Init
This commit is contained in:
223
scripts/02_create_labels.py
Normal file
223
scripts/02_create_labels.py
Normal file
@@ -0,0 +1,223 @@
|
||||
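# --- scripts/02_create_labels.py ---
#
# **IMPORTANT**: This script is used to train the "stage one" region detector.
# It generates exactly one class (class_id = 0): the *entire* "payment_slip" region.
# It does not use your advanced classify_text logic; instead it relies on
# heuristic rules (keywords, lines) to find the large region.
#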
|
||||
|
||||
import os
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import cv2
|
||||
import re
|
||||
import shutil
|
||||
import json
|
||||
from sklearn.cluster import DBSCAN
|
||||
|
||||
# --- Anchor vocabulary (technique 2) ---
# These terms are only used to *locate* the payment-slip region on the page;
# they are not the values that get extracted.
KEYWORDS = [
    "Bankgirot",
    "PlusGirot",
    "OCR-nummer",
    "Att betala",
    "Mottagare",
    "Betalningsmottagare",
    "Tillhanda senast",
    "Förfallodag",
    "Belopp",
    "BG-nr",
    "PG-nr",
    "Meddelande",
    "OCR-kod",
    "Inbetalningskort",
    "Betalningsavi",
]
# Plain alternation of every keyword, ready to be embedded in a larger pattern.
KEYWORDS_REGEX = "|".join(KEYWORDS)

# Regular expressions for number-like anchors.
REGEX_PATTERNS = dict(
    bg_pg=r"(\b\d{2,4}[- ]\d{4}\b)|(\b\d{2,7}[- ]\d\b)",  # Bankgiro/PlusGiro
    long_num=r"\b\d{10,}\b",  # possibly an OCR reference number
    machine_code=r"#[0-9\s>#]+#",  # machine-readable code line
)

# Combined pattern: the keyword group first, then each entry of
# REGEX_PATTERNS in insertion order (bg_pg, long_num, machine_code).
FULL_REGEX = "|".join([f"({KEYWORDS_REGEX})", *REGEX_PATTERNS.values()])
|
||||
|
||||
# --- Path configuration ---
# Two directory levels above this file.
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

# Inputs: OCR result JSON files and the images they refer to.
OCR_DIR, IMG_DIR = (
    os.path.join(BASE_DIR, "data", leaf)
    for leaf in ("ocr_results", "processed_images")
)

# Outputs: temporary staging directories used for the initial label
# generation (before the train/val split).
LABEL_DIR, IMAGE_DIR = (
    os.path.join(BASE_DIR, "data", "yolo_dataset", leaf)
    for leaf in ("temp_all_labels", "temp_all_images")
)

# Make sure both staging directories exist before anything is written.
for _staging_dir in (LABEL_DIR, IMAGE_DIR):
    os.makedirs(_staging_dir, exist_ok=True)
del _staging_dir
|
||||
|
||||
def find_text_anchors(text_boxes, img_height, pattern=None, top_fraction=0.4):
    """Collect OCR text boxes that look like payment-slip anchors.

    Combines technique 1 (position) with technique 2 (vocabulary): a box is
    kept when its top edge lies below ``top_fraction`` of the page height AND
    its text matches ``pattern`` case-insensitively.

    Malformed boxes (non-dict entries, missing or non-string ``"text"``,
    missing ``"bbox"`` or bbox coordinates) are skipped individually so that
    one bad OCR entry does not discard every anchor on the page.

    Args:
        text_boxes: Iterable of dicts, each with a ``"text"`` string and a
            ``"bbox"`` dict holding ``x_min``, ``y_min``, ``x_max``, ``y_max``.
            ``None`` or an empty sequence yields an empty result.
        img_height: Page height, in the same units as the bbox coordinates.
        pattern: Regex source string or pre-compiled pattern. Defaults to the
            module-level ``FULL_REGEX``. A pre-compiled pattern is used as-is,
            with whatever flags it was compiled with.
        top_fraction: Fraction of the page height above which boxes are
            ignored. Defaults to 0.4.

    Returns:
        A list of ``pandas.Series`` with the fields ``left``, ``top``,
        ``width`` and ``height``, one per matching box, in input order.
    """
    if not text_boxes:
        return []

    if pattern is None:
        pattern = FULL_REGEX
    # Compile once up front instead of resolving the pattern for every box.
    if isinstance(pattern, str):
        matcher = re.compile(pattern, re.IGNORECASE)
    else:
        matcher = pattern

    # Technique 1: only consider boxes starting below this y coordinate.
    y_threshold = img_height * top_fraction

    anchors = []
    for box in text_boxes:
        if not isinstance(box, dict):
            continue

        text = box.get("text")
        if not isinstance(text, str):
            continue

        bbox = box.get("bbox")
        try:
            x_min, y_min = bbox["x_min"], bbox["y_min"]
            x_max, y_max = bbox["x_max"], bbox["y_max"]
        except (KeyError, TypeError):
            # Missing bbox or incomplete coordinates: unusable as an anchor.
            continue

        # Position check first (cheap), then the text content check.
        if not y_min > y_threshold:
            continue
        if not matcher.search(text):
            continue

        anchors.append(pd.Series({
            'left': x_min,
            'top': y_min,
            'width': x_max - x_min,
            'height': y_max - y_min,
        }))
    return anchors
|
||||
|
||||
def find_visual_anchors(image, img_height, img_width):
    """Technique 3 (visual anchors): look for long horizontal lines.

    Only the lower half of the page (from ``img_height // 2`` downwards) is
    searched. The crop is converted to grayscale, binarised with an inverted
    threshold at 150, and opened with a 1-pixel-high kernel whose width is a
    quarter of the page width. Each remaining contour wider than that kernel
    becomes one anchor.

    Args:
        image: Page image; passed to ``cv2.cvtColor`` with ``COLOR_BGR2GRAY``.
        img_height: Height of ``image`` in pixels.
        img_width: Width of ``image`` in pixels.

    Returns:
        A list of ``pandas.Series`` with ``left``, ``top``, ``width`` and
        ``height`` in full-image coordinates. If anything raises, the error
        is printed and whatever was collected so far is returned.
    """
    found = []
    try:
        # 1. Restrict the search to the lower half of the page.
        y_offset = img_height // 2
        lower_half = image[y_offset:, :]

        grey = cv2.cvtColor(lower_half, cv2.COLOR_BGR2GRAY)
        _, binary = cv2.threshold(grey, 150, 255, cv2.THRESH_BINARY_INV)

        # 2. Isolate long horizontal strokes with a wide, flat kernel.
        min_len = img_width // 4
        flat_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (min_len, 1))
        line_mask = cv2.morphologyEx(
            binary, cv2.MORPH_OPEN, flat_kernel, iterations=2
        )

        contours, _ = cv2.findContours(
            line_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )

        for contour in contours:
            x, y, w, h = cv2.boundingRect(contour)
            if w <= min_len:
                # Not long enough to count as a separator line.
                continue
            found.append(pd.Series({
                'left': x,
                'top': y + y_offset,  # shift back into full-image coordinates
                'width': w,
                'height': h,
            }))
    except Exception as e:
        print(f"查找视觉锚点时出错: {e}")
    return found
|
||||
|
||||
def cluster_anchors(all_anchors, img_height):
    """Technique 4 (clustering): keep only the largest group of anchors.

    Anchor centre points are clustered with DBSCAN (``eps`` = 20% of the page
    height, ``min_samples`` = 2) and the anchors of the most populated
    cluster are returned.

    Args:
        all_anchors: List of anchors exposing ``left``, ``top``, ``width``
            and ``height`` attributes.
        img_height: Page height, used to scale the clustering radius.

    Returns:
        ``all_anchors`` itself when there are fewer than two anchors or when
        every point is labelled as noise; otherwise a new list holding only
        the anchors of the largest cluster, in their original order.
    """
    # Too few anchors to cluster.
    if len(all_anchors) < 2:
        return all_anchors

    # 1. Centre point of every anchor.
    centers = np.array([
        (a.left + a.width / 2, a.top + a.height / 2) for a in all_anchors
    ])

    # 2. Run DBSCAN with a radius proportional to the page height.
    model = DBSCAN(eps=img_height * 0.2, min_samples=2)
    labels = model.fit(centers).labels_

    # Label -1 marks noise; with nothing but noise, clustering has failed.
    clustered = labels != -1
    if not clustered.any():
        return all_anchors

    # 3. Pick the most populated cluster (first one wins on ties).
    cluster_ids, sizes = np.unique(labels[clustered], return_counts=True)
    winner = cluster_ids[np.argmax(sizes)]

    # 4. Keep only the anchors that belong to that cluster.
    return [
        anchor for anchor, label in zip(all_anchors, labels)
        if label == winner
    ]
|
||||
|
||||
def _anchors_to_yolo_line(anchors, img_w, img_h, padding=10, class_id=0):
    """Merge anchors into one padded box and format it as a YOLO label line.

    Args:
        anchors: Non-empty list of anchors exposing ``left``, ``top``,
            ``width`` and ``height``.
        img_w: Image width in pixels.
        img_h: Image height in pixels.
        padding: Margin in pixels added on every side before clamping the
            box to the image bounds.
        class_id: Class index written at the start of the line.

    Returns:
        ``"<class_id> <x_center> <y_center> <width> <height>\\n"`` with all
        four box values divided by the image size, or ``None`` when the
        clamped box has no positive width or height.
    """
    # Union of all anchors, padded and clamped to the image.
    min_x = max(0, min(a.left for a in anchors) - padding)
    min_y = max(0, min(a.top for a in anchors) - padding)
    max_x = min(img_w, max(a.left + a.width for a in anchors) + padding)
    max_y = min(img_h, max(a.top + a.height for a in anchors) + padding)

    box_w = max_x - min_x
    box_h = max_y - min_y
    if box_w <= 0 or box_h <= 0:
        # Anchors lie outside the image; no valid label can be produced.
        return None

    x_center = (min_x + box_w / 2) / img_w
    y_center = (min_y + box_h / 2) / img_h
    norm_w = box_w / img_w
    norm_h = box_h / img_h
    return f"{class_id} {x_center} {y_center} {norm_w} {norm_h}\n"


def create_labels():
    """Generate weak single-class YOLO labels for the region detector.

    For every ``*.json`` file in ``OCR_DIR`` that has a matching ``.png`` in
    ``IMG_DIR``, text and visual anchors are collected and clustered, and the
    bounding box around the resulting anchors (plus padding) is written as a
    one-line YOLO label to ``LABEL_DIR``. The image is copied to
    ``IMAGE_DIR``. The class id is always 0 ('payment_slip').

    Files are processed in sorted order. A file that cannot be labelled
    (missing or unreadable image, no anchors, invalid box, or any error while
    processing) is reported on stdout and counted as skipped; it never aborts
    the run. A summary is printed at the end.
    """
    print("开始生成弱标签 (用于区域检测器)...")
    processed_count = 0
    skipped_count = 0

    # Sorted so that progress output and processing order are reproducible.
    for ocr_filename in sorted(os.listdir(OCR_DIR)):
        if not ocr_filename.endswith(".json"):
            continue

        base_name = os.path.splitext(ocr_filename)[0]
        json_path = os.path.join(OCR_DIR, ocr_filename)
        img_path = os.path.join(IMG_DIR, f"{base_name}.png")

        if not os.path.exists(img_path):
            # Report it so the final "skipped" total accounts for every input.
            print(f"SKIPPING: {base_name} (未找到对应图片)")
            skipped_count += 1
            continue

        try:
            # 1. Load the image and its OCR result.
            image = cv2.imread(img_path)
            if image is None:
                # cv2.imread signals an unreadable file by returning None
                # rather than raising.
                print(f"SKIPPING: {base_name} (无法读取图片)")
                skipped_count += 1
                continue
            img_h, img_w = image.shape[:2]

            with open(json_path, 'r', encoding='utf-8') as f:
                ocr_data = json.load(f)
            text_boxes = ocr_data.get("text_boxes", [])

            # 2. Collect every anchor.
            text_anchors = find_text_anchors(text_boxes, img_h)
            visual_anchors = find_visual_anchors(image, img_h, img_w)
            all_anchors = text_anchors + visual_anchors

            if not all_anchors:
                print(f"SKIPPING: {base_name} (未找到任何锚点)")
                skipped_count += 1
                continue

            # 3. Cluster the anchors.
            final_anchors = cluster_anchors(all_anchors, img_h)

            if not final_anchors:
                print(f"SKIPPING: {base_name} (未找到有效聚类)")
                skipped_count += 1
                continue

            # 4-6. Aggregate, pad and convert to YOLO format.
            # ** Key point: the class id is always 0 ('payment_slip'). **
            yolo_label = _anchors_to_yolo_line(
                final_anchors, img_w, img_h, padding=10, class_id=0
            )
            if yolo_label is None:
                print(f"SKIPPING: {base_name} (边界框无效)")
                skipped_count += 1
                continue

            # 7. Save the label and copy the image next to it.
            label_path = os.path.join(LABEL_DIR, f"{base_name}.txt")
            with open(label_path, 'w', encoding='utf-8') as f:
                f.write(yolo_label)

            shutil.copy(
                img_path,
                os.path.join(IMAGE_DIR, f"{base_name}.png")
            )

            processed_count += 1
            if processed_count % 20 == 0:
                print(f"已生成 {processed_count} 个标签...")

        except Exception as e:
            # Batch boundary: one bad file must not stop the whole run.
            print(f"处理 {base_name} 时出错: {e}")
            skipped_count += 1

    print("--- 弱标签生成完成 ---")
    print(f"成功生成: {processed_count} 个")
    print(f"跳过: {skipped_count} 个")
|
||||
|
||||
# Run label generation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    create_labels()
|
||||
Reference in New Issue
Block a user