Files
invoice-master/scripts/main.py
Yaojia Wang dafa86c588 Init
2025-10-26 20:41:11 +01:00

237 lines
8.2 KiB
Python

# --- main.py (已升级) ---
import io
import os
import re
from contextlib import asynccontextmanager
from typing import Optional

import cv2
import numpy as np
import pytesseract
from fastapi import FastAPI, UploadFile, File, HTTPException
from ultralytics import YOLO
# --- Configuration ---
# TODO: make sure this points at your best trained detector checkpoint.
# The path is anchored at the directory that contains this script, so it does
# not depend on the current working directory.
MODEL_PATH = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    "models",
    "invoice_detector_v1",
    "weights",
    "best.pt",
)

# Registry of models loaded at FastAPI startup; `lifespan` fills and clears it.
ml_models: dict = {}
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the YOLO region detector at startup and drop it at shutdown.

    If ``MODEL_PATH`` does not exist, a warning is printed and
    ``ml_models["yolo"]`` is set to ``None``. The extraction endpoint checks
    for that and answers 503.

    Args:
        app: The FastAPI application (required by the lifespan protocol; unused).
    """
    print(MODEL_PATH)
    if not os.path.exists(MODEL_PATH):
        print(f"警告: 找不到模型 {MODEL_PATH}。API 将无法工作。")
        ml_models["yolo"] = None
    else:
        # Detector for the "payment_slip" region.
        ml_models["yolo"] = YOLO(MODEL_PATH)
        print("YOLOv8 区域检测模型加载成功。")
    try:
        yield
    finally:
        # Release the models even if the application exits with an exception;
        # without `finally`, code after `yield` is skipped in that case.
        ml_models.clear()


app = FastAPI(lifespan=lifespan)
# --- Luhn (Modulus 10) check-digit validation ---
def luhn_validate(number_str: str, expected_check_digit: str) -> bool:
    """Check that ``expected_check_digit`` is the Luhn (Modulus 10) check digit of ``number_str``.

    ``number_str`` is the payload WITHOUT its check digit. An example is the
    amount digits "460300" of an OCR line whose check digit is "7". Because the
    check digit is not part of ``number_str``, the weights are applied from
    right to left as 2, 1, 2, 1, ..., so the rightmost payload digit is doubled.
    A two-digit product is reduced to the sum of its digits.

    Args:
        number_str: ASCII digits of the payload (check digit excluded).
        expected_check_digit: The check digit to compare against; a str such as
            "7" (an int is also accepted).

    Returns:
        True if the computed check digit equals ``expected_check_digit``.
        False otherwise, including when ``number_str`` is empty or contains
        anything other than ASCII digits.
    """
    if not number_str or any(ch not in "0123456789" for ch in number_str):
        return False
    total = 0
    for position, ch in enumerate(reversed(number_str)):
        # Position 0 is the rightmost payload digit and gets weight 2.
        product = int(ch) * (2 if position % 2 == 0 else 1)
        total += sum(divmod(product, 10))  # digit sum of a value in 0..18
    calculated_check_digit = (10 - total % 10) % 10
    return str(calculated_check_digit) == str(expected_check_digit)
# --- Extraction logic ---
# Machine-readable OCR line ("OCR-rad") after all whitespace is removed:
#   #<ocr number>#<amount digits><check digit>><account>#
# Group 1: OCR number (digits, possibly with stray '>' characters).
# Group 2: amount digits, i.e. kronor followed by the two öre digits (>= 3 digits).
# Group 3: the single check digit for group 2.
# Group 4: Bankgiro/Plusgiro account number.
_OCR_RAD_RE = re.compile(r'#([\d>]+)#(\d{3,})(\d)>(\d+)#')


def parse_ocr_rad(ocr_text: str) -> Optional[dict]:
    """Plan A: parse the machine-readable OCR line ("OCR-rad") of a payment slip.

    Example line: ``# 400299582421 # 4603 00 7 > 48180020 #14#``. Here the OCR
    number is "400299582421", the amount digits are "460300" (4603 kr 00 öre),
    the check digit is "7" and the account is "48180020".

    Args:
        ocr_text: Raw Tesseract output for the crop.

    Returns:
        A dict with ``source``, ``ocr_number``, ``amount_due`` ("kronor.öre"),
        ``bankgiro_plusgiro`` and ``due_date`` (always None, because the line
        itself carries no date). Returns None if no OCR line is found or if the
        amount fails ``luhn_validate``.
    """
    # OCR output splits the line into space-separated groups; drop all whitespace.
    text_no_space = re.sub(r'\s+', '', ocr_text)
    match = _OCR_RAD_RE.search(text_no_space)
    if not match:
        return None  # no machine-readable line found

    ocr_num = match.group(1).replace(">", "")  # drop any stray '>'
    amount_base = match.group(2)  # e.g. "460300"
    check_digit = match.group(3)  # e.g. "7"
    account = match.group(4)  # e.g. "48180020"

    if not luhn_validate(amount_base, check_digit):
        print(f"Luhn 校验失败: 基础={amount_base}, 期望={check_digit}")
        return None

    # The last two amount digits are öre; group 2 has >= 3 digits, so kronor is non-empty.
    amount_kronor, amount_ore = amount_base[:-2], amount_base[-2:]
    return {
        "source": "OCR-rad (High Confidence)",
        "ocr_number": ocr_num,
        "amount_due": f"{amount_kronor}.{amount_ore}",
        "bankgiro_plusgiro": account,
        "due_date": None,
    }
def parse_human_readable(ocr_text: str) -> dict:
    """Plan B: extract payment fields from the human-readable part of the slip.

    Uses simple keyword-anchored regular expressions; a more elaborate
    classifier could replace them. Only fields that were found are added.

    Args:
        ocr_text: Raw Tesseract output for the crop.

    Returns:
        A dict that always has ``"source"``. It may also contain:
        ``"bankgiro_plusgiro"``; ``"ocr_number"`` (digits only);
        ``"amount_due"`` (with "." as the decimal separator); and
        ``"due_date"`` (YYYY-MM-DD with "-" separators).
    """
    data = {"source": "Human-Readable (Fallback)"}

    # Bankgiro / PlusGiro, preferring an explicitly labelled number.
    bg_match = re.search(r'Bankgiro\D*(\d{2,4}[- ]\d{4})', ocr_text, re.IGNORECASE)
    pg_match = re.search(r'PlusGiro\D*(\d{2,7}[- ]\d)', ocr_text, re.IGNORECASE)
    if bg_match:
        data["bankgiro_plusgiro"] = bg_match.group(1).replace(" ", "")
    elif pg_match:
        data["bankgiro_plusgiro"] = pg_match.group(1).replace(" ", "")
    else:
        # Fallback: any unlabelled number shaped like a Bankgiro or PlusGiro number.
        bg_pg_alt = re.search(r'(\b\d{2,4}[- ]\d{4}\b)|(\b\d{2,7}[- ]\d\b)', ocr_text)
        if bg_pg_alt:
            data["bankgiro_plusgiro"] = bg_pg_alt.group(0).replace(" ", "")

    # OCR reference. Digit groups may be separated by spaces or tabs, but never
    # by a line break, so digits from the following line are not glued on.
    ocr_match = re.search(
        r'(OCR|Fakturanummer|Referens)\D*(\d[\d \t]{5,}\d)', ocr_text, re.IGNORECASE
    )
    if ocr_match:
        data["ocr_number"] = re.sub(r'\s', '', ocr_match.group(2))

    # Amount. It must start with a digit (otherwise a bare separator would
    # match) and stay on one line (otherwise a following line's digits and the
    # newline would leak into the value).
    amount_match = re.search(
        r'(Att betala|Belopp)\D*(\d[\d ,.]*)\s*(kr|SEK)?', ocr_text, re.IGNORECASE
    )
    if amount_match:
        amount_str = re.sub(r'\s', '', amount_match.group(2)).rstrip(',.').replace(',', '.')
        # Several dots mean thousands separators; keep only the last as the decimal point.
        if amount_str.count('.') > 1:
            amount_str = amount_str.replace('.', '', amount_str.count('.') - 1)
        data["amount_due"] = amount_str

    # Due date, e.g. "senast 2024-05-31" or "Förfallodag 2024 05 31".
    date_match = re.search(
        r'(senast|Förfallodag)\D*(\d{4}[- ]\d{2}[- ]\d{2})', ocr_text, re.IGNORECASE
    )
    if date_match:
        data["due_date"] = date_match.group(2).replace(" ", "-")

    return data
def extract_info_from_crop(crop_image: np.ndarray) -> dict:
    """Extract payment fields from one cropped payment-slip region.

    The crop is converted to grayscale and Otsu-binarised once; both OCR passes
    read that same binary image.

    * Plan A: OCR with ``--psm 6`` (treat the crop as one uniform text block,
      which suits the single machine-readable line) and try ``parse_ocr_rad``.
    * Plan B: if Plan A returns nothing, OCR again with ``--psm 3`` (automatic
      page segmentation, better for the human-readable area) and return the
      result of ``parse_human_readable``, which may be partial.

    Returns:
        The dict from whichever plan produced it, or ``{"error": ...}`` if any
        step raised.
    """
    try:
        grayscale = cv2.cvtColor(crop_image, cv2.COLOR_BGR2GRAY)
        _, binary = cv2.threshold(grayscale, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)

        block_text = pytesseract.image_to_string(binary, lang='swe', config='--psm 6')
        machine_fields = parse_ocr_rad(block_text)
        if machine_fields:
            return machine_fields

        layout_text = pytesseract.image_to_string(binary, lang='swe', config='--psm 3')
        return parse_human_readable(layout_text)
    except Exception as exc:
        return {"error": f"提取时出错: {exc}"}
@app.post("/extract_invoice/")
async def extract_invoice_data(file: UploadFile = File(...)):
    """Detect payment-slip regions in an uploaded image and extract fields from each.

    Args:
        file: The uploaded image file.

    Returns:
        ``{"invoice_extractions": [...]}`` with one dict per detected box. Each
        dict carries a ``"bounding_box"`` of plain Python ints
        ``[xmin, ymin, xmax, ymax]``. When nothing is detected, the result is
        ``{"message": ...}`` instead.

    Raises:
        HTTPException: 503 if the YOLO model is not loaded; 400 if the upload
            cannot be read or decoded as an image.
    """
    model = ml_models.get("yolo")
    if model is None:
        raise HTTPException(status_code=503, detail="模型未加载。请检查模型路径。")

    try:
        contents = await file.read()
        nparr = np.frombuffer(contents, np.uint8)
        img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"读取文件时出错: {e}") from e
    # Checked outside the `try` so this 400 is not caught above and re-wrapped
    # into a "读取文件时出错" message.
    if img is None:
        raise HTTPException(status_code=400, detail="无法解码图像文件。")

    # Stage 1: YOLO region detection.
    # NOTE(review): YOLO inference and Tesseract are blocking calls inside an
    # async endpoint and will stall the event loop; consider running them in a
    # threadpool.
    results = model(img, verbose=False)
    if not results or not results[0].boxes:
        return {"message": "未在图像中检测到支付凭证区域。"}

    all_extractions = []
    for res in results:
        # One entry per detected slip (usually just one).
        for box_coords in res.boxes.xyxy.cpu().numpy().astype(int):
            # Convert to Python ints, because FastAPI's JSON encoder cannot
            # serialise numpy.int64. Clamp at 0 so a negative coordinate cannot
            # wrap around in the slice.
            xmin, ymin, xmax, ymax = (max(0, int(v)) for v in box_coords)
            crop = img[ymin:ymax, xmin:xmax]
            # Stage 2: Plan A / Plan B field extraction on the crop.
            extracted_data = extract_info_from_crop(crop)
            extracted_data["bounding_box"] = [xmin, ymin, xmax, ymax]
            all_extractions.append(extracted_data)

    if not all_extractions:
        return {"message": "检测到支付凭证, 但未能提取任何信息。"}
    return {"invoice_extractions": all_extractions}
if __name__ == "__main__":
    # Development entry point: serve the API locally with uvicorn.
    import uvicorn

    host, port = "127.0.0.1", 8000
    print("--- 启动 FastAPI 服务 ---")
    print(f"加载模型: {MODEL_PATH}")
    print(f"访问 http://{host}:{port}/docs 查看 API 文档")
    uvicorn.run(app, host=host, port=port)