100 lines
3.8 KiB
Python
100 lines
3.8 KiB
Python
# app/routers/documents.py
|
|
import uuid
|
|
import mimetypes
|
|
import base64
|
|
from fastapi import APIRouter, UploadFile, File, HTTPException
|
|
from typing import Dict, Any, List
|
|
from fastapi.concurrency import run_in_threadpool
|
|
from PIL import Image
|
|
from io import BytesIO
|
|
|
|
from .. import agents
|
|
from ..core.pdf_processor import convert_pdf_to_images, image_to_base64_str
|
|
from ..core.ocr import extract_text_from_image
|
|
|
|
# Router that groups every document-processing endpoint in this module.
router = APIRouter(
    tags=["Document Processing"],  # groups these endpoints under one tag in the API docs
    prefix="/documents",  # prepended to every path registered on this router
)

# In-memory stand-in for a SQL database that holds the final results.
# It is a plain module-level dict, so its contents last only as long as the process.
db_results: Dict[str, Any] = dict()
|
|
|
|
|
|
def _model_to_dict(model: Any) -> Dict[str, Any]:
    """Serialize an extraction result to a plain dict.

    Prefers ``model_dump()`` (the Pydantic v2 API, under which ``dict()`` is
    deprecated). Falls back to ``dict()`` for objects that only provide the
    older API.
    """
    model_dump = getattr(model, "model_dump", None)
    if callable(model_dump):
        return model_dump()
    return model.dict()


async def hybrid_process_pipeline(
    doc_id: str, image: Image.Image, page_num: int
) -> Dict[str, Any]:
    """Run the hybrid OCR + agent pipeline on one page and store the result.

    1. OCR the page with ``extract_text_from_image``. The call runs in a
       thread pool so it does not block the event loop.
    2. Classify the OCR text to obtain a category and a language.
    3. For ``RECEIPT`` / ``INVOICE`` pages only, base64-encode the image and
       pass it, together with the detected language, to the matching
       extraction agent. Other categories skip this step.

    Args:
        doc_id: Identifier of the uploaded document this page belongs to.
        image: The page image.
        page_num: Page number, used to build the per-page result id.

    Returns:
        A dict with these keys:

        - ``doc_id``: ``"<doc_id>_page_<page_num>"``.
        - ``category`` and ``language``.
        - ``ocr_text_for_classification``.
        - ``extraction_data``: the serialized extraction result, or ``None``
          when no extraction ran.
        - ``status``.

        The same dict is also stored in ``db_results`` under that ``doc_id``.
    """
    ocr_text = await run_in_threadpool(extract_text_from_image, image)

    classification_result = await agents.agent_classify_document_from_text(ocr_text)
    category = classification_result.category
    language = classification_result.language
    print(f"Document page {page_num} classified as: {category}, Language: {language}")

    # Select the extraction agent for this category; None means "no extraction".
    # Equality checks (rather than a dict keyed on category) keep the original
    # comparison semantics whatever type the classifier uses for `category`.
    if category == "RECEIPT":
        extractor = agents.agent_extract_receipt_info
    elif category == "INVOICE":
        extractor = agents.agent_extract_invoice_info
    else:
        extractor = None

    extraction_result = None
    if extractor is None:
        print(f"Document classified as '{category}', skipping high-precision extraction.")
    else:
        # Extraction works on the page image rather than the OCR text. The
        # language detected during classification is forwarded to the agent.
        image_base64 = await run_in_threadpool(image_to_base64_str, image)
        extraction_result = await extractor(image_base64, language)

    final_result: Dict[str, Any] = {
        "doc_id": f"{doc_id}_page_{page_num}",
        "category": category,
        "language": language,
        "ocr_text_for_classification": ocr_text,
        "extraction_data": (
            _model_to_dict(extraction_result) if extraction_result is not None else None
        ),
        "status": "Processed",
    }

    db_results[final_result["doc_id"]] = final_result
    return final_result
|
|
|
|
|
|
@router.post("/process", summary="上传并处理单个文档(混合模式)")
async def upload_and_process_document(file: UploadFile = File(...)):
    """Process an uploaded document file (PDF, PNG, JPG) and return one result per page."""
    # Client errors raised by this endpoint (all HTTP 400):
    #   - no filename,
    #   - empty upload,
    #   - unsupported file type,
    #   - no page images produced.
    # Any other failure becomes an HTTP 500 carrying the error text.
    if not file.filename:
        raise HTTPException(status_code=400, detail="No file provided.")

    doc_id = str(uuid.uuid4())
    print(f"\n接收到新文件: {file.filename} (分配ID: {doc_id})")
    contents = await file.read()

    # An empty body would otherwise fail inside the PDF/image decoders and be
    # reported as a server error.
    if not contents:
        raise HTTPException(status_code=400, detail="Uploaded file is empty.")

    try:
        # Guess the type from the file extension first. Fall back to the
        # client-declared Content-Type when the extension is missing or unknown.
        file_type = mimetypes.guess_type(file.filename)[0] or file.content_type
        print(f"检测到文件类型: {file_type}")

        images: List[Image.Image] = []
        if file_type == "application/pdf":
            images = await run_in_threadpool(convert_pdf_to_images, contents)
        elif file_type in ("image/png", "image/jpeg"):
            images.append(Image.open(BytesIO(contents)))
        else:
            raise HTTPException(status_code=400, detail=f"Unsupported file type: {file_type}.")

        if not images:
            raise HTTPException(status_code=400, detail="Could not extract images from document.")

        # Pages are processed one after another, in order, numbered from 1.
        return [
            await hybrid_process_pipeline(doc_id, img, page_num)
            for page_num, img in enumerate(images, start=1)
        ]
    except HTTPException:
        # Let deliberate HTTP errors (the 400s above) through unchanged. Without
        # this clause the generic handler below would turn them into 500s.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}") from e
|
|
|
|
|
|
@router.get("/results/{doc_id}", summary="根据ID获取处理结果")
async def get_result(doc_id: str):
    """Fetch the stored processing result for a doc_id returned by the processing endpoint."""
    # Unknown ids map to HTTP 404 instead of surfacing the KeyError.
    try:
        return db_results[doc_id]
    except KeyError:
        raise HTTPException(status_code=404, detail="Document not found.") from None
|