"""LangGraph Supervisor construction -- connects registry, agents, LLM, and persistence.""" from __future__ import annotations import logging from typing import TYPE_CHECKING from langgraph.prebuilt import create_react_agent from langgraph_supervisor import create_supervisor from app.agents import get_tools_by_names from app.intent import ClassificationResult, IntentClassifier if TYPE_CHECKING: from langchain_core.language_models import BaseChatModel from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver from langgraph.graph.state import CompiledStateGraph from app.registry import AgentRegistry logger = logging.getLogger(__name__) SUPERVISOR_PROMPT = ( "You are a customer support supervisor. " "Route customer requests to the appropriate agent based on their description.\n\n" "Available agents and their roles:\n" "{agent_descriptions}\n\n" "Routing rules:\n" "- For order status and tracking queries, use the order_lookup agent.\n" "- For order modifications like cancellations, use the order_actions agent.\n" "- For discounts, promotions, or coupon codes, use the discount agent.\n" "- For anything else or when uncertain, use the fallback agent.\n" "- If the user's request involves multiple actions, execute them in order.\n" "- If a previous intent classification is provided, follow it.\n" ) def _format_agent_descriptions(registry: AgentRegistry) -> str: """Build agent description text for the supervisor prompt.""" lines = [] for agent in registry.list_agents(): lines.append(f"- {agent.name}: {agent.description}") return "\n".join(lines) def build_agent_nodes( registry: AgentRegistry, llm: BaseChatModel, ) -> list: """Create LangGraph react agent nodes from registry configurations.""" agent_nodes = [] for agent_config in registry.list_agents(): tools = get_tools_by_names(agent_config.tools) system_prompt = ( f"You are the {agent_config.name} agent. " f"Personality: {agent_config.personality.tone}. " f"{agent_config.description} " f"Permission level: {agent_config.permission}." ) agent_node = create_react_agent( model=llm, tools=tools, name=agent_config.name, prompt=system_prompt, ) agent_nodes.append(agent_node) return agent_nodes def build_graph( registry: AgentRegistry, llm: BaseChatModel, checkpointer: AsyncPostgresSaver, intent_classifier: IntentClassifier | None = None, ) -> CompiledStateGraph: """Build and compile the LangGraph supervisor graph. If an intent_classifier is provided, the supervisor prompt is enhanced with agent descriptions for better routing. The classifier is stored for use by the routing layer (ws_handler). """ agent_nodes = build_agent_nodes(registry, llm) agent_descriptions = _format_agent_descriptions(registry) prompt = SUPERVISOR_PROMPT.format(agent_descriptions=agent_descriptions) workflow = create_supervisor( agent_nodes, model=llm, prompt=prompt, output_mode="full_history", ) graph = workflow.compile(checkpointer=checkpointer) # Attach classifier and registry to graph for use by ws_handler graph.intent_classifier = intent_classifier # type: ignore[attr-defined] graph.agent_registry = registry # type: ignore[attr-defined] return graph async def classify_intent( graph: CompiledStateGraph, message: str, ) -> ClassificationResult | None: """Classify user intent using the graph's attached classifier. Returns None if no classifier is configured. """ classifier = getattr(graph, "intent_classifier", None) registry = getattr(graph, "agent_registry", None) if classifier is None or registry is None: return None agents = registry.list_agents() return await classifier.classify(message, agents)