Files
invoice-master/scripts/02_create_labels.py
Yaojia Wang dafa86c588 Init
2025-10-26 20:41:11 +01:00

223 lines
8.0 KiB
Python

# --- scripts/02_create_labels.py ---
#
# **重要**: 此脚本用于训练 "阶段一" 的区域检测器.
# 它只生成 1 个类别 (class_id = 0), 即 "payment_slip" 的 *整个* 区域.
# 它不使用您的高级 classify_text 逻辑, 而是使用启发式规则 (关键词, 线条) 来找到大区域.
#
import os
import pandas as pd
import numpy as np
import cv2
import re
import shutil
import json
from sklearn.cluster import DBSCAN
# --- Anchor dictionary (technique 2) ---
# These words are used to *locate* the payment-slip region, not to extract values.
KEYWORDS = [
    "Bankgirot", "PlusGirot", "OCR-nummer", "Att betala", "Mottagare",
    "Betalningsmottagare", "Tillhanda senast", "Förfallodag", "Belopp",
    "BG-nr", "PG-nr", "Meddelande", "OCR-kod", "Inbetalningskort", "Betalningsavi",
]
# Escape every keyword so that regex metacharacters in future entries
# (e.g. "Ref.", "(OCR)") are matched literally instead of corrupting the
# alternation; for the current keywords the matching behaviour is unchanged.
KEYWORDS_REGEX = '|'.join(re.escape(keyword) for keyword in KEYWORDS)
REGEX_PATTERNS = {
    'bg_pg': r'(\b\d{2,4}[- ]\d{4}\b)|(\b\d{2,7}[- ]\d\b)',  # Bankgiro/PlusGiro
    'long_num': r'\b\d{10,}\b',  # possibly an OCR reference number
    'machine_code': r'#[0-9\s>#]+#',  # machine-readable code line
}
# One combined pattern: any keyword OR any of the number/code patterns.
FULL_REGEX = f"({KEYWORDS_REGEX})|{REGEX_PATTERNS['bg_pg']}|{REGEX_PATTERNS['long_num']}|{REGEX_PATTERNS['machine_code']}"
# --- Path configuration ---
# Project root: the parent of the directory that contains this script.
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
_DATA_ROOT = os.path.join(BASE_DIR, "data")
_YOLO_ROOT = os.path.join(_DATA_ROOT, "yolo_dataset")

# Inputs produced by earlier pipeline steps.
OCR_DIR = os.path.join(_DATA_ROOT, "ocr_results")
IMG_DIR = os.path.join(_DATA_ROOT, "processed_images")

# Staging directories for every generated label/image, filled before the
# train/val split is made. They are created at import time.
LABEL_DIR = os.path.join(_YOLO_ROOT, "temp_all_labels")
IMAGE_DIR = os.path.join(_YOLO_ROOT, "temp_all_images")
for _staging_dir in (LABEL_DIR, IMAGE_DIR):
    os.makedirs(_staging_dir, exist_ok=True)
del _staging_dir
def find_text_anchors(text_boxes, img_height):
    """Collect OCR boxes that look like payment-slip anchors.

    Technique 1 (position) + technique 2 (dictionary): a box is kept when
    its top edge lies strictly below 40% of the page height and its text
    matches FULL_REGEX (case-insensitive).

    Args:
        text_boxes: OCR boxes, each a mapping with a "text" string and a
            "bbox" mapping holding x_min / y_min / x_max / y_max.
        img_height: Page height in pixels.

    Returns:
        A list of pd.Series with left / top / width / height, one per
        matching box; an empty list when ``text_boxes`` is empty or None.
    """
    if not text_boxes:
        return []

    # Technique 1: ignore boxes whose top edge is in the upper 40% of the page.
    cutoff_y = img_height * 0.4
    matches = []
    for entry in text_boxes:
        # "not >" (rather than "<=") also skips NaN coordinates.
        if not entry["bbox"]["y_min"] > cutoff_y:
            continue
        # Technique 2: keyword / number-pattern match on the box text.
        if re.search(FULL_REGEX, entry["text"], re.IGNORECASE) is None:
            continue
        rect = entry["bbox"]
        matches.append(pd.Series({
            'left': rect["x_min"],
            'top': rect["y_min"],
            'width': rect["x_max"] - rect["x_min"],
            'height': rect["y_max"] - rect["y_min"],
        }))
    return matches
def find_visual_anchors(image, img_height, img_width):
    """Find long horizontal lines in the lower half of the page (technique 3).

    The lower half of the image is binarised with an inverted fixed
    threshold. A morphological opening with a horizontal kernel, a quarter
    of the page wide, then keeps only long horizontal strokes. Each surviving
    contour that is wider than that kernel becomes an anchor.

    Args:
        image: Page image as returned by ``cv2.imread`` (BGR). A
            single-channel 2-D grayscale array is also accepted.
        img_height: Image height in pixels.
        img_width: Image width in pixels.

    Returns:
        A list of pd.Series (left, top, width, height) in full-page
        coordinates. Any OpenCV/processing error is printed and yields
        the anchors collected so far (normally an empty list).
    """
    anchors = []
    try:
        # 1. Only inspect the lower half of the page.
        crop_y_start = img_height // 2
        crop = image[crop_y_start:, :]
        # cvtColor(BGR2GRAY) rejects 2-D arrays, so pass grayscale through as-is.
        if crop.ndim == 2:
            gray = crop
        else:
            gray = cv2.cvtColor(crop, cv2.COLOR_BGR2GRAY)
        # Inverted binary threshold: pixels <= 150 (dark ink) become 255.
        thresh = cv2.threshold(gray, 150, 255, cv2.THRESH_BINARY_INV)[1]
        # 2. Keep only horizontal runs at least a quarter of the page wide.
        # max(1, ...) avoids a zero-width kernel on very narrow images.
        min_line_length = max(1, img_width // 4)
        horizontal_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (min_line_length, 1))
        detected_lines = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, horizontal_kernel, iterations=2)
        # findContours returns (image, contours, hierarchy) in OpenCV 3.x and
        # (contours, hierarchy) in OpenCV 4.x; [-2] selects contours in both.
        contours = cv2.findContours(
            detected_lines, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )[-2]
        for c in contours:
            x, y, w, h = cv2.boundingRect(c)
            if w > min_line_length:  # make sure the line is long enough
                original_y = y + crop_y_start  # convert back to full-page coordinates
                anchors.append(pd.Series({
                    'left': x, 'top': original_y, 'width': w, 'height': h
                }))
    except Exception as e:
        print(f"查找视觉锚点时出错: {e}")
    return anchors
def cluster_anchors(all_anchors, img_height, *, eps_ratio=0.2, min_samples=2):
    """Keep only the anchors belonging to the largest DBSCAN cluster (technique 4).

    Each anchor is reduced to the centre point of its box. The points are
    clustered with DBSCAN using a neighbourhood radius of
    ``img_height * eps_ratio``.

    Args:
        all_anchors: Sequence of anchor records exposing ``left``, ``top``,
            ``width`` and ``height`` attributes (e.g. pd.Series).
        img_height: Page height in pixels; scales the DBSCAN radius.
        eps_ratio: DBSCAN ``eps`` as a fraction of ``img_height``
            (default 0.2).
        min_samples: DBSCAN ``min_samples`` (default 2).

    Returns:
        The anchors in the most populous cluster (ties go to the lowest
        cluster label). The input is returned unchanged when it has fewer
        than 2 anchors or when DBSCAN labels every point as noise.
    """
    if len(all_anchors) < 2:
        return all_anchors  # too few anchors to cluster

    # 1. Centre point of every anchor box.
    points = np.array([
        (anchor.left + anchor.width / 2, anchor.top + anchor.height / 2)
        for anchor in all_anchors
    ])

    # 2. Run DBSCAN with a radius proportional to the page height.
    labels = DBSCAN(eps=img_height * eps_ratio, min_samples=min_samples).fit(points).labels_

    # 3. Drop noise (label -1); if nothing clustered, fall back to all anchors.
    clustered = labels[labels != -1]
    if clustered.size == 0:
        return all_anchors
    unique_labels, counts = np.unique(clustered, return_counts=True)
    largest_cluster_label = unique_labels[np.argmax(counts)]

    # 4. Return only the anchors that belong to the largest cluster.
    return [
        anchor for anchor, label in zip(all_anchors, labels)
        if label == largest_cluster_label
    ]
def _load_page(img_path, json_path):
    """Load a page image and its OCR text boxes; raise ValueError if the image is unreadable."""
    image = cv2.imread(img_path)
    if image is None:
        # cv2.imread signals unreadable/corrupt files by returning None, not raising.
        raise ValueError(f"无法读取图片: {img_path}")
    with open(json_path, 'r', encoding='utf-8') as f:
        ocr_data = json.load(f)
    return image, ocr_data.get("text_boxes", [])


def _anchors_to_yolo_line(anchors, img_w, img_h, padding=10, class_id=0):
    """Return a YOLO label line for the padded box enclosing all anchors."""
    # Union of all anchor boxes.
    min_x = min(a.left for a in anchors)
    min_y = min(a.top for a in anchors)
    max_x = max(a.left + a.width for a in anchors)
    max_y = max(a.top + a.height for a in anchors)
    # Add padding, clamped to the image bounds.
    min_x = max(0, min_x - padding)
    min_y = max(0, min_y - padding)
    max_x = min(img_w, max_x + padding)
    max_y = min(img_h, max_y + padding)
    # YOLO format: class x_center y_center width height, normalised to [0, 1].
    box_w = max_x - min_x
    box_h = max_y - min_y
    x_center = (min_x + box_w / 2) / img_w
    y_center = (min_y + box_h / 2) / img_h
    norm_w = box_w / img_w
    norm_h = box_h / img_h
    return f"{class_id} {x_center} {y_center} {norm_w} {norm_h}\n"


def create_labels():
    """Generate one weak 'payment_slip' YOLO label per page.

    For every ``*.json`` file in OCR_DIR that has a matching ``<name>.png``
    in IMG_DIR, this function:

    * finds text anchors and horizontal-line anchors;
    * keeps the main anchor cluster;
    * writes a single class-0 YOLO box enclosing that cluster to
      ``LABEL_DIR/<name>.txt``;
    * copies the image to IMAGE_DIR.

    Files are processed in sorted order. Pages with no matching image, no
    anchors, or a processing error are reported and counted as skipped
    rather than aborting the run. A summary is printed at the end.
    """
    print("开始生成弱标签 (用于区域检测器)...")
    processed_count = 0
    skipped_count = 0
    # Sorted so repeated runs process (and log) files in the same order.
    for ocr_filename in sorted(os.listdir(OCR_DIR)):
        if not ocr_filename.endswith(".json"):
            continue
        base_name = os.path.splitext(ocr_filename)[0]
        json_path = os.path.join(OCR_DIR, ocr_filename)
        img_path = os.path.join(IMG_DIR, f"{base_name}.png")
        if not os.path.exists(img_path):
            # Count missing images so the final "skipped" total is accurate.
            print(f"SKIPPING: {base_name} (未找到对应图片)")
            skipped_count += 1
            continue
        try:
            # 1. Load data.
            image, text_boxes = _load_page(img_path, json_path)
            img_h, img_w = image.shape[:2]
            # 2. Find all anchors.
            text_anchors = find_text_anchors(text_boxes, img_h)
            visual_anchors = find_visual_anchors(image, img_h, img_w)
            all_anchors = text_anchors + visual_anchors
            if not all_anchors:
                print(f"SKIPPING: {base_name} (未找到任何锚点)")
                skipped_count += 1
                continue
            # 3. Cluster the anchors.
            final_anchors = cluster_anchors(all_anchors, img_h)
            if not final_anchors:
                print(f"SKIPPING: {base_name} (未找到有效聚类)")
                skipped_count += 1
                continue
            # 4-6. Aggregate, pad and convert to YOLO; the only class is 0 ('payment_slip').
            yolo_label = _anchors_to_yolo_line(final_anchors, img_w, img_h, padding=10, class_id=0)
            # 7. Save the label and copy the image.
            label_path = os.path.join(LABEL_DIR, f"{base_name}.txt")
            with open(label_path, 'w', encoding='utf-8') as f:
                f.write(yolo_label)
            shutil.copy(
                img_path,
                os.path.join(IMAGE_DIR, f"{base_name}.png")
            )
            processed_count += 1
            if processed_count % 20 == 0:
                print(f"已生成 {processed_count} 个标签...")
        except Exception as e:
            # Top-level batch boundary: report and keep processing other files.
            print(f"处理 {base_name} 时出错: {e}")
            skipped_count += 1
    print("--- 弱标签生成完成 ---")
    print(f"成功生成: {processed_count}")
    print(f"跳过: {skipped_count}")


if __name__ == "__main__":
    create_labels()