Files
invoice-master/scripts/main.py
Yaojia Wang b3f361847a Init
2025-10-26 22:14:21 +01:00

408 lines
16 KiB
Python

# --- main.py (upgraded, v4.6 - InvoiceNr and OCRNr are now separate fields) ---
from fastapi import FastAPI, UploadFile, File, HTTPException
from ultralytics import YOLO
import cv2
import numpy as np
import pytesseract
import pandas as pd
import re
import io
import os
import sys
from contextlib import asynccontextmanager
from pathlib import Path
from pdf2image import convert_from_bytes
# --- Configuration (uses your custom config) ---
# 1. Set BASE_DIR
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# 2. Add the project root directory to sys.path
sys.path.insert(0, BASE_DIR)
# 3. Import paths from config.py
from config import POPPLER_PATH, apply_tesseract_path
# ---------------------------------
# Location of the trained YOLO weights, built relative to BASE_DIR.
MODEL_PATH_STR = os.path.join(BASE_DIR, "models", "payment_slip_detector_v1", "weights", "best.pt")
MODEL_PATH = Path(MODEL_PATH_STR)
# Registry of loaded models; the "yolo" entry is set during application
# startup and the dict is cleared on shutdown.
ml_models = {}
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Prepare shared resources on startup and release them on shutdown.

    Startup (before ``yield``):
      * loads the YOLO weights at ``MODEL_PATH`` into ``ml_models["yolo"]``,
        or stores ``None`` there when the weights file does not exist;
      * calls ``apply_tesseract_path()``, printing a warning if it raises;
      * prints a warning when ``POPPLER_PATH`` is unset or does not exist.

    Shutdown (after ``yield``): empties ``ml_models``.
    """
    # Load the model on startup.
    print(f"正在加载模型: {MODEL_PATH}")
    if not os.path.exists(MODEL_PATH):
        print(f"警告: 找不到模型 {MODEL_PATH}。API 将无法工作。")
        ml_models["yolo"] = None
    else:
        ml_models["yolo"] = YOLO(MODEL_PATH)
        print("YOLOv8 区域检测模型加载成功。")

    # Apply the Tesseract path (from config); a failure here is only logged.
    try:
        apply_tesseract_path()
        print("已应用 Tesseract 路径配置。")
    except Exception as e:
        print(f"警告: 调用 apply_tesseract_path 时出错: {e}")

    # Check the Poppler path (from config).
    if not POPPLER_PATH or not os.path.exists(POPPLER_PATH):
        print(f"警告: POPPLER_PATH 未设置或无效: {POPPLER_PATH}。PDF 处理将失败。")
    else:
        print(f"Poppler 路径已加载: {POPPLER_PATH}")

    try:
        yield
    finally:
        # Release the models even when an exception is thrown into the
        # generator at the yield point.
        ml_models.clear()


app = FastAPI(lifespan=lifespan)
# --- Validation functions (Luhn, Mod11, etc.) ---
def luhn_validate(number_str: str, expected_check_digit: str) -> bool:
    """Check a digit string against its Modulus 10 (Luhn) check digit.

    Weights 2, 1, 2, 1, ... are applied from right to left over
    ``number_str``, which must NOT include the check digit itself. Products
    of 10 or more contribute the sum of their two digits.

    Args:
        number_str: The payload digits, without the check digit.
        expected_check_digit: The check digit to compare against, as a string.

    Returns:
        True when the computed check digit equals ``expected_check_digit``.
        False for a mismatch and for empty, non-string or non-numeric input.
    """
    # An empty payload would otherwise "validate" against check digit "0".
    if not isinstance(number_str, str) or not number_str:
        return False
    try:
        digits = [int(ch) for ch in number_str]
    except ValueError:
        return False

    total = 0
    for position, digit in enumerate(reversed(digits)):
        # The rightmost payload digit gets weight 2, then the weights alternate.
        product = digit * (2 if position % 2 == 0 else 1)
        # For a two-digit product (10..18) the digit sum equals product - 9.
        total += product - 9 if product >= 10 else product

    calculated_check_digit = (10 - total % 10) % 10
    return str(calculated_check_digit) == expected_check_digit
def modulus_11_validate(number_str: str) -> bool:
    """Check a digit string whose last character is a Modulus 11 check digit.

    The digits before the check digit are weighted 2, 3, 4, ... from right to
    left (the weight keeps increasing, it does not wrap around). The check
    value is ``11 - (weighted sum % 11)``; check values of 10 and 11 are both
    mapped to the digit "0".

    Returns:
        True when the computed check digit equals the last character.
        False for a mismatch, for input shorter than two characters, for
        non-digit input, or when any error occurs while computing.
    """
    try:
        if len(number_str) < 2 or not number_str.isdigit():
            return False
        *payload, check_char = number_str
        weighted_total = sum(
            int(ch) * weight
            for weight, ch in enumerate(reversed(payload), start=2)
        )
        check_value = 11 - weighted_total % 11
        computed = "0" if check_value in (10, 11) else str(check_value)
        return computed == check_char
    except Exception:
        return False
def luhn_validate_ocr(number_str: str) -> bool:
    """Validate a digit string whose final character is its Luhn check digit.

    Splits off the last character and hands both parts to ``luhn_validate``.

    Returns:
        The result of ``luhn_validate``, or False when the input is shorter
        than two characters, is not all digits, or any error occurs.
    """
    try:
        if len(number_str) >= 2 and number_str.isdigit():
            payload, check_char = number_str[:-1], number_str[-1]
            return luhn_validate(payload, check_char)
        return False
    except Exception:
        return False
def check_ocr_validity(ocr_number: str) -> bool:
    """Report whether a reference number passes either supported checksum.

    The number is accepted when it is a non-empty digit string that passes
    ``luhn_validate_ocr`` or, failing that, ``modulus_11_validate``.
    """
    is_digit_string = bool(ocr_number) and ocr_number.isdigit()
    if not is_digit_string:
        return False
    # Luhn is tried first; Modulus 11 is only consulted when Luhn fails.
    return bool(luhn_validate_ocr(ocr_number) or modulus_11_validate(ocr_number))
# --- Plan A: machine-readable code line parsing (fixed) ---
def parse_ocr_rad(ocr_text: str) -> dict:
    """Plan A: parse the machine-readable code line ("OCR-rad") from OCR text.

    All whitespace is removed from ``ocr_text`` and the result is searched for
    the pattern ``#<reference>#<amount><check digit>><account>#``. The amount
    digits must pass ``luhn_validate`` against the single check digit; their
    last two digits are reported as öre and the rest as kronor.

    Args:
        ocr_text: Raw text returned by the OCR engine.

    Returns:
        A dict with the keys ``source``, ``ocr_number``, ``ocr_number_valid``,
        ``amount_due``, ``amount_due_valid``, ``bankgiro_plusgiro``,
        ``account_type``, ``due_date`` and ``invoice_number`` (the last three
        are always None here), or None when the text is empty, no code line is
        found, the amount checksum fails, or parsing raises.
    """
    if not ocr_text:
        return None
    text_no_space = re.sub(r'\s+', '', ocr_text)
    # group(1) = reference (OCR) number, group(2) = amount digits,
    # group(3) = amount check digit, group(4) = account number.
    rad_regex = r'#([\d>]+)#(\d{2,})(\d{1})>([\d]+)#'
    match = re.search(rad_regex, text_no_space)
    if not match:
        return None
    try:
        ocr_num = match.group(1).replace(">", "")
        amount_base = match.group(2)
        check_digit = match.group(3)
        account = match.group(4)

        if not luhn_validate(amount_base, check_digit):
            print(f"Luhn 校验失败: 基础={amount_base}, 期望={check_digit}")
            return None

        # A two-digit amount has no kronor part; report "0.50", not ".50".
        amount_kronor = amount_base[:-2] or "0"
        amount_ore = amount_base[-2:]
        ocr_valid = check_ocr_validity(ocr_num)
        print("Plan A (OCR-rad) 成功!")
        return {
            "source": "OCR-rad (High Confidence)",
            "ocr_number": ocr_num,  # this is the payment reference number
            "ocr_number_valid": ocr_valid,
            "amount_due": f"{amount_kronor}.{amount_ore}",
            "amount_due_valid": True,
            "bankgiro_plusgiro": account,
            "account_type": None,
            "due_date": None,
            "invoice_number": None,  # Plan A cannot know the invoice number
        }
    except Exception as e:
        print(f"解析 OCR-rad 时出错: {e}")
        return None
# --- Upgrade: Plan B (smart KIE logic) ---
def find_value_for_key(key_box, all_boxes, value_regex, max_dist_x_h=400, max_dist_y_h=20, max_dist_y_v=70, max_dist_x_v=150):
    """Find the text of the closest "value" box to the right of or below a "key" box.

    Every row of ``all_boxes`` other than the key itself (compared by row
    label) whose stripped text matches ``value_regex`` at its start is a
    candidate when it satisfies one of two placement checks:

      * horizontal: it starts to the right of the key, its vertical centre is
        within ``max_dist_y_h`` of the key's, and the gap is below ``max_dist_x_h``;
      * vertical: it starts below the key, its horizontal centre is within
        ``max_dist_x_v`` of the key's, and the gap is below ``max_dist_y_v``.

    Vertical candidates are penalised by 1000 so horizontal ones win.

    Args:
        key_box: Row with ``left``, ``top``, ``width`` and ``height`` fields.
        all_boxes: DataFrame with those fields plus ``text``.
        value_regex: Pattern passed to ``re.match`` for candidate texts.

    Returns:
        The stripped text of the best-scoring candidate, or None.
    """
    key_right = key_box['left'] + key_box['width']
    key_bottom = key_box['top'] + key_box['height']
    key_mid_x = key_box['left'] + key_box['width'] / 2
    key_mid_y = key_box['top'] + key_box['height'] / 2

    candidates = []
    for _, row in all_boxes.iterrows():
        if row.name == key_box.name:
            continue
        candidate_text = str(row['text']).strip()
        if re.match(value_regex, candidate_text) is None:
            continue

        gap_x = row['left'] - key_right
        gap_y = row['top'] - key_bottom
        same_line = abs(row['top'] + row['height'] / 2 - key_mid_y) < max_dist_y_h
        same_column = abs(row['left'] + row['width'] / 2 - key_mid_x) < max_dist_x_v

        # Check A: value sits to the right of the key, on the same line.
        if gap_x > 0 and same_line and gap_x < max_dist_x_h:
            candidates.append((gap_x, candidate_text))
        # Check B: value sits below the key; +1000 keeps horizontal hits first.
        if gap_y > 0 and same_column and gap_y < max_dist_y_v:
            candidates.append((gap_y + 1000, candidate_text))

    if not candidates:
        return None
    return min(candidates, key=lambda item: item[0])[1]
def parse_human_readable_KIE(ocr_df: pd.DataFrame) -> dict:
    """Plan B (v4.6): key/value extraction from human-readable labels.

    Rows with ``conf`` of 30 or less, missing text or blank text are dropped.
    Each remaining box is tested against a list of label patterns; when one
    matches, ``find_value_for_key`` looks for a value to its right or below.
    The first value found for a field wins. ``invoice_number`` and
    ``ocr_number`` are extracted as separate fields.

    Args:
        ocr_df: Word-level OCR table with at least ``conf`` and ``text``
            columns, plus the geometry columns ``find_value_for_key`` reads.

    Returns:
        A dict with the keys ``source``, ``bankgiro_plusgiro``,
        ``account_type``, ``invoice_number``, ``ocr_number``,
        ``ocr_number_valid``, ``amount_due``, ``amount_due_valid`` and
        ``due_date``. Fields that were not found keep their None/False
        defaults; a frame that is None or lacks the required columns yields
        the defaults unchanged.
    """
    data = {
        "source": "Human-Readable KIE (Fallback)",
        "bankgiro_plusgiro": None,
        "account_type": None,
        "invoice_number": None,
        "ocr_number": None,
        "ocr_number_valid": False,
        "amount_due": None,
        "amount_due_valid": False,
        "due_date": None
    }
    # A frame without the expected columns cannot be filtered below.
    if ocr_df is None or not {'conf', 'text'}.issubset(ocr_df.columns):
        return data

    # Drop missing text BEFORE casting to str; casting first would turn NaN
    # into the literal string "nan" and make dropna a no-op.
    ocr_df = ocr_df[ocr_df['conf'] > 30].dropna(subset=['text']).copy()
    ocr_df['text'] = ocr_df['text'].astype(str).str.strip()
    ocr_df = ocr_df[ocr_df['text'] != ""]
    if ocr_df.empty:
        return data

    # Key/value lookup rules.
    # Tuple layout: (label regex, value regex, target field, account type tag)
    search_rules = [
        (r'PlusGiro|PG', r'(\d{2,7}[- ]\d{1,2}[- ]\d{1,2})', "bankgiro_plusgiro", "PG"),
        (r'Bankgiro|BG|bankgironr', r'(\d{2,4}[- ]\d{4})', "bankgiro_plusgiro", "BG"),
        (r'Att betala:?|Belopp:?', r'([\d\s,.]+\d{2})', "amount_due", None),
        (r'senast|Förfallodag', r'\d{4}[- ]\d{2}[- ]\d{2}', "due_date", None),
        # --- invoice number and payment reference are handled separately ---
        (r'Fakturanr|Fakturanummer', r'([\w\d][\w\d-]{3,}[\w\d])', "invoice_number", None),
        (r'OCR|Referens|Betalningsreferens', r'([\w\d][\w\d-]{3,}[\w\d])', "ocr_number", None),
    ]
    for _, box in ocr_df.iterrows():
        text = box['text']
        for key_regex, value_regex, field_name, extra_value in search_rules:
            # Skip fields that are already filled (e.g. once a PG/BG account
            # is found, the other account rule is not tried).
            if data.get(field_name) is not None:
                continue
            if not re.search(key_regex, text, re.IGNORECASE):
                continue
            value = find_value_for_key(box, ocr_df, value_regex, max_dist_y_v=70)
            if not value:
                continue

            # Field-specific clean-up.
            if field_name == "bankgiro_plusgiro":
                data[field_name] = re.sub(r'\s', '', value)
                data["account_type"] = extra_value  # "BG" or "PG"
            elif field_name == "amount_due":
                amount_str = value.replace(" ", "").replace(",", ".")
                # Keep only the last dot as the decimal separator.
                if amount_str.count('.') > 1:
                    amount_str = amount_str.replace(".", "", amount_str.count('.') - 1)
                data[field_name] = amount_str
            elif field_name == "due_date":
                data[field_name] = value.replace(" ", "-")
            elif field_name in ["invoice_number", "ocr_number"]:
                if value.lower() == "kronor" or value.lower() == "öre":
                    continue  # skip a wrong match
                data[field_name] = re.sub(r'\s', '', value)
                if field_name == "ocr_number":
                    data["ocr_number_valid"] = check_ocr_validity(data[field_name])
    return data
# --- Plan C: full-page fallback (for the invoice number) ---
def find_invoice_number_full_page(ocr_df: pd.DataFrame) -> dict:
    """Plan C: search the whole page's OCR table for the invoice number.

    Rows with ``conf`` of 30 or less, missing text or blank text are dropped.
    The first box whose text contains "Fakturanr" or "Fakturanummer"
    (case-insensitive) and has a matching value to its right or below, as
    decided by ``find_value_for_key``, supplies the result.

    Args:
        ocr_df: Word-level OCR table for the full page. It may be an empty
            frame without any columns (the caller passes one when full-page
            OCR failed).

    Returns:
        ``{"invoice_number": <value without whitespace>}`` when found,
        otherwise an empty dict.
    """
    # A column-less frame would raise KeyError on 'conf' below and make the
    # caller discard everything it already extracted.
    if ocr_df is None or not {'conf', 'text'}.issubset(ocr_df.columns):
        return {}

    # Drop missing text BEFORE casting to str; casting first would turn NaN
    # into the literal string "nan" and make dropna a no-op.
    ocr_df = ocr_df[ocr_df['conf'] > 30].dropna(subset=['text']).copy()
    ocr_df['text'] = ocr_df['text'].astype(str).str.strip()
    ocr_df = ocr_df[ocr_df['text'] != ""]
    if ocr_df.empty:
        return {}

    inv_regex = r'([\w\d][\w\d-]{3,}[\w\d])'
    for _, box in ocr_df.iterrows():
        # Only the invoice-number label is searched for here.
        if re.search(r'Fakturanr|Fakturanummer', box['text'], re.IGNORECASE):
            value = find_value_for_key(box, ocr_df, inv_regex, max_dist_x_h=200, max_dist_y_h=10, max_dist_y_v=70, max_dist_x_v=150)
            if value:
                inv_num_str = re.sub(r'\s', '', value)
                return {"invoice_number": inv_num_str}
    return {}
def extract_info_from_crop(crop_image: np.ndarray, full_image_df: pd.DataFrame) -> dict:
    """Run Plan B, Plan A and (if needed) Plan C on one detected region.

    The crop is converted to grayscale and Otsu-thresholded, then:

      1. Plan B (``parse_human_readable_KIE``) always runs first on a
         ``--psm 11`` word table and provides the base result.
      2. Plan A (``parse_ocr_rad``) runs on ``--psm 7`` text; when it succeeds
         its source and payment fields overwrite Plan B's, while Plan B's
         ``invoice_number`` is kept.
      3. Plan C (``find_invoice_number_full_page``) runs on ``full_image_df``
         only when ``invoice_number`` is still None.

    Returns:
        The merged result dict, or ``{"error": ...}`` when any step raises.
    """
    try:
        grayscale = cv2.cvtColor(crop_image, cv2.COLOR_BGR2GRAY)
        _, binarized = cv2.threshold(grayscale, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)

        # 1. Plan B (KIE): always run first to collect the base fields.
        kie_frame = pytesseract.image_to_data(binarized, lang='swe', config='--psm 11', output_type=pytesseract.Output.DATAFRAME)
        final_data = parse_human_readable_KIE(kie_frame)

        # 2. Plan A (OCR-rad); psm 7 treats the image as a single text line.
        rad_text = pytesseract.image_to_string(binarized, lang='swe', config='--psm 7')
        rad_data = parse_ocr_rad(rad_text)
        if rad_data:
            # High-confidence result: overwrite the payment fields but leave
            # Plan B's invoice_number (and account_type) untouched.
            overridden_fields = (
                "source",
                "ocr_number",
                "ocr_number_valid",
                "amount_due",
                "amount_due_valid",
                "bankgiro_plusgiro",
            )
            for field in overridden_fields:
                final_data[field] = rad_data[field]

        # 3. Plan C: full-page fallback for the invoice number.
        if final_data["invoice_number"] is not None:
            return final_data
        print("Plan B 未能从裁剪区找到 InvoiceNr, 启动 Plan C (全页搜索)...")
        page_hit = find_invoice_number_full_page(full_image_df)
        if page_hit:
            final_data["invoice_number"] = page_hit.get("invoice_number")
            final_data["source"] += " + Full-Page-InvNr"
        return final_data
    except Exception as e:
        print(f"提取时出错: {e}")
        return {"error": f"提取时出错: {e}"}
@app.post("/extract_invoice/")
async def extract_invoice_data(file: UploadFile = File(...)):
    """Detect payment-slip regions in an uploaded file and extract their fields.

    Accepts a PDF (only the first page is used) or a JPEG/PNG/BMP image. The
    YOLO model locates candidate regions; each one is cropped and passed to
    ``extract_info_from_crop`` together with a full-page OCR table.

    Returns:
        ``{"invoice_extractions": [...]}`` with one dict per detected region,
        each carrying its ``bounding_box`` as ``[xmin, ymin, xmax, ymax]``, or
        ``{"message": ...}`` when nothing was detected or extracted.

    Raises:
        HTTPException: 503 when the model is not loaded; 415 for an
            unsupported content type; 500 when PDF handling is not configured
            or PDF conversion fails; 400 when the file cannot be read or
            decoded.
    """
    if ml_models.get("yolo") is None:
        raise HTTPException(status_code=503, detail="模型未加载。请检查模型路径。")
    try:
        contents = await file.read()
        content_type = file.content_type
        img = None
        if content_type == "application/pdf":
            if not POPPLER_PATH or not os.path.exists(POPPLER_PATH):
                raise HTTPException(status_code=500, detail="PDF 处理未在服务器上配置 (POPPLER_PATH)。")
            try:
                images = convert_from_bytes(contents, poppler_path=POPPLER_PATH)
                if images:
                    img = cv2.cvtColor(np.array(images[0]), cv2.COLOR_RGB2BGR)
            except Exception as e:
                raise HTTPException(status_code=500, detail=f"PDF 处理失败: {e}") from e
        elif content_type in ["image/jpeg", "image/png", "image/bmp"]:
            nparr = np.frombuffer(contents, np.uint8)
            img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
        else:
            raise HTTPException(status_code=415, detail=f"不支持的文件类型: {content_type}")
        if img is None:
            raise HTTPException(status_code=400, detail="无法解码文件。")
    except HTTPException:
        # HTTPException is itself an Exception; without this clause the
        # 415/500 responses raised above would be rewrapped as a generic 400.
        raise
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"读取文件时出错: {e}") from e

    results = ml_models["yolo"](img, verbose=False)
    if not results or not results[0].boxes:
        return {"message": "未在图像中检测到支付凭证区域。"}

    # Full-page OCR is only consumed by the per-crop extraction, so it runs
    # after the early return above. A failure degrades to an empty frame.
    try:
        full_image_df = pytesseract.image_to_data(img, lang='swe', config='--psm 3', output_type=pytesseract.Output.DATAFRAME)
    except Exception as e:
        print(f"全页 OCR 失败: {e}")
        full_image_df = pd.DataFrame()

    all_extractions = []
    for res in results:
        for box_coords in res.boxes.xyxy.cpu().numpy().astype(int):
            xmin, ymin, xmax, ymax = box_coords
            crop = img[ymin:ymax, xmin:xmax]
            extracted_data = extract_info_from_crop(crop, full_image_df)
            # Cast to plain ints so the coordinates are JSON-serialisable.
            extracted_data["bounding_box"] = [int(xmin), int(ymin), int(xmax), int(ymax)]
            all_extractions.append(extracted_data)
    if not all_extractions:
        return {"message": "检测到支付凭证, 但未能提取任何信息。"}
    return {"invoice_extractions": all_extractions}
if __name__ == "__main__":
    import uvicorn

    # Startup banner: model location, optional Poppler warning, docs URL.
    poppler_ready = bool(POPPLER_PATH) and os.path.exists(POPPLER_PATH)
    print("--- 启动 FastAPI 服务 ---")
    print(f"加载模型: {MODEL_PATH}")
    if not poppler_ready:
        print("警告: POPPLER_PATH 未设置或无效。PDF 上传将失败。")
    print("访问 http://127.0.0.1:8000/docs 查看 API 文档")

    # Serve on the loopback interface only.
    uvicorn.run(app, host="127.0.0.1", port=8000)