Init
This commit is contained in:
@@ -10,64 +10,54 @@ from io import BytesIO
|
||||
|
||||
from .. import agents
|
||||
from ..core.pdf_processor import convert_pdf_to_images, image_to_base64_str
|
||||
from ..core.ocr import extract_text_from_image
|
||||
|
||||
# 创建一个APIRouter实例
|
||||
# Create an APIRouter instance
|
||||
router = APIRouter(
|
||||
prefix="/documents", # 为这个路由下的所有路径添加前缀
|
||||
tags=["Document Processing"], # 在API文档中为这组端点添加标签
|
||||
prefix="/documents",
|
||||
tags=["Document Processing"],
|
||||
)
|
||||
|
||||
# 模拟一个SQL数据库来存储最终结果
|
||||
# Simulate an SQL database to store the final results
|
||||
db_results: Dict[str, Any] = {}
|
||||
|
||||
|
||||
async def hybrid_process_pipeline(doc_id: str, image: Image.Image, page_num: int):
|
||||
"""混合处理流水线"""
|
||||
ocr_text = await run_in_threadpool(extract_text_from_image, image)
|
||||
|
||||
classification_result = await agents.agent_classify_document_from_text(ocr_text)
|
||||
async def multimodal_process_pipeline(doc_id: str, image: Image.Image, page_num: int):
|
||||
image_base64 = await run_in_threadpool(image_to_base64_str, image)
|
||||
classification_result = await agents.agent_classify_document_from_image(image_base64)
|
||||
category = classification_result.category
|
||||
language = classification_result.language
|
||||
print(f"Document page {page_num} classified as: {category}, Language: {language}")
|
||||
|
||||
extraction_result = None
|
||||
if category in ["RECEIPT", "INVOICE"]:
|
||||
image_base64 = await run_in_threadpool(image_to_base64_str, image)
|
||||
if category == "RECEIPT":
|
||||
# 将语言传递给提取Agent
|
||||
extraction_result = await agents.agent_extract_receipt_info(image_base64, language)
|
||||
elif category == "INVOICE":
|
||||
# 将语言传递给提取Agent
|
||||
extraction_result = await agents.agent_extract_invoice_info(image_base64, language)
|
||||
else:
|
||||
print(f"Document classified as '{category}', skipping high-precision extraction.")
|
||||
print(f"Document classified as '{category}', skipping extraction.")
|
||||
|
||||
final_result = {
|
||||
"doc_id": f"{doc_id}_page_{page_num}",
|
||||
"category": category,
|
||||
"language": language,
|
||||
"ocr_text_for_classification": ocr_text,
|
||||
"extraction_data": extraction_result.dict() if extraction_result else None,
|
||||
"status": "Processed"
|
||||
}
|
||||
db_results[final_result["doc_id"]] = final_result
|
||||
return final_result
|
||||
|
||||
|
||||
@router.post("/process", summary="上传并处理单个文档(混合模式)")
|
||||
@router.post("/process", summary="upload and process a document")
|
||||
async def upload_and_process_document(file: UploadFile = File(...)):
|
||||
"""处理上传的文档文件 (PDF, PNG, JPG)"""
|
||||
if not file.filename:
|
||||
raise HTTPException(status_code=400, detail="No file provided.")
|
||||
|
||||
doc_id = str(uuid.uuid4())
|
||||
print(f"\n接收到新文件: {file.filename} (分配ID: {doc_id})")
|
||||
print(f"\nReceiving document: {file.filename} (allocated ID: {doc_id})")
|
||||
contents = await file.read()
|
||||
|
||||
try:
|
||||
file_type = mimetypes.guess_type(file.filename)[0]
|
||||
print(f"检测到文件类型: {file_type}")
|
||||
print(f"File type: {file_type}")
|
||||
|
||||
images: List[Image.Image] = []
|
||||
if file_type == 'application/pdf':
|
||||
@@ -80,20 +70,39 @@ async def upload_and_process_document(file: UploadFile = File(...)):
|
||||
if not images:
|
||||
raise HTTPException(status_code=400, detail="Could not extract images from document.")
|
||||
|
||||
all_page_results = []
|
||||
for i, img in enumerate(images):
|
||||
page_result = await hybrid_process_pipeline(doc_id, img, i + 1)
|
||||
all_page_results.append(page_result)
|
||||
images_base64 = [await run_in_threadpool(image_to_base64_str, img) for img in images]
|
||||
|
||||
return all_page_results
|
||||
classification_result = await agents.agent_classify_document_from_image(images_base64)
|
||||
category = classification_result.category
|
||||
language = classification_result.language
|
||||
print(f"The document is classified as: {category}, Language: {language}")
|
||||
|
||||
extraction_result = None
|
||||
if category in ["RECEIPT", "INVOICE"]:
|
||||
if category == "RECEIPT":
|
||||
extraction_result = await agents.agent_extract_receipt_info(images_base64, language)
|
||||
elif category == "INVOICE":
|
||||
extraction_result = await agents.agent_extract_invoice_info(images_base64, language)
|
||||
else:
|
||||
print(f"The document is classified as '{category}',skipping extraction。")
|
||||
|
||||
# 3. Return a unified result
|
||||
final_result = {
|
||||
"doc_id": doc_id,
|
||||
"page_count": len(images),
|
||||
"category": category,
|
||||
"language": language,
|
||||
"extraction_data": extraction_result.dict() if extraction_result else None,
|
||||
"status": "Processed"
|
||||
}
|
||||
db_results[doc_id] = final_result
|
||||
return final_result
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/results/{doc_id}", summary="根据ID获取处理结果")
|
||||
@router.get("/results/{doc_id}", summary="Get result by doc_id")
|
||||
async def get_result(doc_id: str):
|
||||
"""根据文档处理后返回的 doc_id 获取其详细处理结果。"""
|
||||
if doc_id in db_results:
|
||||
return db_results[doc_id]
|
||||
raise HTTPException(status_code=404, detail="Document not found.")
|
||||
|
||||
Reference in New Issue
Block a user