Init
This commit is contained in:
384
scripts/main.py
384
scripts/main.py
@@ -1,189 +1,340 @@
|
||||
# --- main.py (已升级) ---
|
||||
# --- main.py (已升级, v4.6 - 分离 InvoiceNr 和 OCRNr) ---
|
||||
|
||||
from fastapi import FastAPI, UploadFile, File, HTTPException
|
||||
from ultralytics import YOLO
|
||||
import cv2
|
||||
import numpy as np
|
||||
import pytesseract
|
||||
import pandas as pd
|
||||
import re
|
||||
import io
|
||||
import os
|
||||
import sys
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
from pdf2image import convert_from_bytes
|
||||
|
||||
# --- 配置 ---
|
||||
# TODO: 确保此路径指向您训练好的最佳模型
|
||||
MODEL_PATH = os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)),
|
||||
"models", "invoice_detector_v1", "weights", "best.pt"
|
||||
)
|
||||
# --- 配置 (使用您的自定义配置) ---
|
||||
# 1. 设置 BASE_DIR
|
||||
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
# 2. 将项目根目录添加到 sys.path
|
||||
sys.path.insert(0, BASE_DIR)
|
||||
# 3. 从 config.py 导入路径
|
||||
from config import POPPLER_PATH, apply_tesseract_path
|
||||
# ---------------------------------
|
||||
|
||||
MODEL_PATH_STR = os.path.join(BASE_DIR, "models", "payment_slip_detector_v1", "weights", "best.pt")
|
||||
MODEL_PATH = Path(MODEL_PATH_STR)
|
||||
|
||||
# 定义一个字典来在 FastAPI 启动时加载模型
|
||||
ml_models = {}
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# 启动时加载模型
|
||||
print(MODEL_PATH)
|
||||
print(f"正在加载模型: {MODEL_PATH}")
|
||||
if not os.path.exists(MODEL_PATH):
|
||||
print(f"警告: 找不到模型 {MODEL_PATH}。API 将无法工作。")
|
||||
ml_models["yolo"] = None
|
||||
else:
|
||||
# 加载您的 "payment_slip" 区域检测器
|
||||
ml_models["yolo"] = YOLO(MODEL_PATH)
|
||||
print("YOLOv8 区域检测模型加载成功。")
|
||||
|
||||
# 应用 Tesseract 路径 (来自 config)
|
||||
try:
|
||||
apply_tesseract_path()
|
||||
print("已应用 Tesseract 路径配置。")
|
||||
except Exception as e:
|
||||
print(f"警告: 调用 apply_tesseract_path 时出错: {e}")
|
||||
|
||||
# 检查 Poppler 路径 (来自 config)
|
||||
if not POPPLER_PATH or not os.path.exists(POPPLER_PATH):
|
||||
print(f"警告: POPPLER_PATH 未设置或无效: {POPPLER_PATH}。PDF 处理将失败。")
|
||||
else:
|
||||
print(f"Poppler 路径已加载: {POPPLER_PATH}")
|
||||
|
||||
yield
|
||||
# 清理模型
|
||||
ml_models.clear()
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
# --- Luhn (Modulus 10) 校验函数 ---
|
||||
# --- 校验函数 (Luhn, Mod11, etc.) ---
|
||||
def luhn_validate(number_str: str, expected_check_digit: str) -> bool:
|
||||
"""
|
||||
使用 Modulus 10 (Luhn) 算法验证一个数字字符串。
|
||||
(从右到左, 权重 1, 2, 1, 2...)
|
||||
(从右到左, 权重 2, 1, 2, 1...)
|
||||
"""
|
||||
try:
|
||||
digits = [int(d) for d in number_str]
|
||||
weights = [1, 2] * (len(digits) // 2 + 1)
|
||||
weights = weights[:len(digits)] # 确保权重列表长度一致
|
||||
|
||||
weights = [2, 1] * (len(digits) // 2 + 1)
|
||||
weights = weights[:len(digits)]
|
||||
sum_val = 0
|
||||
# 从右到左计算
|
||||
for d, w in zip(reversed(digits), weights):
|
||||
product = d * w
|
||||
if product >= 10:
|
||||
sum_val += (product // 10) + (product % 10)
|
||||
else:
|
||||
sum_val += product
|
||||
|
||||
calculated_check_digit = (10 - (sum_val % 10)) % 10
|
||||
return str(calculated_check_digit) == expected_check_digit
|
||||
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
# --- 提取逻辑 ---
|
||||
def modulus_11_validate(number_str: str) -> bool:
|
||||
try:
|
||||
if not number_str.isdigit() or len(number_str) < 2: return False
|
||||
base_num = number_str[:-1]
|
||||
expected_check_digit_str = number_str[-1]
|
||||
weights = [i for i in range(2, 2 + len(base_num))]
|
||||
sum_val = 0
|
||||
for digit_char, weight in zip(reversed(base_num), weights):
|
||||
sum_val += int(digit_char) * weight
|
||||
remainder = sum_val % 11
|
||||
calculated_check_val = 11 - remainder
|
||||
if calculated_check_val == 10: calculated_check_digit = "0"
|
||||
elif calculated_check_val == 11: calculated_check_digit = "0"
|
||||
else: calculated_check_digit = str(calculated_check_val)
|
||||
return calculated_check_digit == expected_check_digit_str
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def luhn_validate_ocr(number_str: str) -> bool:
|
||||
try:
|
||||
if not number_str.isdigit() or len(number_str) < 2: return False
|
||||
base_num = number_str[:-1]
|
||||
check_digit = number_str[-1]
|
||||
return luhn_validate(base_num, check_digit)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def check_ocr_validity(ocr_number: str) -> bool:
|
||||
if not ocr_number or not ocr_number.isdigit():
|
||||
return False
|
||||
if luhn_validate_ocr(ocr_number):
|
||||
return True
|
||||
if modulus_11_validate(ocr_number):
|
||||
return True
|
||||
return False
|
||||
|
||||
# --- Plan A: 机读码解析 (已修复) ---
|
||||
def parse_ocr_rad(ocr_text: str) -> dict:
|
||||
"""
|
||||
Plan A: 尝试解析机器可读码 (OCR-rad)
|
||||
示例: # 400299582421 # 4603 00 7 > 48180020 #14#
|
||||
"""
|
||||
# 移除所有空格以简化匹配
|
||||
text_no_space = re.sub(r'\s+', '', ocr_text)
|
||||
|
||||
# 定义一个更健壮的正则表达式
|
||||
# 组1: OCR号 (在 #...# 之间)
|
||||
# 组2: 金额 (Kronor + Öre) (在 #... 之后)
|
||||
# 组3: 校验码 (1位数字)
|
||||
# 组4: 账户号 (在 >...# 之间)
|
||||
rad_regex = r'#([\d>]+)#(\d{2,})(\d)(\d{1})>([\d]+)#'
|
||||
|
||||
# 修正: 捕获组索引 (group(1) 是 OCR, group(2) 是金额, group(3) 是校验码, group(4) 是账号)
|
||||
rad_regex = r'#([\d>]+)#(\d{2,})(\d{1})>([\d]+)#'
|
||||
match = re.search(rad_regex, text_no_space)
|
||||
|
||||
if not match:
|
||||
return None # 未找到机读码行
|
||||
return None
|
||||
|
||||
try:
|
||||
ocr_num = match.group(1).replace(">", "") # 移除 > (如果有)
|
||||
amount_base = match.group(2) + match.group(3) # "4603" + "00" = "460300"
|
||||
check_digit = match.group(4) # "7"
|
||||
account = match.group(5) # "48180020"
|
||||
ocr_num = match.group(1).replace(">", "")
|
||||
amount_base = match.group(2)
|
||||
check_digit = match.group(3)
|
||||
account = match.group(4)
|
||||
|
||||
# 运行Luhn校验
|
||||
if luhn_validate(amount_base, check_digit):
|
||||
# 校验成功! 这是高置信度数据
|
||||
amount_valid = luhn_validate(amount_base, check_digit)
|
||||
|
||||
if amount_valid:
|
||||
amount_kronor = amount_base[:-2]
|
||||
amount_ore = amount_base[-2:]
|
||||
ocr_valid = check_ocr_validity(ocr_num)
|
||||
|
||||
print("Plan A (OCR-rad) 成功!")
|
||||
|
||||
return {
|
||||
"source": "OCR-rad (High Confidence)",
|
||||
"ocr_number": ocr_num,
|
||||
"ocr_number": ocr_num, # 这是支付参考号
|
||||
"ocr_number_valid": ocr_valid,
|
||||
"amount_due": f"{amount_kronor}.{amount_ore}",
|
||||
"amount_due_valid": True,
|
||||
"bankgiro_plusgiro": account,
|
||||
"due_date": None # 机读码行通常不包含日期
|
||||
"account_type": None,
|
||||
"due_date": None,
|
||||
"invoice_number": None # Plan A 不知道发票号
|
||||
}
|
||||
else:
|
||||
# 校验失败!
|
||||
print(f"Luhn 校验失败: 基础={amount_base}, 期望={check_digit}")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
print(f"解析 OCR-rad 时出错: {e}")
|
||||
return None
|
||||
|
||||
def parse_human_readable(ocr_text: str) -> dict:
|
||||
# --- 升级: Plan B (智能 KIE 逻辑) ---
|
||||
def find_value_for_key(key_box, all_boxes, value_regex, max_dist_x_h=400, max_dist_y_h=20, max_dist_y_v=70, max_dist_x_v=150):
|
||||
"""
|
||||
Plan B: 回退到人工可读区域
|
||||
(这里我们使用之前版本中的简单 Regex,
|
||||
您也可以替换为您那个更复杂的 classify_text 逻辑)
|
||||
一个辅助函数, 用于查找一个 "键" (key_box) 右侧或下方的 "值"
|
||||
"""
|
||||
data = {"source": "Human-Readable (Fallback)"}
|
||||
key_y_center = key_box['top'] + key_box['height'] / 2
|
||||
key_x_center = key_box['left'] + key_box['width'] / 2
|
||||
key_x_end = key_box['left'] + key_box['width']
|
||||
key_y_end = key_box['top'] + key_box['height']
|
||||
|
||||
# 查找 BG/PG (Bankgiro/Plusgiro)
|
||||
bg_match = re.search(r'Bankgiro\D*(\d{2,4}[- ]\d{4})', ocr_text, re.IGNORECASE)
|
||||
pg_match = re.search(r'PlusGiro\D*(\d{2,7}[- ]\d)', ocr_text, re.IGNORECASE)
|
||||
if bg_match:
|
||||
data["bankgiro_plusgiro"] = bg_match.group(1).replace(" ", "")
|
||||
elif pg_match:
|
||||
data["bankgiro_plusgiro"] = pg_match.group(1).replace(" ", "")
|
||||
else:
|
||||
# 备用查找
|
||||
bg_pg_alt = re.search(r'(\b\d{2,4}[- ]\d{4}\b)|(\b\d{2,7}[- ]\d\b)', ocr_text)
|
||||
if bg_pg_alt:
|
||||
data["bankgiro_plusgiro"] = bg_pg_alt.group(0).replace(" ", "")
|
||||
potential_values = []
|
||||
|
||||
for _, value_box in all_boxes.iterrows():
|
||||
if value_box.name == key_box.name:
|
||||
continue
|
||||
|
||||
value_y_center = value_box['top'] + value_box['height'] / 2
|
||||
value_x_center = value_box['left'] + value_box['width'] / 2
|
||||
value_x_start = value_box['left']
|
||||
value_y_start = value_box['top']
|
||||
value_text = str(value_box['text']).strip()
|
||||
|
||||
# 查找 OCR
|
||||
ocr_match = re.search(r'(OCR|Fakturanummer|Referens)\D*(\d[\d\s]{5,}\d)', ocr_text, re.IGNORECASE)
|
||||
if ocr_match:
|
||||
data["ocr_number"] = re.sub(r'\s', '', ocr_match.group(2))
|
||||
if not re.match(value_regex, value_text):
|
||||
continue
|
||||
|
||||
# 查找金额
|
||||
amount_match = re.search(r'(Att betala|Belopp)\D*([\d\s,.]+)\s*(kr|SEK)?', ocr_text, re.IGNORECASE)
|
||||
if amount_match:
|
||||
amount_str = amount_match.group(2).strip().replace(" ", "").replace(",", ".")
|
||||
if amount_str.count('.') > 1:
|
||||
amount_str = amount_str.replace(".", "", amount_str.count('.') - 1)
|
||||
data["amount_due"] = amount_str
|
||||
# --- 检查 A: 水平 (Horizontal) ---
|
||||
if value_x_start > key_x_end:
|
||||
if abs(value_y_center - key_y_center) < max_dist_y_h:
|
||||
dist_x = value_x_start - key_x_end
|
||||
if dist_x < max_dist_x_h:
|
||||
potential_values.append((dist_x, value_text))
|
||||
|
||||
# --- 检查 B: 垂直 (Vertical) ---
|
||||
if value_y_start > key_y_end:
|
||||
if abs(value_x_center - key_x_center) < max_dist_x_v:
|
||||
dist_y = value_y_start - key_y_end
|
||||
if dist_y < max_dist_y_v:
|
||||
potential_values.append((dist_y + 1000, value_text)) # H 优先
|
||||
|
||||
if potential_values:
|
||||
potential_values.sort(key=lambda x: x[0])
|
||||
return potential_values[0][1]
|
||||
|
||||
# 查找截止日期
|
||||
date_match = re.search(r'(senast|Förfallodag)\D*(\d{4}[- ]\d{2}[- ]\d{2})', ocr_text, re.IGNORECASE)
|
||||
if date_match:
|
||||
data["due_date"] = date_match.group(2).replace(" ", "-")
|
||||
return None
|
||||
|
||||
def parse_human_readable_KIE(ocr_df: pd.DataFrame) -> dict:
|
||||
"""
|
||||
Plan B (v4.6): 使用 KIE 逻辑回退 (分离 InvoiceNr 和 OCRNr)
|
||||
"""
|
||||
data = {
|
||||
"source": "Human-Readable KIE (Fallback)",
|
||||
"bankgiro_plusgiro": None,
|
||||
"account_type": None,
|
||||
"invoice_number": None, # <-- 新增
|
||||
"ocr_number": None,
|
||||
"ocr_number_valid": False,
|
||||
"amount_due": None,
|
||||
"amount_due_valid": False,
|
||||
"due_date": None
|
||||
}
|
||||
|
||||
ocr_df = ocr_df[ocr_df['conf'] > 30].copy()
|
||||
ocr_df['text'] = ocr_df['text'].astype(str).str.strip()
|
||||
ocr_df = ocr_df.dropna(subset=['text'])
|
||||
ocr_df = ocr_df[ocr_df['text'] != ""]
|
||||
|
||||
if ocr_df.empty:
|
||||
return data
|
||||
|
||||
# 定义所有 Key-Value 查找规则
|
||||
# Regex 格式: ( [关键词列表], 值Regex, 目标字段名 )
|
||||
search_rules = [
|
||||
(r'PlusGiro|PG', r'(\d{2,7}[- ]\d{1,2}[- ]\d{1,2})', "bankgiro_plusgiro", "PG"),
|
||||
(r'Bankgiro|BG|bankgironr', r'(\d{2,4}[- ]\d{4})', "bankgiro_plusgiro", "BG"),
|
||||
(r'Att betala:?|Belopp:?', r'([\d\s,.]+\d{2})', "amount_due", None),
|
||||
(r'senast|Förfallodag', r'\d{4}[- ]\d{2}[- ]\d{2}', "due_date", None),
|
||||
# --- 逻辑分离 ---
|
||||
(r'Fakturanr|Fakturanummer', r'([\w\d][\w\d-]{3,}[\w\d])', "invoice_number", None),
|
||||
(r'OCR|Referens|Betalningsreferens', r'([\w\d][\w\d-]{3,}[\w\d])', "ocr_number", None),
|
||||
]
|
||||
|
||||
for _, box in ocr_df.iterrows():
|
||||
text = box['text']
|
||||
|
||||
for key_regex, value_regex, field_name, extra_value in search_rules:
|
||||
# 如果这个字段已经被填过了, 就跳过 (例如 BG 找到了, 就不要再找 PG)
|
||||
if data.get(field_name) is not None:
|
||||
continue
|
||||
|
||||
if re.search(key_regex, text, re.IGNORECASE):
|
||||
value = find_value_for_key(box, ocr_df, value_regex, max_dist_y_v=70)
|
||||
|
||||
if value:
|
||||
# 特殊清理
|
||||
if field_name == "bankgiro_plusgiro":
|
||||
data[field_name] = re.sub(r'\s', '', value)
|
||||
data["account_type"] = extra_value # 设置 "BG" 或 "PG"
|
||||
elif field_name == "amount_due":
|
||||
amount_str = value.replace(" ", "").replace(",", ".")
|
||||
if amount_str.count('.') > 1:
|
||||
amount_str = amount_str.replace(".", "", amount_str.count('.') - 1)
|
||||
data[field_name] = amount_str
|
||||
elif field_name == "due_date":
|
||||
data[field_name] = value.replace(" ", "-")
|
||||
elif field_name in ["invoice_number", "ocr_number"]:
|
||||
if value.lower() == "kronor" or value.lower() == "öre":
|
||||
continue # 跳过错误的匹配
|
||||
data[field_name] = re.sub(r'\s', '', value)
|
||||
if field_name == "ocr_number":
|
||||
data["ocr_number_valid"] = check_ocr_validity(data[field_name])
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def extract_info_from_crop(crop_image: np.ndarray) -> dict:
|
||||
# --- Plan C: 全页回退 (用于发票号码) ---
|
||||
def find_invoice_number_full_page(ocr_df: pd.DataFrame) -> dict:
|
||||
"""
|
||||
主提取函数: 执行 Plan A 和 Plan B
|
||||
Plan C: 如果 invoice_number 仍为 null, 则在整个页面上搜索。
|
||||
"""
|
||||
ocr_df = ocr_df[ocr_df['conf'] > 30].copy()
|
||||
ocr_df['text'] = ocr_df['text'].astype(str).str.strip()
|
||||
ocr_df = ocr_df.dropna(subset=['text'])
|
||||
ocr_df = ocr_df[ocr_df['text'] != ""]
|
||||
|
||||
if ocr_df.empty:
|
||||
return {}
|
||||
|
||||
inv_regex = r'([\w\d][\w\d-]{3,}[\w\d])'
|
||||
for _, box in ocr_df.iterrows():
|
||||
# 只查找 Fakturanr
|
||||
if re.search(r'Fakturanr|Fakturanummer', box['text'], re.IGNORECASE):
|
||||
value = find_value_for_key(box, ocr_df, inv_regex, max_dist_x_h=200, max_dist_y_h=10, max_dist_y_v=70, max_dist_x_v=150)
|
||||
if value:
|
||||
inv_num_str = re.sub(r'\s', '', value)
|
||||
return {"invoice_number": inv_num_str}
|
||||
return {}
|
||||
|
||||
def extract_info_from_crop(crop_image: np.ndarray, full_image_df: pd.DataFrame) -> dict:
|
||||
"""
|
||||
主提取函数: 执行 Plan A, B, C
|
||||
"""
|
||||
try:
|
||||
# 1. 预处理并运行 OCR (获取所有文本)
|
||||
gray = cv2.cvtColor(crop_image, cv2.COLOR_BGR2GRAY)
|
||||
thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)[1]
|
||||
|
||||
# 使用 psm 6 假设是一个统一的文本块, 这对 OCR-rad 很友好
|
||||
ocr_text = pytesseract.image_to_string(thresh, lang='swe', config='--psm 6')
|
||||
# 1. --- Plan B (KIE) ---
|
||||
# 总是先运行 Plan B, 获取基础信息
|
||||
ocr_df_crop = pytesseract.image_to_data(thresh, lang='swe', config='--psm 11', output_type=pytesseract.Output.DATAFRAME)
|
||||
final_data = parse_human_readable_KIE(ocr_df_crop)
|
||||
|
||||
# 2. --- Plan A: 尝试解析机读码 ---
|
||||
rad_data = parse_ocr_rad(ocr_text)
|
||||
# 2. --- 尝试 Plan A (OCR-rad) ---
|
||||
ocr_text_rad = pytesseract.image_to_string(thresh, lang='swe', config='--psm 7') # psm 7: 假设为单行文本
|
||||
rad_data = parse_ocr_rad(ocr_text_rad)
|
||||
|
||||
if rad_data:
|
||||
# Plan A 成功!
|
||||
return rad_data
|
||||
# Plan A 成功! (高置信度)
|
||||
# **覆盖** Plan B 的支付字段, 但 *保留* Plan B 的 invoice_number
|
||||
final_data["source"] = rad_data["source"]
|
||||
final_data["ocr_number"] = rad_data["ocr_number"]
|
||||
final_data["ocr_number_valid"] = rad_data["ocr_number_valid"]
|
||||
final_data["amount_due"] = rad_data["amount_due"]
|
||||
final_data["amount_due_valid"] = rad_data["amount_due_valid"]
|
||||
final_data["bankgiro_plusgiro"] = rad_data["bankgiro_plusgiro"]
|
||||
# (如果 Plan A 能确定 BG/PG, 也可以覆盖 account_type)
|
||||
|
||||
# 3. --- Plan C (全页回退发票号码) ---
|
||||
if final_data["invoice_number"] is None:
|
||||
print("Plan B 未能从裁剪区找到 InvoiceNr, 启动 Plan C (全页搜索)...")
|
||||
full_page_inv_result = find_invoice_number_full_page(full_image_df)
|
||||
|
||||
if full_page_inv_result:
|
||||
final_data["invoice_number"] = full_page_inv_result.get("invoice_number")
|
||||
final_data["source"] += " + Full-Page-InvNr"
|
||||
|
||||
# 3. --- Plan B: Plan A 失败, 回退到人工读取 ---
|
||||
# 我们重新运行 OCR, 使用 psm 3 (自动布局), 这对人工区域更友好
|
||||
ocr_text_human = pytesseract.image_to_string(thresh, lang='swe', config='--psm 3')
|
||||
|
||||
human_data = parse_human_readable(ocr_text_human)
|
||||
|
||||
# 即使回退失败, 也返回空字典 (或部分数据)
|
||||
return human_data
|
||||
return final_data
|
||||
|
||||
except Exception as e:
|
||||
print(f"提取时出错: {e}")
|
||||
return {"error": f"提取时出错: {e}"}
|
||||
|
||||
|
||||
@@ -195,15 +346,36 @@ async def extract_invoice_data(file: UploadFile = File(...)):
|
||||
|
||||
try:
|
||||
contents = await file.read()
|
||||
nparr = np.frombuffer(contents, np.uint8)
|
||||
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
|
||||
content_type = file.content_type
|
||||
img = None
|
||||
|
||||
if content_type == "application/pdf":
|
||||
if not POPPLER_PATH or not os.path.exists(POPPLER_PATH):
|
||||
raise HTTPException(status_code=500, detail="PDF 处理未在服务器上配置 (POPPLER_PATH)。")
|
||||
try:
|
||||
images = convert_from_bytes(contents, poppler_path=POPPLER_PATH)
|
||||
if images:
|
||||
img = cv2.cvtColor(np.array(images[0]), cv2.COLOR_RGB2BGR)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"PDF 处理失败: {e}")
|
||||
elif content_type in ["image/jpeg", "image/png", "image/bmp"]:
|
||||
nparr = np.frombuffer(contents, np.uint8)
|
||||
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
|
||||
else:
|
||||
raise HTTPException(status_code=415, detail=f"不支持的文件类型: {content_type}。")
|
||||
|
||||
if img is None:
|
||||
raise HTTPException(status_code=400, detail="无法解码图像文件。")
|
||||
raise HTTPException(status_code=400, detail="无法解码文件。")
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"读取文件时出错: {e}")
|
||||
|
||||
# 1. 运行 YOLO 检测 (阶段一)
|
||||
try:
|
||||
full_image_df = pytesseract.image_to_data(img, lang='swe', config='--psm 3', output_type=pytesseract.Output.DATAFRAME)
|
||||
except Exception as e:
|
||||
print(f"全页 OCR 失败: {e}")
|
||||
full_image_df = pd.DataFrame()
|
||||
|
||||
results = ml_models["yolo"](img, verbose=False)
|
||||
|
||||
all_extractions = []
|
||||
@@ -212,16 +384,11 @@ async def extract_invoice_data(file: UploadFile = File(...)):
|
||||
return {"message": "未在图像中检测到支付凭证区域。"}
|
||||
|
||||
for res in results:
|
||||
# 遍历所有检测到的凭证 (通常只有一个)
|
||||
for box_coords in res.boxes.xyxy.cpu().numpy().astype(int):
|
||||
xmin, ymin, xmax, ymax = box_coords
|
||||
|
||||
# 2. 裁剪图像
|
||||
crop = img[ymin:ymax, xmin:xmax]
|
||||
|
||||
# 3. 在裁剪图上运行 "Plan A/B" 提取 (阶段二)
|
||||
extracted_data = extract_info_from_crop(crop)
|
||||
extracted_data["bounding_box"] = [xmin, ymin, xmax, ymax]
|
||||
extracted_data = extract_info_from_crop(crop, full_image_df)
|
||||
extracted_data["bounding_box"] = [int(xmin), int(ymin), int(xmax), int(ymax)]
|
||||
all_extractions.append(extracted_data)
|
||||
|
||||
if not all_extractions:
|
||||
@@ -233,5 +400,8 @@ if __name__ == "__main__":
|
||||
import uvicorn
|
||||
print(f"--- 启动 FastAPI 服务 ---")
|
||||
print(f"加载模型: {MODEL_PATH}")
|
||||
if not POPPLER_PATH or not os.path.exists(POPPLER_PATH):
|
||||
print(f"警告: POPPLER_PATH 未设置或无效。PDF 上传将失败。")
|
||||
print(f"访问 http://127.0.0.1:8000/docs 查看 API 文档")
|
||||
uvicorn.run(app, host="127.0.0.1", port=8000)
|
||||
uvicorn.run(app, host="127.0.0.1", port=8000)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user