commit dafa86c58840e3c3f61c9eff05b4c433f634561b Author: Yaojia Wang Date: Sun Oct 26 20:41:11 2025 +0100 Init diff --git a/.claude/settings.local.json b/.claude/settings.local.json new file mode 100644 index 0000000..3776064 --- /dev/null +++ b/.claude/settings.local.json @@ -0,0 +1,28 @@ +{ + "permissions": { + "allow": [ + "Bash(mkdir:*)", + "Bash(python:*)", + "Bash(pip install:*)", + "Bash(dir:*)", + "Bash(tesseract:*)", + "Bash(nvidia-smi:*)", + "Bash(pip uninstall:*)", + "Bash(for img in ../../images/train/*.jpg)", + "Bash(do basename=\"$img##*/\")", + "Bash(labelname=\"$basename%.jpg.txt\")", + "Bash(if [ -f \"../../temp_visual_labels/$labelname\" ])", + "Bash(then cp \"../../temp_visual_labels/$labelname\" .)", + "Bash(fi)", + "Bash(done)", + "Bash(awk:*)", + "Bash(chcp 65001)", + "Bash(PYTHONIOENCODING=utf-8 python:*)", + "Bash(timeout 600 tail:*)", + "Bash(cat:*)", + "Bash(powershell -Command \"$response = Invoke-WebRequest -Uri ''http://127.0.0.1:8000/extract_invoice/'' -Method POST -Form @{file=Get-Item ''data\\processed_images\\4BC5E5B3-E561-4A73-BC9C-46D4F08F89C3.png''} -UseBasicParsing; $response.Content\")" + ], + "deny": [], + "ask": [] + } +} diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..d394568 --- /dev/null +++ b/.gitignore @@ -0,0 +1,11 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml +/data +/.idea +/__pycache__/ diff --git a/config.py b/config.py new file mode 100644 index 0000000..ad9fe83 --- /dev/null +++ b/config.py @@ -0,0 +1,137 @@ +""" +Configuration file for system-dependent paths and settings. + +This file contains paths that may vary between different systems. +Copy this file and modify the paths according to your local installation. 
+""" + +import os +from pathlib import Path + +# ============================================================================ +# System Paths - Modify these according to your installation +# ============================================================================ + +# Poppler path (required for PDF to image conversion) +# Download from: https://github.com/oschwartz10612/poppler-windows/releases +# Example: r"C:\poppler-23.11.0\bin" +POPPLER_PATH = os.getenv("POPPLER_PATH", r"C:\Program Files\poppler-25.07.0\Library\bin") + +# Tesseract path (optional - only needed if not in system PATH) +# Download from: https://github.com/UB-Mannheim/tesseract/wiki +# Example: r"C:\Program Files\Tesseract-OCR\tesseract.exe" +TESSERACT_CMD = os.getenv("TESSERACT_CMD", None) + +# ============================================================================ +# Project Paths - Generally don't need to modify these +# ============================================================================ + +# Project root directory +PROJECT_ROOT = Path(__file__).parent.absolute() + +# Data directories +DATA_DIR = PROJECT_ROOT / "data" +RAW_INVOICES_DIR = DATA_DIR / "raw_invoices" +PROCESSED_IMAGES_DIR = DATA_DIR / "processed_images" +OCR_RESULTS_DIR = DATA_DIR / "ocr_results" + +# YOLO dataset directories +YOLO_DATASET_DIR = DATA_DIR / "yolo_dataset" +YOLO_TEMP_IMAGES_DIR = YOLO_DATASET_DIR / "temp_all_images" +YOLO_TEMP_LABELS_DIR = YOLO_DATASET_DIR / "temp_all_labels" +YOLO_TRAIN_IMAGES_DIR = YOLO_DATASET_DIR / "images" / "train" +YOLO_TRAIN_LABELS_DIR = YOLO_DATASET_DIR / "labels" / "train" +YOLO_VAL_IMAGES_DIR = YOLO_DATASET_DIR / "images" / "val" +YOLO_VAL_LABELS_DIR = YOLO_DATASET_DIR / "labels" / "val" + +# Model directories +MODELS_DIR = PROJECT_ROOT / "models" +DEFAULT_MODEL_PATH = MODELS_DIR / "payment_slip_detector_v1" / "weights" / "best.pt" + +# ============================================================================ +# OCR Settings +# ============================================================================ + +# Tesseract language (Swedish + English) +TESSERACT_LANG = "swe" # Ensure Swedish language pack is installed + +# OCR confidence threshold (0-100) +OCR_CONFIDENCE_THRESHOLD = 0 + +# ============================================================================ +# Training Settings +# ============================================================================ + +# YOLO model size: n (nano), s (small), m (medium), l (large), x (xlarge) +YOLO_MODEL_SIZE = "n" + +# Training epochs +TRAINING_EPOCHS = 100 + +# Batch size +BATCH_SIZE = 16 + +# Image size for training +IMAGE_SIZE = 640 + +# Validation split ratio (0.0 to 1.0) +VALIDATION_SPLIT = 0.2 + +# Random seed for reproducibility +RANDOM_SEED = 42 + +# ============================================================================ +# API Settings (for main.py FastAPI server) +# ============================================================================ + +# API host +API_HOST = "127.0.0.1" + +# API port +API_PORT = 8000 + +# ============================================================================ +# Helper Functions +# ============================================================================ + +def apply_tesseract_path(): + """Apply Tesseract path if configured.""" + if TESSERACT_CMD: + import pytesseract + pytesseract.pytesseract.tesseract_cmd = TESSERACT_CMD + +def validate_paths(): + """Validate that required system paths exist.""" + issues = [] + + # Check Poppler + if not os.path.exists(POPPLER_PATH): + issues.append(f"Poppler not 
found at: {POPPLER_PATH}") + issues.append(" Download from: https://github.com/oschwartz10612/poppler-windows/releases") + + # Check Tesseract (if specified) + if TESSERACT_CMD and not os.path.exists(TESSERACT_CMD): + issues.append(f"Tesseract not found at: {TESSERACT_CMD}") + issues.append(" Download from: https://github.com/UB-Mannheim/tesseract/wiki") + + if issues: + print("Configuration Issues Found:") + for issue in issues: + print(f" {issue}") + return False + + return True + +# ============================================================================ +# Example: Environment Variable Override +# ============================================================================ +# You can set these in your environment instead of modifying this file: +# +# Windows: +# set POPPLER_PATH=C:\poppler\bin +# set TESSERACT_CMD=C:\Program Files\Tesseract-OCR\tesseract.exe +# +# Linux/Mac: +# export POPPLER_PATH=/usr/bin +# export TESSERACT_CMD=/usr/bin/tesseract +# ============================================================================ diff --git a/extraction_results.json b/extraction_results.json new file mode 100644 index 0000000..bfcfb00 --- /dev/null +++ b/extraction_results.json @@ -0,0 +1,17 @@ +[ + { + "image": "data\\processed_images\\20250917.03.1.011299_c328f5a8-06f9-4093-85b5-e3a40f24bd30_page_1.jpg", + "fields": {}, + "all_detections": [] + }, + { + "image": "data\\processed_images\\64A80892-8A9E-454C-9AEB-B740E8C3ACB3_page_1.jpg", + "fields": {}, + "all_detections": [] + }, + { + "image": "data\\processed_images\\9fb6129f-671d-4aa1-9bad-096e84e6ded3_page_1.jpg", + "fields": {}, + "all_detections": [] + } +] \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..91c4876 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,45 @@ +# Core dependencies +ultralytics>=8.0.0 # YOLOv8 +pytesseract>=0.3.10 # Tesseract OCR Python wrapper + +# Image processing +pdf2image>=1.16.0 # PDF to image conversion +Pillow>=10.0.0 # Image manipulation +opencv-python>=4.8.0 # Computer vision + +# Data handling +numpy>=1.24.0 +pandas>=2.0.0 +scikit-learn>=1.3.0 # For DBSCAN clustering in 02_create_labels.py + +# API dependencies (for main.py) +fastapi>=0.104.0 # FastAPI web framework +uvicorn>=0.24.0 # ASGI server +python-multipart>=0.0.6 # For file upload support + +# System utilities +# IMPORTANT: Requires system-level installation of: +# +# 1. Tesseract OCR: +# - Windows: Download from https://github.com/UB-Mannheim/tesseract/wiki +# After installation, add to PATH or set: pytesseract.pytesseract.tesseract_cmd = r'C:\Program Files\Tesseract-OCR\tesseract.exe' +# - Linux: sudo apt-get install tesseract-ocr tesseract-ocr-swe tesseract-ocr-eng +# - macOS: brew install tesseract tesseract-lang +# +# 2. Poppler (for pdf2image): +# - Windows: Download from https://github.com/oschwartz10612/poppler-windows/releases +# - Linux: sudo apt-get install poppler-utils +# - macOS: brew install poppler +# +# 3. 
Swedish language data for Tesseract:
+#    After installing Tesseract, you may need to download Swedish language files (swe.traineddata)
+#    from https://github.com/tesseract-ocr/tessdata
+
+# Optional: GPU support
+# torch>=2.0.0  # PyTorch with CUDA support
+# torchvision>=0.15.0
+
+# Development tools (optional)
+# jupyter>=1.0.0
+# matplotlib>=3.7.0
+# seaborn>=0.12.0
diff --git a/scripts/01_process_invoices.py b/scripts/01_process_invoices.py
new file mode 100644
index 0000000..6242024
--- /dev/null
+++ b/scripts/01_process_invoices.py
@@ -0,0 +1,106 @@
+# --- scripts/01_process_invoices.py ---
+
+import os
+import sys
+import json
+import pytesseract
+import pandas as pd
+from pdf2image import convert_from_path
+import cv2
+import numpy as np
+import shutil
+
+# Add parent directory to path to import config
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+from config import POPPLER_PATH, apply_tesseract_path
+
+# Apply Tesseract path from config
+apply_tesseract_path()
+
+# Project path setup
+BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+RAW_DIR = os.path.join(BASE_DIR, "data", "raw_invoices")
+IMG_DIR = os.path.join(BASE_DIR, "data", "processed_images")
+OCR_DIR = os.path.join(BASE_DIR, "data", "ocr_results")
+
+# Create the output directories
+os.makedirs(IMG_DIR, exist_ok=True)
+os.makedirs(OCR_DIR, exist_ok=True)
+
+def process_invoices():
+    print(f"Processing invoices in {RAW_DIR}...")
+
+    for filename in os.listdir(RAW_DIR):
+        filepath = os.path.join(RAW_DIR, filename)
+        base_name = os.path.splitext(filename)[0]
+        img_path = os.path.join(IMG_DIR, f"{base_name}.png")
+        json_path = os.path.join(OCR_DIR, f"{base_name}.json")
+
+        # Skip files that have already been processed
+        if os.path.exists(img_path) and os.path.exists(json_path):
+            continue
+
+        try:
+            # 1. Load the image (PDF or image file)
+            if filename.lower().endswith(".pdf"):
+                images = convert_from_path(filepath, poppler_path=POPPLER_PATH)
+                if not images:
+                    print(f"Warning: could not extract an image from {filename}.")
+                    continue
+                img_pil = images[0]  # Take the first page
+                img = cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGB2BGR)
+            else:
+                img = cv2.imread(filepath)
+                if img is None:
+                    print(f"Warning: could not read image file {filename}.")
+                    continue
+
+            img_h, img_w = img.shape[:2]
+
+            # 2. Save the image in a uniform format
+            cv2.imwrite(img_path, img)
+
+            # 3. Run Tesseract OCR
+            ocr_data_df = pytesseract.image_to_data(
+                img,
+                lang='swe',  # Requires the Swedish language pack
+                output_type=pytesseract.Output.DATAFRAME
+            )
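+            # (The DATAFRAME output has one row per detected word, with Tesseract's
+            # TSV columns: level, page_num, block_num, par_num, line_num, word_num,
+            # left, top, width, height, conf, text.)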
+
+            # 4. Clean the OCR results
+            ocr_data_df = ocr_data_df[ocr_data_df.conf > 0]
+            ocr_data_df.dropna(subset=['text'], inplace=True)
+            ocr_data_df['text'] = ocr_data_df['text'].astype(str).str.strip()
+            ocr_data_df = ocr_data_df[ocr_data_df['text'] != ""]
+
+            # 5. Convert to the JSON format used by the downstream scripts (with text_boxes)
+            text_boxes = []
+            for i, row in ocr_data_df.iterrows():
+                text_boxes.append({
+                    "text": row["text"],
+                    "bbox": {
+                        "x_min": row["left"],
+                        "y_min": row["top"],
+                        "x_max": row["left"] + row["width"],
+                        "y_max": row["top"] + row["height"]
+                    },
+                    "confidence": row["conf"] / 100.0
+                })
+
+            output_json = {
+                "image_name": f"{base_name}.png",
+                "width": img_w,
+                "height": img_h,
+                "text_boxes": text_boxes
+            }
+
+            with open(json_path, 'w', encoding='utf-8') as f:
+                json.dump(output_json, f, indent=4, ensure_ascii=False)
+
+            print(f"Processed: {filename}")
+
+        except Exception as e:
+            print(f"Error while processing {filename}: {e}")
+
+if __name__ == "__main__":
+    process_invoices()
\ No newline at end of file
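For reference, each per-invoice JSON written by 01_process_invoices.py has the shape below (field names from the code above; values illustrative):

    {
        "image_name": "invoice_001.png",
        "width": 2480,
        "height": 3508,
        "text_boxes": [
            {
                "text": "Bankgirot",
                "bbox": {"x_min": 120, "y_min": 2900, "x_max": 310, "y_max": 2940},
                "confidence": 0.96
            }
        ]
    }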
diff --git a/scripts/02_create_labels.py b/scripts/02_create_labels.py
new file mode 100644
index 0000000..fffa1af
--- /dev/null
+++ b/scripts/02_create_labels.py
@@ -0,0 +1,223 @@
+# --- scripts/02_create_labels.py ---
+#
+# **IMPORTANT**: This script generates training data for the "Stage 1" region detector.
+# It emits a single class (class_id = 0): the *entire* "payment_slip" region.
+# It does not use your advanced classify_text logic; instead it uses heuristics
+# (keywords, lines) to locate the large region.
+#
+
+import os
+import pandas as pd
+import numpy as np
+import cv2
+import re
+import shutil
+import json
+from sklearn.cluster import DBSCAN
+
+# --- Anchor dictionary (Technique 2) ---
+# These words are used to *locate* the slip region, not to extract values
+KEYWORDS = [
+    "Bankgirot", "PlusGirot", "OCR-nummer", "Att betala", "Mottagare",
+    "Betalningsmottagare", "Tillhanda senast", "Förfallodag", "Belopp",
+    "BG-nr", "PG-nr", "Meddelande", "OCR-kod", "Inbetalningskort", "Betalningsavi"
+]
+KEYWORDS_REGEX = '|'.join(KEYWORDS)
+
+REGEX_PATTERNS = {
+    'bg_pg': r'(\b\d{2,4}[- ]\d{4}\b)|(\b\d{2,7}[- ]\d\b)',  # Bankgiro/PlusGiro
+    'long_num': r'\b\d{10,}\b',  # Possibly an OCR number
+    'machine_code': r'#[0-9\s>#]+#'  # Machine-readable code
+}
+FULL_REGEX = f"({KEYWORDS_REGEX})|{REGEX_PATTERNS['bg_pg']}|{REGEX_PATTERNS['long_num']}|{REGEX_PATTERNS['machine_code']}"
+
+# --- Path setup ---
+BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+OCR_DIR = os.path.join(BASE_DIR, "data", "ocr_results")
+IMG_DIR = os.path.join(BASE_DIR, "data", "processed_images")
+# Use temp directories for initial label generation (before train/val split)
+LABEL_DIR = os.path.join(BASE_DIR, "data", "yolo_dataset", "temp_all_labels")
+IMAGE_DIR = os.path.join(BASE_DIR, "data", "yolo_dataset", "temp_all_images")
+
+os.makedirs(LABEL_DIR, exist_ok=True)
+os.makedirs(IMAGE_DIR, exist_ok=True)
+
+def find_text_anchors(text_boxes, img_height):
+    """Technique 1 (position) + Technique 2 (dictionary)"""
+    anchors = []
+    if not text_boxes:
+        return []
+
+    # Technique 1: only search the lower part of the page (from 40% down)
+    page_midpoint = img_height * 0.4
+
+    for box in text_boxes:
+        # Check the position
+        if box["bbox"]["y_min"] > page_midpoint:
+            # Check the text content
+            if re.search(FULL_REGEX, box["text"], re.IGNORECASE):
+                bbox = box["bbox"]
+                anchors.append(pd.Series({
+                    'left': bbox["x_min"],
+                    'top': bbox["y_min"],
+                    'width': bbox["x_max"] - bbox["x_min"],
+                    'height': bbox["y_max"] - bbox["y_min"]
+                }))
+    return anchors
+
+def find_visual_anchors(image, img_height, img_width):
+    """Technique 3 (visual anchors - line detection)"""
+    anchors = []
+    try:
+        # 1. Only look at the lower half of the page
+        crop_y_start = img_height // 2
+        crop = image[crop_y_start:, :]
+
+        gray = cv2.cvtColor(crop, cv2.COLOR_BGR2GRAY)
+        thresh = cv2.threshold(gray, 150, 255, cv2.THRESH_BINARY_INV)[1]
+
+        # 2. Find long horizontal lines
+        min_line_length = img_width // 4
+        horizontal_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (min_line_length, 1))
+        detected_lines = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, horizontal_kernel, iterations=2)
+
+        contours, _ = cv2.findContours(detected_lines, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+
+        for c in contours:
+            x, y, w, h = cv2.boundingRect(c)
+            if w > min_line_length:  # Make sure the line is long enough
+                original_y = y + crop_y_start  # Convert back to full-image coordinates
+                anchors.append(pd.Series({
+                    'left': x, 'top': original_y, 'width': w, 'height': h
+                }))
+    except Exception as e:
+        print(f"Error while finding visual anchors: {e}")
+    return anchors
+
+def cluster_anchors(all_anchors, img_height):
+    """Technique 4 (clustering)"""
+    if len(all_anchors) < 2:
+        return all_anchors  # Too few anchors to cluster
+
+    # 1. Get the center points
+    points = []
+    for anchor in all_anchors:
+        x_center = anchor.left + anchor.width / 2
+        y_center = anchor.top + anchor.height / 2
+        points.append((x_center, y_center))
+
+    points = np.array(points)
+
+    # 2. Run DBSCAN
+    eps_dist = img_height * 0.2
+    clustering = DBSCAN(eps=eps_dist, min_samples=2).fit(points)
+    labels = clustering.labels_
+
+    if len(labels) == 0 or np.all(labels == -1):
+        return all_anchors  # Clustering failed
+
+    # 3. Find the largest cluster
+    unique_labels, counts = np.unique(labels[labels != -1], return_counts=True)
+    if len(counts) == 0:
+        return all_anchors  # Noise only
+
+    largest_cluster_label = unique_labels[np.argmax(counts)]
+
+    # 4. Return only the anchors belonging to the largest cluster
+    main_cluster_anchors = [
+        anchor for i, anchor in enumerate(all_anchors)
+        if labels[i] == largest_cluster_label
+    ]
+    return main_cluster_anchors
+
+def create_labels():
+    print("Generating weak labels (for the region detector)...")
+    processed_count = 0
+    skipped_count = 0
+
+    for ocr_filename in os.listdir(OCR_DIR):
+        if not ocr_filename.endswith(".json"):
+            continue
+
+        base_name = os.path.splitext(ocr_filename)[0]
+        json_path = os.path.join(OCR_DIR, ocr_filename)
+        img_path = os.path.join(IMG_DIR, f"{base_name}.png")
+
+        if not os.path.exists(img_path):
+            continue
+
+        try:
+            # 1. Load the data
+            image = cv2.imread(img_path)
+            img_h, img_w = image.shape[:2]
+            with open(json_path, 'r', encoding='utf-8') as f:
+                ocr_data = json.load(f)
+            text_boxes = ocr_data.get("text_boxes", [])
+
+            # 2. Find all anchors
+            text_anchors = find_text_anchors(text_boxes, img_h)
+            visual_anchors = find_visual_anchors(image, img_h, img_w)
+            all_anchors = text_anchors + visual_anchors
+
+            if not all_anchors:
+                print(f"SKIPPING: {base_name} (no anchors found)")
+                skipped_count += 1
+                continue
+
+            # 3. Cluster the anchors
+            final_anchors = cluster_anchors(all_anchors, img_h)
+
+            if not final_anchors:
+                print(f"SKIPPING: {base_name} (no valid cluster found)")
+                skipped_count += 1
+                continue
+
+            # 4. Aggregate the coordinates
+            min_x = min(a.left for a in final_anchors)
+            min_y = min(a.top for a in final_anchors)
+            max_x = max(a.left + a.width for a in final_anchors)
+            max_y = max(a.top + a.height for a in final_anchors)
+
+            # 5. Add padding
+            padding = 10
+            min_x = max(0, min_x - padding)
+            min_y = max(0, min_y - padding)
+            max_x = min(img_w, max_x + padding)
+            max_y = min(img_h, max_y + padding)
+
+            # 6. Convert to YOLO format
+            box_w = max_x - min_x
+            box_h = max_y - min_y
+            x_center = (min_x + box_w / 2) / img_w
+            y_center = (min_y + box_h / 2) / img_h
+            norm_w = box_w / img_w
+            norm_h = box_h / img_h
+
+            # ** Key point: the class ID is always 0 **
+            class_id = 0  # The only class: 'payment_slip'
+
+            yolo_label = f"{class_id} {x_center} {y_center} {norm_w} {norm_h}\n"
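+            # A resulting line looks like "0 0.512 0.803 0.952 0.329" (illustrative values):
+            # class 0 ('payment_slip'), then the box center x/y and width/height,
+            # all normalized to [0, 1]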
+
+            # 7. Save the label and the image
+            label_path = os.path.join(LABEL_DIR, f"{base_name}.txt")
+            with open(label_path, 'w', encoding='utf-8') as f:
+                f.write(yolo_label)
+
+            shutil.copy(
+                img_path,
+                os.path.join(IMAGE_DIR, f"{base_name}.png")
+            )
+
+            processed_count += 1
+            if processed_count % 20 == 0:
+                print(f"Generated {processed_count} labels...")
+
+        except Exception as e:
+            print(f"Error while processing {base_name}: {e}")
+            skipped_count += 1
+
+    print("--- Weak label generation complete ---")
+    print(f"Generated: {processed_count}")
+    print(f"Skipped: {skipped_count}")
+
+if __name__ == "__main__":
+    create_labels()
\ No newline at end of file
diff --git a/scripts/03_split_dataset.py b/scripts/03_split_dataset.py
new file mode 100644
index 0000000..92d64c8
--- /dev/null
+++ b/scripts/03_split_dataset.py
@@ -0,0 +1,123 @@
+"""
+Dataset Split Script - Step 3
+Splits images and labels into training and validation sets
+"""
+
+import shutil
+import random
+from pathlib import Path
+import os
+
+# Paths
+BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+YOLO_DATASET_DIR = Path(BASE_DIR) / "data" / "yolo_dataset"
+TEMP_IMAGES_DIR = YOLO_DATASET_DIR / "temp_all_images"
+TEMP_LABELS_DIR = YOLO_DATASET_DIR / "temp_all_labels"
+
+TRAIN_IMAGES_DIR = YOLO_DATASET_DIR / "images" / "train"
+VAL_IMAGES_DIR = YOLO_DATASET_DIR / "images" / "val"
+TRAIN_LABELS_DIR = YOLO_DATASET_DIR / "labels" / "train"
+VAL_LABELS_DIR = YOLO_DATASET_DIR / "labels" / "val"
+
+# Configuration
+VALIDATION_SPLIT = 0.2  # 20% for validation
+RANDOM_SEED = 42
+
+
+def split_dataset(val_split=VALIDATION_SPLIT, seed=RANDOM_SEED):
+    """
+    Split dataset into training and validation sets
+
+    Args:
+        val_split: Fraction of data to use for validation (0.0 to 1.0)
+        seed: Random seed for reproducibility
+    """
+    print("="*60)
+    print("Splitting Dataset into Train/Val Sets")
+    print("="*60)
+    print(f"Validation split: {val_split*100:.1f}%")
+    print(f"Random seed: {seed}\n")
+
+    # Check if temp directories exist
+    if not TEMP_IMAGES_DIR.exists() or not TEMP_LABELS_DIR.exists():
+        print(f"Error: temporary directories not found under {YOLO_DATASET_DIR}")
+        print("Please run 02_create_labels.py first")
+        return
+
+    # Get all image files
+    image_files = list(TEMP_IMAGES_DIR.glob("*.jpg")) + list(TEMP_IMAGES_DIR.glob("*.png"))
+
+    if not image_files:
+        print(f"No image files found in {TEMP_IMAGES_DIR}")
+        return
+
+    # Filter images that have corresponding labels
+    valid_pairs = []
+    for image_file in image_files:
+        label_file = TEMP_LABELS_DIR / (image_file.stem + ".txt")
+        if label_file.exists():
+            valid_pairs.append({
+                "image": image_file,
+                "label": label_file
+            })
+
+    if not valid_pairs:
+        print("No valid image-label pairs found")
+        return
+
+    print(f"Found {len(valid_pairs)} image-label pair(s)")
+
+    # Shuffle and split
+    random.seed(seed)
+    random.shuffle(valid_pairs)
+
+    split_index = int(len(valid_pairs) * (1 - val_split))
+    train_pairs = valid_pairs[:split_index]
+    val_pairs = valid_pairs[split_index:]
+
+    print(f"\nSplit results:")
+    print(f"  Training set: {len(train_pairs)} samples")
+    print(f"  Validation set: {len(val_pairs)} samples")
+    print()
+
+    # Clear existing train/val directories
+    for directory in [TRAIN_IMAGES_DIR, VAL_IMAGES_DIR, TRAIN_LABELS_DIR, VAL_LABELS_DIR]:
+        if directory.exists():
+            shutil.rmtree(directory)
+        directory.mkdir(parents=True, exist_ok=True)
+
+    # Copy training files
+    print("Copying training files...")
+    for pair in train_pairs:
+        shutil.copy(pair["image"], TRAIN_IMAGES_DIR / pair["image"].name)
+        shutil.copy(pair["label"], TRAIN_LABELS_DIR / pair["label"].name)
+    print(f"  Copied {len(train_pairs)} image-label pairs to train/")
+
+    # Copy validation files
+    print("Copying validation files...")
+    for pair in val_pairs:
+        shutil.copy(pair["image"], VAL_IMAGES_DIR / pair["image"].name)
+        shutil.copy(pair["label"], VAL_LABELS_DIR / pair["label"].name)
+    print(f"  Copied {len(val_pairs)} image-label pairs to val/")
+
+    print("\n" + "="*60)
+    print("Dataset split complete!")
+    print(f"\nDataset structure:")
+    print(f"  {TRAIN_IMAGES_DIR}")
+    print(f"  {TRAIN_LABELS_DIR}")
+    print(f"  {VAL_IMAGES_DIR}")
+    print(f"  {VAL_LABELS_DIR}")
+    print(f"\nNext step: Run 04_train_yolo.py to train the model")
+    print("="*60)
+
+
+def main():
+    """Main function"""
+    split_dataset()
+
+
+if __name__ == "__main__":
+    main()
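Note: 04_train_yolo.py below expects data/yolo_dataset/dataset.yaml, which no script in this commit generates. A minimal single-class config matching the directory layout created above could look like this (a sketch; depending on your Ultralytics settings, `path` may need to be absolute):

    # data/yolo_dataset/dataset.yaml
    path: data/yolo_dataset
    train: images/train
    val: images/val
    names:
      0: payment_slip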
shutil.copy(pair["label"], TRAIN_LABELS_DIR / pair["label"].name) + print(f" Copied {len(train_pairs)} image-label pairs to train/") + + # Copy validation files + print("Copying validation files...") + for pair in val_pairs: + shutil.copy(pair["image"], VAL_IMAGES_DIR / pair["image"].name) + shutil.copy(pair["label"], VAL_LABELS_DIR / pair["label"].name) + print(f" Copied {len(val_pairs)} image-label pairs to val/") + + print("\n" + "="*60) + print("Dataset split complete!") + print(f"\nDataset structure:") + print(f" {TRAIN_IMAGES_DIR}") + print(f" {TRAIN_LABELS_DIR}") + print(f" {VAL_IMAGES_DIR}") + print(f" {VAL_LABELS_DIR}") + print(f"\nNext step: Run 04_train_yolo.py to train the model") + print("="*60) + + +def main(): + """Main function""" + split_dataset() + + +if __name__ == "__main__": + main() diff --git a/scripts/04_train_yolo.py b/scripts/04_train_yolo.py new file mode 100644 index 0000000..c25b35c --- /dev/null +++ b/scripts/04_train_yolo.py @@ -0,0 +1,230 @@ +""" +YOLO Training Script - Step 4 +Trains YOLOv8 model on the prepared invoice dataset +""" + +from pathlib import Path +from ultralytics import YOLO +import torch +import os + +# Paths +BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +DATASET_YAML = Path(BASE_DIR + "/data/yolo_dataset/dataset.yaml") +MODELS_DIR = Path(BASE_DIR + "/models") + +# Training configuration +MODEL_SIZE = "n" # Options: n (nano), s (small), m (medium), l (large), x (xlarge) +EPOCHS = 100 +BATCH_SIZE = 16 +IMAGE_SIZE = 640 +DEVICE = 0 if torch.cuda.is_available() else "cpu" # Use GPU with PyTorch 2.7 + CUDA 12.8 + +# Create models directory +MODELS_DIR.mkdir(exist_ok=True) + + +def train_model( + model_size=MODEL_SIZE, + epochs=EPOCHS, + batch_size=BATCH_SIZE, + img_size=IMAGE_SIZE, + device=DEVICE +): + """ + Train YOLOv8 model on invoice dataset + + Args: + model_size: Size of YOLO model (n, s, m, l, x) + epochs: Number of training epochs + batch_size: Batch size for training + img_size: Input image size + device: Device to use for training (cuda or cpu) + """ + print("="*60) + print("YOLOv8 Invoice Detection Training") + print("="*60) + + # Check if dataset.yaml exists + if not DATASET_YAML.exists(): + print(f"Error: {DATASET_YAML} not found") + print("Please ensure the dataset.yaml file exists") + return + + # Print configuration + print(f"\nConfiguration:") + print(f" Model: YOLOv8{model_size}") + print(f" Epochs: {epochs}") + print(f" Batch size: {batch_size}") + print(f" Image size: {img_size}") + print(f" Device: {device}") + print(f" Dataset config: {DATASET_YAML}") + print() + + # Initialize model + print(f"Loading YOLOv8{model_size} model...") + model = YOLO(f"yolov8{model_size}.pt") # Load pretrained model + + # Print device info + if device == 0: + print(f"Using GPU: {torch.cuda.get_device_name(0)}") + print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB") + else: + print("Using CPU (training will be slower)") + + print("\nStarting training...") + print("-" * 60) + + # Train the model + results = model.train( + data=str(DATASET_YAML), + epochs=epochs, + imgsz=img_size, + batch=batch_size, + device=device, + project=str(MODELS_DIR), + name="payment_slip_detector_v1", + exist_ok=True, + patience=20, # Early stopping patience + save=True, + save_period=10, # Save checkpoint every 10 epochs + verbose=True, + plots=True # Generate training plots + ) + + print("\n" + "="*60) + print("Training complete!") + print("="*60) + + # Print results + best_model_path = MODELS_DIR / 
"payment_slip_detector_v1" / "weights" / "best.pt" + last_model_path = MODELS_DIR / "payment_slip_detector_v1" / "weights" / "last.pt" + + print(f"\nTrained models saved to:") + print(f" Best model: {best_model_path}") + print(f" Last model: {last_model_path}") + + print(f"\nTraining plots saved to:") + print(f" {MODELS_DIR / 'payment_slip_detector_v1'}") + + print("\n" + "="*60) + + +def validate_model(model_path=None): + """ + Validate trained model on validation set + + Args: + model_path: Path to model weights (default: best.pt from last training) + """ + if model_path is None: + model_path = MODELS_DIR / "payment_slip_detector_v1" / "weights" / "best.pt" + + if not Path(model_path).exists(): + print(f"Error: Model not found at {model_path}") + print("Please train a model first") + return + + print("="*60) + print("Validating Model") + print("="*60) + print(f"Model: {model_path}\n") + + # Load model + model = YOLO(str(model_path)) + + # Validate + results = model.val(data=str(DATASET_YAML)) + + print("\n" + "="*60) + print("Validation complete!") + print("="*60) + + +def predict_sample(model_path=None, image_path=None, conf_threshold=0.25): + """ + Run prediction on a sample image + + Args: + model_path: Path to model weights + image_path: Path to image to predict on + conf_threshold: Confidence threshold for detections + """ + if model_path is None: + model_path = MODELS_DIR / "payment_slip_detector_v1" / "weights" / "best.pt" + + if not Path(model_path).exists(): + print(f"Error: Model not found at {model_path}") + return + + if image_path is None: + # Try to get a sample from validation set + val_images_dir = Path("data/yolo_dataset/images/val") + sample_images = list(val_images_dir.glob("*.jpg")) + list(val_images_dir.glob("*.png")) + if sample_images: + image_path = sample_images[0] + else: + print("No sample images found") + return + + print("="*60) + print("Running Prediction") + print("="*60) + print(f"Model: {model_path}") + print(f"Image: {image_path}") + print(f"Confidence threshold: {conf_threshold}\n") + + # Load model + model = YOLO(str(model_path)) + + # Predict + results = model.predict( + source=str(image_path), + conf=conf_threshold, + save=True, + project=str(MODELS_DIR / "predictions"), + name="sample" + ) + + print(f"\nPrediction saved to: {MODELS_DIR / 'predictions' / 'sample'}") + print("="*60) + + +def main(): + """Main training function""" + import argparse + + parser = argparse.ArgumentParser(description="Train YOLOv8 on invoice dataset") + parser.add_argument("--mode", type=str, default="train", choices=["train", "validate", "predict"], + help="Mode: train, validate, or predict") + parser.add_argument("--model-size", type=str, default=MODEL_SIZE, + choices=["n", "s", "m", "l", "x"], + help="YOLO model size") + parser.add_argument("--epochs", type=int, default=EPOCHS, + help="Number of training epochs") + parser.add_argument("--batch-size", type=int, default=BATCH_SIZE, + help="Batch size") + parser.add_argument("--img-size", type=int, default=IMAGE_SIZE, + help="Image size") + parser.add_argument("--model-path", type=str, default=None, + help="Path to model weights (for validate/predict)") + parser.add_argument("--image-path", type=str, default=None, + help="Path to image (for predict)") + + args = parser.parse_args() + + if args.mode == "train": + train_model( + model_size=args.model_size, + epochs=args.epochs, + batch_size=args.batch_size, + img_size=args.img_size + ) + elif args.mode == "validate": + validate_model(model_path=args.model_path) + elif 
args.mode == "predict": + predict_sample(model_path=args.model_path, image_path=args.image_path) + + +if __name__ == "__main__": + main() diff --git a/scripts/main.py b/scripts/main.py new file mode 100644 index 0000000..82b81e6 --- /dev/null +++ b/scripts/main.py @@ -0,0 +1,237 @@ +# --- main.py (已升级) --- + +from fastapi import FastAPI, UploadFile, File, HTTPException +from ultralytics import YOLO +import cv2 +import numpy as np +import pytesseract +import re +import io +import os +from contextlib import asynccontextmanager + +# --- 配置 --- +# TODO: 确保此路径指向您训练好的最佳模型 +MODEL_PATH = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "models", "invoice_detector_v1", "weights", "best.pt" +) + +# 定义一个字典来在 FastAPI 启动时加载模型 +ml_models = {} + +@asynccontextmanager +async def lifespan(app: FastAPI): + # 启动时加载模型 + print(MODEL_PATH) + if not os.path.exists(MODEL_PATH): + print(f"警告: 找不到模型 {MODEL_PATH}。API 将无法工作。") + ml_models["yolo"] = None + else: + # 加载您的 "payment_slip" 区域检测器 + ml_models["yolo"] = YOLO(MODEL_PATH) + print("YOLOv8 区域检测模型加载成功。") + yield + # 清理模型 + ml_models.clear() + +app = FastAPI(lifespan=lifespan) + +# --- Luhn (Modulus 10) 校验函数 --- +def luhn_validate(number_str: str, expected_check_digit: str) -> bool: + """ + 使用 Modulus 10 (Luhn) 算法验证一个数字字符串。 + (从右到左, 权重 1, 2, 1, 2...) + """ + try: + digits = [int(d) for d in number_str] + weights = [1, 2] * (len(digits) // 2 + 1) + weights = weights[:len(digits)] # 确保权重列表长度一致 + + sum_val = 0 + # 从右到左计算 + for d, w in zip(reversed(digits), weights): + product = d * w + if product >= 10: + sum_val += (product // 10) + (product % 10) + else: + sum_val += product + + calculated_check_digit = (10 - (sum_val % 10)) % 10 + return str(calculated_check_digit) == expected_check_digit + + except Exception: + return False + +# --- 提取逻辑 --- + +def parse_ocr_rad(ocr_text: str) -> dict: + """ + Plan A: 尝试解析机器可读码 (OCR-rad) + 示例: # 400299582421 # 4603 00 7 > 48180020 #14# + """ + # 移除所有空格以简化匹配 + text_no_space = re.sub(r'\s+', '', ocr_text) + + # 定义一个更健壮的正则表达式 + # 组1: OCR号 (在 #...# 之间) + # 组2: 金额 (Kronor + Öre) (在 #... 之后) + # 组3: 校验码 (1位数字) + # 组4: 账户号 (在 >...# 之间) + rad_regex = r'#([\d>]+)#(\d{2,})(\d)(\d{1})>([\d]+)#' + + match = re.search(rad_regex, text_no_space) + + if not match: + return None # 未找到机读码行 + + try: + ocr_num = match.group(1).replace(">", "") # 移除 > (如果有) + amount_base = match.group(2) + match.group(3) # "4603" + "00" = "460300" + check_digit = match.group(4) # "7" + account = match.group(5) # "48180020" + + # 运行Luhn校验 + if luhn_validate(amount_base, check_digit): + # 校验成功! 这是高置信度数据 + amount_kronor = amount_base[:-2] + amount_ore = amount_base[-2:] + + return { + "source": "OCR-rad (High Confidence)", + "ocr_number": ocr_num, + "amount_due": f"{amount_kronor}.{amount_ore}", + "bankgiro_plusgiro": account, + "due_date": None # 机读码行通常不包含日期 + } + else: + # 校验失败! 
+ print(f"Luhn 校验失败: 基础={amount_base}, 期望={check_digit}") + return None + + except Exception as e: + print(f"解析 OCR-rad 时出错: {e}") + return None + +def parse_human_readable(ocr_text: str) -> dict: + """ + Plan B: 回退到人工可读区域 + (这里我们使用之前版本中的简单 Regex, + 您也可以替换为您那个更复杂的 classify_text 逻辑) + """ + data = {"source": "Human-Readable (Fallback)"} + + # 查找 BG/PG (Bankgiro/Plusgiro) + bg_match = re.search(r'Bankgiro\D*(\d{2,4}[- ]\d{4})', ocr_text, re.IGNORECASE) + pg_match = re.search(r'PlusGiro\D*(\d{2,7}[- ]\d)', ocr_text, re.IGNORECASE) + if bg_match: + data["bankgiro_plusgiro"] = bg_match.group(1).replace(" ", "") + elif pg_match: + data["bankgiro_plusgiro"] = pg_match.group(1).replace(" ", "") + else: + # 备用查找 + bg_pg_alt = re.search(r'(\b\d{2,4}[- ]\d{4}\b)|(\b\d{2,7}[- ]\d\b)', ocr_text) + if bg_pg_alt: + data["bankgiro_plusgiro"] = bg_pg_alt.group(0).replace(" ", "") + + # 查找 OCR + ocr_match = re.search(r'(OCR|Fakturanummer|Referens)\D*(\d[\d\s]{5,}\d)', ocr_text, re.IGNORECASE) + if ocr_match: + data["ocr_number"] = re.sub(r'\s', '', ocr_match.group(2)) + + # 查找金额 + amount_match = re.search(r'(Att betala|Belopp)\D*([\d\s,.]+)\s*(kr|SEK)?', ocr_text, re.IGNORECASE) + if amount_match: + amount_str = amount_match.group(2).strip().replace(" ", "").replace(",", ".") + if amount_str.count('.') > 1: + amount_str = amount_str.replace(".", "", amount_str.count('.') - 1) + data["amount_due"] = amount_str + + # 查找截止日期 + date_match = re.search(r'(senast|Förfallodag)\D*(\d{4}[- ]\d{2}[- ]\d{2})', ocr_text, re.IGNORECASE) + if date_match: + data["due_date"] = date_match.group(2).replace(" ", "-") + + return data + + +def extract_info_from_crop(crop_image: np.ndarray) -> dict: + """ + 主提取函数: 执行 Plan A 和 Plan B + """ + try: + # 1. 预处理并运行 OCR (获取所有文本) + gray = cv2.cvtColor(crop_image, cv2.COLOR_BGR2GRAY) + thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)[1] + + # 使用 psm 6 假设是一个统一的文本块, 这对 OCR-rad 很友好 + ocr_text = pytesseract.image_to_string(thresh, lang='swe', config='--psm 6') + + # 2. --- Plan A: 尝试解析机读码 --- + rad_data = parse_ocr_rad(ocr_text) + + if rad_data: + # Plan A 成功! + return rad_data + + # 3. --- Plan B: Plan A 失败, 回退到人工读取 --- + # 我们重新运行 OCR, 使用 psm 3 (自动布局), 这对人工区域更友好 + ocr_text_human = pytesseract.image_to_string(thresh, lang='swe', config='--psm 3') + + human_data = parse_human_readable(ocr_text_human) + + # 即使回退失败, 也返回空字典 (或部分数据) + return human_data + + except Exception as e: + return {"error": f"提取时出错: {e}"} + + +@app.post("/extract_invoice/") +async def extract_invoice_data(file: UploadFile = File(...)): + + if ml_models.get("yolo") is None: + raise HTTPException(status_code=503, detail="模型未加载。请检查模型路径。") + + try: + contents = await file.read() + nparr = np.frombuffer(contents, np.uint8) + img = cv2.imdecode(nparr, cv2.IMREAD_COLOR) + if img is None: + raise HTTPException(status_code=400, detail="无法解码图像文件。") + + except Exception as e: + raise HTTPException(status_code=400, detail=f"读取文件时出错: {e}") + + # 1. 运行 YOLO 检测 (阶段一) + results = ml_models["yolo"](img, verbose=False) + + all_extractions = [] + + if not results or not results[0].boxes: + return {"message": "未在图像中检测到支付凭证区域。"} + + for res in results: + # 遍历所有检测到的凭证 (通常只有一个) + for box_coords in res.boxes.xyxy.cpu().numpy().astype(int): + xmin, ymin, xmax, ymax = box_coords + + # 2. 裁剪图像 + crop = img[ymin:ymax, xmin:xmax] + + # 3. 
+
+            # 3. Run the "Plan A/B" extraction on the crop (Stage 2)
+            extracted_data = extract_info_from_crop(crop)
+            extracted_data["bounding_box"] = [xmin, ymin, xmax, ymax]
+            all_extractions.append(extracted_data)
+
+    if not all_extractions:
+        return {"message": "Payment slip detected, but no information could be extracted."}
+
+    return {"invoice_extractions": all_extractions}
+
+if __name__ == "__main__":
+    import uvicorn
+    print("--- Starting FastAPI service ---")
+    print(f"Loading model: {MODEL_PATH}")
+    print("Visit http://127.0.0.1:8000/docs for the API documentation")
+    uvicorn.run(app, host="127.0.0.1", port=8000)
\ No newline at end of file
diff --git a/test_api.py b/test_api.py
new file mode 100644
index 0000000..06c04f1
--- /dev/null
+++ b/test_api.py
@@ -0,0 +1,14 @@
+import requests
+import json
+
+# Test the API with the problematic invoice
+url = "http://127.0.0.1:8000/extract_invoice/"
+file_path = r"data\processed_images\4BC5E5B3-E561-4A73-BC9C-46D4F08F89C3.png"
+
+with open(file_path, 'rb') as f:
+    files = {'file': f}
+    response = requests.post(url, files=files)
+
+print("Status Code:", response.status_code)
+print("\nResponse JSON:")
+print(json.dumps(response.json(), indent=2, ensure_ascii=False))
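For reference, the run order implied by the script numbering (from the project root, after installing requirements.txt and the system dependencies it lists):

    python scripts/01_process_invoices.py   # PDFs/images -> normalized PNGs + OCR JSON
    python scripts/02_create_labels.py      # weak YOLO labels from text/visual anchors
    python scripts/03_split_dataset.py      # train/val split
    python scripts/04_train_yolo.py         # train the region detector
    python scripts/main.py                  # serve the extraction API
    python test_api.py                      # exercise the /extract_invoice/ endpoint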