2025-08-11 14:20:56 +02:00
parent 0a80400720
commit f077c6351d
17 changed files with 165 additions and 248 deletions

View File

@@ -1,7 +1,4 @@
# app/agents/__init__.py
# This file re-exports all agents from the 'agents' package.
from .classification_agent import agent_classify_document_from_image
from .receipt_agent import agent_extract_receipt_info
from .invoice_agent import agent_extract_invoice_info
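With the re-exports above, callers can pull every agent from the package root. A minimal usage sketch (the `app` package name is taken from the path comments; the caller module is hypothetical):

```python
# Hypothetical caller; assumes 'app' is importable from the project root.
from app.agents import (
    agent_classify_document_from_image,
    agent_extract_receipt_info,
    agent_extract_invoice_info,
)
```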

View File

@@ -1,42 +1,48 @@
# app/agents/classification_agent.py
from langchain.prompts import PromptTemplate
from langchain_core.messages import HumanMessage
from langchain_core.output_parsers import PydanticOutputParser
from ..core.llm import llm
from ..schemas import ClassificationResult
from typing import List
# 1. Set up the PydanticOutputParser
parser = PydanticOutputParser(pydantic_object=ClassificationResult)
# 2. Prompt template: classify the whole document and detect its language
classification_template = """
You are a professional document analysis assistant. The following images represent pages from a single document. Please perform two tasks based on all pages provided:
1. Determine the overall category of the document. The category must be one of: ["LETTER", "INVOICE", "RECEIPT", "CONTRACT", "OTHER"].
2. Detect the primary language of the document. Return the language as a two-letter ISO 639-1 code (e.g., "en" for English, "zh" for Chinese).
Please provide a single response for the entire document in the requested JSON format.
{format_instructions}
"""
classification_prompt = PromptTemplate(
    template=classification_template,
    input_variables=[],
    partial_variables={"format_instructions": parser.get_format_instructions()},
)
async def agent_classify_document_from_image(images_base64: List[str]) -> ClassificationResult:
    """Agent 1: Classifies an entire document (multiple pages) and detects its language from a list of images."""
    print(f"--- [Agent 1] Calling multimodal LLM for classification of a {len(images_base64)}-page document...")
    prompt_text = await classification_prompt.aformat()
    # Build the message content, starting with the text prompt
    content_parts = [{"type": "text", "text": prompt_text}]
    # Append each page image as a data URL
    for image_base64 in images_base64:
        content_parts.append({
            "type": "image_url",
            "image_url": {"url": f"data:image/png;base64,{image_base64}"},
        })
    msg = HumanMessage(content=content_parts)
    chain = llm | parser
    result = await chain.ainvoke([msg])
    return result
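A minimal sketch of calling the classifier, assuming the document's pages have already been rendered to PNG files (the file names and the rendering step are hypothetical, not part of this commit):

```python
import asyncio
import base64
from pathlib import Path

from app.agents.classification_agent import agent_classify_document_from_image

async def main() -> None:
    # Hypothetical page files; in the real pipeline these would come
    # from a PDF-to-image rendering step.
    pages = ["page_1.png", "page_2.png"]
    images_base64 = [
        base64.b64encode(Path(p).read_bytes()).decode("ascii") for p in pages
    ]
    result = await agent_classify_document_from_image(images_base64)
    print(result.category, result.language)

asyncio.run(main())
```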

View File

@@ -4,10 +4,10 @@ from langchain_core.output_parsers import PydanticOutputParser
from langchain.prompts import PromptTemplate
from ..core.llm import llm
from ..schemas import InvoiceInfo
from typing import List
parser = PydanticOutputParser(pydantic_object=InvoiceInfo)
# The prompt now includes the detailed rules for each field using snake_case.
invoice_template = """
You are an expert data entry clerk AI. Your primary goal is to extract information from an invoice image with the highest possible accuracy.
The document's primary language is '{language}'.
@@ -30,6 +30,9 @@ Carefully analyze the invoice image and extract the following fields according t
- `customer_address_region`: This is the receiver's region. If not found, find the region of the extracted city or country. If unclear, leave as an empty string.
- `customer_address_care_of`: This is the receiver's 'care of' (c/o) line. If not found or unclear, leave as an empty string.
- `billo_id`: To find this, think step-by-step: 1. Find the customer_address. 2. Scan the address for a pattern of three letters, an optional space, three digits, an optional dash, and one alphanumeric character (e.g., 'ABC 123-X' or 'DEF 456Z'). 3. If found, extract it. If not found or unclear, leave as an empty string.
- `bank_giro`: If found, extract the bank giro number. It often follows patterns like 'ddd-dddd', 'dddd-dddd', or 'dddddddd #41#'. If not found or unclear, leave as an empty string.
- `plus_giro`: If found, extract the plus giro number. It often follows patterns like 'ddddddd-d #16#', 'ddddddd-d', or 'ddd dd dd-d'. If not found or unclear, leave as an empty string.
- `customer_ssn`: If found, extract the customer social security number (personnummer). It follows the pattern 'YYYYMMDD-XXXX' or 'YYMMDD-XXXX'. If not found or unclear, leave as an empty string.
- `line_items`: Extract all line items from the invoice. For each item, extract the `description`, `quantity`, `unit_price`, and `total_price`. If a value is not present, leave it as null.
## Example:
@@ -41,33 +44,31 @@ If the invoice shows a line item "Consulting Services | 2 hours | $100.00/hr | $
"unit_price": 100.00,
"total_price": 200.00
}}
```
Your Task:
Now, analyze the provided image and output the full JSON object according to the format below.
{format_instructions}
"""
invoice_prompt = PromptTemplate(
    template=invoice_template,
    input_variables=["language"],
    partial_variables={"format_instructions": parser.get_format_instructions()},
)
async def agent_extract_invoice_info(images_base64: List[str], language: str) -> InvoiceInfo:
    """Agent 3: Extracts invoice information from a list of images, aware of the document's language."""
    print(f"--- [Agent 3] Calling multimodal LLM to extract invoice info (Language: {language})...")
    prompt_text = await invoice_prompt.aformat(language=language)
    # Build the message content: the text prompt followed by every page image
    content_parts = [{"type": "text", "text": prompt_text}]
    for image_base64 in images_base64:
        content_parts.append({
            "type": "image_url",
            "image_url": {"url": f"data:image/png;base64,{image_base64}"},
        })
    msg = HumanMessage(content=content_parts)
    chain = llm | parser
    invoice_info = await chain.ainvoke([msg])
    return invoice_info
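The `billo_id`, giro, and `customer_ssn` rules in the prompt above are written in prose. As a sanity check on the LLM's output, they could be mirrored as regexes; the patterns below are one interpretation of that prose (they are not part of this commit and ignore the optional '#41#'/'#16#' suffixes):

```python
import re

# One interpretation of the prompt's prose patterns; adjust against real data.
BILLO_ID_RE = re.compile(r"[A-Z]{3} ?\d{3}-?[A-Za-z0-9]")    # e.g. 'ABC 123-X', 'DEF 456Z'
BANK_GIRO_RE = re.compile(r"\d{3,4}-\d{4}")                  # 'ddd-dddd' or 'dddd-dddd'
PLUS_GIRO_RE = re.compile(r"\d{7}-\d|\d{3} \d{2} \d{2}-\d")  # 'ddddddd-d' or 'ddd dd dd-d'
CUSTOMER_SSN_RE = re.compile(r"(?:\d{8}|\d{6})-\d{4}")       # 'YYYYMMDD-XXXX' or 'YYMMDD-XXXX'

def looks_like_billo_id(value: str) -> bool:
    """Return True if an extracted billo_id has the expected shape."""
    return bool(BILLO_ID_RE.fullmatch(value.strip()))
```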

View File

@@ -4,15 +4,15 @@ from langchain_core.output_parsers import PydanticOutputParser
from langchain.prompts import PromptTemplate
from ..core.llm import llm
from ..schemas import ReceiptInfo
from typing import List
parser = PydanticOutputParser(pydantic_object=ReceiptInfo)
# The prompt template now includes the document's language
receipt_template = """
You are a highly accurate receipt information extraction robot.
The document's primary language is '{language}'.
Please extract all key information from the following receipt images, which belong to a single document.
If some information is not present in the images, leave it as null.
Please strictly follow the JSON format below, without adding any extra explanations or comments.
{format_instructions}
@@ -25,22 +25,21 @@ receipt_prompt = PromptTemplate(
)
async def agent_extract_receipt_info(images_base64: List[str], language: str) -> ReceiptInfo:
    """Agent 2: Extracts receipt information from a list of images, aware of the document's language."""
    print(f"--- [Agent 2] Calling multimodal LLM to extract receipt info (Language: {language})...")
    prompt_text = await receipt_prompt.aformat(language=language)
    # Build the message content: the text prompt followed by every page image
    content_parts = [{"type": "text", "text": prompt_text}]
    for image_base64 in images_base64:
        content_parts.append({
            "type": "image_url",
            "image_url": {"url": f"data:image/png;base64,{image_base64}"},
        })
    msg = HumanMessage(content=content_parts)
    chain = llm | parser
    receipt_info = await chain.ainvoke([msg])
    return receipt_info
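Taken together, the three agents suggest a classify-then-route flow. A hedged sketch of that orchestration (the `process_document` helper is hypothetical; only the three agent functions come from this commit):

```python
from typing import List

from app.agents import (
    agent_classify_document_from_image,
    agent_extract_invoice_info,
    agent_extract_receipt_info,
)

async def process_document(images_base64: List[str]):
    """Hypothetical orchestrator: classify once, then route to an extractor."""
    classification = await agent_classify_document_from_image(images_base64)
    if classification.category == "INVOICE":
        return await agent_extract_invoice_info(images_base64, classification.language)
    if classification.category == "RECEIPT":
        return await agent_extract_receipt_info(images_base64, classification.language)
    # LETTER, CONTRACT and OTHER have no structured extractor in this commit.
    return None
```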

View File

@@ -2,32 +2,32 @@
from langchain.text_splitter import RecursiveCharacterTextSplitter
from ..core.vector_store import vector_store, embedding_model
# Initialize the text splitter to divide long documents into smaller chunks
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=500,    # size of each chunk, in characters
    chunk_overlap=50,  # overlap between consecutive chunks
)
def agent_vectorize_and_store(doc_id: str, text: str, category: str):
    """Agent 4: Vectorization and Storage (Real Implementation)"""
    print(f"--- [Agent 4] Vectorizing document (ID: {doc_id})...")
    # 1. Split the document text into chunks
    chunks = text_splitter.split_text(text)
    print(f"--- [Agent 4] Document split into {len(chunks)} chunks.")
    if not chunks:
        print("--- [Agent 4] Document is empty, skipping vectorization.")
        return
    # 2. Create a unique ID and metadata for each chunk
    chunk_ids = [f"{doc_id}_{i}" for i in range(len(chunks))]
    metadatas = [{"doc_id": doc_id, "category": category, "chunk_number": i} for i in range(len(chunks))]
    # 3. Use the embedding model to generate vectors for all chunks
    embeddings = embedding_model.embed_documents(chunks)
    # 4. Add the IDs, vectors, metadata, and text chunks to ChromaDB
    vector_store.add(
        ids=chunk_ids,
        embeddings=embeddings,
@@ -35,4 +35,4 @@ def agent_vectorize_and_store(doc_id: str, text: str, category: str):
        metadatas=metadatas
    )
    print(f"--- [Agent 4] Vectors for document {doc_id} stored in ChromaDB.")