# Refactoring Plan

**Project**: Invoice Field Extraction System
**Generated**: 2026-01-22
**Based on**: CODE_REVIEW_REPORT.md
**Goal**: improve maintainability, testability, and security

---

## 📋 Table of Contents

1. [Refactoring Goals](#-refactoring-goals)
2. [Overall Strategy](#-overall-strategy)
3. [Three-Phase Execution Plan](#️-three-phase-execution-plan)
4. [Detailed Refactoring Steps](#-detailed-refactoring-steps)
5. [Testing Strategy](#-testing-strategy)
6. [Risk Management](#️-risk-management)
7. [Success Metrics](#-success-metrics)

---

## 🎯 Refactoring Goals

### Primary Goals

1. **Security**: eliminate plaintext passwords, SQL injection, and other vulnerabilities
2. **Maintainability**: reduce code duplication and function complexity
3. **Testability**: raise test coverage to 70%+ and add integration tests
4. **Readability**: unify code style and add necessary documentation
5. **Performance**: optimize batch and concurrent processing

### Quantitative Targets

- Test coverage: 45% → 70%+
- Average function length: 80 lines → under 50 lines
- Code duplication: 15% → under 5%
- Cyclomatic complexity: max 15+ → max 10
- Docstring coverage of key functions: 30% → 80%+

---

## 📐 Overall Strategy

### Principles

1. **Incremental refactoring**: small steps; keep the system runnable after every change
2. **Tests first**: add tests before refactoring to lock in current behavior
3. **Backward compatibility**: keep API contracts stable; deprecate old interfaces gradually
4. **Docs in sync**: update documentation alongside code changes

### Workflow

```
1. Add tests for the module to be refactored (cover existing behavior)
   ↓
2. Refactor (Extract Method, Extract Class, etc.)
   ↓
3. Run the full test suite (verify behavior is unchanged)
   ↓
4. Update documentation
   ↓
5. Code review
   ↓
6. Merge to main
```

---

## 🗓️ Three-Phase Execution Plan

### Phase 1: Urgent Fixes (1 week)

**Goal**: fix security vulnerabilities and critical bugs

| Task | Priority | Estimate | Module |
|------|----------|----------|--------|
| Fix plaintext password | P0 | 1 hour | `src/db/config.py` |
| Set up environment-variable config | P0 | 2 hours | root `.env` |
| Fix SQL injection risk | P0 | 3 hours | `src/db/operations.py` |
| Add input validation | P1 | 4 hours | `src/web/routes.py` |
| Standardize exception handling | P1 | 1 day | global |

### Phase 2: Core Refactoring (2-3 weeks)

**Goal**: reduce complexity, eliminate duplication

| Task | Priority | Estimate | Module |
|------|----------|----------|--------|
| Split `_normalize_customer_number` | P0 | 1 day | `field_extractor.py` |
| Unify payment_line parsing | P0 | 2 days | extract into its own module |
| Refactor `process_document` | P1 | 2 days | `pipeline.py` |
| Extract Method: split long functions | P1 | 3 days | global |
| Add integration tests | P0 | 3 days | `tests/integration/` |
| Raise unit-test coverage | P1 | 2 days | all modules |

### Phase 3: Improvements (1-2 weeks)

**Goal**: performance optimization, documentation

| Task | Priority | Estimate | Module |
|------|----------|----------|--------|
| Concurrent batch processing | P1 | 2 days | `batch_processor.py` |
| Complete API docs | P2 | 1 day | `docs/API.md` |
| Extract config into constants | P2 | 1 day | `src/config/constants.py` |
| Improve logging | P2 | 1 day | `src/utils/logging.py` |
| Profiling and optimization | P2 | 2 days | global |

---

## 🔧 Detailed Refactoring Steps

### Step 1: Fix Plaintext Password (P0, 1 hour)

**Current problem**:

```python
# src/db/config.py:29
DATABASE_CONFIG = {
    "host": "localhost",
    "port": 3306,
    "user": "root",
    "password": "your_password",  # ❌ plaintext password
    "database": "invoice_extraction",
}
```

**Steps**:

1. Create a `.env.example` template:

```bash
# Database Configuration
DB_HOST=localhost
DB_PORT=3306
DB_USER=root
DB_PASSWORD=your_password_here
DB_NAME=invoice_extraction
```

2. Create a `.env` file (add it to `.gitignore`):

```bash
DB_PASSWORD=actual_secure_password
```

3. Update `src/db/config.py`:

```python
import os
from dotenv import load_dotenv

load_dotenv()

DATABASE_CONFIG = {
    "host": os.getenv("DB_HOST", "localhost"),
    "port": int(os.getenv("DB_PORT", "3306")),
    "user": os.getenv("DB_USER", "root"),
    "password": os.getenv("DB_PASSWORD"),  # ✅ read from the environment
    "database": os.getenv("DB_NAME", "invoice_extraction"),
}

# Validate at startup
if not DATABASE_CONFIG["password"]:
    raise ValueError("DB_PASSWORD environment variable not set")
```

4. Install the dependency:

```bash
pip install python-dotenv
```

5. Update `requirements.txt`:

```
python-dotenv>=1.0.0
```

**Tests**:

- Verify environment variables are read correctly
- Verify an exception is raised when the variable is missing
- Test the database connection

---

### Step 2: Fix SQL Injection (P0, 3 hours)

**Current problem**:

```python
# src/db/operations.py:156
query = f"SELECT * FROM documents WHERE id = {doc_id}"  # ❌ SQL injection risk
cursor.execute(query)
```

**Steps**:

1. Audit all SQL queries for string interpolation:

```bash
grep -n "f\".*SELECT" src/db/operations.py
grep -n "f\".*INSERT" src/db/operations.py
grep -n "f\".*UPDATE" src/db/operations.py
grep -n "f\".*DELETE" src/db/operations.py
```

2. Replace with parameterized queries:

```python
# Before
query = f"SELECT * FROM documents WHERE id = {doc_id}"
cursor.execute(query)

# After ✅
query = "SELECT * FROM documents WHERE id = %s"
cursor.execute(query, (doc_id,))
```

3. Common replacements:

```python
# INSERT
query = "INSERT INTO documents (filename, status) VALUES (%s, %s)"
cursor.execute(query, (filename, status))

# UPDATE
query = "UPDATE documents SET status = %s WHERE id = %s"
cursor.execute(query, (new_status, doc_id))

# IN clause
placeholders = ','.join(['%s'] * len(ids))
query = f"SELECT * FROM documents WHERE id IN ({placeholders})"
cursor.execute(query, ids)
```
4. Add a query-builder helper:

```python
# src/db/query_builder.py
from typing import Optional


def build_select(
    table: str,
    columns: Optional[list[str]] = None,
    where: Optional[dict] = None,
) -> tuple[str, tuple]:
    """Build a safe SELECT query with parameters.

    Note: `table` and `columns` are interpolated as identifiers and must
    come from trusted code, never from user input.
    """
    cols = ', '.join(columns) if columns else '*'
    query = f"SELECT {cols} FROM {table}"
    params = []
    if where:
        conditions = []
        for key, value in where.items():
            conditions.append(f"{key} = %s")
            params.append(value)
        query += " WHERE " + " AND ".join(conditions)
    return query, tuple(params)
```

**Tests**:

- Unit-test every modified query function
- SQL injection tests: feed in malicious input such as `"1 OR 1=1"`
- Integration tests to confirm behavior

---

### Step 3: Unify payment_line Parsing (P0, 2 days)

**Current problem**: payment_line parsing is implemented in three places

- `src/inference/field_extractor.py:632-705` (normalization)
- `src/inference/pipeline.py:217-252` (parsing for cross-validation)
- `src/inference/test_field_extractor.py:269-344` (test cases)

**Steps**:

1. Create a standalone module `src/inference/payment_line_parser.py`:

```python
"""
Swedish Payment Line Parser

Handles parsing and validation of Swedish machine-readable payment lines.
Format: # <OCR> # <Kronor> <Öre> <Type> > <Account>#<Check>#
"""
import logging
import re
from dataclasses import dataclass
from typing import Optional


@dataclass
class PaymentLineData:
    """Parsed payment line data."""
    ocr_number: str
    amount: str            # Format: "KRONOR.ÖRE"
    account_number: str    # Bankgiro or Plusgiro
    record_type: str       # Usually "5" or "9"
    check_digits: str
    raw_text: str
    is_valid: bool
    error: Optional[str] = None


class PaymentLineParser:
    """Parser for Swedish payment lines with OCR error handling."""

    # Pattern with OCR error tolerance
    FULL_PATTERN = re.compile(
        r'#\s*(\d[\d\s]*)\s*#\s*([\d\s]+?)\s+(\d{2})\s+(\d)\s*>?\s*([\d\s]+)\s*#\s*(\d+)\s*#'
    )

    # Pattern without amount (fallback)
    PARTIAL_PATTERN = re.compile(
        r'#\s*(\d[\d\s]*)\s*#.*?(\d)\s*>?\s*([\d\s]+)\s*#\s*(\d+)\s*#'
    )

    def __init__(self):
        self.logger = logging.getLogger(__name__)

    def parse(self, text: str) -> PaymentLineData:
        """
        Parse payment line text.

        Handles common OCR errors:
        - Spaces in numbers: "12 0 0" → "1200"
        - Missing symbols: missing ">"
        - Spaces in check digits: "#41 #" → "#41#"

        Args:
            text: Raw payment line text

        Returns:
            PaymentLineData with parsed fields
        """
        text = text.strip()

        # Try full pattern with amount
        match = self.FULL_PATTERN.search(text)
        if match:
            return self._parse_full_match(match, text)

        # Try partial pattern without amount
        match = self.PARTIAL_PATTERN.search(text)
        if match:
            return self._parse_partial_match(match, text)

        # No match
        return PaymentLineData(
            ocr_number="",
            amount="",
            account_number="",
            record_type="",
            check_digits="",
            raw_text=text,
            is_valid=False,
            error="Invalid payment line format"
        )

    def _parse_full_match(self, match: re.Match, raw_text: str) -> PaymentLineData:
        """Parse full pattern match (with amount)."""
        ocr = self._clean_digits(match.group(1))
        kronor = self._clean_digits(match.group(2))
        ore = match.group(3)
        record_type = match.group(4)
        account = self._clean_digits(match.group(5))
        check_digits = match.group(6)

        amount = f"{kronor}.{ore}"

        return PaymentLineData(
            ocr_number=ocr,
            amount=amount,
            account_number=account,
            record_type=record_type,
            check_digits=check_digits,
            raw_text=raw_text,
            is_valid=True
        )

    def _parse_partial_match(self, match: re.Match, raw_text: str) -> PaymentLineData:
        """Parse partial pattern match (without amount)."""
        ocr = self._clean_digits(match.group(1))
        record_type = match.group(2)
        account = self._clean_digits(match.group(3))
        check_digits = match.group(4)

        return PaymentLineData(
            ocr_number=ocr,
            amount="",  # No amount in partial format
            account_number=account,
            record_type=record_type,
            check_digits=check_digits,
            raw_text=raw_text,
            is_valid=True
        )

    def _clean_digits(self, text: str) -> str:
        """Remove spaces from digit string."""
        return text.replace(' ', '')

    def format_machine_readable(self, data: PaymentLineData) -> str:
        """
        Format parsed data back to machine-readable format.

        Returns:
            Formatted string: "# OCR # KRONOR ÖRE TYPE > ACCOUNT#CHECK#"
        """
        if not data.is_valid:
            return data.raw_text

        if data.amount:
            kronor, ore = data.amount.split('.')
            return (
                f"# {data.ocr_number} # {kronor} {ore} {data.record_type} > "
                f"{data.account_number}#{data.check_digits}#"
            )
        else:
            return (
                f"# {data.ocr_number} # ... {data.record_type} > "
                f"{data.account_number}#{data.check_digits}#"
            )
```

2. Refactor `field_extractor.py` to use the new parser:

```python
# src/inference/field_extractor.py
from .payment_line_parser import PaymentLineParser


class FieldExtractor:
    def __init__(self):
        self.payment_parser = PaymentLineParser()
        # ...

    def _normalize_payment_line(self, text: str) -> tuple[str | None, bool, str | None]:
        """Normalize payment line using dedicated parser."""
        data = self.payment_parser.parse(text)
        if not data.is_valid:
            return None, False, data.error
        formatted = self.payment_parser.format_machine_readable(data)
        return formatted, True, None
```

3. Refactor `pipeline.py` to use the new parser:

```python
# src/inference/pipeline.py
from .payment_line_parser import PaymentLineParser


class InferencePipeline:
    def __init__(self):
        self.payment_parser = PaymentLineParser()
        # ...

    def _parse_machine_readable_payment_line(
        self, payment_line: str
    ) -> tuple[str | None, str | None, str | None]:
        """Parse payment line for cross-validation."""
        data = self.payment_parser.parse(payment_line)
        if not data.is_valid:
            return None, None, None
        return data.ocr_number, data.amount, data.account_number
```
4. Update the tests to use the new parser:

```python
# tests/unit/test_payment_line_parser.py
from src.inference.payment_line_parser import PaymentLineParser


class TestPaymentLineParser:
    def test_full_format_with_spaces(self):
        """Test parsing with OCR-induced spaces."""
        parser = PaymentLineParser()
        text = "# 6026726908 # 736 00 9 > 5692041 #41 #"
        data = parser.parse(text)
        assert data.is_valid
        assert data.ocr_number == "6026726908"
        assert data.amount == "736.00"
        assert data.account_number == "5692041"
        assert data.check_digits == "41"

    def test_format_without_amount(self):
        """Test parsing without amount."""
        parser = PaymentLineParser()
        text = "# 11000770600242 # ... 5 > 3082963#41#"
        data = parser.parse(text)
        assert data.is_valid
        assert data.ocr_number == "11000770600242"
        assert data.amount == ""
        assert data.account_number == "3082963"

    def test_machine_readable_format(self):
        """Test formatting back to machine-readable."""
        parser = PaymentLineParser()
        text = "# 6026726908 # 736 00 9 > 5692041 #41 #"
        data = parser.parse(text)
        formatted = parser.format_machine_readable(data)
        assert "# 6026726908 #" in formatted
        assert "736 00" in formatted
        assert "5692041#41#" in formatted
```

**Migration steps**:

1. Create `payment_line_parser.py` and its tests
2. Run the tests to verify the new implementation
3. Migrate file by file to the new parser
4. Run the full test suite after each migration
5. Delete the old implementations
6. Update the documentation

**Tests**:

- Unit tests covering every parsing scenario
- Integration tests for end-to-end behavior
- Regression tests to confirm behavior is unchanged

---

### Step 4: Split `_normalize_customer_number` (P0, 1 day)

**Current problems**:

- Function length: 127 lines
- Cyclomatic complexity: 15+
- Too many responsibilities: pattern matching, formatting, and validation are entangled

**Strategy**: Extract Method + Strategy Pattern

**Steps**:

1. Create `src/inference/customer_number_parser.py`:

```python
"""
Customer Number Parser

Handles extraction and normalization of Swedish customer numbers.
"""
import logging
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional


@dataclass
class CustomerNumberMatch:
    """Customer number match result."""
    value: str
    pattern_name: str
    confidence: float
    raw_text: str


class CustomerNumberPattern(ABC):
    """Abstract base for customer number patterns."""

    @abstractmethod
    def match(self, text: str) -> Optional[CustomerNumberMatch]:
        """Try to match pattern in text."""
        pass

    @abstractmethod
    def format(self, match: re.Match) -> str:
        """Format matched groups to standard format."""
        pass


class DashFormatPattern(CustomerNumberPattern):
    """Pattern: ABC 123-X"""

    PATTERN = re.compile(r'\b([A-Za-z]{2,4})\s+(\d{1,4})-([A-Za-z0-9])\b')

    def match(self, text: str) -> Optional[CustomerNumberMatch]:
        match = self.PATTERN.search(text)
        if not match:
            return None
        formatted = self.format(match)
        return CustomerNumberMatch(
            value=formatted,
            pattern_name="DashFormat",
            confidence=0.95,
            raw_text=match.group(0)
        )

    def format(self, match: re.Match) -> str:
        prefix = match.group(1).upper()
        number = match.group(2)
        suffix = match.group(3).upper()
        return f"{prefix} {number}-{suffix}"


class NoDashFormatPattern(CustomerNumberPattern):
    """Pattern: ABC 123X (no dash)"""

    PATTERN = re.compile(r'\b([A-Za-z]{2,4})\s+(\d{2,4})([A-Za-z])\b')

    def match(self, text: str) -> Optional[CustomerNumberMatch]:
        match = self.PATTERN.search(text)
        if not match:
            return None
        # Exclude postal codes
        full_text = match.group(0)
        if self._is_postal_code(full_text):
            return None
        formatted = self.format(match)
        return CustomerNumberMatch(
            value=formatted,
            pattern_name="NoDashFormat",
            confidence=0.90,
            raw_text=full_text
        )

    def format(self, match: re.Match) -> str:
        prefix = match.group(1).upper()
        number = match.group(2)
        suffix = match.group(3).upper()
        return f"{prefix} {number}-{suffix}"

    def _is_postal_code(self, text: str) -> bool:
        """Check if text looks like a Swedish postal code."""
        # SE 106 43, SE 10643, etc.
        return bool(re.match(r'^SE\s*\d{3}\s*\d{2}', text, re.IGNORECASE))


class CustomerNumberParser:
    """Parser for Swedish customer numbers."""

    def __init__(self):
        # Patterns ordered by specificity (most specific first)
        self.patterns: list[CustomerNumberPattern] = [
            DashFormatPattern(),
            NoDashFormatPattern(),
            # Add more patterns as needed
        ]
        self.logger = logging.getLogger(__name__)

    def parse(self, text: str) -> tuple[Optional[str], bool, Optional[str]]:
        """
        Parse customer number from text.

        Returns:
            (customer_number, is_valid, error)
        """
        text = text.strip()

        # Try each pattern
        matches: list[CustomerNumberMatch] = []
        for pattern in self.patterns:
            match = pattern.match(text)
            if match:
                matches.append(match)

        # No matches
        if not matches:
            return None, False, "No customer number found"

        # Return highest confidence match
        best_match = max(matches, key=lambda m: m.confidence)
        return best_match.value, True, None

    def parse_all(self, text: str) -> list[CustomerNumberMatch]:
        """
        Find all customer numbers in text.

        Useful for cases with multiple potential matches.
        """
        matches: list[CustomerNumberMatch] = []
        for pattern in self.patterns:
            match = pattern.match(text)
            if match:
                matches.append(match)
        return sorted(matches, key=lambda m: m.confidence, reverse=True)
```

2. Refactor `field_extractor.py`:

```python
# src/inference/field_extractor.py
from .customer_number_parser import CustomerNumberParser


class FieldExtractor:
    def __init__(self):
        self.customer_parser = CustomerNumberParser()
        # ...

    def _normalize_customer_number(
        self, text: str
    ) -> tuple[str | None, bool, str | None]:
        """Normalize customer number using dedicated parser."""
        return self.customer_parser.parse(text)
```
3. Add tests:

```python
# tests/unit/test_customer_number_parser.py
from src.inference.customer_number_parser import (
    CustomerNumberParser,
    DashFormatPattern,
    NoDashFormatPattern,
)


class TestDashFormatPattern:
    def test_standard_format(self):
        pattern = DashFormatPattern()
        match = pattern.match("Customer: JTY 576-3")
        assert match is not None
        assert match.value == "JTY 576-3"
        assert match.confidence == 0.95


class TestNoDashFormatPattern:
    def test_no_dash_format(self):
        pattern = NoDashFormatPattern()
        match = pattern.match("Dwq 211X")
        assert match is not None
        assert match.value == "DWQ 211-X"
        assert match.confidence == 0.90

    def test_exclude_postal_code(self):
        pattern = NoDashFormatPattern()
        match = pattern.match("SE 106 43")
        assert match is None  # Should be filtered out


class TestCustomerNumberParser:
    def test_parse_with_dash(self):
        parser = CustomerNumberParser()
        result, is_valid, error = parser.parse("Customer: JTY 576-3")
        assert is_valid
        assert result == "JTY 576-3"
        assert error is None

    def test_parse_without_dash(self):
        parser = CustomerNumberParser()
        result, is_valid, error = parser.parse("Dwq 211X Billo")
        assert is_valid
        assert result == "DWQ 211-X"

    def test_parse_all_finds_multiple(self):
        parser = CustomerNumberParser()
        text = "JTY 576-3 and DWQ 211X"
        matches = parser.parse_all(text)
        assert len(matches) >= 1  # At least one match
        assert matches[0].confidence >= 0.90
```

**Migration plan**:

1. Day 1 morning: create the new parser and its tests
2. Day 1 afternoon: migrate `field_extractor.py`, run the tests
3. Regression-test all document processing

---

### Step 5: Refactor `process_document` (P1, 2 days)

**Current problem**: `pipeline.py:100-250` (150 lines) has too many responsibilities

**Strategy**: Extract Method + separation of concerns

**Target structure**:

```python
def process_document(self, image_path: Path, document_id: str) -> DocumentResult:
    """Main orchestration - keep under 30 lines."""
    # 1. Run detection
    detections = self._run_yolo_detection(image_path)

    # 2. Extract fields
    fields = self._extract_fields_from_detections(detections, image_path)

    # 3. Apply cross-validation
    fields = self._apply_cross_validation(fields)

    # 4. Multi-source fusion
    fields = self._apply_multi_source_fusion(fields)

    # 5. Build result
    return self._build_document_result(document_id, fields, detections)
```

See `docs/CODE_REVIEW_REPORT.md` Section 5.3 for detailed steps.

---

### Step 6: Add Integration Tests (P0, 3 days)

**Current state**: no end-to-end integration tests

**Goal**: a complete integration test suite

**Test scenarios**:

1. PDF → inference → result verification (end to end)
2. Batch processing of multiple documents
3. API endpoint tests
4. Database integration tests
5. Error-path tests

**Steps**:

1. Create the test dataset:

```
tests/
├── fixtures/
│   ├── sample_invoices/
│   │   ├── billo_363.pdf
│   │   ├── billo_308.pdf
│   │   └── billo_310.pdf
│   └── expected_results/
│       ├── billo_363.json
│       ├── billo_308.json
│       └── billo_310.json
```

2. Create `tests/integration/test_end_to_end.py`:

```python
import json
from pathlib import Path

import pytest

from src.inference.pipeline import InferencePipeline
from src.inference.field_extractor import FieldExtractor


@pytest.fixture
def pipeline():
    """Create inference pipeline."""
    extractor = FieldExtractor()
    return InferencePipeline(
        model_path="runs/train/invoice_fields/weights/best.pt",
        confidence_threshold=0.5,
        dpi=150,
        field_extractor=extractor
    )


@pytest.fixture
def sample_invoices():
    """Load sample invoices and expected results."""
    fixtures_dir = Path(__file__).parent.parent / "fixtures"
    samples = []
    for pdf_path in (fixtures_dir / "sample_invoices").glob("*.pdf"):
        json_path = fixtures_dir / "expected_results" / f"{pdf_path.stem}.json"
        with open(json_path) as f:
            expected = json.load(f)
        samples.append({
            "pdf_path": pdf_path,
            "expected": expected
        })
    return samples


class TestEndToEnd:
    """End-to-end integration tests."""

    def test_single_document_processing(self, pipeline, sample_invoices):
        """Test processing a single invoice from PDF to extracted fields."""
        sample = sample_invoices[0]

        # Process PDF
        result = pipeline.process_pdf(
            sample["pdf_path"],
            document_id="test_001"
        )

        # Verify success
        assert result.success

        # Verify extracted fields match expected
        expected = sample["expected"]
        assert result.fields["amount"] == expected["amount"]
        assert result.fields["ocr_number"] == expected["ocr_number"]
        assert result.fields["customer_number"] == expected["customer_number"]

    def test_batch_processing(self, pipeline, sample_invoices):
        """Test batch processing multiple invoices."""
        pdf_paths = [s["pdf_path"] for s in sample_invoices]

        # Process batch
        results = pipeline.process_batch(pdf_paths)

        # Verify all processed
        assert len(results) == len(pdf_paths)

        # Verify success rate
        success_count = sum(1 for r in results if r.success)
        assert success_count >= len(pdf_paths) * 0.9  # At least 90% success

    def test_cross_validation_overrides(self, pipeline):
        """Test that payment_line values override detected values."""
        # Use sample with known discrepancy (Billo310)
        pdf_path = Path("tests/fixtures/sample_invoices/billo_310.pdf")
        result = pipeline.process_pdf(pdf_path, document_id="test_cross_val")

        # Verify payment_line was parsed
        assert "payment_line" in result.fields

        # Verify Amount was corrected from payment_line
        # (Billo310: detected 20736.00, payment_line has 736.00)
        assert result.fields["amount"] == "736.00"

    def test_error_handling_invalid_pdf(self, pipeline):
        """Test graceful error handling for invalid PDF."""
        invalid_pdf = Path("tests/fixtures/invalid.pdf")
        result = pipeline.process_pdf(invalid_pdf, document_id="test_error")

        # Should return result with success=False
        assert not result.success
        assert result.errors
        assert len(result.errors) > 0


class TestAPIIntegration:
    """API endpoint integration tests."""

    @pytest.fixture
    def client(self):
        """Create test client."""
        from fastapi.testclient import TestClient
        from src.web.app import create_app
        from src.web.config import AppConfig

        config = AppConfig.from_defaults()
        app = create_app(config)
        return TestClient(app)

    def test_health_endpoint(self, client):
        """Test /api/v1/health endpoint."""
        response = client.get("/api/v1/health")
        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "healthy"
        assert "model_loaded" in data

    def test_infer_endpoint_with_pdf(self, client, sample_invoices):
        """Test /api/v1/infer with PDF upload."""
        sample = sample_invoices[0]
        with open(sample["pdf_path"], "rb") as f:
            response = client.post(
                "/api/v1/infer",
                files={"file": ("test.pdf", f, "application/pdf")}
            )
        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "success"
        assert "result" in data
        assert "fields" in data["result"]

    def test_infer_endpoint_invalid_file(self, client):
        """Test /api/v1/infer rejects invalid file."""
        response = client.post(
            "/api/v1/infer",
            files={"file": ("test.txt", b"invalid", "text/plain")}
        )
        assert response.status_code == 400
        assert "Unsupported file type" in response.json()["detail"]


class TestDatabaseIntegration:
    """Database integration tests."""

    @pytest.fixture
    def db_connection(self):
        """Create test database connection."""
        from src.db.connection import DatabaseConnection

        # Use test database
        conn = DatabaseConnection(database="invoice_extraction_test")
        yield conn
        conn.close()

    def test_save_and_retrieve_result(self, db_connection, pipeline, sample_invoices):
        """Test saving inference result to database and retrieving it."""
        sample = sample_invoices[0]

        # Process document
        result = pipeline.process_pdf(sample["pdf_path"], document_id="test_db_001")

        # Save to database
        db_connection.save_inference_result(result)

        # Retrieve from database
        retrieved = db_connection.get_inference_result("test_db_001")

        # Verify
        assert retrieved is not None
        assert retrieved["document_id"] == "test_db_001"
        assert retrieved["fields"]["amount"] == result.fields["amount"]
```

3. Configure pytest to run integration tests:

```ini
# pytest.ini
[pytest]
markers =
    unit: Unit tests (fast, no external dependencies)
    integration: Integration tests (slower, may use database/files)
    slow: Slow tests

# Run unit tests by default
addopts = -v -m "not integration"

# Run all tests including integration:
#   pytest -m ""
# Run only integration tests:
#   pytest -m integration
```
4. CI/CD integration:

```yaml
# .github/workflows/test.yml
name: Tests

on: [push, pull_request]

jobs:
  unit-tests:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: '3.11'
      - name: Install dependencies
        run: |
          pip install -r requirements.txt
          pip install pytest pytest-cov
      - name: Run unit tests
        run: pytest -m "not integration" --cov=src --cov-report=xml
      - name: Upload coverage
        uses: codecov/codecov-action@v3
        with:
          file: ./coverage.xml

  integration-tests:
    runs-on: ubuntu-latest
    services:
      mysql:
        image: mysql:8.0
        env:
          MYSQL_ROOT_PASSWORD: test_password
          MYSQL_DATABASE: invoice_extraction_test
        ports:
          - 3306:3306
    steps:
      - uses: actions/checkout@v3
      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: '3.11'
      - name: Install dependencies
        run: |
          pip install -r requirements.txt
          pip install pytest
      - name: Run integration tests
        env:
          DB_HOST: localhost
          DB_PORT: 3306
          DB_USER: root
          DB_PASSWORD: test_password
          DB_NAME: invoice_extraction_test
        run: pytest -m integration
```

**Time allocation**:

- Day 1: prepare test data, set up the test framework
- Day 2: write end-to-end and API tests
- Day 3: database integration tests, CI/CD configuration

---

### Step 7: Standardize Exception Handling (P1, 1 day)

**Current problem**: 31 occurrences of overly broad `except Exception`

**Goal**: an exception hierarchy enabling precise catches

**Steps**:

1. Create `src/exceptions.py`:

```python
"""
Application-specific exceptions.
"""


class InvoiceExtractionError(Exception):
    """Base exception for invoice extraction errors."""
    pass


class PDFProcessingError(InvoiceExtractionError):
    """Error during PDF processing."""
    pass


class OCRError(InvoiceExtractionError):
    """Error during OCR."""
    pass


class ModelInferenceError(InvoiceExtractionError):
    """Error during model inference."""
    pass


class FieldValidationError(InvoiceExtractionError):
    """Error during field validation."""
    pass


class DatabaseError(InvoiceExtractionError):
    """Error during database operations."""
    pass


class ConfigurationError(InvoiceExtractionError):
    """Error in configuration."""
    pass
```
2. Replace broad exception catches:

```python
# Before ❌
try:
    result = process_pdf(path)
except Exception as e:
    logger.error(f"Error: {e}")
    return None

# After ✅
try:
    result = process_pdf(path)
except PDFProcessingError as e:
    logger.error(f"PDF processing failed: {e}")
    return None
except OCRError as e:
    logger.warning(f"OCR failed, trying fallback: {e}")
    result = fallback_ocr(path)
except ModelInferenceError as e:
    logger.error(f"Model inference failed: {e}")
    raise  # Re-raise for upper layer
```

3. Raise specific exceptions in each module:

```python
# src/inference/pdf_processor.py
from src.exceptions import PDFProcessingError


def convert_pdf_to_image(pdf_path: Path, dpi: int) -> list[np.ndarray]:
    try:
        images = pdf2image.convert_from_path(pdf_path, dpi=dpi)
    except Exception as e:
        raise PDFProcessingError(f"Failed to convert PDF: {e}") from e
    if not images:
        raise PDFProcessingError("PDF conversion returned no images")
    return images
```

4. Add an error-handling decorator:

```python
# src/utils/error_handling.py
import functools
import logging
from typing import Callable, Type

from src.exceptions import InvoiceExtractionError


def handle_errors(
    *exception_types: Type[Exception],
    default_return=None,
    log_error: bool = True
):
    """Decorator for standardized error handling."""
    def decorator(func: Callable):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except exception_types as e:
                if log_error:
                    logger = logging.getLogger(func.__module__)
                    logger.error(
                        f"Error in {func.__name__}: {e}",
                        exc_info=True
                    )
                return default_return
        return wrapper
    return decorator


# Usage
@handle_errors(PDFProcessingError, OCRError, default_return=None)
def safe_process_document(doc_path: Path):
    return process_document(doc_path)
```

---

### Steps 8-12: Remaining Refactoring Tasks

See `CODE_REVIEW_REPORT.md` Section 6 (Action Plan) for detailed steps.

---

## 🧪 Testing Strategy

### Test Pyramid

```
        /\
       /  \        E2E Tests (10%)
      /----\       - Full pipeline tests
     /      \      - API integration tests
    /--------\
   /          \    Integration Tests (30%)
  /------------\   - Module integration
 /              \  - Database tests
----------------   Unit Tests (60%)
                   - Function-level tests
                   - High coverage
```

### Coverage Targets

| Module | Current | Target |
|--------|---------|--------|
| `field_extractor.py` | 40% | 80% |
| `pipeline.py` | 50% | 75% |
| `payment_line_parser.py` | 0% (new) | 90% |
| `customer_number_parser.py` | 0% (new) | 90% |
| `web/routes.py` | 30% | 70% |
| `db/operations.py` | 20% | 60% |
| **Overall** | **45%** | **70%+** |

### Regression Tests

Run after every refactoring:

```bash
# 1. Unit tests
pytest tests/unit/ -v

# 2. Integration tests
pytest tests/integration/ -v

# 3. End-to-end tests (with real PDFs)
pytest tests/e2e/ -v

# 4. Performance tests (guard against regressions)
pytest tests/performance/ -v --benchmark

# 5. Coverage check
pytest --cov=src --cov-report=html --cov-fail-under=70
```

---

## ⚠️ Risk Management

### Identified Risks

| Risk | Impact | Likelihood | Mitigation |
|------|--------|------------|------------|
| Refactoring breaks existing features | High | Medium | 1. Add tests before refactoring<br>2. Small iterations<br>3. Regression tests |
| Performance regression | Medium | Low | 1. Performance baselines<br>2. Continuous monitoring<br>3. Profile-guided optimization |
| API changes break clients | High | Low | 1. Semantic versioning<br>2. Deprecation notice period<br>3. Backward compatibility |
| Database migration failure | High | Low | 1. Back up data<br>2. Staged migration<br>3. Rollback plan |
| Schedule overrun | Medium | Medium | 1. Prioritization<br>2. Weekly progress review<br>3. Adjust scope if needed |

### Rollback Plan

Every refactoring step needs an explicit rollback strategy:

1. **Code rollback**: isolate changes on Git branches

```bash
# Create a feature branch per refactoring task
git checkout -b refactor/payment-line-parser

# To roll back
git checkout main
git branch -D refactor/payment-line-parser
```

2. **Database rollback**: use a migration tool

```bash
# Apply migrations
alembic upgrade head

# Roll back one migration
alembic downgrade -1
```

3. **Config rollback**: keep compatibility with old config

```python
# Support both old and new config formats
password = config.get("db_password") or config.get("password")
```

---

## 📊 Success Metrics

### Quantitative Metrics

| Metric | Current | Target | Measured by |
|--------|---------|--------|-------------|
| Test coverage | 45% | 70%+ | `pytest --cov` |
| Average function length | 80 lines | <50 lines | `radon cc` |
| Cyclomatic complexity | max 15+ | <10 | `radon cc` |
| Code duplication | ~15% | <5% | `pylint` (`duplicate-code` checker) |
| Security issues | 2 (plaintext password, SQL injection) | 0 | manual review + `bandit` |
| Docstring coverage | 30% | 80%+ | manual review |
| Average processing time | ~2 s/document | <2 s/document | performance tests |

### Quality Gates

Every change must satisfy:

- ✅ Test coverage ≥ 70%
- ✅ All tests pass (unit + integration + E2E)
- ✅ No high-severity security issues
- ✅ Code review approved
- ✅ No performance regression (within ±5%)
- ✅ Documentation updated

---

## 📅 Timeline

### Phase 1: Urgent Fixes (Week 1)

| Day | Task | Owner | Status |
|-----|------|-------|--------|
| Day 1 | Fix plaintext password + env-var config | | ⏳ |
| Day 2-3 | Fix SQL injection + parameterized queries | | ⏳ |
| Day 4-5 | Standardize exception handling | | ⏳ |

### Phase 2: Core Refactoring (Weeks 2-4)

| Week | Task | Status |
|------|------|--------|
| Week 2 | Unify payment_line parsing + split customer_number | ⏳ |
| Week 3 | Refactor pipeline + Extract Method | ⏳ |
| Week 4 | Add integration tests + raise unit-test coverage | ⏳ |

### Phase 3: Improvements (Weeks 5-6)

| Week | Task | Status |
|------|------|--------|
| Week 5 | Batch-processing optimization + config extraction | ⏳ |
| Week 6 | Documentation + logging improvements + performance tuning | ⏳ |

---

## 🔄 Continuous Improvement

### Code Review Checklist

Before every commit:

- [ ] All tests pass
- [ ] Coverage target met
- [ ] No new security issues
- [ ] Code follows the style guide
- [ ] Functions < 50 lines
- [ ] Cyclomatic complexity < 10
- [ ] Documentation updated
- [ ] Changelog updated

### Automation

Configure pre-commit hooks:

```yaml
# .pre-commit-config.yaml
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.4.0
    hooks:
      - id: trailing-whitespace
      - id: end-of-file-fixer
      - id: check-yaml
      - id: check-added-large-files

  - repo: https://github.com/psf/black
    rev: 23.3.0
    hooks:
      - id: black
        language_version: python3.11

  - repo: https://github.com/PyCQA/flake8
    rev: 6.0.0
    hooks:
      - id: flake8
        args: [--max-line-length=88, --extend-ignore=E203]

  - repo: https://github.com/PyCQA/bandit
    rev: 1.7.5
    hooks:
      - id: bandit
        args: [-c, pyproject.toml]

  - repo: local
    hooks:
      - id: pytest-check
        name: pytest-check
        entry: pytest
        language: system
        pass_filenames: false
        always_run: true
        args: [-m, "not integration", --tb=short]
```

---

## 📚 References

### Refactoring Books

- *Refactoring: Improving the Design of Existing Code* - Martin Fowler
- *Clean Code* - Robert C. Martin
- *Working Effectively with Legacy Code* - Michael Feathers

### Design Patterns

- Strategy Pattern (customer_number patterns)
- Factory Pattern (parser creation)
- Template Method (field normalization)

### Python Best Practices

- PEP 8: Style Guide for Python Code
- PEP 257: Docstring Conventions
- Google Python Style Guide

---

## ✅ Acceptance Criteria

Definition of done for this refactoring:

1. ✅ All P0 and P1 tasks complete
2. ✅ Test coverage ≥ 70%
3. ✅ All security issues fixed
4. ✅ Code duplication < 5%
5. ✅ All long functions (>100 lines) split
6. ✅ API documentation complete
7. ✅ No performance regression
8. ✅ Successful production deployment

---

**End of document**

Next step: execute Phase 1, Day 1 - fix the plaintext password.
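
---

## Appendix: Example Test for Step 1's Config Validation

Step 1's test checklist asks for a check that configuration loading fails when the password is missing, but no test code is shown. Below is a minimal, self-contained sketch of that check. `load_database_config` is a hypothetical helper standing in for the loading logic in `src/db/config.py`; it takes the environment as a dict so the sketch needs no real environment variables or `.env` file:

```python
def load_database_config(env: dict) -> dict:
    """Build the DB config from environment values (sketch of src/db/config.py).

    Raises ValueError when DB_PASSWORD is missing, mirroring the
    startup validation added in Step 1.
    """
    config = {
        "host": env.get("DB_HOST", "localhost"),
        "port": int(env.get("DB_PORT", "3306")),
        "user": env.get("DB_USER", "root"),
        "password": env.get("DB_PASSWORD"),
        "database": env.get("DB_NAME", "invoice_extraction"),
    }
    if not config["password"]:
        raise ValueError("DB_PASSWORD environment variable not set")
    return config


# A missing password must be rejected
try:
    load_database_config({})
    raise AssertionError("expected ValueError for missing DB_PASSWORD")
except ValueError as e:
    print("rejected:", e)

# With a password set, defaults are applied for the other fields
cfg = load_database_config({"DB_PASSWORD": "s3cret", "DB_PORT": "3307"})
print(cfg["port"], cfg["host"])
```

In the real suite this would live under `tests/unit/` and patch the actual environment (for example with pytest's `monkeypatch.setenv`/`monkeypatch.delenv`) instead of passing a dict.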