Init
This commit is contained in:
@@ -1,7 +1,4 @@
|
||||
# app/agents/__init__.py
|
||||
|
||||
# This file makes it easy to import all agents from the 'agents' package.
|
||||
# We are importing the function with its correct name now.
|
||||
from .classification_agent import agent_classify_document_from_text
|
||||
from .classification_agent import agent_classify_document_from_image
|
||||
from .receipt_agent import agent_extract_receipt_info
|
||||
from .invoice_agent import agent_extract_invoice_info
|
||||
@@ -1,42 +1,48 @@
|
||||
# app/agents/classification_agent.py
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.output_parsers import PydanticOutputParser
|
||||
from langchain.prompts import PromptTemplate
|
||||
from ..core.llm import llm
|
||||
from ..schemas import ClassificationResult # 导入新的Schema
|
||||
from ..schemas import ClassificationResult
|
||||
from typing import List
|
||||
|
||||
# 1. 设置PydanticOutputParser
|
||||
parser = PydanticOutputParser(pydantic_object=ClassificationResult)
|
||||
|
||||
# 2. 更新Prompt模板以要求语言,并包含格式指令
|
||||
classification_template = """
|
||||
You are a professional document analysis assistant. Please perform two tasks on the following text:
|
||||
1. Determine its category. The category must be one of: ["LETTER", "INVOICE", "RECEIPT", "CONTRACT", "OTHER"].
|
||||
2. Detect the primary language of the text. Return the language as a two-letter ISO 639-1 code (e.g., "en" for English, "zh" for Chinese, "es" for Spanish).
|
||||
You are a professional document analysis assistant. The following images represent pages from a single document. Please perform two tasks based on all pages provided:
|
||||
1. Determine the overall category of the document. The category must be one of: ["LETTER", "INVOICE", "RECEIPT", "CONTRACT", "OTHER"].
|
||||
2. Detect the primary language of the document. Return the language as a two-letter ISO 639-1 code (e.g., "en" for English, "zh" for Chinese).
|
||||
|
||||
Please provide a single response for the entire document in the requested JSON format.
|
||||
|
||||
{format_instructions}
|
||||
|
||||
[Document Text Start]
|
||||
{document_text}
|
||||
[Document Text End]
|
||||
"""
|
||||
|
||||
classification_prompt = PromptTemplate(
|
||||
template=classification_template,
|
||||
input_variables=["document_text"],
|
||||
input_variables=[],
|
||||
partial_variables={"format_instructions": parser.get_format_instructions()},
|
||||
)
|
||||
|
||||
# 3. 创建新的LangChain链
|
||||
classification_chain = classification_prompt | llm | parser
|
||||
|
||||
async def agent_classify_document_from_image(images_base64: List[str]) -> ClassificationResult:
|
||||
"""Agent 1: Classifies an entire document (multiple pages) and detects its language from a list of images."""
|
||||
print(f"--- [Agent 1] Calling multimodal LLM for classification of a {len(images_base64)}-page document...")
|
||||
|
||||
async def agent_classify_document_from_text(text: str) -> ClassificationResult:
|
||||
"""Agent 1: Classify document and detect language from OCR-extracted text."""
|
||||
print("--- [Agent 1] Calling LLM for classification and language detection...")
|
||||
if not text.strip():
|
||||
print("--- [Agent 1] Text content is empty, classifying as 'OTHER'.")
|
||||
return ClassificationResult(category="OTHER", language="unknown")
|
||||
prompt_text = await classification_prompt.aformat()
|
||||
|
||||
# 调用链并返回Pydantic对象
|
||||
result = await classification_chain.ainvoke({"document_text": text})
|
||||
# Create a list of content parts, starting with the text prompt
|
||||
content_parts = [{"type": "text", "text": prompt_text}]
|
||||
|
||||
# Add each image to the content list
|
||||
for image_base64 in images_base64:
|
||||
content_parts.append({
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:image/png;base64,{image_base64}"},
|
||||
})
|
||||
|
||||
msg = HumanMessage(content=content_parts)
|
||||
|
||||
chain = llm | parser
|
||||
result = await chain.ainvoke([msg])
|
||||
return result
|
||||
|
||||
@@ -4,10 +4,10 @@ from langchain_core.output_parsers import PydanticOutputParser
|
||||
from langchain.prompts import PromptTemplate
|
||||
from ..core.llm import llm
|
||||
from ..schemas import InvoiceInfo
|
||||
from typing import List
|
||||
|
||||
parser = PydanticOutputParser(pydantic_object=InvoiceInfo)
|
||||
|
||||
# The prompt now includes the detailed rules for each field using snake_case.
|
||||
invoice_template = """
|
||||
You are an expert data entry clerk AI. Your primary goal is to extract information from an invoice image with the highest possible accuracy.
|
||||
The document's primary language is '{language}'.
|
||||
@@ -30,6 +30,9 @@ Carefully analyze the invoice image and extract the following fields according t
|
||||
- `customer_address_region`: This is the receiver's region. If not found, find the region of the extracted city or country. If unclear, leave as an empty string.
|
||||
- `customer_address_care_of`: This is the receiver's 'care of' (c/o) line. If not found or unclear, leave as an empty string.
|
||||
- `billo_id`: To find this, think step-by-step: 1. Find the customer_address. 2. Scan the address for a pattern of three letters, an optional space, three digits, an optional dash, and one alphanumeric character (e.g., 'ABC 123-X' or 'DEF 456Z'). 3. If found, extract it. If not found or unclear, leave as an empty string.
|
||||
- `bank_giro`: If found, extract the bank giro number. It often follows patterns like 'ddd-dddd', 'dddd-dddd', or 'dddddddd #41#'. If not found or unclear, leave as an empty string.
|
||||
- `plus_giro`: If found, extract the plus giro number. It often follows patterns like 'ddddddd-d #16#', 'ddddddd-d', or 'ddd dd dd-d'. If not found or unclear, leave as an empty string.
|
||||
- `customer_ssn`: If found, extract the customer social security number (personnummer). It follows the pattern 'YYYYMMDD-XXXX' or 'YYMMDD-XXXX'. If not found or unclear, leave as an empty string.
|
||||
- `line_items`: Extract all line items from the invoice. For each item, extract the `description`, `quantity`, `unit_price`, and `total_price`. If a value is not present, leave it as null.
|
||||
|
||||
## Example:
|
||||
@@ -41,33 +44,31 @@ If the invoice shows a line item "Consulting Services | 2 hours | $100.00/hr | $
|
||||
"unit_price": 100.00,
|
||||
"total_price": 200.00
|
||||
}}
|
||||
|
||||
```
|
||||
Your Task:
|
||||
Now, analyze the provided image and output the full JSON object according to the format below.
|
||||
{format_instructions}
|
||||
"""
|
||||
|
||||
invoice_prompt = PromptTemplate(
|
||||
template=invoice_template,
|
||||
input_variables=["language"],
|
||||
partial_variables={"format_instructions": parser.get_format_instructions()},
|
||||
partial_variables={"format_instructions": parser.get_format_instructions()}
|
||||
)
|
||||
|
||||
async def agent_extract_invoice_info(image_base64: str, language: str) -> InvoiceInfo:
|
||||
"""Agent 3: Extracts invoice information from an image, aware of the document's language."""
|
||||
async def agent_extract_invoice_info(images_base64: List[str], language: str) -> InvoiceInfo:
|
||||
"""Agent 3: Extracts invoice information from a list of images, aware of the document's language."""
|
||||
print(f"--- [Agent 3] Calling multimodal LLM to extract invoice info (Language: {language})...")
|
||||
|
||||
prompt_text = await invoice_prompt.aformat(language=language)
|
||||
|
||||
msg = HumanMessage(
|
||||
content=[
|
||||
{"type": "text", "text": prompt_text},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": f"data:image/png;base64,{image_base64}",
|
||||
},
|
||||
]
|
||||
)
|
||||
content_parts = [{"type": "text", "text": prompt_text}]
|
||||
for image_base64 in images_base64:
|
||||
content_parts.append({
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:image/png;base64,{image_base64}"},
|
||||
})
|
||||
|
||||
msg = HumanMessage(content=content_parts)
|
||||
|
||||
chain = llm | parser
|
||||
invoice_info = await chain.ainvoke([msg])
|
||||
|
||||
@@ -4,15 +4,15 @@ from langchain_core.output_parsers import PydanticOutputParser
|
||||
from langchain.prompts import PromptTemplate
|
||||
from ..core.llm import llm
|
||||
from ..schemas import ReceiptInfo
|
||||
from typing import List
|
||||
|
||||
parser = PydanticOutputParser(pydantic_object=ReceiptInfo)
|
||||
|
||||
# 更新Prompt模板以包含语言信息
|
||||
receipt_template = """
|
||||
You are a highly accurate receipt information extraction robot.
|
||||
The document's primary language is '{language}'.
|
||||
Please extract all key information from the following receipt image.
|
||||
If some information is not present in the image, leave it as null.
|
||||
Please extract all key information from the following receipt images, which belong to a single document.
|
||||
If some information is not present in the images, leave it as null.
|
||||
Please strictly follow the JSON format below, without adding any extra explanations or comments.
|
||||
|
||||
{format_instructions}
|
||||
@@ -25,22 +25,21 @@ receipt_prompt = PromptTemplate(
|
||||
)
|
||||
|
||||
|
||||
async def agent_extract_receipt_info(image_base64: str, language: str) -> ReceiptInfo:
|
||||
"""Agent 2: Extracts receipt information from an image, aware of the document's language."""
|
||||
async def agent_extract_receipt_info(images_base64: List[str], language: str) -> ReceiptInfo:
|
||||
"""Agent 2: Extracts receipt information from a list of images, aware of the document's language."""
|
||||
print(f"--- [Agent 2] Calling multimodal LLM to extract receipt info (Language: {language})...")
|
||||
|
||||
prompt_text = await receipt_prompt.aformat(language=language)
|
||||
|
||||
msg = HumanMessage(
|
||||
content=[
|
||||
{"type": "text", "text": prompt_text},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": f"data:image/png;base64,{image_base64}",
|
||||
},
|
||||
]
|
||||
)
|
||||
content_parts = [{"type": "text", "text": prompt_text}]
|
||||
for image_base64 in images_base64:
|
||||
content_parts.append({
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:image/png;base64,{image_base64}"},
|
||||
})
|
||||
|
||||
msg = HumanMessage(content=content_parts)
|
||||
|
||||
chain = llm | parser
|
||||
receipt_info = await chain.ainvoke([msg])
|
||||
return receipt_info
|
||||
return receipt_info
|
||||
@@ -2,32 +2,32 @@
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
from ..core.vector_store import vector_store, embedding_model
|
||||
|
||||
# 初始化文本分割器,用于将长文档切成小块
|
||||
# Initialize the text splitter to divide long documents into smaller chunks
|
||||
text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=500, # 每个块的大小(字符数)
|
||||
chunk_overlap=50, # 块之间的重叠部分
|
||||
chunk_size=500,
|
||||
chunk_overlap=50,
|
||||
)
|
||||
|
||||
def agent_vectorize_and_store(doc_id: str, text: str, category: str):
|
||||
"""Agent 4: 向量化并存储 (真实实现)"""
|
||||
print(f"--- [Agent 4] 正在向量化文档 (ID: {doc_id})...")
|
||||
"""Agent 4: Vectorization and Storage (Real Implementation)"""
|
||||
print(f"--- [Agent 4] Vectorizing document (ID: {doc_id})...")
|
||||
|
||||
# 1. 将文档文本分割成块
|
||||
# 1. Split the document text into chunks
|
||||
chunks = text_splitter.split_text(text)
|
||||
print(f"--- [Agent 4] 文档被切分为 {len(chunks)} 个块。")
|
||||
print(f"--- [Agent 4] Document split into {len(chunks)} chunks.")
|
||||
|
||||
if not chunks:
|
||||
print(f"--- [Agent 4] 文档内容为空,跳过向量化。")
|
||||
print(f"--- [Agent 4] Document is empty, skipping vectorization.")
|
||||
return
|
||||
|
||||
# 2. 为每个块创建唯一的ID和元数据
|
||||
# 2. Create a unique ID and metadata for each chunk
|
||||
chunk_ids = [f"{doc_id}_{i}" for i in range(len(chunks))]
|
||||
metadatas = [{"doc_id": doc_id, "category": category, "chunk_number": i} for i in range(len(chunks))]
|
||||
|
||||
# 3. 使用嵌入模型为所有块生成向量
|
||||
# 3. Use an embedding model to generate vectors for all chunks
|
||||
embeddings = embedding_model.embed_documents(chunks)
|
||||
|
||||
# 4. 将ID、向量、元数据和文本块本身添加到ChromaDB
|
||||
# 4. Add the IDs, vectors, metadata, and text chunks to ChromaDB
|
||||
vector_store.add(
|
||||
ids=chunk_ids,
|
||||
embeddings=embeddings,
|
||||
@@ -35,4 +35,4 @@ def agent_vectorize_and_store(doc_id: str, text: str, category: str):
|
||||
metadatas=metadatas
|
||||
)
|
||||
|
||||
print(f"--- [Agent 4] 文档 {doc_id} 的向量已存入ChromaDB。")
|
||||
print(f"--- [Agent 4] document {doc_id} stored in ChromaDB。")
|
||||
|
||||
Reference in New Issue
Block a user