This commit is contained in:
Yaojia Wang
2025-10-26 22:14:21 +01:00
parent dafa86c588
commit b3f361847a

View File

@@ -1,189 +1,340 @@
# --- main.py (已升级) ---
# --- main.py (已升级, v4.6 - 分离 InvoiceNr 和 OCRNr) ---
from fastapi import FastAPI, UploadFile, File, HTTPException
from ultralytics import YOLO
import cv2
import numpy as np
import pytesseract
import pandas as pd
import re
import io
import os
import sys
from contextlib import asynccontextmanager
from pathlib import Path
from pdf2image import convert_from_bytes
# --- 配置 ---
# TODO: 确保此路径指向您训练好的最佳模型
MODEL_PATH = os.path.join(
os.path.dirname(os.path.abspath(__file__)),
"models", "invoice_detector_v1", "weights", "best.pt"
)
# --- 配置 (使用您的自定义配置) ---
# 1. 设置 BASE_DIR
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# 2. 将项目根目录添加到 sys.path
sys.path.insert(0, BASE_DIR)
# 3. 从 config.py 导入路径
from config import POPPLER_PATH, apply_tesseract_path
# ---------------------------------
MODEL_PATH_STR = os.path.join(BASE_DIR, "models", "payment_slip_detector_v1", "weights", "best.pt")
MODEL_PATH = Path(MODEL_PATH_STR)
# 定义一个字典来在 FastAPI 启动时加载模型
ml_models = {}
@asynccontextmanager
async def lifespan(app: FastAPI):
# 启动时加载模型
print(MODEL_PATH)
print(f"正在加载模型: {MODEL_PATH}")
if not os.path.exists(MODEL_PATH):
print(f"警告: 找不到模型 {MODEL_PATH}。API 将无法工作。")
ml_models["yolo"] = None
else:
# 加载您的 "payment_slip" 区域检测器
ml_models["yolo"] = YOLO(MODEL_PATH)
print("YOLOv8 区域检测模型加载成功。")
# 应用 Tesseract 路径 (来自 config)
try:
apply_tesseract_path()
print("已应用 Tesseract 路径配置。")
except Exception as e:
print(f"警告: 调用 apply_tesseract_path 时出错: {e}")
# 检查 Poppler 路径 (来自 config)
if not POPPLER_PATH or not os.path.exists(POPPLER_PATH):
print(f"警告: POPPLER_PATH 未设置或无效: {POPPLER_PATH}。PDF 处理将失败。")
else:
print(f"Poppler 路径已加载: {POPPLER_PATH}")
yield
# 清理模型
ml_models.clear()
app = FastAPI(lifespan=lifespan)
# --- Luhn (Modulus 10) 校验函数 ---
# --- 校验函数 (Luhn, Mod11, etc.) ---
def luhn_validate(number_str: str, expected_check_digit: str) -> bool:
"""
使用 Modulus 10 (Luhn) 算法验证一个数字字符串。
(从右到左, 权重 1, 2, 1, 2...)
(从右到左, 权重 2, 1, 2, 1...)
"""
try:
digits = [int(d) for d in number_str]
weights = [1, 2] * (len(digits) // 2 + 1)
weights = weights[:len(digits)] # 确保权重列表长度一致
weights = [2, 1] * (len(digits) // 2 + 1)
weights = weights[:len(digits)]
sum_val = 0
# 从右到左计算
for d, w in zip(reversed(digits), weights):
product = d * w
if product >= 10:
sum_val += (product // 10) + (product % 10)
else:
sum_val += product
calculated_check_digit = (10 - (sum_val % 10)) % 10
return str(calculated_check_digit) == expected_check_digit
except Exception:
return False
# --- 提取逻辑 ---
def modulus_11_validate(number_str: str) -> bool:
try:
if not number_str.isdigit() or len(number_str) < 2: return False
base_num = number_str[:-1]
expected_check_digit_str = number_str[-1]
weights = [i for i in range(2, 2 + len(base_num))]
sum_val = 0
for digit_char, weight in zip(reversed(base_num), weights):
sum_val += int(digit_char) * weight
remainder = sum_val % 11
calculated_check_val = 11 - remainder
if calculated_check_val == 10: calculated_check_digit = "0"
elif calculated_check_val == 11: calculated_check_digit = "0"
else: calculated_check_digit = str(calculated_check_val)
return calculated_check_digit == expected_check_digit_str
except Exception:
return False
def luhn_validate_ocr(number_str: str) -> bool:
try:
if not number_str.isdigit() or len(number_str) < 2: return False
base_num = number_str[:-1]
check_digit = number_str[-1]
return luhn_validate(base_num, check_digit)
except Exception:
return False
def check_ocr_validity(ocr_number: str) -> bool:
if not ocr_number or not ocr_number.isdigit():
return False
if luhn_validate_ocr(ocr_number):
return True
if modulus_11_validate(ocr_number):
return True
return False
# --- Plan A: 机读码解析 (已修复) ---
def parse_ocr_rad(ocr_text: str) -> dict:
"""
Plan A: 尝试解析机器可读码 (OCR-rad)
示例: # 400299582421 # 4603 00 7 > 48180020 #14#
"""
# 移除所有空格以简化匹配
text_no_space = re.sub(r'\s+', '', ocr_text)
# 定义一个更健壮的正则表达式
# 组1: OCR号 (在 #...# 之间)
# 组2: 金额 (Kronor + Öre) (在 #... 之后)
# 组3: 校验码 (1位数字)
# 组4: 账户号 (在 >...# 之间)
rad_regex = r'#([\d>]+)#(\d{2,})(\d)(\d{1})>([\d]+)#'
# 修正: 捕获组索引 (group(1) 是 OCR, group(2) 是金额, group(3) 是校验码, group(4) 是账号)
rad_regex = r'#([\d>]+)#(\d{2,})(\d{1})>([\d]+)#'
match = re.search(rad_regex, text_no_space)
if not match:
return None # 未找到机读码行
return None
try:
ocr_num = match.group(1).replace(">", "") # 移除 > (如果有)
amount_base = match.group(2) + match.group(3) # "4603" + "00" = "460300"
check_digit = match.group(4) # "7"
account = match.group(5) # "48180020"
ocr_num = match.group(1).replace(">", "")
amount_base = match.group(2)
check_digit = match.group(3)
account = match.group(4)
# 运行Luhn校验
if luhn_validate(amount_base, check_digit):
# 校验成功! 这是高置信度数据
amount_valid = luhn_validate(amount_base, check_digit)
if amount_valid:
amount_kronor = amount_base[:-2]
amount_ore = amount_base[-2:]
ocr_valid = check_ocr_validity(ocr_num)
print("Plan A (OCR-rad) 成功!")
return {
"source": "OCR-rad (High Confidence)",
"ocr_number": ocr_num,
"ocr_number": ocr_num, # 这是支付参考号
"ocr_number_valid": ocr_valid,
"amount_due": f"{amount_kronor}.{amount_ore}",
"amount_due_valid": True,
"bankgiro_plusgiro": account,
"due_date": None # 机读码行通常不包含日期
"account_type": None,
"due_date": None,
"invoice_number": None # Plan A 不知道发票号
}
else:
# 校验失败!
print(f"Luhn 校验失败: 基础={amount_base}, 期望={check_digit}")
return None
except Exception as e:
print(f"解析 OCR-rad 时出错: {e}")
return None
def parse_human_readable(ocr_text: str) -> dict:
# --- 升级: Plan B (智能 KIE 逻辑) ---
def find_value_for_key(key_box, all_boxes, value_regex, max_dist_x_h=400, max_dist_y_h=20, max_dist_y_v=70, max_dist_x_v=150):
"""
Plan B: 回退到人工可读区域
(这里我们使用之前版本中的简单 Regex,
您也可以替换为您那个更复杂的 classify_text 逻辑)
一个辅助函数, 用于查找一个 "" (key_box) 右侧或下方的 ""
"""
data = {"source": "Human-Readable (Fallback)"}
key_y_center = key_box['top'] + key_box['height'] / 2
key_x_center = key_box['left'] + key_box['width'] / 2
key_x_end = key_box['left'] + key_box['width']
key_y_end = key_box['top'] + key_box['height']
# 查找 BG/PG (Bankgiro/Plusgiro)
bg_match = re.search(r'Bankgiro\D*(\d{2,4}[- ]\d{4})', ocr_text, re.IGNORECASE)
pg_match = re.search(r'PlusGiro\D*(\d{2,7}[- ]\d)', ocr_text, re.IGNORECASE)
if bg_match:
data["bankgiro_plusgiro"] = bg_match.group(1).replace(" ", "")
elif pg_match:
data["bankgiro_plusgiro"] = pg_match.group(1).replace(" ", "")
else:
# 备用查找
bg_pg_alt = re.search(r'(\b\d{2,4}[- ]\d{4}\b)|(\b\d{2,7}[- ]\d\b)', ocr_text)
if bg_pg_alt:
data["bankgiro_plusgiro"] = bg_pg_alt.group(0).replace(" ", "")
potential_values = []
for _, value_box in all_boxes.iterrows():
if value_box.name == key_box.name:
continue
value_y_center = value_box['top'] + value_box['height'] / 2
value_x_center = value_box['left'] + value_box['width'] / 2
value_x_start = value_box['left']
value_y_start = value_box['top']
value_text = str(value_box['text']).strip()
# 查找 OCR
ocr_match = re.search(r'(OCR|Fakturanummer|Referens)\D*(\d[\d\s]{5,}\d)', ocr_text, re.IGNORECASE)
if ocr_match:
data["ocr_number"] = re.sub(r'\s', '', ocr_match.group(2))
if not re.match(value_regex, value_text):
continue
# 查找金额
amount_match = re.search(r'(Att betala|Belopp)\D*([\d\s,.]+)\s*(kr|SEK)?', ocr_text, re.IGNORECASE)
if amount_match:
amount_str = amount_match.group(2).strip().replace(" ", "").replace(",", ".")
if amount_str.count('.') > 1:
amount_str = amount_str.replace(".", "", amount_str.count('.') - 1)
data["amount_due"] = amount_str
# --- 检查 A: 水平 (Horizontal) ---
if value_x_start > key_x_end:
if abs(value_y_center - key_y_center) < max_dist_y_h:
dist_x = value_x_start - key_x_end
if dist_x < max_dist_x_h:
potential_values.append((dist_x, value_text))
# --- 检查 B: 垂直 (Vertical) ---
if value_y_start > key_y_end:
if abs(value_x_center - key_x_center) < max_dist_x_v:
dist_y = value_y_start - key_y_end
if dist_y < max_dist_y_v:
potential_values.append((dist_y + 1000, value_text)) # H 优先
if potential_values:
potential_values.sort(key=lambda x: x[0])
return potential_values[0][1]
# 查找截止日期
date_match = re.search(r'(senast|Förfallodag)\D*(\d{4}[- ]\d{2}[- ]\d{2})', ocr_text, re.IGNORECASE)
if date_match:
data["due_date"] = date_match.group(2).replace(" ", "-")
return None
def parse_human_readable_KIE(ocr_df: pd.DataFrame) -> dict:
"""
Plan B (v4.6): 使用 KIE 逻辑回退 (分离 InvoiceNr 和 OCRNr)
"""
data = {
"source": "Human-Readable KIE (Fallback)",
"bankgiro_plusgiro": None,
"account_type": None,
"invoice_number": None, # <-- 新增
"ocr_number": None,
"ocr_number_valid": False,
"amount_due": None,
"amount_due_valid": False,
"due_date": None
}
ocr_df = ocr_df[ocr_df['conf'] > 30].copy()
ocr_df['text'] = ocr_df['text'].astype(str).str.strip()
ocr_df = ocr_df.dropna(subset=['text'])
ocr_df = ocr_df[ocr_df['text'] != ""]
if ocr_df.empty:
return data
# 定义所有 Key-Value 查找规则
# Regex 格式: ( [关键词列表], 值Regex, 目标字段名 )
search_rules = [
(r'PlusGiro|PG', r'(\d{2,7}[- ]\d{1,2}[- ]\d{1,2})', "bankgiro_plusgiro", "PG"),
(r'Bankgiro|BG|bankgironr', r'(\d{2,4}[- ]\d{4})', "bankgiro_plusgiro", "BG"),
(r'Att betala:?|Belopp:?', r'([\d\s,.]+\d{2})', "amount_due", None),
(r'senast|Förfallodag', r'\d{4}[- ]\d{2}[- ]\d{2}', "due_date", None),
# --- 逻辑分离 ---
(r'Fakturanr|Fakturanummer', r'([\w\d][\w\d-]{3,}[\w\d])', "invoice_number", None),
(r'OCR|Referens|Betalningsreferens', r'([\w\d][\w\d-]{3,}[\w\d])', "ocr_number", None),
]
for _, box in ocr_df.iterrows():
text = box['text']
for key_regex, value_regex, field_name, extra_value in search_rules:
# 如果这个字段已经被填过了, 就跳过 (例如 BG 找到了, 就不要再找 PG)
if data.get(field_name) is not None:
continue
if re.search(key_regex, text, re.IGNORECASE):
value = find_value_for_key(box, ocr_df, value_regex, max_dist_y_v=70)
if value:
# 特殊清理
if field_name == "bankgiro_plusgiro":
data[field_name] = re.sub(r'\s', '', value)
data["account_type"] = extra_value # 设置 "BG" 或 "PG"
elif field_name == "amount_due":
amount_str = value.replace(" ", "").replace(",", ".")
if amount_str.count('.') > 1:
amount_str = amount_str.replace(".", "", amount_str.count('.') - 1)
data[field_name] = amount_str
elif field_name == "due_date":
data[field_name] = value.replace(" ", "-")
elif field_name in ["invoice_number", "ocr_number"]:
if value.lower() == "kronor" or value.lower() == "öre":
continue # 跳过错误的匹配
data[field_name] = re.sub(r'\s', '', value)
if field_name == "ocr_number":
data["ocr_number_valid"] = check_ocr_validity(data[field_name])
return data
def extract_info_from_crop(crop_image: np.ndarray) -> dict:
# --- Plan C: 全页回退 (用于发票号码) ---
def find_invoice_number_full_page(ocr_df: pd.DataFrame) -> dict:
"""
主提取函数: 执行 Plan A 和 Plan B
Plan C: 如果 invoice_number 仍为 null, 则在整个页面上搜索。
"""
ocr_df = ocr_df[ocr_df['conf'] > 30].copy()
ocr_df['text'] = ocr_df['text'].astype(str).str.strip()
ocr_df = ocr_df.dropna(subset=['text'])
ocr_df = ocr_df[ocr_df['text'] != ""]
if ocr_df.empty:
return {}
inv_regex = r'([\w\d][\w\d-]{3,}[\w\d])'
for _, box in ocr_df.iterrows():
# 只查找 Fakturanr
if re.search(r'Fakturanr|Fakturanummer', box['text'], re.IGNORECASE):
value = find_value_for_key(box, ocr_df, inv_regex, max_dist_x_h=200, max_dist_y_h=10, max_dist_y_v=70, max_dist_x_v=150)
if value:
inv_num_str = re.sub(r'\s', '', value)
return {"invoice_number": inv_num_str}
return {}
def extract_info_from_crop(crop_image: np.ndarray, full_image_df: pd.DataFrame) -> dict:
"""
主提取函数: 执行 Plan A, B, C
"""
try:
# 1. 预处理并运行 OCR (获取所有文本)
gray = cv2.cvtColor(crop_image, cv2.COLOR_BGR2GRAY)
thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)[1]
# 使用 psm 6 假设是一个统一的文本块, 这对 OCR-rad 很友好
ocr_text = pytesseract.image_to_string(thresh, lang='swe', config='--psm 6')
# 1. --- Plan B (KIE) ---
# 总是先运行 Plan B, 获取基础信息
ocr_df_crop = pytesseract.image_to_data(thresh, lang='swe', config='--psm 11', output_type=pytesseract.Output.DATAFRAME)
final_data = parse_human_readable_KIE(ocr_df_crop)
# 2. --- Plan A: 尝试解析机读码 ---
rad_data = parse_ocr_rad(ocr_text)
# 2. --- 尝试 Plan A (OCR-rad) ---
ocr_text_rad = pytesseract.image_to_string(thresh, lang='swe', config='--psm 7') # psm 7: 假设为单行文本
rad_data = parse_ocr_rad(ocr_text_rad)
if rad_data:
# Plan A 成功!
return rad_data
# Plan A 成功! (高置信度)
# **覆盖** Plan B 的支付字段, 但 *保留* Plan B 的 invoice_number
final_data["source"] = rad_data["source"]
final_data["ocr_number"] = rad_data["ocr_number"]
final_data["ocr_number_valid"] = rad_data["ocr_number_valid"]
final_data["amount_due"] = rad_data["amount_due"]
final_data["amount_due_valid"] = rad_data["amount_due_valid"]
final_data["bankgiro_plusgiro"] = rad_data["bankgiro_plusgiro"]
# (如果 Plan A 能确定 BG/PG, 也可以覆盖 account_type)
# 3. --- Plan C (全页回退发票号码) ---
if final_data["invoice_number"] is None:
print("Plan B 未能从裁剪区找到 InvoiceNr, 启动 Plan C (全页搜索)...")
full_page_inv_result = find_invoice_number_full_page(full_image_df)
if full_page_inv_result:
final_data["invoice_number"] = full_page_inv_result.get("invoice_number")
final_data["source"] += " + Full-Page-InvNr"
# 3. --- Plan B: Plan A 失败, 回退到人工读取 ---
# 我们重新运行 OCR, 使用 psm 3 (自动布局), 这对人工区域更友好
ocr_text_human = pytesseract.image_to_string(thresh, lang='swe', config='--psm 3')
human_data = parse_human_readable(ocr_text_human)
# 即使回退失败, 也返回空字典 (或部分数据)
return human_data
return final_data
except Exception as e:
print(f"提取时出错: {e}")
return {"error": f"提取时出错: {e}"}
@@ -195,15 +346,36 @@ async def extract_invoice_data(file: UploadFile = File(...)):
try:
contents = await file.read()
nparr = np.frombuffer(contents, np.uint8)
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
content_type = file.content_type
img = None
if content_type == "application/pdf":
if not POPPLER_PATH or not os.path.exists(POPPLER_PATH):
raise HTTPException(status_code=500, detail="PDF 处理未在服务器上配置 (POPPLER_PATH)。")
try:
images = convert_from_bytes(contents, poppler_path=POPPLER_PATH)
if images:
img = cv2.cvtColor(np.array(images[0]), cv2.COLOR_RGB2BGR)
except Exception as e:
raise HTTPException(status_code=500, detail=f"PDF 处理失败: {e}")
elif content_type in ["image/jpeg", "image/png", "image/bmp"]:
nparr = np.frombuffer(contents, np.uint8)
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
else:
raise HTTPException(status_code=415, detail=f"不支持的文件类型: {content_type}")
if img is None:
raise HTTPException(status_code=400, detail="无法解码图像文件。")
raise HTTPException(status_code=400, detail="无法解码文件。")
except Exception as e:
raise HTTPException(status_code=400, detail=f"读取文件时出错: {e}")
# 1. 运行 YOLO 检测 (阶段一)
try:
full_image_df = pytesseract.image_to_data(img, lang='swe', config='--psm 3', output_type=pytesseract.Output.DATAFRAME)
except Exception as e:
print(f"全页 OCR 失败: {e}")
full_image_df = pd.DataFrame()
results = ml_models["yolo"](img, verbose=False)
all_extractions = []
@@ -212,16 +384,11 @@ async def extract_invoice_data(file: UploadFile = File(...)):
return {"message": "未在图像中检测到支付凭证区域。"}
for res in results:
# 遍历所有检测到的凭证 (通常只有一个)
for box_coords in res.boxes.xyxy.cpu().numpy().astype(int):
xmin, ymin, xmax, ymax = box_coords
# 2. 裁剪图像
crop = img[ymin:ymax, xmin:xmax]
# 3. 在裁剪图上运行 "Plan A/B" 提取 (阶段二)
extracted_data = extract_info_from_crop(crop)
extracted_data["bounding_box"] = [xmin, ymin, xmax, ymax]
extracted_data = extract_info_from_crop(crop, full_image_df)
extracted_data["bounding_box"] = [int(xmin), int(ymin), int(xmax), int(ymax)]
all_extractions.append(extracted_data)
if not all_extractions:
@@ -233,5 +400,8 @@ if __name__ == "__main__":
import uvicorn
print(f"--- 启动 FastAPI 服务 ---")
print(f"加载模型: {MODEL_PATH}")
if not POPPLER_PATH or not os.path.exists(POPPLER_PATH):
print(f"警告: POPPLER_PATH 未设置或无效。PDF 上传将失败。")
print(f"访问 http://127.0.0.1:8000/docs 查看 API 文档")
uvicorn.run(app, host="127.0.0.1", port=8000)
uvicorn.run(app, host="127.0.0.1", port=8000)