Init
This commit is contained in:
28
.claude/settings.local.json
Normal file
28
.claude/settings.local.json
Normal file
@@ -0,0 +1,28 @@
|
||||
{
|
||||
"permissions": {
|
||||
"allow": [
|
||||
"Bash(mkdir:*)",
|
||||
"Bash(python:*)",
|
||||
"Bash(pip install:*)",
|
||||
"Bash(dir:*)",
|
||||
"Bash(tesseract:*)",
|
||||
"Bash(nvidia-smi:*)",
|
||||
"Bash(pip uninstall:*)",
|
||||
"Bash(for img in ../../images/train/*.jpg)",
|
||||
"Bash(do basename=\"$img##*/\")",
|
||||
"Bash(labelname=\"$basename%.jpg.txt\")",
|
||||
"Bash(if [ -f \"../../temp_visual_labels/$labelname\" ])",
|
||||
"Bash(then cp \"../../temp_visual_labels/$labelname\" .)",
|
||||
"Bash(fi)",
|
||||
"Bash(done)",
|
||||
"Bash(awk:*)",
|
||||
"Bash(chcp 65001)",
|
||||
"Bash(PYTHONIOENCODING=utf-8 python:*)",
|
||||
"Bash(timeout 600 tail:*)",
|
||||
"Bash(cat:*)",
|
||||
"Bash(powershell -Command \"$response = Invoke-WebRequest -Uri ''http://127.0.0.1:8000/extract_invoice/'' -Method POST -Form @{file=Get-Item ''data\\processed_images\\4BC5E5B3-E561-4A73-BC9C-46D4F08F89C3.png''} -UseBasicParsing; $response.Content\")"
|
||||
],
|
||||
"deny": [],
|
||||
"ask": []
|
||||
}
|
||||
}
|
||||
11
.gitignore
vendored
Normal file
11
.gitignore
vendored
Normal file
@@ -0,0 +1,11 @@
|
||||
# Default ignored files
|
||||
/shelf/
|
||||
/workspace.xml
|
||||
# Editor-based HTTP Client requests
|
||||
/httpRequests/
|
||||
# Datasource local storage ignored files
|
||||
/dataSources/
|
||||
/dataSources.local.xml
|
||||
/data
|
||||
/.idea
|
||||
/__pycache__/
|
||||
137
config.py
Normal file
137
config.py
Normal file
@@ -0,0 +1,137 @@
|
||||
"""
|
||||
Configuration file for system-dependent paths and settings.
|
||||
|
||||
This file contains paths that may vary between different systems.
|
||||
Copy this file and modify the paths according to your local installation.
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
# ============================================================================
|
||||
# System Paths - Modify these according to your installation
|
||||
# ============================================================================
|
||||
|
||||
# Poppler path (required for PDF to image conversion)
|
||||
# Download from: https://github.com/oschwartz10612/poppler-windows/releases
|
||||
# Example: r"C:\poppler-23.11.0\bin"
|
||||
POPPLER_PATH = os.getenv("POPPLER_PATH", r"C:\Program Files\poppler-25.07.0\Library\bin")
|
||||
|
||||
# Tesseract path (optional - only needed if not in system PATH)
|
||||
# Download from: https://github.com/UB-Mannheim/tesseract/wiki
|
||||
# Example: r"C:\Program Files\Tesseract-OCR\tesseract.exe"
|
||||
TESSERACT_CMD = os.getenv("TESSERACT_CMD", None)
|
||||
|
||||
# ============================================================================
|
||||
# Project Paths - Generally don't need to modify these
|
||||
# ============================================================================
|
||||
|
||||
# Project root directory
|
||||
PROJECT_ROOT = Path(__file__).parent.absolute()
|
||||
|
||||
# Data directories
|
||||
DATA_DIR = PROJECT_ROOT / "data"
|
||||
RAW_INVOICES_DIR = DATA_DIR / "raw_invoices"
|
||||
PROCESSED_IMAGES_DIR = DATA_DIR / "processed_images"
|
||||
OCR_RESULTS_DIR = DATA_DIR / "ocr_results"
|
||||
|
||||
# YOLO dataset directories
|
||||
YOLO_DATASET_DIR = DATA_DIR / "yolo_dataset"
|
||||
YOLO_TEMP_IMAGES_DIR = YOLO_DATASET_DIR / "temp_all_images"
|
||||
YOLO_TEMP_LABELS_DIR = YOLO_DATASET_DIR / "temp_all_labels"
|
||||
YOLO_TRAIN_IMAGES_DIR = YOLO_DATASET_DIR / "images" / "train"
|
||||
YOLO_TRAIN_LABELS_DIR = YOLO_DATASET_DIR / "labels" / "train"
|
||||
YOLO_VAL_IMAGES_DIR = YOLO_DATASET_DIR / "images" / "val"
|
||||
YOLO_VAL_LABELS_DIR = YOLO_DATASET_DIR / "labels" / "val"
|
||||
|
||||
# Model directories
|
||||
MODELS_DIR = PROJECT_ROOT / "models"
|
||||
DEFAULT_MODEL_PATH = MODELS_DIR / "payment_slip_detector_v1" / "weights" / "best.pt"
|
||||
|
||||
# ============================================================================
|
||||
# OCR Settings
|
||||
# ============================================================================
|
||||
|
||||
# Tesseract language (Swedish only; use "swe+eng" to also recognize English)
|
||||
TESSERACT_LANG = "swe" # Ensure Swedish language pack is installed
|
||||
|
||||
# OCR confidence threshold (0-100)
|
||||
OCR_CONFIDENCE_THRESHOLD = 0
|
||||
|
||||
# ============================================================================
|
||||
# Training Settings
|
||||
# ============================================================================
|
||||
|
||||
# YOLO model size: n (nano), s (small), m (medium), l (large), x (xlarge)
|
||||
YOLO_MODEL_SIZE = "n"
|
||||
|
||||
# Training epochs
|
||||
TRAINING_EPOCHS = 100
|
||||
|
||||
# Batch size
|
||||
BATCH_SIZE = 16
|
||||
|
||||
# Image size for training
|
||||
IMAGE_SIZE = 640
|
||||
|
||||
# Validation split ratio (0.0 to 1.0)
|
||||
VALIDATION_SPLIT = 0.2
|
||||
|
||||
# Random seed for reproducibility
|
||||
RANDOM_SEED = 42
|
||||
|
||||
# ============================================================================
|
||||
# API Settings (for main.py FastAPI server)
|
||||
# ============================================================================
|
||||
|
||||
# API host
|
||||
API_HOST = "127.0.0.1"
|
||||
|
||||
# API port
|
||||
API_PORT = 8000
|
||||
|
||||
# ============================================================================
|
||||
# Helper Functions
|
||||
# ============================================================================
|
||||
|
||||
def apply_tesseract_path():
    """Point pytesseract at the configured Tesseract executable.

    Does nothing when ``TESSERACT_CMD`` is unset or empty; pytesseract is
    imported only when a path is actually configured.
    """
    if not TESSERACT_CMD:
        return

    import pytesseract

    pytesseract.pytesseract.tesseract_cmd = TESSERACT_CMD
|
||||
|
||||
def validate_paths():
    """Check that the configured external tool paths exist on disk.

    Poppler is always checked; Tesseract is checked only when
    ``TESSERACT_CMD`` is set. Any problems are printed to stdout.

    Returns:
        True when no problem was found, False otherwise.
    """
    problems = []

    # Poppler path
    if not os.path.exists(POPPLER_PATH):
        problems += [
            f"Poppler not found at: {POPPLER_PATH}",
            " Download from: https://github.com/oschwartz10612/poppler-windows/releases",
        ]

    # Tesseract path (only when one was explicitly configured)
    if TESSERACT_CMD and not os.path.exists(TESSERACT_CMD):
        problems += [
            f"Tesseract not found at: {TESSERACT_CMD}",
            " Download from: https://github.com/UB-Mannheim/tesseract/wiki",
        ]

    if not problems:
        return True

    print("Configuration Issues Found:")
    for line in problems:
        print(f" {line}")
    return False
|
||||
|
||||
# ============================================================================
|
||||
# Example: Environment Variable Override
|
||||
# ============================================================================
|
||||
# You can set these in your environment instead of modifying this file:
|
||||
#
|
||||
# Windows:
|
||||
# set POPPLER_PATH=C:\poppler\bin
|
||||
# set TESSERACT_CMD=C:\Program Files\Tesseract-OCR\tesseract.exe
|
||||
#
|
||||
# Linux/Mac:
|
||||
# export POPPLER_PATH=/usr/bin
|
||||
# export TESSERACT_CMD=/usr/bin/tesseract
|
||||
# ============================================================================
|
||||
17
extraction_results.json
Normal file
17
extraction_results.json
Normal file
@@ -0,0 +1,17 @@
|
||||
[
|
||||
{
|
||||
"image": "data\\processed_images\\20250917.03.1.011299_c328f5a8-06f9-4093-85b5-e3a40f24bd30_page_1.jpg",
|
||||
"fields": {},
|
||||
"all_detections": []
|
||||
},
|
||||
{
|
||||
"image": "data\\processed_images\\64A80892-8A9E-454C-9AEB-B740E8C3ACB3_page_1.jpg",
|
||||
"fields": {},
|
||||
"all_detections": []
|
||||
},
|
||||
{
|
||||
"image": "data\\processed_images\\9fb6129f-671d-4aa1-9bad-096e84e6ded3_page_1.jpg",
|
||||
"fields": {},
|
||||
"all_detections": []
|
||||
}
|
||||
]
|
||||
45
requirements.txt
Normal file
45
requirements.txt
Normal file
@@ -0,0 +1,45 @@
|
||||
# Core dependencies
|
||||
ultralytics>=8.0.0 # YOLOv8
|
||||
pytesseract>=0.3.10 # Tesseract OCR Python wrapper
|
||||
|
||||
# Image processing
|
||||
pdf2image>=1.16.0 # PDF to image conversion
|
||||
Pillow>=10.0.0 # Image manipulation
|
||||
opencv-python>=4.8.0 # Computer vision
|
||||
|
||||
# Data handling
|
||||
numpy>=1.24.0
|
||||
pandas>=2.0.0
|
||||
scikit-learn>=1.3.0 # For DBSCAN clustering in 02_create_labels.py
|
||||
|
||||
# API dependencies (for main.py)
|
||||
fastapi>=0.104.0 # FastAPI web framework
|
||||
uvicorn>=0.24.0 # ASGI server
|
||||
python-multipart>=0.0.6 # For file upload support
|
||||
|
||||
# System utilities
|
||||
# IMPORTANT: Requires system-level installation of:
|
||||
#
|
||||
# 1. Tesseract OCR:
|
||||
# - Windows: Download from https://github.com/UB-Mannheim/tesseract/wiki
|
||||
# After installation, add to PATH or set: pytesseract.pytesseract.tesseract_cmd = r'C:\Program Files\Tesseract-OCR\tesseract.exe'
|
||||
# - Linux: sudo apt-get install tesseract-ocr tesseract-ocr-swe tesseract-ocr-eng
|
||||
# - macOS: brew install tesseract tesseract-lang
|
||||
#
|
||||
# 2. Poppler (for pdf2image):
|
||||
# - Windows: Download from https://github.com/oschwartz10612/poppler-windows/releases
|
||||
# - Linux: sudo apt-get install poppler-utils
|
||||
# - macOS: brew install poppler
|
||||
#
|
||||
# 3. Swedish language data for Tesseract:
|
||||
# After installing Tesseract, you may need to download Swedish language files (swe.traineddata)
|
||||
# from https://github.com/tesseract-ocr/tessdata
|
||||
|
||||
# Optional: GPU support
|
||||
# torch>=2.0.0 # PyTorch with CUDA support
|
||||
# torchvision>=0.15.0
|
||||
|
||||
# Development tools (optional)
|
||||
# jupyter>=1.0.0
|
||||
# matplotlib>=3.7.0
|
||||
# seaborn>=0.12.0
|
||||
106
scripts/01_process_invoices.py
Normal file
106
scripts/01_process_invoices.py
Normal file
@@ -0,0 +1,106 @@
|
||||
# --- scripts/01_process_invoices.py ---
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import pytesseract
|
||||
import pandas as pd
|
||||
from pdf2image import convert_from_path
|
||||
import cv2
|
||||
import numpy as np
|
||||
import shutil
|
||||
|
||||
# Add parent directory to path to import config
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
from config import POPPLER_PATH, apply_tesseract_path
|
||||
|
||||
# Apply Tesseract path from config
|
||||
apply_tesseract_path()
|
||||
|
||||
# 项目路径设置
|
||||
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
RAW_DIR = os.path.join(BASE_DIR, "data", "raw_invoices")
|
||||
IMG_DIR = os.path.join(BASE_DIR, "data", "processed_images")
|
||||
OCR_DIR = os.path.join(BASE_DIR, "data", "ocr_results")
|
||||
|
||||
# 创建输出目录
|
||||
os.makedirs(IMG_DIR, exist_ok=True)
|
||||
os.makedirs(OCR_DIR, exist_ok=True)
|
||||
|
||||
def process_invoices():
    """Convert each raw invoice into a PNG image plus a JSON file of OCR text boxes.

    Every entry of ``RAW_DIR`` is handled independently:

    1. PDFs are rendered with pdf2image (first page only); anything else is
       read with ``cv2.imread``.
    2. The image is written to ``IMG_DIR`` as ``<base_name>.png``.
    3. Tesseract (Swedish language data) is run on the image; rows with a
       non-positive confidence or empty text are discarded.
    4. The remaining words are written to ``OCR_DIR`` as ``<base_name>.json``
       with keys ``image_name``, ``width``, ``height`` and ``text_boxes``.

    Entries whose PNG and JSON outputs both already exist are skipped. A
    failure on one file is reported on stdout and does not stop the loop.
    """
    print(f"开始处理 {RAW_DIR} 中的发票...")

    for filename in os.listdir(RAW_DIR):
        filepath = os.path.join(RAW_DIR, filename)
        base_name = os.path.splitext(filename)[0]
        img_path = os.path.join(IMG_DIR, f"{base_name}.png")
        json_path = os.path.join(OCR_DIR, f"{base_name}.json")

        # Avoid reprocessing inputs whose outputs are already present.
        if os.path.exists(img_path) and os.path.exists(json_path):
            continue

        try:
            # 1. Load the image (PDF or regular image file).
            if filename.lower().endswith(".pdf"):
                # Only the first page is used, so only that page is rendered
                # instead of rasterizing the whole document.
                images = convert_from_path(
                    filepath,
                    poppler_path=POPPLER_PATH,
                    first_page=1,
                    last_page=1,
                )
                if not images:
                    print(f"警告: 无法从 {filename} 提取图像。")
                    continue
                img = cv2.cvtColor(np.array(images[0]), cv2.COLOR_RGB2BGR)
            else:
                img = cv2.imread(filepath)
                if img is None:
                    print(f"警告: 无法读取图像文件 {filename}。")
                    continue

            img_h, img_w = img.shape[:2]

            # 2. Save the image in the unified PNG format.
            cv2.imwrite(img_path, img)

            # 3. Run Tesseract OCR (requires the Swedish language pack).
            ocr_data_df = pytesseract.image_to_data(
                img,
                lang='swe',
                output_type=pytesseract.Output.DATAFRAME
            )

            # 4. Clean the OCR result. Work on an explicit copy so the column
            #    assignment below never targets a view of the original frame.
            ocr_data_df = ocr_data_df[ocr_data_df.conf > 0]
            ocr_data_df = ocr_data_df.dropna(subset=['text']).copy()
            ocr_data_df['text'] = ocr_data_df['text'].astype(str).str.strip()
            ocr_data_df = ocr_data_df[ocr_data_df['text'] != ""]

            # 5. Convert to the JSON layout (with "text_boxes") consumed by the
            #    label-generation script. Values are cast to built-in types so
            #    json.dump never receives numpy scalars.
            text_boxes = []
            for _, row in ocr_data_df.iterrows():
                left = int(row["left"])
                top = int(row["top"])
                text_boxes.append({
                    "text": row["text"],
                    "bbox": {
                        "x_min": left,
                        "y_min": top,
                        "x_max": left + int(row["width"]),
                        "y_max": top + int(row["height"]),
                    },
                    "confidence": float(row["conf"]) / 100.0,
                })

            output_json = {
                "image_name": f"{base_name}.png",
                "width": img_w,
                "height": img_h,
                "text_boxes": text_boxes,
            }

            with open(json_path, 'w', encoding='utf-8') as f:
                json.dump(output_json, f, indent=4, ensure_ascii=False)

            print(f"已处理: {filename}")

        except Exception as e:
            # Best-effort batch job: report the failure and move on.
            print(f"处理 {filename} 时出错: {e}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
process_invoices()
|
||||
223
scripts/02_create_labels.py
Normal file
223
scripts/02_create_labels.py
Normal file
@@ -0,0 +1,223 @@
|
||||
# --- scripts/02_create_labels.py ---
|
||||
#
|
||||
# **重要**: 此脚本用于训练 "阶段一" 的区域检测器.
|
||||
# 它只生成 1 个类别 (class_id = 0), 即 "payment_slip" 的 *整个* 区域.
|
||||
# 它不使用您的高级 classify_text 逻辑, 而是使用启发式规则 (关键词, 线条) 来找到大区域.
|
||||
#
|
||||
|
||||
import os
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import cv2
|
||||
import re
|
||||
import shutil
|
||||
import json
|
||||
from sklearn.cluster import DBSCAN
|
||||
|
||||
# --- 锚点词典 (技术二) ---
|
||||
# 这些是用来 *定位* 凭证区域的词, 不是用来提取的
|
||||
KEYWORDS = [
|
||||
"Bankgirot", "PlusGirot", "OCR-nummer", "Att betala", "Mottagare",
|
||||
"Betalningsmottagare", "Tillhanda senast", "Förfallodag", "Belopp",
|
||||
"BG-nr", "PG-nr", "Meddelande", "OCR-kod", "Inbetalningskort", "Betalningsavi"
|
||||
]
|
||||
KEYWORDS_REGEX = '|'.join(KEYWORDS)
|
||||
|
||||
REGEX_PATTERNS = {
|
||||
'bg_pg': r'(\b\d{2,4}[- ]\d{4}\b)|(\b\d{2,7}[- ]\d\b)', # Bankgiro/PlusGiro
|
||||
'long_num': r'\b\d{10,}\b', # 可能是 OCR
|
||||
'machine_code': r'#[0-9\s>#]+#' # 机读码
|
||||
}
|
||||
FULL_REGEX = f"({KEYWORDS_REGEX})|{REGEX_PATTERNS['bg_pg']}|{REGEX_PATTERNS['long_num']}|{REGEX_PATTERNS['machine_code']}"
|
||||
|
||||
# --- 路径设置 ---
|
||||
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
OCR_DIR = os.path.join(BASE_DIR, "data", "ocr_results")
|
||||
IMG_DIR = os.path.join(BASE_DIR, "data", "processed_images")
|
||||
# Use temp directories for initial label generation (before train/val split)
|
||||
LABEL_DIR = os.path.join(BASE_DIR, "data", "yolo_dataset", "temp_all_labels")
|
||||
IMAGE_DIR = os.path.join(BASE_DIR, "data", "yolo_dataset", "temp_all_images")
|
||||
|
||||
os.makedirs(LABEL_DIR, exist_ok=True)
|
||||
os.makedirs(IMAGE_DIR, exist_ok=True)
|
||||
|
||||
def find_text_anchors(text_boxes, img_height):
    """Collect OCR boxes that look like payment-slip anchors (position + keywords).

    A box qualifies when its top edge lies below 40% of the page height and
    its text matches ``FULL_REGEX`` (case-insensitive).

    Args:
        text_boxes: Iterable of dicts with a ``"text"`` string and a ``"bbox"``
            dict holding ``x_min``, ``y_min``, ``x_max`` and ``y_max``.
        img_height: Height of the page image in pixels.

    Returns:
        A list of ``pd.Series`` with ``left``, ``top``, ``width`` and
        ``height``; empty when nothing qualifies.
    """
    if not text_boxes:
        return []

    # Only text starting below this y coordinate is considered.
    cutoff_y = img_height * 0.4
    found = []

    for entry in text_boxes:
        rect = entry["bbox"]

        # Position filter.
        if not (rect["y_min"] > cutoff_y):
            continue

        # Content filter.
        if not re.search(FULL_REGEX, entry["text"], re.IGNORECASE):
            continue

        found.append(pd.Series({
            'left': rect["x_min"],
            'top': rect["y_min"],
            'width': rect["x_max"] - rect["x_min"],
            'height': rect["y_max"] - rect["y_min"]
        }))

    return found
|
||||
|
||||
def find_visual_anchors(image, img_height, img_width):
    """Find long horizontal lines in the lower half of the page (visual anchors).

    The lower half of the image is binarized (inverse threshold at 150) and
    opened with a horizontal kernel one quarter of the page width wide; each
    surviving contour wider than that kernel becomes an anchor.

    Args:
        image: BGR image array of the whole page.
        img_height: Page height in pixels.
        img_width: Page width in pixels.

    Returns:
        A list of ``pd.Series`` with ``left``, ``top``, ``width`` and
        ``height`` in full-page coordinates. Any error is printed and the
        anchors gathered so far are returned.
    """
    lines_found = []
    try:
        # 1. Restrict the search to the lower half of the page.
        y_offset = img_height // 2
        lower_half = image[y_offset:, :]

        grayscale = cv2.cvtColor(lower_half, cv2.COLOR_BGR2GRAY)
        _, binary = cv2.threshold(grayscale, 150, 255, cv2.THRESH_BINARY_INV)

        # 2. Keep only long horizontal strokes.
        min_line_length = img_width // 4
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (min_line_length, 1))
        opened = cv2.morphologyEx(binary, cv2.MORPH_OPEN, kernel, iterations=2)

        contours, _ = cv2.findContours(opened, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        for contour in contours:
            x, y, w, h = cv2.boundingRect(contour)
            if not (w > min_line_length):
                continue  # line is not long enough
            # Shift y back into full-page coordinates.
            lines_found.append(pd.Series({
                'left': x, 'top': y + y_offset, 'width': w, 'height': h
            }))
    except Exception as e:
        print(f"查找视觉锚点时出错: {e}")
    return lines_found
|
||||
|
||||
def cluster_anchors(all_anchors, img_height):
    """Keep only the anchors that belong to the largest DBSCAN cluster.

    Anchor centers are clustered with ``eps`` equal to 20% of the page height
    and ``min_samples=2``.

    Args:
        all_anchors: List of objects exposing ``left``, ``top``, ``width`` and
            ``height`` attributes.
        img_height: Page height in pixels.

    Returns:
        The anchors of the most populated cluster. The input list itself is
        returned when there are fewer than two anchors or when clustering
        yields only noise.
    """
    # Too few anchors to cluster.
    if len(all_anchors) < 2:
        return all_anchors

    # 1. Center point of every anchor.
    centers = np.array([
        (a.left + a.width / 2, a.top + a.height / 2)
        for a in all_anchors
    ])

    # 2. Run DBSCAN on the centers.
    assignments = DBSCAN(eps=img_height * 0.2, min_samples=2).fit(centers).labels_

    # Clustering produced nothing usable.
    if len(assignments) == 0 or np.all(assignments == -1):
        return all_anchors

    # 3. Pick the most populated non-noise cluster.
    cluster_ids, sizes = np.unique(assignments[assignments != -1], return_counts=True)
    if len(sizes) == 0:
        return all_anchors  # noise only

    winner = cluster_ids[np.argmax(sizes)]

    # 4. Return just the members of that cluster.
    return [
        anchor
        for anchor, cluster_id in zip(all_anchors, assignments)
        if cluster_id == winner
    ]
|
||||
|
||||
def create_labels():
    """Generate single-class YOLO weak labels for the region detector.

    For every ``*.json`` file in ``OCR_DIR`` that has a matching PNG in
    ``IMG_DIR``: gather text and visual anchors, keep the largest cluster,
    take the bounding box around it (padded by 10 px and clipped to the
    page), and write one YOLO line with class id 0 to ``LABEL_DIR``. The
    image is copied to ``IMAGE_DIR``. Progress and a summary are printed.
    """
    print("开始生成弱标签 (用于区域检测器)...")
    n_done = 0
    n_skipped = 0

    for ocr_filename in os.listdir(OCR_DIR):
        if not ocr_filename.endswith(".json"):
            continue

        base_name = os.path.splitext(ocr_filename)[0]
        json_path = os.path.join(OCR_DIR, ocr_filename)
        img_path = os.path.join(IMG_DIR, f"{base_name}.png")

        if not os.path.exists(img_path):
            continue

        try:
            # 1. Load the page image and its OCR boxes.
            image = cv2.imread(img_path)
            img_h, img_w = image.shape[:2]
            with open(json_path, 'r', encoding='utf-8') as fh:
                ocr_data = json.load(fh)
            text_boxes = ocr_data.get("text_boxes", [])

            # 2. Gather every anchor candidate.
            all_anchors = (
                find_text_anchors(text_boxes, img_h)
                + find_visual_anchors(image, img_h, img_w)
            )
            if not all_anchors:
                print(f"SKIPPING: {base_name} (未找到任何锚点)")
                n_skipped += 1
                continue

            # 3. Keep the dominant cluster only.
            final_anchors = cluster_anchors(all_anchors, img_h)
            if not final_anchors:
                print(f"SKIPPING: {base_name} (未找到有效聚类)")
                n_skipped += 1
                continue

            # 4. Bounding box around the remaining anchors.
            min_x = min(a.left for a in final_anchors)
            min_y = min(a.top for a in final_anchors)
            max_x = max(a.left + a.width for a in final_anchors)
            max_y = max(a.top + a.height for a in final_anchors)

            # 5. Pad the box and clip it to the page.
            padding = 10
            min_x, min_y = max(0, min_x - padding), max(0, min_y - padding)
            max_x, max_y = min(img_w, max_x + padding), min(img_h, max_y + padding)

            # 6. Convert to normalized YOLO center/size format.
            box_w = max_x - min_x
            box_h = max_y - min_y
            x_center = (min_x + box_w / 2) / img_w
            y_center = (min_y + box_h / 2) / img_h
            norm_w = box_w / img_w
            norm_h = box_h / img_h

            # The class id is always 0: the only class is 'payment_slip'.
            class_id = 0
            yolo_label = f"{class_id} {x_center} {y_center} {norm_w} {norm_h}\n"

            # 7. Persist the label and copy the image next to it.
            label_path = os.path.join(LABEL_DIR, f"{base_name}.txt")
            with open(label_path, 'w', encoding='utf-8') as fh:
                fh.write(yolo_label)

            shutil.copy(img_path, os.path.join(IMAGE_DIR, f"{base_name}.png"))

            n_done += 1
            if n_done % 20 == 0:
                print(f"已生成 {n_done} 个标签...")

        except Exception as e:
            print(f"处理 {base_name} 时出错: {e}")
            n_skipped += 1

    print("--- 弱标签生成完成 ---")
    print(f"成功生成: {n_done} 个")
    print(f"跳过: {n_skipped} 个")
|
||||
|
||||
if __name__ == "__main__":
|
||||
create_labels()
|
||||
123
scripts/03_split_dataset.py
Normal file
123
scripts/03_split_dataset.py
Normal file
@@ -0,0 +1,123 @@
|
||||
"""
|
||||
Dataset Split Script - Step 3
|
||||
Splits images and labels into training and validation sets
|
||||
"""
|
||||
|
||||
import shutil
|
||||
import random
|
||||
from pathlib import Path
|
||||
import os
|
||||
|
||||
# Paths
|
||||
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
YOLO_DATASET_DIR = Path(BASE_DIR + "/data/yolo_dataset")
|
||||
TEMP_IMAGES_DIR = YOLO_DATASET_DIR / "temp_all_images"
|
||||
TEMP_LABELS_DIR = YOLO_DATASET_DIR / "temp_all_labels"
|
||||
|
||||
TRAIN_IMAGES_DIR = YOLO_DATASET_DIR / "images" / "train"
|
||||
VAL_IMAGES_DIR = YOLO_DATASET_DIR / "images" / "val"
|
||||
TRAIN_LABELS_DIR = YOLO_DATASET_DIR / "labels" / "train"
|
||||
VAL_LABELS_DIR = YOLO_DATASET_DIR / "labels" / "val"
|
||||
|
||||
# Configuration
|
||||
VALIDATION_SPLIT = 0.2 # 20% for validation
|
||||
RANDOM_SEED = 42
|
||||
|
||||
|
||||
def split_dataset(val_split=VALIDATION_SPLIT, seed=RANDOM_SEED):
    """Split the labelled images into training and validation sets.

    Image/label pairs are read from ``TEMP_IMAGES_DIR`` / ``TEMP_LABELS_DIR``,
    shuffled, and copied into the train/val image and label directories.
    Those four target directories are deleted and recreated on every run.

    Args:
        val_split: Fraction of data to use for validation (0.0 to 1.0).
        seed: Random seed for reproducibility.

    Returns:
        None. Progress is reported on stdout; the function returns early when
        the temporary directories or any image-label pairs are missing.
    """
    print("="*60)
    print("Splitting Dataset into Train/Val Sets")
    print("="*60)
    print(f"Validation split: {val_split*100:.1f}%")
    print(f"Random seed: {seed}\n")

    # Check if temp directories exist
    if not TEMP_IMAGES_DIR.exists() or not TEMP_LABELS_DIR.exists():
        print(BASE_DIR)
        print(YOLO_DATASET_DIR)
        print(TEMP_IMAGES_DIR)
        print("Error: Temporary directories not found")
        print("Please run 02_create_labels.py first")
        return

    # Get all image files. Glob order depends on the filesystem, so sort the
    # result: otherwise the same seed yields different splits per machine.
    image_files = sorted(
        list(TEMP_IMAGES_DIR.glob("*.jpg")) + list(TEMP_IMAGES_DIR.glob("*.png"))
    )

    if not image_files:
        print(f"No image files found in {TEMP_IMAGES_DIR}")
        return

    # Keep only images that have a corresponding label file
    valid_pairs = []
    for image_file in image_files:
        label_file = TEMP_LABELS_DIR / (image_file.stem + ".txt")
        if label_file.exists():
            valid_pairs.append({"image": image_file, "label": label_file})

    if not valid_pairs:
        print("No valid image-label pairs found")
        return

    print(f"Found {len(valid_pairs)} image-label pair(s)")

    # Shuffle with a private generator so the process-wide random state is
    # left untouched, then split.
    random.Random(seed).shuffle(valid_pairs)

    split_index = int(len(valid_pairs) * (1 - val_split))
    train_pairs = valid_pairs[:split_index]
    val_pairs = valid_pairs[split_index:]

    print("\nSplit results:")
    print(f" Training set: {len(train_pairs)} samples")
    print(f" Validation set: {len(val_pairs)} samples")
    print()

    # Clear existing train/val directories
    for directory in [TRAIN_IMAGES_DIR, VAL_IMAGES_DIR, TRAIN_LABELS_DIR, VAL_LABELS_DIR]:
        if directory.exists():
            shutil.rmtree(directory)
        directory.mkdir(parents=True, exist_ok=True)

    # Copy training files
    print("Copying training files...")
    for pair in train_pairs:
        shutil.copy(pair["image"], TRAIN_IMAGES_DIR / pair["image"].name)
        shutil.copy(pair["label"], TRAIN_LABELS_DIR / pair["label"].name)
    print(f" Copied {len(train_pairs)} image-label pairs to train/")

    # Copy validation files
    print("Copying validation files...")
    for pair in val_pairs:
        shutil.copy(pair["image"], VAL_IMAGES_DIR / pair["image"].name)
        shutil.copy(pair["label"], VAL_LABELS_DIR / pair["label"].name)
    print(f" Copied {len(val_pairs)} image-label pairs to val/")

    print("\n" + "="*60)
    print("Dataset split complete!")
    print("\nDataset structure:")
    print(f" {TRAIN_IMAGES_DIR}")
    print(f" {TRAIN_LABELS_DIR}")
    print(f" {VAL_IMAGES_DIR}")
    print(f" {VAL_LABELS_DIR}")
    print("\nNext step: Run 04_train_yolo.py to train the model")
    print("="*60)
|
||||
|
||||
|
||||
def main():
    """Script entry point: run the dataset split with its default settings."""
    split_dataset()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
230
scripts/04_train_yolo.py
Normal file
230
scripts/04_train_yolo.py
Normal file
@@ -0,0 +1,230 @@
|
||||
"""
|
||||
YOLO Training Script - Step 4
|
||||
Trains YOLOv8 model on the prepared invoice dataset
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from ultralytics import YOLO
|
||||
import torch
|
||||
import os
|
||||
|
||||
# Paths
|
||||
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
DATASET_YAML = Path(BASE_DIR + "/data/yolo_dataset/dataset.yaml")
|
||||
MODELS_DIR = Path(BASE_DIR + "/models")
|
||||
|
||||
# Training configuration
|
||||
MODEL_SIZE = "n" # Options: n (nano), s (small), m (medium), l (large), x (xlarge)
|
||||
EPOCHS = 100
|
||||
BATCH_SIZE = 16
|
||||
IMAGE_SIZE = 640
|
||||
DEVICE = 0 if torch.cuda.is_available() else "cpu" # Use GPU with PyTorch 2.7 + CUDA 12.8
|
||||
|
||||
# Create models directory
|
||||
MODELS_DIR.mkdir(exist_ok=True)
|
||||
|
||||
|
||||
def train_model(
    model_size=MODEL_SIZE,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    img_size=IMAGE_SIZE,
    device=DEVICE
):
    """Train a pretrained YOLOv8 model on the invoice dataset.

    Training output goes to ``MODELS_DIR / "payment_slip_detector_v1"``.
    The function returns early (after printing an error) when
    ``DATASET_YAML`` does not exist.

    Args:
        model_size: Size of YOLO model (n, s, m, l, x).
        epochs: Number of training epochs.
        batch_size: Batch size for training.
        img_size: Input image size.
        device: Device handed to the trainer; the module default is ``0``
            when CUDA is available and ``"cpu"`` otherwise.
    """
    banner = "=" * 60
    run_name = "payment_slip_detector_v1"

    print(banner)
    print("YOLOv8 Invoice Detection Training")
    print(banner)

    # The dataset description must exist before anything else happens.
    if not DATASET_YAML.exists():
        print(f"Error: {DATASET_YAML} not found")
        print("Please ensure the dataset.yaml file exists")
        return

    # Echo the effective configuration.
    print("\nConfiguration:")
    print(f" Model: YOLOv8{model_size}")
    print(f" Epochs: {epochs}")
    print(f" Batch size: {batch_size}")
    print(f" Image size: {img_size}")
    print(f" Device: {device}")
    print(f" Dataset config: {DATASET_YAML}")
    print()

    # Start from the pretrained checkpoint of the requested size.
    print(f"Loading YOLOv8{model_size} model...")
    model = YOLO(f"yolov8{model_size}.pt")

    # Report which device will be used.
    if device == 0:
        gpu_memory_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
        print(f"Using GPU: {torch.cuda.get_device_name(0)}")
        print(f"GPU Memory: {gpu_memory_gb:.2f} GB")
    else:
        print("Using CPU (training will be slower)")

    print("\nStarting training...")
    print("-" * 60)

    model.train(
        data=str(DATASET_YAML),
        epochs=epochs,
        imgsz=img_size,
        batch=batch_size,
        device=device,
        project=str(MODELS_DIR),
        name=run_name,
        exist_ok=True,
        patience=20,      # early-stopping patience
        save=True,
        save_period=10,   # checkpoint every 10 epochs
        verbose=True,
        plots=True        # generate training plots
    )

    print("\n" + banner)
    print("Training complete!")
    print(banner)

    # Tell the user where the artifacts ended up.
    run_dir = MODELS_DIR / run_name
    weights_dir = run_dir / "weights"

    print("\nTrained models saved to:")
    print(f" Best model: {weights_dir / 'best.pt'}")
    print(f" Last model: {weights_dir / 'last.pt'}")

    print("\nTraining plots saved to:")
    print(f" {run_dir}")

    print("\n" + banner)
|
||||
|
||||
|
||||
def validate_model(model_path=None):
    """Evaluate trained weights on the validation set described by ``DATASET_YAML``.

    Args:
        model_path: Path to model weights. When omitted, ``best.pt`` of the
            ``payment_slip_detector_v1`` run under ``MODELS_DIR`` is used.

    Returns:
        None. Prints an error and returns early when the weights file is
        missing.
    """
    weights = (
        model_path
        if model_path is not None
        else MODELS_DIR / "payment_slip_detector_v1" / "weights" / "best.pt"
    )

    if not Path(weights).exists():
        print(f"Error: Model not found at {weights}")
        print("Please train a model first")
        return

    banner = "=" * 60
    print(banner)
    print("Validating Model")
    print(banner)
    print(f"Model: {weights}\n")

    # Load the weights and run validation.
    detector = YOLO(str(weights))
    detector.val(data=str(DATASET_YAML))

    print("\n" + banner)
    print("Validation complete!")
    print(banner)
|
||||
|
||||
|
||||
def predict_sample(model_path=None, image_path=None, conf_threshold=0.25):
    """Run prediction on a sample image and save the annotated result.

    Args:
        model_path: Path to model weights. When omitted, ``best.pt`` of the
            ``payment_slip_detector_v1`` run under ``MODELS_DIR`` is used.
        image_path: Path to image to predict on. When omitted, the first
            (alphabetically) ``.jpg``/``.png`` file of the validation image
            directory is used.
        conf_threshold: Confidence threshold for detections.

    Returns:
        None. Prints an error and returns early when the weights file or a
        sample image cannot be found.
    """
    if model_path is None:
        model_path = MODELS_DIR / "payment_slip_detector_v1" / "weights" / "best.pt"

    if not Path(model_path).exists():
        print(f"Error: Model not found at {model_path}")
        return

    if image_path is None:
        # Anchor the lookup at the project root (like every other path in
        # this script) so it works no matter where the script is run from.
        val_images_dir = Path(BASE_DIR) / "data" / "yolo_dataset" / "images" / "val"
        # Sorted so the chosen sample does not depend on filesystem order.
        sample_images = sorted(
            list(val_images_dir.glob("*.jpg")) + list(val_images_dir.glob("*.png"))
        )
        if not sample_images:
            print("No sample images found")
            return
        image_path = sample_images[0]

    print("="*60)
    print("Running Prediction")
    print("="*60)
    print(f"Model: {model_path}")
    print(f"Image: {image_path}")
    print(f"Confidence threshold: {conf_threshold}\n")

    # Load model
    model = YOLO(str(model_path))

    # Predict; annotated output is written under MODELS_DIR/predictions/sample
    model.predict(
        source=str(image_path),
        conf=conf_threshold,
        save=True,
        project=str(MODELS_DIR / "predictions"),
        name="sample"
    )

    print(f"\nPrediction saved to: {MODELS_DIR / 'predictions' / 'sample'}")
    print("="*60)
|
||||
|
||||
|
||||
def main():
    """Parse command-line arguments and dispatch to train, validate or predict."""
    import argparse

    cli = argparse.ArgumentParser(description="Train YOLOv8 on invoice dataset")
    cli.add_argument(
        "--mode",
        type=str,
        default="train",
        choices=["train", "validate", "predict"],
        help="Mode: train, validate, or predict",
    )
    cli.add_argument(
        "--model-size",
        type=str,
        default=MODEL_SIZE,
        choices=["n", "s", "m", "l", "x"],
        help="YOLO model size",
    )
    cli.add_argument("--epochs", type=int, default=EPOCHS,
                     help="Number of training epochs")
    cli.add_argument("--batch-size", type=int, default=BATCH_SIZE,
                     help="Batch size")
    cli.add_argument("--img-size", type=int, default=IMAGE_SIZE,
                     help="Image size")
    cli.add_argument("--model-path", type=str, default=None,
                     help="Path to model weights (for validate/predict)")
    cli.add_argument("--image-path", type=str, default=None,
                     help="Path to image (for predict)")

    options = cli.parse_args()
    mode = options.mode

    if mode == "train":
        train_model(
            model_size=options.model_size,
            epochs=options.epochs,
            batch_size=options.batch_size,
            img_size=options.img_size,
        )
        return

    if mode == "validate":
        validate_model(model_path=options.model_path)
        return

    if mode == "predict":
        predict_sample(model_path=options.model_path, image_path=options.image_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
237
scripts/main.py
Normal file
237
scripts/main.py
Normal file
@@ -0,0 +1,237 @@
|
||||
# --- main.py (upgraded) ---
|
||||
|
||||
from fastapi import FastAPI, UploadFile, File, HTTPException
|
||||
from ultralytics import YOLO
|
||||
import cv2
|
||||
import numpy as np
|
||||
import pytesseract
|
||||
import re
|
||||
import io
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
# --- Configuration ---
# TODO: make sure this path points to your best trained model.
# NOTE(review): the training script's prediction default uses a
# "payment_slip_detector_v1" run directory — confirm that "invoice_detector_v1"
# below is really where the trained weights live.
MODEL_PATH = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    "models", "invoice_detector_v1", "weights", "best.pt"
)

# Dictionary that holds the models loaded when the FastAPI app starts
ml_models = {}
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    FastAPI lifespan hook: load the YOLO detector before serving requests
    and empty the model registry on shutdown.

    When the weights file is missing, the "yolo" entry is set to None so the
    app still starts; a warning is printed instead.
    """
    print(MODEL_PATH)
    if os.path.exists(MODEL_PATH):
        # Load the "payment_slip" region detector
        ml_models["yolo"] = YOLO(MODEL_PATH)
        print("YOLOv8 区域检测模型加载成功。")
    else:
        print(f"警告: 找不到模型 {MODEL_PATH}。API 将无法工作。")
        ml_models["yolo"] = None

    yield

    # Shutdown: drop the loaded models
    ml_models.clear()


app = FastAPI(lifespan=lifespan)
|
||||
|
||||
# --- Luhn (Modulus 10) check-digit validation ---
|
||||
def luhn_validate(number_str: str, expected_check_digit: str) -> bool:
    """
    Validate a digit string against its Modulus 10 (Luhn) check digit.

    ``number_str`` is the payload WITHOUT the check digit. Walking from the
    rightmost payload digit to the left, digits are weighted 2, 1, 2, 1, ...;
    a two-digit product contributes the sum of its digits. The check digit is
    the value that brings the total up to a multiple of 10.

    Args:
        number_str: Payload digits, e.g. "460300".
        expected_check_digit: The single check digit read next to the payload,
            as a string, e.g. "7".

    Returns:
        True when the computed check digit equals ``expected_check_digit``;
        False otherwise, including for empty or non-numeric input.
    """
    try:
        digits = [int(ch) for ch in number_str]
    except (TypeError, ValueError):
        return False

    if not digits:
        # An empty payload would trivially "validate" against "0".
        return False

    total = 0
    # Right to left; the digit adjacent to the check digit is doubled first.
    for position, digit in enumerate(reversed(digits)):
        product = digit * 2 if position % 2 == 0 else digit
        # For products 10..18 the digit sum equals product - 9.
        total += product - 9 if product > 9 else product

    calculated_check_digit = (10 - (total % 10)) % 10
    return str(calculated_check_digit) == expected_check_digit
|
||||
|
||||
# --- Extraction logic ---
|
||||
|
||||
def parse_ocr_rad(ocr_text: str) -> dict:
    """
    Plan A: try to parse the machine-readable payment line (OCR-rad).

    Example line: # 400299582421 # 4603 00 7 > 48180020 #14#

    Returns a dict with the keys "source", "ocr_number", "amount_due",
    "bankgiro_plusgiro" and "due_date" when the line is found and the amount
    passes the Luhn check. Returns None when no such line is found, when the
    check fails, or when parsing raises.
    """
    # All whitespace is dropped first so the pattern does not have to allow for it.
    compact_text = re.sub(r'\s+', '', ocr_text)

    # Capture groups:
    #   1     OCR reference between the first pair of '#' (may contain '>')
    #   2 + 3 amount digits (kronor followed by the two öre digits)
    #   4     check digit of the amount
    #   5     account number between '>' and the next '#'
    found = re.search(r'#([\d>]+)#(\d{2,})(\d)(\d{1})>([\d]+)#', compact_text)
    if found is None:
        return None  # no machine-readable line present

    try:
        reference, amount_head, amount_tail, check_digit, account = found.groups()
        reference = reference.replace(">", "")
        amount_base = amount_head + amount_tail

        if not luhn_validate(amount_base, check_digit):
            # Check digit mismatch: do not trust this line.
            print(f"Luhn 校验失败: 基础={amount_base}, 期望={check_digit}")
            return None

        # Check passed: treat the values as high confidence.
        amount_kronor = amount_base[:-2]
        amount_ore = amount_base[-2:]
        return {
            "source": "OCR-rad (High Confidence)",
            "ocr_number": reference,
            "amount_due": f"{amount_kronor}.{amount_ore}",
            "bankgiro_plusgiro": account,
            "due_date": None,  # the machine-readable line normally carries no date
        }
    except Exception as e:
        print(f"解析 OCR-rad 时出错: {e}")
        return None
|
||||
|
||||
def _normalize_amount(raw_amount: str):
    """Turn a captured amount such as '1 234,50' into '1234.50'; None if it has no digits."""
    # The capture may run over a line break; only the first line is the amount.
    lines = raw_amount.strip().splitlines()
    first_line = lines[0] if lines else ""

    # Remove every kind of whitespace, then sentence punctuation stuck to the
    # ends, so a trailing '.' is not mistaken for the decimal separator.
    amount_str = re.sub(r'\s', '', first_line).strip(".,").replace(",", ".")

    # Several separators left: all but the last one are thousands separators.
    separators = amount_str.count('.')
    if separators > 1:
        amount_str = amount_str.replace(".", "", separators - 1)

    if not any(ch.isdigit() for ch in amount_str):
        return None
    return amount_str


def parse_human_readable(ocr_text: str) -> dict:
    """
    Plan B: fall back to the human-readable part of the payment slip.

    Uses simple keyword-anchored regular expressions; fields that are not
    found are left out of the result.

    Args:
        ocr_text: Raw OCR text of the cropped slip.

    Returns:
        A dict that always contains "source" and, when found, any of
        "bankgiro_plusgiro", "ocr_number", "amount_due" and "due_date".
    """
    data = {"source": "Human-Readable (Fallback)"}

    # Bankgiro / PlusGiro account number
    bg_match = re.search(r'Bankgiro\D*(\d{2,4}[- ]\d{4})', ocr_text, re.IGNORECASE)
    pg_match = re.search(r'PlusGiro\D*(\d{2,7}[- ]\d)', ocr_text, re.IGNORECASE)
    if bg_match:
        data["bankgiro_plusgiro"] = bg_match.group(1).replace(" ", "")
    elif pg_match:
        data["bankgiro_plusgiro"] = pg_match.group(1).replace(" ", "")
    else:
        # No keyword found: accept anything shaped like an account number
        bg_pg_alt = re.search(r'(\b\d{2,4}[- ]\d{4}\b)|(\b\d{2,7}[- ]\d\b)', ocr_text)
        if bg_pg_alt:
            data["bankgiro_plusgiro"] = bg_pg_alt.group(0).replace(" ", "")

    # OCR / invoice reference number
    ocr_match = re.search(r'(OCR|Fakturanummer|Referens)\D*(\d[\d\s]{5,}\d)', ocr_text, re.IGNORECASE)
    if ocr_match:
        data["ocr_number"] = re.sub(r'\s', '', ocr_match.group(2))

    # Amount to pay; only stored when the capture really contains digits
    amount_match = re.search(r'(Att betala|Belopp)\D*([\d\s,.]+)\s*(kr|SEK)?', ocr_text, re.IGNORECASE)
    if amount_match:
        amount_str = _normalize_amount(amount_match.group(2))
        if amount_str is not None:
            data["amount_due"] = amount_str

    # Due date
    date_match = re.search(r'(senast|Förfallodag)\D*(\d{4}[- ]\d{2}[- ]\d{2})', ocr_text, re.IGNORECASE)
    if date_match:
        data["due_date"] = date_match.group(2).replace(" ", "-")

    return data
|
||||
|
||||
|
||||
def extract_info_from_crop(crop_image: np.ndarray) -> dict:
    """
    Main extraction routine: run Plan A, then Plan B, on one cropped slip.

    The crop is converted to grayscale and binarized with Otsu thresholding.
    Plan A OCRs it with ``--psm 6`` and parses the machine-readable line;
    if that yields nothing, Plan B OCRs again with ``--psm 3`` and parses the
    human-readable fields.

    Returns the parsed dict, or ``{"error": ...}`` if any step raises.
    """
    try:
        grayscale = cv2.cvtColor(crop_image, cv2.COLOR_BGR2GRAY)
        _, binary = cv2.threshold(grayscale, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)

        # Plan A: psm 6 treats the crop as a single uniform block of text,
        # which suits the machine-readable line.
        machine_text = pytesseract.image_to_string(binary, lang='swe', config='--psm 6')
        machine_result = parse_ocr_rad(machine_text)
        if machine_result:
            return machine_result

        # Plan B: psm 3 (automatic page layout) suits the human-readable area.
        # Its result is returned even when only partially filled.
        fallback_text = pytesseract.image_to_string(binary, lang='swe', config='--psm 3')
        return parse_human_readable(fallback_text)
    except Exception as e:
        return {"error": f"提取时出错: {e}"}
|
||||
|
||||
|
||||
@app.post("/extract_invoice/")
async def extract_invoice_data(file: UploadFile = File(...)):
    """
    Detect payment-slip regions in an uploaded image and extract their fields.

    Args:
        file: The uploaded image file.

    Returns:
        ``{"invoice_extractions": [...]}`` with one dict per detected region
        (the extracted fields plus "bounding_box" as [xmin, ymin, xmax, ymax]),
        or ``{"message": ...}`` when nothing was detected.

    Raises:
        HTTPException: 503 when the model is not loaded; 400 when the upload
            cannot be read or decoded as an image.
    """
    if ml_models.get("yolo") is None:
        raise HTTPException(status_code=503, detail="模型未加载。请检查模型路径。")

    try:
        contents = await file.read()
        nparr = np.frombuffer(contents, np.uint8)
        img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"读取文件时出错: {e}") from e

    # Checked outside the try block so this HTTPException is not caught and
    # re-wrapped by the generic handler above.
    if img is None:
        raise HTTPException(status_code=400, detail="无法解码图像文件。")

    # 1. Run YOLO detection (stage one)
    # NOTE(review): inference and OCR are synchronous calls inside this
    # coroutine; consider a worker thread if concurrent requests matter.
    results = ml_models["yolo"](img, verbose=False)

    all_extractions = []

    if not results or not results[0].boxes:
        return {"message": "未在图像中检测到支付凭证区域。"}

    for res in results:
        # Iterate over every detected slip (usually just one).
        # .tolist() yields plain Python ints; numpy integer scalars in the
        # response would not be JSON-serializable.
        for xmin, ymin, xmax, ymax in res.boxes.xyxy.cpu().numpy().astype(int).tolist():
            # 2. Crop the detected region
            crop = img[ymin:ymax, xmin:xmax]

            # 3. Run the "Plan A/B" extraction on the crop (stage two)
            extracted_data = extract_info_from_crop(crop)
            extracted_data["bounding_box"] = [xmin, ymin, xmax, ymax]
            all_extractions.append(extracted_data)

    if not all_extractions:
        return {"message": "检测到支付凭证, 但未能提取任何信息。"}

    return {"invoice_extractions": all_extractions}
|
||||
|
||||
if __name__ == "__main__":
    # Executed as a script: serve the API locally with uvicorn.
    import uvicorn
    print(f"--- 启动 FastAPI 服务 ---")
    print(f"加载模型: {MODEL_PATH}")
    print(f"访问 http://127.0.0.1:8000/docs 查看 API 文档")
    uvicorn.run(app, host="127.0.0.1", port=8000)
|
||||
14
test_api.py
Normal file
14
test_api.py
Normal file
@@ -0,0 +1,14 @@
|
||||
import requests
|
||||
import json
|
||||
|
||||
# Manual smoke test: POST the problematic invoice image to the locally running API.
url = "http://127.0.0.1:8000/extract_invoice/"
# Forward slashes work on Windows as well as on POSIX systems.
file_path = "data/processed_images/4BC5E5B3-E561-4A73-BC9C-46D4F08F89C3.png"


def main():
    """Upload the sample image and print the status code and response body."""
    with open(file_path, 'rb') as f:
        # A timeout keeps the script from hanging forever if the server stalls.
        response = requests.post(url, files={'file': f}, timeout=120)

    print("Status Code:", response.status_code)

    try:
        payload = response.json()
    except ValueError:
        # Error responses are not necessarily JSON; show the raw body instead.
        print("\nResponse body (not JSON):")
        print(response.text)
        return

    print("\nResponse JSON:")
    print(json.dumps(payload, indent=2, ensure_ascii=False))


# Guarded so that importing this module (e.g. pytest collecting test_*.py)
# does not fire a network request.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user