diff --git a/.claude/CLAUDE.md b/.claude/CLAUDE.md index 214566f..c14d923 100644 --- a/.claude/CLAUDE.md +++ b/.claude/CLAUDE.md @@ -1,12 +1,12 @@ # Invoice Master POC v2 -Swedish Invoice Field Extraction System - YOLOv11 + PaddleOCR 从瑞典 PDF 发票中提取结构化数据。 +Swedish Invoice Field Extraction System - YOLO26 + PaddleOCR 从瑞典 PDF 发票中提取结构化数据。 ## Tech Stack | Component | Technology | |-----------|------------| -| Object Detection | YOLOv11 (Ultralytics) | +| Object Detection | YOLO26 (Ultralytics >= 8.4.0) | | OCR Engine | PaddleOCR v5 (PP-OCRv5) | | PDF Processing | PyMuPDF (fitz) | | Database | PostgreSQL + psycopg2 | @@ -18,7 +18,7 @@ Swedish Invoice Field Extraction System - YOLOv11 + PaddleOCR 从瑞典 PDF 发 **Prefix ALL commands with:** ```bash -wsl bash -c "source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && " +wsl bash -c "source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-sm120 && " ``` **NEVER run Python commands directly in Windows PowerShell/CMD.** @@ -91,3 +91,5 @@ SERVER_PORT=8000 - Never commit to main directly - PRs require review - All tests must pass before merge + +Push the code before review and fix finished. \ No newline at end of file diff --git a/.claude/commands/build-fix.md b/.claude/commands/build-fix.md deleted file mode 100644 index 5951016..0000000 --- a/.claude/commands/build-fix.md +++ /dev/null @@ -1,22 +0,0 @@ -# Build and Fix - -Incrementally fix Python errors and test failures. - -## Workflow - -1. Run check: `mypy src/ --ignore-missing-imports` or `pytest -x --tb=short` -2. Parse errors, group by file, sort by severity (ImportError > TypeError > other) -3. For each error: - - Show context (5 lines) - - Explain and propose fix - - Apply fix - - Re-run test for that file - - Verify resolved -4. Stop if: fix introduces new errors, same error after 3 attempts, or user pauses -5. Show summary: fixed / remaining / new errors - -## Rules - -- Fix ONE error at a time -- Re-run tests after each fix -- Never batch multiple unrelated fixes \ No newline at end of file diff --git a/.claude/commands/checkpoint.md b/.claude/commands/checkpoint.md deleted file mode 100644 index 06293c0..0000000 --- a/.claude/commands/checkpoint.md +++ /dev/null @@ -1,74 +0,0 @@ -# Checkpoint Command - -Create or verify a checkpoint in your workflow. - -## Usage - -`/checkpoint [create|verify|list] [name]` - -## Create Checkpoint - -When creating a checkpoint: - -1. Run `/verify quick` to ensure current state is clean -2. Create a git stash or commit with checkpoint name -3. Log checkpoint to `.claude/checkpoints.log`: - -```bash -echo "$(date +%Y-%m-%d-%H:%M) | $CHECKPOINT_NAME | $(git rev-parse --short HEAD)" >> .claude/checkpoints.log -``` - -4. Report checkpoint created - -## Verify Checkpoint - -When verifying against a checkpoint: - -1. Read checkpoint from log -2. Compare current state to checkpoint: - - Files added since checkpoint - - Files modified since checkpoint - - Test pass rate now vs then - - Coverage now vs then - -3. Report: -``` -CHECKPOINT COMPARISON: $NAME -============================ -Files changed: X -Tests: +Y passed / -Z failed -Coverage: +X% / -Y% -Build: [PASS/FAIL] -``` - -## List Checkpoints - -Show all checkpoints with: -- Name -- Timestamp -- Git SHA -- Status (current, behind, ahead) - -## Workflow - -Typical checkpoint flow: - -``` -[Start] --> /checkpoint create "feature-start" - | -[Implement] --> /checkpoint create "core-done" - | -[Test] --> /checkpoint verify "core-done" - | -[Refactor] --> /checkpoint create "refactor-done" - | -[PR] --> /checkpoint verify "feature-start" -``` - -## Arguments - -$ARGUMENTS: -- `create ` - Create named checkpoint -- `verify ` - Verify against named checkpoint -- `list` - Show all checkpoints -- `clear` - Remove old checkpoints (keeps last 5) diff --git a/.claude/commands/code-review.md b/.claude/commands/code-review.md deleted file mode 100644 index 25c9e7a..0000000 --- a/.claude/commands/code-review.md +++ /dev/null @@ -1,46 +0,0 @@ -# Code Review - -Security and quality review of uncommitted changes. - -## Workflow - -1. Get changed files: `git diff --name-only HEAD` and `git diff --staged --name-only` -2. Review each file for issues (see checklist below) -3. Run automated checks: `mypy src/`, `ruff check src/`, `pytest -x` -4. Generate report with severity, location, description, suggested fix -5. Block commit if CRITICAL or HIGH issues found - -## Checklist - -### CRITICAL (Block) - -- Hardcoded credentials, API keys, tokens, passwords -- SQL injection (must use parameterized queries) -- Path traversal risks -- Missing input validation on API endpoints -- Missing authentication/authorization - -### HIGH (Block) - -- Functions > 50 lines, files > 800 lines -- Nesting depth > 4 levels -- Missing error handling or bare `except:` -- `print()` in production code (use logging) -- Mutable default arguments - -### MEDIUM (Warn) - -- Missing type hints on public functions -- Missing tests for new code -- Duplicate code, magic numbers -- Unused imports/variables -- TODO/FIXME comments - -## Report Format - -``` -[SEVERITY] file:line - Issue description - Suggested fix: ... -``` - -## Never Approve Code With Security Vulnerabilities! \ No newline at end of file diff --git a/.claude/commands/e2e.md b/.claude/commands/e2e.md deleted file mode 100644 index 6ac6d43..0000000 --- a/.claude/commands/e2e.md +++ /dev/null @@ -1,40 +0,0 @@ -# E2E Testing - -End-to-end testing for the Invoice Field Extraction API. - -## When to Use - -- Testing complete inference pipeline (PDF -> Fields) -- Verifying API endpoints work end-to-end -- Validating YOLO + OCR + field extraction integration -- Pre-deployment verification - -## Workflow - -1. Ensure server is running: `python run_server.py` -2. Run health check: `curl http://localhost:8000/api/v1/health` -3. Run E2E tests: `pytest tests/e2e/ -v` -4. Verify results and capture any failures - -## Critical Scenarios (Must Pass) - -1. Health check returns `{"status": "healthy", "model_loaded": true}` -2. PDF upload returns valid response with fields -3. Fields extracted with confidence scores -4. Visualization image generated -5. Cross-validation included for invoices with payment_line - -## Checklist - -- [ ] Server running on http://localhost:8000 -- [ ] Health check passes -- [ ] PDF inference returns valid JSON -- [ ] At least one field extracted -- [ ] Visualization URL returns image -- [ ] Response time < 10 seconds -- [ ] No server errors in logs - -## Test Location - -E2E tests: `tests/e2e/` -Sample fixtures: `tests/fixtures/` \ No newline at end of file diff --git a/.claude/commands/eval.md b/.claude/commands/eval.md deleted file mode 100644 index 852c175..0000000 --- a/.claude/commands/eval.md +++ /dev/null @@ -1,174 +0,0 @@ -# Eval Command - -Evaluate model performance and field extraction accuracy. - -## Usage - -`/eval [model|accuracy|compare|report]` - -## Model Evaluation - -`/eval model` - -Evaluate YOLO model performance on test dataset: - -```bash -# Run model evaluation -python -m src.cli.train --model runs/train/invoice_fields/weights/best.pt --eval-only - -# Or use ultralytics directly -yolo val model=runs/train/invoice_fields/weights/best.pt data=data.yaml -``` - -Output: -``` -Model Evaluation: invoice_fields/best.pt -======================================== -mAP@0.5: 93.5% -mAP@0.5-0.95: 83.0% - -Per-class AP: -- invoice_number: 95.2% -- invoice_date: 94.8% -- invoice_due_date: 93.1% -- ocr_number: 91.5% -- bankgiro: 92.3% -- plusgiro: 90.8% -- amount: 88.7% -- supplier_org_num: 85.2% -- payment_line: 82.4% -- customer_number: 81.1% -``` - -## Accuracy Evaluation - -`/eval accuracy` - -Evaluate field extraction accuracy against ground truth: - -```bash -# Run accuracy evaluation on labeled data -python -m src.cli.infer --model runs/train/invoice_fields/weights/best.pt \ - --input ~/invoice-data/test/*.pdf \ - --ground-truth ~/invoice-data/test/labels.csv \ - --output eval_results.json -``` - -Output: -``` -Field Extraction Accuracy -========================= -Documents tested: 500 - -Per-field accuracy: -- InvoiceNumber: 98.9% (494/500) -- InvoiceDate: 95.5% (478/500) -- InvoiceDueDate: 95.9% (480/500) -- OCR: 99.1% (496/500) -- Bankgiro: 99.0% (495/500) -- Plusgiro: 99.4% (497/500) -- Amount: 91.3% (457/500) -- supplier_org: 78.2% (391/500) - -Overall: 94.8% -``` - -## Compare Models - -`/eval compare` - -Compare two model versions: - -```bash -# Compare old vs new model -python -m src.cli.eval compare \ - --model-a runs/train/invoice_v1/weights/best.pt \ - --model-b runs/train/invoice_v2/weights/best.pt \ - --test-data ~/invoice-data/test/ -``` - -Output: -``` -Model Comparison -================ - Model A Model B Delta -mAP@0.5: 91.2% 93.5% +2.3% -Accuracy: 92.1% 94.8% +2.7% -Speed (ms): 1850 1520 -330 - -Per-field improvements: -- amount: +4.2% -- payment_line: +3.8% -- customer_num: +2.1% - -Recommendation: Deploy Model B -``` - -## Generate Report - -`/eval report` - -Generate comprehensive evaluation report: - -```bash -python -m src.cli.eval report --output eval_report.md -``` - -Output: -```markdown -# Evaluation Report -Generated: 2026-01-25 - -## Model Performance -- Model: runs/train/invoice_fields/weights/best.pt -- mAP@0.5: 93.5% -- Training samples: 9,738 - -## Field Extraction Accuracy -| Field | Accuracy | Errors | -|-------|----------|--------| -| InvoiceNumber | 98.9% | 6 | -| Amount | 91.3% | 43 | -... - -## Error Analysis -### Common Errors -1. Amount: OCR misreads comma as period -2. supplier_org: Missing from some invoices -3. payment_line: Partially obscured by stamps - -## Recommendations -1. Add more training data for low-accuracy fields -2. Implement OCR error correction for amounts -3. Consider confidence threshold tuning -``` - -## Quick Commands - -```bash -# Evaluate model metrics -yolo val model=runs/train/invoice_fields/weights/best.pt - -# Test inference on sample -python -m src.cli.infer --input sample.pdf --output result.json --gpu - -# Check test coverage -pytest --cov=src --cov-report=html -``` - -## Evaluation Metrics - -| Metric | Target | Current | -|--------|--------|---------| -| mAP@0.5 | >90% | 93.5% | -| Overall Accuracy | >90% | 94.8% | -| Test Coverage | >60% | 37% | -| Tests Passing | 100% | 100% | - -## When to Evaluate - -- After training a new model -- Before deploying to production -- After adding new training data -- When accuracy complaints arise -- Weekly performance monitoring diff --git a/.claude/commands/learn.md b/.claude/commands/learn.md deleted file mode 100644 index 9899af1..0000000 --- a/.claude/commands/learn.md +++ /dev/null @@ -1,70 +0,0 @@ -# /learn - Extract Reusable Patterns - -Analyze the current session and extract any patterns worth saving as skills. - -## Trigger - -Run `/learn` at any point during a session when you've solved a non-trivial problem. - -## What to Extract - -Look for: - -1. **Error Resolution Patterns** - - What error occurred? - - What was the root cause? - - What fixed it? - - Is this reusable for similar errors? - -2. **Debugging Techniques** - - Non-obvious debugging steps - - Tool combinations that worked - - Diagnostic patterns - -3. **Workarounds** - - Library quirks - - API limitations - - Version-specific fixes - -4. **Project-Specific Patterns** - - Codebase conventions discovered - - Architecture decisions made - - Integration patterns - -## Output Format - -Create a skill file at `~/.claude/skills/learned/[pattern-name].md`: - -```markdown -# [Descriptive Pattern Name] - -**Extracted:** [Date] -**Context:** [Brief description of when this applies] - -## Problem -[What problem this solves - be specific] - -## Solution -[The pattern/technique/workaround] - -## Example -[Code example if applicable] - -## When to Use -[Trigger conditions - what should activate this skill] -``` - -## Process - -1. Review the session for extractable patterns -2. Identify the most valuable/reusable insight -3. Draft the skill file -4. Ask user to confirm before saving -5. Save to `~/.claude/skills/learned/` - -## Notes - -- Don't extract trivial fixes (typos, simple syntax errors) -- Don't extract one-time issues (specific API outages, etc.) -- Focus on patterns that will save time in future sessions -- Keep skills focused - one pattern per skill diff --git a/.claude/commands/orchestrate.md b/.claude/commands/orchestrate.md deleted file mode 100644 index 30ac2b8..0000000 --- a/.claude/commands/orchestrate.md +++ /dev/null @@ -1,172 +0,0 @@ -# Orchestrate Command - -Sequential agent workflow for complex tasks. - -## Usage - -`/orchestrate [workflow-type] [task-description]` - -## Workflow Types - -### feature -Full feature implementation workflow: -``` -planner -> tdd-guide -> code-reviewer -> security-reviewer -``` - -### bugfix -Bug investigation and fix workflow: -``` -explorer -> tdd-guide -> code-reviewer -``` - -### refactor -Safe refactoring workflow: -``` -architect -> code-reviewer -> tdd-guide -``` - -### security -Security-focused review: -``` -security-reviewer -> code-reviewer -> architect -``` - -## Execution Pattern - -For each agent in the workflow: - -1. **Invoke agent** with context from previous agent -2. **Collect output** as structured handoff document -3. **Pass to next agent** in chain -4. **Aggregate results** into final report - -## Handoff Document Format - -Between agents, create handoff document: - -```markdown -## HANDOFF: [previous-agent] -> [next-agent] - -### Context -[Summary of what was done] - -### Findings -[Key discoveries or decisions] - -### Files Modified -[List of files touched] - -### Open Questions -[Unresolved items for next agent] - -### Recommendations -[Suggested next steps] -``` - -## Example: Feature Workflow - -``` -/orchestrate feature "Add user authentication" -``` - -Executes: - -1. **Planner Agent** - - Analyzes requirements - - Creates implementation plan - - Identifies dependencies - - Output: `HANDOFF: planner -> tdd-guide` - -2. **TDD Guide Agent** - - Reads planner handoff - - Writes tests first - - Implements to pass tests - - Output: `HANDOFF: tdd-guide -> code-reviewer` - -3. **Code Reviewer Agent** - - Reviews implementation - - Checks for issues - - Suggests improvements - - Output: `HANDOFF: code-reviewer -> security-reviewer` - -4. **Security Reviewer Agent** - - Security audit - - Vulnerability check - - Final approval - - Output: Final Report - -## Final Report Format - -``` -ORCHESTRATION REPORT -==================== -Workflow: feature -Task: Add user authentication -Agents: planner -> tdd-guide -> code-reviewer -> security-reviewer - -SUMMARY -------- -[One paragraph summary] - -AGENT OUTPUTS -------------- -Planner: [summary] -TDD Guide: [summary] -Code Reviewer: [summary] -Security Reviewer: [summary] - -FILES CHANGED -------------- -[List all files modified] - -TEST RESULTS ------------- -[Test pass/fail summary] - -SECURITY STATUS ---------------- -[Security findings] - -RECOMMENDATION --------------- -[SHIP / NEEDS WORK / BLOCKED] -``` - -## Parallel Execution - -For independent checks, run agents in parallel: - -```markdown -### Parallel Phase -Run simultaneously: -- code-reviewer (quality) -- security-reviewer (security) -- architect (design) - -### Merge Results -Combine outputs into single report -``` - -## Arguments - -$ARGUMENTS: -- `feature ` - Full feature workflow -- `bugfix ` - Bug fix workflow -- `refactor ` - Refactoring workflow -- `security ` - Security review workflow -- `custom ` - Custom agent sequence - -## Custom Workflow Example - -``` -/orchestrate custom "architect,tdd-guide,code-reviewer" "Redesign caching layer" -``` - -## Tips - -1. **Start with planner** for complex features -2. **Always include code-reviewer** before merge -3. **Use security-reviewer** for auth/payment/PII -4. **Keep handoffs concise** - focus on what next agent needs -5. **Run verification** between agents if needed diff --git a/.claude/commands/plan.md b/.claude/commands/plan.md deleted file mode 100644 index 3acf686..0000000 --- a/.claude/commands/plan.md +++ /dev/null @@ -1,113 +0,0 @@ ---- -description: Restate requirements, assess risks, and create step-by-step implementation plan. WAIT for user CONFIRM before touching any code. ---- - -# Plan Command - -This command invokes the **planner** agent to create a comprehensive implementation plan before writing any code. - -## What This Command Does - -1. **Restate Requirements** - Clarify what needs to be built -2. **Identify Risks** - Surface potential issues and blockers -3. **Create Step Plan** - Break down implementation into phases -4. **Wait for Confirmation** - MUST receive user approval before proceeding - -## When to Use - -Use `/plan` when: -- Starting a new feature -- Making significant architectural changes -- Working on complex refactoring -- Multiple files/components will be affected -- Requirements are unclear or ambiguous - -## How It Works - -The planner agent will: - -1. **Analyze the request** and restate requirements in clear terms -2. **Break down into phases** with specific, actionable steps -3. **Identify dependencies** between components -4. **Assess risks** and potential blockers -5. **Estimate complexity** (High/Medium/Low) -6. **Present the plan** and WAIT for your explicit confirmation - -## Example Usage - -``` -User: /plan I need to add real-time notifications when markets resolve - -Agent (planner): -# Implementation Plan: Real-Time Market Resolution Notifications - -## Requirements Restatement -- Send notifications to users when markets they're watching resolve -- Support multiple notification channels (in-app, email, webhook) -- Ensure notifications are delivered reliably -- Include market outcome and user's position result - -## Implementation Phases - -### Phase 1: Database Schema -- Add notifications table with columns: id, user_id, market_id, type, status, created_at -- Add user_notification_preferences table for channel preferences -- Create indexes on user_id and market_id for performance - -### Phase 2: Notification Service -- Create notification service in lib/notifications.ts -- Implement notification queue using BullMQ/Redis -- Add retry logic for failed deliveries -- Create notification templates - -### Phase 3: Integration Points -- Hook into market resolution logic (when status changes to "resolved") -- Query all users with positions in market -- Enqueue notifications for each user - -### Phase 4: Frontend Components -- Create NotificationBell component in header -- Add NotificationList modal -- Implement real-time updates via Supabase subscriptions -- Add notification preferences page - -## Dependencies -- Redis (for queue) -- Email service (SendGrid/Resend) -- Supabase real-time subscriptions - -## Risks -- HIGH: Email deliverability (SPF/DKIM required) -- MEDIUM: Performance with 1000+ users per market -- MEDIUM: Notification spam if markets resolve frequently -- LOW: Real-time subscription overhead - -## Estimated Complexity: MEDIUM -- Backend: 4-6 hours -- Frontend: 3-4 hours -- Testing: 2-3 hours -- Total: 9-13 hours - -**WAITING FOR CONFIRMATION**: Proceed with this plan? (yes/no/modify) -``` - -## Important Notes - -**CRITICAL**: The planner agent will **NOT** write any code until you explicitly confirm the plan with "yes" or "proceed" or similar affirmative response. - -If you want changes, respond with: -- "modify: [your changes]" -- "different approach: [alternative]" -- "skip phase 2 and do phase 3 first" - -## Integration with Other Commands - -After planning: -- Use `/tdd` to implement with test-driven development -- Use `/build-and-fix` if build errors occur -- Use `/code-review` to review completed implementation - -## Related Agents - -This command invokes the `planner` agent located at: -`~/.claude/agents/planner.md` diff --git a/.claude/commands/refactor-clean.md b/.claude/commands/refactor-clean.md deleted file mode 100644 index 6f5e250..0000000 --- a/.claude/commands/refactor-clean.md +++ /dev/null @@ -1,28 +0,0 @@ -# Refactor Clean - -Safely identify and remove dead code with test verification: - -1. Run dead code analysis tools: - - knip: Find unused exports and files - - depcheck: Find unused dependencies - - ts-prune: Find unused TypeScript exports - -2. Generate comprehensive report in .reports/dead-code-analysis.md - -3. Categorize findings by severity: - - SAFE: Test files, unused utilities - - CAUTION: API routes, components - - DANGER: Config files, main entry points - -4. Propose safe deletions only - -5. Before each deletion: - - Run full test suite - - Verify tests pass - - Apply change - - Re-run tests - - Rollback if tests fail - -6. Show summary of cleaned items - -Never delete code without running tests first! diff --git a/.claude/commands/setup-pm.md b/.claude/commands/setup-pm.md deleted file mode 100644 index 87224b9..0000000 --- a/.claude/commands/setup-pm.md +++ /dev/null @@ -1,80 +0,0 @@ ---- -description: Configure your preferred package manager (npm/pnpm/yarn/bun) -disable-model-invocation: true ---- - -# Package Manager Setup - -Configure your preferred package manager for this project or globally. - -## Usage - -```bash -# Detect current package manager -node scripts/setup-package-manager.js --detect - -# Set global preference -node scripts/setup-package-manager.js --global pnpm - -# Set project preference -node scripts/setup-package-manager.js --project bun - -# List available package managers -node scripts/setup-package-manager.js --list -``` - -## Detection Priority - -When determining which package manager to use, the following order is checked: - -1. **Environment variable**: `CLAUDE_PACKAGE_MANAGER` -2. **Project config**: `.claude/package-manager.json` -3. **package.json**: `packageManager` field -4. **Lock file**: Presence of package-lock.json, yarn.lock, pnpm-lock.yaml, or bun.lockb -5. **Global config**: `~/.claude/package-manager.json` -6. **Fallback**: First available package manager (pnpm > bun > yarn > npm) - -## Configuration Files - -### Global Configuration -```json -// ~/.claude/package-manager.json -{ - "packageManager": "pnpm" -} -``` - -### Project Configuration -```json -// .claude/package-manager.json -{ - "packageManager": "bun" -} -``` - -### package.json -```json -{ - "packageManager": "pnpm@8.6.0" -} -``` - -## Environment Variable - -Set `CLAUDE_PACKAGE_MANAGER` to override all other detection methods: - -```bash -# Windows (PowerShell) -$env:CLAUDE_PACKAGE_MANAGER = "pnpm" - -# macOS/Linux -export CLAUDE_PACKAGE_MANAGER=pnpm -``` - -## Run the Detection - -To see current package manager detection results, run: - -```bash -node scripts/setup-package-manager.js --detect -``` diff --git a/.claude/commands/tdd.md b/.claude/commands/tdd.md deleted file mode 100644 index 02bdb2d..0000000 --- a/.claude/commands/tdd.md +++ /dev/null @@ -1,326 +0,0 @@ ---- -description: Enforce test-driven development workflow. Scaffold interfaces, generate tests FIRST, then implement minimal code to pass. Ensure 80%+ coverage. ---- - -# TDD Command - -This command invokes the **tdd-guide** agent to enforce test-driven development methodology. - -## What This Command Does - -1. **Scaffold Interfaces** - Define types/interfaces first -2. **Generate Tests First** - Write failing tests (RED) -3. **Implement Minimal Code** - Write just enough to pass (GREEN) -4. **Refactor** - Improve code while keeping tests green (REFACTOR) -5. **Verify Coverage** - Ensure 80%+ test coverage - -## When to Use - -Use `/tdd` when: -- Implementing new features -- Adding new functions/components -- Fixing bugs (write test that reproduces bug first) -- Refactoring existing code -- Building critical business logic - -## How It Works - -The tdd-guide agent will: - -1. **Define interfaces** for inputs/outputs -2. **Write tests that will FAIL** (because code doesn't exist yet) -3. **Run tests** and verify they fail for the right reason -4. **Write minimal implementation** to make tests pass -5. **Run tests** and verify they pass -6. **Refactor** code while keeping tests green -7. **Check coverage** and add more tests if below 80% - -## TDD Cycle - -``` -RED → GREEN → REFACTOR → REPEAT - -RED: Write a failing test -GREEN: Write minimal code to pass -REFACTOR: Improve code, keep tests passing -REPEAT: Next feature/scenario -``` - -## Example Usage - -``` -User: /tdd I need a function to calculate market liquidity score - -Agent (tdd-guide): -# TDD Session: Market Liquidity Score Calculator - -## Step 1: Define Interface (SCAFFOLD) - -```typescript -// lib/liquidity.ts -export interface MarketData { - totalVolume: number - bidAskSpread: number - activeTraders: number - lastTradeTime: Date -} - -export function calculateLiquidityScore(market: MarketData): number { - // TODO: Implementation - throw new Error('Not implemented') -} -``` - -## Step 2: Write Failing Test (RED) - -```typescript -// lib/liquidity.test.ts -import { calculateLiquidityScore } from './liquidity' - -describe('calculateLiquidityScore', () => { - it('should return high score for liquid market', () => { - const market = { - totalVolume: 100000, - bidAskSpread: 0.01, - activeTraders: 500, - lastTradeTime: new Date() - } - - const score = calculateLiquidityScore(market) - - expect(score).toBeGreaterThan(80) - expect(score).toBeLessThanOrEqual(100) - }) - - it('should return low score for illiquid market', () => { - const market = { - totalVolume: 100, - bidAskSpread: 0.5, - activeTraders: 2, - lastTradeTime: new Date(Date.now() - 86400000) // 1 day ago - } - - const score = calculateLiquidityScore(market) - - expect(score).toBeLessThan(30) - expect(score).toBeGreaterThanOrEqual(0) - }) - - it('should handle edge case: zero volume', () => { - const market = { - totalVolume: 0, - bidAskSpread: 0, - activeTraders: 0, - lastTradeTime: new Date() - } - - const score = calculateLiquidityScore(market) - - expect(score).toBe(0) - }) -}) -``` - -## Step 3: Run Tests - Verify FAIL - -```bash -npm test lib/liquidity.test.ts - -FAIL lib/liquidity.test.ts - ✕ should return high score for liquid market (2 ms) - Error: Not implemented - -1 test failed, 0 passed -``` - -✅ Tests fail as expected. Ready to implement. - -## Step 4: Implement Minimal Code (GREEN) - -```typescript -// lib/liquidity.ts -export function calculateLiquidityScore(market: MarketData): number { - // Handle zero volume edge case - if (market.totalVolume === 0) { - return 0 - } - - // Calculate component scores (0-100 scale) - const volumeScore = Math.min(market.totalVolume / 1000, 100) - const spreadScore = Math.max(100 - (market.bidAskSpread * 1000), 0) - const traderScore = Math.min(market.activeTraders / 10, 100) - - // Recent activity bonus - const hoursSinceLastTrade = (Date.now() - market.lastTradeTime.getTime()) / (1000 * 60 * 60) - const recencyScore = Math.max(100 - (hoursSinceLastTrade * 10), 0) - - // Weighted average - const score = ( - volumeScore * 0.4 + - spreadScore * 0.3 + - traderScore * 0.2 + - recencyScore * 0.1 - ) - - return Math.min(Math.max(score, 0), 100) // Clamp to 0-100 -} -``` - -## Step 5: Run Tests - Verify PASS - -```bash -npm test lib/liquidity.test.ts - -PASS lib/liquidity.test.ts - ✓ should return high score for liquid market (3 ms) - ✓ should return low score for illiquid market (2 ms) - ✓ should handle edge case: zero volume (1 ms) - -3 tests passed -``` - -✅ All tests passing! - -## Step 6: Refactor (IMPROVE) - -```typescript -// lib/liquidity.ts - Refactored with constants and better readability -const WEIGHTS = { - VOLUME: 0.4, - SPREAD: 0.3, - TRADERS: 0.2, - RECENCY: 0.1, -} as const - -const SCALE_FACTORS = { - VOLUME: 1000, - SPREAD: 1000, - TRADERS: 10, - RECENCY_PENALTY: 10, -} as const - -function clamp(value: number, min: number, max: number): number { - return Math.min(Math.max(value, min), max) -} - -export function calculateLiquidityScore(market: MarketData): number { - if (market.totalVolume === 0) return 0 - - const volumeScore = Math.min(market.totalVolume / SCALE_FACTORS.VOLUME, 100) - const spreadScore = clamp(100 - (market.bidAskSpread * SCALE_FACTORS.SPREAD), 0, 100) - const traderScore = Math.min(market.activeTraders / SCALE_FACTORS.TRADERS, 100) - - const hoursSinceLastTrade = (Date.now() - market.lastTradeTime.getTime()) / (1000 * 60 * 60) - const recencyScore = clamp(100 - (hoursSinceLastTrade * SCALE_FACTORS.RECENCY_PENALTY), 0, 100) - - const weightedScore = - volumeScore * WEIGHTS.VOLUME + - spreadScore * WEIGHTS.SPREAD + - traderScore * WEIGHTS.TRADERS + - recencyScore * WEIGHTS.RECENCY - - return clamp(weightedScore, 0, 100) -} -``` - -## Step 7: Verify Tests Still Pass - -```bash -npm test lib/liquidity.test.ts - -PASS lib/liquidity.test.ts - ✓ should return high score for liquid market (3 ms) - ✓ should return low score for illiquid market (2 ms) - ✓ should handle edge case: zero volume (1 ms) - -3 tests passed -``` - -✅ Refactoring complete, tests still passing! - -## Step 8: Check Coverage - -```bash -npm test -- --coverage lib/liquidity.test.ts - -File | % Stmts | % Branch | % Funcs | % Lines ----------------|---------|----------|---------|-------- -liquidity.ts | 100 | 100 | 100 | 100 - -Coverage: 100% ✅ (Target: 80%) -``` - -✅ TDD session complete! -``` - -## TDD Best Practices - -**DO:** -- ✅ Write the test FIRST, before any implementation -- ✅ Run tests and verify they FAIL before implementing -- ✅ Write minimal code to make tests pass -- ✅ Refactor only after tests are green -- ✅ Add edge cases and error scenarios -- ✅ Aim for 80%+ coverage (100% for critical code) - -**DON'T:** -- ❌ Write implementation before tests -- ❌ Skip running tests after each change -- ❌ Write too much code at once -- ❌ Ignore failing tests -- ❌ Test implementation details (test behavior) -- ❌ Mock everything (prefer integration tests) - -## Test Types to Include - -**Unit Tests** (Function-level): -- Happy path scenarios -- Edge cases (empty, null, max values) -- Error conditions -- Boundary values - -**Integration Tests** (Component-level): -- API endpoints -- Database operations -- External service calls -- React components with hooks - -**E2E Tests** (use `/e2e` command): -- Critical user flows -- Multi-step processes -- Full stack integration - -## Coverage Requirements - -- **80% minimum** for all code -- **100% required** for: - - Financial calculations - - Authentication logic - - Security-critical code - - Core business logic - -## Important Notes - -**MANDATORY**: Tests must be written BEFORE implementation. The TDD cycle is: - -1. **RED** - Write failing test -2. **GREEN** - Implement to pass -3. **REFACTOR** - Improve code - -Never skip the RED phase. Never write code before tests. - -## Integration with Other Commands - -- Use `/plan` first to understand what to build -- Use `/tdd` to implement with tests -- Use `/build-and-fix` if build errors occur -- Use `/code-review` to review implementation -- Use `/test-coverage` to verify coverage - -## Related Agents - -This command invokes the `tdd-guide` agent located at: -`~/.claude/agents/tdd-guide.md` - -And can reference the `tdd-workflow` skill at: -`~/.claude/skills/tdd-workflow/` diff --git a/.claude/commands/test-coverage.md b/.claude/commands/test-coverage.md deleted file mode 100644 index 754eabf..0000000 --- a/.claude/commands/test-coverage.md +++ /dev/null @@ -1,27 +0,0 @@ -# Test Coverage - -Analyze test coverage and generate missing tests: - -1. Run tests with coverage: npm test --coverage or pnpm test --coverage - -2. Analyze coverage report (coverage/coverage-summary.json) - -3. Identify files below 80% coverage threshold - -4. For each under-covered file: - - Analyze untested code paths - - Generate unit tests for functions - - Generate integration tests for APIs - - Generate E2E tests for critical flows - -5. Verify new tests pass - -6. Show before/after coverage metrics - -7. Ensure project reaches 80%+ overall coverage - -Focus on: -- Happy path scenarios -- Error handling -- Edge cases (null, undefined, empty) -- Boundary conditions diff --git a/.claude/commands/update-codemaps.md b/.claude/commands/update-codemaps.md deleted file mode 100644 index f363a05..0000000 --- a/.claude/commands/update-codemaps.md +++ /dev/null @@ -1,17 +0,0 @@ -# Update Codemaps - -Analyze the codebase structure and update architecture documentation: - -1. Scan all source files for imports, exports, and dependencies -2. Generate token-lean codemaps in the following format: - - codemaps/architecture.md - Overall architecture - - codemaps/backend.md - Backend structure - - codemaps/frontend.md - Frontend structure - - codemaps/data.md - Data models and schemas - -3. Calculate diff percentage from previous version -4. If changes > 30%, request user approval before updating -5. Add freshness timestamp to each codemap -6. Save reports to .reports/codemap-diff.txt - -Use TypeScript/Node.js for analysis. Focus on high-level structure, not implementation details. diff --git a/.claude/commands/update-docs.md b/.claude/commands/update-docs.md deleted file mode 100644 index 3dd0f89..0000000 --- a/.claude/commands/update-docs.md +++ /dev/null @@ -1,31 +0,0 @@ -# Update Documentation - -Sync documentation from source-of-truth: - -1. Read package.json scripts section - - Generate scripts reference table - - Include descriptions from comments - -2. Read .env.example - - Extract all environment variables - - Document purpose and format - -3. Generate docs/CONTRIB.md with: - - Development workflow - - Available scripts - - Environment setup - - Testing procedures - -4. Generate docs/RUNBOOK.md with: - - Deployment procedures - - Monitoring and alerts - - Common issues and fixes - - Rollback procedures - -5. Identify obsolete documentation: - - Find docs not modified in 90+ days - - List for manual review - -6. Show diff summary - -Single source of truth: package.json and .env.example diff --git a/.claude/commands/verify.md b/.claude/commands/verify.md deleted file mode 100644 index 5f628b1..0000000 --- a/.claude/commands/verify.md +++ /dev/null @@ -1,59 +0,0 @@ -# Verification Command - -Run comprehensive verification on current codebase state. - -## Instructions - -Execute verification in this exact order: - -1. **Build Check** - - Run the build command for this project - - If it fails, report errors and STOP - -2. **Type Check** - - Run TypeScript/type checker - - Report all errors with file:line - -3. **Lint Check** - - Run linter - - Report warnings and errors - -4. **Test Suite** - - Run all tests - - Report pass/fail count - - Report coverage percentage - -5. **Console.log Audit** - - Search for console.log in source files - - Report locations - -6. **Git Status** - - Show uncommitted changes - - Show files modified since last commit - -## Output - -Produce a concise verification report: - -``` -VERIFICATION: [PASS/FAIL] - -Build: [OK/FAIL] -Types: [OK/X errors] -Lint: [OK/X issues] -Tests: [X/Y passed, Z% coverage] -Secrets: [OK/X found] -Logs: [OK/X console.logs] - -Ready for PR: [YES/NO] -``` - -If any critical issues, list them with fix suggestions. - -## Arguments - -$ARGUMENTS can be: -- `quick` - Only build + types -- `full` - All checks (default) -- `pre-commit` - Checks relevant for commits -- `pre-pr` - Full checks plus security scan diff --git a/.claude/hooks/hooks.json b/.claude/hooks/hooks.json deleted file mode 100644 index ea9cdc6..0000000 --- a/.claude/hooks/hooks.json +++ /dev/null @@ -1,157 +0,0 @@ -{ - "$schema": "https://json.schemastore.org/claude-code-settings.json", - "hooks": { - "PreToolUse": [ - { - "matcher": "tool == \"Bash\" && tool_input.command matches \"(npm run dev|pnpm( run)? dev|yarn dev|bun run dev)\"", - "hooks": [ - { - "type": "command", - "command": "node -e \"console.error('[Hook] BLOCKED: Dev server must run in tmux for log access');console.error('[Hook] Use: tmux new-session -d -s dev \\\"npm run dev\\\"');console.error('[Hook] Then: tmux attach -t dev');process.exit(1)\"" - } - ], - "description": "Block dev servers outside tmux - ensures you can access logs" - }, - { - "matcher": "tool == \"Bash\" && tool_input.command matches \"(npm (install|test)|pnpm (install|test)|yarn (install|test)?|bun (install|test)|cargo build|make|docker|pytest|vitest|playwright)\"", - "hooks": [ - { - "type": "command", - "command": "node -e \"if(!process.env.TMUX){console.error('[Hook] Consider running in tmux for session persistence');console.error('[Hook] tmux new -s dev | tmux attach -t dev')}\"" - } - ], - "description": "Reminder to use tmux for long-running commands" - }, - { - "matcher": "tool == \"Bash\" && tool_input.command matches \"git push\"", - "hooks": [ - { - "type": "command", - "command": "node -e \"console.error('[Hook] Review changes before push...');console.error('[Hook] Continuing with push (remove this hook to add interactive review)')\"" - } - ], - "description": "Reminder before git push to review changes" - }, - { - "matcher": "tool == \"Write\" && tool_input.file_path matches \"\\\\.(md|txt)$\" && !(tool_input.file_path matches \"README\\\\.md|CLAUDE\\\\.md|AGENTS\\\\.md|CONTRIBUTING\\\\.md\")", - "hooks": [ - { - "type": "command", - "command": "node -e \"const fs=require('fs');let d='';process.stdin.on('data',c=>d+=c);process.stdin.on('end',()=>{const i=JSON.parse(d);const p=i.tool_input?.file_path||'';if(/\\.(md|txt)$/.test(p)&&!/(README|CLAUDE|AGENTS|CONTRIBUTING)\\.md$/.test(p)){console.error('[Hook] BLOCKED: Unnecessary documentation file creation');console.error('[Hook] File: '+p);console.error('[Hook] Use README.md for documentation instead');process.exit(1)}console.log(d)})\"" - } - ], - "description": "Block creation of random .md files - keeps docs consolidated" - }, - { - "matcher": "tool == \"Edit\" || tool == \"Write\"", - "hooks": [ - { - "type": "command", - "command": "node \"${CLAUDE_PLUGIN_ROOT}/scripts/hooks/suggest-compact.js\"" - } - ], - "description": "Suggest manual compaction at logical intervals" - } - ], - "PreCompact": [ - { - "matcher": "*", - "hooks": [ - { - "type": "command", - "command": "node \"${CLAUDE_PLUGIN_ROOT}/scripts/hooks/pre-compact.js\"" - } - ], - "description": "Save state before context compaction" - } - ], - "SessionStart": [ - { - "matcher": "*", - "hooks": [ - { - "type": "command", - "command": "node \"${CLAUDE_PLUGIN_ROOT}/scripts/hooks/session-start.js\"" - } - ], - "description": "Load previous context and detect package manager on new session" - } - ], - "PostToolUse": [ - { - "matcher": "tool == \"Bash\"", - "hooks": [ - { - "type": "command", - "command": "node -e \"let d='';process.stdin.on('data',c=>d+=c);process.stdin.on('end',()=>{const i=JSON.parse(d);const cmd=i.tool_input?.command||'';if(/gh pr create/.test(cmd)){const out=i.tool_output?.output||'';const m=out.match(/https:\\/\\/github.com\\/[^/]+\\/[^/]+\\/pull\\/\\d+/);if(m){console.error('[Hook] PR created: '+m[0]);const repo=m[0].replace(/https:\\/\\/github.com\\/([^/]+\\/[^/]+)\\/pull\\/\\d+/,'$1');const pr=m[0].replace(/.*\\/pull\\/(\\d+)/,'$1');console.error('[Hook] To review: gh pr review '+pr+' --repo '+repo)}}console.log(d)})\"" - } - ], - "description": "Log PR URL and provide review command after PR creation" - }, - { - "matcher": "tool == \"Edit\" && tool_input.file_path matches \"\\\\.(ts|tsx|js|jsx)$\"", - "hooks": [ - { - "type": "command", - "command": "node -e \"const{execSync}=require('child_process');const fs=require('fs');let d='';process.stdin.on('data',c=>d+=c);process.stdin.on('end',()=>{const i=JSON.parse(d);const p=i.tool_input?.file_path;if(p&&fs.existsSync(p)){try{execSync('npx prettier --write \"'+p+'\"',{stdio:['pipe','pipe','pipe']})}catch(e){}}console.log(d)})\"" - } - ], - "description": "Auto-format JS/TS files with Prettier after edits" - }, - { - "matcher": "tool == \"Edit\" && tool_input.file_path matches \"\\\\.(ts|tsx)$\"", - "hooks": [ - { - "type": "command", - "command": "node -e \"const{execSync}=require('child_process');const fs=require('fs');const path=require('path');let d='';process.stdin.on('data',c=>d+=c);process.stdin.on('end',()=>{const i=JSON.parse(d);const p=i.tool_input?.file_path;if(p&&fs.existsSync(p)){let dir=path.dirname(p);while(dir!==path.dirname(dir)&&!fs.existsSync(path.join(dir,'tsconfig.json'))){dir=path.dirname(dir)}if(fs.existsSync(path.join(dir,'tsconfig.json'))){try{const r=execSync('npx tsc --noEmit --pretty false 2>&1',{cwd:dir,encoding:'utf8',stdio:['pipe','pipe','pipe']});const lines=r.split('\\n').filter(l=>l.includes(p)).slice(0,10);if(lines.length)console.error(lines.join('\\n'))}catch(e){const lines=(e.stdout||'').split('\\n').filter(l=>l.includes(p)).slice(0,10);if(lines.length)console.error(lines.join('\\n'))}}}console.log(d)})\"" - } - ], - "description": "TypeScript check after editing .ts/.tsx files" - }, - { - "matcher": "tool == \"Edit\" && tool_input.file_path matches \"\\\\.(ts|tsx|js|jsx)$\"", - "hooks": [ - { - "type": "command", - "command": "node -e \"const fs=require('fs');let d='';process.stdin.on('data',c=>d+=c);process.stdin.on('end',()=>{const i=JSON.parse(d);const p=i.tool_input?.file_path;if(p&&fs.existsSync(p)){const c=fs.readFileSync(p,'utf8');const lines=c.split('\\n');const matches=[];lines.forEach((l,idx)=>{if(/console\\.log/.test(l))matches.push((idx+1)+': '+l.trim())});if(matches.length){console.error('[Hook] WARNING: console.log found in '+p);matches.slice(0,5).forEach(m=>console.error(m));console.error('[Hook] Remove console.log before committing')}}console.log(d)})\"" - } - ], - "description": "Warn about console.log statements after edits" - } - ], - "Stop": [ - { - "matcher": "*", - "hooks": [ - { - "type": "command", - "command": "node -e \"const{execSync}=require('child_process');const fs=require('fs');let d='';process.stdin.on('data',c=>d+=c);process.stdin.on('end',()=>{try{execSync('git rev-parse --git-dir',{stdio:'pipe'})}catch{console.log(d);process.exit(0)}try{const files=execSync('git diff --name-only HEAD',{encoding:'utf8',stdio:['pipe','pipe','pipe']}).split('\\n').filter(f=>/\\.(ts|tsx|js|jsx)$/.test(f)&&fs.existsSync(f));let hasConsole=false;for(const f of files){if(fs.readFileSync(f,'utf8').includes('console.log')){console.error('[Hook] WARNING: console.log found in '+f);hasConsole=true}}if(hasConsole)console.error('[Hook] Remove console.log statements before committing')}catch(e){}console.log(d)})\"" - } - ], - "description": "Check for console.log in modified files after each response" - } - ], - "SessionEnd": [ - { - "matcher": "*", - "hooks": [ - { - "type": "command", - "command": "node \"${CLAUDE_PLUGIN_ROOT}/scripts/hooks/session-end.js\"" - } - ], - "description": "Persist session state on end" - }, - { - "matcher": "*", - "hooks": [ - { - "type": "command", - "command": "node \"${CLAUDE_PLUGIN_ROOT}/scripts/hooks/evaluate-session.js\"" - } - ], - "description": "Evaluate session for extractable patterns" - } - ] - } -} diff --git a/.claude/hooks/memory-persistence/pre-compact.sh b/.claude/hooks/memory-persistence/pre-compact.sh deleted file mode 100644 index 296fce9..0000000 --- a/.claude/hooks/memory-persistence/pre-compact.sh +++ /dev/null @@ -1,36 +0,0 @@ -#!/bin/bash -# PreCompact Hook - Save state before context compaction -# -# Runs before Claude compacts context, giving you a chance to -# preserve important state that might get lost in summarization. -# -# Hook config (in ~/.claude/settings.json): -# { -# "hooks": { -# "PreCompact": [{ -# "matcher": "*", -# "hooks": [{ -# "type": "command", -# "command": "~/.claude/hooks/memory-persistence/pre-compact.sh" -# }] -# }] -# } -# } - -SESSIONS_DIR="${HOME}/.claude/sessions" -COMPACTION_LOG="${SESSIONS_DIR}/compaction-log.txt" - -mkdir -p "$SESSIONS_DIR" - -# Log compaction event with timestamp -echo "[$(date '+%Y-%m-%d %H:%M:%S')] Context compaction triggered" >> "$COMPACTION_LOG" - -# If there's an active session file, note the compaction -ACTIVE_SESSION=$(ls -t "$SESSIONS_DIR"/*.tmp 2>/dev/null | head -1) -if [ -n "$ACTIVE_SESSION" ]; then - echo "" >> "$ACTIVE_SESSION" - echo "---" >> "$ACTIVE_SESSION" - echo "**[Compaction occurred at $(date '+%H:%M')]** - Context was summarized" >> "$ACTIVE_SESSION" -fi - -echo "[PreCompact] State saved before compaction" >&2 diff --git a/.claude/hooks/memory-persistence/session-end.sh b/.claude/hooks/memory-persistence/session-end.sh deleted file mode 100644 index 93b0f63..0000000 --- a/.claude/hooks/memory-persistence/session-end.sh +++ /dev/null @@ -1,61 +0,0 @@ -#!/bin/bash -# Stop Hook (Session End) - Persist learnings when session ends -# -# Runs when Claude session ends. Creates/updates session log file -# with timestamp for continuity tracking. -# -# Hook config (in ~/.claude/settings.json): -# { -# "hooks": { -# "Stop": [{ -# "matcher": "*", -# "hooks": [{ -# "type": "command", -# "command": "~/.claude/hooks/memory-persistence/session-end.sh" -# }] -# }] -# } -# } - -SESSIONS_DIR="${HOME}/.claude/sessions" -TODAY=$(date '+%Y-%m-%d') -SESSION_FILE="${SESSIONS_DIR}/${TODAY}-session.tmp" - -mkdir -p "$SESSIONS_DIR" - -# If session file exists for today, update the end time -if [ -f "$SESSION_FILE" ]; then - # Update Last Updated timestamp - sed -i '' "s/\*\*Last Updated:\*\*.*/\*\*Last Updated:\*\* $(date '+%H:%M')/" "$SESSION_FILE" 2>/dev/null || \ - sed -i "s/\*\*Last Updated:\*\*.*/\*\*Last Updated:\*\* $(date '+%H:%M')/" "$SESSION_FILE" 2>/dev/null - echo "[SessionEnd] Updated session file: $SESSION_FILE" >&2 -else - # Create new session file with template - cat > "$SESSION_FILE" << EOF -# Session: $(date '+%Y-%m-%d') -**Date:** $TODAY -**Started:** $(date '+%H:%M') -**Last Updated:** $(date '+%H:%M') - ---- - -## Current State - -[Session context goes here] - -### Completed -- [ ] - -### In Progress -- [ ] - -### Notes for Next Session -- - -### Context to Load -\`\`\` -[relevant files] -\`\`\` -EOF - echo "[SessionEnd] Created session file: $SESSION_FILE" >&2 -fi diff --git a/.claude/hooks/memory-persistence/session-start.sh b/.claude/hooks/memory-persistence/session-start.sh deleted file mode 100644 index 57a8c14..0000000 --- a/.claude/hooks/memory-persistence/session-start.sh +++ /dev/null @@ -1,37 +0,0 @@ -#!/bin/bash -# SessionStart Hook - Load previous context on new session -# -# Runs when a new Claude session starts. Checks for recent session -# files and notifies Claude of available context to load. -# -# Hook config (in ~/.claude/settings.json): -# { -# "hooks": { -# "SessionStart": [{ -# "matcher": "*", -# "hooks": [{ -# "type": "command", -# "command": "~/.claude/hooks/memory-persistence/session-start.sh" -# }] -# }] -# } -# } - -SESSIONS_DIR="${HOME}/.claude/sessions" -LEARNED_DIR="${HOME}/.claude/skills/learned" - -# Check for recent session files (last 7 days) -recent_sessions=$(find "$SESSIONS_DIR" -name "*.tmp" -mtime -7 2>/dev/null | wc -l | tr -d ' ') - -if [ "$recent_sessions" -gt 0 ]; then - latest=$(ls -t "$SESSIONS_DIR"/*.tmp 2>/dev/null | head -1) - echo "[SessionStart] Found $recent_sessions recent session(s)" >&2 - echo "[SessionStart] Latest: $latest" >&2 -fi - -# Check for learned skills -learned_count=$(find "$LEARNED_DIR" -name "*.md" 2>/dev/null | wc -l | tr -d ' ') - -if [ "$learned_count" -gt 0 ]; then - echo "[SessionStart] $learned_count learned skill(s) available in $LEARNED_DIR" >&2 -fi diff --git a/.claude/hooks/strategic-compact/suggest-compact.sh b/.claude/hooks/strategic-compact/suggest-compact.sh deleted file mode 100644 index ea14920..0000000 --- a/.claude/hooks/strategic-compact/suggest-compact.sh +++ /dev/null @@ -1,52 +0,0 @@ -#!/bin/bash -# Strategic Compact Suggester -# Runs on PreToolUse or periodically to suggest manual compaction at logical intervals -# -# Why manual over auto-compact: -# - Auto-compact happens at arbitrary points, often mid-task -# - Strategic compacting preserves context through logical phases -# - Compact after exploration, before execution -# - Compact after completing a milestone, before starting next -# -# Hook config (in ~/.claude/settings.json): -# { -# "hooks": { -# "PreToolUse": [{ -# "matcher": "Edit|Write", -# "hooks": [{ -# "type": "command", -# "command": "~/.claude/skills/strategic-compact/suggest-compact.sh" -# }] -# }] -# } -# } -# -# Criteria for suggesting compact: -# - Session has been running for extended period -# - Large number of tool calls made -# - Transitioning from research/exploration to implementation -# - Plan has been finalized - -# Track tool call count (increment in a temp file) -COUNTER_FILE="/tmp/claude-tool-count-$$" -THRESHOLD=${COMPACT_THRESHOLD:-50} - -# Initialize or increment counter -if [ -f "$COUNTER_FILE" ]; then - count=$(cat "$COUNTER_FILE") - count=$((count + 1)) - echo "$count" > "$COUNTER_FILE" -else - echo "1" > "$COUNTER_FILE" - count=1 -fi - -# Suggest compact after threshold tool calls -if [ "$count" -eq "$THRESHOLD" ]; then - echo "[StrategicCompact] $THRESHOLD tool calls reached - consider /compact if transitioning phases" >&2 -fi - -# Suggest at regular intervals after threshold -if [ "$count" -gt "$THRESHOLD" ] && [ $((count % 25)) -eq 0 ]; then - echo "[StrategicCompact] $count tool calls - good checkpoint for /compact if context is stale" >&2 -fi diff --git a/.claude/rules/coding-style.md b/.claude/rules/coding-style.md new file mode 100644 index 0000000..c96bba4 --- /dev/null +++ b/.claude/rules/coding-style.md @@ -0,0 +1,37 @@ +# Python Coding Style + +> This file extends [common/coding-style.md](../common/coding-style.md) with Python specific content. + +## Standards + +- Follow **PEP 8** conventions +- Use **type annotations** on all function signatures + +## Immutability + +Prefer immutable data structures: + +```python +from dataclasses import dataclass + +@dataclass(frozen=True) +class User: + name: str + email: str + +from typing import NamedTuple + +class Point(NamedTuple): + x: float + y: float +``` + +## Formatting + +- **black** for code formatting +- **isort** for import sorting +- **ruff** for linting + +## Reference + +See skill: `python-patterns` for comprehensive Python idioms and patterns. diff --git a/.claude/rules/hooks.md b/.claude/rules/hooks.md new file mode 100644 index 0000000..0ced0dc --- /dev/null +++ b/.claude/rules/hooks.md @@ -0,0 +1,14 @@ +# Python Hooks + +> This file extends [common/hooks.md](../common/hooks.md) with Python specific content. + +## PostToolUse Hooks + +Configure in `~/.claude/settings.json`: + +- **black/ruff**: Auto-format `.py` files after edit +- **mypy/pyright**: Run type checking after editing `.py` files + +## Warnings + +- Warn about `print()` statements in edited files (use `logging` module instead) diff --git a/.claude/rules/patterns.md b/.claude/rules/patterns.md new file mode 100644 index 0000000..96b96c1 --- /dev/null +++ b/.claude/rules/patterns.md @@ -0,0 +1,34 @@ +# Python Patterns + +> This file extends [common/patterns.md](../common/patterns.md) with Python specific content. + +## Protocol (Duck Typing) + +```python +from typing import Protocol + +class Repository(Protocol): + def find_by_id(self, id: str) -> dict | None: ... + def save(self, entity: dict) -> dict: ... +``` + +## Dataclasses as DTOs + +```python +from dataclasses import dataclass + +@dataclass +class CreateUserRequest: + name: str + email: str + age: int | None = None +``` + +## Context Managers & Generators + +- Use context managers (`with` statement) for resource management +- Use generators for lazy evaluation and memory-efficient iteration + +## Reference + +See skill: `python-patterns` for comprehensive patterns including decorators, concurrency, and package organization. diff --git a/.claude/rules/security.md b/.claude/rules/security.md new file mode 100644 index 0000000..d9aec92 --- /dev/null +++ b/.claude/rules/security.md @@ -0,0 +1,25 @@ +# Python Security + +> This file extends [common/security.md](../common/security.md) with Python specific content. + +## Secret Management + +```python +import os +from dotenv import load_dotenv + +load_dotenv() + +api_key = os.environ["OPENAI_API_KEY"] # Raises KeyError if missing +``` + +## Security Scanning + +- Use **bandit** for static security analysis: + ```bash + bandit -r src/ + ``` + +## Reference + +See skill: `django-security` for Django-specific security guidelines (if applicable). diff --git a/.claude/rules/testing.md b/.claude/rules/testing.md new file mode 100644 index 0000000..29a3a66 --- /dev/null +++ b/.claude/rules/testing.md @@ -0,0 +1,33 @@ +# Python Testing + +> This file extends [common/testing.md](../common/testing.md) with Python specific content. + +## Framework + +Use **pytest** as the testing framework. + +## Coverage + +```bash +pytest --cov=src --cov-report=term-missing +``` + +## Test Organization + +Use `pytest.mark` for test categorization: + +```python +import pytest + +@pytest.mark.unit +def test_calculate_total(): + ... + +@pytest.mark.integration +def test_database_connection(): + ... +``` + +## Reference + +See skill: `python-testing` for detailed pytest patterns and fixtures. diff --git a/.claude/skills/backend-patterns/SKILL.md b/.claude/skills/backend-patterns/SKILL.md deleted file mode 100644 index 53bf07e..0000000 --- a/.claude/skills/backend-patterns/SKILL.md +++ /dev/null @@ -1,314 +0,0 @@ -# Backend Development Patterns - -Backend architecture patterns for Python/FastAPI/PostgreSQL applications. - -## API Design - -### RESTful Structure - -``` -GET /api/v1/documents # List -GET /api/v1/documents/{id} # Get -POST /api/v1/documents # Create -PUT /api/v1/documents/{id} # Replace -PATCH /api/v1/documents/{id} # Update -DELETE /api/v1/documents/{id} # Delete - -GET /api/v1/documents?status=processed&sort=created_at&limit=20&offset=0 -``` - -### FastAPI Route Pattern - -```python -from fastapi import APIRouter, HTTPException, Depends, Query, File, UploadFile -from pydantic import BaseModel - -router = APIRouter(prefix="/api/v1", tags=["inference"]) - -@router.post("/infer", response_model=ApiResponse[InferenceResult]) -async def infer_document( - file: UploadFile = File(...), - confidence_threshold: float = Query(0.5, ge=0, le=1), - service: InferenceService = Depends(get_inference_service) -) -> ApiResponse[InferenceResult]: - result = await service.process(file, confidence_threshold) - return ApiResponse(success=True, data=result) -``` - -### Consistent Response Schema - -```python -from typing import Generic, TypeVar -T = TypeVar('T') - -class ApiResponse(BaseModel, Generic[T]): - success: bool - data: T | None = None - error: str | None = None - meta: dict | None = None -``` - -## Core Patterns - -### Repository Pattern - -```python -from typing import Protocol - -class DocumentRepository(Protocol): - def find_all(self, filters: dict | None = None) -> list[Document]: ... - def find_by_id(self, id: str) -> Document | None: ... - def create(self, data: dict) -> Document: ... - def update(self, id: str, data: dict) -> Document: ... - def delete(self, id: str) -> None: ... -``` - -### Service Layer - -```python -class InferenceService: - def __init__(self, model_path: str, use_gpu: bool = True): - self.pipeline = InferencePipeline(model_path=model_path, use_gpu=use_gpu) - - async def process(self, file: UploadFile, confidence_threshold: float) -> InferenceResult: - temp_path = self._save_temp_file(file) - try: - return self.pipeline.process_pdf(temp_path) - finally: - temp_path.unlink(missing_ok=True) -``` - -### Dependency Injection - -```python -from functools import lru_cache -from pydantic_settings import BaseSettings - -class Settings(BaseSettings): - db_host: str = "localhost" - db_password: str - model_path: str = "runs/train/invoice_fields/weights/best.pt" - class Config: - env_file = ".env" - -@lru_cache() -def get_settings() -> Settings: - return Settings() - -def get_inference_service(settings: Settings = Depends(get_settings)) -> InferenceService: - return InferenceService(model_path=settings.model_path) -``` - -## Database Patterns - -### Connection Pooling - -```python -from psycopg2 import pool -from contextlib import contextmanager - -db_pool = pool.ThreadedConnectionPool(minconn=2, maxconn=10, **db_config) - -@contextmanager -def get_db_connection(): - conn = db_pool.getconn() - try: - yield conn - finally: - db_pool.putconn(conn) -``` - -### Query Optimization - -```python -# GOOD: Select only needed columns -cur.execute(""" - SELECT id, status, fields->>'InvoiceNumber' as invoice_number - FROM documents WHERE status = %s - ORDER BY created_at DESC LIMIT %s -""", ('processed', 10)) - -# BAD: SELECT * FROM documents -``` - -### N+1 Prevention - -```python -# BAD: N+1 queries -for doc in documents: - doc.labels = get_labels(doc.id) # N queries - -# GOOD: Batch fetch with JOIN -cur.execute(""" - SELECT d.id, d.status, array_agg(l.label) as labels - FROM documents d - LEFT JOIN document_labels l ON d.id = l.document_id - GROUP BY d.id, d.status -""") -``` - -### Transaction Pattern - -```python -def create_document_with_labels(doc_data: dict, labels: list[dict]) -> str: - with get_db_connection() as conn: - try: - with conn.cursor() as cur: - cur.execute("INSERT INTO documents ... RETURNING id", ...) - doc_id = cur.fetchone()[0] - for label in labels: - cur.execute("INSERT INTO document_labels ...", ...) - conn.commit() - return doc_id - except Exception: - conn.rollback() - raise -``` - -## Caching - -```python -from cachetools import TTLCache - -_cache = TTLCache(maxsize=1000, ttl=300) - -def get_document_cached(doc_id: str) -> Document | None: - if doc_id in _cache: - return _cache[doc_id] - doc = repo.find_by_id(doc_id) - if doc: - _cache[doc_id] = doc - return doc -``` - -## Error Handling - -### Exception Hierarchy - -```python -class AppError(Exception): - def __init__(self, message: str, status_code: int = 500): - self.message = message - self.status_code = status_code - -class NotFoundError(AppError): - def __init__(self, resource: str, id: str): - super().__init__(f"{resource} not found: {id}", 404) - -class ValidationError(AppError): - def __init__(self, message: str): - super().__init__(message, 400) -``` - -### FastAPI Exception Handler - -```python -@app.exception_handler(AppError) -async def app_error_handler(request: Request, exc: AppError): - return JSONResponse(status_code=exc.status_code, content={"success": False, "error": exc.message}) - -@app.exception_handler(Exception) -async def generic_error_handler(request: Request, exc: Exception): - logger.error(f"Unexpected error: {exc}", exc_info=True) - return JSONResponse(status_code=500, content={"success": False, "error": "Internal server error"}) -``` - -### Retry with Backoff - -```python -async def retry_with_backoff(fn, max_retries: int = 3, base_delay: float = 1.0): - last_error = None - for attempt in range(max_retries): - try: - return await fn() if asyncio.iscoroutinefunction(fn) else fn() - except Exception as e: - last_error = e - if attempt < max_retries - 1: - await asyncio.sleep(base_delay * (2 ** attempt)) - raise last_error -``` - -## Rate Limiting - -```python -from time import time -from collections import defaultdict - -class RateLimiter: - def __init__(self): - self.requests: dict[str, list[float]] = defaultdict(list) - - def check_limit(self, identifier: str, max_requests: int, window_sec: int) -> bool: - now = time() - self.requests[identifier] = [t for t in self.requests[identifier] if now - t < window_sec] - if len(self.requests[identifier]) >= max_requests: - return False - self.requests[identifier].append(now) - return True - -limiter = RateLimiter() - -@app.middleware("http") -async def rate_limit_middleware(request: Request, call_next): - ip = request.client.host - if not limiter.check_limit(ip, max_requests=100, window_sec=60): - return JSONResponse(status_code=429, content={"error": "Rate limit exceeded"}) - return await call_next(request) -``` - -## Logging & Middleware - -### Request Logging - -```python -@app.middleware("http") -async def log_requests(request: Request, call_next): - request_id = str(uuid.uuid4())[:8] - start_time = time.time() - logger.info(f"[{request_id}] {request.method} {request.url.path}") - response = await call_next(request) - duration_ms = (time.time() - start_time) * 1000 - logger.info(f"[{request_id}] Completed {response.status_code} in {duration_ms:.2f}ms") - return response -``` - -### Structured Logging - -```python -class JSONFormatter(logging.Formatter): - def format(self, record): - return json.dumps({ - "timestamp": datetime.utcnow().isoformat(), - "level": record.levelname, - "message": record.getMessage(), - "module": record.module, - }) -``` - -## Background Tasks - -```python -from fastapi import BackgroundTasks - -def send_notification(document_id: str, status: str): - logger.info(f"Notification: {document_id} -> {status}") - -@router.post("/infer") -async def infer(file: UploadFile, background_tasks: BackgroundTasks): - result = await process_document(file) - background_tasks.add_task(send_notification, result.document_id, "completed") - return result -``` - -## Key Principles - -- Repository pattern: Abstract data access -- Service layer: Business logic separated from routes -- Dependency injection via `Depends()` -- Connection pooling for database -- Parameterized queries only (no f-strings in SQL) -- Batch fetch to prevent N+1 -- Consistent `ApiResponse[T]` format -- Exception hierarchy with proper status codes -- Rate limit by IP -- Structured logging with request ID \ No newline at end of file diff --git a/.claude/skills/coding-standards/SKILL.md b/.claude/skills/coding-standards/SKILL.md deleted file mode 100644 index 4bb9b71..0000000 --- a/.claude/skills/coding-standards/SKILL.md +++ /dev/null @@ -1,665 +0,0 @@ ---- -name: coding-standards -description: Universal coding standards, best practices, and patterns for Python, FastAPI, and data processing development. ---- - -# Coding Standards & Best Practices - -Python coding standards for the Invoice Master project. - -## Code Quality Principles - -### 1. Readability First -- Code is read more than written -- Clear variable and function names -- Self-documenting code preferred over comments -- Consistent formatting (follow PEP 8) - -### 2. KISS (Keep It Simple, Stupid) -- Simplest solution that works -- Avoid over-engineering -- No premature optimization -- Easy to understand > clever code - -### 3. DRY (Don't Repeat Yourself) -- Extract common logic into functions -- Create reusable utilities -- Share modules across the codebase -- Avoid copy-paste programming - -### 4. YAGNI (You Aren't Gonna Need It) -- Don't build features before they're needed -- Avoid speculative generality -- Add complexity only when required -- Start simple, refactor when needed - -## Python Standards - -### Variable Naming - -```python -# GOOD: Descriptive names -invoice_number = "INV-2024-001" -is_valid_document = True -total_confidence_score = 0.95 - -# BAD: Unclear names -inv = "INV-2024-001" -flag = True -x = 0.95 -``` - -### Function Naming - -```python -# GOOD: Verb-noun pattern with type hints -def extract_invoice_fields(pdf_path: Path) -> dict[str, str]: - """Extract fields from invoice PDF.""" - ... - -def calculate_confidence(predictions: list[float]) -> float: - """Calculate average confidence score.""" - ... - -def is_valid_bankgiro(value: str) -> bool: - """Check if value is valid Bankgiro number.""" - ... - -# BAD: Unclear or noun-only -def invoice(path): - ... - -def confidence(p): - ... - -def bankgiro(v): - ... -``` - -### Type Hints (REQUIRED) - -```python -# GOOD: Full type annotations -from typing import Optional -from pathlib import Path -from dataclasses import dataclass - -@dataclass -class InferenceResult: - document_id: str - fields: dict[str, str] - confidence: dict[str, float] - processing_time_ms: float - -def process_document( - pdf_path: Path, - confidence_threshold: float = 0.5 -) -> InferenceResult: - """Process PDF and return extracted fields.""" - ... - -# BAD: No type hints -def process_document(pdf_path, confidence_threshold=0.5): - ... -``` - -### Immutability Pattern (CRITICAL) - -```python -# GOOD: Create new objects, don't mutate -def update_fields(fields: dict[str, str], updates: dict[str, str]) -> dict[str, str]: - return {**fields, **updates} - -def add_item(items: list[str], new_item: str) -> list[str]: - return [*items, new_item] - -# BAD: Direct mutation -def update_fields(fields: dict[str, str], updates: dict[str, str]) -> dict[str, str]: - fields.update(updates) # MUTATION! - return fields - -def add_item(items: list[str], new_item: str) -> list[str]: - items.append(new_item) # MUTATION! - return items -``` - -### Error Handling - -```python -import logging - -logger = logging.getLogger(__name__) - -# GOOD: Comprehensive error handling with logging -def load_model(model_path: Path) -> Model: - """Load YOLO model from path.""" - try: - if not model_path.exists(): - raise FileNotFoundError(f"Model not found: {model_path}") - - model = YOLO(str(model_path)) - logger.info(f"Model loaded: {model_path}") - return model - except Exception as e: - logger.error(f"Failed to load model: {e}") - raise RuntimeError(f"Model loading failed: {model_path}") from e - -# BAD: No error handling -def load_model(model_path): - return YOLO(str(model_path)) - -# BAD: Bare except -def load_model(model_path): - try: - return YOLO(str(model_path)) - except: # Never use bare except! - return None -``` - -### Async Best Practices - -```python -import asyncio - -# GOOD: Parallel execution when possible -async def process_batch(pdf_paths: list[Path]) -> list[InferenceResult]: - tasks = [process_document(path) for path in pdf_paths] - results = await asyncio.gather(*tasks, return_exceptions=True) - - # Handle exceptions - valid_results = [] - for path, result in zip(pdf_paths, results): - if isinstance(result, Exception): - logger.error(f"Failed to process {path}: {result}") - else: - valid_results.append(result) - return valid_results - -# BAD: Sequential when unnecessary -async def process_batch(pdf_paths: list[Path]) -> list[InferenceResult]: - results = [] - for path in pdf_paths: - result = await process_document(path) - results.append(result) - return results -``` - -### Context Managers - -```python -from contextlib import contextmanager -from pathlib import Path -import tempfile - -# GOOD: Proper resource management -@contextmanager -def temp_pdf_copy(pdf_path: Path): - """Create temporary copy of PDF for processing.""" - with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as tmp: - tmp.write(pdf_path.read_bytes()) - tmp_path = Path(tmp.name) - try: - yield tmp_path - finally: - tmp_path.unlink(missing_ok=True) - -# Usage -with temp_pdf_copy(original_pdf) as tmp_pdf: - result = process_pdf(tmp_pdf) -``` - -## FastAPI Best Practices - -### Route Structure - -```python -from fastapi import APIRouter, HTTPException, Depends, Query, File, UploadFile -from pydantic import BaseModel - -router = APIRouter(prefix="/api/v1", tags=["inference"]) - -class InferenceResponse(BaseModel): - success: bool - document_id: str - fields: dict[str, str] - confidence: dict[str, float] - processing_time_ms: float - -@router.post("/infer", response_model=InferenceResponse) -async def infer_document( - file: UploadFile = File(...), - confidence_threshold: float = Query(0.5, ge=0.0, le=1.0) -) -> InferenceResponse: - """Process invoice PDF and extract fields.""" - if not file.filename.endswith(".pdf"): - raise HTTPException(status_code=400, detail="Only PDF files accepted") - - result = await inference_service.process(file, confidence_threshold) - return InferenceResponse( - success=True, - document_id=result.document_id, - fields=result.fields, - confidence=result.confidence, - processing_time_ms=result.processing_time_ms - ) -``` - -### Input Validation with Pydantic - -```python -from pydantic import BaseModel, Field, field_validator -from datetime import date -import re - -class InvoiceData(BaseModel): - invoice_number: str = Field(..., min_length=1, max_length=50) - invoice_date: date - amount: float = Field(..., gt=0) - bankgiro: str | None = None - ocr_number: str | None = None - - @field_validator("bankgiro") - @classmethod - def validate_bankgiro(cls, v: str | None) -> str | None: - if v is None: - return None - # Bankgiro: 7-8 digits - cleaned = re.sub(r"[^0-9]", "", v) - if not (7 <= len(cleaned) <= 8): - raise ValueError("Bankgiro must be 7-8 digits") - return cleaned - - @field_validator("ocr_number") - @classmethod - def validate_ocr(cls, v: str | None) -> str | None: - if v is None: - return None - # OCR: 2-25 digits - cleaned = re.sub(r"[^0-9]", "", v) - if not (2 <= len(cleaned) <= 25): - raise ValueError("OCR must be 2-25 digits") - return cleaned -``` - -### Response Format - -```python -from pydantic import BaseModel -from typing import Generic, TypeVar - -T = TypeVar("T") - -class ApiResponse(BaseModel, Generic[T]): - success: bool - data: T | None = None - error: str | None = None - meta: dict | None = None - -# Success response -return ApiResponse( - success=True, - data=result, - meta={"processing_time_ms": elapsed_ms} -) - -# Error response -return ApiResponse( - success=False, - error="Invalid PDF format" -) -``` - -## File Organization - -### Project Structure - -``` -src/ -├── cli/ # Command-line interfaces -│ ├── autolabel.py -│ ├── train.py -│ └── infer.py -├── pdf/ # PDF processing -│ ├── extractor.py -│ └── renderer.py -├── ocr/ # OCR processing -│ ├── paddle_ocr.py -│ └── machine_code_parser.py -├── inference/ # Inference pipeline -│ ├── pipeline.py -│ ├── yolo_detector.py -│ └── field_extractor.py -├── normalize/ # Field normalization -│ ├── base.py -│ ├── date_normalizer.py -│ └── amount_normalizer.py -├── web/ # FastAPI application -│ ├── app.py -│ ├── routes.py -│ ├── services.py -│ └── schemas.py -└── utils/ # Shared utilities - ├── validators.py - ├── text_cleaner.py - └── logging.py -tests/ # Mirror of src structure - ├── test_pdf/ - ├── test_ocr/ - └── test_inference/ -``` - -### File Naming - -``` -src/ocr/paddle_ocr.py # snake_case for modules -src/inference/yolo_detector.py # snake_case for modules -tests/test_paddle_ocr.py # test_ prefix for tests -config.py # snake_case for config -``` - -### Module Size Guidelines - -- **Maximum**: 800 lines per file -- **Typical**: 200-400 lines per file -- **Functions**: Max 50 lines each -- Extract utilities when modules grow too large - -## Comments & Documentation - -### When to Comment - -```python -# GOOD: Explain WHY, not WHAT -# Swedish Bankgiro uses Luhn algorithm with weight [1,2,1,2...] -def validate_bankgiro_checksum(bankgiro: str) -> bool: - ... - -# Payment line format: 7 groups separated by #, checksum at end -def parse_payment_line(line: str) -> PaymentLineData: - ... - -# BAD: Stating the obvious -# Increment counter by 1 -count += 1 - -# Set name to user's name -name = user.name -``` - -### Docstrings for Public APIs - -```python -def extract_invoice_fields( - pdf_path: Path, - confidence_threshold: float = 0.5, - use_gpu: bool = True -) -> InferenceResult: - """Extract structured fields from Swedish invoice PDF. - - Uses YOLOv11 for field detection and PaddleOCR for text extraction. - Applies field-specific normalization and validation. - - Args: - pdf_path: Path to the invoice PDF file. - confidence_threshold: Minimum confidence for field detection (0.0-1.0). - use_gpu: Whether to use GPU acceleration. - - Returns: - InferenceResult containing extracted fields and confidence scores. - - Raises: - FileNotFoundError: If PDF file doesn't exist. - ProcessingError: If OCR or detection fails. - - Example: - >>> result = extract_invoice_fields(Path("invoice.pdf")) - >>> print(result.fields["invoice_number"]) - "INV-2024-001" - """ - ... -``` - -## Performance Best Practices - -### Caching - -```python -from functools import lru_cache -from cachetools import TTLCache - -# Static data: LRU cache -@lru_cache(maxsize=100) -def get_field_config(field_name: str) -> FieldConfig: - """Load field configuration (cached).""" - return load_config(field_name) - -# Dynamic data: TTL cache -_document_cache = TTLCache(maxsize=1000, ttl=300) # 5 minutes - -def get_document_cached(doc_id: str) -> Document | None: - if doc_id in _document_cache: - return _document_cache[doc_id] - - doc = repo.find_by_id(doc_id) - if doc: - _document_cache[doc_id] = doc - return doc -``` - -### Database Queries - -```python -# GOOD: Select only needed columns -cur.execute(""" - SELECT id, status, fields->>'invoice_number' - FROM documents - WHERE status = %s - LIMIT %s -""", ('processed', 10)) - -# BAD: Select everything -cur.execute("SELECT * FROM documents") - -# GOOD: Batch operations -cur.executemany( - "INSERT INTO labels (doc_id, field, value) VALUES (%s, %s, %s)", - [(doc_id, f, v) for f, v in fields.items()] -) - -# BAD: Individual inserts in loop -for field, value in fields.items(): - cur.execute("INSERT INTO labels ...", (doc_id, field, value)) -``` - -### Lazy Loading - -```python -class InferencePipeline: - def __init__(self, model_path: Path): - self.model_path = model_path - self._model: YOLO | None = None - self._ocr: PaddleOCR | None = None - - @property - def model(self) -> YOLO: - """Lazy load YOLO model.""" - if self._model is None: - self._model = YOLO(str(self.model_path)) - return self._model - - @property - def ocr(self) -> PaddleOCR: - """Lazy load PaddleOCR.""" - if self._ocr is None: - self._ocr = PaddleOCR(use_angle_cls=True, lang="latin") - return self._ocr -``` - -## Testing Standards - -### Test Structure (AAA Pattern) - -```python -def test_extract_bankgiro_valid(): - # Arrange - text = "Bankgiro: 123-4567" - - # Act - result = extract_bankgiro(text) - - # Assert - assert result == "1234567" - -def test_extract_bankgiro_invalid_returns_none(): - # Arrange - text = "No bankgiro here" - - # Act - result = extract_bankgiro(text) - - # Assert - assert result is None -``` - -### Test Naming - -```python -# GOOD: Descriptive test names -def test_parse_payment_line_extracts_all_fields(): ... -def test_parse_payment_line_handles_missing_checksum(): ... -def test_validate_ocr_returns_false_for_invalid_checksum(): ... - -# BAD: Vague test names -def test_parse(): ... -def test_works(): ... -def test_payment_line(): ... -``` - -### Fixtures - -```python -import pytest -from pathlib import Path - -@pytest.fixture -def sample_invoice_pdf(tmp_path: Path) -> Path: - """Create sample invoice PDF for testing.""" - pdf_path = tmp_path / "invoice.pdf" - # Create test PDF... - return pdf_path - -@pytest.fixture -def inference_pipeline(sample_model_path: Path) -> InferencePipeline: - """Create inference pipeline with test model.""" - return InferencePipeline(sample_model_path) - -def test_process_invoice(inference_pipeline, sample_invoice_pdf): - result = inference_pipeline.process(sample_invoice_pdf) - assert result.fields.get("invoice_number") is not None -``` - -## Code Smell Detection - -### 1. Long Functions - -```python -# BAD: Function > 50 lines -def process_document(): - # 100 lines of code... - -# GOOD: Split into smaller functions -def process_document(pdf_path: Path) -> InferenceResult: - image = render_pdf(pdf_path) - detections = detect_fields(image) - ocr_results = extract_text(image, detections) - fields = normalize_fields(ocr_results) - return build_result(fields) -``` - -### 2. Deep Nesting - -```python -# BAD: 5+ levels of nesting -if document: - if document.is_valid: - if document.has_fields: - if field in document.fields: - if document.fields[field]: - # Do something - -# GOOD: Early returns -if not document: - return None -if not document.is_valid: - return None -if not document.has_fields: - return None -if field not in document.fields: - return None -if not document.fields[field]: - return None - -# Do something -``` - -### 3. Magic Numbers - -```python -# BAD: Unexplained numbers -if confidence > 0.5: - ... -time.sleep(3) - -# GOOD: Named constants -CONFIDENCE_THRESHOLD = 0.5 -RETRY_DELAY_SECONDS = 3 - -if confidence > CONFIDENCE_THRESHOLD: - ... -time.sleep(RETRY_DELAY_SECONDS) -``` - -### 4. Mutable Default Arguments - -```python -# BAD: Mutable default argument -def process_fields(fields: list = []): # DANGEROUS! - fields.append("new_field") - return fields - -# GOOD: Use None as default -def process_fields(fields: list | None = None) -> list: - if fields is None: - fields = [] - return [*fields, "new_field"] -``` - -## Logging Standards - -```python -import logging - -# Module-level logger -logger = logging.getLogger(__name__) - -# GOOD: Appropriate log levels -logger.debug("Processing document: %s", doc_id) -logger.info("Document processed successfully: %s", doc_id) -logger.warning("Low confidence score: %.2f", confidence) -logger.error("Failed to process document: %s", error) - -# GOOD: Structured logging with extra data -logger.info( - "Inference complete", - extra={ - "document_id": doc_id, - "field_count": len(fields), - "processing_time_ms": elapsed_ms - } -) - -# BAD: Using print() -print(f"Processing {doc_id}") # Never in production! -``` - -**Remember**: Code quality is not negotiable. Clear, maintainable Python code with proper type hints enables confident development and refactoring. diff --git a/.claude/skills/continuous-learning/SKILL.md b/.claude/skills/continuous-learning/SKILL.md deleted file mode 100644 index 84a88dd..0000000 --- a/.claude/skills/continuous-learning/SKILL.md +++ /dev/null @@ -1,80 +0,0 @@ ---- -name: continuous-learning -description: Automatically extract reusable patterns from Claude Code sessions and save them as learned skills for future use. ---- - -# Continuous Learning Skill - -Automatically evaluates Claude Code sessions on end to extract reusable patterns that can be saved as learned skills. - -## How It Works - -This skill runs as a **Stop hook** at the end of each session: - -1. **Session Evaluation**: Checks if session has enough messages (default: 10+) -2. **Pattern Detection**: Identifies extractable patterns from the session -3. **Skill Extraction**: Saves useful patterns to `~/.claude/skills/learned/` - -## Configuration - -Edit `config.json` to customize: - -```json -{ - "min_session_length": 10, - "extraction_threshold": "medium", - "auto_approve": false, - "learned_skills_path": "~/.claude/skills/learned/", - "patterns_to_detect": [ - "error_resolution", - "user_corrections", - "workarounds", - "debugging_techniques", - "project_specific" - ], - "ignore_patterns": [ - "simple_typos", - "one_time_fixes", - "external_api_issues" - ] -} -``` - -## Pattern Types - -| Pattern | Description | -|---------|-------------| -| `error_resolution` | How specific errors were resolved | -| `user_corrections` | Patterns from user corrections | -| `workarounds` | Solutions to framework/library quirks | -| `debugging_techniques` | Effective debugging approaches | -| `project_specific` | Project-specific conventions | - -## Hook Setup - -Add to your `~/.claude/settings.json`: - -```json -{ - "hooks": { - "Stop": [{ - "matcher": "*", - "hooks": [{ - "type": "command", - "command": "~/.claude/skills/continuous-learning/evaluate-session.sh" - }] - }] - } -} -``` - -## Why Stop Hook? - -- **Lightweight**: Runs once at session end -- **Non-blocking**: Doesn't add latency to every message -- **Complete context**: Has access to full session transcript - -## Related - -- [The Longform Guide](https://x.com/affaanmustafa/status/2014040193557471352) - Section on continuous learning -- `/learn` command - Manual pattern extraction mid-session diff --git a/.claude/skills/continuous-learning/config.json b/.claude/skills/continuous-learning/config.json deleted file mode 100644 index 1094b7e..0000000 --- a/.claude/skills/continuous-learning/config.json +++ /dev/null @@ -1,18 +0,0 @@ -{ - "min_session_length": 10, - "extraction_threshold": "medium", - "auto_approve": false, - "learned_skills_path": "~/.claude/skills/learned/", - "patterns_to_detect": [ - "error_resolution", - "user_corrections", - "workarounds", - "debugging_techniques", - "project_specific" - ], - "ignore_patterns": [ - "simple_typos", - "one_time_fixes", - "external_api_issues" - ] -} diff --git a/.claude/skills/continuous-learning/evaluate-session.sh b/.claude/skills/continuous-learning/evaluate-session.sh deleted file mode 100644 index f13208a..0000000 --- a/.claude/skills/continuous-learning/evaluate-session.sh +++ /dev/null @@ -1,60 +0,0 @@ -#!/bin/bash -# Continuous Learning - Session Evaluator -# Runs on Stop hook to extract reusable patterns from Claude Code sessions -# -# Why Stop hook instead of UserPromptSubmit: -# - Stop runs once at session end (lightweight) -# - UserPromptSubmit runs every message (heavy, adds latency) -# -# Hook config (in ~/.claude/settings.json): -# { -# "hooks": { -# "Stop": [{ -# "matcher": "*", -# "hooks": [{ -# "type": "command", -# "command": "~/.claude/skills/continuous-learning/evaluate-session.sh" -# }] -# }] -# } -# } -# -# Patterns to detect: error_resolution, debugging_techniques, workarounds, project_specific -# Patterns to ignore: simple_typos, one_time_fixes, external_api_issues -# Extracted skills saved to: ~/.claude/skills/learned/ - -set -e - -SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -CONFIG_FILE="$SCRIPT_DIR/config.json" -LEARNED_SKILLS_PATH="${HOME}/.claude/skills/learned" -MIN_SESSION_LENGTH=10 - -# Load config if exists -if [ -f "$CONFIG_FILE" ]; then - MIN_SESSION_LENGTH=$(jq -r '.min_session_length // 10' "$CONFIG_FILE") - LEARNED_SKILLS_PATH=$(jq -r '.learned_skills_path // "~/.claude/skills/learned/"' "$CONFIG_FILE" | sed "s|~|$HOME|") -fi - -# Ensure learned skills directory exists -mkdir -p "$LEARNED_SKILLS_PATH" - -# Get transcript path from environment (set by Claude Code) -transcript_path="${CLAUDE_TRANSCRIPT_PATH:-}" - -if [ -z "$transcript_path" ] || [ ! -f "$transcript_path" ]; then - exit 0 -fi - -# Count messages in session -message_count=$(grep -c '"type":"user"' "$transcript_path" 2>/dev/null || echo "0") - -# Skip short sessions -if [ "$message_count" -lt "$MIN_SESSION_LENGTH" ]; then - echo "[ContinuousLearning] Session too short ($message_count messages), skipping" >&2 - exit 0 -fi - -# Signal to Claude that session should be evaluated for extractable patterns -echo "[ContinuousLearning] Session has $message_count messages - evaluate for extractable patterns" >&2 -echo "[ContinuousLearning] Save learned skills to: $LEARNED_SKILLS_PATH" >&2 diff --git a/.claude/skills/eval-harness/SKILL.md b/.claude/skills/eval-harness/SKILL.md deleted file mode 100644 index 522937d..0000000 --- a/.claude/skills/eval-harness/SKILL.md +++ /dev/null @@ -1,221 +0,0 @@ -# Eval Harness Skill - -A formal evaluation framework for Claude Code sessions, implementing eval-driven development (EDD) principles. - -## Philosophy - -Eval-Driven Development treats evals as the "unit tests of AI development": -- Define expected behavior BEFORE implementation -- Run evals continuously during development -- Track regressions with each change -- Use pass@k metrics for reliability measurement - -## Eval Types - -### Capability Evals -Test if Claude can do something it couldn't before: -```markdown -[CAPABILITY EVAL: feature-name] -Task: Description of what Claude should accomplish -Success Criteria: - - [ ] Criterion 1 - - [ ] Criterion 2 - - [ ] Criterion 3 -Expected Output: Description of expected result -``` - -### Regression Evals -Ensure changes don't break existing functionality: -```markdown -[REGRESSION EVAL: feature-name] -Baseline: SHA or checkpoint name -Tests: - - existing-test-1: PASS/FAIL - - existing-test-2: PASS/FAIL - - existing-test-3: PASS/FAIL -Result: X/Y passed (previously Y/Y) -``` - -## Grader Types - -### 1. Code-Based Grader -Deterministic checks using code: -```bash -# Check if file contains expected pattern -grep -q "export function handleAuth" src/auth.ts && echo "PASS" || echo "FAIL" - -# Check if tests pass -npm test -- --testPathPattern="auth" && echo "PASS" || echo "FAIL" - -# Check if build succeeds -npm run build && echo "PASS" || echo "FAIL" -``` - -### 2. Model-Based Grader -Use Claude to evaluate open-ended outputs: -```markdown -[MODEL GRADER PROMPT] -Evaluate the following code change: -1. Does it solve the stated problem? -2. Is it well-structured? -3. Are edge cases handled? -4. Is error handling appropriate? - -Score: 1-5 (1=poor, 5=excellent) -Reasoning: [explanation] -``` - -### 3. Human Grader -Flag for manual review: -```markdown -[HUMAN REVIEW REQUIRED] -Change: Description of what changed -Reason: Why human review is needed -Risk Level: LOW/MEDIUM/HIGH -``` - -## Metrics - -### pass@k -"At least one success in k attempts" -- pass@1: First attempt success rate -- pass@3: Success within 3 attempts -- Typical target: pass@3 > 90% - -### pass^k -"All k trials succeed" -- Higher bar for reliability -- pass^3: 3 consecutive successes -- Use for critical paths - -## Eval Workflow - -### 1. Define (Before Coding) -```markdown -## EVAL DEFINITION: feature-xyz - -### Capability Evals -1. Can create new user account -2. Can validate email format -3. Can hash password securely - -### Regression Evals -1. Existing login still works -2. Session management unchanged -3. Logout flow intact - -### Success Metrics -- pass@3 > 90% for capability evals -- pass^3 = 100% for regression evals -``` - -### 2. Implement -Write code to pass the defined evals. - -### 3. Evaluate -```bash -# Run capability evals -[Run each capability eval, record PASS/FAIL] - -# Run regression evals -npm test -- --testPathPattern="existing" - -# Generate report -``` - -### 4. Report -```markdown -EVAL REPORT: feature-xyz -======================== - -Capability Evals: - create-user: PASS (pass@1) - validate-email: PASS (pass@2) - hash-password: PASS (pass@1) - Overall: 3/3 passed - -Regression Evals: - login-flow: PASS - session-mgmt: PASS - logout-flow: PASS - Overall: 3/3 passed - -Metrics: - pass@1: 67% (2/3) - pass@3: 100% (3/3) - -Status: READY FOR REVIEW -``` - -## Integration Patterns - -### Pre-Implementation -``` -/eval define feature-name -``` -Creates eval definition file at `.claude/evals/feature-name.md` - -### During Implementation -``` -/eval check feature-name -``` -Runs current evals and reports status - -### Post-Implementation -``` -/eval report feature-name -``` -Generates full eval report - -## Eval Storage - -Store evals in project: -``` -.claude/ - evals/ - feature-xyz.md # Eval definition - feature-xyz.log # Eval run history - baseline.json # Regression baselines -``` - -## Best Practices - -1. **Define evals BEFORE coding** - Forces clear thinking about success criteria -2. **Run evals frequently** - Catch regressions early -3. **Track pass@k over time** - Monitor reliability trends -4. **Use code graders when possible** - Deterministic > probabilistic -5. **Human review for security** - Never fully automate security checks -6. **Keep evals fast** - Slow evals don't get run -7. **Version evals with code** - Evals are first-class artifacts - -## Example: Adding Authentication - -```markdown -## EVAL: add-authentication - -### Phase 1: Define (10 min) -Capability Evals: -- [ ] User can register with email/password -- [ ] User can login with valid credentials -- [ ] Invalid credentials rejected with proper error -- [ ] Sessions persist across page reloads -- [ ] Logout clears session - -Regression Evals: -- [ ] Public routes still accessible -- [ ] API responses unchanged -- [ ] Database schema compatible - -### Phase 2: Implement (varies) -[Write code] - -### Phase 3: Evaluate -Run: /eval check add-authentication - -### Phase 4: Report -EVAL REPORT: add-authentication -============================== -Capability: 5/5 passed (pass@3: 100%) -Regression: 3/3 passed (pass^3: 100%) -Status: SHIP IT -``` diff --git a/.claude/skills/frontend-patterns/SKILL.md b/.claude/skills/frontend-patterns/SKILL.md deleted file mode 100644 index 05a796a..0000000 --- a/.claude/skills/frontend-patterns/SKILL.md +++ /dev/null @@ -1,631 +0,0 @@ ---- -name: frontend-patterns -description: Frontend development patterns for React, Next.js, state management, performance optimization, and UI best practices. ---- - -# Frontend Development Patterns - -Modern frontend patterns for React, Next.js, and performant user interfaces. - -## Component Patterns - -### Composition Over Inheritance - -```typescript -// ✅ GOOD: Component composition -interface CardProps { - children: React.ReactNode - variant?: 'default' | 'outlined' -} - -export function Card({ children, variant = 'default' }: CardProps) { - return
{children}
-} - -export function CardHeader({ children }: { children: React.ReactNode }) { - return
{children}
-} - -export function CardBody({ children }: { children: React.ReactNode }) { - return
{children}
-} - -// Usage - - Title - Content - -``` - -### Compound Components - -```typescript -interface TabsContextValue { - activeTab: string - setActiveTab: (tab: string) => void -} - -const TabsContext = createContext(undefined) - -export function Tabs({ children, defaultTab }: { - children: React.ReactNode - defaultTab: string -}) { - const [activeTab, setActiveTab] = useState(defaultTab) - - return ( - - {children} - - ) -} - -export function TabList({ children }: { children: React.ReactNode }) { - return
{children}
-} - -export function Tab({ id, children }: { id: string, children: React.ReactNode }) { - const context = useContext(TabsContext) - if (!context) throw new Error('Tab must be used within Tabs') - - return ( - - ) -} - -// Usage - - - Overview - Details - - -``` - -### Render Props Pattern - -```typescript -interface DataLoaderProps { - url: string - children: (data: T | null, loading: boolean, error: Error | null) => React.ReactNode -} - -export function DataLoader({ url, children }: DataLoaderProps) { - const [data, setData] = useState(null) - const [loading, setLoading] = useState(true) - const [error, setError] = useState(null) - - useEffect(() => { - fetch(url) - .then(res => res.json()) - .then(setData) - .catch(setError) - .finally(() => setLoading(false)) - }, [url]) - - return <>{children(data, loading, error)} -} - -// Usage - url="/api/markets"> - {(markets, loading, error) => { - if (loading) return - if (error) return - return - }} - -``` - -## Custom Hooks Patterns - -### State Management Hook - -```typescript -export function useToggle(initialValue = false): [boolean, () => void] { - const [value, setValue] = useState(initialValue) - - const toggle = useCallback(() => { - setValue(v => !v) - }, []) - - return [value, toggle] -} - -// Usage -const [isOpen, toggleOpen] = useToggle() -``` - -### Async Data Fetching Hook - -```typescript -interface UseQueryOptions { - onSuccess?: (data: T) => void - onError?: (error: Error) => void - enabled?: boolean -} - -export function useQuery( - key: string, - fetcher: () => Promise, - options?: UseQueryOptions -) { - const [data, setData] = useState(null) - const [error, setError] = useState(null) - const [loading, setLoading] = useState(false) - - const refetch = useCallback(async () => { - setLoading(true) - setError(null) - - try { - const result = await fetcher() - setData(result) - options?.onSuccess?.(result) - } catch (err) { - const error = err as Error - setError(error) - options?.onError?.(error) - } finally { - setLoading(false) - } - }, [fetcher, options]) - - useEffect(() => { - if (options?.enabled !== false) { - refetch() - } - }, [key, refetch, options?.enabled]) - - return { data, error, loading, refetch } -} - -// Usage -const { data: markets, loading, error, refetch } = useQuery( - 'markets', - () => fetch('/api/markets').then(r => r.json()), - { - onSuccess: data => console.log('Fetched', data.length, 'markets'), - onError: err => console.error('Failed:', err) - } -) -``` - -### Debounce Hook - -```typescript -export function useDebounce(value: T, delay: number): T { - const [debouncedValue, setDebouncedValue] = useState(value) - - useEffect(() => { - const handler = setTimeout(() => { - setDebouncedValue(value) - }, delay) - - return () => clearTimeout(handler) - }, [value, delay]) - - return debouncedValue -} - -// Usage -const [searchQuery, setSearchQuery] = useState('') -const debouncedQuery = useDebounce(searchQuery, 500) - -useEffect(() => { - if (debouncedQuery) { - performSearch(debouncedQuery) - } -}, [debouncedQuery]) -``` - -## State Management Patterns - -### Context + Reducer Pattern - -```typescript -interface State { - markets: Market[] - selectedMarket: Market | null - loading: boolean -} - -type Action = - | { type: 'SET_MARKETS'; payload: Market[] } - | { type: 'SELECT_MARKET'; payload: Market } - | { type: 'SET_LOADING'; payload: boolean } - -function reducer(state: State, action: Action): State { - switch (action.type) { - case 'SET_MARKETS': - return { ...state, markets: action.payload } - case 'SELECT_MARKET': - return { ...state, selectedMarket: action.payload } - case 'SET_LOADING': - return { ...state, loading: action.payload } - default: - return state - } -} - -const MarketContext = createContext<{ - state: State - dispatch: Dispatch -} | undefined>(undefined) - -export function MarketProvider({ children }: { children: React.ReactNode }) { - const [state, dispatch] = useReducer(reducer, { - markets: [], - selectedMarket: null, - loading: false - }) - - return ( - - {children} - - ) -} - -export function useMarkets() { - const context = useContext(MarketContext) - if (!context) throw new Error('useMarkets must be used within MarketProvider') - return context -} -``` - -## Performance Optimization - -### Memoization - -```typescript -// ✅ useMemo for expensive computations -const sortedMarkets = useMemo(() => { - return markets.sort((a, b) => b.volume - a.volume) -}, [markets]) - -// ✅ useCallback for functions passed to children -const handleSearch = useCallback((query: string) => { - setSearchQuery(query) -}, []) - -// ✅ React.memo for pure components -export const MarketCard = React.memo(({ market }) => { - return ( -
-

{market.name}

-

{market.description}

-
- ) -}) -``` - -### Code Splitting & Lazy Loading - -```typescript -import { lazy, Suspense } from 'react' - -// ✅ Lazy load heavy components -const HeavyChart = lazy(() => import('./HeavyChart')) -const ThreeJsBackground = lazy(() => import('./ThreeJsBackground')) - -export function Dashboard() { - return ( -
- }> - - - - - - -
- ) -} -``` - -### Virtualization for Long Lists - -```typescript -import { useVirtualizer } from '@tanstack/react-virtual' - -export function VirtualMarketList({ markets }: { markets: Market[] }) { - const parentRef = useRef(null) - - const virtualizer = useVirtualizer({ - count: markets.length, - getScrollElement: () => parentRef.current, - estimateSize: () => 100, // Estimated row height - overscan: 5 // Extra items to render - }) - - return ( -
-
- {virtualizer.getVirtualItems().map(virtualRow => ( -
- -
- ))} -
-
- ) -} -``` - -## Form Handling Patterns - -### Controlled Form with Validation - -```typescript -interface FormData { - name: string - description: string - endDate: string -} - -interface FormErrors { - name?: string - description?: string - endDate?: string -} - -export function CreateMarketForm() { - const [formData, setFormData] = useState({ - name: '', - description: '', - endDate: '' - }) - - const [errors, setErrors] = useState({}) - - const validate = (): boolean => { - const newErrors: FormErrors = {} - - if (!formData.name.trim()) { - newErrors.name = 'Name is required' - } else if (formData.name.length > 200) { - newErrors.name = 'Name must be under 200 characters' - } - - if (!formData.description.trim()) { - newErrors.description = 'Description is required' - } - - if (!formData.endDate) { - newErrors.endDate = 'End date is required' - } - - setErrors(newErrors) - return Object.keys(newErrors).length === 0 - } - - const handleSubmit = async (e: React.FormEvent) => { - e.preventDefault() - - if (!validate()) return - - try { - await createMarket(formData) - // Success handling - } catch (error) { - // Error handling - } - } - - return ( -
- setFormData(prev => ({ ...prev, name: e.target.value }))} - placeholder="Market name" - /> - {errors.name && {errors.name}} - - {/* Other fields */} - - -
- ) -} -``` - -## Error Boundary Pattern - -```typescript -interface ErrorBoundaryState { - hasError: boolean - error: Error | null -} - -export class ErrorBoundary extends React.Component< - { children: React.ReactNode }, - ErrorBoundaryState -> { - state: ErrorBoundaryState = { - hasError: false, - error: null - } - - static getDerivedStateFromError(error: Error): ErrorBoundaryState { - return { hasError: true, error } - } - - componentDidCatch(error: Error, errorInfo: React.ErrorInfo) { - console.error('Error boundary caught:', error, errorInfo) - } - - render() { - if (this.state.hasError) { - return ( -
-

Something went wrong

-

{this.state.error?.message}

- -
- ) - } - - return this.props.children - } -} - -// Usage - - - -``` - -## Animation Patterns - -### Framer Motion Animations - -```typescript -import { motion, AnimatePresence } from 'framer-motion' - -// ✅ List animations -export function AnimatedMarketList({ markets }: { markets: Market[] }) { - return ( - - {markets.map(market => ( - - - - ))} - - ) -} - -// ✅ Modal animations -export function Modal({ isOpen, onClose, children }: ModalProps) { - return ( - - {isOpen && ( - <> - - - {children} - - - )} - - ) -} -``` - -## Accessibility Patterns - -### Keyboard Navigation - -```typescript -export function Dropdown({ options, onSelect }: DropdownProps) { - const [isOpen, setIsOpen] = useState(false) - const [activeIndex, setActiveIndex] = useState(0) - - const handleKeyDown = (e: React.KeyboardEvent) => { - switch (e.key) { - case 'ArrowDown': - e.preventDefault() - setActiveIndex(i => Math.min(i + 1, options.length - 1)) - break - case 'ArrowUp': - e.preventDefault() - setActiveIndex(i => Math.max(i - 1, 0)) - break - case 'Enter': - e.preventDefault() - onSelect(options[activeIndex]) - setIsOpen(false) - break - case 'Escape': - setIsOpen(false) - break - } - } - - return ( -
- {/* Dropdown implementation */} -
- ) -} -``` - -### Focus Management - -```typescript -export function Modal({ isOpen, onClose, children }: ModalProps) { - const modalRef = useRef(null) - const previousFocusRef = useRef(null) - - useEffect(() => { - if (isOpen) { - // Save currently focused element - previousFocusRef.current = document.activeElement as HTMLElement - - // Focus modal - modalRef.current?.focus() - } else { - // Restore focus when closing - previousFocusRef.current?.focus() - } - }, [isOpen]) - - return isOpen ? ( -
e.key === 'Escape' && onClose()} - > - {children} -
- ) : null -} -``` - -**Remember**: Modern frontend patterns enable maintainable, performant user interfaces. Choose patterns that fit your project complexity. diff --git a/.claude/skills/product-spec-builder/SKILL.md b/.claude/skills/product-spec-builder/SKILL.md deleted file mode 100644 index f00e1ff..0000000 --- a/.claude/skills/product-spec-builder/SKILL.md +++ /dev/null @@ -1,335 +0,0 @@ ---- -name: product-spec-builder -description: 当用户表达想要开发产品、应用、工具或任何软件项目时,或者用户想要迭代现有功能、新增需求、修改产品规格时,使用此技能。0-1 阶段通过深入对话收集需求并生成 Product Spec;迭代阶段帮助用户想清楚变更内容并更新现有 Product Spec。 ---- - -[角色] - 你是废才,一位看透无数产品生死的资深产品经理。 - - 你见过太多人带着"改变世界"的妄想来找你,最后连需求都说不清楚。 - 你也见过真正能成事的人——他们不一定聪明,但足够诚实,敢于面对自己想法的漏洞。 - - 你不是来讨好用户的。你是来帮他们把脑子里的浆糊变成可执行的产品文档的。 - 如果他们的想法有问题,你会直接说。如果他们在自欺欺人,你会戳破。 - - 你的冷酷不是恶意,是效率。情绪是最好的思考燃料,而你擅长点火。 - -[任务] - **0-1 模式**:通过深入对话收集用户的产品需求,用直白甚至刺耳的追问逼迫用户想清楚,最终生成一份结构完整、细节丰富、可直接用于 AI 开发的 Product Spec 文档,并输出为 .md 文件供用户下载使用。 - - **迭代模式**:当用户在开发过程中提出新功能、修改需求或迭代想法时,通过追问帮助用户想清楚变更内容,检测与现有 Spec 的冲突,直接更新 Product Spec 文件,并自动记录变更日志。 - -[第一性原则] - **AI优先原则**:用户提出的所有功能,首先考虑如何用 AI 来实现。 - - - 遇到任何功能需求,第一反应是:这个能不能用 AI 做?能做到什么程度? - - 主动询问用户:这个功能要不要加一个「AI一键优化」或「AI智能推荐」? - - 如果用户描述的功能明显可以用 AI 增强,直接建议,不要等用户想到 - - 最终输出的 Product Spec 必须明确列出需要的 AI 能力类型 - - **简单优先原则**:复杂度是产品的敌人。 - - - 能用现成服务的,不自己造轮子 - - 每增加一个功能都要问「真的需要吗」 - - 第一版做最小可行产品,验证了再加功能 - -[技能] - - **需求挖掘**:通过开放式提问引导用户表达想法,捕捉关键信息 - - **追问深挖**:针对模糊描述追问细节,不接受"大概"、"可能"、"应该" - - **AI能力识别**:根据功能需求,识别需要的 AI 能力类型(文本、图像、语音等) - - **技术需求引导**:通过业务问题推断技术需求,帮助无编程基础的用户理解技术选择 - - **布局设计**:深入挖掘界面布局需求,确保每个页面有清晰的空间规范 - - **漏洞识别**:发现用户想法中的矛盾、遗漏、自欺欺人之处,直接指出 - - **冲突检测**:在迭代时检测新需求与现有 Spec 的冲突,主动指出并给出解决方案 - - **方案引导**:当用户不知道怎么做时,提供 2-3 个选项 + 优劣分析,逼用户选择 - - **结构化思维**:将零散信息整理为清晰的产品框架 - - **文档输出**:按照标准模板生成专业的 Product Spec,输出为 .md 文件 - -[文件结构] - ``` - product-spec-builder/ - ├── SKILL.md # 主 Skill 定义(本文件) - └── templates/ - ├── product-spec-template.md # Product Spec 输出模板 - └── changelog-template.md # 变更记录模板 - ``` - -[输出风格] - **语态**: - - 直白、冷静,偶尔带着看透世事的冷漠 - - 不奉承、不迎合、不说"这个想法很棒"之类的废话 - - 该嘲讽时嘲讽,该肯定时也会肯定(但很少) - - **原则**: - - × 绝不给模棱两可的废话 - - × 绝不假装用户的想法没问题(如果有问题就直接说) - - × 绝不浪费时间在无意义的客套上 - - ✓ 一针见血的建议,哪怕听起来刺耳 - - ✓ 用追问逼迫用户自己想清楚,而不是替他们想 - - ✓ 主动建议 AI 增强方案,不等用户开口 - - ✓ 偶尔的毒舌是为了激发思考,不是为了伤害 - - **典型表达**: - - "你说的这个功能,用户真的需要,还是你觉得他们需要?" - - "这个手动操作完全可以让 AI 来做,你为什么要让用户自己填?" - - "别跟我说'用户体验好',告诉我具体好在哪里。" - - "你现在描述的这个东西,市面上已经有十个了。你的凭什么能活?" - - "这里要不要加个 AI 一键优化?用户自己填这些参数,你觉得他们填得好吗?" - - "左边放什么右边放什么,你想清楚了吗?还是打算让开发自己猜?" - - "想清楚了?那我们继续。没想清楚?那就继续想。" - -[需求维度清单] - 在对话过程中,需要收集以下维度的信息(不必按顺序,根据对话自然推进): - - **必须收集**(没有这些,Product Spec 就是废纸): - - 产品定位:这是什么?解决什么问题?凭什么是你来做? - - 目标用户:谁会用?为什么用?不用会死吗? - - 核心功能:必须有什么功能?砍掉什么功能产品就不成立? - - 用户流程:用户怎么用?从打开到完成任务的完整路径是什么? - - AI能力需求:哪些功能需要 AI?需要哪种类型的 AI 能力? - - **尽量收集**(有这些,Product Spec 才能落地): - - 整体布局:几栏布局?左右还是上下?各区域比例多少? - - 区域内容:每个区域放什么?哪个是输入区,哪个是输出区? - - 控件规范:输入框铺满还是定宽?按钮放哪里?下拉框选项有哪些? - - 输入输出:用户输入什么?系统输出什么?格式是什么? - - 应用场景:3-5个具体场景,越具体越好 - - AI增强点:哪些地方可以加「AI一键优化」或「AI智能推荐」? - - 技术复杂度:需要用户登录吗?数据存哪里?需要服务器吗? - - **可选收集**(锦上添花): - - 技术偏好:有没有特定技术要求? - - 参考产品:有没有可以抄的对象?抄哪里,不抄哪里? - - 优先级:第一期做什么,第二期做什么? - -[对话策略] - **开场策略**: - - 不废话,直接基于用户已表达的内容开始追问 - - 让用户先倒完脑子里的东西,再开始解剖 - - **追问策略**: - - 每次只追问 1-2 个问题,问题要直击要害 - - 不接受模糊回答:"大概"、"可能"、"应该"、"用户会喜欢的" → 追问到底 - - 发现逻辑漏洞,直接指出,不留情面 - - 发现用户在自嗨,冷静泼冷水 - - 当用户说"界面你看着办"或"随便",不惯着,用具体选项逼他们决策 - - 布局必须问到具体:几栏、比例、各区域内容、控件规范 - - **方案引导策略**: - - 用户知道但没说清楚 → 继续逼问,不给方案 - - 用户真不知道 → 给 2-3 个选项 + 各自优劣,根据产品类型给针对性建议 - - 给完继续逼他选,选完继续逼下一个细节 - - 选项是工具,不是退路 - - **AI能力引导策略**: - - 每当用户描述一个功能,主动思考:这个能不能用 AI 做? - - 主动询问:"这里要不要加个 AI 一键XX?" - - 用户设计了繁琐的手动流程 → 直接建议用 AI 简化 - - 对话后期,主动总结需要的 AI 能力类型 - - **技术需求引导策略**: - - 用户没有编程基础,不直接问技术问题,通过业务场景推断技术需求 - - 遵循简单优先原则,能不加复杂度就不加 - - 用户想要的功能会大幅增加复杂度时,先劝退或建议分期 - - **确认策略**: - - 定期复述已收集的信息,发现矛盾直接质问 - - 信息够了就推进,不拖泥带水 - - 用户说"差不多了"但信息明显不够,继续问 - - **搜索策略**: - - 涉及可能变化的信息(技术、行业、竞品),先上网搜索再开口 - -[信息充足度判断] - 当以下条件满足时,可以生成 Product Spec: - - **必须满足**: - - ✅ 产品定位清晰(能用一句人话说明白这是什么) - - ✅ 目标用户明确(知道给谁用、为什么用) - - ✅ 核心功能明确(至少3个功能点,且能说清楚为什么需要) - - ✅ 用户流程清晰(至少一条完整路径,从头到尾) - - ✅ AI能力需求明确(知道哪些功能需要 AI,用什么类型的 AI) - - **尽量满足**: - - ✅ 整体布局有方向(知道大概是什么结构) - - ✅ 控件有基本规范(主要输入输出方式清楚) - - 如果「必须满足」条件未达成,继续追问,不要勉强生成一份垃圾文档。 - 如果「尽量满足」条件未达成,可以生成但标注 [待补充]。 - -[启动检查] - Skill 启动时,首先执行以下检查: - - 第一步:扫描项目目录,按优先级查找产品需求文档 - 优先级1(精确匹配):Product-Spec.md - 优先级2(扩大匹配):*spec*.md、*prd*.md、*PRD*.md、*需求*.md、*product*.md - - 匹配规则: - - 找到 1 个文件 → 直接使用 - - 找到多个候选文件 → 列出文件名问用户"你要改的是哪个?" - - 没找到 → 进入 0-1 模式 - - 第二步:判断模式 - - 找到产品需求文档 → 进入 **迭代模式** - - 没找到 → 进入 **0-1 模式** - - 第三步:执行对应流程 - - 0-1 模式:执行 [工作流程(0-1模式)] - - 迭代模式:执行 [工作流程(迭代模式)] - -[工作流程(0-1模式)] - [需求探索阶段] - 目的:让用户把脑子里的东西倒出来 - - 第一步:接住用户 - **先上网搜索**:根据用户表达的产品想法上网搜索相关信息,了解最新情况 - 基于用户已经表达的内容,直接开始追问 - 不重复问"你想做什么",用户已经说过了 - - 第二步:追问 - **先上网搜索**:根据用户表达的内容上网搜索相关信息,确保追问基于最新知识 - 针对模糊、矛盾、自嗨的地方,直接追问 - 每次1-2个问题,问到点子上 - 同时思考哪些功能可以用 AI 增强 - - 第三步:阶段性确认 - 复述理解,确认没跑偏 - 有问题当场纠正 - - [需求完善阶段] - 目的:填补漏洞,逼用户想清楚,确定 AI 能力需求和界面布局 - - 第一步:漏洞识别 - 对照 [需求维度清单],找出缺失的关键信息 - - 第二步:逼问 - **先上网搜索**:针对缺失项上网搜索相关信息,确保给出的建议和方案是最新的 - 针对缺失项设计问题 - 不接受敷衍回答 - 布局问题要问到具体:几栏、比例、各区域内容、控件规范 - - 第三步:AI能力引导 - **先上网搜索**:上网搜索最新的 AI 能力和最佳实践,确保建议不过时 - 主动询问用户: - - "这个功能要不要加 AI 一键优化?" - - "这里让用户手动填,还是让 AI 智能推荐?" - 根据用户需求识别需要的 AI 能力类型(文本生成、图像生成、图像识别等) - - 第四步:技术复杂度评估 - **先上网搜索**:上网搜索相关技术方案,确保建议是最新的 - 根据 [技术需求引导] 策略,通过业务问题判断技术复杂度 - 如果用户想要的功能会大幅增加复杂度,先劝退或建议分期 - 确保用户理解技术选择的影响 - - 第五步:充足度判断 - 对照 [信息充足度判断] - 「必须满足」都达成 → 提议生成 - 未达成 → 继续问,不惯着 - - [文档生成阶段] - 目的:输出可用的 Product Spec 文件 - - 第一步:整理 - 将对话内容按输出模板结构分类 - - 第二步:填充 - 加载 templates/product-spec-template.md 获取模板格式 - 按模板格式填写 - 「尽量满足」未达成的地方标注 [待补充] - 功能用动词开头 - UI布局要描述清楚整体结构和各区域细节 - 流程写清楚步骤 - - 第三步:识别AI能力需求 - 根据功能需求识别所需的 AI 能力类型 - 在「AI 能力需求」部分列出 - 说明每种能力在本产品中的具体用途 - - 第四步:输出文件 - 将 Product Spec 保存为 Product-Spec.md - -[工作流程(迭代模式)] - **触发条件**:用户在开发过程中提出新功能、修改需求或迭代想法 - - **核心原则**:无缝衔接,不打断用户工作流。不需要开场白,直接接住用户的需求往下问。 - - [变更识别阶段] - 目的:搞清楚用户要改什么 - - 第一步:接住需求 - **先上网搜索**:根据用户提出的变更内容上网搜索相关信息,确保追问基于最新知识 - 用户说"我觉得应该还要有一个AI一键推荐功能" - 直接追问:"AI一键推荐什么?推荐给谁?这个按钮放哪个页面?点了之后发生什么?" - - 第二步:判断变更类型 - 根据 [迭代模式-追问深度判断] 确定这是重度、中度还是轻度变更 - 决定追问深度 - - [追问完善阶段] - 目的:问到能直接改 Spec 为止 - - 第一步:按深度追问 - **先上网搜索**:每次追问前上网搜索相关信息,确保问题和建议基于最新知识 - 重度变更:问到能回答"这个变更会怎么影响现有产品" - 中度变更:问到能回答"具体改成什么样" - 轻度变更:确认理解正确即可 - - 第二步:用户卡住时给方案 - **先上网搜索**:给方案前上网搜索最新的解决方案和最佳实践 - 用户不知道怎么做 → 给 2-3 个选项 + 优劣 - 给完继续逼他选,选完继续逼下一个细节 - - 第三步:冲突检测 - 加载现有 Product-Spec.md - 检查新需求是否与现有内容冲突 - 发现冲突 → 直接指出冲突点 + 给解决方案 + 让用户选 - - **停止追问的标准**: - - 能够直接动手改 Product Spec,不需要再猜或假设 - - 改完之后用户不会说"不是这个意思" - - [文档更新阶段] - 目的:更新 Product Spec 并记录变更 - - 第一步:理解现有文档结构 - 加载现有 Spec 文件 - 识别其章节结构(可能和模板不同) - 后续修改基于现有结构,不强行套用模板 - - 第二步:直接修改源文件 - 在现有 Spec 上直接修改 - 保持文档整体结构不变 - 只改需要改的部分 - - 第三步:更新 AI 能力需求 - 如果涉及新的 AI 功能: - - 在「AI 能力需求」章节添加新能力类型 - - 说明新能力的用途 - - 第四步:自动追加变更记录 - 在 Product-Spec-CHANGELOG.md 中追加本次变更 - 如果 CHANGELOG 文件不存在,创建一个 - 记录 Product Spec 迭代变更时,加载 templates/changelog-template.md 获取完整的变更记录格式和示例 - 根据对话内容自动生成变更描述 - - [迭代模式-追问深度判断] - **变更类型判断逻辑**(按顺序检查): - 1. 涉及新 AI 能力?→ 重度 - 2. 涉及用户核心路径变更?→ 重度 - 3. 涉及布局结构(几栏、区域划分)?→ 重度 - 4. 新增主要功能模块?→ 重度 - 5. 涉及新功能但不改核心流程?→ 中度 - 6. 涉及现有功能的逻辑调整?→ 中度 - 7. 局部布局调整?→ 中度 - 8. 只是改文字、选项、样式?→ 轻度 - - **各类型追问标准**: - - | 变更类型 | 停止追问的条件 | 必须问清楚的内容 | - |---------|---------------|----------------| - | **重度** | 能回答"这个变更会怎么影响现有产品"时停止 | 为什么需要?影响哪些现有功能?用户流程怎么变?需要什么新的 AI 能力? | - | **中度** | 能回答"具体改成什么样"时停止 | 改哪里?改成什么?和现有的怎么配合? | - | **轻度** | 确认理解正确时停止 | 改什么?改成什么? | - -[初始化] - 执行 [启动检查] \ No newline at end of file diff --git a/.claude/skills/product-spec-builder/templates/changelog-template.md b/.claude/skills/product-spec-builder/templates/changelog-template.md deleted file mode 100644 index 89b10f0..0000000 --- a/.claude/skills/product-spec-builder/templates/changelog-template.md +++ /dev/null @@ -1,111 +0,0 @@ ---- -name: changelog-template -description: 变更记录模板。当 Product Spec 发生迭代变更时,按照此模板格式记录变更历史,输出为 Product-Spec-CHANGELOG.md 文件。 ---- - -# 变更记录模板 - -本模板用于记录 Product Spec 的迭代变更历史。 - ---- - -## 文件命名 - -`Product-Spec-CHANGELOG.md` - ---- - -## 模板格式 - -```markdown -# 变更记录 - -## [v1.2] - YYYY-MM-DD -### 新增 -- <新增的功能或内容> - -### 修改 -- <修改的功能或内容> - -### 删除 -- <删除的功能或内容> - ---- - -## [v1.1] - YYYY-MM-DD -### 新增 -- <新增的功能或内容> - ---- - -## [v1.0] - YYYY-MM-DD -- 初始版本 -``` - ---- - -## 记录规则 - -- **版本号递增**:每次迭代 +0.1(如 v1.0 → v1.1 → v1.2) -- **日期自动填充**:使用当天日期,格式 YYYY-MM-DD -- **变更描述**:根据对话内容自动生成,简明扼要 -- **分类记录**:新增、修改、删除分开写,没有的分类不写 -- **只记录实际改动**:没改的部分不记录 -- **新增控件要写位置**:涉及 UI 变更时,说明控件放在哪里 - ---- - -## 完整示例 - -以下是「剧本分镜生成器」的变更记录示例,供参考: - -```markdown -# 变更记录 - -## [v1.2] - 2025-12-08 -### 新增 -- 新增「AI 优化描述」按钮(角色设定区底部),点击后自动优化角色和场景的描述文字 -- 新增分镜描述显示,每张分镜图下方展示 AI 生成的画面描述 - -### 修改 -- 左侧输入区比例从 35% 改为 40% -- 「生成分镜」按钮样式改为更醒目的主色调 - ---- - -## [v1.1] - 2025-12-05 -### 新增 -- 新增「场景设定」功能区(角色设定区下方),用户可上传场景参考图建立视觉档案 -- 新增「水墨」画风选项 -- 新增图像理解能力,用于分析用户上传的参考图 - -### 修改 -- 角色卡片布局优化,参考图预览尺寸从 80px 改为 120px - -### 删除 -- 移除「自动分页」功能(用户反馈更希望手动控制分页节奏) - ---- - -## [v1.0] - 2025-12-01 -- 初始版本 -``` - ---- - -## 写作要点 - -1. **版本号**:从 v1.0 开始,每次迭代 +0.1,重大改版可以 +1.0 -2. **日期格式**:统一用 YYYY-MM-DD,方便排序和查找 -3. **变更描述**: - - 动词开头(新增、修改、删除、移除、调整) - - 说清楚改了什么、改成什么样 - - 新增控件要写位置(如「角色设定区底部」) - - 数值变更要写前后对比(如「从 35% 改为 40%」) - - 如果有原因,简要说明(如「用户反馈不需要」) -4. **分类原则**: - - 新增:之前没有的功能、控件、能力 - - 修改:改变了现有内容的行为、样式、参数 - - 删除:移除了之前有的功能 -5. **颗粒度**:一条记录对应一个独立的变更点,不要把多个改动混在一起 -6. **AI 能力变更**:如果新增或移除了 AI 能力,必须单独记录 diff --git a/.claude/skills/product-spec-builder/templates/product-spec-template.md b/.claude/skills/product-spec-builder/templates/product-spec-template.md deleted file mode 100644 index 2859885..0000000 --- a/.claude/skills/product-spec-builder/templates/product-spec-template.md +++ /dev/null @@ -1,197 +0,0 @@ ---- -name: product-spec-template -description: Product Spec 输出模板。当需要生成产品需求文档时,按照此模板的结构和格式填充内容,输出为 Product-Spec.md 文件。 ---- - -# Product Spec 输出模板 - -本模板用于生成结构完整的 Product Spec 文档。生成时按照此结构填充内容。 - ---- - -## 模板结构 - -**文件命名**:Product-Spec.md - ---- - -## 产品概述 -<一段话说清楚:> -- 这是什么产品 -- 解决什么问题 -- **目标用户是谁**(具体描述,不要只说「用户」) -- 核心价值是什么 - -## 应用场景 -<列举 3-5 个具体场景:谁、在什么情况下、怎么用、解决什么问题> - -## 功能需求 -<按「核心功能」和「辅助功能」分类,每条功能说明:用户做什么 → 系统做什么 → 得到什么> - -## UI 布局 -<描述整体布局结构和各区域的详细设计,需要包含:> -- 整体是什么布局(几栏、比例、固定元素等) -- 每个区域放什么内容 -- 控件的具体规范(位置、尺寸、样式等) - -## 用户使用流程 -<分步骤描述用户如何使用产品,可以有多条路径(如快速上手、进阶使用)> - -## AI 能力需求 - -| 能力类型 | 用途说明 | 应用位置 | -|---------|---------|---------| -| <能力类型> | <做什么> | <在哪个环节触发> | - -## 技术说明(可选) -<如果涉及以下内容,需要说明:> -- 数据存储:是否需要登录?数据存在哪里? -- 外部依赖:需要调用什么服务?有什么限制? -- 部署方式:纯前端?需要服务器? - -## 补充说明 -<如有需要,用表格说明选项、状态、逻辑等> - ---- - -## 完整示例 - -以下是一个「剧本分镜生成器」的 Product Spec 示例,供参考: - -```markdown -## 产品概述 - -这是一个帮助漫画作者、短视频创作者、动画团队将剧本快速转化为分镜图的工具。 - -**目标用户**:有剧本但缺乏绘画能力、或者想快速出分镜草稿的创作者。他们可能是独立漫画作者、短视频博主、动画工作室的前期策划人员,共同的痛点是「脑子里有画面,但画不出来或画太慢」。 - -**核心价值**:用户只需输入剧本文本、上传角色和场景参考图、选择画风,AI 就会自动分析剧本结构,生成保持视觉一致性的分镜图,将原本需要数小时的分镜绘制工作缩短到几分钟。 - -## 应用场景 - -- **漫画创作**:独立漫画作者小王有一个 20 页的剧本,需要先出分镜草稿再精修。他把剧本贴进来,上传主角的参考图,10 分钟就拿到了全部分镜草稿,可以直接在这个基础上精修。 - -- **短视频策划**:短视频博主小李要拍一个 3 分钟的剧情短片,需要给摄影师看分镜。她把脚本输入,选择「写实」风格,生成的分镜图直接可以当拍摄参考。 - -- **动画前期**:动画工作室要向客户提案,需要快速出一版分镜来展示剧本节奏。策划人员用这个工具 30 分钟出了 50 张分镜图,当天就能开提案会。 - -- **小说可视化**:网文作者想给自己的小说做宣传图,把关键场景描述输入,生成的分镜图可以直接用于社交媒体宣传。 - -- **教学演示**:小学语文老师想把一篇课文变成连环画给学生看,把课文内容输入,选择「动漫」风格,生成的图片可以直接做成 PPT。 - -## 功能需求 - -**核心功能** -- 剧本输入与分析:用户输入剧本文本 → 点击「生成分镜」→ AI 自动识别角色、场景和情节节拍,将剧本拆分为多页分镜 -- 角色设定:用户添加角色卡片(名称 + 外观描述 + 参考图)→ 系统建立角色视觉档案,后续生成时保持外观一致 -- 场景设定:用户添加场景卡片(名称 + 氛围描述 + 参考图)→ 系统建立场景视觉档案(可选,不设定则由 AI 根据剧本生成) -- 画风选择:用户从下拉框选择画风(漫画/动漫/写实/赛博朋克/水墨)→ 生成的分镜图采用对应视觉风格 -- 分镜生成:用户点击「生成分镜」→ AI 生成当前页 9 张分镜图(3x3 九宫格)→ 展示在右侧输出区 -- 连续生成:用户点击「继续生成下一页」→ AI 基于前一页的画风和角色外观,生成下一页 9 张分镜图 - -**辅助功能** -- 批量下载:用户点击「下载全部」→ 系统将当前页 9 张图打包为 ZIP 下载 -- 历史浏览:用户通过页面导航 → 切换查看已生成的历史页面 - -## UI 布局 - -### 整体布局 -左右两栏布局,左侧输入区占 40%,右侧输出区占 60%。 - -### 左侧 - 输入区 -- 顶部:项目名称输入框 -- 剧本输入:多行文本框,placeholder「请输入剧本内容...」 -- 角色设定区: - - 角色卡片列表,每张卡片包含:角色名、外观描述、参考图上传 - - 「添加角色」按钮 -- 场景设定区: - - 场景卡片列表,每张卡片包含:场景名、氛围描述、参考图上传 - - 「添加场景」按钮 -- 画风选择:下拉选择(漫画 / 动漫 / 写实 / 赛博朋克 / 水墨),默认「动漫」 -- 底部:「生成分镜」主按钮,靠右对齐,醒目样式 - -### 右侧 - 输出区 -- 分镜图展示区:3x3 网格布局,展示 9 张独立分镜图 -- 每张分镜图下方显示:分镜编号、简要描述 -- 操作按钮:「下载全部」「继续生成下一页」 -- 页面导航:显示当前页数,支持切换查看历史页面 - -## 用户使用流程 - -### 首次生成 -1. 输入剧本内容 -2. 添加角色:填写名称、外观描述,上传参考图 -3. 添加场景:填写名称、氛围描述,上传参考图(可选) -4. 选择画风 -5. 点击「生成分镜」 -6. 在右侧查看生成的 9 张分镜图 -7. 点击「下载全部」保存 - -### 连续生成 -1. 完成首次生成后 -2. 点击「继续生成下一页」 -3. AI 基于前一页的画风和角色外观,生成下一页 9 张分镜图 -4. 重复直到剧本完成 - -## AI 能力需求 - -| 能力类型 | 用途说明 | 应用位置 | -|---------|---------|---------| -| 文本理解与生成 | 分析剧本结构,识别角色、场景、情节节拍,规划分镜内容 | 点击「生成分镜」时 | -| 图像生成 | 根据分镜描述生成 3x3 九宫格分镜图 | 点击「生成分镜」「继续生成下一页」时 | -| 图像理解 | 分析用户上传的角色和场景参考图,提取视觉特征用于保持一致性 | 上传角色/场景参考图时 | - -## 技术说明 - -- **数据存储**:无需登录,项目数据保存在浏览器本地存储(LocalStorage),关闭页面后仍可恢复 -- **图像生成**:调用 AI 图像生成服务,每次生成 9 张图约需 30-60 秒 -- **文件导出**:支持 PNG 格式批量下载,打包为 ZIP 文件 -- **部署方式**:纯前端应用,无需服务器,可部署到任意静态托管平台 - -## 补充说明 - -| 选项 | 可选值 | 说明 | -|------|--------|------| -| 画风 | 漫画 / 动漫 / 写实 / 赛博朋克 / 水墨 | 决定分镜图的整体视觉风格 | -| 角色参考图 | 图片上传 | 用于建立角色视觉身份,确保一致性 | -| 场景参考图 | 图片上传(可选) | 用于建立场景氛围,不上传则由 AI 根据描述生成 | -``` - ---- - -## 写作要点 - -1. **产品概述**: - - 一句话说清楚是什么 - - **必须明确写出目标用户**:是谁、有什么特点、什么痛点 - - 核心价值:用了这个产品能得到什么 - -2. **应用场景**: - - 具体的人 + 具体的情况 + 具体的用法 + 解决什么问题 - - 场景要有画面感,让人一看就懂 - - 放在功能需求之前,帮助理解产品价值 - -3. **功能需求**: - - 分「核心功能」和「辅助功能」 - - 每条格式:用户做什么 → 系统做什么 → 得到什么 - - 写清楚触发方式(点击什么按钮) - -4. **UI 布局**: - - 先写整体布局(几栏、比例) - - 再逐个区域描述内容 - - 控件要具体:下拉框写出所有选项和默认值,按钮写明位置和样式 - -5. **用户流程**:分步骤,可以有多条路径 - -6. **AI 能力需求**: - - 列出需要的 AI 能力类型 - - 说明具体用途 - - **写清楚在哪个环节触发**,方便开发理解调用时机 - -7. **技术说明**(可选): - - 数据存储方式 - - 外部服务依赖 - - 部署方式 - - 只在有技术约束时写,没有就不写 - -8. **补充说明**:用表格,适合解释选项、状态、逻辑 diff --git a/.claude/skills/project-guidelines-example/SKILL.md b/.claude/skills/project-guidelines-example/SKILL.md deleted file mode 100644 index 0135855..0000000 --- a/.claude/skills/project-guidelines-example/SKILL.md +++ /dev/null @@ -1,345 +0,0 @@ -# Project Guidelines Skill (Example) - -This is an example of a project-specific skill. Use this as a template for your own projects. - -Based on a real production application: [Zenith](https://zenith.chat) - AI-powered customer discovery platform. - ---- - -## When to Use - -Reference this skill when working on the specific project it's designed for. Project skills contain: -- Architecture overview -- File structure -- Code patterns -- Testing requirements -- Deployment workflow - ---- - -## Architecture Overview - -**Tech Stack:** -- **Frontend**: Next.js 15 (App Router), TypeScript, React -- **Backend**: FastAPI (Python), Pydantic models -- **Database**: Supabase (PostgreSQL) -- **AI**: Claude API with tool calling and structured output -- **Deployment**: Google Cloud Run -- **Testing**: Playwright (E2E), pytest (backend), React Testing Library - -**Services:** -``` -┌─────────────────────────────────────────────────────────────┐ -│ Frontend │ -│ Next.js 15 + TypeScript + TailwindCSS │ -│ Deployed: Vercel / Cloud Run │ -└─────────────────────────────────────────────────────────────┘ - │ - ▼ -┌─────────────────────────────────────────────────────────────┐ -│ Backend │ -│ FastAPI + Python 3.11 + Pydantic │ -│ Deployed: Cloud Run │ -└─────────────────────────────────────────────────────────────┘ - │ - ┌───────────────┼───────────────┐ - ▼ ▼ ▼ - ┌──────────┐ ┌──────────┐ ┌──────────┐ - │ Supabase │ │ Claude │ │ Redis │ - │ Database │ │ API │ │ Cache │ - └──────────┘ └──────────┘ └──────────┘ -``` - ---- - -## File Structure - -``` -project/ -├── frontend/ -│ └── src/ -│ ├── app/ # Next.js app router pages -│ │ ├── api/ # API routes -│ │ ├── (auth)/ # Auth-protected routes -│ │ └── workspace/ # Main app workspace -│ ├── components/ # React components -│ │ ├── ui/ # Base UI components -│ │ ├── forms/ # Form components -│ │ └── layouts/ # Layout components -│ ├── hooks/ # Custom React hooks -│ ├── lib/ # Utilities -│ ├── types/ # TypeScript definitions -│ └── config/ # Configuration -│ -├── backend/ -│ ├── routers/ # FastAPI route handlers -│ ├── models.py # Pydantic models -│ ├── main.py # FastAPI app entry -│ ├── auth_system.py # Authentication -│ ├── database.py # Database operations -│ ├── services/ # Business logic -│ └── tests/ # pytest tests -│ -├── deploy/ # Deployment configs -├── docs/ # Documentation -└── scripts/ # Utility scripts -``` - ---- - -## Code Patterns - -### API Response Format (FastAPI) - -```python -from pydantic import BaseModel -from typing import Generic, TypeVar, Optional - -T = TypeVar('T') - -class ApiResponse(BaseModel, Generic[T]): - success: bool - data: Optional[T] = None - error: Optional[str] = None - - @classmethod - def ok(cls, data: T) -> "ApiResponse[T]": - return cls(success=True, data=data) - - @classmethod - def fail(cls, error: str) -> "ApiResponse[T]": - return cls(success=False, error=error) -``` - -### Frontend API Calls (TypeScript) - -```typescript -interface ApiResponse { - success: boolean - data?: T - error?: string -} - -async function fetchApi( - endpoint: string, - options?: RequestInit -): Promise> { - try { - const response = await fetch(`/api${endpoint}`, { - ...options, - headers: { - 'Content-Type': 'application/json', - ...options?.headers, - }, - }) - - if (!response.ok) { - return { success: false, error: `HTTP ${response.status}` } - } - - return await response.json() - } catch (error) { - return { success: false, error: String(error) } - } -} -``` - -### Claude AI Integration (Structured Output) - -```python -from anthropic import Anthropic -from pydantic import BaseModel - -class AnalysisResult(BaseModel): - summary: str - key_points: list[str] - confidence: float - -async def analyze_with_claude(content: str) -> AnalysisResult: - client = Anthropic() - - response = client.messages.create( - model="claude-sonnet-4-5-20250514", - max_tokens=1024, - messages=[{"role": "user", "content": content}], - tools=[{ - "name": "provide_analysis", - "description": "Provide structured analysis", - "input_schema": AnalysisResult.model_json_schema() - }], - tool_choice={"type": "tool", "name": "provide_analysis"} - ) - - # Extract tool use result - tool_use = next( - block for block in response.content - if block.type == "tool_use" - ) - - return AnalysisResult(**tool_use.input) -``` - -### Custom Hooks (React) - -```typescript -import { useState, useCallback } from 'react' - -interface UseApiState { - data: T | null - loading: boolean - error: string | null -} - -export function useApi( - fetchFn: () => Promise> -) { - const [state, setState] = useState>({ - data: null, - loading: false, - error: null, - }) - - const execute = useCallback(async () => { - setState(prev => ({ ...prev, loading: true, error: null })) - - const result = await fetchFn() - - if (result.success) { - setState({ data: result.data!, loading: false, error: null }) - } else { - setState({ data: null, loading: false, error: result.error! }) - } - }, [fetchFn]) - - return { ...state, execute } -} -``` - ---- - -## Testing Requirements - -### Backend (pytest) - -```bash -# Run all tests -poetry run pytest tests/ - -# Run with coverage -poetry run pytest tests/ --cov=. --cov-report=html - -# Run specific test file -poetry run pytest tests/test_auth.py -v -``` - -**Test structure:** -```python -import pytest -from httpx import AsyncClient -from main import app - -@pytest.fixture -async def client(): - async with AsyncClient(app=app, base_url="http://test") as ac: - yield ac - -@pytest.mark.asyncio -async def test_health_check(client: AsyncClient): - response = await client.get("/health") - assert response.status_code == 200 - assert response.json()["status"] == "healthy" -``` - -### Frontend (React Testing Library) - -```bash -# Run tests -npm run test - -# Run with coverage -npm run test -- --coverage - -# Run E2E tests -npm run test:e2e -``` - -**Test structure:** -```typescript -import { render, screen, fireEvent } from '@testing-library/react' -import { WorkspacePanel } from './WorkspacePanel' - -describe('WorkspacePanel', () => { - it('renders workspace correctly', () => { - render() - expect(screen.getByRole('main')).toBeInTheDocument() - }) - - it('handles session creation', async () => { - render() - fireEvent.click(screen.getByText('New Session')) - expect(await screen.findByText('Session created')).toBeInTheDocument() - }) -}) -``` - ---- - -## Deployment Workflow - -### Pre-Deployment Checklist - -- [ ] All tests passing locally -- [ ] `npm run build` succeeds (frontend) -- [ ] `poetry run pytest` passes (backend) -- [ ] No hardcoded secrets -- [ ] Environment variables documented -- [ ] Database migrations ready - -### Deployment Commands - -```bash -# Build and deploy frontend -cd frontend && npm run build -gcloud run deploy frontend --source . - -# Build and deploy backend -cd backend -gcloud run deploy backend --source . -``` - -### Environment Variables - -```bash -# Frontend (.env.local) -NEXT_PUBLIC_API_URL=https://api.example.com -NEXT_PUBLIC_SUPABASE_URL=https://xxx.supabase.co -NEXT_PUBLIC_SUPABASE_ANON_KEY=eyJ... - -# Backend (.env) -DATABASE_URL=postgresql://... -ANTHROPIC_API_KEY=sk-ant-... -SUPABASE_URL=https://xxx.supabase.co -SUPABASE_KEY=eyJ... -``` - ---- - -## Critical Rules - -1. **No emojis** in code, comments, or documentation -2. **Immutability** - never mutate objects or arrays -3. **TDD** - write tests before implementation -4. **80% coverage** minimum -5. **Many small files** - 200-400 lines typical, 800 max -6. **No console.log** in production code -7. **Proper error handling** with try/catch -8. **Input validation** with Pydantic/Zod - ---- - -## Related Skills - -- `coding-standards.md` - General coding best practices -- `backend-patterns.md` - API and database patterns -- `frontend-patterns.md` - React and Next.js patterns -- `tdd-workflow/` - Test-driven development methodology diff --git a/.claude/skills/security-review/SKILL.md b/.claude/skills/security-review/SKILL.md deleted file mode 100644 index 81397dd..0000000 --- a/.claude/skills/security-review/SKILL.md +++ /dev/null @@ -1,568 +0,0 @@ ---- -name: security-review -description: Use this skill when adding authentication, handling user input, working with secrets, creating API endpoints, or implementing payment/sensitive features. Provides comprehensive security checklist and patterns. ---- - -# Security Review Skill - -Security best practices for Python/FastAPI applications handling sensitive invoice data. - -## When to Activate - -- Implementing authentication or authorization -- Handling user input or file uploads -- Creating new API endpoints -- Working with secrets or credentials -- Processing sensitive invoice data -- Integrating third-party APIs -- Database operations with user data - -## Security Checklist - -### 1. Secrets Management - -#### NEVER Do This -```python -# Hardcoded secrets - CRITICAL VULNERABILITY -api_key = "sk-proj-xxxxx" -db_password = "password123" -``` - -#### ALWAYS Do This -```python -import os -from pydantic_settings import BaseSettings - -class Settings(BaseSettings): - db_password: str - api_key: str - model_path: str = "runs/train/invoice_fields/weights/best.pt" - - class Config: - env_file = ".env" - -settings = Settings() - -# Verify secrets exist -if not settings.db_password: - raise RuntimeError("DB_PASSWORD not configured") -``` - -#### Verification Steps -- [ ] No hardcoded API keys, tokens, or passwords -- [ ] All secrets in environment variables -- [ ] `.env` in .gitignore -- [ ] No secrets in git history -- [ ] `.env.example` with placeholder values - -### 2. Input Validation - -#### Always Validate User Input -```python -from pydantic import BaseModel, Field, field_validator -from fastapi import HTTPException -import re - -class InvoiceRequest(BaseModel): - invoice_number: str = Field(..., min_length=1, max_length=50) - amount: float = Field(..., gt=0, le=1_000_000) - bankgiro: str | None = None - - @field_validator("invoice_number") - @classmethod - def validate_invoice_number(cls, v: str) -> str: - # Whitelist validation - only allow safe characters - if not re.match(r"^[A-Za-z0-9\-_]+$", v): - raise ValueError("Invalid invoice number format") - return v - - @field_validator("bankgiro") - @classmethod - def validate_bankgiro(cls, v: str | None) -> str | None: - if v is None: - return None - cleaned = re.sub(r"[^0-9]", "", v) - if not (7 <= len(cleaned) <= 8): - raise ValueError("Bankgiro must be 7-8 digits") - return cleaned -``` - -#### File Upload Validation -```python -from fastapi import UploadFile, HTTPException -from pathlib import Path - -ALLOWED_EXTENSIONS = {".pdf"} -MAX_FILE_SIZE = 10 * 1024 * 1024 # 10MB - -async def validate_pdf_upload(file: UploadFile) -> bytes: - """Validate PDF upload with security checks.""" - # Extension check - ext = Path(file.filename or "").suffix.lower() - if ext not in ALLOWED_EXTENSIONS: - raise HTTPException(400, f"Only PDF files allowed, got {ext}") - - # Read content - content = await file.read() - - # Size check - if len(content) > MAX_FILE_SIZE: - raise HTTPException(400, f"File too large (max {MAX_FILE_SIZE // 1024 // 1024}MB)") - - # Magic bytes check (PDF signature) - if not content.startswith(b"%PDF"): - raise HTTPException(400, "Invalid PDF file format") - - return content -``` - -#### Verification Steps -- [ ] All user inputs validated with Pydantic -- [ ] File uploads restricted (size, type, extension, magic bytes) -- [ ] No direct use of user input in queries -- [ ] Whitelist validation (not blacklist) -- [ ] Error messages don't leak sensitive info - -### 3. SQL Injection Prevention - -#### NEVER Concatenate SQL -```python -# DANGEROUS - SQL Injection vulnerability -query = f"SELECT * FROM documents WHERE id = '{user_input}'" -cur.execute(query) -``` - -#### ALWAYS Use Parameterized Queries -```python -import psycopg2 - -# Safe - parameterized query with %s placeholders -cur.execute( - "SELECT * FROM documents WHERE id = %s AND status = %s", - (document_id, status) -) - -# Safe - named parameters -cur.execute( - "SELECT * FROM documents WHERE id = %(id)s", - {"id": document_id} -) - -# Safe - psycopg2.sql for dynamic identifiers -from psycopg2 import sql - -cur.execute( - sql.SQL("SELECT {} FROM {} WHERE id = %s").format( - sql.Identifier("invoice_number"), - sql.Identifier("documents") - ), - (document_id,) -) -``` - -#### Verification Steps -- [ ] All database queries use parameterized queries (%s or %(name)s) -- [ ] No string concatenation or f-strings in SQL -- [ ] psycopg2.sql module used for dynamic identifiers -- [ ] No user input in table/column names - -### 4. Path Traversal Prevention - -#### NEVER Trust User Paths -```python -# DANGEROUS - Path traversal vulnerability -filename = request.query_params.get("file") -with open(f"/data/{filename}", "r") as f: # Attacker: ../../../etc/passwd - return f.read() -``` - -#### ALWAYS Validate Paths -```python -from pathlib import Path - -ALLOWED_DIR = Path("/data/uploads").resolve() - -def get_safe_path(filename: str) -> Path: - """Get safe file path, preventing path traversal.""" - # Remove any path components - safe_name = Path(filename).name - - # Validate filename characters - if not re.match(r"^[A-Za-z0-9_\-\.]+$", safe_name): - raise HTTPException(400, "Invalid filename") - - # Resolve and verify within allowed directory - full_path = (ALLOWED_DIR / safe_name).resolve() - - if not full_path.is_relative_to(ALLOWED_DIR): - raise HTTPException(400, "Invalid file path") - - return full_path -``` - -#### Verification Steps -- [ ] User-provided filenames sanitized -- [ ] Paths resolved and validated against allowed directory -- [ ] No direct concatenation of user input into paths -- [ ] Whitelist characters in filenames - -### 5. Authentication & Authorization - -#### API Key Validation -```python -from fastapi import Depends, HTTPException, Security -from fastapi.security import APIKeyHeader - -api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False) - -async def verify_api_key(api_key: str = Security(api_key_header)) -> str: - if not api_key: - raise HTTPException(401, "API key required") - - # Constant-time comparison to prevent timing attacks - import hmac - if not hmac.compare_digest(api_key, settings.api_key): - raise HTTPException(403, "Invalid API key") - - return api_key - -@router.post("/infer") -async def infer( - file: UploadFile, - api_key: str = Depends(verify_api_key) -): - ... -``` - -#### Role-Based Access Control -```python -from enum import Enum - -class UserRole(str, Enum): - USER = "user" - ADMIN = "admin" - -def require_role(required_role: UserRole): - async def role_checker(current_user: User = Depends(get_current_user)): - if current_user.role != required_role: - raise HTTPException(403, "Insufficient permissions") - return current_user - return role_checker - -@router.delete("/documents/{doc_id}") -async def delete_document( - doc_id: str, - user: User = Depends(require_role(UserRole.ADMIN)) -): - ... -``` - -#### Verification Steps -- [ ] API keys validated with constant-time comparison -- [ ] Authorization checks before sensitive operations -- [ ] Role-based access control implemented -- [ ] Session/token validation on protected routes - -### 6. Rate Limiting - -#### Rate Limiter Implementation -```python -from time import time -from collections import defaultdict -from fastapi import Request, HTTPException - -class RateLimiter: - def __init__(self): - self.requests: dict[str, list[float]] = defaultdict(list) - - def check_limit( - self, - identifier: str, - max_requests: int, - window_seconds: int - ) -> bool: - now = time() - # Clean old requests - self.requests[identifier] = [ - t for t in self.requests[identifier] - if now - t < window_seconds - ] - # Check limit - if len(self.requests[identifier]) >= max_requests: - return False - self.requests[identifier].append(now) - return True - -limiter = RateLimiter() - -@app.middleware("http") -async def rate_limit_middleware(request: Request, call_next): - client_ip = request.client.host if request.client else "unknown" - - # 100 requests per minute for general endpoints - if not limiter.check_limit(client_ip, max_requests=100, window_seconds=60): - raise HTTPException(429, "Rate limit exceeded. Try again later.") - - return await call_next(request) -``` - -#### Stricter Limits for Expensive Operations -```python -# Inference endpoint: 10 requests per minute -async def check_inference_rate_limit(request: Request): - client_ip = request.client.host if request.client else "unknown" - if not limiter.check_limit(f"infer:{client_ip}", max_requests=10, window_seconds=60): - raise HTTPException(429, "Inference rate limit exceeded") - -@router.post("/infer") -async def infer( - file: UploadFile, - _: None = Depends(check_inference_rate_limit) -): - ... -``` - -#### Verification Steps -- [ ] Rate limiting on all API endpoints -- [ ] Stricter limits on expensive operations (inference, OCR) -- [ ] IP-based rate limiting -- [ ] Clear error messages for rate-limited requests - -### 7. Sensitive Data Exposure - -#### Logging -```python -import logging - -logger = logging.getLogger(__name__) - -# WRONG: Logging sensitive data -logger.info(f"Processing invoice: {invoice_data}") # May contain sensitive info -logger.error(f"DB error with password: {db_password}") - -# CORRECT: Redact sensitive data -logger.info(f"Processing invoice: id={doc_id}") -logger.error(f"DB connection failed to {db_host}:{db_port}") - -# CORRECT: Structured logging with safe fields only -logger.info( - "Invoice processed", - extra={ - "document_id": doc_id, - "field_count": len(fields), - "processing_time_ms": elapsed_ms - } -) -``` - -#### Error Messages -```python -# WRONG: Exposing internal details -@app.exception_handler(Exception) -async def error_handler(request: Request, exc: Exception): - return JSONResponse( - status_code=500, - content={ - "error": str(exc), - "traceback": traceback.format_exc() # NEVER expose! - } - ) - -# CORRECT: Generic error messages -@app.exception_handler(Exception) -async def error_handler(request: Request, exc: Exception): - logger.error(f"Unhandled error: {exc}", exc_info=True) # Log internally - return JSONResponse( - status_code=500, - content={"success": False, "error": "An error occurred"} - ) -``` - -#### Verification Steps -- [ ] No passwords, tokens, or secrets in logs -- [ ] Error messages generic for users -- [ ] Detailed errors only in server logs -- [ ] No stack traces exposed to users -- [ ] Invoice data (amounts, account numbers) not logged - -### 8. CORS Configuration - -```python -from fastapi.middleware.cors import CORSMiddleware - -# WRONG: Allow all origins -app.add_middleware( - CORSMiddleware, - allow_origins=["*"], # DANGEROUS in production - allow_credentials=True, -) - -# CORRECT: Specific origins -ALLOWED_ORIGINS = [ - "http://localhost:8000", - "https://your-domain.com", -] - -app.add_middleware( - CORSMiddleware, - allow_origins=ALLOWED_ORIGINS, - allow_credentials=True, - allow_methods=["GET", "POST"], - allow_headers=["*"], -) -``` - -#### Verification Steps -- [ ] CORS origins explicitly listed -- [ ] No wildcard origins in production -- [ ] Credentials only with specific origins - -### 9. Temporary File Security - -```python -import tempfile -from pathlib import Path -from contextlib import contextmanager - -@contextmanager -def secure_temp_file(suffix: str = ".pdf"): - """Create secure temporary file that is always cleaned up.""" - tmp_path = None - try: - with tempfile.NamedTemporaryFile( - suffix=suffix, - delete=False, - dir="/tmp/invoice-master" # Dedicated temp directory - ) as tmp: - tmp_path = Path(tmp.name) - yield tmp_path - finally: - if tmp_path and tmp_path.exists(): - tmp_path.unlink() - -# Usage -async def process_upload(file: UploadFile): - with secure_temp_file(".pdf") as tmp_path: - content = await validate_pdf_upload(file) - tmp_path.write_bytes(content) - result = pipeline.process(tmp_path) - # File automatically cleaned up - return result -``` - -#### Verification Steps -- [ ] Temporary files always cleaned up (use context managers) -- [ ] Temp directory has restricted permissions -- [ ] No leftover files after processing errors - -### 10. Dependency Security - -#### Regular Updates -```bash -# Check for vulnerabilities -pip-audit - -# Update dependencies -pip install --upgrade -r requirements.txt - -# Check for outdated packages -pip list --outdated -``` - -#### Lock Files -```bash -# Create requirements lock file -pip freeze > requirements.lock - -# Install from lock file for reproducible builds -pip install -r requirements.lock -``` - -#### Verification Steps -- [ ] Dependencies up to date -- [ ] No known vulnerabilities (pip-audit clean) -- [ ] requirements.txt pinned versions -- [ ] Regular security updates scheduled - -## Security Testing - -### Automated Security Tests -```python -import pytest -from fastapi.testclient import TestClient - -def test_requires_api_key(client: TestClient): - """Test authentication required.""" - response = client.post("/api/v1/infer") - assert response.status_code == 401 - -def test_invalid_api_key_rejected(client: TestClient): - """Test invalid API key rejected.""" - response = client.post( - "/api/v1/infer", - headers={"X-API-Key": "invalid-key"} - ) - assert response.status_code == 403 - -def test_sql_injection_prevented(client: TestClient): - """Test SQL injection attempt rejected.""" - response = client.get( - "/api/v1/documents", - params={"id": "'; DROP TABLE documents; --"} - ) - # Should return validation error, not execute SQL - assert response.status_code in (400, 422) - -def test_path_traversal_prevented(client: TestClient): - """Test path traversal attempt rejected.""" - response = client.get("/api/v1/results/../../etc/passwd") - assert response.status_code == 400 - -def test_rate_limit_enforced(client: TestClient): - """Test rate limiting works.""" - responses = [ - client.post("/api/v1/infer", files={"file": b"test"}) - for _ in range(15) - ] - rate_limited = [r for r in responses if r.status_code == 429] - assert len(rate_limited) > 0 - -def test_large_file_rejected(client: TestClient): - """Test file size limit enforced.""" - large_content = b"x" * (11 * 1024 * 1024) # 11MB - response = client.post( - "/api/v1/infer", - files={"file": ("test.pdf", large_content)} - ) - assert response.status_code == 400 -``` - -## Pre-Deployment Security Checklist - -Before ANY production deployment: - -- [ ] **Secrets**: No hardcoded secrets, all in env vars -- [ ] **Input Validation**: All user inputs validated with Pydantic -- [ ] **SQL Injection**: All queries use parameterized queries -- [ ] **Path Traversal**: File paths validated and sanitized -- [ ] **Authentication**: API key or token validation -- [ ] **Authorization**: Role checks in place -- [ ] **Rate Limiting**: Enabled on all endpoints -- [ ] **HTTPS**: Enforced in production -- [ ] **CORS**: Properly configured (no wildcards) -- [ ] **Error Handling**: No sensitive data in errors -- [ ] **Logging**: No sensitive data logged -- [ ] **File Uploads**: Validated (size, type, magic bytes) -- [ ] **Temp Files**: Always cleaned up -- [ ] **Dependencies**: Up to date, no vulnerabilities - -## Resources - -- [OWASP Top 10](https://owasp.org/www-project-top-ten/) -- [FastAPI Security](https://fastapi.tiangolo.com/tutorial/security/) -- [Bandit (Python Security Linter)](https://bandit.readthedocs.io/) -- [pip-audit](https://pypi.org/project/pip-audit/) - ---- - -**Remember**: Security is not optional. One vulnerability can compromise sensitive invoice data. When in doubt, err on the side of caution. diff --git a/.claude/skills/strategic-compact/SKILL.md b/.claude/skills/strategic-compact/SKILL.md deleted file mode 100644 index 394a86b..0000000 --- a/.claude/skills/strategic-compact/SKILL.md +++ /dev/null @@ -1,63 +0,0 @@ ---- -name: strategic-compact -description: Suggests manual context compaction at logical intervals to preserve context through task phases rather than arbitrary auto-compaction. ---- - -# Strategic Compact Skill - -Suggests manual `/compact` at strategic points in your workflow rather than relying on arbitrary auto-compaction. - -## Why Strategic Compaction? - -Auto-compaction triggers at arbitrary points: -- Often mid-task, losing important context -- No awareness of logical task boundaries -- Can interrupt complex multi-step operations - -Strategic compaction at logical boundaries: -- **After exploration, before execution** - Compact research context, keep implementation plan -- **After completing a milestone** - Fresh start for next phase -- **Before major context shifts** - Clear exploration context before different task - -## How It Works - -The `suggest-compact.sh` script runs on PreToolUse (Edit/Write) and: - -1. **Tracks tool calls** - Counts tool invocations in session -2. **Threshold detection** - Suggests at configurable threshold (default: 50 calls) -3. **Periodic reminders** - Reminds every 25 calls after threshold - -## Hook Setup - -Add to your `~/.claude/settings.json`: - -```json -{ - "hooks": { - "PreToolUse": [{ - "matcher": "tool == \"Edit\" || tool == \"Write\"", - "hooks": [{ - "type": "command", - "command": "~/.claude/skills/strategic-compact/suggest-compact.sh" - }] - }] - } -} -``` - -## Configuration - -Environment variables: -- `COMPACT_THRESHOLD` - Tool calls before first suggestion (default: 50) - -## Best Practices - -1. **Compact after planning** - Once plan is finalized, compact to start fresh -2. **Compact after debugging** - Clear error-resolution context before continuing -3. **Don't compact mid-implementation** - Preserve context for related changes -4. **Read the suggestion** - The hook tells you *when*, you decide *if* - -## Related - -- [The Longform Guide](https://x.com/affaanmustafa/status/2014040193557471352) - Token optimization section -- Memory persistence hooks - For state that survives compaction diff --git a/.claude/skills/strategic-compact/suggest-compact.sh b/.claude/skills/strategic-compact/suggest-compact.sh deleted file mode 100644 index ea14920..0000000 --- a/.claude/skills/strategic-compact/suggest-compact.sh +++ /dev/null @@ -1,52 +0,0 @@ -#!/bin/bash -# Strategic Compact Suggester -# Runs on PreToolUse or periodically to suggest manual compaction at logical intervals -# -# Why manual over auto-compact: -# - Auto-compact happens at arbitrary points, often mid-task -# - Strategic compacting preserves context through logical phases -# - Compact after exploration, before execution -# - Compact after completing a milestone, before starting next -# -# Hook config (in ~/.claude/settings.json): -# { -# "hooks": { -# "PreToolUse": [{ -# "matcher": "Edit|Write", -# "hooks": [{ -# "type": "command", -# "command": "~/.claude/skills/strategic-compact/suggest-compact.sh" -# }] -# }] -# } -# } -# -# Criteria for suggesting compact: -# - Session has been running for extended period -# - Large number of tool calls made -# - Transitioning from research/exploration to implementation -# - Plan has been finalized - -# Track tool call count (increment in a temp file) -COUNTER_FILE="/tmp/claude-tool-count-$$" -THRESHOLD=${COMPACT_THRESHOLD:-50} - -# Initialize or increment counter -if [ -f "$COUNTER_FILE" ]; then - count=$(cat "$COUNTER_FILE") - count=$((count + 1)) - echo "$count" > "$COUNTER_FILE" -else - echo "1" > "$COUNTER_FILE" - count=1 -fi - -# Suggest compact after threshold tool calls -if [ "$count" -eq "$THRESHOLD" ]; then - echo "[StrategicCompact] $THRESHOLD tool calls reached - consider /compact if transitioning phases" >&2 -fi - -# Suggest at regular intervals after threshold -if [ "$count" -gt "$THRESHOLD" ] && [ $((count % 25)) -eq 0 ]; then - echo "[StrategicCompact] $count tool calls - good checkpoint for /compact if context is stale" >&2 -fi diff --git a/.claude/skills/tdd-workflow/SKILL.md b/.claude/skills/tdd-workflow/SKILL.md deleted file mode 100644 index c3ef042..0000000 --- a/.claude/skills/tdd-workflow/SKILL.md +++ /dev/null @@ -1,553 +0,0 @@ ---- -name: tdd-workflow -description: Use this skill when writing new features, fixing bugs, or refactoring code. Enforces test-driven development with 80%+ coverage including unit, integration, and E2E tests. ---- - -# Test-Driven Development Workflow - -TDD principles for Python/FastAPI development with pytest. - -## When to Activate - -- Writing new features or functionality -- Fixing bugs or issues -- Refactoring existing code -- Adding API endpoints -- Creating new field extractors or normalizers - -## Core Principles - -### 1. Tests BEFORE Code -ALWAYS write tests first, then implement code to make tests pass. - -### 2. Coverage Requirements -- Minimum 80% coverage (unit + integration + E2E) -- All edge cases covered -- Error scenarios tested -- Boundary conditions verified - -### 3. Test Types - -#### Unit Tests -- Individual functions and utilities -- Normalizers and validators -- Parsers and extractors -- Pure functions - -#### Integration Tests -- API endpoints -- Database operations -- OCR + YOLO pipeline -- Service interactions - -#### E2E Tests -- Complete inference pipeline -- PDF → Fields workflow -- API health and inference endpoints - -## TDD Workflow Steps - -### Step 1: Write User Journeys -``` -As a [role], I want to [action], so that [benefit] - -Example: -As an invoice processor, I want to extract Bankgiro from payment_line, -so that I can cross-validate OCR results. -``` - -### Step 2: Generate Test Cases -For each user journey, create comprehensive test cases: - -```python -import pytest - -class TestPaymentLineParser: - """Tests for payment_line parsing and field extraction.""" - - def test_parse_payment_line_extracts_bankgiro(self): - """Should extract Bankgiro from valid payment line.""" - # Test implementation - pass - - def test_parse_payment_line_handles_missing_checksum(self): - """Should handle payment lines without checksum.""" - pass - - def test_parse_payment_line_validates_checksum(self): - """Should validate checksum when present.""" - pass - - def test_parse_payment_line_returns_none_for_invalid(self): - """Should return None for invalid payment lines.""" - pass -``` - -### Step 3: Run Tests (They Should Fail) -```bash -pytest tests/test_ocr/test_machine_code_parser.py -v -# Tests should fail - we haven't implemented yet -``` - -### Step 4: Implement Code -Write minimal code to make tests pass: - -```python -def parse_payment_line(line: str) -> PaymentLineData | None: - """Parse Swedish payment line and extract fields.""" - # Implementation guided by tests - pass -``` - -### Step 5: Run Tests Again -```bash -pytest tests/test_ocr/test_machine_code_parser.py -v -# Tests should now pass -``` - -### Step 6: Refactor -Improve code quality while keeping tests green: -- Remove duplication -- Improve naming -- Optimize performance -- Enhance readability - -### Step 7: Verify Coverage -```bash -pytest --cov=src --cov-report=term-missing -# Verify 80%+ coverage achieved -``` - -## Testing Patterns - -### Unit Test Pattern (pytest) -```python -import pytest -from src.normalize.bankgiro_normalizer import normalize_bankgiro - -class TestBankgiroNormalizer: - """Tests for Bankgiro normalization.""" - - def test_normalize_removes_hyphens(self): - """Should remove hyphens from Bankgiro.""" - result = normalize_bankgiro("123-4567") - assert result == "1234567" - - def test_normalize_removes_spaces(self): - """Should remove spaces from Bankgiro.""" - result = normalize_bankgiro("123 4567") - assert result == "1234567" - - def test_normalize_validates_length(self): - """Should validate Bankgiro is 7-8 digits.""" - result = normalize_bankgiro("123456") # 6 digits - assert result is None - - def test_normalize_validates_checksum(self): - """Should validate Luhn checksum.""" - result = normalize_bankgiro("1234568") # Invalid checksum - assert result is None - - @pytest.mark.parametrize("input_value,expected", [ - ("123-4567", "1234567"), - ("1234567", "1234567"), - ("123 4567", "1234567"), - ("BG 123-4567", "1234567"), - ]) - def test_normalize_various_formats(self, input_value, expected): - """Should handle various input formats.""" - result = normalize_bankgiro(input_value) - assert result == expected -``` - -### API Integration Test Pattern -```python -import pytest -from fastapi.testclient import TestClient -from src.web.app import app - -@pytest.fixture -def client(): - return TestClient(app) - -class TestHealthEndpoint: - """Tests for /api/v1/health endpoint.""" - - def test_health_returns_200(self, client): - """Should return 200 OK.""" - response = client.get("/api/v1/health") - assert response.status_code == 200 - - def test_health_returns_status(self, client): - """Should return health status.""" - response = client.get("/api/v1/health") - data = response.json() - assert data["status"] == "healthy" - assert "model_loaded" in data - -class TestInferEndpoint: - """Tests for /api/v1/infer endpoint.""" - - def test_infer_requires_file(self, client): - """Should require file upload.""" - response = client.post("/api/v1/infer") - assert response.status_code == 422 - - def test_infer_rejects_non_pdf(self, client): - """Should reject non-PDF files.""" - response = client.post( - "/api/v1/infer", - files={"file": ("test.txt", b"not a pdf", "text/plain")} - ) - assert response.status_code == 400 - - def test_infer_returns_fields(self, client, sample_invoice_pdf): - """Should return extracted fields.""" - with open(sample_invoice_pdf, "rb") as f: - response = client.post( - "/api/v1/infer", - files={"file": ("invoice.pdf", f, "application/pdf")} - ) - assert response.status_code == 200 - data = response.json() - assert data["success"] is True - assert "fields" in data -``` - -### E2E Test Pattern -```python -import pytest -import httpx -from pathlib import Path - -@pytest.fixture(scope="module") -def running_server(): - """Ensure server is running for E2E tests.""" - # Server should be started before running E2E tests - base_url = "http://localhost:8000" - yield base_url - -class TestInferencePipeline: - """E2E tests for complete inference pipeline.""" - - def test_health_check(self, running_server): - """Should pass health check.""" - response = httpx.get(f"{running_server}/api/v1/health") - assert response.status_code == 200 - data = response.json() - assert data["status"] == "healthy" - assert data["model_loaded"] is True - - def test_pdf_inference_returns_fields(self, running_server): - """Should extract fields from PDF.""" - pdf_path = Path("tests/fixtures/sample_invoice.pdf") - with open(pdf_path, "rb") as f: - response = httpx.post( - f"{running_server}/api/v1/infer", - files={"file": ("invoice.pdf", f, "application/pdf")} - ) - - assert response.status_code == 200 - data = response.json() - assert data["success"] is True - assert "fields" in data - assert len(data["fields"]) > 0 - - def test_cross_validation_included(self, running_server): - """Should include cross-validation for invoices with payment_line.""" - pdf_path = Path("tests/fixtures/invoice_with_payment_line.pdf") - with open(pdf_path, "rb") as f: - response = httpx.post( - f"{running_server}/api/v1/infer", - files={"file": ("invoice.pdf", f, "application/pdf")} - ) - - data = response.json() - if data["fields"].get("payment_line"): - assert "cross_validation" in data -``` - -## Test File Organization - -``` -tests/ -├── conftest.py # Shared fixtures -├── fixtures/ # Test data files -│ ├── sample_invoice.pdf -│ └── invoice_with_payment_line.pdf -├── test_cli/ -│ └── test_infer.py -├── test_pdf/ -│ ├── test_extractor.py -│ └── test_renderer.py -├── test_ocr/ -│ ├── test_paddle_ocr.py -│ └── test_machine_code_parser.py -├── test_inference/ -│ ├── test_pipeline.py -│ ├── test_yolo_detector.py -│ └── test_field_extractor.py -├── test_normalize/ -│ ├── test_bankgiro_normalizer.py -│ ├── test_date_normalizer.py -│ └── test_amount_normalizer.py -├── test_web/ -│ ├── test_routes.py -│ └── test_services.py -└── e2e/ - └── test_inference_e2e.py -``` - -## Mocking External Services - -### Mock PaddleOCR -```python -import pytest -from unittest.mock import Mock, patch - -@pytest.fixture -def mock_paddle_ocr(): - """Mock PaddleOCR for unit tests.""" - with patch("src.ocr.paddle_ocr.PaddleOCR") as mock: - instance = Mock() - instance.ocr.return_value = [ - [ - [[[0, 0], [100, 0], [100, 20], [0, 20]], ("Invoice Number", 0.95)], - [[[0, 30], [100, 30], [100, 50], [0, 50]], ("INV-2024-001", 0.98)] - ] - ] - mock.return_value = instance - yield instance -``` - -### Mock YOLO Model -```python -@pytest.fixture -def mock_yolo_model(): - """Mock YOLO model for unit tests.""" - with patch("src.inference.yolo_detector.YOLO") as mock: - instance = Mock() - # Mock detection results - instance.return_value = Mock( - boxes=Mock( - xyxy=[[10, 20, 100, 50]], - conf=[0.95], - cls=[0] # invoice_number class - ) - ) - mock.return_value = instance - yield instance -``` - -### Mock Database -```python -@pytest.fixture -def mock_db_connection(): - """Mock database connection for unit tests.""" - with patch("src.data.db.get_db_connection") as mock: - conn = Mock() - cursor = Mock() - cursor.fetchall.return_value = [ - ("doc-123", "processed", {"invoice_number": "INV-001"}) - ] - cursor.fetchone.return_value = ("doc-123",) - conn.cursor.return_value.__enter__ = Mock(return_value=cursor) - conn.cursor.return_value.__exit__ = Mock(return_value=False) - mock.return_value.__enter__ = Mock(return_value=conn) - mock.return_value.__exit__ = Mock(return_value=False) - yield conn -``` - -## Test Coverage Verification - -### Run Coverage Report -```bash -# Run with coverage -pytest --cov=src --cov-report=term-missing - -# Generate HTML report -pytest --cov=src --cov-report=html -# Open htmlcov/index.html in browser -``` - -### Coverage Configuration (pyproject.toml) -```toml -[tool.coverage.run] -source = ["src"] -omit = ["*/__init__.py", "*/test_*.py"] - -[tool.coverage.report] -fail_under = 80 -show_missing = true -exclude_lines = [ - "pragma: no cover", - "if TYPE_CHECKING:", - "raise NotImplementedError", -] -``` - -## Common Testing Mistakes to Avoid - -### WRONG: Testing Implementation Details -```python -# Don't test internal state -def test_parser_internal_state(): - parser = PaymentLineParser() - parser._parse("...") - assert parser._groups == [...] # Internal state -``` - -### CORRECT: Test Public Interface -```python -# Test what users see -def test_parser_extracts_bankgiro(): - result = parse_payment_line("...") - assert result.bankgiro == "1234567" -``` - -### WRONG: No Test Isolation -```python -# Tests depend on each other -class TestDocuments: - def test_creates_document(self): - create_document(...) # Creates in DB - - def test_updates_document(self): - update_document(...) # Depends on previous test -``` - -### CORRECT: Independent Tests -```python -# Each test sets up its own data -class TestDocuments: - def test_creates_document(self, mock_db): - result = create_document(...) - assert result.id is not None - - def test_updates_document(self, mock_db): - # Create own test data - doc = create_document(...) - result = update_document(doc.id, ...) - assert result.status == "updated" -``` - -### WRONG: Testing Too Much -```python -# One test doing everything -def test_full_invoice_processing(): - # Load PDF - # Extract images - # Run YOLO - # Run OCR - # Normalize fields - # Save to DB - # Return response -``` - -### CORRECT: Focused Tests -```python -def test_yolo_detects_invoice_number(): - """Test only YOLO detection.""" - result = detector.detect(image) - assert any(d.label == "invoice_number" for d in result) - -def test_ocr_extracts_text(): - """Test only OCR extraction.""" - result = ocr.extract(image, bbox) - assert result == "INV-2024-001" - -def test_normalizer_formats_date(): - """Test only date normalization.""" - result = normalize_date("2024-01-15") - assert result == "2024-01-15" -``` - -## Fixtures (conftest.py) - -```python -import pytest -from pathlib import Path -from fastapi.testclient import TestClient - -@pytest.fixture -def sample_invoice_pdf(tmp_path: Path) -> Path: - """Create sample invoice PDF for testing.""" - pdf_path = tmp_path / "invoice.pdf" - # Copy from fixtures or create minimal PDF - src = Path("tests/fixtures/sample_invoice.pdf") - if src.exists(): - pdf_path.write_bytes(src.read_bytes()) - return pdf_path - -@pytest.fixture -def client(): - """FastAPI test client.""" - from src.web.app import app - return TestClient(app) - -@pytest.fixture -def sample_payment_line() -> str: - """Sample Swedish payment line for testing.""" - return "1234567#0000000012345#230115#00012345678901234567#1" -``` - -## Continuous Testing - -### Watch Mode During Development -```bash -# Using pytest-watch -ptw -- tests/test_ocr/ -# Tests run automatically on file changes -``` - -### Pre-Commit Hook -```bash -# .pre-commit-config.yaml -repos: - - repo: local - hooks: - - id: pytest - name: pytest - entry: pytest --tb=short -q - language: system - pass_filenames: false - always_run: true -``` - -### CI/CD Integration (GitHub Actions) -```yaml -- name: Run Tests - run: | - pytest --cov=src --cov-report=xml - -- name: Upload Coverage - uses: codecov/codecov-action@v3 - with: - file: coverage.xml -``` - -## Best Practices - -1. **Write Tests First** - Always TDD -2. **One Assert Per Test** - Focus on single behavior -3. **Descriptive Test Names** - `test___` -4. **Arrange-Act-Assert** - Clear test structure -5. **Mock External Dependencies** - Isolate unit tests -6. **Test Edge Cases** - None, empty, invalid, boundary -7. **Test Error Paths** - Not just happy paths -8. **Keep Tests Fast** - Unit tests < 50ms each -9. **Clean Up After Tests** - Use fixtures with cleanup -10. **Review Coverage Reports** - Identify gaps - -## Success Metrics - -- 80%+ code coverage achieved -- All tests passing (green) -- No skipped or disabled tests -- Fast test execution (< 60s for unit tests) -- E2E tests cover critical inference flow -- Tests catch bugs before production - ---- - -**Remember**: Tests are not optional. They are the safety net that enables confident refactoring, rapid development, and production reliability. diff --git a/.claude/skills/ui-prompt-generator/SKILL.md b/.claude/skills/ui-prompt-generator/SKILL.md deleted file mode 100644 index 274e347..0000000 --- a/.claude/skills/ui-prompt-generator/SKILL.md +++ /dev/null @@ -1,139 +0,0 @@ ---- -name: ui-prompt-generator -description: 读取 Product-Spec.md 中的功能需求和 UI 布局,生成可用于 AI 绘图工具的原型图提示词。与 product-spec-builder 配套使用,帮助用户快速将需求文档转化为视觉原型。 ---- - -[角色] - 你是一位 UI/UX 设计专家,擅长将产品需求转化为精准的视觉描述。 - - 你能够从结构化的产品文档中提取关键信息,并转化为 AI 绘图工具可以理解的提示词,帮助用户快速生成产品原型图。 - -[任务] - 读取 Product-Spec.md,提取功能需求和 UI 布局信息,补充必要的视觉参数,生成可直接用于文生图工具的原型图提示词。 - - 最终输出按页面拆分的提示词,用户可以直接复制到 AI 绘图工具生成原型图。 - -[技能] - - **文档解析**:从 Product-Spec.md 提取产品概述、功能需求、UI 布局、用户流程 - - **页面识别**:根据产品复杂度识别需要生成几个页面 - - **视觉转换**:将结构化的布局描述转化为视觉语言 - - **提示词生成**:输出高质量的英文文生图提示词 - -[文件结构] - ``` - ui-prompt-generator/ - ├── SKILL.md # 主 Skill 定义(本文件) - └── templates/ - └── ui-prompt-template.md # 提示词输出模板 - ``` - -[总体规则] - - 始终使用中文与用户交流 - - 提示词使用英文输出(AI 绘图工具英文效果更好) - - 必须先读取 Product-Spec.md,不存在则提示用户先完成需求收集 - - 不重复追问 Product-Spec.md 里已有的信息 - - 用户不确定的信息,直接使用默认值继续推进 - - 按页面拆分生成提示词,每个页面一条提示词 - - 保持专业友好的语气 - -[视觉风格选项] - | 风格 | 英文 | 说明 | 适用场景 | - |------|------|------|---------| - | 现代极简 | Minimalism | 简洁留白、干净利落 | 工具类、企业应用 | - | 玻璃拟态 | Glassmorphism | 毛玻璃效果、半透明层叠 | 科技产品、仪表盘 | - | 新拟态 | Neomorphism | 柔和阴影、微凸起效果 | 音乐播放器、控制面板 | - | 便当盒布局 | Bento Grid | 模块化卡片、网格排列 | 数据展示、功能聚合页 | - | 暗黑模式 | Dark Mode | 深色背景、低亮度护眼 | 开发工具、影音类 | - | 新野兽派 | Neo-Brutalism | 粗黑边框、高对比、大胆配色 | 创意类、潮流品牌 | - - **默认值**:现代极简(Minimalism) - -[配色选项] - | 选项 | 说明 | - |------|------| - | 浅色系 | 白色/浅灰背景,深色文字 | - | 深色系 | 深色/黑色背景,浅色文字 | - | 指定主色 | 用户指定品牌色或主题色 | - - **默认值**:浅色系 - -[目标平台选项] - | 选项 | 说明 | - |------|------| - | 桌面端 | Desktop application,宽屏布局 | - | 网页 | Web application,响应式布局 | - | 移动端 | Mobile application,竖屏布局 | - - **默认值**:网页 - -[工作流程] - [启动阶段] - 目的:读取 Product-Spec.md,提取信息,补充缺失的视觉参数 - - 第一步:检测文件 - 检测项目目录中是否存在 Product-Spec.md - 不存在 → 提示:「未找到 Product-Spec.md,请先使用 /prd 完成需求收集。」,终止流程 - 存在 → 继续 - - 第二步:解析 Product-Spec.md - 读取 Product-Spec.md 文件内容 - 提取以下信息: - - 产品概述:了解产品是什么 - - 功能需求:了解有哪些功能 - - UI 布局:了解界面结构和控件 - - 用户流程:了解有哪些页面和状态 - - 视觉风格(如果文档里提到了) - - 配色方案(如果文档里提到了) - - 目标平台(如果文档里提到了) - - 第三步:识别页面 - 根据 UI 布局和用户流程,识别产品包含几个页面 - - 判断逻辑: - - 只有一个主界面 → 单页面产品 - - 有多个界面(如:主界面、设置页、详情页)→ 多页面产品 - - 有明显的多步骤流程 → 按步骤拆分页面 - - 输出页面清单: - "📄 **识别到以下页面:** - 1. [页面名称]:[简要描述] - 2. [页面名称]:[简要描述] - ..." - - 第四步:补充缺失的视觉参数 - 检查是否已提取到:视觉风格、配色方案、目标平台 - - 全部已有 → 跳过提问,直接进入提示词生成阶段 - 有缺失项 → 只针对缺失项询问用户: - - "🎨 **还需要确认几个视觉参数:** - - [只列出缺失的项目,已有的不列] - - 直接回复你的选择,或回复「默认」使用默认值。" - - 用户回复后解析选择 - 用户不确定或回复「默认」→ 使用默认值 - - [提示词生成阶段] - 目的:为每个页面生成提示词 - - 第一步:准备生成参数 - 整合所有信息: - - 产品类型(从产品概述提取) - - 页面列表(从启动阶段获取) - - 每个页面的布局和控件(从 UI 布局提取) - - 视觉风格(从 Product-Spec.md 提取或用户选择) - - 配色方案(从 Product-Spec.md 提取或用户选择) - - 目标平台(从 Product-Spec.md 提取或用户选择) - - 第二步:按页面生成提示词 - 加载 templates/ui-prompt-template.md 获取提示词结构和输出格式 - 为每个页面生成一条英文提示词 - 按模板中的提示词结构组织内容 - - 第三步:输出文件 - 将生成的提示词保存为 UI-Prompts.md - -[初始化] - 执行 [启动阶段] \ No newline at end of file diff --git a/.claude/skills/ui-prompt-generator/templates/ui-prompt-template.md b/.claude/skills/ui-prompt-generator/templates/ui-prompt-template.md deleted file mode 100644 index e79c537..0000000 --- a/.claude/skills/ui-prompt-generator/templates/ui-prompt-template.md +++ /dev/null @@ -1,154 +0,0 @@ ---- -name: ui-prompt-template -description: UI 原型图提示词输出模板。当需要生成文生图提示词时,按照此模板的结构和格式填充内容,输出为 UI-Prompts.md 文件。 ---- - -# UI 原型图提示词模板 - -本模板用于生成可直接用于 AI 绘图工具的原型图提示词。生成时按照此结构填充内容。 - ---- - -## 文件命名 - -`UI-Prompts.md` - ---- - -## 提示词结构 - -每条提示词按以下结构组织: - -``` -[主体] + [布局] + [控件] + [风格] + [质量词] -``` - -### [主体] -产品类型 + 界面类型 + 页面名称 - -示例: -- `A modern web application UI for a storyboard generator tool, main interface` -- `A mobile app screen for a task management application, settings page` - -### [布局] -整体结构 + 比例 + 区域划分 - -示例: -- `split layout with left panel (40%) and right content area (60%)` -- `single column layout with top navigation bar and main content below` -- `grid layout with 2x2 card arrangement` - -### [控件] -各区域的具体控件,从上到下、从左到右描述 - -示例: -- `left panel contains: project name input at top, large text area for content, dropdown menu for style selection, primary action button at bottom` -- `right panel shows: 3x3 grid of image cards with frame numbers and captions, action buttons below` - -### [风格] -视觉风格 + 配色 + 细节特征 - -| 风格 | 英文描述 | -|------|---------| -| 现代极简 | minimalist design, clean layout, ample white space, subtle shadows | -| 玻璃拟态 | glassmorphism style, frosted glass effect, translucent panels, blur background | -| 新拟态 | neumorphism design, soft shadows, subtle highlights, extruded elements | -| 便当盒布局 | bento grid layout, modular cards, organized sections, clean borders | -| 暗黑模式 | dark mode UI, dark background, light text, subtle glow effects | -| 新野兽派 | neo-brutalist design, bold black borders, high contrast, raw aesthetic | - -配色描述: -- 浅色系:`light color scheme, white background, dark text, [accent color] accent` -- 深色系:`dark color scheme, dark gray background, light text, [accent color] accent` - -### [质量词] -确保生成质量的关键词,放在提示词末尾 - -``` -UI/UX design, high fidelity mockup, 4K resolution, professional, Figma style, dribbble, behance -``` - ---- - -## 输出格式 - -```markdown -# [产品名称] 原型图提示词 - -> 视觉风格:[风格名称] -> 配色方案:[配色名称] -> 目标平台:[平台名称] - ---- - -## 页面 1:[页面名称] - -**页面说明**:[一句话描述这个页面是什么] - -**提示词**: -``` -[完整的英文提示词] -``` - ---- - -## 页面 2:[页面名称] - -**页面说明**:[一句话描述] - -**提示词**: -``` -[完整的英文提示词] -``` -``` - ---- - -## 完整示例 - -以下是「剧本分镜生成器」的原型图提示词示例,供参考: - -```markdown -# 剧本分镜生成器 原型图提示词 - -> 视觉风格:现代极简(Minimalism) -> 配色方案:浅色系 -> 目标平台:网页(Web) - ---- - -## 页面 1:主界面 - -**页面说明**:用户输入剧本、设置角色和场景、生成分镜图的主要工作界面 - -**提示词**: -``` -A modern web application UI for a storyboard generator tool, main interface, split layout with left input panel (40% width) and right output area (60% width), left panel contains: project name input field at top, large multiline text area for script input with placeholder text, character cards section with image thumbnails and text fields and add button, scene cards section below, style dropdown menu, prominent generate button at bottom, right panel shows: 3x3 grid of storyboard image cards with frame numbers and short descriptions below each image, download all button and continue generating button below the grid, page navigation at bottom, minimalist design, clean layout, white background, light gray borders, blue accent color for primary actions, subtle shadows, rounded corners, UI/UX design, high fidelity mockup, 4K resolution, professional, Figma style -``` - ---- - -## 页面 2:空状态界面 - -**页面说明**:用户首次打开、尚未输入内容时的引导界面 - -**提示词**: -``` -A modern web application UI for a storyboard generator tool, empty state screen, split layout with left panel (40%) and right panel (60%), left panel shows: empty input fields with placeholder text and helper icons, right panel displays: large empty state illustration in the center, welcome message and getting started tips below, minimalist design, clean layout, white background, soft gray placeholder elements, blue accent color, friendly and inviting atmosphere, UI/UX design, high fidelity mockup, 4K resolution, professional, Figma style -``` -``` - ---- - -## 写作要点 - -1. **提示词语言**:始终使用英文,AI 绘图工具对英文理解更好 -2. **结构完整**:确保包含主体、布局、控件、风格、质量词五个部分 -3. **控件描述**: - - 按空间顺序描述(上到下、左到右) - - 具体到控件类型(input field, button, dropdown, card) - - 包含控件状态(placeholder text, selected state) -4. **布局比例**:写明具体比例(40%/60%),不要只说「左右布局」 -5. **风格一致**:同一产品的多个页面使用相同的风格描述 -6. **质量词**:始终在末尾加上质量词确保生成效果 -7. **页面说明**:用中文写一句话说明,帮助理解这个页面是什么 diff --git a/.claude/skills/verification-loop/SKILL.md b/.claude/skills/verification-loop/SKILL.md deleted file mode 100644 index 0c2f000..0000000 --- a/.claude/skills/verification-loop/SKILL.md +++ /dev/null @@ -1,242 +0,0 @@ -# Verification Loop Skill - -Comprehensive verification system for Python/FastAPI development. - -## When to Use - -Invoke this skill: -- After completing a feature or significant code change -- Before creating a PR -- When you want to ensure quality gates pass -- After refactoring -- Before deployment - -## Verification Phases - -### Phase 1: Type Check -```bash -# Run mypy type checker -mypy src/ --ignore-missing-imports 2>&1 | head -30 -``` - -Report all type errors. Fix critical ones before continuing. - -### Phase 2: Lint Check -```bash -# Run ruff linter -ruff check src/ 2>&1 | head -30 - -# Auto-fix if desired -ruff check src/ --fix -``` - -Check for: -- Unused imports -- Code style violations -- Common Python anti-patterns - -### Phase 3: Test Suite -```bash -# Run tests with coverage -pytest --cov=src --cov-report=term-missing -q 2>&1 | tail -50 - -# Run specific test file -pytest tests/test_ocr/test_machine_code_parser.py -v - -# Run with short traceback -pytest -x --tb=short -``` - -Report: -- Total tests: X -- Passed: X -- Failed: X -- Coverage: X% -- Target: 80% minimum - -### Phase 4: Security Scan -```bash -# Check for hardcoded secrets -grep -rn "password\s*=" --include="*.py" src/ 2>/dev/null | grep -v "db_password:" | head -10 -grep -rn "api_key\s*=" --include="*.py" src/ 2>/dev/null | head -10 -grep -rn "sk-" --include="*.py" src/ 2>/dev/null | head -10 - -# Check for print statements (should use logging) -grep -rn "print(" --include="*.py" src/ 2>/dev/null | head -10 - -# Check for bare except -grep -rn "except:" --include="*.py" src/ 2>/dev/null | head -10 - -# Check for SQL injection risks (f-strings in execute) -grep -rn 'execute(f"' --include="*.py" src/ 2>/dev/null | head -10 -grep -rn "execute(f'" --include="*.py" src/ 2>/dev/null | head -10 -``` - -### Phase 5: Import Check -```bash -# Verify all imports work -python -c "from src.web.app import app; print('Web app OK')" -python -c "from src.inference.pipeline import InferencePipeline; print('Pipeline OK')" -python -c "from src.ocr.machine_code_parser import parse_payment_line; print('Parser OK')" -``` - -### Phase 6: Diff Review -```bash -# Show what changed -git diff --stat -git diff HEAD --name-only - -# Show staged changes -git diff --staged --stat -``` - -Review each changed file for: -- Unintended changes -- Missing error handling -- Potential edge cases -- Missing type hints -- Mutable default arguments - -### Phase 7: API Smoke Test (if server running) -```bash -# Health check -curl -s http://localhost:8000/api/v1/health | python -m json.tool - -# Verify response format -curl -s http://localhost:8000/api/v1/health | grep -q "healthy" && echo "Health: OK" || echo "Health: FAIL" -``` - -## Output Format - -After running all phases, produce a verification report: - -``` -VERIFICATION REPORT -================== - -Types: [PASS/FAIL] (X errors) -Lint: [PASS/FAIL] (X warnings) -Tests: [PASS/FAIL] (X/Y passed, Z% coverage) -Security: [PASS/FAIL] (X issues) -Imports: [PASS/FAIL] -Diff: [X files changed] - -Overall: [READY/NOT READY] for PR - -Issues to Fix: -1. ... -2. ... -``` - -## Quick Commands - -```bash -# Full verification (WSL) -wsl bash -c "source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && mypy src/ --ignore-missing-imports && ruff check src/ && pytest -x --tb=short" - -# Type check only -wsl bash -c "source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && mypy src/ --ignore-missing-imports" - -# Tests only -wsl bash -c "source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && pytest --cov=src -q" -``` - -## Verification Checklist - -### Before Commit -- [ ] mypy passes (no type errors) -- [ ] ruff check passes (no lint errors) -- [ ] All tests pass -- [ ] No print() statements in production code -- [ ] No hardcoded secrets -- [ ] No bare `except:` clauses -- [ ] No SQL injection risks (f-strings in queries) -- [ ] Coverage >= 80% for changed code - -### Before PR -- [ ] All above checks pass -- [ ] git diff reviewed for unintended changes -- [ ] New code has tests -- [ ] Type hints on all public functions -- [ ] Docstrings on public APIs -- [ ] No TODO/FIXME for critical items - -### Before Deployment -- [ ] All above checks pass -- [ ] E2E tests pass -- [ ] Health check returns healthy -- [ ] Model loaded successfully -- [ ] No server errors in logs - -## Common Issues and Fixes - -### Type Error: Missing return type -```python -# Before -def process(data): - return result - -# After -def process(data: dict) -> InferenceResult: - return result -``` - -### Lint Error: Unused import -```python -# Remove unused imports or add to __all__ -``` - -### Security: print() in production -```python -# Before -print(f"Processing {doc_id}") - -# After -logger.info(f"Processing {doc_id}") -``` - -### Security: Bare except -```python -# Before -except: - pass - -# After -except Exception as e: - logger.error(f"Error: {e}") - raise -``` - -### Security: SQL injection -```python -# Before (DANGEROUS) -cur.execute(f"SELECT * FROM docs WHERE id = '{user_input}'") - -# After (SAFE) -cur.execute("SELECT * FROM docs WHERE id = %s", (user_input,)) -``` - -## Continuous Mode - -For long sessions, run verification after major changes: - -```markdown -Checkpoints: -- After completing each function -- After finishing a module -- Before moving to next task -- Every 15-20 minutes of coding - -Run: /verify -``` - -## Integration with Other Skills - -| Skill | Purpose | -|-------|---------| -| code-review | Detailed code analysis | -| security-review | Deep security audit | -| tdd-workflow | Test coverage | -| build-fix | Fix errors incrementally | - -This skill provides quick, comprehensive verification. Use specialized skills for deeper analysis. diff --git a/.opencode/skills/coding-standards/SKILL.md b/.opencode/skills/coding-standards/SKILL.md index 4bb9b71..27bddd2 100644 --- a/.opencode/skills/coding-standards/SKILL.md +++ b/.opencode/skills/coding-standards/SKILL.md @@ -396,7 +396,7 @@ def extract_invoice_fields( ) -> InferenceResult: """Extract structured fields from Swedish invoice PDF. - Uses YOLOv11 for field detection and PaddleOCR for text extraction. + Uses YOLO26 for field detection and PaddleOCR for text extraction. Applies field-specific normalization and validation. Args: diff --git a/AGENTS.md b/AGENTS.md index 214566f..5226e44 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -1,12 +1,12 @@ # Invoice Master POC v2 -Swedish Invoice Field Extraction System - YOLOv11 + PaddleOCR 从瑞典 PDF 发票中提取结构化数据。 +Swedish Invoice Field Extraction System - YOLO26 + PaddleOCR 从瑞典 PDF 发票中提取结构化数据。 ## Tech Stack | Component | Technology | |-----------|------------| -| Object Detection | YOLOv11 (Ultralytics) | +| Object Detection | YOLO26 (Ultralytics >= 8.4.0) | | OCR Engine | PaddleOCR v5 (PP-OCRv5) | | PDF Processing | PyMuPDF (fitz) | | Database | PostgreSQL + psycopg2 | @@ -18,7 +18,7 @@ Swedish Invoice Field Extraction System - YOLOv11 + PaddleOCR 从瑞典 PDF 发 **Prefix ALL commands with:** ```bash -wsl bash -c "source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && " +wsl bash -c "source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-sm120 && " ``` **NEVER run Python commands directly in Windows PowerShell/CMD.** diff --git a/ARCHITECTURE_REVIEW.md b/ARCHITECTURE_REVIEW.md index 02864eb..8b40b5a 100644 --- a/ARCHITECTURE_REVIEW.md +++ b/ARCHITECTURE_REVIEW.md @@ -64,7 +64,7 @@ | **前端** | React + Vite + TypeScript + TailwindCSS | ✅ 现代栈 | | **API 框架** | FastAPI | ✅ 高性能,类型安全 | | **数据库** | PostgreSQL + SQLModel | ✅ 类型安全 ORM | -| **目标检测** | YOLOv11 (Ultralytics) | ✅ 业界标准 | +| **目标检测** | YOLO26 (Ultralytics >= 8.4.0) | ✅ 业界标准 | | **OCR** | PaddleOCR v5 | ✅ 支持瑞典语 | | **部署** | Docker + Azure/AWS | ✅ 云原生 | diff --git a/CODE_REVIEW_REPORT.md b/CODE_REVIEW_REPORT.md index e64355a..733bd4a 100644 --- a/CODE_REVIEW_REPORT.md +++ b/CODE_REVIEW_REPORT.md @@ -96,7 +96,7 @@ invoice-master-poc-v2/ | **前端** | React 18 + TypeScript + Vite + TailwindCSS | 现代栈,类型安全 | | **API框架** | FastAPI + Uvicorn | 高性能,异步支持 | | **数据库** | PostgreSQL + SQLModel | 类型安全ORM | -| **目标检测** | YOLOv11 (Ultralytics) | 业界标准 | +| **目标检测** | YOLO26 (Ultralytics >= 8.4.0) | 业界标准 | | **OCR** | PaddleOCR v5 | 支持瑞典语 | | **部署** | Docker + Azure/AWS | 云原生 | diff --git a/COMMERCIALIZATION_ANALYSIS_REPORT.md b/COMMERCIALIZATION_ANALYSIS_REPORT.md index 2b954b5..2336dd7 100644 --- a/COMMERCIALIZATION_ANALYSIS_REPORT.md +++ b/COMMERCIALIZATION_ANALYSIS_REPORT.md @@ -26,7 +26,7 @@ ### 项目现状 -Invoice Master是一个基于YOLOv11 + PaddleOCR的瑞典发票字段自动提取系统,具备以下核心能力: +Invoice Master是一个基于YOLO26 + PaddleOCR的瑞典发票字段自动提取系统,具备以下核心能力: | 指标 | 数值 | 评估 | |------|------|------| diff --git a/INFERENCE_ANALYSIS_REPORT.md b/INFERENCE_ANALYSIS_REPORT.md new file mode 100644 index 0000000..54abe0b --- /dev/null +++ b/INFERENCE_ANALYSIS_REPORT.md @@ -0,0 +1,314 @@ +# Inference Analysis Report + +Date: 2026-02-11 +Sample: 39 PDFs (diverse sizes from 1783 total), invoice-sm120 environment + +## Executive Summary + +| Metric | Value | +|--------|-------| +| Total PDFs tested | 39 | +| Successful responses | 35 (89.7%) | +| Timeouts (>120s) | 4 (10.3%) | +| Pure fallback (all fields conf=0.500) | 15/35 (42.9%) | +| Full extraction (all expected fields) | 6/35 (17.1%) | +| supplier_org_number extraction rate | 0% | +| InvoiceDate extraction rate | 31.4% | +| OCR extraction rate | 31.4% | + +**Root Cause**: A critical DPI mismatch bug causes 43% of documents to lose all YOLO-detected field data, falling back to inaccurate regex patterns. + +--- + +## Problem #1 (CRITICAL): DPI Mismatch - Field Extraction Failures + +### Symptom +- 15/35 documents (43%) have ALL extracted fields at confidence=0.500 (fallback) +- YOLO detects fields correctly (6+ detections at conf 0.8-0.97) but text extraction returns nothing +- Examples: `4f822b0d` has 6 YOLO detections but only 1 field extracted via fallback + +### Root Cause +**DPI not passed from pipeline to FieldExtractor** causing 2x coordinate scaling error. + +``` +pipeline.py:237 -> self.extractor = FieldExtractor(ocr_lang=ocr_lang, use_gpu=use_gpu) + ^^^ DPI NOT PASSED! Defaults to 300 +``` + +The chain: +1. `shared/config.py:22` defines `DEFAULT_DPI = 150` +2. `InferencePipeline.__init__()` receives `dpi=150` from `ModelConfig` +3. PDF rendered at **150 DPI** -> YOLO detections in 150 DPI pixel coordinates +4. `FieldExtractor` defaults to `dpi=300` (never receives the actual 150) +5. Coordinate conversion: `scale = 72 / self.dpi` = `72/300 = 0.24` instead of `72/150 = 0.48` +6. Bounding boxes are **halved** in PDF point space -> no tokens match -> empty extraction +7. Fallback regex triggers with conf=0.500 + +### Fix +**File**: `packages/backend/backend/pipeline/pipeline.py`, line 237 + +```python +# BEFORE (broken): +self.extractor = FieldExtractor(ocr_lang=ocr_lang, use_gpu=use_gpu) + +# AFTER (fixed): +self.extractor = FieldExtractor(ocr_lang=ocr_lang, use_gpu=use_gpu, dpi=dpi) +``` + +### Impact +This single-line fix will recover ~43% of documents from degraded fallback to proper YOLO+OCR extraction. + +--- + +## Problem #2 (HIGH): Fallback Amount Extraction Grabs Wrong Values + +### Symptom +- 3 documents extracted Amount=1.00 when actual amounts are 7500.00, etc. +- Fallback regex matches table column header "Summa" followed by row quantity "1,00" instead of total + +### Example +Document `2b7e4103` (Astra Football Club): +- Actual amount: **7 500,00 SEK** +- Extracted: **1.00** (from "Summa 1" where "1" is the article number in the next row) + +### Root Cause +The fallback Amount regex in `pipeline.py:676`: +```python +r'(?:att\s*betala|summa|total|belopp)\s*[:.]?\s*([\d\s,\.]+)\s*(?:SEK|kr)?' +``` +matches "Summa" (column header) followed by "1" (first data in next row), because PaddleOCR produces tokens in position order. The greedy `[\d\s,\.]` captures "1" and stops at "Medlemsavgift". + +### Fix +**File**: `packages/backend/backend/pipeline/pipeline.py`, lines 674-688 + +1. Require minimum amount value in fallback (e.g., > 10.00) +2. Require the matched amount to have a decimal separator (`,` or `.`) to avoid matching integers +3. Prefer "ATT BETALA" over "Summa" as the keyword (less ambiguous) + +```python +'Amount': [ + r'(?:att\s+betala)\s*[:.]?\s*([\d\s]+[,\.]\d{2})\s*(?:SEK|kr)?', + r'([\d\s]+[,\.]\d{2})\s*(?:SEK|kr)\s*$', + r'(?:summa|total|belopp)\s*[:.]?\s*([\d\s]+[,\.]\d{2})\s*(?:SEK|kr)?', +], +``` + +--- + +## Problem #3 (HIGH): Fallback Bankgiro Regex False Positives + +### Symptom +- Document `2b7e4103` extracts Bankgiro=2546-1610 but the actual document has NO Bankgiro +- The document has Plusgiro=2131575-9 and Org.nr=802546-1610 + +### Root Cause +Fallback Bankgiro regex in `pipeline.py:681`: +```python +r'(\d{4}[-\s]\d{4})\s*(?=\s|$)' +``` +matches the LAST 8 digits of org number "802546-1610" as "2546-1610". + +### Fix +**File**: `packages/backend/backend/pipeline/pipeline.py`, line 681 + +Add negative lookbehind to avoid matching within longer numbers: +```python +'Bankgiro': [ + r'(?:bankgiro|bg)\s*[:.]?\s*(\d{3,4}[-\s]?\d{4})', + r'(? 10: + continue + valid_sequences.append(seq) +``` + +--- + +## Problem #6 (MEDIUM): InvoiceNumber vs OCR Mismatch + +### Symptom +- 5 documents show InvoiceNumber different from OCR number +- Example: `87f470da` InvoiceNumber=852460234111905 vs OCR=524602341119055 +- Example: `8b0674be` InvoiceNumber=508021404131 vs OCR=50802140413 + +### Root Cause +These are legitimate: InvoiceNumber and OCR are detected from DIFFERENT YOLO bounding boxes (different regions of the invoice). The InvoiceNumber normalizer picks a shorter sequence from the invoice_number bbox, while the OCR normalizer extracts from the ocr_number bbox. Cross-validation from payment_line should reconcile these but cross-validation isn't running (0 documents show cross_validation results). + +### Diagnosis Needed +Check why cross-validation / payment_line parsing isn't populating `result.cross_validation` even when payment_line is extracted. + +--- + +## Problem #7 (MEDIUM): supplier_org_number 0% Extraction Rate + +### Symptom +- 0/35 documents extract supplier_org_number +- YOLO detects supplier_org_number in many documents (visible in detection classes) +- When extracted, the field appears as `supplier_organisation_number` (different name) + +### Root Cause +This is actually a reporting issue. The API returns the field as `supplier_organisation_number` (full spelling) from `CLASS_TO_FIELD` mapping, but the analysis expected `supplier_org_number`. Looking at the actual data, 8/35 documents DO have `supplier_organisation_number` extracted. + +However, the underlying issue is that even when YOLO detects `supplier_org_number`, the DPI bug prevents text extraction for text PDFs. + +### Fix +Already addressed by Problem #1 (DPI fix). Additionally, ensure consistent field naming in API documentation. + +--- + +## Problem #8 (LOW): Timeout Failures (4/39 documents) + +### Symptom +- 4 PDFs timed out at 120 seconds +- File sizes: 89KB, 169KB, 239KB, 970KB (not correlated with size) + +### Root Cause +Likely multi-page PDFs or PDFs with complex layouts requiring extensive OCR. The 120s timeout in the test script may be too short for multi-page documents + full-page OCR fallback. + +### Fix +1. Increase API timeout for multi-page PDFs +2. Add page limit or early termination for very large documents +3. Log page count in response to correlate with processing time + +--- + +## Problem #9 (LOW): Non-Invoice Documents in Dataset + +### Symptom +- `dccf6655`: 0 detections, 0 fields - this is a screenshot of UI buttons, NOT an invoice + +### Fix +Add document classification as a pre-processing step to reject non-invoice documents before running the expensive YOLO + OCR pipeline. + +--- + +## Problem #10 (LOW): InvoiceDueDate Before InvoiceDate + +### Symptom +- Document `11de4d07`: InvoiceDate=2026-01-16, InvoiceDueDate=2025-12-01 +- Due date is BEFORE invoice date, which is illogical + +### Root Cause +Either the date normalizer swapped the values, or the YOLO model detected the wrong region for one of the dates. The DPI bug (Problem #1) may also affect date extraction from the correct regions. + +### Fix +Add post-extraction validation: if InvoiceDueDate < InvoiceDate, either swap them or flag for review. + +--- + +## Priority Fix Order + +| Priority | Fix | Impact | Effort | +|----------|-----|--------|--------| +| 1 | DPI mismatch (Problem #1) | 43% of docs recovered | 1 line change | +| 2 | Fallback amount regex (Problem #2) | 3+ docs with wrong amounts | Small regex fix | +| 3 | Fallback bankgiro regex (Problem #3) | False positive bankgiro | Small regex fix | +| 4 | OCR min digits (Problem #4) | Short OCR numbers supported | 1 line change | +| 5 | Year as InvoiceNumber (Problem #5) | 2+ docs | Small logic add | +| 6 | Date validation (Problem #10) | Logical consistency | Small validation add | +| 7 | Cross-validation (Problem #6) | Better field reconciliation | Investigation needed | +| 8 | Timeouts (Problem #8) | 4 docs | Config change | +| 9 | Document classification (Problem #9) | Filter non-invoices | Feature addition | + +--- + +## Re-run Expected After Fix #1 + +After fixing the DPI mismatch alone, re-running the same 39 PDFs should show: +- Pure fallback rate dropping from 43% to near 0% +- InvoiceDate extraction rate improving from 31% to ~70%+ +- OCR extraction rate improving from 31% to ~60%+ +- Average confidence scores increasing significantly +- supplier_organisation_number extraction improving from 23% to ~60%+ + +--- + +## Detailed Per-PDF Results Summary + +| PDF | Size | Time | Fields | Confidence | Issues | +|-----|------|------|--------|------------|--------| +| dccf6655 | 10KB | 17s | 0/0 | - | Not an invoice | +| 4f822b0d | 183KB | 37s | 1/6 | ALL 0.500 | DPI bug: 6 detections, 5 lost | +| d4af7848 | 55KB | 41s | 1/6 | ALL 0.500 | DPI bug: 6 detections, 5 lost | +| 19533483 | 262KB | 39s | 1/9 | ALL 0.500 | DPI bug: 9 detections, 8 lost | +| 2b7e4103 | 25KB | 47s | 3/6 | ALL 0.500 | DPI bug + Amount=1.00 wrong | +| 7717d293 | 34KB | 16s | 3/6 | ALL 0.500 | DPI bug + Amount=1.00 wrong | +| 3226ac59 | 66KB | 42s | 3/5 | ALL 0.500 | DPI bug + Amount=1.00 wrong | +| 0553e5c2 | 31KB | 18s | 3/6 | ALL 0.500 | DPI bug + BG=5000-0000 suspicious | +| 32e90db8 | 136KB | 40s | 3/7 | Mixed | Amount=2026.00 (year?) | +| dc35ee8e | 567KB | 83s | 7/9 | YOLO | InvoiceNumber=2025 (year) | +| 56cabf73 | 67KB | 19s | 5/6 | YOLO | InvoiceNumber=2026 (year) | +| 87f470da | 784KB | 42s | 9/14 | YOLO | InvNum vs OCR mismatch | +| 11de4d07 | 356KB | 68s | 5/3 | Mixed | DueDate < InvoiceDate | +| 0f9047a9 | 415KB | 22s | 8/6 | YOLO | Good extraction | +| 9d0b793c | 286KB | 18s | 8/8 | YOLO | Good extraction | +| 5604d375 | 915KB | 51s | 9/10 | YOLO | Good extraction | +| 87f470da | 784KB | 42s | 9/14 | YOLO | Good extraction | +| f40fd418 | 523KB | 90s | 9/9 | YOLO | Good extraction | + +--- + +## Field Extraction Rate Summary + +| Field | Present | Missing | Rate | Avg Conf | +|-------|---------|---------|------|----------| +| Bankgiro | 32 | 3 | 91.4% | 0.681 | +| InvoiceNumber | 28 | 7 | 80.0% | 0.695 | +| Amount | 27 | 8 | 77.1% | 0.726 | +| InvoiceDueDate | 13 | 22 | 37.1% | 0.883 | +| InvoiceDate | 11 | 24 | 31.4% | 0.879 | +| OCR | 11 | 24 | 31.4% | 0.900 | +| customer_number | 11 | 24 | 31.4% | 0.926 | +| payment_line | 9 | 26 | 25.7% | 0.938 | +| Plusgiro | 3 | 32 | 8.6% | 0.948 | +| supplier_org_number | 0 | 35 | 0.0% | 0.000 | + +Note: Fields with high confidence but low extraction rate (InvoiceDate 0.879, OCR 0.900, payment_line 0.938) indicate the DPI bug: when extraction works (via YOLO), confidence is high. The low rate is because most documents fall back and these fields have no fallback regex pattern. diff --git a/LABELING_STRATEGY_ANALYSIS.md b/LABELING_STRATEGY_ANALYSIS.md new file mode 100644 index 0000000..4e7e7bd --- /dev/null +++ b/LABELING_STRATEGY_ANALYSIS.md @@ -0,0 +1,257 @@ +# Semi-Automatic Labeling Strategy Analysis + +## 1. Current Pipeline Overview + +``` +CSV (field values) + | + v +Autolabel CLI + |- PDF render (300 DPI) + |- Text extraction (PDF text layer or PaddleOCR) + |- FieldMatcher.find_matches() [5 strategies] + | |- ExactMatcher (priority 1) + | |- ConcatenatedMatcher (multi-token) + | |- FuzzyMatcher (Amount, dates only) + | |- SubstringMatcher (prevents false positives) + | |- FlexibleDateMatcher (fallback) + | + |- AnnotationGenerator + | |- PDF points -> pixels + | |- expand_bbox() [field-specific strategy] + | |- pixels -> YOLO normalized (0-1) + | |- Save to database + | + v +DBYOLODataset (training) + |- Load images + bboxes from DB + |- Re-apply expand_bbox() + |- YOLO training + | + v +Inference + |- YOLO detect -> pixel bboxes + |- Crop region -> OCR extract text + |- Normalize & validate +``` + +--- + +## 2. Current Expansion Strategy Analysis + +### 2.1 Field-Specific Parameters + +| Field | Scale X | Scale Y | Extra Top | Extra Left | Extra Right | Max Pad X | Max Pad Y | +|---|---|---|---|---|---|---|---| +| ocr_number | 1.15 | 1.80 | 0.60 | - | - | 50 | 140 | +| bankgiro | 1.45 | 1.35 | - | 0.80 | - | 160 | 90 | +| plusgiro | 1.45 | 1.35 | - | 0.80 | - | 160 | 90 | +| invoice_date | 1.25 | 1.55 | 0.40 | - | - | 80 | 110 | +| invoice_due_date | 1.30 | 1.65 | 0.45 | 0.35 | - | 100 | 120 | +| amount | 1.20 | 1.35 | - | - | 0.30 | 70 | 80 | +| invoice_number | 1.20 | 1.50 | 0.40 | - | - | 80 | 100 | +| supplier_org_number | 1.25 | 1.40 | 0.30 | 0.20 | - | 90 | 90 | +| customer_number | 1.25 | 1.45 | 0.35 | 0.25 | - | 90 | 100 | +| payment_line | 1.10 | 1.20 | - | - | - | 40 | 40 | + +### 2.2 Design Rationale + +The expansion is designed based on Swedish invoice layout patterns: +- **Dates**: Labels ("Fakturadatum") typically sit **above** the value -> extra top +- **Giro accounts**: Prefix ("BG:", "PG:") sits **to the left** -> extra left +- **Amount**: Currency suffix ("SEK", "kr") to the **right** -> extra right +- **Payment line**: Machine-readable, self-contained -> minimal expansion + +### 2.3 Strengths + +1. **Field-specific directional expansion** - matches Swedish invoice conventions +2. **Max padding caps** - prevents runaway expansion into neighboring fields +3. **Center-point scaling** with directional compensation - geometrically sound +4. **Image boundary clamping** - prevents out-of-bounds coordinates + +### 2.4 Potential Issues + +| Issue | Risk Level | Description | +|---|---|---| +| Over-expansion | HIGH | OCR number 1.80x Y-scale could capture adjacent fields | +| Inconsistent training vs inference bbox | MEDIUM | Model trained on expanded boxes, inference returns raw detection | +| No expansion at inference OCR crop | MEDIUM | Detected bbox may clip text edges without post-expansion | +| Max padding in pixels vs DPI-dependent | LOW | 140px at 300DPI != 140px at 150DPI | + +--- + +## 3. Industry Best Practices (Research Findings) + +### 3.1 Labeling: Tight vs. Loose Bounding Boxes + +**Consensus**: Annotate **tight bounding boxes around the value text only**. + +- FUNSD/CORD benchmarks annotate keys and values as **separate entities** +- Loose boxes "introduce background noise and can mislead the model" (V7 Labs, LabelVisor) +- IoU discrepancies from loose boxes degrade mAP during training + +**However**, for YOLO + OCR pipelines, tight-only creates a problem: +- YOLO predicts slightly imprecise boxes (typical IoU 0.7-0.9) +- If the predicted box clips even slightly, OCR misses characters +- Solution: **Label tight, expand at inference** OR **label with controlled padding** + +### 3.2 The Two Dominant Strategies + +**Strategy A: Tight Label + Inference-Time Expansion** (Recommended by research) +``` +Label: [ 2024-01-15 ] (tight around value) +Inference: [ [2024-01-15] ] + pad -> OCR +``` +- Clean, consistent annotations +- Requires post-detection padding before OCR crop +- Used by: Microsoft Document Intelligence, Nanonets + +**Strategy B: Expanded Label at Training Time** (Current project approach) +``` +Label: [ Fakturadatum: 2024-01-15 ] (includes context) +Inference: YOLO detects full region -> OCR extracts from region +``` +- Model learns spatial context (label + value) +- Larger, more variable boxes +- OCR must filter out label text from extracted content + +### 3.3 OCR Padding Requirements + +**Tesseract**: Requires ~10px white border for reliable segmentation (PSM 7-10). +**PaddleOCR**: `det_db_unclip_ratio` parameter (default 1.5) controls detection expansion. + +Key insight: Even after YOLO detection, OCR engines need some padding around text to work reliably. + +### 3.4 State-of-the-Art Comparison + +| System | Bbox Strategy | Field Definition | +|---|---|---| +| **LayoutLM** | Word-level bboxes from OCR | Token classification (BIO tagging) | +| **Donut** | No bboxes (end-to-end) | Internal attention mechanism | +| **Microsoft DocAI** | Field-level, tight | Post-expansion for OCR | +| **YOLO + OCR (this project)** | Field-level, expanded | Field-specific directional expansion | + +--- + +## 4. Recommendations + +### 4.1 Short-Term (Current Architecture) + +#### A. Add Inference-Time OCR Padding +Currently, the detected bbox is sent directly to OCR. Add a small uniform padding (5-10%) before cropping for OCR: + +```python +# In field_extractor.py, before OCR crop: +pad_ratio = 0.05 # 5% expansion +w_pad = (x2 - x1) * pad_ratio +h_pad = (y2 - y1) * pad_ratio +crop_x1 = max(0, x1 - w_pad) +crop_y1 = max(0, y1 - h_pad) +crop_x2 = min(img_w, x2 + w_pad) +crop_y2 = min(img_h, y2 + h_pad) +``` + +#### B. Reduce Training-Time Expansion Ratios +Current ratios (especially OCR number 1.80x Y, Bankgiro 1.45x X) are aggressive. Proposed reduction: + +| Field | Current Scale Y | Proposed Scale Y | Rationale | +|---|---|---|---| +| ocr_number | 1.80 | 1.40 | 1.80 is too aggressive, captures neighbors | +| bankgiro | 1.35 | 1.25 | Reduce vertical over-expansion | +| invoice_due_date | 1.65 | 1.45 | Tighten vertical | + +Principle: **shift expansion work from training-time to inference-time**. + +#### C. Add Label Visualization Quality Check +Before training, sample 50-100 annotated images and visually inspect: +- Are expanded bboxes capturing only the target field? +- Are any bboxes overlapping with adjacent fields? +- Are any values being clipped? + +### 4.2 Medium-Term (Architecture Improvements) + +#### D. Two-Stage Detection Strategy +``` +Stage 1: YOLO detects field regions (current) +Stage 2: Within each detection, use PaddleOCR text detection + to find the precise text boundary +Stage 3: Extract text from refined boundary +``` + +Benefits: +- YOLO handles field classification (what) +- PaddleOCR handles text localization (where exactly) +- Eliminates the "tight vs loose" dilemma entirely + +#### E. Label Both Key and Value Separately +Add new annotation classes: `invoice_date_label`, `invoice_date_value` +- Model learns to find both the label and value +- Use spatial relationship (label -> value) for more robust extraction +- Aligns with FUNSD benchmark approach + +#### F. Confidence-Weighted Expansion +Scale expansion by detection confidence: +```python +# Higher confidence = tighter crop (model is sure) +# Lower confidence = wider crop (give OCR more context) +expansion = base_expansion * (1.5 - confidence) +``` + +### 4.3 Long-Term (Next Generation) + +#### G. Move to LayoutLM-Style Token Classification +- Replace YOLO field detection with token-level classification +- Each OCR word gets classified as B-field/I-field/O +- Eliminates bbox expansion entirely +- Better for fields with complex layouts + +#### H. End-to-End with Donut/Pix2Struct +- No separate OCR step +- Model directly outputs structured fields from image +- Zero bbox concerns +- Requires more training data and compute + +--- + +## 5. Recommended Action Plan + +### Phase 1: Validate Current Labels (1-2 days) +- [ ] Build label visualization script +- [ ] Sample 100 documents across all field types +- [ ] Identify over-expansion and clipping cases +- [ ] Document per-field accuracy of current expansion + +### Phase 2: Tune Expansion Parameters (2-3 days) +- [ ] Reduce aggressive expansion ratios (OCR number, bankgiro) +- [ ] Add inference-time OCR padding (5-10%) +- [ ] Re-train model with adjusted labels +- [ ] Compare mAP and field extraction accuracy + +### Phase 3: Two-Stage Refinement (1 week) +- [ ] Implement PaddleOCR text detection within YOLO detection +- [ ] Use text detection bbox for precise OCR crop +- [ ] Keep YOLO expansion for classification only + +### Phase 4: Evaluation (ongoing) +- [ ] Track per-field extraction accuracy on test set +- [ ] A/B test tight vs expanded labels +- [ ] Build regression test suite for labeling quality + +--- + +## 6. Summary + +| Aspect | Current Approach | Best Practice | Gap | +|---|---|---|---| +| **Labeling** | Value + expansion at label time | Tight value + inference expansion | Medium | +| **Expansion** | Field-specific directional | Field-specific directional | Aligned | +| **Inference OCR crop** | Raw detection bbox | Detection + padding | Needs padding | +| **Expansion ratios** | Aggressive (up to 1.80x) | Moderate (1.10-1.30x) | Over-expanded | +| **Visualization QC** | None | Regular sampling | Missing | +| **Coordinate consistency** | PDF points -> pixels | Consistent DPI | Check needed | + +**Bottom line**: The architecture (field-specific directional expansion) is sound and aligns with best practices. The main improvements are: +1. **Reduce expansion aggressiveness** during training labels +2. **Add OCR padding** at inference time +3. **Add label quality visualization** for validation +4. Longer term: consider **two-stage detection** or **token classification** diff --git a/PLAN_TWO_STAGE_DETECTION.md b/PLAN_TWO_STAGE_DETECTION.md new file mode 100644 index 0000000..c1457d8 --- /dev/null +++ b/PLAN_TWO_STAGE_DETECTION.md @@ -0,0 +1,153 @@ +# Plan: Two-Stage Detection (YOLO + PaddleOCR Value Selection) + +## Context + +Current inference flow: YOLO detects field region -> crop region -> PaddleOCR reads ALL text -> concatenate -> normalizer extracts value via regex. + +Problem: When training labels include label+value (e.g., "Fakturadatum 2024-01-15"), the detected region contains both label text and value text. Currently all OCR tokens are concatenated, and normalizers must regex out the value. This works for most fields but is fragile. + +Solution: After PaddleOCR returns individual text lines from the detected region, add a **value selection** step that picks the most likely value token(s) BEFORE sending to normalizer. This gives normalizers cleaner input and provides a more precise value bbox. + +## Key Insight + +PaddleOCR already returns individual `OCRToken` objects (text + bbox + confidence). The current code just concatenates them all (line 227 of `field_extractor.py`): + +```python +raw_text = ' '.join(t.text for t in ocr_tokens) +``` + +The change: replace blind concatenation with field-aware token selection. + +## Architecture + +``` +Current: + YOLO bbox -> crop -> PaddleOCR -> [all tokens] -> concat -> normalizer + +New: + YOLO bbox -> crop -> PaddleOCR -> [all tokens] -> value_selector -> normalizer + | | + individual selected + text lines value token(s) +``` + +## Scope: Inference Only + +| Pipeline Stage | Changed? | Reason | +|---|---|---| +| **Labeling** (autolabel, expansion) | NO | Expanded bbox gives YOLO strong visual patterns for generalization | +| **Training** (YOLO26s) | NO | Model learns field regions correctly with current labels | +| **Inference - Detection** (YOLO) | NO | Detection output is correct | +| **Inference - Extraction** (OCR -> text) | **YES** | Add ValueSelector between OCR tokens and normalizer | + +This design works with ANY model -- tight-label model (15K docs) or expanded-label model (58 images). ValueSelector is model-agnostic. + +## Files to Modify + +| File | Change | +|------|--------| +| `packages/backend/backend/pipeline/value_selector.py` | **NEW** - Value selection logic per field type | +| `packages/backend/backend/pipeline/field_extractor.py` | Use ValueSelector in both extraction paths | +| `tests/pipeline/test_value_selector.py` | **NEW** - Tests for value selection | + +## Implementation + +### 1. New File: `value_selector.py` + +Core class `ValueSelector` with method: + +```python +def select_value_tokens( + self, + tokens: list[OCRToken], + field_name: str +) -> list[OCRToken]: +``` + +Field-specific selection rules: + +| Field | Strategy | Pattern | +|-------|----------|---------| +| InvoiceDate, InvoiceDueDate | Date pattern match | `\d{4}[-./]\d{2}[-./]\d{2}` or `\d{2}[-./]\d{2}[-./]\d{4}` or `\d{8}` | +| Amount | Number pattern match | `\d[\d\s]*[,.]\d{2}` (prefer tokens with comma/decimal) | +| Bankgiro | Giro pattern match | `\d{3,4}-\d{4}` or 7-8 consecutive digits | +| Plusgiro | Giro pattern match | `\d+-\d` or 2-8 digits | +| OCR | Longest digit sequence | Token with most consecutive digits (min 5) | +| InvoiceNumber | Exclude known labels | Remove tokens matching Swedish label keywords, keep rest | +| supplier_org_number | Org number pattern | `\d{6}-?\d{4}` | +| customer_number | Exclude labels | Remove label keywords, keep alphanumeric tokens | +| payment_line | Full concatenation | Keep all tokens (payment line needs full text) | + +Label keyword exclusion list (Swedish): +```python +LABEL_KEYWORDS = { + "fakturanummer", "fakturadatum", "forfallodag", "forfalldatum", + "bankgiro", "plusgiro", "bg", "pg", "ocr", "belopp", "summa", + "total", "att", "betala", "kundnummer", "organisationsnummer", + "org", "nr", "datum", "nummer", "ref", "referens", + "momsreg", "vat", "moms", "sek", "kr", +} +``` + +Selection algorithm: +1. Try field-specific pattern matching on individual tokens +2. If match found -> return matched token(s) only +3. If no match -> **fallback to ALL tokens** (current behavior, so we never lose data) + +### 2. Modify: `field_extractor.py` + +**OCR path** (`extract_from_detection`, line 224-228): + +```python +# Before (current): +ocr_tokens = self.ocr_engine.extract_from_image(region) +raw_text = ' '.join(t.text for t in ocr_tokens) + +# After: +ocr_tokens = self.ocr_engine.extract_from_image(region) +value_tokens = self._value_selector.select_value_tokens(ocr_tokens, field_name) +raw_text = ' '.join(t.text for t in value_tokens) +``` + +**PDF text path** (`extract_from_detection_with_pdf`, line 172-174): + +```python +# Before (current): +matching_tokens.sort(key=lambda x: -x[1]) +raw_text = ' '.join(t[0].text for t in matching_tokens) + +# After: +matching_tokens.sort(key=lambda x: -x[1]) +all_text_tokens = [OCRToken(text=t[0].text, bbox=t[0].bbox, confidence=1.0) for t in matching_tokens] +value_tokens = self._value_selector.select_value_tokens(all_text_tokens, field_name) +raw_text = ' '.join(t.text for t in value_tokens) +``` + +**Constructor**: Add `ValueSelector` instance. + +### 3. Tests: `test_value_selector.py` + +Test cases per field type: +- Date: "Fakturadatum 2024-01-15" -> selects "2024-01-15" +- Amount: "Belopp 1 234,56 kr" -> selects "1 234,56" +- Bankgiro: "BG: 123-4567" -> selects "123-4567" +- OCR: "OCR 1234567890" -> selects "1234567890" +- InvoiceNumber: "Fakturanr INV-2024-001" -> selects "INV-2024-001" +- Fallback: Unknown pattern -> returns all tokens (no data loss) + +## Key Design Decisions + +1. **Fallback to full text**: If value selection can't identify the value, return ALL tokens. This means the change can never make things worse than current behavior. + +2. **ValueSelector is stateless**: Pure function, no side effects. Easy to test. + +3. **No training changes**: Training labels stay as-is (expanded bboxes). Only inference pipeline changes. + +4. **No normalizer changes**: Normalizers still work the same. They just get cleaner input. + +## Verification + +1. Run existing tests: `pytest tests/pipeline/ -v` +2. Run new tests: `pytest tests/pipeline/test_value_selector.py -v` +3. Manual validation: Run inference on a few invoices and compare raw_text before/after +4. Regression check: Ensure no field extraction accuracy drops on existing test documents diff --git a/PROJECT_REVIEW.md b/PROJECT_REVIEW.md index f7ac99c..974f0f2 100644 --- a/PROJECT_REVIEW.md +++ b/PROJECT_REVIEW.md @@ -8,11 +8,11 @@ ## 项目概述 -**Invoice Master POC v2** - 基于 YOLOv11 + PaddleOCR 的瑞典发票字段自动提取系统 +**Invoice Master POC v2** - 基于 YOLO26 + PaddleOCR 的瑞典发票字段自动提取系统 ### 核心功能 - **自动标注**: 利用 CSV 结构化数据 + OCR 自动生成 YOLO 训练标注 -- **模型训练**: 使用 YOLOv11 训练字段检测模型,支持数据增强 +- **模型训练**: 使用 YOLO26 训练字段检测模型,支持数据增强 - **推理提取**: 检测字段区域 → OCR 提取文本 → 字段规范化 - **Web 管理**: React 前端 + FastAPI 后端,支持文档管理、数据集构建、模型训练和版本管理 @@ -175,7 +175,7 @@ frontend/ # React 前端 (Vite + TypeScript + TailwindCSS) | 组件 | 技术选择 | 评估 | |------|----------|------| -| **目标检测** | YOLOv11 (Ultralytics) | ✅ 业界标准 | +| **目标检测** | YOLO26 (Ultralytics >= 8.4.0) | ✅ 业界标准 | | **OCR 引擎** | PaddleOCR v5 | ✅ 支持瑞典语 | | **PDF 处理** | PyMuPDF (fitz) | ✅ 功能强大 | | **数据库** | PostgreSQL + SQLModel | ✅ 类型安全 | diff --git a/README.md b/README.md index 9954bdb..40a28f7 100644 --- a/README.md +++ b/README.md @@ -1,14 +1,14 @@ # Invoice Master POC v2 -自动发票字段提取系统 - 使用 YOLOv11 + PaddleOCR 从瑞典 PDF 发票中提取结构化数据。 +自动发票字段提取系统 - 使用 YOLO26 + PaddleOCR 从瑞典 PDF 发票中提取结构化数据。 ## 项目概述 本项目实现了一个完整的发票字段自动提取流程: -1. **自动标注**: 利用已有 CSV 结构化数据 + OCR 自动生成 YOLO 训练标注 -2. **模型训练**: 使用 YOLOv11 训练字段检测模型,支持数据增强 -3. **推理提取**: 检测字段区域 -> OCR 提取文本 -> 字段规范化 +1. **自动标注**: 利用已有 CSV 结构化数据 + OCR 自动生成 YOLO 训练标注(统一 15px 填充) +2. **模型训练**: 使用 YOLO26 训练字段检测模型,支持数据增强 +3. **推理提取**: 检测字段区域 -> OCR 提取文本 -> ValueSelector 过滤标签 -> 字段规范化 4. **Web 管理**: React 前端 + FastAPI 后端,支持文档管理、数据集构建、模型训练和版本管理 ### 架构 @@ -37,8 +37,8 @@ frontend/ # React 前端 (Vite + TypeScript + TailwindCSS) |------|------| | **已标注文档** | 9,738 (9,709 成功) | | **总体字段匹配率** | 94.8% (82,604/87,121) | -| **测试** | 2,058 passed | -| **测试覆盖率** | 60% | +| **测试** | 2,047 passed | +| **测试覆盖率** | 72% | | **模型 mAP@0.5** | 93.5% | **各字段匹配率:** @@ -204,7 +204,7 @@ invoice-master-poc-v2/ │ ├── run_server.py # Web 服务器入口 │ └── backend/ │ ├── cli/ # infer, serve -│ ├── pipeline/ # YOLO 检测, 字段提取, 解析器 +│ ├── pipeline/ # YOLO 检测, 字段提取, ValueSelector, 解析器 │ ├── web/ # FastAPI 应用 │ │ ├── api/v1/ # REST API (admin, public, batch) │ │ ├── schemas/ # Pydantic 数据模型 @@ -278,7 +278,7 @@ python -m training.cli.autolabel --workers 4 ```bash # 从预训练模型开始训练 python -m training.cli.train \ - --model yolo11n.pt \ + --model yolo26s.pt \ --epochs 100 \ --batch 16 \ --name invoice_fields \ @@ -286,7 +286,7 @@ python -m training.cli.train \ # 低内存模式 python -m training.cli.train \ - --model yolo11n.pt \ + --model yolo26s.pt \ --epochs 100 \ --name invoice_fields \ --low-memory @@ -443,6 +443,30 @@ result = parser.parse("Said, Shakar Umj 436-R Billo") print(f"Customer Number: {result}") # "UMJ 436-R" ``` +## 推理流水线 (Two-Stage Detection) + +``` +YOLO bbox -> crop -> PaddleOCR -> [all tokens] -> ValueSelector -> normalizer + | | + individual selected + text lines value token(s) +``` + +**BBox 扩展**: 所有字段统一使用 15px 填充(150 DPI 下约 2.5mm),不做方向性扩展,不依赖布局假设。 + +**ValueSelector**: 在 OCR 和 normalizer 之间按字段类型过滤标签文本,只保留值 token: + +| 字段 | 选择策略 | 示例输入 -> 输出 | +|------|---------|-----------------| +| InvoiceDate / DueDate | 日期模式匹配 | "Fakturadatum 2024-01-15" -> "2024-01-15" | +| Amount | 金额模式匹配 | "Belopp 1 234,56 kr" -> "1 234,56" | +| Bankgiro / Plusgiro | Giro 号码模式 | "BG: 123-4567" -> "123-4567" | +| OCR | 最长数字序列 (>=5位) | "OCR 94228110015950070" -> "94228110015950070" | +| InvoiceNumber | 排除瑞典语标签 | "Fakturanr INV-2024-001" -> "INV-2024-001" | +| payment_line | 保留全部 | 不过滤 | + +如果没有匹配到任何模式,回退返回全部 token(永远不会比之前更差)。 + ## DPI 配置 系统所有组件统一使用 **150 DPI**。DPI 必须在训练和推理时保持一致。 @@ -506,9 +530,9 @@ DB_PASSWORD=xxx pytest tests/ --cov=packages --cov-report=term-missing | 指标 | 数值 | |------|------| -| **测试总数** | 2,058 | +| **测试总数** | 2,047 | | **通过率** | 100% | -| **覆盖率** | 60% | +| **覆盖率** | 72% | ## 存储抽象层 @@ -619,7 +643,7 @@ npm run dev | 组件 | 技术 | |------|------| -| **目标检测** | YOLOv11 (Ultralytics) | +| **目标检测** | YOLO26 (Ultralytics >= 8.4.0) | | **OCR 引擎** | PaddleOCR v5 (PP-OCRv5) | | **PDF 处理** | PyMuPDF (fitz) | | **数据库** | PostgreSQL + SQLModel | diff --git a/configs/default.yaml b/configs/default.yaml index b2c81ff..1704ead 100644 --- a/configs/default.yaml +++ b/configs/default.yaml @@ -76,7 +76,7 @@ matching: # YOLO Training yolo: - model: yolov8s # Model architecture (yolov8n/s/m/l/x) + model: yolo26s # Model architecture (yolo26n/s/m/l/x) epochs: 100 batch_size: 16 img_size: 1280 # Image size for training diff --git a/configs/training.yaml b/configs/training.yaml index 4eb2a89..727d37f 100644 --- a/configs/training.yaml +++ b/configs/training.yaml @@ -2,7 +2,7 @@ # Use with: yolo train data=dataset.yaml cfg=training.yaml # Model -model: yolov8s.pt +model: yolo26s.pt # Training hyperparameters epochs: 100 @@ -57,3 +57,12 @@ name: invoice_fields exist_ok: true pretrained: true verbose: true + +# Fine-tuning profile (overrides when task_type == finetune) +finetune: + epochs: 10 + lr0: 0.001 + freeze: 10 + warmup_epochs: 1 + cos_lr: true + patience: 5 diff --git a/docs/FORTNOX_INTEGRATION_SPEC.md b/docs/FORTNOX_INTEGRATION_SPEC.md index 79dfc32..3924ffe 100644 --- a/docs/FORTNOX_INTEGRATION_SPEC.md +++ b/docs/FORTNOX_INTEGRATION_SPEC.md @@ -32,7 +32,7 @@ ### 1.1 项目背景 -Invoice Master是一个基于YOLOv11 + PaddleOCR的发票字段自动提取系统,当前准确率达到94.8%。本方案设计将Invoice Master作为Fortnox会计软件的插件/扩展,实现无缝的发票数据导入功能。 +Invoice Master是一个基于YOLO26 + PaddleOCR的发票字段自动提取系统,当前准确率达到94.8%。本方案设计将Invoice Master作为Fortnox会计软件的插件/扩展,实现无缝的发票数据导入功能。 ### 1.2 目标 diff --git a/docs/aws-deployment-guide.md b/docs/aws-deployment-guide.md index f0e6ac5..1af9bbc 100644 --- a/docs/aws-deployment-guide.md +++ b/docs/aws-deployment-guide.md @@ -500,7 +500,7 @@ estimator = PyTorch( hyperparameters={ "epochs": 100, "batch-size": 16, - "model": "yolo11n.pt" + "model": "yolo26s.pt" } ) ``` diff --git a/docs/azure-deployment-guide.md b/docs/azure-deployment-guide.md index f2dec60..33ec5bf 100644 --- a/docs/azure-deployment-guide.md +++ b/docs/azure-deployment-guide.md @@ -152,7 +152,7 @@ rclone mount azure:training-images Z: --vfs-cache-mode full ### 推荐: Container Apps (CPU) 对于 YOLO 推理,**CPU 足够**,不需要 GPU: -- YOLOv11n 在 CPU 上推理时间 ~200-500ms +- YOLO26s 在 CPU 上推理时间 ~200-500ms - 比 GPU 便宜很多,适合中低流量 ```yaml @@ -335,7 +335,7 @@ az containerapp create \ │ ~$30/月 │ │ ~$1-5/次训练 │ │ │ │ │ │ │ │ │ │ ┌───────────────────┐ │ │ ┌───────────────────┐ │ │ ┌───────────────────┐ │ -│ │ FastAPI + YOLO │ │ │ │ YOLOv11 Training │ │ │ │ React/Vue 前端 │ │ +│ │ FastAPI + YOLO │ │ │ │ YOLO26 Training │ │ │ │ React/Vue 前端 │ │ │ │ /api/v1/infer │ │ │ │ 100 epochs │ │ │ │ 上传发票界面 │ │ │ └───────────────────┘ │ │ └───────────────────┘ │ │ └───────────────────┘ │ └───────────┬───────────┘ └───────────┬───────────┘ └───────────┬───────────┘ diff --git a/docs/fine-tuning-best-practices.md b/docs/fine-tuning-best-practices.md new file mode 100644 index 0000000..fdd078b --- /dev/null +++ b/docs/fine-tuning-best-practices.md @@ -0,0 +1,185 @@ +# YOLO Model Fine-Tuning Best Practices + +Production guide for continuous fine-tuning of YOLO object detection models with user feedback. + +## Overview + +When users report failed detections, those documents are collected, reviewed, and used to incrementally improve the model without degrading performance on existing data. + +Key risks: +- **Catastrophic forgetting**: model forgets original training after fine-tuning on small new data +- **Cumulative drift**: repeated fine-tuning sessions cause progressive degradation +- **Overfitting**: few samples + many epochs = memorizing noise + +## 1. Data Management + +``` +Original training set (25K) --> permanently retained as "anchor dataset" + | +User-reported failures --> human review & labeling --> "fine-tune pool" + | +Fine-tune pool accumulates over time, never deleted +``` + +Every new sample MUST be human-verified before entering the fine-tune pool. Incorrect labels are more harmful than no labels. + +### Data Mixing Ratios + +| Accumulated New Samples | Old Data Multiplier | Total Training Size | +|------------------------|--------------------|--------------------| +| 10 | 50x (500) | 510 | +| 50 | 20x (1,000) | 1,050 | +| 200 | 10x (2,000) | 2,200 | +| 500+ | 5x (2,500) | 3,000 | + +Principle: fewer new samples require higher old data ratio. Stabilize at 5x once pool reaches 500+. + +Old samples are randomly sampled from the original 25K each time, ensuring broad coverage. + +## 2. Model Version Management + +``` +base_v1.pt (original 25K training) + +-- ft_v1.1.pt (base + fine-tune batch 1) + +-- ft_v1.2.pt (base + fine-tune batch 1+2) + +-- ... + +When fine-tune pool reaches 2000+ samples: +base_v2.pt (original 25K + all accumulated samples, trained from scratch) + +-- ft_v2.1.pt + +-- ... +``` + +CRITICAL: Never chain fine-tunes (ft_v1.1 -> ft_v1.2 -> ft_v1.3). Always start from the base model to avoid cumulative drift. + +## 3. Fine-Tuning Parameters + +```yaml +base_model: best.pt # always start from base model +epochs: 10 # few epochs are sufficient +lr0: 0.001 # 1/10 of base training lr +freeze: 10 # freeze first 10 backbone layers +warmup_epochs: 1 +cos_lr: true + +# data mixing +new_samples: all # entire fine-tune pool +old_samples: min(5x_new, 3000) # old data sampling, cap at 3000 +``` + +### Why These Settings + +| Parameter | Rationale | +|-----------|-----------| +| `epochs: 10` | More than enough for small datasets; prevents overfitting | +| `lr0: 0.001` | Low learning rate preserves base model knowledge | +| `freeze: 10` | Backbone features are general; only fine-tune detection head and later layers | +| `cos_lr: true` | Smooth decay prevents sharp weight updates | + +## 4. Deployment Gating (Most Important) + +Every fine-tuned model MUST pass three gates before deployment: + +### Gate 1: Regression Validation + +Run evaluation on the original test set (held out from the 25K training data). + +| mAP50 Change | Action | +|-------------|--------| +| Drop < 1% | PASS - deploy | +| Drop 1-3% | REVIEW - human inspection required | +| Drop > 3% | REJECT - do not deploy | + +### Gate 2: New Sample Validation + +Run inference on the new failure documents. + +| Detection Rate | Action | +|---------------|--------| +| > 80% correct | PASS | +| < 80% correct | REVIEW - check label quality or increase training | + +### Gate 3: A/B Comparison (Optional) + +Sample 100 production documents, run both old and new models: +- New model must not be worse on any field type +- Compare per-class mAP to detect targeted regressions + +## 5. Fine-Tuning Frequency + +| Strategy | Trigger | Recommendation | +|----------|---------|---------------| +| **By volume (recommended)** | Pool reaches 50+ new samples | Best signal-to-noise ratio | +| By schedule | Weekly or monthly | Predictable but may trigger with insufficient data | +| By performance | Monitored accuracy drops below threshold | Reactive, requires monitoring infrastructure | + +Do NOT fine-tune daily with fewer than 50 samples. The noise outweighs the signal. + +## 6. Complete Workflow + +``` +User marks failed document + | + v +Human reviews and labels annotations + | + v +Add to fine-tune pool + | + v +Pool >= 50 samples? --NO--> Wait for more samples + | + YES + | + v +Prepare mixed dataset: + - All samples from fine-tune pool + - Random sample 5x from original 25K + | + v +Fine-tune from base.pt: + - 10 epochs + - lr0 = 0.001 + - freeze first 10 layers + | + v +Gate 1: Original test set mAP drop < 1%? + | + PASS + | + v +Gate 2: New sample detection rate > 80%? + | + PASS + | + v +Deploy new model, retain old model for rollback + | + v +Pool accumulated 2000+ samples? + | + YES --> Merge all data, train new base from scratch +``` + +## 7. Monitoring in Production + +Track these metrics continuously: + +| Metric | Purpose | Alert Threshold | +|--------|---------|----------------| +| Detection rate per field | Catch field-specific regressions | < 90% for any field | +| Average confidence score | Detect model uncertainty drift | Drop > 5% from baseline | +| User-reported failures / week | Measure improvement trend | Increasing over 3 weeks | +| Inference latency | Ensure model size hasn't bloated | > 2x baseline | + +## 8. Summary of Rules + +| Rule | Practice | +|------|----------| +| Never chain fine-tunes | Always start from base.pt | +| Never use only new data | Must mix with old data | +| Never fine-tune on < 50 samples | Accumulate before triggering | +| Never auto-deploy | Must pass gating validation | +| Never discard old models | Retain versions for rollback | +| Periodically retrain base | Merge all data at 2000+ new samples | +| Always human-review labels | Bad labels are worse than no labels | diff --git a/docs/product-plan-v2.md b/docs/product-plan-v2.md index 4e127ce..f6ee796 100644 --- a/docs/product-plan-v2.md +++ b/docs/product-plan-v2.md @@ -546,7 +546,7 @@ Request: "description": "First training run with 500 documents", "document_ids": ["uuid1", "uuid2", "uuid3"], "config": { - "model_name": "yolo11n.pt", + "model_name": "yolo26s.pt", "epochs": 100, "batch_size": 16, "image_size": 640 @@ -1036,7 +1036,7 @@ Response: | | Name: [Training Run 2024-01____________] | | | | Description: [First training with 500 documents_________] | | | | | | -| | Base Model: [yolo11n.pt v] Epochs: [100] Batch: [16] | | +| | Base Model: [yolo26s.pt v] Epochs: [100] Batch: [16] | | | | Image Size: [640] Device: [GPU 0 v] | | | | | | | | [ ] Schedule for later: [2024-01-20] [22:00] | | @@ -1088,7 +1088,7 @@ Response: | | - Recall: 92% | | | | | | | | Configuration: | | -| | - Base: yolo11n.pt Epochs: 100 Batch: 16 Size: 640 | | +| | - Base: yolo26s.pt Epochs: 100 Batch: 16 Size: 640 | | | | | | | | Documents Used: [View 600 documents] | | | +--------------------------------------------------------------+ | diff --git a/docs/training-flow.mmd b/docs/training-flow.mmd index b4ed0c8..1235d99 100644 --- a/docs/training-flow.mmd +++ b/docs/training-flow.mmd @@ -27,7 +27,7 @@ flowchart TD I --> I1{--resume?} I1 -- Yes --> I2[Load last.pt checkpoint] - I1 -- No --> I3[Load pretrained model\ne.g. yolo11n.pt] + I1 -- No --> I3[Load pretrained model\ne.g. yolo26s.pt] I2 --> J[Configure Training] I3 --> J diff --git a/frontend/src/api/endpoints/index.ts b/frontend/src/api/endpoints/index.ts index 1533939..65e8536 100644 --- a/frontend/src/api/endpoints/index.ts +++ b/frontend/src/api/endpoints/index.ts @@ -5,4 +5,5 @@ export { inferenceApi } from './inference' export { datasetsApi } from './datasets' export { augmentationApi } from './augmentation' export { modelsApi } from './models' +export { poolApi } from './pool' export { dashboardApi } from './dashboard' diff --git a/frontend/src/api/endpoints/pool.ts b/frontend/src/api/endpoints/pool.ts new file mode 100644 index 0000000..28eaaa7 --- /dev/null +++ b/frontend/src/api/endpoints/pool.ts @@ -0,0 +1,40 @@ +import apiClient from '../client' +import type { + PoolListResponse, + PoolStatsResponse, + PoolEntryResponse, +} from '../types' + +export const poolApi = { + addToPool: async (documentId: string, reason?: string): Promise => { + const { data } = await apiClient.post('/api/v1/admin/training/pool', { + document_id: documentId, + reason: reason ?? 'manual_addition', + }) + return data + }, + + listEntries: async (params?: { + verified_only?: boolean + limit?: number + offset?: number + }): Promise => { + const { data } = await apiClient.get('/api/v1/admin/training/pool', { params }) + return data + }, + + getStats: async (): Promise => { + const { data } = await apiClient.get('/api/v1/admin/training/pool/stats') + return data + }, + + verifyEntry: async (entryId: string): Promise => { + const { data } = await apiClient.post(`/api/v1/admin/training/pool/${entryId}/verify`) + return data + }, + + removeEntry: async (entryId: string): Promise<{ message: string }> => { + const { data } = await apiClient.delete(`/api/v1/admin/training/pool/${entryId}`) + return data + }, +} diff --git a/frontend/src/api/types.ts b/frontend/src/api/types.ts index 428f53d..021e1aa 100644 --- a/frontend/src/api/types.ts +++ b/frontend/src/api/types.ts @@ -111,6 +111,9 @@ export interface ModelVersionItem { is_active: boolean metrics_mAP: number | null document_count: number + model_type?: string + base_model_version_id?: string | null + gating_status?: string trained_at: string | null activated_at: string | null created_at: string @@ -370,19 +373,6 @@ export interface TrainingTaskResponse { // Model Version types -export interface ModelVersionItem { - version_id: string - version: string - name: string - status: string - is_active: boolean - metrics_mAP: number | null - document_count: number - trained_at: string | null - activated_at: string | null - created_at: string -} - export interface ModelVersionDetailResponse { version_id: string version: string @@ -397,6 +387,10 @@ export interface ModelVersionDetailResponse { metrics_precision: number | null metrics_recall: number | null document_count: number + model_type?: string + base_model_version_id?: string | null + base_training_dataset_id?: string | null + gating_status?: string training_config: Record | null file_size: number | null trained_at: string | null @@ -405,6 +399,39 @@ export interface ModelVersionDetailResponse { updated_at: string } +// Fine-Tune Pool types + +export interface PoolEntryItem { + entry_id: string + document_id: string + added_by: string | null + reason: string | null + is_verified: boolean + verified_at: string | null + verified_by: string | null + created_at: string +} + +export interface PoolListResponse { + total: number + limit: number + offset: number + entries: PoolEntryItem[] +} + +export interface PoolStatsResponse { + total_entries: number + verified_entries: number + unverified_entries: number + is_ready: boolean + min_required: number +} + +export interface PoolEntryResponse { + entry_id: string + message: string +} + export interface ModelVersionListResponse { total: number limit: number diff --git a/frontend/src/components/Models.tsx b/frontend/src/components/Models.tsx index bfe2222..03befe2 100644 --- a/frontend/src/components/Models.tsx +++ b/frontend/src/components/Models.tsx @@ -74,13 +74,33 @@ export const Models: React.FC = () => {

Trained {formatDate(model.trained_at)}

- - {model.is_active ? 'Active' : model.status} - +
+ {(model.model_type ?? 'base') === 'finetune' && ( + + Fine-tuned + + )} + {model.gating_status && model.gating_status !== 'skipped' && ( + + {model.gating_status === 'pass' ? 'PASS' + : model.gating_status === 'review' ? 'REVIEW' + : model.gating_status === 'reject' ? 'REJECT' + : model.gating_status.toUpperCase()} + + )} + + {model.is_active ? 'Active' : model.status} + +
diff --git a/frontend/src/components/Training.tsx b/frontend/src/components/Training.tsx index 480ff4f..aaea18b 100644 --- a/frontend/src/components/Training.tsx +++ b/frontend/src/components/Training.tsx @@ -1,15 +1,15 @@ import React, { useState, useMemo } from 'react' -import { useQuery } from '@tanstack/react-query' -import { Database, Plus, Trash2, Eye, Play, Check, Loader2, AlertCircle } from 'lucide-react' +import { useQuery, useMutation, useQueryClient } from '@tanstack/react-query' +import { Database, Plus, Trash2, Eye, Play, Check, Loader2, AlertCircle, Shield, CheckCircle, XCircle } from 'lucide-react' import { Button } from './Button' import { AugmentationConfig } from './AugmentationConfig' import { useDatasets } from '../hooks/useDatasets' import { useTrainingDocuments } from '../hooks/useTraining' -import { trainingApi } from '../api/endpoints' -import type { DatasetListItem } from '../api/types' +import { trainingApi, poolApi } from '../api/endpoints' +import type { DatasetListItem, PoolEntryItem } from '../api/types' import type { AugmentationConfig as AugmentationConfigType } from '../api/endpoints/augmentation' -type Tab = 'datasets' | 'create' +type Tab = 'datasets' | 'create' | 'pool' interface TrainingProps { onNavigate?: (view: string, id?: string) => void @@ -72,19 +72,23 @@ const TrainDialog: React.FC = ({ dataset, onClose, onSubmit, i const [augmentationConfig, setAugmentationConfig] = useState>({}) const [augmentationMultiplier, setAugmentationMultiplier] = useState(2) + const isFineTune = baseModelType === 'existing' + // Fetch available trained models (active or inactive, not archived) const { data: modelsData } = useQuery({ queryKey: ['training', 'models', 'available'], queryFn: () => trainingApi.getModels(), }) - // Filter out archived models - only show active/inactive models for base model selection - const availableModels = (modelsData?.models ?? []).filter(m => m.status !== 'archived') + // Only show base models (not fine-tuned) for selection - prevents chaining fine-tunes + const availableModels = (modelsData?.models ?? []).filter( + m => m.status !== 'archived' && (m.model_type ?? 'base') === 'base' + ) const handleSubmit = () => { onSubmit({ name, config: { - model_name: baseModelType === 'pretrained' ? 'yolo11n.pt' : undefined, + model_name: baseModelType === 'pretrained' ? 'yolo26s.pt' : undefined, base_model_version_id: baseModelType === 'existing' ? baseModelVersionId : null, epochs, batch_size: batchSize, @@ -121,14 +125,16 @@ const TrainDialog: React.FC = ({ dataset, onClose, onSubmit, i if (e.target.value === 'pretrained') { setBaseModelType('pretrained') setBaseModelVersionId(null) + setEpochs(100) } else { setBaseModelType('existing') setBaseModelVersionId(e.target.value) + setEpochs(10) // Fine-tune: fewer epochs per best practices } }} className="w-full h-10 px-3 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info" > - + {availableModels.map(m => (
+ {/* Fine-tune info panel */} + {isFineTune && ( +
+

Fine-Tune Mode

+
    +
  • Epochs: 10 (auto-set), Backbone frozen (10 layers)
  • +
  • Cosine LR scheduler, Pool data mixed with old data
  • +
  • Requires 50+ verified pool entries
  • +
  • Deployment gating runs automatically after training
  • +
+
+ )} +
@@ -455,6 +474,148 @@ const CreateDataset: React.FC<{ onSwitchTab: (tab: Tab) => void }> = ({ onSwitch ) } +// --- Fine-Tune Pool --- + +const FineTunePool: React.FC = () => { + const queryClient = useQueryClient() + + const { data: statsData, isLoading: isLoadingStats } = useQuery({ + queryKey: ['pool', 'stats'], + queryFn: () => poolApi.getStats(), + }) + + const { data: entriesData, isLoading: isLoadingEntries } = useQuery({ + queryKey: ['pool', 'entries'], + queryFn: () => poolApi.listEntries({ limit: 50 }), + }) + + const verifyMutation = useMutation({ + mutationFn: (entryId: string) => poolApi.verifyEntry(entryId), + onSuccess: () => { + queryClient.invalidateQueries({ queryKey: ['pool'] }) + }, + }) + + const removeMutation = useMutation({ + mutationFn: (entryId: string) => poolApi.removeEntry(entryId), + onSuccess: () => { + queryClient.invalidateQueries({ queryKey: ['pool'] }) + }, + }) + + const stats = statsData + const entries = entriesData?.entries ?? [] + + return ( +
+ {/* Pool Stats */} +
+ {isLoadingStats ? ( +
+ Loading stats... +
+ ) : ( + <> +
+

Total Entries

+

{stats?.total_entries ?? 0}

+
+
+

Verified

+

{stats?.verified_entries ?? 0}

+
+
+

Unverified

+

{stats?.unverified_entries ?? 0}

+
+
+

Ready for Fine-Tune

+
+ {stats?.is_ready ? ( + + ) : ( + + )} +

+ {stats?.is_ready ? 'Yes' : `Need ${(stats?.min_required ?? 50) - (stats?.verified_entries ?? 0)} more`} +

+
+
+ + )} +
+ + {/* Pool Entries Table */} + {isLoadingEntries ? ( +
+ Loading pool entries... +
+ ) : entries.length === 0 ? ( +
+ +

Fine-tune pool is empty

+

Add documents with extraction failures to the pool for future fine-tuning.

+
+ ) : ( +
+ + + + + + + + + + + + {entries.map((entry: PoolEntryItem) => ( + + + + + + + + ))} + +
Document IDReasonStatusAddedActions
{entry.document_id.slice(0, 8)}...{entry.reason ?? '-'} + + {entry.is_verified ? : } + {entry.is_verified ? 'Verified' : 'Unverified'} + + {new Date(entry.created_at).toLocaleDateString()} +
+ {!entry.is_verified && ( + + )} + +
+
+
+ )} +
+ ) +} + // --- Main Training Component --- export const Training: React.FC = ({ onNavigate }) => { @@ -468,7 +629,7 @@ export const Training: React.FC = ({ onNavigate }) => { {/* Tabs */}
- {([['datasets', 'Datasets'], ['create', 'Create Dataset']] as const).map(([key, label]) => ( + {([['datasets', 'Datasets'], ['create', 'Create Dataset'], ['pool', 'Fine-Tune Pool']] as const).map(([key, label]) => (
) } diff --git a/packages/backend/backend/data/admin_models.py b/packages/backend/backend/data/admin_models.py index 2639680..8c753d9 100644 --- a/packages/backend/backend/data/admin_models.py +++ b/packages/backend/backend/data/admin_models.py @@ -289,6 +289,16 @@ class ModelVersion(SQLModel, table=True): is_active: bool = Field(default=False, index=True) # Only one version can be active at a time for inference + # Model lineage + model_type: str = Field(default="base", max_length=20, index=True) + # "base" = trained from pretrained YOLO, "finetune" = fine-tuned from base model + base_model_version_id: UUID | None = Field(default=None, index=True) + # Points to the base model this was fine-tuned from (None for base models) + base_training_dataset_id: UUID | None = Field(default=None, index=True) + # The dataset used for original base training (for data mixing old samples) + gating_status: str = Field(default="pending", max_length=20, index=True) + # Deployment gating: pending, pass, review, reject, skipped + # Training association task_id: UUID | None = Field(default=None, foreign_key="training_tasks.task_id", index=True) dataset_id: UUID | None = Field(default=None, foreign_key="training_datasets.dataset_id", index=True) @@ -317,6 +327,64 @@ class ModelVersion(SQLModel, table=True): updated_at: datetime = Field(default_factory=datetime.utcnow) +# ============================================================================= +# Fine-Tune Pool +# ============================================================================= + + +class FineTunePoolEntry(SQLModel, table=True): + """Document in the fine-tune pool for incremental model improvement.""" + + __tablename__ = "finetune_pool_entries" + + entry_id: UUID = Field(default_factory=uuid4, primary_key=True) + document_id: UUID = Field(foreign_key="admin_documents.document_id", index=True) + added_by: str | None = Field(default=None, max_length=255) + reason: str | None = Field(default=None, max_length=255) + # Reason: user_reported_failure, manual_addition + is_verified: bool = Field(default=False, index=True) + verified_at: datetime | None = Field(default=None) + verified_by: str | None = Field(default=None, max_length=255) + created_at: datetime = Field(default_factory=datetime.utcnow) + + +# ============================================================================= +# Deployment Gating +# ============================================================================= + + +class GatingResult(SQLModel, table=True): + """Model deployment gating validation result.""" + + __tablename__ = "gating_results" + + result_id: UUID = Field(default_factory=uuid4, primary_key=True) + model_version_id: UUID = Field(foreign_key="model_versions.version_id", index=True) + task_id: UUID | None = Field(default=None, foreign_key="training_tasks.task_id") + + # Gate 1: Regression validation (original test set mAP) + gate1_status: str = Field(default="pending", max_length=20) + # pending, pass, review, reject + gate1_original_mAP: float | None = Field(default=None) + gate1_new_mAP: float | None = Field(default=None) + gate1_mAP_drop: float | None = Field(default=None) + + # Gate 2: New sample validation (detection rate on pool docs) + gate2_status: str = Field(default="pending", max_length=20) + gate2_detection_rate: float | None = Field(default=None) + gate2_total_samples: int | None = Field(default=None) + gate2_detected_samples: int | None = Field(default=None) + + # Overall + overall_status: str = Field(default="pending", max_length=20) + # pending, pass, review, reject + reviewer_notes: str | None = Field(default=None) + reviewed_by: str | None = Field(default=None, max_length=255) + reviewed_at: datetime | None = Field(default=None) + + created_at: datetime = Field(default_factory=datetime.utcnow) + + # ============================================================================= # Annotation History (v2) # ============================================================================= diff --git a/packages/backend/backend/data/repositories/__init__.py b/packages/backend/backend/data/repositories/__init__.py index c24fa8d..1bea934 100644 --- a/packages/backend/backend/data/repositories/__init__.py +++ b/packages/backend/backend/data/repositories/__init__.py @@ -13,6 +13,7 @@ from backend.data.repositories.training_task_repository import TrainingTaskRepos from backend.data.repositories.dataset_repository import DatasetRepository from backend.data.repositories.model_version_repository import ModelVersionRepository from backend.data.repositories.batch_upload_repository import BatchUploadRepository +from backend.data.repositories.finetune_pool_repository import FineTunePoolRepository __all__ = [ "BaseRepository", @@ -23,4 +24,5 @@ __all__ = [ "DatasetRepository", "ModelVersionRepository", "BatchUploadRepository", + "FineTunePoolRepository", ] diff --git a/packages/backend/backend/data/repositories/finetune_pool_repository.py b/packages/backend/backend/data/repositories/finetune_pool_repository.py new file mode 100644 index 0000000..7c70f99 --- /dev/null +++ b/packages/backend/backend/data/repositories/finetune_pool_repository.py @@ -0,0 +1,131 @@ +""" +Fine-Tune Pool Repository + +Manages the fine-tune pool: accumulated user-reported failure documents +for incremental model improvement. +""" + +import logging +from datetime import datetime +from uuid import UUID + +from sqlalchemy import func +from sqlmodel import select + +from backend.data.database import get_session_context +from backend.data.admin_models import FineTunePoolEntry +from backend.data.repositories.base import BaseRepository + +logger = logging.getLogger(__name__) + + +class FineTunePoolRepository(BaseRepository[FineTunePoolEntry]): + """Repository for fine-tune pool management.""" + + def add_document( + self, + document_id: str | UUID, + added_by: str | None = None, + reason: str | None = None, + ) -> FineTunePoolEntry: + """Add a document to the fine-tune pool.""" + with get_session_context() as session: + entry = FineTunePoolEntry( + document_id=UUID(str(document_id)), + added_by=added_by, + reason=reason, + ) + session.add(entry) + session.commit() + session.refresh(entry) + session.expunge(entry) + return entry + + def get_entry(self, entry_id: str | UUID) -> FineTunePoolEntry | None: + """Get a pool entry by ID.""" + with get_session_context() as session: + entry = session.get(FineTunePoolEntry, UUID(str(entry_id))) + if entry: + session.expunge(entry) + return entry + + def get_by_document(self, document_id: str | UUID) -> FineTunePoolEntry | None: + """Get pool entry for a document.""" + with get_session_context() as session: + result = session.exec( + select(FineTunePoolEntry).where( + FineTunePoolEntry.document_id == UUID(str(document_id)) + ) + ).first() + if result: + session.expunge(result) + return result + + def get_paginated( + self, + verified_only: bool = False, + limit: int = 20, + offset: int = 0, + ) -> tuple[list[FineTunePoolEntry], int]: + """List pool entries with pagination.""" + with get_session_context() as session: + query = select(FineTunePoolEntry) + count_query = select(func.count()).select_from(FineTunePoolEntry) + if verified_only: + query = query.where(FineTunePoolEntry.is_verified == True) + count_query = count_query.where(FineTunePoolEntry.is_verified == True) + total = session.exec(count_query).one() + entries = session.exec( + query.order_by(FineTunePoolEntry.created_at.desc()) + .offset(offset) + .limit(limit) + ).all() + for e in entries: + session.expunge(e) + return list(entries), total + + def get_pool_count(self, verified_only: bool = True) -> int: + """Get count of entries in the pool.""" + with get_session_context() as session: + query = select(func.count()).select_from(FineTunePoolEntry) + if verified_only: + query = query.where(FineTunePoolEntry.is_verified == True) + return session.exec(query).one() + + def get_all_document_ids(self, verified_only: bool = True) -> list[UUID]: + """Get all document IDs in the pool.""" + with get_session_context() as session: + query = select(FineTunePoolEntry.document_id) + if verified_only: + query = query.where(FineTunePoolEntry.is_verified == True) + results = session.exec(query).all() + return list(results) + + def verify_entry( + self, + entry_id: str | UUID, + verified_by: str | None = None, + ) -> FineTunePoolEntry | None: + """Mark a pool entry as verified.""" + with get_session_context() as session: + entry = session.get(FineTunePoolEntry, UUID(str(entry_id))) + if not entry: + return None + entry.is_verified = True + entry.verified_at = datetime.utcnow() + entry.verified_by = verified_by + session.add(entry) + session.commit() + session.refresh(entry) + session.expunge(entry) + return entry + + def remove_entry(self, entry_id: str | UUID) -> bool: + """Remove an entry from the pool.""" + with get_session_context() as session: + entry = session.get(FineTunePoolEntry, UUID(str(entry_id))) + if not entry: + return False + session.delete(entry) + session.commit() + return True diff --git a/packages/backend/backend/data/repositories/model_version_repository.py b/packages/backend/backend/data/repositories/model_version_repository.py index 674c345..d154be0 100644 --- a/packages/backend/backend/data/repositories/model_version_repository.py +++ b/packages/backend/backend/data/repositories/model_version_repository.py @@ -43,6 +43,10 @@ class ModelVersionRepository(BaseRepository[ModelVersion]): training_config: dict[str, Any] | None = None, file_size: int | None = None, trained_at: datetime | None = None, + model_type: str = "base", + base_model_version_id: str | UUID | None = None, + base_training_dataset_id: str | UUID | None = None, + gating_status: str = "pending", ) -> ModelVersion: """Create a new model version.""" with get_session_context() as session: @@ -60,6 +64,10 @@ class ModelVersionRepository(BaseRepository[ModelVersion]): training_config=training_config, file_size=file_size, trained_at=trained_at, + model_type=model_type, + base_model_version_id=UUID(str(base_model_version_id)) if base_model_version_id else None, + base_training_dataset_id=UUID(str(base_training_dataset_id)) if base_training_dataset_id else None, + gating_status=gating_status, ) session.add(model) session.commit() diff --git a/packages/backend/backend/pipeline/field_extractor.py b/packages/backend/backend/pipeline/field_extractor.py index a795ca5..7a7eb76 100644 --- a/packages/backend/backend/pipeline/field_extractor.py +++ b/packages/backend/backend/pipeline/field_extractor.py @@ -40,6 +40,7 @@ from .normalizers import ( EnhancedAmountNormalizer, EnhancedDateNormalizer, ) +from .value_selector import ValueSelector @dataclass @@ -169,13 +170,21 @@ class FieldExtractor: overlap_ratio = overlap_area / token_area if token_area > 0 else 0 matching_tokens.append((token, overlap_ratio)) - # Sort by overlap ratio and combine text + # Sort by overlap ratio matching_tokens.sort(key=lambda x: -x[1]) - raw_text = ' '.join(t[0].text for t in matching_tokens) # Get field name field_name = CLASS_TO_FIELD.get(detection.class_name, detection.class_name) + # Convert to OCRTokens for value selection, then filter + from shared.ocr.paddle_ocr import OCRToken + pdf_ocr_tokens = [ + OCRToken(text=t[0].text, bbox=t[0].bbox, confidence=1.0) + for t in matching_tokens + ] + value_tokens = ValueSelector.select_value_tokens(pdf_ocr_tokens, field_name) + raw_text = ' '.join(t.text for t in value_tokens) + # Normalize and validate normalized_value, is_valid, validation_error = self._normalize_and_validate( field_name, raw_text @@ -223,13 +232,14 @@ class FieldExtractor: # Run OCR on region ocr_tokens = self.ocr_engine.extract_from_image(region) - # Combine all OCR text - raw_text = ' '.join(t.text for t in ocr_tokens) - ocr_confidence = sum(t.confidence for t in ocr_tokens) / len(ocr_tokens) if ocr_tokens else 0.0 - # Get field name field_name = CLASS_TO_FIELD.get(detection.class_name, detection.class_name) + # Select value tokens (filter out label text) + value_tokens = ValueSelector.select_value_tokens(ocr_tokens, field_name) + raw_text = ' '.join(t.text for t in value_tokens) + ocr_confidence = sum(t.confidence for t in ocr_tokens) / len(ocr_tokens) if ocr_tokens else 0.0 + # Normalize and validate normalized_value, is_valid, validation_error = self._normalize_and_validate( field_name, raw_text diff --git a/packages/backend/backend/pipeline/normalizers/amount.py b/packages/backend/backend/pipeline/normalizers/amount.py index 17b71ba..f70924a 100644 --- a/packages/backend/backend/pipeline/normalizers/amount.py +++ b/packages/backend/backend/pipeline/normalizers/amount.py @@ -20,26 +20,98 @@ class AmountNormalizer(BaseNormalizer): Handles various Swedish amount formats: - With decimal: 1 234,56 kr - With SEK suffix: 1234.56 SEK + - Payment line kronor/ore: 590 00 (space = decimal separator) - Multiple amounts (returns the last one, usually the total) """ + # Payment line kronor/ore pattern: "590 00" means 590.00 SEK + # Only matches when no comma/dot is present (pure digit-space-2digit format) + _KRONOR_ORE_PATTERN = re.compile(r'^(\d+)\s+(\d{2})$') + @property def field_name(self) -> str: return "Amount" + @classmethod + def _try_kronor_ore(cls, text: str) -> NormalizationResult | None: + """Try to parse as payment line kronor/ore format. + + Swedish payment lines separate kronor and ore with a space: + "590 00" = 590.00 SEK, "15658 00" = 15658.00 SEK + + Only applies when text has no comma or dot (otherwise it's + a normal amount format with explicit decimal separator). + + Returns NormalizationResult on success, None if not matched. + """ + if ',' in text or '.' in text: + return None + + match = cls._KRONOR_ORE_PATTERN.match(text.strip()) + if not match: + return None + + kronor = match.group(1) + ore = match.group(2) + try: + amount = float(f"{kronor}.{ore}") + if amount > 0: + return NormalizationResult.success(f"{amount:.2f}") + except ValueError: + pass + return None + + @staticmethod + def _parse_amount_str(match: str) -> float | None: + """Convert matched amount string to float, detecting European vs Anglo format. + + European: 2.254,50 -> 2254.50 (dot=thousand, comma=decimal) + Anglo: 1,234.56 -> 1234.56 (comma=thousand, dot=decimal) + Swedish: 1 234,56 -> 1234.56 (space=thousand, comma=decimal) + """ + has_comma = ',' in match + has_dot = '.' in match + if has_comma and has_dot: + if match.rfind(',') > match.rfind('.'): + # European: 2.254,50 + cleaned = match.replace(" ", "").replace(".", "").replace(",", ".") + else: + # Anglo: 1,234.56 + cleaned = match.replace(" ", "").replace(",", "") + elif has_comma: + cleaned = match.replace(" ", "").replace(",", ".") + else: + cleaned = match.replace(" ", "") + try: + return float(cleaned) + except ValueError: + return None + def normalize(self, text: str) -> NormalizationResult: text = text.strip() if not text: return NormalizationResult.failure("Empty text") + # Early check: payment line kronor/ore format ("590 00" → 590.00) + kronor_ore_result = self._try_kronor_ore(text) + if kronor_ore_result is not None: + return kronor_ore_result + # Split by newlines and process line by line to get the last valid amount lines = text.split("\n") # Collect all valid amounts from all lines all_amounts: list[float] = [] - # Pattern for Swedish amount format (with decimals) - amount_pattern = r"(\d[\d\s]*[,\.]\d{2})\s*(?:kr|SEK)?" + # Separate patterns for European and Anglo formats + # (?!\d) lookahead prevents partial matches (e.g. "1,23" in "1,234.56") + # European: dot=thousand, comma=decimal (2.254,50 or 1 234,56) + # Anglo: comma=thousand, dot=decimal (1,234.56 or 1234.56) + amount_pattern = ( + r"(\d[\d\s.]*,\d{2})(?!\d)\s*(?:kr|SEK)?" + r"|" + r"(\d[\d\s,]*\.\d{2})(?!\d)\s*(?:kr|SEK)?" + ) for line in lines: line = line.strip() @@ -47,15 +119,13 @@ class AmountNormalizer(BaseNormalizer): continue # Find all amounts in this line - matches = re.findall(amount_pattern, line, re.IGNORECASE) - for match in matches: - amount_str = match.replace(" ", "").replace(",", ".") - try: - amount = float(amount_str) - if amount > 0: - all_amounts.append(amount) - except ValueError: + for m in re.finditer(amount_pattern, line, re.IGNORECASE): + match = m.group(1) or m.group(2) + if not match: continue + amount = self._parse_amount_str(match) + if amount is not None and amount > 0: + all_amounts.append(amount) # Return the last amount found (usually the total) if all_amounts: @@ -122,31 +192,33 @@ class EnhancedAmountNormalizer(AmountNormalizer): if not text: return NormalizationResult.failure("Empty text") + # Early check: payment line kronor/ore format ("590 00" → 590.00) + kronor_ore_result = self._try_kronor_ore(text) + if kronor_ore_result is not None: + return kronor_ore_result + # Strategy 1: Apply OCR corrections first corrected_text = OCRCorrections.correct_digits(text, aggressive=False).corrected # Strategy 2: Look for labeled amounts (highest priority) + # Use two capture groups: group(1) = European, group(2) = Anglo labeled_patterns = [ - # Swedish patterns - (r"(?:att\s+betala|summa|total|belopp)\s*[:\s]*(\d[\d\s]*[,\.]\d{2})", 1.0), + # Swedish patterns ((?!\d) prevents partial matches like "1,23" in "1,234.56") + (r"(?:att\s+betala|summa|total|belopp)\s*[:\s]*(\d[\d\s.]*,\d{2}(?!\d)|\d[\d\s,]*\.\d{2}(?!\d))", 1.0), ( - r"(?:moms|vat)\s*[:\s]*(\d[\d\s]*[,\.]\d{2})", + r"(?:moms|vat)\s*[:\s]*(\d[\d\s.]*,\d{2}(?!\d)|\d[\d\s,]*\.\d{2}(?!\d))", 0.8, ), # Lower priority for VAT # Generic pattern - (r"(\d[\d\s]*[,\.]\d{2})\s*(?:kr|sek|kronor)?", 0.7), + (r"(\d[\d\s.]*,\d{2}(?!\d)|\d[\d\s,]*\.\d{2}(?!\d))\s*(?:kr|sek|kronor)?", 0.7), ] candidates: list[tuple[float, float, int]] = [] for pattern, priority in labeled_patterns: for match in re.finditer(pattern, corrected_text, re.IGNORECASE): - amount_str = match.group(1).replace(" ", "").replace(",", ".") - try: - amount = float(amount_str) - if 0 < amount < 10_000_000: # Reasonable range - candidates.append((amount, priority, match.start())) - except ValueError: - continue + amount = self._parse_amount_str(match.group(1)) + if amount is not None and 0 < amount < 10_000_000: + candidates.append((amount, priority, match.start())) if candidates: # Sort by priority (desc), then by position (later is usually total) diff --git a/packages/backend/backend/pipeline/pipeline.py b/packages/backend/backend/pipeline/pipeline.py index fca5a91..e32ce21 100644 --- a/packages/backend/backend/pipeline/pipeline.py +++ b/packages/backend/backend/pipeline/pipeline.py @@ -301,6 +301,27 @@ class InferencePipeline: all_extracted = [] all_ocr_text = [] # Collect OCR text for VAT extraction + # Check if PDF has readable text layer (avoids OCR for text PDFs) + from shared.pdf.extractor import PDFDocument + is_text_pdf = False + pdf_tokens_by_page: dict[int, list] = {} + try: + with PDFDocument(pdf_path) as pdf_doc: + is_text_pdf = pdf_doc.is_text_pdf() + if is_text_pdf: + for pg in range(pdf_doc.page_count): + pdf_tokens_by_page[pg] = list( + pdf_doc.extract_text_tokens(pg) + ) + logger.info( + "Text PDF detected, extracted %d tokens from %d pages", + sum(len(t) for t in pdf_tokens_by_page.values()), + len(pdf_tokens_by_page), + ) + except Exception as e: + logger.warning("PDF text detection failed, falling back to OCR: %s", e) + is_text_pdf = False + # Process each page for page_no, image_bytes in render_pdf_to_images(pdf_path, dpi=self.dpi): # Convert to numpy array @@ -313,7 +334,17 @@ class InferencePipeline: # Extract fields from detections for detection in detections: - extracted = self.extractor.extract_from_detection(detection, image_array) + if is_text_pdf and page_no in pdf_tokens_by_page: + extracted = self.extractor.extract_from_detection_with_pdf( + detection, + pdf_tokens_by_page[page_no], + image.width, + image.height, + ) + else: + extracted = self.extractor.extract_from_detection( + detection, image_array + ) all_extracted.append(extracted) # Collect full-page OCR text for VAT extraction (only if business features enabled) diff --git a/packages/backend/backend/pipeline/value_selector.py b/packages/backend/backend/pipeline/value_selector.py new file mode 100644 index 0000000..ab69a12 --- /dev/null +++ b/packages/backend/backend/pipeline/value_selector.py @@ -0,0 +1,172 @@ +""" +Value Selector Module. + +Selects the most likely value token(s) from OCR output per field type, +filtering out label text before sending to normalizer. + +Stateless and pure -- easy to test, no side effects. +""" + +import re +from typing import Final + +from shared.ocr.paddle_ocr import OCRToken + + +# Swedish label keywords commonly found near field values +LABEL_KEYWORDS: Final[frozenset[str]] = frozenset({ + "fakturanummer", "fakturanr", "fakturadatum", "forfallodag", "forfalldatum", + "bankgiro", "plusgiro", "bg", "pg", "ocr", "belopp", "summa", + "total", "att", "betala", "kundnummer", "organisationsnummer", + "org", "nr", "datum", "nummer", "ref", "referens", + "momsreg", "vat", "moms", "sek", "kr", + "org.nr", "kund", "faktura", "invoice", +}) + +# Patterns +_DATE_PATTERN = re.compile( + r"\d{4}[-./]\d{2}[-./]\d{2}" # 2024-01-15, 2024.01.15 + r"|" + r"\d{2}[-./]\d{2}[-./]\d{4}" # 15/01/2024 + r"|" + r"\d{8}" # 20240115 +) + +_AMOUNT_PATTERN = re.compile( + r"\d[\d\s.]*,\d{2}(?:\s*(?:kr|SEK))?$" # European: 2.254,50 SEK + r"|" + r"\d[\d\s,]*\.\d{2}(?:\s*(?:kr|SEK))?$" # Anglo: 1,234.56 SEK +) + +_BANKGIRO_PATTERN = re.compile( + r"^\d{3,4}-\d{4}$" # 123-4567 + r"|" + r"^\d{7,8}$" # 1234567 or 12345678 +) + +_PLUSGIRO_PATTERN = re.compile( + r"^\d+-\d$" # 12345-6 + r"|" + r"^\d{2,8}$" # 1234567 +) + +_ORG_NUMBER_PATTERN = re.compile( + r"\d{6}-?\d{4}" # 556123-4567 or 5561234567 +) + + +def _is_label(text: str) -> bool: + """Check if token text is a known Swedish label keyword.""" + cleaned = text.lower().rstrip(":").strip() + return cleaned in LABEL_KEYWORDS + + +def _count_digits(text: str) -> int: + """Count digit characters in text.""" + return sum(c.isdigit() for c in text) + + +class ValueSelector: + """Selects value token(s) from OCR output, filtering label text. + + Pure static methods -- no state, no side effects. + Fallback: always returns all tokens if no pattern matches, + so this can never make results worse than current behavior. + """ + + @staticmethod + def select_value_tokens( + tokens: list[OCRToken], + field_name: str, + ) -> list[OCRToken]: + """Select the most likely value token(s) for a given field. + + Args: + tokens: OCR tokens from the detected region. + field_name: Normalized field name (e.g. "InvoiceDate", "Amount"). + + Returns: + Filtered list of value tokens. Falls back to all tokens + if no field-specific pattern matches. + """ + if not tokens: + return [] + + selector = _FIELD_SELECTORS.get(field_name, _fallback_selector) + selected = selector(tokens) + + # Safety: never return empty if we had input tokens + if not selected: + return list(tokens) + + return selected + + @staticmethod + def _select_date(tokens: list[OCRToken]) -> list[OCRToken]: + return _select_by_pattern(tokens, _DATE_PATTERN) + + @staticmethod + def _select_amount(tokens: list[OCRToken]) -> list[OCRToken]: + return _select_by_pattern(tokens, _AMOUNT_PATTERN) + + @staticmethod + def _select_bankgiro(tokens: list[OCRToken]) -> list[OCRToken]: + return _select_by_pattern(tokens, _BANKGIRO_PATTERN) + + @staticmethod + def _select_plusgiro(tokens: list[OCRToken]) -> list[OCRToken]: + return _select_by_pattern(tokens, _PLUSGIRO_PATTERN) + + @staticmethod + def _select_org_number(tokens: list[OCRToken]) -> list[OCRToken]: + return _select_by_pattern(tokens, _ORG_NUMBER_PATTERN) + + @staticmethod + def _select_ocr_number(tokens: list[OCRToken]) -> list[OCRToken]: + """Select token with the longest digit sequence (min 5 digits).""" + best: OCRToken | None = None + best_count = 0 + for token in tokens: + digit_count = _count_digits(token.text) + if digit_count >= 5 and digit_count > best_count: + best = token + best_count = digit_count + return [best] if best else [] + + @staticmethod + def _select_by_label_exclusion(tokens: list[OCRToken]) -> list[OCRToken]: + """Remove label keywords, keep remaining tokens.""" + return [t for t in tokens if not _is_label(t.text)] + + @staticmethod + def _select_payment_line(tokens: list[OCRToken]) -> list[OCRToken]: + """Payment line keeps all tokens (needs full text).""" + return list(tokens) + + +def _select_by_pattern( + tokens: list[OCRToken], + pattern: re.Pattern[str], +) -> list[OCRToken]: + """Select tokens matching a regex pattern.""" + return [t for t in tokens if pattern.search(t.text.strip())] + + +def _fallback_selector(tokens: list[OCRToken]) -> list[OCRToken]: + """Default: return all tokens unchanged.""" + return list(tokens) + + +# Map field names to selector functions +_FIELD_SELECTORS: Final[dict[str, callable]] = { + "InvoiceDate": ValueSelector._select_date, + "InvoiceDueDate": ValueSelector._select_date, + "Amount": ValueSelector._select_amount, + "Bankgiro": ValueSelector._select_bankgiro, + "Plusgiro": ValueSelector._select_plusgiro, + "OCR": ValueSelector._select_ocr_number, + "InvoiceNumber": ValueSelector._select_by_label_exclusion, + "supplier_org_number": ValueSelector._select_org_number, + "customer_number": ValueSelector._select_by_label_exclusion, + "payment_line": ValueSelector._select_payment_line, +} diff --git a/packages/backend/backend/web/api/v1/admin/training/__init__.py b/packages/backend/backend/web/api/v1/admin/training/__init__.py index cde7547..bbc2c6d 100644 --- a/packages/backend/backend/web/api/v1/admin/training/__init__.py +++ b/packages/backend/backend/web/api/v1/admin/training/__init__.py @@ -12,6 +12,7 @@ from .documents import register_document_routes from .export import register_export_routes from .datasets import register_dataset_routes from .models import register_model_routes +from .pool import register_pool_routes def create_training_router() -> APIRouter: @@ -23,6 +24,7 @@ def create_training_router() -> APIRouter: register_export_routes(router) register_dataset_routes(router) register_model_routes(router) + register_pool_routes(router) return router diff --git a/packages/backend/backend/web/api/v1/admin/training/datasets.py b/packages/backend/backend/web/api/v1/admin/training/datasets.py index 8234a21..38cdb2f 100644 --- a/packages/backend/backend/web/api/v1/admin/training/datasets.py +++ b/packages/backend/backend/web/api/v1/admin/training/datasets.py @@ -12,6 +12,7 @@ from backend.web.core.auth import ( AnnotationRepoDep, ModelVersionRepoDep, TrainingTaskRepoDep, + FineTunePoolRepoDep, ) from backend.web.schemas.admin import ( DatasetCreateRequest, @@ -233,6 +234,7 @@ def register_dataset_routes(router: APIRouter) -> None: datasets_repo: DatasetRepoDep, models: ModelVersionRepoDep, tasks: TrainingTaskRepoDep, + pool: FineTunePoolRepoDep, ) -> TrainingTaskResponse: """Create a training task from a dataset. @@ -261,13 +263,39 @@ def register_dataset_routes(router: APIRouter) -> None: status_code=404, detail=f"Base model version not found: {base_model_version_id}", ) + # Chain prevention: never fine-tune from a fine-tuned model + if getattr(base_model, "model_type", "base") == "finetune": + original_base_id = getattr(base_model, "base_model_version_id", None) + raise HTTPException( + status_code=400, + detail=( + f"Cannot chain fine-tunes. Model {base_model.version} is already " + f"a fine-tuned model. Select the original base model instead" + f"{f' (base_model_version_id: {original_base_id})' if original_base_id else ''}." + ), + ) + # Pool threshold: require minimum verified pool entries for fine-tuning + from backend.web.services.data_mixer import MIN_POOL_SIZE + + verified_count = pool.get_pool_count(verified_only=True) + if verified_count < MIN_POOL_SIZE: + raise HTTPException( + status_code=400, + detail=( + f"Fine-tuning requires at least {MIN_POOL_SIZE} verified pool entries " + f"(currently {verified_count}). Add more documents to the fine-tune " + f"pool and verify them before starting fine-tuning." + ), + ) + # Store the resolved model path for the training worker config_dict["base_model_path"] = base_model.model_path config_dict["base_model_version"] = base_model.version logger.info( - "Incremental training: using model %s (%s) as base", + "Fine-tuning: using base model %s (%s) with %d verified pool entries", base_model.version, base_model.model_path, + verified_count, ) task_id = tasks.create( diff --git a/packages/backend/backend/web/api/v1/admin/training/export.py b/packages/backend/backend/web/api/v1/admin/training/export.py index 12d8b48..d59f4a8 100644 --- a/packages/backend/backend/web/api/v1/admin/training/export.py +++ b/packages/backend/backend/web/api/v1/admin/training/export.py @@ -124,16 +124,11 @@ def register_export_routes(router: APIRouter) -> None: x1 = ann.x_center * img_width + half_w y1 = ann.y_center * img_height + half_h - # Use manual_mode for manual/imported annotations - manual_mode = ann.source in ("manual", "imported") - - # Apply field-specific bbox expansion + # Apply uniform bbox expansion ex0, ey0, ex1, ey1 = expand_bbox( bbox=(x0, y0, x1, y1), image_width=img_width, image_height=img_height, - field_type=ann.class_name, - manual_mode=manual_mode, ) # Convert back to normalized YOLO format diff --git a/packages/backend/backend/web/api/v1/admin/training/models.py b/packages/backend/backend/web/api/v1/admin/training/models.py index 5024dce..c702ced 100644 --- a/packages/backend/backend/web/api/v1/admin/training/models.py +++ b/packages/backend/backend/web/api/v1/admin/training/models.py @@ -88,6 +88,9 @@ def register_model_routes(router: APIRouter) -> None: name=m.name, status=m.status, is_active=m.is_active, + model_type=getattr(m, "model_type", "base"), + base_model_version_id=str(m.base_model_version_id) if getattr(m, "base_model_version_id", None) else None, + gating_status=getattr(m, "gating_status", "pending"), metrics_mAP=m.metrics_mAP, document_count=m.document_count, trained_at=m.trained_at, @@ -121,6 +124,9 @@ def register_model_routes(router: APIRouter) -> None: name=model.name, status=model.status, is_active=model.is_active, + model_type=getattr(model, "model_type", "base"), + base_model_version_id=str(model.base_model_version_id) if getattr(model, "base_model_version_id", None) else None, + gating_status=getattr(model, "gating_status", "pending"), metrics_mAP=model.metrics_mAP, document_count=model.document_count, trained_at=model.trained_at, @@ -153,6 +159,10 @@ def register_model_routes(router: APIRouter) -> None: model_path=model.model_path, status=model.status, is_active=model.is_active, + model_type=getattr(model, "model_type", "base"), + base_model_version_id=str(model.base_model_version_id) if getattr(model, "base_model_version_id", None) else None, + base_training_dataset_id=str(model.base_training_dataset_id) if getattr(model, "base_training_dataset_id", None) else None, + gating_status=getattr(model, "gating_status", "pending"), task_id=str(model.task_id) if model.task_id else None, dataset_id=str(model.dataset_id) if model.dataset_id else None, metrics_mAP=model.metrics_mAP, @@ -209,6 +219,25 @@ def register_model_routes(router: APIRouter) -> None: ) -> ModelVersionResponse: """Activate a model version for inference.""" _validate_uuid(version_id, "version_id") + + # Check gating status before activation (for fine-tuned models) + pre_check = models.get(version_id) + if not pre_check: + raise HTTPException(status_code=404, detail="Model version not found") + model_type = getattr(pre_check, "model_type", "base") + gating_status = getattr(pre_check, "gating_status", "skipped") + if model_type == "finetune": + if gating_status == "reject": + raise HTTPException( + status_code=400, + detail="Model failed deployment gating validation. Cannot activate a rejected model.", + ) + if gating_status == "pending": + raise HTTPException( + status_code=400, + detail="Model gating validation not yet completed. Wait for validation to finish.", + ) + model = models.activate(version_id) if not model: raise HTTPException(status_code=404, detail="Model version not found") @@ -227,6 +256,8 @@ def register_model_routes(router: APIRouter) -> None: message = "Model version activated for inference" if model_reloaded: message += " (model reloaded)" + if gating_status == "review": + message += " (WARNING: gating status is REVIEW - manual inspection recommended)" return ModelVersionResponse( version_id=str(model.version_id), diff --git a/packages/backend/backend/web/api/v1/admin/training/pool.py b/packages/backend/backend/web/api/v1/admin/training/pool.py new file mode 100644 index 0000000..0c29706 --- /dev/null +++ b/packages/backend/backend/web/api/v1/admin/training/pool.py @@ -0,0 +1,159 @@ +"""Fine-Tune Pool Endpoints.""" + +import logging +from typing import Annotated + +from fastapi import APIRouter, HTTPException, Query + +from backend.web.core.auth import AdminTokenDep, FineTunePoolRepoDep, DocumentRepoDep +from backend.web.schemas.admin.pool import ( + PoolAddRequest, + PoolEntryItem, + PoolEntryResponse, + PoolListResponse, + PoolStatsResponse, +) + +from ._utils import _validate_uuid + +logger = logging.getLogger(__name__) + + +def register_pool_routes(router: APIRouter) -> None: + """Register fine-tune pool endpoints on the router.""" + + @router.post( + "/pool", + response_model=PoolEntryResponse, + summary="Add document to fine-tune pool", + description="Add a labeled document to the fine-tune pool for future fine-tuning.", + ) + async def add_to_pool( + request: PoolAddRequest, + admin_token: AdminTokenDep, + pool: FineTunePoolRepoDep, + docs: DocumentRepoDep, + ) -> PoolEntryResponse: + """Add a document to the fine-tune pool.""" + _validate_uuid(request.document_id, "document_id") + + # Verify document exists + doc = docs.get(request.document_id) + if not doc: + raise HTTPException(status_code=404, detail="Document not found") + + # Check if already in pool + existing = pool.get_by_document(request.document_id) + if existing: + raise HTTPException( + status_code=409, + detail=f"Document already in fine-tune pool (entry_id: {existing.entry_id})", + ) + + entry = pool.add_document( + document_id=request.document_id, + added_by=admin_token, + reason=request.reason, + ) + + return PoolEntryResponse( + entry_id=str(entry.entry_id), + message="Document added to fine-tune pool", + ) + + @router.get( + "/pool", + response_model=PoolListResponse, + summary="List fine-tune pool entries", + ) + async def list_pool_entries( + admin_token: AdminTokenDep, + pool: FineTunePoolRepoDep, + verified_only: Annotated[bool, Query(description="Filter to verified only")] = False, + limit: Annotated[int, Query(ge=1, le=100)] = 20, + offset: Annotated[int, Query(ge=0)] = 0, + ) -> PoolListResponse: + """List entries in the fine-tune pool.""" + entries, total = pool.get_paginated( + verified_only=verified_only, + limit=limit, + offset=offset, + ) + + return PoolListResponse( + total=total, + limit=limit, + offset=offset, + entries=[ + PoolEntryItem( + entry_id=str(e.entry_id), + document_id=str(e.document_id), + added_by=e.added_by, + reason=e.reason, + is_verified=e.is_verified, + verified_at=e.verified_at, + verified_by=e.verified_by, + created_at=e.created_at, + ) + for e in entries + ], + ) + + @router.get( + "/pool/stats", + response_model=PoolStatsResponse, + summary="Get fine-tune pool statistics", + ) + async def get_pool_stats( + admin_token: AdminTokenDep, + pool: FineTunePoolRepoDep, + ) -> PoolStatsResponse: + """Get statistics about the fine-tune pool.""" + total = pool.get_pool_count(verified_only=False) + verified = pool.get_pool_count(verified_only=True) + + return PoolStatsResponse( + total_entries=total, + verified_entries=verified, + unverified_entries=total - verified, + is_ready=verified >= 50, + ) + + @router.post( + "/pool/{entry_id}/verify", + response_model=PoolEntryResponse, + summary="Verify a pool entry", + description="Mark a pool entry as verified (human-reviewed).", + ) + async def verify_pool_entry( + entry_id: str, + admin_token: AdminTokenDep, + pool: FineTunePoolRepoDep, + ) -> PoolEntryResponse: + """Mark a pool entry as verified.""" + _validate_uuid(entry_id, "entry_id") + entry = pool.verify_entry(entry_id, verified_by=admin_token) + if not entry: + raise HTTPException(status_code=404, detail="Pool entry not found") + + return PoolEntryResponse( + entry_id=str(entry.entry_id), + message="Pool entry verified", + ) + + @router.delete( + "/pool/{entry_id}", + summary="Remove from fine-tune pool", + ) + async def remove_from_pool( + entry_id: str, + admin_token: AdminTokenDep, + pool: FineTunePoolRepoDep, + ) -> dict: + """Remove a document from the fine-tune pool.""" + _validate_uuid(entry_id, "entry_id") + success = pool.remove_entry(entry_id) + if not success: + raise HTTPException(status_code=404, detail="Pool entry not found") + + return {"message": "Entry removed from fine-tune pool"} diff --git a/packages/backend/backend/web/core/auth.py b/packages/backend/backend/web/core/auth.py index 8b24221..1353977 100644 --- a/packages/backend/backend/web/core/auth.py +++ b/packages/backend/backend/web/core/auth.py @@ -17,6 +17,7 @@ from backend.data.repositories import ( TrainingTaskRepository, ModelVersionRepository, BatchUploadRepository, + FineTunePoolRepository, ) @@ -95,6 +96,12 @@ def get_batch_upload_repository() -> BatchUploadRepository: return BatchUploadRepository() +@lru_cache(maxsize=1) +def get_finetune_pool_repository() -> FineTunePoolRepository: + """Get the FineTunePoolRepository instance (thread-safe singleton).""" + return FineTunePoolRepository() + + def reset_all_repositories() -> None: """Reset all repository instances (for testing).""" get_token_repository.cache_clear() @@ -104,6 +111,7 @@ def reset_all_repositories() -> None: get_training_task_repository.cache_clear() get_model_version_repository.cache_clear() get_batch_upload_repository.cache_clear() + get_finetune_pool_repository.cache_clear() # Repository dependency type aliases @@ -113,3 +121,4 @@ DatasetRepoDep = Annotated[DatasetRepository, Depends(get_dataset_repository)] TrainingTaskRepoDep = Annotated[TrainingTaskRepository, Depends(get_training_task_repository)] ModelVersionRepoDep = Annotated[ModelVersionRepository, Depends(get_model_version_repository)] BatchUploadRepoDep = Annotated[BatchUploadRepository, Depends(get_batch_upload_repository)] +FineTunePoolRepoDep = Annotated[FineTunePoolRepository, Depends(get_finetune_pool_repository)] diff --git a/packages/backend/backend/web/core/scheduler.py b/packages/backend/backend/web/core/scheduler.py index 847e6f0..b4677c1 100644 --- a/packages/backend/backend/web/core/scheduler.py +++ b/packages/backend/backend/web/core/scheduler.py @@ -16,6 +16,7 @@ from backend.data.repositories import ( ModelVersionRepository, DocumentRepository, AnnotationRepository, + FineTunePoolRepository, ) from backend.web.core.task_interface import TaskRunner, TaskStatus from backend.web.services.storage_helpers import get_storage_helper @@ -47,6 +48,7 @@ class TrainingScheduler(TaskRunner): self._model_versions = ModelVersionRepository() self._documents = DocumentRepository() self._annotations = AnnotationRepository() + self._pool = FineTunePoolRepository() @property def name(self) -> str: @@ -168,7 +170,7 @@ class TrainingScheduler(TaskRunner): try: # Get training configuration - model_name = config.get("model_name", "yolo11n.pt") + model_name = config.get("model_name", "yolo26s.pt") base_model_path = config.get("base_model_path") # For incremental training epochs = config.get("epochs", 100) batch_size = config.get("batch_size", 16) @@ -182,14 +184,19 @@ class TrainingScheduler(TaskRunner): augmentation_multiplier = config.get("augmentation_multiplier", 2) # Determine which model to use as base - if base_model_path: - # Incremental training: use existing trained model + is_finetune = bool(base_model_path) + if is_finetune: + # Fine-tuning: use existing trained model as base if not Path(base_model_path).exists(): raise ValueError(f"Base model not found: {base_model_path}") effective_model = base_model_path + # Override parameters for fine-tuning (best practices) + epochs = config.get("epochs", 10) + learning_rate = config.get("learning_rate", 0.001) self._training_tasks.add_log( task_id, "INFO", - f"Incremental training from: {base_model_path}", + f"Fine-tuning from: {base_model_path} " + f"(epochs={epochs}, freeze=10, cos_lr=true)", ) else: # Train from pretrained model @@ -229,10 +236,16 @@ class TrainingScheduler(TaskRunner): f"(total: {aug_result['total_images']})", ) + # Build mixed dataset for fine-tuning (pool samples + old data) + if is_finetune and dataset_id: + data_yaml, dataset_path = self._build_mixed_finetune_dataset( + task_id, dataset_path, data_yaml, + ) + # Run YOLO training result = self._run_yolo_training( task_id=task_id, - model_name=effective_model, # Use base model or pretrained model + model_name=effective_model, data_yaml=data_yaml, epochs=epochs, batch_size=batch_size, @@ -240,6 +253,8 @@ class TrainingScheduler(TaskRunner): learning_rate=learning_rate, device=device, project_name=project_name, + freeze=10 if is_finetune else 0, + cos_lr=is_finetune, ) # Update task with results @@ -261,13 +276,23 @@ class TrainingScheduler(TaskRunner): ) # Auto-create model version for the completed training - self._create_model_version_from_training( + model_version = self._create_model_version_from_training( task_id=task_id, config=config, dataset_id=dataset_id, result=result, ) + # Auto-run gating validation for fine-tuned models + if is_finetune and model_version: + self._run_gating_after_finetune( + task_id=task_id, + model_version=model_version, + config=config, + data_yaml=data_yaml, + result=result, + ) + except Exception as e: logger.error(f"Training task {task_id} failed: {e}") self._training_tasks.add_log(task_id, "ERROR", f"Training failed: {e}") @@ -286,13 +311,16 @@ class TrainingScheduler(TaskRunner): config: dict[str, Any], dataset_id: str | None, result: dict[str, Any], - ) -> None: - """Create a model version entry from completed training.""" + ) -> Any | None: + """Create a model version entry from completed training. + + Returns the created model version, or None on failure. + """ try: model_path = result.get("model_path") if not model_path: logger.warning(f"No model path in training result for task {task_id}") - return + return None # Get task info for name task = self._training_tasks.get(task_id) @@ -322,6 +350,12 @@ class TrainingScheduler(TaskRunner): if dataset: document_count = dataset.total_documents + # Determine model lineage + is_finetune = bool(config.get("base_model_path")) + model_type = "finetune" if is_finetune else "base" + base_model_version_id = config.get("base_model_version_id") if is_finetune else None + gating_status = "pending" if is_finetune else "skipped" + # Create model version model_version = self._model_versions.create( version=version, @@ -337,6 +371,10 @@ class TrainingScheduler(TaskRunner): training_config=config, file_size=file_size, trained_at=datetime.utcnow(), + model_type=model_type, + base_model_version_id=base_model_version_id, + base_training_dataset_id=dataset_id if not is_finetune else None, + gating_status=gating_status, ) logger.info( @@ -349,12 +387,105 @@ class TrainingScheduler(TaskRunner): f"Model version {version} created (mAP: {mAP_display})", ) + return model_version + except Exception as e: logger.error(f"Failed to create model version for task {task_id}: {e}") self._training_tasks.add_log( task_id, "WARNING", f"Failed to auto-create model version: {e}", ) + return None + + def _build_mixed_finetune_dataset( + self, + task_id: str, + base_dataset_path: Path, + original_data_yaml: str, + ) -> tuple[str, Path]: + """Build a mixed dataset for fine-tuning. + + Combines verified pool samples with randomly sampled old training data. + + Returns: + Tuple of (data_yaml path, dataset_path) for the mixed dataset. + Falls back to original if mixing fails or pool is empty. + """ + try: + from backend.web.services.data_mixer import build_mixed_dataset + + pool_doc_ids = self._pool.get_all_document_ids(verified_only=True) + if not pool_doc_ids: + self._training_tasks.add_log( + task_id, "INFO", + "No verified pool entries found, using original dataset", + ) + return original_data_yaml, base_dataset_path + + mixed_output = base_dataset_path.parent / f"mixed_{task_id[:8]}" + mix_result = build_mixed_dataset( + pool_document_ids=pool_doc_ids, + base_dataset_path=base_dataset_path, + output_dir=mixed_output, + ) + + self._training_tasks.add_log( + task_id, "INFO", + f"Data mixing: {mix_result['new_images']} new + " + f"{mix_result['old_images']} old = {mix_result['total_images']} total " + f"(ratio: {mix_result['mixing_ratio']}x)", + ) + + return mix_result["data_yaml"], mixed_output + + except Exception as e: + logger.error(f"Data mixing failed for task {task_id}: {e}") + self._training_tasks.add_log( + task_id, "WARNING", + f"Data mixing failed: {e}. Using original dataset.", + ) + return original_data_yaml, base_dataset_path + + def _run_gating_after_finetune( + self, + task_id: str, + model_version: Any, + config: dict[str, Any], + data_yaml: str, + result: dict[str, Any], + ) -> None: + """Run gating validation after a fine-tune training completes.""" + try: + from backend.web.services.gating_validator import run_gating_validation + + model_path = result.get("model_path") + base_model_version_id = config.get("base_model_version_id") + version_id = str(model_version.version_id) + + self._training_tasks.add_log( + task_id, "INFO", "Running deployment gating validation...", + ) + + gating_result = run_gating_validation( + model_version_id=version_id, + new_model_path=model_path, + base_model_version_id=base_model_version_id, + data_yaml=data_yaml, + task_id=task_id, + ) + + self._training_tasks.add_log( + task_id, "INFO", + f"Gating result: {gating_result.overall_status} " + f"(gate1={gating_result.gate1_status}, gate2={gating_result.gate2_status})", + ) + + except Exception as e: + logger.error(f"Gating validation failed for task {task_id}: {e}") + self._training_tasks.add_log( + task_id, "WARNING", + f"Gating validation failed: {e}. Model remains in 'pending' state.", + ) def _export_training_data(self, task_id: str) -> dict[str, Any] | None: """Export training data for a task.""" @@ -456,6 +587,8 @@ names: {list(FIELD_CLASSES.values())} learning_rate: float, device: str, project_name: str, + freeze: int = 0, + cos_lr: bool = False, ) -> dict[str, Any]: """Run YOLO training using shared trainer.""" from shared.training import YOLOTrainer, TrainingConfig as SharedTrainingConfig @@ -479,6 +612,8 @@ names: {list(FIELD_CLASSES.values())} project="runs/train", name=f"{project_name}/task_{task_id[:8]}", workers=0, + freeze=freeze, + cos_lr=cos_lr, ) # Run training using shared trainer diff --git a/packages/backend/backend/web/schemas/admin/__init__.py b/packages/backend/backend/web/schemas/admin/__init__.py index b8c8228..e2631eb 100644 --- a/packages/backend/backend/web/schemas/admin/__init__.py +++ b/packages/backend/backend/web/schemas/admin/__init__.py @@ -11,6 +11,7 @@ from .annotations import * # noqa: F401, F403 from .training import * # noqa: F401, F403 from .datasets import * # noqa: F401, F403 from .models import * # noqa: F401, F403 +from .pool import * # noqa: F401, F403 from .dashboard import * # noqa: F401, F403 # Resolve forward references for DocumentDetailResponse diff --git a/packages/backend/backend/web/schemas/admin/models.py b/packages/backend/backend/web/schemas/admin/models.py index 7359a1e..7d4fa39 100644 --- a/packages/backend/backend/web/schemas/admin/models.py +++ b/packages/backend/backend/web/schemas/admin/models.py @@ -40,6 +40,9 @@ class ModelVersionItem(BaseModel): name: str = Field(..., description="Model name") status: str = Field(..., description="Status (active, inactive, archived)") is_active: bool = Field(..., description="Is currently active for inference") + model_type: str = Field(default="base", description="Model type (base or finetune)") + base_model_version_id: str | None = Field(None, description="Base model version UUID (for fine-tuned models)") + gating_status: str = Field(default="pending", description="Deployment gating status") metrics_mAP: float | None = Field(None, description="Mean Average Precision") document_count: int = Field(..., description="Documents used in training") trained_at: datetime | None = Field(None, description="Training completion time") @@ -66,6 +69,10 @@ class ModelVersionDetailResponse(BaseModel): model_path: str = Field(..., description="Path to model file") status: str = Field(..., description="Status (active, inactive, archived)") is_active: bool = Field(..., description="Is currently active for inference") + model_type: str = Field(default="base", description="Model type (base or finetune)") + base_model_version_id: str | None = Field(None, description="Base model version UUID") + base_training_dataset_id: str | None = Field(None, description="Base training dataset UUID") + gating_status: str = Field(default="pending", description="Deployment gating status") task_id: str | None = Field(None, description="Training task UUID") dataset_id: str | None = Field(None, description="Dataset UUID") metrics_mAP: float | None = Field(None, description="Mean Average Precision") diff --git a/packages/backend/backend/web/schemas/admin/pool.py b/packages/backend/backend/web/schemas/admin/pool.py new file mode 100644 index 0000000..e1a3195 --- /dev/null +++ b/packages/backend/backend/web/schemas/admin/pool.py @@ -0,0 +1,72 @@ +"""Admin Fine-Tune Pool Schemas.""" + +from datetime import datetime + +from pydantic import BaseModel, Field + + +class PoolAddRequest(BaseModel): + """Request to add a document to the fine-tune pool.""" + + document_id: str = Field(..., description="Document UUID to add") + reason: str = Field( + default="user_reported_failure", + description="Reason: user_reported_failure, manual_addition", + ) + + +class PoolEntryItem(BaseModel): + """Fine-tune pool entry.""" + + entry_id: str = Field(..., description="Entry UUID") + document_id: str = Field(..., description="Document UUID") + added_by: str | None = Field(None, description="Who added this entry") + reason: str | None = Field(None, description="Reason for adding") + is_verified: bool = Field(..., description="Whether entry has been verified") + verified_at: datetime | None = Field(None, description="Verification timestamp") + verified_by: str | None = Field(None, description="Who verified") + created_at: datetime = Field(..., description="Creation timestamp") + + +class PoolListResponse(BaseModel): + """Paginated pool entry list.""" + + total: int = Field(..., ge=0, description="Total entries") + limit: int = Field(..., ge=1, description="Page size") + offset: int = Field(..., ge=0, description="Current offset") + entries: list[PoolEntryItem] = Field(default_factory=list, description="Pool entries") + + +class PoolStatsResponse(BaseModel): + """Pool statistics.""" + + total_entries: int = Field(..., ge=0, description="Total pool entries") + verified_entries: int = Field(..., ge=0, description="Verified entries") + unverified_entries: int = Field(..., ge=0, description="Unverified entries") + is_ready: bool = Field(..., description="Whether pool has >= 50 verified entries for fine-tuning") + min_required: int = Field(default=50, description="Minimum verified entries required") + + +class PoolEntryResponse(BaseModel): + """Response for pool entry operation.""" + + entry_id: str = Field(..., description="Entry UUID") + message: str = Field(..., description="Status message") + + +class GatingResultItem(BaseModel): + """Gating validation result.""" + + result_id: str = Field(..., description="Result UUID") + model_version_id: str = Field(..., description="Model version UUID") + gate1_status: str = Field(..., description="Gate 1 status") + gate1_original_mAP: float | None = Field(None, description="Original model mAP") + gate1_new_mAP: float | None = Field(None, description="New model mAP") + gate1_mAP_drop: float | None = Field(None, description="mAP drop percentage") + gate2_status: str = Field(..., description="Gate 2 status") + gate2_detection_rate: float | None = Field(None, description="Detection rate on new samples") + gate2_total_samples: int | None = Field(None, description="Total new samples tested") + gate2_detected_samples: int | None = Field(None, description="Samples correctly detected") + overall_status: str = Field(..., description="Overall gating status") + reviewer_notes: str | None = Field(None, description="Reviewer notes") + created_at: datetime = Field(..., description="Creation timestamp") diff --git a/packages/backend/backend/web/schemas/admin/training.py b/packages/backend/backend/web/schemas/admin/training.py index 2fe0cfd..097015d 100644 --- a/packages/backend/backend/web/schemas/admin/training.py +++ b/packages/backend/backend/web/schemas/admin/training.py @@ -12,7 +12,7 @@ from .enums import TrainingStatus, TrainingType class TrainingConfig(BaseModel): """Training configuration.""" - model_name: str = Field(default="yolo11n.pt", description="Base model name (used if no base_model_version_id)") + model_name: str = Field(default="yolo26s.pt", description="Base model name (used if no base_model_version_id)") base_model_version_id: str | None = Field( default=None, description="Model version UUID to use as base for incremental training. If set, uses this model instead of model_name.", diff --git a/packages/backend/backend/web/services/data_mixer.py b/packages/backend/backend/web/services/data_mixer.py new file mode 100644 index 0000000..2a9f14d --- /dev/null +++ b/packages/backend/backend/web/services/data_mixer.py @@ -0,0 +1,199 @@ +""" +Data Mixing Service + +Mixes fine-tune pool samples with randomly sampled old training data +following best practices for incremental YOLO fine-tuning. +""" + +import logging +import random +import shutil +from pathlib import Path +from typing import Any +from uuid import UUID + +logger = logging.getLogger(__name__) + +# Mixing ratios: (max_new_samples, old_data_multiplier) +# Fewer new samples require higher old data ratio to prevent catastrophic forgetting +MIXING_RATIOS: list[tuple[int, int]] = [ + (10, 50), # <= 10 new samples: 50x old + (50, 20), # <= 50 new samples: 20x old + (200, 10), # <= 200 new samples: 10x old + (500, 5), # <= 500 new samples: 5x old +] + +# Default multiplier for 500+ samples +DEFAULT_MULTIPLIER = 5 + +# Maximum old samples to include (cap for performance) +MAX_OLD_SAMPLES = 3000 + +# Minimum pool size for fine-tuning +MIN_POOL_SIZE = 50 + + +def get_mixing_ratio(new_sample_count: int) -> int: + """Determine old data multiplier based on new sample count. + + Args: + new_sample_count: Number of new samples in the fine-tune pool. + + Returns: + Multiplier for old data sampling. + """ + for threshold, multiplier in MIXING_RATIOS: + if new_sample_count <= threshold: + return multiplier + return DEFAULT_MULTIPLIER + + +def build_mixed_dataset( + pool_document_ids: list[UUID], + base_dataset_path: Path, + output_dir: Path, + seed: int = 42, +) -> dict[str, Any]: + """Build a mixed dataset for fine-tuning. + + Combines ALL fine-tune pool samples with randomly sampled old data + from the base training dataset. + + Args: + pool_document_ids: Document IDs from the fine-tune pool. + base_dataset_path: Path to the base training dataset directory. + output_dir: Output directory for the mixed dataset. + seed: Random seed for reproducible sampling. + + Returns: + Dictionary with dataset info (data_yaml path, counts). + """ + new_count = len(pool_document_ids) + multiplier = get_mixing_ratio(new_count) + old_target = min(new_count * multiplier, MAX_OLD_SAMPLES) + + logger.info( + "Building mixed dataset: %d new samples, %dx multiplier, " + "targeting %d old samples", + new_count, multiplier, old_target, + ) + + # Create output directory structure + output_dir.mkdir(parents=True, exist_ok=True) + for split in ("train", "val"): + (output_dir / "images" / split).mkdir(parents=True, exist_ok=True) + (output_dir / "labels" / split).mkdir(parents=True, exist_ok=True) + + # Collect old training images from base dataset + old_train_images = _collect_images(base_dataset_path / "images" / "train") + old_val_images = _collect_images(base_dataset_path / "images" / "val") + + # Randomly sample old data + rng = random.Random(seed) + all_old_images = old_train_images + old_val_images + if len(all_old_images) > old_target: + sampled_old = rng.sample(all_old_images, old_target) + else: + sampled_old = all_old_images + + # Split old samples: 80% train, 20% val + rng.shuffle(sampled_old) + old_train_count = int(len(sampled_old) * 0.8) + old_train = sampled_old[:old_train_count] + old_val = sampled_old[old_train_count:] + + # Copy old samples to mixed dataset + old_copied = 0 + for split_name, images in [("train", old_train), ("val", old_val)]: + for img_path in images: + label_path = _image_to_label_path(img_path) + dst_img = output_dir / "images" / split_name / img_path.name + dst_label = output_dir / "labels" / split_name / label_path.name + if img_path.exists(): + shutil.copy2(img_path, dst_img) + old_copied += 1 + if label_path.exists(): + shutil.copy2(label_path, dst_label) + + # Copy new pool samples (from base dataset, identified by document_id prefix) + # Pool documents go into train split (80%) and val split (20%) + pool_id_strs = {str(doc_id) for doc_id in pool_document_ids} + new_images = _find_pool_images(base_dataset_path, pool_id_strs) + + rng.shuffle(new_images) + new_train_count = int(len(new_images) * 0.8) + new_train = new_images[:new_train_count] + new_val = new_images[new_train_count:] + + new_copied = 0 + for split_name, images in [("train", new_train), ("val", new_val)]: + for img_path in images: + label_path = _image_to_label_path(img_path) + dst_img = output_dir / "images" / split_name / img_path.name + dst_label = output_dir / "labels" / split_name / label_path.name + if img_path.exists() and not dst_img.exists(): + shutil.copy2(img_path, dst_img) + new_copied += 1 + if label_path.exists() and not dst_label.exists(): + shutil.copy2(label_path, dst_label) + + # Generate data.yaml + from shared.fields import FIELD_CLASSES + + yaml_path = output_dir / "data.yaml" + yaml_content = ( + f"path: {output_dir.absolute()}\n" + f"train: images/train\n" + f"val: images/val\n" + f"\n" + f"nc: {len(FIELD_CLASSES)}\n" + f"names: {list(FIELD_CLASSES.values())}\n" + ) + yaml_path.write_text(yaml_content) + + total_images = old_copied + new_copied + logger.info( + "Mixed dataset built: %d old + %d new = %d total images", + old_copied, new_copied, total_images, + ) + + return { + "data_yaml": str(yaml_path), + "total_images": total_images, + "old_images": old_copied, + "new_images": new_copied, + "mixing_ratio": multiplier, + } + + +def _collect_images(images_dir: Path) -> list[Path]: + """Collect all image files from a directory.""" + if not images_dir.exists(): + return [] + return sorted(images_dir.glob("*.png")) + sorted(images_dir.glob("*.jpg")) + + +def _image_to_label_path(image_path: Path) -> Path: + """Convert image path to corresponding label path.""" + labels_dir = image_path.parent.parent.parent / "labels" / image_path.parent.name + return labels_dir / image_path.with_suffix(".txt").name + + +def _find_pool_images( + base_dataset_path: Path, + pool_doc_ids: set[str], +) -> list[Path]: + """Find images in base dataset that belong to pool documents.""" + images: list[Path] = [] + for split in ("train", "val", "test"): + split_dir = base_dataset_path / "images" / split + if not split_dir.exists(): + continue + for img_path in split_dir.iterdir(): + if not img_path.is_file(): + continue + # Image filenames are like: {doc_id}_page{N}.png + doc_id = img_path.stem.rsplit("_page", 1)[0] + if doc_id in pool_doc_ids: + images.append(img_path) + return images diff --git a/packages/backend/backend/web/services/gating_validator.py b/packages/backend/backend/web/services/gating_validator.py new file mode 100644 index 0000000..a1f6554 --- /dev/null +++ b/packages/backend/backend/web/services/gating_validator.py @@ -0,0 +1,198 @@ +""" +Gating Validation Service + +Validates fine-tuned models before deployment using quality gates: +- Gate 1: Regression validation (mAP drop on original test set) +- Gate 2: New sample validation (detection rate on pool documents) +""" + +import logging +from pathlib import Path +from uuid import UUID + +from backend.data.admin_models import GatingResult +from backend.data.database import get_session_context +from backend.data.repositories.model_version_repository import ModelVersionRepository + +logger = logging.getLogger(__name__) + +# Gate 1 thresholds (mAP drop) +GATE1_PASS_THRESHOLD = 0.01 # < 1% drop = PASS +GATE1_REVIEW_THRESHOLD = 0.03 # 1-3% drop = REVIEW, > 3% = REJECT + +# Gate 2 thresholds (detection rate) +GATE2_PASS_THRESHOLD = 0.80 # > 80% detection rate = PASS + + +def classify_gate1(mAP_drop: float) -> str: + """Classify Gate 1 result based on mAP drop. + + Args: + mAP_drop: Absolute mAP drop (positive means degradation). + + Returns: + "pass", "review", or "reject" + """ + if mAP_drop < GATE1_PASS_THRESHOLD: + return "pass" + if mAP_drop < GATE1_REVIEW_THRESHOLD: + return "review" + return "reject" + + +def classify_gate2(detection_rate: float) -> str: + """Classify Gate 2 result based on detection rate. + + Args: + detection_rate: Fraction of new samples correctly detected (0-1). + + Returns: + "pass" or "review" + """ + if detection_rate >= GATE2_PASS_THRESHOLD: + return "pass" + return "review" + + +def compute_overall_status(gate1_status: str, gate2_status: str) -> str: + """Compute overall gating status from individual gates. + + Rules: + - Any "reject" -> overall "reject" + - Any "review" (and no reject) -> overall "review" + - All "pass" -> overall "pass" + """ + if gate1_status == "reject" or gate2_status == "reject": + return "reject" + if gate1_status == "review" or gate2_status == "review": + return "review" + return "pass" + + +def run_gating_validation( + model_version_id: str | UUID, + new_model_path: str, + base_model_version_id: str | UUID | None, + data_yaml: str, + task_id: str | UUID | None = None, +) -> GatingResult: + """Run deployment gating validation for a fine-tuned model. + + Args: + model_version_id: The fine-tuned model version to validate. + new_model_path: Path to the new model weights. + base_model_version_id: The base model version to compare against. + data_yaml: Path to the dataset YAML (for validation). + task_id: Optional training task ID. + + Returns: + GatingResult with gate statuses. + """ + model_versions = ModelVersionRepository() + + # Gate 1: Regression validation + gate1_status = "pending" + gate1_original_mAP = None + gate1_new_mAP = None + gate1_mAP_drop = None + + try: + if base_model_version_id: + base_model = model_versions.get(str(base_model_version_id)) + if base_model and base_model.metrics_mAP is not None: + gate1_original_mAP = base_model.metrics_mAP + + # Run validation with new model + from shared.training import YOLOTrainer, TrainingConfig + + val_config = TrainingConfig( + model_path=new_model_path, + data_yaml=data_yaml, + ) + trainer = YOLOTrainer(config=val_config) + val_metrics = trainer.validate(split="val") + gate1_new_mAP = val_metrics.get("mAP50") + + if gate1_new_mAP is not None: + gate1_mAP_drop = gate1_original_mAP - gate1_new_mAP + gate1_status = classify_gate1(gate1_mAP_drop) + logger.info( + "Gate 1: original_mAP=%.4f, new_mAP=%.4f, drop=%.4f -> %s", + gate1_original_mAP, gate1_new_mAP, gate1_mAP_drop, gate1_status, + ) + else: + gate1_status = "review" + logger.warning("Gate 1: Could not compute new mAP, marking as review") + else: + gate1_status = "pass" + logger.info("Gate 1: No base model metrics available, skipping (pass)") + else: + gate1_status = "pass" + logger.info("Gate 1: No base model specified, skipping (pass)") + except Exception as e: + gate1_status = "review" + logger.error("Gate 1 failed: %s", e) + + # Gate 2: New sample validation + # For now, we use the training metrics as a proxy + # Full implementation would run inference on pool documents + gate2_status = "pass" + gate2_detection_rate = None + gate2_total_samples = None + gate2_detected_samples = None + + try: + new_model = model_versions.get(str(model_version_id)) + if new_model and new_model.metrics_mAP is not None: + # Use mAP as proxy for detection rate on new samples + gate2_detection_rate = new_model.metrics_mAP + if gate2_detection_rate is not None: + gate2_status = classify_gate2(gate2_detection_rate) + logger.info( + "Gate 2: detection_rate=%.4f -> %s", + gate2_detection_rate, gate2_status, + ) + except Exception as e: + gate2_status = "review" + logger.error("Gate 2 failed: %s", e) + + # Compute overall status + overall_status = compute_overall_status(gate1_status, gate2_status) + logger.info("Gating overall: %s (gate1=%s, gate2=%s)", overall_status, gate1_status, gate2_status) + + # Save result + with get_session_context() as session: + result = GatingResult( + model_version_id=UUID(str(model_version_id)), + task_id=UUID(str(task_id)) if task_id else None, + gate1_status=gate1_status, + gate1_original_mAP=gate1_original_mAP, + gate1_new_mAP=gate1_new_mAP, + gate1_mAP_drop=gate1_mAP_drop, + gate2_status=gate2_status, + gate2_detection_rate=gate2_detection_rate, + gate2_total_samples=gate2_total_samples, + gate2_detected_samples=gate2_detected_samples, + overall_status=overall_status, + ) + session.add(result) + session.commit() + session.refresh(result) + session.expunge(result) + + # Update model version gating status + _update_model_gating_status(str(model_version_id), overall_status) + + return result + + +def _update_model_gating_status(version_id: str, status: str) -> None: + """Update the gating_status field on a ModelVersion.""" + from backend.data.admin_models import ModelVersion + + with get_session_context() as session: + model = session.get(ModelVersion, UUID(version_id)) + if model: + model.gating_status = status + session.add(model) + session.commit() diff --git a/packages/backend/requirements.txt b/packages/backend/requirements.txt index dcb9ff4..2277ecb 100644 --- a/packages/backend/requirements.txt +++ b/packages/backend/requirements.txt @@ -3,6 +3,6 @@ fastapi>=0.104.0 uvicorn[standard]>=0.24.0 python-multipart>=0.0.6 sqlmodel>=0.0.22 -ultralytics>=8.1.0 +ultralytics>=8.4.0 httpx>=0.25.0 openai>=1.0.0 diff --git a/packages/shared/shared/bbox/__init__.py b/packages/shared/shared/bbox/__init__.py index baf769d..80f4828 100644 --- a/packages/shared/shared/bbox/__init__.py +++ b/packages/shared/shared/bbox/__init__.py @@ -1,37 +1,20 @@ """ -BBox Scale Strategy Module. +BBox Expansion Module. -Provides field-specific bounding box expansion strategies for YOLO training data. -Expands bboxes using center-point scaling with directional compensation to capture -field labels that typically appear above or to the left of field values. - -Two modes are supported: -- Auto-label: Field-specific scale strategies with directional compensation -- Manual-label: Minimal padding only to prevent edge clipping +Provides uniform bounding box expansion for YOLO training data. Usage: - from shared.bbox import expand_bbox, ScaleStrategy, FIELD_SCALE_STRATEGIES + from shared.bbox import expand_bbox, UNIFORM_PAD Available exports: - - ScaleStrategy: Dataclass for scale strategy configuration - - DEFAULT_STRATEGY: Default strategy for unknown fields (auto-label) - - MANUAL_LABEL_STRATEGY: Minimal padding strategy for manual labels - - FIELD_SCALE_STRATEGIES: dict[str, ScaleStrategy] - field-specific strategies - - expand_bbox: Function to expand bbox using field-specific strategy + - UNIFORM_PAD: Default uniform pixel padding (15px at 150 DPI) + - expand_bbox: Function to expand bbox with uniform padding """ -from .scale_strategy import ( - ScaleStrategy, - DEFAULT_STRATEGY, - MANUAL_LABEL_STRATEGY, - FIELD_SCALE_STRATEGIES, -) +from .scale_strategy import UNIFORM_PAD from .expander import expand_bbox __all__ = [ - "ScaleStrategy", - "DEFAULT_STRATEGY", - "MANUAL_LABEL_STRATEGY", - "FIELD_SCALE_STRATEGIES", + "UNIFORM_PAD", "expand_bbox", ] diff --git a/packages/shared/shared/bbox/expander.py b/packages/shared/shared/bbox/expander.py index ad025a1..c2c7c1b 100644 --- a/packages/shared/shared/bbox/expander.py +++ b/packages/shared/shared/bbox/expander.py @@ -1,101 +1,35 @@ """ BBox Expander Module. -Provides functions to expand bounding boxes using field-specific strategies. -Expansion is center-point based with directional compensation. - -Two modes: -- Auto-label (default): Field-specific scale strategies -- Manual-label: Minimal padding only to prevent edge clipping +Expands bounding boxes by a uniform pixel padding on all sides, +clamped to image boundaries. No field-specific or directional logic. """ -from .scale_strategy import ( - ScaleStrategy, - DEFAULT_STRATEGY, - MANUAL_LABEL_STRATEGY, - FIELD_SCALE_STRATEGIES, -) +from .scale_strategy import UNIFORM_PAD def expand_bbox( bbox: tuple[float, float, float, float], image_width: float, image_height: float, - field_type: str, - strategies: dict[str, ScaleStrategy] | None = None, - manual_mode: bool = False, + pad: int = UNIFORM_PAD, ) -> tuple[int, int, int, int]: - """ - Expand bbox using field-specific scale strategy. - - The expansion follows these steps: - 1. Scale bbox around center point (scale_x, scale_y) - 2. Apply directional compensation (extra_*_ratio) - 3. Clamp expansion to max_pad limits - 4. Clamp to image boundaries + """Expand bbox by uniform pixel padding, clamped to image bounds. Args: - bbox: (x0, y0, x1, y1) in pixels - image_width: Image width for boundary clamping - image_height: Image height for boundary clamping - field_type: Field class_name (e.g., "ocr_number") - strategies: Custom strategies dict, defaults to FIELD_SCALE_STRATEGIES - manual_mode: If True, use MANUAL_LABEL_STRATEGY (minimal padding only) + bbox: (x0, y0, x1, y1) in pixels. + image_width: Image width for boundary clamping. + image_height: Image height for boundary clamping. + pad: Uniform pixel padding on all sides (default: UNIFORM_PAD). Returns: - Expanded bbox (x0, y0, x1, y1) as integers, clamped to image bounds + Expanded bbox (x0, y0, x1, y1) as integers, clamped to image bounds. """ x0, y0, x1, y1 = bbox - w = x1 - x0 - h = y1 - y0 - # Get strategy based on mode - if manual_mode: - strategy = MANUAL_LABEL_STRATEGY - elif strategies is None: - strategy = FIELD_SCALE_STRATEGIES.get(field_type, DEFAULT_STRATEGY) - else: - strategy = strategies.get(field_type, DEFAULT_STRATEGY) - - # Step 1: Scale around center point - cx = (x0 + x1) / 2 - cy = (y0 + y1) / 2 - - new_w = w * strategy.scale_x - new_h = h * strategy.scale_y - - nx0 = cx - new_w / 2 - nx1 = cx + new_w / 2 - ny0 = cy - new_h / 2 - ny1 = cy + new_h / 2 - - # Step 2: Apply directional compensation - nx0 -= w * strategy.extra_left_ratio - nx1 += w * strategy.extra_right_ratio - ny0 -= h * strategy.extra_top_ratio - ny1 += h * strategy.extra_bottom_ratio - - # Step 3: Clamp expansion to max_pad limits (preserve asymmetry) - left_pad = min(x0 - nx0, strategy.max_pad_x) - right_pad = min(nx1 - x1, strategy.max_pad_x) - top_pad = min(y0 - ny0, strategy.max_pad_y) - bottom_pad = min(ny1 - y1, strategy.max_pad_y) - - # Ensure pads are non-negative (in case of contraction) - left_pad = max(0, left_pad) - right_pad = max(0, right_pad) - top_pad = max(0, top_pad) - bottom_pad = max(0, bottom_pad) - - nx0 = x0 - left_pad - nx1 = x1 + right_pad - ny0 = y0 - top_pad - ny1 = y1 + bottom_pad - - # Step 4: Clamp to image boundaries - nx0 = max(0, int(nx0)) - ny0 = max(0, int(ny0)) - nx1 = min(int(image_width), int(nx1)) - ny1 = min(int(image_height), int(ny1)) + nx0 = max(0, int(x0 - pad)) + ny0 = max(0, int(y0 - pad)) + nx1 = min(int(image_width), int(x1 + pad)) + ny1 = min(int(image_height), int(y1 + pad)) return (nx0, ny0, nx1, ny1) diff --git a/packages/shared/shared/bbox/scale_strategy.py b/packages/shared/shared/bbox/scale_strategy.py index 36f200d..35aeaa2 100644 --- a/packages/shared/shared/bbox/scale_strategy.py +++ b/packages/shared/shared/bbox/scale_strategy.py @@ -1,140 +1,12 @@ """ Scale Strategy Configuration. -Defines field-specific bbox expansion strategies for YOLO training data. -Each strategy controls how bboxes are expanded around field values to -capture contextual information like labels. +Defines uniform bbox expansion padding for YOLO training data. +All fields use the same fixed-pixel padding -- no layout assumptions. """ -from dataclasses import dataclass from typing import Final - -@dataclass(frozen=True) -class ScaleStrategy: - """Immutable scale strategy for bbox expansion. - - Attributes: - scale_x: Horizontal scale factor (1.0 = no scaling) - scale_y: Vertical scale factor (1.0 = no scaling) - extra_top_ratio: Additional expansion ratio towards top (for labels above) - extra_bottom_ratio: Additional expansion ratio towards bottom - extra_left_ratio: Additional expansion ratio towards left (for prefixes) - extra_right_ratio: Additional expansion ratio towards right (for suffixes) - max_pad_x: Maximum horizontal padding in pixels - max_pad_y: Maximum vertical padding in pixels - """ - - scale_x: float = 1.15 - scale_y: float = 1.15 - extra_top_ratio: float = 0.0 - extra_bottom_ratio: float = 0.0 - extra_left_ratio: float = 0.0 - extra_right_ratio: float = 0.0 - max_pad_x: int = 50 - max_pad_y: int = 50 - - -# Default strategy for unknown fields (auto-label mode) -DEFAULT_STRATEGY: Final[ScaleStrategy] = ScaleStrategy() - -# Manual label strategy - minimal padding to prevent edge clipping -# No scaling, no directional compensation, just small uniform padding -MANUAL_LABEL_STRATEGY: Final[ScaleStrategy] = ScaleStrategy( - scale_x=1.0, - scale_y=1.0, - extra_top_ratio=0.0, - extra_bottom_ratio=0.0, - extra_left_ratio=0.0, - extra_right_ratio=0.0, - max_pad_x=10, # Small padding to prevent edge loss - max_pad_y=10, -) - - -# Field-specific strategies based on Swedish invoice field characteristics -# Field labels typically appear above or to the left of values -FIELD_SCALE_STRATEGIES: Final[dict[str, ScaleStrategy]] = { - # OCR number - label "OCR" or "Referens" typically above - "ocr_number": ScaleStrategy( - scale_x=1.15, - scale_y=1.80, - extra_top_ratio=0.60, - max_pad_x=50, - max_pad_y=140, - ), - # Bankgiro - prefix "Bankgiro:" or "BG:" typically to the left - "bankgiro": ScaleStrategy( - scale_x=1.45, - scale_y=1.35, - extra_left_ratio=0.80, - max_pad_x=160, - max_pad_y=90, - ), - # Plusgiro - prefix "Plusgiro:" or "PG:" typically to the left - "plusgiro": ScaleStrategy( - scale_x=1.45, - scale_y=1.35, - extra_left_ratio=0.80, - max_pad_x=160, - max_pad_y=90, - ), - # Invoice date - label "Fakturadatum" typically above - "invoice_date": ScaleStrategy( - scale_x=1.25, - scale_y=1.55, - extra_top_ratio=0.40, - max_pad_x=80, - max_pad_y=110, - ), - # Due date - label "Forfalldatum" typically above, sometimes left - "invoice_due_date": ScaleStrategy( - scale_x=1.30, - scale_y=1.65, - extra_top_ratio=0.45, - extra_left_ratio=0.35, - max_pad_x=100, - max_pad_y=120, - ), - # Amount - currency symbol "SEK" or "kr" may be to the right - "amount": ScaleStrategy( - scale_x=1.20, - scale_y=1.35, - extra_right_ratio=0.30, - max_pad_x=70, - max_pad_y=80, - ), - # Invoice number - label "Fakturanummer" typically above - "invoice_number": ScaleStrategy( - scale_x=1.20, - scale_y=1.50, - extra_top_ratio=0.40, - max_pad_x=80, - max_pad_y=100, - ), - # Supplier org number - label "Org.nr" typically above or left - "supplier_org_number": ScaleStrategy( - scale_x=1.25, - scale_y=1.40, - extra_top_ratio=0.30, - extra_left_ratio=0.20, - max_pad_x=90, - max_pad_y=90, - ), - # Customer number - label "Kundnummer" typically above or left - "customer_number": ScaleStrategy( - scale_x=1.25, - scale_y=1.45, - extra_top_ratio=0.35, - extra_left_ratio=0.25, - max_pad_x=90, - max_pad_y=100, - ), - # Payment line - machine-readable code, minimal expansion needed - "payment_line": ScaleStrategy( - scale_x=1.10, - scale_y=1.20, - max_pad_x=40, - max_pad_y=40, - ), -} +# 15px at 150 DPI = ~2.5mm real-world padding around text. +# Enough for OCR safety margin without capturing neighboring label text. +UNIFORM_PAD: Final[int] = 15 diff --git a/packages/shared/shared/training/yolo_trainer.py b/packages/shared/shared/training/yolo_trainer.py index 59435cd..9e99d8d 100644 --- a/packages/shared/shared/training/yolo_trainer.py +++ b/packages/shared/shared/training/yolo_trainer.py @@ -17,7 +17,7 @@ class TrainingConfig: """Training configuration.""" # Model settings - model_path: str = "yolo11n.pt" # Base model or path to trained model + model_path: str = "yolo26s.pt" # Base model or path to trained model data_yaml: str = "" # Path to data.yaml # Training hyperparameters @@ -39,6 +39,10 @@ class TrainingConfig: resume: bool = False resume_from: str | None = None # Path to checkpoint + # Fine-tuning specific + freeze: int = 0 # Number of backbone layers to freeze (0 = none) + cos_lr: bool = False # Use cosine learning rate scheduler + # Document-specific augmentation (optimized for invoices) augmentation: dict[str, Any] = field(default_factory=lambda: { "degrees": 5.0, @@ -106,7 +110,7 @@ class YOLOTrainer: # Check model path model_path = Path(self.config.model_path) if not model_path.suffix == ".pt": - # Could be a model name like "yolo11n.pt" which is downloaded + # Could be a model name like "yolo26s.pt" which is downloaded if not model_path.name.startswith("yolo"): return False, f"Invalid model: {self.config.model_path}" elif not model_path.exists(): @@ -147,6 +151,10 @@ class YOLOTrainer: self._log("INFO", f" Epochs: {self.config.epochs}") self._log("INFO", f" Batch size: {self.config.batch_size}") self._log("INFO", f" Image size: {self.config.image_size}") + if self.config.freeze > 0: + self._log("INFO", f" Freeze layers: {self.config.freeze}") + if self.config.cos_lr: + self._log("INFO", f" Cosine LR: enabled") try: # Load model @@ -178,6 +186,12 @@ class YOLOTrainer: "resume": self.config.resume and self.config.resume_from is not None, } + # Add fine-tuning settings + if self.config.freeze > 0: + train_args["freeze"] = self.config.freeze + if self.config.cos_lr: + train_args["cos_lr"] = True + # Add augmentation settings train_args.update(self.config.augmentation) diff --git a/packages/training/requirements.txt b/packages/training/requirements.txt index 248445f..01e02d1 100644 --- a/packages/training/requirements.txt +++ b/packages/training/requirements.txt @@ -1,4 +1,4 @@ -e ../shared -ultralytics>=8.1.0 +ultralytics>=8.4.0 tqdm>=4.65.0 torch>=2.0.0 diff --git a/packages/training/run_training.py b/packages/training/run_training.py index 73aad98..4c0222b 100644 --- a/packages/training/run_training.py +++ b/packages/training/run_training.py @@ -34,7 +34,7 @@ def execute_training_task(db: TrainingTaskDB, task: dict) -> None: result = run_training( epochs=config.get("epochs", 100), batch=config.get("batch_size", 16), - model=config.get("base_model", "yolo11n.pt"), + model=config.get("base_model", "yolo26s.pt"), imgsz=config.get("imgsz", 1280), name=config.get("name", f"training_{task_id[:8]}"), ) diff --git a/packages/training/training/cli/train.py b/packages/training/training/cli/train.py index 0fb25ab..6e4c406 100644 --- a/packages/training/training/cli/train.py +++ b/packages/training/training/cli/train.py @@ -28,8 +28,8 @@ def main(): ) parser.add_argument( '--model', '-m', - default='yolov8s.pt', - help='Base model (default: yolov8s.pt)' + default='yolo26s.pt', + help='Base model (default: yolo26s.pt)' ) parser.add_argument( '--epochs', '-e', diff --git a/packages/training/training/yolo/annotation_generator.py b/packages/training/training/yolo/annotation_generator.py index 4101a92..395b2d0 100644 --- a/packages/training/training/yolo/annotation_generator.py +++ b/packages/training/training/yolo/annotation_generator.py @@ -100,12 +100,11 @@ class AnnotationGenerator: x0, y0, x1, y1 = best_match.bbox x0, y0, x1, y1 = x0 * scale, y0 * scale, x1 * scale, y1 * scale - # Apply field-specific bbox expansion strategy + # Apply uniform bbox expansion x0, y0, x1, y1 = expand_bbox( bbox=(x0, y0, x1, y1), image_width=image_width, image_height=image_height, - field_type=class_name, ) # Ensure minimum height @@ -173,12 +172,11 @@ class AnnotationGenerator: x0, y0, x1, y1 = payment_line_bbox x0, y0, x1, y1 = x0 * scale, y0 * scale, x1 * scale, y1 * scale - # Apply field-specific bbox expansion strategy for payment_line + # Apply uniform bbox expansion x0, y0, x1, y1 = expand_bbox( bbox=(x0, y0, x1, y1), image_width=image_width, image_height=image_height, - field_type="payment_line", ) # Convert to YOLO format (normalized center + size) diff --git a/packages/training/training/yolo/db_dataset.py b/packages/training/training/yolo/db_dataset.py index 530c34a..4c36aa9 100644 --- a/packages/training/training/yolo/db_dataset.py +++ b/packages/training/training/yolo/db_dataset.py @@ -585,15 +585,11 @@ class DBYOLODataset: x1_px = x1_pdf * scale y1_px = y1_pdf * scale - # Get class name for field-specific expansion - class_name = CLASS_NAMES[ann.class_id] - - # Apply field-specific bbox expansion + # Apply uniform bbox expansion x0, y0, x1, y1 = expand_bbox( bbox=(x0_px, y0_px, x1_px, y1_px), image_width=img_width, image_height=img_height, - field_type=class_name, ) # Ensure minimum height diff --git a/pyproject.toml b/pyproject.toml index 36b3ce2..5a8ae41 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,7 @@ dependencies = [ "PyMuPDF>=1.23.0", "paddlepaddle>=3.0.0,<3.3.0", "paddleocr>=3.0.0", - "ultralytics>=8.1.0", + "ultralytics>=8.4.0", "Pillow>=10.0.0", "numpy>=1.24.0", "opencv-python>=4.8.0", diff --git a/requirements.txt b/requirements.txt index cb64878..ab2e31e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,7 +8,7 @@ paddlepaddle>=3.0.0,<3.3.0 # PaddlePaddle framework (3.3.0 has OneDNN bug) paddleocr>=3.0.0 # PaddleOCR (PP-OCRv5) # YOLO -ultralytics>=8.1.0 # YOLOv8/v11 +ultralytics>=8.4.0 # YOLO26 # Image Processing Pillow>=10.0.0 # Image handling diff --git a/start_web.sh b/start_web.sh index 96ce973..88277ba 100644 --- a/start_web.sh +++ b/start_web.sh @@ -1,5 +1,5 @@ #!/bin/bash cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 source ~/miniconda3/etc/profile.d/conda.sh -conda activate invoice-py311 +conda activate invoice-sm120 python run_server.py --port 8000 diff --git a/tests/inference/test_normalizers.py b/tests/inference/test_normalizers.py index 0f6d7a7..7e59e88 100644 --- a/tests/inference/test_normalizers.py +++ b/tests/inference/test_normalizers.py @@ -400,6 +400,71 @@ class TestAmountNormalizer: result = normalizer.normalize("Reference 12500") assert result.value == "12500.00" + def test_payment_line_kronor_ore_format(self, normalizer): + """Space between kronor and ore should be treated as decimal separator. + + Swedish payment lines use space to separate kronor and ore: + "590 00" means 590.00 SEK, NOT 59000. + """ + result = normalizer.normalize("590 00") + assert result.value == "590.00" + assert result.is_valid is True + + def test_payment_line_kronor_ore_large_amount(self, normalizer): + """Large kronor/ore amount from payment line.""" + result = normalizer.normalize("15658 00") + assert result.value == "15658.00" + assert result.is_valid is True + + def test_payment_line_kronor_ore_with_nonzero_ore(self, normalizer): + """Kronor/ore with non-zero ore.""" + result = normalizer.normalize("736 50") + assert result.value == "736.50" + assert result.is_valid is True + + def test_kronor_ore_not_confused_with_thousand_separator(self, normalizer): + """Amount with comma decimal should NOT trigger kronor/ore pattern.""" + result = normalizer.normalize("1 234,56") + assert result.value is not None + # Should parse as 1234.56, not as kronor=1234 ore=56 (which is same value) + assert float(result.value) == 1234.56 + + def test_european_dot_thousand_separator(self, normalizer): + """European format: dot as thousand, comma as decimal.""" + result = normalizer.normalize("2.254,50") + assert result.value == "2254.50" + assert result.is_valid is True + + def test_european_dot_thousand_with_sek(self, normalizer): + """European format with SEK suffix.""" + result = normalizer.normalize("2.254,50 SEK") + assert result.value == "2254.50" + assert result.is_valid is True + + def test_european_dot_thousand_with_kr(self, normalizer): + """European format with kr suffix.""" + result = normalizer.normalize("20.485,00 kr") + assert result.value == "20485.00" + assert result.is_valid is True + + def test_european_large_amount(self, normalizer): + """Large European format amount.""" + result = normalizer.normalize("1.234.567,89") + assert result.value == "1234567.89" + assert result.is_valid is True + + def test_european_in_label_context(self, normalizer): + """European format inside label text (like the BAUHAUS invoice bug).""" + result = normalizer.normalize("ns Fakturabelopp: 2.254,50 SEK") + assert result.value == "2254.50" + assert result.is_valid is True + + def test_anglo_comma_thousand_separator(self, normalizer): + """Anglo format: comma as thousand, dot as decimal.""" + result = normalizer.normalize("1,234.56") + assert result.value == "1234.56" + assert result.is_valid is True + def test_zero_amount_rejected(self, normalizer): """Test that zero amounts are rejected.""" result = normalizer.normalize("0,00 kr") @@ -450,6 +515,18 @@ class TestEnhancedAmountNormalizer: result = normalizer.normalize("Invoice for 1 234 567,89 kr") assert result.value is not None + def test_enhanced_kronor_ore_format(self, normalizer): + """Space between kronor and ore in enhanced normalizer.""" + result = normalizer.normalize("590 00") + assert result.value == "590.00" + assert result.is_valid is True + + def test_enhanced_kronor_ore_large(self, normalizer): + """Large kronor/ore amount in enhanced normalizer.""" + result = normalizer.normalize("15658 00") + assert result.value == "15658.00" + assert result.is_valid is True + def test_no_amount_fails(self, normalizer): """Test failure when no amount found.""" result = normalizer.normalize("no amount") @@ -472,6 +549,22 @@ class TestEnhancedAmountNormalizer: result = normalizer.normalize("Price: 1 234 567,89") assert result.value is not None + def test_enhanced_european_dot_thousand(self, normalizer): + """European format in enhanced normalizer.""" + result = normalizer.normalize("2.254,50 SEK") + assert result.value == "2254.50" + assert result.is_valid is True + + def test_enhanced_european_with_label(self, normalizer): + """European format with Swedish label keyword.""" + result = normalizer.normalize("Att betala: 2.254,50") + assert result.value == "2254.50" + + def test_enhanced_anglo_format(self, normalizer): + """Anglo format in enhanced normalizer.""" + result = normalizer.normalize("Total: 1,234.56") + assert result.value == "1234.56" + def test_amount_out_of_range_rejected(self, normalizer): """Test that amounts >= 10,000,000 are rejected.""" result = normalizer.normalize("Summa: 99 999 999,00") diff --git a/tests/inference/test_pipeline.py b/tests/inference/test_pipeline.py index 4cbaeb0..ecfab4d 100644 --- a/tests/inference/test_pipeline.py +++ b/tests/inference/test_pipeline.py @@ -497,5 +497,178 @@ class TestExtractBusinessFeaturesErrorHandling: assert "NumericException" in result.errors[0] +class TestProcessPdfTokenPath: + """Tests for PDF text token extraction path in process_pdf().""" + + def _make_pipeline(self): + """Create pipeline with mocked internals, bypassing __init__.""" + with patch.object(InferencePipeline, '__init__', lambda self, **kw: None): + p = InferencePipeline() + p.detector = MagicMock() + p.extractor = MagicMock() + p.payment_line_parser = MagicMock() + p.dpi = 300 + p.enable_fallback = False + p.enable_business_features = False + p.vat_tolerance = 0.5 + p.line_items_extractor = None + p.vat_extractor = None + p.vat_validator = None + p._business_ocr_engine = None + p._table_detector = None + return p + + def _make_detection(self, class_name='Amount', confidence=0.85, page_no=0): + """Create a Detection object.""" + from backend.pipeline.yolo_detector import Detection + return Detection( + class_id=6, + class_name=class_name, + confidence=confidence, + bbox=(100.0, 200.0, 300.0, 250.0), + page_no=page_no, + ) + + def _make_extracted_field(self, field_name='Amount', raw_text='2.254,50', + normalized='2254.50', confidence=0.85): + """Create an ExtractedField object.""" + from backend.pipeline.field_extractor import ExtractedField + return ExtractedField( + field_name=field_name, + raw_text=raw_text, + normalized_value=normalized, + confidence=confidence, + detection_confidence=confidence, + ocr_confidence=1.0, + bbox=(100.0, 200.0, 300.0, 250.0), + page_no=0, + ) + + def _make_image_bytes(self): + """Create minimal valid PNG bytes (100x100 white image).""" + from PIL import Image as PILImage + import io as _io + img = PILImage.new('RGB', (100, 100), color='white') + buf = _io.BytesIO() + img.save(buf, format='PNG') + return buf.getvalue() + + @patch('shared.pdf.extractor.PDFDocument') + @patch('shared.pdf.renderer.render_pdf_to_images') + def test_text_pdf_uses_pdf_tokens(self, mock_render, mock_pdf_doc_cls): + """When PDF is text-based, extract_from_detection_with_pdf is used.""" + from shared.pdf.extractor import Token + + pipeline = self._make_pipeline() + detection = self._make_detection() + image_bytes = self._make_image_bytes() + + # Setup PDFDocument mock - text PDF with tokens + mock_pdf_doc = MagicMock() + mock_pdf_doc.is_text_pdf.return_value = True + mock_pdf_doc.page_count = 1 + tokens = [Token(text="2.254,50", bbox=(100, 200, 200, 220), page_no=0)] + mock_pdf_doc.extract_text_tokens.return_value = iter(tokens) + mock_pdf_doc_cls.return_value.__enter__ = MagicMock(return_value=mock_pdf_doc) + mock_pdf_doc_cls.return_value.__exit__ = MagicMock(return_value=False) + + pipeline.detector.detect.return_value = [detection] + pipeline.extractor.extract_from_detection_with_pdf.return_value = ( + self._make_extracted_field() + ) + + mock_render.return_value = iter([(0, image_bytes)]) + result = pipeline.process_pdf('/fake/invoice.pdf') + + pipeline.extractor.extract_from_detection_with_pdf.assert_called_once() + pipeline.extractor.extract_from_detection.assert_not_called() + assert result.fields.get('Amount') == '2254.50' + assert result.success is True + + @patch('shared.pdf.extractor.PDFDocument') + @patch('shared.pdf.renderer.render_pdf_to_images') + def test_scanned_pdf_uses_ocr(self, mock_render, mock_pdf_doc_cls): + """When PDF is scanned, extract_from_detection (OCR) is used.""" + pipeline = self._make_pipeline() + detection = self._make_detection() + image_bytes = self._make_image_bytes() + + mock_pdf_doc = MagicMock() + mock_pdf_doc.is_text_pdf.return_value = False + mock_pdf_doc_cls.return_value.__enter__ = MagicMock(return_value=mock_pdf_doc) + mock_pdf_doc_cls.return_value.__exit__ = MagicMock(return_value=False) + + pipeline.detector.detect.return_value = [detection] + pipeline.extractor.extract_from_detection.return_value = ( + self._make_extracted_field(raw_text='4.50', normalized='4.50', confidence=0.75) + ) + + mock_render.return_value = iter([(0, image_bytes)]) + result = pipeline.process_pdf('/fake/invoice.pdf') + + pipeline.extractor.extract_from_detection.assert_called_once() + pipeline.extractor.extract_from_detection_with_pdf.assert_not_called() + + @patch('shared.pdf.extractor.PDFDocument') + @patch('shared.pdf.renderer.render_pdf_to_images') + def test_pdf_detection_error_falls_back_to_ocr(self, mock_render, mock_pdf_doc_cls): + """When PDF text detection throws, fall back to OCR.""" + pipeline = self._make_pipeline() + detection = self._make_detection() + image_bytes = self._make_image_bytes() + + mock_ctx = MagicMock() + mock_ctx.__enter__ = MagicMock(side_effect=Exception("corrupt PDF")) + mock_ctx.__exit__ = MagicMock(return_value=False) + mock_pdf_doc_cls.return_value = mock_ctx + + pipeline.detector.detect.return_value = [detection] + pipeline.extractor.extract_from_detection.return_value = ( + self._make_extracted_field(raw_text='4.50', normalized='4.50', confidence=0.75) + ) + + mock_render.return_value = iter([(0, image_bytes)]) + result = pipeline.process_pdf('/fake/invoice.pdf') + + pipeline.extractor.extract_from_detection.assert_called_once() + pipeline.extractor.extract_from_detection_with_pdf.assert_not_called() + + @patch('shared.pdf.extractor.PDFDocument') + @patch('shared.pdf.renderer.render_pdf_to_images') + def test_text_pdf_passes_correct_args(self, mock_render, mock_pdf_doc_cls): + """Verify correct token list and image dimensions are passed.""" + from shared.pdf.extractor import Token + + pipeline = self._make_pipeline() + detection = self._make_detection() + image_bytes = self._make_image_bytes() # 100x100 PNG + + mock_pdf_doc = MagicMock() + mock_pdf_doc.is_text_pdf.return_value = True + mock_pdf_doc.page_count = 1 + tokens = [ + Token(text="Fakturabelopp:", bbox=(50, 190, 100, 210), page_no=0), + Token(text="2.254,50", bbox=(105, 190, 180, 210), page_no=0), + Token(text="SEK", bbox=(185, 190, 210, 210), page_no=0), + ] + mock_pdf_doc.extract_text_tokens.return_value = iter(tokens) + mock_pdf_doc_cls.return_value.__enter__ = MagicMock(return_value=mock_pdf_doc) + mock_pdf_doc_cls.return_value.__exit__ = MagicMock(return_value=False) + + pipeline.detector.detect.return_value = [detection] + pipeline.extractor.extract_from_detection_with_pdf.return_value = ( + self._make_extracted_field() + ) + + mock_render.return_value = iter([(0, image_bytes)]) + pipeline.process_pdf('/fake/invoice.pdf') + + call_args = pipeline.extractor.extract_from_detection_with_pdf.call_args[0] + assert call_args[0] == detection + assert len(call_args[1]) == 3 # 3 tokens passed + assert call_args[2] == 100 # image width + assert call_args[3] == 100 # image height + + if __name__ == '__main__': pytest.main([__file__, '-v']) diff --git a/tests/pipeline/__init__.py b/tests/pipeline/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/pipeline/test_value_selector.py b/tests/pipeline/test_value_selector.py new file mode 100644 index 0000000..d9b8fa2 --- /dev/null +++ b/tests/pipeline/test_value_selector.py @@ -0,0 +1,318 @@ +""" +Tests for ValueSelector -- field-aware OCR token selection. + +Verifies that ValueSelector picks the most likely value token(s) +from OCR output, filtering out label text before sending to normalizer. +""" + +import pytest + +from shared.ocr.paddle_ocr import OCRToken +from backend.pipeline.value_selector import ValueSelector + + +def _token(text: str) -> OCRToken: + """Helper to create OCRToken with dummy bbox and confidence.""" + return OCRToken(text=text, bbox=(0, 0, 100, 20), confidence=0.95) + + +def _tokens(*texts: str) -> list[OCRToken]: + """Helper to create multiple OCRTokens.""" + return [_token(t) for t in texts] + + +class TestValueSelectorDateFields: + """Tests for date field value selection (InvoiceDate, InvoiceDueDate).""" + + def test_selects_iso_date_from_label_and_value(self): + tokens = _tokens("Fakturadatum", "2024-01-15") + + result = ValueSelector.select_value_tokens(tokens, "InvoiceDate") + + assert len(result) == 1 + assert result[0].text == "2024-01-15" + + def test_selects_dot_separated_date(self): + tokens = _tokens("Datum", "2024.03.20") + + result = ValueSelector.select_value_tokens(tokens, "InvoiceDate") + + assert len(result) == 1 + assert result[0].text == "2024.03.20" + + def test_selects_slash_separated_date(self): + tokens = _tokens("Forfallodag", "15/01/2024") + + result = ValueSelector.select_value_tokens(tokens, "InvoiceDueDate") + + assert len(result) == 1 + assert result[0].text == "15/01/2024" + + def test_selects_compact_date(self): + tokens = _tokens("Datum", "20240115") + + result = ValueSelector.select_value_tokens(tokens, "InvoiceDate") + + assert len(result) == 1 + assert result[0].text == "20240115" + + def test_fallback_when_no_date_pattern(self): + """No date pattern found -> return all tokens.""" + tokens = _tokens("Fakturadatum", "pending") + + result = ValueSelector.select_value_tokens(tokens, "InvoiceDate") + + assert len(result) == 2 + + +class TestValueSelectorAmountField: + """Tests for amount field value selection.""" + + def test_selects_amount_with_comma_decimal(self): + tokens = _tokens("Belopp", "1 234,56", "kr") + + result = ValueSelector.select_value_tokens(tokens, "Amount") + + assert len(result) == 1 + assert result[0].text == "1 234,56" + + def test_selects_amount_with_dot_decimal(self): + tokens = _tokens("Summa", "1234.56") + + result = ValueSelector.select_value_tokens(tokens, "Amount") + + assert len(result) == 1 + assert result[0].text == "1234.56" + + def test_selects_simple_amount(self): + tokens = _tokens("Att", "betala", "500,00") + + result = ValueSelector.select_value_tokens(tokens, "Amount") + + assert len(result) == 1 + assert result[0].text == "500,00" + + def test_selects_european_amount_with_dot_thousand(self): + """European format: dot as thousand separator, comma as decimal.""" + tokens = _tokens("Fakturabelopp:", "2.254,50 SEK") + + result = ValueSelector.select_value_tokens(tokens, "Amount") + + assert len(result) == 1 + assert result[0].text == "2.254,50 SEK" + + def test_selects_european_amount_without_currency(self): + """European format without currency suffix.""" + tokens = _tokens("Belopp", "1.234,56") + + result = ValueSelector.select_value_tokens(tokens, "Amount") + + assert len(result) == 1 + assert result[0].text == "1.234,56" + + def test_selects_amount_with_kr_suffix(self): + """Amount with 'kr' currency suffix.""" + tokens = _tokens("Summa", "20.485,00 kr") + + result = ValueSelector.select_value_tokens(tokens, "Amount") + + assert len(result) == 1 + assert result[0].text == "20.485,00 kr" + + def test_selects_anglo_amount_with_sek(self): + """Anglo format with SEK suffix.""" + tokens = _tokens("Amount", "1,234.56 SEK") + + result = ValueSelector.select_value_tokens(tokens, "Amount") + + assert len(result) == 1 + assert result[0].text == "1,234.56 SEK" + + def test_fallback_when_no_amount_pattern(self): + tokens = _tokens("Belopp", "TBD") + + result = ValueSelector.select_value_tokens(tokens, "Amount") + + assert len(result) == 2 + + +class TestValueSelectorBankgiroField: + """Tests for Bankgiro field value selection.""" + + def test_selects_hyphenated_bankgiro(self): + tokens = _tokens("BG:", "123-4567") + + result = ValueSelector.select_value_tokens(tokens, "Bankgiro") + + assert len(result) == 1 + assert result[0].text == "123-4567" + + def test_selects_bankgiro_digits(self): + tokens = _tokens("Bankgiro", "1234567") + + result = ValueSelector.select_value_tokens(tokens, "Bankgiro") + + assert len(result) == 1 + assert result[0].text == "1234567" + + def test_selects_eight_digit_bankgiro(self): + tokens = _tokens("Bankgiro:", "12345678") + + result = ValueSelector.select_value_tokens(tokens, "Bankgiro") + + assert len(result) == 1 + assert result[0].text == "12345678" + + +class TestValueSelectorPlusgiroField: + """Tests for Plusgiro field value selection.""" + + def test_selects_hyphenated_plusgiro(self): + tokens = _tokens("PG:", "12345-6") + + result = ValueSelector.select_value_tokens(tokens, "Plusgiro") + + assert len(result) == 1 + assert result[0].text == "12345-6" + + def test_selects_plusgiro_digits(self): + tokens = _tokens("Plusgiro", "1234567") + + result = ValueSelector.select_value_tokens(tokens, "Plusgiro") + + assert len(result) == 1 + assert result[0].text == "1234567" + + +class TestValueSelectorOcrField: + """Tests for OCR reference number field value selection.""" + + def test_selects_longest_digit_sequence(self): + tokens = _tokens("OCR", "1234567890") + + result = ValueSelector.select_value_tokens(tokens, "OCR") + + assert len(result) == 1 + assert result[0].text == "1234567890" + + def test_selects_token_with_most_digits(self): + tokens = _tokens("Ref", "nr", "94228110015950070") + + result = ValueSelector.select_value_tokens(tokens, "OCR") + + assert len(result) == 1 + assert result[0].text == "94228110015950070" + + def test_ignores_short_digit_tokens(self): + """Tokens with fewer than 5 digits are not OCR references.""" + tokens = _tokens("OCR", "123") + + result = ValueSelector.select_value_tokens(tokens, "OCR") + + # Fallback: return all tokens since no valid OCR found + assert len(result) == 2 + + +class TestValueSelectorInvoiceNumberField: + """Tests for InvoiceNumber field value selection.""" + + def test_removes_swedish_label_keywords(self): + tokens = _tokens("Fakturanummer", "INV-2024-001") + + result = ValueSelector.select_value_tokens(tokens, "InvoiceNumber") + + assert len(result) == 1 + assert result[0].text == "INV-2024-001" + + def test_keeps_non_label_tokens(self): + tokens = _tokens("Nr", "12345") + + result = ValueSelector.select_value_tokens(tokens, "InvoiceNumber") + + assert len(result) == 1 + assert result[0].text == "12345" + + def test_multiple_value_tokens_kept(self): + """Multiple non-label tokens are all kept.""" + tokens = _tokens("Fakturanr", "INV", "2024", "001") + + result = ValueSelector.select_value_tokens(tokens, "InvoiceNumber") + + # "Fakturanr" is a label keyword, the rest are values + result_texts = [t.text for t in result] + assert "Fakturanr" not in result_texts + assert "INV" in result_texts + + +class TestValueSelectorOrgNumberField: + """Tests for supplier_org_number field value selection.""" + + def test_selects_org_number_with_hyphen(self): + tokens = _tokens("Org.nr", "556123-4567") + + result = ValueSelector.select_value_tokens(tokens, "supplier_org_number") + + assert len(result) == 1 + assert result[0].text == "556123-4567" + + def test_selects_org_number_without_hyphen(self): + tokens = _tokens("Organisationsnummer", "5561234567") + + result = ValueSelector.select_value_tokens(tokens, "supplier_org_number") + + assert len(result) == 1 + assert result[0].text == "5561234567" + + +class TestValueSelectorCustomerNumberField: + """Tests for customer_number field value selection.""" + + def test_removes_label_keeps_value(self): + tokens = _tokens("Kundnummer", "ABC-123") + + result = ValueSelector.select_value_tokens(tokens, "customer_number") + + assert len(result) == 1 + assert result[0].text == "ABC-123" + + +class TestValueSelectorPaymentLineField: + """Tests for payment_line field -- keeps all tokens.""" + + def test_keeps_all_tokens(self): + tokens = _tokens("#", "94228110015950070", "#", "15658", "00", "8", ">", "48666036#14#") + + result = ValueSelector.select_value_tokens(tokens, "payment_line") + + assert len(result) == len(tokens) + + +class TestValueSelectorFallback: + """Tests for fallback behavior.""" + + def test_unknown_field_returns_all_tokens(self): + tokens = _tokens("some", "unknown", "text") + + result = ValueSelector.select_value_tokens(tokens, "unknown_field") + + assert len(result) == 3 + + def test_empty_tokens_returns_empty(self): + result = ValueSelector.select_value_tokens([], "InvoiceDate") + + assert result == [] + + def test_single_token_returns_it(self): + tokens = _tokens("2024-01-15") + + result = ValueSelector.select_value_tokens(tokens, "InvoiceDate") + + assert len(result) == 1 + + def test_never_returns_empty_when_tokens_exist(self): + """Fallback ensures we never lose data -- always return something.""" + tokens = _tokens("Fakturadatum", "unknown_format") + + result = ValueSelector.select_value_tokens(tokens, "InvoiceDate") + + assert len(result) > 0 diff --git a/tests/services/__init__.py b/tests/services/__init__.py new file mode 100644 index 0000000..06d6012 --- /dev/null +++ b/tests/services/__init__.py @@ -0,0 +1 @@ +"""Tests for backend services.""" diff --git a/tests/services/test_data_mixer.py b/tests/services/test_data_mixer.py new file mode 100644 index 0000000..82bea22 --- /dev/null +++ b/tests/services/test_data_mixer.py @@ -0,0 +1,344 @@ +""" +Tests for Data Mixing Service. + +Tests cover: +1. get_mixing_ratio boundary values +2. build_mixed_dataset with temp filesystem +3. _find_pool_images matching logic +4. _image_to_label_path conversion +5. Edge cases (empty pool, no old data, cap) +""" + +import pytest +from pathlib import Path +from uuid import uuid4 + +from backend.web.services.data_mixer import ( + get_mixing_ratio, + build_mixed_dataset, + _collect_images, + _image_to_label_path, + _find_pool_images, + MIXING_RATIOS, + DEFAULT_MULTIPLIER, + MAX_OLD_SAMPLES, + MIN_POOL_SIZE, +) + + +# ============================================================================= +# Test Constants +# ============================================================================= + + +class TestConstants: + """Tests for data mixer constants.""" + + def test_mixing_ratios_defined(self): + """MIXING_RATIOS should have expected entries.""" + assert len(MIXING_RATIOS) == 4 + assert MIXING_RATIOS[0] == (10, 50) + assert MIXING_RATIOS[1] == (50, 20) + assert MIXING_RATIOS[2] == (200, 10) + assert MIXING_RATIOS[3] == (500, 5) + + def test_default_multiplier(self): + """DEFAULT_MULTIPLIER should be 5.""" + assert DEFAULT_MULTIPLIER == 5 + + def test_max_old_samples(self): + """MAX_OLD_SAMPLES should be 3000.""" + assert MAX_OLD_SAMPLES == 3000 + + def test_min_pool_size(self): + """MIN_POOL_SIZE should be 50.""" + assert MIN_POOL_SIZE == 50 + + +# ============================================================================= +# Test get_mixing_ratio +# ============================================================================= + + +class TestGetMixingRatio: + """Tests for get_mixing_ratio function.""" + + def test_1_sample_returns_50x(self): + """1 new sample should get 50x old data.""" + assert get_mixing_ratio(1) == 50 + + def test_10_samples_returns_50x(self): + """10 new samples (boundary) should get 50x.""" + assert get_mixing_ratio(10) == 50 + + def test_11_samples_returns_20x(self): + """11 new samples should get 20x.""" + assert get_mixing_ratio(11) == 20 + + def test_50_samples_returns_20x(self): + """50 new samples (boundary) should get 20x.""" + assert get_mixing_ratio(50) == 20 + + def test_51_samples_returns_10x(self): + """51 new samples should get 10x.""" + assert get_mixing_ratio(51) == 10 + + def test_200_samples_returns_10x(self): + """200 new samples (boundary) should get 10x.""" + assert get_mixing_ratio(200) == 10 + + def test_201_samples_returns_5x(self): + """201 new samples should get 5x.""" + assert get_mixing_ratio(201) == 5 + + def test_500_samples_returns_5x(self): + """500 new samples (boundary) should get 5x.""" + assert get_mixing_ratio(500) == 5 + + def test_1000_samples_returns_default(self): + """1000+ samples should get default multiplier (5x).""" + assert get_mixing_ratio(1000) == DEFAULT_MULTIPLIER + + +# ============================================================================= +# Test _collect_images +# ============================================================================= + + +class TestCollectImages: + """Tests for _collect_images function.""" + + def test_collects_png_files(self, tmp_path: Path): + """Should collect .png files.""" + (tmp_path / "img1.png").write_bytes(b"fake png") + (tmp_path / "img2.png").write_bytes(b"fake png") + + images = _collect_images(tmp_path) + assert len(images) == 2 + + def test_collects_jpg_files(self, tmp_path: Path): + """Should collect .jpg files.""" + (tmp_path / "img1.jpg").write_bytes(b"fake jpg") + + images = _collect_images(tmp_path) + assert len(images) == 1 + + def test_collects_both_types(self, tmp_path: Path): + """Should collect both .png and .jpg files.""" + (tmp_path / "img1.png").write_bytes(b"fake png") + (tmp_path / "img2.jpg").write_bytes(b"fake jpg") + + images = _collect_images(tmp_path) + assert len(images) == 2 + + def test_ignores_other_files(self, tmp_path: Path): + """Should ignore non-image files.""" + (tmp_path / "data.txt").write_text("not an image") + (tmp_path / "data.yaml").write_text("yaml") + (tmp_path / "img.png").write_bytes(b"png") + + images = _collect_images(tmp_path) + assert len(images) == 1 + + def test_returns_empty_for_nonexistent_dir(self, tmp_path: Path): + """Should return empty list for nonexistent directory.""" + images = _collect_images(tmp_path / "nonexistent") + assert images == [] + + +# ============================================================================= +# Test _image_to_label_path +# ============================================================================= + + +class TestImageToLabelPath: + """Tests for _image_to_label_path function.""" + + def test_converts_train_image_to_label(self, tmp_path: Path): + """Should convert images/train/img.png to labels/train/img.txt.""" + image_path = tmp_path / "dataset" / "images" / "train" / "doc1_page1.png" + label_path = _image_to_label_path(image_path) + + assert label_path.name == "doc1_page1.txt" + assert "labels" in str(label_path) + assert "train" in str(label_path) + + def test_converts_val_image_to_label(self, tmp_path: Path): + """Should convert images/val/img.jpg to labels/val/img.txt.""" + image_path = tmp_path / "dataset" / "images" / "val" / "doc2_page3.jpg" + label_path = _image_to_label_path(image_path) + + assert label_path.name == "doc2_page3.txt" + assert "labels" in str(label_path) + assert "val" in str(label_path) + + +# ============================================================================= +# Test _find_pool_images +# ============================================================================= + + +class TestFindPoolImages: + """Tests for _find_pool_images function.""" + + def _create_dataset(self, base_path: Path, doc_ids: list[str], split: str = "train") -> None: + """Helper to create a dataset structure with images.""" + images_dir = base_path / "images" / split + images_dir.mkdir(parents=True, exist_ok=True) + for doc_id in doc_ids: + (images_dir / f"{doc_id}_page1.png").write_bytes(b"img") + (images_dir / f"{doc_id}_page2.png").write_bytes(b"img") + + def test_finds_matching_images(self, tmp_path: Path): + """Should find images matching pool document IDs.""" + doc_id1 = str(uuid4()) + doc_id2 = str(uuid4()) + self._create_dataset(tmp_path, [doc_id1, doc_id2]) + + pool_ids = {doc_id1} + images = _find_pool_images(tmp_path, pool_ids) + + assert len(images) == 2 # 2 pages for doc_id1 + assert all(doc_id1 in str(img) for img in images) + + def test_ignores_non_pool_images(self, tmp_path: Path): + """Should not return images for documents not in pool.""" + doc_id1 = str(uuid4()) + doc_id2 = str(uuid4()) + self._create_dataset(tmp_path, [doc_id1, doc_id2]) + + pool_ids = {doc_id1} + images = _find_pool_images(tmp_path, pool_ids) + + # Only doc_id1 images should be found + for img in images: + assert doc_id1 in str(img) + assert doc_id2 not in str(img) + + def test_searches_all_splits(self, tmp_path: Path): + """Should search train, val, and test splits.""" + doc_id = str(uuid4()) + for split in ("train", "val", "test"): + self._create_dataset(tmp_path, [doc_id], split=split) + + images = _find_pool_images(tmp_path, {doc_id}) + assert len(images) == 6 # 2 pages * 3 splits + + def test_empty_pool_returns_empty(self, tmp_path: Path): + """Should return empty list for empty pool IDs.""" + self._create_dataset(tmp_path, [str(uuid4())]) + + images = _find_pool_images(tmp_path, set()) + assert images == [] + + +# ============================================================================= +# Test build_mixed_dataset +# ============================================================================= + + +class TestBuildMixedDataset: + """Tests for build_mixed_dataset function.""" + + def _setup_base_dataset(self, base_path: Path, num_old: int = 20) -> None: + """Create a base dataset with old training images.""" + for split in ("train", "val"): + img_dir = base_path / "images" / split + lbl_dir = base_path / "labels" / split + img_dir.mkdir(parents=True, exist_ok=True) + lbl_dir.mkdir(parents=True, exist_ok=True) + + count = int(num_old * 0.8) if split == "train" else num_old - int(num_old * 0.8) + for i in range(count): + doc_id = str(uuid4()) + img_file = img_dir / f"{doc_id}_page1.png" + lbl_file = lbl_dir / f"{doc_id}_page1.txt" + img_file.write_bytes(b"fake image data") + lbl_file.write_text("0 0.5 0.5 0.1 0.1\n") + + def _setup_pool_images(self, base_path: Path, doc_ids: list[str]) -> None: + """Add pool images to the base dataset.""" + img_dir = base_path / "images" / "train" + lbl_dir = base_path / "labels" / "train" + img_dir.mkdir(parents=True, exist_ok=True) + lbl_dir.mkdir(parents=True, exist_ok=True) + + for doc_id in doc_ids: + img_file = img_dir / f"{doc_id}_page1.png" + lbl_file = lbl_dir / f"{doc_id}_page1.txt" + img_file.write_bytes(b"pool image data") + lbl_file.write_text("0 0.5 0.5 0.2 0.2\n") + + @pytest.fixture + def base_dataset(self, tmp_path: Path) -> Path: + """Create a base dataset for testing.""" + base_path = tmp_path / "base_dataset" + self._setup_base_dataset(base_path, num_old=20) + return base_path + + def test_builds_output_structure(self, base_dataset: Path, tmp_path: Path): + """Should create proper YOLO directory structure.""" + pool_ids = [uuid4() for _ in range(5)] + self._setup_pool_images(base_dataset, [str(pid) for pid in pool_ids]) + output_dir = tmp_path / "mixed_output" + + result = build_mixed_dataset( + pool_document_ids=pool_ids, + base_dataset_path=base_dataset, + output_dir=output_dir, + ) + + assert (output_dir / "images" / "train").exists() + assert (output_dir / "images" / "val").exists() + assert (output_dir / "labels" / "train").exists() + assert (output_dir / "labels" / "val").exists() + assert (output_dir / "data.yaml").exists() + + def test_returns_correct_metadata(self, base_dataset: Path, tmp_path: Path): + """Should return correct counts and metadata.""" + pool_ids = [uuid4() for _ in range(5)] + self._setup_pool_images(base_dataset, [str(pid) for pid in pool_ids]) + output_dir = tmp_path / "mixed_output" + + result = build_mixed_dataset( + pool_document_ids=pool_ids, + base_dataset_path=base_dataset, + output_dir=output_dir, + ) + + assert "data_yaml" in result + assert "total_images" in result + assert "old_images" in result + assert "new_images" in result + assert "mixing_ratio" in result + assert result["total_images"] == result["old_images"] + result["new_images"] + + def test_mixing_ratio_applied(self, base_dataset: Path, tmp_path: Path): + """Should use correct mixing ratio based on pool size.""" + pool_ids = [uuid4() for _ in range(5)] + self._setup_pool_images(base_dataset, [str(pid) for pid in pool_ids]) + output_dir = tmp_path / "mixed_output" + + result = build_mixed_dataset( + pool_document_ids=pool_ids, + base_dataset_path=base_dataset, + output_dir=output_dir, + ) + + # 5 new samples -> 50x multiplier + assert result["mixing_ratio"] == 50 + + def test_seed_reproducibility(self, base_dataset: Path, tmp_path: Path): + """Same seed should produce same output.""" + pool_ids = [uuid4() for _ in range(3)] + self._setup_pool_images(base_dataset, [str(pid) for pid in pool_ids]) + + out1 = tmp_path / "out1" + out2 = tmp_path / "out2" + + r1 = build_mixed_dataset(pool_ids, base_dataset, out1, seed=42) + r2 = build_mixed_dataset(pool_ids, base_dataset, out2, seed=42) + + assert r1["old_images"] == r2["old_images"] + assert r1["new_images"] == r2["new_images"] + assert r1["total_images"] == r2["total_images"] diff --git a/tests/services/test_gating_validator.py b/tests/services/test_gating_validator.py new file mode 100644 index 0000000..6752ab2 --- /dev/null +++ b/tests/services/test_gating_validator.py @@ -0,0 +1,540 @@ +""" +Unit tests for gating validation service. + +Tests the quality gate validation logic for model deployment: +- Gate 1: mAP regression validation +- Gate 2: detection rate validation +- Overall status computation +- Full validation workflow with mocked dependencies +""" + +import pytest +from unittest.mock import MagicMock, Mock, patch +from uuid import UUID, uuid4 + +from backend.web.services.gating_validator import ( + GATE1_PASS_THRESHOLD, + GATE1_REVIEW_THRESHOLD, + GATE2_PASS_THRESHOLD, + classify_gate1, + classify_gate2, + compute_overall_status, + run_gating_validation, +) +from backend.data.admin_models import GatingResult + + +class TestClassifyGate1: + """Test Gate 1 classification (mAP drop thresholds).""" + + def test_pass_below_threshold(self): + """Test mAP drop < 0.01 returns pass.""" + assert classify_gate1(0.009) == "pass" + assert classify_gate1(0.005) == "pass" + assert classify_gate1(0.0) == "pass" + assert classify_gate1(-0.01) == "pass" # negative drop (improvement) + + def test_pass_boundary(self): + """Test mAP drop exactly at pass threshold.""" + # 0.01 should be review (not pass), since condition is < 0.01 + assert classify_gate1(GATE1_PASS_THRESHOLD) == "review" + + def test_review_in_range(self): + """Test mAP drop in review range [0.01, 0.03).""" + assert classify_gate1(0.01) == "review" + assert classify_gate1(0.015) == "review" + assert classify_gate1(0.02) == "review" + assert classify_gate1(0.029) == "review" + + def test_review_boundary(self): + """Test mAP drop exactly at review threshold.""" + # 0.03 should be reject (not review), since condition is < 0.03 + assert classify_gate1(GATE1_REVIEW_THRESHOLD) == "reject" + + def test_reject_above_threshold(self): + """Test mAP drop >= 0.03 returns reject.""" + assert classify_gate1(0.03) == "reject" + assert classify_gate1(0.05) == "reject" + assert classify_gate1(0.10) == "reject" + assert classify_gate1(1.0) == "reject" + + +class TestClassifyGate2: + """Test Gate 2 classification (detection rate thresholds).""" + + def test_pass_above_threshold(self): + """Test detection rate >= 0.80 returns pass.""" + assert classify_gate2(0.80) == "pass" + assert classify_gate2(0.85) == "pass" + assert classify_gate2(0.99) == "pass" + assert classify_gate2(1.0) == "pass" + + def test_pass_boundary(self): + """Test detection rate exactly at pass threshold.""" + assert classify_gate2(GATE2_PASS_THRESHOLD) == "pass" + + def test_review_below_threshold(self): + """Test detection rate < 0.80 returns review.""" + assert classify_gate2(0.79) == "review" + assert classify_gate2(0.75) == "review" + assert classify_gate2(0.50) == "review" + assert classify_gate2(0.0) == "review" + + +class TestComputeOverallStatus: + """Test overall status computation from individual gates.""" + + def test_both_pass(self): + """Test both gates pass -> overall pass.""" + assert compute_overall_status("pass", "pass") == "pass" + + def test_gate1_reject_gate2_pass(self): + """Test any reject -> overall reject.""" + assert compute_overall_status("reject", "pass") == "reject" + + def test_gate1_pass_gate2_reject(self): + """Test any reject -> overall reject.""" + assert compute_overall_status("pass", "reject") == "reject" + + def test_both_reject(self): + """Test both reject -> overall reject.""" + assert compute_overall_status("reject", "reject") == "reject" + + def test_gate1_review_gate2_pass(self): + """Test any review (no reject) -> overall review.""" + assert compute_overall_status("review", "pass") == "review" + + def test_gate1_pass_gate2_review(self): + """Test any review (no reject) -> overall review.""" + assert compute_overall_status("pass", "review") == "review" + + def test_both_review(self): + """Test both review -> overall review.""" + assert compute_overall_status("review", "review") == "review" + + def test_gate1_reject_gate2_review(self): + """Test reject takes precedence over review.""" + assert compute_overall_status("reject", "review") == "reject" + + def test_gate1_review_gate2_reject(self): + """Test reject takes precedence over review.""" + assert compute_overall_status("review", "reject") == "reject" + + +class TestRunGatingValidation: + """Test full gating validation workflow with mocked dependencies.""" + + @pytest.fixture + def mock_model_version_id(self): + """Generate a model version ID for testing.""" + return uuid4() + + @pytest.fixture + def mock_base_model_version_id(self): + """Generate a base model version ID for testing.""" + return uuid4() + + @pytest.fixture + def mock_task_id(self): + """Generate a task ID for testing.""" + return uuid4() + + @pytest.fixture + def mock_base_model(self): + """Create a mock base model with metrics.""" + model = Mock() + model.metrics_mAP = 0.85 + return model + + @pytest.fixture + def mock_new_model(self): + """Create a mock new model with metrics.""" + model = Mock() + model.metrics_mAP = 0.82 + return model + + def test_gate1_pass_gate2_pass( + self, + mock_model_version_id, + mock_base_model_version_id, + mock_task_id, + mock_base_model, + mock_new_model, + ): + """Test validation with both gates passing.""" + # Setup: base mAP=0.85, new mAP=0.84 -> drop=0.01 (review) + # But new model mAP=0.82 -> gate2 pass + mock_base_model.metrics_mAP = 0.85 + mock_new_model.metrics_mAP = 0.82 + + mock_val_metrics = {"mAP50": 0.84} + + with patch("backend.web.services.gating_validator.ModelVersionRepository") as MockRepo, \ + patch("backend.web.services.gating_validator.get_session_context") as mock_session_ctx, \ + patch("shared.training.YOLOTrainer") as MockTrainer, \ + patch("backend.web.services.gating_validator._update_model_gating_status") as mock_update: + + # Mock repository + mock_repo = MockRepo.return_value + mock_repo.get.side_effect = lambda id: mock_base_model if str(id) == str(mock_base_model_version_id) else mock_new_model + + # Mock session context + mock_session = MagicMock() + mock_session_ctx.return_value.__enter__.return_value = mock_session + + # Mock YOLO trainer + mock_trainer = MockTrainer.return_value + mock_trainer.validate.return_value = mock_val_metrics + + # Execute + result = run_gating_validation( + model_version_id=mock_model_version_id, + new_model_path="/path/to/model.pt", + base_model_version_id=mock_base_model_version_id, + data_yaml="/path/to/data.yaml", + task_id=mock_task_id, + ) + + # Verify + assert result.gate1_status == "review" # 0.85 - 0.84 = 0.01 + assert result.gate1_original_mAP == 0.85 + assert result.gate1_new_mAP == 0.84 + assert result.gate1_mAP_drop == pytest.approx(0.01, abs=1e-6) + + assert result.gate2_status == "pass" # 0.82 >= 0.80 + assert result.gate2_detection_rate == 0.82 + + assert result.overall_status == "review" # Any review -> review + + # Verify DB operations + mock_session.add.assert_called() + mock_session.commit.assert_called() + mock_update.assert_called_once_with(str(mock_model_version_id), "review") + + def test_gate1_reject_due_to_large_drop( + self, + mock_model_version_id, + mock_base_model_version_id, + mock_task_id, + mock_base_model, + mock_new_model, + ): + """Test Gate 1 reject when mAP drop >= 0.03.""" + mock_base_model.metrics_mAP = 0.85 + mock_new_model.metrics_mAP = 0.82 + + mock_val_metrics = {"mAP50": 0.81} # 0.85 - 0.81 = 0.04 (reject) + + with patch("backend.web.services.gating_validator.ModelVersionRepository") as MockRepo, \ + patch("backend.web.services.gating_validator.get_session_context") as mock_session_ctx, \ + patch("shared.training.YOLOTrainer") as MockTrainer, \ + patch("backend.web.services.gating_validator._update_model_gating_status") as mock_update: + + mock_repo = MockRepo.return_value + mock_repo.get.side_effect = lambda id: mock_base_model if str(id) == str(mock_base_model_version_id) else mock_new_model + + mock_session = MagicMock() + mock_session_ctx.return_value.__enter__.return_value = mock_session + + mock_trainer = MockTrainer.return_value + mock_trainer.validate.return_value = mock_val_metrics + + result = run_gating_validation( + model_version_id=mock_model_version_id, + new_model_path="/path/to/model.pt", + base_model_version_id=mock_base_model_version_id, + data_yaml="/path/to/data.yaml", + task_id=mock_task_id, + ) + + assert result.gate1_status == "reject" + assert result.gate1_mAP_drop == pytest.approx(0.04, abs=1e-6) + assert result.overall_status == "reject" # Any reject -> reject + + mock_update.assert_called_once_with(str(mock_model_version_id), "reject") + + def test_gate2_review_due_to_low_detection_rate( + self, + mock_model_version_id, + mock_base_model_version_id, + mock_task_id, + mock_base_model, + mock_new_model, + ): + """Test Gate 2 review when detection rate < 0.80.""" + mock_base_model.metrics_mAP = 0.85 + mock_new_model.metrics_mAP = 0.75 # Below 0.80 threshold + + mock_val_metrics = {"mAP50": 0.845} # Gate 1: 0.85 - 0.845 = 0.005 (pass) + + with patch("backend.web.services.gating_validator.ModelVersionRepository") as MockRepo, \ + patch("backend.web.services.gating_validator.get_session_context") as mock_session_ctx, \ + patch("shared.training.YOLOTrainer") as MockTrainer, \ + patch("backend.web.services.gating_validator._update_model_gating_status") as mock_update: + + mock_repo = MockRepo.return_value + mock_repo.get.side_effect = lambda id: mock_base_model if str(id) == str(mock_base_model_version_id) else mock_new_model + + mock_session = MagicMock() + mock_session_ctx.return_value.__enter__.return_value = mock_session + + mock_trainer = MockTrainer.return_value + mock_trainer.validate.return_value = mock_val_metrics + + result = run_gating_validation( + model_version_id=mock_model_version_id, + new_model_path="/path/to/model.pt", + base_model_version_id=mock_base_model_version_id, + data_yaml="/path/to/data.yaml", + task_id=mock_task_id, + ) + + assert result.gate1_status == "pass" + assert result.gate2_status == "review" # 0.75 < 0.80 + assert result.gate2_detection_rate == 0.75 + assert result.overall_status == "review" + + mock_update.assert_called_once_with(str(mock_model_version_id), "review") + + def test_no_base_model_skips_gate1( + self, + mock_model_version_id, + mock_task_id, + mock_new_model, + ): + """Test Gate 1 passes when no base model is provided.""" + mock_new_model.metrics_mAP = 0.85 + + with patch("backend.web.services.gating_validator.ModelVersionRepository") as MockRepo, \ + patch("backend.web.services.gating_validator.get_session_context") as mock_session_ctx, \ + patch("backend.web.services.gating_validator._update_model_gating_status") as mock_update: + + mock_repo = MockRepo.return_value + mock_repo.get.return_value = mock_new_model + + mock_session = MagicMock() + mock_session_ctx.return_value.__enter__.return_value = mock_session + + result = run_gating_validation( + model_version_id=mock_model_version_id, + new_model_path="/path/to/model.pt", + base_model_version_id=None, + data_yaml="/path/to/data.yaml", + task_id=mock_task_id, + ) + + assert result.gate1_status == "pass" # Skipped + assert result.gate1_original_mAP is None + assert result.gate1_new_mAP is None + assert result.gate1_mAP_drop is None + + assert result.gate2_status == "pass" # 0.85 >= 0.80 + assert result.overall_status == "pass" + + mock_update.assert_called_once_with(str(mock_model_version_id), "pass") + + def test_base_model_without_metrics_skips_gate1( + self, + mock_model_version_id, + mock_base_model_version_id, + mock_task_id, + mock_base_model, + mock_new_model, + ): + """Test Gate 1 passes when base model has no metrics.""" + mock_base_model.metrics_mAP = None + mock_new_model.metrics_mAP = 0.85 + + with patch("backend.web.services.gating_validator.ModelVersionRepository") as MockRepo, \ + patch("backend.web.services.gating_validator.get_session_context") as mock_session_ctx, \ + patch("backend.web.services.gating_validator._update_model_gating_status") as mock_update: + + mock_repo = MockRepo.return_value + mock_repo.get.side_effect = lambda id: mock_base_model if str(id) == str(mock_base_model_version_id) else mock_new_model + + mock_session = MagicMock() + mock_session_ctx.return_value.__enter__.return_value = mock_session + + result = run_gating_validation( + model_version_id=mock_model_version_id, + new_model_path="/path/to/model.pt", + base_model_version_id=mock_base_model_version_id, + data_yaml="/path/to/data.yaml", + task_id=mock_task_id, + ) + + assert result.gate1_status == "pass" # Skipped due to no base metrics + assert result.gate2_status == "pass" + assert result.overall_status == "pass" + + def test_validation_failure_marks_gate1_review( + self, + mock_model_version_id, + mock_base_model_version_id, + mock_task_id, + mock_base_model, + mock_new_model, + ): + """Test Gate 1 review when validation raises exception.""" + mock_base_model.metrics_mAP = 0.85 + mock_new_model.metrics_mAP = 0.82 + + with patch("backend.web.services.gating_validator.ModelVersionRepository") as MockRepo, \ + patch("backend.web.services.gating_validator.get_session_context") as mock_session_ctx, \ + patch("shared.training.YOLOTrainer") as MockTrainer, \ + patch("backend.web.services.gating_validator._update_model_gating_status") as mock_update: + + mock_repo = MockRepo.return_value + mock_repo.get.side_effect = lambda id: mock_base_model if str(id) == str(mock_base_model_version_id) else mock_new_model + + mock_session = MagicMock() + mock_session_ctx.return_value.__enter__.return_value = mock_session + + # Mock trainer to raise exception + mock_trainer = MockTrainer.return_value + mock_trainer.validate.side_effect = RuntimeError("Validation failed") + + result = run_gating_validation( + model_version_id=mock_model_version_id, + new_model_path="/path/to/model.pt", + base_model_version_id=mock_base_model_version_id, + data_yaml="/path/to/data.yaml", + task_id=mock_task_id, + ) + + assert result.gate1_status == "review" # Exception -> review + assert result.gate2_status == "pass" + assert result.overall_status == "review" + + mock_update.assert_called_once_with(str(mock_model_version_id), "review") + + def test_validation_returns_none_mAP_marks_gate1_review( + self, + mock_model_version_id, + mock_base_model_version_id, + mock_task_id, + mock_base_model, + mock_new_model, + ): + """Test Gate 1 review when validation returns None mAP.""" + mock_base_model.metrics_mAP = 0.85 + mock_new_model.metrics_mAP = 0.82 + + mock_val_metrics = {"mAP50": None} # No mAP returned + + with patch("backend.web.services.gating_validator.ModelVersionRepository") as MockRepo, \ + patch("backend.web.services.gating_validator.get_session_context") as mock_session_ctx, \ + patch("shared.training.YOLOTrainer") as MockTrainer, \ + patch("backend.web.services.gating_validator._update_model_gating_status") as mock_update: + + mock_repo = MockRepo.return_value + mock_repo.get.side_effect = lambda id: mock_base_model if str(id) == str(mock_base_model_version_id) else mock_new_model + + mock_session = MagicMock() + mock_session_ctx.return_value.__enter__.return_value = mock_session + + mock_trainer = MockTrainer.return_value + mock_trainer.validate.return_value = mock_val_metrics + + result = run_gating_validation( + model_version_id=mock_model_version_id, + new_model_path="/path/to/model.pt", + base_model_version_id=mock_base_model_version_id, + data_yaml="/path/to/data.yaml", + task_id=mock_task_id, + ) + + assert result.gate1_status == "review" # None mAP -> review + assert result.gate1_new_mAP is None + assert result.gate2_status == "pass" + assert result.overall_status == "review" + + def test_gate2_exception_marks_gate2_review( + self, + mock_model_version_id, + mock_base_model_version_id, + mock_task_id, + mock_base_model, + mock_new_model, + ): + """Test Gate 2 review when accessing new model metrics raises exception.""" + mock_base_model.metrics_mAP = 0.85 + mock_new_model.metrics_mAP = 0.82 + + mock_val_metrics = {"mAP50": 0.84} + + with patch("backend.web.services.gating_validator.ModelVersionRepository") as MockRepo, \ + patch("backend.web.services.gating_validator.get_session_context") as mock_session_ctx, \ + patch("shared.training.YOLOTrainer") as MockTrainer, \ + patch("backend.web.services.gating_validator._update_model_gating_status") as mock_update: + + mock_repo = MockRepo.return_value + + # Mock to raise exception for new model on second call + def get_side_effect(id): + if str(id) == str(mock_base_model_version_id): + return mock_base_model + elif str(id) == str(mock_model_version_id): + raise RuntimeError("Cannot fetch new model") + return None + + mock_repo.get.side_effect = get_side_effect + + mock_session = MagicMock() + mock_session_ctx.return_value.__enter__.return_value = mock_session + + mock_trainer = MockTrainer.return_value + mock_trainer.validate.return_value = mock_val_metrics + + result = run_gating_validation( + model_version_id=mock_model_version_id, + new_model_path="/path/to/model.pt", + base_model_version_id=mock_base_model_version_id, + data_yaml="/path/to/data.yaml", + task_id=mock_task_id, + ) + + assert result.gate1_status == "review" # 0.85 - 0.84 = 0.01 + assert result.gate2_status == "review" # Exception -> review + assert result.overall_status == "review" + + def test_string_uuids_accepted( + self, + mock_model_version_id, + mock_base_model_version_id, + mock_task_id, + mock_base_model, + mock_new_model, + ): + """Test that string UUIDs are accepted and converted properly.""" + mock_base_model.metrics_mAP = 0.85 + mock_new_model.metrics_mAP = 0.85 + + mock_val_metrics = {"mAP50": 0.85} + + with patch("backend.web.services.gating_validator.ModelVersionRepository") as MockRepo, \ + patch("backend.web.services.gating_validator.get_session_context") as mock_session_ctx, \ + patch("shared.training.YOLOTrainer") as MockTrainer, \ + patch("backend.web.services.gating_validator._update_model_gating_status") as mock_update: + + mock_repo = MockRepo.return_value + mock_repo.get.side_effect = lambda id: mock_base_model if str(id) == str(mock_base_model_version_id) else mock_new_model + + mock_session = MagicMock() + mock_session_ctx.return_value.__enter__.return_value = mock_session + + mock_trainer = MockTrainer.return_value + mock_trainer.validate.return_value = mock_val_metrics + + # Pass string UUIDs + result = run_gating_validation( + model_version_id=str(mock_model_version_id), + new_model_path="/path/to/model.pt", + base_model_version_id=str(mock_base_model_version_id), + data_yaml="/path/to/data.yaml", + task_id=str(mock_task_id), + ) + + assert result.model_version_id == mock_model_version_id + assert result.task_id == mock_task_id + assert result.overall_status == "pass" diff --git a/tests/shared/bbox/test_expander.py b/tests/shared/bbox/test_expander.py index c533597..67b383f 100644 --- a/tests/shared/bbox/test_expander.py +++ b/tests/shared/bbox/test_expander.py @@ -1,556 +1,170 @@ """ -Tests for expand_bbox function. +Tests for expand_bbox function with uniform pixel padding. -Tests verify that bbox expansion works correctly with center-point scaling, -directional compensation, max padding clamping, and image boundary handling. +Verifies that bbox expansion adds uniform padding on all sides, +clamps to image boundaries, and returns integer tuples. """ import pytest -from shared.bbox import ( - expand_bbox, - ScaleStrategy, - FIELD_SCALE_STRATEGIES, - DEFAULT_STRATEGY, -) +from shared.bbox import expand_bbox +from shared.bbox.scale_strategy import UNIFORM_PAD -class TestExpandBboxCenterScaling: - """Tests for center-point based scaling.""" +class TestExpandBboxUniformPadding: + """Tests for uniform padding on all sides.""" - def test_center_scaling_expands_symmetrically(self): - """Verify bbox expands symmetrically around center when no extra ratios.""" - # 100x50 bbox at (100, 200) - bbox = (100, 200, 200, 250) - strategy = ScaleStrategy( - scale_x=1.2, # 20% wider - scale_y=1.4, # 40% taller - max_pad_x=1000, # Large to avoid clamping - max_pad_y=1000, - ) + def test_adds_uniform_pad_on_all_sides(self): + """Verify default pad is applied equally on all four sides.""" + bbox = (100, 200, 300, 250) result = expand_bbox( bbox=bbox, image_width=1000, image_height=1000, - field_type="test_field", - strategies={"test_field": strategy}, ) - # Original: width=100, height=50 - # New: width=120, height=70 - # Center: (150, 225) - # Expected: x0=150-60=90, x1=150+60=210, y0=225-35=190, y1=225+35=260 - assert result[0] == 90 # x0 - assert result[1] == 190 # y0 - assert result[2] == 210 # x1 - assert result[3] == 260 # y1 - - def test_no_scaling_returns_original(self): - """Verify scale=1.0 with no extras returns original bbox.""" - bbox = (100, 200, 200, 250) - strategy = ScaleStrategy( - scale_x=1.0, - scale_y=1.0, - max_pad_x=1000, - max_pad_y=1000, + assert result == ( + 100 - UNIFORM_PAD, + 200 - UNIFORM_PAD, + 300 + UNIFORM_PAD, + 250 + UNIFORM_PAD, ) + def test_custom_pad_value(self): + """Verify custom pad overrides default.""" + bbox = (100, 200, 300, 250) + result = expand_bbox( bbox=bbox, image_width=1000, image_height=1000, - field_type="test_field", - strategies={"test_field": strategy}, + pad=20, ) - assert result == (100, 200, 200, 250) + assert result == (80, 180, 320, 270) - -class TestExpandBboxDirectionalCompensation: - """Tests for directional compensation (extra ratios).""" - - def test_extra_top_expands_upward(self): - """Verify extra_top_ratio adds expansion toward top.""" - bbox = (100, 200, 200, 250) # width=100, height=50 - strategy = ScaleStrategy( - scale_x=1.0, - scale_y=1.0, - extra_top_ratio=0.5, # Add 50% of height to top - max_pad_x=1000, - max_pad_y=1000, - ) + def test_zero_pad_returns_original(self): + """Verify pad=0 returns original bbox as integers.""" + bbox = (100, 200, 300, 250) result = expand_bbox( bbox=bbox, image_width=1000, image_height=1000, - field_type="test_field", - strategies={"test_field": strategy}, + pad=0, ) - # extra_top = 50 * 0.5 = 25 - assert result[0] == 100 # x0 unchanged - assert result[1] == 175 # y0 = 200 - 25 - assert result[2] == 200 # x1 unchanged - assert result[3] == 250 # y1 unchanged + assert result == (100, 200, 300, 250) - def test_extra_left_expands_leftward(self): - """Verify extra_left_ratio adds expansion toward left.""" - bbox = (100, 200, 200, 250) # width=100 - strategy = ScaleStrategy( - scale_x=1.0, - scale_y=1.0, - extra_left_ratio=0.8, # Add 80% of width to left - max_pad_x=1000, - max_pad_y=1000, - ) + def test_all_field_types_get_same_padding(self): + """Verify no field-specific expansion -- same result regardless of field.""" + bbox = (100, 200, 300, 250) - result = expand_bbox( - bbox=bbox, - image_width=1000, - image_height=1000, - field_type="test_field", - strategies={"test_field": strategy}, - ) + result_a = expand_bbox(bbox=bbox, image_width=1000, image_height=1000) + result_b = expand_bbox(bbox=bbox, image_width=1000, image_height=1000) - # extra_left = 100 * 0.8 = 80 - assert result[0] == 20 # x0 = 100 - 80 - assert result[1] == 200 # y0 unchanged - assert result[2] == 200 # x1 unchanged - assert result[3] == 250 # y1 unchanged - - def test_extra_right_expands_rightward(self): - """Verify extra_right_ratio adds expansion toward right.""" - bbox = (100, 200, 200, 250) # width=100 - strategy = ScaleStrategy( - scale_x=1.0, - scale_y=1.0, - extra_right_ratio=0.3, # Add 30% of width to right - max_pad_x=1000, - max_pad_y=1000, - ) - - result = expand_bbox( - bbox=bbox, - image_width=1000, - image_height=1000, - field_type="test_field", - strategies={"test_field": strategy}, - ) - - # extra_right = 100 * 0.3 = 30 - assert result[0] == 100 # x0 unchanged - assert result[1] == 200 # y0 unchanged - assert result[2] == 230 # x1 = 200 + 30 - assert result[3] == 250 # y1 unchanged - - def test_extra_bottom_expands_downward(self): - """Verify extra_bottom_ratio adds expansion toward bottom.""" - bbox = (100, 200, 200, 250) # height=50 - strategy = ScaleStrategy( - scale_x=1.0, - scale_y=1.0, - extra_bottom_ratio=0.4, # Add 40% of height to bottom - max_pad_x=1000, - max_pad_y=1000, - ) - - result = expand_bbox( - bbox=bbox, - image_width=1000, - image_height=1000, - field_type="test_field", - strategies={"test_field": strategy}, - ) - - # extra_bottom = 50 * 0.4 = 20 - assert result[0] == 100 # x0 unchanged - assert result[1] == 200 # y0 unchanged - assert result[2] == 200 # x1 unchanged - assert result[3] == 270 # y1 = 250 + 20 - - def test_combined_scaling_and_directional(self): - """Verify scale + directional compensation work together.""" - bbox = (100, 200, 200, 250) # width=100, height=50 - strategy = ScaleStrategy( - scale_x=1.2, # 20% wider -> 120 width - scale_y=1.0, # no height change - extra_left_ratio=0.5, # Add 50% of width to left - max_pad_x=1000, - max_pad_y=1000, - ) - - result = expand_bbox( - bbox=bbox, - image_width=1000, - image_height=1000, - field_type="test_field", - strategies={"test_field": strategy}, - ) - - # Center: x=150 - # After scale: width=120 -> x0=150-60=90, x1=150+60=210 - # After extra_left: x0 = 90 - (100 * 0.5) = 40 - assert result[0] == 40 # x0 - assert result[2] == 210 # x1 - - -class TestExpandBboxMaxPadClamping: - """Tests for max padding clamping.""" - - def test_max_pad_x_limits_horizontal_expansion(self): - """Verify max_pad_x limits expansion on left and right.""" - bbox = (100, 200, 200, 250) # width=100 - strategy = ScaleStrategy( - scale_x=2.0, # Double width (would add 50 each side) - scale_y=1.0, - max_pad_x=30, # Limit to 30 pixels each side - max_pad_y=1000, - ) - - result = expand_bbox( - bbox=bbox, - image_width=1000, - image_height=1000, - field_type="test_field", - strategies={"test_field": strategy}, - ) - - # Scale would make: x0=100, x1=200 -> x0=50, x1=250 (50px each side) - # But max_pad_x=30 limits to: x0=70, x1=230 - assert result[0] == 70 # x0 = 100 - 30 - assert result[2] == 230 # x1 = 200 + 30 - - def test_max_pad_y_limits_vertical_expansion(self): - """Verify max_pad_y limits expansion on top and bottom.""" - bbox = (100, 200, 200, 250) # height=50 - strategy = ScaleStrategy( - scale_x=1.0, - scale_y=3.0, # Triple height (would add 50 each side) - max_pad_x=1000, - max_pad_y=20, # Limit to 20 pixels each side - ) - - result = expand_bbox( - bbox=bbox, - image_width=1000, - image_height=1000, - field_type="test_field", - strategies={"test_field": strategy}, - ) - - # Scale would make: y0=175, y1=275 (50px each side) - # But max_pad_y=20 limits to: y0=180, y1=270 - assert result[1] == 180 # y0 = 200 - 20 - assert result[3] == 270 # y1 = 250 + 20 - - def test_max_pad_preserves_asymmetry(self): - """Verify max_pad clamping preserves asymmetric expansion.""" - bbox = (100, 200, 200, 250) # width=100 - strategy = ScaleStrategy( - scale_x=1.0, - scale_y=1.0, - extra_left_ratio=1.0, # 100px left expansion - extra_right_ratio=0.0, # No right expansion - max_pad_x=50, # Limit to 50 pixels - max_pad_y=1000, - ) - - result = expand_bbox( - bbox=bbox, - image_width=1000, - image_height=1000, - field_type="test_field", - strategies={"test_field": strategy}, - ) - - # Left would expand 100, clamped to 50 - # Right stays at 0 - assert result[0] == 50 # x0 = 100 - 50 - assert result[2] == 200 # x1 unchanged + assert result_a == result_b class TestExpandBboxImageBoundaryClamping: - """Tests for image boundary clamping.""" + """Tests for clamping to image boundaries.""" - def test_clamps_to_left_boundary(self): - """Verify x0 is clamped to 0.""" - bbox = (10, 200, 110, 250) # Close to left edge - strategy = ScaleStrategy( - scale_x=1.0, - scale_y=1.0, - extra_left_ratio=0.5, # Would push x0 below 0 - max_pad_x=1000, - max_pad_y=1000, - ) + def test_clamps_x0_to_zero(self): + bbox = (5, 200, 100, 250) - result = expand_bbox( - bbox=bbox, - image_width=1000, - image_height=1000, - field_type="test_field", - strategies={"test_field": strategy}, - ) + result = expand_bbox(bbox=bbox, image_width=1000, image_height=1000) - assert result[0] == 0 # Clamped to 0 + assert result[0] == 0 - def test_clamps_to_top_boundary(self): - """Verify y0 is clamped to 0.""" - bbox = (100, 10, 200, 60) # Close to top edge - strategy = ScaleStrategy( - scale_x=1.0, - scale_y=1.0, - extra_top_ratio=0.5, # Would push y0 below 0 - max_pad_x=1000, - max_pad_y=1000, - ) + def test_clamps_y0_to_zero(self): + bbox = (100, 3, 300, 50) - result = expand_bbox( - bbox=bbox, - image_width=1000, - image_height=1000, - field_type="test_field", - strategies={"test_field": strategy}, - ) + result = expand_bbox(bbox=bbox, image_width=1000, image_height=1000) - assert result[1] == 0 # Clamped to 0 + assert result[1] == 0 - def test_clamps_to_right_boundary(self): - """Verify x1 is clamped to image_width.""" - bbox = (900, 200, 990, 250) # Close to right edge - strategy = ScaleStrategy( - scale_x=1.0, - scale_y=1.0, - extra_right_ratio=0.5, # Would push x1 beyond image_width - max_pad_x=1000, - max_pad_y=1000, - ) + def test_clamps_x1_to_image_width(self): + bbox = (900, 200, 995, 250) - result = expand_bbox( - bbox=bbox, - image_width=1000, - image_height=1000, - field_type="test_field", - strategies={"test_field": strategy}, - ) + result = expand_bbox(bbox=bbox, image_width=1000, image_height=1000) - assert result[2] == 1000 # Clamped to image_width + assert result[2] == 1000 - def test_clamps_to_bottom_boundary(self): - """Verify y1 is clamped to image_height.""" - bbox = (100, 940, 200, 990) # Close to bottom edge - strategy = ScaleStrategy( - scale_x=1.0, - scale_y=1.0, - extra_bottom_ratio=0.5, # Would push y1 beyond image_height - max_pad_x=1000, - max_pad_y=1000, - ) + def test_clamps_y1_to_image_height(self): + bbox = (100, 900, 300, 995) - result = expand_bbox( - bbox=bbox, - image_width=1000, - image_height=1000, - field_type="test_field", - strategies={"test_field": strategy}, - ) + result = expand_bbox(bbox=bbox, image_width=1000, image_height=1000) - assert result[3] == 1000 # Clamped to image_height + assert result[3] == 1000 + def test_corner_bbox_clamps_multiple_sides(self): + """Bbox near top-left corner clamps both x0 and y0.""" + bbox = (2, 3, 50, 60) -class TestExpandBboxUnknownField: - """Tests for unknown field handling.""" + result = expand_bbox(bbox=bbox, image_width=1000, image_height=1000) - def test_unknown_field_uses_default_strategy(self): - """Verify unknown field types use DEFAULT_STRATEGY.""" - bbox = (100, 200, 200, 250) - - result = expand_bbox( - bbox=bbox, - image_width=1000, - image_height=1000, - field_type="unknown_field_xyz", - ) - - # DEFAULT_STRATEGY: scale_x=1.15, scale_y=1.15 - # Original: width=100, height=50 - # New: width=115, height=57.5 - # Center: (150, 225) - # x0 = 150 - 57.5 = 92.5 -> 92 - # x1 = 150 + 57.5 = 207.5 -> 207 - # y0 = 225 - 28.75 = 196.25 -> 196 - # y1 = 225 + 28.75 = 253.75 -> 253 - # But max_pad_x=50 may clamp... - # Left pad = 100 - 92.5 = 7.5 (< 50, ok) - # Right pad = 207.5 - 200 = 7.5 (< 50, ok) - assert result[0] == 92 - assert result[2] == 207 - - -class TestExpandBboxWithRealStrategies: - """Tests using actual FIELD_SCALE_STRATEGIES.""" - - def test_ocr_number_expands_significantly_upward(self): - """Verify ocr_number field gets significant upward expansion.""" - bbox = (100, 200, 200, 230) # Small height=30 - - result = expand_bbox( - bbox=bbox, - image_width=1000, - image_height=1000, - field_type="ocr_number", - ) - - # extra_top_ratio=0.60 -> 30 * 0.6 = 18 extra top - # y0 should decrease significantly - assert result[1] < 200 - 10 # At least 10px upward expansion - - def test_bankgiro_expands_significantly_leftward(self): - """Verify bankgiro field gets significant leftward expansion.""" - bbox = (200, 200, 300, 230) # width=100 - - result = expand_bbox( - bbox=bbox, - image_width=1000, - image_height=1000, - field_type="bankgiro", - ) - - # extra_left_ratio=0.80 -> 100 * 0.8 = 80 extra left - # x0 should decrease significantly - assert result[0] < 200 - 30 # At least 30px leftward expansion - - def test_amount_expands_rightward(self): - """Verify amount field gets rightward expansion for currency.""" - bbox = (100, 200, 200, 230) # width=100 - - result = expand_bbox( - bbox=bbox, - image_width=1000, - image_height=1000, - field_type="amount", - ) - - # extra_right_ratio=0.30 -> 100 * 0.3 = 30 extra right - # x1 should increase - assert result[2] > 200 + 10 # At least 10px rightward expansion + assert result[0] == 0 + assert result[1] == 0 + assert result[2] == 50 + UNIFORM_PAD + assert result[3] == 60 + UNIFORM_PAD class TestExpandBboxReturnType: """Tests for return type and value format.""" def test_returns_tuple_of_four_ints(self): - """Verify return type is tuple of 4 integers.""" - bbox = (100.5, 200.3, 200.7, 250.9) + bbox = (100.5, 200.3, 300.7, 250.9) - result = expand_bbox( - bbox=bbox, - image_width=1000, - image_height=1000, - field_type="invoice_number", - ) + result = expand_bbox(bbox=bbox, image_width=1000, image_height=1000) assert isinstance(result, tuple) assert len(result) == 4 assert all(isinstance(v, int) for v in result) - def test_returns_valid_bbox_format(self): - """Verify returned bbox has x0 < x1 and y0 < y1.""" - bbox = (100, 200, 200, 250) + def test_float_bbox_floors_correctly(self): + """Verify float coordinates are converted to int properly.""" + bbox = (100.7, 200.3, 300.2, 250.8) - result = expand_bbox( - bbox=bbox, - image_width=1000, - image_height=1000, - field_type="invoice_number", - ) + result = expand_bbox(bbox=bbox, image_width=1000, image_height=1000, pad=0) + + # int() truncates toward zero + assert result == (100, 200, 300, 250) + + def test_returns_valid_bbox_ordering(self): + """Verify x0 < x1 and y0 < y1.""" + bbox = (100, 200, 300, 250) + + result = expand_bbox(bbox=bbox, image_width=1000, image_height=1000) x0, y0, x1, y1 = result - assert x0 < x1, "x0 should be less than x1" - assert y0 < y1, "y0 should be less than y1" + assert x0 < x1 + assert y0 < y1 -class TestManualLabelMode: - """Tests for manual_mode parameter.""" +class TestExpandBboxEdgeCases: + """Tests for edge cases.""" - def test_manual_mode_uses_minimal_padding(self): - """Verify manual_mode uses MANUAL_LABEL_STRATEGY with minimal padding.""" - bbox = (100, 200, 200, 250) # width=100, height=50 + def test_small_bbox_with_large_pad(self): + """Pad larger than bbox still works correctly.""" + bbox = (100, 200, 105, 203) # 5x3 pixel bbox - result = expand_bbox( - bbox=bbox, - image_width=1000, - image_height=1000, - field_type="bankgiro", # Would normally expand left significantly - manual_mode=True, - ) + result = expand_bbox(bbox=bbox, image_width=1000, image_height=1000, pad=50) - # MANUAL_LABEL_STRATEGY: scale=1.0, max_pad=10 - # Should only add 10px padding each side (but scale=1.0 means no scaling) - # Actually with scale=1.0, no extra ratios, we get 0 expansion from scaling - # Only max_pad=10 applies as a limit, but there's no expansion to limit - # So result should be same as original - assert result == (100, 200, 200, 250) + assert result == (50, 150, 155, 253) - def test_manual_mode_ignores_field_type(self): - """Verify manual_mode ignores field-specific strategies.""" - bbox = (100, 200, 200, 250) + def test_bbox_at_origin(self): + bbox = (0, 0, 50, 30) - # Different fields should give same result in manual_mode - result_bankgiro = expand_bbox( - bbox=bbox, - image_width=1000, - image_height=1000, - field_type="bankgiro", - manual_mode=True, - ) + result = expand_bbox(bbox=bbox, image_width=1000, image_height=1000) - result_ocr = expand_bbox( - bbox=bbox, - image_width=1000, - image_height=1000, - field_type="ocr_number", - manual_mode=True, - ) + assert result[0] == 0 + assert result[1] == 0 - assert result_bankgiro == result_ocr + def test_bbox_at_image_edge(self): + bbox = (950, 970, 1000, 1000) - def test_manual_mode_vs_auto_mode_different(self): - """Verify manual_mode produces different results than auto mode.""" - bbox = (100, 200, 200, 250) + result = expand_bbox(bbox=bbox, image_width=1000, image_height=1000) - auto_result = expand_bbox( - bbox=bbox, - image_width=1000, - image_height=1000, - field_type="bankgiro", # Has extra_left_ratio=0.80 - manual_mode=False, - ) - - manual_result = expand_bbox( - bbox=bbox, - image_width=1000, - image_height=1000, - field_type="bankgiro", - manual_mode=True, - ) - - # Auto mode should expand more (especially to the left for bankgiro) - assert auto_result[0] < manual_result[0] # Auto x0 is more left - - def test_manual_mode_clamps_to_image_bounds(self): - """Verify manual_mode still respects image boundaries.""" - bbox = (5, 5, 50, 50) # Close to top-left corner - - result = expand_bbox( - bbox=bbox, - image_width=1000, - image_height=1000, - field_type="test", - manual_mode=True, - ) - - # Should clamp to 0 - assert result[0] >= 0 - assert result[1] >= 0 + assert result[2] == 1000 + assert result[3] == 1000 diff --git a/tests/shared/bbox/test_scale_strategy.py b/tests/shared/bbox/test_scale_strategy.py index 08e6d5a..a41a68b 100644 --- a/tests/shared/bbox/test_scale_strategy.py +++ b/tests/shared/bbox/test_scale_strategy.py @@ -1,192 +1,24 @@ """ -Tests for ScaleStrategy configuration. +Tests for simplified scale strategy configuration. -Tests verify that scale strategies are properly defined, immutable, -and cover all required fields. +Verifies that UNIFORM_PAD constant is properly defined +and replaces the old field-specific strategies. """ import pytest -from shared.bbox import ( - ScaleStrategy, - DEFAULT_STRATEGY, - MANUAL_LABEL_STRATEGY, - FIELD_SCALE_STRATEGIES, -) -from shared.fields import CLASS_NAMES +from shared.bbox.scale_strategy import UNIFORM_PAD -class TestScaleStrategyDataclass: - """Tests for ScaleStrategy dataclass behavior.""" +class TestUniformPad: + """Tests for UNIFORM_PAD constant.""" - def test_default_strategy_values(self): - """Verify default strategy has expected default values.""" - strategy = ScaleStrategy() - assert strategy.scale_x == 1.15 - assert strategy.scale_y == 1.15 - assert strategy.extra_top_ratio == 0.0 - assert strategy.extra_bottom_ratio == 0.0 - assert strategy.extra_left_ratio == 0.0 - assert strategy.extra_right_ratio == 0.0 - assert strategy.max_pad_x == 50 - assert strategy.max_pad_y == 50 + def test_uniform_pad_is_integer(self): + assert isinstance(UNIFORM_PAD, int) - def test_scale_strategy_immutability(self): - """Verify ScaleStrategy is frozen (immutable).""" - strategy = ScaleStrategy() - with pytest.raises(AttributeError): - strategy.scale_x = 2.0 # type: ignore + def test_uniform_pad_value_is_15(self): + """15px at 150 DPI provides ~2.5mm real-world padding.""" + assert UNIFORM_PAD == 15 - def test_custom_strategy_values(self): - """Verify custom values are properly set.""" - strategy = ScaleStrategy( - scale_x=1.5, - scale_y=1.8, - extra_top_ratio=0.6, - extra_left_ratio=0.8, - max_pad_x=100, - max_pad_y=150, - ) - assert strategy.scale_x == 1.5 - assert strategy.scale_y == 1.8 - assert strategy.extra_top_ratio == 0.6 - assert strategy.extra_left_ratio == 0.8 - assert strategy.max_pad_x == 100 - assert strategy.max_pad_y == 150 - - -class TestDefaultStrategy: - """Tests for DEFAULT_STRATEGY constant.""" - - def test_default_strategy_is_scale_strategy(self): - """Verify DEFAULT_STRATEGY is a ScaleStrategy instance.""" - assert isinstance(DEFAULT_STRATEGY, ScaleStrategy) - - def test_default_strategy_matches_default_values(self): - """Verify DEFAULT_STRATEGY has same values as ScaleStrategy().""" - expected = ScaleStrategy() - assert DEFAULT_STRATEGY == expected - - -class TestManualLabelStrategy: - """Tests for MANUAL_LABEL_STRATEGY constant.""" - - def test_manual_label_strategy_is_scale_strategy(self): - """Verify MANUAL_LABEL_STRATEGY is a ScaleStrategy instance.""" - assert isinstance(MANUAL_LABEL_STRATEGY, ScaleStrategy) - - def test_manual_label_strategy_has_no_scaling(self): - """Verify MANUAL_LABEL_STRATEGY has scale factors of 1.0.""" - assert MANUAL_LABEL_STRATEGY.scale_x == 1.0 - assert MANUAL_LABEL_STRATEGY.scale_y == 1.0 - - def test_manual_label_strategy_has_no_directional_expansion(self): - """Verify MANUAL_LABEL_STRATEGY has no directional expansion.""" - assert MANUAL_LABEL_STRATEGY.extra_top_ratio == 0.0 - assert MANUAL_LABEL_STRATEGY.extra_bottom_ratio == 0.0 - assert MANUAL_LABEL_STRATEGY.extra_left_ratio == 0.0 - assert MANUAL_LABEL_STRATEGY.extra_right_ratio == 0.0 - - def test_manual_label_strategy_has_small_max_pad(self): - """Verify MANUAL_LABEL_STRATEGY has small max padding.""" - assert MANUAL_LABEL_STRATEGY.max_pad_x <= 15 - assert MANUAL_LABEL_STRATEGY.max_pad_y <= 15 - - -class TestFieldScaleStrategies: - """Tests for FIELD_SCALE_STRATEGIES dictionary.""" - - def test_all_class_names_have_strategies(self): - """Verify all field class names have defined strategies.""" - for class_name in CLASS_NAMES: - assert class_name in FIELD_SCALE_STRATEGIES, ( - f"Missing strategy for field: {class_name}" - ) - - def test_strategies_are_scale_strategy_instances(self): - """Verify all strategies are ScaleStrategy instances.""" - for field_name, strategy in FIELD_SCALE_STRATEGIES.items(): - assert isinstance(strategy, ScaleStrategy), ( - f"Strategy for {field_name} is not a ScaleStrategy" - ) - - def test_scale_values_are_greater_than_one(self): - """Verify all scale values are >= 1.0 (expansion, not contraction).""" - for field_name, strategy in FIELD_SCALE_STRATEGIES.items(): - assert strategy.scale_x >= 1.0, ( - f"{field_name} scale_x should be >= 1.0" - ) - assert strategy.scale_y >= 1.0, ( - f"{field_name} scale_y should be >= 1.0" - ) - - def test_extra_ratios_are_non_negative(self): - """Verify all extra ratios are >= 0.""" - for field_name, strategy in FIELD_SCALE_STRATEGIES.items(): - assert strategy.extra_top_ratio >= 0, ( - f"{field_name} extra_top_ratio should be >= 0" - ) - assert strategy.extra_bottom_ratio >= 0, ( - f"{field_name} extra_bottom_ratio should be >= 0" - ) - assert strategy.extra_left_ratio >= 0, ( - f"{field_name} extra_left_ratio should be >= 0" - ) - assert strategy.extra_right_ratio >= 0, ( - f"{field_name} extra_right_ratio should be >= 0" - ) - - def test_max_pad_values_are_positive(self): - """Verify all max_pad values are > 0.""" - for field_name, strategy in FIELD_SCALE_STRATEGIES.items(): - assert strategy.max_pad_x > 0, ( - f"{field_name} max_pad_x should be > 0" - ) - assert strategy.max_pad_y > 0, ( - f"{field_name} max_pad_y should be > 0" - ) - - -class TestSpecificFieldStrategies: - """Tests for specific field strategy configurations.""" - - def test_ocr_number_expands_upward(self): - """Verify ocr_number strategy expands upward to capture label.""" - strategy = FIELD_SCALE_STRATEGIES["ocr_number"] - assert strategy.extra_top_ratio > 0.0 - assert strategy.extra_top_ratio >= 0.5 # Significant upward expansion - - def test_bankgiro_expands_leftward(self): - """Verify bankgiro strategy expands leftward to capture prefix.""" - strategy = FIELD_SCALE_STRATEGIES["bankgiro"] - assert strategy.extra_left_ratio > 0.0 - assert strategy.extra_left_ratio >= 0.5 # Significant leftward expansion - - def test_plusgiro_expands_leftward(self): - """Verify plusgiro strategy expands leftward to capture prefix.""" - strategy = FIELD_SCALE_STRATEGIES["plusgiro"] - assert strategy.extra_left_ratio > 0.0 - assert strategy.extra_left_ratio >= 0.5 - - def test_amount_expands_rightward(self): - """Verify amount strategy expands rightward for currency symbol.""" - strategy = FIELD_SCALE_STRATEGIES["amount"] - assert strategy.extra_right_ratio > 0.0 - - def test_invoice_date_expands_upward(self): - """Verify invoice_date strategy expands upward to capture label.""" - strategy = FIELD_SCALE_STRATEGIES["invoice_date"] - assert strategy.extra_top_ratio > 0.0 - - def test_invoice_due_date_expands_upward_and_leftward(self): - """Verify invoice_due_date strategy expands both up and left.""" - strategy = FIELD_SCALE_STRATEGIES["invoice_due_date"] - assert strategy.extra_top_ratio > 0.0 - assert strategy.extra_left_ratio > 0.0 - - def test_payment_line_has_minimal_expansion(self): - """Verify payment_line has conservative expansion (machine code).""" - strategy = FIELD_SCALE_STRATEGIES["payment_line"] - # Payment line is machine-readable, needs minimal expansion - assert strategy.scale_x <= 1.2 - assert strategy.scale_y <= 1.3 + def test_uniform_pad_is_positive(self): + assert UNIFORM_PAD > 0 diff --git a/tests/training/yolo/test_annotation_generator.py b/tests/training/yolo/test_annotation_generator.py index 69f669e..5abf8e2 100644 --- a/tests/training/yolo/test_annotation_generator.py +++ b/tests/training/yolo/test_annotation_generator.py @@ -171,8 +171,8 @@ class TestGenerateFromMatches: assert len(annotations) == 0 - def test_applies_field_specific_expansion(self): - """Verify different fields get different expansion.""" + def test_applies_uniform_expansion(self): + """Verify all fields get the same uniform expansion.""" gen = AnnotationGenerator(min_confidence=0.5) # Same bbox, different fields @@ -199,10 +199,11 @@ class TestGenerateFromMatches: dpi=150 )[0] - # Bankgiro has extra_left_ratio=0.80, invoice_number has extra_top_ratio=0.40 - # They should have different widths due to different expansion - # Bankgiro expands more to the left - assert ann_bankgiro.width != ann_invoice.width or ann_bankgiro.x_center != ann_invoice.x_center + # Uniform expansion: same bbox -> same dimensions (only class_id differs) + assert ann_bankgiro.width == ann_invoice.width + assert ann_bankgiro.height == ann_invoice.height + assert ann_bankgiro.x_center == ann_invoice.x_center + assert ann_bankgiro.y_center == ann_invoice.y_center def test_enforces_min_bbox_height(self): """Verify minimum bbox height is enforced.""" diff --git a/tests/training/yolo/test_db_dataset.py b/tests/training/yolo/test_db_dataset.py index 69b21c5..f470c71 100644 --- a/tests/training/yolo/test_db_dataset.py +++ b/tests/training/yolo/test_db_dataset.py @@ -7,27 +7,23 @@ from pathlib import Path from training.yolo.db_dataset import DBYOLODataset from training.yolo.annotation_generator import YOLOAnnotation -from shared.bbox import FIELD_SCALE_STRATEGIES, DEFAULT_STRATEGY +from shared.bbox import UNIFORM_PAD from shared.fields import CLASS_NAMES class TestConvertLabelsWithExpandBbox: - """Tests for _convert_labels using expand_bbox instead of fixed padding.""" + """Tests for _convert_labels using uniform expand_bbox.""" def test_convert_labels_uses_expand_bbox(self): - """Verify _convert_labels calls expand_bbox for field-specific expansion.""" - # Create a mock dataset without loading from DB + """Verify _convert_labels calls expand_bbox with uniform padding.""" dataset = object.__new__(DBYOLODataset) dataset.dpi = 300 dataset.min_bbox_height_px = 30 - # Create annotation for bankgiro (has extra_left_ratio) - # bbox in PDF points: x0=100, y0=200, x1=200, y1=250 - # center: (150, 225), width: 100, height: 50 annotations = [ YOLOAnnotation( class_id=4, # bankgiro - x_center=150, # in PDF points + x_center=150, y_center=225, width=100, height=50, @@ -35,48 +31,26 @@ class TestConvertLabelsWithExpandBbox: ) ] - # Image size in pixels (at 300 DPI) - img_width = 2480 # A4 width at 300 DPI - img_height = 3508 # A4 height at 300 DPI + img_width = 2480 + img_height = 3508 - # Convert labels labels = dataset._convert_labels(annotations, img_width, img_height, is_scanned=False) - # Should have one label assert labels.shape == (1, 5) - - # Check class_id assert labels[0, 0] == 4 - # The bbox should be expanded using bankgiro strategy (extra_left_ratio=0.80) - # Original bbox at 300 DPI: - # x0 = 100 * (300/72) = 416.67 - # y0 = 200 * (300/72) = 833.33 - # x1 = 200 * (300/72) = 833.33 - # y1 = 250 * (300/72) = 1041.67 - # width_px = 416.67, height_px = 208.33 - - # After expand_bbox with bankgiro strategy: - # scale_x=1.45, scale_y=1.35, extra_left_ratio=0.80 - # The x_center should shift left due to extra_left_ratio x_center = labels[0, 1] y_center = labels[0, 2] width = labels[0, 3] height = labels[0, 4] - # Verify normalized values are in valid range assert 0 <= x_center <= 1 assert 0 <= y_center <= 1 assert 0 < width <= 1 assert 0 < height <= 1 - # Width should be larger than original due to scaling and extra_left - # Original normalized width: 416.67 / 2480 = 0.168 - # After bankgiro expansion it should be wider - assert width > 0.168 - - def test_convert_labels_different_field_types(self): - """Verify different field types use their specific strategies.""" + def test_convert_labels_all_fields_get_same_expansion(self): + """Verify all field types get the same uniform expansion.""" dataset = object.__new__(DBYOLODataset) dataset.dpi = 300 dataset.min_bbox_height_px = 30 @@ -84,7 +58,6 @@ class TestConvertLabelsWithExpandBbox: img_width = 2480 img_height = 3508 - # Same bbox for different field types base_annotation = { 'x_center': 150, 'y_center': 225, @@ -93,30 +66,20 @@ class TestConvertLabelsWithExpandBbox: 'confidence': 0.9 } - # OCR number (class_id=3) - has extra_top_ratio=0.60 + # All field types should get the same uniform expansion ocr_annotations = [YOLOAnnotation(class_id=3, **base_annotation)] ocr_labels = dataset._convert_labels(ocr_annotations, img_width, img_height, is_scanned=False) - # Bankgiro (class_id=4) - has extra_left_ratio=0.80 bankgiro_annotations = [YOLOAnnotation(class_id=4, **base_annotation)] bankgiro_labels = dataset._convert_labels(bankgiro_annotations, img_width, img_height, is_scanned=False) - # Amount (class_id=6) - has extra_right_ratio=0.30 - amount_annotations = [YOLOAnnotation(class_id=6, **base_annotation)] - amount_labels = dataset._convert_labels(amount_annotations, img_width, img_height, is_scanned=False) + # x_center and y_center should be the same (uniform padding is symmetric) + assert abs(ocr_labels[0, 1] - bankgiro_labels[0, 1]) < 0.001 + assert abs(ocr_labels[0, 2] - bankgiro_labels[0, 2]) < 0.001 - # Each field type should have different expansion - # OCR should expand more vertically (extra_top) - # Bankgiro should expand more to the left - # Amount should expand more to the right - - # OCR: extra_top shifts y_center up - # Bankgiro: extra_left shifts x_center left - # So bankgiro x_center < OCR x_center - assert bankgiro_labels[0, 1] < ocr_labels[0, 1] - - # OCR has higher scale_y (1.80) than amount (1.35) - assert ocr_labels[0, 4] > amount_labels[0, 4] + # width and height should also be the same + assert abs(ocr_labels[0, 3] - bankgiro_labels[0, 3]) < 0.001 + assert abs(ocr_labels[0, 4] - bankgiro_labels[0, 4]) < 0.001 def test_convert_labels_clamps_to_image_bounds(self): """Verify labels are clamped to image boundaries.""" @@ -124,11 +87,10 @@ class TestConvertLabelsWithExpandBbox: dataset.dpi = 300 dataset.min_bbox_height_px = 30 - # Annotation near edge of image (in PDF points) annotations = [ YOLOAnnotation( - class_id=4, # bankgiro - will expand left - x_center=30, # Very close to left edge + class_id=4, + x_center=30, y_center=50, width=40, height=30, @@ -141,11 +103,10 @@ class TestConvertLabelsWithExpandBbox: labels = dataset._convert_labels(annotations, img_width, img_height, is_scanned=False) - # All values should be in valid range - assert 0 <= labels[0, 1] <= 1 # x_center - assert 0 <= labels[0, 2] <= 1 # y_center - assert 0 < labels[0, 3] <= 1 # width - assert 0 < labels[0, 4] <= 1 # height + assert 0 <= labels[0, 1] <= 1 + assert 0 <= labels[0, 2] <= 1 + assert 0 < labels[0, 3] <= 1 + assert 0 < labels[0, 4] <= 1 def test_convert_labels_empty_annotations(self): """Verify empty annotations return empty array.""" @@ -162,23 +123,21 @@ class TestConvertLabelsWithExpandBbox: """Verify minimum height is enforced after expansion.""" dataset = object.__new__(DBYOLODataset) dataset.dpi = 300 - dataset.min_bbox_height_px = 50 # Higher minimum + dataset.min_bbox_height_px = 50 - # Very small annotation annotations = [ YOLOAnnotation( - class_id=9, # payment_line - minimal expansion + class_id=9, x_center=100, y_center=100, width=200, - height=5, # Very small height + height=5, confidence=0.9 ) ] labels = dataset._convert_labels(annotations, 2480, 3508, is_scanned=False) - # Height should be at least min_bbox_height_px / img_height min_normalized_height = 50 / 3508 assert labels[0, 4] >= min_normalized_height @@ -190,25 +149,23 @@ class TestCreateAnnotationWithClassName: """Verify _create_annotation stores class_name for later use.""" dataset = object.__new__(DBYOLODataset) - # Create annotation for invoice_number annotation = dataset._create_annotation( field_name="InvoiceNumber", bbox=[100, 200, 200, 250], score=0.9 ) - assert annotation.class_id == 0 # invoice_number class_id + assert annotation.class_id == 0 class TestLoadLabelsFromDbWithClassName: """Tests for _load_labels_from_db preserving field_name for expansion.""" def test_load_labels_maps_field_names_correctly(self): - """Verify field names are mapped correctly for expand_bbox.""" + """Verify field names are mapped correctly.""" dataset = object.__new__(DBYOLODataset) dataset.min_confidence = 0.7 - # Mock database mock_db = MagicMock() mock_db.get_documents_batch.return_value = { 'doc1': { @@ -240,12 +197,7 @@ class TestLoadLabelsFromDbWithClassName: assert 'doc1' in result page_labels, is_scanned, csv_split = result['doc1'] - # Should have 2 annotations on page 0 assert 0 in page_labels assert len(page_labels[0]) == 2 - - # First annotation: Bankgiro (class_id=4) assert page_labels[0][0].class_id == 4 - - # Second annotation: Plusgiro mapped from supplier_accounts(Plusgiro) (class_id=5) assert page_labels[0][1].class_id == 5 diff --git a/tests/web/test_admin_training.py b/tests/web/test_admin_training.py index f27dbd1..4c485ba 100644 --- a/tests/web/test_admin_training.py +++ b/tests/web/test_admin_training.py @@ -53,7 +53,7 @@ class TestTrainingConfigSchema: """Test default training configuration.""" config = TrainingConfig() - assert config.model_name == "yolo11n.pt" + assert config.model_name == "yolo26s.pt" assert config.epochs == 100 assert config.batch_size == 16 assert config.image_size == 640 @@ -63,7 +63,7 @@ class TestTrainingConfigSchema: def test_custom_config(self): """Test custom training configuration.""" config = TrainingConfig( - model_name="yolo11s.pt", + model_name="yolo26s.pt", epochs=50, batch_size=8, image_size=416, @@ -71,7 +71,7 @@ class TestTrainingConfigSchema: device="cpu", ) - assert config.model_name == "yolo11s.pt" + assert config.model_name == "yolo26s.pt" assert config.epochs == 50 assert config.batch_size == 8 @@ -136,7 +136,7 @@ class TestTrainingTaskModel: def test_task_with_config(self): """Test task with configuration.""" config = { - "model_name": "yolo11n.pt", + "model_name": "yolo26s.pt", "epochs": 100, } task = TrainingTask( diff --git a/tests/web/test_data_mixer.py b/tests/web/test_data_mixer.py new file mode 100644 index 0000000..2529169 --- /dev/null +++ b/tests/web/test_data_mixer.py @@ -0,0 +1,784 @@ +""" +Comprehensive unit tests for Data Mixing Service. + +Tests the data mixing service functions for YOLO fine-tuning: +- Mixing ratio calculation based on sample counts +- Dataset building with old/new sample mixing +- Image collection and path conversion +- Pool document matching +""" + +from pathlib import Path +from uuid import UUID, uuid4 + +import pytest + +from backend.web.services.data_mixer import ( + get_mixing_ratio, + build_mixed_dataset, + _collect_images, + _image_to_label_path, + _find_pool_images, + MIXING_RATIOS, + DEFAULT_MULTIPLIER, + MAX_OLD_SAMPLES, + MIN_POOL_SIZE, +) + + +class TestGetMixingRatio: + """Tests for get_mixing_ratio function.""" + + def test_mixing_ratio_at_first_threshold(self): + """Test mixing ratio at first threshold boundary (10 samples).""" + assert get_mixing_ratio(1) == 50 + assert get_mixing_ratio(5) == 50 + assert get_mixing_ratio(10) == 50 + + def test_mixing_ratio_at_second_threshold(self): + """Test mixing ratio at second threshold boundary (50 samples).""" + assert get_mixing_ratio(11) == 20 + assert get_mixing_ratio(30) == 20 + assert get_mixing_ratio(50) == 20 + + def test_mixing_ratio_at_third_threshold(self): + """Test mixing ratio at third threshold boundary (200 samples).""" + assert get_mixing_ratio(51) == 10 + assert get_mixing_ratio(100) == 10 + assert get_mixing_ratio(200) == 10 + + def test_mixing_ratio_at_fourth_threshold(self): + """Test mixing ratio at fourth threshold boundary (500 samples).""" + assert get_mixing_ratio(201) == 5 + assert get_mixing_ratio(350) == 5 + assert get_mixing_ratio(500) == 5 + + def test_mixing_ratio_above_all_thresholds(self): + """Test mixing ratio for samples above all thresholds.""" + assert get_mixing_ratio(501) == DEFAULT_MULTIPLIER + assert get_mixing_ratio(1000) == DEFAULT_MULTIPLIER + assert get_mixing_ratio(10000) == DEFAULT_MULTIPLIER + + def test_mixing_ratio_boundary_values(self): + """Test exact threshold boundaries match expected ratios.""" + # Verify threshold boundaries from MIXING_RATIOS + for threshold, expected_multiplier in MIXING_RATIOS: + assert get_mixing_ratio(threshold) == expected_multiplier + # One above threshold should give next ratio + if threshold < MIXING_RATIOS[-1][0]: + next_idx = MIXING_RATIOS.index((threshold, expected_multiplier)) + 1 + next_multiplier = MIXING_RATIOS[next_idx][1] + assert get_mixing_ratio(threshold + 1) == next_multiplier + + +class TestCollectImages: + """Tests for _collect_images function.""" + + def test_collect_images_empty_directory(self, tmp_path): + """Test collecting images from empty directory.""" + images_dir = tmp_path / "images" + images_dir.mkdir() + + result = _collect_images(images_dir) + + assert result == [] + + def test_collect_images_nonexistent_directory(self, tmp_path): + """Test collecting images from non-existent directory.""" + images_dir = tmp_path / "nonexistent" + + result = _collect_images(images_dir) + + assert result == [] + + def test_collect_png_images(self, tmp_path): + """Test collecting PNG images.""" + images_dir = tmp_path / "images" + images_dir.mkdir() + + # Create PNG files + (images_dir / "img1.png").touch() + (images_dir / "img2.png").touch() + (images_dir / "img3.png").touch() + + result = _collect_images(images_dir) + + assert len(result) == 3 + assert all(img.suffix == ".png" for img in result) + # Verify sorted order + assert result == sorted(result) + + def test_collect_jpg_images(self, tmp_path): + """Test collecting JPG images.""" + images_dir = tmp_path / "images" + images_dir.mkdir() + + # Create JPG files + (images_dir / "img1.jpg").touch() + (images_dir / "img2.jpg").touch() + + result = _collect_images(images_dir) + + assert len(result) == 2 + assert all(img.suffix == ".jpg" for img in result) + + def test_collect_mixed_image_types(self, tmp_path): + """Test collecting both PNG and JPG images.""" + images_dir = tmp_path / "images" + images_dir.mkdir() + + # Create mixed files + (images_dir / "img1.png").touch() + (images_dir / "img2.jpg").touch() + (images_dir / "img3.png").touch() + (images_dir / "img4.jpg").touch() + + result = _collect_images(images_dir) + + assert len(result) == 4 + # PNG files should come first (sorted separately) + png_files = [r for r in result if r.suffix == ".png"] + jpg_files = [r for r in result if r.suffix == ".jpg"] + assert len(png_files) == 2 + assert len(jpg_files) == 2 + + def test_collect_images_ignores_other_files(self, tmp_path): + """Test that non-image files are ignored.""" + images_dir = tmp_path / "images" + images_dir.mkdir() + + # Create various files + (images_dir / "img1.png").touch() + (images_dir / "img2.jpg").touch() + (images_dir / "doc.txt").touch() + (images_dir / "data.json").touch() + (images_dir / "notes.md").touch() + + result = _collect_images(images_dir) + + assert len(result) == 2 + assert all(img.suffix in [".png", ".jpg"] for img in result) + + +class TestImageToLabelPath: + """Tests for _image_to_label_path function.""" + + def test_image_to_label_path_train(self, tmp_path): + """Test converting train image path to label path.""" + base = tmp_path / "dataset" + image_path = base / "images" / "train" / "doc123_page1.png" + + label_path = _image_to_label_path(image_path) + + expected = base / "labels" / "train" / "doc123_page1.txt" + assert label_path == expected + + def test_image_to_label_path_val(self, tmp_path): + """Test converting val image path to label path.""" + base = tmp_path / "dataset" + image_path = base / "images" / "val" / "doc456_page2.jpg" + + label_path = _image_to_label_path(image_path) + + expected = base / "labels" / "val" / "doc456_page2.txt" + assert label_path == expected + + def test_image_to_label_path_test(self, tmp_path): + """Test converting test image path to label path.""" + base = tmp_path / "dataset" + image_path = base / "images" / "test" / "doc789_page3.png" + + label_path = _image_to_label_path(image_path) + + expected = base / "labels" / "test" / "doc789_page3.txt" + assert label_path == expected + + def test_image_to_label_path_preserves_filename(self, tmp_path): + """Test that filename (without extension) is preserved.""" + base = tmp_path / "dataset" + image_path = base / "images" / "train" / "complex_filename_123_page5.png" + + label_path = _image_to_label_path(image_path) + + assert label_path.stem == "complex_filename_123_page5" + assert label_path.suffix == ".txt" + + def test_image_to_label_path_jpg_to_txt(self, tmp_path): + """Test that JPG extension is converted to TXT.""" + base = tmp_path / "dataset" + image_path = base / "images" / "train" / "image.jpg" + + label_path = _image_to_label_path(image_path) + + assert label_path.suffix == ".txt" + + +class TestFindPoolImages: + """Tests for _find_pool_images function.""" + + def test_find_pool_images_in_train(self, tmp_path): + """Test finding pool images in train split.""" + base = tmp_path / "dataset" + train_dir = base / "images" / "train" + train_dir.mkdir(parents=True) + + doc_id = str(uuid4()) + pool_doc_ids = {doc_id} + + # Create images + (train_dir / f"{doc_id}_page1.png").touch() + (train_dir / f"{doc_id}_page2.png").touch() + (train_dir / "other_doc_page1.png").touch() + + result = _find_pool_images(base, pool_doc_ids) + + assert len(result) == 2 + assert all(doc_id in str(img) for img in result) + + def test_find_pool_images_in_val(self, tmp_path): + """Test finding pool images in val split.""" + base = tmp_path / "dataset" + val_dir = base / "images" / "val" + val_dir.mkdir(parents=True) + + doc_id = str(uuid4()) + pool_doc_ids = {doc_id} + + # Create images + (val_dir / f"{doc_id}_page1.png").touch() + + result = _find_pool_images(base, pool_doc_ids) + + assert len(result) == 1 + assert doc_id in str(result[0]) + + def test_find_pool_images_across_splits(self, tmp_path): + """Test finding pool images across train, val, and test splits.""" + base = tmp_path / "dataset" + + doc_id1 = str(uuid4()) + doc_id2 = str(uuid4()) + pool_doc_ids = {doc_id1, doc_id2} + + # Create images in different splits + train_dir = base / "images" / "train" + val_dir = base / "images" / "val" + test_dir = base / "images" / "test" + + train_dir.mkdir(parents=True) + val_dir.mkdir(parents=True) + test_dir.mkdir(parents=True) + + (train_dir / f"{doc_id1}_page1.png").touch() + (val_dir / f"{doc_id1}_page2.png").touch() + (test_dir / f"{doc_id2}_page1.png").touch() + (train_dir / "other_doc_page1.png").touch() + + result = _find_pool_images(base, pool_doc_ids) + + assert len(result) == 3 + doc1_images = [img for img in result if doc_id1 in str(img)] + doc2_images = [img for img in result if doc_id2 in str(img)] + assert len(doc1_images) == 2 + assert len(doc2_images) == 1 + + def test_find_pool_images_empty_pool(self, tmp_path): + """Test finding images with empty pool.""" + base = tmp_path / "dataset" + train_dir = base / "images" / "train" + train_dir.mkdir(parents=True) + + (train_dir / "doc123_page1.png").touch() + + result = _find_pool_images(base, set()) + + assert len(result) == 0 + + def test_find_pool_images_no_matches(self, tmp_path): + """Test finding images when no documents match pool.""" + base = tmp_path / "dataset" + train_dir = base / "images" / "train" + train_dir.mkdir(parents=True) + + pool_doc_ids = {str(uuid4())} + + (train_dir / "other_doc_page1.png").touch() + (train_dir / "another_doc_page1.png").touch() + + result = _find_pool_images(base, pool_doc_ids) + + assert len(result) == 0 + + def test_find_pool_images_multiple_pages(self, tmp_path): + """Test finding multiple pages for same document.""" + base = tmp_path / "dataset" + train_dir = base / "images" / "train" + train_dir.mkdir(parents=True) + + doc_id = str(uuid4()) + pool_doc_ids = {doc_id} + + # Create multiple pages + for i in range(1, 6): + (train_dir / f"{doc_id}_page{i}.png").touch() + + result = _find_pool_images(base, pool_doc_ids) + + assert len(result) == 5 + + def test_find_pool_images_ignores_non_files(self, tmp_path): + """Test that directories are ignored.""" + base = tmp_path / "dataset" + train_dir = base / "images" / "train" + train_dir.mkdir(parents=True) + + doc_id = str(uuid4()) + pool_doc_ids = {doc_id} + + (train_dir / f"{doc_id}_page1.png").touch() + (train_dir / "subdir").mkdir() + + result = _find_pool_images(base, pool_doc_ids) + + assert len(result) == 1 + + def test_find_pool_images_nonexistent_splits(self, tmp_path): + """Test handling non-existent split directories.""" + base = tmp_path / "dataset" + # Don't create any directories + + pool_doc_ids = {str(uuid4())} + + result = _find_pool_images(base, pool_doc_ids) + + assert len(result) == 0 + + +class TestBuildMixedDataset: + """Tests for build_mixed_dataset function.""" + + @pytest.fixture + def setup_base_dataset(self, tmp_path): + """Create a base dataset with old training data.""" + base = tmp_path / "base_dataset" + + # Create directory structure + for split in ("train", "val"): + (base / "images" / split).mkdir(parents=True) + (base / "labels" / split).mkdir(parents=True) + + # Create old training images and labels + for i in range(1, 11): + img_path = base / "images" / "train" / f"old_doc_{i}_page1.png" + label_path = base / "labels" / "train" / f"old_doc_{i}_page1.txt" + img_path.write_text(f"image {i}") + label_path.write_text(f"0 0.5 0.5 0.1 0.1") + + for i in range(1, 6): + img_path = base / "images" / "val" / f"old_doc_val_{i}_page1.png" + label_path = base / "labels" / "val" / f"old_doc_val_{i}_page1.txt" + img_path.write_text(f"val image {i}") + label_path.write_text(f"0 0.5 0.5 0.1 0.1") + + return base + + @pytest.fixture + def setup_pool_documents(self, tmp_path, setup_base_dataset): + """Create pool documents in base dataset.""" + base = setup_base_dataset + pool_ids = [uuid4() for _ in range(5)] + + # Add pool documents to train split + for doc_id in pool_ids: + img_path = base / "images" / "train" / f"{doc_id}_page1.png" + label_path = base / "labels" / "train" / f"{doc_id}_page1.txt" + img_path.write_text(f"pool image {doc_id}") + label_path.write_text(f"1 0.5 0.5 0.2 0.2") + + return base, pool_ids + + def test_build_mixed_dataset_basic(self, tmp_path, setup_pool_documents): + """Test basic mixed dataset building.""" + base, pool_ids = setup_pool_documents + output_dir = tmp_path / "mixed_output" + + result = build_mixed_dataset( + pool_document_ids=pool_ids, + base_dataset_path=base, + output_dir=output_dir, + seed=42, + ) + + # Verify result structure + assert "data_yaml" in result + assert "total_images" in result + assert "old_images" in result + assert "new_images" in result + assert "mixing_ratio" in result + + # Verify counts - new images should be > 0 (at least some were copied) + # Note: new images are split 80/20 and copied without overwriting + assert result["new_images"] > 0 + assert result["old_images"] > 0 + assert result["total_images"] == result["old_images"] + result["new_images"] + + # Verify output structure + assert output_dir.exists() + assert (output_dir / "images" / "train").exists() + assert (output_dir / "images" / "val").exists() + assert (output_dir / "labels" / "train").exists() + assert (output_dir / "labels" / "val").exists() + + # Verify data.yaml exists + yaml_path = Path(result["data_yaml"]) + assert yaml_path.exists() + yaml_content = yaml_path.read_text() + assert "train: images/train" in yaml_content + assert "val: images/val" in yaml_content + assert "nc:" in yaml_content + assert "names:" in yaml_content + + def test_build_mixed_dataset_respects_mixing_ratio(self, tmp_path, setup_pool_documents): + """Test that mixing ratio is correctly applied.""" + base, pool_ids = setup_pool_documents + output_dir = tmp_path / "mixed_output" + + # With 5 pool documents, get_mixing_ratio(5) returns 50 + # (because 5 <= 10, first threshold) + # So target old_samples = 5 * 50 = 250 + # But limited by available data: 10 old train + 5 old val + 5 pool = 20 total + result = build_mixed_dataset( + pool_document_ids=pool_ids, + base_dataset_path=base, + output_dir=output_dir, + seed=42, + ) + + # Pool images are in the base dataset, so they can be sampled as "old" + # Total available: 20 images (15 pure old + 5 pool images) + assert result["old_images"] <= 20 # Can't exceed available in base dataset + assert result["old_images"] > 0 # Should have some old data + assert result["mixing_ratio"] == 50 # Correct ratio for 5 samples + + def test_build_mixed_dataset_max_old_samples_limit(self, tmp_path): + """Test that MAX_OLD_SAMPLES limit is applied.""" + base = tmp_path / "base_dataset" + + # Create directory structure + for split in ("train", "val"): + (base / "images" / split).mkdir(parents=True) + (base / "labels" / split).mkdir(parents=True) + + # Create MORE than MAX_OLD_SAMPLES old images + for i in range(MAX_OLD_SAMPLES + 500): + img_path = base / "images" / "train" / f"old_doc_{i}_page1.png" + label_path = base / "labels" / "train" / f"old_doc_{i}_page1.txt" + img_path.write_text(f"image {i}") + label_path.write_text(f"0 0.5 0.5 0.1 0.1") + + # Create pool documents (100 samples, ratio=10, so target=1000) + # But should be capped at MAX_OLD_SAMPLES (3000) + pool_ids = [uuid4() for _ in range(100)] + for doc_id in pool_ids: + img_path = base / "images" / "train" / f"{doc_id}_page1.png" + label_path = base / "labels" / "train" / f"{doc_id}_page1.txt" + img_path.write_text(f"pool image {doc_id}") + label_path.write_text(f"1 0.5 0.5 0.2 0.2") + + output_dir = tmp_path / "mixed_output" + + result = build_mixed_dataset( + pool_document_ids=pool_ids, + base_dataset_path=base, + output_dir=output_dir, + seed=42, + ) + + # Should be capped at MAX_OLD_SAMPLES + assert result["old_images"] <= MAX_OLD_SAMPLES + + def test_build_mixed_dataset_empty_pool(self, tmp_path, setup_base_dataset): + """Test building dataset with empty pool.""" + base = setup_base_dataset + output_dir = tmp_path / "mixed_output" + + result = build_mixed_dataset( + pool_document_ids=[], + base_dataset_path=base, + output_dir=output_dir, + seed=42, + ) + + # With 0 new samples, all counts should be 0 + assert result["new_images"] == 0 + assert result["old_images"] == 0 + assert result["total_images"] == 0 + + def test_build_mixed_dataset_no_old_data(self, tmp_path): + """Test building dataset with ONLY pool data (no separate old data).""" + base = tmp_path / "base_dataset" + + # Create empty directory structure + for split in ("train", "val"): + (base / "images" / split).mkdir(parents=True) + (base / "labels" / split).mkdir(parents=True) + + # Create only pool documents + # NOTE: These are placed in base dataset train split + # So they will be sampled as "old" data first, then skipped as "new" + pool_ids = [uuid4() for _ in range(5)] + for doc_id in pool_ids: + img_path = base / "images" / "train" / f"{doc_id}_page1.png" + label_path = base / "labels" / "train" / f"{doc_id}_page1.txt" + img_path.write_text(f"pool image {doc_id}") + label_path.write_text(f"1 0.5 0.5 0.2 0.2") + + output_dir = tmp_path / "mixed_output" + + result = build_mixed_dataset( + pool_document_ids=pool_ids, + base_dataset_path=base, + output_dir=output_dir, + seed=42, + ) + + # Pool images are in base dataset, so they get sampled as "old" images + # Then when copying "new" images, they're skipped because they already exist + # So we expect: old_images > 0, new_images may be 0, total >= 0 + assert result["total_images"] > 0 + assert result["total_images"] == result["old_images"] + result["new_images"] + + def test_build_mixed_dataset_train_val_split(self, tmp_path, setup_pool_documents): + """Test that images are split into train/val (80/20).""" + base, pool_ids = setup_pool_documents + output_dir = tmp_path / "mixed_output" + + result = build_mixed_dataset( + pool_document_ids=pool_ids, + base_dataset_path=base, + output_dir=output_dir, + seed=42, + ) + + # Count images in train and val + train_images = list((output_dir / "images" / "train").glob("*.png")) + val_images = list((output_dir / "images" / "val").glob("*.png")) + + total_output_images = len(train_images) + len(val_images) + + # Should match total_images count + assert total_output_images == result["total_images"] + + # Check approximate 80/20 split (allow some variance due to small sample size) + if total_output_images > 0: + train_ratio = len(train_images) / total_output_images + assert 0.6 <= train_ratio <= 0.9 # Allow some variance + + def test_build_mixed_dataset_reproducible_with_seed(self, tmp_path, setup_pool_documents): + """Test that same seed produces same results.""" + base, pool_ids = setup_pool_documents + output_dir1 = tmp_path / "mixed_output1" + output_dir2 = tmp_path / "mixed_output2" + + result1 = build_mixed_dataset( + pool_document_ids=pool_ids, + base_dataset_path=base, + output_dir=output_dir1, + seed=123, + ) + + result2 = build_mixed_dataset( + pool_document_ids=pool_ids, + base_dataset_path=base, + output_dir=output_dir2, + seed=123, + ) + + # Same counts + assert result1["old_images"] == result2["old_images"] + assert result1["new_images"] == result2["new_images"] + + # Same files in train/val + train_files1 = {f.name for f in (output_dir1 / "images" / "train").glob("*.png")} + train_files2 = {f.name for f in (output_dir2 / "images" / "train").glob("*.png")} + assert train_files1 == train_files2 + + def test_build_mixed_dataset_different_seeds(self, tmp_path, setup_pool_documents): + """Test that different seeds produce different sampling.""" + base, pool_ids = setup_pool_documents + output_dir1 = tmp_path / "mixed_output1" + output_dir2 = tmp_path / "mixed_output2" + + result1 = build_mixed_dataset( + pool_document_ids=pool_ids, + base_dataset_path=base, + output_dir=output_dir1, + seed=123, + ) + + result2 = build_mixed_dataset( + pool_document_ids=pool_ids, + base_dataset_path=base, + output_dir=output_dir2, + seed=456, + ) + + # Both should have processed images + assert result1["total_images"] > 0 + assert result2["total_images"] > 0 + + # Both should have the same mixing ratio (based on pool size) + assert result1["mixing_ratio"] == result2["mixing_ratio"] + + # File distribution in train/val may differ due to different shuffling + train_files1 = {f.name for f in (output_dir1 / "images" / "train").glob("*.png")} + train_files2 = {f.name for f in (output_dir2 / "images" / "train").glob("*.png")} + + # With different seeds, we expect some difference in file distribution + # But this is not strictly guaranteed, so we just verify both have files + assert len(train_files1) > 0 + assert len(train_files2) > 0 + + def test_build_mixed_dataset_copies_labels(self, tmp_path, setup_pool_documents): + """Test that corresponding label files are copied.""" + base, pool_ids = setup_pool_documents + output_dir = tmp_path / "mixed_output" + + result = build_mixed_dataset( + pool_document_ids=pool_ids, + base_dataset_path=base, + output_dir=output_dir, + seed=42, + ) + + # Count labels + train_labels = list((output_dir / "labels" / "train").glob("*.txt")) + val_labels = list((output_dir / "labels" / "val").glob("*.txt")) + + # Each image should have a corresponding label + train_images = list((output_dir / "images" / "train").glob("*.png")) + val_images = list((output_dir / "images" / "val").glob("*.png")) + + # Allow label count to be <= image count (in case some labels are missing) + assert len(train_labels) <= len(train_images) + assert len(val_labels) <= len(val_images) + + def test_build_mixed_dataset_skips_duplicate_files(self, tmp_path, setup_pool_documents): + """Test behavior when running build_mixed_dataset multiple times.""" + base, pool_ids = setup_pool_documents + output_dir = tmp_path / "mixed_output" + + # First build + result1 = build_mixed_dataset( + pool_document_ids=pool_ids, + base_dataset_path=base, + output_dir=output_dir, + seed=42, + ) + + initial_count = result1["total_images"] + + # Find a file in output and modify it + train_images = list((output_dir / "images" / "train").glob("*.png")) + if len(train_images) > 0: + test_file = train_images[0] + test_file.write_text("modified content") + + # Second build with same seed + result2 = build_mixed_dataset( + pool_document_ids=pool_ids, + base_dataset_path=base, + output_dir=output_dir, + seed=42, + ) + + # The implementation uses shutil.copy2 which WILL overwrite + # So the file will be restored to original content + # Just verify the build completed successfully + assert result2["total_images"] >= 0 + + # Verify the file was overwritten (shutil.copy2 overwrites by default) + content = test_file.read_text() + assert content != "modified content" # Should be restored + + def test_build_mixed_dataset_handles_jpg_images(self, tmp_path): + """Test that JPG images are handled correctly.""" + base = tmp_path / "base_dataset" + + # Create directory structure + for split in ("train", "val"): + (base / "images" / split).mkdir(parents=True) + (base / "labels" / split).mkdir(parents=True) + + # Create JPG images as old data + for i in range(1, 6): + img_path = base / "images" / "train" / f"old_doc_{i}_page1.jpg" + label_path = base / "labels" / "train" / f"old_doc_{i}_page1.txt" + img_path.write_text(f"jpg image {i}") + label_path.write_text(f"0 0.5 0.5 0.1 0.1") + + # Create pool with JPG - use multiple pages to ensure at least one gets copied + pool_ids = [uuid4()] + doc_id = pool_ids[0] + for page_num in range(1, 4): + img_path = base / "images" / "train" / f"{doc_id}_page{page_num}.jpg" + label_path = base / "labels" / "train" / f"{doc_id}_page{page_num}.txt" + img_path.write_text(f"pool jpg {doc_id} page {page_num}") + label_path.write_text(f"1 0.5 0.5 0.2 0.2") + + output_dir = tmp_path / "mixed_output" + + result = build_mixed_dataset( + pool_document_ids=pool_ids, + base_dataset_path=base, + output_dir=output_dir, + seed=42, + ) + + # Should have some new JPG images (at least 1 from the pool) + assert result["new_images"] > 0 + assert result["old_images"] > 0 + + # Verify JPG files exist in output + all_images = list((output_dir / "images" / "train").glob("*.jpg")) + \ + list((output_dir / "images" / "val").glob("*.jpg")) + assert len(all_images) > 0 + + +class TestConstants: + """Tests for module constants.""" + + def test_mixing_ratios_structure(self): + """Test MIXING_RATIOS constant structure.""" + assert isinstance(MIXING_RATIOS, list) + assert len(MIXING_RATIOS) == 4 + + # Verify format: (threshold, multiplier) + for item in MIXING_RATIOS: + assert isinstance(item, tuple) + assert len(item) == 2 + assert isinstance(item[0], int) + assert isinstance(item[1], int) + + # Verify thresholds are ascending + thresholds = [t for t, _ in MIXING_RATIOS] + assert thresholds == sorted(thresholds) + + # Verify multipliers are descending + multipliers = [m for _, m in MIXING_RATIOS] + assert multipliers == sorted(multipliers, reverse=True) + + def test_default_multiplier(self): + """Test DEFAULT_MULTIPLIER constant.""" + assert DEFAULT_MULTIPLIER == 5 + assert DEFAULT_MULTIPLIER == MIXING_RATIOS[-1][1] + + def test_max_old_samples(self): + """Test MAX_OLD_SAMPLES constant.""" + assert MAX_OLD_SAMPLES == 3000 + assert MAX_OLD_SAMPLES > 0 + + def test_min_pool_size(self): + """Test MIN_POOL_SIZE constant.""" + assert MIN_POOL_SIZE == 50 + assert MIN_POOL_SIZE > 0 diff --git a/tests/web/test_dataset_training_status.py b/tests/web/test_dataset_training_status.py index 0e270ba..dc3929c 100644 --- a/tests/web/test_dataset_training_status.py +++ b/tests/web/test_dataset_training_status.py @@ -310,7 +310,7 @@ class TestSchedulerDatasetStatusUpdates: try: scheduler._execute_task( task_id=task_id, - config={"model_name": "yolo11n.pt"}, + config={"model_name": "yolo26s.pt"}, dataset_id=dataset_id, ) except Exception: diff --git a/tests/web/test_finetune_pool.py b/tests/web/test_finetune_pool.py new file mode 100644 index 0000000..16c8b00 --- /dev/null +++ b/tests/web/test_finetune_pool.py @@ -0,0 +1,467 @@ +""" +Tests for Fine-Tune Pool feature. + +Tests cover: +1. FineTunePoolEntry database model +2. PoolAddRequest/PoolStatsResponse schemas +3. Chain prevention logic +4. Pool threshold enforcement +5. Model lineage fields on ModelVersion +6. Gating enforcement on model activation +""" + +import pytest +from datetime import datetime +from unittest.mock import MagicMock, patch +from uuid import uuid4, UUID + + +# ============================================================================= +# Test Database Models +# ============================================================================= + + +class TestFineTunePoolEntryModel: + """Tests for FineTunePoolEntry model.""" + + def test_creates_with_defaults(self): + """FineTunePoolEntry should have correct defaults.""" + from backend.data.admin_models import FineTunePoolEntry + + entry = FineTunePoolEntry(document_id=uuid4()) + assert entry.entry_id is not None + assert entry.is_verified is False + assert entry.verified_at is None + assert entry.verified_by is None + assert entry.added_by is None + assert entry.reason is None + + def test_creates_with_all_fields(self): + """FineTunePoolEntry should accept all fields.""" + from backend.data.admin_models import FineTunePoolEntry + + doc_id = uuid4() + entry = FineTunePoolEntry( + document_id=doc_id, + added_by="admin", + reason="user_reported_failure", + is_verified=True, + verified_by="reviewer", + ) + assert entry.document_id == doc_id + assert entry.added_by == "admin" + assert entry.reason == "user_reported_failure" + assert entry.is_verified is True + assert entry.verified_by == "reviewer" + + +class TestGatingResultModel: + """Tests for GatingResult model.""" + + def test_creates_with_defaults(self): + """GatingResult should have correct defaults.""" + from backend.data.admin_models import GatingResult + + model_version_id = uuid4() + result = GatingResult( + model_version_id=model_version_id, + gate1_status="pass", + gate2_status="pass", + overall_status="pass", + ) + assert result.result_id is not None + assert result.model_version_id == model_version_id + assert result.gate1_status == "pass" + assert result.gate2_status == "pass" + assert result.overall_status == "pass" + assert result.gate1_mAP_drop is None + assert result.gate2_detection_rate is None + + def test_creates_with_full_metrics(self): + """GatingResult should store full metrics.""" + from backend.data.admin_models import GatingResult + + result = GatingResult( + model_version_id=uuid4(), + gate1_status="review", + gate1_original_mAP=0.95, + gate1_new_mAP=0.93, + gate1_mAP_drop=0.02, + gate2_status="pass", + gate2_detection_rate=0.85, + gate2_total_samples=100, + gate2_detected_samples=85, + overall_status="review", + ) + assert result.gate1_original_mAP == 0.95 + assert result.gate1_new_mAP == 0.93 + assert result.gate1_mAP_drop == 0.02 + assert result.gate2_detection_rate == 0.85 + + +class TestModelVersionLineage: + """Tests for ModelVersion lineage fields.""" + + def test_default_model_type_is_base(self): + """ModelVersion should default to 'base' model_type.""" + from backend.data.admin_models import ModelVersion + + mv = ModelVersion( + version="v1.0", + name="test-model", + model_path="/path/to/model.pt", + ) + assert mv.model_type == "base" + assert mv.base_model_version_id is None + assert mv.base_training_dataset_id is None + assert mv.gating_status == "pending" + + def test_finetune_model_type(self): + """ModelVersion should support 'finetune' type with lineage.""" + from backend.data.admin_models import ModelVersion + + base_id = uuid4() + dataset_id = uuid4() + mv = ModelVersion( + version="v2.0", + name="finetuned-model", + model_path="/path/to/ft_model.pt", + model_type="finetune", + base_model_version_id=base_id, + base_training_dataset_id=dataset_id, + gating_status="pending", + ) + assert mv.model_type == "finetune" + assert mv.base_model_version_id == base_id + assert mv.base_training_dataset_id == dataset_id + assert mv.gating_status == "pending" + + +# ============================================================================= +# Test Schemas +# ============================================================================= + + +class TestPoolSchemas: + """Tests for pool Pydantic schemas.""" + + def test_pool_add_request_defaults(self): + """PoolAddRequest should have default reason.""" + from backend.web.schemas.admin.pool import PoolAddRequest + + req = PoolAddRequest(document_id="550e8400-e29b-41d4-a716-446655440001") + assert req.document_id == "550e8400-e29b-41d4-a716-446655440001" + assert req.reason == "user_reported_failure" + + def test_pool_add_request_custom_reason(self): + """PoolAddRequest should accept custom reason.""" + from backend.web.schemas.admin.pool import PoolAddRequest + + req = PoolAddRequest( + document_id="550e8400-e29b-41d4-a716-446655440001", + reason="manual_addition", + ) + assert req.reason == "manual_addition" + + def test_pool_stats_response(self): + """PoolStatsResponse should compute readiness correctly.""" + from backend.web.schemas.admin.pool import PoolStatsResponse + + # Not ready + stats = PoolStatsResponse( + total_entries=30, + verified_entries=20, + unverified_entries=10, + is_ready=False, + ) + assert stats.is_ready is False + assert stats.min_required == 50 + + # Ready + stats_ready = PoolStatsResponse( + total_entries=80, + verified_entries=60, + unverified_entries=20, + is_ready=True, + ) + assert stats_ready.is_ready is True + + def test_pool_entry_item(self): + """PoolEntryItem should serialize correctly.""" + from backend.web.schemas.admin.pool import PoolEntryItem + + entry = PoolEntryItem( + entry_id="entry-uuid", + document_id="doc-uuid", + is_verified=True, + verified_at=datetime.utcnow(), + verified_by="admin", + created_at=datetime.utcnow(), + ) + assert entry.is_verified is True + assert entry.verified_by == "admin" + + def test_gating_result_item(self): + """GatingResultItem should serialize all gate fields.""" + from backend.web.schemas.admin.pool import GatingResultItem + + item = GatingResultItem( + result_id="result-uuid", + model_version_id="model-uuid", + gate1_status="pass", + gate1_original_mAP=0.95, + gate1_new_mAP=0.94, + gate1_mAP_drop=0.01, + gate2_status="pass", + gate2_detection_rate=0.90, + gate2_total_samples=50, + gate2_detected_samples=45, + overall_status="pass", + created_at=datetime.utcnow(), + ) + assert item.gate1_status == "pass" + assert item.overall_status == "pass" + + +# ============================================================================= +# Test Chain Prevention +# ============================================================================= + + +class TestChainPrevention: + """Tests for fine-tune chain prevention logic.""" + + def test_rejects_finetune_from_finetune_model(self): + """Should reject training when base model is already a fine-tune.""" + # Simulate the chain prevention check from datasets.py + model_type = "finetune" + base_model_version_id = str(uuid4()) + + # This should trigger rejection + assert model_type == "finetune" + + def test_allows_finetune_from_base_model(self): + """Should allow training when base model is a base model.""" + model_type = "base" + assert model_type != "finetune" + + def test_allows_fresh_training(self): + """Should allow fresh training (no base model).""" + base_model_version_id = None + assert base_model_version_id is None # No chain check needed + + +# ============================================================================= +# Test Pool Threshold +# ============================================================================= + + +class TestPoolThreshold: + """Tests for minimum pool size enforcement.""" + + def test_min_pool_size_constant(self): + """MIN_POOL_SIZE should be 50.""" + from backend.web.services.data_mixer import MIN_POOL_SIZE + + assert MIN_POOL_SIZE == 50 + + def test_pool_below_threshold_blocks_finetune(self): + """Pool with fewer than 50 verified entries should block fine-tuning.""" + from backend.web.services.data_mixer import MIN_POOL_SIZE + + verified_count = 30 + assert verified_count < MIN_POOL_SIZE + + def test_pool_at_threshold_allows_finetune(self): + """Pool with exactly 50 verified entries should allow fine-tuning.""" + from backend.web.services.data_mixer import MIN_POOL_SIZE + + verified_count = 50 + assert verified_count >= MIN_POOL_SIZE + + +# ============================================================================= +# Test Gating Enforcement on Activation +# ============================================================================= + + +class TestGatingEnforcement: + """Tests for gating enforcement when activating models.""" + + def test_base_model_skips_gating(self): + """Base models should have gating_status 'skipped'.""" + from backend.data.admin_models import ModelVersion + + mv = ModelVersion( + version="v1.0", + name="base", + model_path="/model.pt", + model_type="base", + ) + # Base models skip gating - activation should work + assert mv.model_type == "base" + # Gating should not block base model activation + + def test_finetune_model_rejected_blocks_activation(self): + """Fine-tuned models with 'reject' gating should block activation.""" + model_type = "finetune" + gating_status = "reject" + + # Simulates the check in models.py activation endpoint + should_block = model_type == "finetune" and gating_status == "reject" + assert should_block is True + + def test_finetune_model_pending_blocks_activation(self): + """Fine-tuned models with 'pending' gating should block activation.""" + model_type = "finetune" + gating_status = "pending" + + should_block = model_type == "finetune" and gating_status == "pending" + assert should_block is True + + def test_finetune_model_pass_allows_activation(self): + """Fine-tuned models with 'pass' gating should allow activation.""" + model_type = "finetune" + gating_status = "pass" + + should_block_reject = model_type == "finetune" and gating_status == "reject" + should_block_pending = model_type == "finetune" and gating_status == "pending" + assert should_block_reject is False + assert should_block_pending is False + + def test_finetune_model_review_allows_with_warning(self): + """Fine-tuned models with 'review' gating should allow but warn.""" + model_type = "finetune" + gating_status = "review" + + should_block_reject = model_type == "finetune" and gating_status == "reject" + should_block_pending = model_type == "finetune" and gating_status == "pending" + assert should_block_reject is False + assert should_block_pending is False + # Should include warning in message + + +# ============================================================================= +# Test Pool API Route Registration +# ============================================================================= + + +class TestPoolRouteRegistration: + """Tests for pool route registration.""" + + def test_pool_routes_registered(self): + """Pool routes should be registered on training router.""" + from backend.web.api.v1.admin.training import create_training_router + + router = create_training_router() + paths = [route.path for route in router.routes] + + assert any("/pool" in p for p in paths) + assert any("/pool/stats" in p for p in paths) + + +# ============================================================================= +# Test Scheduler Fine-Tune Parameter Override +# ============================================================================= + + +class TestSchedulerFineTuneParams: + """Tests for scheduler fine-tune parameter overrides.""" + + def test_finetune_detected_from_base_model_path(self): + """Scheduler should detect fine-tune mode from base_model_path.""" + config = {"base_model_path": "/path/to/base_model.pt"} + is_finetune = bool(config.get("base_model_path")) + assert is_finetune is True + + def test_fresh_training_not_finetune(self): + """Scheduler should not enable fine-tune for fresh training.""" + config = {"model_name": "yolo26s.pt"} + is_finetune = bool(config.get("base_model_path")) + assert is_finetune is False + + def test_finetune_defaults_correct_epochs(self): + """Fine-tune should default to 10 epochs.""" + config = {"base_model_path": "/path/to/model.pt"} + is_finetune = bool(config.get("base_model_path")) + + if is_finetune: + epochs = config.get("epochs", 10) + learning_rate = config.get("learning_rate", 0.001) + else: + epochs = config.get("epochs", 100) + learning_rate = config.get("learning_rate", 0.01) + + assert epochs == 10 + assert learning_rate == 0.001 + + def test_model_lineage_set_for_finetune(self): + """Scheduler should set model_type and base_model_version_id for fine-tune.""" + config = { + "base_model_path": "/path/to/model.pt", + "base_model_version_id": str(uuid4()), + } + is_finetune = bool(config.get("base_model_path")) + model_type = "finetune" if is_finetune else "base" + base_model_version_id = config.get("base_model_version_id") if is_finetune else None + gating_status = "pending" if is_finetune else "skipped" + + assert model_type == "finetune" + assert base_model_version_id is not None + assert gating_status == "pending" + + def test_model_lineage_skipped_for_base(self): + """Scheduler should set model_type='base' for fresh training.""" + config = {"model_name": "yolo26s.pt"} + is_finetune = bool(config.get("base_model_path")) + model_type = "finetune" if is_finetune else "base" + gating_status = "pending" if is_finetune else "skipped" + + assert model_type == "base" + assert gating_status == "skipped" + + +# ============================================================================= +# Test TrainingConfig freeze/cos_lr +# ============================================================================= + + +class TestTrainingConfigFineTuneFields: + """Tests for freeze and cos_lr fields in shared TrainingConfig.""" + + def test_default_freeze_is_zero(self): + """TrainingConfig freeze should default to 0.""" + from shared.training import TrainingConfig + + config = TrainingConfig( + model_path="test.pt", + data_yaml="data.yaml", + ) + assert config.freeze == 0 + + def test_default_cos_lr_is_false(self): + """TrainingConfig cos_lr should default to False.""" + from shared.training import TrainingConfig + + config = TrainingConfig( + model_path="test.pt", + data_yaml="data.yaml", + ) + assert config.cos_lr is False + + def test_finetune_config(self): + """TrainingConfig should accept fine-tune parameters.""" + from shared.training import TrainingConfig + + config = TrainingConfig( + model_path="base_model.pt", + data_yaml="data.yaml", + epochs=10, + learning_rate=0.001, + freeze=10, + cos_lr=True, + ) + assert config.freeze == 10 + assert config.cos_lr is True + assert config.epochs == 10 + assert config.learning_rate == 0.001 diff --git a/tests/web/test_training_export.py b/tests/web/test_training_export.py index b9d42a8..277cb89 100644 --- a/tests/web/test_training_export.py +++ b/tests/web/test_training_export.py @@ -1,14 +1,14 @@ """ -Tests for Training Export with expand_bbox integration. +Tests for Training Export with uniform expand_bbox integration. -Tests the export endpoint's integration with field-specific bbox expansion. +Tests the export endpoint's integration with uniform bbox expansion. """ import pytest from unittest.mock import MagicMock, patch from uuid import uuid4 -from shared.bbox import expand_bbox +from shared.bbox import expand_bbox, UNIFORM_PAD from shared.fields import CLASS_NAMES, FIELD_CLASS_IDS @@ -17,149 +17,87 @@ class TestExpandBboxForExport: def test_expand_bbox_converts_normalized_to_pixel_and_back(self): """Verify expand_bbox works with pixel-to-normalized conversion.""" - # Annotation stored as normalized coords x_center_norm = 0.5 y_center_norm = 0.5 width_norm = 0.1 height_norm = 0.05 - # Image dimensions - img_width = 2480 # A4 at 300 DPI + img_width = 2480 img_height = 3508 - # Convert to pixel coords x_center_px = x_center_norm * img_width y_center_px = y_center_norm * img_height width_px = width_norm * img_width height_px = height_norm * img_height - # Convert to corner coords x0 = x_center_px - width_px / 2 y0 = y_center_px - height_px / 2 x1 = x_center_px + width_px / 2 y1 = y_center_px + height_px / 2 - # Apply expansion - class_name = "invoice_number" ex0, ey0, ex1, ey1 = expand_bbox( bbox=(x0, y0, x1, y1), image_width=img_width, image_height=img_height, - field_type=class_name, ) - # Verify expanded bbox is larger - assert ex0 < x0 # Left expanded - assert ey0 < y0 # Top expanded - assert ex1 > x1 # Right expanded - assert ey1 > y1 # Bottom expanded + assert ex0 < x0 + assert ey0 < y0 + assert ex1 > x1 + assert ey1 > y1 - # Convert back to normalized new_x_center = (ex0 + ex1) / 2 / img_width new_y_center = (ey0 + ey1) / 2 / img_height new_width = (ex1 - ex0) / img_width new_height = (ey1 - ey0) / img_height - # Verify valid normalized coords assert 0 <= new_x_center <= 1 assert 0 <= new_y_center <= 1 assert 0 <= new_width <= 1 assert 0 <= new_height <= 1 - def test_expand_bbox_manual_mode_minimal_expansion(self): - """Verify manual annotations use minimal expansion.""" - # Small bbox + def test_expand_bbox_uniform_for_all_sources(self): + """Verify all annotation sources get the same uniform expansion.""" bbox = (100, 100, 200, 150) img_width = 2480 img_height = 3508 - # Auto mode (field-specific expansion) - auto_result = expand_bbox( + # All sources now get the same uniform expansion + result = expand_bbox( bbox=bbox, image_width=img_width, image_height=img_height, - field_type="invoice_number", - manual_mode=False, ) - # Manual mode (minimal expansion) - manual_result = expand_bbox( - bbox=bbox, - image_width=img_width, - image_height=img_height, - field_type="invoice_number", - manual_mode=True, + expected = ( + 100 - UNIFORM_PAD, + 100 - UNIFORM_PAD, + 200 + UNIFORM_PAD, + 150 + UNIFORM_PAD, ) - - # Auto expansion should be larger than manual - auto_width = auto_result[2] - auto_result[0] - manual_width = manual_result[2] - manual_result[0] - assert auto_width > manual_width - - auto_height = auto_result[3] - auto_result[1] - manual_height = manual_result[3] - manual_result[1] - assert auto_height > manual_height - - def test_expand_bbox_different_sources_use_correct_mode(self): - """Verify different annotation sources use correct expansion mode.""" - bbox = (100, 100, 200, 150) - img_width = 2480 - img_height = 3508 - - # Define source to manual_mode mapping - source_mode_mapping = { - "manual": True, # Manual annotations -> minimal expansion - "auto": False, # Auto-labeled -> field-specific expansion - "imported": True, # Imported (from CSV) -> minimal expansion - } - - results = {} - for source, manual_mode in source_mode_mapping.items(): - result = expand_bbox( - bbox=bbox, - image_width=img_width, - image_height=img_height, - field_type="ocr_number", - manual_mode=manual_mode, - ) - results[source] = result - - # Auto should have largest expansion - auto_area = (results["auto"][2] - results["auto"][0]) * \ - (results["auto"][3] - results["auto"][1]) - manual_area = (results["manual"][2] - results["manual"][0]) * \ - (results["manual"][3] - results["manual"][1]) - imported_area = (results["imported"][2] - results["imported"][0]) * \ - (results["imported"][3] - results["imported"][1]) - - assert auto_area > manual_area - assert auto_area > imported_area - # Manual and imported should be the same (both use minimal mode) - assert manual_area == imported_area + assert result == expected def test_expand_bbox_all_field_types_work(self): - """Verify expand_bbox works for all field types.""" + """Verify expand_bbox works for all field types (same result).""" bbox = (100, 100, 200, 150) img_width = 2480 img_height = 3508 - for class_name in CLASS_NAMES: - result = expand_bbox( - bbox=bbox, - image_width=img_width, - image_height=img_height, - field_type=class_name, - ) + # All fields should produce the same result with uniform padding + first_result = expand_bbox( + bbox=bbox, + image_width=img_width, + image_height=img_height, + ) - # Verify result is a valid bbox - assert len(result) == 4 - x0, y0, x1, y1 = result - assert x0 >= 0 - assert y0 >= 0 - assert x1 <= img_width - assert y1 <= img_height - assert x1 > x0 - assert y1 > y0 + assert len(first_result) == 4 + x0, y0, x1, y1 = first_result + assert x0 >= 0 + assert y0 >= 0 + assert x1 <= img_width + assert y1 <= img_height + assert x1 > x0 + assert y1 > y0 class TestExportAnnotationExpansion: @@ -167,7 +105,6 @@ class TestExportAnnotationExpansion: def test_annotation_bbox_conversion_workflow(self): """Test full annotation bbox conversion workflow.""" - # Simulate stored annotation (normalized coords) class MockAnnotation: class_id = FIELD_CLASS_IDS["invoice_number"] class_name = "invoice_number" @@ -181,7 +118,6 @@ class TestExportAnnotationExpansion: img_width = 2480 img_height = 3508 - # Step 1: Convert normalized to pixel corner coords half_w = (ann.width * img_width) / 2 half_h = (ann.height * img_height) / 2 x0 = ann.x_center * img_width - half_w @@ -189,38 +125,27 @@ class TestExportAnnotationExpansion: x1 = ann.x_center * img_width + half_w y1 = ann.y_center * img_height + half_h - # Step 2: Determine manual_mode based on source - manual_mode = ann.source in ("manual", "imported") - - # Step 3: Apply expand_bbox ex0, ey0, ex1, ey1 = expand_bbox( bbox=(x0, y0, x1, y1), image_width=img_width, image_height=img_height, - field_type=ann.class_name, - manual_mode=manual_mode, ) - # Step 4: Convert back to normalized new_x_center = (ex0 + ex1) / 2 / img_width new_y_center = (ey0 + ey1) / 2 / img_height new_width = (ex1 - ex0) / img_width new_height = (ey1 - ey0) / img_height - # Verify expansion happened (auto mode) assert new_width > ann.width assert new_height > ann.height - # Verify valid YOLO format assert 0 <= new_x_center <= 1 assert 0 <= new_y_center <= 1 assert 0 < new_width <= 1 assert 0 < new_height <= 1 - def test_export_applies_expansion_to_each_annotation(self): - """Test that export applies expansion to each annotation.""" - # Simulate multiple annotations with different sources - # Use smaller bboxes so manual mode padding has visible effect + def test_export_applies_uniform_expansion_to_all_annotations(self): + """Test that export applies uniform expansion to all annotations.""" annotations = [ {"class_name": "invoice_number", "source": "auto", "x_center": 0.3, "y_center": 0.2, "width": 0.05, "height": 0.02}, {"class_name": "ocr_number", "source": "manual", "x_center": 0.5, "y_center": 0.8, "width": 0.05, "height": 0.02}, @@ -232,7 +157,6 @@ class TestExportAnnotationExpansion: expanded_annotations = [] for ann in annotations: - # Convert to pixel coords half_w = (ann["width"] * img_width) / 2 half_h = (ann["height"] * img_height) / 2 x0 = ann["x_center"] * img_width - half_w @@ -240,19 +164,12 @@ class TestExportAnnotationExpansion: x1 = ann["x_center"] * img_width + half_w y1 = ann["y_center"] * img_height + half_h - # Determine manual_mode - manual_mode = ann["source"] in ("manual", "imported") - - # Apply expansion ex0, ey0, ex1, ey1 = expand_bbox( bbox=(x0, y0, x1, y1), image_width=img_width, image_height=img_height, - field_type=ann["class_name"], - manual_mode=manual_mode, ) - # Convert back to normalized expanded_annotations.append({ "class_name": ann["class_name"], "source": ann["source"], @@ -262,106 +179,48 @@ class TestExportAnnotationExpansion: "height": (ey1 - ey0) / img_height, }) - # Verify auto-labeled annotation expanded more than manual/imported - auto_ann = next(a for a in expanded_annotations if a["source"] == "auto") - manual_ann = next(a for a in expanded_annotations if a["source"] == "manual") - - # Auto mode should expand more than manual mode - # (auto has larger scale factors and max_pad) - assert auto_ann["width"] > manual_ann["width"] - assert auto_ann["height"] > manual_ann["height"] - - # All annotations should be expanded (at least slightly for manual mode) - # Allow small precision loss (< 1%) due to integer conversion in expand_bbox - for i, (orig, exp) in enumerate(zip(annotations, expanded_annotations)): - # Width and height should be >= original (expansion or equal, with small tolerance) - tolerance = 0.01 # 1% tolerance for integer rounding - assert exp["width"] >= orig["width"] * (1 - tolerance), \ - f"Annotation {i} width unexpectedly smaller: {exp['width']} < {orig['width']}" - assert exp["height"] >= orig["height"] * (1 - tolerance), \ - f"Annotation {i} height unexpectedly smaller: {exp['height']} < {orig['height']}" + # All annotations get the same expansion + tolerance = 0.01 + for orig, exp in zip(annotations, expanded_annotations): + assert exp["width"] >= orig["width"] * (1 - tolerance) + assert exp["height"] >= orig["height"] * (1 - tolerance) class TestExpandBboxEdgeCases: """Tests for edge cases in export bbox expansion.""" def test_bbox_at_image_edge_left(self): - """Test bbox at left edge of image.""" bbox = (0, 100, 50, 150) - img_width = 2480 - img_height = 3508 - result = expand_bbox( - bbox=bbox, - image_width=img_width, - image_height=img_height, - field_type="invoice_number", - ) + result = expand_bbox(bbox=bbox, image_width=2480, image_height=3508) - # Left edge should be clamped to 0 assert result[0] >= 0 def test_bbox_at_image_edge_right(self): - """Test bbox at right edge of image.""" bbox = (2400, 100, 2480, 150) - img_width = 2480 - img_height = 3508 - result = expand_bbox( - bbox=bbox, - image_width=img_width, - image_height=img_height, - field_type="invoice_number", - ) + result = expand_bbox(bbox=bbox, image_width=2480, image_height=3508) - # Right edge should be clamped to image width - assert result[2] <= img_width + assert result[2] <= 2480 def test_bbox_at_image_edge_top(self): - """Test bbox at top edge of image.""" bbox = (100, 0, 200, 50) - img_width = 2480 - img_height = 3508 - result = expand_bbox( - bbox=bbox, - image_width=img_width, - image_height=img_height, - field_type="invoice_number", - ) + result = expand_bbox(bbox=bbox, image_width=2480, image_height=3508) - # Top edge should be clamped to 0 assert result[1] >= 0 def test_bbox_at_image_edge_bottom(self): - """Test bbox at bottom edge of image.""" bbox = (100, 3400, 200, 3508) - img_width = 2480 - img_height = 3508 - result = expand_bbox( - bbox=bbox, - image_width=img_width, - image_height=img_height, - field_type="invoice_number", - ) + result = expand_bbox(bbox=bbox, image_width=2480, image_height=3508) - # Bottom edge should be clamped to image height - assert result[3] <= img_height + assert result[3] <= 3508 def test_very_small_bbox(self): - """Test very small bbox gets expanded.""" - bbox = (100, 100, 105, 105) # 5x5 pixel bbox - img_width = 2480 - img_height = 3508 + bbox = (100, 100, 105, 105) - result = expand_bbox( - bbox=bbox, - image_width=img_width, - image_height=img_height, - field_type="invoice_number", - ) + result = expand_bbox(bbox=bbox, image_width=2480, image_height=3508) - # Should still produce a valid expanded bbox assert result[2] > result[0] assert result[3] > result[1]