# Invoice Master POC v2

Swedish Invoice Field Extraction System - YOLO + PaddleOCR extracts structured data from Swedish PDF invoices.
## Architecture

```
PDF → PyMuPDF (DPI=150) → YOLO Detection → PaddleOCR → Field Extraction → Normalization → Output
```
### Project Structure

```
packages/
├── backend/        # FastAPI web server + inference pipeline
│   └── pipeline/   # YOLO detector → OCR → field extractor → value selector → normalizers
├── shared/         # Common utilities (bbox, OCR, field mappings)
└── training/       # YOLO training data generation (annotation, dataset)
tests/              # Mirrors packages/ structure
```
### Pipeline Flow (process_pdf)

1. YOLO detects field regions on the rendered PDF page
2. PaddleOCR extracts text from detected bboxes
3. Field extractor maps detections to invoice fields via CLASS_TO_FIELD
4. Value selector picks the best candidate per field (confidence + validation)
5. Normalizers clean values (dates, amounts, invoice numbers)
6. Fallback regex extraction runs if key fields are missing
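The six steps above can be sketched as one orchestration function. Everything here is illustrative - the real `process_pdf` lives in `packages/backend/backend/pipeline/pipeline.py`, and its actual signatures and helper names will differ:

```python
# Hypothetical sketch of the process_pdf flow; all names are illustrative.
def process_pdf(render, detect, ocr, class_to_field, select, normalize, fallback):
    """Run the pipeline stages in order and return extracted fields."""
    page = render()                                   # 1. render PDF page (DPI=150)
    detections = detect(page)                         # 2. YOLO field regions
    candidates = {}
    for det in detections:
        field = class_to_field.get(det["cls"])        # 3. class → invoice field
        if field is None:
            continue                                  # unknown class: skip
        text = ocr(page, det["bbox"])                 # OCR the detected bbox
        candidates.setdefault(field, []).append((text, det["conf"]))
    fields = {f: select(c) for f, c in candidates.items()}    # 4. best candidate
    fields = {f: normalize(f, v) for f, v in fields.items()}  # 5. normalize values
    fields.update(fallback(page, fields))             # 6. regex fallback for missing
    return fields
```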
## Tech Stack

| Component | Technology |
|-----------|------------|
| Object Detection | YOLO (Ultralytics >= 8.4.0) |
| OCR | PaddleOCR v5 (PP-OCRv5) |
| PDF | PyMuPDF (fitz), DPI=150 |
| Database | PostgreSQL + psycopg2 |
| Web | FastAPI + Uvicorn |
| ML | PyTorch + CUDA 12.x |
## WSL Environment (REQUIRED)

ALL Python commands MUST use this prefix:

```bash
wsl bash -c "source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-sm120 && <command>"
```

NEVER run Python directly in Windows PowerShell/CMD.
## Project Rules

- Python 3.10, type hints on all function signatures
- No `print()` in production code - use `logging` module
- Validation with `pydantic` or `dataclasses`
- Error handling with `try/except` (not try/catch)
- Run tests: `pytest --cov=packages tests/`
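The "no `print()`, use `logging`" rule in practice looks like this (a minimal sketch; the function and field names are illustrative, not the project's API):

```python
import logging

# Module-level logger named after the module path (standard logging idiom).
logger = logging.getLogger(__name__)

def extract_field(name: str, raw: str) -> str:
    """Trim a raw OCR value and log it at DEBUG level instead of printing."""
    value = raw.strip()
    # logging, not print(): levels can be filtered and routed in production.
    logger.debug("extracted %s=%r", name, value)
    return value
```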
## Critical Rules

### Code Organization

- Many small files over few large files
- High cohesion, low coupling
- 200-400 lines typical, 800 max per file
- Organize by feature/domain, not by type

### Code Style

- No emojis in code, comments, or documentation
- Immutability always - never mutate objects in place
- No `print()` in production code - use `logging`
- Proper error handling with `try/except`
- Input validation with `pydantic` or similar

### Testing

- TDD: Write tests first
- 80% minimum coverage
- Unit tests for utilities
- Integration tests for APIs
- E2E tests for critical flows

### Security

- No hardcoded secrets
- Environment variables for sensitive data
- Validate all user inputs
- Parameterized queries only
- CSRF protection enabled

## Key Files
| File | Purpose |
|------|---------|
| `packages/backend/backend/pipeline/pipeline.py` | Main inference pipeline |
| `packages/backend/backend/pipeline/field_extractor.py` | YOLO → field mapping |
| `packages/backend/backend/pipeline/value_selector.py` | Best candidate selection |
| `packages/shared/shared/fields/mappings.py` | CLASS_TO_FIELD mapping |
| `packages/shared/shared/ocr/paddle_ocr.py` | OCRToken definition |
| `packages/shared/shared/bbox/` | Bbox expansion strategies |
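The "best candidate selection" job of `value_selector.py` can be sketched as confidence-ranked selection gated by a validator (this is a hypothetical sketch, not the file's actual logic):

```python
def select_best(candidates, validate):
    """Pick the highest-confidence candidate whose text passes validation.

    candidates: list of (text, confidence) pairs; validate: callable -> bool.
    Falls back to the top-confidence candidate if none validate.
    """
    ranked = sorted(candidates, key=lambda c: c[1], reverse=True)
    for text, conf in ranked:
        if validate(text):
            return text
    return ranked[0][0] if ranked else None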
## Environment Variables

```
CONFIDENCE_THRESHOLD=0.5
SERVER_HOST=0.0.0.0
SERVER_PORT=8000
```
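These variables would typically be read with `os.environ` plus typed defaults matching the values shown. A sketch (the project's actual config loader may differ):

```python
import os

def load_server_config() -> dict:
    """Read the variables above, falling back to the documented defaults."""
    return {
        "confidence_threshold": float(os.environ.get("CONFIDENCE_THRESHOLD", "0.5")),
        "server_host": os.environ.get("SERVER_HOST", "0.0.0.0"),
        "server_port": int(os.environ.get("SERVER_PORT", "8000")),
    }
```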
## Available Commands

- `/tdd` - Test-driven development workflow
- `/plan` - Create implementation plan
- `/code-review` - Review code quality
- `/build-fix` - Fix build errors
## Git Workflow

- Conventional commits: `feat:`, `fix:`, `refactor:`, `docs:`, `test:`
- Never commit to main directly
- PRs require review
- All tests must pass before merge
- Push the code only after review and fixes are finished

## Auto-trigger Rules (ALWAYS FOLLOW - even after context compaction)

These rules MUST be followed regardless of conversation history:

- New feature or bug fix → MUST use **tdd-guide** agent (write tests first)
- When writing code → MUST follow the coding standards skill for the target language:
  - Python → `python-patterns` (PEP 8, type hints, Pythonic idioms)
  - C# → `dotnet-skills:coding-standards` (records, pattern matching, modern C#)
  - TS/JS → `coding-standards` (universal best practices)
- After writing/modifying code → MUST use **code-reviewer** agent
- Before git commit → MUST use **security-reviewer** agent
- When build/test fails → MUST use **build-error-resolver** agent
- After context compaction → read MEMORY.md to restore session state
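The conventional-commit rule above is easy to machine-check. A sketch of a validator for the listed types (the regex is illustrative; scope and the `!` breaking-change marker come from the wider convention):

```python
import re

# Types taken from the workflow rules above; optional scope and "!" marker.
COMMIT_RE = re.compile(r"^(feat|fix|refactor|docs|test)(\([\w-]+\))?!?: .+")

def is_conventional(message: str) -> bool:
    """Return True if the commit message follows the conventional format."""
    return COMMIT_RE.match(message) is not None
```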
## Plan Completion Protocol

After completing any plan or major task:

1. **Test** - Run `pytest` to confirm all tests pass
2. **Security review** - Use **security-reviewer** agent on changed files
3. **Fix loop** - If the security review reports CRITICAL or HIGH issues:
   - Fix the issues
   - Re-run tests (back to step 1)
   - Re-run the security review (back to step 2)
   - Repeat until no CRITICAL/HIGH issues remain
4. **Commit** - Auto-commit with a conventional commit message (`feat:`, `fix:`, `refactor:`, etc.). Stage only the files changed in this task, not unrelated files
5. **Save** - Write a summary to MEMORY.md including: what was done, files changed, decisions made, remaining work
6. **Suggest clear** - Tell the user: "Plan complete. Recommend `/clear` to free context for the next task."
7. **Do NOT start a new task** in the same context - wait for the user to `/clear` first

This keeps each plan in a fresh context window for maximum quality.
## Known Issues

- Pre-existing test failures: `test_s3.py`, `test_azure.py` (missing boto3/azure) - safe to ignore
- Always re-run dedup/validation after fallback adds new fields
- PDF DPI must be 150 (not 300) for correct bbox alignment
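The DPI constraint comes from coordinate scaling: PDF user space is 72 units per inch, so a render at DPI=150 scales every coordinate by 150/72, and detections made on that render only map back correctly if the same factor is assumed everywhere. A sketch of the arithmetic (the helper names are illustrative):

```python
def dpi_scale(dpi: float, base: float = 72.0) -> float:
    """PDF user space is 72 units/inch; rendering at `dpi` scales coords by dpi/72."""
    return dpi / base

def pixels_to_points(bbox, dpi: float = 150.0):
    """Map a pixel bbox detected on a DPI=`dpi` render back to PDF points."""
    s = dpi_scale(dpi)
    return tuple(v / s for v in bbox)
```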
# Python Coding Style

> This file extends [common/coding-style.md](../common/coding-style.md) with Python specific content.

## Standards

- Follow **PEP 8** conventions
- Use **type annotations** on all function signatures

## Immutability

Prefer immutable data structures:

```python
from dataclasses import dataclass
from typing import NamedTuple


@dataclass(frozen=True)
class User:
    name: str
    email: str


class Point(NamedTuple):
    x: float
    y: float
```
## Formatting

- **black** for code formatting
- **isort** for import sorting
- **ruff** for linting

## Reference

See skill: `python-patterns` for comprehensive Python idioms and patterns.
# Python Hooks

> This file extends [common/hooks.md](../common/hooks.md) with Python specific content.

## PostToolUse Hooks

Configure in `~/.claude/settings.json`:

- **black/ruff**: Auto-format `.py` files after edit
- **mypy/pyright**: Run type checking after editing `.py` files

## Warnings

- Warn about `print()` statements in edited files (use `logging` module instead)
# Python Patterns

> This file extends [common/patterns.md](../common/patterns.md) with Python specific content.

## Protocol (Duck Typing)

```python
from typing import Protocol


class Repository(Protocol):
    def find_by_id(self, id: str) -> dict | None: ...
    def save(self, entity: dict) -> dict: ...
```

## Dataclasses as DTOs

```python
from dataclasses import dataclass


@dataclass
class CreateUserRequest:
    name: str
    email: str
    age: int | None = None
```
## Context Managers & Generators

- Use context managers (`with` statement) for resource management
- Use generators for lazy evaluation and memory-efficient iteration
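A minimal example of both bullets together (the resource here is a plain dict standing in for a file handle or connection):

```python
from contextlib import contextmanager

@contextmanager
def managed_resource(name: str):
    """Acquire a resource, yield it, and always release it - even on error."""
    resource = {"name": name, "open": True}   # acquire
    try:
        yield resource
    finally:
        resource["open"] = False              # always released

def read_chunks(data: bytes, size: int):
    """Generator: lazily yield fixed-size chunks instead of slicing everything upfront."""
    for i in range(0, len(data), size):
        yield data[i:i + size]
```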
## Reference

See skill: `python-patterns` for comprehensive patterns including decorators, concurrency, and package organization.
# Python Security

> This file extends [common/security.md](../common/security.md) with Python specific content.

## Secret Management

```python
import os

from dotenv import load_dotenv

load_dotenv()

api_key = os.environ["OPENAI_API_KEY"]  # Raises KeyError if missing
```
## Security Scanning

- Use **bandit** for static security analysis:

```bash
bandit -r src/
```

## Reference

See skill: `django-security` for Django-specific security guidelines (if applicable).
# Python Testing

> This file extends [common/testing.md](../common/testing.md) with Python specific content.

## Framework

Use **pytest** as the testing framework.

## Coverage

```bash
pytest --cov=src --cov-report=term-missing
```

## Test Organization

Use `pytest.mark` for test categorization:

```python
import pytest


@pytest.mark.unit
def test_calculate_total():
    ...


@pytest.mark.integration
def test_database_connection():
    ...
```
## Reference

See skill: `python-testing` for detailed pytest patterns and fixtures.
# Invoice Master POC v2 - Overall Architecture Review Report

**Review date**: 2026-02-01
**Reviewer**: Claude Code
**Project path**: `/Users/yiukai/Documents/git/invoice-master-poc-v2`

---
## Architecture Overview

### Overall Architecture Diagram

```
┌─────────────────────────────────────────────────────────────────┐
│                        Frontend (React)                         │
│              Vite + TypeScript + TailwindCSS                    │
└─────────────────────────────┬───────────────────────────────────┘
                              │ HTTP/REST
┌─────────────────────────────▼───────────────────────────────────┐
│                   Inference Service (FastAPI)                   │
│  ┌──────────────┬──────────────┬──────────────┬──────────────┐  │
│  │  Public API  │   Admin API  │ Training API │   Batch API  │  │
│  └──────────────┴──────────────┴──────────────┴──────────────┘  │
│  ┌────────────────────────────────────────────────────────────┐ │
│  │                       Service Layer                        │ │
│  │ InferenceService │ AsyncProcessing │ BatchUpload │ Dataset │ │
│  └────────────────────────────────────────────────────────────┘ │
│  ┌────────────────────────────────────────────────────────────┐ │
│  │                        Data Layer                          │ │
│  │     AdminDB │ AsyncRequestDB │ SQLModel │ PostgreSQL       │ │
│  └────────────────────────────────────────────────────────────┘ │
│  ┌────────────────────────────────────────────────────────────┐ │
│  │                      Core Components                       │ │
│  │       RateLimiter │ Schedulers │ TaskQueues │ Auth         │ │
│  └────────────────────────────────────────────────────────────┘ │
└─────────────────────────────┬───────────────────────────────────┘
                              │ PostgreSQL
┌─────────────────────────────▼───────────────────────────────────┐
│                    Training Service (GPU)                       │
│  ┌────────────────────────────────────────────────────────────┐ │
│  │       CLI: train │ autolabel │ analyze │ validate          │ │
│  └────────────────────────────────────────────────────────────┘ │
│  ┌────────────────────────────────────────────────────────────┐ │
│  │         YOLO: db_dataset │ annotation_generator            │ │
│  └────────────────────────────────────────────────────────────┘ │
│  ┌────────────────────────────────────────────────────────────┐ │
│  │      Processing: CPU Pool │ GPU Pool │ Task Dispatcher     │ │
│  └────────────────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────────────┘
                              │
                    ┌─────────┴─────────┐
                    ▼                   ▼
            ┌──────────────┐    ┌──────────────┐
            │    Shared    │    │   Storage    │
            │  PDF │ OCR   │    │ Local/Azure/ │
            │  Normalize   │    │      S3      │
            └──────────────┘    └──────────────┘
```
### Tech Stack

| Layer | Technology | Assessment |
|------|------|------|
| **Frontend** | React + Vite + TypeScript + TailwindCSS | ✅ Modern stack |
| **API framework** | FastAPI | ✅ High performance, type safe |
| **Database** | PostgreSQL + SQLModel | ✅ Type-safe ORM |
| **Object detection** | YOLO26 (Ultralytics >= 8.4.0) | ✅ Industry standard |
| **OCR** | PaddleOCR v5 | ✅ Swedish language support |
| **Deployment** | Docker + Azure/AWS | ✅ Cloud native |

---
## Architecture Strengths

### 1. Monorepo Structure ✅

```
packages/
├── shared/      # Shared library - no external dependencies
├── training/    # Training service - depends on shared
└── inference/   # Inference service - depends on shared
```

**Advantages**:
- Clear package boundaries, no circular dependencies
- Independent deployment; training starts only when needed
- High code reuse
### 2. Layered Architecture ✅

```
API Routes (web/api/v1/)
    ↓
Service Layer (web/services/)
    ↓
Data Layer (data/)
    ↓
Database (PostgreSQL)
```

**Advantages**:
- Clear separation of responsibilities
- Easy to unit test
- Underlying implementations can be replaced
### 3. Dependency Injection ✅

```python
# FastAPI Depends is used appropriately
@router.post("/infer")
async def infer(
    file: UploadFile,
    db: AdminDB = Depends(get_admin_db),  # injected
    token: str = Depends(validate_admin_token),
):
```
### 4. Storage Abstraction Layer ✅

```python
# Unified interface supporting multiple backends
class StorageBackend(ABC):
    def upload(self, source: Path, destination: str) -> None: ...
    def download(self, source: str, destination: Path) -> None: ...
    def get_presigned_url(self, path: str) -> str: ...

# Implementations: LocalStorageBackend, AzureStorageBackend, S3StorageBackend
```
### 5. Dynamic Model Management ✅

```python
# Database-driven model switching
def get_active_model_path() -> Path | None:
    db = AdminDB()
    active_model = db.get_active_model_version()
    return active_model.model_path if active_model else None

inference_service = InferenceService(
    model_path_resolver=get_active_model_path,
)
```
### 6. Separate Task Queues ✅

```
# Different task types use different queues
- AsyncTaskQueue: async inference tasks
- BatchQueue: batch upload tasks
- TrainingScheduler: training task scheduling
- AutoLabelScheduler: auto-labeling scheduling
```

---
## Architecture Issues and Risks

### 1. Database Layer Has Too Many Responsibilities ⚠️ **Medium risk**

**Problem**: The `AdminDB` class is too large and violates the single-responsibility principle

```python
# packages/inference/inference/data/admin_db.py
class AdminDB:
    # Token management (5 methods)
    def is_valid_admin_token(self, token: str) -> bool: ...
    def create_admin_token(self, token: str, name: str): ...

    # Document management (8 methods)
    def create_document(self, ...): ...
    def get_document(self, doc_id: str): ...

    # Annotation management (6 methods)
    def create_annotation(self, ...): ...
    def get_annotations(self, doc_id: str): ...

    # Training tasks (7 methods)
    def create_training_task(self, ...): ...
    def update_training_task(self, ...): ...

    # Datasets (6 methods)
    def create_dataset(self, ...): ...
    def get_dataset(self, dataset_id: str): ...

    # Model versions (5 methods)
    def create_model_version(self, ...): ...
    def activate_model_version(self, ...): ...

    # Batch processing (4 methods)
    # Lock management (3 methods)
    # ... 50+ methods in total
```
**Impact**:
- The class is too large and hard to maintain
- Testing is difficult
- Changes in different domains affect each other

**Recommendation**: Split by domain using the Repository pattern

```python
# Suggested refactoring
class TokenRepository:
    def validate(self, token: str) -> bool: ...
    def create(self, token: Token) -> None: ...

class DocumentRepository:
    def find_by_id(self, doc_id: str) -> Document | None: ...
    def save(self, document: Document) -> None: ...

class TrainingRepository:
    def create_task(self, config: TrainingConfig) -> TrainingTask: ...
    def update_task_status(self, task_id: str, status: TaskStatus): ...

class ModelRepository:
    def get_active(self) -> ModelVersion | None: ...
    def activate(self, version_id: str) -> None: ...
```

---
### 2. Service Layer Mixes Business Logic with Technical Details ⚠️ **Medium risk**

**Problem**: `InferenceService` handles both business logic and technical implementation

```python
# packages/inference/inference/web/services/inference.py
class InferenceService:
    def process(self, image_bytes: bytes) -> ServiceResult:
        # 1. Technical detail: image decoding
        image = Image.open(io.BytesIO(image_bytes))

        # 2. Business logic: field extraction
        fields = self._extract_fields(image)

        # 3. Technical detail: model inference
        detections = self._model.predict(image)

        # 4. Business logic: result validation
        if not self._validate_fields(fields):
            raise ValidationError()
```
**Impact**:
- Business logic is hard to test
- Technical changes affect business code
- Technical implementations cannot be swapped

**Recommendation**: Introduce a domain layer and the adapter pattern

```python
# Domain layer - pure business logic
@dataclass
class InvoiceDocument:
    document_id: str
    pages: list[Page]

class InvoiceExtractor:
    """Pure business logic, no dependency on technical implementations"""
    def extract(self, document: InvoiceDocument) -> InvoiceFields:
        # Business rules only
        pass

# Adapter layer - technical implementations
class YoloFieldDetector:
    """YOLO technology adapter"""
    def __init__(self, model_path: Path):
        self._model = YOLO(model_path)

    def detect(self, image: np.ndarray) -> list[FieldRegion]:
        return self._model.predict(image)

class PaddleOcrEngine:
    """PaddleOCR technology adapter"""
    def __init__(self):
        self._ocr = PaddleOCR()

    def recognize(self, image: np.ndarray, region: BoundingBox) -> str:
        return self._ocr.ocr(image, region)

# Application service - coordinates domain and adapters
class InvoiceProcessingService:
    def __init__(
        self,
        extractor: InvoiceExtractor,
        detector: FieldDetector,
        ocr: OcrEngine,
    ):
        self._extractor = extractor
        self._detector = detector
        self._ocr = ocr
```

---
### 3. Scattered Scheduler Design ⚠️ **Medium risk**

**Problem**: Multiple independent schedulers lack unified coordination

```python
# Current design - 4 independent schedulers
# 1. TrainingScheduler (core/scheduler.py)
# 2. AutoLabelScheduler (core/autolabel_scheduler.py)
# 3. AsyncTaskQueue (workers/async_queue.py)
# 4. BatchQueue (workers/batch_queue.py)

# Started separately in app.py
start_scheduler()             # training scheduler
start_autolabel_scheduler()   # auto-label scheduler
init_batch_queue()            # batch queue
```

**Impact**:
- Risk of resource contention
- Hard to monitor and trace
- Task priorities are hard to manage
- Tasks are lost on restart

**Recommendation**: Use Celery + Redis as a unified task queue

```python
# Suggested refactoring
from celery import Celery

app = Celery('invoice_master')

@app.task(bind=True, max_retries=3)
def process_inference(self, document_id: str):
    """Async inference task"""
    try:
        service = get_inference_service()
        result = service.process(document_id)
        return result
    except Exception as exc:
        raise self.retry(exc=exc, countdown=60)

@app.task
def train_model(dataset_id: str, config: dict):
    """Training task"""
    training_service = get_training_service()
    return training_service.train(dataset_id, config)

@app.task
def auto_label_documents(document_ids: list[str]):
    """Batch auto-labeling"""
    for doc_id in document_ids:
        auto_label_document.delay(doc_id)

# Priority queues
app.conf.task_routes = {
    'tasks.process_inference': {'queue': 'high_priority'},
    'tasks.train_model': {'queue': 'gpu_queue'},
    'tasks.auto_label_documents': {'queue': 'low_priority'},
}
```

---
### 4. Scattered Configuration ⚠️ **Low risk**

**Problem**: Configuration is spread across multiple files

```python
# packages/shared/shared/config.py
DATABASE = {...}
PATHS = {...}
AUTOLABEL = {...}

# packages/inference/inference/web/config.py
@dataclass
class ModelConfig: ...
@dataclass
class ServerConfig: ...
@dataclass
class FileConfig: ...

# Environment variables
# .env file
```

**Impact**:
- Configuration is hard to trace
- Inconsistencies can creep in
- No configuration validation

**Recommendation**: Centralize with Pydantic Settings

```python
# config/settings.py
from pydantic import field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict

class DatabaseSettings(BaseSettings):
    model_config = SettingsConfigDict(env_prefix='DB_')

    host: str = 'localhost'
    port: int = 5432
    name: str = 'docmaster'
    user: str = 'docmaster'
    password: str  # no default - must be set

class StorageSettings(BaseSettings):
    model_config = SettingsConfigDict(env_prefix='STORAGE_')

    backend: str = 'local'
    base_path: str = '~/invoice-data'
    azure_connection_string: str | None = None
    s3_bucket: str | None = None

class Settings(BaseSettings):
    model_config = SettingsConfigDict(
        env_file='.env',
        env_file_encoding='utf-8',
    )

    database: DatabaseSettings = DatabaseSettings()
    storage: StorageSettings = StorageSettings()

    # Validation
    @field_validator('database')
    def validate_database(cls, v):
        if not v.password:
            raise ValueError('Database password is required')
        return v

# Global settings instance
settings = Settings()
```

---
### 5. In-Memory Queues Are a Single Point of Failure ⚠️ **Medium risk**

**Problem**: AsyncTaskQueue and BatchQueue are memory-based

```python
# workers/async_queue.py
class AsyncTaskQueue:
    def __init__(self):
        self._queue = Queue()  # in-memory queue
        self._workers = []

    def enqueue(self, task: AsyncTask) -> None:
        self._queue.put(task)  # stored only in memory
```

**Impact**:
- All pending tasks are lost on service restart
- No horizontal scaling
- Task persistence is difficult

**Recommendation**: Use a persistent Redis/RabbitMQ queue

---
### 6. No API Version Migration Strategy ❓ **Low risk**

**Problem**: There is an `/api/v1/` version, but no upgrade strategy

```
Current: /api/v1/admin/documents
Future:  /api/v2/admin/documents ?
```

**Recommendations**:
- Define an API version upgrade process
- Use header-based version negotiation
- Maintain version compatibility documentation

---
## Key Architecture Risk Matrix

| Risk | Probability | Impact | Risk level | Priority |
|--------|------|------|----------|--------|
| In-memory queues lose tasks | Medium | High | **High** | 🔴 P0 |
| AdminDB has too many responsibilities | High | Medium | **Medium** | 🟡 P1 |
| Mixed service layer | High | Medium | **Medium** | 🟡 P1 |
| Scheduler resource contention | Medium | Medium | **Medium** | 🟡 P1 |
| Scattered configuration | High | Low | **Low** | 🟢 P2 |
| API versioning strategy | Low | Low | **Low** | 🟢 P2 |

---
## Improvement Roadmap

### Phase 1: Immediate (this week)

#### 1.1 Split AdminDB
```
# Create a repositories package
inference/data/repositories/
├── __init__.py
├── base.py          # Repository base class
├── token.py         # TokenRepository
├── document.py      # DocumentRepository
├── annotation.py    # AnnotationRepository
├── training.py      # TrainingRepository
├── dataset.py       # DatasetRepository
└── model.py         # ModelRepository
```

#### 1.2 Unify configuration
```
# Create a unified configuration module
inference/config/
├── __init__.py
├── settings.py      # Pydantic Settings
└── validators.py    # configuration validation
```
### Phase 2: Short term (this month)

#### 2.1 Introduce a message queue
```yaml
# Add to docker-compose.yml
services:
  redis:
    image: redis:7-alpine
    ports:
      - "6379:6379"

  celery_worker:
    build: .
    command: celery -A inference.tasks worker -l info
    depends_on:
      - redis
      - postgres
```

#### 2.2 Add a cache layer
```python
# Cache hot data in Redis
from redis import Redis

redis_client = Redis(host='localhost', port=6379)

class CachedDocumentRepository(DocumentRepository):
    def find_by_id(self, doc_id: str) -> Document | None:
        # Check the cache first
        cached = redis_client.get(f"doc:{doc_id}")
        if cached:
            return Document.parse_raw(cached)

        # Then the database
        doc = super().find_by_id(doc_id)
        if doc:
            redis_client.setex(f"doc:{doc_id}", 3600, doc.json())
        return doc
```
### Phase 3: Long term (this quarter)

#### 3.1 Database read/write splitting
```python
# Configure primary/replica databases
class DatabaseManager:
    def __init__(self):
        self._master = create_engine(MASTER_DB_URL)
        self._replica = create_engine(REPLICA_DB_URL)

    def get_session(self, readonly: bool = False) -> Session:
        engine = self._replica if readonly else self._master
        return Session(engine)
```

#### 3.2 Event-driven architecture
```python
# Introduce an event bus
from event_bus import EventBus

bus = EventBus()

# Publish events
@router.post("/documents")
async def create_document(...):
    doc = document_repo.save(document)
    bus.publish('document.created', {'document_id': doc.id})
    return doc

# Subscribe to events
@bus.subscribe('document.created')
def on_document_created(event):
    # Trigger auto-labeling
    auto_label_task.delay(event['document_id'])
```

---
## Architecture Evolution

### Current architecture (suits 1-10 users)

```
Single Instance
├── FastAPI App
├── Memory Queues
└── PostgreSQL
```

### Target architecture (suits 100+ users)

```
Load Balancer
├── FastAPI Instance 1
├── FastAPI Instance 2
└── FastAPI Instance N
          │
    ┌─────┴─────┐
    ▼           ▼
Redis Cluster     PostgreSQL
(Celery + Cache)  (Master + Replica)
```

---
## Summary

### Overall Score

| Dimension | Score | Notes |
|------|------|------|
| **Modularity** | 8/10 | Clear package structure, but some classes are too large |
| **Scalability** | 7/10 | Good horizontal scaling; vertical scaling is limited |
| **Maintainability** | 8/10 | Sensible layering, but responsibility boundaries need refinement |
| **Reliability** | 7/10 | In-memory queues are a single point of failure |
| **Performance** | 8/10 | Good async processing |
| **Security** | 8/10 | Baseline security in place |
| **Overall** | **7.7/10** | A solid architectural foundation; details need optimization |

### Key Conclusions

1. **Sound architecture**: Monorepo + layered architecture fits the current scale
2. **Main risks**: in-memory queues and an overloaded database class
3. **Evolution path**: introduce a message queue and a cache layer
4. **Cost/benefit**: the current architecture can support 100+ users without a large-scale rewrite
### Next Actions

| Priority | Task | Estimated effort | Impact |
|--------|------|----------|------|
| 🔴 P0 | Introduce Celery + Redis | 3 days | Fixes task loss |
| 🟡 P1 | Split AdminDB | 2 days | Improves maintainability |
| 🟡 P1 | Unify configuration management | 1 day | Fewer configuration errors |
| 🟢 P2 | Add a cache layer | 2 days | Better performance |
| 🟢 P2 | Database read/write splitting | 3 days | Better scalability |

---
## Appendix

### Key File List

| File | Responsibility | Issue |
|------|------|------|
| `inference/data/admin_db.py` | Database operations | Class too large; needs splitting |
| `inference/web/services/inference.py` | Inference service | Mixes business and technical code |
| `inference/web/workers/async_queue.py` | Async queue | In-memory storage; tasks easily lost |
| `inference/web/core/scheduler.py` | Task scheduling | No unified coordination |
| `shared/shared/config.py` | Shared configuration | Scattered management |
### References

- [Repository Pattern](https://martinfowler.com/eaaCatalog/repository.html)
- [Celery Documentation](https://docs.celeryproject.org/)
- [Pydantic Settings](https://docs.pydantic.dev/latest/concepts/pydantic_settings/)
- [FastAPI Best Practices](https://fastapi.tiangolo.com/tutorial/bigger-applications/)
# Changelog

All notable changes to the Invoice Field Extraction project will be documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]

### Added - Phase 1: Security & Infrastructure (2026-01-22)

#### Security Enhancements

- **Environment Variable Management**: Added `python-dotenv` for secure configuration management
  - Created `.env.example` template file for configuration reference
  - Created `.env` file for actual credentials (gitignored)
  - Updated `config.py` to load database password from environment variables
  - Added validation to ensure `DB_PASSWORD` is set at startup
  - Files modified: `config.py`, `requirements.txt`
  - New files: `.env`, `.env.example`
  - Tests: `tests/test_config.py` (7 tests, all passing)

- **SQL Injection Prevention**: Fixed SQL injection vulnerabilities in database queries
  - Replaced f-string formatting with parameterized queries in `LIMIT` clauses
  - Updated `get_all_documents_summary()` to use `%s` placeholder for LIMIT parameter
  - Updated `get_failed_matches()` to use `%s` placeholder for LIMIT parameter
  - Files modified: `src/data/db.py` (lines 246, 298)
  - Tests: `tests/test_db_security.py` (9 tests, all passing)
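The LIMIT fix above, illustrated with stdlib `sqlite3` (`?` placeholders) since psycopg2 (`%s` placeholders) may not be available everywhere; the principle is identical: the limit travels as a bound parameter, never interpolated into the SQL string. The table and column names here are illustrative:

```python
import sqlite3

def get_documents(conn: sqlite3.Connection, limit: int) -> list:
    # Parameterized LIMIT - never f"... LIMIT {limit}"
    cur = conn.execute(
        "SELECT id, name FROM documents ORDER BY id LIMIT ?", (limit,)
    )
    return cur.fetchall()
```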
#### Code Quality

- **Exception Hierarchy**: Created comprehensive custom exception system
  - Added base class `InvoiceExtractionError` with message and details support
  - Added specific exception types:
    - `PDFProcessingError` - PDF rendering/conversion errors
    - `OCRError` - OCR processing errors
    - `ModelInferenceError` - YOLO model errors
    - `FieldValidationError` - Field validation errors (with field-specific attributes)
    - `DatabaseError` - Database operation errors
    - `ConfigurationError` - Configuration errors
    - `PaymentLineParseError` - Payment line parsing errors
    - `CustomerNumberParseError` - Customer number parsing errors
    - `DataLoadError` - Data loading errors
    - `AnnotationError` - Annotation generation errors
  - New file: `src/exceptions.py`
  - Tests: `tests/test_exceptions.py` (16 tests, all passing)

### Testing

- Added 32 new tests across 3 test files
  - Configuration tests: 7 tests
  - SQL injection prevention tests: 9 tests
  - Exception hierarchy tests: 16 tests
- All tests passing (32/32)
### Documentation

- Created `docs/CODE_REVIEW_REPORT.md` - Comprehensive code quality analysis (550+ lines)
- Created `docs/REFACTORING_PLAN.md` - Detailed 3-phase refactoring plan (600+ lines)
- Created `CHANGELOG.md` - Project changelog (this file)

### Changed

- **Configuration Loading**: Database configuration now loads from environment variables instead of hardcoded values
  - Breaking change: Requires a `.env` file with `DB_PASSWORD` set
  - Migration: Copy `.env.example` to `.env` and set your database password

### Security

- **Fixed**: Database password no longer stored in plain text in `config.py`
- **Fixed**: SQL injection vulnerabilities in LIMIT clauses (2 instances)

### Technical Debt Addressed

- Eliminated security vulnerability: plaintext password storage
- Reduced SQL injection attack surface
- Improved error-handling granularity with custom exceptions

---
### Added - Phase 2: Parser Refactoring (2026-01-22)

#### Unified Parser Modules

- **Payment Line Parser**: Created a dedicated payment line parsing module
  - Handles the Swedish payment line format: `# <OCR> # <Kronor> <Öre> <Type> > <Account>#<Check>#`
  - Tolerates common OCR errors: spaces in numbers, missing symbols, spaces in check digits
  - Supports 4 parsing patterns: full format, no amount, alternative, account-only
  - Returns a structured `PaymentLineData` with parsed fields
  - New file: `src/inference/payment_line_parser.py` (90 lines, 92% coverage)
  - Tests: `tests/test_payment_line_parser.py` (23 tests, all passing)
  - Eliminates the first instance of code duplication (payment line parsing logic)

- **Customer Number Parser**: Created a dedicated customer number parsing module
  - Handles Swedish customer number formats: `JTY 576-3`, `DWQ 211-X`, `FFL 019N`, etc.
  - Uses the Strategy Pattern with 5 pattern classes:
    - `LabeledPattern` - Explicit labels (highest priority, 0.98 confidence)
    - `DashFormatPattern` - Standard format with dash (0.95 confidence)
    - `NoDashFormatPattern` - Format without dash, adds the dash automatically (0.90 confidence)
    - `CompactFormatPattern` - Compact format without spaces (0.75 confidence)
    - `GenericAlphanumericPattern` - Fallback generic pattern (variable confidence)
  - Excludes Swedish postal codes (`SE XXX XX` format)
  - Returns the highest-confidence match
  - New file: `src/inference/customer_number_parser.py` (154 lines, 92% coverage)
  - Tests: `tests/test_customer_number_parser.py` (32 tests, all passing)
  - Reduces `_normalize_customer_number` complexity (127 lines → 5-10 lines after integration)

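Two of the five strategies can be sketched as follows; only the class names and confidence scores come from the list above, while the regexes and the shared `match` protocol are assumptions:

```python
import re

class DashFormatPattern:
    """Standard dashed format, e.g. 'JTY 576-3' (0.95 confidence)."""
    confidence = 0.95
    regex = re.compile(r"\b([A-Z]{3}) ?(\d{3})-([0-9A-Z])\b")

    def match(self, text):
        m = self.regex.search(text)
        return (f"{m[1]} {m[2]}-{m[3]}", self.confidence) if m else None

class NoDashFormatPattern:
    """No-dash variant, e.g. 'JTY 576 3'; the dash is added (0.90 confidence)."""
    confidence = 0.90
    regex = re.compile(r"\b([A-Z]{3}) (\d{3}) ([0-9A-Z])\b")

    def match(self, text):
        m = self.regex.search(text)
        return (f"{m[1]} {m[2]}-{m[3]}", self.confidence) if m else None

PATTERNS = [DashFormatPattern(), NoDashFormatPattern()]

def parse_customer_number(text):
    """Run every strategy and keep the highest-confidence candidate."""
    hits = [h for p in PATTERNS if (h := p.match(text))]
    return max(hits, key=lambda h: h[1]) if hits else None
```

Adding a new format means adding one class to `PATTERNS`; the selection logic never changes, which is the point of the Strategy Pattern here.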
### Testing Summary

**Phase 1 Tests** (32 tests):
- Configuration tests: 7 tests ([test_config.py](tests/test_config.py))
- SQL injection prevention tests: 9 tests ([test_db_security.py](tests/test_db_security.py))
- Exception hierarchy tests: 16 tests ([test_exceptions.py](tests/test_exceptions.py))

**Phase 2 Tests** (121 tests):
- Payment line parser tests: 23 tests ([test_payment_line_parser.py](tests/test_payment_line_parser.py))
  - Standard parsing, OCR error handling, real-world examples, edge cases
  - Coverage: 92%
- Customer number parser tests: 32 tests ([test_customer_number_parser.py](tests/test_customer_number_parser.py))
  - Pattern matching (DashFormat, NoDashFormat, Compact, Labeled)
  - Real-world examples, edge cases, Swedish postal code exclusion
  - Coverage: 92%
- Field extractor integration tests: 45 tests ([test_field_extractor.py](src/inference/test_field_extractor.py))
  - Validates backward compatibility with existing code
  - Tests for invoice numbers, bankgiro, plusgiro, amounts, OCR, dates, payment lines, customer numbers
- Pipeline integration tests: 21 tests ([test_pipeline.py](src/inference/test_pipeline.py))
  - Cross-validation, payment line parsing, field overrides

**Total**: 153 tests, 100% passing, 4.50s runtime

### Code Quality

- **Eliminated Code Duplication**: Payment line parsing previously lived in 3 places, now unified in 1 module
- **Improved Maintainability**: The Strategy Pattern makes customer number patterns easy to extend
- **Better Test Coverage**: New parsers have 92% coverage vs the original 10% in field_extractor.py

#### Parser Integration into field_extractor.py (2026-01-22)

- **field_extractor.py Integration**: Successfully integrated the new parsers
  - Added `PaymentLineParser` and `CustomerNumberParser` instances (lines 99-101)
  - Replaced `_normalize_payment_line` method: 74 lines → 3 lines (lines 640-657)
  - Replaced `_normalize_customer_number` method: 127 lines → 3 lines (lines 697-707)
  - All 45 existing tests pass (100% backward compatibility maintained)
  - Test runtime: 4.21 seconds
  - File: `src/inference/field_extractor.py`

#### Parser Integration into pipeline.py (2026-01-22)

- **pipeline.py Integration**: Successfully integrated PaymentLineParser
  - Added `PaymentLineParser` import (line 15)
  - Added `payment_line_parser` instance initialization (line 128)
  - Replaced `_parse_machine_readable_payment_line` method: 36 lines → 6 lines (lines 219-233)
  - All 21 existing tests pass (100% backward compatibility maintained)
  - Test runtime: 4.00 seconds
  - File: `src/inference/pipeline.py`

### Phase 2 Status: **COMPLETED** ✅

- [x] Create unified `payment_line_parser` module
- [x] Create unified `customer_number_parser` module
- [x] Refactor `field_extractor.py` to use the new parsers
- [x] Refactor `pipeline.py` to use the new parsers
- [x] Comprehensive test suite (153 tests, 100% passing)

### Achieved Impact

- Eliminated code duplication: 3 implementations → 1 (payment line parsing unified across field_extractor.py, pipeline.py, tests)
- Reduced `_normalize_payment_line` complexity in field_extractor.py: 74 lines → 3 lines
- Reduced `_normalize_customer_number` complexity in field_extractor.py: 127 lines → 3 lines
- Reduced `_parse_machine_readable_payment_line` complexity in pipeline.py: 36 lines → 6 lines
- Total duplicated code eliminated: 237 lines reduced to 12 lines (95% reduction)
- Improved test coverage: new parser modules have 92% coverage (vs the original 10% in field_extractor.py)
- Simplified maintenance: the pattern-based approach makes extension easy
- 100% backward compatibility: all 66 existing tests pass (45 field_extractor + 21 pipeline)

---
## Phase 3: Performance & Documentation (2026-01-22)

### Added

#### Configuration Constants Extraction

- **Created `src/inference/constants.py`**: Centralized configuration constants
  - Detection & model configuration (confidence thresholds, IOU)
  - Image processing configuration (DPI, scaling factors)
  - Customer number parser confidence scores
  - Field extraction confidence multipliers
  - Account type detection thresholds
  - Pattern matching constants
  - 90 lines of well-documented constants with usage notes
  - Eliminates ~15 hardcoded magic numbers across the codebase
  - File: [src/inference/constants.py](src/inference/constants.py)

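A constants module of this kind might look like the sketch below; the DPI and the parser confidence scores are taken from earlier sections, while the constant names and the YOLO threshold values are assumptions:

```python
# Sketch of a centralized constants module (names are illustrative).

# Detection & model configuration (threshold values are assumed defaults)
DETECTION_CONFIDENCE_THRESHOLD = 0.25  # discard YOLO boxes below this score
DETECTION_IOU_THRESHOLD = 0.45         # NMS overlap threshold

# Image processing configuration
PDF_RENDER_DPI = 150                   # matches the pipeline's PyMuPDF DPI

# Customer number parser confidence scores (from the Phase 2 parser)
CONFIDENCE_LABELED = 0.98
CONFIDENCE_DASH_FORMAT = 0.95
CONFIDENCE_NO_DASH_FORMAT = 0.90
CONFIDENCE_COMPACT_FORMAT = 0.75
```

Call sites import these names instead of repeating literals, so a tuning change touches exactly one file.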
#### Performance Optimization Documentation

- **Created `docs/PERFORMANCE_OPTIMIZATION.md`**: Comprehensive performance guide (400+ lines)
  - **Batch Processing Optimization**: Parallel processing strategies, the already-implemented dual-pool system
  - **Database Query Optimization**: Connection pooling recommendations, index strategies
  - **Caching Strategies**: Model loading cache, parser reuse (already optimal), OCR result caching
  - **Memory Management**: Explicit cleanup, generator patterns, context managers
  - **Profiling Guidelines**: cProfile, memory_profiler, py-spy recommendations
  - **Benchmarking Scripts**: Ready-to-use performance measurement code
  - **Priority Roadmap**: High/Medium/Low priority optimizations with effort estimates
  - Expected impact: 2-5x throughput improvement for batch processing
  - File: [docs/PERFORMANCE_OPTIMIZATION.md](docs/PERFORMANCE_OPTIMIZATION.md)

### Phase 3 Status: **COMPLETED** ✅

- [x] Configuration constants extraction
- [x] Performance optimization analysis
- [x] Batch processing optimization recommendations
- [x] Database optimization strategies
- [x] Caching and memory management guidelines
- [x] Profiling and benchmarking documentation

### Deliverables

**New Files** (2 files):
1. `src/inference/constants.py` (90 lines) - Centralized configuration constants
2. `docs/PERFORMANCE_OPTIMIZATION.md` (400+ lines) - Performance optimization guide

**Impact**:
- Eliminates 15+ hardcoded magic numbers
- Provides a clear optimization roadmap
- Documents existing performance features
- Identifies quick wins (connection pooling, indexes)
- Outlines a long-term strategy (caching, profiling)

---
## Notes

### Breaking Changes

- **v2.x**: Requires a `.env` file with database credentials
  - Action required: Create a `.env` file based on `.env.example`
  - Affected: All deployments, CI/CD pipelines

### Migration Guide

#### From v1.x to v2.x (Environment Variables)

1. Copy `.env.example` to `.env`:
   ```bash
   cp .env.example .env
   ```

2. Edit `.env` and set your database password:
   ```
   DB_PASSWORD=your_actual_password_here
   ```

3. Install the new dependency:
   ```bash
   pip install python-dotenv
   ```

4. Verify the configuration loads correctly:
   ```bash
   python -c "import config; print('Config loaded successfully')"
   ```

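Step 4 works because `config.py` validates at import time. A minimal sketch of that validation (the real module calls `load_dotenv()` from python-dotenv first and raises its own `ConfigurationError`; `RuntimeError` here is a stand-in):

```python
import os

def require_db_password() -> str:
    # Fail fast at startup instead of on the first database call.
    password = os.getenv("DB_PASSWORD")
    if not password:
        raise RuntimeError(
            "DB_PASSWORD is not set. Copy .env.example to .env and set it."
        )
    return password

os.environ.pop("DB_PASSWORD", None)          # simulate a missing .env
try:
    require_db_password()
    missing_detected = False
except RuntimeError:
    missing_detected = True

os.environ["DB_PASSWORD"] = "example-only"   # simulate a populated .env
assert require_db_password() == "example-only"
```

Failing at import keeps a misconfigured deployment from serving requests that would only crash later.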
## Summary of All Work Completed

### Files Created (13 new files)

**Phase 1** (3 files):
1. `.env` - Environment variables for database credentials
2. `.env.example` - Template for environment configuration
3. `src/exceptions.py` - Custom exception hierarchy (35 lines, 66% coverage)

**Phase 2** (7 files):
4. `src/inference/payment_line_parser.py` - Unified payment line parsing (90 lines, 92% coverage)
5. `src/inference/customer_number_parser.py` - Unified customer number parsing (154 lines, 92% coverage)
6. `tests/test_config.py` - Configuration tests (7 tests)
7. `tests/test_db_security.py` - SQL injection prevention tests (9 tests)
8. `tests/test_exceptions.py` - Exception hierarchy tests (16 tests)
9. `tests/test_payment_line_parser.py` - Payment line parser tests (23 tests)
10. `tests/test_customer_number_parser.py` - Customer number parser tests (32 tests)

**Phase 3** (2 files):
11. `src/inference/constants.py` - Centralized configuration constants (90 lines)
12. `docs/PERFORMANCE_OPTIMIZATION.md` - Performance optimization guide (400+ lines)

**Documentation** (1 file):
13. `CHANGELOG.md` - This file (260+ lines of detailed documentation)

### Files Modified (5 files)

1. `config.py` - Added environment variable loading with python-dotenv
2. `src/data/db.py` - Fixed 2 SQL injection vulnerabilities (lines 246, 298)
3. `src/inference/field_extractor.py` - Integrated the new parsers (reduced 201 lines to 6 lines)
4. `src/inference/pipeline.py` - Integrated PaymentLineParser (reduced 36 lines to 6 lines)
5. `requirements.txt` - Added the python-dotenv dependency

### Test Summary

- **Total tests**: 153 tests across 7 test files
- **Passing**: 153 (100%)
- **Failing**: 0
- **Runtime**: 4.50 seconds
- **Coverage**:
  - New parser modules: 92%
  - Config module: 100%
  - Exception module: 66%
  - DB security coverage: 18% (focused on parameterized queries)

### Code Metrics

- **Lines eliminated**: 237 lines of duplicated/complex code → 12 lines (95% reduction)
  - field_extractor.py: 201 lines → 6 lines
  - pipeline.py: 36 lines → 6 lines
- **New code added**: 279 lines of well-tested parser code
- **Net impact**: Replaced 237 lines of duplicated code with 279 lines of unified, tested code (+42 lines, but 3 implementations → 1)
- **Test coverage improvement**: 0% → 92% for parser logic

### Performance Impact

- Configuration loading: Negligible (<1ms overhead for .env parsing)
- SQL queries: No performance change (parameterized queries are standard practice)
- Parser refactoring: No performance degradation (logic simplified, not changed)
- Exception handling: Minimal overhead (only when exceptions are raised)

### Security Improvements

- ✅ Eliminated plaintext password storage
- ✅ Fixed 2 SQL injection vulnerabilities
- ✅ Added input validation in the database layer

### Maintainability Improvements

- ✅ Eliminated code duplication (3 implementations → 1)
- ✅ Strategy Pattern enables easy extension of customer number formats
- ✅ Comprehensive test suite (153 tests) ensures safe refactoring
- ✅ 100% backward compatibility maintained
- ✅ Custom exception hierarchy for granular error handling

@@ -1,805 +0,0 @@
# Invoice Master POC v2 - Detailed Code Review Report

**Review date**: 2026-02-01
**Reviewer**: Claude Code
**Project path**: `C:\Users\yaoji\git\ColaCoder\invoice-master-poc-v2`
**Code statistics**:
- Python files: 200+
- Test files: 97
- TypeScript/React files: 39
- Total tests: 1,601
- Test coverage: 28%

---
|
||||
## 目录
|
||||
|
||||
1. [执行摘要](#执行摘要)
|
||||
2. [架构概览](#架构概览)
|
||||
3. [详细模块审查](#详细模块审查)
|
||||
4. [代码质量问题](#代码质量问题)
|
||||
5. [安全风险分析](#安全风险分析)
|
||||
6. [性能问题](#性能问题)
|
||||
7. [改进建议](#改进建议)
|
||||
8. [总结与评分](#总结与评分)
|
||||
|
||||
---
|
||||
|
||||
## Executive Summary

### Overall Assessment

| Dimension | Score | Status |
|------|------|------|
| **Code quality** | 7.5/10 | Good, with room for improvement |
| **Security** | 7/10 | Basics in place, needs hardening |
| **Maintainability** | 8/10 | Well modularized |
| **Test coverage** | 5/10 | Low, needs improvement |
| **Performance** | 8/10 | Good async processing |
| **Documentation** | 8/10 | Thorough |
| **Overall** | **7.3/10** | Production-ready after minor improvements |

### Key Findings

**Strengths:**
- Clear monorepo architecture with a sensible three-package split
- High type-annotation coverage (>90%)
- Well-designed storage abstraction layer
- Idiomatic FastAPI usage with a good dependency-injection pattern
- Solid exception handling with a clear custom exception hierarchy

**Risks:**
- Test coverage is only 28%, far below industry standards
- The AdminDB class is too large (50+ methods), violating the single-responsibility principle
- The in-memory queue is a single point of failure
- Some security details need hardening (timing attacks, file-upload validation)
- Frontend state management is simplistic and may be hard to scale

---
## Architecture Overview

### Project Structure

```
invoice-master-poc-v2/
├── packages/
│   ├── shared/          # Shared library (74 Python files)
│   │   ├── pdf/         # PDF processing
│   │   ├── ocr/         # OCR wrappers
│   │   ├── normalize/   # Field normalization
│   │   ├── matcher/     # Field matching
│   │   ├── storage/     # Storage abstraction layer
│   │   ├── training/    # Training components
│   │   └── augmentation/# Data augmentation
│   ├── training/        # Training service (26 Python files)
│   │   ├── cli/         # Command-line tools
│   │   ├── yolo/        # YOLO datasets
│   │   └── processing/  # Task processing
│   └── inference/       # Inference service (100 Python files)
│       ├── web/         # FastAPI application
│       ├── pipeline/    # Inference pipeline
│       ├── data/        # Data layer
│       └── cli/         # Command-line tools
├── frontend/            # React frontend (39 TS/TSX files)
│   └── src/
│       ├── components/  # UI components
│       ├── hooks/       # React Query hooks
│       └── api/         # API client
└── tests/               # Tests (97 Python files)
```

### Tech Stack

| Layer | Technology | Assessment |
|------|------|------|
| **Frontend** | React 18 + TypeScript + Vite + TailwindCSS | Modern stack, type-safe |
| **API framework** | FastAPI + Uvicorn | High performance, async support |
| **Database** | PostgreSQL + SQLModel | Type-safe ORM |
| **Object detection** | YOLO (Ultralytics >= 8.4.0) | Industry standard |
| **OCR** | PaddleOCR v5 | Swedish-language support |
| **Deployment** | Docker + Azure/AWS | Cloud-native |

---
## Detailed Module Review

### 1. Shared Package

#### 1.1 Configuration Module (`shared/config.py`)

**Location**: `packages/shared/shared/config.py`
**Lines of code**: 82

**Strengths:**
- Configuration is loaded from environment variables; no hardcoded secrets
- DPI configuration is managed centrally (DEFAULT_DPI = 150)
- The password has no default value, so it must be set explicitly

**Issues:**
```python
# Issue 1: configuration is scattered and unvalidated
DATABASE = {
    'host': os.getenv('DB_HOST', '192.168.68.31'),  # hardcoded IP fallback
    'port': int(os.getenv('DB_PORT', '5432')),
    # ...
}

# Issue 2: no type safety
# Recommendation: use Pydantic Settings
```

**Severity**: Medium
**Recommendation**: Manage configuration centrally with Pydantic Settings and add validation logic

---
#### 1.2 Storage Abstraction Layer (`shared/storage/`)

**Location**: `packages/shared/shared/storage/`
**Files**: 8

**Strengths:**
- Well-designed abstract interface `StorageBackend`
- Supports multiple backends: Local/Azure/S3
- Presigned-URL support
- Clear exception hierarchy

**Code example - good design:**
```python
class StorageBackend(ABC):
    @abstractmethod
    def upload(self, local_path: Path, remote_path: str, overwrite: bool = False) -> str:
        pass

    @abstractmethod
    def get_presigned_url(self, remote_path: str, expires_in_seconds: int = 3600) -> str:
        pass
```

**Issues:**
- The default implementations of `upload_bytes` and `download_bytes` go through temporary files, which is inefficient
- No file-type validation (magic-byte checks)

**Severity**: Low
**Recommendation**: Let subclasses override the bytes methods for efficiency, and add file-type validation

---
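The suggested override, as a sketch: a backend short-circuits the bytes methods so data never touches disk. The in-memory backend below is an illustrative stand-in for a real Azure/S3 subclass, not the project's code:

```python
class InMemoryBackend:
    """Stand-in subclass: overrides the bytes methods with direct transfers."""
    def __init__(self):
        self._blobs = {}

    def upload_bytes(self, data: bytes, remote_path: str) -> str:
        # No NamedTemporaryFile round-trip: the payload goes straight
        # from memory to the (simulated) remote store.
        self._blobs[remote_path] = data
        return remote_path

    def download_bytes(self, remote_path: str) -> bytes:
        return self._blobs[remote_path]

backend = InMemoryBackend()
backend.upload_bytes(b"%PDF-1.7 ...", "invoices/a.pdf")
```

A real cloud subclass would make the same move with its client's native bytes/stream upload call, keeping the temp-file default only as a fallback for backends without one.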
#### 1.3 Exception Definitions (`shared/exceptions.py`)

**Location**: `packages/shared/shared/exceptions.py`
**Lines of code**: 103

**Strengths:**
- Clear exception hierarchy
- All exceptions inherit from `InvoiceExtractionError`
- Rich error context included

**Code example:**
```python
class InvoiceExtractionError(Exception):
    def __init__(self, message: str, details: dict = None):
        super().__init__(message)
        self.message = message
        self.details = details or {}
```

**Score**: 9/10 - excellent design

---

#### 1.4 Data Augmentation (`shared/augmentation/`)

**Location**: `packages/shared/shared/augmentation/`
**Files**: 10

**Features:**
- 12 data augmentation strategies
- Perspective transforms, wrinkles, edge damage, stains, etc.
- Gaussian blur, motion blur, noise, etc.

**Code quality**: Good, modular design

---
### 2. Inference Package

#### 2.1 Authentication (`inference/web/core/auth.py`)

**Location**: `packages/inference/inference/web/core/auth.py`
**Lines of code**: 61

**Strengths:**
- Uses the FastAPI dependency-injection pattern
- Token expiry checks
- Records last-used time

**Security issue:**
```python
# Issue: timing-attack risk (line 46)
if not admin_db.is_valid_admin_token(x_admin_token):
    raise HTTPException(status_code=401, detail="Invalid or expired admin token.")

# Recommendation: use a constant-time comparison
import hmac
if not hmac.compare_digest(token, expected_token):
    raise HTTPException(status_code=401, ...)
```

**Severity**: Medium
**Recommendation**: Use `hmac.compare_digest()` for constant-time comparison

---
#### 2.2 Rate Limiter (`inference/web/core/rate_limiter.py`)

**Location**: `packages/inference/inference/web/core/rate_limiter.py`
**Lines of code**: 212

**Strengths:**
- Sliding-window algorithm implementation
- Thread-safe (uses a Lock)
- Supports concurrent-job limits
- Configurable rate-limit policies

**Code example - good design:**
```python
@dataclass(frozen=True)
class RateLimitConfig:
    requests_per_minute: int = 10
    max_concurrent_jobs: int = 3
    min_poll_interval_ms: int = 1000
```

**Issues:**
- In-memory storage: rate-limit state is lost on service restart
- State cannot be shared across distributed deployments

**Severity**: Medium
**Recommendation**: Use Redis for distributed rate limiting in production

---
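The sliding-window idea behind `RateLimitConfig` can be sketched independently of the project's code; the class below is an illustrative reimplementation, not the module's actual one:

```python
import time
from collections import deque
from threading import Lock

class SlidingWindowLimiter:
    """Illustrative sliding-window limiter in the spirit of RateLimitConfig."""
    def __init__(self, requests_per_minute=10):
        self.limit = requests_per_minute
        self.window = 60.0   # seconds
        self.hits = deque()  # timestamps of accepted requests
        self.lock = Lock()   # same thread-safety approach as the module

    def allow(self, now=None):
        now = time.monotonic() if now is None else now
        with self.lock:
            # Drop timestamps that slid out of the 60-second window.
            while self.hits and now - self.hits[0] >= self.window:
                self.hits.popleft()
            if len(self.hits) >= self.limit:
                return False
            self.hits.append(now)
            return True
```

A Redis-backed variant, the report's recommendation, would replace the deque with a per-client sorted set so all worker processes share one window.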
#### 2.3 AdminDB (`inference/data/admin_db.py`)

**Location**: `packages/inference/inference/data/admin_db.py`
**Lines of code**: 1300+

**Serious issue - the class is too large:**
```python
class AdminDB:
    # Token management (5 methods)
    # Document management (8 methods)
    # Annotation management (6 methods)
    # Training tasks (7 methods)
    # Datasets (6 methods)
    # Model versions (5 methods)
    # Batch processing (4 methods)
    # Lock management (3 methods)
    # ... 50+ methods in total
```

**Impact:**
- Violates the single-responsibility principle
- Hard to maintain
- Hard to test
- Changes in different domains interfere with each other

**Severity**: High
**Recommendation**: Split by domain into the Repository pattern

```python
# Suggested refactoring
class TokenRepository:
    def validate(self, token: str) -> bool: ...

class DocumentRepository:
    def find_by_id(self, doc_id: str) -> Document | None: ...

class TrainingRepository:
    def create_task(self, config: TrainingConfig) -> TrainingTask: ...
```

---
#### 2.4 Document Routes (`inference/web/api/v1/admin/documents.py`)

**Location**: `packages/inference/inference/web/api/v1/admin/documents.py`
**Lines of code**: 692

**Strengths:**
- Idiomatic FastAPI usage
- Thorough input validation
- Clear response-model definitions
- Good error handling

**Issues:**
```python
# Issue 1: file upload lacks magic-byte validation (lines 127-131)
content = await file.read()
# Recommendation: verify the PDF magic bytes %PDF

# Issue 2: path-traversal risk (lines 494-498)
filename = Path(document.file_path).name
# Recommendation: use Path.name and verify the resolved path stays in bounds

# Issue 3: functions are too long and do too much
# _convert_pdf_to_images mixes PDF processing with storage operations
```

**Severity**: Medium
**Recommendation**: Add file-type validation and split the large functions

---
#### 2.5 Inference Service (`inference/web/services/inference.py`)

**Location**: `packages/inference/inference/web/services/inference.py`
**Lines of code**: 361

**Strengths:**
- Supports dynamic model loading
- Lazy initialization
- Model hot-reload support

**Issues:**
```python
# Issue 1: business logic mixed with technical details
def process_image(self, image_path: Path, ...) -> ServiceResult:
    # 1. Technical detail: image decoding
    # 2. Business logic: field extraction
    # 3. Technical detail: model inference
    # 4. Business logic: result validation

# Issue 2: the visualization method reloads the model every time
model = YOLO(str(self.model_config.model_path))  # line 316
# It should be loaded at initialization to avoid repeated IO

# Issue 3: temporary files are not managed with a context manager
temp_path = results_dir / f"{doc_id}_temp.png"
# Recommendation: use a tempfile context manager
```

**Severity**: Medium
**Recommendation**: Introduce a domain layer and adapter pattern to separate business and technical logic

---
#### 2.6 Async Queue (`inference/web/workers/async_queue.py`)

**Location**: `packages/inference/inference/web/workers/async_queue.py`
**Lines of code**: 213

**Strengths:**
- Thread-safe implementation
- Graceful-shutdown support
- Task status tracking

**Serious issues:**
```python
# Issue: in-memory queue; tasks are lost on service restart (line 42)
self._queue: Queue[AsyncTask] = Queue(maxsize=max_size)

# Issue: cannot scale horizontally
# Issue: task persistence is difficult
```

**Severity**: High
**Recommendation**: Use a persistent Redis/RabbitMQ queue

---
### 3. Training Package

#### 3.1 Overall Assessment

**File count**: 26 Python files

**Strengths:**
- Well-designed CLI tools
- Excellent dual-pool coordinator (CPU + GPU)
- Rich set of data augmentation strategies

**Overall score**: 8/10

---
### 4. Frontend

#### 4.1 API Client (`frontend/src/api/client.ts`)

**Location**: `frontend/src/api/client.ts`
**Lines of code**: 42

**Strengths:**
- Clear Axios configuration
- Request/response interceptors
- Auth token added automatically

**Issues:**
```typescript
// Issue 1: the token is stored in localStorage, exposing it to XSS
const token = localStorage.getItem('admin_token')

// Issue 2: incomplete 401 handling
if (error.response?.status === 401) {
  console.warn('Authentication required...')
  // should trigger re-login or a token refresh
}
```

**Severity**: Medium
**Recommendation**: Consider storing the token in an http-only cookie, and complete the error handling

---
#### 4.2 Dashboard Component (`frontend/src/components/Dashboard.tsx`)

**Location**: `frontend/src/components/Dashboard.tsx`
**Lines of code**: 301

**Strengths:**
- Idiomatic React hooks usage
- Clear type definitions
- Responsive UI design

**Issues:**
```typescript
// Issue 1: hardcoded progress value
const getAutoLabelProgress = (doc: DocumentItem): number | undefined => {
  if (doc.auto_label_status === 'running') {
    return 45 // hardcoded!
  }
  // ...
}

// Issue 2: search is not implemented
// no onChange handler

// Issue 3: no error-boundary handling
// the component should be wrapped in an Error Boundary
```

**Severity**: Low
**Recommendation**: Fetch real progress and implement the search feature

---
#### 4.3 Overall Assessment

**Strengths:**
- TypeScript type safety
- React Query state management
- Consistent TailwindCSS styling

**Issues:**
- No error boundaries
- Some features are hardcoded
- No unit tests

**Overall score**: 7.5/10

---
### 5. Tests

#### 5.1 Test Statistics

- **Test files**: 97
- **Total tests**: 1,601
- **Test coverage**: 28%

#### 5.2 Coverage Analysis

| Module | Estimated coverage | Status |
|------|-----------|------|
| `shared/` | 35% | Low |
| `inference/web/` | 25% | Low |
| `inference/pipeline/` | 20% | Severely lacking |
| `training/` | 30% | Low |
| `frontend/` | 15% | Severely lacking |

#### 5.3 Test Quality Issues

**Strengths:**
- Uses the pytest framework
- Has a conftest.py configuration
- Some integration tests

**Issues:**
- Coverage is far below the industry standard (80%)
- No end-to-end tests
- Some tests may be too simplistic

**Severity**: High
**Recommendation**: Draw up a test plan and prioritize covering core business logic

---
## Code Quality Issues

### High Priority

| Issue | Location | Impact | Recommendation |
|------|------|------|------|
| AdminDB class too large | `inference/data/admin_db.py` | Hard to maintain | Split into the Repository pattern |
| In-memory queue single point of failure | `inference/web/workers/async_queue.py` | Task loss | Persist with Redis |
| Test coverage too low | whole project | Code risk | Raise to 60%+ |

### Medium Priority

| Issue | Location | Impact | Recommendation |
|------|------|------|------|
| Timing-attack risk | `inference/web/core/auth.py` | Security vulnerability | Use hmac.compare_digest |
| Rate limiter state in memory | `inference/web/core/rate_limiter.py` | Distributed-deployment problem | Use Redis |
| Scattered configuration | `shared/config.py` | Hard to manage | Use Pydantic Settings |
| Insufficient upload validation | `inference/web/api/v1/admin/documents.py` | Security risk | Add magic-byte validation |
| Inference service mixes responsibilities | `inference/web/services/inference.py` | Hard to test | Separate business and technical logic |

### Low Priority

| Issue | Location | Impact | Recommendation |
|------|------|------|------|
| Frontend search not implemented | `frontend/src/components/Dashboard.tsx` | Missing feature | Implement search |
| Hardcoded progress value | `frontend/src/components/Dashboard.tsx` | User experience | Fetch real progress |
| Token storage method | `frontend/src/api/client.ts` | XSS risk | Consider an http-only cookie |

---
## Security Risk Analysis

### Identified Risks

#### 1. Timing Attack (medium risk)

**Location**: `inference/web/core/auth.py:46`

```python
# current implementation (risky)
if not admin_db.is_valid_admin_token(x_admin_token):
    raise HTTPException(status_code=401, ...)

# safe implementation
import hmac
if not hmac.compare_digest(token, expected_token):
    raise HTTPException(status_code=401, ...)
```

#### 2. Insufficient File Upload Validation (medium risk)

**Location**: `inference/web/api/v1/admin/documents.py:127-131`

```python
# recommendation: add magic-byte validation
ALLOWED_EXTENSIONS = {".pdf"}
MAX_FILE_SIZE = 10 * 1024 * 1024

if not content.startswith(b"%PDF"):
    raise HTTPException(400, "Invalid PDF file format")
```

#### 3. Path Traversal Risk (medium risk)

**Location**: `inference/web/api/v1/admin/documents.py:494-498`

```python
# suggested implementation
from pathlib import Path

def get_safe_path(filename: str, base_dir: Path) -> Path:
    safe_name = Path(filename).name
    full_path = (base_dir / safe_name).resolve()
    if not full_path.is_relative_to(base_dir):
        raise HTTPException(400, "Invalid file path")
    return full_path
```

#### 4. CORS Configuration (low risk)

**Location**: FastAPI middleware configuration

```python
# recommended production configuration
ALLOWED_ORIGINS = [
    "http://localhost:5173",
    "https://your-domain.com",
]
```

#### 5. XSS Risk (low risk)

**Location**: `frontend/src/api/client.ts:13`

```typescript
// current implementation
const token = localStorage.getItem('admin_token')

// recommendation
// store sensitive tokens in an http-only cookie
```

### Security Scores

| Category | Score | Notes |
|------|------|------|
| Authentication | 8/10 | Good basics; harden against timing attacks |
| Input validation | 7/10 | Basic validation in place; strengthen file validation |
| Data protection | 8/10 | No hardcoded secrets |
| Transport security | 8/10 | HTTPS in production |
| Overall | 7.5/10 | Good baseline security; details need hardening |

---
## Performance Issues

### Identified Issues

#### 1. Repeated Model Loading

**Location**: `inference/web/services/inference.py:316`

```python
# issue: the model is reloaded for every visualization
model = YOLO(str(self.model_config.model_path))

# recommendation: reuse the already-loaded model
```

#### 2. Temporary File Handling

**Location**: `shared/storage/base.py:178-203`

```python
# issue: byte operations go through temporary files
def upload_bytes(self, data: bytes, ...):
    with tempfile.NamedTemporaryFile(delete=False) as f:
        f.write(data)
        temp_path = Path(f.name)
    # ...

# recommendation: subclasses should override with direct uploads
```

#### 3. Database Query Optimization

**Location**: `inference/data/admin_db.py`

```python
# issue: N+1 query risk
for doc in documents:
    annotations = db.get_annotations_for_document(str(doc.document_id))
    # ...

# recommendation: preload with a join
```

### Performance Scores

| Category | Score | Notes |
|------|------|------|
| Response time | 8/10 | Good async processing |
| Resource usage | 7/10 | Room for optimization |
| Scalability | 7/10 | Limited by the in-memory queue |
| Concurrency | 8/10 | Well-designed thread pools |
| Overall | 7.5/10 | Good, with room to optimize |

---
## Recommendations

### Immediate (this week)

1. **Split AdminDB**
   - Create a `repositories/` directory
   - Split by domain: TokenRepository, DocumentRepository, TrainingRepository
   - Estimated effort: 2 days

2. **Fix security vulnerabilities**
   - Add `hmac.compare_digest()` timing-attack protection
   - Add file magic-byte validation
   - Estimated effort: 0.5 days

3. **Raise test coverage**
   - Prioritize `inference/pipeline/`
   - Add API integration tests
   - Target: from 28% to 50%
   - Estimated effort: 3 days

### Short term (this month)

4. **Introduce a message queue**
   - Add a Redis service
   - Replace the in-memory queue with Celery
   - Estimated effort: 3 days

5. **Unify configuration management**
   - Use Pydantic Settings
   - Centralize validation logic
   - Estimated effort: 1 day

6. **Add a caching layer**
   - Cache hot data in Redis
   - Cache documents and model configuration
   - Estimated effort: 2 days

### Long term (this quarter)

7. **Database read/write splitting**
   - Configure primary/replica databases
   - Route reads to replicas
   - Estimated effort: 3 days

8. **Event-driven architecture**
   - Introduce an event bus
   - Decouple module dependencies
   - Estimated effort: 5 days

9. **Frontend improvements**
   - Add error boundaries
   - Implement real search
   - Add E2E tests
   - Estimated effort: 3 days

---
## 总结与评分
|
||||
|
||||
### 各维度评分
|
||||
|
||||
| 维度 | 评分 | 权重 | 加权得分 |
|
||||
|------|------|------|----------|
|
||||
| **代码质量** | 7.5/10 | 20% | 1.5 |
|
||||
| **安全性** | 7.5/10 | 20% | 1.5 |
|
||||
| **可维护性** | 8/10 | 15% | 1.2 |
|
||||
| **测试覆盖** | 5/10 | 15% | 0.75 |
|
||||
| **性能** | 7.5/10 | 15% | 1.125 |
|
||||
| **文档** | 8/10 | 10% | 0.8 |
|
||||
| **架构设计** | 8/10 | 5% | 0.4 |
|
||||
| **总体** | **7.3/10** | 100% | **7.275** |
|
||||
|
||||
### 关键结论
|
||||
|
||||
1. **架构设计优秀**: Monorepo + 三包分离架构清晰,便于维护和扩展
|
||||
2. **代码质量良好**: 类型注解完善,文档详尽,结构清晰
|
||||
3. **安全基础良好**: 没有严重的安全漏洞,基础防护到位
|
||||
4. **测试是短板**: 28%覆盖率是最大风险点
|
||||
5. **生产就绪**: 经过小幅改进后可以投入生产使用
|
||||
|
### Next Actions

| Priority | Task | Est. effort | Impact |
|----------|------|-------------|--------|
| High | Split AdminDB | 2 days | Better maintainability |
| High | Introduce a Redis queue | 3 days | Fixes task-loss problem |
| High | Raise test coverage | 5 days | Lower code risk |
| Medium | Fix security issues | 0.5 days | Better security |
| Medium | Unify configuration | 1 day | Fewer config errors |
| Low | Frontend improvements | 3 days | Better UX |

---

## Appendix

### Key Files

| File | Responsibility | Issue |
|------|----------------|-------|
| `inference/data/admin_db.py` | Database operations | Class too large; needs splitting |
| `inference/web/services/inference.py` | Inference service | Mixes business and technical concerns |
| `inference/web/workers/async_queue.py` | Async queue | In-memory storage; tasks easily lost |
| `inference/web/core/scheduler.py` | Task scheduling | Lacks unified coordination |
| `shared/shared/config.py` | Shared configuration | Scattered management |

### References

- [Repository Pattern](https://martinfowler.com/eaaCatalog/repository.html)
- [Celery Documentation](https://docs.celeryproject.org/)
- [Pydantic Settings](https://docs.pydantic.dev/latest/concepts/pydantic_settings/)
- [FastAPI Best Practices](https://fastapi.tiangolo.com/tutorial/bigger-applications/)
- [OWASP Top 10](https://owasp.org/www-project-top-ten/)

---

**Report generated**: 2026-02-01
**Review tools**: Claude Code + AST-grep + LSP
```diff
@@ -124,7 +124,7 @@ class AmountNormalizer(BaseNormalizer):
             if not match:
                 continue
             amount = self._parse_amount_str(match)
-            if amount is not None and amount > 0:
+            if amount is not None and 0 < amount < 10_000_000:
                 all_amounts.append(amount)

         # Return the last amount found (usually the total)
@@ -134,7 +134,7 @@ class AmountNormalizer(BaseNormalizer):
         # Fallback: try shared validator on cleaned text
         cleaned = TextCleaner.normalize_amount_text(text)
         amount = FieldValidators.parse_amount(cleaned)
-        if amount is not None and amount > 0:
+        if amount is not None and 0 < amount < 10_000_000:
             return NormalizationResult.success(f"{amount:.2f}")

         # Try to find any decimal number
@@ -144,7 +144,7 @@ class AmountNormalizer(BaseNormalizer):
             amount_str = matches[-1].replace(",", ".")
             try:
                 amount = float(amount_str)
-                if amount > 0:
+                if 0 < amount < 10_000_000:
                     return NormalizationResult.success(f"{amount:.2f}")
             except ValueError:
                 pass
@@ -156,7 +156,7 @@ class AmountNormalizer(BaseNormalizer):
         if match:
             try:
                 amount = float(match.group(1))
-                if amount > 0:
+                if 0 < amount < 10_000_000:
                     return NormalizationResult.success(f"{amount:.2f}")
             except ValueError:
                 pass
@@ -168,7 +168,7 @@ class AmountNormalizer(BaseNormalizer):
         # Take the last/largest number
         try:
             amount = float(matches[-1])
-            if amount > 0:
+            if 0 < amount < 10_000_000:
                 return NormalizationResult.success(f"{amount:.2f}")
         except ValueError:
             pass
```
```diff
@@ -62,14 +62,25 @@ class InvoiceNumberNormalizer(BaseNormalizer):
             # Skip if it looks like a date (YYYYMMDD)
             if len(seq) == 8 and seq.startswith("20"):
                 continue
             # Skip year-only values (2024, 2025, 2026, etc.)
             if len(seq) == 4 and seq.startswith("20"):
                 continue
             # Skip if too long (likely OCR number)
             if len(seq) > 10:
                 continue
             valid_sequences.append(seq)

         if valid_sequences:
-            # Return shortest valid sequence
-            return NormalizationResult.success(min(valid_sequences, key=len))
+            # Prefer 4-8 digit sequences (typical invoice numbers),
+            # then closest to 6 digits within that range.
+            # This avoids picking short fragments like "775" from amounts.
+            def _score(seq: str) -> tuple[int, int]:
+                length = len(seq)
+                if 4 <= length <= 8:
+                    return (1, -abs(length - 6))
+                return (0, -length)
+
+            return NormalizationResult.success(max(valid_sequences, key=_score))

         # Fallback: extract all digits if nothing else works
         digits = re.sub(r"\D", "", text)
```
```diff
@@ -14,7 +14,7 @@ class OcrNumberNormalizer(BaseNormalizer):
     Normalizes OCR (Optical Character Recognition) reference numbers.

     OCR numbers in Swedish payment systems:
-    - Minimum 5 digits
+    - Minimum 2 digits
     - Used for automated payment matching
     """

@@ -29,7 +29,7 @@ class OcrNumberNormalizer(BaseNormalizer):

         digits = re.sub(r"\D", "", text)

-        if len(digits) < 5:
+        if len(digits) < 2:
             return NormalizationResult.failure(
                 f"Too few digits for OCR: {len(digits)}"
             )
```
```diff
@@ -234,7 +234,7 @@ class InferencePipeline:
             confidence_threshold=confidence_threshold,
             device='cuda' if use_gpu else 'cpu'
         )
-        self.extractor = FieldExtractor(ocr_lang=ocr_lang, use_gpu=use_gpu)
+        self.extractor = FieldExtractor(ocr_lang=ocr_lang, use_gpu=use_gpu, dpi=dpi)
         self.payment_line_parser = PaymentLineParser()
         self.dpi = dpi
         self.enable_fallback = enable_fallback
@@ -361,6 +361,7 @@ class InferencePipeline:
         # Fallback if key fields are missing
         if self.enable_fallback and self._needs_fallback(result):
             self._run_fallback(pdf_path, result)
+            self._dedup_invoice_number(result)

         # Extract business invoice features if enabled
         if use_business_features:
@@ -477,9 +478,48 @@ class InferencePipeline:
             # Store bbox for each field (useful for payment_line and other fields)
             result.bboxes[field_name] = best.bbox

+        # Validate date consistency
+        self._validate_dates(result)
+
         # Perform cross-validation if payment_line is detected
         self._cross_validate_payment_line(result)

+        # Remove InvoiceNumber if it duplicates OCR or Bankgiro
+        self._dedup_invoice_number(result)
+
+    def _validate_dates(self, result: InferenceResult) -> None:
+        """Remove InvoiceDueDate if it is earlier than InvoiceDate."""
+        invoice_date = result.fields.get('InvoiceDate')
+        due_date = result.fields.get('InvoiceDueDate')
+        if invoice_date and due_date and due_date < invoice_date:
+            del result.fields['InvoiceDueDate']
+            result.confidence.pop('InvoiceDueDate', None)
+            result.bboxes.pop('InvoiceDueDate', None)
+
+    def _dedup_invoice_number(self, result: InferenceResult) -> None:
+        """Remove InvoiceNumber if it duplicates OCR or Bankgiro digits."""
+        inv_num = result.fields.get('InvoiceNumber')
+        if not inv_num:
+            return
+        inv_digits = re.sub(r'\D', '', str(inv_num))
+
+        # Check against OCR
+        ocr = result.fields.get('OCR')
+        if ocr and inv_digits == re.sub(r'\D', '', str(ocr)):
+            del result.fields['InvoiceNumber']
+            result.confidence.pop('InvoiceNumber', None)
+            result.bboxes.pop('InvoiceNumber', None)
+            return
+
+        # Check against Bankgiro (exact or substring match)
+        bg = result.fields.get('Bankgiro')
+        if bg:
+            bg_digits = re.sub(r'\D', '', str(bg))
+            if inv_digits == bg_digits or inv_digits in bg_digits:
+                del result.fields['InvoiceNumber']
+                result.confidence.pop('InvoiceNumber', None)
+                result.bboxes.pop('InvoiceNumber', None)
+
     def _parse_machine_readable_payment_line(self, payment_line: str) -> tuple[str | None, str | None, str | None]:
         """
         Parse machine-readable Swedish payment line format using unified PaymentLineParser.
@@ -638,10 +678,14 @@ class InferencePipeline:

     def _needs_fallback(self, result: InferenceResult) -> bool:
         """Check if fallback OCR is needed."""
         # Check for key fields
         key_fields = ['Amount', 'InvoiceNumber', 'OCR']
-        missing = sum(1 for f in key_fields if f not in result.fields)
-        return missing >= 2  # Fallback if 2+ key fields missing
+        important_fields = ['InvoiceDate', 'InvoiceDueDate', 'supplier_organisation_number']
+
+        key_missing = sum(1 for f in key_fields if f not in result.fields)
+        important_missing = sum(1 for f in important_fields if f not in result.fields)
+
+        # Fallback if any key field missing OR 2+ important fields missing
+        return key_missing >= 1 or important_missing >= 2

     def _run_fallback(self, pdf_path: str | Path, result: InferenceResult) -> None:
         """Run full-page OCR fallback."""
@@ -673,12 +717,13 @@ class InferencePipeline:
         """Extract fields using regex patterns (fallback)."""
         patterns = {
             'Amount': [
-                r'(?:att\s*betala|summa|total|belopp)\s*[:.]?\s*([\d\s,\.]+)\s*(?:SEK|kr)?',
-                r'([\d\s,\.]+)\s*(?:SEK|kr)\s*$',
+                r'(?:att\s+betala)\s*[:.]?\s*([\d\s\.]*\d+[,\.]\d{2})\s*(?:SEK|kr)?',
+                r'(?:summa|total|belopp)\s*[:.]?\s*([\d\s\.]*\d+[,\.]\d{2})\s*(?:SEK|kr)?',
+                r'([\d\s\.]*\d+[,\.]\d{2})\s*(?:SEK|kr)\s*$',
             ],
             'Bankgiro': [
                 r'(?:bankgiro|bg)\s*[:.]?\s*(\d{3,4}[-\s]?\d{4})',
-                r'(\d{4}[-\s]\d{4})\s*(?=\s|$)',
+                r'(?<!\d)(\d{3,4}[-\s]\d{4})(?!\d)',
             ],
             'OCR': [
                 r'(?:ocr|referens)\s*[:.]?\s*(\d{10,25})',
@@ -686,6 +731,20 @@ class InferencePipeline:
             'InvoiceNumber': [
                 r'(?:fakturanr|fakturanummer|invoice)\s*[:.]?\s*(\d+)',
             ],
+            'InvoiceDate': [
+                r'(?:fakturadatum|invoice\s*date)\s*[:.]?\s*(\d{4}[-/]\d{2}[-/]\d{2})',
+                r'(?:fakturadatum|invoice\s*date)\s*[:.]?\s*(\d{2}[-/]\d{2}[-/]\d{4})',
+            ],
+            'InvoiceDueDate': [
+                r'(?:f[oö]rfallo(?:dag|datum)?|due\s*date|betala\s*senast)\s*[:.]?\s*(\d{4}[-/]\d{2}[-/]\d{2})',
+                r'(?:f[oö]rfallo(?:dag|datum)?|due\s*date|betala\s*senast)\s*[:.]?\s*(\d{2}[-/]\d{2}[-/]\d{4})',
+            ],
+            'supplier_organisation_number': [
+                r'(?:org\.?\s*n[ru]|organisationsnummer)\s*[:.]?\s*(\d{6}[-\s]?\d{4})',
+            ],
+            'Plusgiro': [
+                r'(?:plusgiro|pg)\s*[:.]?\s*(\d[\d\s-]{4,12}\d)',
+            ],
         }

         for field_name, field_patterns in patterns.items():
@@ -708,6 +767,22 @@ class InferencePipeline:
                     digits = re.sub(r'\D', '', value)
                     if len(digits) == 8:
                         value = f"{digits[:4]}-{digits[4:]}"
+                elif field_name in ('InvoiceDate', 'InvoiceDueDate'):
+                    # Normalize DD/MM/YYYY to YYYY-MM-DD
+                    date_match = re.match(r'(\d{2})[-/](\d{2})[-/](\d{4})', value)
+                    if date_match:
+                        value = f"{date_match.group(3)}-{date_match.group(2)}-{date_match.group(1)}"
+                    # Replace / with -
+                    value = value.replace('/', '-')
+                elif field_name == 'InvoiceNumber':
+                    # Skip year-like values (2024, 2025, 2026, etc.)
+                    if re.match(r'^20\d{2}$', value):
+                        continue
+                elif field_name == 'supplier_organisation_number':
+                    # Ensure NNNNNN-NNNN format
+                    digits = re.sub(r'\D', '', value)
+                    if len(digits) == 10:
+                        value = f"{digits[:6]}-{digits[6:]}"

                 result.fields[field_name] = value
                 result.confidence[field_name] = 0.5  # Lower confidence for regex
```
```diff
@@ -123,12 +123,12 @@ class ValueSelector:

     @staticmethod
     def _select_ocr_number(tokens: list[OCRToken]) -> list[OCRToken]:
-        """Select token with the longest digit sequence (min 5 digits)."""
+        """Select token with the longest digit sequence (min 2 digits)."""
         best: OCRToken | None = None
         best_count = 0
         for token in tokens:
             digit_count = _count_digits(token.text)
-            if digit_count >= 5 and digit_count > best_count:
+            if digit_count >= 2 and digit_count > best_count:
                 best = token
                 best_count = digit_count
         return [best] if best else []
```
`scripts/analyze_v3.py` (new file, 103 lines):

```python
#!/usr/bin/env python3
"""Analyze batch inference v3 results (Round 2 fixes)."""

import json
from collections import Counter

with open("scripts/inference_results_v3.json") as f:
    results = json.load(f)

total = len(results)
success = sum(1 for r in results if r["status"] == 200)
print(f"Total PDFs: {total}, Successful: {success}")
print()

# Summary table
header = f"{'PDF':<40} {'Det':<4} {'Fld':<4} {'Time':<7} Fields"
print(header)
print("-" * 140)
for r in results:
    fn = r["filename"][:39]
    data = r.get("data", {})
    result_data = data.get("result", {})
    fields = result_data.get("fields", {})
    dets = len(result_data.get("detections", []))
    nfields = len(fields)
    t = r["time_seconds"]
    parts = []
    for k, v in fields.items():
        sv = str(v)
        if len(sv) > 30:
            sv = sv[:27] + "..."
        parts.append(f"{k}={sv}")
    field_str = ", ".join(parts)
    print(f"{fn:<40} {dets:<4} {nfields:<4} {t:<7} {field_str}")

print()

# Field coverage
field_counts: Counter = Counter()
conf_sums: Counter = Counter()
ok_count = 0
for r in results:
    if r["status"] != 200:
        continue
    ok_count += 1
    result_data = r["data"]["result"]
    for k in result_data.get("fields", {}):
        field_counts[k] += 1
    for k, v in (result_data.get("confidence") or {}).items():
        conf_sums[k] += v

print(f"Field Coverage ({ok_count} successful PDFs):")
hdr = f"{'Field':<35} {'Present':<10} {'Rate':<10} {'Avg Conf':<10}"
print(hdr)
print("-" * 65)
for field in [
    "InvoiceNumber", "InvoiceDate", "InvoiceDueDate", "OCR",
    "Amount", "Bankgiro", "Plusgiro",
    "supplier_organisation_number", "customer_number", "payment_line",
]:
    cnt = field_counts.get(field, 0)
    rate = cnt / ok_count * 100 if ok_count else 0
    avg_conf = conf_sums.get(field, 0) / cnt if cnt else 0
    flag = ""
    if rate < 30:
        flag = " <<<"
    elif rate < 60:
        flag = " !!"
    print(f"{field:<35} {cnt:<10} {rate:<10.1f} {avg_conf:<10.3f}{flag}")

# Fallback count
fb_count = 0
for r in results:
    if r["status"] == 200:
        result_data = r["data"]["result"]
        if result_data.get("fallback_used"):
            fb_count += 1
print(f"\nFallback used: {fb_count}/{ok_count}")

# Low-confidence fields
print("\nLow-confidence extractions (< 0.7):")
for r in results:
    if r["status"] != 200:
        continue
    result_data = r["data"]["result"]
    for k, v in (result_data.get("confidence") or {}).items():
        if v < 0.7:
            fv = result_data.get("fields", {}).get(k, "?")
            print(f"  [{v:.3f}] {k:<25} = {str(fv)[:40]:<40} ({r['filename'][:36]})")

# PDFs with very few fields (possible issues)
print("\nPDFs with <= 2 fields extracted:")
for r in results:
    if r["status"] != 200:
        continue
    result_data = r["data"]["result"]
    fields = result_data.get("fields", {})
    if len(fields) <= 2:
        print(f"  {r['filename']}: {len(fields)} fields - {list(fields.keys())}")

# Avg time
avg_time = sum(r["time_seconds"] for r in results) / len(results)
print(f"\nAverage processing time: {avg_time:.2f}s")
```
`scripts/batch_inference_v3.py` (new file, 92 lines):

```python
#!/usr/bin/env python3
"""Batch inference v3 - 30 random PDFs for Round 2 validation."""

import json
import os
import random
import time

import requests

PDF_DIR = "/mnt/c/Users/yaoji/git/Billo/Billo.Platform.Document/Billo.Platform.Document.AdminAPI/downloads/to_check"
API_URL = "http://localhost:8000/api/v1/infer"
OUTPUT_FILE = "/mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2/scripts/inference_results_v3.json"
SAMPLE_SIZE = 30


def main():
    random.seed(99_2026)  # New seed for Round 3

    all_pdfs = [f for f in os.listdir(PDF_DIR) if f.lower().endswith(".pdf")]
    selected = random.sample(all_pdfs, min(SAMPLE_SIZE, len(all_pdfs)))

    print(f"Selected {len(selected)} random PDFs for inference")

    results = []
    for i, filename in enumerate(selected, 1):
        filepath = os.path.join(PDF_DIR, filename)
        filesize = os.path.getsize(filepath)
        print(f"[{i}/{len(selected)}] Processing (unknown)...", end=" ", flush=True)

        start = time.time()
        try:
            with open(filepath, "rb") as f:
                resp = requests.post(
                    API_URL,
                    files={"file": (filename, f, "application/pdf")},
                    timeout=120,
                )
            elapsed = round(time.time() - start, 2)

            if resp.status_code == 200:
                data = resp.json()
                field_count = sum(
                    1 for k, v in data.items()
                    if k not in (
                        "DocumentId", "confidence", "success", "fallback_used",
                        "bboxes", "cross_validation", "processing_time_ms",
                        "line_items", "vat_summary", "vat_validation",
                        "raw_detections", "detection_classes", "detection_count",
                    )
                    and v is not None
                )
                det_count = data.get("detection_count", "?")
                print(f"OK ({elapsed}s) - {field_count} fields, {det_count} detections")
                results.append({
                    "filename": filename,
                    "status": resp.status_code,
                    "time_seconds": elapsed,
                    "filesize": filesize,
                    "data": data,
                })
            else:
                print(f"HTTP {resp.status_code} ({elapsed}s)")
                results.append({
                    "filename": filename,
                    "status": resp.status_code,
                    "time_seconds": elapsed,
                    "filesize": filesize,
                    "error": resp.text[:200],
                })
        except Exception as e:
            elapsed = round(time.time() - start, 2)
            print(f"FAILED ({elapsed}s) - {e}")
            results.append({
                "filename": filename,
                "status": -1,
                "time_seconds": elapsed,
                "filesize": filesize,
                "error": str(e),
            })

    with open(OUTPUT_FILE, "w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False, indent=2)
    print(f"\nResults saved to {OUTPUT_FILE}")

    success = sum(1 for r in results if r["status"] == 200)
    failed = len(results) - success
    print(f"Total: {len(results)}, Success: {success}, Failed: {failed}")


if __name__ == "__main__":
    main()
```
`scripts/inference_results_v3.json` (new file, 2986 lines): diff suppressed because it is too large.
Deleted file (387 lines), the PP-StructureV3 line items extraction POC:

```python
#!/usr/bin/env python3
"""
PP-StructureV3 Line Items Extraction POC

Tests line items extraction from Swedish invoices using PP-StructureV3.
Parses HTML table structure to extract structured line item data.

Run with invoice-sm120 conda environment.
"""

import sys
import re
from pathlib import Path
from html.parser import HTMLParser
from dataclasses import dataclass

# Add project root to path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root / "packages" / "backend"))

from paddleocr import PPStructureV3
import fitz  # PyMuPDF


@dataclass
class LineItem:
    """Single line item from invoice."""
    row_index: int
    article_number: str | None
    description: str | None
    quantity: str | None
    unit: str | None
    unit_price: str | None
    amount: str | None
    vat_rate: str | None
    confidence: float = 0.9


class TableHTMLParser(HTMLParser):
    """Parse HTML table into rows and cells."""

    def __init__(self):
        super().__init__()
        self.rows: list[list[str]] = []
        self.current_row: list[str] = []
        self.current_cell: str = ""
        self.in_td = False
        self.in_thead = False
        self.header_row: list[str] = []

    def handle_starttag(self, tag, attrs):
        if tag == "tr":
            self.current_row = []
        elif tag in ("td", "th"):
            self.in_td = True
            self.current_cell = ""
        elif tag == "thead":
            self.in_thead = True

    def handle_endtag(self, tag):
        if tag in ("td", "th"):
            self.in_td = False
            self.current_row.append(self.current_cell.strip())
        elif tag == "tr":
            if self.current_row:
                if self.in_thead:
                    self.header_row = self.current_row
                else:
                    self.rows.append(self.current_row)
        elif tag == "thead":
            self.in_thead = False

    def handle_data(self, data):
        if self.in_td:
            self.current_cell += data


# Swedish column name mappings
# Note: Some headers may contain multiple column names merged together
COLUMN_MAPPINGS = {
    'article_number': ['art nummer', 'artikelnummer', 'artikel', 'artnr', 'art.nr', 'art nr'],
    'description': ['beskrivning', 'produktbeskrivning', 'produkt', 'tjänst', 'text', 'benämning', 'vara/tjänst', 'vara'],
    'quantity': ['antal', 'qty', 'st', 'pcs', 'kvantitet'],
    'unit': ['enhet', 'unit'],
    'unit_price': ['á-pris', 'a-pris', 'pris', 'styckpris', 'enhetspris', 'à pris'],
    'amount': ['belopp', 'summa', 'total', 'netto', 'rad summa'],
    'vat_rate': ['moms', 'moms%', 'vat', 'skatt', 'moms %'],
}


def normalize_header(header: str) -> str:
    """Normalize header text for matching."""
    return header.lower().strip().replace(".", "").replace("-", " ")


def map_columns(headers: list[str]) -> dict[int, str]:
    """Map column indices to field names."""
    mapping = {}
    for idx, header in enumerate(headers):
        normalized = normalize_header(header)

        # Skip empty headers
        if not normalized.strip():
            continue

        best_match = None
        best_match_len = 0

        for field, patterns in COLUMN_MAPPINGS.items():
            for pattern in patterns:
                # Require exact match or pattern must be a significant portion
                if pattern == normalized:
                    # Exact match - use immediately
                    best_match = field
                    best_match_len = len(pattern) + 100  # Prioritize exact
                    break
                elif pattern in normalized and len(pattern) > best_match_len:
                    # Pattern found in header - use longer matches
                    if len(pattern) >= 3:  # Minimum pattern length
                        best_match = field
                        best_match_len = len(pattern)

            if best_match_len > 100:  # Was exact match
                break

        if best_match:
            mapping[idx] = best_match

    return mapping


def parse_table_html(html: str) -> tuple[list[str], list[list[str]]]:
    """Parse HTML table and return header and rows."""
    parser = TableHTMLParser()
    parser.feed(html)
    return parser.header_row, parser.rows


def detect_header_row(rows: list[list[str]]) -> tuple[int, list[str], bool]:
    """
    Detect which row is the header based on content patterns.

    Returns (header_row_index, header_row, is_at_end).
    is_at_end indicates if header is at the end (table is reversed).
    Returns (-1, [], False) if no header detected.
    """
    header_keywords = set()
    for patterns in COLUMN_MAPPINGS.values():
        for p in patterns:
            header_keywords.add(p.lower())

    best_match = (-1, [], 0)

    for i, row in enumerate(rows):
        # Skip empty rows
        if all(not cell.strip() for cell in row):
            continue

        # Check if row contains header keywords
        row_text = " ".join(cell.lower() for cell in row)
        matches = sum(1 for kw in header_keywords if kw in row_text)

        # Track the best match
        if matches > best_match[2]:
            best_match = (i, row, matches)

    if best_match[2] >= 2:
        header_idx = best_match[0]
        is_at_end = header_idx == len(rows) - 1 or header_idx > len(rows) // 2
        return header_idx, best_match[1], is_at_end

    return -1, [], False


def extract_line_items(html: str) -> list[LineItem]:
    """Extract line items from HTML table."""
    header, rows = parse_table_html(html)

    is_reversed = False
    if not header:
        # Try to detect header row from content
        header_idx, detected_header, is_at_end = detect_header_row(rows)
        if header_idx >= 0:
            header = detected_header
            if is_at_end:
                # Header is at the end - table is reversed
                is_reversed = True
                rows = rows[:header_idx]  # Data rows are before header
            else:
                rows = rows[header_idx + 1:]  # Data rows start after header
        elif rows:
            # Fall back to first non-empty row
            for i, row in enumerate(rows):
                if any(cell.strip() for cell in row):
                    header = row
                    rows = rows[i + 1:]
                    break

    column_map = map_columns(header)

    items = []
    for row_idx, row in enumerate(rows):
        item_data = {
            'row_index': row_idx,
            'article_number': None,
            'description': None,
            'quantity': None,
            'unit': None,
            'unit_price': None,
            'amount': None,
            'vat_rate': None,
        }

        for col_idx, cell in enumerate(row):
            if col_idx in column_map:
                field = column_map[col_idx]
                item_data[field] = cell if cell else None

        # Only add if we have at least description or amount
        if item_data['description'] or item_data['amount']:
            items.append(LineItem(**item_data))

    return items


def render_pdf_to_image(pdf_path: str, dpi: int = 200) -> bytes:
    """Render first page of PDF to image bytes."""
    doc = fitz.open(pdf_path)
    page = doc[0]
    mat = fitz.Matrix(dpi / 72, dpi / 72)
    pix = page.get_pixmap(matrix=mat)
    img_bytes = pix.tobytes("png")
    doc.close()
    return img_bytes


def test_line_items_extraction(pdf_path: str) -> dict:
    """Test line items extraction on a PDF."""
    print(f"\n{'='*70}")
    print(f"Testing Line Items Extraction: {Path(pdf_path).name}")
    print(f"{'='*70}")

    # Render PDF to image
    print("Rendering PDF to image...")
    img_bytes = render_pdf_to_image(pdf_path)

    # Save temp image
    temp_img_path = "/tmp/test_invoice.png"
    with open(temp_img_path, "wb") as f:
        f.write(img_bytes)

    # Initialize PP-StructureV3
    print("Initializing PP-StructureV3...")
    pipeline = PPStructureV3(
        device="gpu:0",
        use_doc_orientation_classify=False,
        use_doc_unwarping=False,
    )

    # Run detection
    print("Running table detection...")
    results = pipeline.predict(temp_img_path)

    all_line_items = []
    table_details = []

    for result in results if results else []:
        table_res_list = result.get("table_res_list") if hasattr(result, "get") else None

        if table_res_list:
            print(f"\nFound {len(table_res_list)} tables")

            for i, table_res in enumerate(table_res_list):
                html = table_res.get("pred_html", "")
                ocr_pred = table_res.get("table_ocr_pred", {})

                print(f"\n--- Table {i+1} ---")

                # Debug: show full HTML for first table
                if i == 0:
                    print(f"  Full HTML:\n{html}")

                # Debug: inspect table_ocr_pred structure
                if isinstance(ocr_pred, dict):
                    print(f"  table_ocr_pred keys: {list(ocr_pred.keys())}")
                    # Check if rec_texts exists (actual OCR text)
                    if "rec_texts" in ocr_pred:
                        texts = ocr_pred["rec_texts"]
                        print(f"  OCR texts count: {len(texts)}")
                        print(f"  Sample OCR texts: {texts[:5]}")
                elif isinstance(ocr_pred, list):
                    print(f"  table_ocr_pred is list with {len(ocr_pred)} items")
                    if ocr_pred:
                        print(f"  First item type: {type(ocr_pred[0])}")
                        print(f"  First few items: {ocr_pred[:3]}")

                # Parse HTML
                header, rows = parse_table_html(html)
                print(f"  HTML Header (from thead): {header}")
                print(f"  HTML Rows: {len(rows)}")

                # Try to detect header if not in thead
                detected_header = None
                is_reversed = False
                if not header and rows:
                    header_idx, detected_header, is_at_end = detect_header_row(rows)
                    if header_idx >= 0:
                        is_reversed = is_at_end
                        print(f"  Detected header at row {header_idx}: {detected_header}")
                        print(f"  Table is {'REVERSED (header at bottom)' if is_reversed else 'normal'}")
                        header = detected_header

                if rows:
                    print(f"  First row: {rows[0]}")
                    if len(rows) > 1:
                        print(f"  Second row: {rows[1]}")

                # Check if this looks like a line items table
                column_map = map_columns(header) if header else {}
                print(f"  Column mapping: {column_map}")

                is_line_items_table = (
                    'description' in column_map.values() or
                    'amount' in column_map.values() or
                    'article_number' in column_map.values()
                )

                if is_line_items_table:
                    print(f"  >>> This appears to be a LINE ITEMS table!")
                    items = extract_line_items(html)
                    print(f"  Extracted {len(items)} line items:")
                    for item in items:
                        print(f"    - {item.description}: {item.quantity} x {item.unit_price} = {item.amount}")
                    all_line_items.extend(items)
                else:
                    print(f"  >>> This is NOT a line items table (summary/payment)")

                table_details.append({
                    "index": i,
                    "header": header,
                    "row_count": len(rows),
                    "is_line_items": is_line_items_table,
                    "column_map": column_map,
                })

    print(f"\n{'='*70}")
    print(f"EXTRACTION SUMMARY")
    print(f"{'='*70}")
    print(f"Total tables: {len(table_details)}")
    print(f"Line items tables: {sum(1 for t in table_details if t['is_line_items'])}")
    print(f"Total line items: {len(all_line_items)}")

    return {
        "pdf": pdf_path,
        "tables": table_details,
        "line_items": all_line_items,
    }


def main():
    import argparse
    parser = argparse.ArgumentParser(description="Test line items extraction")
    parser.add_argument("--pdf", type=str, help="Path to PDF file")
    args = parser.parse_args()

    if args.pdf:
        # Test specific PDF
        pdf_path = Path(args.pdf)
        if not pdf_path.exists():
            # Try relative to project root
            pdf_path = project_root / args.pdf
            if not pdf_path.exists():
                print(f"PDF not found: {args.pdf}")
                return
        test_line_items_extraction(str(pdf_path))
    else:
        # Test default invoice
        default_pdf = project_root / "exampl" / "Faktura54011.pdf"
        if default_pdf.exists():
            test_line_items_extraction(str(default_pdf))
        else:
            print(f"Default PDF not found: {default_pdf}")
            print("Usage: python ppstructure_line_items_poc.py --pdf <path>")


if __name__ == "__main__":
    main()
```
||||
@@ -1,154 +0,0 @@
#!/usr/bin/env python3
"""
PP-StructureV3 POC Script

Tests table detection on real Swedish invoices using PP-StructureV3.
Run with invoice-sm120 conda environment.
"""

import sys
from pathlib import Path

# Add project root to path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root / "packages" / "backend"))

from paddleocr import PPStructureV3
import fitz  # PyMuPDF


def render_pdf_to_image(pdf_path: str, dpi: int = 200) -> bytes:
    """Render first page of PDF to image bytes."""
    doc = fitz.open(pdf_path)
    page = doc[0]
    mat = fitz.Matrix(dpi / 72, dpi / 72)
    pix = page.get_pixmap(matrix=mat)
    img_bytes = pix.tobytes("png")
    doc.close()
    return img_bytes


def test_table_detection(pdf_path: str) -> dict:
    """Test PP-StructureV3 table detection on a PDF."""
    print(f"\n{'='*60}")
    print(f"Testing: {Path(pdf_path).name}")
    print(f"{'='*60}")

    # Render PDF to image
    print("Rendering PDF to image...")
    img_bytes = render_pdf_to_image(pdf_path)

    # Save temp image
    temp_img_path = "/tmp/test_invoice.png"
    with open(temp_img_path, "wb") as f:
        f.write(img_bytes)
    print(f"Saved temp image: {temp_img_path}")

    # Initialize PP-StructureV3
    print("Initializing PP-StructureV3...")
    pipeline = PPStructureV3(
        device="gpu:0",
        use_doc_orientation_classify=False,
        use_doc_unwarping=False,
    )

    # Run detection
    print("Running table detection...")
    results = pipeline.predict(temp_img_path)

    # Parse results - PaddleX 3.x returns dict-like LayoutParsingResultV2
    tables_found = []
    all_elements = []

    for result in results if results else []:
        # Get table results from the new API
        table_res_list = result.get("table_res_list") if hasattr(result, "get") else None

        if table_res_list:
            print(f"  Found {len(table_res_list)} tables in table_res_list")
            for i, table_res in enumerate(table_res_list):
                # Debug: show all keys in table_res
                if isinstance(table_res, dict):
                    print(f"  Table {i+1} keys: {list(table_res.keys())}")
                else:
                    print(f"  Table {i+1} attrs: {[a for a in dir(table_res) if not a.startswith('_')]}")

                # Extract table info - use correct key names from PaddleX 3.x
                cell_boxes = table_res.get("cell_box_list", [])
                html = table_res.get("pred_html", "")  # HTML is in pred_html
                ocr_text = table_res.get("table_ocr_pred", [])
                region_id = table_res.get("table_region_id", -1)
                bbox = []  # bbox is stored elsewhere in parsing_res_list

                print(f"  Table {i+1}:")
                print(f"    - Cells: {len(cell_boxes) if cell_boxes is not None else 0}")
                print(f"    - Region ID: {region_id}")
                print(f"    - HTML length: {len(html) if html else 0}")
                print(f"    - OCR texts: {len(ocr_text) if ocr_text else 0}")

                if html:
                    print(f"    - HTML preview: {html[:300]}...")

                if ocr_text and len(ocr_text) > 0:
                    print(f"    - First few OCR texts: {ocr_text[:3]}")

                tables_found.append({
                    "index": i,
                    "cell_count": len(cell_boxes) if cell_boxes is not None else 0,
                    "region_id": region_id,
                    "html": html[:1000] if html else "",
                    "ocr_count": len(ocr_text) if ocr_text else 0,
                })

        # Get parsing results for all layout elements
        parsing_res_list = result.get("parsing_res_list") if hasattr(result, "get") else None

        if parsing_res_list:
            print(f"\n  Layout elements from parsing_res_list:")
            for elem in parsing_res_list[:10]:  # Show first 10
                label = elem.get("label", "unknown") if isinstance(elem, dict) else getattr(elem, "label", "unknown")
                bbox = elem.get("bbox", []) if isinstance(elem, dict) else getattr(elem, "bbox", [])
                print(f"    - {label}: {bbox}")
                all_elements.append({"label": label, "bbox": bbox})

    print(f"\nSummary:")
    print(f"  Tables detected: {len(tables_found)}")
    print(f"  Layout elements: {len(all_elements)}")

    return {"pdf": pdf_path, "tables": tables_found, "elements": all_elements}


def main():
    # Find test PDFs
    pdf_dir = Path("/mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2/data/admin_uploads")
    pdf_files = list(pdf_dir.glob("*.pdf"))[:5]  # Test first 5

    if not pdf_files:
        print("No PDF files found in admin_uploads directory")
        return

    print(f"Found {len(pdf_files)} PDF files")

    all_results = []
    for pdf_file in pdf_files:
        result = test_table_detection(str(pdf_file))
        all_results.append(result)

    # Summary
    print(f"\n{'='*60}")
    print("FINAL SUMMARY")
    print(f"{'='*60}")
    total_tables = sum(len(r["tables"]) for r in all_results)
    print(f"Total PDFs tested: {len(all_results)}")
    print(f"Total tables detected: {total_tables}")

    for r in all_results:
        pdf_name = Path(r["pdf"]).name
        table_count = len(r["tables"])
        print(f"  {pdf_name}: {table_count} tables")
        for t in r["tables"]:
            print(f"    - Table {t['index']+1}: {t['cell_count']} cells")


if __name__ == "__main__":
    main()
54
scripts/render_pdfs_v3.py
Normal file
@@ -0,0 +1,54 @@
#!/usr/bin/env python3
"""Render selected PDFs from v3 batch for visual comparison."""

import os

import fitz  # PyMuPDF

PDF_DIR = "/mnt/c/Users/yaoji/git/Billo/Billo.Platform.Document/Billo.Platform.Document.AdminAPI/downloads/to_check"
OUTPUT_DIR = "/mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2/scripts/pdf_renders_v3"

# Select 10 PDFs covering different scenarios:
SELECTED = [
    # Potentially wrong Amount (81648164.00 - too high?)
    "b84c7d70-821d-4a1a-9be7-d7bb2392bd91.pdf",
    # Only 2 fields extracted
    "072571e2-da5f-4268-b1a8-f0e5a85a3ec4.pdf",
    # InvoiceNumber=5085 (suspiciously short, same as BG prefix?)
    "6a83ba35-afdf-4c13-ade1-25513e213637.pdf",
    # InvoiceNumber=450 (very short, might be wrong)
    "8551b540-d93d-459d-b7eb-e9ee086f9f16.pdf",
    # InvoiceNumber=134 (very short, same as BG prefix)
    "cb1bd3b1-63d0-4140-930f-e4a7ae2b6cd5.pdf",
    # Large Amount=172904.52, InvoiceNumber=89902
    "d121a5ee-7382-41d8-8010-63880def1f96.pdf",
    # Good 9-field PDF for positive check
    "6cb90895-e52b-4831-b57b-7cb968bcdd54.pdf",
    # Amount=2026.00 (same as year - could be confused?)
    "d376c5b5-0dc5-4ccf-b787-0d481eef8577.pdf",
    # 8 fields, good coverage
    "f3f5da6f-7552-4ec6-8625-3629042fbfd0.pdf",
    # Low confidence Amount=596.49
    "5783e4af-eef3-411c-84b1-3a8f4694fed8.pdf",
]

os.makedirs(OUTPUT_DIR, exist_ok=True)

for pdf_name in SELECTED:
    pdf_path = os.path.join(PDF_DIR, pdf_name)
    if not os.path.exists(pdf_path):
        print(f"SKIP {pdf_name} - not found")
        continue

    doc = fitz.open(pdf_path)
    page = doc[0]
    mat = fitz.Matrix(150 / 72, 150 / 72)
    pix = page.get_pixmap(matrix=mat)

    out_name = pdf_name.replace(".pdf", ".png")
    out_path = os.path.join(OUTPUT_DIR, out_name)
    pix.save(out_path)
    print(f"Rendered {pdf_name} -> {out_name} ({pix.width}x{pix.height})")
    doc.close()

print(f"\nAll renders saved to {OUTPUT_DIR}")
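PDF user space is laid out at 72 points per inch, so the `fitz.Matrix(150 / 72, 150 / 72)` scale above renders at 150 DPI. A minimal sketch of the resulting pixel-size arithmetic (the 595 x 842 pt A4 page size is an illustrative assumption, not taken from the invoices themselves):

```python
# PDF user space is 72 points per inch; scaling by dpi/72 renders at `dpi`.
def render_size(page_width_pt: float, page_height_pt: float, dpi: int) -> tuple[int, int]:
    scale = dpi / 72
    return round(page_width_pt * scale), round(page_height_pt * scale)

# An A4 page (595 x 842 pt) rendered at 150 DPI:
print(render_size(595, 842, 150))  # -> (1240, 1754)
```

The same relation explains the 200 DPI default in the POC's `render_pdf_to_image` versus the pipeline's 150 DPI: a higher scale factor simply yields proportionally more pixels per page.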
@@ -213,8 +213,8 @@ class TestNormalizeOCR:
        assert ' ' not in result.value  # Spaces should be removed

    def test_short_ocr_invalid(self, normalizer):
        """Test that too short OCR is invalid."""
        result = normalizer.normalize("123")
        """Test that single-digit OCR is invalid (min 2 digits)."""
        result = normalizer.normalize("5")
        assert result.is_valid is False

@@ -100,6 +100,22 @@ class TestInvoiceNumberNormalizer:
        result = normalizer.normalize("Invoice 54321 OCR 12345678901234")
        assert result.value == "54321"

    def test_year_not_extracted_when_real_number_exists(self, normalizer):
        """4-digit year should be skipped when a real invoice number is present."""
        result = normalizer.normalize("Faktura 12345 Datum 2025")
        assert result.value == "12345"

    def test_year_2026_not_extracted(self, normalizer):
        """Year '2026' should not be preferred over a real invoice number."""
        result = normalizer.normalize("Invoice 54321 Date 2026")
        assert result.value == "54321"

    def test_non_year_4_digit_still_matches(self, normalizer):
        """4-digit numbers that are NOT years should still match."""
        result = normalizer.normalize("Invoice 3456")
        assert result.value == "3456"
        assert result.is_valid is True

    def test_fallback_extraction(self, normalizer):
        """Test fallback to digit extraction."""
        # This matches Pattern 3 (short digit sequence 3-10 digits)

@@ -107,6 +123,16 @@
        assert result.value == "123"
        assert result.is_valid is True

    def test_amount_fragment_not_selected(self, normalizer):
        """Amount fragment '775' from '9 775,96' should lose to real invoice number."""
        result = normalizer.normalize("9 775,96 Belopp Kontoutdragsnr 04862823")
        assert result.value == "04862823"

    def test_prefer_medium_length_over_shortest(self, normalizer):
        """Prefer 4-8 digit sequences over very short 3-digit ones."""
        result = normalizer.normalize("Ref 999 Fakturanr 12345")
        assert result.value == "12345"

    def test_no_valid_sequence(self, normalizer):
        """Test failure when no valid sequence found."""
        result = normalizer.normalize("no numbers here")

@@ -134,8 +160,21 @@ class TestOcrNumberNormalizer:
        assert result.value == "310196187399952"
        assert " " not in result.value

    def test_4_digit_ocr_valid(self, normalizer):
        """4-digit OCR numbers like '3046' should be accepted."""
        result = normalizer.normalize("3046")
        assert result.is_valid is True
        assert result.value == "3046"

    def test_2_digit_ocr_valid(self, normalizer):
        """2-digit OCR numbers should be accepted."""
        result = normalizer.normalize("42")
        assert result.is_valid is True
        assert result.value == "42"

    def test_too_short(self, normalizer):
        result = normalizer.normalize("1234")
        """Single-digit OCR should be rejected."""
        result = normalizer.normalize("5")
        assert result.is_valid is False

    def test_empty_string(self, normalizer):

@@ -477,6 +516,38 @@ class TestAmountNormalizer:
        assert result.value == "100.00"
        assert result.is_valid is True

    def test_astronomical_amount_rejected(self, normalizer):
        """IBAN digits should NOT produce astronomical amounts (>10M)."""
        # IBAN "SE14120000001201138650" contains long digit sequences
        # The standalone fallback pattern should not extract these as amounts
        result = normalizer.normalize("SE14120000001201138650")
        if result.is_valid:
            assert float(result.value) < 10_000_000

    def test_large_valid_amount_accepted(self, normalizer):
        """Valid large amount like 108000,00 should be accepted."""
        result = normalizer.normalize("108000,00")
        assert result.value == "108000.00"
        assert result.is_valid is True

    def test_standalone_iban_digits_rejected(self, normalizer):
        """Very long digit sequence (IBAN fragment) should not produce >10M."""
        result = normalizer.normalize("1036149234823114")
        if result.is_valid:
            assert float(result.value) < 10_000_000

    def test_main_pattern_rejects_over_10m(self, normalizer):
        """Main regex path should reject amounts over 10M (e.g. IBAN-like digits)."""
        result = normalizer.normalize("Belopp 81648164,00 kr")
        # 81648164.00 > 10M, should be rejected
        assert not result.is_valid or float(result.value) < 10_000_000

    def test_main_pattern_accepts_under_10m(self, normalizer):
        """Main regex path should accept valid amounts under 10M."""
        result = normalizer.normalize("Summa 999999,99 kr")
        assert result.value == "999999.99"
        assert result.is_valid is True


class TestEnhancedAmountNormalizer:
    """Tests for EnhancedAmountNormalizer."""

@@ -670,5 +670,387 @@ class TestProcessPdfTokenPath:
        assert call_args[3] == 100  # image height


class TestDpiPassthrough:
    """Tests for DPI being passed from pipeline to FieldExtractor (Bug 1)."""

    def test_field_extractor_receives_pipeline_dpi(self):
        """FieldExtractor should receive the pipeline's DPI, not default to 300."""
        with patch('backend.pipeline.pipeline.YOLODetector'):
            with patch('backend.pipeline.pipeline.FieldExtractor') as mock_fe_cls:
                InferencePipeline(
                    model_path='/fake/model.pt',
                    dpi=150,
                    use_gpu=False,
                )
                mock_fe_cls.assert_called_once_with(
                    ocr_lang='en', use_gpu=False, dpi=150
                )

    def test_field_extractor_receives_default_dpi(self):
        """When dpi=300 (default), FieldExtractor should also get 300."""
        with patch('backend.pipeline.pipeline.YOLODetector'):
            with patch('backend.pipeline.pipeline.FieldExtractor') as mock_fe_cls:
                InferencePipeline(
                    model_path='/fake/model.pt',
                    dpi=300,
                    use_gpu=False,
                )
                mock_fe_cls.assert_called_once_with(
                    ocr_lang='en', use_gpu=False, dpi=300
                )


class TestFallbackPatternExtraction:
    """Tests for _extract_with_patterns fallback regex (Bugs 2, 3)."""

    def _make_pipeline_with_patterns(self):
        """Create pipeline with mocked internals for pattern testing."""
        with patch.object(InferencePipeline, '__init__', lambda self, **kw: None):
            p = InferencePipeline()
        p.dpi = 150
        p.enable_fallback = True
        return p

    def test_bankgiro_no_match_in_org_number(self):
        """Bankgiro regex must NOT match digits embedded in an org number."""
        p = self._make_pipeline_with_patterns()
        result = InferenceResult()
        p._extract_with_patterns("Org.nr 802546-1610", result)
        assert 'Bankgiro' not in result.fields

    def test_bankgiro_matches_labeled(self):
        """Bankgiro regex should match when preceded by 'Bankgiro' label."""
        p = self._make_pipeline_with_patterns()
        result = InferenceResult()
        p._extract_with_patterns("Bankgiro 5393-9484", result)
        assert result.fields.get('Bankgiro') == '5393-9484'

    def test_bankgiro_matches_standalone(self):
        """Bankgiro regex should match a standalone 4-4 digit pattern."""
        p = self._make_pipeline_with_patterns()
        result = InferenceResult()
        p._extract_with_patterns("Betala till 5393-9484 senast", result)
        assert result.fields.get('Bankgiro') == '5393-9484'

    def test_amount_rejects_bare_integer(self):
        """Amount regex must NOT match bare integers like 'Summa 1'."""
        p = self._make_pipeline_with_patterns()
        result = InferenceResult()
        p._extract_with_patterns("Summa 1 Medlemsavgift", result)
        assert 'Amount' not in result.fields

    def test_amount_requires_decimal(self):
        """Amount regex should require a decimal separator."""
        p = self._make_pipeline_with_patterns()
        result = InferenceResult()
        p._extract_with_patterns("Total 5 items", result)
        assert 'Amount' not in result.fields

    def test_amount_with_decimal_works(self):
        """Amount regex should match Swedish decimal amounts."""
        p = self._make_pipeline_with_patterns()
        result = InferenceResult()
        p._extract_with_patterns("Att betala 1 234,56 SEK", result)
        assert 'Amount' in result.fields
        assert float(result.fields['Amount']) == pytest.approx(1234.56, abs=0.01)

    def test_amount_with_sek_suffix(self):
        """Amount regex should match amounts ending with SEK."""
        p = self._make_pipeline_with_patterns()
        result = InferenceResult()
        p._extract_with_patterns("7 500,00 SEK", result)
        assert 'Amount' in result.fields
        assert float(result.fields['Amount']) == pytest.approx(7500.00, abs=0.01)

    def test_fallback_extracts_invoice_date(self):
        """Fallback should extract InvoiceDate from Swedish text."""
        p = self._make_pipeline_with_patterns()
        result = InferenceResult()
        p._extract_with_patterns("Fakturadatum 2025-01-15 Referens ABC", result)
        assert result.fields.get('InvoiceDate') == '2025-01-15'

    def test_fallback_extracts_due_date(self):
        """Fallback should extract InvoiceDueDate from Swedish text."""
        p = self._make_pipeline_with_patterns()
        result = InferenceResult()
        p._extract_with_patterns("Forfallodag 2025-02-15 Belopp", result)
        assert result.fields.get('InvoiceDueDate') == '2025-02-15'

    def test_fallback_extracts_supplier_org(self):
        """Fallback should extract supplier_organisation_number."""
        p = self._make_pipeline_with_patterns()
        result = InferenceResult()
        p._extract_with_patterns("Org.nr 556123-4567 Stockholm", result)
        assert result.fields.get('supplier_organisation_number') == '556123-4567'

    def test_fallback_extracts_plusgiro(self):
        """Fallback should extract Plusgiro number."""
        p = self._make_pipeline_with_patterns()
        result = InferenceResult()
        p._extract_with_patterns("Plusgiro 12 34 56-7 betalning", result)
        assert 'Plusgiro' in result.fields

    def test_fallback_skips_year_as_invoice_number(self):
        """Fallback should NOT extract year-like value as InvoiceNumber."""
        p = self._make_pipeline_with_patterns()
        result = InferenceResult()
        p._extract_with_patterns("Fakturanr 2025 Datum 2025-01-15", result)
        assert 'InvoiceNumber' not in result.fields

    def test_fallback_accepts_valid_invoice_number(self):
        """Fallback should extract valid non-year InvoiceNumber."""
        p = self._make_pipeline_with_patterns()
        result = InferenceResult()
        p._extract_with_patterns("Fakturanr 12345 Summa", result)
        assert result.fields.get('InvoiceNumber') == '12345'


class TestDateValidation:
    """Tests for InvoiceDueDate < InvoiceDate validation (Bug 6)."""

    def _make_pipeline_for_merge(self):
        """Create pipeline with mocked internals for merge testing."""
        with patch.object(InferencePipeline, '__init__', lambda self, **kw: None):
            p = InferencePipeline()
        p.payment_line_parser = MagicMock()
        p.payment_line_parser.parse.return_value = MagicMock(is_valid=False)
        return p

    def test_due_date_before_invoice_date_dropped(self):
        """DueDate earlier than InvoiceDate should be removed."""
        from backend.pipeline.field_extractor import ExtractedField

        p = self._make_pipeline_for_merge()
        result = InferenceResult()
        result.extracted_fields = [
            ExtractedField(
                field_name='InvoiceDate', raw_text='2026-01-16',
                normalized_value='2026-01-16', confidence=0.9,
                detection_confidence=0.9, ocr_confidence=1.0,
                bbox=(0, 0, 100, 50), page_no=0,
            ),
            ExtractedField(
                field_name='InvoiceDueDate', raw_text='2025-12-01',
                normalized_value='2025-12-01', confidence=0.9,
                detection_confidence=0.9, ocr_confidence=1.0,
                bbox=(0, 60, 100, 110), page_no=0,
            ),
        ]
        p._merge_fields(result)
        assert 'InvoiceDate' in result.fields
        assert 'InvoiceDueDate' not in result.fields

    def test_valid_dates_preserved(self):
        """Both dates kept when DueDate >= InvoiceDate."""
        from backend.pipeline.field_extractor import ExtractedField

        p = self._make_pipeline_for_merge()
        result = InferenceResult()
        result.extracted_fields = [
            ExtractedField(
                field_name='InvoiceDate', raw_text='2026-01-16',
                normalized_value='2026-01-16', confidence=0.9,
                detection_confidence=0.9, ocr_confidence=1.0,
                bbox=(0, 0, 100, 50), page_no=0,
            ),
            ExtractedField(
                field_name='InvoiceDueDate', raw_text='2026-02-15',
                normalized_value='2026-02-15', confidence=0.9,
                detection_confidence=0.9, ocr_confidence=1.0,
                bbox=(0, 60, 100, 110), page_no=0,
            ),
        ]
        p._merge_fields(result)
        assert result.fields['InvoiceDate'] == '2026-01-16'
        assert result.fields['InvoiceDueDate'] == '2026-02-15'

    def test_same_dates_preserved(self):
        """Same InvoiceDate and DueDate should both be kept."""
        from backend.pipeline.field_extractor import ExtractedField

        p = self._make_pipeline_for_merge()
        result = InferenceResult()
        result.extracted_fields = [
            ExtractedField(
                field_name='InvoiceDate', raw_text='2026-01-16',
                normalized_value='2026-01-16', confidence=0.9,
                detection_confidence=0.9, ocr_confidence=1.0,
                bbox=(0, 0, 100, 50), page_no=0,
            ),
            ExtractedField(
                field_name='InvoiceDueDate', raw_text='2026-01-16',
                normalized_value='2026-01-16', confidence=0.9,
                detection_confidence=0.9, ocr_confidence=1.0,
                bbox=(0, 60, 100, 110), page_no=0,
            ),
        ]
        p._merge_fields(result)
        assert result.fields['InvoiceDate'] == '2026-01-16'
        assert result.fields['InvoiceDueDate'] == '2026-01-16'


class TestCrossFieldDedup:
    """Tests for cross-field deduplication of InvoiceNumber vs OCR/Bankgiro."""

    def _make_pipeline_for_merge(self):
        """Create pipeline with mocked internals for merge testing."""
        with patch.object(InferencePipeline, '__init__', lambda self, **kw: None):
            p = InferencePipeline()
        p.payment_line_parser = MagicMock()
        p.payment_line_parser.parse.return_value = MagicMock(is_valid=False)
        return p

    def _make_extracted_field(self, field_name, raw_text, normalized, confidence=0.9):
        from backend.pipeline.field_extractor import ExtractedField
        return ExtractedField(
            field_name=field_name,
            raw_text=raw_text,
            normalized_value=normalized,
            confidence=confidence,
            detection_confidence=confidence,
            ocr_confidence=1.0,
            bbox=(0, 0, 100, 50),
            page_no=0,
        )

    def test_invoice_number_not_same_as_ocr(self):
        """When InvoiceNumber == OCR, InvoiceNumber should be dropped."""
        p = self._make_pipeline_for_merge()
        result = InferenceResult()
        result.extracted_fields = [
            self._make_extracted_field('InvoiceNumber', '9179845608', '9179845608'),
            self._make_extracted_field('OCR', '9179845608', '9179845608'),
            self._make_extracted_field('Amount', '1234,56', '1234.56'),
        ]
        p._merge_fields(result)
        assert 'OCR' in result.fields
        assert result.fields['OCR'] == '9179845608'
        assert 'InvoiceNumber' not in result.fields

    def test_invoice_number_not_same_as_bankgiro_digits(self):
        """When InvoiceNumber digits == Bankgiro digits, InvoiceNumber should be dropped."""
        p = self._make_pipeline_for_merge()
        result = InferenceResult()
        result.extracted_fields = [
            self._make_extracted_field('InvoiceNumber', '53939484', '53939484'),
            self._make_extracted_field('Bankgiro', '5393-9484', '5393-9484'),
            self._make_extracted_field('Amount', '500,00', '500.00'),
        ]
        p._merge_fields(result)
        assert 'Bankgiro' in result.fields
        assert result.fields['Bankgiro'] == '5393-9484'
        assert 'InvoiceNumber' not in result.fields

    def test_unrelated_values_kept(self):
        """When InvoiceNumber, OCR, and Bankgiro are all different, keep all."""
        p = self._make_pipeline_for_merge()
        result = InferenceResult()
        result.extracted_fields = [
            self._make_extracted_field('InvoiceNumber', '19061', '19061'),
            self._make_extracted_field('OCR', '9179845608', '9179845608'),
            self._make_extracted_field('Bankgiro', '5393-9484', '5393-9484'),
        ]
        p._merge_fields(result)
        assert result.fields['InvoiceNumber'] == '19061'
        assert result.fields['OCR'] == '9179845608'
        assert result.fields['Bankgiro'] == '5393-9484'

    def test_dedup_after_fallback_re_add(self):
        """Dedup should remove InvoiceNumber re-added by fallback if it matches OCR."""
        p = self._make_pipeline_for_merge()
        result = InferenceResult()
        # Simulate state after fallback re-adds InvoiceNumber = OCR
        result.fields = {
            'OCR': '758200602426',
            'Amount': '164.00',
            'InvoiceNumber': '758200602426',  # re-added by fallback
        }
        result.confidence = {
            'OCR': 0.9,
            'Amount': 0.9,
            'InvoiceNumber': 0.5,  # fallback confidence
        }
        result.bboxes = {}
        p._dedup_invoice_number(result)
        assert 'InvoiceNumber' not in result.fields
        assert 'OCR' in result.fields

    def test_invoice_number_substring_of_bankgiro(self):
        """When InvoiceNumber digits are a substring of Bankgiro digits, drop InvoiceNumber."""
        p = self._make_pipeline_for_merge()
        result = InferenceResult()
        result.extracted_fields = [
            self._make_extracted_field('InvoiceNumber', '4639', '4639'),
            self._make_extracted_field('Bankgiro', '134-4639', '134-4639'),
            self._make_extracted_field('Amount', '500,00', '500.00'),
        ]
        p._merge_fields(result)
        assert 'Bankgiro' in result.fields
        assert result.fields['Bankgiro'] == '134-4639'
        assert 'InvoiceNumber' not in result.fields

    def test_invoice_number_not_substring_of_unrelated_bankgiro(self):
        """When InvoiceNumber is NOT a substring of Bankgiro, keep both."""
        p = self._make_pipeline_for_merge()
        result = InferenceResult()
        result.extracted_fields = [
            self._make_extracted_field('InvoiceNumber', '19061', '19061'),
            self._make_extracted_field('Bankgiro', '5393-9484', '5393-9484'),
            self._make_extracted_field('Amount', '500,00', '500.00'),
        ]
        p._merge_fields(result)
        assert result.fields['InvoiceNumber'] == '19061'
        assert result.fields['Bankgiro'] == '5393-9484'
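The dedup behaviour these tests exercise can be summarized as: drop `InvoiceNumber` whenever its digit string equals, or is a substring of, the digit string of `OCR` or `Bankgiro`. A minimal standalone sketch of that rule follows; the helper name and dict-based interface are illustrative assumptions, not the pipeline's actual `_dedup_invoice_number` implementation:

```python
def dedup_invoice_number(fields: dict) -> dict:
    """Sketch: drop InvoiceNumber if its digits match (or are contained in)
    the digits of OCR or Bankgiro - those fields win the conflict."""
    inv = fields.get("InvoiceNumber")
    if inv is None:
        return fields
    inv_digits = "".join(ch for ch in inv if ch.isdigit())
    for other in ("OCR", "Bankgiro"):
        value = fields.get(other)
        if value is None:
            continue
        other_digits = "".join(ch for ch in value if ch.isdigit())
        if inv_digits and inv_digits in other_digits:
            # e.g. InvoiceNumber '4639' vs Bankgiro '134-4639' -> drop it
            return {k: v for k, v in fields.items() if k != "InvoiceNumber"}
    return fields
```

Under this rule, `{"InvoiceNumber": "53939484", "Bankgiro": "5393-9484"}` loses its `InvoiceNumber`, while an unrelated value like `"19061"` survives, matching the positive and negative cases above.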


class TestFallbackTrigger:
    """Tests for _needs_fallback trigger threshold."""

    def _make_pipeline(self):
        with patch.object(InferencePipeline, '__init__', lambda self, **kw: None):
            p = InferencePipeline()
        return p

    def test_fallback_triggers_when_1_key_field_missing(self):
        """Should trigger when only 1 key field (e.g. InvoiceNumber) is missing."""
        p = self._make_pipeline()
        result = InferenceResult()
        result.fields = {
            'Amount': '1234.56',
            'OCR': '12345678901',
            'InvoiceDate': '2025-01-15',
            'InvoiceDueDate': '2025-02-15',
            'supplier_organisation_number': '556123-4567',
        }
        # InvoiceNumber missing -> should trigger
        assert p._needs_fallback(result) is True

    def test_fallback_triggers_when_dates_missing(self):
        """Should trigger when all key fields present but 2+ important fields missing."""
        p = self._make_pipeline()
        result = InferenceResult()
        result.fields = {
            'Amount': '1234.56',
            'InvoiceNumber': '12345',
            'OCR': '12345678901',
        }
        # InvoiceDate, InvoiceDueDate, supplier_org all missing -> should trigger
        assert p._needs_fallback(result) is True

    def test_no_fallback_when_all_fields_present(self):
        """Should NOT trigger when all key and important fields present."""
        p = self._make_pipeline()
        result = InferenceResult()
        result.fields = {
            'Amount': '1234.56',
            'InvoiceNumber': '12345',
            'OCR': '12345678901',
            'InvoiceDate': '2025-01-15',
            'InvoiceDueDate': '2025-02-15',
            'supplier_organisation_number': '556123-4567',
        }
        assert p._needs_fallback(result) is False


if __name__ == '__main__':
    pytest.main([__file__, '-v'])

@@ -335,12 +335,15 @@ class TestFallbackLogic:
        with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
            pipeline = InferencePipeline.__new__(InferencePipeline)

        # All key fields present
        # All key and important fields present
        result = InferenceResult(
            fields={
                "Amount": "1500.00",
                "InvoiceNumber": "INV-001",
                "OCR": "12345678901234",
                "InvoiceDate": "2025-01-15",
                "InvoiceDueDate": "2025-02-15",
                "supplier_organisation_number": "556123-4567",
            }
        )

@@ -203,15 +203,33 @@ class TestValueSelectorOcrField:
        assert len(result) == 1
        assert result[0].text == "94228110015950070"

    def test_ignores_short_digit_tokens(self):
        """Tokens with fewer than 5 digits are not OCR references."""
        tokens = _tokens("OCR", "123")
    def test_ignores_single_digit_tokens(self):
        """Tokens with fewer than 2 digits are not OCR references."""
        tokens = _tokens("OCR", "5")

        result = ValueSelector.select_value_tokens(tokens, "OCR")

        # Fallback: return all tokens since no valid OCR found
        assert len(result) == 2

    def test_ocr_4_digit_token_selected(self):
        """4-digit OCR token should be selected."""
        tokens = _tokens("OCR", "3046")

        result = ValueSelector.select_value_tokens(tokens, "OCR")

        assert len(result) == 1
        assert result[0].text == "3046"

    def test_ocr_2_digit_token_selected(self):
        """2-digit OCR token should be selected."""
        tokens = _tokens("OCR", "42")

        result = ValueSelector.select_value_tokens(tokens, "OCR")

        assert len(result) == 1
        assert result[0].text == "42"


class TestValueSelectorInvoiceNumberField:
    """Tests for InvoiceNumber field value selection."""