2025-08-11 14:20:56 +02:00
parent 0a80400720
commit f077c6351d
17 changed files with 165 additions and 248 deletions

.idea/AmazingDoc.iml generated
View File

@@ -5,7 +5,7 @@
</component>
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="inheritedJdk" />
<orderEntry type="jdk" jdkName="AmazingDoc" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
<component name="PyDocumentationSettings">

.idea/misc.xml generated
View File

@@ -3,5 +3,5 @@
<component name="Black">
<option name="sdkName" value="Python 3.13" />
</component>
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.13" project-jdk-type="Python SDK" />
<component name="ProjectRootManager" version="2" project-jdk-name="AmazingDoc" project-jdk-type="Python SDK" />
</project>

View File

@@ -1,51 +0,0 @@
# app/agents.py
import asyncio
import random
from .schemas import ReceiptInfo, InvoiceInfo, ReceiptItem

# --- Core agent functions (placeholder/mock implementations) ---
# In a real application these functions would be replaced with real logic that calls LangChain and an LLM.

async def agent_classify_document(text: str) -> str:
    """Agent 1: Document classification (mock)."""
    print("--- [Agent 1] Classifying document...")
    await asyncio.sleep(0.5)  # Simulate network latency
    doc_types = ["信件", "收据", "发票", "合约"]  # letter, receipt, invoice, contract
    if "发票" in text: return "发票"  # invoice keyword
    if "收据" in text or "小票" in text: return "收据"  # receipt keywords
    if "合同" in text or "协议" in text: return "合约"  # contract/agreement keywords
    return random.choice(doc_types)

async def agent_extract_receipt_info(text: str) -> ReceiptInfo:
    """Agent 2: Receipt information extraction (mock)."""
    print("--- [Agent 2] Extracting receipt information...")
    await asyncio.sleep(1)  # Simulate LLM processing time
    return ReceiptInfo(
        merchant_name="模拟超市",  # "Mock Supermarket"
        transaction_date="2025-08-10",
        total_amount=198.50,
        items=[ReceiptItem(name="牛奶", quantity=2, price=11.5)]  # "milk"
    )

async def agent_extract_invoice_info(text: str) -> InvoiceInfo:
    """Agent 3: Invoice information extraction (mock)."""
    print("--- [Agent 3] Extracting invoice information...")
    await asyncio.sleep(1)  # Simulate LLM processing time
    return InvoiceInfo(
        invoice_number="INV123456789",
        issue_date="2025-08-09",
        seller_name="模拟科技有限公司",  # "Mock Technology Co., Ltd."
        total_amount_in_figures=12000.00
    )

async def agent_vectorize_and_store(doc_id: str, text: str, category: str, vector_db: dict):
    """Agent 4: Vectorize and store (mock)."""
    print(f"--- [Agent 4] Vectorizing document (ID: {doc_id})...")
    await asyncio.sleep(0.5)
    chunks = [text[i:i+200] for i in range(0, len(text), 200)]
    vector_db[doc_id] = {
        "metadata": {"category": category, "chunk_count": len(chunks)},
        "content_chunks": chunks,
        "vectors": [random.random() for _ in range(len(chunks) * 128)]
    }
    print(f"--- [Agent 4] Document {doc_id} stored in the vector database.")

View File

@@ -1,7 +1,4 @@
# app/agents/__init__.py
# This file makes it easy to import all agents from the 'agents' package.
# We are importing the function with its correct name now.
from .classification_agent import agent_classify_document_from_text
from .classification_agent import agent_classify_document_from_image
from .receipt_agent import agent_extract_receipt_info
from .invoice_agent import agent_extract_invoice_info
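
With these re-exports in place, callers can pull every agent from a single namespace. A minimal sketch, assuming the app package root is on the import path:

from app.agents import (
    agent_classify_document_from_image,
    agent_extract_receipt_info,
    agent_extract_invoice_info,
)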

View File

@@ -1,42 +1,48 @@
# app/agents/classification_agent.py
from langchain.prompts import PromptTemplate
from langchain_core.messages import HumanMessage
from langchain_core.output_parsers import PydanticOutputParser
from langchain.prompts import PromptTemplate
from ..core.llm import llm
from ..schemas import ClassificationResult  # Import the new schema
from ..schemas import ClassificationResult
from typing import List
# 1. Set up the PydanticOutputParser
parser = PydanticOutputParser(pydantic_object=ClassificationResult)
# 2. Update the prompt template to ask for the language and include format instructions
classification_template = """
You are a professional document analysis assistant. Please perform two tasks on the following text:
1. Determine its category. The category must be one of: ["LETTER", "INVOICE", "RECEIPT", "CONTRACT", "OTHER"].
2. Detect the primary language of the text. Return the language as a two-letter ISO 639-1 code (e.g., "en" for English, "zh" for Chinese, "es" for Spanish).
You are a professional document analysis assistant. The following images represent pages from a single document. Please perform two tasks based on all pages provided:
1. Determine the overall category of the document. The category must be one of: ["LETTER", "INVOICE", "RECEIPT", "CONTRACT", "OTHER"].
2. Detect the primary language of the document. Return the language as a two-letter ISO 639-1 code (e.g., "en" for English, "zh" for Chinese).
Please provide a single response for the entire document in the requested JSON format.
{format_instructions}
[Document Text Start]
{document_text}
[Document Text End]
"""
classification_prompt = PromptTemplate(
    template=classification_template,
    input_variables=["document_text"],
    input_variables=[],
    partial_variables={"format_instructions": parser.get_format_instructions()},
)

# 3. Create the new LangChain chain
classification_chain = classification_prompt | llm | parser

async def agent_classify_document_from_image(images_base64: List[str]) -> ClassificationResult:
    """Agent 1: Classifies an entire document (multiple pages) and detects its language from a list of images."""
    print(f"--- [Agent 1] Calling multimodal LLM for classification of a {len(images_base64)}-page document...")
async def agent_classify_document_from_text(text: str) -> ClassificationResult:
    """Agent 1: Classify document and detect language from OCR-extracted text."""
    print("--- [Agent 1] Calling LLM for classification and language detection...")
    if not text.strip():
        print("--- [Agent 1] Text content is empty, classifying as 'OTHER'.")
        return ClassificationResult(category="OTHER", language="unknown")
    prompt_text = await classification_prompt.aformat()
    # Invoke the chain and return the Pydantic object
    result = await classification_chain.ainvoke({"document_text": text})
    # Create a list of content parts, starting with the text prompt
    content_parts = [{"type": "text", "text": prompt_text}]
    # Add each image to the content list
    for image_base64 in images_base64:
        content_parts.append({
            "type": "image_url",
            "image_url": {"url": f"data:image/png;base64,{image_base64}"},
        })
    msg = HumanMessage(content=content_parts)
    chain = llm | parser
    result = await chain.ainvoke([msg])
    return result
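
For reference, a minimal calling sketch for the new image-based entry point; this assumes the app package is importable, and page1.png/page2.png are placeholder page scans, not files from this commit:

import asyncio
import base64

from app.agents import agent_classify_document_from_image

def png_to_base64(path: str) -> str:
    # Read a PNG page from disk and base64-encode it.
    with open(path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf-8")

async def main():
    pages = [png_to_base64(p) for p in ("page1.png", "page2.png")]
    result = await agent_classify_document_from_image(pages)
    print(result.category, result.language)

asyncio.run(main())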

View File

@@ -4,10 +4,10 @@ from langchain_core.output_parsers import PydanticOutputParser
from langchain.prompts import PromptTemplate
from ..core.llm import llm
from ..schemas import InvoiceInfo
from typing import List
parser = PydanticOutputParser(pydantic_object=InvoiceInfo)
# The prompt now includes the detailed rules for each field using snake_case.
invoice_template = """
You are an expert data entry clerk AI. Your primary goal is to extract information from an invoice image with the highest possible accuracy.
The document's primary language is '{language}'.
@@ -30,6 +30,9 @@ Carefully analyze the invoice image and extract the following fields according t
- `customer_address_region`: This is the receiver's region. If not found, find the region of the extracted city or country. If unclear, leave as an empty string.
- `customer_address_care_of`: This is the receiver's 'care of' (c/o) line. If not found or unclear, leave as an empty string.
- `billo_id`: To find this, think step-by-step: 1. Find the customer_address. 2. Scan the address for a pattern of three letters, an optional space, three digits, an optional dash, and one alphanumeric character (e.g., 'ABC 123-X' or 'DEF 456Z'). 3. If found, extract it. If not found or unclear, leave as an empty string.
- `bank_giro`: If found, extract the bank giro number. It often follows patterns like 'ddd-dddd', 'dddd-dddd', or 'dddddddd #41#'. If not found or unclear, leave as an empty string.
- `plus_giro`: If found, extract the plus giro number. It often follows patterns like 'ddddddd-d #16#', 'ddddddd-d', or 'ddd dd dd-d'. If not found or unclear, leave as an empty string.
- `customer_ssn`: If found, extract the customer social security number (personnummer). It follows the pattern 'YYYYMMDD-XXXX' or 'YYMMDD-XXXX'. If not found or unclear, leave as an empty string.
- `line_items`: Extract all line items from the invoice. For each item, extract the `description`, `quantity`, `unit_price`, and `total_price`. If a value is not present, leave it as null.
## Example:
@@ -41,33 +44,31 @@ If the invoice shows a line item "Consulting Services | 2 hours | $100.00/hr | $
"unit_price": 100.00,
"total_price": 200.00
}}
```
Your Task:
Now, analyze the provided image and output the full JSON object according to the format below.
{format_instructions}
"""
invoice_prompt = PromptTemplate(
    template=invoice_template,
    input_variables=["language"],
    partial_variables={"format_instructions": parser.get_format_instructions()},
    partial_variables={"format_instructions": parser.get_format_instructions()}
)

async def agent_extract_invoice_info(image_base64: str, language: str) -> InvoiceInfo:
    """Agent 3: Extracts invoice information from an image, aware of the document's language."""
async def agent_extract_invoice_info(images_base64: List[str], language: str) -> InvoiceInfo:
    """Agent 3: Extracts invoice information from a list of images, aware of the document's language."""
    print(f"--- [Agent 3] Calling multimodal LLM to extract invoice info (Language: {language})...")
    prompt_text = await invoice_prompt.aformat(language=language)
    msg = HumanMessage(
        content=[
            {"type": "text", "text": prompt_text},
            {
                "type": "image_url",
                "image_url": f"data:image/png;base64,{image_base64}",
            },
        ]
    )
    content_parts = [{"type": "text", "text": prompt_text}]
    for image_base64 in images_base64:
        content_parts.append({
            "type": "image_url",
            "image_url": {"url": f"data:image/png;base64,{image_base64}"},
        })
    msg = HumanMessage(content=content_parts)
    chain = llm | parser
    invoice_info = await chain.ainvoke([msg])
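
A hedged usage sketch for the new multi-page signature, assuming pages_b64 already holds the base64-encoded pages of one invoice:

import asyncio

from app.agents import agent_extract_invoice_info

async def demo(pages_b64):
    # One call now covers every page of the document.
    info = await agent_extract_invoice_info(pages_b64, language="en")
    print(info.dict())  # the full structured InvoiceInfo

# asyncio.run(demo(pages_b64))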

View File

@@ -4,15 +4,15 @@ from langchain_core.output_parsers import PydanticOutputParser
from langchain.prompts import PromptTemplate
from ..core.llm import llm
from ..schemas import ReceiptInfo
from typing import List
parser = PydanticOutputParser(pydantic_object=ReceiptInfo)
# Update the prompt template to include language information
receipt_template = """
You are a highly accurate receipt information extraction robot.
The document's primary language is '{language}'.
Please extract all key information from the following receipt image.
If some information is not present in the image, leave it as null.
Please extract all key information from the following receipt images, which belong to a single document.
If some information is not present in the images, leave it as null.
Please strictly follow the JSON format below, without adding any extra explanations or comments.
{format_instructions}
@@ -25,22 +25,21 @@ receipt_prompt = PromptTemplate(
)
async def agent_extract_receipt_info(image_base64: str, language: str) -> ReceiptInfo:
    """Agent 2: Extracts receipt information from an image, aware of the document's language."""
async def agent_extract_receipt_info(images_base64: List[str], language: str) -> ReceiptInfo:
    """Agent 2: Extracts receipt information from a list of images, aware of the document's language."""
    print(f"--- [Agent 2] Calling multimodal LLM to extract receipt info (Language: {language})...")
    prompt_text = await receipt_prompt.aformat(language=language)
    msg = HumanMessage(
        content=[
            {"type": "text", "text": prompt_text},
            {
                "type": "image_url",
                "image_url": f"data:image/png;base64,{image_base64}",
            },
        ]
    )
    content_parts = [{"type": "text", "text": prompt_text}]
    for image_base64 in images_base64:
        content_parts.append({
            "type": "image_url",
            "image_url": {"url": f"data:image/png;base64,{image_base64}"},
        })
    msg = HumanMessage(content=content_parts)
    chain = llm | parser
    receipt_info = await chain.ainvoke([msg])
    return receipt_info
    return receipt_info
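
The calling convention mirrors the invoice agent; a sketch assuming pages_b64 holds the encoded pages of one receipt:

import asyncio

from app.agents import agent_extract_receipt_info

async def demo(pages_b64):
    receipt = await agent_extract_receipt_info(pages_b64, language="zh")
    for item in receipt.items or []:
        print(item.name, item.quantity, item.price)

# asyncio.run(demo(pages_b64))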

View File

@@ -2,32 +2,32 @@
from langchain.text_splitter import RecursiveCharacterTextSplitter
from ..core.vector_store import vector_store, embedding_model
# 初始化文本分割器,用于将长文档切成小块
# Initialize the text splitter to divide long documents into smaller chunks
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=500,   # size of each chunk, in characters
    chunk_overlap=50, # overlap between adjacent chunks
    chunk_size=500,
    chunk_overlap=50,
)

def agent_vectorize_and_store(doc_id: str, text: str, category: str):
    """Agent 4: 向量化并存储 (真实实现)"""
    print(f"--- [Agent 4] 正在向量化文档 (ID: {doc_id})...")
    """Agent 4: Vectorization and Storage (Real Implementation)"""
    print(f"--- [Agent 4] Vectorizing document (ID: {doc_id})...")
    # 1. 将文档文本分割成块
    # 1. Split the document text into chunks
    chunks = text_splitter.split_text(text)
    print(f"--- [Agent 4] 文档被切分为 {len(chunks)} 个块。")
    print(f"--- [Agent 4] Document split into {len(chunks)} chunks.")
    if not chunks:
        print(f"--- [Agent 4] 文档内容为空,跳过向量化。")
        print(f"--- [Agent 4] Document is empty, skipping vectorization.")
        return
    # 2. 为每个块创建唯一的ID和元数据
    # 2. Create a unique ID and metadata for each chunk
    chunk_ids = [f"{doc_id}_{i}" for i in range(len(chunks))]
    metadatas = [{"doc_id": doc_id, "category": category, "chunk_number": i} for i in range(len(chunks))]
    # 3. 使用嵌入模型为所有块生成向量
    # 3. Use an embedding model to generate vectors for all chunks
    embeddings = embedding_model.embed_documents(chunks)
    # 4. 将ID、向量、元数据和文本块本身添加到ChromaDB
    # 4. Add the IDs, vectors, metadata, and text chunks to ChromaDB
    vector_store.add(
        ids=chunk_ids,
        embeddings=embeddings,
@@ -35,4 +35,4 @@ def agent_vectorize_and_store(doc_id: str, text: str, category: str):
        metadatas=metadatas
    )
    print(f"--- [Agent 4] 文档 {doc_id} 的向量已存入ChromaDB。")
    print(f"--- [Agent 4] Vectors for document {doc_id} stored in ChromaDB.")

app/core/callbacks.py Normal file
View File

@@ -0,0 +1,18 @@
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult
from typing import Any, Dict

class TokenUsageCallbackHandler(BaseCallbackHandler):
    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        # llm_output can be None for some providers, so guard before calling .get()
        token_usage = (response.llm_output or {}).get('token_usage', {})
        if token_usage:
            prompt_tokens = token_usage.get('prompt_tokens', 0)
            completion_tokens = token_usage.get('completion_tokens', 0)
            total_tokens = token_usage.get('total_tokens', 0)
            print("--- [Token Usage] ---")
            print(f"  Prompt Tokens: {prompt_tokens}")
            print(f"  Completion Tokens: {completion_tokens}")
            print(f"  Total Tokens: {total_tokens}")
            print("---------------------")

View File

@@ -1,45 +1,33 @@
# app/core/llm.py
import os
from dotenv import load_dotenv
from langchain_openai import AzureChatOpenAI, ChatOpenAI
from .callbacks import TokenUsageCallbackHandler

# Load environment variables from the .env file
load_dotenv()

# Read the configured LLM provider
LLM_PROVIDER = os.getenv("LLM_PROVIDER", "openai").lower()
llm = None
print(f"--- [Core] Initializing LLM with provider: {LLM_PROVIDER} ---")
if LLM_PROVIDER == "azure":
    # --- Azure OpenAI configuration ---
    required_vars = [
        "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_API_KEY",
        "OPENAI_API_VERSION", "AZURE_OPENAI_CHAT_DEPLOYMENT_NAME"
    ]
    if not all(os.getenv(var) for var in required_vars):
        raise ValueError("One or more Azure OpenAI environment variables for chat are not set.")
token_callback = TokenUsageCallbackHandler()
if LLM_PROVIDER == "azure":
    llm = AzureChatOpenAI(
        azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
        api_key=os.getenv("AZURE_OPENAI_API_KEY"),
        api_version=os.getenv("OPENAI_API_VERSION"),
        azure_deployment=os.getenv("AZURE_OPENAI_CHAT_DEPLOYMENT_NAME"),
        temperature=0,
        callbacks=[token_callback]
    )
elif LLM_PROVIDER == "openai":
    # --- Standard OpenAI configuration ---
    if not os.getenv("OPENAI_API_KEY"):
        raise ValueError("OPENAI_API_KEY is not set for the 'openai' provider.")
    llm = ChatOpenAI(
        api_key=os.getenv("OPENAI_API_KEY"),
        model_name=os.getenv("OPENAI_MODEL_NAME", "gpt-4o"),
        temperature=0,
        callbacks=[token_callback]
    )
else:
    raise ValueError(f"Unsupported LLM_PROVIDER: {LLM_PROVIDER}. Please use 'azure' or 'openai'.")
    raise ValueError(f"Unsupported LLM_PROVIDER: {LLM_PROVIDER}. Please use 'azure' or 'openai'.")
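
For reference, a matching .env for the 'azure' branch might look like the sketch below; the variable names come from the os.getenv calls above, while every value is a placeholder:

LLM_PROVIDER=azure
AZURE_OPENAI_ENDPOINT=https://<your-resource>.openai.azure.com/
AZURE_OPENAI_API_KEY=<your-key>
OPENAI_API_VERSION=<api-version>
AZURE_OPENAI_CHAT_DEPLOYMENT_NAME=<your-chat-deployment>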

View File

@@ -1,27 +0,0 @@
# app/core/ocr.py
import pytesseract
from PIL import Image

# Note: you must first install Google's Tesseract OCR engine on your system.
# See the earlier installation notes for details.

def extract_text_from_image(image: Image.Image) -> str:
    """
    Extract text from a Pillow Image object using Tesseract OCR.

    Args:
        image: a Pillow Image object.

    Returns:
        The text string extracted from the image.
    """
    try:
        print("--- [Core OCR] Extracting text from the image for classification...")
        # lang='chi_sim+eng' recognizes both Simplified Chinese and English
        text = pytesseract.image_to_string(image, lang='chi_sim+eng')
        print("--- [Core OCR] Text extraction succeeded.")
        return text
    except Exception as e:
        print(f"--- [Core OCR] OCR processing failed: {e}")
        raise IOError(f"OCR processing failed: {e}")

View File

@@ -5,39 +5,20 @@ from io import BytesIO
from typing import List
import base64
# Note: Poppler must be installed.
# - macOS: brew install poppler
# - Ubuntu/Debian: sudo apt-get install poppler-utils
# - Windows: download Poppler and add its bin directory to the system PATH.

def convert_pdf_to_images(pdf_bytes: bytes) -> List[Image.Image]:
    """Convert the byte stream of a PDF file into a list of Pillow Image objects."""
    try:
        print("--- [Core PDF] 正在将PDF转换为图片...")
        print("--- [Core PDF] Converting PDF to images...")
        # --- Start of added code ---
        # Point this directly at the Poppler bin directory on your machine.
        # Be sure to replace the example path below with your real path.
        poppler_path = r"C:\ProgramData\chocolatey\lib\poppler\tools\Library\bin"
        # --- End of added code ---
        # --- Start of modified code ---
        # Pass the poppler_path argument in the call
        images = convert_from_bytes(pdf_bytes)
        # --- End of modified code ---
        print(f"--- [Core PDF] 转换成功,共 {len(images)} ")
        print(f"--- [Core PDF] Converted PDF to images, {len(images)} pages in total.")
        return images
    except Exception as e:
        print(f"--- [Core PDF] PDF转换失败: {e}")
        # Add a friendlier error message
        print("--- [Core PDF] Please make sure Poppler is installed and that the correct poppler_path is specified in the code above.")
        print(f"--- [Core PDF] PDF conversion failed: {e}")
        raise IOError(f"PDF to image conversion failed: {e}")

def image_to_base64_str(image: Image.Image) -> str:
    """Convert a Pillow Image object into a base64-encoded string."""
    buffered = BytesIO()
    image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode('utf-8')
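
Tying the two helpers together for a local file; a sketch where sample.pdf is a placeholder path:

with open("sample.pdf", "rb") as f:
    pages = convert_pdf_to_images(f.read())

pages_b64 = [image_to_base64_str(img) for img in pages]
print(f"Encoded {len(pages_b64)} page(s) for the multimodal agents.")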

View File

@@ -4,10 +4,8 @@ import chromadb
from dotenv import load_dotenv
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
# Load environment variables from the .env file
load_dotenv()

# Read the configured LLM provider
LLM_PROVIDER = os.getenv("LLM_PROVIDER", "openai").lower()
embedding_model = None
@@ -15,7 +13,6 @@ embedding_model = None
print(f"--- [Core] Initializing Embeddings with provider: {LLM_PROVIDER} ---")
if LLM_PROVIDER == "azure":
    # --- Azure OpenAI configuration ---
    required_vars = [
        "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_API_KEY",
        "OPENAI_API_VERSION", "AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME"
@@ -31,7 +28,6 @@ if LLM_PROVIDER == "azure":
    )
elif LLM_PROVIDER == "openai":
    # --- Standard OpenAI configuration ---
    if not os.getenv("OPENAI_API_KEY"):
        raise ValueError("OPENAI_API_KEY is not set for the 'openai' provider.")
@@ -44,7 +40,6 @@ else:
    raise ValueError(f"Unsupported LLM_PROVIDER: {LLM_PROVIDER}. Please use 'azure' or 'openai'.")

# Initialize the ChromaDB client (unchanged)
client = chromadb.PersistentClient(path="./chroma_db")
vector_store = client.get_or_create_collection(
    name="documents",
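
Since get_or_create_collection is idempotent, re-importing this module reuses the on-disk collection. A quick sketch for peeking at its contents:

from app.core.vector_store import vector_store

print(vector_store.count())       # number of stored chunks
print(vector_store.get(limit=1))  # sample one chunk with its metadata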

View File

@@ -1,20 +1,18 @@
# app/main.py
import uvicorn
from fastapi import FastAPI
from .routers import documents  # Import our new documents router
from .routers import documents

app = FastAPI(
    title="混合模式文档处理AI Agent",
    description="一个用于自动分类、提取和处理文档的AI应用框架。",
    version="0.9.0",  # Version bump: modular API routes
    title="Hybrid Mode Document Processing AI Agent",
    description="An AI application framework for automatic document classification, extraction, and processing.",
    version="0.9.0",
)

# Include the documents router in the main app
app.include_router(documents.router)

@app.get("/", tags=["Root"])
async def read_root():
    """A simple root endpoint to check that the service is running."""
    return {"message": "Welcome to the Document Processing AI Agent API!"}

if __name__ == "__main__":

View File

@@ -10,64 +10,54 @@ from io import BytesIO
from .. import agents
from ..core.pdf_processor import convert_pdf_to_images, image_to_base64_str
from ..core.ocr import extract_text_from_image
# 创建一个APIRouter实例
# Create an APIRouter instance
router = APIRouter(
    prefix="/documents",  # Prefix added to every path in this router
    tags=["Document Processing"],  # Tag for this group of endpoints in the API docs
    prefix="/documents",
    tags=["Document Processing"],
)

# Simulate an SQL database to store the final results
db_results: Dict[str, Any] = {}

async def hybrid_process_pipeline(doc_id: str, image: Image.Image, page_num: int):
    """Hybrid processing pipeline"""
    ocr_text = await run_in_threadpool(extract_text_from_image, image)
    classification_result = await agents.agent_classify_document_from_text(ocr_text)
async def multimodal_process_pipeline(doc_id: str, image: Image.Image, page_num: int):
    image_base64 = await run_in_threadpool(image_to_base64_str, image)
    classification_result = await agents.agent_classify_document_from_image(image_base64)
    category = classification_result.category
    language = classification_result.language
    print(f"Document page {page_num} classified as: {category}, Language: {language}")
    extraction_result = None
    if category in ["RECEIPT", "INVOICE"]:
        image_base64 = await run_in_threadpool(image_to_base64_str, image)
        if category == "RECEIPT":
            # Pass the language on to the extraction agent
            extraction_result = await agents.agent_extract_receipt_info(image_base64, language)
        elif category == "INVOICE":
            # Pass the language on to the extraction agent
            extraction_result = await agents.agent_extract_invoice_info(image_base64, language)
    else:
        print(f"Document classified as '{category}', skipping high-precision extraction.")
        print(f"Document classified as '{category}', skipping extraction.")
    final_result = {
        "doc_id": f"{doc_id}_page_{page_num}",
        "category": category,
        "language": language,
        "ocr_text_for_classification": ocr_text,
        "extraction_data": extraction_result.dict() if extraction_result else None,
        "status": "Processed"
    }
    db_results[final_result["doc_id"]] = final_result
    return final_result

@router.post("/process", summary="上传并处理单个文档(混合模式)")
@router.post("/process", summary="Upload and process a document")
async def upload_and_process_document(file: UploadFile = File(...)):
    """Process an uploaded document file (PDF, PNG, JPG)."""
    if not file.filename:
        raise HTTPException(status_code=400, detail="No file provided.")
    doc_id = str(uuid.uuid4())
    print(f"\n接收到新文件: {file.filename} (分配ID: {doc_id})")
    print(f"\nReceiving document: {file.filename} (allocated ID: {doc_id})")
    contents = await file.read()
    try:
        file_type = mimetypes.guess_type(file.filename)[0]
        print(f"检测到文件类型: {file_type}")
        print(f"File type: {file_type}")
        images: List[Image.Image] = []
        if file_type == 'application/pdf':
@@ -80,20 +70,39 @@ async def upload_and_process_document(file: UploadFile = File(...)):
        if not images:
            raise HTTPException(status_code=400, detail="Could not extract images from document.")
        all_page_results = []
        for i, img in enumerate(images):
            page_result = await hybrid_process_pipeline(doc_id, img, i + 1)
            all_page_results.append(page_result)
        images_base64 = [await run_in_threadpool(image_to_base64_str, img) for img in images]
        return all_page_results
        classification_result = await agents.agent_classify_document_from_image(images_base64)
        category = classification_result.category
        language = classification_result.language
        print(f"The document is classified as: {category}, Language: {language}")
        extraction_result = None
        if category in ["RECEIPT", "INVOICE"]:
            if category == "RECEIPT":
                extraction_result = await agents.agent_extract_receipt_info(images_base64, language)
            elif category == "INVOICE":
                extraction_result = await agents.agent_extract_invoice_info(images_base64, language)
        else:
            print(f"The document is classified as '{category}', skipping extraction.")
        # 3. Return a unified result
        final_result = {
            "doc_id": doc_id,
            "page_count": len(images),
            "category": category,
            "language": language,
            "extraction_data": extraction_result.dict() if extraction_result else None,
            "status": "Processed"
        }
        db_results[doc_id] = final_result
        return final_result
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")

@router.get("/results/{doc_id}", summary="根据ID获取处理结果")
@router.get("/results/{doc_id}", summary="Get result by doc_id")
async def get_result(doc_id: str):
    """Fetch the detailed processing result using the doc_id returned after processing."""
    if doc_id in db_results:
        return db_results[doc_id]
    raise HTTPException(status_code=404, detail="Document not found.")
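
A client-side sketch against a locally running instance (requests is an assumed extra dependency, and http://localhost:8000 is the default uvicorn address):

import requests

BASE = "http://localhost:8000"

with open("invoice.pdf", "rb") as f:
    resp = requests.post(
        f"{BASE}/documents/process",
        files={"file": ("invoice.pdf", f, "application/pdf")},
    )
result = resp.json()
print(result["category"], result["language"])

# The same result can be fetched again later by its ID.
stored = requests.get(f"{BASE}/documents/results/{result['doc_id']}").json()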

View File

@@ -2,26 +2,26 @@
from pydantic import BaseModel, Field
from typing import List, Optional
# --- 分类结果模型 ---
# --- Classification Result Model ---
class ClassificationResult(BaseModel):
    """Defines the structured output for the classification agent."""
    category: str = Field(description="The category of the document, must be one of ['LETTER', 'INVOICE', 'RECEIPT', 'CONTRACT', 'OTHER']")
    language: str = Field(description="The detected primary language of the document as a two-letter code (e.g., 'en', 'zh', 'es').")

# --- 现有模型 (无变化) ---
# --- Existing Model (Unchanged) ---
class ReceiptItem(BaseModel):
    name: str = Field(description="购买的项目或服务名称")
    quantity: float = Field(description="项目数量")
    price: float = Field(description="项目单价")
    name: str = Field(description="The name of the purchased item or service")
    quantity: float = Field(description="The quantity of the item")
    price: float = Field(description="The unit price of the item")

class ReceiptInfo(BaseModel):
    merchant_name: Optional[str] = Field(None, description="商户或店铺的名称")
    transaction_date: Optional[str] = Field(None, description="交易日期,格式为 YYYY-MM-DD")
    total_amount: Optional[float] = Field(None, description="收据上的总金额")
    items: Optional[List[ReceiptItem]] = Field(None, description="购买的所有项目列表")
    merchant_name: Optional[str] = Field(None, description="The name of the merchant or store")
    transaction_date: Optional[str] = Field(None, description="The transaction date in the YYYY-MM-DD format")
    total_amount: Optional[float] = Field(None, description="The total amount on the receipt")
    items: Optional[List[ReceiptItem]] = Field(None, description="The list of all purchased items")

# --- 新增: 发票行项目模型 ---
# --- Added: Invoice Line Item Model ---
class LineItem(BaseModel):
    """Defines a single line item from an invoice."""
    description: Optional[str] = Field("", description="The description of the product or service.")
@@ -30,7 +30,6 @@ class LineItem(BaseModel):
    total_price: Optional[float] = Field(None, description="The total price for this line item (quantity * unit_price).")

# --- Invoice Model (Updated) ---
class InvoiceInfo(BaseModel):
    """Defines the detailed, structured information to be extracted from an invoice."""
    date: Optional[str] = Field("", description="Extract in YYYY-MM-DD format. If unclear, leave as an empty string.")
@@ -48,4 +47,7 @@ class InvoiceInfo(BaseModel):
    customer_address_region: Optional[str] = Field("", description="It's the receiver's address region. If not found, find the region of the extracted city or country. If unclear, leave as an empty string.")
    customer_address_care_of: Optional[str] = Field("", description="It's the receiver's address care of. If not found or unclear, leave as an empty string.")
    billo_id: Optional[str] = Field("", description="Extract from customer_address if it exists, following the format 'LLL NNN-A'. If not found or unclear, leave as an empty string.")
    bank_giro: Optional[str] = Field("", description="BankGiro number, e.g., '123-4567'. If not found, leave as an empty string.")
    plus_giro: Optional[str] = Field("", description="PlusGiro number, e.g., '123456-7'. If not found, leave as an empty string.")
    customer_ssn: Optional[str] = Field("", description="Customer's social security number, e.g., 'YYYYMMDD-XXXX'. If not found, leave as an empty string.")
    line_items: List[LineItem] = Field([], description="A list of all line items from the invoice.")
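
These models do double duty: PydanticOutputParser derives the {format_instructions} block in the prompts from them and validates raw LLM output back into typed objects. A small sketch, with a hypothetical minimal model response:

from langchain_core.output_parsers import PydanticOutputParser

from app.schemas import InvoiceInfo

parser = PydanticOutputParser(pydantic_object=InvoiceInfo)
print(parser.get_format_instructions())  # the JSON schema sent to the LLM

# Parsing a minimal (hypothetical) model response into a typed object:
invoice = parser.parse('{"date": "2025-08-09", "line_items": []}')
print(invoice.date)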

View File

@@ -1,11 +1,12 @@
fastapi
uvicorn[standard]
python-dotenv
langchain
langchain-openai
chromadb
fastapi~=0.115.9
uvicorn[standard]~=0.35.0
python-dotenv~=1.1.0
langchain~=0.3.25
langchain-openai~=0.3.16
chromadb~=1.0.16
tiktoken
pdf2image
pdf2image~=1.17.0
python-multipart
pytesseract
Pillow
Pillow~=11.3.0
langchain-core~=0.3.58
pydantic~=2.11.7