This commit is contained in:
2025-08-11 14:20:56 +02:00
parent 0a80400720
commit f077c6351d
17 changed files with 165 additions and 248 deletions

View File

@@ -1,42 +1,48 @@
# app/agents/classification_agent.py
from langchain.prompts import PromptTemplate
from langchain_core.messages import HumanMessage
from langchain_core.output_parsers import PydanticOutputParser
from langchain.prompts import PromptTemplate
from ..core.llm import llm
from ..schemas import ClassificationResult # 导入新的Schema
from ..schemas import ClassificationResult
from typing import List
# 1. 设置PydanticOutputParser
parser = PydanticOutputParser(pydantic_object=ClassificationResult)
# 2. 更新Prompt模板以要求语言并包含格式指令
classification_template = """
You are a professional document analysis assistant. Please perform two tasks on the following text:
1. Determine its category. The category must be one of: ["LETTER", "INVOICE", "RECEIPT", "CONTRACT", "OTHER"].
2. Detect the primary language of the text. Return the language as a two-letter ISO 639-1 code (e.g., "en" for English, "zh" for Chinese, "es" for Spanish).
You are a professional document analysis assistant. The following images represent pages from a single document. Please perform two tasks based on all pages provided:
1. Determine the overall category of the document. The category must be one of: ["LETTER", "INVOICE", "RECEIPT", "CONTRACT", "OTHER"].
2. Detect the primary language of the document. Return the language as a two-letter ISO 639-1 code (e.g., "en" for English, "zh" for Chinese).
Please provide a single response for the entire document in the requested JSON format.
{format_instructions}
[Document Text Start]
{document_text}
[Document Text End]
"""
classification_prompt = PromptTemplate(
template=classification_template,
input_variables=["document_text"],
input_variables=[],
partial_variables={"format_instructions": parser.get_format_instructions()},
)
# 3. 创建新的LangChain链
classification_chain = classification_prompt | llm | parser
async def agent_classify_document_from_image(images_base64: List[str]) -> ClassificationResult:
"""Agent 1: Classifies an entire document (multiple pages) and detects its language from a list of images."""
print(f"--- [Agent 1] Calling multimodal LLM for classification of a {len(images_base64)}-page document...")
async def agent_classify_document_from_text(text: str) -> ClassificationResult:
"""Agent 1: Classify document and detect language from OCR-extracted text."""
print("--- [Agent 1] Calling LLM for classification and language detection...")
if not text.strip():
print("--- [Agent 1] Text content is empty, classifying as 'OTHER'.")
return ClassificationResult(category="OTHER", language="unknown")
prompt_text = await classification_prompt.aformat()
# 调用链并返回Pydantic对象
result = await classification_chain.ainvoke({"document_text": text})
# Create a list of content parts, starting with the text prompt
content_parts = [{"type": "text", "text": prompt_text}]
# Add each image to the content list
for image_base64 in images_base64:
content_parts.append({
"type": "image_url",
"image_url": {"url": f"data:image/png;base64,{image_base64}"},
})
msg = HumanMessage(content=content_parts)
chain = llm | parser
result = await chain.ainvoke([msg])
return result