# Refactoring Plan
**Project**: Invoice Field Extraction System
**Date**: 2026-01-22
**Based on**: CODE_REVIEW_REPORT.md
**Goal**: Improve code maintainability, testability, and security
---
## 📋 Table of Contents
1. [Refactoring Goals](#refactoring-goals)
2. [Overall Strategy](#overall-strategy)
3. [Three-Phase Plan](#three-phase-plan)
4. [Detailed Refactoring Steps](#detailed-refactoring-steps)
5. [Testing Strategy](#testing-strategy)
6. [Risk Management](#risk-management)
7. [Success Metrics](#success-metrics)
---
## 🎯 Refactoring Goals
### Primary Goals
1. **Security**: eliminate plaintext passwords, SQL injection, and other security risks
2. **Maintainability**: reduce code duplication and lower function complexity
3. **Testability**: raise test coverage to 70%+ and add integration tests
4. **Readability**: unify code style and add necessary documentation
5. **Performance**: optimize batch and concurrent processing
### Quantitative Targets
- Test coverage: 45% → 70%+
- Average function length: 80 lines → under 50 lines
- Code duplication: 15% → under 5%
- Cyclomatic complexity: max 15+ → max 10
- Docstring coverage of key functions: 30% → 80%+
---
## 📐 Overall Strategy
### Principles
1. **Incremental refactoring**: small steps; keep the system runnable after every change
2. **Tests first**: add tests before refactoring to pin down existing behavior
3. **Backward compatibility**: keep API interfaces compatible; deprecate old interfaces gradually
4. **Docs in sync**: update documentation alongside code changes
### Workflow
```
1. Add tests for the module to be refactored (cover existing behavior)
2. Refactor (Extract Method, Extract Class, etc.)
3. Run the full test suite (verify behavior is unchanged)
4. Update documentation
5. Code review
6. Merge to main
```
---
## 🗓️ Three-Phase Plan
### Phase 1: Urgent Fixes (1 week)
**Goal**: fix security vulnerabilities and critical bugs
| Task | Priority | Estimate | Module |
|------|----------|----------|--------|
| Fix plaintext password | P0 | 1 h | `src/db/config.py` |
| Set up environment-variable config | P0 | 2 h | `.env` in repo root |
| Fix SQL injection risk | P0 | 3 h | `src/db/operations.py` |
| Add input validation | P1 | 4 h | `src/web/routes.py` |
| Standardize exception handling | P1 | 1 day | global |
### Phase 2: Core Refactoring (2-3 weeks)
**Goal**: reduce code complexity and remove duplication
| Task | Priority | Estimate | Module |
|------|----------|----------|--------|
| Split `_normalize_customer_number` | P0 | 1 day | `field_extractor.py` |
| Unify payment_line parsing | P0 | 2 days | extract into its own module |
| Refactor `process_document` | P1 | 2 days | `pipeline.py` |
| Extract Method: split long functions | P1 | 3 days | global |
| Add integration tests | P0 | 3 days | `tests/integration/` |
| Raise unit-test coverage | P1 | 2 days | all modules |
### Phase 3: Optimization (1-2 weeks)
**Goal**: performance tuning and documentation
| Task | Priority | Estimate | Module |
|------|----------|----------|--------|
| Optimize batch-processing concurrency | P1 | 2 days | `batch_processor.py` |
| Complete API documentation | P2 | 1 day | `docs/API.md` |
| Extract config into constants | P2 | 1 day | `src/config/constants.py` |
| Improve the logging system | P2 | 1 day | `src/utils/logging.py` |
| Profiling and optimization | P2 | 2 days | global |
## 🔧 Detailed Refactoring Steps
### Step 1: Fix Plaintext Password (P0, 1 h)
**Current problem**:
```python
# src/db/config.py:29
DATABASE_CONFIG = {
    "host": "localhost",
    "port": 3306,
    "user": "root",
    "password": "your_password",  # ❌ plaintext password
    "database": "invoice_extraction",
}
```
**Refactoring steps**:
1. Create a `.env.example` template:
```bash
# Database Configuration
DB_HOST=localhost
DB_PORT=3306
DB_USER=root
DB_PASSWORD=your_password_here
DB_NAME=invoice_extraction
```
2. Create a `.env` file (add it to `.gitignore`):
```bash
DB_PASSWORD=actual_secure_password
```
3. Update `src/db/config.py`:
```python
import os

from dotenv import load_dotenv

load_dotenv()

DATABASE_CONFIG = {
    "host": os.getenv("DB_HOST", "localhost"),
    "port": int(os.getenv("DB_PORT", "3306")),
    "user": os.getenv("DB_USER", "root"),
    "password": os.getenv("DB_PASSWORD"),  # ✅ read from environment
    "database": os.getenv("DB_NAME", "invoice_extraction"),
}

# Validate at startup
if not DATABASE_CONFIG["password"]:
    raise ValueError("DB_PASSWORD environment variable not set")
```
4. Install the dependency:
```bash
pip install python-dotenv
```
5. Update `requirements.txt`:
```
python-dotenv>=1.0.0
```
**Tests**:
- Verify environment variables are read correctly
- Confirm an exception is raised when the variable is missing
- Test the database connection
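The startup validation can be smoke-tested without a live database. A minimal sketch, assuming the config-building logic is pulled into a helper function (`load_db_config` is a hypothetical name, not existing project code):

```python
import os


def load_db_config(env: dict = os.environ) -> dict:
    """Build the DB config from an environment mapping, failing fast
    when the password is missing (mirrors the startup check above)."""
    password = env.get("DB_PASSWORD")
    if not password:
        raise ValueError("DB_PASSWORD environment variable not set")
    return {
        "host": env.get("DB_HOST", "localhost"),
        "port": int(env.get("DB_PORT", "3306")),
        "user": env.get("DB_USER", "root"),
        "password": password,
        "database": env.get("DB_NAME", "invoice_extraction"),
    }
```

Taking the environment as a parameter makes the failure path trivially testable: pass `{}` and assert that `ValueError` is raised.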
---
### Step 2: Fix SQL Injection (P0, 3 h)
**Current problem**:
```python
# src/db/operations.py:156
query = f"SELECT * FROM documents WHERE id = {doc_id}"  # ❌ SQL injection risk
cursor.execute(query)
```
**Refactoring steps**:
1. Audit all SQL queries for string interpolation:
```bash
grep -n "f\".*SELECT" src/db/operations.py
grep -n "f\".*INSERT" src/db/operations.py
grep -n "f\".*UPDATE" src/db/operations.py
grep -n "f\".*DELETE" src/db/operations.py
```
2. Replace with parameterized queries:
```python
# Before
query = f"SELECT * FROM documents WHERE id = {doc_id}"
cursor.execute(query)

# After ✅
query = "SELECT * FROM documents WHERE id = %s"
cursor.execute(query, (doc_id,))
```
3. Common patterns:
```python
# INSERT
query = "INSERT INTO documents (filename, status) VALUES (%s, %s)"
cursor.execute(query, (filename, status))

# UPDATE
query = "UPDATE documents SET status = %s WHERE id = %s"
cursor.execute(query, (new_status, doc_id))

# IN clause
placeholders = ','.join(['%s'] * len(ids))
query = f"SELECT * FROM documents WHERE id IN ({placeholders})"
cursor.execute(query, ids)
```
4. Add a query-builder helper:
```python
# src/db/query_builder.py
def build_select(
    table: str,
    columns: list[str] | None = None,
    where: dict | None = None,
) -> tuple[str, tuple]:
    """Build a SELECT query with parameterized values.

    Note: `table` and the column/key names are interpolated directly,
    so they must come from trusted code, never from user input.
    """
    cols = ', '.join(columns) if columns else '*'
    query = f"SELECT {cols} FROM {table}"
    params = []
    if where:
        conditions = []
        for key, value in where.items():
            conditions.append(f"{key} = %s")
            params.append(value)
        query += " WHERE " + " AND ".join(conditions)
    return query, tuple(params)
```
**Tests**:
- Unit-test every modified query function
- SQL injection tests: pass malicious input such as `"1 OR 1=1"`
- Integration tests to verify behavior is unchanged
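To see why parameterization neutralizes a payload like `"1 OR 1=1"`, here is a self-contained sketch using Python's built-in `sqlite3` (which uses `?` placeholders) in place of MySQL's `%s`:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE documents (id INTEGER, filename TEXT)")
conn.execute("INSERT INTO documents VALUES (1, 'a.pdf'), (2, 'b.pdf')")

malicious = "1 OR 1=1"

# Unsafe interpolation: the payload becomes part of the SQL,
# turning the WHERE clause into a tautology that matches every row.
unsafe = conn.execute(
    f"SELECT * FROM documents WHERE id = {malicious}"
).fetchall()

# Parameterized: the payload is bound as a single opaque value,
# compared against `id` as data, so it matches nothing.
safe = conn.execute(
    "SELECT * FROM documents WHERE id = ?", (malicious,)
).fetchall()
```

Here `unsafe` returns both rows while `safe` returns none; the same reasoning applies to MySQL's `%s` placeholders used in the snippets above.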
---
### Step 3: Unify payment_line Parsing (P0, 2 days)
**Current problem**: payment_line parsing logic is implemented in three places
- `src/inference/field_extractor.py:632-705` (normalization)
- `src/inference/pipeline.py:217-252` (parsing for cross-validation)
- `src/inference/test_field_extractor.py:269-344` (test cases)
**Refactoring steps**:
1. Create a standalone module `src/inference/payment_line_parser.py`:
```python
"""
Swedish Payment Line Parser

Handles parsing and validation of Swedish machine-readable payment lines.
Format: # <OCR> # <Kronor> <Öre> <Type> > <Account>#<Check>#
"""
import logging
import re
from dataclasses import dataclass
from typing import Optional


@dataclass
class PaymentLineData:
    """Parsed payment line data."""
    ocr_number: str
    amount: str           # Format: "KRONOR.ÖRE"
    account_number: str   # Bankgiro or Plusgiro
    record_type: str      # Usually "5" or "9"
    check_digits: str
    raw_text: str
    is_valid: bool
    error: Optional[str] = None


class PaymentLineParser:
    """Parser for Swedish payment lines with OCR error handling."""

    # Pattern with OCR error tolerance
    FULL_PATTERN = re.compile(
        r'#\s*(\d[\d\s]*)\s*#\s*([\d\s]+?)\s+(\d{2})\s+(\d)\s*>?\s*([\d\s]+)\s*#\s*(\d+)\s*#'
    )
    # Pattern without amount (fallback)
    PARTIAL_PATTERN = re.compile(
        r'#\s*(\d[\d\s]*)\s*#.*?(\d)\s*>?\s*([\d\s]+)\s*#\s*(\d+)\s*#'
    )

    def __init__(self):
        self.logger = logging.getLogger(__name__)

    def parse(self, text: str) -> PaymentLineData:
        """
        Parse payment line text.

        Handles common OCR errors:
        - Spaces in numbers: "12 0 0" → "1200"
        - Missing symbols: missing ">"
        - Spaces in check digits: "#41 #" → "#41#"

        Args:
            text: Raw payment line text

        Returns:
            PaymentLineData with parsed fields
        """
        text = text.strip()
        # Try full pattern with amount
        match = self.FULL_PATTERN.search(text)
        if match:
            return self._parse_full_match(match, text)
        # Try partial pattern without amount
        match = self.PARTIAL_PATTERN.search(text)
        if match:
            return self._parse_partial_match(match, text)
        # No match
        return PaymentLineData(
            ocr_number="",
            amount="",
            account_number="",
            record_type="",
            check_digits="",
            raw_text=text,
            is_valid=False,
            error="Invalid payment line format",
        )

    def _parse_full_match(self, match: re.Match, raw_text: str) -> PaymentLineData:
        """Parse full pattern match (with amount)."""
        ocr = self._clean_digits(match.group(1))
        kronor = self._clean_digits(match.group(2))
        ore = match.group(3)
        record_type = match.group(4)
        account = self._clean_digits(match.group(5))
        check_digits = match.group(6)
        amount = f"{kronor}.{ore}"
        return PaymentLineData(
            ocr_number=ocr,
            amount=amount,
            account_number=account,
            record_type=record_type,
            check_digits=check_digits,
            raw_text=raw_text,
            is_valid=True,
        )

    def _parse_partial_match(self, match: re.Match, raw_text: str) -> PaymentLineData:
        """Parse partial pattern match (without amount)."""
        ocr = self._clean_digits(match.group(1))
        record_type = match.group(2)
        account = self._clean_digits(match.group(3))
        check_digits = match.group(4)
        return PaymentLineData(
            ocr_number=ocr,
            amount="",  # No amount in partial format
            account_number=account,
            record_type=record_type,
            check_digits=check_digits,
            raw_text=raw_text,
            is_valid=True,
        )

    def _clean_digits(self, text: str) -> str:
        """Remove spaces from digit string."""
        return text.replace(' ', '')

    def format_machine_readable(self, data: PaymentLineData) -> str:
        """
        Format parsed data back to machine-readable format.

        Returns:
            Formatted string: "# OCR # KRONOR ÖRE TYPE > ACCOUNT#CHECK#"
        """
        if not data.is_valid:
            return data.raw_text
        if data.amount:
            kronor, ore = data.amount.split('.')
            return (
                f"# {data.ocr_number} # {kronor} {ore} {data.record_type} > "
                f"{data.account_number}#{data.check_digits}#"
            )
        else:
            return (
                f"# {data.ocr_number} # ... {data.record_type} > "
                f"{data.account_number}#{data.check_digits}#"
            )
```
2. Refactor `field_extractor.py` to use the new parser:
```python
# src/inference/field_extractor.py
from .payment_line_parser import PaymentLineParser


class FieldExtractor:
    def __init__(self):
        self.payment_parser = PaymentLineParser()
        # ...

    def _normalize_payment_line(self, text: str) -> tuple[str | None, bool, str | None]:
        """Normalize payment line using dedicated parser."""
        data = self.payment_parser.parse(text)
        if not data.is_valid:
            return None, False, data.error
        formatted = self.payment_parser.format_machine_readable(data)
        return formatted, True, None
```
3. Refactor `pipeline.py` to use the new parser:
```python
# src/inference/pipeline.py
from .payment_line_parser import PaymentLineParser


class InferencePipeline:
    def __init__(self):
        self.payment_parser = PaymentLineParser()
        # ...

    def _parse_machine_readable_payment_line(
        self, payment_line: str
    ) -> tuple[str | None, str | None, str | None]:
        """Parse payment line for cross-validation."""
        data = self.payment_parser.parse(payment_line)
        if not data.is_valid:
            return None, None, None
        return data.ocr_number, data.amount, data.account_number
```
4. Update the tests to use the new parser:
```python
# tests/unit/test_payment_line_parser.py
from src.inference.payment_line_parser import PaymentLineParser


class TestPaymentLineParser:
    def test_full_format_with_spaces(self):
        """Test parsing with OCR-induced spaces."""
        parser = PaymentLineParser()
        text = "# 6026726908 # 736 00 9 > 5692041 #41 #"
        data = parser.parse(text)
        assert data.is_valid
        assert data.ocr_number == "6026726908"
        assert data.amount == "736.00"
        assert data.account_number == "5692041"
        assert data.check_digits == "41"

    def test_format_without_amount(self):
        """Test parsing without amount."""
        parser = PaymentLineParser()
        text = "# 11000770600242 # ... 5 > 3082963#41#"
        data = parser.parse(text)
        assert data.is_valid
        assert data.ocr_number == "11000770600242"
        assert data.amount == ""
        assert data.account_number == "3082963"

    def test_machine_readable_format(self):
        """Test formatting back to machine-readable."""
        parser = PaymentLineParser()
        text = "# 6026726908 # 736 00 9 > 5692041 #41 #"
        data = parser.parse(text)
        formatted = parser.format_machine_readable(data)
        assert "# 6026726908 #" in formatted
        assert "736 00" in formatted
        assert "5692041#41#" in formatted
```
**Migration steps**:
1. Create `payment_line_parser.py` along with its tests
2. Run the tests to confirm the new implementation is correct
3. Migrate one file at a time to the new parser
4. Run the full test suite after each migration
5. Delete the old implementations
6. Update the documentation
**Tests**:
- Unit tests covering all parsing scenarios
- Integration tests for end-to-end behavior
- Regression tests to confirm behavior is unchanged
---
### Step 4: Split `_normalize_customer_number` (P0, 1 day)
**Current problem**:
- Function length: 127 lines
- Cyclomatic complexity: 15+
- Too many responsibilities: pattern matching, formatting, and validation are mixed together
**Refactoring strategy**: Extract Method + Strategy pattern
**Refactoring steps**:
1. Create `src/inference/customer_number_parser.py`:
```python
"""
Customer Number Parser

Handles extraction and normalization of Swedish customer numbers.
"""
import logging
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional


@dataclass
class CustomerNumberMatch:
    """Customer number match result."""
    value: str
    pattern_name: str
    confidence: float
    raw_text: str


class CustomerNumberPattern(ABC):
    """Abstract base for customer number patterns."""

    @abstractmethod
    def match(self, text: str) -> Optional[CustomerNumberMatch]:
        """Try to match pattern in text."""

    @abstractmethod
    def format(self, match: re.Match) -> str:
        """Format matched groups to standard format."""


class DashFormatPattern(CustomerNumberPattern):
    """Pattern: ABC 123-X"""

    PATTERN = re.compile(r'\b([A-Za-z]{2,4})\s+(\d{1,4})-([A-Za-z0-9])\b')

    def match(self, text: str) -> Optional[CustomerNumberMatch]:
        match = self.PATTERN.search(text)
        if not match:
            return None
        formatted = self.format(match)
        return CustomerNumberMatch(
            value=formatted,
            pattern_name="DashFormat",
            confidence=0.95,
            raw_text=match.group(0),
        )

    def format(self, match: re.Match) -> str:
        prefix = match.group(1).upper()
        number = match.group(2)
        suffix = match.group(3).upper()
        return f"{prefix} {number}-{suffix}"


class NoDashFormatPattern(CustomerNumberPattern):
    """Pattern: ABC 123X (no dash)"""

    PATTERN = re.compile(r'\b([A-Za-z]{2,4})\s+(\d{2,4})([A-Za-z])\b')

    def match(self, text: str) -> Optional[CustomerNumberMatch]:
        match = self.PATTERN.search(text)
        if not match:
            return None
        # Exclude postal codes
        full_text = match.group(0)
        if self._is_postal_code(full_text):
            return None
        formatted = self.format(match)
        return CustomerNumberMatch(
            value=formatted,
            pattern_name="NoDashFormat",
            confidence=0.90,
            raw_text=full_text,
        )

    def format(self, match: re.Match) -> str:
        prefix = match.group(1).upper()
        number = match.group(2)
        suffix = match.group(3).upper()
        return f"{prefix} {number}-{suffix}"

    def _is_postal_code(self, text: str) -> bool:
        """Check if text looks like a Swedish postal code."""
        # SE 106 43, SE 10643, etc.
        return bool(re.match(r'^SE\s*\d{3}\s*\d{2}', text, re.IGNORECASE))


class CustomerNumberParser:
    """Parser for Swedish customer numbers."""

    def __init__(self):
        # Patterns ordered by specificity (most specific first)
        self.patterns: list[CustomerNumberPattern] = [
            DashFormatPattern(),
            NoDashFormatPattern(),
            # Add more patterns as needed
        ]
        self.logger = logging.getLogger(__name__)

    def parse(self, text: str) -> tuple[Optional[str], bool, Optional[str]]:
        """
        Parse customer number from text.

        Returns:
            (customer_number, is_valid, error)
        """
        text = text.strip()
        # Try each pattern
        matches: list[CustomerNumberMatch] = []
        for pattern in self.patterns:
            match = pattern.match(text)
            if match:
                matches.append(match)
        # No matches
        if not matches:
            return None, False, "No customer number found"
        # Return the highest-confidence match
        best_match = max(matches, key=lambda m: m.confidence)
        return best_match.value, True, None

    def parse_all(self, text: str) -> list[CustomerNumberMatch]:
        """
        Find all customer numbers in text.

        Useful for cases with multiple potential matches.
        """
        matches: list[CustomerNumberMatch] = []
        for pattern in self.patterns:
            match = pattern.match(text)
            if match:
                matches.append(match)
        return sorted(matches, key=lambda m: m.confidence, reverse=True)
```
2. Refactor `field_extractor.py`:
```python
# src/inference/field_extractor.py
from .customer_number_parser import CustomerNumberParser


class FieldExtractor:
    def __init__(self):
        self.customer_parser = CustomerNumberParser()
        # ...

    def _normalize_customer_number(
        self, text: str
    ) -> tuple[str | None, bool, str | None]:
        """Normalize customer number using dedicated parser."""
        return self.customer_parser.parse(text)
```
3. Add tests:
```python
# tests/unit/test_customer_number_parser.py
from src.inference.customer_number_parser import (
    CustomerNumberParser,
    DashFormatPattern,
    NoDashFormatPattern,
)


class TestDashFormatPattern:
    def test_standard_format(self):
        pattern = DashFormatPattern()
        match = pattern.match("Customer: JTY 576-3")
        assert match is not None
        assert match.value == "JTY 576-3"
        assert match.confidence == 0.95


class TestNoDashFormatPattern:
    def test_no_dash_format(self):
        pattern = NoDashFormatPattern()
        match = pattern.match("Dwq 211X")
        assert match is not None
        assert match.value == "DWQ 211-X"
        assert match.confidence == 0.90

    def test_exclude_postal_code(self):
        pattern = NoDashFormatPattern()
        match = pattern.match("SE 106 43")
        assert match is None  # Should be filtered out


class TestCustomerNumberParser:
    def test_parse_with_dash(self):
        parser = CustomerNumberParser()
        result, is_valid, error = parser.parse("Customer: JTY 576-3")
        assert is_valid
        assert result == "JTY 576-3"
        assert error is None

    def test_parse_without_dash(self):
        parser = CustomerNumberParser()
        result, is_valid, error = parser.parse("Dwq 211X Billo")
        assert is_valid
        assert result == "DWQ 211-X"

    def test_parse_all_finds_multiple(self):
        parser = CustomerNumberParser()
        text = "JTY 576-3 and DWQ 211X"
        matches = parser.parse_all(text)
        assert len(matches) >= 1  # At least one match
        assert matches[0].confidence >= 0.90
```
**Migration plan**:
1. Day 1 morning: create the new parser and its tests
2. Day 1 afternoon: migrate `field_extractor.py` and run the tests
3. Regression-test that all document processing still works
---
### Step 5: Refactor `process_document` (P1, 2 days)
**Current problem**: `pipeline.py:100-250` (150 lines) has too many responsibilities
**Refactoring strategy**: Extract Method + separation of concerns
**Target structure**:
```python
def process_document(self, image_path: Path, document_id: str) -> DocumentResult:
    """Main orchestration - keep under 30 lines."""
    # 1. Run detection
    detections = self._run_yolo_detection(image_path)
    # 2. Extract fields
    fields = self._extract_fields_from_detections(detections, image_path)
    # 3. Apply cross-validation
    fields = self._apply_cross_validation(fields)
    # 4. Multi-source fusion
    fields = self._apply_multi_source_fusion(fields)
    # 5. Build result
    return self._build_document_result(document_id, fields, detections)
```
See `docs/CODE_REVIEW_REPORT.md` Section 5.3 for the detailed steps.
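One of the extracted helpers could look like the following sketch. The dict-based field structure and key names (`payment_line_parsed`, `amount`, `ocr_number`) are illustrative assumptions, not the project's actual types; it only demonstrates the override rule used in cross-validation, where values parsed from the payment line win over detected ones:

```python
def apply_cross_validation(fields: dict) -> dict:
    """Let payment-line values override detected field values.

    `fields` is assumed to hold detected values plus an optional
    "payment_line_parsed" dict carrying "amount" / "ocr_number" keys.
    """
    parsed = fields.get("payment_line_parsed") or {}
    for key in ("amount", "ocr_number"):
        value = parsed.get(key)
        if value:  # payment line is authoritative when present
            fields[key] = value
    return fields
```

For example, a detection of `20736.00` with `736.00` in the parsed payment line would be corrected to `736.00`, matching the Billo310 case used in the integration tests below.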
---
### Step 6: Add Integration Tests (P0, 3 days)
**Current state**: no end-to-end integration tests
**Goal**: build a complete integration test suite
**Test scenarios**:
1. PDF → inference → result verification (end to end)
2. Batch processing of multiple documents
3. API endpoint tests
4. Database integration tests
5. Error-path tests
**Implementation steps**:
1. Create the test dataset:
```
tests/
├── fixtures/
│   ├── sample_invoices/
│   │   ├── billo_363.pdf
│   │   ├── billo_308.pdf
│   │   └── billo_310.pdf
│   └── expected_results/
│       ├── billo_363.json
│       ├── billo_308.json
│       └── billo_310.json
```
2. Create `tests/integration/test_end_to_end.py`:
```python
import json
from pathlib import Path

import pytest

from src.inference.pipeline import InferencePipeline
from src.inference.field_extractor import FieldExtractor


@pytest.fixture
def pipeline():
    """Create inference pipeline."""
    extractor = FieldExtractor()
    return InferencePipeline(
        model_path="runs/train/invoice_fields/weights/best.pt",
        confidence_threshold=0.5,
        dpi=150,
        field_extractor=extractor,
    )


@pytest.fixture
def sample_invoices():
    """Load sample invoices and expected results."""
    fixtures_dir = Path(__file__).parent.parent / "fixtures"
    samples = []
    for pdf_path in (fixtures_dir / "sample_invoices").glob("*.pdf"):
        json_path = fixtures_dir / "expected_results" / f"{pdf_path.stem}.json"
        with open(json_path) as f:
            expected = json.load(f)
        samples.append({
            "pdf_path": pdf_path,
            "expected": expected,
        })
    return samples


class TestEndToEnd:
    """End-to-end integration tests."""

    def test_single_document_processing(self, pipeline, sample_invoices):
        """Test processing a single invoice from PDF to extracted fields."""
        sample = sample_invoices[0]
        # Process PDF
        result = pipeline.process_pdf(
            sample["pdf_path"],
            document_id="test_001",
        )
        # Verify success
        assert result.success
        # Verify extracted fields match expected
        expected = sample["expected"]
        assert result.fields["amount"] == expected["amount"]
        assert result.fields["ocr_number"] == expected["ocr_number"]
        assert result.fields["customer_number"] == expected["customer_number"]

    def test_batch_processing(self, pipeline, sample_invoices):
        """Test batch processing multiple invoices."""
        pdf_paths = [s["pdf_path"] for s in sample_invoices]
        # Process batch
        results = pipeline.process_batch(pdf_paths)
        # Verify all processed
        assert len(results) == len(pdf_paths)
        # Verify success rate
        success_count = sum(1 for r in results if r.success)
        assert success_count >= len(pdf_paths) * 0.9  # At least 90% success

    def test_cross_validation_overrides(self, pipeline):
        """Test that payment_line values override detected values."""
        # Use sample with known discrepancy (Billo310)
        pdf_path = Path("tests/fixtures/sample_invoices/billo_310.pdf")
        result = pipeline.process_pdf(pdf_path, document_id="test_cross_val")
        # Verify payment_line was parsed
        assert "payment_line" in result.fields
        # Verify Amount was corrected from payment_line
        # (Billo310: detected 20736.00, payment_line has 736.00)
        assert result.fields["amount"] == "736.00"

    def test_error_handling_invalid_pdf(self, pipeline):
        """Test graceful error handling for invalid PDF."""
        invalid_pdf = Path("tests/fixtures/invalid.pdf")
        result = pipeline.process_pdf(invalid_pdf, document_id="test_error")
        # Should return result with success=False
        assert not result.success
        assert result.errors
        assert len(result.errors) > 0


class TestAPIIntegration:
    """API endpoint integration tests."""

    @pytest.fixture
    def client(self):
        """Create test client."""
        from fastapi.testclient import TestClient
        from src.web.app import create_app
        from src.web.config import AppConfig

        config = AppConfig.from_defaults()
        app = create_app(config)
        return TestClient(app)

    def test_health_endpoint(self, client):
        """Test /api/v1/health endpoint."""
        response = client.get("/api/v1/health")
        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "healthy"
        assert "model_loaded" in data

    def test_infer_endpoint_with_pdf(self, client, sample_invoices):
        """Test /api/v1/infer with PDF upload."""
        sample = sample_invoices[0]
        with open(sample["pdf_path"], "rb") as f:
            response = client.post(
                "/api/v1/infer",
                files={"file": ("test.pdf", f, "application/pdf")},
            )
        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "success"
        assert "result" in data
        assert "fields" in data["result"]

    def test_infer_endpoint_invalid_file(self, client):
        """Test /api/v1/infer rejects invalid file."""
        response = client.post(
            "/api/v1/infer",
            files={"file": ("test.txt", b"invalid", "text/plain")},
        )
        assert response.status_code == 400
        assert "Unsupported file type" in response.json()["detail"]


class TestDatabaseIntegration:
    """Database integration tests."""

    @pytest.fixture
    def db_connection(self):
        """Create test database connection."""
        from src.db.connection import DatabaseConnection

        # Use test database
        conn = DatabaseConnection(database="invoice_extraction_test")
        yield conn
        conn.close()

    def test_save_and_retrieve_result(self, db_connection, pipeline, sample_invoices):
        """Test saving an inference result to the database and retrieving it."""
        sample = sample_invoices[0]
        # Process document
        result = pipeline.process_pdf(sample["pdf_path"], document_id="test_db_001")
        # Save to database
        db_connection.save_inference_result(result)
        # Retrieve from database
        retrieved = db_connection.get_inference_result("test_db_001")
        # Verify
        assert retrieved is not None
        assert retrieved["document_id"] == "test_db_001"
        assert retrieved["fields"]["amount"] == result.fields["amount"]
```
3. Configure pytest to separate integration tests:
```ini
# pytest.ini
[pytest]
markers =
    unit: Unit tests (fast, no external dependencies)
    integration: Integration tests (slower, may use database/files)
    slow: Slow tests

# Run unit tests by default
addopts = -v -m "not integration"

# Run all tests including integration:
#   pytest -m ""
# Run only integration tests:
#   pytest -m integration
```
4. CI/CD integration:
```yaml
# .github/workflows/test.yml
name: Tests
on: [push, pull_request]

jobs:
  unit-tests:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: '3.11'
      - name: Install dependencies
        run: |
          pip install -r requirements.txt
          pip install pytest pytest-cov
      - name: Run unit tests
        run: pytest -m "not integration" --cov=src --cov-report=xml
      - name: Upload coverage
        uses: codecov/codecov-action@v3
        with:
          file: ./coverage.xml

  integration-tests:
    runs-on: ubuntu-latest
    services:
      mysql:
        image: mysql:8.0
        env:
          MYSQL_ROOT_PASSWORD: test_password
          MYSQL_DATABASE: invoice_extraction_test
        ports:
          - 3306:3306
    steps:
      - uses: actions/checkout@v3
      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: '3.11'
      - name: Install dependencies
        run: |
          pip install -r requirements.txt
          pip install pytest
      - name: Run integration tests
        env:
          DB_HOST: localhost
          DB_PORT: 3306
          DB_USER: root
          DB_PASSWORD: test_password
          DB_NAME: invoice_extraction_test
        run: pytest -m integration
```
**Time allocation**:
- Day 1: prepare test data, set up the test framework
- Day 2: write end-to-end and API tests
- Day 3: database integration tests, CI/CD configuration
---
### Step 7: Standardize Exception Handling (P1, 1 day)
**Current problem**: 31 occurrences of overly broad `except Exception`
**Goal**: create an exception hierarchy and catch exceptions precisely
**Implementation steps**:
1. Create `src/exceptions.py`:
```python
"""Application-specific exceptions."""


class InvoiceExtractionError(Exception):
    """Base exception for invoice extraction errors."""


class PDFProcessingError(InvoiceExtractionError):
    """Error during PDF processing."""


class OCRError(InvoiceExtractionError):
    """Error during OCR."""


class ModelInferenceError(InvoiceExtractionError):
    """Error during model inference."""


class FieldValidationError(InvoiceExtractionError):
    """Error during field validation."""


class DatabaseError(InvoiceExtractionError):
    """Error during database operations."""


class ConfigurationError(InvoiceExtractionError):
    """Error in configuration."""
```
2. Replace broad exception catches:
```python
# Before ❌
try:
    result = process_pdf(path)
except Exception as e:
    logger.error(f"Error: {e}")
    return None

# After ✅
try:
    result = process_pdf(path)
except PDFProcessingError as e:
    logger.error(f"PDF processing failed: {e}")
    return None
except OCRError as e:
    logger.warning(f"OCR failed, trying fallback: {e}")
    result = fallback_ocr(path)
except ModelInferenceError as e:
    logger.error(f"Model inference failed: {e}")
    raise  # Re-raise for the caller
```
3. Raise specific exceptions in each module:
```python
# src/inference/pdf_processor.py
from src.exceptions import PDFProcessingError


def convert_pdf_to_image(pdf_path: Path, dpi: int) -> list[np.ndarray]:
    try:
        images = pdf2image.convert_from_path(pdf_path, dpi=dpi)
    except Exception as e:
        raise PDFProcessingError(f"Failed to convert PDF: {e}") from e
    if not images:
        raise PDFProcessingError("PDF conversion returned no images")
    return images
```
4. Create an error-handling decorator:
```python
# src/utils/error_handling.py
import functools
import logging
from typing import Callable, Type

from src.exceptions import OCRError, PDFProcessingError


def handle_errors(
    *exception_types: Type[Exception],
    default_return=None,
    log_error: bool = True,
):
    """Decorator for standardized error handling."""
    def decorator(func: Callable):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except exception_types as e:
                if log_error:
                    logger = logging.getLogger(func.__module__)
                    logger.error(
                        f"Error in {func.__name__}: {e}",
                        exc_info=True,
                    )
                return default_return
        return wrapper
    return decorator


# Usage
@handle_errors(PDFProcessingError, OCRError, default_return=None)
def safe_process_document(doc_path: Path):
    return process_document(doc_path)
```
---
### Steps 8-12: Remaining Refactoring Tasks
See `CODE_REVIEW_REPORT.md` Section 6 (Action Plan) for the detailed steps.
---
## 🧪 Testing Strategy
### Test Pyramid
```
        /\
       /  \         E2E Tests (10%)
      /----\        - Full pipeline tests
     /      \       - API integration tests
    /--------\
   /          \     Integration Tests (30%)
  /------------\    - Module integration
 /              \   - Database tests
------------------
    Unit Tests (60%)
    - Function-level tests
    - High coverage
```
### Coverage Targets
| Module | Current coverage | Target coverage |
|--------|------------------|-----------------|
| `field_extractor.py` | 40% | 80% |
| `pipeline.py` | 50% | 75% |
| `payment_line_parser.py` | 0% (new) | 90% |
| `customer_number_parser.py` | 0% (new) | 90% |
| `web/routes.py` | 30% | 70% |
| `db/operations.py` | 20% | 60% |
| **Overall** | **45%** | **70%+** |
### Regression Tests
Run after every refactoring step:
```bash
# 1. Unit tests
pytest tests/unit/ -v

# 2. Integration tests
pytest tests/integration/ -v

# 3. End-to-end tests with real PDFs
pytest tests/e2e/ -v

# 4. Performance tests (guard against regressions)
pytest tests/performance/ -v --benchmark

# 5. Coverage check
pytest --cov=src --cov-report=html --cov-fail-under=70
```
---
## ⚠️ Risk Management
### Identified Risks
| Risk | Impact | Probability | Mitigation |
|------|--------|-------------|------------|
| Refactoring breaks existing features | High | Medium | 1. Add tests first<br>2. Small iterations<br>3. Regression tests |
| Performance regression | Medium | Low | 1. Performance benchmarks<br>2. Continuous monitoring<br>3. Profile and optimize |
| API changes break clients | High | Low | 1. Semantic versioning<br>2. Deprecation notice period<br>3. Backward compatibility |
| Database migration failure | High | Low | 1. Back up data<br>2. Staged migration<br>3. Rollback plan |
| Schedule overrun | Medium | Medium | 1. Prioritization<br>2. Weekly progress review<br>3. Adjust scope if needed |
### Rollback Plan
Every refactoring step should have a clear rollback strategy:
1. **Code rollback**: isolate changes on Git branches
```bash
# Create a feature branch per refactoring task
git checkout -b refactor/payment-line-parser
# To roll back
git checkout main
git branch -D refactor/payment-line-parser
```
2. **Database rollback**: use a database migration tool
```bash
# Apply migrations
alembic upgrade head
# Roll back one migration
alembic downgrade -1
```
3. **Config rollback**: keep old config keys working
```python
# Support both old and new config formats
password = config.get("db_password") or config.get("password")
```
---
## 📊 Success Metrics
### Quantitative Metrics
| Metric | Current | Target | Measurement |
|--------|---------|--------|-------------|
| Test coverage | 45% | 70%+ | `pytest --cov` |
| Average function length | 80 lines | <50 lines | `radon raw` |
| Cyclomatic complexity | max 15+ | <10 | `radon cc` |
| Code duplication | ~15% | <5% | `pylint` duplicate-code check |
| Security issues | 2 (plaintext password, SQL injection) | 0 | manual review + `bandit` |
| Documentation coverage | 30% | 80%+ | manual review |
| Average processing time | ~2 s/document | <2 s/document | performance tests |
### Quality Gates
All changes must satisfy:
- ✅ Test coverage ≥ 70%
- ✅ All tests pass (unit + integration + E2E)
- ✅ No high-severity security issues
- ✅ Code review approved
- ✅ No performance regression (within ±5%)
- ✅ Documentation updated
---
## 📅 Timeline
### Phase 1: Urgent Fixes (Week 1)
| Date | Task | Owner | Status |
|------|------|-------|--------|
| Day 1 | Fix plaintext password + env-var config | | ⏳ |
| Day 2-3 | Fix SQL injection + parameterized queries | | ⏳ |
| Day 4-5 | Standardize exception handling | | ⏳ |
### Phase 2: Core Refactoring (Weeks 2-4)
| Week | Task | Status |
|------|------|--------|
| Week 2 | Unify payment_line parsing + split customer_number | ⏳ |
| Week 3 | Refactor pipeline + Extract Method | ⏳ |
| Week 4 | Add integration tests + raise unit-test coverage | ⏳ |
### Phase 3: Optimization (Weeks 5-6)
| Week | Task | Status |
|------|------|--------|
| Week 5 | Batch-processing optimization + config extraction | ⏳ |
| Week 6 | Documentation + logging improvements + performance tuning | ⏳ |
---
## 🔄 Continuous Improvement
### Code Review Checklist
Check before every commit:
- [ ] All tests pass
- [ ] Coverage target met
- [ ] No new security issues
- [ ] Code follows the style guide
- [ ] Function length < 50 lines
- [ ] Cyclomatic complexity < 10
- [ ] Documentation updated
- [ ] Changelog entry added
### Automation
Configure pre-commit hooks:
```yaml
# .pre-commit-config.yaml
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.4.0
    hooks:
      - id: trailing-whitespace
      - id: end-of-file-fixer
      - id: check-yaml
      - id: check-added-large-files
  - repo: https://github.com/psf/black
    rev: 23.3.0
    hooks:
      - id: black
        language_version: python3.11
  - repo: https://github.com/PyCQA/flake8
    rev: 6.0.0
    hooks:
      - id: flake8
        args: [--max-line-length=88, --extend-ignore=E203]
  - repo: https://github.com/PyCQA/bandit
    rev: 1.7.5
    hooks:
      - id: bandit
        args: [-c, pyproject.toml]
  - repo: local
    hooks:
      - id: pytest-check
        name: pytest-check
        entry: pytest
        language: system
        pass_filenames: false
        always_run: true
        args: [-m, "not integration", --tb=short]
```
---
## 📚 References
### Refactoring Books
- *Refactoring: Improving the Design of Existing Code* - Martin Fowler
- *Clean Code* - Robert C. Martin
- *Working Effectively with Legacy Code* - Michael Feathers
### Design Patterns
- Strategy pattern (customer_number patterns)
- Factory pattern (parser creation)
- Template Method (field normalization)
### Python Best Practices
- PEP 8: Style Guide
- PEP 257: Docstring Conventions
- Google Python Style Guide
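A minimal sketch of the Template Method idea applied to field normalization (class names, hooks, and the amount-cleaning rules here are illustrative assumptions, not the project's actual code): the base class fixes the `normalize()` skeleton, and each field subclass fills in the parsing and formatting hooks.

```python
from abc import ABC, abstractmethod
from typing import Optional


class FieldNormalizer(ABC):
    """Template Method: fixed normalize() skeleton, field-specific hooks."""

    def normalize(self, text: str) -> tuple[Optional[str], bool, Optional[str]]:
        """Shared skeleton: clean, parse, then format."""
        text = text.strip()
        parsed = self.parse(text)
        if parsed is None:
            return None, False, f"{self.field_name()} not found"
        return self.format(parsed), True, None

    @abstractmethod
    def field_name(self) -> str: ...

    @abstractmethod
    def parse(self, text: str) -> Optional[str]: ...

    def format(self, value: str) -> str:
        """Default hook: return the parsed value unchanged."""
        return value


class AmountNormalizer(FieldNormalizer):
    """Example subclass: keep digits and separators, normalize decimal comma."""

    def field_name(self) -> str:
        return "amount"

    def parse(self, text: str) -> Optional[str]:
        cleaned = "".join(c for c in text if c.isdigit() or c in ",.")
        return cleaned or None

    def format(self, value: str) -> str:
        return value.replace(",", ".")
```

The `(value, is_valid, error)` return shape matches the tuples the parsers above already use, so a subclass per field could slot into `FieldExtractor` without changing its callers.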
---
## ✅ Acceptance Criteria
Definition of done for this refactoring:
1. ✅ All P0 and P1 tasks completed
2. ✅ Test coverage ≥ 70%
3. ✅ All security issues fixed
4. ✅ Code duplication < 5%
5. ✅ All long functions (>100 lines) split
6. ✅ API documentation complete
7. ✅ No performance regression
8. ✅ Successful production deployment
---
**End of document**
Next step: start Phase 1, Day 1 - fix the plaintext password.