diff --git a/.idea/AmazingDoc.iml b/.idea/AmazingDoc.iml
index 5904ce9..67d6524 100644
--- a/.idea/AmazingDoc.iml
+++ b/.idea/AmazingDoc.iml
@@ -5,7 +5,7 @@
-
+
diff --git a/.idea/misc.xml b/.idea/misc.xml
index 1d3ce46..59ad763 100644
--- a/.idea/misc.xml
+++ b/.idea/misc.xml
@@ -3,5 +3,5 @@
-
+
\ No newline at end of file
diff --git a/app/agents.py b/app/agents.py
deleted file mode 100644
index 0e0e95b..0000000
--- a/app/agents.py
+++ /dev/null
@@ -1,51 +0,0 @@
-# app/agents.py
-import asyncio
-import random
-from .schemas import ReceiptInfo, InvoiceInfo, ReceiptItem
-
-# --- Agent core functions (placeholders / mock implementations) ---
-# In a real application, these functions would be replaced with real logic that calls LangChain and an LLM.
-
-async def agent_classify_document(text: str) -> str:
- """Agent 1: 文件分类 (模拟)"""
- print("--- [Agent 1] 正在进行文档分类...")
- await asyncio.sleep(0.5) # 模拟网络延迟
- doc_types = ["信件", "收据", "发票", "合约"]
- if "发票" in text: return "发票"
- if "收据" in text or "小票" in text: return "收据"
- if "合同" in text or "协议" in text: return "合约"
- return random.choice(doc_types)
-
-async def agent_extract_receipt_info(text: str) -> ReceiptInfo:
- """Agent 2: 收据信息提取 (模拟)"""
- print("--- [Agent 2] 正在提取收据信息...")
- await asyncio.sleep(1) # 模拟LLM处理时间
- return ReceiptInfo(
-        merchant_name="Mock Supermarket",
- transaction_date="2025-08-10",
- total_amount=198.50,
-        items=[ReceiptItem(name="Milk", quantity=2, price=11.5)]
- )
-
-async def agent_extract_invoice_info(text: str) -> InvoiceInfo:
- """Agent 3: 发票信息提取 (模拟)"""
- print("--- [Agent 3] 正在提取发票信息...")
- await asyncio.sleep(1) # 模拟LLM处理时间
- return InvoiceInfo(
- invoice_number="INV123456789",
- issue_date="2025-08-09",
-        seller_name="Mock Tech Co., Ltd.",
- total_amount_in_figures=12000.00
- )
-
-async def agent_vectorize_and_store(doc_id: str, text: str, category: str, vector_db: dict):
- """Agent 4: 向量化并存储 (模拟)"""
- print(f"--- [Agent 4] 正在向量化文档 (ID: {doc_id})...")
- await asyncio.sleep(0.5)
- chunks = [text[i:i+200] for i in range(0, len(text), 200)]
- vector_db[doc_id] = {
- "metadata": {"category": category, "chunk_count": len(chunks)},
- "content_chunks": chunks,
- "vectors": [random.random() for _ in range(len(chunks) * 128)]
- }
- print(f"--- [Agent 4] 文档 {doc_id} 已存入向量数据库。")
\ No newline at end of file
diff --git a/app/agents/__init__.py b/app/agents/__init__.py
index ee85a62..49b8c66 100644
--- a/app/agents/__init__.py
+++ b/app/agents/__init__.py
@@ -1,7 +1,4 @@
# app/agents/__init__.py
-
-# This file makes it easy to import all agents from the 'agents' package.
-# We are importing the function with its correct name now.
-from .classification_agent import agent_classify_document_from_text
+from .classification_agent import agent_classify_document_from_image
from .receipt_agent import agent_extract_receipt_info
from .invoice_agent import agent_extract_invoice_info
\ No newline at end of file
diff --git a/app/agents/classification_agent.py b/app/agents/classification_agent.py
index d04b862..42d3ea2 100644
--- a/app/agents/classification_agent.py
+++ b/app/agents/classification_agent.py
@@ -1,42 +1,48 @@
# app/agents/classification_agent.py
-from langchain.prompts import PromptTemplate
+from langchain_core.messages import HumanMessage
from langchain_core.output_parsers import PydanticOutputParser
+from langchain.prompts import PromptTemplate
from ..core.llm import llm
-from ..schemas import ClassificationResult # import the new schema
+from ..schemas import ClassificationResult
+from typing import List
-# 1. Set up the PydanticOutputParser
parser = PydanticOutputParser(pydantic_object=ClassificationResult)
-# 2. Update the prompt template to request the language and include format instructions
classification_template = """
-You are a professional document analysis assistant. Please perform two tasks on the following text:
-1. Determine its category. The category must be one of: ["LETTER", "INVOICE", "RECEIPT", "CONTRACT", "OTHER"].
-2. Detect the primary language of the text. Return the language as a two-letter ISO 639-1 code (e.g., "en" for English, "zh" for Chinese, "es" for Spanish).
+You are a professional document analysis assistant. The following images represent pages from a single document. Please perform two tasks based on all pages provided:
+1. Determine the overall category of the document. The category must be one of: ["LETTER", "INVOICE", "RECEIPT", "CONTRACT", "OTHER"].
+2. Detect the primary language of the document. Return the language as a two-letter ISO 639-1 code (e.g., "en" for English, "zh" for Chinese).
+
+Please provide a single response for the entire document in the requested JSON format.
{format_instructions}
-
-[Document Text Start]
-{document_text}
-[Document Text End]
"""
classification_prompt = PromptTemplate(
template=classification_template,
- input_variables=["document_text"],
+ input_variables=[],
partial_variables={"format_instructions": parser.get_format_instructions()},
)
-# 3. Create the new LangChain chain
-classification_chain = classification_prompt | llm | parser
+async def agent_classify_document_from_image(images_base64: List[str]) -> ClassificationResult:
+ """Agent 1: Classifies an entire document (multiple pages) and detects its language from a list of images."""
+ print(f"--- [Agent 1] Calling multimodal LLM for classification of a {len(images_base64)}-page document...")
-async def agent_classify_document_from_text(text: str) -> ClassificationResult:
- """Agent 1: Classify document and detect language from OCR-extracted text."""
- print("--- [Agent 1] Calling LLM for classification and language detection...")
- if not text.strip():
- print("--- [Agent 1] Text content is empty, classifying as 'OTHER'.")
- return ClassificationResult(category="OTHER", language="unknown")
+ prompt_text = await classification_prompt.aformat()
-    # Invoke the chain and return a Pydantic object
- result = await classification_chain.ainvoke({"document_text": text})
+ # Create a list of content parts, starting with the text prompt
+ content_parts = [{"type": "text", "text": prompt_text}]
+
+ # Add each image to the content list
+ for image_base64 in images_base64:
+ content_parts.append({
+ "type": "image_url",
+ "image_url": {"url": f"data:image/png;base64,{image_base64}"},
+ })
+
+ msg = HumanMessage(content=content_parts)
+
+ chain = llm | parser
+ result = await chain.ainvoke([msg])
return result
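A note on the new entry point: `agent_classify_document_from_image` now takes a list of base64-encoded pages, so a minimal standalone driver might look like the sketch below (the file name and single-page list are illustrative, not part of this change):

```python
# Illustrative driver for the new multi-page classifier (not part of the diff).
import asyncio
import base64

from app.agents.classification_agent import agent_classify_document_from_image

async def main() -> None:
    # "page1.png" is a placeholder; each page goes in as one base64 string.
    with open("page1.png", "rb") as f:
        pages = [base64.b64encode(f.read()).decode("utf-8")]
    result = await agent_classify_document_from_image(pages)
    print(result.category, result.language)  # e.g. "INVOICE", "en"

asyncio.run(main())
```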
diff --git a/app/agents/invoice_agent.py b/app/agents/invoice_agent.py
index b6c2570..ecf3569 100644
--- a/app/agents/invoice_agent.py
+++ b/app/agents/invoice_agent.py
@@ -4,10 +4,10 @@ from langchain_core.output_parsers import PydanticOutputParser
from langchain.prompts import PromptTemplate
from ..core.llm import llm
from ..schemas import InvoiceInfo
+from typing import List
parser = PydanticOutputParser(pydantic_object=InvoiceInfo)
-# The prompt now includes the detailed rules for each field using snake_case.
invoice_template = """
You are an expert data entry clerk AI. Your primary goal is to extract information from an invoice image with the highest possible accuracy.
The document's primary language is '{language}'.
@@ -30,6 +30,9 @@ Carefully analyze the invoice image and extract the following fields according t
- `customer_address_region`: This is the receiver's region. If not found, find the region of the extracted city or country. If unclear, leave as an empty string.
- `customer_address_care_of`: This is the receiver's 'care of' (c/o) line. If not found or unclear, leave as an empty string.
- `billo_id`: To find this, think step-by-step: 1. Find the customer_address. 2. Scan the address for a pattern of three letters, an optional space, three digits, an optional dash, and one alphanumeric character (e.g., 'ABC 123-X' or 'DEF 456Z'). 3. If found, extract it. If not found or unclear, leave as an empty string.
+- `bank_giro`: If found, extract the bank giro number. It often follows patterns like 'ddd-dddd', 'dddd-dddd', or 'dddddddd #41#'. If not found or unclear, leave as an empty string.
+- `plus_giro`: If found, extract the plus giro number. It often follows patterns like 'ddddddd-d #16#', 'ddddddd-d', or 'ddd dd dd-d'. If not found or unclear, leave as an empty string.
+- `customer_ssn`: If found, extract the customer social security number (personnummer). It follows the pattern 'YYYYMMDD-XXXX' or 'YYMMDD-XXXX'. If not found or unclear, leave as an empty string.
- `line_items`: Extract all line items from the invoice. For each item, extract the `description`, `quantity`, `unit_price`, and `total_price`. If a value is not present, leave it as null.
## Example:
@@ -41,33 +44,31 @@ If the invoice shows a line item "Consulting Services | 2 hours | $100.00/hr | $
"unit_price": 100.00,
"total_price": 200.00
}}
-
+```
Your Task:
Now, analyze the provided image and output the full JSON object according to the format below.
{format_instructions}
"""
-
invoice_prompt = PromptTemplate(
template=invoice_template,
input_variables=["language"],
- partial_variables={"format_instructions": parser.get_format_instructions()},
+ partial_variables={"format_instructions": parser.get_format_instructions()}
)
-async def agent_extract_invoice_info(image_base64: str, language: str) -> InvoiceInfo:
- """Agent 3: Extracts invoice information from an image, aware of the document's language."""
+async def agent_extract_invoice_info(images_base64: List[str], language: str) -> InvoiceInfo:
+ """Agent 3: Extracts invoice information from a list of images, aware of the document's language."""
print(f"--- [Agent 3] Calling multimodal LLM to extract invoice info (Language: {language})...")
prompt_text = await invoice_prompt.aformat(language=language)
- msg = HumanMessage(
- content=[
- {"type": "text", "text": prompt_text},
- {
- "type": "image_url",
- "image_url": f"data:image/png;base64,{image_base64}",
- },
- ]
- )
+ content_parts = [{"type": "text", "text": prompt_text}]
+ for image_base64 in images_base64:
+ content_parts.append({
+ "type": "image_url",
+ "image_url": {"url": f"data:image/png;base64,{image_base64}"},
+ })
+
+ msg = HumanMessage(content=content_parts)
chain = llm | parser
invoice_info = await chain.ainvoke([msg])
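The `bank_giro`, `plus_giro`, and `customer_ssn` rules above are loose textual patterns, so extracted values can be sanity-checked downstream with regexes mirroring them. A rough sketch (illustrative only; the '#41#'/'#16#' suffix variants and the Luhn check digits used in real Swedish giro/personnummer validation are ignored here):

```python
# Rough validators for the field rules described in the prompt above.
import re

BANK_GIRO = re.compile(r"\b\d{3,4}-\d{4}\b")                       # ddd-dddd / dddd-dddd
PLUS_GIRO = re.compile(r"\b\d{7}-\d\b|\b\d{3} \d{2} \d{2}-\d\b")   # ddddddd-d / ddd dd dd-d
PERSONNUMMER = re.compile(r"\b\d{6}(?:\d{2})?-\d{4}\b")            # YYMMDD-XXXX / YYYYMMDD-XXXX

sample = "Bankgiro: 123-4567  Plusgiro: 1234567-8  Pnr: 19900101-1234"
print(BANK_GIRO.search(sample).group())     # 123-4567
print(PLUS_GIRO.search(sample).group())     # 1234567-8
print(PERSONNUMMER.search(sample).group())  # 19900101-1234
```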
diff --git a/app/agents/receipt_agent.py b/app/agents/receipt_agent.py
index e1719d0..4608641 100644
--- a/app/agents/receipt_agent.py
+++ b/app/agents/receipt_agent.py
@@ -4,15 +4,15 @@ from langchain_core.output_parsers import PydanticOutputParser
from langchain.prompts import PromptTemplate
from ..core.llm import llm
from ..schemas import ReceiptInfo
+from typing import List
parser = PydanticOutputParser(pydantic_object=ReceiptInfo)
-# Update the prompt template to include language information
receipt_template = """
You are a highly accurate receipt information extraction robot.
The document's primary language is '{language}'.
-Please extract all key information from the following receipt image.
-If some information is not present in the image, leave it as null.
+Please extract all key information from the following receipt images, which belong to a single document.
+If some information is not present in the images, leave it as null.
Please strictly follow the JSON format below, without adding any extra explanations or comments.
{format_instructions}
@@ -25,22 +25,21 @@ receipt_prompt = PromptTemplate(
)
-async def agent_extract_receipt_info(image_base64: str, language: str) -> ReceiptInfo:
- """Agent 2: Extracts receipt information from an image, aware of the document's language."""
+async def agent_extract_receipt_info(images_base64: List[str], language: str) -> ReceiptInfo:
+ """Agent 2: Extracts receipt information from a list of images, aware of the document's language."""
print(f"--- [Agent 2] Calling multimodal LLM to extract receipt info (Language: {language})...")
prompt_text = await receipt_prompt.aformat(language=language)
- msg = HumanMessage(
- content=[
- {"type": "text", "text": prompt_text},
- {
- "type": "image_url",
- "image_url": f"data:image/png;base64,{image_base64}",
- },
- ]
- )
+ content_parts = [{"type": "text", "text": prompt_text}]
+ for image_base64 in images_base64:
+ content_parts.append({
+ "type": "image_url",
+ "image_url": {"url": f"data:image/png;base64,{image_base64}"},
+ })
+
+ msg = HumanMessage(content=content_parts)
chain = llm | parser
receipt_info = await chain.ainvoke([msg])
- return receipt_info
+ return receipt_info
\ No newline at end of file
diff --git a/app/agents/vectorization_agent.py b/app/agents/vectorization_agent.py
index 0c127be..983e568 100644
--- a/app/agents/vectorization_agent.py
+++ b/app/agents/vectorization_agent.py
@@ -2,32 +2,32 @@
from langchain.text_splitter import RecursiveCharacterTextSplitter
from ..core.vector_store import vector_store, embedding_model
-# 初始化文本分割器,用于将长文档切成小块
+# Initialize the text splitter to divide long documents into smaller chunks
text_splitter = RecursiveCharacterTextSplitter(
-    chunk_size=500, # size of each chunk (in characters)
-    chunk_overlap=50, # overlap between adjacent chunks
+ chunk_size=500,
+ chunk_overlap=50,
)
def agent_vectorize_and_store(doc_id: str, text: str, category: str):
- """Agent 4: 向量化并存储 (真实实现)"""
- print(f"--- [Agent 4] 正在向量化文档 (ID: {doc_id})...")
+ """Agent 4: Vectorization and Storage (Real Implementation)"""
+ print(f"--- [Agent 4] Vectorizing document (ID: {doc_id})...")
- # 1. 将文档文本分割成块
+ # 1. Split the document text into chunks
chunks = text_splitter.split_text(text)
- print(f"--- [Agent 4] 文档被切分为 {len(chunks)} 个块。")
+ print(f"--- [Agent 4] Document split into {len(chunks)} chunks.")
if not chunks:
- print(f"--- [Agent 4] 文档内容为空,跳过向量化。")
+ print(f"--- [Agent 4] Document is empty, skipping vectorization.")
return
- # 2. 为每个块创建唯一的ID和元数据
+ # 2. Create a unique ID and metadata for each chunk
chunk_ids = [f"{doc_id}_{i}" for i in range(len(chunks))]
metadatas = [{"doc_id": doc_id, "category": category, "chunk_number": i} for i in range(len(chunks))]
- # 3. 使用嵌入模型为所有块生成向量
+ # 3. Use an embedding model to generate vectors for all chunks
embeddings = embedding_model.embed_documents(chunks)
- # 4. 将ID、向量、元数据和文本块本身添加到ChromaDB
+ # 4. Add the IDs, vectors, metadata, and text chunks to ChromaDB
vector_store.add(
ids=chunk_ids,
embeddings=embeddings,
@@ -35,4 +35,4 @@ def agent_vectorize_and_store(doc_id: str, text: str, category: str):
metadatas=metadatas
)
- print(f"--- [Agent 4] 文档 {doc_id} 的向量已存入ChromaDB。")
+ print(f"--- [Agent 4] document {doc_id} stored in ChromaDB。")
diff --git a/app/core/callbacks.py b/app/core/callbacks.py
new file mode 100644
index 0000000..e4b878e
--- /dev/null
+++ b/app/core/callbacks.py
@@ -0,0 +1,18 @@
+from langchain_core.callbacks import BaseCallbackHandler
+from langchain_core.outputs import LLMResult
+from typing import Any
+
+class TokenUsageCallbackHandler(BaseCallbackHandler):
+    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+        # llm_output can be None for providers that report no usage, so guard before .get()
+        token_usage = (response.llm_output or {}).get('token_usage', {})
+
+ if token_usage:
+ prompt_tokens = token_usage.get('prompt_tokens', 0)
+ completion_tokens = token_usage.get('completion_tokens', 0)
+ total_tokens = token_usage.get('total_tokens', 0)
+
+ print("--- [Token Usage] ---")
+ print(f" Prompt Tokens: {prompt_tokens}")
+ print(f" Completion Tokens: {completion_tokens}")
+ print(f" Total Tokens: {total_tokens}")
+ print("---------------------")
diff --git a/app/core/llm.py b/app/core/llm.py
index 3e07072..d4f9412 100644
--- a/app/core/llm.py
+++ b/app/core/llm.py
@@ -1,45 +1,33 @@
-# app/core/llm.py
import os
from dotenv import load_dotenv
from langchain_openai import AzureChatOpenAI, ChatOpenAI
+from .callbacks import TokenUsageCallbackHandler
-# Load environment variables from the .env file
load_dotenv()
-
-# Get the configured LLM provider
LLM_PROVIDER = os.getenv("LLM_PROVIDER", "openai").lower()
-
llm = None
print(f"--- [Core] Initializing LLM with provider: {LLM_PROVIDER} ---")
-if LLM_PROVIDER == "azure":
-    # --- Azure OpenAI configuration ---
- required_vars = [
- "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_API_KEY",
- "OPENAI_API_VERSION", "AZURE_OPENAI_CHAT_DEPLOYMENT_NAME"
- ]
- if not all(os.getenv(var) for var in required_vars):
- raise ValueError("One or more Azure OpenAI environment variables for chat are not set.")
+token_callback = TokenUsageCallbackHandler()
+if LLM_PROVIDER == "azure":
llm = AzureChatOpenAI(
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
api_key=os.getenv("AZURE_OPENAI_API_KEY"),
api_version=os.getenv("OPENAI_API_VERSION"),
azure_deployment=os.getenv("AZURE_OPENAI_CHAT_DEPLOYMENT_NAME"),
temperature=0,
+ callbacks=[token_callback]
)
elif LLM_PROVIDER == "openai":
-    # --- Standard OpenAI configuration ---
- if not os.getenv("OPENAI_API_KEY"):
- raise ValueError("OPENAI_API_KEY is not set for the 'openai' provider.")
-
llm = ChatOpenAI(
api_key=os.getenv("OPENAI_API_KEY"),
model_name=os.getenv("OPENAI_MODEL_NAME", "gpt-4o"),
temperature=0,
+ callbacks=[token_callback]
)
else:
- raise ValueError(f"Unsupported LLM_PROVIDER: {LLM_PROVIDER}. Please use 'azure' or 'openai'.")
\ No newline at end of file
+ raise ValueError(f"Unsupported LLM_PROVIDER: {LLM_PROVIDER}. Please use 'azure' or 'openai'.")
diff --git a/app/core/ocr.py b/app/core/ocr.py
deleted file mode 100644
index b91bca5..0000000
--- a/app/core/ocr.py
+++ /dev/null
@@ -1,27 +0,0 @@
-# app/core/ocr.py
-import pytesseract
-from PIL import Image
-
-
-# Note: Google's Tesseract OCR engine must be installed on your system first.
-# See the earlier installation instructions for details.
-
-def extract_text_from_image(image: Image.Image) -> str:
- """
-    Extract text from a Pillow Image object using Tesseract OCR.
-
-    Args:
-        image: a Pillow Image object.
-
-    Returns:
-        The text extracted from the image, as a string.
- """
- try:
- print("--- [Core OCR] 正在从图片中提取文本用于分类...")
- # lang='chi_sim+eng' 表示同时识别简体中文和英文
- text = pytesseract.image_to_string(image, lang='chi_sim+eng')
- print("--- [Core OCR] 文本提取成功。")
- return text
- except Exception as e:
- print(f"--- [Core OCR] OCR处理失败: {e}")
- raise IOError(f"OCR processing failed: {e}")
\ No newline at end of file
diff --git a/app/core/pdf_processor.py b/app/core/pdf_processor.py
index 0aa3759..1118262 100644
--- a/app/core/pdf_processor.py
+++ b/app/core/pdf_processor.py
@@ -5,39 +5,20 @@ from io import BytesIO
from typing import List
import base64
-
-# Note: Poppler must be installed on your system.
-# - macOS: brew install poppler
-# - Ubuntu/Debian: sudo apt-get install poppler-utils
-# - Windows: download Poppler and add its bin directory to the system PATH.
-
def convert_pdf_to_images(pdf_bytes: bytes) -> List[Image.Image]:
- """将PDF文件的字节流转换为Pillow Image对象列表。"""
try:
- print("--- [Core PDF] 正在将PDF转换为图片...")
+ print("--- [Core PDF] Converting PDF to images...")
-        # --- Added code start ---
-        # Specify the path to Poppler's bin directory on your machine here.
-        # Be sure to replace the example path below with your actual path.
-        poppler_path = r"C:\ProgramData\chocolatey\lib\poppler\tools\Library\bin"
-        # --- Added code end ---
-
-        # --- Modified code start ---
-        # Pass the poppler_path parameter in the call
images = convert_from_bytes(pdf_bytes)
-        # --- Modified code end ---
- print(f"--- [Core PDF] 转换成功,共 {len(images)} 页。")
+ print(f"--- [Core PDF] converted PDF to images,total {len(images)} pages。")
return images
except Exception as e:
- print(f"--- [Core PDF] PDF转换失败: {e}")
-        # Add a friendlier error hint
-        print("--- [Core PDF] Please make sure Poppler is installed on your system and that the correct poppler_path is set in the code above.")
+ print(f"--- [Core PDF] PDF conversion failed: {e}")
raise IOError(f"PDF to image conversion failed: {e}")
def image_to_base64_str(image: Image.Image) -> str:
- """将Pillow Image对象转换为Base64编码的字符串。"""
buffered = BytesIO()
image.save(buffered, format="PNG")
return base64.b64encode(buffered.getvalue()).decode('utf-8')
\ No newline at end of file
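With the hard-coded `poppler_path` gone, `convert_from_bytes` relies on Poppler being available on the system PATH. A minimal round-trip sketch using the two helpers above (`sample.pdf` is a placeholder path):

```python
# Illustrative round trip (not part of the diff): PDF bytes -> page images
# -> base64 strings, mirroring what the documents router does.
from app.core.pdf_processor import convert_pdf_to_images, image_to_base64_str

with open("sample.pdf", "rb") as f:  # placeholder input file
    pages = convert_pdf_to_images(f.read())

pages_base64 = [image_to_base64_str(img) for img in pages]
print(f"{len(pages_base64)} page(s); first page is {len(pages_base64[0])} base64 chars")
```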
diff --git a/app/core/vector_store.py b/app/core/vector_store.py
index 89fc5f0..f0ecd3a 100644
--- a/app/core/vector_store.py
+++ b/app/core/vector_store.py
@@ -4,10 +4,8 @@ import chromadb
from dotenv import load_dotenv
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
-# Load environment variables from the .env file
load_dotenv()
-# Get the configured LLM provider
LLM_PROVIDER = os.getenv("LLM_PROVIDER", "openai").lower()
embedding_model = None
@@ -15,7 +13,6 @@ embedding_model = None
print(f"--- [Core] Initializing Embeddings with provider: {LLM_PROVIDER} ---")
if LLM_PROVIDER == "azure":
-    # --- Azure OpenAI configuration ---
required_vars = [
"AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_API_KEY",
"OPENAI_API_VERSION", "AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME"
@@ -31,7 +28,6 @@ if LLM_PROVIDER == "azure":
)
elif LLM_PROVIDER == "openai":
-    # --- Standard OpenAI configuration ---
if not os.getenv("OPENAI_API_KEY"):
raise ValueError("OPENAI_API_KEY is not set for the 'openai' provider.")
@@ -44,7 +40,6 @@ else:
raise ValueError(f"Unsupported LLM_PROVIDER: {LLM_PROVIDER}. Please use 'azure' or 'openai'.")
-# Initialize the ChromaDB client (unchanged)
client = chromadb.PersistentClient(path="./chroma_db")
vector_store = client.get_or_create_collection(
name="documents",
diff --git a/app/main.py b/app/main.py
index 55a9de1..8b66e75 100644
--- a/app/main.py
+++ b/app/main.py
@@ -1,20 +1,18 @@
# app/main.py
import uvicorn
from fastapi import FastAPI
-from .routers import documents # import our new documents router
+from .routers import documents
app = FastAPI(
- title="混合模式文档处理AI Agent",
- description="一个用于自动分类、提取和处理文档的AI应用框架。",
- version="0.9.0", # 版本升级: 模块化API路由
+ title="Hybrid Mode Document Processing AI Agent",
+ description="An AI application framework for automatic document classification, extraction, and processing.",
+ version="0.9.0",
)
-# Include the documents router in the main app
app.include_router(documents.router)
@app.get("/", tags=["Root"])
async def read_root():
- """一个简单的根端点,用于检查服务是否正在运行。"""
return {"message": "Welcome to the Document Processing AI Agent API!"}
if __name__ == "__main__":
diff --git a/app/routers/documents.py b/app/routers/documents.py
index ad99729..8d3e9ec 100644
--- a/app/routers/documents.py
+++ b/app/routers/documents.py
@@ -10,64 +10,54 @@ from io import BytesIO
from .. import agents
from ..core.pdf_processor import convert_pdf_to_images, image_to_base64_str
-from ..core.ocr import extract_text_from_image
-# 创建一个APIRouter实例
+# Create an APIRouter instance
router = APIRouter(
- prefix="/documents", # 为这个路由下的所有路径添加前缀
- tags=["Document Processing"], # 在API文档中为这组端点添加标签
+ prefix="/documents",
+ tags=["Document Processing"],
)
-# 模拟一个SQL数据库来存储最终结果
+# Simulate an SQL database to store the final results
db_results: Dict[str, Any] = {}
-
-async def hybrid_process_pipeline(doc_id: str, image: Image.Image, page_num: int):
- """混合处理流水线"""
- ocr_text = await run_in_threadpool(extract_text_from_image, image)
-
- classification_result = await agents.agent_classify_document_from_text(ocr_text)
+async def multimodal_process_pipeline(doc_id: str, image: Image.Image, page_num: int):
+ image_base64 = await run_in_threadpool(image_to_base64_str, image)
+    classification_result = await agents.agent_classify_document_from_image([image_base64])  # the agent now expects a list of pages
category = classification_result.category
language = classification_result.language
print(f"Document page {page_num} classified as: {category}, Language: {language}")
extraction_result = None
if category in ["RECEIPT", "INVOICE"]:
- image_base64 = await run_in_threadpool(image_to_base64_str, image)
if category == "RECEIPT":
-            # Pass the language to the extraction agent
-            extraction_result = await agents.agent_extract_receipt_info(image_base64, language)
+            extraction_result = await agents.agent_extract_receipt_info([image_base64], language)
elif category == "INVOICE":
-            # Pass the language to the extraction agent
-            extraction_result = await agents.agent_extract_invoice_info(image_base64, language)
+            extraction_result = await agents.agent_extract_invoice_info([image_base64], language)
else:
- print(f"Document classified as '{category}', skipping high-precision extraction.")
+ print(f"Document classified as '{category}', skipping extraction.")
final_result = {
"doc_id": f"{doc_id}_page_{page_num}",
"category": category,
"language": language,
- "ocr_text_for_classification": ocr_text,
"extraction_data": extraction_result.dict() if extraction_result else None,
"status": "Processed"
}
db_results[final_result["doc_id"]] = final_result
return final_result
-
-@router.post("/process", summary="上传并处理单个文档(混合模式)")
+@router.post("/process", summary="upload and process a document")
async def upload_and_process_document(file: UploadFile = File(...)):
- """处理上传的文档文件 (PDF, PNG, JPG)"""
if not file.filename:
raise HTTPException(status_code=400, detail="No file provided.")
doc_id = str(uuid.uuid4())
- print(f"\n接收到新文件: {file.filename} (分配ID: {doc_id})")
+ print(f"\nReceiving document: {file.filename} (allocated ID: {doc_id})")
contents = await file.read()
try:
file_type = mimetypes.guess_type(file.filename)[0]
- print(f"检测到文件类型: {file_type}")
+ print(f"File type: {file_type}")
images: List[Image.Image] = []
if file_type == 'application/pdf':
@@ -80,20 +70,39 @@ async def upload_and_process_document(file: UploadFile = File(...)):
if not images:
raise HTTPException(status_code=400, detail="Could not extract images from document.")
- all_page_results = []
- for i, img in enumerate(images):
- page_result = await hybrid_process_pipeline(doc_id, img, i + 1)
- all_page_results.append(page_result)
+ images_base64 = [await run_in_threadpool(image_to_base64_str, img) for img in images]
- return all_page_results
+ classification_result = await agents.agent_classify_document_from_image(images_base64)
+ category = classification_result.category
+ language = classification_result.language
+ print(f"The document is classified as: {category}, Language: {language}")
+
+ extraction_result = None
+ if category in ["RECEIPT", "INVOICE"]:
+ if category == "RECEIPT":
+ extraction_result = await agents.agent_extract_receipt_info(images_base64, language)
+ elif category == "INVOICE":
+ extraction_result = await agents.agent_extract_invoice_info(images_base64, language)
+ else:
+ print(f"The document is classified as '{category}',skipping extraction。")
+
+    # Return a single, unified result for the whole document
+ final_result = {
+ "doc_id": doc_id,
+ "page_count": len(images),
+ "category": category,
+ "language": language,
+ "extraction_data": extraction_result.dict() if extraction_result else None,
+ "status": "Processed"
+ }
+ db_results[doc_id] = final_result
+ return final_result
except Exception as e:
raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
-
-@router.get("/results/{doc_id}", summary="根据ID获取处理结果")
+@router.get("/results/{doc_id}", summary="Get result by doc_id")
async def get_result(doc_id: str):
- """根据文档处理后返回的 doc_id 获取其详细处理结果。"""
if doc_id in db_results:
return db_results[doc_id]
raise HTTPException(status_code=404, detail="Document not found.")
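End to end, the reworked endpoint now returns one result per document instead of one per page. A minimal client sketch (httpx, the local URL, and the file name are assumptions, not part of this change):

```python
# Illustrative client for the reworked endpoints (not part of the diff).
import httpx

with httpx.Client(base_url="http://127.0.0.1:8000") as client:
    with open("invoice.pdf", "rb") as f:  # placeholder input file
        resp = client.post(
            "/documents/process",
            files={"file": ("invoice.pdf", f, "application/pdf")},
        )
    resp.raise_for_status()
    result = resp.json()
    print(result["category"], result["language"], result["page_count"])

    # The same payload can be re-fetched later by its doc_id.
    again = client.get(f"/documents/results/{result['doc_id']}")
    print(again.json()["status"])  # "Processed"
```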
diff --git a/app/schemas.py b/app/schemas.py
index d3b879f..530011c 100644
--- a/app/schemas.py
+++ b/app/schemas.py
@@ -2,26 +2,26 @@
from pydantic import BaseModel, Field
from typing import List, Optional
-# --- 分类结果模型 ---
+# --- Classification Result Model ---
class ClassificationResult(BaseModel):
"""Defines the structured output for the classification agent."""
category: str = Field(description="The category of the document, must be one of ['LETTER', 'INVOICE', 'RECEIPT', 'CONTRACT', 'OTHER']")
language: str = Field(description="The detected primary language of the document as a two-letter code (e.g., 'en', 'zh', 'es').")
-# --- 现有模型 (无变化) ---
+# --- Existing Model (Unchanged) ---
class ReceiptItem(BaseModel):
- name: str = Field(description="购买的项目或服务名称")
- quantity: float = Field(description="项目数量")
- price: float = Field(description="项目单价")
+ name: str = Field(description="The name of the purchased item or service")
+ quantity: float = Field(description="The quantity of the item")
+ price: float = Field(description="The unit price of the item")
class ReceiptInfo(BaseModel):
- merchant_name: Optional[str] = Field(None, description="商户或店铺的名称")
- transaction_date: Optional[str] = Field(None, description="交易日期,格式为 YYYY-MM-DD")
- total_amount: Optional[float] = Field(None, description="收据上的总金额")
- items: Optional[List[ReceiptItem]] = Field(None, description="购买的所有项目列表")
+ merchant_name: Optional[str] = Field(None, description="The name of the merchant or store")
+ transaction_date: Optional[str] = Field(None, description="The transaction date in the YYYY-MM-DD format")
+ total_amount: Optional[float] = Field(None, description="The total amount on the receipt")
+ items: Optional[List[ReceiptItem]] = Field(None, description="The list of all purchased items")
-# --- 新增: 发票行项目模型 ---
+# --- Added: Invoice Line Item Model ---
class LineItem(BaseModel):
"""Defines a single line item from an invoice."""
description: Optional[str] = Field("", description="The description of the product or service.")
@@ -30,7 +30,6 @@ class LineItem(BaseModel):
total_price: Optional[float] = Field(None, description="The total price for this line item (quantity * unit_price).")
-# --- Invoice model (updated) ---
class InvoiceInfo(BaseModel):
"""Defines the detailed, structured information to be extracted from an invoice."""
date: Optional[str] = Field("", description="Extract in YYYY-MM-DD format. If unclear, leave as an empty string.")
@@ -48,4 +47,7 @@ class InvoiceInfo(BaseModel):
customer_address_region: Optional[str] = Field("", description="It's the receiver's address region. If not found, find the region of the extracted city or country. If unclear, leave as an empty string.")
customer_address_care_of: Optional[str] = Field("", description="It's the receiver's address care of. If not found or unclear, leave as an empty string.")
billo_id: Optional[str] = Field("", description="Extract from customer_address if it exists, following the format 'LLL NNN-A'. If not found or unclear, leave as an empty string.")
+ bank_giro: Optional[str] = Field("", description="BankGiro number, e.g., '123-4567'. If not found, leave as an empty string.")
+ plus_giro: Optional[str] = Field("", description="PlusGiro number, e.g., '123456-7'. If not found, leave as an empty string.")
+ customer_ssn: Optional[str] = Field("", description="Customer's social security number, e.g., 'YYYYMMDD-XXXX'. If not found, leave as an empty string.")
line_items: List[LineItem] = Field([], description="A list of all line items from the invoice.")
\ No newline at end of file
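Because the new fields default to empty values, payloads stored before this change still validate. A quick pydantic v2 sketch, assuming the `InvoiceInfo` fields outside this hunk keep the same optional-with-default pattern visible here:

```python
# Illustrative (not part of the diff): the new giro/SSN fields default to "".
from app.schemas import InvoiceInfo

info = InvoiceInfo.model_validate_json('{"date": "2025-08-09"}')
print(repr(info.bank_giro), repr(info.plus_giro), repr(info.customer_ssn))  # '' '' ''
print(info.line_items)  # []
```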
diff --git a/requirements.txt b/requirements.txt
index 4c8dfa0..a0d6349 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,11 +1,12 @@
-fastapi
-uvicorn[standard]
-python-dotenv
-langchain
-langchain-openai
-chromadb
+fastapi~=0.115.9
+uvicorn[standard]~=0.35.0
+python-dotenv~=1.1.0
+langchain~=0.3.25
+langchain-openai~=0.3.16
+chromadb~=1.0.16
tiktoken
-pdf2image
+pdf2image~=1.17.0
python-multipart
-pytesseract
-Pillow
+Pillow~=11.3.0
+langchain-core~=0.3.58
+pydantic~=2.11.7
\ No newline at end of file