# --- main.py (upgraded) ---
"""Invoice extraction service.

Stage 1: a YOLO model locates payment-slip regions in an uploaded image.
Stage 2: each region is OCR'd with Tesseract. The machine-readable OCR line
("OCR-rad", Plan A) is tried first; if it is missing or fails its Luhn check,
the human-readable text is parsed with regexes instead (Plan B).
"""
import io
import os
import re
import threading
from contextlib import asynccontextmanager
from typing import Optional

import cv2
import numpy as np
import pytesseract
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from ultralytics import YOLO

# --- Configuration ---
# TODO: make sure this path points to your best trained model.
MODEL_PATH = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    "models", "invoice_detector_v1", "weights", "best.pt",
)

# Models loaded at FastAPI startup, keyed by name ("yolo" -> YOLO model or None).
ml_models = {}

# Inference runs in FastAPI's thread pool, but a single Ultralytics model
# instance must not be called from several threads at once; serialize access.
_yolo_lock = threading.Lock()


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the YOLO region detector at startup and release it on shutdown.

    If MODEL_PATH does not exist, ml_models["yolo"] is set to None and the
    extraction endpoint responds with HTTP 503.
    """
    print(MODEL_PATH)
    if not os.path.exists(MODEL_PATH):
        print(f"警告: 找不到模型 {MODEL_PATH}。API 将无法工作。")
        ml_models["yolo"] = None
    else:
        # Load the "payment_slip" region detector.
        ml_models["yolo"] = YOLO(MODEL_PATH)
        print("YOLOv8 区域检测模型加载成功。")
    yield
    # Drop model references on shutdown.
    ml_models.clear()


app = FastAPI(lifespan=lifespan)


# --- Luhn (Modulus 10) check ---
def luhn_validate(number_str: str, expected_check_digit: str) -> bool:
    """Return True if *expected_check_digit* is the Luhn check digit of *number_str*.

    The check digit is appended to the payload and the whole string is
    validated with the standard Luhn (mod 10) rule. Counting from the right,
    the check digit has weight 1, the payload's rightmost digit weight 2, and
    so on, alternating 1, 2, 1, 2 ... Doubled values above 9 have 9
    subtracted, which is the same as summing their digits.

    Example: the amount "460300" (4603 kronor 00 öre) has check digit "7".

    Args:
        number_str: Payload digits, without the check digit.
        expected_check_digit: A single digit read from the document.

    Returns:
        False for any non-string, empty, multi-character or non-ASCII-digit
        input, otherwise whether the Luhn sum is divisible by 10.
    """
    if not isinstance(number_str, str) or not isinstance(expected_check_digit, str):
        return False
    if not number_str or len(expected_check_digit) != 1:
        return False
    payload = number_str + expected_check_digit
    # isascii() guards against Unicode digits such as '²' that int() rejects.
    if not (payload.isascii() and payload.isdigit()):
        return False

    total = 0
    for position, char in enumerate(reversed(payload)):
        value = int(char)
        if position % 2 == 1:  # every second digit from the right is doubled
            value *= 2
            if value > 9:
                value -= 9
        total += value
    return total % 10 == 0


# --- Extraction logic ---

# Machine-readable OCR line with all whitespace removed:
#   #<OCR reference>#<kronor><öre (2 digits)><check digit>><account>#
# group 1: OCR reference (any '>' characters in it are stripped later)
# group 2: kronor, group 3: öre, group 4: Luhn check digit over kronor+öre
# group 5: Bankgiro/PlusGiro account
_OCR_RAD_RE = re.compile(r'#([\d>]+)#(\d+)(\d{2})(\d)>(\d+)#')


def parse_ocr_rad(ocr_text: str) -> Optional[dict]:
    """Plan A: parse the machine-readable OCR line ("OCR-rad").

    Example input line:
        # 400299582421 # 4603 00 7 > 48180020 #14#
    yields ocr_number "400299582421", amount_due "4603.00" and
    bankgiro_plusgiro "48180020".

    Returns:
        A result dict when a line is found and its amount passes the Luhn
        check, otherwise None.
    """
    text_no_space = re.sub(r'\s+', '', ocr_text)

    match = _OCR_RAD_RE.search(text_no_space)
    if not match:
        return None  # no machine-readable line found

    ocr_num = match.group(1).replace(">", "")
    amount_kronor = match.group(2)
    amount_ore = match.group(3)
    check_digit = match.group(4)
    account = match.group(5)
    amount_base = amount_kronor + amount_ore  # e.g. "4603" + "00" = "460300"

    if not luhn_validate(amount_base, check_digit):
        print(f"Luhn 校验失败: 基础={amount_base}, 期望={check_digit}")
        return None

    # Checksum passed: treat this as high-confidence data.
    return {
        "source": "OCR-rad (High Confidence)",
        "ocr_number": ocr_num,
        "amount_due": f"{amount_kronor}.{amount_ore}",
        "bankgiro_plusgiro": account,
        "due_date": None,  # the machine-readable line carries no date
    }


def parse_human_readable(ocr_text: str) -> dict:
    """Plan B: fall back to regex parsing of the human-readable slip text.

    These are simple keyword regexes; a more elaborate classifier can be
    substituted here.

    Returns:
        A dict that always contains "source". The keys "bankgiro_plusgiro",
        "ocr_number", "amount_due" and "due_date" are present only when found.
    """
    data = {"source": "Human-Readable (Fallback)"}

    # Bankgiro / PlusGiro number, preferring a labelled match.
    bg_match = re.search(r'Bankgiro\D*(\d{2,4}[- ]\d{4})', ocr_text, re.IGNORECASE)
    pg_match = re.search(r'PlusGiro\D*(\d{2,7}[- ]\d)', ocr_text, re.IGNORECASE)
    account_match = bg_match or pg_match
    if account_match:
        data["bankgiro_plusgiro"] = account_match.group(1).replace(" ", "")
    else:
        # Unlabelled fallback: anything shaped like a BG or PG number.
        bg_pg_alt = re.search(r'(\b\d{2,4}[- ]\d{4}\b)|(\b\d{2,7}[- ]\d\b)', ocr_text)
        if bg_pg_alt:
            data["bankgiro_plusgiro"] = bg_pg_alt.group(0).replace(" ", "")

    # OCR / invoice reference. Only horizontal whitespace may separate digit
    # groups, so numbers on adjacent lines are not merged.
    ocr_match = re.search(
        r'(OCR|Fakturanummer|Referens)\D*(\d[\d \t]{5,}\d)', ocr_text, re.IGNORECASE
    )
    if ocr_match:
        data["ocr_number"] = re.sub(r'\s', '', ocr_match.group(2))

    # Amount. It starts with a digit and may contain spaces or tabs as
    # thousands separators and ',' or '.' as separators, but never a newline.
    amount_match = re.search(
        r'(Att betala|Belopp)\D*(\d[\d \t,.]*)\s*(kr|SEK)?', ocr_text, re.IGNORECASE
    )
    if amount_match:
        amount_str = re.sub(r'\s', '', amount_match.group(2)).replace(",", ".")
        amount_str = amount_str.rstrip(".")  # drop trailing punctuation, e.g. "500,"
        if amount_str.count('.') > 1:
            # Keep only the last separator as the decimal point.
            amount_str = amount_str.replace(".", "", amount_str.count('.') - 1)
        data["amount_due"] = amount_str

    # Due date (YYYY-MM-DD, '-' or ' ' separated).
    date_match = re.search(
        r'(senast|Förfallodag)\D*(\d{4}[- ]\d{2}[- ]\d{2})', ocr_text, re.IGNORECASE
    )
    if date_match:
        data["due_date"] = date_match.group(2).replace(" ", "-")

    return data


def extract_info_from_crop(crop_image: np.ndarray) -> dict:
    """OCR one detected region and run Plan A, falling back to Plan B.

    Args:
        crop_image: BGR (3-channel) or grayscale image of one detected region.

    Returns:
        The Plan A dict if it succeeds, otherwise the (possibly partial)
        Plan B dict. Returns {"error": ...} if the crop is empty or OCR fails.
    """
    if crop_image is None or crop_image.size == 0:
        return {"error": "提取时出错: 裁剪区域为空。"}
    try:
        # 1. Preprocess: grayscale + Otsu binarisation.
        gray = crop_image if crop_image.ndim == 2 else cv2.cvtColor(crop_image, cv2.COLOR_BGR2GRAY)
        thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)[1]

        # 2. Plan A: psm 6 treats the crop as one uniform text block, which
        #    suits the machine-readable line.
        ocr_text = pytesseract.image_to_string(thresh, lang='swe', config='--psm 6')
        rad_data = parse_ocr_rad(ocr_text)
        if rad_data:
            return rad_data

        # 3. Plan B: rerun OCR with psm 3 (automatic page layout), which
        #    suits the human-readable fields.
        ocr_text_human = pytesseract.image_to_string(thresh, lang='swe', config='--psm 3')
        return parse_human_readable(ocr_text_human)
    except Exception as e:
        # Per-crop boundary: report the failure instead of failing the request.
        return {"error": f"提取时出错: {e}"}


def _detect_and_extract(model: YOLO, img: np.ndarray) -> dict:
    """Run detection (stage 1) and per-region extraction (stage 2) synchronously.

    Blocking; intended to be executed in a worker thread.
    """
    with _yolo_lock:
        results = model(img, verbose=False)

    if not results or not results[0].boxes:
        return {"message": "未在图像中检测到支付凭证区域。"}

    img_h, img_w = img.shape[:2]
    all_extractions = []
    for res in results:
        # .tolist() yields plain Python ints: numpy int64 scalars are not
        # JSON-serializable by FastAPI's encoder.
        for xmin, ymin, xmax, ymax in res.boxes.xyxy.cpu().numpy().astype(int).tolist():
            # Clamp to the image; negative indices would wrap around in numpy.
            xmin, ymin = max(xmin, 0), max(ymin, 0)
            xmax, ymax = min(xmax, img_w), min(ymax, img_h)
            crop = img[ymin:ymax, xmin:xmax]
            extracted_data = extract_info_from_crop(crop)
            extracted_data["bounding_box"] = [xmin, ymin, xmax, ymax]
            all_extractions.append(extracted_data)

    if not all_extractions:
        return {"message": "检测到支付凭证, 但未能提取任何信息。"}
    return {"invoice_extractions": all_extractions}


@app.post("/extract_invoice/")
async def extract_invoice_data(file: UploadFile = File(...)):
    """Detect payment-slip regions in the uploaded image and extract their fields.

    Raises:
        HTTPException(503): the model is not loaded.
        HTTPException(400): the upload cannot be read or decoded as an image.
    """
    model = ml_models.get("yolo")
    if model is None:
        raise HTTPException(status_code=503, detail="模型未加载。请检查模型路径。")

    try:
        contents = await file.read()
        img = cv2.imdecode(np.frombuffer(contents, np.uint8), cv2.IMREAD_COLOR)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"读取文件时出错: {e}") from e
    # Checked outside the try so this 400 is not caught and re-wrapped above.
    if img is None:
        raise HTTPException(status_code=400, detail="无法解码图像文件。")

    # YOLO and Tesseract are blocking CPU work; run them off the event loop
    # so other requests keep being served.
    return await run_in_threadpool(_detect_and_extract, model, img)


if __name__ == "__main__":
    import uvicorn

    print("--- 启动 FastAPI 服务 ---")
    print(f"加载模型: {MODEL_PATH}")
    print("访问 http://127.0.0.1:8000/docs 查看 API 文档")
    uvicorn.run(app, host="127.0.0.1", port=8000)