Files
AmazingDoc/app/routers/documents.py
Yaojia Wang 0a80400720 Init project
2025-08-11 00:07:41 +02:00

100 lines
3.8 KiB
Python

# app/routers/documents.py
import uuid
import mimetypes
import base64
from fastapi import APIRouter, UploadFile, File, HTTPException
from typing import Dict, Any, List
from fastapi.concurrency import run_in_threadpool
from PIL import Image
from io import BytesIO
from .. import agents
from ..core.pdf_processor import convert_pdf_to_images, image_to_base64_str
from ..core.ocr import extract_text_from_image
# In-memory stand-in for a SQL database. Keys are the per-page ids built by
# hybrid_process_pipeline ("<uuid>_page_<n>"); values are that page's final
# result dict.
# NOTE(review): nothing in this module ever evicts entries, and a module-level
# dict lives only inside the current process, so results accumulate for the
# process lifetime and are not visible to other worker processes.
db_results: Dict[str, Any] = dict()

# Router that groups every document-processing endpoint of this module.
router = APIRouter(
    tags=["Document Processing"],  # section heading for these endpoints in the generated API docs
    prefix="/documents",  # prepended to every route path registered on this router
)
def _model_to_dict(model: Any) -> Dict[str, Any]:
    """Serialise an extraction-result model to a plain dict.

    Prefers ``model_dump()`` (Pydantic v2, where ``.dict()`` is deprecated)
    and falls back to ``.dict()`` for objects that do not provide it
    (e.g. Pydantic v1 models), so the output is the same either way.
    """
    dump = getattr(model, "model_dump", None)
    if callable(dump):
        return dump()
    return model.dict()


async def hybrid_process_pipeline(doc_id: str, image: Image.Image, page_num: int):
    """Run the hybrid OCR + agent pipeline on a single document page.

    Steps:
      1. OCR the page image in a worker thread to obtain plain text.
      2. Classify that text with the classification agent, which yields a
         category and a language.
      3. Only for ``RECEIPT`` / ``INVOICE`` pages: base64-encode the image
         (in a worker thread) and pass it, together with the detected
         language, to the matching extraction agent.
      4. Store the combined result in ``db_results`` and return it.

    Args:
        doc_id: Id of the uploaded document; combined with ``page_num`` to
            form the per-page result key ``"<doc_id>_page_<page_num>"``.
        image: The page image to process.
        page_num: Page number used in the log line and in the result key
            (the upload endpoint in this module passes 1-based numbers).

    Returns:
        A dict with keys ``doc_id`` (the per-page key), ``category``,
        ``language``, ``ocr_text_for_classification``, ``extraction_data``
        (a dict, or ``None`` when no extraction was run) and ``status``.
    """
    # OCR is blocking, so keep it off the event loop.
    ocr_text = await run_in_threadpool(extract_text_from_image, image)
    classification_result = await agents.agent_classify_document_from_text(ocr_text)
    category = classification_result.category
    language = classification_result.language
    print(f"Document page {page_num} classified as: {category}, Language: {language}")

    # Pick the high-precision extraction agent for this category. Compare with
    # `==` rather than a dict lookup so Enum-typed categories keep matching.
    extractor = None
    if category == "RECEIPT":
        extractor = agents.agent_extract_receipt_info
    elif category == "INVOICE":
        extractor = agents.agent_extract_invoice_info

    extraction_result = None
    if extractor is None:
        print(f"Document classified as '{category}', skipping high-precision extraction.")
    else:
        # Only pay for base64 encoding when an extraction agent will use it.
        image_base64 = await run_in_threadpool(image_to_base64_str, image)
        # The detected language is forwarded to the extraction agent.
        extraction_result = await extractor(image_base64, language)

    final_result = {
        "doc_id": f"{doc_id}_page_{page_num}",
        "category": category,
        "language": language,
        "ocr_text_for_classification": ocr_text,
        "extraction_data": (
            _model_to_dict(extraction_result) if extraction_result is not None else None
        ),
        "status": "Processed",
    }
    db_results[final_result["doc_id"]] = final_result
    return final_result
@router.post("/process", summary="上传并处理单个文档(混合模式)")
async def upload_and_process_document(file: UploadFile = File(...)):
    """Process an uploaded document file (PDF, PNG or JPEG).

    The file type is detected from the filename extension, falling back to
    the client-declared Content-Type when the extension is unknown. A PDF is
    converted to one image per page; a PNG/JPEG is treated as a single page.
    Each page is run through ``hybrid_process_pipeline`` in order.

    Args:
        file: The uploaded document.

    Returns:
        A list with one result dict per page, in page order.

    Raises:
        HTTPException: 400 when no filename is given, the type is unsupported,
            the image bytes cannot be decoded, or no page images could be
            produced; 500 for any other failure during processing.
    """
    if not file.filename:
        raise HTTPException(status_code=400, detail="No file provided.")
    doc_id = str(uuid.uuid4())
    print(f"\n接收到新文件: {file.filename} (分配ID: {doc_id})")
    contents = await file.read()
    try:
        # Prefer the extension-based guess; fall back to the declared type
        # for files without a recognised extension.
        file_type = mimetypes.guess_type(file.filename)[0] or file.content_type
        print(f"检测到文件类型: {file_type}")
        images: List[Image.Image] = []
        if file_type == 'application/pdf':
            # PDF rasterisation is blocking work; run it in a worker thread.
            images = await run_in_threadpool(convert_pdf_to_images, contents)
        elif file_type in ['image/png', 'image/jpeg']:
            from PIL import UnidentifiedImageError

            try:
                images.append(Image.open(BytesIO(contents)))
            except UnidentifiedImageError as e:
                # Corrupt or mislabelled image data is a client error, not a 500.
                raise HTTPException(
                    status_code=400, detail=f"Could not decode image file: {e}"
                ) from e
        else:
            raise HTTPException(status_code=400, detail=f"Unsupported file type: {file_type}.")
        if not images:
            raise HTTPException(status_code=400, detail="Could not extract images from document.")
        # Pages are processed sequentially, numbered from 1.
        return [
            await hybrid_process_pipeline(doc_id, img, page_num)
            for page_num, img in enumerate(images, start=1)
        ]
    except HTTPException:
        # Deliberate HTTP errors raised above (e.g. the 400s) must keep their
        # status code instead of being rewrapped as a 500 below.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}") from e
@router.get("/results/{doc_id}", summary="根据ID获取处理结果")
async def get_result(doc_id: str):
    """Look up a stored processing result.

    Args:
        doc_id: A page-level id, i.e. the ``"doc_id"`` value found in the
            results returned by ``POST /documents/process``.

    Returns:
        The stored result dict for that id.

    Raises:
        HTTPException: 404 when nothing is stored under ``doc_id``.
    """
    # A private sentinel distinguishes "absent" from any stored value,
    # so a single dict lookup suffices.
    not_found = object()
    stored = db_results.get(doc_id, not_found)
    if stored is not_found:
        raise HTTPException(status_code=404, detail="Document not found.")
    return stored