# --- main.py (upgraded, v4.6 - separate InvoiceNr and OCRNr) ---
from fastapi import FastAPI, UploadFile, File, HTTPException
from ultralytics import YOLO
import cv2
import numpy as np
import pytesseract
import pandas as pd
import re
import io
import os
import sys
from contextlib import asynccontextmanager
from pathlib import Path
from pdf2image import convert_from_bytes

# --- Configuration (project-specific) ---
# 1. Project root: the parent of the directory that contains this script.
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# 2. Put the project root on sys.path so that `config` can be imported below.
sys.path.insert(0, BASE_DIR)
# 3. Import the external-tool paths from config.py (must come after step 2).
from config import POPPLER_PATH, apply_tesseract_path
# ---------------------------------

MODEL_PATH_STR = os.path.join(BASE_DIR, "models", "payment_slip_detector_v1", "weights", "best.pt")
MODEL_PATH = Path(MODEL_PATH_STR)

# Registry of models loaded at startup; "yolo" maps to the detector or None.
ml_models = {}


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the YOLO detector and apply OCR/PDF tool configuration at startup.

    On startup this stores the detector under ``ml_models["yolo"]`` (``None``
    when the weights file is missing), calls ``apply_tesseract_path()`` and
    warns when ``POPPLER_PATH`` is unset or does not exist.  On shutdown the
    model registry is cleared.
    """
    print(f"正在加载模型: {MODEL_PATH}")
    if not os.path.exists(MODEL_PATH):
        print(f"警告: 找不到模型 {MODEL_PATH}。API 将无法工作。")
        ml_models["yolo"] = None
    else:
        ml_models["yolo"] = YOLO(MODEL_PATH)
        print("YOLOv8 区域检测模型加载成功。")

    # Apply the Tesseract path from config; a failure here is reported but
    # deliberately does not abort startup.
    try:
        apply_tesseract_path()
        print("已应用 Tesseract 路径配置。")
    except Exception as e:
        print(f"警告: 调用 apply_tesseract_path 时出错: {e}")

    # Check the Poppler path from config (needed for PDF uploads).
    if not POPPLER_PATH or not os.path.exists(POPPLER_PATH):
        print(f"警告: POPPLER_PATH 未设置或无效: {POPPLER_PATH}。PDF 处理将失败。")
    else:
        print(f"Poppler 路径已加载: {POPPLER_PATH}")

    try:
        yield
    finally:
        # Run the cleanup even when an exception is thrown into the generator
        # at the yield point, so the model reference is always released.
        ml_models.clear()


app = FastAPI(lifespan=lifespan)
# --- Check-digit validators (Luhn / Modulus 11) ---
def _is_ascii_digits(value) -> bool:
    """Return True when ``value`` is a non-empty string of ASCII digits 0-9."""
    return isinstance(value, str) and value.isascii() and value.isdigit()


def luhn_validate(number_str: str, expected_check_digit: str) -> bool:
    """Check a Luhn (modulus 10) check digit.

    ``number_str`` is the payload WITHOUT its check digit.  Going from right
    to left the digits are weighted 2, 1, 2, 1, ...; products of 10 or more
    contribute the sum of their two digits.

    Returns True when the computed check digit equals
    ``expected_check_digit``.  Empty or non-numeric payloads return False.
    """
    if not _is_ascii_digits(number_str):
        return False
    total = 0
    for position, char in enumerate(reversed(number_str)):
        product = int(char) * (2 if position % 2 == 0 else 1)
        # For product < 10 this is just the product; otherwise its digit sum.
        total += product // 10 + product % 10
    return str((10 - total % 10) % 10) == expected_check_digit


def modulus_11_validate(number_str: str) -> bool:
    """Check the last digit of ``number_str`` as a modulus-11 check digit.

    The payload (everything but the last digit) is weighted 2, 3, 4, ... from
    right to left without cycling.  A computed check value of 10 or 11 is
    mapped to "0".
    NOTE(review): many modulus-11 schemes cycle the weights and treat a check
    value of 10 as invalid -- confirm which variant is intended here.
    """
    if not _is_ascii_digits(number_str) or len(number_str) < 2:
        return False
    payload, expected = number_str[:-1], number_str[-1]
    total = sum(int(char) * weight
                for weight, char in enumerate(reversed(payload), start=2))
    check_value = 11 - total % 11  # always in 1..11
    check_digit = "0" if check_value >= 10 else str(check_value)
    return check_digit == expected


def luhn_validate_ocr(number_str: str) -> bool:
    """Validate a full number whose last digit is its Luhn check digit."""
    if not _is_ascii_digits(number_str) or len(number_str) < 2:
        return False
    return luhn_validate(number_str[:-1], number_str[-1])


def check_ocr_validity(ocr_number: str) -> bool:
    """Return True when ``ocr_number`` passes the Luhn or the modulus-11 check."""
    if not ocr_number or not ocr_number.isdigit():
        return False
    return luhn_validate_ocr(ocr_number) or modulus_11_validate(ocr_number)


# --- Plan A: machine-readable code line (OCR-rad) ---
def parse_ocr_rad(ocr_text: str) -> "dict | None":
    """Parse the machine-readable line of a payment slip.

    Example input: ``# 400299582421 # 4603 00 7 > 48180020 #14#``.
    After whitespace is removed the line is matched as
    ``#<reference>#<amount in öre><check digit>><account>#``.

    Returns a result dict when the amount passes its Luhn check, otherwise
    None (also when no machine-readable line is found).
    """
    text_no_space = re.sub(r'\s+', '', ocr_text)
    # group(1) reference, group(2) amount, group(3) check digit, group(4) account
    match = re.search(r'#([\d>]+)#(\d{2,})(\d{1})>([\d]+)#', text_no_space)
    if not match:
        return None

    ocr_num = match.group(1).replace(">", "")
    amount_base = match.group(2)
    check_digit = match.group(3)
    account = match.group(4)

    if not luhn_validate(amount_base, check_digit):
        print(f"Luhn 校验失败: 基础={amount_base}, 期望={check_digit}")
        return None

    # The last two digits are öre.  Amounts below 1 kr have no kronor digits
    # at all, so fall back to "0" instead of emitting ".50".
    amount_kronor = amount_base[:-2] or "0"
    amount_ore = amount_base[-2:]

    print("Plan A (OCR-rad) 成功!")
    return {
        "source": "OCR-rad (High Confidence)",
        "ocr_number": ocr_num,  # the payment reference
        "ocr_number_valid": check_ocr_validity(ocr_num),
        "amount_due": f"{amount_kronor}.{amount_ore}",
        "amount_due_valid": True,
        "bankgiro_plusgiro": account,
        "account_type": None,
        "due_date": None,
        "invoice_number": None,  # the code line carries no invoice number
    }


# --- Plan B: key/value extraction from OCR word boxes ---
def find_value_for_key(key_box, all_boxes, value_regex, max_dist_x_h=400, max_dist_y_h=20, max_dist_y_v=70, max_dist_x_v=150):
    """Find the value located to the right of, or below, a key word box.

    Args:
        key_box: row (pandas Series) of the key word, with ``left``, ``top``,
            ``width``, ``height`` and a ``name`` identifying the row.
        all_boxes: DataFrame of candidate word boxes (same columns + ``text``).
        value_regex: pattern the candidate text must match at its start.
        max_dist_x_h / max_dist_y_h: horizontal search limits (gap to the
            right, vertical centre offset).
        max_dist_y_v / max_dist_x_v: vertical search limits (gap below,
            horizontal centre offset).

    Returns:
        The matched portion of the nearest candidate's text, or None.
        Candidates to the right always win over candidates below.
    """
    key_right = key_box['left'] + key_box['width']
    key_bottom = key_box['top'] + key_box['height']
    key_x_center = key_box['left'] + key_box['width'] / 2
    key_y_center = key_box['top'] + key_box['height'] / 2

    pattern = re.compile(value_regex)
    candidates = []  # (priority, distance, text); priority 0 = right, 1 = below

    for _, value_box in all_boxes.iterrows():
        if value_box.name == key_box.name:
            continue

        matched = pattern.match(str(value_box['text']).strip())
        if not matched:
            continue
        # Keep only the part that matched, so trailing junk glued to the
        # token ("1.234,50kr", "2024-01-15.") does not leak into the value.
        value_text = matched.group(0)

        value_left = value_box['left']
        value_top = value_box['top']
        value_x_center = value_left + value_box['width'] / 2
        value_y_center = value_top + value_box['height'] / 2

        # Check A: same text line, to the right of the key.
        if (value_left > key_right
                and abs(value_y_center - key_y_center) < max_dist_y_h
                and value_left - key_right < max_dist_x_h):
            candidates.append((0, value_left - key_right, value_text))

        # Check B: roughly the same column, below the key.
        if (value_top > key_bottom
                and abs(value_x_center - key_x_center) < max_dist_x_v
                and value_top - key_bottom < max_dist_y_v):
            candidates.append((1, value_top - key_bottom, value_text))

    if not candidates:
        return None
    return min(candidates, key=lambda item: item[:2])[2]


_OCR_REQUIRED_COLUMNS = ("conf", "text", "left", "top", "width", "height")


def _clean_ocr_frame(ocr_df: pd.DataFrame, min_conf: float = 30) -> pd.DataFrame:
    """Return word boxes with confidence above ``min_conf`` and non-empty text.

    A missing, empty or column-less frame yields an empty frame instead of
    raising, so callers can treat "no OCR data" uniformly.
    """
    if (ocr_df is None or ocr_df.empty
            or not set(_OCR_REQUIRED_COLUMNS).issubset(ocr_df.columns)):
        return pd.DataFrame(columns=list(_OCR_REQUIRED_COLUMNS))

    confidence = pd.to_numeric(ocr_df['conf'], errors='coerce')
    # Drop missing text BEFORE converting to str; afterwards NaN would have
    # become the literal token "nan" and could no longer be dropped.
    cleaned = ocr_df[confidence > min_conf].dropna(subset=['text']).copy()
    cleaned['text'] = cleaned['text'].astype(str).str.strip()
    return cleaned[cleaned['text'] != ""]


def parse_human_readable_KIE(ocr_df: pd.DataFrame) -> dict:
    """Plan B (v4.6): extract payment fields from the human-readable area.

    Each word box is tested against a list of key patterns; for a matching
    key the nearest box to the right/below that matches the value pattern is
    taken.  Invoice number and OCR reference are looked up separately.

    Returns a dict that always contains every field (None/False when a field
    was not found).
    """
    data = {
        "source": "Human-Readable KIE (Fallback)",
        "bankgiro_plusgiro": None,
        "account_type": None,
        "invoice_number": None,
        "ocr_number": None,
        "ocr_number_valid": False,
        "amount_due": None,
        "amount_due_valid": False,
        "due_date": None
    }

    ocr_df = _clean_ocr_frame(ocr_df)
    if ocr_df.empty:
        return data

    # Rule format: (key regex, value regex, target field, extra value)
    # NOTE(review): keys are matched as case-insensitive substrings, so the
    # short alternatives "PG"/"BG" also hit inside longer words (for example
    # "uppgifter") -- consider anchoring them.
    search_rules = [
        (r'PlusGiro|PG', r'(\d{2,7}[- ]\d{1,2}[- ]\d{1,2})', "bankgiro_plusgiro", "PG"),
        (r'Bankgiro|BG|bankgironr', r'(\d{2,4}[- ]\d{4})', "bankgiro_plusgiro", "BG"),
        (r'Att betala:?|Belopp:?', r'([\d\s,.]+\d{2})', "amount_due", None),
        (r'senast|Förfallodag', r'\d{4}[- ]\d{2}[- ]\d{2}', "due_date", None),
        # Invoice number and payment reference are kept apart on purpose.
        (r'Fakturanr|Fakturanummer', r'([\w\d][\w\d-]{3,}[\w\d])', "invoice_number", None),
        (r'OCR|Referens|Betalningsreferens', r'([\w\d][\w\d-]{3,}[\w\d])', "ocr_number", None),
    ]

    for _, box in ocr_df.iterrows():
        text = box['text']

        for key_regex, value_regex, field_name, extra_value in search_rules:
            # First hit wins: once a field is filled it is not searched again
            # (e.g. after BG was found, PG is no longer looked for).
            if data.get(field_name) is not None:
                continue
            if not re.search(key_regex, text, re.IGNORECASE):
                continue

            value = find_value_for_key(box, ocr_df, value_regex, max_dist_y_v=70)
            if not value:
                continue

            if field_name == "bankgiro_plusgiro":
                data[field_name] = re.sub(r'\s', '', value)
                data["account_type"] = extra_value  # "BG" or "PG"
            elif field_name == "amount_due":
                amount_str = value.replace(" ", "").replace(",", ".")
                # Keep only the last dot as the decimal separator.
                if amount_str.count('.') > 1:
                    amount_str = amount_str.replace(".", "", amount_str.count('.') - 1)
                data[field_name] = amount_str
            elif field_name == "due_date":
                data[field_name] = value.replace(" ", "-")
            elif field_name in ["invoice_number", "ocr_number"]:
                if value.lower() in ("kronor", "öre"):
                    continue  # column headers of the amount box, not a number
                data[field_name] = re.sub(r'\s', '', value)
                if field_name == "ocr_number":
                    data["ocr_number_valid"] = check_ocr_validity(data[field_name])

    return data


# --- Plan C: full-page fallback (invoice number only) ---
def find_invoice_number_full_page(ocr_df: pd.DataFrame) -> dict:
    """Search the OCR boxes of the whole page for an invoice number.

    Returns ``{"invoice_number": <str>}`` for the first "Fakturanr" /
    "Fakturanummer" key that has a matching value next to it, otherwise an
    empty dict (also when ``ocr_df`` is empty or lacks the OCR columns).
    """
    ocr_df = _clean_ocr_frame(ocr_df)
    if ocr_df.empty:
        return {}

    inv_regex = r'([\w\d][\w\d-]{3,}[\w\d])'
    for _, box in ocr_df.iterrows():
        if not re.search(r'Fakturanr|Fakturanummer', box['text'], re.IGNORECASE):
            continue
        value = find_value_for_key(box, ocr_df, inv_regex, max_dist_x_h=200, max_dist_y_h=10, max_dist_y_v=70, max_dist_x_v=150)
        if value:
            return {"invoice_number": re.sub(r'\s', '', value)}
    return {}


def _merge_rad_result(final_data: dict, rad_data: dict) -> dict:
    """Overwrite the payment fields of ``final_data`` with the Plan A result.

    ``invoice_number`` and ``due_date`` from Plan B are kept.  The Plan B
    ``account_type`` is kept only when Plan B had found the same account
    number (digits compared); otherwise it would label a different number and
    is reset to None.
    """
    kie_account = final_data.get("bankgiro_plusgiro") or ""
    for field in ("source", "ocr_number", "ocr_number_valid",
                  "amount_due", "amount_due_valid", "bankgiro_plusgiro"):
        final_data[field] = rad_data[field]

    rad_account = rad_data["bankgiro_plusgiro"] or ""
    if re.sub(r'\D', '', kie_account) != re.sub(r'\D', '', rad_account):
        final_data["account_type"] = None
    return final_data


def extract_info_from_crop(crop_image: np.ndarray, full_image_df: pd.DataFrame) -> dict:
    """Run Plan A, B and C on one detected payment-slip crop.

    Args:
        crop_image: BGR image of the detected region.
        full_image_df: OCR word boxes of the whole page, used only by Plan C;
            may be empty when full-page OCR failed.

    Returns:
        The merged field dict, or ``{"error": ...}`` when OCR or
        preprocessing raises.
    """
    try:
        gray = cv2.cvtColor(crop_image, cv2.COLOR_BGR2GRAY)
        thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)[1]

        # 1. Plan B (KIE) always runs first and provides the baseline fields.
        ocr_df_crop = pytesseract.image_to_data(thresh, lang='swe', config='--psm 11', output_type=pytesseract.Output.DATAFRAME)
        final_data = parse_human_readable_KIE(ocr_df_crop)

        # 2. Plan A (OCR-rad); psm 7 treats the image as a single text line.
        ocr_text_rad = pytesseract.image_to_string(thresh, lang='swe', config='--psm 7')
        rad_data = parse_ocr_rad(ocr_text_rad)
        if rad_data:
            # High-confidence data: override the payment fields from Plan B.
            _merge_rad_result(final_data, rad_data)

        # 3. Plan C: look for the invoice number on the whole page.
        if final_data["invoice_number"] is None:
            print("Plan B 未能从裁剪区找到 InvoiceNr, 启动 Plan C (全页搜索)...")
            full_page_inv_result = find_invoice_number_full_page(full_image_df)

            if full_page_inv_result:
                final_data["invoice_number"] = full_page_inv_result.get("invoice_number")
                final_data["source"] += " + Full-Page-InvNr"

        return final_data

    except Exception as e:
        print(f"提取时出错: {e}")
        return {"error": f"提取时出错: {e}"}
运行 YOLO 检测 (阶段一) + try: + full_image_df = pytesseract.image_to_data(img, lang='swe', config='--psm 3', output_type=pytesseract.Output.DATAFRAME) + except Exception as e: + print(f"全页 OCR 失败: {e}") + full_image_df = pd.DataFrame() + results = ml_models["yolo"](img, verbose=False) all_extractions = [] @@ -212,16 +384,11 @@ async def extract_invoice_data(file: UploadFile = File(...)): return {"message": "未在图像中检测到支付凭证区域。"} for res in results: - # 遍历所有检测到的凭证 (通常只有一个) for box_coords in res.boxes.xyxy.cpu().numpy().astype(int): xmin, ymin, xmax, ymax = box_coords - - # 2. 裁剪图像 crop = img[ymin:ymax, xmin:xmax] - - # 3. 在裁剪图上运行 "Plan A/B" 提取 (阶段二) - extracted_data = extract_info_from_crop(crop) - extracted_data["bounding_box"] = [xmin, ymin, xmax, ymax] + extracted_data = extract_info_from_crop(crop, full_image_df) + extracted_data["bounding_box"] = [int(xmin), int(ymin), int(xmax), int(ymax)] all_extractions.append(extracted_data) if not all_extractions: @@ -233,5 +400,8 @@ if __name__ == "__main__": import uvicorn print(f"--- 启动 FastAPI 服务 ---") print(f"加载模型: {MODEL_PATH}") + if not POPPLER_PATH or not os.path.exists(POPPLER_PATH): + print(f"警告: POPPLER_PATH 未设置或无效。PDF 上传将失败。") print(f"访问 http://127.0.0.1:8000/docs 查看 API 文档") - uvicorn.run(app, host="127.0.0.1", port=8000) \ No newline at end of file + uvicorn.run(app, host="127.0.0.1", port=8000) +