# --- scripts/02_create_labels.py ---
#
# **IMPORTANT**: This script is used to train the "stage one" region detector.
# It generates only 1 class (class_id = 0): the *entire* "payment_slip" region.
# It does not use your advanced classify_text logic; instead it relies on
# heuristic rules (keywords, lines) to locate the large region.
#

import os
import pandas as pd
import numpy as np
import cv2
import re
import shutil
import json
from sklearn.cluster import DBSCAN

# --- Anchor keyword dictionary (Technique 2) ---
# These words are used to *locate* the payment-slip region, not to extract fields.
KEYWORDS = [
    "Bankgirot", "PlusGirot", "OCR-nummer", "Att betala", "Mottagare",
    "Betalningsmottagare", "Tillhanda senast", "Förfallodag", "Belopp",
    "BG-nr", "PG-nr", "Meddelande", "OCR-kod", "Inbetalningskort", "Betalningsavi"
]
KEYWORDS_REGEX = '|'.join(KEYWORDS)

REGEX_PATTERNS = {
    'bg_pg': r'(\b\d{2,4}[- ]\d{4}\b)|(\b\d{2,7}[- ]\d\b)',  # Bankgiro/PlusGiro
    'long_num': r'\b\d{10,}\b',                               # likely an OCR number
    'machine_code': r'#[0-9\s>#]+#'                           # machine-readable code line
}
FULL_REGEX = f"({KEYWORDS_REGEX})|{REGEX_PATTERNS['bg_pg']}|{REGEX_PATTERNS['long_num']}|{REGEX_PATTERNS['machine_code']}"

# --- Path setup ---
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
OCR_DIR = os.path.join(BASE_DIR, "data", "ocr_results")
IMG_DIR = os.path.join(BASE_DIR, "data", "processed_images")

# Use temp directories for initial label generation (before train/val split)
LABEL_DIR = os.path.join(BASE_DIR, "data", "yolo_dataset", "temp_all_labels")
IMAGE_DIR = os.path.join(BASE_DIR, "data", "yolo_dataset", "temp_all_images")
os.makedirs(LABEL_DIR, exist_ok=True)
os.makedirs(IMAGE_DIR, exist_ok=True)


def find_text_anchors(text_boxes, img_height):
    """Technique 1 (position) + Technique 2 (keyword dictionary)."""
    anchors = []
    if not text_boxes:
        return []

    # Technique 1: only search the lower part of the page (starting at 40% of the height)
    page_midpoint = img_height * 0.4

    for box in text_boxes:
        # Check position
        if box["bbox"]["y_min"] > page_midpoint:
            # Check text content
            if re.search(FULL_REGEX, box["text"], re.IGNORECASE):
                bbox = box["bbox"]
                anchors.append(pd.Series({
                    'left': bbox["x_min"],
                    'top': bbox["y_min"],
                    'width': bbox["x_max"] - bbox["x_min"],
                    'height': bbox["y_max"] - bbox["y_min"]
                }))
    return anchors


def find_visual_anchors(image, img_height, img_width):
    """Technique 3 (visual anchors: find horizontal lines)."""
    anchors = []
    try:
        # 1. Only look at the lower half of the page
        crop_y_start = img_height // 2
        crop = image[crop_y_start:, :]
        gray = cv2.cvtColor(crop, cv2.COLOR_BGR2GRAY)
        thresh = cv2.threshold(gray, 150, 255, cv2.THRESH_BINARY_INV)[1]

        # 2. Find long horizontal lines
        min_line_length = img_width // 4
        horizontal_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (min_line_length, 1))
        detected_lines = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, horizontal_kernel, iterations=2)

        contours, _ = cv2.findContours(detected_lines, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        for c in contours:
            x, y, w, h = cv2.boundingRect(c)
            if w > min_line_length:          # make sure the line is long enough
                original_y = y + crop_y_start  # convert back to full-image coordinates
                anchors.append(pd.Series({
                    'left': x, 'top': original_y, 'width': w, 'height': h
                }))
    except Exception as e:
        print(f"Error while finding visual anchors: {e}")
    return anchors


def cluster_anchors(all_anchors, img_height):
    """Technique 4 (clustering)."""
    if len(all_anchors) < 2:
        return all_anchors  # too few anchors to cluster

    # 1. Get center points
    points = []
    for anchor in all_anchors:
        x_center = anchor.left + anchor.width / 2
        y_center = anchor.top + anchor.height / 2
        points.append((x_center, y_center))
    points = np.array(points)

    # 2. Run DBSCAN
    eps_dist = img_height * 0.2
    clustering = DBSCAN(eps=eps_dist, min_samples=2).fit(points)
    labels = clustering.labels_

    if len(labels) == 0 or np.all(labels == -1):
        return all_anchors  # clustering failed

    # 3. Find the largest cluster
    unique_labels, counts = np.unique(labels[labels != -1], return_counts=True)
    if len(counts) == 0:
        return all_anchors  # noise only
    largest_cluster_label = unique_labels[np.argmax(counts)]

    # 4. Keep only the anchors belonging to the largest cluster
    main_cluster_anchors = [
        anchor for i, anchor in enumerate(all_anchors)
        if labels[i] == largest_cluster_label
    ]
    return main_cluster_anchors


def create_labels():
    print("Generating weak labels (for the region detector)...")
    processed_count = 0
    skipped_count = 0

    for ocr_filename in os.listdir(OCR_DIR):
        if not ocr_filename.endswith(".json"):
            continue

        base_name = os.path.splitext(ocr_filename)[0]
        json_path = os.path.join(OCR_DIR, ocr_filename)
        img_path = os.path.join(IMG_DIR, f"{base_name}.png")

        if not os.path.exists(img_path):
            continue

        try:
            # 1. Load data
            image = cv2.imread(img_path)
            if image is None:
                print(f"SKIPPING: {base_name} (could not read image)")
                skipped_count += 1
                continue
            img_h, img_w = image.shape[:2]
            with open(json_path, 'r', encoding='utf-8') as f:
                ocr_data = json.load(f)
            text_boxes = ocr_data.get("text_boxes", [])

            # 2. Find all anchors
            text_anchors = find_text_anchors(text_boxes, img_h)
            visual_anchors = find_visual_anchors(image, img_h, img_w)
            all_anchors = text_anchors + visual_anchors

            if not all_anchors:
                print(f"SKIPPING: {base_name} (no anchors found)")
                skipped_count += 1
                continue

            # 3. Cluster the anchors
            final_anchors = cluster_anchors(all_anchors, img_h)
            if not final_anchors:
                print(f"SKIPPING: {base_name} (no valid cluster found)")
                skipped_count += 1
                continue

            # 4. Aggregate coordinates
            min_x = min(a.left for a in final_anchors)
            min_y = min(a.top for a in final_anchors)
            max_x = max(a.left + a.width for a in final_anchors)
            max_y = max(a.top + a.height for a in final_anchors)

            # 5. Add padding
            padding = 10
            min_x = max(0, min_x - padding)
            min_y = max(0, min_y - padding)
            max_x = min(img_w, max_x + padding)
            max_y = min(img_h, max_y + padding)

            # 6. Convert to YOLO format
            box_w = max_x - min_x
            box_h = max_y - min_y
            x_center = (min_x + box_w / 2) / img_w
            y_center = (min_y + box_h / 2) / img_h
            norm_w = box_w / img_w
            norm_h = box_h / img_h

            # ** Key point: the class ID is always 0 **
            class_id = 0  # the only class: 'payment_slip'

            yolo_label = f"{class_id} {x_center} {y_center} {norm_w} {norm_h}\n"

            # 7. Save label and image
            label_path = os.path.join(LABEL_DIR, f"{base_name}.txt")
            with open(label_path, 'w', encoding='utf-8') as f:
                f.write(yolo_label)

            shutil.copy(
                img_path,
                os.path.join(IMAGE_DIR, f"{base_name}.png")
            )

            processed_count += 1
            if processed_count % 20 == 0:
                print(f"Generated {processed_count} labels...")

        except Exception as e:
            print(f"Error processing {base_name}: {e}")
            skipped_count += 1

    print("--- Weak label generation complete ---")
    print(f"Successfully generated: {processed_count}")
    print(f"Skipped: {skipped_count}")


if __name__ == "__main__":
    create_labels()
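
# --- Expected input / output formats (illustrative sketch) ---
# The OCR JSON schema below is an assumption inferred from the parsing code above
# (keys "text_boxes", "text", "bbox", "x_min"/"y_min"/"x_max"/"y_max"); adjust the
# field names if your OCR step emits a different structure. The example values are
# made up for illustration only.
#
# data/ocr_results/<name>.json:
#   {
#     "text_boxes": [
#       {"text": "Bankgirot 123-4567",
#        "bbox": {"x_min": 120, "y_min": 1580, "x_max": 410, "y_max": 1612}}
#     ]
#   }
#
# data/yolo_dataset/temp_all_labels/<name>.txt (one line, class 0 = 'payment_slip'):
#   0 <x_center_norm> <y_center_norm> <width_norm> <height_norm>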