From 0a804007203ead0c70b54be3527ef57331795890 Mon Sep 17 00:00:00 2001 From: Yaojia Wang Date: Mon, 11 Aug 2025 00:07:41 +0200 Subject: [PATCH] Init project --- .idea/.gitignore | 8 ++ .idea/AmazingDoc.iml | 18 ++++ .../inspectionProfiles/profiles_settings.xml | 6 ++ .idea/misc.xml | 7 ++ .idea/modules.xml | 8 ++ .idea/vcs.xml | 6 ++ app/__init__.py | 0 app/agents.py | 51 ++++++++++ app/agents/__init__.py | 7 ++ app/agents/classification_agent.py | 42 ++++++++ app/agents/invoice_agent.py | 74 ++++++++++++++ app/agents/receipt_agent.py | 46 +++++++++ app/agents/vectorization_agent.py | 38 +++++++ app/core/__init__.py | 0 app/core/llm.py | 45 +++++++++ app/core/ocr.py | 27 +++++ app/core/pdf_processor.py | 43 ++++++++ app/core/vector_store.py | 52 ++++++++++ app/main.py | 21 ++++ app/routers/__init__.py | 0 app/routers/documents.py | 99 +++++++++++++++++++ app/schemas.py | 51 ++++++++++ requirements.txt | 11 +++ 23 files changed, 660 insertions(+) create mode 100644 .idea/.gitignore create mode 100644 .idea/AmazingDoc.iml create mode 100644 .idea/inspectionProfiles/profiles_settings.xml create mode 100644 .idea/misc.xml create mode 100644 .idea/modules.xml create mode 100644 .idea/vcs.xml create mode 100644 app/__init__.py create mode 100644 app/agents.py create mode 100644 app/agents/__init__.py create mode 100644 app/agents/classification_agent.py create mode 100644 app/agents/invoice_agent.py create mode 100644 app/agents/receipt_agent.py create mode 100644 app/agents/vectorization_agent.py create mode 100644 app/core/__init__.py create mode 100644 app/core/llm.py create mode 100644 app/core/ocr.py create mode 100644 app/core/pdf_processor.py create mode 100644 app/core/vector_store.py create mode 100644 app/main.py create mode 100644 app/routers/__init__.py create mode 100644 app/routers/documents.py create mode 100644 app/schemas.py create mode 100644 requirements.txt diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..13566b8 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/AmazingDoc.iml b/.idea/AmazingDoc.iml new file mode 100644 index 0000000..5904ce9 --- /dev/null +++ b/.idea/AmazingDoc.iml @@ -0,0 +1,18 @@ + + + + + + + + + + + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..1d3ce46 --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,7 @@ + + + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..f07df23 --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..35eb1dd --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/app/__init__.py b/app/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/agents.py b/app/agents.py new file mode 100644 index 0000000..0e0e95b --- /dev/null +++ b/app/agents.py @@ -0,0 +1,51 @@ +# app/agents.py +import asyncio +import random +from .schemas import ReceiptInfo, InvoiceInfo, ReceiptItem + +# --- Agent核心功能 (占位符/模拟实现) --- +# 在实际应用中,这些函数将被替换为调用LangChain和LLM的真实逻辑。 + +async def agent_classify_document(text: str) -> str: + """Agent 1: 文件分类 (模拟)""" + print("--- [Agent 1] 正在进行文档分类...") + await asyncio.sleep(0.5) # 模拟网络延迟 + doc_types = ["信件", "收据", "发票", "合约"] + if "发票" in text: return "发票" + if "收据" in text or "小票" in text: return "收据" + if "合同" in text or "协议" in text: return "合约" + return random.choice(doc_types) + +async def agent_extract_receipt_info(text: str) -> ReceiptInfo: + """Agent 2: 收据信息提取 (模拟)""" + print("--- [Agent 2] 正在提取收据信息...") + await asyncio.sleep(1) # 模拟LLM处理时间 + return ReceiptInfo( + merchant_name="模拟超市", + transaction_date="2025-08-10", + total_amount=198.50, + items=[ReceiptItem(name="牛奶", quantity=2, price=11.5)] + ) + +async def agent_extract_invoice_info(text: str) -> InvoiceInfo: + """Agent 3: 发票信息提取 (模拟)""" + print("--- [Agent 3] 正在提取发票信息...") + await asyncio.sleep(1) # 模拟LLM处理时间 + return InvoiceInfo( + invoice_number="INV123456789", + issue_date="2025-08-09", + seller_name="模拟科技有限公司", + total_amount_in_figures=12000.00 + ) + +async def agent_vectorize_and_store(doc_id: str, text: str, category: str, vector_db: dict): + """Agent 4: 向量化并存储 (模拟)""" + print(f"--- [Agent 4] 正在向量化文档 (ID: {doc_id})...") + await asyncio.sleep(0.5) + chunks = [text[i:i+200] for i in range(0, len(text), 200)] + vector_db[doc_id] = { + "metadata": {"category": category, "chunk_count": len(chunks)}, + "content_chunks": chunks, + "vectors": [random.random() for _ in range(len(chunks) * 128)] + } + print(f"--- [Agent 4] 文档 {doc_id} 已存入向量数据库。") \ No newline at end of file diff --git a/app/agents/__init__.py b/app/agents/__init__.py new file mode 100644 index 0000000..ee85a62 --- /dev/null +++ b/app/agents/__init__.py @@ -0,0 +1,7 @@ +# app/agents/__init__.py + +# This file makes it easy to import all agents from the 'agents' package. +# We are importing the function with its correct name now. +from .classification_agent import agent_classify_document_from_text +from .receipt_agent import agent_extract_receipt_info +from .invoice_agent import agent_extract_invoice_info \ No newline at end of file diff --git a/app/agents/classification_agent.py b/app/agents/classification_agent.py new file mode 100644 index 0000000..d04b862 --- /dev/null +++ b/app/agents/classification_agent.py @@ -0,0 +1,42 @@ +# app/agents/classification_agent.py +from langchain.prompts import PromptTemplate +from langchain_core.output_parsers import PydanticOutputParser +from ..core.llm import llm +from ..schemas import ClassificationResult # 导入新的Schema + +# 1. 设置PydanticOutputParser +parser = PydanticOutputParser(pydantic_object=ClassificationResult) + +# 2. 更新Prompt模板以要求语言,并包含格式指令 +classification_template = """ +You are a professional document analysis assistant. Please perform two tasks on the following text: +1. Determine its category. The category must be one of: ["LETTER", "INVOICE", "RECEIPT", "CONTRACT", "OTHER"]. +2. Detect the primary language of the text. Return the language as a two-letter ISO 639-1 code (e.g., "en" for English, "zh" for Chinese, "es" for Spanish). + +{format_instructions} + +[Document Text Start] +{document_text} +[Document Text End] +""" + +classification_prompt = PromptTemplate( + template=classification_template, + input_variables=["document_text"], + partial_variables={"format_instructions": parser.get_format_instructions()}, +) + +# 3. 创建新的LangChain链 +classification_chain = classification_prompt | llm | parser + + +async def agent_classify_document_from_text(text: str) -> ClassificationResult: + """Agent 1: Classify document and detect language from OCR-extracted text.""" + print("--- [Agent 1] Calling LLM for classification and language detection...") + if not text.strip(): + print("--- [Agent 1] Text content is empty, classifying as 'OTHER'.") + return ClassificationResult(category="OTHER", language="unknown") + + # 调用链并返回Pydantic对象 + result = await classification_chain.ainvoke({"document_text": text}) + return result diff --git a/app/agents/invoice_agent.py b/app/agents/invoice_agent.py new file mode 100644 index 0000000..b6c2570 --- /dev/null +++ b/app/agents/invoice_agent.py @@ -0,0 +1,74 @@ +# app/agents/invoice_agent.py +from langchain_core.messages import HumanMessage +from langchain_core.output_parsers import PydanticOutputParser +from langchain.prompts import PromptTemplate +from ..core.llm import llm +from ..schemas import InvoiceInfo + +parser = PydanticOutputParser(pydantic_object=InvoiceInfo) + +# The prompt now includes the detailed rules for each field using snake_case. +invoice_template = """ +You are an expert data entry clerk AI. Your primary goal is to extract information from an invoice image with the highest possible accuracy. +The document's primary language is '{language}'. + +## Instructions: +Carefully analyze the invoice image and extract the following fields according to these specific rules. Do not invent information. If a field is not found or is unclear, follow the specific instruction for that field. + +- `date`: Extract in YYYY-MM-DD format. If unclear, leave as an empty string. +- `invoice_number`: If not found or unclear, leave as an empty string. +- `supplier_number`: This is the organisation number. If not found or unclear, leave as an empty string. +- `biller_name`: This is the sender's name. If not found or unclear, leave as an empty string. +- `amount`: Extract the final total amount and format it to a decimal number. If not present, leave as null. +- `customer_name`: This is the receiver's name. Ensure it is a name and clear any special characters. If not found or unclear, leave as an empty string. +- `customer_address`: This is the receiver's full address. Put it in one line. If not found or unclear, leave as an empty string. +- `customer_address_line`: This is only the street address line from the receiver's address. If not found or unclear, leave as an empty string. +- `customer_address_city`: This is the receiver's city. If not found, try to find any city in the document. If unclear, leave as an empty string. +- `customer_address_country`: This is the receiver's country. If not found, find the country of the extracted city. If unclear, leave as an empty string. +- `customer_address_postal_code`: This is the receiver's postal code. If not found or unclear, leave as an empty string. +- `customer_address_apartment`: This is the receiver's apartment or suite number. If not found or unclear, leave as an empty string. +- `customer_address_region`: This is the receiver's region. If not found, find the region of the extracted city or country. If unclear, leave as an empty string. +- `customer_address_care_of`: This is the receiver's 'care of' (c/o) line. If not found or unclear, leave as an empty string. +- `billo_id`: To find this, think step-by-step: 1. Find the customer_address. 2. Scan the address for a pattern of three letters, an optional space, three digits, an optional dash, and one alphanumeric character (e.g., 'ABC 123-X' or 'DEF 456Z'). 3. If found, extract it. If not found or unclear, leave as an empty string. +- `line_items`: Extract all line items from the invoice. For each item, extract the `description`, `quantity`, `unit_price`, and `total_price`. If a value is not present, leave it as null. + +## Example: +If the invoice shows a line item "Consulting Services | 2 hours | $100.00/hr | $200.00", the output for that line item should be: +```json +{{ + "description": "Consulting Services", + "quantity": 2, + "unit_price": 100.00, + "total_price": 200.00 +}} + +Your Task: +Now, analyze the provided image and output the full JSON object according to the format below. +{format_instructions} +""" + +invoice_prompt = PromptTemplate( + template=invoice_template, + input_variables=["language"], + partial_variables={"format_instructions": parser.get_format_instructions()}, +) + +async def agent_extract_invoice_info(image_base64: str, language: str) -> InvoiceInfo: + """Agent 3: Extracts invoice information from an image, aware of the document's language.""" + print(f"--- [Agent 3] Calling multimodal LLM to extract invoice info (Language: {language})...") + + prompt_text = await invoice_prompt.aformat(language=language) + + msg = HumanMessage( + content=[ + {"type": "text", "text": prompt_text}, + { + "type": "image_url", + "image_url": f"data:image/png;base64,{image_base64}", + }, + ] + ) + + chain = llm | parser + invoice_info = await chain.ainvoke([msg]) + return invoice_info diff --git a/app/agents/receipt_agent.py b/app/agents/receipt_agent.py new file mode 100644 index 0000000..e1719d0 --- /dev/null +++ b/app/agents/receipt_agent.py @@ -0,0 +1,46 @@ +# app/agents/receipt_agent.py +from langchain_core.messages import HumanMessage +from langchain_core.output_parsers import PydanticOutputParser +from langchain.prompts import PromptTemplate +from ..core.llm import llm +from ..schemas import ReceiptInfo + +parser = PydanticOutputParser(pydantic_object=ReceiptInfo) + +# 更新Prompt模板以包含语言信息 +receipt_template = """ +You are a highly accurate receipt information extraction robot. +The document's primary language is '{language}'. +Please extract all key information from the following receipt image. +If some information is not present in the image, leave it as null. +Please strictly follow the JSON format below, without adding any extra explanations or comments. + +{format_instructions} +""" + +receipt_prompt = PromptTemplate( + template=receipt_template, + input_variables=["language"], + partial_variables={"format_instructions": parser.get_format_instructions()}, +) + + +async def agent_extract_receipt_info(image_base64: str, language: str) -> ReceiptInfo: + """Agent 2: Extracts receipt information from an image, aware of the document's language.""" + print(f"--- [Agent 2] Calling multimodal LLM to extract receipt info (Language: {language})...") + + prompt_text = await receipt_prompt.aformat(language=language) + + msg = HumanMessage( + content=[ + {"type": "text", "text": prompt_text}, + { + "type": "image_url", + "image_url": f"data:image/png;base64,{image_base64}", + }, + ] + ) + + chain = llm | parser + receipt_info = await chain.ainvoke([msg]) + return receipt_info diff --git a/app/agents/vectorization_agent.py b/app/agents/vectorization_agent.py new file mode 100644 index 0000000..0c127be --- /dev/null +++ b/app/agents/vectorization_agent.py @@ -0,0 +1,38 @@ +# app/agents/vectorization_agent.py +from langchain.text_splitter import RecursiveCharacterTextSplitter +from ..core.vector_store import vector_store, embedding_model + +# 初始化文本分割器,用于将长文档切成小块 +text_splitter = RecursiveCharacterTextSplitter( + chunk_size=500, # 每个块的大小(字符数) + chunk_overlap=50, # 块之间的重叠部分 +) + +def agent_vectorize_and_store(doc_id: str, text: str, category: str): + """Agent 4: 向量化并存储 (真实实现)""" + print(f"--- [Agent 4] 正在向量化文档 (ID: {doc_id})...") + + # 1. 将文档文本分割成块 + chunks = text_splitter.split_text(text) + print(f"--- [Agent 4] 文档被切分为 {len(chunks)} 个块。") + + if not chunks: + print(f"--- [Agent 4] 文档内容为空,跳过向量化。") + return + + # 2. 为每个块创建唯一的ID和元数据 + chunk_ids = [f"{doc_id}_{i}" for i in range(len(chunks))] + metadatas = [{"doc_id": doc_id, "category": category, "chunk_number": i} for i in range(len(chunks))] + + # 3. 使用嵌入模型为所有块生成向量 + embeddings = embedding_model.embed_documents(chunks) + + # 4. 将ID、向量、元数据和文本块本身添加到ChromaDB + vector_store.add( + ids=chunk_ids, + embeddings=embeddings, + documents=chunks, + metadatas=metadatas + ) + + print(f"--- [Agent 4] 文档 {doc_id} 的向量已存入ChromaDB。") diff --git a/app/core/__init__.py b/app/core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/core/llm.py b/app/core/llm.py new file mode 100644 index 0000000..3e07072 --- /dev/null +++ b/app/core/llm.py @@ -0,0 +1,45 @@ +# app/core/llm.py +import os +from dotenv import load_dotenv +from langchain_openai import AzureChatOpenAI, ChatOpenAI + +# 加载.env文件中的环境变量 +load_dotenv() + +# 获取配置的LLM供应商 +LLM_PROVIDER = os.getenv("LLM_PROVIDER", "openai").lower() + +llm = None + +print(f"--- [Core] Initializing LLM with provider: {LLM_PROVIDER} ---") + +if LLM_PROVIDER == "azure": + # --- Azure OpenAI 配置 --- + required_vars = [ + "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_API_KEY", + "OPENAI_API_VERSION", "AZURE_OPENAI_CHAT_DEPLOYMENT_NAME" + ] + if not all(os.getenv(var) for var in required_vars): + raise ValueError("One or more Azure OpenAI environment variables for chat are not set.") + + llm = AzureChatOpenAI( + azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), + api_key=os.getenv("AZURE_OPENAI_API_KEY"), + api_version=os.getenv("OPENAI_API_VERSION"), + azure_deployment=os.getenv("AZURE_OPENAI_CHAT_DEPLOYMENT_NAME"), + temperature=0, + ) + +elif LLM_PROVIDER == "openai": + # --- 标准 OpenAI 配置 --- + if not os.getenv("OPENAI_API_KEY"): + raise ValueError("OPENAI_API_KEY is not set for the 'openai' provider.") + + llm = ChatOpenAI( + api_key=os.getenv("OPENAI_API_KEY"), + model_name=os.getenv("OPENAI_MODEL_NAME", "gpt-4o"), + temperature=0, + ) + +else: + raise ValueError(f"Unsupported LLM_PROVIDER: {LLM_PROVIDER}. Please use 'azure' or 'openai'.") \ No newline at end of file diff --git a/app/core/ocr.py b/app/core/ocr.py new file mode 100644 index 0000000..b91bca5 --- /dev/null +++ b/app/core/ocr.py @@ -0,0 +1,27 @@ +# app/core/ocr.py +import pytesseract +from PIL import Image + + +# 注意: 您需要先在您的系统中安装Google的Tesseract OCR引擎。 +# 详情请参考之前的安装说明。 + +def extract_text_from_image(image: Image.Image) -> str: + """ + 使用Tesseract OCR从Pillow Image对象中提取文本。 + + 参数: + image: Pillow Image对象。 + + 返回: + 从图片中提取出的字符串文本。 + """ + try: + print("--- [Core OCR] 正在从图片中提取文本用于分类...") + # lang='chi_sim+eng' 表示同时识别简体中文和英文 + text = pytesseract.image_to_string(image, lang='chi_sim+eng') + print("--- [Core OCR] 文本提取成功。") + return text + except Exception as e: + print(f"--- [Core OCR] OCR处理失败: {e}") + raise IOError(f"OCR processing failed: {e}") \ No newline at end of file diff --git a/app/core/pdf_processor.py b/app/core/pdf_processor.py new file mode 100644 index 0000000..0aa3759 --- /dev/null +++ b/app/core/pdf_processor.py @@ -0,0 +1,43 @@ +# app/core/pdf_processor.py +from pdf2image import convert_from_bytes +from PIL import Image +from io import BytesIO +from typing import List +import base64 + + +# 注意: 您需要安装Poppler。 +# - macOS: brew install poppler +# - Ubuntu/Debian: sudo apt-get install poppler-utils +# - Windows: 下载Poppler并将其bin目录添加到系统PATH。 + +def convert_pdf_to_images(pdf_bytes: bytes) -> List[Image.Image]: + """将PDF文件的字节流转换为Pillow Image对象列表。""" + try: + print("--- [Core PDF] 正在将PDF转换为图片...") + + # --- 新增代码开始 --- + # 在这里直接指定您电脑上Poppler的bin目录路径 + # 请确保将下面的示例路径替换为您的真实路径 + poppler_path = r"C:\ProgramData\chocolatey\lib\poppler\tools\Library\bin" + # --- 新增代码结束 --- + + # --- 修改的代码开始 --- + # 在调用时传入poppler_path参数 + images = convert_from_bytes(pdf_bytes) + # --- 修改的代码结束 --- + + print(f"--- [Core PDF] 转换成功,共 {len(images)} 页。") + return images + except Exception as e: + print(f"--- [Core PDF] PDF转换失败: {e}") + # 增加一个更友好的错误提示 + print("--- [Core PDF] 请确认您已在系统中正确安装Poppler,并在上面的代码中指定了正确的poppler_path。") + raise IOError(f"PDF to image conversion failed: {e}") + + +def image_to_base64_str(image: Image.Image) -> str: + """将Pillow Image对象转换为Base64编码的字符串。""" + buffered = BytesIO() + image.save(buffered, format="PNG") + return base64.b64encode(buffered.getvalue()).decode('utf-8') \ No newline at end of file diff --git a/app/core/vector_store.py b/app/core/vector_store.py new file mode 100644 index 0000000..89fc5f0 --- /dev/null +++ b/app/core/vector_store.py @@ -0,0 +1,52 @@ +# app/core/vector_store.py +import os +import chromadb +from dotenv import load_dotenv +from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings + +# 加载.env文件中的环境变量 +load_dotenv() + +# 获取配置的LLM供应商 +LLM_PROVIDER = os.getenv("LLM_PROVIDER", "openai").lower() + +embedding_model = None + +print(f"--- [Core] Initializing Embeddings with provider: {LLM_PROVIDER} ---") + +if LLM_PROVIDER == "azure": + # --- Azure OpenAI 配置 --- + required_vars = [ + "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_API_KEY", + "OPENAI_API_VERSION", "AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME" + ] + if not all(os.getenv(var) for var in required_vars): + raise ValueError("One or more Azure OpenAI environment variables for embeddings are not set.") + + embedding_model = AzureOpenAIEmbeddings( + azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), + api_key=os.getenv("AZURE_OPENAI_API_KEY"), + api_version=os.getenv("OPENAI_API_VERSION"), + azure_deployment=os.getenv("AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME"), + ) + +elif LLM_PROVIDER == "openai": + # --- 标准 OpenAI 配置 --- + if not os.getenv("OPENAI_API_KEY"): + raise ValueError("OPENAI_API_KEY is not set for the 'openai' provider.") + + embedding_model = OpenAIEmbeddings( + api_key=os.getenv("OPENAI_API_KEY"), + model=os.getenv("OPENAI_EMBEDDING_MODEL_NAME", "text-embedding-3-small") + ) + +else: + raise ValueError(f"Unsupported LLM_PROVIDER: {LLM_PROVIDER}. Please use 'azure' or 'openai'.") + + +# 初始化ChromaDB客户端 (无变化) +client = chromadb.PersistentClient(path="./chroma_db") +vector_store = client.get_or_create_collection( + name="documents", + metadata={"hnsw:space": "cosine"} +) \ No newline at end of file diff --git a/app/main.py b/app/main.py new file mode 100644 index 0000000..55a9de1 --- /dev/null +++ b/app/main.py @@ -0,0 +1,21 @@ +# app/main.py +import uvicorn +from fastapi import FastAPI +from .routers import documents # 导入我们新的文档路由 + +app = FastAPI( + title="混合模式文档处理AI Agent", + description="一个用于自动分类、提取和处理文档的AI应用框架。", + version="0.9.0", # 版本升级: 模块化API路由 +) + +# 将文档路由包含到主应用中 +app.include_router(documents.router) + +@app.get("/", tags=["Root"]) +async def read_root(): + """一个简单的根端点,用于检查服务是否正在运行。""" + return {"message": "Welcome to the Document Processing AI Agent API!"} + +if __name__ == "__main__": + uvicorn.run("app.main:app", host="0.0.0.0", port=8000, reload=True) diff --git a/app/routers/__init__.py b/app/routers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/routers/documents.py b/app/routers/documents.py new file mode 100644 index 0000000..ad99729 --- /dev/null +++ b/app/routers/documents.py @@ -0,0 +1,99 @@ +# app/routers/documents.py +import uuid +import mimetypes +import base64 +from fastapi import APIRouter, UploadFile, File, HTTPException +from typing import Dict, Any, List +from fastapi.concurrency import run_in_threadpool +from PIL import Image +from io import BytesIO + +from .. import agents +from ..core.pdf_processor import convert_pdf_to_images, image_to_base64_str +from ..core.ocr import extract_text_from_image + +# 创建一个APIRouter实例 +router = APIRouter( + prefix="/documents", # 为这个路由下的所有路径添加前缀 + tags=["Document Processing"], # 在API文档中为这组端点添加标签 +) + +# 模拟一个SQL数据库来存储最终结果 +db_results: Dict[str, Any] = {} + + +async def hybrid_process_pipeline(doc_id: str, image: Image.Image, page_num: int): + """混合处理流水线""" + ocr_text = await run_in_threadpool(extract_text_from_image, image) + + classification_result = await agents.agent_classify_document_from_text(ocr_text) + category = classification_result.category + language = classification_result.language + print(f"Document page {page_num} classified as: {category}, Language: {language}") + + extraction_result = None + if category in ["RECEIPT", "INVOICE"]: + image_base64 = await run_in_threadpool(image_to_base64_str, image) + if category == "RECEIPT": + # 将语言传递给提取Agent + extraction_result = await agents.agent_extract_receipt_info(image_base64, language) + elif category == "INVOICE": + # 将语言传递给提取Agent + extraction_result = await agents.agent_extract_invoice_info(image_base64, language) + else: + print(f"Document classified as '{category}', skipping high-precision extraction.") + + final_result = { + "doc_id": f"{doc_id}_page_{page_num}", + "category": category, + "language": language, + "ocr_text_for_classification": ocr_text, + "extraction_data": extraction_result.dict() if extraction_result else None, + "status": "Processed" + } + db_results[final_result["doc_id"]] = final_result + return final_result + + +@router.post("/process", summary="上传并处理单个文档(混合模式)") +async def upload_and_process_document(file: UploadFile = File(...)): + """处理上传的文档文件 (PDF, PNG, JPG)""" + if not file.filename: + raise HTTPException(status_code=400, detail="No file provided.") + + doc_id = str(uuid.uuid4()) + print(f"\n接收到新文件: {file.filename} (分配ID: {doc_id})") + contents = await file.read() + + try: + file_type = mimetypes.guess_type(file.filename)[0] + print(f"检测到文件类型: {file_type}") + + images: List[Image.Image] = [] + if file_type == 'application/pdf': + images = await run_in_threadpool(convert_pdf_to_images, contents) + elif file_type in ['image/png', 'image/jpeg']: + images.append(Image.open(BytesIO(contents))) + else: + raise HTTPException(status_code=400, detail=f"Unsupported file type: {file_type}.") + + if not images: + raise HTTPException(status_code=400, detail="Could not extract images from document.") + + all_page_results = [] + for i, img in enumerate(images): + page_result = await hybrid_process_pipeline(doc_id, img, i + 1) + all_page_results.append(page_result) + + return all_page_results + + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}") + + +@router.get("/results/{doc_id}", summary="根据ID获取处理结果") +async def get_result(doc_id: str): + """根据文档处理后返回的 doc_id 获取其详细处理结果。""" + if doc_id in db_results: + return db_results[doc_id] + raise HTTPException(status_code=404, detail="Document not found.") diff --git a/app/schemas.py b/app/schemas.py new file mode 100644 index 0000000..d3b879f --- /dev/null +++ b/app/schemas.py @@ -0,0 +1,51 @@ +# app/schemas.py +from pydantic import BaseModel, Field +from typing import List, Optional + +# --- 分类结果模型 --- +class ClassificationResult(BaseModel): + """Defines the structured output for the classification agent.""" + category: str = Field(description="The category of the document, must be one of ['LETTER', 'INVOICE', 'RECEIPT', 'CONTRACT', 'OTHER']") + language: str = Field(description="The detected primary language of the document as a two-letter code (e.g., 'en', 'zh', 'es').") + +# --- 现有模型 (无变化) --- +class ReceiptItem(BaseModel): + name: str = Field(description="购买的项目或服务名称") + quantity: float = Field(description="项目数量") + price: float = Field(description="项目单价") + +class ReceiptInfo(BaseModel): + merchant_name: Optional[str] = Field(None, description="商户或店铺的名称") + transaction_date: Optional[str] = Field(None, description="交易日期,格式为 YYYY-MM-DD") + total_amount: Optional[float] = Field(None, description="收据上的总金额") + items: Optional[List[ReceiptItem]] = Field(None, description="购买的所有项目列表") + + +# --- 新增: 发票行项目模型 --- +class LineItem(BaseModel): + """Defines a single line item from an invoice.""" + description: Optional[str] = Field("", description="The description of the product or service.") + quantity: Optional[float] = Field(None, description="The quantity of the item.") + unit_price: Optional[float] = Field(None, description="The price for a single unit of the item.") + total_price: Optional[float] = Field(None, description="The total price for this line item (quantity * unit_price).") + + +# --- 发票模型 (已更新) --- +class InvoiceInfo(BaseModel): + """Defines the detailed, structured information to be extracted from an invoice.""" + date: Optional[str] = Field("", description="Extract in YYYY-MM-DD format. If unclear, leave as an empty string.") + invoice_number: Optional[str] = Field("", description="If not found or unclear, leave as an empty string.") + supplier_number: Optional[str] = Field("", description="It's the organisation number. If not found or unclear, leave as an empty string.") + biller_name: Optional[str] = Field("", description="It's the sender's name. If not found or unclear, leave as an empty string.") + amount: Optional[float] = Field(None, description="Extract and format to decimal. If not present, leave as null.") + customer_name: Optional[str] = Field("", description="It's the receiver's name. Clean any special chars from the name. If not found or unclear, leave as an empty string.") + customer_address: Optional[str] = Field("", description="It's the receiver's address. Put it in one line. If not found or unclear, leave as an empty string.") + customer_address_line: Optional[str] = Field("", description="It's the receiver's address line, not the whole address. If not found or unclear, leave as an empty string.") + customer_address_city: Optional[str] = Field("", description="It's the receiver's address city. If not found, find a city in the document. If unclear, leave as an empty string.") + customer_address_country: Optional[str] = Field("", description="It's the receiver's address country. If not found, find the country of the extracted city. If unclear, leave as an empty string.") + customer_address_postal_code: Optional[str] = Field("", description="It's the receiver's address postal code. If not found or unclear, leave as an empty string.") + customer_address_apartment: Optional[str] = Field("", description="It's the receiver's address apartment or suite. If not found or unclear, leave as an empty string.") + customer_address_region: Optional[str] = Field("", description="It's the receiver's address region. If not found, find the region of the extracted city or country. If unclear, leave as an empty string.") + customer_address_care_of: Optional[str] = Field("", description="It's the receiver's address care of. If not found or unclear, leave as an empty string.") + billo_id: Optional[str] = Field("", description="Extract from customer_address if it exists, following the format 'LLL NNN-A'. If not found or unclear, leave as an empty string.") + line_items: List[LineItem] = Field([], description="A list of all line items from the invoice.") \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..4c8dfa0 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,11 @@ +fastapi +uvicorn[standard] +python-dotenv +langchain +langchain-openai +chromadb +tiktoken +pdf2image +python-multipart +pytesseract +Pillow