408 lines
16 KiB
Python
408 lines
16 KiB
Python
# --- main.py (upgraded, v4.6 - separates InvoiceNr from OCRNr) ---
|
|
|
|
from fastapi import FastAPI, UploadFile, File, HTTPException
|
|
from ultralytics import YOLO
|
|
import cv2
|
|
import numpy as np
|
|
import pytesseract
|
|
import pandas as pd
|
|
import re
|
|
import io
|
|
import os
|
|
import sys
|
|
from contextlib import asynccontextmanager
|
|
from pathlib import Path
|
|
from pdf2image import convert_from_bytes
|
|
|
|
# --- Configuration (project-specific settings) ---

# 1. BASE_DIR: the directory two levels above this file.
BASE_DIR = os.path.dirname(
    os.path.dirname(
        os.path.abspath(__file__)
    )
)

# 2. Put the project root at the front of sys.path
#    (presumably so that config.py below can be imported - verify its location).
sys.path.insert(0, BASE_DIR)

# 3. Pull the external-tool settings from config.py.
from config import POPPLER_PATH, apply_tesseract_path
# ---------------------------------

# Location of the detector weights, kept both as a string and as a Path.
MODEL_PATH_STR = os.path.join(
    BASE_DIR,
    "models",
    "payment_slip_detector_v1",
    "weights",
    "best.pt",
)
MODEL_PATH = Path(MODEL_PATH_STR)

# Registry of loaded models, filled on startup and cleared on shutdown.
ml_models = {}
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan handler: prepare resources, yield, then clean up.

    Startup:
      * stores a YOLO model under ``ml_models["yolo"]`` when the weights file
        exists, otherwise stores ``None`` and prints a warning;
      * calls ``apply_tesseract_path()``; any exception is only printed;
      * prints whether ``POPPLER_PATH`` is set and exists on disk.

    Shutdown: empties ``ml_models``.
    """
    # Load the detection model on startup.
    print(f"正在加载模型: {MODEL_PATH}")
    if os.path.exists(MODEL_PATH):
        ml_models["yolo"] = YOLO(MODEL_PATH)
        print("YOLOv8 区域检测模型加载成功。")
    else:
        print(f"警告: 找不到模型 {MODEL_PATH}。API 将无法工作。")
        ml_models["yolo"] = None

    # Apply the Tesseract path (comes from config); failures are non-fatal.
    try:
        apply_tesseract_path()
        print("已应用 Tesseract 路径配置。")
    except Exception as exc:
        print(f"警告: 调用 apply_tesseract_path 时出错: {exc}")

    # Check the Poppler path (comes from config).
    poppler_available = bool(POPPLER_PATH) and os.path.exists(POPPLER_PATH)
    if poppler_available:
        print(f"Poppler 路径已加载: {POPPLER_PATH}")
    else:
        print(f"警告: POPPLER_PATH 未设置或无效: {POPPLER_PATH}。PDF 处理将失败。")

    yield

    ml_models.clear()


app = FastAPI(lifespan=lifespan)
|
|
|
|
# --- Check-digit validation functions (Luhn, Mod11, etc.) ---
|
|
def luhn_validate(number_str: str, expected_check_digit: str) -> bool:
    """Validate a digit string against a Modulus 10 (Luhn) check digit.

    Weights 2, 1, 2, 1, ... are applied starting from the rightmost digit of
    ``number_str``; a product of 10 or more contributes the sum of its two
    digits. The check digit is ``(10 - total % 10) % 10``.

    Args:
        number_str: The payload digits, *without* the check digit.
        expected_check_digit: The single check digit to compare against.

    Returns:
        True only when ``number_str`` is a non-empty string of ASCII digits
        and the computed check digit equals ``expected_check_digit``.
        Empty, non-string or non-digit input yields False.
    """
    # An empty payload would otherwise "validate" against check digit "0".
    if not isinstance(number_str, str) or not number_str:
        return False
    if not (number_str.isascii() and number_str.isdigit()):
        return False

    total = 0
    for position, char in enumerate(reversed(number_str)):
        product = int(char) * (2 if position % 2 == 0 else 1)
        # For 10..18 the digit sum equals product - 9.
        total += product - 9 if product > 9 else product

    calculated_check_digit = (10 - (total % 10)) % 10
    return str(calculated_check_digit) == expected_check_digit
|
|
|
|
def modulus_11_validate(number_str: str) -> bool:
    """Validate a digit string whose last character is a Modulus 11 check digit.

    The digits before the check digit are weighted 2, 3, 4, ... from right to
    left (the weights keep increasing, they do not wrap around). The check
    value is ``11 - (weighted_sum % 11)``; check values 10 and 11 are both
    represented by the digit "0".

    Returns:
        True when the computed digit equals the last character. Strings
        shorter than two characters, non-digit strings, and inputs that raise
        any exception yield False.
    """
    try:
        if len(number_str) < 2 or not number_str.isdigit():
            return False

        payload, expected = number_str[:-1], number_str[-1]
        weighted_sum = sum(
            int(char) * weight
            for weight, char in enumerate(reversed(payload), start=2)
        )
        check_value = 11 - (weighted_sum % 11)
        computed = "0" if check_value in (10, 11) else str(check_value)
        return computed == expected
    except Exception:
        return False
|
|
|
|
def luhn_validate_ocr(number_str: str) -> bool:
    """Luhn-validate a number whose final character is its own check digit.

    Splits ``number_str`` into payload (all but the last character) and check
    digit (the last character) and delegates to ``luhn_validate``.

    Returns:
        False for strings shorter than two characters, non-digit strings, or
        when any exception occurs; otherwise the result of ``luhn_validate``.
    """
    try:
        if len(number_str) < 2 or not number_str.isdigit():
            return False
        payload, check_digit = number_str[:-1], number_str[-1]
        return luhn_validate(payload, check_digit)
    except Exception:
        return False
|
|
|
|
def check_ocr_validity(ocr_number: str) -> bool:
    """Report whether an OCR reference passes either supported checksum.

    The Luhn check (``luhn_validate_ocr``) is tried first, then Modulus 11
    (``modulus_11_validate``); the first success wins.

    Returns:
        False for empty/None or non-digit input, or when both checks fail.
    """
    if ocr_number and ocr_number.isdigit():
        if luhn_validate_ocr(ocr_number) or modulus_11_validate(ocr_number):
            return True
    return False
|
|
|
|
# --- Plan A: machine-readable code line (OCR-rad) parsing (fixed) ---
|
|
def parse_ocr_rad(ocr_text: str) -> "dict | None":
    """Plan A: parse the machine-readable code line ("OCR-rad") of a slip.

    All whitespace is removed from ``ocr_text`` and the result is searched for
    ``#<ocr>#<amount><check>><account>#``. The digit right before ``>`` is
    treated as a Luhn check digit over the amount digits preceding it; the
    last two amount digits are the öre part, the rest the kronor part.

    Args:
        ocr_text: Raw OCR text that may contain the code line.

    Returns:
        A result dict (same keys as the Plan B result) when the pattern is
        found and the amount check digit is valid; otherwise ``None``
        (empty input, no match, failed checksum, or an unexpected error).
    """
    if not ocr_text:
        return None

    text_no_space = re.sub(r'\s+', '', ocr_text)
    # group(1) = OCR reference, group(2) = amount digits,
    # group(3) = amount check digit, group(4) = account number.
    rad_regex = r'#([\d>]+)#(\d{2,})(\d{1})>([\d]+)#'
    match = re.search(rad_regex, text_no_space)

    if not match:
        return None

    try:
        ocr_num = match.group(1).replace(">", "")
        amount_base = match.group(2)
        check_digit = match.group(3)
        account = match.group(4)

        if not luhn_validate(amount_base, check_digit):
            print(f"Luhn 校验失败: 基础={amount_base}, 期望={check_digit}")
            return None

        # With only two amount digits the kronor part is empty; emit "0"
        # so the amount reads "0.50" rather than the malformed ".50".
        amount_kronor = amount_base[:-2] or "0"
        amount_ore = amount_base[-2:]
        ocr_valid = check_ocr_validity(ocr_num)

        print("Plan A (OCR-rad) 成功!")

        return {
            "source": "OCR-rad (High Confidence)",
            "ocr_number": ocr_num,  # this is the payment reference
            "ocr_number_valid": ocr_valid,
            "amount_due": f"{amount_kronor}.{amount_ore}",
            "amount_due_valid": True,
            "bankgiro_plusgiro": account,
            "account_type": None,
            "due_date": None,
            "invoice_number": None,  # Plan A cannot know the invoice number
        }
    except Exception as e:
        print(f"解析 OCR-rad 时出错: {e}")
        return None
|
|
|
|
# --- Upgrade: Plan B (smart key-information-extraction logic) ---
|
|
def find_value_for_key(key_box, all_boxes, value_regex, max_dist_x_h=400, max_dist_y_h=20, max_dist_y_v=70, max_dist_x_v=150):
    """Find the "value" box located to the right of or below a "key" box.

    Every row of ``all_boxes`` other than ``key_box`` (rows are compared by
    their index label) whose stripped text matches ``value_regex`` at its
    start is considered:

    * horizontal candidate: the box starts right of the key's right edge,
      its vertical centre is within ``max_dist_y_h`` of the key's, and the
      horizontal gap is below ``max_dist_x_h``;
    * vertical candidate: the box starts below the key's bottom edge, its
      horizontal centre is within ``max_dist_x_v`` of the key's, and the
      vertical gap is below ``max_dist_y_v``.

    Horizontal candidates always beat vertical ones; within a group the
    smallest gap wins, and on a tie the earlier row wins.

    Args:
        key_box: Row with 'left', 'top', 'width', 'height' and a ``.name``.
        all_boxes: DataFrame with those columns plus 'text'.
        value_regex: Pattern the value text must match (``re.match``).

    Returns:
        The stripped text of the best candidate, or ``None`` if none exists.
    """
    key_x_end = key_box['left'] + key_box['width']
    key_y_end = key_box['top'] + key_box['height']
    key_x_center = key_box['left'] + key_box['width'] / 2
    key_y_center = key_box['top'] + key_box['height'] / 2

    value_pattern = re.compile(value_regex)  # hoisted out of the row loop

    # Entries are (priority, gap, text); priority 0 = horizontal, 1 = vertical.
    # An explicit priority keeps "horizontal first" correct for any
    # max_dist_x_h, unlike adding a fixed offset to the vertical gap.
    candidates = []

    for _, value_box in all_boxes.iterrows():
        if value_box.name == key_box.name:
            continue

        value_text = str(value_box['text']).strip()
        if not value_pattern.match(value_text):
            continue

        value_x_start = value_box['left']
        value_y_start = value_box['top']
        value_x_center = value_x_start + value_box['width'] / 2
        value_y_center = value_y_start + value_box['height'] / 2

        # --- Check A: horizontal ---
        if value_x_start > key_x_end and abs(value_y_center - key_y_center) < max_dist_y_h:
            dist_x = value_x_start - key_x_end
            if dist_x < max_dist_x_h:
                candidates.append((0, dist_x, value_text))

        # --- Check B: vertical ---
        if value_y_start > key_y_end and abs(value_x_center - key_x_center) < max_dist_x_v:
            dist_y = value_y_start - key_y_end
            if dist_y < max_dist_y_v:
                candidates.append((1, dist_y, value_text))

    if not candidates:
        return None

    # min() returns the first minimal element, matching a stable sort.
    best = min(candidates, key=lambda item: (item[0], item[1]))
    return best[2]
|
|
|
|
def parse_human_readable_KIE(ocr_df: pd.DataFrame) -> dict:
    """Plan B (v4.6): key/value extraction from human-readable slip text.

    Keeps invoice number (InvoiceNr) and payment reference (OCRNr) as
    separate fields. Rows with confidence <= 30, missing text or blank text
    are ignored. For every remaining box whose text matches a rule's keyword
    regex, ``find_value_for_key`` looks for a matching value box; the first
    value found for a field is kept and later matches are skipped.

    Args:
        ocr_df: Word-level OCR table with at least 'conf' and 'text'
            columns (plus the geometry columns ``find_value_for_key`` reads).

    Returns:
        A dict with keys source, bankgiro_plusgiro, account_type,
        invoice_number, ocr_number, ocr_number_valid, amount_due,
        amount_due_valid and due_date. Fields that were not found keep
        their defaults (None / False).
    """
    data = {
        "source": "Human-Readable KIE (Fallback)",
        "bankgiro_plusgiro": None,
        "account_type": None,
        "invoice_number": None,
        "ocr_number": None,
        "ocr_number_valid": False,
        "amount_due": None,
        "amount_due_valid": False,
        "due_date": None,
    }

    # A failed OCR pass may hand over an empty, column-less frame.
    if ocr_df is None or ocr_df.empty or not {'conf', 'text'}.issubset(ocr_df.columns):
        return data

    # Drop missing text BEFORE converting to str: astype(str) would turn NaN
    # into the literal "nan", which a later dropna can no longer remove.
    ocr_df = ocr_df[(ocr_df['conf'] > 30) & ocr_df['text'].notna()].copy()
    ocr_df['text'] = ocr_df['text'].astype(str).str.strip()
    ocr_df = ocr_df[ocr_df['text'] != ""]

    if ocr_df.empty:
        return data

    # All key/value lookup rules.
    # Format: (keyword regex, value regex, target field name, account type tag)
    search_rules = [
        (r'PlusGiro|PG', r'(\d{2,7}[- ]\d{1,2}[- ]\d{1,2})', "bankgiro_plusgiro", "PG"),
        (r'Bankgiro|BG|bankgironr', r'(\d{2,4}[- ]\d{4})', "bankgiro_plusgiro", "BG"),
        (r'Att betala:?|Belopp:?', r'([\d\s,.]+\d{2})', "amount_due", None),
        (r'senast|Förfallodag', r'\d{4}[- ]\d{2}[- ]\d{2}', "due_date", None),
        # --- invoice number and payment reference are handled separately ---
        (r'Fakturanr|Fakturanummer', r'([\w\d][\w\d-]{3,}[\w\d])', "invoice_number", None),
        (r'OCR|Referens|Betalningsreferens', r'([\w\d][\w\d-]{3,}[\w\d])', "ocr_number", None),
    ]

    for _, box in ocr_df.iterrows():
        text = box['text']

        for key_regex, value_regex, field_name, extra_value in search_rules:
            # Skip fields that are already filled (e.g. once BG is found,
            # do not look for PG any more).
            if data.get(field_name) is not None:
                continue

            if not re.search(key_regex, text, re.IGNORECASE):
                continue

            value = find_value_for_key(box, ocr_df, value_regex, max_dist_y_v=70)
            if not value:
                continue

            # Field-specific clean-up.
            if field_name == "bankgiro_plusgiro":
                data[field_name] = re.sub(r'\s', '', value)
                data["account_type"] = extra_value  # "BG" or "PG"
            elif field_name == "amount_due":
                amount_str = value.replace(" ", "").replace(",", ".")
                # Keep only the last dot as the decimal separator.
                if amount_str.count('.') > 1:
                    amount_str = amount_str.replace(".", "", amount_str.count('.') - 1)
                data[field_name] = amount_str
            elif field_name == "due_date":
                data[field_name] = value.replace(" ", "-")
            elif field_name in ["invoice_number", "ocr_number"]:
                if value.lower() == "kronor" or value.lower() == "öre":
                    continue  # skip an obviously wrong match
                data[field_name] = re.sub(r'\s', '', value)
                if field_name == "ocr_number":
                    data["ocr_number_valid"] = check_ocr_validity(data[field_name])

    return data
|
|
|
|
# --- Plan C: full-page fallback (for the invoice number) ---
|
|
def find_invoice_number_full_page(ocr_df: pd.DataFrame) -> dict:
    """Plan C: search the whole page for the invoice number.

    Intended for the case where ``invoice_number`` is still None after the
    crop-level extraction. Only boxes whose text matches
    ``Fakturanr|Fakturanummer`` (case-insensitive) act as keys.

    Args:
        ocr_df: Word-level OCR table of the full page. May be empty or lack
            columns when the full-page OCR pass failed.

    Returns:
        ``{"invoice_number": <value without whitespace>}`` for the first key
        that has a matching value, otherwise an empty dict.
    """
    # A failed full-page OCR is represented by an empty, column-less frame;
    # treat it as "nothing found" instead of raising KeyError.
    if ocr_df is None or ocr_df.empty or not {'conf', 'text'}.issubset(ocr_df.columns):
        return {}

    # Drop missing text BEFORE converting to str: astype(str) would turn NaN
    # into the literal "nan", which a later dropna can no longer remove.
    ocr_df = ocr_df[(ocr_df['conf'] > 30) & ocr_df['text'].notna()].copy()
    ocr_df['text'] = ocr_df['text'].astype(str).str.strip()
    ocr_df = ocr_df[ocr_df['text'] != ""]

    if ocr_df.empty:
        return {}

    inv_regex = r'([\w\d][\w\d-]{3,}[\w\d])'
    key_pattern = re.compile(r'Fakturanr|Fakturanummer', re.IGNORECASE)

    for _, box in ocr_df.iterrows():
        # Only the invoice-number keywords are searched here.
        if not key_pattern.search(box['text']):
            continue
        value = find_value_for_key(box, ocr_df, inv_regex, max_dist_x_h=200, max_dist_y_h=10, max_dist_y_v=70, max_dist_x_v=150)
        if value:
            return {"invoice_number": re.sub(r'\s', '', value)}

    return {}
|
|
|
|
def extract_info_from_crop(crop_image: np.ndarray, full_image_df: pd.DataFrame) -> dict:
    """Main extraction routine for one detected region: runs Plan A, B and C.

    1. Plan B: word-level OCR of the binarised crop, parsed by
       ``parse_human_readable_KIE``; always run to obtain the base fields.
    2. Plan A: single-line OCR of the same crop, parsed by ``parse_ocr_rad``.
       On success its payment fields overwrite Plan B's, while Plan B's
       ``invoice_number`` is kept.
    3. Plan C: if ``invoice_number`` is still None, search the full-page OCR
       table. A failure here is logged and ignored so that the Plan A/B
       result is still returned.

    Args:
        crop_image: Image of the cropped region; it is converted with
            ``cv2.COLOR_BGR2GRAY``.
        full_image_df: Word-level OCR table of the whole page.

    Returns:
        The merged field dict, or ``{"error": ...}`` when Plan A/B fail.
    """
    try:
        gray = cv2.cvtColor(crop_image, cv2.COLOR_BGR2GRAY)
        thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)[1]

        # 1. --- Plan B (KIE) ---
        # Always run Plan B first to collect the base information.
        ocr_df_crop = pytesseract.image_to_data(thresh, lang='swe', config='--psm 11', output_type=pytesseract.Output.DATAFRAME)
        final_data = parse_human_readable_KIE(ocr_df_crop)

        # 2. --- Try Plan A (OCR-rad) ---
        # psm 7: treat the image as a single line of text.
        ocr_text_rad = pytesseract.image_to_string(thresh, lang='swe', config='--psm 7')
        rad_data = parse_ocr_rad(ocr_text_rad)

        if rad_data:
            # Plan A succeeded (high confidence): overwrite Plan B's payment
            # fields but keep Plan B's invoice_number.
            # (account_type could also be overwritten here if Plan A were
            # able to tell BG from PG.)
            for field in (
                "source",
                "ocr_number",
                "ocr_number_valid",
                "amount_due",
                "amount_due_valid",
                "bankgiro_plusgiro",
            ):
                final_data[field] = rad_data[field]
    except Exception as e:
        print(f"提取时出错: {e}")
        return {"error": f"提取时出错: {e}"}

    # 3. --- Plan C (full-page fallback for the invoice number) ---
    if final_data["invoice_number"] is None:
        print("Plan B 未能从裁剪区找到 InvoiceNr, 启动 Plan C (全页搜索)...")
        try:
            full_page_inv_result = find_invoice_number_full_page(full_image_df)
        except Exception as e:
            # The fallback is best-effort: never let it discard Plan A/B data.
            print(f"Plan C (全页搜索) 失败: {e}")
            full_page_inv_result = {}

        if full_page_inv_result:
            final_data["invoice_number"] = full_page_inv_result.get("invoice_number")
            final_data["source"] += " + Full-Page-InvNr"

    return final_data
|
|
|
|
|
|
@app.post("/extract_invoice/")
async def extract_invoice_data(file: UploadFile = File(...)):
    """Detect payment-slip regions in an uploaded file and extract their fields.

    Accepts a PDF (only the first page is used) or a JPEG/PNG/BMP image.

    Raises:
        HTTPException 503: the YOLO model is not loaded.
        HTTPException 415: unsupported content type.
        HTTPException 500: Poppler is not configured or PDF conversion failed.
        HTTPException 400: the file could not be read or decoded.

    Returns:
        ``{"invoice_extractions": [...]}`` with one dict per detected region
        (each including its ``bounding_box``), or ``{"message": ...}`` when
        nothing was detected or extracted.
    """
    if ml_models.get("yolo") is None:
        raise HTTPException(status_code=503, detail="模型未加载。请检查模型路径。")

    try:
        contents = await file.read()
        content_type = file.content_type
        img = None

        if content_type == "application/pdf":
            if not POPPLER_PATH or not os.path.exists(POPPLER_PATH):
                raise HTTPException(status_code=500, detail="PDF 处理未在服务器上配置 (POPPLER_PATH)。")
            try:
                images = convert_from_bytes(contents, poppler_path=POPPLER_PATH)
                if images:
                    img = cv2.cvtColor(np.array(images[0]), cv2.COLOR_RGB2BGR)
            except Exception as e:
                raise HTTPException(status_code=500, detail=f"PDF 处理失败: {e}")
        elif content_type in ["image/jpeg", "image/png", "image/bmp"]:
            nparr = np.frombuffer(contents, np.uint8)
            img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
        else:
            raise HTTPException(status_code=415, detail=f"不支持的文件类型: {content_type}。")

        if img is None:
            raise HTTPException(status_code=400, detail="无法解码文件。")

    except HTTPException:
        # HTTPException is an Exception subclass: re-raise it untouched so the
        # specific status codes above are not rewritten to a generic 400.
        raise
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"读取文件时出错: {e}")

    # Full-page OCR feeds the invoice-number fallback; a failure here is
    # tolerated and represented by an empty table.
    try:
        full_image_df = pytesseract.image_to_data(img, lang='swe', config='--psm 3', output_type=pytesseract.Output.DATAFRAME)
    except Exception as e:
        print(f"全页 OCR 失败: {e}")
        full_image_df = pd.DataFrame()

    results = ml_models["yolo"](img, verbose=False)

    all_extractions = []

    if not results or not results[0].boxes:
        return {"message": "未在图像中检测到支付凭证区域。"}

    for res in results:
        for box_coords in res.boxes.xyxy.cpu().numpy().astype(int):
            xmin, ymin, xmax, ymax = box_coords
            crop = img[ymin:ymax, xmin:xmax]
            extracted_data = extract_info_from_crop(crop, full_image_df)
            # Cast to plain int: the coordinates are numpy integers here.
            extracted_data["bounding_box"] = [int(xmin), int(ymin), int(xmax), int(ymax)]
            all_extractions.append(extracted_data)

    if not all_extractions:
        return {"message": "检测到支付凭证, 但未能提取任何信息。"}

    return {"invoice_extractions": all_extractions}
|
|
|
|
if __name__ == "__main__":
    # Script entry point: print startup hints, then serve the app locally.
    import uvicorn

    print("--- 启动 FastAPI 服务 ---")
    print(f"加载模型: {MODEL_PATH}")

    poppler_ready = bool(POPPLER_PATH) and os.path.exists(POPPLER_PATH)
    if not poppler_ready:
        print("警告: POPPLER_PATH 未设置或无效。PDF 上传将失败。")

    print("访问 http://127.0.0.1:8000/docs 查看 API 文档")
    uvicorn.run(app, host="127.0.0.1", port=8000)
|
|
|