223 lines
8.0 KiB
Python
223 lines
8.0 KiB
Python
# --- scripts/02_create_labels.py ---
#
# **重要**: 此脚本用于训练 "阶段一" 的区域检测器.
# 它只生成 1 个类别 (class_id = 0), 即 "payment_slip" 的 *整个* 区域.
# 它不使用您的高级 classify_text 逻辑, 而是使用启发式规则 (关键词, 线条) 来找到大区域.
#
import os
|
|
import pandas as pd
|
|
import numpy as np
|
|
import cv2
|
|
import re
|
|
import shutil
|
|
import json
|
|
from sklearn.cluster import DBSCAN
|
|
|
|
# --- Anchor dictionary (technique 2) ---
# These words are used to *locate* the payment-slip region, not to extract values.
KEYWORDS = [
    "Bankgirot", "PlusGirot", "OCR-nummer", "Att betala", "Mottagare",
    "Betalningsmottagare", "Tillhanda senast", "Förfallodag", "Belopp",
    "BG-nr", "PG-nr", "Meddelande", "OCR-kod", "Inbetalningskort", "Betalningsavi",
]
# Escape every keyword so that an entry containing regex metacharacters
# (e.g. "Ref." or "(BG)") is matched literally instead of silently changing
# the meaning of the pattern or raising re.error.
KEYWORDS_REGEX = '|'.join(re.escape(keyword) for keyword in KEYWORDS)

REGEX_PATTERNS = {
    'bg_pg': r'(\b\d{2,4}[- ]\d{4}\b)|(\b\d{2,7}[- ]\d\b)',  # Bankgiro/PlusGiro number
    'long_num': r'\b\d{10,}\b',  # run of 10+ digits, possibly an OCR reference
    'machine_code': r'#[0-9\s>#]+#',  # machine-readable code line
}

# One top-level alternation of every anchor pattern (keywords + number shapes).
FULL_REGEX = f"({KEYWORDS_REGEX})|{REGEX_PATTERNS['bg_pg']}|{REGEX_PATTERNS['long_num']}|{REGEX_PATTERNS['machine_code']}"


# --- Path setup ---
# Project root: the parent of the directory that contains this script.
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

_DATA_DIR = os.path.join(BASE_DIR, "data")
OCR_DIR = os.path.join(_DATA_DIR, "ocr_results")
IMG_DIR = os.path.join(_DATA_DIR, "processed_images")

# Staging directories holding every generated label/image; the train/val
# split is done afterwards from these.
_YOLO_DIR = os.path.join(_DATA_DIR, "yolo_dataset")
LABEL_DIR = os.path.join(_YOLO_DIR, "temp_all_labels")
IMAGE_DIR = os.path.join(_YOLO_DIR, "temp_all_images")

# Module-level side effect: both output directories exist as soon as the
# module is imported.
for _out_dir in (LABEL_DIR, IMAGE_DIR):
    os.makedirs(_out_dir, exist_ok=True)
del _out_dir


def find_text_anchors(text_boxes, img_height, *, min_y_ratio=0.4, pattern=None):
    """Return OCR text boxes that look like payment-slip anchors.

    Combines technique 1 (position) and technique 2 (dictionary): a box is
    kept when its top edge lies strictly below ``img_height * min_y_ratio``
    AND its text matches ``pattern`` case-insensitively.

    Args:
        text_boxes: OCR boxes. Each box is read as ``box["bbox"]`` (a mapping
            with ``x_min``/``y_min``/``x_max``/``y_max``) and ``box["text"]``.
        img_height: Image height in pixels.
        min_y_ratio: Fraction of the page height above which boxes are
            ignored. Defaults to 0.4 (search starts 40% down the page).
        pattern: Regex string searched in each box's text. Defaults to the
            module-level ``FULL_REGEX``.

    Returns:
        A list of ``pd.Series`` with ``left``, ``top``, ``width`` and
        ``height`` for every matching box (empty list when none match).

    Raises:
        KeyError: if a box has no ``"bbox"`` entry.
    """
    if not text_boxes:
        return []

    if pattern is None:
        pattern = FULL_REGEX
    # Compile once per call instead of re-resolving the pattern for every box.
    matcher = re.compile(pattern, re.IGNORECASE)

    # Technique 1: only boxes starting below this y coordinate are considered.
    min_top = img_height * min_y_ratio

    anchors = []
    for box in text_boxes:
        bbox = box["bbox"]
        if bbox["y_min"] <= min_top:
            continue
        # Missing/None OCR text cannot match; treat it as empty so re.search
        # does not raise TypeError and abort the whole image.
        text = box.get("text") or ""
        if not matcher.search(text):
            continue
        anchors.append(pd.Series({
            'left': bbox["x_min"],
            'top': bbox["y_min"],
            'width': bbox["x_max"] - bbox["x_min"],
            'height': bbox["y_max"] - bbox["y_min"],
        }))
    return anchors


def find_visual_anchors(image, img_height, img_width):
    """Return long horizontal lines in the lower half of the page (technique 3).

    The lower half of ``image`` is binarised with an inverted threshold
    (gray values <= 150 become foreground) and morphologically opened with a
    1-pixel-tall kernel a quarter of ``img_width`` wide, so only long
    horizontal strokes survive.

    Args:
        image: Page image, expected to be BGR as returned by ``cv2.imread``.
            Single-channel and 4-channel (BGRA) arrays are also accepted.
        img_height: Image height in pixels.
        img_width: Image width in pixels.

    Returns:
        A list of ``pd.Series`` with ``left``, ``top``, ``width`` and
        ``height`` in full-image coordinates, one per detected line. Errors
        are printed, not raised; the anchors collected so far are returned.
    """
    anchors = []
    try:
        # 1. Only the lower half of the page is searched.
        crop_y_start = img_height // 2
        crop = image[crop_y_start:, :]

        # COLOR_BGR2GRAY requires exactly 3 channels; handle other layouts
        # explicitly instead of letting cvtColor raise and losing every anchor.
        if crop.ndim == 2 or crop.shape[2] == 1:
            gray = crop.reshape(crop.shape[:2])
        elif crop.shape[2] == 4:
            gray = cv2.cvtColor(crop, cv2.COLOR_BGRA2GRAY)
        else:
            gray = cv2.cvtColor(crop, cv2.COLOR_BGR2GRAY)
        # Inverted binary threshold: dark ink (<= 150) -> 255, paper -> 0.
        thresh = cv2.threshold(gray, 150, 255, cv2.THRESH_BINARY_INV)[1]

        # 2. Opening with a wide, 1-pixel-tall kernel removes horizontal runs
        # shorter than the kernel. Clamp to 1 so images narrower than 4 px do
        # not produce a zero-size kernel.
        min_line_length = max(1, img_width // 4)
        horizontal_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (min_line_length, 1))
        detected_lines = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, horizontal_kernel, iterations=2)

        contours, _ = cv2.findContours(detected_lines, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        for contour in contours:
            x, y, w, h = cv2.boundingRect(contour)
            if w > min_line_length:  # make sure the line is long enough
                anchors.append(pd.Series({
                    'left': x,
                    'top': y + crop_y_start,  # back to full-image coordinates
                    'width': w,
                    'height': h,
                }))
    except Exception as e:
        # Best effort: report the problem and return whatever was found so far.
        print(f"查找视觉锚点时出错: {e}")
    return anchors


def cluster_anchors(all_anchors, img_height, *, eps_ratio=0.2, min_samples=2):
    """Keep only the anchors belonging to the largest spatial cluster (technique 4).

    The centre point of every anchor box is clustered with DBSCAN using a
    neighbourhood radius of ``img_height * eps_ratio`` pixels.

    Args:
        all_anchors: Anchors exposing ``left``, ``top``, ``width`` and
            ``height`` attributes (e.g. the ``pd.Series`` built in this file).
        img_height: Image height in pixels; scales the DBSCAN radius.
        eps_ratio: DBSCAN ``eps`` as a fraction of ``img_height`` (default 0.2).
        min_samples: DBSCAN ``min_samples`` (default 2).

    Returns:
        ``all_anchors`` unchanged when it has fewer than 2 items or when
        DBSCAN marks every point as noise; otherwise a new list with the
        anchors of the most populous cluster (ties go to the lowest label).
    """
    if len(all_anchors) < 2:
        return all_anchors  # too few anchors to cluster

    # 1. Centre point of every anchor box.
    points = np.array([
        (anchor.left + anchor.width / 2, anchor.top + anchor.height / 2)
        for anchor in all_anchors
    ])

    # 2. Run DBSCAN on the centres.
    clustering = DBSCAN(eps=img_height * eps_ratio, min_samples=min_samples).fit(points)
    labels = clustering.labels_

    # 3. Pick the largest cluster; DBSCAN labels noise points -1.
    clustered = labels[labels != -1]
    if clustered.size == 0:
        return all_anchors  # everything is noise: clustering failed
    unique_labels, counts = np.unique(clustered, return_counts=True)
    largest_cluster_label = unique_labels[np.argmax(counts)]

    # 4. Return only the anchors in that cluster, preserving input order.
    return [
        anchor for anchor, label in zip(all_anchors, labels)
        if label == largest_cluster_label
    ]


def _union_bbox(anchors):
    """Return (min_x, min_y, max_x, max_y) of the box enclosing every anchor."""
    return (
        min(a.left for a in anchors),
        min(a.top for a in anchors),
        max(a.left + a.width for a in anchors),
        max(a.top + a.height for a in anchors),
    )


def _to_yolo_line(bbox, img_w, img_h, class_id=0, padding=10):
    """Pad/clamp ``bbox`` to the image and format it as one YOLO label line.

    Returns None when the clamped box has no area (e.g. anchor coordinates
    lying outside the image), because such a label would be invalid.
    """
    min_x, min_y, max_x, max_y = bbox
    # Add padding, then clamp to the image bounds.
    min_x = max(0, min_x - padding)
    min_y = max(0, min_y - padding)
    max_x = min(img_w, max_x + padding)
    max_y = min(img_h, max_y + padding)

    box_w = max_x - min_x
    box_h = max_y - min_y
    if box_w <= 0 or box_h <= 0:
        return None

    # YOLO: "class x_center y_center width height", normalised by image size.
    x_center = (min_x + box_w / 2) / img_w
    y_center = (min_y + box_h / 2) / img_h
    norm_w = box_w / img_w
    norm_h = box_h / img_h
    return f"{class_id} {x_center} {y_center} {norm_w} {norm_h}\n"


def create_labels():
    """Generate one weak YOLO label per page for the stage-one region detector.

    For every ``*.json`` in ``OCR_DIR`` with a matching ``<name>.png`` in
    ``IMG_DIR``: collect text and visual anchors, keep the largest anchor
    cluster, wrap it in a single padded box of class 0, write the label to
    ``LABEL_DIR`` and copy the image to ``IMAGE_DIR``. Progress and a final
    summary are printed; per-file failures are printed and counted as skipped.
    """
    print("开始生成弱标签 (用于区域检测器)...")
    processed_count = 0
    skipped_count = 0

    # Sorted so processing order (and the progress log) is reproducible;
    # os.listdir returns entries in arbitrary order.
    for ocr_filename in sorted(os.listdir(OCR_DIR)):
        if not ocr_filename.endswith(".json"):
            continue

        base_name = os.path.splitext(ocr_filename)[0]
        json_path = os.path.join(OCR_DIR, ocr_filename)
        img_path = os.path.join(IMG_DIR, f"{base_name}.png")

        if not os.path.exists(img_path):
            # Count and report this skip so the final summary is accurate.
            print(f"SKIPPING: {base_name} (未找到对应图片)")
            skipped_count += 1
            continue

        try:
            # 1. Load the image and its OCR results.
            image = cv2.imread(img_path)
            if image is None:
                # cv2.imread signals an unreadable file by returning None.
                print(f"SKIPPING: {base_name} (无法读取图片)")
                skipped_count += 1
                continue
            img_h, img_w = image.shape[:2]
            with open(json_path, 'r', encoding='utf-8') as f:
                ocr_data = json.load(f)
            text_boxes = ocr_data.get("text_boxes", [])

            # 2. Collect text anchors first, then visual anchors.
            all_anchors = (
                find_text_anchors(text_boxes, img_h)
                + find_visual_anchors(image, img_h, img_w)
            )
            if not all_anchors:
                print(f"SKIPPING: {base_name} (未找到任何锚点)")
                skipped_count += 1
                continue

            # 3. Keep the largest spatial cluster of anchors.
            final_anchors = cluster_anchors(all_anchors, img_h)
            if not final_anchors:
                print(f"SKIPPING: {base_name} (未找到有效聚类)")
                skipped_count += 1
                continue

            # 4-6. One padded box around the cluster, in YOLO format.
            # The class id is always 0: the single class 'payment_slip'.
            yolo_label = _to_yolo_line(
                _union_bbox(final_anchors), img_w, img_h, class_id=0, padding=10
            )
            if yolo_label is None:
                print(f"SKIPPING: {base_name} (标签框为空)")
                skipped_count += 1
                continue

            # 7. Save the label and copy the image alongside it.
            label_path = os.path.join(LABEL_DIR, f"{base_name}.txt")
            with open(label_path, 'w', encoding='utf-8') as f:
                f.write(yolo_label)
            shutil.copy(img_path, os.path.join(IMAGE_DIR, f"{base_name}.png"))

            processed_count += 1
            if processed_count % 20 == 0:
                print(f"已生成 {processed_count} 个标签...")

        except Exception as e:
            # Per-file boundary: report the failure and continue with the next file.
            print(f"处理 {base_name} 时出错: {e}")
            skipped_count += 1

    print("--- 弱标签生成完成 ---")
    print(f"成功生成: {processed_count} 个")
    print(f"跳过: {skipped_count} 个")


# Script entry point: generate all labels when run directly.
if __name__ == "__main__":
    create_labels()