Files
AmazingDoc/app/agents/classification_agent.py
Yaojia Wang 0a80400720 Init project
2025-08-11 00:07:41 +02:00

43 lines
1.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# app/agents/classification_agent.py
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import PydanticOutputParser
from ..core.llm import llm
from ..schemas import ClassificationResult # 导入新的Schema
# Step 1: output parser that converts the raw LLM reply into a
# ClassificationResult instance.
parser = PydanticOutputParser(pydantic_object=ClassificationResult)

# Step 2: prompt template. It asks the model for both the document category
# and the document's language, and embeds the parser's format instructions
# so the reply can be parsed by `parser`.
classification_template = """
You are a professional document analysis assistant. Please perform two tasks on the following text:
1. Determine its category. The category must be one of: ["LETTER", "INVOICE", "RECEIPT", "CONTRACT", "OTHER"].
2. Detect the primary language of the text. Return the language as a two-letter ISO 639-1 code (e.g., "en" for English, "zh" for Chinese, "es" for Spanish).
{format_instructions}
[Document Text Start]
{document_text}
[Document Text End]
"""

# The format instructions are fixed for this parser, so they are pre-filled
# as a partial variable; callers only need to supply `document_text`.
_format_instructions = parser.get_format_instructions()
classification_prompt = PromptTemplate(
    input_variables=["document_text"],
    partial_variables={"format_instructions": _format_instructions},
    template=classification_template,
)

# Step 3: LangChain pipeline: prompt -> LLM -> Pydantic parser.
classification_chain = classification_prompt | llm | parser
async def agent_classify_document_from_text(text: str) -> ClassificationResult:
    """Agent 1: classify a document and detect its language from OCR-extracted text.

    Args:
        text: Plain text extracted from the document (e.g. by OCR).

    Returns:
        The ``ClassificationResult`` produced by ``classification_chain``
        (prompt -> LLM -> Pydantic parser). If ``text`` is ``None``, empty or
        whitespace-only, the LLM is not called. Instead, a fallback result
        with ``category="OTHER"`` and ``language="unknown"`` is returned.
    """
    # Short-circuit blank input: there is nothing for the LLM to classify.
    # `not text` also covers None, which would otherwise fail on `.strip()`.
    if not text or not text.strip():
        print("--- [Agent 1] Text content is empty, classifying as 'OTHER'.")
        # NOTE(review): "unknown" is not a two-letter ISO 639-1 code as the
        # prompt requests; confirm downstream consumers expect this sentinel.
        return ClassificationResult(category="OTHER", language="unknown")

    # Announce the LLM call only once we know it will actually happen.
    print("--- [Agent 1] Calling LLM for classification and language detection...")
    # Run the chain; its final step parses the reply into a ClassificationResult.
    return await classification_chain.ainvoke({"document_text": text})