From f077c6351ddd2c265fe2bd960f4c67289cb4f35d Mon Sep 17 00:00:00 2001 From: Yaojia Wang Date: Mon, 11 Aug 2025 14:20:56 +0200 Subject: [PATCH] Init --- .idea/AmazingDoc.iml | 2 +- .idea/misc.xml | 2 +- app/agents.py | 51 ----------------------- app/agents/__init__.py | 5 +-- app/agents/classification_agent.py | 50 ++++++++++++---------- app/agents/invoice_agent.py | 31 +++++++------- app/agents/receipt_agent.py | 29 +++++++------ app/agents/vectorization_agent.py | 24 +++++------ app/core/callbacks.py | 18 ++++++++ app/core/llm.py | 24 +++-------- app/core/ocr.py | 27 ------------ app/core/pdf_processor.py | 25 ++--------- app/core/vector_store.py | 5 --- app/main.py | 10 ++--- app/routers/documents.py | 67 +++++++++++++++++------------- app/schemas.py | 24 ++++++----- requirements.txt | 19 +++++---- 17 files changed, 165 insertions(+), 248 deletions(-) delete mode 100644 app/agents.py create mode 100644 app/core/callbacks.py delete mode 100644 app/core/ocr.py diff --git a/.idea/AmazingDoc.iml b/.idea/AmazingDoc.iml index 5904ce9..67d6524 100644 --- a/.idea/AmazingDoc.iml +++ b/.idea/AmazingDoc.iml @@ -5,7 +5,7 @@ - + diff --git a/.idea/misc.xml b/.idea/misc.xml index 1d3ce46..59ad763 100644 --- a/.idea/misc.xml +++ b/.idea/misc.xml @@ -3,5 +3,5 @@ - + \ No newline at end of file diff --git a/app/agents.py b/app/agents.py deleted file mode 100644 index 0e0e95b..0000000 --- a/app/agents.py +++ /dev/null @@ -1,51 +0,0 @@ -# app/agents.py -import asyncio -import random -from .schemas import ReceiptInfo, InvoiceInfo, ReceiptItem - -# --- Agent核心功能 (占位符/模拟实现) --- -# 在实际应用中,这些函数将被替换为调用LangChain和LLM的真实逻辑。 - -async def agent_classify_document(text: str) -> str: - """Agent 1: 文件分类 (模拟)""" - print("--- [Agent 1] 正在进行文档分类...") - await asyncio.sleep(0.5) # 模拟网络延迟 - doc_types = ["信件", "收据", "发票", "合约"] - if "发票" in text: return "发票" - if "收据" in text or "小票" in text: return "收据" - if "合同" in text or "协议" in text: return "合约" - return random.choice(doc_types) - -async def agent_extract_receipt_info(text: str) -> ReceiptInfo: - """Agent 2: 收据信息提取 (模拟)""" - print("--- [Agent 2] 正在提取收据信息...") - await asyncio.sleep(1) # 模拟LLM处理时间 - return ReceiptInfo( - merchant_name="模拟超市", - transaction_date="2025-08-10", - total_amount=198.50, - items=[ReceiptItem(name="牛奶", quantity=2, price=11.5)] - ) - -async def agent_extract_invoice_info(text: str) -> InvoiceInfo: - """Agent 3: 发票信息提取 (模拟)""" - print("--- [Agent 3] 正在提取发票信息...") - await asyncio.sleep(1) # 模拟LLM处理时间 - return InvoiceInfo( - invoice_number="INV123456789", - issue_date="2025-08-09", - seller_name="模拟科技有限公司", - total_amount_in_figures=12000.00 - ) - -async def agent_vectorize_and_store(doc_id: str, text: str, category: str, vector_db: dict): - """Agent 4: 向量化并存储 (模拟)""" - print(f"--- [Agent 4] 正在向量化文档 (ID: {doc_id})...") - await asyncio.sleep(0.5) - chunks = [text[i:i+200] for i in range(0, len(text), 200)] - vector_db[doc_id] = { - "metadata": {"category": category, "chunk_count": len(chunks)}, - "content_chunks": chunks, - "vectors": [random.random() for _ in range(len(chunks) * 128)] - } - print(f"--- [Agent 4] 文档 {doc_id} 已存入向量数据库。") \ No newline at end of file diff --git a/app/agents/__init__.py b/app/agents/__init__.py index ee85a62..49b8c66 100644 --- a/app/agents/__init__.py +++ b/app/agents/__init__.py @@ -1,7 +1,4 @@ # app/agents/__init__.py - -# This file makes it easy to import all agents from the 'agents' package. -# We are importing the function with its correct name now. -from .classification_agent import agent_classify_document_from_text +from .classification_agent import agent_classify_document_from_image from .receipt_agent import agent_extract_receipt_info from .invoice_agent import agent_extract_invoice_info \ No newline at end of file diff --git a/app/agents/classification_agent.py b/app/agents/classification_agent.py index d04b862..42d3ea2 100644 --- a/app/agents/classification_agent.py +++ b/app/agents/classification_agent.py @@ -1,42 +1,48 @@ # app/agents/classification_agent.py -from langchain.prompts import PromptTemplate +from langchain_core.messages import HumanMessage from langchain_core.output_parsers import PydanticOutputParser +from langchain.prompts import PromptTemplate from ..core.llm import llm -from ..schemas import ClassificationResult # 导入新的Schema +from ..schemas import ClassificationResult +from typing import List -# 1. 设置PydanticOutputParser parser = PydanticOutputParser(pydantic_object=ClassificationResult) -# 2. 更新Prompt模板以要求语言,并包含格式指令 classification_template = """ -You are a professional document analysis assistant. Please perform two tasks on the following text: -1. Determine its category. The category must be one of: ["LETTER", "INVOICE", "RECEIPT", "CONTRACT", "OTHER"]. -2. Detect the primary language of the text. Return the language as a two-letter ISO 639-1 code (e.g., "en" for English, "zh" for Chinese, "es" for Spanish). +You are a professional document analysis assistant. The following images represent pages from a single document. Please perform two tasks based on all pages provided: +1. Determine the overall category of the document. The category must be one of: ["LETTER", "INVOICE", "RECEIPT", "CONTRACT", "OTHER"]. +2. Detect the primary language of the document. Return the language as a two-letter ISO 639-1 code (e.g., "en" for English, "zh" for Chinese). + +Please provide a single response for the entire document in the requested JSON format. {format_instructions} - -[Document Text Start] -{document_text} -[Document Text End] """ classification_prompt = PromptTemplate( template=classification_template, - input_variables=["document_text"], + input_variables=[], partial_variables={"format_instructions": parser.get_format_instructions()}, ) -# 3. 创建新的LangChain链 -classification_chain = classification_prompt | llm | parser +async def agent_classify_document_from_image(images_base64: List[str]) -> ClassificationResult: + """Agent 1: Classifies an entire document (multiple pages) and detects its language from a list of images.""" + print(f"--- [Agent 1] Calling multimodal LLM for classification of a {len(images_base64)}-page document...") -async def agent_classify_document_from_text(text: str) -> ClassificationResult: - """Agent 1: Classify document and detect language from OCR-extracted text.""" - print("--- [Agent 1] Calling LLM for classification and language detection...") - if not text.strip(): - print("--- [Agent 1] Text content is empty, classifying as 'OTHER'.") - return ClassificationResult(category="OTHER", language="unknown") + prompt_text = await classification_prompt.aformat() - # 调用链并返回Pydantic对象 - result = await classification_chain.ainvoke({"document_text": text}) + # Create a list of content parts, starting with the text prompt + content_parts = [{"type": "text", "text": prompt_text}] + + # Add each image to the content list + for image_base64 in images_base64: + content_parts.append({ + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{image_base64}"}, + }) + + msg = HumanMessage(content=content_parts) + + chain = llm | parser + result = await chain.ainvoke([msg]) return result diff --git a/app/agents/invoice_agent.py b/app/agents/invoice_agent.py index b6c2570..ecf3569 100644 --- a/app/agents/invoice_agent.py +++ b/app/agents/invoice_agent.py @@ -4,10 +4,10 @@ from langchain_core.output_parsers import PydanticOutputParser from langchain.prompts import PromptTemplate from ..core.llm import llm from ..schemas import InvoiceInfo +from typing import List parser = PydanticOutputParser(pydantic_object=InvoiceInfo) -# The prompt now includes the detailed rules for each field using snake_case. invoice_template = """ You are an expert data entry clerk AI. Your primary goal is to extract information from an invoice image with the highest possible accuracy. The document's primary language is '{language}'. @@ -30,6 +30,9 @@ Carefully analyze the invoice image and extract the following fields according t - `customer_address_region`: This is the receiver's region. If not found, find the region of the extracted city or country. If unclear, leave as an empty string. - `customer_address_care_of`: This is the receiver's 'care of' (c/o) line. If not found or unclear, leave as an empty string. - `billo_id`: To find this, think step-by-step: 1. Find the customer_address. 2. Scan the address for a pattern of three letters, an optional space, three digits, an optional dash, and one alphanumeric character (e.g., 'ABC 123-X' or 'DEF 456Z'). 3. If found, extract it. If not found or unclear, leave as an empty string. +- `bank_giro`: If found, extract the bank giro number. It often follows patterns like 'ddd-dddd', 'dddd-dddd', or 'dddddddd #41#'. If not found or unclear, leave as an empty string. +- `plus_giro`: If found, extract the plus giro number. It often follows patterns like 'ddddddd-d #16#', 'ddddddd-d', or 'ddd dd dd-d'. If not found or unclear, leave as an empty string. +- `customer_ssn`: If found, extract the customer social security number (personnummer). It follows the pattern 'YYYYMMDD-XXXX' or 'YYMMDD-XXXX'. If not found or unclear, leave as an empty string. - `line_items`: Extract all line items from the invoice. For each item, extract the `description`, `quantity`, `unit_price`, and `total_price`. If a value is not present, leave it as null. ## Example: @@ -41,33 +44,31 @@ If the invoice shows a line item "Consulting Services | 2 hours | $100.00/hr | $ "unit_price": 100.00, "total_price": 200.00 }} - +``` Your Task: Now, analyze the provided image and output the full JSON object according to the format below. {format_instructions} """ - invoice_prompt = PromptTemplate( template=invoice_template, input_variables=["language"], - partial_variables={"format_instructions": parser.get_format_instructions()}, + partial_variables={"format_instructions": parser.get_format_instructions()} ) -async def agent_extract_invoice_info(image_base64: str, language: str) -> InvoiceInfo: - """Agent 3: Extracts invoice information from an image, aware of the document's language.""" +async def agent_extract_invoice_info(images_base64: List[str], language: str) -> InvoiceInfo: + """Agent 3: Extracts invoice information from a list of images, aware of the document's language.""" print(f"--- [Agent 3] Calling multimodal LLM to extract invoice info (Language: {language})...") prompt_text = await invoice_prompt.aformat(language=language) - msg = HumanMessage( - content=[ - {"type": "text", "text": prompt_text}, - { - "type": "image_url", - "image_url": f"data:image/png;base64,{image_base64}", - }, - ] - ) + content_parts = [{"type": "text", "text": prompt_text}] + for image_base64 in images_base64: + content_parts.append({ + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{image_base64}"}, + }) + + msg = HumanMessage(content=content_parts) chain = llm | parser invoice_info = await chain.ainvoke([msg]) diff --git a/app/agents/receipt_agent.py b/app/agents/receipt_agent.py index e1719d0..4608641 100644 --- a/app/agents/receipt_agent.py +++ b/app/agents/receipt_agent.py @@ -4,15 +4,15 @@ from langchain_core.output_parsers import PydanticOutputParser from langchain.prompts import PromptTemplate from ..core.llm import llm from ..schemas import ReceiptInfo +from typing import List parser = PydanticOutputParser(pydantic_object=ReceiptInfo) -# 更新Prompt模板以包含语言信息 receipt_template = """ You are a highly accurate receipt information extraction robot. The document's primary language is '{language}'. -Please extract all key information from the following receipt image. -If some information is not present in the image, leave it as null. +Please extract all key information from the following receipt images, which belong to a single document. +If some information is not present in the images, leave it as null. Please strictly follow the JSON format below, without adding any extra explanations or comments. {format_instructions} @@ -25,22 +25,21 @@ receipt_prompt = PromptTemplate( ) -async def agent_extract_receipt_info(image_base64: str, language: str) -> ReceiptInfo: - """Agent 2: Extracts receipt information from an image, aware of the document's language.""" +async def agent_extract_receipt_info(images_base64: List[str], language: str) -> ReceiptInfo: + """Agent 2: Extracts receipt information from a list of images, aware of the document's language.""" print(f"--- [Agent 2] Calling multimodal LLM to extract receipt info (Language: {language})...") prompt_text = await receipt_prompt.aformat(language=language) - msg = HumanMessage( - content=[ - {"type": "text", "text": prompt_text}, - { - "type": "image_url", - "image_url": f"data:image/png;base64,{image_base64}", - }, - ] - ) + content_parts = [{"type": "text", "text": prompt_text}] + for image_base64 in images_base64: + content_parts.append({ + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{image_base64}"}, + }) + + msg = HumanMessage(content=content_parts) chain = llm | parser receipt_info = await chain.ainvoke([msg]) - return receipt_info + return receipt_info \ No newline at end of file diff --git a/app/agents/vectorization_agent.py b/app/agents/vectorization_agent.py index 0c127be..983e568 100644 --- a/app/agents/vectorization_agent.py +++ b/app/agents/vectorization_agent.py @@ -2,32 +2,32 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter from ..core.vector_store import vector_store, embedding_model -# 初始化文本分割器,用于将长文档切成小块 +# Initialize the text splitter to divide long documents into smaller chunks text_splitter = RecursiveCharacterTextSplitter( - chunk_size=500, # 每个块的大小(字符数) - chunk_overlap=50, # 块之间的重叠部分 + chunk_size=500, + chunk_overlap=50, ) def agent_vectorize_and_store(doc_id: str, text: str, category: str): - """Agent 4: 向量化并存储 (真实实现)""" - print(f"--- [Agent 4] 正在向量化文档 (ID: {doc_id})...") + """Agent 4: Vectorization and Storage (Real Implementation)""" + print(f"--- [Agent 4] Vectorizing document (ID: {doc_id})...") - # 1. 将文档文本分割成块 + # 1. Split the document text into chunks chunks = text_splitter.split_text(text) - print(f"--- [Agent 4] 文档被切分为 {len(chunks)} 个块。") + print(f"--- [Agent 4] Document split into {len(chunks)} chunks.") if not chunks: - print(f"--- [Agent 4] 文档内容为空,跳过向量化。") + print(f"--- [Agent 4] Document is empty, skipping vectorization.") return - # 2. 为每个块创建唯一的ID和元数据 + # 2. Create a unique ID and metadata for each chunk chunk_ids = [f"{doc_id}_{i}" for i in range(len(chunks))] metadatas = [{"doc_id": doc_id, "category": category, "chunk_number": i} for i in range(len(chunks))] - # 3. 使用嵌入模型为所有块生成向量 + # 3. Use an embedding model to generate vectors for all chunks embeddings = embedding_model.embed_documents(chunks) - # 4. 将ID、向量、元数据和文本块本身添加到ChromaDB + # 4. Add the IDs, vectors, metadata, and text chunks to ChromaDB vector_store.add( ids=chunk_ids, embeddings=embeddings, @@ -35,4 +35,4 @@ def agent_vectorize_and_store(doc_id: str, text: str, category: str): metadatas=metadatas ) - print(f"--- [Agent 4] 文档 {doc_id} 的向量已存入ChromaDB。") + print(f"--- [Agent 4] document {doc_id} stored in ChromaDB。") diff --git a/app/core/callbacks.py b/app/core/callbacks.py new file mode 100644 index 0000000..e4b878e --- /dev/null +++ b/app/core/callbacks.py @@ -0,0 +1,18 @@ +from langchain_core.callbacks import BaseCallbackHandler +from langchain_core.outputs import LLMResult +from typing import Any, Dict + +class TokenUsageCallbackHandler(BaseCallbackHandler): + def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: + token_usage = response.llm_output.get('token_usage', {}) + + if token_usage: + prompt_tokens = token_usage.get('prompt_tokens', 0) + completion_tokens = token_usage.get('completion_tokens', 0) + total_tokens = token_usage.get('total_tokens', 0) + + print("--- [Token Usage] ---") + print(f" Prompt Tokens: {prompt_tokens}") + print(f" Completion Tokens: {completion_tokens}") + print(f" Total Tokens: {total_tokens}") + print("---------------------") diff --git a/app/core/llm.py b/app/core/llm.py index 3e07072..d4f9412 100644 --- a/app/core/llm.py +++ b/app/core/llm.py @@ -1,45 +1,33 @@ -# app/core/llm.py import os from dotenv import load_dotenv from langchain_openai import AzureChatOpenAI, ChatOpenAI +from .callbacks import TokenUsageCallbackHandler -# 加载.env文件中的环境变量 load_dotenv() - -# 获取配置的LLM供应商 LLM_PROVIDER = os.getenv("LLM_PROVIDER", "openai").lower() - llm = None print(f"--- [Core] Initializing LLM with provider: {LLM_PROVIDER} ---") -if LLM_PROVIDER == "azure": - # --- Azure OpenAI 配置 --- - required_vars = [ - "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_API_KEY", - "OPENAI_API_VERSION", "AZURE_OPENAI_CHAT_DEPLOYMENT_NAME" - ] - if not all(os.getenv(var) for var in required_vars): - raise ValueError("One or more Azure OpenAI environment variables for chat are not set.") +token_callback = TokenUsageCallbackHandler() +if LLM_PROVIDER == "azure": llm = AzureChatOpenAI( azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), api_key=os.getenv("AZURE_OPENAI_API_KEY"), api_version=os.getenv("OPENAI_API_VERSION"), azure_deployment=os.getenv("AZURE_OPENAI_CHAT_DEPLOYMENT_NAME"), temperature=0, + callbacks=[token_callback] ) elif LLM_PROVIDER == "openai": - # --- 标准 OpenAI 配置 --- - if not os.getenv("OPENAI_API_KEY"): - raise ValueError("OPENAI_API_KEY is not set for the 'openai' provider.") - llm = ChatOpenAI( api_key=os.getenv("OPENAI_API_KEY"), model_name=os.getenv("OPENAI_MODEL_NAME", "gpt-4o"), temperature=0, + callbacks=[token_callback] ) else: - raise ValueError(f"Unsupported LLM_PROVIDER: {LLM_PROVIDER}. Please use 'azure' or 'openai'.") \ No newline at end of file + raise ValueError(f"Unsupported LLM_PROVIDER: {LLM_PROVIDER}. Please use 'azure' or 'openai'.") diff --git a/app/core/ocr.py b/app/core/ocr.py deleted file mode 100644 index b91bca5..0000000 --- a/app/core/ocr.py +++ /dev/null @@ -1,27 +0,0 @@ -# app/core/ocr.py -import pytesseract -from PIL import Image - - -# 注意: 您需要先在您的系统中安装Google的Tesseract OCR引擎。 -# 详情请参考之前的安装说明。 - -def extract_text_from_image(image: Image.Image) -> str: - """ - 使用Tesseract OCR从Pillow Image对象中提取文本。 - - 参数: - image: Pillow Image对象。 - - 返回: - 从图片中提取出的字符串文本。 - """ - try: - print("--- [Core OCR] 正在从图片中提取文本用于分类...") - # lang='chi_sim+eng' 表示同时识别简体中文和英文 - text = pytesseract.image_to_string(image, lang='chi_sim+eng') - print("--- [Core OCR] 文本提取成功。") - return text - except Exception as e: - print(f"--- [Core OCR] OCR处理失败: {e}") - raise IOError(f"OCR processing failed: {e}") \ No newline at end of file diff --git a/app/core/pdf_processor.py b/app/core/pdf_processor.py index 0aa3759..1118262 100644 --- a/app/core/pdf_processor.py +++ b/app/core/pdf_processor.py @@ -5,39 +5,20 @@ from io import BytesIO from typing import List import base64 - -# 注意: 您需要安装Poppler。 -# - macOS: brew install poppler -# - Ubuntu/Debian: sudo apt-get install poppler-utils -# - Windows: 下载Poppler并将其bin目录添加到系统PATH。 - def convert_pdf_to_images(pdf_bytes: bytes) -> List[Image.Image]: - """将PDF文件的字节流转换为Pillow Image对象列表。""" try: - print("--- [Core PDF] 正在将PDF转换为图片...") + print("--- [Core PDF] Converting PDF to images...") - # --- 新增代码开始 --- - # 在这里直接指定您电脑上Poppler的bin目录路径 - # 请确保将下面的示例路径替换为您的真实路径 - poppler_path = r"C:\ProgramData\chocolatey\lib\poppler\tools\Library\bin" - # --- 新增代码结束 --- - - # --- 修改的代码开始 --- - # 在调用时传入poppler_path参数 images = convert_from_bytes(pdf_bytes) - # --- 修改的代码结束 --- - print(f"--- [Core PDF] 转换成功,共 {len(images)} 页。") + print(f"--- [Core PDF] converted PDF to images,total {len(images)} pages。") return images except Exception as e: - print(f"--- [Core PDF] PDF转换失败: {e}") - # 增加一个更友好的错误提示 - print("--- [Core PDF] 请确认您已在系统中正确安装Poppler,并在上面的代码中指定了正确的poppler_path。") + print(f"--- [Core PDF] PDF conversion failed: {e}") raise IOError(f"PDF to image conversion failed: {e}") def image_to_base64_str(image: Image.Image) -> str: - """将Pillow Image对象转换为Base64编码的字符串。""" buffered = BytesIO() image.save(buffered, format="PNG") return base64.b64encode(buffered.getvalue()).decode('utf-8') \ No newline at end of file diff --git a/app/core/vector_store.py b/app/core/vector_store.py index 89fc5f0..f0ecd3a 100644 --- a/app/core/vector_store.py +++ b/app/core/vector_store.py @@ -4,10 +4,8 @@ import chromadb from dotenv import load_dotenv from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings -# 加载.env文件中的环境变量 load_dotenv() -# 获取配置的LLM供应商 LLM_PROVIDER = os.getenv("LLM_PROVIDER", "openai").lower() embedding_model = None @@ -15,7 +13,6 @@ embedding_model = None print(f"--- [Core] Initializing Embeddings with provider: {LLM_PROVIDER} ---") if LLM_PROVIDER == "azure": - # --- Azure OpenAI 配置 --- required_vars = [ "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_API_KEY", "OPENAI_API_VERSION", "AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME" @@ -31,7 +28,6 @@ if LLM_PROVIDER == "azure": ) elif LLM_PROVIDER == "openai": - # --- 标准 OpenAI 配置 --- if not os.getenv("OPENAI_API_KEY"): raise ValueError("OPENAI_API_KEY is not set for the 'openai' provider.") @@ -44,7 +40,6 @@ else: raise ValueError(f"Unsupported LLM_PROVIDER: {LLM_PROVIDER}. Please use 'azure' or 'openai'.") -# 初始化ChromaDB客户端 (无变化) client = chromadb.PersistentClient(path="./chroma_db") vector_store = client.get_or_create_collection( name="documents", diff --git a/app/main.py b/app/main.py index 55a9de1..8b66e75 100644 --- a/app/main.py +++ b/app/main.py @@ -1,20 +1,18 @@ # app/main.py import uvicorn from fastapi import FastAPI -from .routers import documents # 导入我们新的文档路由 +from .routers import documents app = FastAPI( - title="混合模式文档处理AI Agent", - description="一个用于自动分类、提取和处理文档的AI应用框架。", - version="0.9.0", # 版本升级: 模块化API路由 + title="Hybrid Mode Document Processing AI Agent", + description="An AI application framework for automatic document classification, extraction, and processing.", + version="0.9.0", ) -# 将文档路由包含到主应用中 app.include_router(documents.router) @app.get("/", tags=["Root"]) async def read_root(): - """一个简单的根端点,用于检查服务是否正在运行。""" return {"message": "Welcome to the Document Processing AI Agent API!"} if __name__ == "__main__": diff --git a/app/routers/documents.py b/app/routers/documents.py index ad99729..8d3e9ec 100644 --- a/app/routers/documents.py +++ b/app/routers/documents.py @@ -10,64 +10,54 @@ from io import BytesIO from .. import agents from ..core.pdf_processor import convert_pdf_to_images, image_to_base64_str -from ..core.ocr import extract_text_from_image -# 创建一个APIRouter实例 +# Create an APIRouter instance router = APIRouter( - prefix="/documents", # 为这个路由下的所有路径添加前缀 - tags=["Document Processing"], # 在API文档中为这组端点添加标签 + prefix="/documents", + tags=["Document Processing"], ) -# 模拟一个SQL数据库来存储最终结果 +# Simulate an SQL database to store the final results db_results: Dict[str, Any] = {} - -async def hybrid_process_pipeline(doc_id: str, image: Image.Image, page_num: int): - """混合处理流水线""" - ocr_text = await run_in_threadpool(extract_text_from_image, image) - - classification_result = await agents.agent_classify_document_from_text(ocr_text) +async def multimodal_process_pipeline(doc_id: str, image: Image.Image, page_num: int): + image_base64 = await run_in_threadpool(image_to_base64_str, image) + classification_result = await agents.agent_classify_document_from_image(image_base64) category = classification_result.category language = classification_result.language print(f"Document page {page_num} classified as: {category}, Language: {language}") extraction_result = None if category in ["RECEIPT", "INVOICE"]: - image_base64 = await run_in_threadpool(image_to_base64_str, image) if category == "RECEIPT": - # 将语言传递给提取Agent extraction_result = await agents.agent_extract_receipt_info(image_base64, language) elif category == "INVOICE": - # 将语言传递给提取Agent extraction_result = await agents.agent_extract_invoice_info(image_base64, language) else: - print(f"Document classified as '{category}', skipping high-precision extraction.") + print(f"Document classified as '{category}', skipping extraction.") final_result = { "doc_id": f"{doc_id}_page_{page_num}", "category": category, "language": language, - "ocr_text_for_classification": ocr_text, "extraction_data": extraction_result.dict() if extraction_result else None, "status": "Processed" } db_results[final_result["doc_id"]] = final_result return final_result - -@router.post("/process", summary="上传并处理单个文档(混合模式)") +@router.post("/process", summary="upload and process a document") async def upload_and_process_document(file: UploadFile = File(...)): - """处理上传的文档文件 (PDF, PNG, JPG)""" if not file.filename: raise HTTPException(status_code=400, detail="No file provided.") doc_id = str(uuid.uuid4()) - print(f"\n接收到新文件: {file.filename} (分配ID: {doc_id})") + print(f"\nReceiving document: {file.filename} (allocated ID: {doc_id})") contents = await file.read() try: file_type = mimetypes.guess_type(file.filename)[0] - print(f"检测到文件类型: {file_type}") + print(f"File type: {file_type}") images: List[Image.Image] = [] if file_type == 'application/pdf': @@ -80,20 +70,39 @@ async def upload_and_process_document(file: UploadFile = File(...)): if not images: raise HTTPException(status_code=400, detail="Could not extract images from document.") - all_page_results = [] - for i, img in enumerate(images): - page_result = await hybrid_process_pipeline(doc_id, img, i + 1) - all_page_results.append(page_result) + images_base64 = [await run_in_threadpool(image_to_base64_str, img) for img in images] - return all_page_results + classification_result = await agents.agent_classify_document_from_image(images_base64) + category = classification_result.category + language = classification_result.language + print(f"The document is classified as: {category}, Language: {language}") + + extraction_result = None + if category in ["RECEIPT", "INVOICE"]: + if category == "RECEIPT": + extraction_result = await agents.agent_extract_receipt_info(images_base64, language) + elif category == "INVOICE": + extraction_result = await agents.agent_extract_invoice_info(images_base64, language) + else: + print(f"The document is classified as '{category}',skipping extraction。") + + # 3. Return a unified result + final_result = { + "doc_id": doc_id, + "page_count": len(images), + "category": category, + "language": language, + "extraction_data": extraction_result.dict() if extraction_result else None, + "status": "Processed" + } + db_results[doc_id] = final_result + return final_result except Exception as e: raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}") - -@router.get("/results/{doc_id}", summary="根据ID获取处理结果") +@router.get("/results/{doc_id}", summary="Get result by doc_id") async def get_result(doc_id: str): - """根据文档处理后返回的 doc_id 获取其详细处理结果。""" if doc_id in db_results: return db_results[doc_id] raise HTTPException(status_code=404, detail="Document not found.") diff --git a/app/schemas.py b/app/schemas.py index d3b879f..530011c 100644 --- a/app/schemas.py +++ b/app/schemas.py @@ -2,26 +2,26 @@ from pydantic import BaseModel, Field from typing import List, Optional -# --- 分类结果模型 --- +# --- Classification Result Model --- class ClassificationResult(BaseModel): """Defines the structured output for the classification agent.""" category: str = Field(description="The category of the document, must be one of ['LETTER', 'INVOICE', 'RECEIPT', 'CONTRACT', 'OTHER']") language: str = Field(description="The detected primary language of the document as a two-letter code (e.g., 'en', 'zh', 'es').") -# --- 现有模型 (无变化) --- +# --- Existing Model (Unchanged) --- class ReceiptItem(BaseModel): - name: str = Field(description="购买的项目或服务名称") - quantity: float = Field(description="项目数量") - price: float = Field(description="项目单价") + name: str = Field(description="The name of the purchased item or service") + quantity: float = Field(description="The quantity of the item") + price: float = Field(description="The unit price of the item") class ReceiptInfo(BaseModel): - merchant_name: Optional[str] = Field(None, description="商户或店铺的名称") - transaction_date: Optional[str] = Field(None, description="交易日期,格式为 YYYY-MM-DD") - total_amount: Optional[float] = Field(None, description="收据上的总金额") - items: Optional[List[ReceiptItem]] = Field(None, description="购买的所有项目列表") + merchant_name: Optional[str] = Field(None, description="The name of the merchant or store") + transaction_date: Optional[str] = Field(None, description="The transaction date in the YYYY-MM-DD format") + total_amount: Optional[float] = Field(None, description="The total amount on the receipt") + items: Optional[List[ReceiptItem]] = Field(None, description="The list of all purchased items") -# --- 新增: 发票行项目模型 --- +# --- Added: Invoice Line Item Model --- class LineItem(BaseModel): """Defines a single line item from an invoice.""" description: Optional[str] = Field("", description="The description of the product or service.") @@ -30,7 +30,6 @@ class LineItem(BaseModel): total_price: Optional[float] = Field(None, description="The total price for this line item (quantity * unit_price).") -# --- 发票模型 (已更新) --- class InvoiceInfo(BaseModel): """Defines the detailed, structured information to be extracted from an invoice.""" date: Optional[str] = Field("", description="Extract in YYYY-MM-DD format. If unclear, leave as an empty string.") @@ -48,4 +47,7 @@ class InvoiceInfo(BaseModel): customer_address_region: Optional[str] = Field("", description="It's the receiver's address region. If not found, find the region of the extracted city or country. If unclear, leave as an empty string.") customer_address_care_of: Optional[str] = Field("", description="It's the receiver's address care of. If not found or unclear, leave as an empty string.") billo_id: Optional[str] = Field("", description="Extract from customer_address if it exists, following the format 'LLL NNN-A'. If not found or unclear, leave as an empty string.") + bank_giro: Optional[str] = Field("", description="BankGiro number, e.g., '123-4567'. If not found, leave as an empty string.") + plus_giro: Optional[str] = Field("", description="PlusGiro number, e.g., '123456-7'. If not found, leave as an empty string.") + customer_ssn: Optional[str] = Field("", description="Customer's social security number, e.g., 'YYYYMMDD-XXXX'. If not found, leave as an empty string.") line_items: List[LineItem] = Field([], description="A list of all line items from the invoice.") \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 4c8dfa0..a0d6349 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,11 +1,12 @@ -fastapi -uvicorn[standard] -python-dotenv -langchain -langchain-openai -chromadb +fastapi~=0.115.9 +uvicorn[standard]~=0.35.0 +python-dotenv~=1.1.0 +langchain~=0.3.25 +langchain-openai~=0.3.16 +chromadb~=1.0.16 tiktoken -pdf2image +pdf2image~=1.17.0 python-multipart -pytesseract -Pillow +Pillow~=11.3.0 +langchain-core~=0.3.58 +pydantic~=2.11.7 \ No newline at end of file