This commit is contained in:
Yaojia Wang
2026-02-12 23:06:00 +01:00
parent ad5ed46b4c
commit 58d36c8927
26 changed files with 3903 additions and 2551 deletions

View File

@@ -1,66 +1,72 @@
# Invoice Master POC v2
Swedish Invoice Field Extraction System - YOLO26 + PaddleOCR 从瑞典 PDF 发票中提取结构化数据。
Swedish Invoice Field Extraction System - YOLO + PaddleOCR extracts structured data from Swedish PDF invoices.
## Architecture
```
PDF → PyMuPDF (DPI=150) → YOLO Detection → PaddleOCR → Field Extraction → Normalization → Output
```
### Project Structure
```
packages/
├── backend/ # FastAPI web server + inference pipeline
│ └── pipeline/ # YOLO detector → OCR → field extractor → value selector → normalizers
├── shared/ # Common utilities (bbox, OCR, field mappings)
└── training/ # YOLO training data generation (annotation, dataset)
tests/ # Mirrors packages/ structure
```
### Pipeline Flow (process_pdf)
1. YOLO detects field regions on rendered PDF page
2. PaddleOCR extracts text from detected bboxes
3. Field extractor maps detections to invoice fields via CLASS_TO_FIELD
4. Value selector picks best candidate per field (confidence + validation)
5. Normalizers clean values (dates, amounts, invoice numbers)
6. Fallback regex extraction if key fields missing
## Tech Stack
| Component | Technology |
|-----------|------------|
| Object Detection | YOLO26 (Ultralytics >= 8.4.0) |
| OCR Engine | PaddleOCR v5 (PP-OCRv5) |
| PDF Processing | PyMuPDF (fitz) |
| Object Detection | YOLO (Ultralytics >= 8.4.0) |
| OCR | PaddleOCR v5 (PP-OCRv5) |
| PDF | PyMuPDF (fitz), DPI=150 |
| Database | PostgreSQL + psycopg2 |
| Web Framework | FastAPI + Uvicorn |
| Deep Learning | PyTorch + CUDA 12.x |
| Web | FastAPI + Uvicorn |
| ML | PyTorch + CUDA 12.x |
## WSL Environment (REQUIRED)
**Prefix ALL commands with:**
ALL Python commands MUST use this prefix:
```bash
wsl bash -c "source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-sm120 && <command>"
```
**NEVER run Python commands directly in Windows PowerShell/CMD.**
NEVER run Python directly in Windows PowerShell/CMD.
## Project-Specific Rules
## Project Rules
- Python 3.11+ with type hints
- No print() in production - use logging
- Run tests: `pytest --cov=src`
- Python 3.10, type hints on all function signatures
- No `print()` in production code - use `logging` module
- Validation with `pydantic` or `dataclasses`
- Error handling with `try/except` (not try/catch)
- Run tests: `pytest --cov=packages tests/`
## Critical Rules
## Key Files
### Code Organization
- Many small files over few large files
- High cohesion, low coupling
- 200-400 lines typical, 800 max per file
- Organize by feature/domain, not by type
### Code Style
- No emojis in code, comments, or documentation
- Immutability always - never mutate objects or arrays
- No console.log in production code
- Proper error handling with try/catch
- Input validation with Zod or similar
### Testing
- TDD: Write tests first
- 80% minimum coverage
- Unit tests for utilities
- Integration tests for APIs
- E2E tests for critical flows
### Security
- No hardcoded secrets
- Environment variables for sensitive data
- Validate all user inputs
- Parameterized queries only
- CSRF protection enabled
| File | Purpose |
|------|---------|
| `packages/backend/backend/pipeline/pipeline.py` | Main inference pipeline |
| `packages/backend/backend/pipeline/field_extractor.py` | YOLO → field mapping |
| `packages/backend/backend/pipeline/value_selector.py` | Best candidate selection |
| `packages/shared/shared/fields/mappings.py` | CLASS_TO_FIELD mapping |
| `packages/shared/shared/ocr/paddle_ocr.py` | OCRToken definition |
| `packages/shared/shared/bbox/` | Bbox expansion strategies |
## Environment Variables
@@ -78,18 +84,41 @@ CONFIDENCE_THRESHOLD=0.5
SERVER_HOST=0.0.0.0
SERVER_PORT=8000
```
## Available Commands
- `/tdd` - Test-driven development workflow
- `/plan` - Create implementation plan
- `/code-review` - Review code quality
- `/build-fix` - Fix build errors
## Auto-trigger Rules (ALWAYS FOLLOW - even after context compaction)
## Git Workflow
These rules MUST be followed regardless of conversation history:
- Conventional commits: `feat:`, `fix:`, `refactor:`, `docs:`, `test:`
- Never commit to main directly
- PRs require review
- All tests must pass before merge
- New feature or bug fix → MUST use **tdd-guide** agent (write tests first)
- When writing code → MUST follow coding standards skill for the target language:
- Python → `python-patterns` (PEP 8, type hints, Pythonic idioms)
- C# → `dotnet-skills:coding-standards` (records, pattern matching, modern C#)
- TS/JS → `coding-standards` (universal best practices)
- After writing/modifying code → MUST use **code-reviewer** agent
- Before git commit → MUST use **security-reviewer** agent
- When build/test fails → MUST use **build-error-resolver** agent
- After context compaction → read MEMORY.md to restore session state
Push the code before review and fix finished.
## Plan Completion Protocol
After completing any plan or major task:
1. **Test** - Run `pytest` to confirm all tests pass
2. **Security review** - Use **security-reviewer** agent on changed files
3. **Fix loop** - If security review reports CRITICAL or HIGH issues:
- Fix the issues
- Re-run tests (back to step 1)
- Re-run security review (back to step 2)
- Repeat until no CRITICAL/HIGH issues remain
4. **Commit** - Auto-commit with conventional commit message (`feat:`, `fix:`, `refactor:`, etc.). Stage only the files changed in this task, not unrelated files
5. **Save** - Write a summary to MEMORY.md including: what was done, files changed, decisions made, remaining work
6. **Suggest clear** - Tell the user: "Plan complete. Recommend `/clear` to free context for the next task."
7. **Do NOT start a new task** in the same context - wait for user to /clear first
This keeps each plan in a fresh context window for maximum quality.
## Known Issues
- Pre-existing test failures: `test_s3.py`, `test_azure.py` (missing boto3/azure) - safe to ignore
- Always re-run dedup/validation after fallback adds new fields
- PDF DPI must be 150 (not 300) for correct bbox alignment

View File

@@ -1,37 +0,0 @@
# Python Coding Style
> This file extends [common/coding-style.md](../common/coding-style.md) with Python specific content.
## Standards
- Follow **PEP 8** conventions
- Use **type annotations** on all function signatures
## Immutability
Prefer immutable data structures:
```python
from dataclasses import dataclass
@dataclass(frozen=True)
class User:
name: str
email: str
from typing import NamedTuple
class Point(NamedTuple):
x: float
y: float
```
## Formatting
- **black** for code formatting
- **isort** for import sorting
- **ruff** for linting
## Reference
See skill: `python-patterns` for comprehensive Python idioms and patterns.

View File

@@ -1,14 +0,0 @@
# Python Hooks
> This file extends [common/hooks.md](../common/hooks.md) with Python specific content.
## PostToolUse Hooks
Configure in `~/.claude/settings.json`:
- **black/ruff**: Auto-format `.py` files after edit
- **mypy/pyright**: Run type checking after editing `.py` files
## Warnings
- Warn about `print()` statements in edited files (use `logging` module instead)

View File

@@ -1,34 +0,0 @@
# Python Patterns
> This file extends [common/patterns.md](../common/patterns.md) with Python specific content.
## Protocol (Duck Typing)
```python
from typing import Protocol
class Repository(Protocol):
def find_by_id(self, id: str) -> dict | None: ...
def save(self, entity: dict) -> dict: ...
```
## Dataclasses as DTOs
```python
from dataclasses import dataclass
@dataclass
class CreateUserRequest:
name: str
email: str
age: int | None = None
```
## Context Managers & Generators
- Use context managers (`with` statement) for resource management
- Use generators for lazy evaluation and memory-efficient iteration
## Reference
See skill: `python-patterns` for comprehensive patterns including decorators, concurrency, and package organization.

View File

@@ -1,25 +0,0 @@
# Python Security
> This file extends [common/security.md](../common/security.md) with Python specific content.
## Secret Management
```python
import os
from dotenv import load_dotenv
load_dotenv()
api_key = os.environ["OPENAI_API_KEY"] # Raises KeyError if missing
```
## Security Scanning
- Use **bandit** for static security analysis:
```bash
bandit -r src/
```
## Reference
See skill: `django-security` for Django-specific security guidelines (if applicable).

View File

@@ -1,33 +0,0 @@
# Python Testing
> This file extends [common/testing.md](../common/testing.md) with Python specific content.
## Framework
Use **pytest** as the testing framework.
## Coverage
```bash
pytest --cov=src --cov-report=term-missing
```
## Test Organization
Use `pytest.mark` for test categorization:
```python
import pytest
@pytest.mark.unit
def test_calculate_total():
...
@pytest.mark.integration
def test_database_connection():
...
```
## Reference
See skill: `python-testing` for detailed pytest patterns and fixtures.

BIN
.coverage

Binary file not shown.

View File

@@ -1,666 +0,0 @@
# Invoice Master POC v2 - 总体架构审查报告
**审查日期**: 2026-02-01
**审查人**: Claude Code
**项目路径**: `/Users/yiukai/Documents/git/invoice-master-poc-v2`
---
## 架构概述
### 整体架构图
```
┌─────────────────────────────────────────────────────────────────┐
│ Frontend (React) │
│ Vite + TypeScript + TailwindCSS │
└─────────────────────────────┬───────────────────────────────────┘
│ HTTP/REST
┌─────────────────────────────▼───────────────────────────────────┐
│ Inference Service (FastAPI) │
│ ┌──────────────┬──────────────┬──────────────┬──────────────┐ │
│ │ Public API │ Admin API │ Training API│ Batch API │ │
│ └──────────────┴──────────────┴──────────────┴──────────────┘ │
│ ┌────────────────────────────────────────────────────────────┐ │
│ │ Service Layer │ │
│ │ InferenceService │ AsyncProcessing │ BatchUpload │ Dataset │ │
│ └────────────────────────────────────────────────────────────┘ │
│ ┌────────────────────────────────────────────────────────────┐ │
│ │ Data Layer │ │
│ │ AdminDB │ AsyncRequestDB │ SQLModel │ PostgreSQL │ │
│ └────────────────────────────────────────────────────────────┘ │
│ ┌────────────────────────────────────────────────────────────┐ │
│ │ Core Components │ │
│ │ RateLimiter │ Schedulers │ TaskQueues │ Auth │ │
│ └────────────────────────────────────────────────────────────┘ │
└─────────────────────────────┬───────────────────────────────────┘
│ PostgreSQL
┌─────────────────────────────▼───────────────────────────────────┐
│ Training Service (GPU) │
│ ┌────────────────────────────────────────────────────────────┐ │
│ │ CLI: train │ autolabel │ analyze │ validate │ │
│ └────────────────────────────────────────────────────────────┘ │
│ ┌────────────────────────────────────────────────────────────┐ │
│ │ YOLO: db_dataset │ annotation_generator │ │
│ └────────────────────────────────────────────────────────────┘ │
│ ┌────────────────────────────────────────────────────────────┐ │
│ │ Processing: CPU Pool │ GPU Pool │ Task Dispatcher │ │
│ └────────────────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────────────┘
┌─────────┴─────────┐
▼ ▼
┌──────────────┐ ┌──────────────┐
│ Shared │ │ Storage │
│ PDF │ OCR │ │ Local/Azure/ │
│ Normalize │ │ S3 │
└──────────────┘ └──────────────┘
```
### 技术栈
| 层级 | 技术 | 评估 |
|------|------|------|
| **前端** | React + Vite + TypeScript + TailwindCSS | ✅ 现代栈 |
| **API 框架** | FastAPI | ✅ 高性能,类型安全 |
| **数据库** | PostgreSQL + SQLModel | ✅ 类型安全 ORM |
| **目标检测** | YOLO26 (Ultralytics >= 8.4.0) | ✅ 业界标准 |
| **OCR** | PaddleOCR v5 | ✅ 支持瑞典语 |
| **部署** | Docker + Azure/AWS | ✅ 云原生 |
---
## 架构优势
### 1. Monorepo 结构 ✅
```
packages/
├── shared/ # 共享库 - 无外部依赖
├── training/ # 训练服务 - 依赖 shared
└── inference/ # 推理服务 - 依赖 shared
```
**优点**:
- 清晰的包边界,无循环依赖
- 独立部署training 按需启动
- 代码复用率高
### 2. 分层架构 ✅
```
API Routes (web/api/v1/)
Service Layer (web/services/)
Data Layer (data/)
Database (PostgreSQL)
```
**优点**:
- 职责分离明确
- 便于单元测试
- 可替换底层实现
### 3. 依赖注入 ✅
```python
# FastAPI Depends 使用得当
@router.post("/infer")
async def infer(
file: UploadFile,
db: AdminDB = Depends(get_admin_db), # 注入
token: str = Depends(validate_admin_token),
):
```
### 4. 存储抽象层 ✅
```python
# 统一接口,支持多后端
class StorageBackend(ABC):
def upload(self, source: Path, destination: str) -> None: ...
def download(self, source: str, destination: Path) -> None: ...
def get_presigned_url(self, path: str) -> str: ...
# 实现: LocalStorageBackend, AzureStorageBackend, S3StorageBackend
```
### 5. 动态模型管理 ✅
```python
# 数据库驱动的模型切换
def get_active_model_path() -> Path | None:
db = AdminDB()
active_model = db.get_active_model_version()
return active_model.model_path if active_model else None
inference_service = InferenceService(
model_path_resolver=get_active_model_path,
)
```
### 6. 任务队列分离 ✅
```python
# 不同类型任务使用不同队列
- AsyncTaskQueue: 异步推理任务
- BatchQueue: 批量上传任务
- TrainingScheduler: 训练任务调度
- AutoLabelScheduler: 自动标注调度
```
---
## 架构问题与风险
### 1. 数据库层职责过重 ⚠️ **中风险**
**问题**: `AdminDB` 类过大,违反单一职责原则
```python
# packages/inference/inference/data/admin_db.py
class AdminDB:
# Token 管理 (5 个方法)
def is_valid_admin_token(self, token: str) -> bool: ...
def create_admin_token(self, token: str, name: str): ...
# 文档管理 (8 个方法)
def create_document(self, ...): ...
def get_document(self, doc_id: str): ...
# 标注管理 (6 个方法)
def create_annotation(self, ...): ...
def get_annotations(self, doc_id: str): ...
# 训练任务 (7 个方法)
def create_training_task(self, ...): ...
def update_training_task(self, ...): ...
# 数据集 (6 个方法)
def create_dataset(self, ...): ...
def get_dataset(self, dataset_id: str): ...
# 模型版本 (5 个方法)
def create_model_version(self, ...): ...
def activate_model_version(self, ...): ...
# 批处理 (4 个方法)
# 锁管理 (3 个方法)
# ... 总计 50+ 方法
```
**影响**:
- 类过大,难以维护
- 测试困难
- 不同领域变更互相影响
**建议**: 按领域拆分为 Repository 模式
```python
# 建议重构
class TokenRepository:
def validate(self, token: str) -> bool: ...
def create(self, token: Token) -> None: ...
class DocumentRepository:
def find_by_id(self, doc_id: str) -> Document | None: ...
def save(self, document: Document) -> None: ...
class TrainingRepository:
def create_task(self, config: TrainingConfig) -> TrainingTask: ...
def update_task_status(self, task_id: str, status: TaskStatus): ...
class ModelRepository:
def get_active(self) -> ModelVersion | None: ...
def activate(self, version_id: str) -> None: ...
```
---
### 2. Service 层混合业务逻辑与技术细节 ⚠️ **中风险**
**问题**: `InferenceService` 既处理业务逻辑又处理技术实现
```python
# packages/inference/inference/web/services/inference.py
class InferenceService:
def process(self, image_bytes: bytes) -> ServiceResult:
# 1. 技术细节: 图像解码
image = Image.open(io.BytesIO(image_bytes))
# 2. 业务逻辑: 字段提取
fields = self._extract_fields(image)
# 3. 技术细节: 模型推理
detections = self._model.predict(image)
# 4. 业务逻辑: 结果验证
if not self._validate_fields(fields):
raise ValidationError()
```
**影响**:
- 难以测试业务逻辑
- 技术变更影响业务代码
- 无法切换技术实现
**建议**: 引入领域层和适配器模式
```python
# 领域层 - 纯业务逻辑
@dataclass
class InvoiceDocument:
document_id: str
pages: list[Page]
class InvoiceExtractor:
"""纯业务逻辑,不依赖技术实现"""
def extract(self, document: InvoiceDocument) -> InvoiceFields:
# 只处理业务规则
pass
# 适配器层 - 技术实现
class YoloFieldDetector:
"""YOLO 技术适配器"""
def __init__(self, model_path: Path):
self._model = YOLO(model_path)
def detect(self, image: np.ndarray) -> list[FieldRegion]:
return self._model.predict(image)
class PaddleOcrEngine:
"""PaddleOCR 技术适配器"""
def __init__(self):
self._ocr = PaddleOCR()
def recognize(self, image: np.ndarray, region: BoundingBox) -> str:
return self._ocr.ocr(image, region)
# 应用服务 - 协调领域和适配器
class InvoiceProcessingService:
def __init__(
self,
extractor: InvoiceExtractor,
detector: FieldDetector,
ocr: OcrEngine,
):
self._extractor = extractor
self._detector = detector
self._ocr = ocr
```
---
### 3. 调度器设计分散 ⚠️ **中风险**
**问题**: 多个独立调度器缺乏统一协调
```python
# 当前设计 - 4 个独立调度器
# 1. TrainingScheduler (core/scheduler.py)
# 2. AutoLabelScheduler (core/autolabel_scheduler.py)
# 3. AsyncTaskQueue (workers/async_queue.py)
# 4. BatchQueue (workers/batch_queue.py)
# app.py 中分别启动
start_scheduler() # 训练调度器
start_autolabel_scheduler() # 自动标注调度器
init_batch_queue() # 批处理队列
```
**影响**:
- 资源竞争风险
- 难以监控和追踪
- 任务优先级难以管理
- 重启时任务丢失
**建议**: 使用 Celery + Redis 统一任务队列
```python
# 建议重构
from celery import Celery
app = Celery('invoice_master')
@app.task(bind=True, max_retries=3)
def process_inference(self, document_id: str):
"""异步推理任务"""
try:
service = get_inference_service()
result = service.process(document_id)
return result
except Exception as exc:
raise self.retry(exc=exc, countdown=60)
@app.task
def train_model(dataset_id: str, config: dict):
"""训练任务"""
training_service = get_training_service()
return training_service.train(dataset_id, config)
@app.task
def auto_label_documents(document_ids: list[str]):
"""批量自动标注"""
for doc_id in document_ids:
auto_label_document.delay(doc_id)
# 优先级队列
app.conf.task_routes = {
'tasks.process_inference': {'queue': 'high_priority'},
'tasks.train_model': {'queue': 'gpu_queue'},
'tasks.auto_label_documents': {'queue': 'low_priority'},
}
```
---
### 4. 配置分散 ⚠️ **低风险**
**问题**: 配置分散在多个文件
```python
# packages/shared/shared/config.py
DATABASE = {...}
PATHS = {...}
AUTOLABEL = {...}
# packages/inference/inference/web/config.py
@dataclass
class ModelConfig: ...
@dataclass
class ServerConfig: ...
@dataclass
class FileConfig: ...
# 环境变量
# .env 文件
```
**影响**:
- 配置难以追踪
- 可能出现不一致
- 缺少配置验证
**建议**: 使用 Pydantic Settings 集中管理
```python
# config/settings.py
from pydantic_settings import BaseSettings, SettingsConfigDict
class DatabaseSettings(BaseSettings):
model_config = SettingsConfigDict(env_prefix='DB_')
host: str = 'localhost'
port: int = 5432
name: str = 'docmaster'
user: str = 'docmaster'
password: str # 无默认值,必须设置
class StorageSettings(BaseSettings):
model_config = SettingsConfigDict(env_prefix='STORAGE_')
backend: str = 'local'
base_path: str = '~/invoice-data'
azure_connection_string: str | None = None
s3_bucket: str | None = None
class Settings(BaseSettings):
model_config = SettingsConfigDict(
env_file='.env',
env_file_encoding='utf-8',
)
database: DatabaseSettings = DatabaseSettings()
storage: StorageSettings = StorageSettings()
# 验证
@field_validator('database')
def validate_database(cls, v):
if not v.password:
raise ValueError('Database password is required')
return v
# 全局配置实例
settings = Settings()
```
---
### 5. 内存队列单点故障 ⚠️ **中风险**
**问题**: AsyncTaskQueue 和 BatchQueue 基于内存
```python
# workers/async_queue.py
class AsyncTaskQueue:
def __init__(self):
self._queue = Queue() # 内存队列
self._workers = []
def enqueue(self, task: AsyncTask) -> None:
self._queue.put(task) # 仅存储在内存
```
**影响**:
- 服务重启丢失所有待处理任务
- 无法水平扩展
- 任务持久化困难
**建议**: 使用 Redis/RabbitMQ 持久化队列
---
### 6. 缺少 API 版本迁移策略 ❓ **低风险**
**问题**: 有 `/api/v1/` 版本,但缺少升级策略
```
当前: /api/v1/admin/documents
未来: /api/v2/admin/documents ?
```
**建议**:
- 制定 API 版本升级流程
- 使用 Header 版本控制
- 维护版本兼容性文档
---
## 关键架构风险矩阵
| 风险项 | 概率 | 影响 | 风险等级 | 优先级 |
|--------|------|------|----------|--------|
| 内存队列丢失任务 | 中 | 高 | **高** | 🔴 P0 |
| AdminDB 职责过重 | 高 | 中 | **中** | 🟡 P1 |
| Service 层混合 | 高 | 中 | **中** | 🟡 P1 |
| 调度器资源竞争 | 中 | 中 | **中** | 🟡 P1 |
| 配置分散 | 高 | 低 | **低** | 🟢 P2 |
| API 版本策略 | 低 | 低 | **低** | 🟢 P2 |
---
## 改进建议路线图
### Phase 1: 立即执行 (本周)
#### 1.1 拆分 AdminDB
```python
# 创建 repositories 包
inference/data/repositories/
├── __init__.py
├── base.py # Repository 基类
├── token.py # TokenRepository
├── document.py # DocumentRepository
├── annotation.py # AnnotationRepository
├── training.py # TrainingRepository
├── dataset.py # DatasetRepository
└── model.py # ModelRepository
```
#### 1.2 统一配置
```python
# 创建统一配置模块
inference/config/
├── __init__.py
├── settings.py # Pydantic Settings
└── validators.py # 配置验证
```
### Phase 2: 短期执行 (本月)
#### 2.1 引入消息队列
```yaml
# docker-compose.yml 添加
services:
redis:
image: redis:7-alpine
ports:
- "6379:6379"
celery_worker:
build: .
command: celery -A inference.tasks worker -l info
depends_on:
- redis
- postgres
```
#### 2.2 添加缓存层
```python
# 使用 Redis 缓存热点数据
from redis import Redis
redis_client = Redis(host='localhost', port=6379)
class CachedDocumentRepository(DocumentRepository):
def find_by_id(self, doc_id: str) -> Document | None:
# 先查缓存
cached = redis_client.get(f"doc:{doc_id}")
if cached:
return Document.parse_raw(cached)
# 再查数据库
doc = super().find_by_id(doc_id)
if doc:
redis_client.setex(f"doc:{doc_id}", 3600, doc.json())
return doc
```
### Phase 3: 长期执行 (本季度)
#### 3.1 数据库读写分离
```python
# 配置主从数据库
class DatabaseManager:
def __init__(self):
self._master = create_engine(MASTER_DB_URL)
self._replica = create_engine(REPLICA_DB_URL)
def get_session(self, readonly: bool = False) -> Session:
engine = self._replica if readonly else self._master
return Session(engine)
```
#### 3.2 事件驱动架构
```python
# 引入事件总线
from event_bus import EventBus
bus = EventBus()
# 发布事件
@router.post("/documents")
async def create_document(...):
doc = document_repo.save(document)
bus.publish('document.created', {'document_id': doc.id})
return doc
# 订阅事件
@bus.subscribe('document.created')
def on_document_created(event):
# 触发自动标注
auto_label_task.delay(event['document_id'])
```
---
## 架构演进建议
### 当前架构 (适合 1-10 用户)
```
Single Instance
├── FastAPI App
├── Memory Queues
└── PostgreSQL
```
### 目标架构 (适合 100+ 用户)
```
Load Balancer
├── FastAPI Instance 1
├── FastAPI Instance 2
└── FastAPI Instance N
┌───────┴───────┐
▼ ▼
Redis Cluster PostgreSQL
(Celery + Cache) (Master + Replica)
```
---
## 总结
### 总体评分
| 维度 | 评分 | 说明 |
|------|------|------|
| **模块化** | 8/10 | 包结构清晰,但部分类过大 |
| **可扩展性** | 7/10 | 水平扩展良好,垂直扩展受限 |
| **可维护性** | 8/10 | 分层合理,但职责边界需细化 |
| **可靠性** | 7/10 | 内存队列是单点故障 |
| **性能** | 8/10 | 异步处理良好 |
| **安全性** | 8/10 | 基础安全到位 |
| **总体** | **7.7/10** | 良好的架构基础,需优化细节 |
### 关键结论
1. **架构设计合理**: Monorepo + 分层架构适合当前规模
2. **主要风险**: 内存队列和数据库职责过重
3. **演进路径**: 引入消息队列和缓存层
4. **投入产出**: 当前架构可支撑到 100+ 用户,无需大规模重构
### 下一步行动
| 优先级 | 任务 | 预计工时 | 影响 |
|--------|------|----------|------|
| 🔴 P0 | 引入 Celery + Redis | 3 天 | 解决任务丢失问题 |
| 🟡 P1 | 拆分 AdminDB | 2 天 | 提升可维护性 |
| 🟡 P1 | 统一配置管理 | 1 天 | 减少配置错误 |
| 🟢 P2 | 添加缓存层 | 2 天 | 提升性能 |
| 🟢 P2 | 数据库读写分离 | 3 天 | 提升扩展性 |
---
## 附录
### 关键文件清单
| 文件 | 职责 | 问题 |
|------|------|------|
| `inference/data/admin_db.py` | 数据库操作 | 类过大,需拆分 |
| `inference/web/services/inference.py` | 推理服务 | 混合业务和技术 |
| `inference/web/workers/async_queue.py` | 异步队列 | 内存存储,易丢失 |
| `inference/web/core/scheduler.py` | 任务调度 | 缺少统一协调 |
| `shared/shared/config.py` | 共享配置 | 分散管理 |
### 参考资源
- [Repository Pattern](https://martinfowler.com/eaaCatalog/repository.html)
- [Celery Documentation](https://docs.celeryproject.org/)
- [Pydantic Settings](https://docs.pydantic.dev/latest/concepts/pydantic_settings/)
- [FastAPI Best Practices](https://fastapi.tiangolo.com/tutorial/bigger-applications/)

View File

@@ -1,317 +0,0 @@
# Changelog
All notable changes to the Invoice Field Extraction project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [Unreleased]
### Added - Phase 1: Security & Infrastructure (2026-01-22)
#### Security Enhancements
- **Environment Variable Management**: Added `python-dotenv` for secure configuration management
- Created `.env.example` template file for configuration reference
- Created `.env` file for actual credentials (gitignored)
- Updated `config.py` to load database password from environment variables
- Added validation to ensure `DB_PASSWORD` is set at startup
- Files modified: `config.py`, `requirements.txt`
- New files: `.env`, `.env.example`
- Tests: `tests/test_config.py` (7 tests, all passing)
- **SQL Injection Prevention**: Fixed SQL injection vulnerabilities in database queries
- Replaced f-string formatting with parameterized queries in `LIMIT` clauses
- Updated `get_all_documents_summary()` to use `%s` placeholder for LIMIT parameter
- Updated `get_failed_matches()` to use `%s` placeholder for LIMIT parameter
- Files modified: `src/data/db.py` (lines 246, 298)
- Tests: `tests/test_db_security.py` (9 tests, all passing)
#### Code Quality
- **Exception Hierarchy**: Created comprehensive custom exception system
- Added base class `InvoiceExtractionError` with message and details support
- Added specific exception types:
- `PDFProcessingError` - PDF rendering/conversion errors
- `OCRError` - OCR processing errors
- `ModelInferenceError` - YOLO model errors
- `FieldValidationError` - Field validation errors (with field-specific attributes)
- `DatabaseError` - Database operation errors
- `ConfigurationError` - Configuration errors
- `PaymentLineParseError` - Payment line parsing errors
- `CustomerNumberParseError` - Customer number parsing errors
- `DataLoadError` - Data loading errors
- `AnnotationError` - Annotation generation errors
- New file: `src/exceptions.py`
- Tests: `tests/test_exceptions.py` (16 tests, all passing)
### Testing
- Added 32 new tests across 3 test files
- Configuration tests: 7 tests
- SQL injection prevention tests: 9 tests
- Exception hierarchy tests: 16 tests
- All tests passing (32/32)
### Documentation
- Created `docs/CODE_REVIEW_REPORT.md` - Comprehensive code quality analysis (550+ lines)
- Created `docs/REFACTORING_PLAN.md` - Detailed 3-phase refactoring plan (600+ lines)
- Created `CHANGELOG.md` - Project changelog (this file)
### Changed
- **Configuration Loading**: Database configuration now loads from environment variables instead of hardcoded values
- Breaking change: Requires `.env` file with `DB_PASSWORD` set
- Migration: Copy `.env.example` to `.env` and set your database password
### Security
- **Fixed**: Database password no longer stored in plain text in `config.py`
- **Fixed**: SQL injection vulnerabilities in LIMIT clauses (2 instances)
### Technical Debt Addressed
- Eliminated security vulnerability: plaintext password storage
- Reduced SQL injection attack surface
- Improved error handling granularity with custom exceptions
---
### Added - Phase 2: Parser Refactoring (2026-01-22)
#### Unified Parser Modules
- **Payment Line Parser**: Created dedicated payment line parsing module
- Handles Swedish payment line format: `# <OCR> # <Kronor> <Öre> <Type> > <Account>#<Check>#`
- Tolerates common OCR errors: spaces in numbers, missing symbols, spaces in check digits
- Supports 4 parsing patterns: full format, no amount, alternative, account-only
- Returns structured `PaymentLineData` with parsed fields
- New file: `src/inference/payment_line_parser.py` (90 lines, 92% coverage)
- Tests: `tests/test_payment_line_parser.py` (23 tests, all passing)
- Eliminates 1st code duplication (payment line parsing logic)
- **Customer Number Parser**: Created dedicated customer number parsing module
- Handles Swedish customer number formats: `JTY 576-3`, `DWQ 211-X`, `FFL 019N`, etc.
- Uses Strategy Pattern with 5 pattern classes:
- `LabeledPattern` - Explicit labels (highest priority, 0.98 confidence)
- `DashFormatPattern` - Standard format with dash (0.95 confidence)
- `NoDashFormatPattern` - Format without dash, adds dash automatically (0.90 confidence)
- `CompactFormatPattern` - Compact format without spaces (0.75 confidence)
- `GenericAlphanumericPattern` - Fallback generic pattern (variable confidence)
- Excludes Swedish postal codes (`SE XXX XX` format)
- Returns highest confidence match
- New file: `src/inference/customer_number_parser.py` (154 lines, 92% coverage)
- Tests: `tests/test_customer_number_parser.py` (32 tests, all passing)
- Reduces `_normalize_customer_number` complexity (127 lines → will use 5-10 lines after integration)
### Testing Summary
**Phase 1 Tests** (32 tests):
- Configuration tests: 7 tests ([test_config.py](tests/test_config.py))
- SQL injection prevention tests: 9 tests ([test_db_security.py](tests/test_db_security.py))
- Exception hierarchy tests: 16 tests ([test_exceptions.py](tests/test_exceptions.py))
**Phase 2 Tests** (121 tests):
- Payment line parser tests: 23 tests ([test_payment_line_parser.py](tests/test_payment_line_parser.py))
- Standard parsing, OCR error handling, real-world examples, edge cases
- Coverage: 92%
- Customer number parser tests: 32 tests ([test_customer_number_parser.py](tests/test_customer_number_parser.py))
- Pattern matching (DashFormat, NoDashFormat, Compact, Labeled)
- Real-world examples, edge cases, Swedish postal code exclusion
- Coverage: 92%
- Field extractor integration tests: 45 tests ([test_field_extractor.py](src/inference/test_field_extractor.py))
- Validates backward compatibility with existing code
- Tests for invoice numbers, bankgiro, plusgiro, amounts, OCR, dates, payment lines, customer numbers
- Pipeline integration tests: 21 tests ([test_pipeline.py](src/inference/test_pipeline.py))
- Cross-validation, payment line parsing, field overrides
**Total**: 153 tests, 100% passing, 4.50s runtime
### Code Quality
- **Eliminated Code Duplication**: Payment line parsing previously in 3 places, now unified in 1 module
- **Improved Maintainability**: Strategy Pattern makes customer number patterns easy to extend
- **Better Test Coverage**: New parsers have 92% coverage vs original 10% in field_extractor.py
#### Parser Integration into field_extractor.py (2026-01-22)
- **field_extractor.py Integration**: Successfully integrated new parsers
- Added `PaymentLineParser` and `CustomerNumberParser` instances (lines 99-101)
- Replaced `_normalize_payment_line` method: 74 lines → 3 lines (lines 640-657)
- Replaced `_normalize_customer_number` method: 127 lines → 3 lines (lines 697-707)
- All 45 existing tests pass (100% backward compatibility maintained)
- Tests run time: 4.21 seconds
- File: `src/inference/field_extractor.py`
#### Parser Integration into pipeline.py (2026-01-22)
- **pipeline.py Integration**: Successfully integrated PaymentLineParser
- Added `PaymentLineParser` import (line 15)
- Added `payment_line_parser` instance initialization (line 128)
- Replaced `_parse_machine_readable_payment_line` method: 36 lines → 6 lines (lines 219-233)
- All 21 existing tests pass (100% backward compatibility maintained)
- Tests run time: 4.00 seconds
- File: `src/inference/pipeline.py`
### Phase 2 Status: **COMPLETED** ✅
- [x] Create unified `payment_line_parser` module ✅
- [x] Create unified `customer_number_parser` module ✅
- [x] Refactor `field_extractor.py` to use new parsers ✅
- [x] Refactor `pipeline.py` to use new parsers ✅
- [x] Comprehensive test suite (153 tests, 100% passing) ✅
### Achieved Impact
- Eliminate code duplication: 3 implementations → 1 ✅ (payment_line unified across field_extractor.py, pipeline.py, tests)
- Reduce `_normalize_payment_line` complexity in field_extractor.py: 74 lines → 3 lines ✅
- Reduce `_normalize_customer_number` complexity in field_extractor.py: 127 lines → 3 lines ✅
- Reduce `_parse_machine_readable_payment_line` complexity in pipeline.py: 36 lines → 6 lines ✅
- Total lines of code eliminated: 201 lines reduced to 12 lines (94% reduction) ✅
- Improve test coverage: New parser modules have 92% coverage (vs original 10% in field_extractor.py)
- Simplify maintenance: Pattern-based approach makes extension easy
- 100% backward compatibility: All 66 existing tests pass (45 field_extractor + 21 pipeline)
---
## Phase 3: Performance & Documentation (2026-01-22)
### Added
#### Configuration Constants Extraction
- **Created `src/inference/constants.py`**: Centralized configuration constants
- Detection & model configuration (confidence thresholds, IOU)
- Image processing configuration (DPI, scaling factors)
- Customer number parser confidence scores
- Field extraction confidence multipliers
- Account type detection thresholds
- Pattern matching constants
- 90 lines of well-documented constants with usage notes
- Eliminates ~15 hardcoded magic numbers across codebase
- File: [src/inference/constants.py](src/inference/constants.py)
#### Performance Optimization Documentation
- **Created `docs/PERFORMANCE_OPTIMIZATION.md`**: Comprehensive performance guide (400+ lines)
- **Batch Processing Optimization**: Parallel processing strategies, already-implemented dual pool system
- **Database Query Optimization**: Connection pooling recommendations, index strategies
- **Caching Strategies**: Model loading cache, parser reuse (already optimal), OCR result caching
- **Memory Management**: Explicit cleanup, generator patterns, context managers
- **Profiling Guidelines**: cProfile, memory_profiler, py-spy recommendations
- **Benchmarking Scripts**: Ready-to-use performance measurement code
- **Priority Roadmap**: High/Medium/Low priority optimizations with effort estimates
- Expected impact: 2-5x throughput improvement for batch processing
- File: [docs/PERFORMANCE_OPTIMIZATION.md](docs/PERFORMANCE_OPTIMIZATION.md)
### Phase 3 Status: **COMPLETED** ✅
- [x] Configuration constants extraction ✅
- [x] Performance optimization analysis ✅
- [x] Batch processing optimization recommendations ✅
- [x] Database optimization strategies ✅
- [x] Caching and memory management guidelines ✅
- [x] Profiling and benchmarking documentation ✅
### Deliverables
**New Files** (2 files):
1. `src/inference/constants.py` (90 lines) - Centralized configuration constants
2. `docs/PERFORMANCE_OPTIMIZATION.md` (400+ lines) - Performance optimization guide
**Impact**:
- Eliminates 15+ hardcoded magic numbers
- Provides clear optimization roadmap
- Documents existing performance features
- Identifies quick wins (connection pooling, indexes)
- Long-term strategy (caching, profiling)
---
## Notes
### Breaking Changes
- **v2.x**: Requires `.env` file with database credentials
- Action required: Create `.env` file based on `.env.example`
- Affected: All deployments, CI/CD pipelines
### Migration Guide
#### From v1.x to v2.x (Environment Variables)
1. Copy `.env.example` to `.env`:
```bash
cp .env.example .env
```
2. Edit `.env` and set your database password:
```
DB_PASSWORD=your_actual_password_here
```
3. Install new dependency:
```bash
pip install python-dotenv
```
4. Verify configuration loads correctly:
```bash
python -c "import config; print('Config loaded successfully')"
```
## Summary of All Work Completed
### Files Created (13 new files)
**Phase 1** (3 files):
1. `.env` - Environment variables for database credentials
2. `.env.example` - Template for environment configuration
3. `src/exceptions.py` - Custom exception hierarchy (35 lines, 66% coverage)
**Phase 2** (7 files):
4. `src/inference/payment_line_parser.py` - Unified payment line parsing (90 lines, 92% coverage)
5. `src/inference/customer_number_parser.py` - Unified customer number parsing (154 lines, 92% coverage)
6. `tests/test_config.py` - Configuration tests (7 tests)
7. `tests/test_db_security.py` - SQL injection prevention tests (9 tests)
8. `tests/test_exceptions.py` - Exception hierarchy tests (16 tests)
9. `tests/test_payment_line_parser.py` - Payment line parser tests (23 tests)
10. `tests/test_customer_number_parser.py` - Customer number parser tests (32 tests)
**Phase 3** (2 files):
11. `src/inference/constants.py` - Centralized configuration constants (90 lines)
12. `docs/PERFORMANCE_OPTIMIZATION.md` - Performance optimization guide (400+ lines)
**Documentation** (1 file):
13. `CHANGELOG.md` - This file (260+ lines of detailed documentation)
### Files Modified (4 files)
1. `config.py` - Added environment variable loading with python-dotenv
2. `src/data/db.py` - Fixed 2 SQL injection vulnerabilities (lines 246, 298)
3. `src/inference/field_extractor.py` - Integrated new parsers (reduced 201 lines to 6 lines)
4. `src/inference/pipeline.py` - Integrated PaymentLineParser (reduced 36 lines to 6 lines)
5. `requirements.txt` - Added python-dotenv dependency
### Test Summary
- **Total tests**: 153 tests across 7 test files
- **Passing**: 153 (100%)
- **Failing**: 0
- **Runtime**: 4.50 seconds
- **Coverage**:
- New parser modules: 92%
- Config module: 100%
- Exception module: 66%
- DB security coverage: 18% (focused on parameterized queries)
### Code Metrics
- **Lines eliminated**: 237 lines of duplicated/complex code → 18 lines (92% reduction)
- field_extractor.py: 201 lines → 6 lines
- pipeline.py: 36 lines → 6 lines
- **New code added**: 279 lines of well-tested parser code
- **Net impact**: Replaced 237 lines of duplicate code with 279 lines of unified, tested code (+42 lines, but -3 implementations)
- **Test coverage improvement**: 0% → 92% for parser logic
### Performance Impact
- Configuration loading: Negligible (<1ms overhead for .env parsing)
- SQL queries: No performance change (parameterized queries are standard practice)
- Parser refactoring: No performance degradation (logic simplified, not changed)
- Exception handling: Minimal overhead (only when exceptions are raised)
### Security Improvements
- Eliminated plaintext password storage
- Fixed 2 SQL injection vulnerabilities
- Added input validation in database layer
### Maintainability Improvements
- Eliminated code duplication (3 implementations 1)
- Strategy Pattern enables easy extension of customer number formats
- Comprehensive test suite (153 tests) ensures safe refactoring
- 100% backward compatibility maintained
- Custom exception hierarchy for granular error handling

View File

@@ -1,805 +0,0 @@
# Invoice Master POC v2 - 详细代码审查报告
**审查日期**: 2026-02-01
**审查人**: Claude Code
**项目路径**: `C:\Users\yaoji\git\ColaCoder\invoice-master-poc-v2`
**代码统计**:
- Python文件: 200+ 个
- 测试文件: 97 个
- TypeScript/React文件: 39 个
- 总测试数: 1,601 个
- 测试覆盖率: 28%
---
## 目录
1. [执行摘要](#执行摘要)
2. [架构概览](#架构概览)
3. [详细模块审查](#详细模块审查)
4. [代码质量问题](#代码质量问题)
5. [安全风险分析](#安全风险分析)
6. [性能问题](#性能问题)
7. [改进建议](#改进建议)
8. [总结与评分](#总结与评分)
---
## 执行摘要
### 总体评估
| 维度 | 评分 | 状态 |
|------|------|------|
| **代码质量** | 7.5/10 | 良好,但有改进空间 |
| **安全性** | 7/10 | 基础安全到位,需加强 |
| **可维护性** | 8/10 | 模块化良好 |
| **测试覆盖** | 5/10 | 偏低,需提升 |
| **性能** | 8/10 | 异步处理良好 |
| **文档** | 8/10 | 文档详尽 |
| **总体** | **7.3/10** | 生产就绪,需小幅改进 |
### 关键发现
**优势:**
- 清晰的Monorepo架构三包分离合理
- 类型注解覆盖率高(>90%
- 存储抽象层设计优秀
- FastAPI使用规范依赖注入模式良好
- 异常处理完善,自定义异常层次清晰
**风险:**
- 测试覆盖率仅28%,远低于行业标准
- AdminDB类过大50+方法),违反单一职责原则
- 内存队列存在单点故障风险
- 部分安全细节需加强(时序攻击、文件上传验证)
- 前端状态管理简单,可能难以扩展
---
## 架构概览
### 项目结构
```
invoice-master-poc-v2/
├── packages/
│ ├── shared/ # 共享库 (74个Python文件)
│ │ ├── pdf/ # PDF处理
│ │ ├── ocr/ # OCR封装
│ │ ├── normalize/ # 字段规范化
│ │ ├── matcher/ # 字段匹配
│ │ ├── storage/ # 存储抽象层
│ │ ├── training/ # 训练组件
│ │ └── augmentation/# 数据增强
│ ├── training/ # 训练服务 (26个Python文件)
│ │ ├── cli/ # 命令行工具
│ │ ├── yolo/ # YOLO数据集
│ │ └── processing/ # 任务处理
│ └── inference/ # 推理服务 (100个Python文件)
│ ├── web/ # FastAPI应用
│ ├── pipeline/ # 推理管道
│ ├── data/ # 数据层
│ └── cli/ # 命令行工具
├── frontend/ # React前端 (39个TS/TSX文件)
│ ├── src/
│ │ ├── components/ # UI组件
│ │ ├── hooks/ # React Query hooks
│ │ └── api/ # API客户端
└── tests/ # 测试 (97个Python文件)
```
### 技术栈
| 层级 | 技术 | 评估 |
|------|------|------|
| **前端** | React 18 + TypeScript + Vite + TailwindCSS | 现代栈,类型安全 |
| **API框架** | FastAPI + Uvicorn | 高性能,异步支持 |
| **数据库** | PostgreSQL + SQLModel | 类型安全ORM |
| **目标检测** | YOLO26 (Ultralytics >= 8.4.0) | 业界标准 |
| **OCR** | PaddleOCR v5 | 支持瑞典语 |
| **部署** | Docker + Azure/AWS | 云原生 |
---
## 详细模块审查
### 1. Shared Package
#### 1.1 配置模块 (`shared/config.py`)
**文件位置**: `packages/shared/shared/config.py`
**代码行数**: 82行
**优点:**
- 使用环境变量加载配置,无硬编码敏感信息
- DPI配置统一管理DEFAULT_DPI = 150
- 密码无默认值,强制要求设置
**问题:**
```python
# 问题1: 配置分散,缺少验证
DATABASE = {
'host': os.getenv('DB_HOST', '192.168.68.31'), # 硬编码IP
'port': int(os.getenv('DB_PORT', '5432')),
# ...
}
# 问题2: 缺少类型安全
# 建议使用 Pydantic Settings
```
**严重程度**: 中
**建议**: 使用 Pydantic Settings 集中管理配置,添加验证逻辑
---
#### 1.2 存储抽象层 (`shared/storage/`)
**文件位置**: `packages/shared/shared/storage/`
**包含文件**: 8个
**优点:**
- 设计优秀的抽象接口 `StorageBackend`
- 支持 Local/Azure/S3 多后端
- 预签名URL支持
- 异常层次清晰
**代码示例 - 优秀设计:**
```python
class StorageBackend(ABC):
@abstractmethod
def upload(self, local_path: Path, remote_path: str, overwrite: bool = False) -> str:
pass
@abstractmethod
def get_presigned_url(self, remote_path: str, expires_in_seconds: int = 3600) -> str:
pass
```
**问题:**
- `upload_bytes``download_bytes` 默认实现使用临时文件,效率较低
- 缺少文件类型验证(魔术字节检查)
**严重程度**: 低
**建议**: 子类可重写bytes方法以提高效率添加文件类型验证
---
#### 1.3 异常定义 (`shared/exceptions.py`)
**文件位置**: `packages/shared/shared/exceptions.py`
**代码行数**: 103行
**优点:**
- 清晰的异常层次结构
- 所有异常继承自 `InvoiceExtractionError`
- 包含详细的错误上下文
**代码示例:**
```python
class InvoiceExtractionError(Exception):
def __init__(self, message: str, details: dict = None):
super().__init__(message)
self.message = message
self.details = details or {}
```
**评分**: 9/10 - 设计优秀
---
#### 1.4 数据增强 (`shared/augmentation/`)
**文件位置**: `packages/shared/shared/augmentation/`
**包含文件**: 10个
**功能:**
- 12种数据增强策略
- 透视变换、皱纹、边缘损坏、污渍等
- 高斯模糊、运动模糊、噪声等
**代码质量**: 良好,模块化设计
---
### 2. Inference Package
#### 2.1 认证模块 (`inference/web/core/auth.py`)
**文件位置**: `packages/inference/inference/web/core/auth.py`
**代码行数**: 61行
**优点:**
- 使用FastAPI依赖注入模式
- Token过期检查
- 记录最后使用时间
**安全问题:**
```python
# 问题: 时序攻击风险 (第46行)
if not admin_db.is_valid_admin_token(x_admin_token):
raise HTTPException(status_code=401, detail="Invalid or expired admin token.")
# 建议: 使用 constant-time 比较
import hmac
if not hmac.compare_digest(token, expected_token):
raise HTTPException(status_code=401, ...)
```
**严重程度**: 中
**建议**: 使用 `hmac.compare_digest()` 进行constant-time比较
---
#### 2.2 限流器 (`inference/web/core/rate_limiter.py`)
**文件位置**: `packages/inference/inference/web/core/rate_limiter.py`
**代码行数**: 212行
**优点:**
- 滑动窗口算法实现
- 线程安全使用Lock
- 支持并发任务限制
- 可配置的限流策略
**代码示例 - 优秀设计:**
```python
@dataclass(frozen=True)
class RateLimitConfig:
requests_per_minute: int = 10
max_concurrent_jobs: int = 3
min_poll_interval_ms: int = 1000
```
**问题:**
- 内存存储,服务重启后限流状态丢失
- 分布式部署时无法共享限流状态
**严重程度**: 中
**建议**: 生产环境使用Redis实现分布式限流
---
#### 2.3 AdminDB (`inference/data/admin_db.py`)
**文件位置**: `packages/inference/inference/data/admin_db.py`
**代码行数**: 1300+行
**严重问题 - 类过大:**
```python
class AdminDB:
# Token管理 (5个方法)
# 文档管理 (8个方法)
# 标注管理 (6个方法)
# 训练任务 (7个方法)
# 数据集 (6个方法)
# 模型版本 (5个方法)
# 批处理 (4个方法)
# 锁管理 (3个方法)
# ... 总计50+方法
```
**影响:**
- 违反单一职责原则
- 难以维护
- 测试困难
- 不同领域变更互相影响
**严重程度**: 高
**建议**: 按领域拆分为Repository模式
```python
# 建议重构
class TokenRepository:
def validate(self, token: str) -> bool: ...
class DocumentRepository:
def find_by_id(self, doc_id: str) -> Document | None: ...
class TrainingRepository:
def create_task(self, config: TrainingConfig) -> TrainingTask: ...
```
---
#### 2.4 文档路由 (`inference/web/api/v1/admin/documents.py`)
**文件位置**: `packages/inference/inference/web/api/v1/admin/documents.py`
**代码行数**: 692行
**优点:**
- FastAPI使用规范
- 输入验证完善
- 响应模型定义清晰
- 错误处理良好
**问题:**
```python
# 问题1: 文件上传缺少魔术字节验证 (第127-131行)
content = await file.read()
# 建议: 验证PDF魔术字节 %PDF
# 问题2: 路径遍历风险 (第494-498行)
filename = Path(document.file_path).name
# 建议: 使用 Path.name 并验证路径范围
# 问题3: 函数过长,职责过多
# _convert_pdf_to_images 函数混合了PDF处理和存储操作
```
**严重程度**: 中
**建议**: 添加文件类型验证,拆分大函数
---
#### 2.5 推理服务 (`inference/web/services/inference.py`)
**文件位置**: `packages/inference/inference/web/services/inference.py`
**代码行数**: 361行
**优点:**
- 支持动态模型加载
- 懒加载初始化
- 模型热重载支持
**问题:**
```python
# 问题1: 混合业务逻辑和技术实现
def process_image(self, image_path: Path, ...) -> ServiceResult:
# 1. 技术细节: 图像解码
# 2. 业务逻辑: 字段提取
# 3. 技术细节: 模型推理
# 4. 业务逻辑: 结果验证
# 问题2: 可视化方法重复加载模型
model = YOLO(str(self.model_config.model_path)) # 第316行
# 应该在初始化时加载避免重复IO
# 问题3: 临时文件未使用上下文管理器
temp_path = results_dir / f"{doc_id}_temp.png"
# 建议使用 tempfile 上下文管理器
```
**严重程度**: 中
**建议**: 引入领域层和适配器模式,分离业务和技术逻辑
---
#### 2.6 异步队列 (`inference/web/workers/async_queue.py`)
**文件位置**: `packages/inference/inference/web/workers/async_queue.py`
**代码行数**: 213行
**优点:**
- 线程安全实现
- 优雅关闭支持
- 任务状态跟踪
**严重问题:**
```python
# 问题: 内存队列,服务重启丢失任务 (第42行)
self._queue: Queue[AsyncTask] = Queue(maxsize=max_size)
# 问题: 无法水平扩展
# 问题: 任务持久化困难
```
**严重程度**: 高
**建议**: 使用Redis/RabbitMQ持久化队列
---
### 3. Training Package
#### 3.1 整体评估
**文件数量**: 26个Python文件
**优点:**
- CLI工具设计良好
- 双池协调器CPU + GPU设计优秀
- 数据增强策略丰富
**总体评分**: 8/10
---
### 4. Frontend
#### 4.1 API客户端 (`frontend/src/api/client.ts`)
**文件位置**: `frontend/src/api/client.ts`
**代码行数**: 42行
**优点:**
- Axios配置清晰
- 请求/响应拦截器
- 认证token自动添加
**问题:**
```typescript
// 问题1: Token存储在localStorage存在XSS风险
const token = localStorage.getItem('admin_token')
// 问题2: 401错误处理不完整
if (error.response?.status === 401) {
console.warn('Authentication required...')
// 应该触发重新登录或token刷新
}
```
**严重程度**: 中
**建议**: 考虑使用http-only cookie存储token完善错误处理
---
#### 4.2 Dashboard组件 (`frontend/src/components/Dashboard.tsx`)
**文件位置**: `frontend/src/components/Dashboard.tsx`
**代码行数**: 301行
**优点:**
- React hooks使用规范
- 类型定义清晰
- UI响应式设计
**问题:**
```typescript
// 问题1: 硬编码的进度值
const getAutoLabelProgress = (doc: DocumentItem): number | undefined => {
if (doc.auto_label_status === 'running') {
return 45 // 硬编码!
}
// ...
}
// 问题2: 搜索功能未实现
// 没有onChange处理
// 问题3: 缺少错误边界处理
// 组件应该包裹在Error Boundary中
```
**严重程度**: 低
**建议**: 实现真实的进度获取,添加搜索功能
---
#### 4.3 整体评估
**优点:**
- TypeScript类型安全
- React Query状态管理
- TailwindCSS样式一致
**问题:**
- 缺少错误边界
- 部分功能硬编码
- 缺少单元测试
**总体评分**: 7.5/10
---
### 5. Tests
#### 5.1 测试统计
- **测试文件数**: 97个
- **测试总数**: 1,601个
- **测试覆盖率**: 28%
#### 5.2 覆盖率分析
| 模块 | 估计覆盖率 | 状态 |
|------|-----------|------|
| `shared/` | 35% | 偏低 |
| `inference/web/` | 25% | 偏低 |
| `inference/pipeline/` | 20% | 严重不足 |
| `training/` | 30% | 偏低 |
| `frontend/` | 15% | 严重不足 |
#### 5.3 测试质量问题
**优点:**
- 使用了pytest框架
- 有conftest.py配置
- 部分集成测试
**问题:**
- 覆盖率远低于行业标准80%
- 缺少端到端测试
- 部分测试可能过于简单
**严重程度**: 高
**建议**: 制定测试计划,优先覆盖核心业务逻辑
---
## 代码质量问题
### 高优先级问题
| 问题 | 位置 | 影响 | 建议 |
|------|------|------|------|
| AdminDB类过大 | `inference/data/admin_db.py` | 维护困难 | 拆分为Repository模式 |
| 内存队列单点故障 | `inference/web/workers/async_queue.py` | 任务丢失 | 使用Redis持久化 |
| 测试覆盖率过低 | 全项目 | 代码风险 | 提升至60%+ |
### 中优先级问题
| 问题 | 位置 | 影响 | 建议 |
|------|------|------|------|
| 时序攻击风险 | `inference/web/core/auth.py` | 安全漏洞 | 使用hmac.compare_digest |
| 限流器内存存储 | `inference/web/core/rate_limiter.py` | 分布式问题 | 使用Redis |
| 配置分散 | `shared/config.py` | 难以管理 | 使用Pydantic Settings |
| 文件上传验证不足 | `inference/web/api/v1/admin/documents.py` | 安全风险 | 添加魔术字节验证 |
| 推理服务混合职责 | `inference/web/services/inference.py` | 难以测试 | 分离业务和技术逻辑 |
### 低优先级问题
| 问题 | 位置 | 影响 | 建议 |
|------|------|------|------|
| 前端搜索未实现 | `frontend/src/components/Dashboard.tsx` | 功能缺失 | 实现搜索功能 |
| 硬编码进度值 | `frontend/src/components/Dashboard.tsx` | 用户体验 | 获取真实进度 |
| Token存储方式 | `frontend/src/api/client.ts` | XSS风险 | 考虑http-only cookie |
---
## 安全风险分析
### 已识别的安全风险
#### 1. 时序攻击 (中风险)
**位置**: `inference/web/core/auth.py:46`
```python
# 当前实现(有风险)
if not admin_db.is_valid_admin_token(x_admin_token):
raise HTTPException(status_code=401, ...)
# 安全实现
import hmac
if not hmac.compare_digest(token, expected_token):
raise HTTPException(status_code=401, ...)
```
#### 2. 文件上传验证不足 (中风险)
**位置**: `inference/web/api/v1/admin/documents.py:127-131`
```python
# 建议添加魔术字节验证
ALLOWED_EXTENSIONS = {".pdf"}
MAX_FILE_SIZE = 10 * 1024 * 1024
if not content.startswith(b"%PDF"):
raise HTTPException(400, "Invalid PDF file format")
```
#### 3. 路径遍历风险 (中风险)
**位置**: `inference/web/api/v1/admin/documents.py:494-498`
```python
# 建议实现
from pathlib import Path
def get_safe_path(filename: str, base_dir: Path) -> Path:
safe_name = Path(filename).name
full_path = (base_dir / safe_name).resolve()
if not full_path.is_relative_to(base_dir):
raise HTTPException(400, "Invalid file path")
return full_path
```
#### 4. CORS配置 (低风险)
**位置**: FastAPI中间件配置
```python
# 建议生产环境配置
ALLOWED_ORIGINS = [
"http://localhost:5173",
"https://your-domain.com",
]
```
#### 5. XSS风险 (低风险)
**位置**: `frontend/src/api/client.ts:13`
```typescript
// 当前实现
const token = localStorage.getItem('admin_token')
// 建议考虑
// 使用http-only cookie存储敏感token
```
### 安全评分
| 类别 | 评分 | 说明 |
|------|------|------|
| 认证 | 8/10 | 基础良好,需加强时序攻击防护 |
| 输入验证 | 7/10 | 基本验证到位,需加强文件验证 |
| 数据保护 | 8/10 | 无敏感信息硬编码 |
| 传输安全 | 8/10 | 使用HTTPS生产环境 |
| 总体 | 7.5/10 | 基础安全良好,需加强细节 |
---
## 性能问题
### 已识别的性能问题
#### 1. 重复模型加载
**位置**: `inference/web/services/inference.py:316`
```python
# 问题: 每次可视化都重新加载模型
model = YOLO(str(self.model_config.model_path))
# 建议: 复用已加载的模型
```
#### 2. 临时文件处理
**位置**: `shared/storage/base.py:178-203`
```python
# 问题: bytes操作使用临时文件
def upload_bytes(self, data: bytes, ...):
with tempfile.NamedTemporaryFile(delete=False) as f:
f.write(data)
temp_path = Path(f.name)
# ...
# 建议: 子类重写为直接上传
```
#### 3. 数据库查询优化
**位置**: `inference/data/admin_db.py`
```python
# 问题: N+1查询风险
for doc in documents:
annotations = db.get_annotations_for_document(str(doc.document_id))
# ...
# 建议: 使用join预加载
```
### 性能评分
| 类别 | 评分 | 说明 |
|------|------|------|
| 响应时间 | 8/10 | 异步处理良好 |
| 资源使用 | 7/10 | 有优化空间 |
| 可扩展性 | 7/10 | 内存队列限制 |
| 并发处理 | 8/10 | 线程池设计良好 |
| 总体 | 7.5/10 | 良好,有优化空间 |
---
## 改进建议
### 立即执行 (本周)
1. **拆分AdminDB**
- 创建 `repositories/` 目录
- 按领域拆分TokenRepository, DocumentRepository, TrainingRepository
- 估计工时: 2天
2. **修复安全漏洞**
- 添加 `hmac.compare_digest()` 时序攻击防护
- 添加文件魔术字节验证
- 估计工时: 0.5天
3. **提升测试覆盖率**
- 优先测试 `inference/pipeline/`
- 添加API集成测试
- 目标: 从28%提升至50%
- 估计工时: 3天
### 短期执行 (本月)
4. **引入消息队列**
- 添加Redis服务
- 使用Celery替换内存队列
- 估计工时: 3天
5. **统一配置管理**
- 使用 Pydantic Settings
- 集中验证逻辑
- 估计工时: 1天
6. **添加缓存层**
- Redis缓存热点数据
- 缓存文档、模型配置
- 估计工时: 2天
### 长期执行 (本季度)
7. **数据库读写分离**
- 配置主从数据库
- 读操作使用从库
- 估计工时: 3天
8. **事件驱动架构**
- 引入事件总线
- 解耦模块依赖
- 估计工时: 5天
9. **前端优化**
- 添加错误边界
- 实现真实搜索功能
- 添加E2E测试
- 估计工时: 3天
---
## 总结与评分
### 各维度评分
| 维度 | 评分 | 权重 | 加权得分 |
|------|------|------|----------|
| **代码质量** | 7.5/10 | 20% | 1.5 |
| **安全性** | 7.5/10 | 20% | 1.5 |
| **可维护性** | 8/10 | 15% | 1.2 |
| **测试覆盖** | 5/10 | 15% | 0.75 |
| **性能** | 7.5/10 | 15% | 1.125 |
| **文档** | 8/10 | 10% | 0.8 |
| **架构设计** | 8/10 | 5% | 0.4 |
| **总体** | **7.3/10** | 100% | **7.275** |
### 关键结论
1. **架构设计优秀**: Monorepo + 三包分离架构清晰,便于维护和扩展
2. **代码质量良好**: 类型注解完善,文档详尽,结构清晰
3. **安全基础良好**: 没有严重的安全漏洞,基础防护到位
4. **测试是短板**: 28%覆盖率是最大风险点
5. **生产就绪**: 经过小幅改进后可以投入生产使用
### 下一步行动
| 优先级 | 任务 | 预计工时 | 影响 |
|--------|------|----------|------|
| 高 | 拆分AdminDB | 2天 | 提升可维护性 |
| 高 | 引入Redis队列 | 3天 | 解决任务丢失问题 |
| 高 | 提升测试覆盖率 | 5天 | 降低代码风险 |
| 中 | 修复安全漏洞 | 0.5天 | 提升安全性 |
| 中 | 统一配置管理 | 1天 | 减少配置错误 |
| 低 | 前端优化 | 3天 | 提升用户体验 |
---
## 附录
### 关键文件清单
| 文件 | 职责 | 问题 |
|------|------|------|
| `inference/data/admin_db.py` | 数据库操作 | 类过大,需拆分 |
| `inference/web/services/inference.py` | 推理服务 | 混合业务和技术 |
| `inference/web/workers/async_queue.py` | 异步队列 | 内存存储,易丢失 |
| `inference/web/core/scheduler.py` | 任务调度 | 缺少统一协调 |
| `shared/shared/config.py` | 共享配置 | 分散管理 |
### 参考资源
- [Repository Pattern](https://martinfowler.com/eaaCatalog/repository.html)
- [Celery Documentation](https://docs.celeryproject.org/)
- [Pydantic Settings](https://docs.pydantic.dev/latest/concepts/pydantic_settings/)
- [FastAPI Best Practices](https://fastapi.tiangolo.com/tutorial/bigger-applications/)
- [OWASP Top 10](https://owasp.org/www-project-top-ten/)
---
**报告生成时间**: 2026-02-01
**审查工具**: Claude Code + AST-grep + LSP

View File

@@ -124,7 +124,7 @@ class AmountNormalizer(BaseNormalizer):
if not match:
continue
amount = self._parse_amount_str(match)
if amount is not None and amount > 0:
if amount is not None and 0 < amount < 10_000_000:
all_amounts.append(amount)
# Return the last amount found (usually the total)
@@ -134,7 +134,7 @@ class AmountNormalizer(BaseNormalizer):
# Fallback: try shared validator on cleaned text
cleaned = TextCleaner.normalize_amount_text(text)
amount = FieldValidators.parse_amount(cleaned)
if amount is not None and amount > 0:
if amount is not None and 0 < amount < 10_000_000:
return NormalizationResult.success(f"{amount:.2f}")
# Try to find any decimal number
@@ -144,7 +144,7 @@ class AmountNormalizer(BaseNormalizer):
amount_str = matches[-1].replace(",", ".")
try:
amount = float(amount_str)
if amount > 0:
if 0 < amount < 10_000_000:
return NormalizationResult.success(f"{amount:.2f}")
except ValueError:
pass
@@ -156,7 +156,7 @@ class AmountNormalizer(BaseNormalizer):
if match:
try:
amount = float(match.group(1))
if amount > 0:
if 0 < amount < 10_000_000:
return NormalizationResult.success(f"{amount:.2f}")
except ValueError:
pass
@@ -168,7 +168,7 @@ class AmountNormalizer(BaseNormalizer):
# Take the last/largest number
try:
amount = float(matches[-1])
if amount > 0:
if 0 < amount < 10_000_000:
return NormalizationResult.success(f"{amount:.2f}")
except ValueError:
pass

View File

@@ -62,14 +62,25 @@ class InvoiceNumberNormalizer(BaseNormalizer):
# Skip if it looks like a date (YYYYMMDD)
if len(seq) == 8 and seq.startswith("20"):
continue
# Skip year-only values (2024, 2025, 2026, etc.)
if len(seq) == 4 and seq.startswith("20"):
continue
# Skip if too long (likely OCR number)
if len(seq) > 10:
continue
valid_sequences.append(seq)
if valid_sequences:
# Return shortest valid sequence
return NormalizationResult.success(min(valid_sequences, key=len))
# Prefer 4-8 digit sequences (typical invoice numbers),
# then closest to 6 digits within that range.
# This avoids picking short fragments like "775" from amounts.
def _score(seq: str) -> tuple[int, int]:
length = len(seq)
if 4 <= length <= 8:
return (1, -abs(length - 6))
return (0, -length)
return NormalizationResult.success(max(valid_sequences, key=_score))
# Fallback: extract all digits if nothing else works
digits = re.sub(r"\D", "", text)

View File

@@ -14,7 +14,7 @@ class OcrNumberNormalizer(BaseNormalizer):
Normalizes OCR (Optical Character Recognition) reference numbers.
OCR numbers in Swedish payment systems:
- Minimum 5 digits
- Minimum 2 digits
- Used for automated payment matching
"""
@@ -29,7 +29,7 @@ class OcrNumberNormalizer(BaseNormalizer):
digits = re.sub(r"\D", "", text)
if len(digits) < 5:
if len(digits) < 2:
return NormalizationResult.failure(
f"Too few digits for OCR: {len(digits)}"
)

View File

@@ -234,7 +234,7 @@ class InferencePipeline:
confidence_threshold=confidence_threshold,
device='cuda' if use_gpu else 'cpu'
)
self.extractor = FieldExtractor(ocr_lang=ocr_lang, use_gpu=use_gpu)
self.extractor = FieldExtractor(ocr_lang=ocr_lang, use_gpu=use_gpu, dpi=dpi)
self.payment_line_parser = PaymentLineParser()
self.dpi = dpi
self.enable_fallback = enable_fallback
@@ -361,6 +361,7 @@ class InferencePipeline:
# Fallback if key fields are missing
if self.enable_fallback and self._needs_fallback(result):
self._run_fallback(pdf_path, result)
self._dedup_invoice_number(result)
# Extract business invoice features if enabled
if use_business_features:
@@ -477,9 +478,48 @@ class InferencePipeline:
# Store bbox for each field (useful for payment_line and other fields)
result.bboxes[field_name] = best.bbox
# Validate date consistency
self._validate_dates(result)
# Perform cross-validation if payment_line is detected
self._cross_validate_payment_line(result)
# Remove InvoiceNumber if it duplicates OCR or Bankgiro
self._dedup_invoice_number(result)
def _validate_dates(self, result: InferenceResult) -> None:
"""Remove InvoiceDueDate if it is earlier than InvoiceDate."""
invoice_date = result.fields.get('InvoiceDate')
due_date = result.fields.get('InvoiceDueDate')
if invoice_date and due_date and due_date < invoice_date:
del result.fields['InvoiceDueDate']
result.confidence.pop('InvoiceDueDate', None)
result.bboxes.pop('InvoiceDueDate', None)
def _dedup_invoice_number(self, result: InferenceResult) -> None:
"""Remove InvoiceNumber if it duplicates OCR or Bankgiro digits."""
inv_num = result.fields.get('InvoiceNumber')
if not inv_num:
return
inv_digits = re.sub(r'\D', '', str(inv_num))
# Check against OCR
ocr = result.fields.get('OCR')
if ocr and inv_digits == re.sub(r'\D', '', str(ocr)):
del result.fields['InvoiceNumber']
result.confidence.pop('InvoiceNumber', None)
result.bboxes.pop('InvoiceNumber', None)
return
# Check against Bankgiro (exact or substring match)
bg = result.fields.get('Bankgiro')
if bg:
bg_digits = re.sub(r'\D', '', str(bg))
if inv_digits == bg_digits or inv_digits in bg_digits:
del result.fields['InvoiceNumber']
result.confidence.pop('InvoiceNumber', None)
result.bboxes.pop('InvoiceNumber', None)
def _parse_machine_readable_payment_line(self, payment_line: str) -> tuple[str | None, str | None, str | None]:
"""
Parse machine-readable Swedish payment line format using unified PaymentLineParser.
@@ -638,10 +678,14 @@ class InferencePipeline:
def _needs_fallback(self, result: InferenceResult) -> bool:
"""Check if fallback OCR is needed."""
# Check for key fields
key_fields = ['Amount', 'InvoiceNumber', 'OCR']
missing = sum(1 for f in key_fields if f not in result.fields)
return missing >= 2 # Fallback if 2+ key fields missing
important_fields = ['InvoiceDate', 'InvoiceDueDate', 'supplier_organisation_number']
key_missing = sum(1 for f in key_fields if f not in result.fields)
important_missing = sum(1 for f in important_fields if f not in result.fields)
# Fallback if any key field missing OR 2+ important fields missing
return key_missing >= 1 or important_missing >= 2
def _run_fallback(self, pdf_path: str | Path, result: InferenceResult) -> None:
"""Run full-page OCR fallback."""
@@ -673,12 +717,13 @@ class InferencePipeline:
"""Extract fields using regex patterns (fallback)."""
patterns = {
'Amount': [
r'(?:att\s*betala|summa|total|belopp)\s*[:.]?\s*([\d\s,\.]+)\s*(?:SEK|kr)?',
r'([\d\s,\.]+)\s*(?:SEK|kr)\s*$',
r'(?:att\s+betala)\s*[:.]?\s*([\d\s\.]*\d+[,\.]\d{2})\s*(?:SEK|kr)?',
r'(?:summa|total|belopp)\s*[:.]?\s*([\d\s\.]*\d+[,\.]\d{2})\s*(?:SEK|kr)?',
r'([\d\s\.]*\d+[,\.]\d{2})\s*(?:SEK|kr)\s*$',
],
'Bankgiro': [
r'(?:bankgiro|bg)\s*[:.]?\s*(\d{3,4}[-\s]?\d{4})',
r'(\d{4}[-\s]\d{4})\s*(?=\s|$)',
r'(?<!\d)(\d{3,4}[-\s]\d{4})(?!\d)',
],
'OCR': [
r'(?:ocr|referens)\s*[:.]?\s*(\d{10,25})',
@@ -686,6 +731,20 @@ class InferencePipeline:
'InvoiceNumber': [
r'(?:fakturanr|fakturanummer|invoice)\s*[:.]?\s*(\d+)',
],
'InvoiceDate': [
r'(?:fakturadatum|invoice\s*date)\s*[:.]?\s*(\d{4}[-/]\d{2}[-/]\d{2})',
r'(?:fakturadatum|invoice\s*date)\s*[:.]?\s*(\d{2}[-/]\d{2}[-/]\d{4})',
],
'InvoiceDueDate': [
r'(?:f[oö]rfallo(?:dag|datum)?|due\s*date|betala\s*senast)\s*[:.]?\s*(\d{4}[-/]\d{2}[-/]\d{2})',
r'(?:f[oö]rfallo(?:dag|datum)?|due\s*date|betala\s*senast)\s*[:.]?\s*(\d{2}[-/]\d{2}[-/]\d{4})',
],
'supplier_organisation_number': [
r'(?:org\.?\s*n[ru]|organisationsnummer)\s*[:.]?\s*(\d{6}[-\s]?\d{4})',
],
'Plusgiro': [
r'(?:plusgiro|pg)\s*[:.]?\s*(\d[\d\s-]{4,12}\d)',
],
}
for field_name, field_patterns in patterns.items():
@@ -708,6 +767,22 @@ class InferencePipeline:
digits = re.sub(r'\D', '', value)
if len(digits) == 8:
value = f"{digits[:4]}-{digits[4:]}"
elif field_name in ('InvoiceDate', 'InvoiceDueDate'):
# Normalize DD/MM/YYYY to YYYY-MM-DD
date_match = re.match(r'(\d{2})[-/](\d{2})[-/](\d{4})', value)
if date_match:
value = f"{date_match.group(3)}-{date_match.group(2)}-{date_match.group(1)}"
# Replace / with -
value = value.replace('/', '-')
elif field_name == 'InvoiceNumber':
# Skip year-like values (2024, 2025, 2026, etc.)
if re.match(r'^20\d{2}$', value):
continue
elif field_name == 'supplier_organisation_number':
# Ensure NNNNNN-NNNN format
digits = re.sub(r'\D', '', value)
if len(digits) == 10:
value = f"{digits[:6]}-{digits[6:]}"
result.fields[field_name] = value
result.confidence[field_name] = 0.5 # Lower confidence for regex

View File

@@ -123,12 +123,12 @@ class ValueSelector:
@staticmethod
def _select_ocr_number(tokens: list[OCRToken]) -> list[OCRToken]:
"""Select token with the longest digit sequence (min 5 digits)."""
"""Select token with the longest digit sequence (min 2 digits)."""
best: OCRToken | None = None
best_count = 0
for token in tokens:
digit_count = _count_digits(token.text)
if digit_count >= 5 and digit_count > best_count:
if digit_count >= 2 and digit_count > best_count:
best = token
best_count = digit_count
return [best] if best else []

103
scripts/analyze_v3.py Normal file
View File

@@ -0,0 +1,103 @@
#!/usr/bin/env python3
"""Analyze batch inference v3 results (Round 2 fixes)."""
import json
from collections import Counter
with open("scripts/inference_results_v3.json") as f:
results = json.load(f)
total = len(results)
success = sum(1 for r in results if r["status"] == 200)
print(f"Total PDFs: {total}, Successful: {success}")
print()
# Summary table
header = f"{'PDF':<40} {'Det':<4} {'Fld':<4} {'Time':<7} Fields"
print(header)
print("-" * 140)
for r in results:
fn = r["filename"][:39]
data = r.get("data", {})
result_data = data.get("result", {})
fields = result_data.get("fields", {})
dets = len(result_data.get("detections", []))
nfields = len(fields)
t = r["time_seconds"]
parts = []
for k, v in fields.items():
sv = str(v)
if len(sv) > 30:
sv = sv[:27] + "..."
parts.append(f"{k}={sv}")
field_str = ", ".join(parts)
print(f"{fn:<40} {dets:<4} {nfields:<4} {t:<7} {field_str}")
print()
# Field coverage
field_counts: Counter = Counter()
conf_sums: Counter = Counter()
ok_count = 0
for r in results:
if r["status"] != 200:
continue
ok_count += 1
result_data = r["data"]["result"]
for k in result_data.get("fields", {}):
field_counts[k] += 1
for k, v in (result_data.get("confidence") or {}).items():
conf_sums[k] += v
print(f"Field Coverage ({ok_count} successful PDFs):")
hdr = f"{'Field':<35} {'Present':<10} {'Rate':<10} {'Avg Conf':<10}"
print(hdr)
print("-" * 65)
for field in [
"InvoiceNumber", "InvoiceDate", "InvoiceDueDate", "OCR",
"Amount", "Bankgiro", "Plusgiro",
"supplier_organisation_number", "customer_number", "payment_line",
]:
cnt = field_counts.get(field, 0)
rate = cnt / ok_count * 100 if ok_count else 0
avg_conf = conf_sums.get(field, 0) / cnt if cnt else 0
flag = ""
if rate < 30:
flag = " <<<"
elif rate < 60:
flag = " !!"
print(f"{field:<35} {cnt:<10} {rate:<10.1f} {avg_conf:<10.3f}{flag}")
# Fallback count
fb_count = 0
for r in results:
if r["status"] == 200:
result_data = r["data"]["result"]
if result_data.get("fallback_used"):
fb_count += 1
print(f"\nFallback used: {fb_count}/{ok_count}")
# Low-confidence fields
print("\nLow-confidence extractions (< 0.7):")
for r in results:
if r["status"] != 200:
continue
result_data = r["data"]["result"]
for k, v in (result_data.get("confidence") or {}).items():
if v < 0.7:
fv = result_data.get("fields", {}).get(k, "?")
print(f" [{v:.3f}] {k:<25} = {str(fv)[:40]:<40} ({r['filename'][:36]})")
# PDFs with very few fields (possible issues)
print("\nPDFs with <= 2 fields extracted:")
for r in results:
if r["status"] != 200:
continue
result_data = r["data"]["result"]
fields = result_data.get("fields", {})
if len(fields) <= 2:
print(f" {r['filename']}: {len(fields)} fields - {list(fields.keys())}")
# Avg time
avg_time = sum(r["time_seconds"] for r in results) / len(results)
print(f"\nAverage processing time: {avg_time:.2f}s")

View File

@@ -0,0 +1,92 @@
#!/usr/bin/env python3
"""Batch inference v3 - 30 random PDFs for Round 2 validation."""
import json
import os
import random
import time
import requests
PDF_DIR = "/mnt/c/Users/yaoji/git/Billo/Billo.Platform.Document/Billo.Platform.Document.AdminAPI/downloads/to_check"
API_URL = "http://localhost:8000/api/v1/infer"
OUTPUT_FILE = "/mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2/scripts/inference_results_v3.json"
SAMPLE_SIZE = 30
def main():
random.seed(99_2026) # New seed for Round 3
all_pdfs = [f for f in os.listdir(PDF_DIR) if f.lower().endswith(".pdf")]
selected = random.sample(all_pdfs, min(SAMPLE_SIZE, len(all_pdfs)))
print(f"Selected {len(selected)} random PDFs for inference")
results = []
for i, filename in enumerate(selected, 1):
filepath = os.path.join(PDF_DIR, filename)
filesize = os.path.getsize(filepath)
print(f"[{i}/{len(selected)}] Processing {filename}...", end=" ", flush=True)
start = time.time()
try:
with open(filepath, "rb") as f:
resp = requests.post(
API_URL,
files={"file": (filename, f, "application/pdf")},
timeout=120,
)
elapsed = round(time.time() - start, 2)
if resp.status_code == 200:
data = resp.json()
field_count = sum(
1 for k, v in data.items()
if k not in (
"DocumentId", "confidence", "success", "fallback_used",
"bboxes", "cross_validation", "processing_time_ms",
"line_items", "vat_summary", "vat_validation",
"raw_detections", "detection_classes", "detection_count",
)
and v is not None
)
det_count = data.get("detection_count", "?")
print(f"OK ({elapsed}s) - {field_count} fields, {det_count} detections")
results.append({
"filename": filename,
"status": resp.status_code,
"time_seconds": elapsed,
"filesize": filesize,
"data": data,
})
else:
print(f"HTTP {resp.status_code} ({elapsed}s)")
results.append({
"filename": filename,
"status": resp.status_code,
"time_seconds": elapsed,
"filesize": filesize,
"error": resp.text[:200],
})
except Exception as e:
elapsed = round(time.time() - start, 2)
print(f"FAILED ({elapsed}s) - {e}")
results.append({
"filename": filename,
"status": -1,
"time_seconds": elapsed,
"filesize": filesize,
"error": str(e),
})
with open(OUTPUT_FILE, "w", encoding="utf-8") as f:
json.dump(results, f, ensure_ascii=False, indent=2)
print(f"\nResults saved to {OUTPUT_FILE}")
success = sum(1 for r in results if r["status"] == 200)
failed = len(results) - success
print(f"Total: {len(results)}, Success: {success}, Failed: {failed}")
if __name__ == "__main__":
main()

File diff suppressed because it is too large Load Diff

View File

@@ -1,387 +0,0 @@
#!/usr/bin/env python3
"""
PP-StructureV3 Line Items Extraction POC
Tests line items extraction from Swedish invoices using PP-StructureV3.
Parses HTML table structure to extract structured line item data.
Run with invoice-sm120 conda environment.
"""
import sys
import re
from pathlib import Path
from html.parser import HTMLParser
from dataclasses import dataclass
# Add project root to path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root / "packages" / "backend"))
from paddleocr import PPStructureV3
import fitz # PyMuPDF
@dataclass
class LineItem:
"""Single line item from invoice."""
row_index: int
article_number: str | None
description: str | None
quantity: str | None
unit: str | None
unit_price: str | None
amount: str | None
vat_rate: str | None
confidence: float = 0.9
class TableHTMLParser(HTMLParser):
"""Parse HTML table into rows and cells."""
def __init__(self):
super().__init__()
self.rows: list[list[str]] = []
self.current_row: list[str] = []
self.current_cell: str = ""
self.in_td = False
self.in_thead = False
self.header_row: list[str] = []
def handle_starttag(self, tag, attrs):
if tag == "tr":
self.current_row = []
elif tag in ("td", "th"):
self.in_td = True
self.current_cell = ""
elif tag == "thead":
self.in_thead = True
def handle_endtag(self, tag):
if tag in ("td", "th"):
self.in_td = False
self.current_row.append(self.current_cell.strip())
elif tag == "tr":
if self.current_row:
if self.in_thead:
self.header_row = self.current_row
else:
self.rows.append(self.current_row)
elif tag == "thead":
self.in_thead = False
def handle_data(self, data):
if self.in_td:
self.current_cell += data
# Swedish column name mappings
# Note: Some headers may contain multiple column names merged together
COLUMN_MAPPINGS = {
'article_number': ['art nummer', 'artikelnummer', 'artikel', 'artnr', 'art.nr', 'art nr'],
'description': ['beskrivning', 'produktbeskrivning', 'produkt', 'tjänst', 'text', 'benämning', 'vara/tjänst', 'vara'],
'quantity': ['antal', 'qty', 'st', 'pcs', 'kvantitet'],
'unit': ['enhet', 'unit'],
'unit_price': ['á-pris', 'a-pris', 'pris', 'styckpris', 'enhetspris', 'à pris'],
'amount': ['belopp', 'summa', 'total', 'netto', 'rad summa'],
'vat_rate': ['moms', 'moms%', 'vat', 'skatt', 'moms %'],
}
def normalize_header(header: str) -> str:
"""Normalize header text for matching."""
return header.lower().strip().replace(".", "").replace("-", " ")
def map_columns(headers: list[str]) -> dict[int, str]:
"""Map column indices to field names."""
mapping = {}
for idx, header in enumerate(headers):
normalized = normalize_header(header)
# Skip empty headers
if not normalized.strip():
continue
best_match = None
best_match_len = 0
for field, patterns in COLUMN_MAPPINGS.items():
for pattern in patterns:
# Require exact match or pattern must be a significant portion
if pattern == normalized:
# Exact match - use immediately
best_match = field
best_match_len = len(pattern) + 100 # Prioritize exact
break
elif pattern in normalized and len(pattern) > best_match_len:
# Pattern found in header - use longer matches
if len(pattern) >= 3: # Minimum pattern length
best_match = field
best_match_len = len(pattern)
if best_match_len > 100: # Was exact match
break
if best_match:
mapping[idx] = best_match
return mapping
def parse_table_html(html: str) -> tuple[list[str], list[list[str]]]:
"""Parse HTML table and return header and rows."""
parser = TableHTMLParser()
parser.feed(html)
return parser.header_row, parser.rows
def detect_header_row(rows: list[list[str]]) -> tuple[int, list[str], bool]:
"""
Detect which row is the header based on content patterns.
Returns (header_row_index, header_row, is_at_end).
is_at_end indicates if header is at the end (table is reversed).
Returns (-1, [], False) if no header detected.
"""
header_keywords = set()
for patterns in COLUMN_MAPPINGS.values():
for p in patterns:
header_keywords.add(p.lower())
best_match = (-1, [], 0)
for i, row in enumerate(rows):
# Skip empty rows
if all(not cell.strip() for cell in row):
continue
# Check if row contains header keywords
row_text = " ".join(cell.lower() for cell in row)
matches = sum(1 for kw in header_keywords if kw in row_text)
# Track the best match
if matches > best_match[2]:
best_match = (i, row, matches)
if best_match[2] >= 2:
header_idx = best_match[0]
is_at_end = header_idx == len(rows) - 1 or header_idx > len(rows) // 2
return header_idx, best_match[1], is_at_end
return -1, [], False
def extract_line_items(html: str) -> list[LineItem]:
"""Extract line items from HTML table."""
header, rows = parse_table_html(html)
is_reversed = False
if not header:
# Try to detect header row from content
header_idx, detected_header, is_at_end = detect_header_row(rows)
if header_idx >= 0:
header = detected_header
if is_at_end:
# Header is at the end - table is reversed
is_reversed = True
rows = rows[:header_idx] # Data rows are before header
else:
rows = rows[header_idx + 1:] # Data rows start after header
elif rows:
# Fall back to first non-empty row
for i, row in enumerate(rows):
if any(cell.strip() for cell in row):
header = row
rows = rows[i + 1:]
break
column_map = map_columns(header)
items = []
for row_idx, row in enumerate(rows):
item_data = {
'row_index': row_idx,
'article_number': None,
'description': None,
'quantity': None,
'unit': None,
'unit_price': None,
'amount': None,
'vat_rate': None,
}
for col_idx, cell in enumerate(row):
if col_idx in column_map:
field = column_map[col_idx]
item_data[field] = cell if cell else None
# Only add if we have at least description or amount
if item_data['description'] or item_data['amount']:
items.append(LineItem(**item_data))
return items
def render_pdf_to_image(pdf_path: str, dpi: int = 200) -> bytes:
"""Render first page of PDF to image bytes."""
doc = fitz.open(pdf_path)
page = doc[0]
mat = fitz.Matrix(dpi / 72, dpi / 72)
pix = page.get_pixmap(matrix=mat)
img_bytes = pix.tobytes("png")
doc.close()
return img_bytes
def test_line_items_extraction(pdf_path: str) -> dict:
"""Test line items extraction on a PDF."""
print(f"\n{'='*70}")
print(f"Testing Line Items Extraction: {Path(pdf_path).name}")
print(f"{'='*70}")
# Render PDF to image
print("Rendering PDF to image...")
img_bytes = render_pdf_to_image(pdf_path)
# Save temp image
temp_img_path = "/tmp/test_invoice.png"
with open(temp_img_path, "wb") as f:
f.write(img_bytes)
# Initialize PP-StructureV3
print("Initializing PP-StructureV3...")
pipeline = PPStructureV3(
device="gpu:0",
use_doc_orientation_classify=False,
use_doc_unwarping=False,
)
# Run detection
print("Running table detection...")
results = pipeline.predict(temp_img_path)
all_line_items = []
table_details = []
for result in results if results else []:
table_res_list = result.get("table_res_list") if hasattr(result, "get") else None
if table_res_list:
print(f"\nFound {len(table_res_list)} tables")
for i, table_res in enumerate(table_res_list):
html = table_res.get("pred_html", "")
ocr_pred = table_res.get("table_ocr_pred", {})
print(f"\n--- Table {i+1} ---")
# Debug: show full HTML for first table
if i == 0:
print(f" Full HTML:\n{html}")
# Debug: inspect table_ocr_pred structure
if isinstance(ocr_pred, dict):
print(f" table_ocr_pred keys: {list(ocr_pred.keys())}")
# Check if rec_texts exists (actual OCR text)
if "rec_texts" in ocr_pred:
texts = ocr_pred["rec_texts"]
print(f" OCR texts count: {len(texts)}")
print(f" Sample OCR texts: {texts[:5]}")
elif isinstance(ocr_pred, list):
print(f" table_ocr_pred is list with {len(ocr_pred)} items")
if ocr_pred:
print(f" First item type: {type(ocr_pred[0])}")
print(f" First few items: {ocr_pred[:3]}")
# Parse HTML
header, rows = parse_table_html(html)
print(f" HTML Header (from thead): {header}")
print(f" HTML Rows: {len(rows)}")
# Try to detect header if not in thead
detected_header = None
is_reversed = False
if not header and rows:
header_idx, detected_header, is_at_end = detect_header_row(rows)
if header_idx >= 0:
is_reversed = is_at_end
print(f" Detected header at row {header_idx}: {detected_header}")
print(f" Table is {'REVERSED (header at bottom)' if is_reversed else 'normal'}")
header = detected_header
if rows:
print(f" First row: {rows[0]}")
if len(rows) > 1:
print(f" Second row: {rows[1]}")
# Check if this looks like a line items table
column_map = map_columns(header) if header else {}
print(f" Column mapping: {column_map}")
is_line_items_table = (
'description' in column_map.values() or
'amount' in column_map.values() or
'article_number' in column_map.values()
)
if is_line_items_table:
print(f" >>> This appears to be a LINE ITEMS table!")
items = extract_line_items(html)
print(f" Extracted {len(items)} line items:")
for item in items:
print(f" - {item.description}: {item.quantity} x {item.unit_price} = {item.amount}")
all_line_items.extend(items)
else:
print(f" >>> This is NOT a line items table (summary/payment)")
table_details.append({
"index": i,
"header": header,
"row_count": len(rows),
"is_line_items": is_line_items_table,
"column_map": column_map,
})
print(f"\n{'='*70}")
print(f"EXTRACTION SUMMARY")
print(f"{'='*70}")
print(f"Total tables: {len(table_details)}")
print(f"Line items tables: {sum(1 for t in table_details if t['is_line_items'])}")
print(f"Total line items: {len(all_line_items)}")
return {
"pdf": pdf_path,
"tables": table_details,
"line_items": all_line_items,
}
def main():
import argparse
parser = argparse.ArgumentParser(description="Test line items extraction")
parser.add_argument("--pdf", type=str, help="Path to PDF file")
args = parser.parse_args()
if args.pdf:
# Test specific PDF
pdf_path = Path(args.pdf)
if not pdf_path.exists():
# Try relative to project root
pdf_path = project_root / args.pdf
if not pdf_path.exists():
print(f"PDF not found: {args.pdf}")
return
test_line_items_extraction(str(pdf_path))
else:
# Test default invoice
default_pdf = project_root / "exampl" / "Faktura54011.pdf"
if default_pdf.exists():
test_line_items_extraction(str(default_pdf))
else:
print(f"Default PDF not found: {default_pdf}")
print("Usage: python ppstructure_line_items_poc.py --pdf <path>")
if __name__ == "__main__":
main()

View File

@@ -1,154 +0,0 @@
#!/usr/bin/env python3
"""
PP-StructureV3 POC Script
Tests table detection on real Swedish invoices using PP-StructureV3.
Run with invoice-sm120 conda environment.
"""
import sys
from pathlib import Path
# Add project root to path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root / "packages" / "backend"))
from paddleocr import PPStructureV3
import fitz # PyMuPDF
def render_pdf_to_image(pdf_path: str, dpi: int = 200) -> bytes:
"""Render first page of PDF to image bytes."""
doc = fitz.open(pdf_path)
page = doc[0]
mat = fitz.Matrix(dpi / 72, dpi / 72)
pix = page.get_pixmap(matrix=mat)
img_bytes = pix.tobytes("png")
doc.close()
return img_bytes
def test_table_detection(pdf_path: str) -> dict:
"""Test PP-StructureV3 table detection on a PDF."""
print(f"\n{'='*60}")
print(f"Testing: {Path(pdf_path).name}")
print(f"{'='*60}")
# Render PDF to image
print("Rendering PDF to image...")
img_bytes = render_pdf_to_image(pdf_path)
# Save temp image
temp_img_path = "/tmp/test_invoice.png"
with open(temp_img_path, "wb") as f:
f.write(img_bytes)
print(f"Saved temp image: {temp_img_path}")
# Initialize PP-StructureV3
print("Initializing PP-StructureV3...")
pipeline = PPStructureV3(
device="gpu:0",
use_doc_orientation_classify=False,
use_doc_unwarping=False,
)
# Run detection
print("Running table detection...")
results = pipeline.predict(temp_img_path)
# Parse results - PaddleX 3.x returns dict-like LayoutParsingResultV2
tables_found = []
all_elements = []
for result in results if results else []:
# Get table results from the new API
table_res_list = result.get("table_res_list") if hasattr(result, "get") else None
if table_res_list:
print(f" Found {len(table_res_list)} tables in table_res_list")
for i, table_res in enumerate(table_res_list):
# Debug: show all keys in table_res
if isinstance(table_res, dict):
print(f" Table {i+1} keys: {list(table_res.keys())}")
else:
print(f" Table {i+1} attrs: {[a for a in dir(table_res) if not a.startswith('_')]}")
# Extract table info - use correct key names from PaddleX 3.x
cell_boxes = table_res.get("cell_box_list", [])
html = table_res.get("pred_html", "") # HTML is in pred_html
ocr_text = table_res.get("table_ocr_pred", [])
region_id = table_res.get("table_region_id", -1)
bbox = [] # bbox is stored elsewhere in parsing_res_list
print(f" Table {i+1}:")
print(f" - Cells: {len(cell_boxes) if cell_boxes is not None else 0}")
print(f" - Region ID: {region_id}")
print(f" - HTML length: {len(html) if html else 0}")
print(f" - OCR texts: {len(ocr_text) if ocr_text else 0}")
if html:
print(f" - HTML preview: {html[:300]}...")
if ocr_text and len(ocr_text) > 0:
print(f" - First few OCR texts: {ocr_text[:3]}")
tables_found.append({
"index": i,
"cell_count": len(cell_boxes) if cell_boxes is not None else 0,
"region_id": region_id,
"html": html[:1000] if html else "",
"ocr_count": len(ocr_text) if ocr_text else 0,
})
# Get parsing results for all layout elements
parsing_res_list = result.get("parsing_res_list") if hasattr(result, "get") else None
if parsing_res_list:
print(f"\n Layout elements from parsing_res_list:")
for elem in parsing_res_list[:10]: # Show first 10
label = elem.get("label", "unknown") if isinstance(elem, dict) else getattr(elem, "label", "unknown")
bbox = elem.get("bbox", []) if isinstance(elem, dict) else getattr(elem, "bbox", [])
print(f" - {label}: {bbox}")
all_elements.append({"label": label, "bbox": bbox})
print(f"\nSummary:")
print(f" Tables detected: {len(tables_found)}")
print(f" Layout elements: {len(all_elements)}")
return {"pdf": pdf_path, "tables": tables_found, "elements": all_elements}
def main():
# Find test PDFs
pdf_dir = Path("/mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2/data/admin_uploads")
pdf_files = list(pdf_dir.glob("*.pdf"))[:5] # Test first 5
if not pdf_files:
print("No PDF files found in admin_uploads directory")
return
print(f"Found {len(pdf_files)} PDF files")
all_results = []
for pdf_file in pdf_files:
result = test_table_detection(str(pdf_file))
all_results.append(result)
# Summary
print(f"\n{'='*60}")
print("FINAL SUMMARY")
print(f"{'='*60}")
total_tables = sum(len(r["tables"]) for r in all_results)
print(f"Total PDFs tested: {len(all_results)}")
print(f"Total tables detected: {total_tables}")
for r in all_results:
pdf_name = Path(r["pdf"]).name
table_count = len(r["tables"])
print(f" {pdf_name}: {table_count} tables")
for t in r["tables"]:
print(f" - Table {t['index']+1}: {t['cell_count']} cells")
if __name__ == "__main__":
main()

54
scripts/render_pdfs_v3.py Normal file
View File

@@ -0,0 +1,54 @@
#!/usr/bin/env python3
"""Render selected PDFs from v3 batch for visual comparison."""
import os
import fitz # PyMuPDF
PDF_DIR = "/mnt/c/Users/yaoji/git/Billo/Billo.Platform.Document/Billo.Platform.Document.AdminAPI/downloads/to_check"
OUTPUT_DIR = "/mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2/scripts/pdf_renders_v3"
# Select 10 PDFs covering different scenarios:
SELECTED = [
# Potentially wrong Amount (81648164.00 - too high?)
"b84c7d70-821d-4a1a-9be7-d7bb2392bd91.pdf",
# Only 2 fields extracted
"072571e2-da5f-4268-b1a8-f0e5a85a3ec4.pdf",
# InvoiceNumber=5085 (suspiciously short, same as BG prefix?)
"6a83ba35-afdf-4c13-ade1-25513e213637.pdf",
# InvoiceNumber=450 (very short, might be wrong)
"8551b540-d93d-459d-b7eb-e9ee086f9f16.pdf",
# InvoiceNumber=134 (very short, same as BG prefix)
"cb1bd3b1-63d0-4140-930f-e4a7ae2b6cd5.pdf",
# Large Amount=172904.52, InvoiceNumber=89902
"d121a5ee-7382-41d8-8010-63880def1f96.pdf",
# Good 9-field PDF for positive check
"6cb90895-e52b-4831-b57b-7cb968bcdd54.pdf",
# Amount=2026.00 (same as year - could be confused?)
"d376c5b5-0dc5-4ccf-b787-0d481eef8577.pdf",
# 8 fields, good coverage
"f3f5da6f-7552-4ec6-8625-3629042fbfd0.pdf",
# Low confidence Amount=596.49
"5783e4af-eef3-411c-84b1-3a8f4694fed8.pdf",
]
os.makedirs(OUTPUT_DIR, exist_ok=True)
for pdf_name in SELECTED:
pdf_path = os.path.join(PDF_DIR, pdf_name)
if not os.path.exists(pdf_path):
print(f"SKIP {pdf_name} - not found")
continue
doc = fitz.open(pdf_path)
page = doc[0]
mat = fitz.Matrix(150 / 72, 150 / 72)
pix = page.get_pixmap(matrix=mat)
out_name = pdf_name.replace(".pdf", ".png")
out_path = os.path.join(OUTPUT_DIR, out_name)
pix.save(out_path)
print(f"Rendered {pdf_name} -> {out_name} ({pix.width}x{pix.height})")
doc.close()
print(f"\nAll renders saved to {OUTPUT_DIR}")

View File

@@ -213,8 +213,8 @@ class TestNormalizeOCR:
assert ' ' not in result.value # Spaces should be removed
def test_short_ocr_invalid(self, normalizer):
"""Test that too short OCR is invalid."""
result = normalizer.normalize("123")
"""Test that single-digit OCR is invalid (min 2 digits)."""
result = normalizer.normalize("5")
assert result.is_valid is False

View File

@@ -100,6 +100,22 @@ class TestInvoiceNumberNormalizer:
result = normalizer.normalize("Invoice 54321 OCR 12345678901234")
assert result.value == "54321"
def test_year_not_extracted_when_real_number_exists(self, normalizer):
"""4-digit year should be skipped when a real invoice number is present."""
result = normalizer.normalize("Faktura 12345 Datum 2025")
assert result.value == "12345"
def test_year_2026_not_extracted(self, normalizer):
"""Year '2026' should not be preferred over a real invoice number."""
result = normalizer.normalize("Invoice 54321 Date 2026")
assert result.value == "54321"
def test_non_year_4_digit_still_matches(self, normalizer):
"""4-digit numbers that are NOT years should still match."""
result = normalizer.normalize("Invoice 3456")
assert result.value == "3456"
assert result.is_valid is True
def test_fallback_extraction(self, normalizer):
"""Test fallback to digit extraction."""
# This matches Pattern 3 (short digit sequence 3-10 digits)
@@ -107,6 +123,16 @@ class TestInvoiceNumberNormalizer:
assert result.value == "123"
assert result.is_valid is True
def test_amount_fragment_not_selected(self, normalizer):
"""Amount fragment '775' from '9 775,96' should lose to real invoice number."""
result = normalizer.normalize("9 775,96 Belopp Kontoutdragsnr 04862823")
assert result.value == "04862823"
def test_prefer_medium_length_over_shortest(self, normalizer):
"""Prefer 4-8 digit sequences over very short 3-digit ones."""
result = normalizer.normalize("Ref 999 Fakturanr 12345")
assert result.value == "12345"
def test_no_valid_sequence(self, normalizer):
"""Test failure when no valid sequence found."""
result = normalizer.normalize("no numbers here")
@@ -134,8 +160,21 @@ class TestOcrNumberNormalizer:
assert result.value == "310196187399952"
assert " " not in result.value
def test_4_digit_ocr_valid(self, normalizer):
"""4-digit OCR numbers like '3046' should be accepted."""
result = normalizer.normalize("3046")
assert result.is_valid is True
assert result.value == "3046"
def test_2_digit_ocr_valid(self, normalizer):
"""2-digit OCR numbers should be accepted."""
result = normalizer.normalize("42")
assert result.is_valid is True
assert result.value == "42"
def test_too_short(self, normalizer):
result = normalizer.normalize("1234")
"""Single-digit OCR should be rejected."""
result = normalizer.normalize("5")
assert result.is_valid is False
def test_empty_string(self, normalizer):
@@ -477,6 +516,38 @@ class TestAmountNormalizer:
assert result.value == "100.00"
assert result.is_valid is True
def test_astronomical_amount_rejected(self, normalizer):
"""IBAN digits should NOT produce astronomical amounts (>10M)."""
# IBAN "SE14120000001201138650" contains long digit sequences
# The standalone fallback pattern should not extract these as amounts
result = normalizer.normalize("SE14120000001201138650")
if result.is_valid:
assert float(result.value) < 10_000_000
def test_large_valid_amount_accepted(self, normalizer):
"""Valid large amount like 108000,00 should be accepted."""
result = normalizer.normalize("108000,00")
assert result.value == "108000.00"
assert result.is_valid is True
def test_standalone_iban_digits_rejected(self, normalizer):
"""Very long digit sequence (IBAN fragment) should not produce >10M."""
result = normalizer.normalize("1036149234823114")
if result.is_valid:
assert float(result.value) < 10_000_000
def test_main_pattern_rejects_over_10m(self, normalizer):
"""Main regex path should reject amounts over 10M (e.g. IBAN-like digits)."""
result = normalizer.normalize("Belopp 81648164,00 kr")
# 81648164.00 > 10M, should be rejected
assert not result.is_valid or float(result.value) < 10_000_000
def test_main_pattern_accepts_under_10m(self, normalizer):
"""Main regex path should accept valid amounts under 10M."""
result = normalizer.normalize("Summa 999999,99 kr")
assert result.value == "999999.99"
assert result.is_valid is True
class TestEnhancedAmountNormalizer:
"""Tests for EnhancedAmountNormalizer."""

View File

@@ -670,5 +670,387 @@ class TestProcessPdfTokenPath:
assert call_args[3] == 100 # image height
class TestDpiPassthrough:
"""Tests for DPI being passed from pipeline to FieldExtractor (Bug 1)."""
def test_field_extractor_receives_pipeline_dpi(self):
"""FieldExtractor should receive the pipeline's DPI, not default to 300."""
with patch('backend.pipeline.pipeline.YOLODetector'):
with patch('backend.pipeline.pipeline.FieldExtractor') as mock_fe_cls:
InferencePipeline(
model_path='/fake/model.pt',
dpi=150,
use_gpu=False,
)
mock_fe_cls.assert_called_once_with(
ocr_lang='en', use_gpu=False, dpi=150
)
def test_field_extractor_receives_default_dpi(self):
"""When dpi=300 (default), FieldExtractor should also get 300."""
with patch('backend.pipeline.pipeline.YOLODetector'):
with patch('backend.pipeline.pipeline.FieldExtractor') as mock_fe_cls:
InferencePipeline(
model_path='/fake/model.pt',
dpi=300,
use_gpu=False,
)
mock_fe_cls.assert_called_once_with(
ocr_lang='en', use_gpu=False, dpi=300
)
class TestFallbackPatternExtraction:
"""Tests for _extract_with_patterns fallback regex (Bugs 2, 3)."""
def _make_pipeline_with_patterns(self):
"""Create pipeline with mocked internals for pattern testing."""
with patch.object(InferencePipeline, '__init__', lambda self, **kw: None):
p = InferencePipeline()
p.dpi = 150
p.enable_fallback = True
return p
def test_bankgiro_no_match_in_org_number(self):
"""Bankgiro regex must NOT match digits embedded in an org number."""
p = self._make_pipeline_with_patterns()
result = InferenceResult()
p._extract_with_patterns("Org.nr 802546-1610", result)
assert 'Bankgiro' not in result.fields
def test_bankgiro_matches_labeled(self):
"""Bankgiro regex should match when preceded by 'Bankgiro' label."""
p = self._make_pipeline_with_patterns()
result = InferenceResult()
p._extract_with_patterns("Bankgiro 5393-9484", result)
assert result.fields.get('Bankgiro') == '5393-9484'
def test_bankgiro_matches_standalone(self):
"""Bankgiro regex should match a standalone 4-4 digit pattern."""
p = self._make_pipeline_with_patterns()
result = InferenceResult()
p._extract_with_patterns("Betala till 5393-9484 senast", result)
assert result.fields.get('Bankgiro') == '5393-9484'
def test_amount_rejects_bare_integer(self):
"""Amount regex must NOT match bare integers like 'Summa 1'."""
p = self._make_pipeline_with_patterns()
result = InferenceResult()
p._extract_with_patterns("Summa 1 Medlemsavgift", result)
assert 'Amount' not in result.fields
def test_amount_requires_decimal(self):
"""Amount regex should require a decimal separator."""
p = self._make_pipeline_with_patterns()
result = InferenceResult()
p._extract_with_patterns("Total 5 items", result)
assert 'Amount' not in result.fields
def test_amount_with_decimal_works(self):
"""Amount regex should match Swedish decimal amounts."""
p = self._make_pipeline_with_patterns()
result = InferenceResult()
p._extract_with_patterns("Att betala 1 234,56 SEK", result)
assert 'Amount' in result.fields
assert float(result.fields['Amount']) == pytest.approx(1234.56, abs=0.01)
def test_amount_with_sek_suffix(self):
"""Amount regex should match amounts ending with SEK."""
p = self._make_pipeline_with_patterns()
result = InferenceResult()
p._extract_with_patterns("7 500,00 SEK", result)
assert 'Amount' in result.fields
assert float(result.fields['Amount']) == pytest.approx(7500.00, abs=0.01)
def test_fallback_extracts_invoice_date(self):
"""Fallback should extract InvoiceDate from Swedish text."""
p = self._make_pipeline_with_patterns()
result = InferenceResult()
p._extract_with_patterns("Fakturadatum 2025-01-15 Referens ABC", result)
assert result.fields.get('InvoiceDate') == '2025-01-15'
def test_fallback_extracts_due_date(self):
"""Fallback should extract InvoiceDueDate from Swedish text."""
p = self._make_pipeline_with_patterns()
result = InferenceResult()
p._extract_with_patterns("Forfallodag 2025-02-15 Belopp", result)
assert result.fields.get('InvoiceDueDate') == '2025-02-15'
def test_fallback_extracts_supplier_org(self):
"""Fallback should extract supplier_organisation_number."""
p = self._make_pipeline_with_patterns()
result = InferenceResult()
p._extract_with_patterns("Org.nr 556123-4567 Stockholm", result)
assert result.fields.get('supplier_organisation_number') == '556123-4567'
def test_fallback_extracts_plusgiro(self):
"""Fallback should extract Plusgiro number."""
p = self._make_pipeline_with_patterns()
result = InferenceResult()
p._extract_with_patterns("Plusgiro 12 34 56-7 betalning", result)
assert 'Plusgiro' in result.fields
def test_fallback_skips_year_as_invoice_number(self):
"""Fallback should NOT extract year-like value as InvoiceNumber."""
p = self._make_pipeline_with_patterns()
result = InferenceResult()
p._extract_with_patterns("Fakturanr 2025 Datum 2025-01-15", result)
assert 'InvoiceNumber' not in result.fields
def test_fallback_accepts_valid_invoice_number(self):
"""Fallback should extract valid non-year InvoiceNumber."""
p = self._make_pipeline_with_patterns()
result = InferenceResult()
p._extract_with_patterns("Fakturanr 12345 Summa", result)
assert result.fields.get('InvoiceNumber') == '12345'
class TestDateValidation:
"""Tests for InvoiceDueDate < InvoiceDate validation (Bug 6)."""
def _make_pipeline_for_merge(self):
"""Create pipeline with mocked internals for merge testing."""
with patch.object(InferencePipeline, '__init__', lambda self, **kw: None):
p = InferencePipeline()
p.payment_line_parser = MagicMock()
p.payment_line_parser.parse.return_value = MagicMock(is_valid=False)
return p
def test_due_date_before_invoice_date_dropped(self):
"""DueDate earlier than InvoiceDate should be removed."""
from backend.pipeline.field_extractor import ExtractedField
p = self._make_pipeline_for_merge()
result = InferenceResult()
result.extracted_fields = [
ExtractedField(
field_name='InvoiceDate', raw_text='2026-01-16',
normalized_value='2026-01-16', confidence=0.9,
detection_confidence=0.9, ocr_confidence=1.0,
bbox=(0, 0, 100, 50), page_no=0,
),
ExtractedField(
field_name='InvoiceDueDate', raw_text='2025-12-01',
normalized_value='2025-12-01', confidence=0.9,
detection_confidence=0.9, ocr_confidence=1.0,
bbox=(0, 60, 100, 110), page_no=0,
),
]
p._merge_fields(result)
assert 'InvoiceDate' in result.fields
assert 'InvoiceDueDate' not in result.fields
def test_valid_dates_preserved(self):
"""Both dates kept when DueDate >= InvoiceDate."""
from backend.pipeline.field_extractor import ExtractedField
p = self._make_pipeline_for_merge()
result = InferenceResult()
result.extracted_fields = [
ExtractedField(
field_name='InvoiceDate', raw_text='2026-01-16',
normalized_value='2026-01-16', confidence=0.9,
detection_confidence=0.9, ocr_confidence=1.0,
bbox=(0, 0, 100, 50), page_no=0,
),
ExtractedField(
field_name='InvoiceDueDate', raw_text='2026-02-15',
normalized_value='2026-02-15', confidence=0.9,
detection_confidence=0.9, ocr_confidence=1.0,
bbox=(0, 60, 100, 110), page_no=0,
),
]
p._merge_fields(result)
assert result.fields['InvoiceDate'] == '2026-01-16'
assert result.fields['InvoiceDueDate'] == '2026-02-15'
def test_same_dates_preserved(self):
"""Same InvoiceDate and DueDate should both be kept."""
from backend.pipeline.field_extractor import ExtractedField
p = self._make_pipeline_for_merge()
result = InferenceResult()
result.extracted_fields = [
ExtractedField(
field_name='InvoiceDate', raw_text='2026-01-16',
normalized_value='2026-01-16', confidence=0.9,
detection_confidence=0.9, ocr_confidence=1.0,
bbox=(0, 0, 100, 50), page_no=0,
),
ExtractedField(
field_name='InvoiceDueDate', raw_text='2026-01-16',
normalized_value='2026-01-16', confidence=0.9,
detection_confidence=0.9, ocr_confidence=1.0,
bbox=(0, 60, 100, 110), page_no=0,
),
]
p._merge_fields(result)
assert result.fields['InvoiceDate'] == '2026-01-16'
assert result.fields['InvoiceDueDate'] == '2026-01-16'
class TestCrossFieldDedup:
"""Tests for cross-field deduplication of InvoiceNumber vs OCR/Bankgiro."""
def _make_pipeline_for_merge(self):
"""Create pipeline with mocked internals for merge testing."""
with patch.object(InferencePipeline, '__init__', lambda self, **kw: None):
p = InferencePipeline()
p.payment_line_parser = MagicMock()
p.payment_line_parser.parse.return_value = MagicMock(is_valid=False)
return p
def _make_extracted_field(self, field_name, raw_text, normalized, confidence=0.9):
from backend.pipeline.field_extractor import ExtractedField
return ExtractedField(
field_name=field_name,
raw_text=raw_text,
normalized_value=normalized,
confidence=confidence,
detection_confidence=confidence,
ocr_confidence=1.0,
bbox=(0, 0, 100, 50),
page_no=0,
)
def test_invoice_number_not_same_as_ocr(self):
"""When InvoiceNumber == OCR, InvoiceNumber should be dropped."""
p = self._make_pipeline_for_merge()
result = InferenceResult()
result.extracted_fields = [
self._make_extracted_field('InvoiceNumber', '9179845608', '9179845608'),
self._make_extracted_field('OCR', '9179845608', '9179845608'),
self._make_extracted_field('Amount', '1234,56', '1234.56'),
]
p._merge_fields(result)
assert 'OCR' in result.fields
assert result.fields['OCR'] == '9179845608'
assert 'InvoiceNumber' not in result.fields
def test_invoice_number_not_same_as_bankgiro_digits(self):
"""When InvoiceNumber digits == Bankgiro digits, InvoiceNumber should be dropped."""
p = self._make_pipeline_for_merge()
result = InferenceResult()
result.extracted_fields = [
self._make_extracted_field('InvoiceNumber', '53939484', '53939484'),
self._make_extracted_field('Bankgiro', '5393-9484', '5393-9484'),
self._make_extracted_field('Amount', '500,00', '500.00'),
]
p._merge_fields(result)
assert 'Bankgiro' in result.fields
assert result.fields['Bankgiro'] == '5393-9484'
assert 'InvoiceNumber' not in result.fields
def test_unrelated_values_kept(self):
"""When InvoiceNumber, OCR, and Bankgiro are all different, keep all."""
p = self._make_pipeline_for_merge()
result = InferenceResult()
result.extracted_fields = [
self._make_extracted_field('InvoiceNumber', '19061', '19061'),
self._make_extracted_field('OCR', '9179845608', '9179845608'),
self._make_extracted_field('Bankgiro', '5393-9484', '5393-9484'),
]
p._merge_fields(result)
assert result.fields['InvoiceNumber'] == '19061'
assert result.fields['OCR'] == '9179845608'
assert result.fields['Bankgiro'] == '5393-9484'
def test_dedup_after_fallback_re_add(self):
"""Dedup should remove InvoiceNumber re-added by fallback if it matches OCR."""
p = self._make_pipeline_for_merge()
result = InferenceResult()
# Simulate state after fallback re-adds InvoiceNumber = OCR
result.fields = {
'OCR': '758200602426',
'Amount': '164.00',
'InvoiceNumber': '758200602426', # re-added by fallback
}
result.confidence = {
'OCR': 0.9,
'Amount': 0.9,
'InvoiceNumber': 0.5, # fallback confidence
}
result.bboxes = {}
p._dedup_invoice_number(result)
assert 'InvoiceNumber' not in result.fields
assert 'OCR' in result.fields
def test_invoice_number_substring_of_bankgiro(self):
"""When InvoiceNumber digits are a substring of Bankgiro digits, drop InvoiceNumber."""
p = self._make_pipeline_for_merge()
result = InferenceResult()
result.extracted_fields = [
self._make_extracted_field('InvoiceNumber', '4639', '4639'),
self._make_extracted_field('Bankgiro', '134-4639', '134-4639'),
self._make_extracted_field('Amount', '500,00', '500.00'),
]
p._merge_fields(result)
assert 'Bankgiro' in result.fields
assert result.fields['Bankgiro'] == '134-4639'
assert 'InvoiceNumber' not in result.fields
def test_invoice_number_not_substring_of_unrelated_bankgiro(self):
"""When InvoiceNumber is NOT a substring of Bankgiro, keep both."""
p = self._make_pipeline_for_merge()
result = InferenceResult()
result.extracted_fields = [
self._make_extracted_field('InvoiceNumber', '19061', '19061'),
self._make_extracted_field('Bankgiro', '5393-9484', '5393-9484'),
self._make_extracted_field('Amount', '500,00', '500.00'),
]
p._merge_fields(result)
assert result.fields['InvoiceNumber'] == '19061'
assert result.fields['Bankgiro'] == '5393-9484'
class TestFallbackTrigger:
"""Tests for _needs_fallback trigger threshold."""
def _make_pipeline(self):
with patch.object(InferencePipeline, '__init__', lambda self, **kw: None):
p = InferencePipeline()
return p
def test_fallback_triggers_when_1_key_field_missing(self):
"""Should trigger when only 1 key field (e.g. InvoiceNumber) is missing."""
p = self._make_pipeline()
result = InferenceResult()
result.fields = {
'Amount': '1234.56',
'OCR': '12345678901',
'InvoiceDate': '2025-01-15',
'InvoiceDueDate': '2025-02-15',
'supplier_organisation_number': '556123-4567',
}
# InvoiceNumber missing -> should trigger
assert p._needs_fallback(result) is True
def test_fallback_triggers_when_dates_missing(self):
"""Should trigger when all key fields present but 2+ important fields missing."""
p = self._make_pipeline()
result = InferenceResult()
result.fields = {
'Amount': '1234.56',
'InvoiceNumber': '12345',
'OCR': '12345678901',
}
# InvoiceDate, InvoiceDueDate, supplier_org all missing -> should trigger
assert p._needs_fallback(result) is True
def test_no_fallback_when_all_fields_present(self):
"""Should NOT trigger when all key and important fields present."""
p = self._make_pipeline()
result = InferenceResult()
result.fields = {
'Amount': '1234.56',
'InvoiceNumber': '12345',
'OCR': '12345678901',
'InvoiceDate': '2025-01-15',
'InvoiceDueDate': '2025-02-15',
'supplier_organisation_number': '556123-4567',
}
assert p._needs_fallback(result) is False
if __name__ == '__main__':
pytest.main([__file__, '-v'])

View File

@@ -335,12 +335,15 @@ class TestFallbackLogic:
with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
pipeline = InferencePipeline.__new__(InferencePipeline)
# All key fields present
# All key and important fields present
result = InferenceResult(
fields={
"Amount": "1500.00",
"InvoiceNumber": "INV-001",
"OCR": "12345678901234",
"InvoiceDate": "2025-01-15",
"InvoiceDueDate": "2025-02-15",
"supplier_organisation_number": "556123-4567",
}
)

View File

@@ -203,15 +203,33 @@ class TestValueSelectorOcrField:
assert len(result) == 1
assert result[0].text == "94228110015950070"
def test_ignores_short_digit_tokens(self):
"""Tokens with fewer than 5 digits are not OCR references."""
tokens = _tokens("OCR", "123")
def test_ignores_single_digit_tokens(self):
"""Tokens with fewer than 2 digits are not OCR references."""
tokens = _tokens("OCR", "5")
result = ValueSelector.select_value_tokens(tokens, "OCR")
# Fallback: return all tokens since no valid OCR found
assert len(result) == 2
def test_ocr_4_digit_token_selected(self):
"""4-digit OCR token should be selected."""
tokens = _tokens("OCR", "3046")
result = ValueSelector.select_value_tokens(tokens, "OCR")
assert len(result) == 1
assert result[0].text == "3046"
def test_ocr_2_digit_token_selected(self):
"""2-digit OCR token should be selected."""
tokens = _tokens("OCR", "42")
result = ValueSelector.select_value_tokens(tokens, "OCR")
assert len(result) == 1
assert result[0].text == "42"
class TestValueSelectorInvoiceNumberField:
"""Tests for InvoiceNumber field value selection."""