# --- main.py (upgraded, v4.6 - InvoiceNr and OCRNr are separated) ---
from fastapi import FastAPI, UploadFile, File, HTTPException
from ultralytics import YOLO
import cv2
import numpy as np
import pytesseract
import pandas as pd
import re
import io
import os
import sys
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Optional
from pdf2image import convert_from_bytes

# --- Configuration (uses the project's custom config) ---
# 1. Set BASE_DIR (two directory levels above this file)
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# 2. Put the project root on sys.path so that `config` can be imported
sys.path.insert(0, BASE_DIR)
# 3. Import paths from config.py (must come after the sys.path tweak above)
from config import POPPLER_PATH, apply_tesseract_path
# ---------------------------------

MODEL_PATH_STR = os.path.join(BASE_DIR, "models", "payment_slip_detector_v1", "weights", "best.pt")
MODEL_PATH = Path(MODEL_PATH_STR)

# Registry of loaded models; populated in `lifespan`, read by the endpoint.
ml_models = {}


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the YOLO model and apply OCR/PDF tool configuration at startup.

    On shutdown the model registry is cleared. A missing model file does not
    abort startup; ``ml_models["yolo"]`` is set to ``None`` and the endpoint
    answers 503 instead.
    """
    # Load the model at startup
    print(f"正在加载模型: {MODEL_PATH}")
    if not os.path.exists(MODEL_PATH):
        print(f"警告: 找不到模型 {MODEL_PATH}。API 将无法工作。")
        ml_models["yolo"] = None
    else:
        ml_models["yolo"] = YOLO(MODEL_PATH)
        print("YOLOv8 区域检测模型加载成功。")

    # Apply the Tesseract path (from config); best-effort, only warn on failure
    try:
        apply_tesseract_path()
        print("已应用 Tesseract 路径配置。")
    except Exception as e:
        print(f"警告: 调用 apply_tesseract_path 时出错: {e}")

    # Check the Poppler path (from config)
    if not POPPLER_PATH or not os.path.exists(POPPLER_PATH):
        print(f"警告: POPPLER_PATH 未设置或无效: {POPPLER_PATH}。PDF 处理将失败。")
    else:
        print(f"Poppler 路径已加载: {POPPLER_PATH}")

    yield

    ml_models.clear()


app = FastAPI(lifespan=lifespan)


# --- Validation functions (Luhn, Mod11, etc.) ---
def luhn_validate(number_str: str, expected_check_digit: str) -> bool:
    """Validate a digit string with the Modulus 10 (Luhn) algorithm.

    Weights 2, 1, 2, 1, ... are applied from right to left to ``number_str``
    (which must NOT include the check digit); products >= 10 contribute the
    sum of their two digits.

    Args:
        number_str: The base digits, without the check digit.
        expected_check_digit: The single check digit to compare against.

    Returns:
        True when the computed check digit equals ``expected_check_digit``;
        False otherwise, including when the input is not a digit string.
    """
    try:
        digits = [int(d) for d in number_str]
    except (TypeError, ValueError):
        return False

    total = 0
    for position, digit in enumerate(reversed(digits)):
        # Rightmost base digit gets weight 2, then 1, 2, 1, ...
        product = digit * (2 if position % 2 == 0 else 1)
        if product >= 10:
            total += (product // 10) + (product % 10)
        else:
            total += product

    calculated_check_digit = (10 - (total % 10)) % 10
    return str(calculated_check_digit) == expected_check_digit


def modulus_11_validate(number_str: str) -> bool:
    """Validate a digit string whose last digit is a Modulus 11 check digit.

    The base digits (all but the last) are weighted 2, 3, 4, ... from right
    to left. The check value is ``11 - (sum % 11)``; values 10 and 11 are both
    mapped to the digit "0".

    Returns:
        True when the computed check digit equals the last digit; False
        otherwise, including for non-digit input or input shorter than 2.
    """
    # NOTE(review): weights grow without cycling and a check value of 10 is
    # accepted as "0" — confirm both against the Mod-11 variant intended here.
    try:
        if not number_str.isdigit() or len(number_str) < 2:
            return False

        base_num = number_str[:-1]
        expected_check_digit_str = number_str[-1]

        sum_val = sum(
            int(digit_char) * weight
            for weight, digit_char in enumerate(reversed(base_num), start=2)
        )
    except (AttributeError, TypeError, ValueError):
        return False

    calculated_check_val = 11 - (sum_val % 11)
    if calculated_check_val in (10, 11):
        calculated_check_digit = "0"
    else:
        calculated_check_digit = str(calculated_check_val)

    return calculated_check_digit == expected_check_digit_str


def luhn_validate_ocr(number_str: str) -> bool:
    """Luhn-validate a full number whose last digit is the check digit."""
    try:
        if not number_str.isdigit() or len(number_str) < 2:
            return False
    except AttributeError:
        # Not a string at all.
        return False
    return luhn_validate(number_str[:-1], number_str[-1])


def check_ocr_validity(ocr_number: str) -> bool:
    """Return True if ``ocr_number`` passes either the Luhn or the Mod-11 check."""
    if not ocr_number or not ocr_number.isdigit():
        return False
    return luhn_validate_ocr(ocr_number) or modulus_11_validate(ocr_number)


# --- Plan A: machine-readable code line parsing ---
def parse_ocr_rad(ocr_text: str) -> Optional[dict]:
    """Parse the machine-readable payment line ("OCR-rad") out of OCR text.

    All whitespace is removed, then the text is searched for
    ``#<ocr>#<amount><check>><account>#``. The amount is accepted only if its
    trailing check digit passes ``luhn_validate``.

    Args:
        ocr_text: Raw text returned by the OCR engine.

    Returns:
        A dict with the payment fields when the line is found and the amount
        check digit is valid; ``None`` otherwise.
    """
    if not ocr_text:
        return None

    text_no_space = re.sub(r'\s+', '', ocr_text)
    # Capture groups: group(1) = OCR reference, group(2) = amount (in öre),
    # group(3) = amount check digit, group(4) = account number.
    rad_regex = r'#([\d>]+)#(\d{2,})(\d{1})>([\d]+)#'
    match = re.search(rad_regex, text_no_space)

    if not match:
        return None

    try:
        ocr_num = match.group(1).replace(">", "")
        amount_base = match.group(2)
        check_digit = match.group(3)
        account = match.group(4)

        amount_valid = luhn_validate(amount_base, check_digit)

        if amount_valid:
            # The last two digits are öre. With exactly two digits the kronor
            # part is empty, so fall back to "0" to yield "0.50", not ".50".
            amount_kronor = amount_base[:-2] or "0"
            amount_ore = amount_base[-2:]
            ocr_valid = check_ocr_validity(ocr_num)

            print("Plan A (OCR-rad) 成功!")
            return {
                "source": "OCR-rad (High Confidence)",
                "ocr_number": ocr_num,  # this is the payment reference
                "ocr_number_valid": ocr_valid,
                "amount_due": f"{amount_kronor}.{amount_ore}",
                "amount_due_valid": True,
                "bankgiro_plusgiro": account,
                "account_type": None,
                "due_date": None,
                "invoice_number": None  # Plan A does not know the invoice number
            }
        else:
            print(f"Luhn 校验失败: 基础={amount_base}, 期望={check_digit}")
            return None
    except Exception as e:
        print(f"解析 OCR-rad 时出错: {e}")
        return None


# --- Plan B (key/value extraction logic) ---
def find_value_for_key(key_box, all_boxes, value_regex,
                       max_dist_x_h=400, max_dist_y_h=20,
                       max_dist_y_v=70, max_dist_x_v=150):
    """Find the "value" box to the right of, or below, a "key" box.

    Args:
        key_box: Row (with ``left``/``top``/``width``/``height``) of the key.
        all_boxes: DataFrame of candidate boxes; rows also need ``text``.
        value_regex: Pattern the candidate text must match from its start.
        max_dist_x_h: Max horizontal gap for a candidate to the right.
        max_dist_y_h: Max vertical centre offset for a candidate to the right.
        max_dist_y_v: Max vertical gap for a candidate below.
        max_dist_x_v: Max horizontal centre offset for a candidate below.

    Returns:
        The stripped text of the best candidate, or ``None`` if there is
        none. Horizontal candidates always win over vertical ones.
    """
    key_y_center = key_box['top'] + key_box['height'] / 2
    key_x_center = key_box['left'] + key_box['width'] / 2
    key_x_end = key_box['left'] + key_box['width']
    key_y_end = key_box['top'] + key_box['height']

    pattern = re.compile(value_regex)
    potential_values = []

    for _, value_box in all_boxes.iterrows():
        # Never pair the key with itself.
        if value_box.name == key_box.name:
            continue

        value_text = str(value_box['text']).strip()
        if not pattern.match(value_text):
            continue

        value_y_center = value_box['top'] + value_box['height'] / 2
        value_x_center = value_box['left'] + value_box['width'] / 2
        value_x_start = value_box['left']
        value_y_start = value_box['top']

        # --- Check A: horizontal (value to the right of the key) ---
        if value_x_start > key_x_end:
            if abs(value_y_center - key_y_center) < max_dist_y_h:
                dist_x = value_x_start - key_x_end
                if dist_x < max_dist_x_h:
                    potential_values.append((dist_x, value_text))

        # --- Check B: vertical (value below the key) ---
        if value_y_start > key_y_end:
            if abs(value_x_center - key_x_center) < max_dist_x_v:
                dist_y = value_y_start - key_y_end
                if dist_y < max_dist_y_v:
                    # +1000 penalty so that horizontal matches take priority
                    potential_values.append((dist_y + 1000, value_text))

    if potential_values:
        # min() keeps the first of equal scores, like a stable sort would.
        return min(potential_values, key=lambda candidate: candidate[0])[1]

    return None


def _prepare_ocr_words(ocr_df: pd.DataFrame, min_conf: float = 30) -> pd.DataFrame:
    """Return the confident, non-empty word rows of a Tesseract data frame.

    Returns an empty DataFrame when the input is ``None``, empty, or lacks the
    ``conf``/``text`` columns, so callers can simply test ``.empty``.
    """
    if ocr_df is None or ocr_df.empty:
        return pd.DataFrame()
    if not {"conf", "text"}.issubset(ocr_df.columns):
        return pd.DataFrame()

    words = ocr_df[ocr_df['conf'] > min_conf].copy()
    # Drop missing text BEFORE casting to str; otherwise NaN turns into the
    # literal string "nan" and survives the filter.
    words = words.dropna(subset=['text'])
    words['text'] = words['text'].astype(str).str.strip()
    return words[words['text'] != ""]


def parse_human_readable_KIE(ocr_df: pd.DataFrame) -> dict:
    """Plan B (v4.6): key/value fallback extraction from human-readable text.

    Invoice number and OCR (payment reference) number are looked up
    separately. Each field is filled at most once, by the first rule that
    finds a value for it.

    Args:
        ocr_df: Word-level OCR data frame (needs ``conf``, ``text`` and the
            box geometry columns used by ``find_value_for_key``).

    Returns:
        A dict with all result keys present; fields not found stay ``None``
        (or ``False`` for the validity flags).
    """
    data = {
        "source": "Human-Readable KIE (Fallback)",
        "bankgiro_plusgiro": None,
        "account_type": None,
        "invoice_number": None,
        "ocr_number": None,
        "ocr_number_valid": False,
        "amount_due": None,
        "amount_due_valid": False,
        "due_date": None
    }

    ocr_df = _prepare_ocr_words(ocr_df)
    if ocr_df.empty:
        return data

    # All key/value lookup rules.
    # Format: (key regex, value regex, target field name, extra value)
    search_rules = [
        (r'PlusGiro|PG', r'(\d{2,7}[- ]\d{1,2}[- ]\d{1,2})', "bankgiro_plusgiro", "PG"),
        (r'Bankgiro|BG|bankgironr', r'(\d{2,4}[- ]\d{4})', "bankgiro_plusgiro", "BG"),
        (r'Att betala:?|Belopp:?', r'([\d\s,.]+\d{2})', "amount_due", None),
        (r'senast|Förfallodag', r'\d{4}[- ]\d{2}[- ]\d{2}', "due_date", None),
        # --- invoice number and OCR number are handled separately ---
        (r'Fakturanr|Fakturanummer', r'([\w\d][\w\d-]{3,}[\w\d])', "invoice_number", None),
        (r'OCR|Referens|Betalningsreferens', r'([\w\d][\w\d-]{3,}[\w\d])', "ocr_number", None),
    ]

    for _, box in ocr_df.iterrows():
        text = box['text']

        for key_regex, value_regex, field_name, extra_value in search_rules:
            # Skip fields that are already filled (e.g. once BG is found,
            # do not look for PG any more).
            if data.get(field_name) is not None:
                continue

            if not re.search(key_regex, text, re.IGNORECASE):
                continue

            value = find_value_for_key(box, ocr_df, value_regex, max_dist_y_v=70)
            if not value:
                continue

            # Field-specific clean-up
            if field_name == "bankgiro_plusgiro":
                data[field_name] = re.sub(r'\s', '', value)
                data["account_type"] = extra_value  # "BG" or "PG"
            elif field_name == "amount_due":
                amount_str = value.replace(" ", "").replace(",", ".")
                # Keep only the last dot as the decimal separator.
                if amount_str.count('.') > 1:
                    amount_str = amount_str.replace(".", "", amount_str.count('.') - 1)
                data[field_name] = amount_str
            elif field_name == "due_date":
                data[field_name] = value.replace(" ", "-")
            elif field_name in ["invoice_number", "ocr_number"]:
                if value.lower() == "kronor" or value.lower() == "öre":
                    continue  # skip a false match on the currency words
                data[field_name] = re.sub(r'\s', '', value)
                if field_name == "ocr_number":
                    data["ocr_number_valid"] = check_ocr_validity(data[field_name])

    return data


# --- Plan C: full-page fallback (for the invoice number) ---
def find_invoice_number_full_page(ocr_df: pd.DataFrame) -> dict:
    """Plan C: search the whole page for the invoice number.

    Args:
        ocr_df: Word-level OCR data frame of the full page. May be empty or
            lack columns (e.g. when full-page OCR failed).

    Returns:
        ``{"invoice_number": <str>}`` for the first "Fakturanr"/"Fakturanummer"
        key that has a matching value nearby, otherwise ``{}``.
    """
    ocr_df = _prepare_ocr_words(ocr_df)
    if ocr_df.empty:
        return {}

    inv_regex = r'([\w\d][\w\d-]{3,}[\w\d])'

    for _, box in ocr_df.iterrows():
        # Only look for the invoice-number key
        if re.search(r'Fakturanr|Fakturanummer', box['text'], re.IGNORECASE):
            value = find_value_for_key(box, ocr_df, inv_regex,
                                       max_dist_x_h=200, max_dist_y_h=10,
                                       max_dist_y_v=70, max_dist_x_v=150)
            if value:
                inv_num_str = re.sub(r'\s', '', value)
                return {"invoice_number": inv_num_str}
    return {}


def extract_info_from_crop(crop_image: np.ndarray, full_image_df: pd.DataFrame) -> dict:
    """Main extraction routine: runs Plan B, then Plan A, then Plan C.

    Args:
        crop_image: BGR crop of the detected payment-slip region.
        full_image_df: Word-level OCR data frame of the full page, used only
            for the Plan C invoice-number fallback.

    Returns:
        The merged result dict, or ``{"error": ...}`` if anything raised.
    """
    try:
        gray = cv2.cvtColor(crop_image, cv2.COLOR_BGR2GRAY)
        thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)[1]

        # 1. --- Plan B (KIE) ---
        # Always run Plan B first to get the baseline fields.
        ocr_df_crop = pytesseract.image_to_data(thresh, lang='swe', config='--psm 11',
                                                output_type=pytesseract.Output.DATAFRAME)
        final_data = parse_human_readable_KIE(ocr_df_crop)

        # 2. --- Try Plan A (OCR-rad) ---
        # psm 7: treat the image as a single text line
        ocr_text_rad = pytesseract.image_to_string(thresh, lang='swe', config='--psm 7')
        rad_data = parse_ocr_rad(ocr_text_rad)

        if rad_data:
            # Plan A succeeded (high confidence): overwrite Plan B's payment
            # fields but keep Plan B's invoice_number.
            final_data["source"] = rad_data["source"]
            final_data["ocr_number"] = rad_data["ocr_number"]
            final_data["ocr_number_valid"] = rad_data["ocr_number_valid"]
            final_data["amount_due"] = rad_data["amount_due"]
            final_data["amount_due_valid"] = rad_data["amount_due_valid"]
            final_data["bankgiro_plusgiro"] = rad_data["bankgiro_plusgiro"]
            # (If Plan A could determine BG/PG, account_type could be
            # overwritten here as well.)

        # 3. --- Plan C (full-page fallback for the invoice number) ---
        if final_data["invoice_number"] is None:
            print("Plan B 未能从裁剪区找到 InvoiceNr, 启动 Plan C (全页搜索)...")
            full_page_inv_result = find_invoice_number_full_page(full_image_df)
            if full_page_inv_result:
                final_data["invoice_number"] = full_page_inv_result.get("invoice_number")
                final_data["source"] += " + Full-Page-InvNr"

        return final_data

    except Exception as e:
        print(f"提取时出错: {e}")
        return {"error": f"提取时出错: {e}"}


@app.post("/extract_invoice/")
async def extract_invoice_data(file: UploadFile = File(...)):
    """Detect payment-slip regions in an uploaded PDF/image and extract fields.

    Only the first page of a PDF is processed.

    Raises:
        HTTPException: 503 if the model is not loaded, 415 for unsupported
            content types, 500 for PDF configuration/conversion failures and
            400 when the file cannot be read or decoded.
    """
    if ml_models.get("yolo") is None:
        raise HTTPException(status_code=503, detail="模型未加载。请检查模型路径。")

    try:
        contents = await file.read()
        content_type = file.content_type
        img = None

        if content_type == "application/pdf":
            if not POPPLER_PATH or not os.path.exists(POPPLER_PATH):
                raise HTTPException(status_code=500, detail="PDF 处理未在服务器上配置 (POPPLER_PATH)。")
            try:
                images = convert_from_bytes(contents, poppler_path=POPPLER_PATH)
                if images:
                    img = cv2.cvtColor(np.array(images[0]), cv2.COLOR_RGB2BGR)
            except Exception as e:
                raise HTTPException(status_code=500, detail=f"PDF 处理失败: {e}")

        elif content_type in ["image/jpeg", "image/png", "image/bmp"]:
            nparr = np.frombuffer(contents, np.uint8)
            img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
        else:
            raise HTTPException(status_code=415, detail=f"不支持的文件类型: {content_type}。")

        if img is None:
            raise HTTPException(status_code=400, detail="无法解码文件。")

    except HTTPException:
        # Let deliberate HTTP errors (415 / 500 / 400) through unchanged;
        # without this the generic handler below would rewrite them as 400.
        raise
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"读取文件时出错: {e}")

    # Full-page OCR is only needed for Plan C, so a failure here is not fatal.
    try:
        full_image_df = pytesseract.image_to_data(img, lang='swe', config='--psm 3',
                                                  output_type=pytesseract.Output.DATAFRAME)
    except Exception as e:
        print(f"全页 OCR 失败: {e}")
        full_image_df = pd.DataFrame()

    results = ml_models["yolo"](img, verbose=False)
    all_extractions = []

    if not results or not results[0].boxes:
        return {"message": "未在图像中检测到支付凭证区域。"}

    for res in results:
        for box_coords in res.boxes.xyxy.cpu().numpy().astype(int):
            xmin, ymin, xmax, ymax = box_coords
            crop = img[ymin:ymax, xmin:xmax]

            extracted_data = extract_info_from_crop(crop, full_image_df)
            extracted_data["bounding_box"] = [int(xmin), int(ymin), int(xmax), int(ymax)]
            all_extractions.append(extracted_data)

    if not all_extractions:
        return {"message": "检测到支付凭证, 但未能提取任何信息。"}

    return {"invoice_extractions": all_extractions}


if __name__ == "__main__":
    import uvicorn

    print(f"--- 启动 FastAPI 服务 ---")
    print(f"加载模型: {MODEL_PATH}")
    if not POPPLER_PATH or not os.path.exists(POPPLER_PATH):
        print(f"警告: POPPLER_PATH 未设置或无效。PDF 上传将失败。")
    print(f"访问 http://127.0.0.1:8000/docs 查看 API 文档")
    uvicorn.run(app, host="127.0.0.1", port=8000)