# app/routers/documents.py
import uuid
import mimetypes
import base64
from fastapi import APIRouter, UploadFile, File, HTTPException
from typing import Dict, Any, List
from fastapi.concurrency import run_in_threadpool
from PIL import Image
from io import BytesIO

from .. import agents
from ..core.pdf_processor import convert_pdf_to_images, image_to_base64_str
from ..core.ocr import extract_text_from_image

# Router for all document-processing endpoints.
router = APIRouter(
    prefix="/documents",  # Prefix prepended to every path in this router
    tags=["Document Processing"],  # Groups these endpoints in the API docs
)

# In-memory stand-in for a SQL database holding final per-page results,
# keyed by "<doc_id>_page_<n>". This is a plain module-level dict: nothing
# is persisted.
db_results: Dict[str, Any] = {}


async def hybrid_process_pipeline(doc_id: str, image: Image.Image, page_num: int) -> Dict[str, Any]:
    """Run the hybrid OCR + agent pipeline on a single page image.

    The page is OCR'd in a worker thread and the text is classified by an
    agent. Pages classified as ``RECEIPT`` or ``INVOICE`` are additionally
    sent (as a base64-encoded image, together with the detected language)
    to the matching high-precision extraction agent; other categories skip
    extraction.

    Args:
        doc_id: Identifier assigned to the uploaded document.
        image: The page image to process.
        page_num: Page number used to build the per-page id (the upload
            endpoint passes 1-based numbers).

    Returns:
        The per-page result dict. It is also stored in ``db_results`` under
        its ``"doc_id"`` key (``"<doc_id>_page_<page_num>"``).
    """
    # OCR is blocking, so keep it off the event loop.
    ocr_text = await run_in_threadpool(extract_text_from_image, image)

    classification_result = await agents.agent_classify_document_from_text(ocr_text)
    category = classification_result.category
    language = classification_result.language
    print(f"Document page {page_num} classified as: {category}, Language: {language}")

    extraction_result = None
    if category in ["RECEIPT", "INVOICE"]:
        image_base64 = await run_in_threadpool(image_to_base64_str, image)
        # The detected language is forwarded to the extraction agent.
        if category == "RECEIPT":
            extraction_result = await agents.agent_extract_receipt_info(image_base64, language)
        elif category == "INVOICE":
            extraction_result = await agents.agent_extract_invoice_info(image_base64, language)
    else:
        print(f"Document classified as '{category}', skipping high-precision extraction.")

    final_result = {
        "doc_id": f"{doc_id}_page_{page_num}",
        "category": category,
        "language": language,
        "ocr_text_for_classification": ocr_text,
        # NOTE(review): `.dict()` assumes a pydantic model; on pydantic v2 this
        # is deprecated in favour of `.model_dump()` — confirm the agent types.
        "extraction_data": extraction_result.dict() if extraction_result else None,
        "status": "Processed",
    }
    db_results[final_result["doc_id"]] = final_result
    return final_result


def _open_image(contents: bytes) -> Image.Image:
    """Decode raw image bytes eagerly, turning undecodable data into a 400."""
    try:
        image = Image.open(BytesIO(contents))
        # Image.open is lazy; load() forces a full decode so corrupt or
        # truncated uploads fail here (a client error) rather than later
        # inside OCR, where they would surface as a 500.
        image.load()
    except OSError as exc:  # PIL.UnidentifiedImageError subclasses OSError
        raise HTTPException(status_code=400, detail="Could not decode image file.") from exc
    return image


@router.post("/process", summary="上传并处理单个文档(混合模式)")
async def upload_and_process_document(file: UploadFile = File(...)):
    """Process an uploaded document (PDF, PNG or JPEG) page by page.

    PDFs are converted to one image per page; PNG/JPEG uploads are treated
    as a single page. Each page is run through ``hybrid_process_pipeline``
    sequentially.

    Returns:
        A list with one result dict per page, in page order.

    Raises:
        HTTPException(400): no filename, unsupported file type, undecodable
            image, or no pages could be extracted.
        HTTPException(500): any other error during processing.
    """
    if not file.filename:
        raise HTTPException(status_code=400, detail="No file provided.")

    doc_id = str(uuid.uuid4())
    print(f"\n接收到新文件: {file.filename} (分配ID: {doc_id})")
    contents = await file.read()

    try:
        # Prefer the extension-based guess; fall back to the client-declared
        # Content-Type when the filename has no (recognised) extension.
        file_type = mimetypes.guess_type(file.filename)[0] or file.content_type
        print(f"检测到文件类型: {file_type}")

        images: List[Image.Image] = []
        if file_type == 'application/pdf':
            images = await run_in_threadpool(convert_pdf_to_images, contents)
        elif file_type in ['image/png', 'image/jpeg']:
            images.append(_open_image(contents))
        else:
            raise HTTPException(status_code=400, detail=f"Unsupported file type: {file_type}.")

        if not images:
            raise HTTPException(status_code=400, detail="Could not extract images from document.")

        all_page_results = []
        for page_num, img in enumerate(images, start=1):
            page_result = await hybrid_process_pipeline(doc_id, img, page_num)
            all_page_results.append(page_result)
        return all_page_results
    except HTTPException:
        # Deliberate client errors raised above must reach the caller with
        # their original status code instead of being rewrapped as 500s.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}") from e


@router.get("/results/{doc_id}", summary="根据ID获取处理结果")
async def get_result(doc_id: str):
    """Return the stored processing result for ``doc_id``.

    ``doc_id`` is the per-page id found in the ``"doc_id"`` field of each
    result returned by ``/documents/process`` (``"<uuid>_page_<n>"``).

    Raises:
        HTTPException(404): no result is stored under ``doc_id``.
    """
    result = db_results.get(doc_id)
    if result is None:
        raise HTTPException(status_code=404, detail="Document not found.")
    return result