# app/agents/classification_agent.py
from typing import List

from langchain_core.messages import HumanMessage
from langchain_core.output_parsers import PydanticOutputParser
from langchain.prompts import PromptTemplate

from ..core.llm import llm
from ..schemas import ClassificationResult

parser = PydanticOutputParser(pydantic_object=ClassificationResult)
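
# NOTE: ClassificationResult is defined elsewhere (app/schemas.py) and is not shown in
# this file. Based on how the parser and prompt use it, it is assumed to be a Pydantic
# model roughly along these lines (a sketch, not the actual definition):
#
#   class ClassificationResult(BaseModel):
#       category: str  # one of "LETTER", "INVOICE", "RECEIPT", "CONTRACT", "OTHER"
#       language: str  # two-letter ISO 639-1 code, e.g. "en" or "zh"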

classification_template = """
You are a professional document analysis assistant. The following images represent pages from a single document. Please perform two tasks based on all pages provided:

1. Determine the overall category of the document. The category must be one of: ["LETTER", "INVOICE", "RECEIPT", "CONTRACT", "OTHER"].

2. Detect the primary language of the document. Return the language as a two-letter ISO 639-1 code (e.g., "en" for English, "zh" for Chinese).

Please provide a single response for the entire document in the requested JSON format.

{format_instructions}
"""

classification_prompt = PromptTemplate(
    template=classification_template,
    input_variables=[],
    partial_variables={"format_instructions": parser.get_format_instructions()},
)
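# The only placeholder, {format_instructions}, is pre-filled via partial_variables, so
# the prompt takes no runtime inputs and aformat() below is awaited with no arguments.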


async def agent_classify_document_from_image(images_base64: List[str]) -> ClassificationResult:
    """Agent 1: Classifies an entire document (multiple pages) and detects its language from a list of images."""
    print(f"--- [Agent 1] Calling multimodal LLM for classification of a {len(images_base64)}-page document...")

    prompt_text = await classification_prompt.aformat()

    # Create a list of content parts, starting with the text prompt
    content_parts = [{"type": "text", "text": prompt_text}]

    # Add each image to the content list
    for image_base64 in images_base64:
        content_parts.append({
            "type": "image_url",
            "image_url": {"url": f"data:image/png;base64,{image_base64}"},
        })

    msg = HumanMessage(content=content_parts)

    chain = llm | parser
    result = await chain.ainvoke([msg])
    return result
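

# Example usage (a minimal sketch, not part of the module): this assumes the caller has
# already rendered each page to a base64-encoded PNG string. The file path below is
# hypothetical, and the result fields follow the assumed schema noted above.
#
#   import asyncio, base64
#
#   async def main():
#       with open("page1.png", "rb") as f:
#           pages = [base64.b64encode(f.read()).decode("utf-8")]
#       result = await agent_classify_document_from_image(pages)
#       print(result.category, result.language)
#
#   asyncio.run(main())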