Init
This commit is contained in:
28
.claude/settings.local.json
Normal file
28
.claude/settings.local.json
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
{
|
||||||
|
"permissions": {
|
||||||
|
"allow": [
|
||||||
|
"Bash(mkdir:*)",
|
||||||
|
"Bash(python:*)",
|
||||||
|
"Bash(pip install:*)",
|
||||||
|
"Bash(dir:*)",
|
||||||
|
"Bash(tesseract:*)",
|
||||||
|
"Bash(nvidia-smi:*)",
|
||||||
|
"Bash(pip uninstall:*)",
|
||||||
|
"Bash(for img in ../../images/train/*.jpg)",
|
||||||
|
"Bash(do basename=\"$img##*/\")",
|
||||||
|
"Bash(labelname=\"$basename%.jpg.txt\")",
|
||||||
|
"Bash(if [ -f \"../../temp_visual_labels/$labelname\" ])",
|
||||||
|
"Bash(then cp \"../../temp_visual_labels/$labelname\" .)",
|
||||||
|
"Bash(fi)",
|
||||||
|
"Bash(done)",
|
||||||
|
"Bash(awk:*)",
|
||||||
|
"Bash(chcp 65001)",
|
||||||
|
"Bash(PYTHONIOENCODING=utf-8 python:*)",
|
||||||
|
"Bash(timeout 600 tail:*)",
|
||||||
|
"Bash(cat:*)",
|
||||||
|
"Bash(powershell -Command \"$response = Invoke-WebRequest -Uri ''http://127.0.0.1:8000/extract_invoice/'' -Method POST -Form @{file=Get-Item ''data\\processed_images\\4BC5E5B3-E561-4A73-BC9C-46D4F08F89C3.png''} -UseBasicParsing; $response.Content\")"
|
||||||
|
],
|
||||||
|
"deny": [],
|
||||||
|
"ask": []
|
||||||
|
}
|
||||||
|
}
|
||||||
11
.gitignore
vendored
Normal file
11
.gitignore
vendored
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
# Default ignored files
|
||||||
|
/shelf/
|
||||||
|
/workspace.xml
|
||||||
|
# Editor-based HTTP Client requests
|
||||||
|
/httpRequests/
|
||||||
|
# Datasource local storage ignored files
|
||||||
|
/dataSources/
|
||||||
|
/dataSources.local.xml
|
||||||
|
/data
|
||||||
|
/.idea
|
||||||
|
/__pycache__/
|
||||||
137
config.py
Normal file
137
config.py
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
"""
|
||||||
|
Configuration file for system-dependent paths and settings.
|
||||||
|
|
||||||
|
This file contains paths that may vary between different systems.
|
||||||
|
Copy this file and modify the paths according to your local installation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# System Paths - Modify these according to your installation
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
# Poppler path (required for PDF to image conversion)
|
||||||
|
# Download from: https://github.com/oschwartz10612/poppler-windows/releases
|
||||||
|
# Example: r"C:\poppler-23.11.0\bin"
|
||||||
|
POPPLER_PATH = os.getenv("POPPLER_PATH", r"C:\Program Files\poppler-25.07.0\Library\bin")
|
||||||
|
|
||||||
|
# Tesseract path (optional - only needed if not in system PATH)
|
||||||
|
# Download from: https://github.com/UB-Mannheim/tesseract/wiki
|
||||||
|
# Example: r"C:\Program Files\Tesseract-OCR\tesseract.exe"
|
||||||
|
TESSERACT_CMD = os.getenv("TESSERACT_CMD", None)
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Project Paths - Generally don't need to modify these
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
# Project root directory
|
||||||
|
PROJECT_ROOT = Path(__file__).parent.absolute()
|
||||||
|
|
||||||
|
# Data directories
|
||||||
|
DATA_DIR = PROJECT_ROOT / "data"
|
||||||
|
RAW_INVOICES_DIR = DATA_DIR / "raw_invoices"
|
||||||
|
PROCESSED_IMAGES_DIR = DATA_DIR / "processed_images"
|
||||||
|
OCR_RESULTS_DIR = DATA_DIR / "ocr_results"
|
||||||
|
|
||||||
|
# YOLO dataset directories
|
||||||
|
YOLO_DATASET_DIR = DATA_DIR / "yolo_dataset"
|
||||||
|
YOLO_TEMP_IMAGES_DIR = YOLO_DATASET_DIR / "temp_all_images"
|
||||||
|
YOLO_TEMP_LABELS_DIR = YOLO_DATASET_DIR / "temp_all_labels"
|
||||||
|
YOLO_TRAIN_IMAGES_DIR = YOLO_DATASET_DIR / "images" / "train"
|
||||||
|
YOLO_TRAIN_LABELS_DIR = YOLO_DATASET_DIR / "labels" / "train"
|
||||||
|
YOLO_VAL_IMAGES_DIR = YOLO_DATASET_DIR / "images" / "val"
|
||||||
|
YOLO_VAL_LABELS_DIR = YOLO_DATASET_DIR / "labels" / "val"
|
||||||
|
|
||||||
|
# Model directories
|
||||||
|
MODELS_DIR = PROJECT_ROOT / "models"
|
||||||
|
DEFAULT_MODEL_PATH = MODELS_DIR / "payment_slip_detector_v1" / "weights" / "best.pt"
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# OCR Settings
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
# Tesseract language (Swedish + English)
|
||||||
|
TESSERACT_LANG = "swe" # Ensure Swedish language pack is installed
|
||||||
|
|
||||||
|
# OCR confidence threshold (0-100)
|
||||||
|
OCR_CONFIDENCE_THRESHOLD = 0
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Training Settings
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
# YOLO model size: n (nano), s (small), m (medium), l (large), x (xlarge)
|
||||||
|
YOLO_MODEL_SIZE = "n"
|
||||||
|
|
||||||
|
# Training epochs
|
||||||
|
TRAINING_EPOCHS = 100
|
||||||
|
|
||||||
|
# Batch size
|
||||||
|
BATCH_SIZE = 16
|
||||||
|
|
||||||
|
# Image size for training
|
||||||
|
IMAGE_SIZE = 640
|
||||||
|
|
||||||
|
# Validation split ratio (0.0 to 1.0)
|
||||||
|
VALIDATION_SPLIT = 0.2
|
||||||
|
|
||||||
|
# Random seed for reproducibility
|
||||||
|
RANDOM_SEED = 42
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# API Settings (for main.py FastAPI server)
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
# API host
|
||||||
|
API_HOST = "127.0.0.1"
|
||||||
|
|
||||||
|
# API port
|
||||||
|
API_PORT = 8000
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Helper Functions
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
def apply_tesseract_path():
|
||||||
|
"""Apply Tesseract path if configured."""
|
||||||
|
if TESSERACT_CMD:
|
||||||
|
import pytesseract
|
||||||
|
pytesseract.pytesseract.tesseract_cmd = TESSERACT_CMD
|
||||||
|
|
||||||
|
def validate_paths():
|
||||||
|
"""Validate that required system paths exist."""
|
||||||
|
issues = []
|
||||||
|
|
||||||
|
# Check Poppler
|
||||||
|
if not os.path.exists(POPPLER_PATH):
|
||||||
|
issues.append(f"Poppler not found at: {POPPLER_PATH}")
|
||||||
|
issues.append(" Download from: https://github.com/oschwartz10612/poppler-windows/releases")
|
||||||
|
|
||||||
|
# Check Tesseract (if specified)
|
||||||
|
if TESSERACT_CMD and not os.path.exists(TESSERACT_CMD):
|
||||||
|
issues.append(f"Tesseract not found at: {TESSERACT_CMD}")
|
||||||
|
issues.append(" Download from: https://github.com/UB-Mannheim/tesseract/wiki")
|
||||||
|
|
||||||
|
if issues:
|
||||||
|
print("Configuration Issues Found:")
|
||||||
|
for issue in issues:
|
||||||
|
print(f" {issue}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Example: Environment Variable Override
|
||||||
|
# ============================================================================
|
||||||
|
# You can set these in your environment instead of modifying this file:
|
||||||
|
#
|
||||||
|
# Windows:
|
||||||
|
# set POPPLER_PATH=C:\poppler\bin
|
||||||
|
# set TESSERACT_CMD=C:\Program Files\Tesseract-OCR\tesseract.exe
|
||||||
|
#
|
||||||
|
# Linux/Mac:
|
||||||
|
# export POPPLER_PATH=/usr/bin
|
||||||
|
# export TESSERACT_CMD=/usr/bin/tesseract
|
||||||
|
# ============================================================================
|
||||||
17
extraction_results.json
Normal file
17
extraction_results.json
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
[
|
||||||
|
{
|
||||||
|
"image": "data\\processed_images\\20250917.03.1.011299_c328f5a8-06f9-4093-85b5-e3a40f24bd30_page_1.jpg",
|
||||||
|
"fields": {},
|
||||||
|
"all_detections": []
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"image": "data\\processed_images\\64A80892-8A9E-454C-9AEB-B740E8C3ACB3_page_1.jpg",
|
||||||
|
"fields": {},
|
||||||
|
"all_detections": []
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"image": "data\\processed_images\\9fb6129f-671d-4aa1-9bad-096e84e6ded3_page_1.jpg",
|
||||||
|
"fields": {},
|
||||||
|
"all_detections": []
|
||||||
|
}
|
||||||
|
]
|
||||||
45
requirements.txt
Normal file
45
requirements.txt
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
# Core dependencies
|
||||||
|
ultralytics>=8.0.0 # YOLOv8
|
||||||
|
pytesseract>=0.3.10 # Tesseract OCR Python wrapper
|
||||||
|
|
||||||
|
# Image processing
|
||||||
|
pdf2image>=1.16.0 # PDF to image conversion
|
||||||
|
Pillow>=10.0.0 # Image manipulation
|
||||||
|
opencv-python>=4.8.0 # Computer vision
|
||||||
|
|
||||||
|
# Data handling
|
||||||
|
numpy>=1.24.0
|
||||||
|
pandas>=2.0.0
|
||||||
|
scikit-learn>=1.3.0 # For DBSCAN clustering in 02_create_labels.py
|
||||||
|
|
||||||
|
# API dependencies (for main.py)
|
||||||
|
fastapi>=0.104.0 # FastAPI web framework
|
||||||
|
uvicorn>=0.24.0 # ASGI server
|
||||||
|
python-multipart>=0.0.6 # For file upload support
|
||||||
|
|
||||||
|
# System utilities
|
||||||
|
# IMPORTANT: Requires system-level installation of:
|
||||||
|
#
|
||||||
|
# 1. Tesseract OCR:
|
||||||
|
# - Windows: Download from https://github.com/UB-Mannheim/tesseract/wiki
|
||||||
|
# After installation, add to PATH or set: pytesseract.pytesseract.tesseract_cmd = r'C:\Program Files\Tesseract-OCR\tesseract.exe'
|
||||||
|
# - Linux: sudo apt-get install tesseract-ocr tesseract-ocr-swe tesseract-ocr-eng
|
||||||
|
# - macOS: brew install tesseract tesseract-lang
|
||||||
|
#
|
||||||
|
# 2. Poppler (for pdf2image):
|
||||||
|
# - Windows: Download from https://github.com/oschwartz10612/poppler-windows/releases
|
||||||
|
# - Linux: sudo apt-get install poppler-utils
|
||||||
|
# - macOS: brew install poppler
|
||||||
|
#
|
||||||
|
# 3. Swedish language data for Tesseract:
|
||||||
|
# After installing Tesseract, you may need to download Swedish language files (swe.traineddata)
|
||||||
|
# from https://github.com/tesseract-ocr/tessdata
|
||||||
|
|
||||||
|
# Optional: GPU support
|
||||||
|
# torch>=2.0.0 # PyTorch with CUDA support
|
||||||
|
# torchvision>=0.15.0
|
||||||
|
|
||||||
|
# Development tools (optional)
|
||||||
|
# jupyter>=1.0.0
|
||||||
|
# matplotlib>=3.7.0
|
||||||
|
# seaborn>=0.12.0
|
||||||
106
scripts/01_process_invoices.py
Normal file
106
scripts/01_process_invoices.py
Normal file
@@ -0,0 +1,106 @@
|
|||||||
|
# --- scripts/01_process_invoices.py ---
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import json
|
||||||
|
import pytesseract
|
||||||
|
import pandas as pd
|
||||||
|
from pdf2image import convert_from_path
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
# Add parent directory to path to import config
|
||||||
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
from config import POPPLER_PATH, apply_tesseract_path
|
||||||
|
|
||||||
|
# Apply Tesseract path from config
|
||||||
|
apply_tesseract_path()
|
||||||
|
|
||||||
|
# 项目路径设置
|
||||||
|
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
RAW_DIR = os.path.join(BASE_DIR, "data", "raw_invoices")
|
||||||
|
IMG_DIR = os.path.join(BASE_DIR, "data", "processed_images")
|
||||||
|
OCR_DIR = os.path.join(BASE_DIR, "data", "ocr_results")
|
||||||
|
|
||||||
|
# 创建输出目录
|
||||||
|
os.makedirs(IMG_DIR, exist_ok=True)
|
||||||
|
os.makedirs(OCR_DIR, exist_ok=True)
|
||||||
|
|
||||||
|
def process_invoices():
|
||||||
|
print(f"开始处理 {RAW_DIR} 中的发票...")
|
||||||
|
|
||||||
|
for filename in os.listdir(RAW_DIR):
|
||||||
|
filepath = os.path.join(RAW_DIR, filename)
|
||||||
|
base_name = os.path.splitext(filename)[0]
|
||||||
|
img_path = os.path.join(IMG_DIR, f"{base_name}.png")
|
||||||
|
json_path = os.path.join(OCR_DIR, f"{base_name}.json")
|
||||||
|
|
||||||
|
# 防止重复处理
|
||||||
|
if os.path.exists(img_path) and os.path.exists(json_path):
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 1. 加载图像 (PDF 或 图片)
|
||||||
|
if filename.lower().endswith(".pdf"):
|
||||||
|
images = convert_from_path(filepath, poppler_path=POPPLER_PATH)
|
||||||
|
if not images:
|
||||||
|
print(f"警告: 无法从 {filename} 提取图像。")
|
||||||
|
continue
|
||||||
|
img_pil = images[0] # 取第一页
|
||||||
|
img = cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGB2BGR)
|
||||||
|
else:
|
||||||
|
img = cv2.imread(filepath)
|
||||||
|
if img is None:
|
||||||
|
print(f"警告: 无法读取图像文件 {filename}。")
|
||||||
|
continue
|
||||||
|
|
||||||
|
img_h, img_w = img.shape[:2]
|
||||||
|
|
||||||
|
# 2. 保存统一格式的图像
|
||||||
|
cv2.imwrite(img_path, img)
|
||||||
|
|
||||||
|
# 3. 运行 Tesseract OCR
|
||||||
|
ocr_data_df = pytesseract.image_to_data(
|
||||||
|
img,
|
||||||
|
lang='swe', # 确保已安装瑞典语包
|
||||||
|
output_type=pytesseract.Output.DATAFRAME
|
||||||
|
)
|
||||||
|
|
||||||
|
# 4. 清理 OCR 结果
|
||||||
|
ocr_data_df = ocr_data_df[ocr_data_df.conf > 0]
|
||||||
|
ocr_data_df.dropna(subset=['text'], inplace=True)
|
||||||
|
ocr_data_df['text'] = ocr_data_df['text'].astype(str).str.strip()
|
||||||
|
ocr_data_df = ocr_data_df[ocr_data_df['text'] != ""]
|
||||||
|
|
||||||
|
# 5. 转换为您在另一脚本中使用的 JSON 格式 (包含 text_boxes)
|
||||||
|
text_boxes = []
|
||||||
|
for i, row in ocr_data_df.iterrows():
|
||||||
|
text_boxes.append({
|
||||||
|
"text": row["text"],
|
||||||
|
"bbox": {
|
||||||
|
"x_min": row["left"],
|
||||||
|
"y_min": row["top"],
|
||||||
|
"x_max": row["left"] + row["width"],
|
||||||
|
"y_max": row["top"] + row["height"]
|
||||||
|
},
|
||||||
|
"confidence": row["conf"] / 100.0
|
||||||
|
})
|
||||||
|
|
||||||
|
output_json = {
|
||||||
|
"image_name": f"{base_name}.png",
|
||||||
|
"width": img_w,
|
||||||
|
"height": img_h,
|
||||||
|
"text_boxes": text_boxes
|
||||||
|
}
|
||||||
|
|
||||||
|
with open(json_path, 'w', encoding='utf-8') as f:
|
||||||
|
json.dump(output_json, f, indent=4, ensure_ascii=False)
|
||||||
|
|
||||||
|
print(f"已处理: {filename}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"处理 {filename} 时出错: {e}")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
process_invoices()
|
||||||
223
scripts/02_create_labels.py
Normal file
223
scripts/02_create_labels.py
Normal file
@@ -0,0 +1,223 @@
|
|||||||
|
# --- scripts/02_create_labels.py ---
|
||||||
|
#
|
||||||
|
# **重要**: 此脚本用于训练 "阶段一" 的区域检测器.
|
||||||
|
# 它只生成 1 个类别 (class_id = 0), 即 "payment_slip" 的 *整个* 区域.
|
||||||
|
# 它不使用您的高级 classify_text 逻辑, 而是使用启发式规则 (关键词, 线条) 来找到大区域.
|
||||||
|
#
|
||||||
|
|
||||||
|
import os
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
import cv2
|
||||||
|
import re
|
||||||
|
import shutil
|
||||||
|
import json
|
||||||
|
from sklearn.cluster import DBSCAN
|
||||||
|
|
||||||
|
# --- 锚点词典 (技术二) ---
|
||||||
|
# 这些是用来 *定位* 凭证区域的词, 不是用来提取的
|
||||||
|
KEYWORDS = [
|
||||||
|
"Bankgirot", "PlusGirot", "OCR-nummer", "Att betala", "Mottagare",
|
||||||
|
"Betalningsmottagare", "Tillhanda senast", "Förfallodag", "Belopp",
|
||||||
|
"BG-nr", "PG-nr", "Meddelande", "OCR-kod", "Inbetalningskort", "Betalningsavi"
|
||||||
|
]
|
||||||
|
KEYWORDS_REGEX = '|'.join(KEYWORDS)
|
||||||
|
|
||||||
|
REGEX_PATTERNS = {
|
||||||
|
'bg_pg': r'(\b\d{2,4}[- ]\d{4}\b)|(\b\d{2,7}[- ]\d\b)', # Bankgiro/PlusGiro
|
||||||
|
'long_num': r'\b\d{10,}\b', # 可能是 OCR
|
||||||
|
'machine_code': r'#[0-9\s>#]+#' # 机读码
|
||||||
|
}
|
||||||
|
FULL_REGEX = f"({KEYWORDS_REGEX})|{REGEX_PATTERNS['bg_pg']}|{REGEX_PATTERNS['long_num']}|{REGEX_PATTERNS['machine_code']}"
|
||||||
|
|
||||||
|
# --- 路径设置 ---
|
||||||
|
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
OCR_DIR = os.path.join(BASE_DIR, "data", "ocr_results")
|
||||||
|
IMG_DIR = os.path.join(BASE_DIR, "data", "processed_images")
|
||||||
|
# Use temp directories for initial label generation (before train/val split)
|
||||||
|
LABEL_DIR = os.path.join(BASE_DIR, "data", "yolo_dataset", "temp_all_labels")
|
||||||
|
IMAGE_DIR = os.path.join(BASE_DIR, "data", "yolo_dataset", "temp_all_images")
|
||||||
|
|
||||||
|
os.makedirs(LABEL_DIR, exist_ok=True)
|
||||||
|
os.makedirs(IMAGE_DIR, exist_ok=True)
|
||||||
|
|
||||||
|
def find_text_anchors(text_boxes, img_height):
|
||||||
|
"""技术一 (位置) + 技术二 (词典)"""
|
||||||
|
anchors = []
|
||||||
|
if not text_boxes:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 技术一: 只在页面下半部分 (40% 处开始) 查找
|
||||||
|
page_midpoint = img_height * 0.4
|
||||||
|
|
||||||
|
for box in text_boxes:
|
||||||
|
# 检查位置
|
||||||
|
if box["bbox"]["y_min"] > page_midpoint:
|
||||||
|
# 检查文本内容
|
||||||
|
if re.search(FULL_REGEX, box["text"], re.IGNORECASE):
|
||||||
|
bbox = box["bbox"]
|
||||||
|
anchors.append(pd.Series({
|
||||||
|
'left': bbox["x_min"],
|
||||||
|
'top': bbox["y_min"],
|
||||||
|
'width': bbox["x_max"] - bbox["x_min"],
|
||||||
|
'height': bbox["y_max"] - bbox["y_min"]
|
||||||
|
}))
|
||||||
|
return anchors
|
||||||
|
|
||||||
|
def find_visual_anchors(image, img_height, img_width):
|
||||||
|
"""技术三 (视觉锚点 - 找线)"""
|
||||||
|
anchors = []
|
||||||
|
try:
|
||||||
|
# 1. 只看下半页
|
||||||
|
crop_y_start = img_height // 2
|
||||||
|
crop = image[crop_y_start:, :]
|
||||||
|
|
||||||
|
gray = cv2.cvtColor(crop, cv2.COLOR_BGR2GRAY)
|
||||||
|
thresh = cv2.threshold(gray, 150, 255, cv2.THRESH_BINARY_INV)[1]
|
||||||
|
|
||||||
|
# 2. 查找长的水平线
|
||||||
|
min_line_length = img_width // 4
|
||||||
|
horizontal_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (min_line_length, 1))
|
||||||
|
detected_lines = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, horizontal_kernel, iterations=2)
|
||||||
|
|
||||||
|
contours, _ = cv2.findContours(detected_lines, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||||
|
|
||||||
|
for c in contours:
|
||||||
|
x, y, w, h = cv2.boundingRect(c)
|
||||||
|
if w > min_line_length: # 确保线足够长
|
||||||
|
original_y = y + crop_y_start # 转换回原图坐标
|
||||||
|
anchors.append(pd.Series({
|
||||||
|
'left': x, 'top': original_y, 'width': w, 'height': h
|
||||||
|
}))
|
||||||
|
except Exception as e:
|
||||||
|
print(f"查找视觉锚点时出错: {e}")
|
||||||
|
return anchors
|
||||||
|
|
||||||
|
def cluster_anchors(all_anchors, img_height):
|
||||||
|
"""技术四 (聚类)"""
|
||||||
|
if len(all_anchors) < 2:
|
||||||
|
return all_anchors # 锚点太少,无法聚类
|
||||||
|
|
||||||
|
# 1. 获取中心点
|
||||||
|
points = []
|
||||||
|
for anchor in all_anchors:
|
||||||
|
x_center = anchor.left + anchor.width / 2
|
||||||
|
y_center = anchor.top + anchor.height / 2
|
||||||
|
points.append((x_center, y_center))
|
||||||
|
|
||||||
|
points = np.array(points)
|
||||||
|
|
||||||
|
# 2. 运行 DBSCAN
|
||||||
|
eps_dist = img_height * 0.2
|
||||||
|
clustering = DBSCAN(eps=eps_dist, min_samples=2).fit(points)
|
||||||
|
labels = clustering.labels_
|
||||||
|
|
||||||
|
if len(labels) == 0 or np.all(labels == -1):
|
||||||
|
return all_anchors # 聚类失败
|
||||||
|
|
||||||
|
# 3. 找到最大的那个簇
|
||||||
|
unique_labels, counts = np.unique(labels[labels != -1], return_counts=True)
|
||||||
|
if len(counts) == 0:
|
||||||
|
return all_anchors # 只有噪声
|
||||||
|
|
||||||
|
largest_cluster_label = unique_labels[np.argmax(counts)]
|
||||||
|
|
||||||
|
# 4. 只返回属于最大簇的锚点
|
||||||
|
main_cluster_anchors = [
|
||||||
|
anchor for i, anchor in enumerate(all_anchors)
|
||||||
|
if labels[i] == largest_cluster_label
|
||||||
|
]
|
||||||
|
return main_cluster_anchors
|
||||||
|
|
||||||
|
def create_labels():
|
||||||
|
print("开始生成弱标签 (用于区域检测器)...")
|
||||||
|
processed_count = 0
|
||||||
|
skipped_count = 0
|
||||||
|
|
||||||
|
for ocr_filename in os.listdir(OCR_DIR):
|
||||||
|
if not ocr_filename.endswith(".json"):
|
||||||
|
continue
|
||||||
|
|
||||||
|
base_name = os.path.splitext(ocr_filename)[0]
|
||||||
|
json_path = os.path.join(OCR_DIR, ocr_filename)
|
||||||
|
img_path = os.path.join(IMG_DIR, f"{base_name}.png")
|
||||||
|
|
||||||
|
if not os.path.exists(img_path):
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 1. 加载数据
|
||||||
|
image = cv2.imread(img_path)
|
||||||
|
img_h, img_w = image.shape[:2]
|
||||||
|
with open(json_path, 'r', encoding='utf-8') as f:
|
||||||
|
ocr_data = json.load(f)
|
||||||
|
text_boxes = ocr_data.get("text_boxes", [])
|
||||||
|
|
||||||
|
# 2. 查找所有锚点
|
||||||
|
text_anchors = find_text_anchors(text_boxes, img_h)
|
||||||
|
visual_anchors = find_visual_anchors(image, img_h, img_w)
|
||||||
|
all_anchors = text_anchors + visual_anchors
|
||||||
|
|
||||||
|
if not all_anchors:
|
||||||
|
print(f"SKIPPING: {base_name} (未找到任何锚点)")
|
||||||
|
skipped_count += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 3. 聚类锚点
|
||||||
|
final_anchors = cluster_anchors(all_anchors, img_h)
|
||||||
|
|
||||||
|
if not final_anchors:
|
||||||
|
print(f"SKIPPING: {base_name} (未找到有效聚类)")
|
||||||
|
skipped_count += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 4. 聚合坐标
|
||||||
|
min_x = min(a.left for a in final_anchors)
|
||||||
|
min_y = min(a.top for a in final_anchors)
|
||||||
|
max_x = max(a.left + a.width for a in final_anchors)
|
||||||
|
max_y = max(a.top + a.height for a in final_anchors)
|
||||||
|
|
||||||
|
# 5. 添加边距 (Padding)
|
||||||
|
padding = 10
|
||||||
|
min_x = max(0, min_x - padding)
|
||||||
|
min_y = max(0, min_y - padding)
|
||||||
|
max_x = min(img_w, max_x + padding)
|
||||||
|
max_y = min(img_h, max_y + padding)
|
||||||
|
|
||||||
|
# 6. 转换为 YOLO 格式
|
||||||
|
box_w = max_x - min_x
|
||||||
|
box_h = max_y - min_y
|
||||||
|
x_center = (min_x + box_w / 2) / img_w
|
||||||
|
y_center = (min_y + box_h / 2) / img_h
|
||||||
|
norm_w = box_w / img_w
|
||||||
|
norm_h = box_h / img_h
|
||||||
|
|
||||||
|
# ** 关键: 类别 ID 永远是 0 **
|
||||||
|
class_id = 0 # 唯一的类别: 'payment_slip'
|
||||||
|
|
||||||
|
yolo_label = f"{class_id} {x_center} {y_center} {norm_w} {norm_h}\n"
|
||||||
|
|
||||||
|
# 7. 保存标签和图片
|
||||||
|
label_path = os.path.join(LABEL_DIR, f"{base_name}.txt")
|
||||||
|
with open(label_path, 'w', encoding='utf-8') as f:
|
||||||
|
f.write(yolo_label)
|
||||||
|
|
||||||
|
shutil.copy(
|
||||||
|
img_path,
|
||||||
|
os.path.join(IMAGE_DIR, f"{base_name}.png")
|
||||||
|
)
|
||||||
|
|
||||||
|
processed_count += 1
|
||||||
|
if processed_count % 20 == 0:
|
||||||
|
print(f"已生成 {processed_count} 个标签...")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"处理 {base_name} 时出错: {e}")
|
||||||
|
skipped_count += 1
|
||||||
|
|
||||||
|
print("--- 弱标签生成完成 ---")
|
||||||
|
print(f"成功生成: {processed_count} 个")
|
||||||
|
print(f"跳过: {skipped_count} 个")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
create_labels()
|
||||||
123
scripts/03_split_dataset.py
Normal file
123
scripts/03_split_dataset.py
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
"""
|
||||||
|
Dataset Split Script - Step 3
|
||||||
|
Splits images and labels into training and validation sets
|
||||||
|
"""
|
||||||
|
|
||||||
|
import shutil
|
||||||
|
import random
|
||||||
|
from pathlib import Path
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Paths
|
||||||
|
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
YOLO_DATASET_DIR = Path(BASE_DIR + "/data/yolo_dataset")
|
||||||
|
TEMP_IMAGES_DIR = YOLO_DATASET_DIR / "temp_all_images"
|
||||||
|
TEMP_LABELS_DIR = YOLO_DATASET_DIR / "temp_all_labels"
|
||||||
|
|
||||||
|
TRAIN_IMAGES_DIR = YOLO_DATASET_DIR / "images" / "train"
|
||||||
|
VAL_IMAGES_DIR = YOLO_DATASET_DIR / "images" / "val"
|
||||||
|
TRAIN_LABELS_DIR = YOLO_DATASET_DIR / "labels" / "train"
|
||||||
|
VAL_LABELS_DIR = YOLO_DATASET_DIR / "labels" / "val"
|
||||||
|
|
||||||
|
# Configuration
|
||||||
|
VALIDATION_SPLIT = 0.2 # 20% for validation
|
||||||
|
RANDOM_SEED = 42
|
||||||
|
|
||||||
|
|
||||||
|
def split_dataset(val_split=VALIDATION_SPLIT, seed=RANDOM_SEED):
|
||||||
|
"""
|
||||||
|
Split dataset into training and validation sets
|
||||||
|
|
||||||
|
Args:
|
||||||
|
val_split: Fraction of data to use for validation (0.0 to 1.0)
|
||||||
|
seed: Random seed for reproducibility
|
||||||
|
"""
|
||||||
|
print("="*60)
|
||||||
|
print("Splitting Dataset into Train/Val Sets")
|
||||||
|
print("="*60)
|
||||||
|
print(f"Validation split: {val_split*100:.1f}%")
|
||||||
|
print(f"Random seed: {seed}\n")
|
||||||
|
|
||||||
|
# Check if temp directories exist
|
||||||
|
if not TEMP_IMAGES_DIR.exists() or not TEMP_LABELS_DIR.exists():
|
||||||
|
print(BASE_DIR)
|
||||||
|
print(YOLO_DATASET_DIR)
|
||||||
|
print(TEMP_IMAGES_DIR)
|
||||||
|
print(f"Error: Temporary directories not found")
|
||||||
|
print(f"Please run 02_create_labels.py first")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Get all image files
|
||||||
|
image_files = list(TEMP_IMAGES_DIR.glob("*.jpg")) + list(TEMP_IMAGES_DIR.glob("*.png"))
|
||||||
|
|
||||||
|
if not image_files:
|
||||||
|
print(f"No image files found in {TEMP_IMAGES_DIR}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Filter images that have corresponding labels
|
||||||
|
valid_pairs = []
|
||||||
|
for image_file in image_files:
|
||||||
|
label_file = TEMP_LABELS_DIR / (image_file.stem + ".txt")
|
||||||
|
if label_file.exists():
|
||||||
|
valid_pairs.append({
|
||||||
|
"image": image_file,
|
||||||
|
"label": label_file
|
||||||
|
})
|
||||||
|
|
||||||
|
if not valid_pairs:
|
||||||
|
print("No valid image-label pairs found")
|
||||||
|
return
|
||||||
|
|
||||||
|
print(f"Found {len(valid_pairs)} image-label pair(s)")
|
||||||
|
|
||||||
|
# Shuffle and split
|
||||||
|
random.seed(seed)
|
||||||
|
random.shuffle(valid_pairs)
|
||||||
|
|
||||||
|
split_index = int(len(valid_pairs) * (1 - val_split))
|
||||||
|
train_pairs = valid_pairs[:split_index]
|
||||||
|
val_pairs = valid_pairs[split_index:]
|
||||||
|
|
||||||
|
print(f"\nSplit results:")
|
||||||
|
print(f" Training set: {len(train_pairs)} samples")
|
||||||
|
print(f" Validation set: {len(val_pairs)} samples")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# Clear existing train/val directories
|
||||||
|
for directory in [TRAIN_IMAGES_DIR, VAL_IMAGES_DIR, TRAIN_LABELS_DIR, VAL_LABELS_DIR]:
|
||||||
|
if directory.exists():
|
||||||
|
shutil.rmtree(directory)
|
||||||
|
directory.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Copy training files
|
||||||
|
print("Copying training files...")
|
||||||
|
for pair in train_pairs:
|
||||||
|
shutil.copy(pair["image"], TRAIN_IMAGES_DIR / pair["image"].name)
|
||||||
|
shutil.copy(pair["label"], TRAIN_LABELS_DIR / pair["label"].name)
|
||||||
|
print(f" Copied {len(train_pairs)} image-label pairs to train/")
|
||||||
|
|
||||||
|
# Copy validation files
|
||||||
|
print("Copying validation files...")
|
||||||
|
for pair in val_pairs:
|
||||||
|
shutil.copy(pair["image"], VAL_IMAGES_DIR / pair["image"].name)
|
||||||
|
shutil.copy(pair["label"], VAL_LABELS_DIR / pair["label"].name)
|
||||||
|
print(f" Copied {len(val_pairs)} image-label pairs to val/")
|
||||||
|
|
||||||
|
print("\n" + "="*60)
|
||||||
|
print("Dataset split complete!")
|
||||||
|
print(f"\nDataset structure:")
|
||||||
|
print(f" {TRAIN_IMAGES_DIR}")
|
||||||
|
print(f" {TRAIN_LABELS_DIR}")
|
||||||
|
print(f" {VAL_IMAGES_DIR}")
|
||||||
|
print(f" {VAL_LABELS_DIR}")
|
||||||
|
print(f"\nNext step: Run 04_train_yolo.py to train the model")
|
||||||
|
print("="*60)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Main function"""
|
||||||
|
split_dataset()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
230
scripts/04_train_yolo.py
Normal file
230
scripts/04_train_yolo.py
Normal file
@@ -0,0 +1,230 @@
|
|||||||
|
"""
|
||||||
|
YOLO Training Script - Step 4
|
||||||
|
Trains YOLOv8 model on the prepared invoice dataset
|
||||||
|
"""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from ultralytics import YOLO
|
||||||
|
import torch
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Paths
|
||||||
|
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
DATASET_YAML = Path(BASE_DIR + "/data/yolo_dataset/dataset.yaml")
|
||||||
|
MODELS_DIR = Path(BASE_DIR + "/models")
|
||||||
|
|
||||||
|
# Training configuration
|
||||||
|
MODEL_SIZE = "n" # Options: n (nano), s (small), m (medium), l (large), x (xlarge)
|
||||||
|
EPOCHS = 100
|
||||||
|
BATCH_SIZE = 16
|
||||||
|
IMAGE_SIZE = 640
|
||||||
|
DEVICE = 0 if torch.cuda.is_available() else "cpu" # Use GPU with PyTorch 2.7 + CUDA 12.8
|
||||||
|
|
||||||
|
# Create models directory
|
||||||
|
MODELS_DIR.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
|
def train_model(
|
||||||
|
model_size=MODEL_SIZE,
|
||||||
|
epochs=EPOCHS,
|
||||||
|
batch_size=BATCH_SIZE,
|
||||||
|
img_size=IMAGE_SIZE,
|
||||||
|
device=DEVICE
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Train YOLOv8 model on invoice dataset
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_size: Size of YOLO model (n, s, m, l, x)
|
||||||
|
epochs: Number of training epochs
|
||||||
|
batch_size: Batch size for training
|
||||||
|
img_size: Input image size
|
||||||
|
device: Device to use for training (cuda or cpu)
|
||||||
|
"""
|
||||||
|
print("="*60)
|
||||||
|
print("YOLOv8 Invoice Detection Training")
|
||||||
|
print("="*60)
|
||||||
|
|
||||||
|
# Check if dataset.yaml exists
|
||||||
|
if not DATASET_YAML.exists():
|
||||||
|
print(f"Error: {DATASET_YAML} not found")
|
||||||
|
print("Please ensure the dataset.yaml file exists")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Print configuration
|
||||||
|
print(f"\nConfiguration:")
|
||||||
|
print(f" Model: YOLOv8{model_size}")
|
||||||
|
print(f" Epochs: {epochs}")
|
||||||
|
print(f" Batch size: {batch_size}")
|
||||||
|
print(f" Image size: {img_size}")
|
||||||
|
print(f" Device: {device}")
|
||||||
|
print(f" Dataset config: {DATASET_YAML}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# Initialize model
|
||||||
|
print(f"Loading YOLOv8{model_size} model...")
|
||||||
|
model = YOLO(f"yolov8{model_size}.pt") # Load pretrained model
|
||||||
|
|
||||||
|
# Print device info
|
||||||
|
if device == 0:
|
||||||
|
print(f"Using GPU: {torch.cuda.get_device_name(0)}")
|
||||||
|
print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
|
||||||
|
else:
|
||||||
|
print("Using CPU (training will be slower)")
|
||||||
|
|
||||||
|
print("\nStarting training...")
|
||||||
|
print("-" * 60)
|
||||||
|
|
||||||
|
# Train the model
|
||||||
|
results = model.train(
|
||||||
|
data=str(DATASET_YAML),
|
||||||
|
epochs=epochs,
|
||||||
|
imgsz=img_size,
|
||||||
|
batch=batch_size,
|
||||||
|
device=device,
|
||||||
|
project=str(MODELS_DIR),
|
||||||
|
name="payment_slip_detector_v1",
|
||||||
|
exist_ok=True,
|
||||||
|
patience=20, # Early stopping patience
|
||||||
|
save=True,
|
||||||
|
save_period=10, # Save checkpoint every 10 epochs
|
||||||
|
verbose=True,
|
||||||
|
plots=True # Generate training plots
|
||||||
|
)
|
||||||
|
|
||||||
|
print("\n" + "="*60)
|
||||||
|
print("Training complete!")
|
||||||
|
print("="*60)
|
||||||
|
|
||||||
|
# Print results
|
||||||
|
best_model_path = MODELS_DIR / "payment_slip_detector_v1" / "weights" / "best.pt"
|
||||||
|
last_model_path = MODELS_DIR / "payment_slip_detector_v1" / "weights" / "last.pt"
|
||||||
|
|
||||||
|
print(f"\nTrained models saved to:")
|
||||||
|
print(f" Best model: {best_model_path}")
|
||||||
|
print(f" Last model: {last_model_path}")
|
||||||
|
|
||||||
|
print(f"\nTraining plots saved to:")
|
||||||
|
print(f" {MODELS_DIR / 'payment_slip_detector_v1'}")
|
||||||
|
|
||||||
|
print("\n" + "="*60)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_model(model_path=None):
|
||||||
|
"""
|
||||||
|
Validate trained model on validation set
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_path: Path to model weights (default: best.pt from last training)
|
||||||
|
"""
|
||||||
|
if model_path is None:
|
||||||
|
model_path = MODELS_DIR / "payment_slip_detector_v1" / "weights" / "best.pt"
|
||||||
|
|
||||||
|
if not Path(model_path).exists():
|
||||||
|
print(f"Error: Model not found at {model_path}")
|
||||||
|
print("Please train a model first")
|
||||||
|
return
|
||||||
|
|
||||||
|
print("="*60)
|
||||||
|
print("Validating Model")
|
||||||
|
print("="*60)
|
||||||
|
print(f"Model: {model_path}\n")
|
||||||
|
|
||||||
|
# Load model
|
||||||
|
model = YOLO(str(model_path))
|
||||||
|
|
||||||
|
# Validate
|
||||||
|
results = model.val(data=str(DATASET_YAML))
|
||||||
|
|
||||||
|
print("\n" + "="*60)
|
||||||
|
print("Validation complete!")
|
||||||
|
print("="*60)
|
||||||
|
|
||||||
|
|
||||||
|
def predict_sample(model_path=None, image_path=None, conf_threshold=0.25):
|
||||||
|
"""
|
||||||
|
Run prediction on a sample image
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_path: Path to model weights
|
||||||
|
image_path: Path to image to predict on
|
||||||
|
conf_threshold: Confidence threshold for detections
|
||||||
|
"""
|
||||||
|
if model_path is None:
|
||||||
|
model_path = MODELS_DIR / "payment_slip_detector_v1" / "weights" / "best.pt"
|
||||||
|
|
||||||
|
if not Path(model_path).exists():
|
||||||
|
print(f"Error: Model not found at {model_path}")
|
||||||
|
return
|
||||||
|
|
||||||
|
if image_path is None:
|
||||||
|
# Try to get a sample from validation set
|
||||||
|
val_images_dir = Path("data/yolo_dataset/images/val")
|
||||||
|
sample_images = list(val_images_dir.glob("*.jpg")) + list(val_images_dir.glob("*.png"))
|
||||||
|
if sample_images:
|
||||||
|
image_path = sample_images[0]
|
||||||
|
else:
|
||||||
|
print("No sample images found")
|
||||||
|
return
|
||||||
|
|
||||||
|
print("="*60)
|
||||||
|
print("Running Prediction")
|
||||||
|
print("="*60)
|
||||||
|
print(f"Model: {model_path}")
|
||||||
|
print(f"Image: {image_path}")
|
||||||
|
print(f"Confidence threshold: {conf_threshold}\n")
|
||||||
|
|
||||||
|
# Load model
|
||||||
|
model = YOLO(str(model_path))
|
||||||
|
|
||||||
|
# Predict
|
||||||
|
results = model.predict(
|
||||||
|
source=str(image_path),
|
||||||
|
conf=conf_threshold,
|
||||||
|
save=True,
|
||||||
|
project=str(MODELS_DIR / "predictions"),
|
||||||
|
name="sample"
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"\nPrediction saved to: {MODELS_DIR / 'predictions' / 'sample'}")
|
||||||
|
print("="*60)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Main training function"""
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description="Train YOLOv8 on invoice dataset")
|
||||||
|
parser.add_argument("--mode", type=str, default="train", choices=["train", "validate", "predict"],
|
||||||
|
help="Mode: train, validate, or predict")
|
||||||
|
parser.add_argument("--model-size", type=str, default=MODEL_SIZE,
|
||||||
|
choices=["n", "s", "m", "l", "x"],
|
||||||
|
help="YOLO model size")
|
||||||
|
parser.add_argument("--epochs", type=int, default=EPOCHS,
|
||||||
|
help="Number of training epochs")
|
||||||
|
parser.add_argument("--batch-size", type=int, default=BATCH_SIZE,
|
||||||
|
help="Batch size")
|
||||||
|
parser.add_argument("--img-size", type=int, default=IMAGE_SIZE,
|
||||||
|
help="Image size")
|
||||||
|
parser.add_argument("--model-path", type=str, default=None,
|
||||||
|
help="Path to model weights (for validate/predict)")
|
||||||
|
parser.add_argument("--image-path", type=str, default=None,
|
||||||
|
help="Path to image (for predict)")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.mode == "train":
|
||||||
|
train_model(
|
||||||
|
model_size=args.model_size,
|
||||||
|
epochs=args.epochs,
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
img_size=args.img_size
|
||||||
|
)
|
||||||
|
elif args.mode == "validate":
|
||||||
|
validate_model(model_path=args.model_path)
|
||||||
|
elif args.mode == "predict":
|
||||||
|
predict_sample(model_path=args.model_path, image_path=args.image_path)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
237
scripts/main.py
Normal file
237
scripts/main.py
Normal file
@@ -0,0 +1,237 @@
|
|||||||
|
# --- main.py (已升级) ---
|
||||||
|
|
||||||
|
from fastapi import FastAPI, UploadFile, File, HTTPException
|
||||||
|
from ultralytics import YOLO
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import pytesseract
|
||||||
|
import re
|
||||||
|
import io
|
||||||
|
import os
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
|
# --- 配置 ---
|
||||||
|
# TODO: 确保此路径指向您训练好的最佳模型
|
||||||
|
MODEL_PATH = os.path.join(
|
||||||
|
os.path.dirname(os.path.abspath(__file__)),
|
||||||
|
"models", "invoice_detector_v1", "weights", "best.pt"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 定义一个字典来在 FastAPI 启动时加载模型
|
||||||
|
ml_models = {}
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
# 启动时加载模型
|
||||||
|
print(MODEL_PATH)
|
||||||
|
if not os.path.exists(MODEL_PATH):
|
||||||
|
print(f"警告: 找不到模型 {MODEL_PATH}。API 将无法工作。")
|
||||||
|
ml_models["yolo"] = None
|
||||||
|
else:
|
||||||
|
# 加载您的 "payment_slip" 区域检测器
|
||||||
|
ml_models["yolo"] = YOLO(MODEL_PATH)
|
||||||
|
print("YOLOv8 区域检测模型加载成功。")
|
||||||
|
yield
|
||||||
|
# 清理模型
|
||||||
|
ml_models.clear()
|
||||||
|
|
||||||
|
app = FastAPI(lifespan=lifespan)
|
||||||
|
|
||||||
|
# --- Luhn (Modulus 10) 校验函数 ---
|
||||||
|
def luhn_validate(number_str: str, expected_check_digit: str) -> bool:
|
||||||
|
"""
|
||||||
|
使用 Modulus 10 (Luhn) 算法验证一个数字字符串。
|
||||||
|
(从右到左, 权重 1, 2, 1, 2...)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
digits = [int(d) for d in number_str]
|
||||||
|
weights = [1, 2] * (len(digits) // 2 + 1)
|
||||||
|
weights = weights[:len(digits)] # 确保权重列表长度一致
|
||||||
|
|
||||||
|
sum_val = 0
|
||||||
|
# 从右到左计算
|
||||||
|
for d, w in zip(reversed(digits), weights):
|
||||||
|
product = d * w
|
||||||
|
if product >= 10:
|
||||||
|
sum_val += (product // 10) + (product % 10)
|
||||||
|
else:
|
||||||
|
sum_val += product
|
||||||
|
|
||||||
|
calculated_check_digit = (10 - (sum_val % 10)) % 10
|
||||||
|
return str(calculated_check_digit) == expected_check_digit
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# --- 提取逻辑 ---
|
||||||
|
|
||||||
|
def parse_ocr_rad(ocr_text: str) -> dict:
|
||||||
|
"""
|
||||||
|
Plan A: 尝试解析机器可读码 (OCR-rad)
|
||||||
|
示例: # 400299582421 # 4603 00 7 > 48180020 #14#
|
||||||
|
"""
|
||||||
|
# 移除所有空格以简化匹配
|
||||||
|
text_no_space = re.sub(r'\s+', '', ocr_text)
|
||||||
|
|
||||||
|
# 定义一个更健壮的正则表达式
|
||||||
|
# 组1: OCR号 (在 #...# 之间)
|
||||||
|
# 组2: 金额 (Kronor + Öre) (在 #... 之后)
|
||||||
|
# 组3: 校验码 (1位数字)
|
||||||
|
# 组4: 账户号 (在 >...# 之间)
|
||||||
|
rad_regex = r'#([\d>]+)#(\d{2,})(\d)(\d{1})>([\d]+)#'
|
||||||
|
|
||||||
|
match = re.search(rad_regex, text_no_space)
|
||||||
|
|
||||||
|
if not match:
|
||||||
|
return None # 未找到机读码行
|
||||||
|
|
||||||
|
try:
|
||||||
|
ocr_num = match.group(1).replace(">", "") # 移除 > (如果有)
|
||||||
|
amount_base = match.group(2) + match.group(3) # "4603" + "00" = "460300"
|
||||||
|
check_digit = match.group(4) # "7"
|
||||||
|
account = match.group(5) # "48180020"
|
||||||
|
|
||||||
|
# 运行Luhn校验
|
||||||
|
if luhn_validate(amount_base, check_digit):
|
||||||
|
# 校验成功! 这是高置信度数据
|
||||||
|
amount_kronor = amount_base[:-2]
|
||||||
|
amount_ore = amount_base[-2:]
|
||||||
|
|
||||||
|
return {
|
||||||
|
"source": "OCR-rad (High Confidence)",
|
||||||
|
"ocr_number": ocr_num,
|
||||||
|
"amount_due": f"{amount_kronor}.{amount_ore}",
|
||||||
|
"bankgiro_plusgiro": account,
|
||||||
|
"due_date": None # 机读码行通常不包含日期
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
# 校验失败!
|
||||||
|
print(f"Luhn 校验失败: 基础={amount_base}, 期望={check_digit}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"解析 OCR-rad 时出错: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def parse_human_readable(ocr_text: str) -> dict:
|
||||||
|
"""
|
||||||
|
Plan B: 回退到人工可读区域
|
||||||
|
(这里我们使用之前版本中的简单 Regex,
|
||||||
|
您也可以替换为您那个更复杂的 classify_text 逻辑)
|
||||||
|
"""
|
||||||
|
data = {"source": "Human-Readable (Fallback)"}
|
||||||
|
|
||||||
|
# 查找 BG/PG (Bankgiro/Plusgiro)
|
||||||
|
bg_match = re.search(r'Bankgiro\D*(\d{2,4}[- ]\d{4})', ocr_text, re.IGNORECASE)
|
||||||
|
pg_match = re.search(r'PlusGiro\D*(\d{2,7}[- ]\d)', ocr_text, re.IGNORECASE)
|
||||||
|
if bg_match:
|
||||||
|
data["bankgiro_plusgiro"] = bg_match.group(1).replace(" ", "")
|
||||||
|
elif pg_match:
|
||||||
|
data["bankgiro_plusgiro"] = pg_match.group(1).replace(" ", "")
|
||||||
|
else:
|
||||||
|
# 备用查找
|
||||||
|
bg_pg_alt = re.search(r'(\b\d{2,4}[- ]\d{4}\b)|(\b\d{2,7}[- ]\d\b)', ocr_text)
|
||||||
|
if bg_pg_alt:
|
||||||
|
data["bankgiro_plusgiro"] = bg_pg_alt.group(0).replace(" ", "")
|
||||||
|
|
||||||
|
# 查找 OCR
|
||||||
|
ocr_match = re.search(r'(OCR|Fakturanummer|Referens)\D*(\d[\d\s]{5,}\d)', ocr_text, re.IGNORECASE)
|
||||||
|
if ocr_match:
|
||||||
|
data["ocr_number"] = re.sub(r'\s', '', ocr_match.group(2))
|
||||||
|
|
||||||
|
# 查找金额
|
||||||
|
amount_match = re.search(r'(Att betala|Belopp)\D*([\d\s,.]+)\s*(kr|SEK)?', ocr_text, re.IGNORECASE)
|
||||||
|
if amount_match:
|
||||||
|
amount_str = amount_match.group(2).strip().replace(" ", "").replace(",", ".")
|
||||||
|
if amount_str.count('.') > 1:
|
||||||
|
amount_str = amount_str.replace(".", "", amount_str.count('.') - 1)
|
||||||
|
data["amount_due"] = amount_str
|
||||||
|
|
||||||
|
# 查找截止日期
|
||||||
|
date_match = re.search(r'(senast|Förfallodag)\D*(\d{4}[- ]\d{2}[- ]\d{2})', ocr_text, re.IGNORECASE)
|
||||||
|
if date_match:
|
||||||
|
data["due_date"] = date_match.group(2).replace(" ", "-")
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def extract_info_from_crop(crop_image: np.ndarray) -> dict:
|
||||||
|
"""
|
||||||
|
主提取函数: 执行 Plan A 和 Plan B
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 1. 预处理并运行 OCR (获取所有文本)
|
||||||
|
gray = cv2.cvtColor(crop_image, cv2.COLOR_BGR2GRAY)
|
||||||
|
thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)[1]
|
||||||
|
|
||||||
|
# 使用 psm 6 假设是一个统一的文本块, 这对 OCR-rad 很友好
|
||||||
|
ocr_text = pytesseract.image_to_string(thresh, lang='swe', config='--psm 6')
|
||||||
|
|
||||||
|
# 2. --- Plan A: 尝试解析机读码 ---
|
||||||
|
rad_data = parse_ocr_rad(ocr_text)
|
||||||
|
|
||||||
|
if rad_data:
|
||||||
|
# Plan A 成功!
|
||||||
|
return rad_data
|
||||||
|
|
||||||
|
# 3. --- Plan B: Plan A 失败, 回退到人工读取 ---
|
||||||
|
# 我们重新运行 OCR, 使用 psm 3 (自动布局), 这对人工区域更友好
|
||||||
|
ocr_text_human = pytesseract.image_to_string(thresh, lang='swe', config='--psm 3')
|
||||||
|
|
||||||
|
human_data = parse_human_readable(ocr_text_human)
|
||||||
|
|
||||||
|
# 即使回退失败, 也返回空字典 (或部分数据)
|
||||||
|
return human_data
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return {"error": f"提取时出错: {e}"}
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/extract_invoice/")
|
||||||
|
async def extract_invoice_data(file: UploadFile = File(...)):
|
||||||
|
|
||||||
|
if ml_models.get("yolo") is None:
|
||||||
|
raise HTTPException(status_code=503, detail="模型未加载。请检查模型路径。")
|
||||||
|
|
||||||
|
try:
|
||||||
|
contents = await file.read()
|
||||||
|
nparr = np.frombuffer(contents, np.uint8)
|
||||||
|
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
|
||||||
|
if img is None:
|
||||||
|
raise HTTPException(status_code=400, detail="无法解码图像文件。")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=400, detail=f"读取文件时出错: {e}")
|
||||||
|
|
||||||
|
# 1. 运行 YOLO 检测 (阶段一)
|
||||||
|
results = ml_models["yolo"](img, verbose=False)
|
||||||
|
|
||||||
|
all_extractions = []
|
||||||
|
|
||||||
|
if not results or not results[0].boxes:
|
||||||
|
return {"message": "未在图像中检测到支付凭证区域。"}
|
||||||
|
|
||||||
|
for res in results:
|
||||||
|
# 遍历所有检测到的凭证 (通常只有一个)
|
||||||
|
for box_coords in res.boxes.xyxy.cpu().numpy().astype(int):
|
||||||
|
xmin, ymin, xmax, ymax = box_coords
|
||||||
|
|
||||||
|
# 2. 裁剪图像
|
||||||
|
crop = img[ymin:ymax, xmin:xmax]
|
||||||
|
|
||||||
|
# 3. 在裁剪图上运行 "Plan A/B" 提取 (阶段二)
|
||||||
|
extracted_data = extract_info_from_crop(crop)
|
||||||
|
extracted_data["bounding_box"] = [xmin, ymin, xmax, ymax]
|
||||||
|
all_extractions.append(extracted_data)
|
||||||
|
|
||||||
|
if not all_extractions:
|
||||||
|
return {"message": "检测到支付凭证, 但未能提取任何信息。"}
|
||||||
|
|
||||||
|
return {"invoice_extractions": all_extractions}
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import uvicorn
|
||||||
|
print(f"--- 启动 FastAPI 服务 ---")
|
||||||
|
print(f"加载模型: {MODEL_PATH}")
|
||||||
|
print(f"访问 http://127.0.0.1:8000/docs 查看 API 文档")
|
||||||
|
uvicorn.run(app, host="127.0.0.1", port=8000)
|
||||||
14
test_api.py
Normal file
14
test_api.py
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
import requests
|
||||||
|
import json
|
||||||
|
|
||||||
|
# Test the API with the problematic invoice
|
||||||
|
url = "http://127.0.0.1:8000/extract_invoice/"
|
||||||
|
file_path = r"data\processed_images\4BC5E5B3-E561-4A73-BC9C-46D4F08F89C3.png"
|
||||||
|
|
||||||
|
with open(file_path, 'rb') as f:
|
||||||
|
files = {'file': f}
|
||||||
|
response = requests.post(url, files=files)
|
||||||
|
|
||||||
|
print("Status Code:", response.status_code)
|
||||||
|
print("\nResponse JSON:")
|
||||||
|
print(json.dumps(response.json(), indent=2, ensure_ascii=False))
|
||||||
Reference in New Issue
Block a user