109 lines
4.2 KiB
Python
109 lines
4.2 KiB
Python
# app/routers/documents.py
|
||
import uuid
|
||
import mimetypes
|
||
import base64
|
||
from fastapi import APIRouter, UploadFile, File, HTTPException
|
||
from typing import Dict, Any, List
|
||
from fastapi.concurrency import run_in_threadpool
|
||
from PIL import Image
|
||
from io import BytesIO
|
||
|
||
from .. import agents
|
||
from ..core.pdf_processor import convert_pdf_to_images, image_to_base64_str
|
||
|
||
# Router grouping every document-processing endpoint under the "/documents" prefix.
router = APIRouter(prefix="/documents", tags=["Document Processing"])
|
||
|
||
# In-memory stand-in for a SQL results table: maps a result's "doc_id" to its
# final result dict. Being a plain module-level dict, nothing is persisted.
db_results: Dict[str, Any] = dict()
|
||
|
||
async def multimodal_process_pipeline(doc_id: str, image: Image.Image, page_num: int):
    """Classify one page image, extract its fields when applicable, and store the result.

    Args:
        doc_id: Identifier of the document the page belongs to.
        image: The page as a PIL image.
        page_num: Page number; used in log output and in the stored result key.

    Returns:
        The result dict, which is also saved in ``db_results`` under the key
        ``"<doc_id>_page_<page_num>"``.
    """
    # Encoding is delegated to a worker thread so the event loop is not blocked.
    encoded_page = await run_in_threadpool(image_to_base64_str, image)

    classification = await agents.agent_classify_document_from_image(encoded_page)
    category, language = classification.category, classification.language
    print(f"Document page {page_num} classified as: {category}, Language: {language}")

    # Only receipts and invoices have an extraction agent; anything else is skipped.
    extracted = None
    if category == "RECEIPT":
        extracted = await agents.agent_extract_receipt_info(encoded_page, language)
    elif category == "INVOICE":
        extracted = await agents.agent_extract_invoice_info(encoded_page, language)
    else:
        print(f"Document classified as '{category}', skipping extraction.")

    if extracted:
        extraction_data = extracted.dict()
    else:
        extraction_data = None

    page_key = f"{doc_id}_page_{page_num}"
    page_result = {
        "doc_id": page_key,
        "category": category,
        "language": language,
        "extraction_data": extraction_data,
        "status": "Processed",
    }
    db_results[page_key] = page_result
    return page_result
|
||
|
||
@router.post("/process", summary="upload and process a document")
async def upload_and_process_document(file: UploadFile = File(...)):
    """Accept an uploaded PDF, PNG or JPEG, classify it, and extract structured data.

    The file type is guessed from the filename. A PDF is converted to a list of
    images; a PNG/JPEG is opened as a single image. All images are encoded and
    handed to the classification agent, and, for receipts and invoices, to the
    matching extraction agent. The result is stored in ``db_results`` under the
    newly generated document ID and returned.

    Args:
        file: The uploaded file.

    Returns:
        A dict with ``doc_id``, ``page_count``, ``category``, ``language``,
        ``extraction_data`` (``None`` when extraction was skipped) and ``status``.

    Raises:
        HTTPException: 400 when the upload has no filename, the file type is
            unsupported, or no images could be obtained from the document;
            500 for any other failure during processing.
    """
    if not file.filename:
        raise HTTPException(status_code=400, detail="No file provided.")

    doc_id = str(uuid.uuid4())
    print(f"\nReceiving document: {file.filename} (allocated ID: {doc_id})")
    contents = await file.read()

    try:
        # The type is inferred from the filename extension, not from the content.
        file_type = mimetypes.guess_type(file.filename)[0]
        print(f"File type: {file_type}")

        images: List[Image.Image] = []
        if file_type == 'application/pdf':
            images = await run_in_threadpool(convert_pdf_to_images, contents)
        elif file_type in ['image/png', 'image/jpeg']:
            images.append(Image.open(BytesIO(contents)))
        else:
            raise HTTPException(status_code=400, detail=f"Unsupported file type: {file_type}.")

        if not images:
            raise HTTPException(status_code=400, detail="Could not extract images from document.")

        images_base64 = [await run_in_threadpool(image_to_base64_str, img) for img in images]

        # NOTE(review): the whole list of encodings is passed to the agents here,
        # whereas multimodal_process_pipeline passes a single string -- confirm
        # the agent functions accept both forms.
        classification_result = await agents.agent_classify_document_from_image(images_base64)
        category = classification_result.category
        language = classification_result.language
        print(f"The document is classified as: {category}, Language: {language}")

        # Only receipts and invoices have an extraction agent.
        extraction_result = None
        if category == "RECEIPT":
            extraction_result = await agents.agent_extract_receipt_info(images_base64, language)
        elif category == "INVOICE":
            extraction_result = await agents.agent_extract_invoice_info(images_base64, language)
        else:
            print(f"The document is classified as '{category}',skipping extraction。")

        # Assemble the unified result, store it, and return it to the caller.
        final_result = {
            "doc_id": doc_id,
            "page_count": len(images),
            "category": category,
            "language": language,
            "extraction_data": extraction_result.dict() if extraction_result else None,
            "status": "Processed"
        }
        db_results[doc_id] = final_result
        return final_result

    except HTTPException:
        # Client errors raised above (e.g. the 400s) must keep their own status
        # code instead of being rewrapped as a 500 by the generic handler below.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}") from e
|
||
|
||
@router.get("/results/{doc_id}", summary="Get result by doc_id")
async def get_result(doc_id: str):
    """Return the stored processing result for ``doc_id``.

    Raises:
        HTTPException: 404 when nothing is stored under that ID.
    """
    if doc_id not in db_results:
        raise HTTPException(status_code=404, detail="Document not found.")
    return db_results[doc_id]
|