Yaojia Wang
2025-10-26 20:41:11 +01:00
commit dafa86c588
11 changed files with 1171 additions and 0 deletions

View File

@@ -0,0 +1,28 @@
{
"permissions": {
"allow": [
"Bash(mkdir:*)",
"Bash(python:*)",
"Bash(pip install:*)",
"Bash(dir:*)",
"Bash(tesseract:*)",
"Bash(nvidia-smi:*)",
"Bash(pip uninstall:*)",
"Bash(for img in ../../images/train/*.jpg)",
"Bash(do basename=\"$img##*/\")",
"Bash(labelname=\"$basename%.jpg.txt\")",
"Bash(if [ -f \"../../temp_visual_labels/$labelname\" ])",
"Bash(then cp \"../../temp_visual_labels/$labelname\" .)",
"Bash(fi)",
"Bash(done)",
"Bash(awk:*)",
"Bash(chcp 65001)",
"Bash(PYTHONIOENCODING=utf-8 python:*)",
"Bash(timeout 600 tail:*)",
"Bash(cat:*)",
"Bash(powershell -Command \"$response = Invoke-WebRequest -Uri ''http://127.0.0.1:8000/extract_invoice/'' -Method POST -Form @{file=Get-Item ''data\\processed_images\\4BC5E5B3-E561-4A73-BC9C-46D4F08F89C3.png''} -UseBasicParsing; $response.Content\")"
],
"deny": [],
"ask": []
}
}

11
.gitignore vendored Normal file
View File

@@ -0,0 +1,11 @@
# Default ignored files
/shelf/
/workspace.xml
# Editor-based HTTP Client requests
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml
/data
/.idea
/__pycache__/

137
config.py Normal file
View File

@@ -0,0 +1,137 @@
"""
Configuration file for system-dependent paths and settings.
This file contains paths that may vary between different systems.
Copy this file and modify the paths according to your local installation.
"""
import os
from pathlib import Path
# ============================================================================
# System Paths - Modify these according to your installation
# ============================================================================
# Poppler path (required for PDF to image conversion)
# Download from: https://github.com/oschwartz10612/poppler-windows/releases
# Example: r"C:\poppler-23.11.0\bin"
POPPLER_PATH = os.getenv("POPPLER_PATH", r"C:\Program Files\poppler-25.07.0\Library\bin")
# Tesseract path (optional - only needed if not in system PATH)
# Download from: https://github.com/UB-Mannheim/tesseract/wiki
# Example: r"C:\Program Files\Tesseract-OCR\tesseract.exe"
TESSERACT_CMD = os.getenv("TESSERACT_CMD", None)
# ============================================================================
# Project Paths - Generally don't need to modify these
# ============================================================================
# Project root directory
PROJECT_ROOT = Path(__file__).parent.absolute()
# Data directories
DATA_DIR = PROJECT_ROOT / "data"
RAW_INVOICES_DIR = DATA_DIR / "raw_invoices"
PROCESSED_IMAGES_DIR = DATA_DIR / "processed_images"
OCR_RESULTS_DIR = DATA_DIR / "ocr_results"
# YOLO dataset directories
YOLO_DATASET_DIR = DATA_DIR / "yolo_dataset"
YOLO_TEMP_IMAGES_DIR = YOLO_DATASET_DIR / "temp_all_images"
YOLO_TEMP_LABELS_DIR = YOLO_DATASET_DIR / "temp_all_labels"
YOLO_TRAIN_IMAGES_DIR = YOLO_DATASET_DIR / "images" / "train"
YOLO_TRAIN_LABELS_DIR = YOLO_DATASET_DIR / "labels" / "train"
YOLO_VAL_IMAGES_DIR = YOLO_DATASET_DIR / "images" / "val"
YOLO_VAL_LABELS_DIR = YOLO_DATASET_DIR / "labels" / "val"
# Model directories
MODELS_DIR = PROJECT_ROOT / "models"
DEFAULT_MODEL_PATH = MODELS_DIR / "payment_slip_detector_v1" / "weights" / "best.pt"
# ============================================================================
# OCR Settings
# ============================================================================
# Tesseract language (Swedish + English)
TESSERACT_LANG = "swe" # Ensure Swedish language pack is installed
# OCR confidence threshold (0-100)
OCR_CONFIDENCE_THRESHOLD = 0
# ============================================================================
# Training Settings
# ============================================================================
# YOLO model size: n (nano), s (small), m (medium), l (large), x (xlarge)
YOLO_MODEL_SIZE = "n"
# Training epochs
TRAINING_EPOCHS = 100
# Batch size
BATCH_SIZE = 16
# Image size for training
IMAGE_SIZE = 640
# Validation split ratio (0.0 to 1.0)
VALIDATION_SPLIT = 0.2
# Random seed for reproducibility
RANDOM_SEED = 42
# ============================================================================
# API Settings (for main.py FastAPI server)
# ============================================================================
# API host
API_HOST = "127.0.0.1"
# API port
API_PORT = 8000
# ============================================================================
# Helper Functions
# ============================================================================
def apply_tesseract_path():
"""Apply Tesseract path if configured."""
if TESSERACT_CMD:
import pytesseract
pytesseract.pytesseract.tesseract_cmd = TESSERACT_CMD
def validate_paths():
"""Validate that required system paths exist."""
issues = []
# Check Poppler
if not os.path.exists(POPPLER_PATH):
issues.append(f"Poppler not found at: {POPPLER_PATH}")
issues.append(" Download from: https://github.com/oschwartz10612/poppler-windows/releases")
# Check Tesseract (if specified)
if TESSERACT_CMD and not os.path.exists(TESSERACT_CMD):
issues.append(f"Tesseract not found at: {TESSERACT_CMD}")
issues.append(" Download from: https://github.com/UB-Mannheim/tesseract/wiki")
if issues:
print("Configuration Issues Found:")
for issue in issues:
print(f" {issue}")
return False
return True
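# Example usage (a minimal sketch): call validate_paths() at startup and bail
# out early if a system dependency is missing.
#
#   import sys
#   from config import validate_paths
#   if not validate_paths():
#       sys.exit(1)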
# ============================================================================
# Example: Environment Variable Override
# ============================================================================
# You can set these in your environment instead of modifying this file:
#
# Windows:
# set POPPLER_PATH=C:\poppler\bin
# set TESSERACT_CMD=C:\Program Files\Tesseract-OCR\tesseract.exe
#
# Linux/Mac:
# export POPPLER_PATH=/usr/bin
# export TESSERACT_CMD=/usr/bin/tesseract
# ============================================================================

17
extraction_results.json Normal file
View File

@@ -0,0 +1,17 @@
[
{
"image": "data\\processed_images\\20250917.03.1.011299_c328f5a8-06f9-4093-85b5-e3a40f24bd30_page_1.jpg",
"fields": {},
"all_detections": []
},
{
"image": "data\\processed_images\\64A80892-8A9E-454C-9AEB-B740E8C3ACB3_page_1.jpg",
"fields": {},
"all_detections": []
},
{
"image": "data\\processed_images\\9fb6129f-671d-4aa1-9bad-096e84e6ded3_page_1.jpg",
"fields": {},
"all_detections": []
}
]

45
requirements.txt Normal file
View File

@@ -0,0 +1,45 @@
# Core dependencies
ultralytics>=8.0.0 # YOLOv8
pytesseract>=0.3.10 # Tesseract OCR Python wrapper
# Image processing
pdf2image>=1.16.0 # PDF to image conversion
Pillow>=10.0.0 # Image manipulation
opencv-python>=4.8.0 # Computer vision
# Data handling
numpy>=1.24.0
pandas>=2.0.0
scikit-learn>=1.3.0 # For DBSCAN clustering in 02_create_labels.py
# API dependencies (for main.py)
fastapi>=0.104.0 # FastAPI web framework
uvicorn>=0.24.0 # ASGI server
python-multipart>=0.0.6 # For file upload support
# System utilities
# IMPORTANT: Requires system-level installation of:
#
# 1. Tesseract OCR:
# - Windows: Download from https://github.com/UB-Mannheim/tesseract/wiki
# After installation, add to PATH or set: pytesseract.pytesseract.tesseract_cmd = r'C:\Program Files\Tesseract-OCR\tesseract.exe'
# - Linux: sudo apt-get install tesseract-ocr tesseract-ocr-swe tesseract-ocr-eng
# - macOS: brew install tesseract tesseract-lang
#
# 2. Poppler (for pdf2image):
# - Windows: Download from https://github.com/oschwartz10612/poppler-windows/releases
# - Linux: sudo apt-get install poppler-utils
# - macOS: brew install poppler
#
# 3. Swedish language data for Tesseract:
# After installing Tesseract, you may need to download Swedish language files (swe.traineddata)
# from https://github.com/tesseract-ocr/tessdata
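#
# To verify that Tesseract can see the Swedish pack, run:
#   tesseract --list-langs
# and check that "swe" appears in the output.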
# Optional: GPU support
# torch>=2.0.0 # PyTorch with CUDA support
# torchvision>=0.15.0
# Development tools (optional)
# jupyter>=1.0.0
# matplotlib>=3.7.0
# seaborn>=0.12.0

106
scripts/01_process_invoices.py Normal file
View File

@@ -0,0 +1,106 @@
# --- scripts/01_process_invoices.py ---
import os
import sys
import json
import pytesseract
import pandas as pd
from pdf2image import convert_from_path
import cv2
import numpy as np
# Add parent directory to path to import config
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from config import POPPLER_PATH, apply_tesseract_path
# Apply Tesseract path from config
apply_tesseract_path()
# Project path setup
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
RAW_DIR = os.path.join(BASE_DIR, "data", "raw_invoices")
IMG_DIR = os.path.join(BASE_DIR, "data", "processed_images")
OCR_DIR = os.path.join(BASE_DIR, "data", "ocr_results")
# Create the output directories
os.makedirs(IMG_DIR, exist_ok=True)
os.makedirs(OCR_DIR, exist_ok=True)
def process_invoices():
print(f"开始处理 {RAW_DIR} 中的发票...")
for filename in os.listdir(RAW_DIR):
filepath = os.path.join(RAW_DIR, filename)
base_name = os.path.splitext(filename)[0]
img_path = os.path.join(IMG_DIR, f"{base_name}.png")
json_path = os.path.join(OCR_DIR, f"{base_name}.json")
# Skip files that were already processed
if os.path.exists(img_path) and os.path.exists(json_path):
continue
try:
# 1. Load the image (PDF or image file)
if filename.lower().endswith(".pdf"):
images = convert_from_path(filepath, poppler_path=POPPLER_PATH)
if not images:
print(f"警告: 无法从 {filename} 提取图像。")
continue
img_pil = images[0]  # take the first page
img = cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGB2BGR)
else:
img = cv2.imread(filepath)
if img is None:
print(f"警告: 无法读取图像文件 {filename}")
continue
img_h, img_w = img.shape[:2]
# 2. Save the image in a uniform format
cv2.imwrite(img_path, img)
# 3. Run Tesseract OCR
ocr_data_df = pytesseract.image_to_data(
img,
lang='swe',  # make sure the Swedish language pack is installed
output_type=pytesseract.Output.DATAFRAME
)
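# The DATAFRAME output exposes Tesseract's TSV columns; the ones used
# below are: left, top, width, height, conf, text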
# 4. Clean up the OCR results
ocr_data_df = ocr_data_df[ocr_data_df.conf > 0]
ocr_data_df.dropna(subset=['text'], inplace=True)
ocr_data_df['text'] = ocr_data_df['text'].astype(str).str.strip()
ocr_data_df = ocr_data_df[ocr_data_df['text'] != ""]
# 5. Convert to the JSON format used by the downstream scripts (with text_boxes)
text_boxes = []
for i, row in ocr_data_df.iterrows():
text_boxes.append({
"text": row["text"],
"bbox": {
"x_min": row["left"],
"y_min": row["top"],
"x_max": row["left"] + row["width"],
"y_max": row["top"] + row["height"]
},
"confidence": row["conf"] / 100.0
})
output_json = {
"image_name": f"{base_name}.png",
"width": img_w,
"height": img_h,
"text_boxes": text_boxes
}
with open(json_path, 'w', encoding='utf-8') as f:
json.dump(output_json, f, indent=4, ensure_ascii=False)
print(f"已处理: {filename}")
except Exception as e:
print(f"处理 {filename} 时出错: {e}")
if __name__ == "__main__":
process_invoices()

223
scripts/02_create_labels.py Normal file
View File

@@ -0,0 +1,223 @@
# --- scripts/02_create_labels.py ---
#
# **IMPORTANT**: this script is for training the "stage one" region detector.
# It emits only 1 class (class_id = 0): the *entire* "payment_slip" region.
# It does not use your advanced classify_text logic; instead it relies on
# heuristics (keywords, lines) to find the large region.
#
import os
import pandas as pd
import numpy as np
import cv2
import re
import shutil
import json
from sklearn.cluster import DBSCAN
# --- Anchor dictionary (technique 2) ---
# These words are used to *locate* the payment-slip region, not to extract fields
KEYWORDS = [
"Bankgirot", "PlusGirot", "OCR-nummer", "Att betala", "Mottagare",
"Betalningsmottagare", "Tillhanda senast", "Förfallodag", "Belopp",
"BG-nr", "PG-nr", "Meddelande", "OCR-kod", "Inbetalningskort", "Betalningsavi"
]
KEYWORDS_REGEX = '|'.join(KEYWORDS)
REGEX_PATTERNS = {
'bg_pg': r'(\b\d{2,4}[- ]\d{4}\b)|(\b\d{2,7}[- ]\d\b)', # Bankgiro/PlusGiro
'long_num': r'\b\d{10,}\b', # likely an OCR number
'machine_code': r'#[0-9\s>#]+#' # machine-readable code line
}
FULL_REGEX = f"({KEYWORDS_REGEX})|{REGEX_PATTERNS['bg_pg']}|{REGEX_PATTERNS['long_num']}|{REGEX_PATTERNS['machine_code']}"
# --- Path setup ---
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
OCR_DIR = os.path.join(BASE_DIR, "data", "ocr_results")
IMG_DIR = os.path.join(BASE_DIR, "data", "processed_images")
# Use temp directories for initial label generation (before train/val split)
LABEL_DIR = os.path.join(BASE_DIR, "data", "yolo_dataset", "temp_all_labels")
IMAGE_DIR = os.path.join(BASE_DIR, "data", "yolo_dataset", "temp_all_images")
os.makedirs(LABEL_DIR, exist_ok=True)
os.makedirs(IMAGE_DIR, exist_ok=True)
def find_text_anchors(text_boxes, img_height):
"""技术一 (位置) + 技术二 (词典)"""
anchors = []
if not text_boxes:
return []
# Technique 1: only search the lower part of the page (from 40% down)
page_midpoint = img_height * 0.4
for box in text_boxes:
# Check the position
if box["bbox"]["y_min"] > page_midpoint:
# Check the text content
if re.search(FULL_REGEX, box["text"], re.IGNORECASE):
bbox = box["bbox"]
anchors.append(pd.Series({
'left': bbox["x_min"],
'top': bbox["y_min"],
'width': bbox["x_max"] - bbox["x_min"],
'height': bbox["y_max"] - bbox["y_min"]
}))
return anchors
def find_visual_anchors(image, img_height, img_width):
"""技术三 (视觉锚点 - 找线)"""
anchors = []
try:
# 1. Only look at the lower half of the page
crop_y_start = img_height // 2
crop = image[crop_y_start:, :]
gray = cv2.cvtColor(crop, cv2.COLOR_BGR2GRAY)
thresh = cv2.threshold(gray, 150, 255, cv2.THRESH_BINARY_INV)[1]
# 2. Find long horizontal lines
min_line_length = img_width // 4
horizontal_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (min_line_length, 1))
detected_lines = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, horizontal_kernel, iterations=2)
contours, _ = cv2.findContours(detected_lines, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
for c in contours:
x, y, w, h = cv2.boundingRect(c)
if w > min_line_length: # make sure the line is long enough
original_y = y + crop_y_start # convert back to full-image coordinates
anchors.append(pd.Series({
'left': x, 'top': original_y, 'width': w, 'height': h
}))
except Exception as e:
print(f"查找视觉锚点时出错: {e}")
return anchors
def cluster_anchors(all_anchors, img_height):
"""技术四 (聚类)"""
if len(all_anchors) < 2:
return all_anchors # too few anchors to cluster
# 1. Collect the center points
points = []
for anchor in all_anchors:
x_center = anchor.left + anchor.width / 2
y_center = anchor.top + anchor.height / 2
points.append((x_center, y_center))
points = np.array(points)
# 2. Run DBSCAN (eps at 20% of the image height groups anchors that lie close together)
eps_dist = img_height * 0.2
clustering = DBSCAN(eps=eps_dist, min_samples=2).fit(points)
labels = clustering.labels_
if len(labels) == 0 or np.all(labels == -1):
return all_anchors # clustering failed
# 3. Find the largest cluster
unique_labels, counts = np.unique(labels[labels != -1], return_counts=True)
if len(counts) == 0:
return all_anchors # noise only
largest_cluster_label = unique_labels[np.argmax(counts)]
# 4. Keep only the anchors that belong to the largest cluster
main_cluster_anchors = [
anchor for i, anchor in enumerate(all_anchors)
if labels[i] == largest_cluster_label
]
return main_cluster_anchors
def create_labels():
print("开始生成弱标签 (用于区域检测器)...")
processed_count = 0
skipped_count = 0
for ocr_filename in os.listdir(OCR_DIR):
if not ocr_filename.endswith(".json"):
continue
base_name = os.path.splitext(ocr_filename)[0]
json_path = os.path.join(OCR_DIR, ocr_filename)
img_path = os.path.join(IMG_DIR, f"{base_name}.png")
if not os.path.exists(img_path):
continue
try:
# 1. Load the data
image = cv2.imread(img_path)
img_h, img_w = image.shape[:2]
with open(json_path, 'r', encoding='utf-8') as f:
ocr_data = json.load(f)
text_boxes = ocr_data.get("text_boxes", [])
# 2. Find all anchors
text_anchors = find_text_anchors(text_boxes, img_h)
visual_anchors = find_visual_anchors(image, img_h, img_w)
all_anchors = text_anchors + visual_anchors
if not all_anchors:
print(f"SKIPPING: {base_name} (未找到任何锚点)")
skipped_count += 1
continue
# 3. Cluster the anchors
final_anchors = cluster_anchors(all_anchors, img_h)
if not final_anchors:
print(f"SKIPPING: {base_name} (未找到有效聚类)")
skipped_count += 1
continue
# 4. Aggregate the coordinates
min_x = min(a.left for a in final_anchors)
min_y = min(a.top for a in final_anchors)
max_x = max(a.left + a.width for a in final_anchors)
max_y = max(a.top + a.height for a in final_anchors)
# 5. Add padding
padding = 10
min_x = max(0, min_x - padding)
min_y = max(0, min_y - padding)
max_x = min(img_w, max_x + padding)
max_y = min(img_h, max_y + padding)
# 6. Convert to YOLO format (class x_center y_center width height, all normalized)
box_w = max_x - min_x
box_h = max_y - min_y
x_center = (min_x + box_w / 2) / img_w
y_center = (min_y + box_h / 2) / img_h
norm_w = box_w / img_w
norm_h = box_h / img_h
# ** Key point: the class ID is always 0 **
class_id = 0 # the only class: 'payment_slip'
yolo_label = f"{class_id} {x_center} {y_center} {norm_w} {norm_h}\n"
# 7. Save the label and the image
label_path = os.path.join(LABEL_DIR, f"{base_name}.txt")
with open(label_path, 'w', encoding='utf-8') as f:
f.write(yolo_label)
shutil.copy(
img_path,
os.path.join(IMAGE_DIR, f"{base_name}.png")
)
processed_count += 1
if processed_count % 20 == 0:
print(f"已生成 {processed_count} 个标签...")
except Exception as e:
print(f"处理 {base_name} 时出错: {e}")
skipped_count += 1
print("--- 弱标签生成完成 ---")
print(f"成功生成: {processed_count}")
print(f"跳过: {skipped_count}")
if __name__ == "__main__":
create_labels()

123
scripts/03_split_dataset.py Normal file
View File

@@ -0,0 +1,123 @@
"""
Dataset Split Script - Step 3
Splits images and labels into training and validation sets
"""
import shutil
import random
from pathlib import Path
import os
# Paths
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
YOLO_DATASET_DIR = Path(BASE_DIR) / "data" / "yolo_dataset"
TEMP_IMAGES_DIR = YOLO_DATASET_DIR / "temp_all_images"
TEMP_LABELS_DIR = YOLO_DATASET_DIR / "temp_all_labels"
TRAIN_IMAGES_DIR = YOLO_DATASET_DIR / "images" / "train"
VAL_IMAGES_DIR = YOLO_DATASET_DIR / "images" / "val"
TRAIN_LABELS_DIR = YOLO_DATASET_DIR / "labels" / "train"
VAL_LABELS_DIR = YOLO_DATASET_DIR / "labels" / "val"
# Configuration
VALIDATION_SPLIT = 0.2 # 20% for validation
RANDOM_SEED = 42
def split_dataset(val_split=VALIDATION_SPLIT, seed=RANDOM_SEED):
"""
Split dataset into training and validation sets
Args:
val_split: Fraction of data to use for validation (0.0 to 1.0)
seed: Random seed for reproducibility
"""
print("="*60)
print("Splitting Dataset into Train/Val Sets")
print("="*60)
print(f"Validation split: {val_split*100:.1f}%")
print(f"Random seed: {seed}\n")
# Check if temp directories exist
if not TEMP_IMAGES_DIR.exists() or not TEMP_LABELS_DIR.exists():
print(f"Error: temporary directories not found under {YOLO_DATASET_DIR}")
print("Please run 02_create_labels.py first")
return
# Get all image files
image_files = list(TEMP_IMAGES_DIR.glob("*.jpg")) + list(TEMP_IMAGES_DIR.glob("*.png"))
if not image_files:
print(f"No image files found in {TEMP_IMAGES_DIR}")
return
# Filter images that have corresponding labels
valid_pairs = []
for image_file in image_files:
label_file = TEMP_LABELS_DIR / (image_file.stem + ".txt")
if label_file.exists():
valid_pairs.append({
"image": image_file,
"label": label_file
})
if not valid_pairs:
print("No valid image-label pairs found")
return
print(f"Found {len(valid_pairs)} image-label pair(s)")
# Shuffle and split
random.seed(seed)
random.shuffle(valid_pairs)
split_index = int(len(valid_pairs) * (1 - val_split))
train_pairs = valid_pairs[:split_index]
val_pairs = valid_pairs[split_index:]
print(f"\nSplit results:")
print(f" Training set: {len(train_pairs)} samples")
print(f" Validation set: {len(val_pairs)} samples")
print()
# Clear existing train/val directories
for directory in [TRAIN_IMAGES_DIR, VAL_IMAGES_DIR, TRAIN_LABELS_DIR, VAL_LABELS_DIR]:
if directory.exists():
shutil.rmtree(directory)
directory.mkdir(parents=True, exist_ok=True)
# Copy training files
print("Copying training files...")
for pair in train_pairs:
shutil.copy(pair["image"], TRAIN_IMAGES_DIR / pair["image"].name)
shutil.copy(pair["label"], TRAIN_LABELS_DIR / pair["label"].name)
print(f" Copied {len(train_pairs)} image-label pairs to train/")
# Copy validation files
print("Copying validation files...")
for pair in val_pairs:
shutil.copy(pair["image"], VAL_IMAGES_DIR / pair["image"].name)
shutil.copy(pair["label"], VAL_LABELS_DIR / pair["label"].name)
print(f" Copied {len(val_pairs)} image-label pairs to val/")
print("\n" + "="*60)
print("Dataset split complete!")
print(f"\nDataset structure:")
print(f" {TRAIN_IMAGES_DIR}")
print(f" {TRAIN_LABELS_DIR}")
print(f" {VAL_IMAGES_DIR}")
print(f" {VAL_LABELS_DIR}")
print(f"\nNext step: Run 04_train_yolo.py to train the model")
print("="*60)
def main():
"""Main function"""
split_dataset()
if __name__ == "__main__":
main()

230
scripts/04_train_yolo.py Normal file
View File

@@ -0,0 +1,230 @@
"""
YOLO Training Script - Step 4
Trains YOLOv8 model on the prepared invoice dataset
"""
from pathlib import Path
from ultralytics import YOLO
import torch
import os
# Paths
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
DATASET_YAML = Path(BASE_DIR) / "data" / "yolo_dataset" / "dataset.yaml"
MODELS_DIR = Path(BASE_DIR) / "models"
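# dataset.yaml is expected to look roughly like this (a sketch -- adjust the
# paths to your layout; the single class comes from 02_create_labels.py):
#
#   path: data/yolo_dataset
#   train: images/train
#   val: images/val
#   names:
#     0: payment_slip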
# Training configuration
MODEL_SIZE = "n" # Options: n (nano), s (small), m (medium), l (large), x (xlarge)
EPOCHS = 100
BATCH_SIZE = 16
IMAGE_SIZE = 640
DEVICE = 0 if torch.cuda.is_available() else "cpu" # use the first GPU when CUDA is available
# Create models directory
MODELS_DIR.mkdir(exist_ok=True)
def train_model(
model_size=MODEL_SIZE,
epochs=EPOCHS,
batch_size=BATCH_SIZE,
img_size=IMAGE_SIZE,
device=DEVICE
):
"""
Train YOLOv8 model on invoice dataset
Args:
model_size: Size of YOLO model (n, s, m, l, x)
epochs: Number of training epochs
batch_size: Batch size for training
img_size: Input image size
device: Device to use for training (cuda or cpu)
"""
print("="*60)
print("YOLOv8 Invoice Detection Training")
print("="*60)
# Check if dataset.yaml exists
if not DATASET_YAML.exists():
print(f"Error: {DATASET_YAML} not found")
print("Please ensure the dataset.yaml file exists")
return
# Print configuration
print(f"\nConfiguration:")
print(f" Model: YOLOv8{model_size}")
print(f" Epochs: {epochs}")
print(f" Batch size: {batch_size}")
print(f" Image size: {img_size}")
print(f" Device: {device}")
print(f" Dataset config: {DATASET_YAML}")
print()
# Initialize model
print(f"Loading YOLOv8{model_size} model...")
model = YOLO(f"yolov8{model_size}.pt") # Load pretrained model
# Print device info
if device == 0:
print(f"Using GPU: {torch.cuda.get_device_name(0)}")
print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
else:
print("Using CPU (training will be slower)")
print("\nStarting training...")
print("-" * 60)
# Train the model
results = model.train(
data=str(DATASET_YAML),
epochs=epochs,
imgsz=img_size,
batch=batch_size,
device=device,
project=str(MODELS_DIR),
name="payment_slip_detector_v1",
exist_ok=True,
patience=20, # Early stopping patience
save=True,
save_period=10, # Save checkpoint every 10 epochs
verbose=True,
plots=True # Generate training plots
)
print("\n" + "="*60)
print("Training complete!")
print("="*60)
# Print results
best_model_path = MODELS_DIR / "payment_slip_detector_v1" / "weights" / "best.pt"
last_model_path = MODELS_DIR / "payment_slip_detector_v1" / "weights" / "last.pt"
print(f"\nTrained models saved to:")
print(f" Best model: {best_model_path}")
print(f" Last model: {last_model_path}")
print(f"\nTraining plots saved to:")
print(f" {MODELS_DIR / 'payment_slip_detector_v1'}")
print("\n" + "="*60)
def validate_model(model_path=None):
"""
Validate trained model on validation set
Args:
model_path: Path to model weights (default: best.pt from last training)
"""
if model_path is None:
model_path = MODELS_DIR / "payment_slip_detector_v1" / "weights" / "best.pt"
if not Path(model_path).exists():
print(f"Error: Model not found at {model_path}")
print("Please train a model first")
return
print("="*60)
print("Validating Model")
print("="*60)
print(f"Model: {model_path}\n")
# Load model
model = YOLO(str(model_path))
# Validate
results = model.val(data=str(DATASET_YAML))
print("\n" + "="*60)
print("Validation complete!")
print("="*60)
def predict_sample(model_path=None, image_path=None, conf_threshold=0.25):
"""
Run prediction on a sample image
Args:
model_path: Path to model weights
image_path: Path to image to predict on
conf_threshold: Confidence threshold for detections
"""
if model_path is None:
model_path = MODELS_DIR / "payment_slip_detector_v1" / "weights" / "best.pt"
if not Path(model_path).exists():
print(f"Error: Model not found at {model_path}")
return
if image_path is None:
# Try to get a sample from validation set
val_images_dir = Path("data/yolo_dataset/images/val")
sample_images = list(val_images_dir.glob("*.jpg")) + list(val_images_dir.glob("*.png"))
if sample_images:
image_path = sample_images[0]
else:
print("No sample images found")
return
print("="*60)
print("Running Prediction")
print("="*60)
print(f"Model: {model_path}")
print(f"Image: {image_path}")
print(f"Confidence threshold: {conf_threshold}\n")
# Load model
model = YOLO(str(model_path))
# Predict
results = model.predict(
source=str(image_path),
conf=conf_threshold,
save=True,
project=str(MODELS_DIR / "predictions"),
name="sample"
)
print(f"\nPrediction saved to: {MODELS_DIR / 'predictions' / 'sample'}")
print("="*60)
def main():
"""Main training function"""
import argparse
parser = argparse.ArgumentParser(description="Train YOLOv8 on invoice dataset")
parser.add_argument("--mode", type=str, default="train", choices=["train", "validate", "predict"],
help="Mode: train, validate, or predict")
parser.add_argument("--model-size", type=str, default=MODEL_SIZE,
choices=["n", "s", "m", "l", "x"],
help="YOLO model size")
parser.add_argument("--epochs", type=int, default=EPOCHS,
help="Number of training epochs")
parser.add_argument("--batch-size", type=int, default=BATCH_SIZE,
help="Batch size")
parser.add_argument("--img-size", type=int, default=IMAGE_SIZE,
help="Image size")
parser.add_argument("--model-path", type=str, default=None,
help="Path to model weights (for validate/predict)")
parser.add_argument("--image-path", type=str, default=None,
help="Path to image (for predict)")
args = parser.parse_args()
if args.mode == "train":
train_model(
model_size=args.model_size,
epochs=args.epochs,
batch_size=args.batch_size,
img_size=args.img_size
)
elif args.mode == "validate":
validate_model(model_path=args.model_path)
elif args.mode == "predict":
predict_sample(model_path=args.model_path, image_path=args.image_path)
if __name__ == "__main__":
main()

237
scripts/main.py Normal file
View File

@@ -0,0 +1,237 @@
# --- main.py (upgraded) ---
from fastapi import FastAPI, UploadFile, File, HTTPException
from ultralytics import YOLO
import cv2
import numpy as np
import pytesseract
import re
import io
import os
from contextlib import asynccontextmanager
# --- Configuration ---
# TODO: make sure this path points to your best trained model
# This file lives in scripts/, so step up one level to the project root; the
# run name matches the output of 04_train_yolo.py
MODEL_PATH = os.path.join(
os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
"models", "payment_slip_detector_v1", "weights", "best.pt"
)
# Registry that holds models loaded at FastAPI startup
ml_models = {}
@asynccontextmanager
async def lifespan(app: FastAPI):
# Load the model at startup
print(MODEL_PATH)
if not os.path.exists(MODEL_PATH):
print(f"警告: 找不到模型 {MODEL_PATH}。API 将无法工作。")
ml_models["yolo"] = None
else:
# 加载您的 "payment_slip" 区域检测器
ml_models["yolo"] = YOLO(MODEL_PATH)
print("YOLOv8 区域检测模型加载成功。")
yield
# Release the models
ml_models.clear()
app = FastAPI(lifespan=lifespan)
# --- Luhn (Modulus 10) check ---
def luhn_validate(number_str: str, expected_check_digit: str) -> bool:
"""
Validate a digit string with the Modulus 10 (Luhn) algorithm.
(Weights run 2, 1, 2, 1, ... from right to left over the payload; the
separately supplied check digit itself would carry weight 1.)
"""
try:
digits = [int(d) for d in number_str]
# The check digit is passed in separately, so the rightmost payload digit
# must carry weight 2 for standard Luhn (starting at weight 1 fails on the
# sample OCR-rad below)
weights = [2, 1] * (len(digits) // 2 + 1)
weights = weights[:len(digits)] # trim the weight list to the payload length
sum_val = 0
# Accumulate from right to left
for d, w in zip(reversed(digits), weights):
product = d * w
if product >= 10:
sum_val += (product // 10) + (product % 10)
else:
sum_val += product
calculated_check_digit = (10 - (sum_val % 10)) % 10
return str(calculated_check_digit) == expected_check_digit
except Exception:
return False
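# Worked example using the OCR-rad sample below: for payload "460300" with
# check digit "7", the reversed digits 0,0,3,0,6,4 weighted 2,1,2,1,2,1 give
# 0+0+6+0+3+4 = 13, and (10 - 13 % 10) % 10 = 7, so
# luhn_validate("460300", "7") returns True.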
# --- Extraction logic ---
def parse_ocr_rad(ocr_text: str) -> dict:
"""
Plan A: try to parse the machine-readable code line (OCR-rad)
Example: # 400299582421 # 4603 00 7 > 48180020 #14#
"""
# Strip all whitespace to simplify matching
text_no_space = re.sub(r'\s+', '', ocr_text)
# A more robust regular expression:
# group 1: the OCR number (between #...#)
# groups 2+3: the amount digits (kronor + öre), concatenated below
# group 4: the check digit (one digit)
# group 5: the account number (between >...#)
rad_regex = r'#([\d>]+)#(\d{2,})(\d)(\d{1})>([\d]+)#'
match = re.search(rad_regex, text_no_space)
if not match:
return None # no machine-readable line found
try:
ocr_num = match.group(1).replace(">", "") # strip > (if present)
amount_base = match.group(2) + match.group(3) # "46030" + "0" = "460300"
check_digit = match.group(4) # "7"
account = match.group(5) # "48180020"
# Run the Luhn check
if luhn_validate(amount_base, check_digit):
# The check passed! This is high-confidence data
amount_kronor = amount_base[:-2]
amount_ore = amount_base[-2:]
return {
"source": "OCR-rad (High Confidence)",
"ocr_number": ocr_num,
"amount_due": f"{amount_kronor}.{amount_ore}",
"bankgiro_plusgiro": account,
"due_date": None # 机读码行通常不包含日期
}
else:
# The check failed!
print(f"Luhn check failed: base={amount_base}, expected={check_digit}")
return None
except Exception as e:
print(f"解析 OCR-rad 时出错: {e}")
return None
def parse_human_readable(ocr_text: str) -> dict:
"""
Plan B: fall back to the human-readable area
(This uses the simple regexes from the earlier version; you could also
swap in your more elaborate classify_text logic.)
"""
data = {"source": "Human-Readable (Fallback)"}
# Look for BG/PG (Bankgiro/PlusGiro)
bg_match = re.search(r'Bankgiro\D*(\d{2,4}[- ]\d{4})', ocr_text, re.IGNORECASE)
pg_match = re.search(r'PlusGiro\D*(\d{2,7}[- ]\d)', ocr_text, re.IGNORECASE)
if bg_match:
data["bankgiro_plusgiro"] = bg_match.group(1).replace(" ", "")
elif pg_match:
data["bankgiro_plusgiro"] = pg_match.group(1).replace(" ", "")
else:
# Fallback search
bg_pg_alt = re.search(r'(\b\d{2,4}[- ]\d{4}\b)|(\b\d{2,7}[- ]\d\b)', ocr_text)
if bg_pg_alt:
data["bankgiro_plusgiro"] = bg_pg_alt.group(0).replace(" ", "")
# Look for the OCR number
ocr_match = re.search(r'(OCR|Fakturanummer|Referens)\D*(\d[\d\s]{5,}\d)', ocr_text, re.IGNORECASE)
if ocr_match:
data["ocr_number"] = re.sub(r'\s', '', ocr_match.group(2))
# Look for the amount
amount_match = re.search(r'(Att betala|Belopp)\D*([\d\s,.]+)\s*(kr|SEK)?', ocr_text, re.IGNORECASE)
if amount_match:
amount_str = amount_match.group(2).strip().replace(" ", "").replace(",", ".")
if amount_str.count('.') > 1:
amount_str = amount_str.replace(".", "", amount_str.count('.') - 1)
data["amount_due"] = amount_str
# Look for the due date
date_match = re.search(r'(senast|Förfallodag)\D*(\d{4}[- ]\d{2}[- ]\d{2})', ocr_text, re.IGNORECASE)
if date_match:
data["due_date"] = date_match.group(2).replace(" ", "-")
return data
def extract_info_from_crop(crop_image: np.ndarray) -> dict:
"""
Main extraction function: runs Plan A, then Plan B
"""
try:
# 1. Preprocess and run OCR (grab all text)
gray = cv2.cvtColor(crop_image, cv2.COLOR_BGR2GRAY)
thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)[1]
# psm 6 assumes a single uniform text block, which suits the OCR-rad
ocr_text = pytesseract.image_to_string(thresh, lang='swe', config='--psm 6')
# 2. --- Plan A: try to parse the machine-readable line ---
rad_data = parse_ocr_rad(ocr_text)
if rad_data:
# Plan A succeeded!
return rad_data
# 3. --- Plan B: Plan A failed, fall back to the human-readable area ---
# Re-run OCR with psm 3 (automatic layout), which suits that area better
ocr_text_human = pytesseract.image_to_string(thresh, lang='swe', config='--psm 3')
human_data = parse_human_readable(ocr_text_human)
# Even if the fallback finds nothing, return the (possibly partial) dict
return human_data
except Exception as e:
return {"error": f"提取时出错: {e}"}
@app.post("/extract_invoice/")
async def extract_invoice_data(file: UploadFile = File(...)):
if ml_models.get("yolo") is None:
raise HTTPException(status_code=503, detail="Model not loaded. Check the model path.")
try:
contents = await file.read()
nparr = np.frombuffer(contents, np.uint8)
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
if img is None:
raise HTTPException(status_code=400, detail="无法解码图像文件。")
except Exception as e:
raise HTTPException(status_code=400, detail=f"读取文件时出错: {e}")
# 1. Run YOLO detection (stage one)
results = ml_models["yolo"](img, verbose=False)
all_extractions = []
if not results or not results[0].boxes:
return {"message": "未在图像中检测到支付凭证区域。"}
for res in results:
# Iterate over all detected slips (usually just one)
for box_coords in res.boxes.xyxy.cpu().numpy().astype(int):
xmin, ymin, xmax, ymax = box_coords
# 2. Crop the image
crop = img[ymin:ymax, xmin:xmax]
# 3. Run the "Plan A/B" extraction on the crop (stage two)
extracted_data = extract_info_from_crop(crop)
extracted_data["bounding_box"] = [xmin, ymin, xmax, ymax]
all_extractions.append(extracted_data)
if not all_extractions:
return {"message": "检测到支付凭证, 但未能提取任何信息。"}
return {"invoice_extractions": all_extractions}
if __name__ == "__main__":
import uvicorn
print(f"--- 启动 FastAPI 服务 ---")
print(f"加载模型: {MODEL_PATH}")
print(f"访问 http://127.0.0.1:8000/docs 查看 API 文档")
uvicorn.run(app, host="127.0.0.1", port=8000)

14
test_api.py Normal file
View File

@@ -0,0 +1,14 @@
import requests
import json
# Test the API with the problematic invoice
url = "http://127.0.0.1:8000/extract_invoice/"
file_path = r"data\processed_images\4BC5E5B3-E561-4A73-BC9C-46D4F08F89C3.png"
with open(file_path, 'rb') as f:
files = {'file': f}
response = requests.post(url, files=files)
print("Status Code:", response.status_code)
print("\nResponse JSON:")
print(json.dumps(response.json(), indent=2, ensure_ascii=False))
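# An equivalent request with curl (assuming the server from scripts/main.py is
# running locally):
#   curl -X POST -F "file=@data/processed_images/4BC5E5B3-E561-4A73-BC9C-46D4F08F89C3.png" http://127.0.0.1:8000/extract_invoice/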