diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..13566b8
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/.idea/AmazingDoc.iml b/.idea/AmazingDoc.iml
new file mode 100644
index 0000000..5904ce9
--- /dev/null
+++ b/.idea/AmazingDoc.iml
@@ -0,0 +1,18 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+<component name="InspectionProjectProfileManager">
+  <settings>
+    <option name="USE_PROJECT_PROFILE" value="false" />
+    <version value="1.0" />
+  </settings>
+</component>
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000..1d3ce46
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,7 @@
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..f07df23
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="ProjectModuleManager">
+    <modules>
+      <module fileurl="file://$PROJECT_DIR$/.idea/AmazingDoc.iml" filepath="$PROJECT_DIR$/.idea/AmazingDoc.iml" />
+    </modules>
+  </component>
+</project>
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..35eb1dd
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="VcsDirectoryMappings">
+    <mapping directory="$PROJECT_DIR$" vcs="Git" />
+  </component>
+</project>
\ No newline at end of file
diff --git a/app/__init__.py b/app/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/app/agents.py b/app/agents.py
new file mode 100644
index 0000000..0e0e95b
--- /dev/null
+++ b/app/agents.py
@@ -0,0 +1,51 @@
+# app/agents.py
+import asyncio
+import random
+from .schemas import ReceiptInfo, InvoiceInfo, ReceiptItem
+
+# --- Core agent functions (placeholder/mock implementations) ---
+# Note: the real agents live in the app/agents/ package, which shadows this module on import.
+
+async def agent_classify_document(text: str) -> str:
+    """Agent 1: Document classification (mock)."""
+    print("--- [Agent 1] Classifying document...")
+    await asyncio.sleep(0.5)  # Simulate network latency
+    doc_types = ["信件", "收据", "发票", "合约"]  # letter, receipt, invoice, contract
+    if "发票" in text: return "发票"  # "invoice" keyword
+    if "收据" in text or "小票" in text: return "收据"  # "receipt" / "sales slip" keywords
+    if "合同" in text or "协议" in text: return "合约"  # "contract" / "agreement" keywords
+    return random.choice(doc_types)
+
+async def agent_extract_receipt_info(text: str) -> ReceiptInfo:
+    """Agent 2: Receipt information extraction (mock)."""
+    print("--- [Agent 2] Extracting receipt information...")
+    await asyncio.sleep(1)  # Simulate LLM processing time
+    return ReceiptInfo(
+        merchant_name="Mock Supermarket",
+        transaction_date="2025-08-10",
+        total_amount=198.50,
+        items=[ReceiptItem(name="Milk", quantity=2, price=11.5)]
+    )
+
+async def agent_extract_invoice_info(text: str) -> InvoiceInfo:
+    """Agent 3: Invoice information extraction (mock); field names match the InvoiceInfo schema."""
+    print("--- [Agent 3] Extracting invoice information...")
+    await asyncio.sleep(1)  # Simulate LLM processing time
+    return InvoiceInfo(
+        invoice_number="INV123456789",
+        date="2025-08-09",
+        biller_name="Mock Tech Co., Ltd.",
+        amount=12000.00
+    )
+
+async def agent_vectorize_and_store(doc_id: str, text: str, category: str, vector_db: dict):
+    """Agent 4: Vectorize and store (mock)."""
+    print(f"--- [Agent 4] Vectorizing document (ID: {doc_id})...")
+    await asyncio.sleep(0.5)
+    chunks = [text[i:i+200] for i in range(0, len(text), 200)]
+    vector_db[doc_id] = {
+        "metadata": {"category": category, "chunk_count": len(chunks)},
+        "content_chunks": chunks,
+        "vectors": [random.random() for _ in range(len(chunks) * 128)]
+    }
+    print(f"--- [Agent 4] Document {doc_id} stored in the vector database.")
\ No newline at end of file
diff --git a/app/agents/__init__.py b/app/agents/__init__.py
new file mode 100644
index 0000000..ee85a62
--- /dev/null
+++ b/app/agents/__init__.py
@@ -0,0 +1,7 @@
+# app/agents/__init__.py
+
+# Re-export the agents so callers can simply do `from .. import agents`
+# and use each function under its public name.
+from .classification_agent import agent_classify_document_from_text
+from .receipt_agent import agent_extract_receipt_info
+from .invoice_agent import agent_extract_invoice_info
\ No newline at end of file
diff --git a/app/agents/classification_agent.py b/app/agents/classification_agent.py
new file mode 100644
index 0000000..d04b862
--- /dev/null
+++ b/app/agents/classification_agent.py
@@ -0,0 +1,49 @@
+# app/agents/classification_agent.py
+from langchain.prompts import PromptTemplate
+from langchain_core.output_parsers import PydanticOutputParser
+from ..core.llm import llm
+from ..schemas import ClassificationResult  # the classification output schema
+
+# 1. Set up a PydanticOutputParser that parses the LLM output into ClassificationResult
+parser = PydanticOutputParser(pydantic_object=ClassificationResult)
+
+# 2. Prompt template asking for both category and language, with format instructions appended
+classification_template = """
+You are a professional document analysis assistant. Please perform two tasks on the following text:
+1. Determine its category. The category must be one of: ["LETTER", "INVOICE", "RECEIPT", "CONTRACT", "OTHER"].
+2. Detect the primary language of the text. Return the language as a two-letter ISO 639-1 code (e.g., "en" for English, "zh" for Chinese, "es" for Spanish).
+
+{format_instructions}
+
+[Document Text Start]
+{document_text}
+[Document Text End]
+"""
+
+classification_prompt = PromptTemplate(
+    template=classification_template,
+    input_variables=["document_text"],
+    partial_variables={"format_instructions": parser.get_format_instructions()},
+)
+
+# 3. Build the classification chain: prompt -> LLM -> structured parser
+classification_chain = classification_prompt | llm | parser
+
+
+async def agent_classify_document_from_text(text: str) -> ClassificationResult:
+    """Agent 1: Classify document and detect language from OCR-extracted text."""
+    print("--- [Agent 1] Calling LLM for classification and language detection...")
+    if not text.strip():
+        print("--- [Agent 1] Text content is empty, classifying as 'OTHER'.")
+        return ClassificationResult(category="OTHER", language="unknown")
+
+    # Invoke the chain and return the parsed Pydantic object
+    result = await classification_chain.ainvoke({"document_text": text})
+    return result
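+
+# Usage sketch (a hedged example, not part of the pipeline; assumes provider
+# credentials are set in .env, since importing this module initializes the LLM):
+#   import asyncio
+#   from app.agents.classification_agent import agent_classify_document_from_text
+#   result = asyncio.run(agent_classify_document_from_text("INVOICE No. 2025-001 ..."))
+#   print(result.category, result.language)  # e.g. "INVOICE", "en"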
diff --git a/app/agents/invoice_agent.py b/app/agents/invoice_agent.py
new file mode 100644
index 0000000..b6c2570
--- /dev/null
+++ b/app/agents/invoice_agent.py
@@ -0,0 +1,83 @@
+# app/agents/invoice_agent.py
+from langchain_core.messages import HumanMessage
+from langchain_core.output_parsers import PydanticOutputParser
+from langchain.prompts import PromptTemplate
+from ..core.llm import llm
+from ..schemas import InvoiceInfo
+
+parser = PydanticOutputParser(pydantic_object=InvoiceInfo)
+
+# The prompt spells out the extraction rules for each snake_case field.
+invoice_template = """
+You are an expert data entry clerk AI. Your primary goal is to extract information from an invoice image with the highest possible accuracy.
+The document's primary language is '{language}'.
+
+## Instructions:
+Carefully analyze the invoice image and extract the following fields according to these specific rules. Do not invent information. If a field is not found or is unclear, follow the specific instruction for that field.
+
+- `date`: Extract in YYYY-MM-DD format. If unclear, leave as an empty string.
+- `invoice_number`: If not found or unclear, leave as an empty string.
+- `supplier_number`: This is the organisation number. If not found or unclear, leave as an empty string.
+- `biller_name`: This is the sender's name. If not found or unclear, leave as an empty string.
+- `amount`: Extract the final total amount and format it to a decimal number. If not present, leave as null.
+- `customer_name`: This is the receiver's name. Ensure it is a name and clear any special characters. If not found or unclear, leave as an empty string.
+- `customer_address`: This is the receiver's full address. Put it in one line. If not found or unclear, leave as an empty string.
+- `customer_address_line`: This is only the street address line from the receiver's address. If not found or unclear, leave as an empty string.
+- `customer_address_city`: This is the receiver's city. If not found, try to find any city in the document. If unclear, leave as an empty string.
+- `customer_address_country`: This is the receiver's country. If not found, find the country of the extracted city. If unclear, leave as an empty string.
+- `customer_address_postal_code`: This is the receiver's postal code. If not found or unclear, leave as an empty string.
+- `customer_address_apartment`: This is the receiver's apartment or suite number. If not found or unclear, leave as an empty string.
+- `customer_address_region`: This is the receiver's region. If not found, find the region of the extracted city or country. If unclear, leave as an empty string.
+- `customer_address_care_of`: This is the receiver's 'care of' (c/o) line. If not found or unclear, leave as an empty string.
+- `billo_id`: To find this, think step-by-step: 1. Find the customer_address. 2. Scan the address for a pattern of three letters, an optional space, three digits, an optional dash, and one alphanumeric character (e.g., 'ABC 123-X' or 'DEF 456Z'). 3. If found, extract it. If not found or unclear, leave as an empty string.
+- `line_items`: Extract all line items from the invoice. For each item, extract the `description`, `quantity`, `unit_price`, and `total_price`. If a value is not present, leave it as null.
+
+## Example:
+If the invoice shows a line item "Consulting Services | 2 hours | $100.00/hr | $200.00", the output for that line item should be:
+```json
+{{
+    "description": "Consulting Services",
+    "quantity": 2,
+    "unit_price": 100.00,
+    "total_price": 200.00
+}}
+```
+## Your Task:
+Now, analyze the provided image and output the full JSON object according to the format below.
+{format_instructions}
+"""
+
+invoice_prompt = PromptTemplate(
+    template=invoice_template,
+    input_variables=["language"],
+    partial_variables={"format_instructions": parser.get_format_instructions()},
+)
+
+async def agent_extract_invoice_info(image_base64: str, language: str) -> InvoiceInfo:
+    """Agent 3: Extracts invoice information from an image, aware of the document's language."""
+    print(f"--- [Agent 3] Calling multimodal LLM to extract invoice info (Language: {language})...")
+
+    prompt_text = await invoice_prompt.aformat(language=language)
+
+    msg = HumanMessage(
+        content=[
+            {"type": "text", "text": prompt_text},
+            {
+                "type": "image_url",
+                # OpenAI-style image content part: the URL field carries a data URI
+                "image_url": {"url": f"data:image/png;base64,{image_base64}"},
+            },
+        ]
+    )
+
+    chain = llm | parser
+    invoice_info = await chain.ainvoke([msg])
+    return invoice_info
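+
+# Usage sketch (hedged; assumes a local `invoice.png` and a multimodal-capable
+# model such as gpt-4o configured in core/llm.py):
+#   import asyncio, base64
+#   with open("invoice.png", "rb") as f:
+#       b64 = base64.b64encode(f.read()).decode()
+#   info = asyncio.run(agent_extract_invoice_info(b64, language="en"))
+#   print(info.invoice_number, info.amount, len(info.line_items))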
diff --git a/app/agents/receipt_agent.py b/app/agents/receipt_agent.py
new file mode 100644
index 0000000..e1719d0
--- /dev/null
+++ b/app/agents/receipt_agent.py
@@ -0,0 +1,46 @@
+# app/agents/receipt_agent.py
+from langchain_core.messages import HumanMessage
+from langchain_core.output_parsers import PydanticOutputParser
+from langchain.prompts import PromptTemplate
+from ..core.llm import llm
+from ..schemas import ReceiptInfo
+
+parser = PydanticOutputParser(pydantic_object=ReceiptInfo)
+
+# Prompt template; includes the detected document language
+receipt_template = """
+You are a highly accurate receipt information extraction robot.
+The document's primary language is '{language}'.
+Please extract all key information from the following receipt image.
+If some information is not present in the image, leave it as null.
+Please strictly follow the JSON format below, without adding any extra explanations or comments.
+
+{format_instructions}
+"""
+
+receipt_prompt = PromptTemplate(
+ template=receipt_template,
+ input_variables=["language"],
+ partial_variables={"format_instructions": parser.get_format_instructions()},
+)
+
+
+async def agent_extract_receipt_info(image_base64: str, language: str) -> ReceiptInfo:
+ """Agent 2: Extracts receipt information from an image, aware of the document's language."""
+ print(f"--- [Agent 2] Calling multimodal LLM to extract receipt info (Language: {language})...")
+
+ prompt_text = await receipt_prompt.aformat(language=language)
+
+ msg = HumanMessage(
+ content=[
+ {"type": "text", "text": prompt_text},
+ {
+ "type": "image_url",
+ "image_url": f"data:image/png;base64,{image_base64}",
+ },
+ ]
+ )
+
+ chain = llm | parser
+ receipt_info = await chain.ainvoke([msg])
+ return receipt_info
diff --git a/app/agents/vectorization_agent.py b/app/agents/vectorization_agent.py
new file mode 100644
index 0000000..0c127be
--- /dev/null
+++ b/app/agents/vectorization_agent.py
@@ -0,0 +1,44 @@
+# app/agents/vectorization_agent.py
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from ..core.vector_store import vector_store, embedding_model
+
+# Text splitter used to cut long documents into smaller chunks
+text_splitter = RecursiveCharacterTextSplitter(
+    chunk_size=500,  # chunk size, in characters
+    chunk_overlap=50,  # overlap between consecutive chunks
+)
+
+def agent_vectorize_and_store(doc_id: str, text: str, category: str):
+    """Agent 4: Vectorize and store (real implementation)."""
+    print(f"--- [Agent 4] Vectorizing document (ID: {doc_id})...")
+
+    # 1. Split the document text into chunks
+    chunks = text_splitter.split_text(text)
+    print(f"--- [Agent 4] Document was split into {len(chunks)} chunks.")
+
+    if not chunks:
+        print("--- [Agent 4] Document content is empty, skipping vectorization.")
+        return
+
+    # 2. Create a unique ID and metadata entry for each chunk
+    chunk_ids = [f"{doc_id}_{i}" for i in range(len(chunks))]
+    metadatas = [{"doc_id": doc_id, "category": category, "chunk_number": i} for i in range(len(chunks))]
+
+    # 3. Generate embeddings for all chunks with the embedding model
+    embeddings = embedding_model.embed_documents(chunks)
+
+    # 4. Add the IDs, embeddings, metadata, and chunk texts to ChromaDB
+    vector_store.add(
+        ids=chunk_ids,
+        embeddings=embeddings,
+        documents=chunks,
+        metadatas=metadatas
+    )
+
+    print(f"--- [Agent 4] Embeddings for document {doc_id} stored in ChromaDB.")
diff --git a/app/core/__init__.py b/app/core/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/app/core/llm.py b/app/core/llm.py
new file mode 100644
index 0000000..3e07072
--- /dev/null
+++ b/app/core/llm.py
@@ -0,0 +1,58 @@
+# app/core/llm.py
+import os
+from dotenv import load_dotenv
+from langchain_openai import AzureChatOpenAI, ChatOpenAI
+
+# Load environment variables from the .env file
+load_dotenv()
+
+# Read which LLM provider is configured
+LLM_PROVIDER = os.getenv("LLM_PROVIDER", "openai").lower()
+
+llm = None
+
+print(f"--- [Core] Initializing LLM with provider: {LLM_PROVIDER} ---")
+
+if LLM_PROVIDER == "azure":
+    # --- Azure OpenAI configuration ---
+    required_vars = [
+        "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_API_KEY",
+        "OPENAI_API_VERSION", "AZURE_OPENAI_CHAT_DEPLOYMENT_NAME"
+    ]
+    if not all(os.getenv(var) for var in required_vars):
+        raise ValueError("One or more Azure OpenAI environment variables for chat are not set.")
+
+    llm = AzureChatOpenAI(
+        azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
+        api_key=os.getenv("AZURE_OPENAI_API_KEY"),
+        api_version=os.getenv("OPENAI_API_VERSION"),
+        azure_deployment=os.getenv("AZURE_OPENAI_CHAT_DEPLOYMENT_NAME"),
+        temperature=0,
+    )
+
+elif LLM_PROVIDER == "openai":
+    # --- Standard OpenAI configuration ---
+    if not os.getenv("OPENAI_API_KEY"):
+        raise ValueError("OPENAI_API_KEY is not set for the 'openai' provider.")
+
+    llm = ChatOpenAI(
+        api_key=os.getenv("OPENAI_API_KEY"),
+        model=os.getenv("OPENAI_MODEL_NAME", "gpt-4o"),
+        temperature=0,
+    )
+
+else:
+    raise ValueError(f"Unsupported LLM_PROVIDER: {LLM_PROVIDER}. Please use 'azure' or 'openai'.")
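+
+# Example .env (hedged sketch; placeholder values, pick one provider):
+#   LLM_PROVIDER=openai
+#   OPENAI_API_KEY=sk-...
+#   OPENAI_MODEL_NAME=gpt-4o
+#   OPENAI_EMBEDDING_MODEL_NAME=text-embedding-3-small
+# or:
+#   LLM_PROVIDER=azure
+#   AZURE_OPENAI_ENDPOINT=https://<resource>.openai.azure.com/
+#   AZURE_OPENAI_API_KEY=...
+#   OPENAI_API_VERSION=2024-02-01
+#   AZURE_OPENAI_CHAT_DEPLOYMENT_NAME=<chat-deployment>
+#   AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME=<embedding-deployment>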
\ No newline at end of file
diff --git a/app/core/ocr.py b/app/core/ocr.py
new file mode 100644
index 0000000..b91bca5
--- /dev/null
+++ b/app/core/ocr.py
@@ -0,0 +1,32 @@
+# app/core/ocr.py
+import pytesseract
+from PIL import Image
+
+
+# Note: Google's Tesseract OCR engine must be installed on the system first.
+# See the Tesseract documentation for platform-specific installation steps.
+
+def extract_text_from_image(image: Image.Image) -> str:
+    """
+    Extract text from a Pillow Image object using Tesseract OCR.
+
+    Args:
+        image: a Pillow Image object.
+
+    Returns:
+        The text string extracted from the image.
+    """
+    try:
+        print("--- [Core OCR] Extracting text from the image for classification...")
+        # lang='chi_sim+eng' recognizes Simplified Chinese and English together
+        text = pytesseract.image_to_string(image, lang='chi_sim+eng')
+        print("--- [Core OCR] Text extraction succeeded.")
+        return text
+    except Exception as e:
+        print(f"--- [Core OCR] OCR processing failed: {e}")
+        raise IOError(f"OCR processing failed: {e}")
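+
+# Usage sketch (hedged; assumes Tesseract plus the chi_sim language pack are installed,
+# e.g. `sudo apt-get install tesseract-ocr tesseract-ocr-chi-sim` on Debian/Ubuntu):
+#   from PIL import Image
+#   print(extract_text_from_image(Image.open("receipt.png")))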
\ No newline at end of file
diff --git a/app/core/pdf_processor.py b/app/core/pdf_processor.py
new file mode 100644
index 0000000..0aa3759
--- /dev/null
+++ b/app/core/pdf_processor.py
@@ -0,0 +1,43 @@
+# app/core/pdf_processor.py
+import os
+from pdf2image import convert_from_bytes
+from PIL import Image
+from io import BytesIO
+from typing import List, Optional
+import base64
+
+
+# Note: Poppler must be installed:
+# - macOS: brew install poppler
+# - Ubuntu/Debian: sudo apt-get install poppler-utils
+# - Windows: download Poppler and either add its bin directory to PATH or set the
+#   POPPLER_PATH environment variable to it, e.g.
+#   C:\ProgramData\chocolatey\lib\poppler\tools\Library\bin
+
+# Optional explicit path to Poppler's bin directory; None means "use PATH".
+poppler_path: Optional[str] = os.getenv("POPPLER_PATH")
+
+
+def convert_pdf_to_images(pdf_bytes: bytes) -> List[Image.Image]:
+    """Convert the bytes of a PDF file into a list of Pillow Image objects."""
+    try:
+        print("--- [Core PDF] Converting PDF to images...")
+        images = convert_from_bytes(pdf_bytes, poppler_path=poppler_path)
+        print(f"--- [Core PDF] Conversion succeeded, {len(images)} page(s).")
+        return images
+    except Exception as e:
+        print(f"--- [Core PDF] PDF conversion failed: {e}")
+        print("--- [Core PDF] Make sure Poppler is installed and, if it is not on PATH, that POPPLER_PATH points to its bin directory.")
+        raise IOError(f"PDF to image conversion failed: {e}")
+
+
+def image_to_base64_str(image: Image.Image) -> str:
+    """Convert a Pillow Image object into a Base64-encoded string."""
+    buffered = BytesIO()
+    image.save(buffered, format="PNG")
+    return base64.b64encode(buffered.getvalue()).decode('utf-8')
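+
+# Usage sketch (hedged; assumes a local `sample.pdf`):
+#   with open("sample.pdf", "rb") as f:
+#       pages = convert_pdf_to_images(f.read())
+#   b64_first_page = image_to_base64_str(pages[0])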
\ No newline at end of file
diff --git a/app/core/vector_store.py b/app/core/vector_store.py
new file mode 100644
index 0000000..89fc5f0
--- /dev/null
+++ b/app/core/vector_store.py
@@ -0,0 +1,60 @@
+# app/core/vector_store.py
+import os
+import chromadb
+from dotenv import load_dotenv
+from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
+
+# Load environment variables from the .env file
+load_dotenv()
+
+# Read which LLM provider is configured
+LLM_PROVIDER = os.getenv("LLM_PROVIDER", "openai").lower()
+
+embedding_model = None
+
+print(f"--- [Core] Initializing Embeddings with provider: {LLM_PROVIDER} ---")
+
+if LLM_PROVIDER == "azure":
+    # --- Azure OpenAI configuration ---
+    required_vars = [
+        "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_API_KEY",
+        "OPENAI_API_VERSION", "AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME"
+    ]
+    if not all(os.getenv(var) for var in required_vars):
+        raise ValueError("One or more Azure OpenAI environment variables for embeddings are not set.")
+
+    embedding_model = AzureOpenAIEmbeddings(
+        azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
+        api_key=os.getenv("AZURE_OPENAI_API_KEY"),
+        api_version=os.getenv("OPENAI_API_VERSION"),
+        azure_deployment=os.getenv("AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME"),
+    )
+
+elif LLM_PROVIDER == "openai":
+    # --- Standard OpenAI configuration ---
+    if not os.getenv("OPENAI_API_KEY"):
+        raise ValueError("OPENAI_API_KEY is not set for the 'openai' provider.")
+
+    embedding_model = OpenAIEmbeddings(
+        api_key=os.getenv("OPENAI_API_KEY"),
+        model=os.getenv("OPENAI_EMBEDDING_MODEL_NAME", "text-embedding-3-small")
+    )
+
+else:
+    raise ValueError(f"Unsupported LLM_PROVIDER: {LLM_PROVIDER}. Please use 'azure' or 'openai'.")
+
+
+# Initialize the persistent ChromaDB client and the shared document collection
+client = chromadb.PersistentClient(path="./chroma_db")
+vector_store = client.get_or_create_collection(
+    name="documents",
+    metadata={"hnsw:space": "cosine"}
+)
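+
+# Retrieval sketch (hedged; assumes documents were already ingested by Agent 4):
+#   query_vec = embedding_model.embed_query("total amount of the ACME invoice")
+#   hits = vector_store.query(
+#       query_embeddings=[query_vec],
+#       n_results=3,
+#       where={"category": "INVOICE"},  # optional metadata filter
+#   )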
\ No newline at end of file
diff --git a/app/main.py b/app/main.py
new file mode 100644
index 0000000..55a9de1
--- /dev/null
+++ b/app/main.py
@@ -0,0 +1,25 @@
+# app/main.py
+import uvicorn
+from fastapi import FastAPI
+from .routers import documents  # the document-processing router
+
+app = FastAPI(
+    title="Hybrid Document Processing AI Agent",
+    description="An AI application framework for automatically classifying, extracting, and processing documents.",
+    version="0.9.0",
+)
+
+# Mount the document router on the main application
+app.include_router(documents.router)
+
+@app.get("/", tags=["Root"])
+async def read_root():
+    """A simple root endpoint for checking that the service is running."""
+    return {"message": "Welcome to the Document Processing AI Agent API!"}
+
+if __name__ == "__main__":
+    uvicorn.run("app.main:app", host="0.0.0.0", port=8000, reload=True)
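+
+# Usage sketch (hedged): start the server with `python -m app.main` or
+# `uvicorn app.main:app --reload`, then upload a document:
+#   curl -F "file=@invoice.pdf" http://localhost:8000/documents/process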
diff --git a/app/routers/__init__.py b/app/routers/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/app/routers/documents.py b/app/routers/documents.py
new file mode 100644
index 0000000..ad99729
--- /dev/null
+++ b/app/routers/documents.py
@@ -0,0 +1,104 @@
+# app/routers/documents.py
+import uuid
+import mimetypes
+from fastapi import APIRouter, UploadFile, File, HTTPException
+from typing import Dict, Any, List
+from fastapi.concurrency import run_in_threadpool
+from PIL import Image
+from io import BytesIO
+
+from .. import agents
+from ..core.pdf_processor import convert_pdf_to_images, image_to_base64_str
+from ..core.ocr import extract_text_from_image
+
+# Create an APIRouter instance
+router = APIRouter(
+    prefix="/documents",  # prefix for every path in this router
+    tags=["Document Processing"],  # tag for this group of endpoints in the API docs
+)
+
+# In-memory stand-in for a SQL database that stores the final results
+db_results: Dict[str, Any] = {}
+
+
+async def hybrid_process_pipeline(doc_id: str, image: Image.Image, page_num: int):
+    """Hybrid processing pipeline: OCR for classification, multimodal LLM for extraction."""
+    ocr_text = await run_in_threadpool(extract_text_from_image, image)
+
+    classification_result = await agents.agent_classify_document_from_text(ocr_text)
+    category = classification_result.category
+    language = classification_result.language
+    print(f"Document page {page_num} classified as: {category}, Language: {language}")
+
+    extraction_result = None
+    if category in ["RECEIPT", "INVOICE"]:
+        image_base64 = await run_in_threadpool(image_to_base64_str, image)
+        if category == "RECEIPT":
+            # Pass the detected language on to the extraction agent
+            extraction_result = await agents.agent_extract_receipt_info(image_base64, language)
+        elif category == "INVOICE":
+            extraction_result = await agents.agent_extract_invoice_info(image_base64, language)
+    else:
+        print(f"Document classified as '{category}', skipping high-precision extraction.")
+
+    final_result = {
+        "doc_id": f"{doc_id}_page_{page_num}",
+        "category": category,
+        "language": language,
+        "ocr_text_for_classification": ocr_text,
+        "extraction_data": extraction_result.dict() if extraction_result else None,
+        "status": "Processed"
+    }
+    db_results[final_result["doc_id"]] = final_result
+    return final_result
+
+
+@router.post("/process", summary="Upload and process a single document (hybrid mode)")
+async def upload_and_process_document(file: UploadFile = File(...)):
+    """Process an uploaded document file (PDF, PNG, JPG)."""
+    if not file.filename:
+        raise HTTPException(status_code=400, detail="No file provided.")
+
+    doc_id = str(uuid.uuid4())
+    print(f"\nReceived new file: {file.filename} (assigned ID: {doc_id})")
+    contents = await file.read()
+
+    try:
+        file_type = mimetypes.guess_type(file.filename)[0]
+        print(f"Detected file type: {file_type}")
+
+        images: List[Image.Image] = []
+        if file_type == 'application/pdf':
+            images = await run_in_threadpool(convert_pdf_to_images, contents)
+        elif file_type in ['image/png', 'image/jpeg']:
+            images.append(Image.open(BytesIO(contents)))
+        else:
+            raise HTTPException(status_code=400, detail=f"Unsupported file type: {file_type}.")
+
+        if not images:
+            raise HTTPException(status_code=400, detail="Could not extract images from document.")
+
+        all_page_results = []
+        for i, img in enumerate(images):
+            page_result = await hybrid_process_pipeline(doc_id, img, i + 1)
+            all_page_results.append(page_result)
+
+        return all_page_results
+
+    except HTTPException:
+        # Re-raise client errors (e.g. unsupported file type) instead of masking them as 500s
+        raise
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
+
+
+@router.get("/results/{doc_id}", summary="Fetch a processing result by ID")
+async def get_result(doc_id: str):
+    """Look up the detailed processing result for a page-suffixed doc_id returned by /process."""
+    if doc_id in db_results:
+        return db_results[doc_id]
+    raise HTTPException(status_code=404, detail="Document not found.")
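+
+# Usage sketch (hedged): POST /documents/process returns one result per page, and each
+# page is stored under a page-suffixed ID, so a later lookup must use that full ID:
+#   GET /documents/results/<doc_id>_page_1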
diff --git a/app/schemas.py b/app/schemas.py
new file mode 100644
index 0000000..d3b879f
--- /dev/null
+++ b/app/schemas.py
@@ -0,0 +1,51 @@
+# app/schemas.py
+from pydantic import BaseModel, Field
+from typing import List, Optional
+
+# --- Classification result model ---
+class ClassificationResult(BaseModel):
+    """Defines the structured output for the classification agent."""
+    category: str = Field(description="The category of the document, must be one of ['LETTER', 'INVOICE', 'RECEIPT', 'CONTRACT', 'OTHER']")
+    language: str = Field(description="The detected primary language of the document as a two-letter code (e.g., 'en', 'zh', 'es').")
+
+# --- Receipt models ---
+class ReceiptItem(BaseModel):
+    name: str = Field(description="Name of the purchased item or service")
+    quantity: float = Field(description="Quantity of the item")
+    price: float = Field(description="Unit price of the item")
+
+class ReceiptInfo(BaseModel):
+    merchant_name: Optional[str] = Field(None, description="Name of the merchant or store")
+    transaction_date: Optional[str] = Field(None, description="Transaction date in YYYY-MM-DD format")
+    total_amount: Optional[float] = Field(None, description="Total amount on the receipt")
+    items: Optional[List[ReceiptItem]] = Field(None, description="List of all purchased items")
+
+
+# --- Invoice line item model ---
+class LineItem(BaseModel):
+    """Defines a single line item from an invoice."""
+    description: Optional[str] = Field("", description="The description of the product or service.")
+    quantity: Optional[float] = Field(None, description="The quantity of the item.")
+    unit_price: Optional[float] = Field(None, description="The price for a single unit of the item.")
+    total_price: Optional[float] = Field(None, description="The total price for this line item (quantity * unit_price).")
+
+
+# --- Invoice model ---
+class InvoiceInfo(BaseModel):
+ """Defines the detailed, structured information to be extracted from an invoice."""
+ date: Optional[str] = Field("", description="Extract in YYYY-MM-DD format. If unclear, leave as an empty string.")
+ invoice_number: Optional[str] = Field("", description="If not found or unclear, leave as an empty string.")
+    supplier_number: Optional[str] = Field("", description="The organisation number. If not found or unclear, leave as an empty string.")
+    biller_name: Optional[str] = Field("", description="The sender's name. If not found or unclear, leave as an empty string.")
+    amount: Optional[float] = Field(None, description="The final total amount, formatted as a decimal number. If not present, leave as null.")
+    customer_name: Optional[str] = Field("", description="The receiver's name, cleaned of any special characters. If not found or unclear, leave as an empty string.")
+    customer_address: Optional[str] = Field("", description="The receiver's full address on one line. If not found or unclear, leave as an empty string.")
+    customer_address_line: Optional[str] = Field("", description="Only the street address line from the receiver's address, not the whole address. If not found or unclear, leave as an empty string.")
+    customer_address_city: Optional[str] = Field("", description="The receiver's city. If not found, find a city in the document. If unclear, leave as an empty string.")
+    customer_address_country: Optional[str] = Field("", description="The receiver's country. If not found, find the country of the extracted city. If unclear, leave as an empty string.")
+    customer_address_postal_code: Optional[str] = Field("", description="The receiver's postal code. If not found or unclear, leave as an empty string.")
+    customer_address_apartment: Optional[str] = Field("", description="The receiver's apartment or suite number. If not found or unclear, leave as an empty string.")
+    customer_address_region: Optional[str] = Field("", description="The receiver's region. If not found, find the region of the extracted city or country. If unclear, leave as an empty string.")
+    customer_address_care_of: Optional[str] = Field("", description="The receiver's 'care of' (c/o) line. If not found or unclear, leave as an empty string.")
+ billo_id: Optional[str] = Field("", description="Extract from customer_address if it exists, following the format 'LLL NNN-A'. If not found or unclear, leave as an empty string.")
+ line_items: List[LineItem] = Field([], description="A list of all line items from the invoice.")
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..4c8dfa0
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,11 @@
+fastapi
+uvicorn[standard]
+python-dotenv
+langchain
+langchain-openai
+chromadb
+tiktoken
+pdf2image
+python-multipart
+pytesseract
+Pillow