Init project

This commit is contained in:
Yaojia Wang
2025-08-11 00:07:41 +02:00
parent 840daf2d08
commit 0a80400720
23 changed files with 660 additions and 0 deletions

7
app/agents/__init__.py Normal file
View File

@@ -0,0 +1,7 @@
# app/agents/__init__.py
# This file makes it easy to import all agents from the 'agents' package.
# We are importing the function with its correct name now.
from .classification_agent import agent_classify_document_from_text
from .receipt_agent import agent_extract_receipt_info
from .invoice_agent import agent_extract_invoice_info

View File

@@ -0,0 +1,42 @@
# app/agents/classification_agent.py
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import PydanticOutputParser
from ..core.llm import llm
from ..schemas import ClassificationResult # Import the classification result schema
# 1. Set up the PydanticOutputParser: it parses the LLM's JSON reply into a
#    ClassificationResult and supplies the format instructions for the prompt.
parser = PydanticOutputParser(pydantic_object=ClassificationResult)
# 2. Prompt template: asks for a category and an ISO 639-1 language code, and
#    embeds the parser's format instructions via {format_instructions}.
classification_template = """
You are a professional document analysis assistant. Please perform two tasks on the following text:
1. Determine its category. The category must be one of: ["LETTER", "INVOICE", "RECEIPT", "CONTRACT", "OTHER"].
2. Detect the primary language of the text. Return the language as a two-letter ISO 639-1 code (e.g., "en" for English, "zh" for Chinese, "es" for Spanish).
{format_instructions}
[Document Text Start]
{document_text}
[Document Text End]
"""
classification_prompt = PromptTemplate(
    template=classification_template,
    input_variables=["document_text"],
    partial_variables={"format_instructions": parser.get_format_instructions()},
)
# 3. LCEL chain: prompt -> LLM -> parsed ClassificationResult.
classification_chain = classification_prompt | llm | parser
async def agent_classify_document_from_text(text: str) -> ClassificationResult:
    """Agent 1: Classify a document and detect its language from OCR text.

    Runs ``text`` through the classification chain and returns the parsed
    ``ClassificationResult``. Blank input short-circuits to the
    ``OTHER``/``unknown`` fallback without invoking the LLM.
    """
    print("--- [Agent 1] Calling LLM for classification and language detection...")
    if text.strip():
        # Non-empty document: let the chain classify it and detect the language.
        return await classification_chain.ainvoke({"document_text": text})
    print("--- [Agent 1] Text content is empty, classifying as 'OTHER'.")
    return ClassificationResult(category="OTHER", language="unknown")

View File

@@ -0,0 +1,74 @@
# app/agents/invoice_agent.py
from langchain_core.messages import HumanMessage
from langchain_core.output_parsers import PydanticOutputParser
from langchain.prompts import PromptTemplate
from ..core.llm import llm
from ..schemas import InvoiceInfo
# Parser that converts the LLM's JSON reply into an InvoiceInfo object and
# supplies the format instructions embedded at the end of the prompt.
parser = PydanticOutputParser(pydantic_object=InvoiceInfo)
# The prompt includes the detailed rules for each field using snake_case.
# Literal braces in the JSON example are doubled ({{ }}) so PromptTemplate does
# not treat them as template variables. Fix: the ```json example fence was
# never closed, which left "Your Task" and the format instructions inside the
# code block from the model's point of view.
invoice_template = """
You are an expert data entry clerk AI. Your primary goal is to extract information from an invoice image with the highest possible accuracy.
The document's primary language is '{language}'.
## Instructions:
Carefully analyze the invoice image and extract the following fields according to these specific rules. Do not invent information. If a field is not found or is unclear, follow the specific instruction for that field.
- `date`: Extract in YYYY-MM-DD format. If unclear, leave as an empty string.
- `invoice_number`: If not found or unclear, leave as an empty string.
- `supplier_number`: This is the organisation number. If not found or unclear, leave as an empty string.
- `biller_name`: This is the sender's name. If not found or unclear, leave as an empty string.
- `amount`: Extract the final total amount and format it to a decimal number. If not present, leave as null.
- `customer_name`: This is the receiver's name. Ensure it is a name and clear any special characters. If not found or unclear, leave as an empty string.
- `customer_address`: This is the receiver's full address. Put it in one line. If not found or unclear, leave as an empty string.
- `customer_address_line`: This is only the street address line from the receiver's address. If not found or unclear, leave as an empty string.
- `customer_address_city`: This is the receiver's city. If not found, try to find any city in the document. If unclear, leave as an empty string.
- `customer_address_country`: This is the receiver's country. If not found, find the country of the extracted city. If unclear, leave as an empty string.
- `customer_address_postal_code`: This is the receiver's postal code. If not found or unclear, leave as an empty string.
- `customer_address_apartment`: This is the receiver's apartment or suite number. If not found or unclear, leave as an empty string.
- `customer_address_region`: This is the receiver's region. If not found, find the region of the extracted city or country. If unclear, leave as an empty string.
- `customer_address_care_of`: This is the receiver's 'care of' (c/o) line. If not found or unclear, leave as an empty string.
- `billo_id`: To find this, think step-by-step: 1. Find the customer_address. 2. Scan the address for a pattern of three letters, an optional space, three digits, an optional dash, and one alphanumeric character (e.g., 'ABC 123-X' or 'DEF 456Z'). 3. If found, extract it. If not found or unclear, leave as an empty string.
- `line_items`: Extract all line items from the invoice. For each item, extract the `description`, `quantity`, `unit_price`, and `total_price`. If a value is not present, leave it as null.
## Example:
If the invoice shows a line item "Consulting Services | 2 hours | $100.00/hr | $200.00", the output for that line item should be:
```json
{{
"description": "Consulting Services",
"quantity": 2,
"unit_price": 100.00,
"total_price": 200.00
}}
```
Your Task:
Now, analyze the provided image and output the full JSON object according to the format below.
{format_instructions}
"""
invoice_prompt = PromptTemplate(
    template=invoice_template,
    input_variables=["language"],
    partial_variables={"format_instructions": parser.get_format_instructions()},
)
async def agent_extract_invoice_info(image_base64: str, language: str) -> InvoiceInfo:
    """Agent 3: Extracts invoice information from an image, aware of the document's language."""
    print(f"--- [Agent 3] Calling multimodal LLM to extract invoice info (Language: {language})...")
    # Render the language-specific prompt, then pair it with the image payload.
    rendered_prompt = await invoice_prompt.aformat(language=language)
    text_part = {"type": "text", "text": rendered_prompt}
    # NOTE(review): the image part passes a bare data-URL string; some providers
    # expect {"image_url": {"url": ...}} instead — confirm against core.llm's model.
    image_part = {
        "type": "image_url",
        "image_url": f"data:image/png;base64,{image_base64}",
    }
    message = HumanMessage(content=[text_part, image_part])
    extraction_chain = llm | parser
    return await extraction_chain.ainvoke([message])

View File

@@ -0,0 +1,46 @@
# app/agents/receipt_agent.py
from langchain_core.messages import HumanMessage
from langchain_core.output_parsers import PydanticOutputParser
from langchain.prompts import PromptTemplate
from ..core.llm import llm
from ..schemas import ReceiptInfo
# Parser that converts the LLM's JSON reply into a ReceiptInfo object.
parser = PydanticOutputParser(pydantic_object=ReceiptInfo)
# Prompt template, updated to include the document's language.
receipt_template = """
You are a highly accurate receipt information extraction robot.
The document's primary language is '{language}'.
Please extract all key information from the following receipt image.
If some information is not present in the image, leave it as null.
Please strictly follow the JSON format below, without adding any extra explanations or comments.
{format_instructions}
"""
receipt_prompt = PromptTemplate(
    template=receipt_template,
    input_variables=["language"],
    partial_variables={"format_instructions": parser.get_format_instructions()},
)
async def agent_extract_receipt_info(image_base64: str, language: str) -> ReceiptInfo:
    """Agent 2: Extracts receipt information from an image, aware of the document's language."""
    print(f"--- [Agent 2] Calling multimodal LLM to extract receipt info (Language: {language})...")
    formatted = await receipt_prompt.aformat(language=language)
    # Multimodal message: the rendered prompt plus the receipt image as a data URL.
    content = [
        {"type": "text", "text": formatted},
        {"type": "image_url", "image_url": f"data:image/png;base64,{image_base64}"},
    ]
    return await (llm | parser).ainvoke([HumanMessage(content=content)])

View File

@@ -0,0 +1,38 @@
# app/agents/vectorization_agent.py
from langchain.text_splitter import RecursiveCharacterTextSplitter
from ..core.vector_store import vector_store, embedding_model
# Text splitter used to cut long documents into smaller chunks.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=500,  # size of each chunk, in characters
    chunk_overlap=50,  # overlap between consecutive chunks
)
def agent_vectorize_and_store(doc_id: str, text: str, category: str):
    """Agent 4: Split, embed, and persist a document into the vector store.

    Chunks ``text`` with the module-level splitter, embeds every chunk in one
    batch, and writes ids/embeddings/documents/metadata to ChromaDB keyed by
    ``doc_id``. Empty documents are skipped.
    """
    print(f"--- [Agent 4] 正在向量化文档 (ID: {doc_id})...")
    # 1. Split the document text into chunks.
    chunks = text_splitter.split_text(text)
    print(f"--- [Agent 4] 文档被切分为 {len(chunks)} 个块。")
    if not chunks:
        print(f"--- [Agent 4] 文档内容为空,跳过向量化。")
        return
    # 2. Build one unique id and one metadata record per chunk, keyed by position.
    chunk_ids = []
    metadatas = []
    for position, _chunk in enumerate(chunks):
        chunk_ids.append(f"{doc_id}_{position}")
        metadatas.append({"doc_id": doc_id, "category": category, "chunk_number": position})
    # 3. Embed all chunks with a single batch call to the embedding model.
    embeddings = embedding_model.embed_documents(chunks)
    # 4. Store ids, vectors, metadata, and the raw chunks in ChromaDB.
    vector_store.add(
        ids=chunk_ids,
        embeddings=embeddings,
        documents=chunks,
        metadatas=metadatas,
    )
    print(f"--- [Agent 4] 文档 {doc_id} 的向量已存入ChromaDB。")