Refactoring Plan
Project: Invoice Field Extraction System. Generated: 2026-01-22. Based on: CODE_REVIEW_REPORT.md. Goal: improve code maintainability, testability, and security.
🎯 Refactoring Goals
Primary goals
- Security: eliminate plaintext passwords, SQL injection, and similar risks
- Maintainability: reduce code duplication and function complexity
- Testability: raise test coverage to 70%+ and add integration tests
- Readability: unify code style and add necessary documentation
- Performance: optimize batch and concurrent processing
Quantified targets
- Test coverage: 45% → 70%+
- Average function length: 80 lines → under 50 lines
- Code duplication: 15% → under 5%
- Cyclomatic complexity: max 15+ → max 10
- Docstring coverage of key functions: 30% → 80%+
📐 Overall Strategy
Principles
- Incremental refactoring: small steps, keeping the system runnable after every change
- Tests first: add tests before refactoring to pin down existing behavior
- Backward compatibility: keep API contracts stable and deprecate old interfaces gradually
- Documentation in sync: update docs alongside code changes
Workflow
1. Add tests for the module being refactored (cover its existing behavior)
↓
2. Refactor (Extract Method, Extract Class, etc.)
↓
3. Run the full test suite (verify behavior is unchanged)
↓
4. Update documentation
↓
5. Code review
↓
6. Merge to main
🗓️ Three-Phase Execution Plan
Phase 1: Urgent Fixes (1 week)
Goal: fix security vulnerabilities and critical bugs
| Task | Priority | Estimate | Module |
|---|---|---|---|
| Fix plaintext password | P0 | 1 hour | src/db/config.py |
| Set up environment-variable configuration | P0 | 2 hours | .env in project root |
| Fix SQL injection risk | P0 | 3 hours | src/db/operations.py |
| Add input validation | P1 | 4 hours | src/web/routes.py |
| Standardize exception handling | P1 | 1 day | global |
Phase 2: Core Refactoring (2-3 weeks)
Goal: reduce code complexity and eliminate duplication
| Task | Priority | Estimate | Module |
|---|---|---|---|
| Split _normalize_customer_number | P0 | 1 day | field_extractor.py |
| Unify payment_line parsing | P0 | 2 days | extracted into a dedicated module |
| Refactor process_document | P1 | 2 days | pipeline.py |
| Extract Method: split long functions | P1 | 3 days | global |
| Add integration tests | P0 | 3 days | tests/integration/ |
| Raise unit-test coverage | P1 | 2 days | all modules |
Phase 3: Optimization (1-2 weeks)
Goal: performance tuning and documentation
| Task | Priority | Estimate | Module |
|---|---|---|---|
| Concurrent batch processing | P1 | 2 days | batch_processor.py |
| Complete API documentation | P2 | 1 day | docs/API.md |
| Extract configuration into constants | P2 | 1 day | src/config/constants.py |
| Improve logging | P2 | 1 day | src/utils/logging.py |
| Profile and optimize | P2 | 2 days | global |
🔧 Detailed Refactoring Steps
Step 1: Fix Plaintext Password (P0, 1 hour)
Current problem:
# src/db/config.py:29
DATABASE_CONFIG = {
"host": "localhost",
"port": 3306,
"user": "root",
"password": "your_password",  # ❌ plaintext password
"database": "invoice_extraction",
}
Refactoring steps:
- Create a .env.example template:
# Database Configuration
DB_HOST=localhost
DB_PORT=3306
DB_USER=root
DB_PASSWORD=your_password_here
DB_NAME=invoice_extraction
- Create a .env file (add it to .gitignore):
DB_PASSWORD=actual_secure_password
- Update src/db/config.py:
import os
from dotenv import load_dotenv
load_dotenv()
DATABASE_CONFIG = {
"host": os.getenv("DB_HOST", "localhost"),
"port": int(os.getenv("DB_PORT", "3306")),
"user": os.getenv("DB_USER", "root"),
"password": os.getenv("DB_PASSWORD"),  # ✅ read from environment variable
"database": os.getenv("DB_NAME", "invoice_extraction"),
}
# Validate at startup
if not DATABASE_CONFIG["password"]:
raise ValueError("DB_PASSWORD environment variable not set")
- Install the dependency:
pip install python-dotenv
- Update requirements.txt:
python-dotenv>=1.0.0
Tests:
- Verify environment variables are read correctly
- Confirm an exception is raised when the variable is missing
- Test the database connection
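The validating loader above can be sketched as a testable function. The injectable `env` parameter is an addition here for testability, not part of the plan's `config.py`:

```python
import os

def load_db_config(env=os.environ):
    """Build DATABASE_CONFIG from the environment, failing fast on missing secrets."""
    password = env.get("DB_PASSWORD")
    if not password:
        # Fail at startup rather than connecting with a None password.
        raise ValueError("DB_PASSWORD environment variable not set")
    return {
        "host": env.get("DB_HOST", "localhost"),
        "port": int(env.get("DB_PORT", "3306")),
        "user": env.get("DB_USER", "root"),
        "password": password,
        "database": env.get("DB_NAME", "invoice_extraction"),
    }

# Defaults apply when only the secret is provided.
config = load_db_config({"DB_PASSWORD": "s3cret"})
assert config["port"] == 3306 and config["host"] == "localhost"

# A missing password raises instead of silently misconfiguring.
try:
    load_db_config({})
except ValueError as e:
    assert "DB_PASSWORD" in str(e)
```

Passing a plain dict in the tests avoids mutating the real process environment.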
Step 2: Fix SQL Injection (P0, 3 hours)
Current problem:
# src/db/operations.py:156
query = f"SELECT * FROM documents WHERE id = {doc_id}"  # ❌ SQL injection risk
cursor.execute(query)
Refactoring steps:
- Audit all SQL queries for string interpolation:
grep -n "f\".*SELECT" src/db/operations.py
grep -n "f\".*INSERT" src/db/operations.py
grep -n "f\".*UPDATE" src/db/operations.py
grep -n "f\".*DELETE" src/db/operations.py
- Replace with parameterized queries:
# Before
query = f"SELECT * FROM documents WHERE id = {doc_id}"
cursor.execute(query)
# After ✅
query = "SELECT * FROM documents WHERE id = %s"
cursor.execute(query, (doc_id,))
- Common replacements:
# INSERT
query = "INSERT INTO documents (filename, status) VALUES (%s, %s)"
cursor.execute(query, (filename, status))
# UPDATE
query = "UPDATE documents SET status = %s WHERE id = %s"
cursor.execute(query, (new_status, doc_id))
# IN clause
placeholders = ','.join(['%s'] * len(ids))
query = f"SELECT * FROM documents WHERE id IN ({placeholders})"
cursor.execute(query, ids)
- Add a query-builder helper:
# src/db/query_builder.py
# NOTE: table and column names are interpolated into the SQL text, so they
# must come from trusted code, never from user input; only values are bound.
def build_select(table: str, columns: list[str] | None = None, where: dict | None = None):
"""Build safe SELECT query with parameters."""
cols = ', '.join(columns) if columns else '*'
query = f"SELECT {cols} FROM {table}"
params = []
if where:
conditions = []
for key, value in where.items():
conditions.append(f"{key} = %s")
params.append(value)
query += " WHERE " + " AND ".join(conditions)
return query, tuple(params)
Tests:
- Unit-test every modified query function
- SQL injection tests: pass malicious input such as "1 OR 1=1"
- Integration tests to verify functionality is unchanged
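The injection test can be demonstrated offline with the stdlib sqlite3 driver. Note that sqlite3 uses `?` placeholders where MySQL uses `%s`, and the table here is a throwaway stand-in, not the project schema:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE documents (id INTEGER, filename TEXT)")
conn.executemany("INSERT INTO documents VALUES (?, ?)",
                 [(1, "a.pdf"), (2, "b.pdf")])

malicious = "1 OR 1=1"

# Parameterized: the payload is bound as a single value and matches no row.
safe = conn.execute("SELECT * FROM documents WHERE id = ?", (malicious,)).fetchall()
assert safe == []

# String interpolation: the payload becomes SQL and leaks every row.
unsafe = conn.execute(f"SELECT * FROM documents WHERE id = {malicious}").fetchall()
assert len(unsafe) == 2
```

The same pair of assertions, ported to the MySQL test database, makes a good regression test for this step.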
Step 3: Unify payment_line Parsing (P0, 2 days)
Current problem: the payment_line parsing logic is implemented in three places:
- src/inference/field_extractor.py:632-705 (normalization)
- src/inference/pipeline.py:217-252 (parsing for cross-validation)
- src/inference/test_field_extractor.py:269-344 (test cases)
Refactoring steps:
- Create a dedicated module src/inference/payment_line_parser.py:
"""
Swedish Payment Line Parser
Handles parsing and validation of Swedish machine-readable payment lines.
Format: # <OCR> # <Kronor> <Öre> <Type> > <Account>#<Check>#
"""
import logging
import re
from dataclasses import dataclass
from typing import Optional
@dataclass
class PaymentLineData:
"""Parsed payment line data."""
ocr_number: str
amount: str # Format: "KRONOR.ÖRE"
account_number: str # Bankgiro or Plusgiro
record_type: str # Usually "5" or "9"
check_digits: str
raw_text: str
is_valid: bool
error: Optional[str] = None
class PaymentLineParser:
"""Parser for Swedish payment lines with OCR error handling."""
# Pattern with OCR error tolerance
FULL_PATTERN = re.compile(
r'#\s*(\d[\d\s]*)\s*#\s*([\d\s]+?)\s+(\d{2})\s+(\d)\s*>?\s*([\d\s]+)\s*#\s*(\d+)\s*#'
)
# Pattern without amount (fallback)
PARTIAL_PATTERN = re.compile(
r'#\s*(\d[\d\s]*)\s*#.*?(\d)\s*>?\s*([\d\s]+)\s*#\s*(\d+)\s*#'
)
def __init__(self):
self.logger = logging.getLogger(__name__)
def parse(self, text: str) -> PaymentLineData:
"""
Parse payment line text.
Handles common OCR errors:
- Spaces in numbers: "12 0 0" → "1200"
- Missing symbols: Missing ">"
- Spaces in check digits: "#41 #" → "#41#"
Args:
text: Raw payment line text
Returns:
PaymentLineData with parsed fields
"""
text = text.strip()
# Try full pattern with amount
match = self.FULL_PATTERN.search(text)
if match:
return self._parse_full_match(match, text)
# Try partial pattern without amount
match = self.PARTIAL_PATTERN.search(text)
if match:
return self._parse_partial_match(match, text)
# No match
return PaymentLineData(
ocr_number="",
amount="",
account_number="",
record_type="",
check_digits="",
raw_text=text,
is_valid=False,
error="Invalid payment line format"
)
def _parse_full_match(self, match: re.Match, raw_text: str) -> PaymentLineData:
"""Parse full pattern match (with amount)."""
ocr = self._clean_digits(match.group(1))
kronor = self._clean_digits(match.group(2))
ore = match.group(3)
record_type = match.group(4)
account = self._clean_digits(match.group(5))
check_digits = match.group(6)
amount = f"{kronor}.{ore}"
return PaymentLineData(
ocr_number=ocr,
amount=amount,
account_number=account,
record_type=record_type,
check_digits=check_digits,
raw_text=raw_text,
is_valid=True
)
def _parse_partial_match(self, match: re.Match, raw_text: str) -> PaymentLineData:
"""Parse partial pattern match (without amount)."""
ocr = self._clean_digits(match.group(1))
record_type = match.group(2)
account = self._clean_digits(match.group(3))
check_digits = match.group(4)
return PaymentLineData(
ocr_number=ocr,
amount="", # No amount in partial format
account_number=account,
record_type=record_type,
check_digits=check_digits,
raw_text=raw_text,
is_valid=True
)
def _clean_digits(self, text: str) -> str:
"""Remove spaces from digit string."""
return text.replace(' ', '')
def format_machine_readable(self, data: PaymentLineData) -> str:
"""
Format parsed data back to machine-readable format.
Returns:
Formatted string: "# OCR # KRONOR ÖRE TYPE > ACCOUNT#CHECK#"
"""
if not data.is_valid:
return data.raw_text
if data.amount:
kronor, ore = data.amount.split('.')
return (
f"# {data.ocr_number} # {kronor} {ore} {data.record_type} > "
f"{data.account_number}#{data.check_digits}#"
)
else:
return (
f"# {data.ocr_number} # ... {data.record_type} > "
f"{data.account_number}#{data.check_digits}#"
)
- Refactor field_extractor.py to use the new parser:
# src/inference/field_extractor.py
from .payment_line_parser import PaymentLineParser
class FieldExtractor:
def __init__(self):
self.payment_parser = PaymentLineParser()
# ...
def _normalize_payment_line(self, text: str) -> tuple[str | None, bool, str | None]:
"""Normalize payment line using dedicated parser."""
data = self.payment_parser.parse(text)
if not data.is_valid:
return None, False, data.error
formatted = self.payment_parser.format_machine_readable(data)
return formatted, True, None
- Refactor pipeline.py to use the new parser:
# src/inference/pipeline.py
from .payment_line_parser import PaymentLineParser
class InferencePipeline:
def __init__(self):
self.payment_parser = PaymentLineParser()
# ...
def _parse_machine_readable_payment_line(
self, payment_line: str
) -> tuple[str | None, str | None, str | None]:
"""Parse payment line for cross-validation."""
data = self.payment_parser.parse(payment_line)
if not data.is_valid:
return None, None, None
return data.ocr_number, data.amount, data.account_number
- Update the tests to use the new parser:
# tests/unit/test_payment_line_parser.py
from src.inference.payment_line_parser import PaymentLineParser
class TestPaymentLineParser:
def test_full_format_with_spaces(self):
"""Test parsing with OCR-induced spaces."""
parser = PaymentLineParser()
text = "# 6026726908 # 736 00 9 > 5692041 #41 #"
data = parser.parse(text)
assert data.is_valid
assert data.ocr_number == "6026726908"
assert data.amount == "736.00"
assert data.account_number == "5692041"
assert data.check_digits == "41"
def test_format_without_amount(self):
"""Test parsing without amount."""
parser = PaymentLineParser()
text = "# 11000770600242 # ... 5 > 3082963#41#"
data = parser.parse(text)
assert data.is_valid
assert data.ocr_number == "11000770600242"
assert data.amount == ""
assert data.account_number == "3082963"
def test_machine_readable_format(self):
"""Test formatting back to machine-readable."""
parser = PaymentLineParser()
text = "# 6026726908 # 736 00 9 > 5692041 #41 #"
data = parser.parse(text)
formatted = parser.format_machine_readable(data)
assert "# 6026726908 #" in formatted
assert "736 00" in formatted
assert "5692041#41#" in formatted
Migration steps:
- Create payment_line_parser.py with its tests
- Run the tests to confirm the new implementation is correct
- Migrate callers to the new parser one file at a time
- Run the full test suite after each migration
- Delete the old implementation
- Update documentation
Tests:
- Unit tests covering every parsing scenario
- Integration tests verifying end-to-end behavior
- Regression tests confirming behavior is unchanged
Step 4: Split _normalize_customer_number (P0, 1 day)
Current problems:
- Function length: 127 lines
- Cyclomatic complexity: 15+
- Too many responsibilities: pattern matching, formatting, and validation are intertwined
Refactoring strategy: Extract Method + Strategy pattern
Refactoring steps:
- Create src/inference/customer_number_parser.py:
"""
Customer Number Parser
Handles extraction and normalization of Swedish customer numbers.
"""
import logging
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional
@dataclass
class CustomerNumberMatch:
"""Customer number match result."""
value: str
pattern_name: str
confidence: float
raw_text: str
class CustomerNumberPattern(ABC):
"""Abstract base for customer number patterns."""
@abstractmethod
def match(self, text: str) -> Optional[CustomerNumberMatch]:
"""Try to match pattern in text."""
pass
@abstractmethod
def format(self, match: re.Match) -> str:
"""Format matched groups to standard format."""
pass
class DashFormatPattern(CustomerNumberPattern):
"""Pattern: ABC 123-X"""
PATTERN = re.compile(r'\b([A-Za-z]{2,4})\s+(\d{1,4})-([A-Za-z0-9])\b')
def match(self, text: str) -> Optional[CustomerNumberMatch]:
match = self.PATTERN.search(text)
if not match:
return None
formatted = self.format(match)
return CustomerNumberMatch(
value=formatted,
pattern_name="DashFormat",
confidence=0.95,
raw_text=match.group(0)
)
def format(self, match: re.Match) -> str:
prefix = match.group(1).upper()
number = match.group(2)
suffix = match.group(3).upper()
return f"{prefix} {number}-{suffix}"
class NoDashFormatPattern(CustomerNumberPattern):
"""Pattern: ABC 123X (no dash)"""
PATTERN = re.compile(r'\b([A-Za-z]{2,4})\s+(\d{2,4})([A-Za-z])\b')
def match(self, text: str) -> Optional[CustomerNumberMatch]:
match = self.PATTERN.search(text)
if not match:
return None
# Exclude postal codes
full_text = match.group(0)
if self._is_postal_code(full_text):
return None
formatted = self.format(match)
return CustomerNumberMatch(
value=formatted,
pattern_name="NoDashFormat",
confidence=0.90,
raw_text=full_text
)
def format(self, match: re.Match) -> str:
prefix = match.group(1).upper()
number = match.group(2)
suffix = match.group(3).upper()
return f"{prefix} {number}-{suffix}"
def _is_postal_code(self, text: str) -> bool:
"""Check if text looks like Swedish postal code."""
# SE 106 43, SE 10643, etc.
return bool(re.match(r'^SE\s*\d{3}\s*\d{2}', text, re.IGNORECASE))
class CustomerNumberParser:
"""Parser for Swedish customer numbers."""
def __init__(self):
# Patterns ordered by specificity (most specific first)
self.patterns: list[CustomerNumberPattern] = [
DashFormatPattern(),
NoDashFormatPattern(),
# Add more patterns as needed
]
self.logger = logging.getLogger(__name__)
def parse(self, text: str) -> tuple[Optional[str], bool, Optional[str]]:
"""
Parse customer number from text.
Returns:
(customer_number, is_valid, error)
"""
text = text.strip()
# Try each pattern
matches: list[CustomerNumberMatch] = []
for pattern in self.patterns:
match = pattern.match(text)
if match:
matches.append(match)
# No matches
if not matches:
return None, False, "No customer number found"
# Return highest confidence match
best_match = max(matches, key=lambda m: m.confidence)
return best_match.value, True, None
def parse_all(self, text: str) -> list[CustomerNumberMatch]:
"""
Find all customer numbers in text.
Useful for cases with multiple potential matches.
"""
matches: list[CustomerNumberMatch] = []
for pattern in self.patterns:
match = pattern.match(text)
if match:
matches.append(match)
return sorted(matches, key=lambda m: m.confidence, reverse=True)
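The two patterns sketched above can be checked standalone, outside the class hierarchy. This is a simplification for illustration: no-dash matches are normalized to the dashed form, and postal-code-like strings must not yield a customer number:

```python
import re

# Same regexes as NoDashFormatPattern above.
NO_DASH = re.compile(r'\b([A-Za-z]{2,4})\s+(\d{2,4})([A-Za-z])\b')
POSTAL = re.compile(r'^SE\s*\d{3}\s*\d{2}', re.IGNORECASE)

def find_customer_number(text: str):
    """Return the normalized customer number, or None when nothing qualifies."""
    m = NO_DASH.search(text)
    if not m or POSTAL.match(m.group(0)):
        return None
    return f"{m.group(1).upper()} {m.group(2)}-{m.group(3).upper()}"

assert find_customer_number("Dwq 211X Billo") == "DWQ 211-X"
assert find_customer_number("SE 106 43") is None
```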
- Refactor field_extractor.py:
# src/inference/field_extractor.py
from .customer_number_parser import CustomerNumberParser
class FieldExtractor:
def __init__(self):
self.customer_parser = CustomerNumberParser()
# ...
def _normalize_customer_number(
self, text: str
) -> tuple[str | None, bool, str | None]:
"""Normalize customer number using dedicated parser."""
return self.customer_parser.parse(text)
- Add tests:
# tests/unit/test_customer_number_parser.py
from src.inference.customer_number_parser import (
CustomerNumberParser,
DashFormatPattern,
NoDashFormatPattern,
)
class TestDashFormatPattern:
def test_standard_format(self):
pattern = DashFormatPattern()
match = pattern.match("Customer: JTY 576-3")
assert match is not None
assert match.value == "JTY 576-3"
assert match.confidence == 0.95
class TestNoDashFormatPattern:
def test_no_dash_format(self):
pattern = NoDashFormatPattern()
match = pattern.match("Dwq 211X")
assert match is not None
assert match.value == "DWQ 211-X"
assert match.confidence == 0.90
def test_exclude_postal_code(self):
pattern = NoDashFormatPattern()
match = pattern.match("SE 106 43")
assert match is None # Should be filtered out
class TestCustomerNumberParser:
def test_parse_with_dash(self):
parser = CustomerNumberParser()
result, is_valid, error = parser.parse("Customer: JTY 576-3")
assert is_valid
assert result == "JTY 576-3"
assert error is None
def test_parse_without_dash(self):
parser = CustomerNumberParser()
result, is_valid, error = parser.parse("Dwq 211X Billo")
assert is_valid
assert result == "DWQ 211-X"
def test_parse_all_finds_multiple(self):
parser = CustomerNumberParser()
text = "JTY 576-3 and DWQ 211X"
matches = parser.parse_all(text)
assert len(matches) >= 1 # At least one match
assert matches[0].confidence >= 0.90
Migration plan:
- Day 1, morning: create the new parser and its tests
- Day 1, afternoon: migrate field_extractor.py and run the tests
- Regression-test that all document processing still works
Step 5: Refactor process_document (P1, 2 days)
Current problem: pipeline.py:100-250 (150 lines) has too many responsibilities
Refactoring strategy: Extract Method + separation of concerns
Target structure:
def process_document(self, image_path: Path, document_id: str) -> DocumentResult:
"""Main orchestration - keep under 30 lines."""
# 1. Run detection
detections = self._run_yolo_detection(image_path)
# 2. Extract fields
fields = self._extract_fields_from_detections(detections, image_path)
# 3. Apply cross-validation
fields = self._apply_cross_validation(fields)
# 4. Multi-source fusion
fields = self._apply_multi_source_fusion(fields)
# 5. Build result
return self._build_document_result(document_id, fields, detections)
See docs/CODE_REVIEW_REPORT.md Section 5.3 for detailed steps.
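The orchestration style targeted above can be illustrated with a minimal, dependency-free sketch; every stage name and its toy logic here are hypothetical, the point is that each stage is a small, separately testable method:

```python
from dataclasses import dataclass, field

@dataclass
class DocumentResult:
    document_id: str
    fields: dict = field(default_factory=dict)

class Pipeline:
    def process_document(self, text: str, document_id: str) -> DocumentResult:
        """Orchestration only: sequence the stages, no inline logic."""
        detections = self._detect(text)
        fields = self._extract(detections)
        fields = self._validate(fields)
        return DocumentResult(document_id, fields)

    # Toy stage implementations; each would be unit-tested in isolation.
    def _detect(self, text):
        return text.split()

    def _extract(self, tokens):
        return {"amount": tokens[-1]} if tokens else {}

    def _validate(self, fields):
        return {k: v for k, v in fields.items() if v}

result = Pipeline().process_document("total 736.00", "doc_1")
assert result.fields == {"amount": "736.00"}
```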
Step 6: Add Integration Tests (P0, 3 days)
Current state: no end-to-end integration tests
Goal: build a complete integration test suite
Test scenarios:
- PDF → inference → result verification (end to end)
- Batch processing of multiple documents
- API endpoint tests
- Database integration tests
- Error-path tests
Implementation steps:
- Create the test dataset:
tests/
├── fixtures/
│ ├── sample_invoices/
│ │ ├── billo_363.pdf
│ │ ├── billo_308.pdf
│ │ └── billo_310.pdf
│ └── expected_results/
│ ├── billo_363.json
│ ├── billo_308.json
│ └── billo_310.json
- Create tests/integration/test_end_to_end.py:
import json

import pytest
from pathlib import Path
from src.inference.pipeline import InferencePipeline
from src.inference.field_extractor import FieldExtractor
@pytest.fixture
def pipeline():
"""Create inference pipeline."""
extractor = FieldExtractor()
return InferencePipeline(
model_path="runs/train/invoice_fields/weights/best.pt",
confidence_threshold=0.5,
dpi=150,
field_extractor=extractor
)
@pytest.fixture
def sample_invoices():
"""Load sample invoices and expected results."""
fixtures_dir = Path(__file__).parent.parent / "fixtures"
samples = []
for pdf_path in (fixtures_dir / "sample_invoices").glob("*.pdf"):
json_path = fixtures_dir / "expected_results" / f"{pdf_path.stem}.json"
with open(json_path) as f:
expected = json.load(f)
samples.append({
"pdf_path": pdf_path,
"expected": expected
})
return samples
class TestEndToEnd:
"""End-to-end integration tests."""
def test_single_document_processing(self, pipeline, sample_invoices):
"""Test processing a single invoice from PDF to extracted fields."""
sample = sample_invoices[0]
# Process PDF
result = pipeline.process_pdf(
sample["pdf_path"],
document_id="test_001"
)
# Verify success
assert result.success
# Verify extracted fields match expected
expected = sample["expected"]
assert result.fields["amount"] == expected["amount"]
assert result.fields["ocr_number"] == expected["ocr_number"]
assert result.fields["customer_number"] == expected["customer_number"]
def test_batch_processing(self, pipeline, sample_invoices):
"""Test batch processing multiple invoices."""
pdf_paths = [s["pdf_path"] for s in sample_invoices]
# Process batch
results = pipeline.process_batch(pdf_paths)
# Verify all processed
assert len(results) == len(pdf_paths)
# Verify success rate
success_count = sum(1 for r in results if r.success)
assert success_count >= len(pdf_paths) * 0.9 # At least 90% success
def test_cross_validation_overrides(self, pipeline):
"""Test that payment_line values override detected values."""
# Use sample with known discrepancy (Billo310)
pdf_path = Path("tests/fixtures/sample_invoices/billo_310.pdf")
result = pipeline.process_pdf(pdf_path, document_id="test_cross_val")
# Verify payment_line was parsed
assert "payment_line" in result.fields
# Verify Amount was corrected from payment_line
# (Billo310: detected 20736.00, payment_line has 736.00)
assert result.fields["amount"] == "736.00"
def test_error_handling_invalid_pdf(self, pipeline):
"""Test graceful error handling for invalid PDF."""
invalid_pdf = Path("tests/fixtures/invalid.pdf")
result = pipeline.process_pdf(invalid_pdf, document_id="test_error")
# Should return result with success=False
assert not result.success
assert result.errors
assert len(result.errors) > 0
class TestAPIIntegration:
"""API endpoint integration tests."""
@pytest.fixture
def client(self):
"""Create test client."""
from fastapi.testclient import TestClient
from src.web.app import create_app
from src.web.config import AppConfig
config = AppConfig.from_defaults()
app = create_app(config)
return TestClient(app)
def test_health_endpoint(self, client):
"""Test /api/v1/health endpoint."""
response = client.get("/api/v1/health")
assert response.status_code == 200
data = response.json()
assert data["status"] == "healthy"
assert "model_loaded" in data
def test_infer_endpoint_with_pdf(self, client, sample_invoices):
"""Test /api/v1/infer with PDF upload."""
sample = sample_invoices[0]
with open(sample["pdf_path"], "rb") as f:
response = client.post(
"/api/v1/infer",
files={"file": ("test.pdf", f, "application/pdf")}
)
assert response.status_code == 200
data = response.json()
assert data["status"] == "success"
assert "result" in data
assert "fields" in data["result"]
def test_infer_endpoint_invalid_file(self, client):
"""Test /api/v1/infer rejects invalid file."""
response = client.post(
"/api/v1/infer",
files={"file": ("test.txt", b"invalid", "text/plain")}
)
assert response.status_code == 400
assert "Unsupported file type" in response.json()["detail"]
class TestDatabaseIntegration:
"""Database integration tests."""
@pytest.fixture
def db_connection(self):
"""Create test database connection."""
from src.db.connection import DatabaseConnection
# Use test database
conn = DatabaseConnection(database="invoice_extraction_test")
yield conn
conn.close()
def test_save_and_retrieve_result(self, db_connection, pipeline, sample_invoices):
"""Test saving inference result to database and retrieving it."""
sample = sample_invoices[0]
# Process document
result = pipeline.process_pdf(sample["pdf_path"], document_id="test_db_001")
# Save to database
db_connection.save_inference_result(result)
# Retrieve from database
retrieved = db_connection.get_inference_result("test_db_001")
# Verify
assert retrieved is not None
assert retrieved["document_id"] == "test_db_001"
assert retrieved["fields"]["amount"] == result.fields["amount"]
- Configure pytest for integration tests:
# pytest.ini
[pytest]
markers =
unit: Unit tests (fast, no external dependencies)
integration: Integration tests (slower, may use database/files)
slow: Slow tests
# Run unit tests by default
addopts = -v -m "not integration"
# Run all tests including integration
# pytest -m ""
# Run only integration tests
# pytest -m integration
- CI/CD integration:
# .github/workflows/test.yml
name: Tests
on: [push, pull_request]
jobs:
unit-tests:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.11'
- name: Install dependencies
run: |
pip install -r requirements.txt
pip install pytest pytest-cov
- name: Run unit tests
run: pytest -m "not integration" --cov=src --cov-report=xml
- name: Upload coverage
uses: codecov/codecov-action@v3
with:
file: ./coverage.xml
integration-tests:
runs-on: ubuntu-latest
services:
mysql:
image: mysql:8.0
env:
MYSQL_ROOT_PASSWORD: test_password
MYSQL_DATABASE: invoice_extraction_test
ports:
- 3306:3306
steps:
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.11'
- name: Install dependencies
run: |
pip install -r requirements.txt
pip install pytest
- name: Run integration tests
env:
DB_HOST: localhost
DB_PORT: 3306
DB_USER: root
DB_PASSWORD: test_password
DB_NAME: invoice_extraction_test
run: pytest -m integration
Time allocation:
- Day 1: prepare test data and set up the test framework
- Day 2: write end-to-end and API tests
- Day 3: database integration tests and CI/CD configuration
Step 7: Standardize Exception Handling (P1, 1 day)
Current problem: 31 overly broad except Exception handlers
Goal: define an exception hierarchy and catch precisely
Implementation steps:
- Create src/exceptions.py:
"""
Application-specific exceptions.
"""
class InvoiceExtractionError(Exception):
"""Base exception for invoice extraction errors."""
pass
class PDFProcessingError(InvoiceExtractionError):
"""Error during PDF processing."""
pass
class OCRError(InvoiceExtractionError):
"""Error during OCR."""
pass
class ModelInferenceError(InvoiceExtractionError):
"""Error during model inference."""
pass
class FieldValidationError(InvoiceExtractionError):
"""Error during field validation."""
pass
class DatabaseError(InvoiceExtractionError):
"""Error during database operations."""
pass
class ConfigurationError(InvoiceExtractionError):
"""Error in configuration."""
pass
- Replace broad exception handling:
# Before ❌
try:
result = process_pdf(path)
except Exception as e:
logger.error(f"Error: {e}")
return None
# After ✅
try:
result = process_pdf(path)
except PDFProcessingError as e:
logger.error(f"PDF processing failed: {e}")
return None
except OCRError as e:
logger.warning(f"OCR failed, trying fallback: {e}")
result = fallback_ocr(path)
except ModelInferenceError as e:
logger.error(f"Model inference failed: {e}")
raise # Re-raise for upper layer
- Raise specific exceptions in each module:
# src/inference/pdf_processor.py
from pathlib import Path

import numpy as np
import pdf2image

from src.exceptions import PDFProcessingError
def convert_pdf_to_image(pdf_path: Path, dpi: int) -> list[np.ndarray]:
try:
images = pdf2image.convert_from_path(pdf_path, dpi=dpi)
except Exception as e:
raise PDFProcessingError(f"Failed to convert PDF: {e}") from e
if not images:
raise PDFProcessingError("PDF conversion returned no images")
return images
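The `raise ... from e` idiom above preserves the original cause on the exception chain, which can be verified in isolation (names simplified, no pdf2image dependency):

```python
class InvoiceExtractionError(Exception):
    """Base exception, as in src/exceptions.py above."""

class PDFProcessingError(InvoiceExtractionError):
    """Error during PDF processing."""

def convert_pdf(path: str):
    try:
        raise FileNotFoundError(path)  # stand-in for a pdf2image failure
    except Exception as e:
        # Wrap in a domain exception but keep the low-level cause attached.
        raise PDFProcessingError(f"Failed to convert PDF: {e}") from e

try:
    convert_pdf("missing.pdf")
except PDFProcessingError as err:
    # Callers catch the domain exception; tracebacks still show the cause.
    assert isinstance(err.__cause__, FileNotFoundError)
    assert isinstance(err, InvoiceExtractionError)
```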
- Add an error-handling decorator:
# src/utils/error_handling.py
import functools
import logging
from typing import Callable, Type

from src.exceptions import OCRError, PDFProcessingError
def handle_errors(
*exception_types: Type[Exception],
default_return=None,
log_error: bool = True
):
"""Decorator for standardized error handling."""
def decorator(func: Callable):
@functools.wraps(func)
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except exception_types as e:
if log_error:
logger = logging.getLogger(func.__module__)
logger.error(
f"Error in {func.__name__}: {e}",
exc_info=True
)
return default_return
return wrapper
return decorator
# Usage
@handle_errors(PDFProcessingError, OCRError, default_return=None)
def safe_process_document(doc_path: Path):
return process_document(doc_path)
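A dependency-free smoke test of the decorator pattern above (logging calls omitted for brevity) shows that listed exceptions are converted to the default while everything else still propagates:

```python
import functools

class PDFProcessingError(Exception):
    pass

def handle_errors(*exception_types, default_return=None):
    """Minimal version of the decorator above, without logging."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except exception_types:
                return default_return
        return wrapper
    return decorator

@handle_errors(PDFProcessingError, default_return=None)
def flaky_process(path):
    raise PDFProcessingError(f"cannot read {path}")

# The listed exception is swallowed and the default is returned...
assert flaky_process("x.pdf") is None

# ...but unlisted exceptions still propagate.
@handle_errors(PDFProcessingError)
def type_bug():
    raise TypeError("unrelated bug")

try:
    type_bug()
    raised = False
except TypeError:
    raised = True
assert raised
```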
Steps 8-12: Remaining Refactoring Tasks
See CODE_REVIEW_REPORT.md Section 6 (Action Plan) for detailed steps.
🧪 Testing Strategy
Test pyramid
/\
/ \ E2E Tests (10%)
/----\ - Full pipeline tests
/ \ - API integration tests
/--------\
/ \ Integration Tests (30%)
/------------\ - Module integration
/ \ - Database tests
----------------
Unit Tests (60%)
- Function-level tests
- High coverage
Test coverage targets
| Module | Current coverage | Target coverage |
|---|---|---|
| field_extractor.py | 40% | 80% |
| pipeline.py | 50% | 75% |
| payment_line_parser.py | 0% (new) | 90% |
| customer_number_parser.py | 0% (new) | 90% |
| web/routes.py | 30% | 70% |
| db/operations.py | 20% | 60% |
| Overall | 45% | 70%+ |
Regression tests
Run after every refactoring:
# 1. Unit tests
pytest tests/unit/ -v
# 2. Integration tests
pytest tests/integration/ -v
# 3. End-to-end tests (real PDFs)
pytest tests/e2e/ -v
# 4. Performance tests (watch for regressions)
pytest tests/performance/ -v --benchmark
# 5. Coverage check
pytest --cov=src --cov-report=html --cov-fail-under=70
⚠️ Risk Management
Identified risks
| Risk | Impact | Probability | Mitigation |
|---|---|---|---|
| Refactoring breaks existing functionality | High | Medium | 1. Add tests before refactoring 2. Small iterations 3. Regression tests |
| Performance regression | Medium | Low | 1. Performance baselines 2. Continuous monitoring 3. Profile and optimize |
| API changes break clients | High | Low | 1. Semantic versioning 2. Deprecation notice period 3. Backward compatibility |
| Database migration failure | High | Low | 1. Back up data 2. Staged migration 3. Rollback plan |
| Schedule overrun | Medium | Medium | 1. Prioritization 2. Weekly progress review 3. Adjust scope if needed |
Rollback plan
Every refactoring step should have an explicit rollback strategy:
- Code rollback: isolate changes on Git branches
# Create a feature branch per refactoring task
git checkout -b refactor/payment-line-parser
# To roll back
git checkout main
git branch -D refactor/payment-line-parser
- Database rollback: use a migration tool
# Apply migrations
alembic upgrade head
# Roll back one migration
alembic downgrade -1
- Configuration rollback: keep the old configuration working
# Support both new and old configuration formats
password = config.get("db_password") or config.get("password")
📊 Success Metrics
Quantified metrics
| Metric | Current | Target | Measured by |
|---|---|---|---|
| Test coverage | 45% | 70%+ | pytest --cov |
| Average function length | 80 lines | <50 lines | radon |
| Cyclomatic complexity | max 15+ | <10 | radon cc |
| Code duplication | ~15% | <5% | pylint --duplicate |
| Security issues | 2 (plaintext password, SQL injection) | 0 | manual review + bandit |
| Documentation coverage | 30% | 80%+ | manual review |
| Average processing time | ~2 s/document | <2 s/document | performance tests |
Quality gates
Every change must satisfy:
- ✅ Test coverage ≥ 70%
- ✅ All tests pass (unit + integration + E2E)
- ✅ No high-severity security issues
- ✅ Code review approved
- ✅ No performance regression (within ±5%)
- ✅ Documentation updated
📅 Timeline
Phase 1: Urgent Fixes (Week 1)
| Date | Task | Owner | Status |
|---|---|---|---|
| Day 1 | Fix plaintext password + environment-variable config | | ⏳ |
| Day 2-3 | Fix SQL injection + parameterized queries | | ⏳ |
| Day 4-5 | Standardize exception handling | | ⏳ |
Phase 2: Core Refactoring (Weeks 2-4)
| Week | Task | Status |
|---|---|---|
| Week 2 | Unify payment_line parsing + split customer_number | ⏳ |
| Week 3 | Refactor pipeline + Extract Method | ⏳ |
| Week 4 | Add integration tests + raise unit-test coverage | ⏳ |
Phase 3: Optimization (Weeks 5-6)
| Week | Task | Status |
|---|---|---|
| Week 5 | Batch-processing optimization + config extraction | ⏳ |
| Week 6 | Documentation + logging improvements + performance tuning | ⏳ |
🔄 Continuous Improvement
Code Review Checklist
Before every commit, verify:
- All tests pass
- Coverage targets met
- No new security issues
- Code follows the style guide
- Functions under 50 lines
- Cyclomatic complexity under 10
- Documentation updated
- Changelog entry added
Automated tooling
Configure pre-commit hooks:
# .pre-commit-config.yaml
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-yaml
- id: check-added-large-files
- repo: https://github.com/psf/black
rev: 23.3.0
hooks:
- id: black
language_version: python3.11
- repo: https://github.com/PyCQA/flake8
rev: 6.0.0
hooks:
- id: flake8
args: [--max-line-length=88, --extend-ignore=E203]
- repo: https://github.com/PyCQA/bandit
rev: 1.7.5
hooks:
- id: bandit
args: [-c, pyproject.toml]
- repo: local
hooks:
- id: pytest-check
name: pytest-check
entry: pytest
language: system
pass_filenames: false
always_run: true
args: [-m, "not integration", --tb=short]
📚 References
Refactoring books
- Refactoring: Improving the Design of Existing Code - Martin Fowler
- Clean Code - Robert C. Martin
- Working Effectively with Legacy Code - Michael Feathers
Design patterns
- Strategy Pattern (customer_number patterns)
- Factory Pattern (parser creation)
- Template Method (field normalization)
Python best practices
- PEP 8: Style Guide
- PEP 257: Docstring Conventions
- Google Python Style Guide
✅ Acceptance Criteria
Definition of done for this refactoring:
- ✅ All P0 and P1 tasks completed
- ✅ Test coverage ≥ 70%
- ✅ All security issues fixed
- ✅ Code duplication < 5%
- ✅ All long functions (>100 lines) split
- ✅ API documentation complete
- ✅ No performance regression
- ✅ Successfully deployed to production
End of document
Next step: begin Phase 1, Day 1 - fix the plaintext password