feat: complete phase 2 -- multi-agent routing, interrupt TTL, escalation, templates
- Intent classification with LLM structured output (single/multi/ambiguous)
- Discount agent with apply_discount and generate_coupon tools
- Interrupt manager with 30-min TTL auto-expiration and retry prompts
- Webhook escalation module with exponential backoff retry (max 3)
- Three vertical industry templates (e-commerce, SaaS, fintech)
- Template loading in AgentRegistry
- Enhanced supervisor prompt with dynamic agent descriptions
- 153 tests passing, 90.18% coverage
This commit is contained in:
@@ -10,11 +10,14 @@ from typing import TYPE_CHECKING, Any
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langgraph.types import Command
|
||||
|
||||
from app.graph import classify_intent
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from fastapi import WebSocket
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
|
||||
from app.callbacks import TokenUsageCallbackHandler
|
||||
from app.interrupt_manager import InterruptManager
|
||||
from app.session_manager import SessionManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -31,6 +34,7 @@ async def handle_user_message(
|
||||
callback_handler: TokenUsageCallbackHandler,
|
||||
thread_id: str,
|
||||
content: str,
|
||||
interrupt_manager: InterruptManager | None = None,
|
||||
) -> None:
|
||||
"""Process a user message through the graph and stream results back."""
|
||||
if session_manager.is_expired(thread_id):
|
||||
@@ -39,8 +43,42 @@ async def handle_user_message(
|
||||
return
|
||||
|
||||
session_manager.touch(thread_id)
|
||||
|
||||
# Run intent classification if available (for logging/future multi-intent)
|
||||
classification = await classify_intent(graph, content)
|
||||
if classification is not None:
|
||||
logger.info(
|
||||
"Intent classification for thread %s: ambiguous=%s, intents=%s",
|
||||
thread_id,
|
||||
classification.is_ambiguous,
|
||||
[i.agent_name for i in classification.intents],
|
||||
)
|
||||
|
||||
# If ambiguous, send clarification and return
|
||||
if classification.is_ambiguous and classification.clarification_question:
|
||||
await _send_json(
|
||||
ws,
|
||||
{
|
||||
"type": "clarification",
|
||||
"thread_id": thread_id,
|
||||
"message": classification.clarification_question,
|
||||
},
|
||||
)
|
||||
await _send_json(ws, {"type": "message_complete", "thread_id": thread_id})
|
||||
return
|
||||
|
||||
config = {"configurable": {"thread_id": thread_id}, "callbacks": [callback_handler]}
|
||||
input_msg = {"messages": [HumanMessage(content=content)]}
|
||||
|
||||
# If multi-intent detected, add routing hint to the message
|
||||
if classification and len(classification.intents) > 1:
|
||||
agent_names = [i.agent_name for i in classification.intents]
|
||||
hint = (
|
||||
f"\n[System: This request involves multiple actions. "
|
||||
f"Execute in order: {', '.join(agent_names)}]"
|
||||
)
|
||||
input_msg = {"messages": [HumanMessage(content=content + hint)]}
|
||||
else:
|
||||
input_msg = {"messages": [HumanMessage(content=content)]}
|
||||
|
||||
try:
|
||||
async for chunk in graph.astream(input_msg, config=config, stream_mode="messages"):
|
||||
@@ -72,6 +110,15 @@ async def handle_user_message(
|
||||
if _has_interrupt(state):
|
||||
interrupt_data = _extract_interrupt(state)
|
||||
session_manager.extend_for_interrupt(thread_id)
|
||||
|
||||
# Register interrupt with TTL tracking
|
||||
if interrupt_manager is not None:
|
||||
interrupt_manager.register(
|
||||
thread_id=thread_id,
|
||||
action=interrupt_data.get("action", "unknown"),
|
||||
params=interrupt_data.get("params", {}),
|
||||
)
|
||||
|
||||
await _send_json(
|
||||
ws,
|
||||
{
|
||||
@@ -96,8 +143,21 @@ async def handle_interrupt_response(
|
||||
callback_handler: TokenUsageCallbackHandler,
|
||||
thread_id: str,
|
||||
approved: bool,
|
||||
interrupt_manager: InterruptManager | None = None,
|
||||
) -> None:
|
||||
"""Resume graph execution after interrupt approval/rejection."""
|
||||
# Check interrupt TTL before resuming
|
||||
if interrupt_manager is not None:
|
||||
status = interrupt_manager.check_status(thread_id)
|
||||
if status is not None and status.is_expired:
|
||||
retry_prompt = interrupt_manager.generate_retry_prompt(status.record)
|
||||
interrupt_manager.resolve(thread_id)
|
||||
session_manager.resolve_interrupt(thread_id)
|
||||
await _send_json(ws, retry_prompt)
|
||||
return
|
||||
|
||||
interrupt_manager.resolve(thread_id)
|
||||
|
||||
session_manager.resolve_interrupt(thread_id)
|
||||
session_manager.touch(thread_id)
|
||||
|
||||
@@ -136,6 +196,7 @@ async def dispatch_message(
|
||||
session_manager: SessionManager,
|
||||
callback_handler: TokenUsageCallbackHandler,
|
||||
raw_data: str,
|
||||
interrupt_manager: InterruptManager | None = None,
|
||||
) -> None:
|
||||
"""Parse and route an incoming WebSocket message."""
|
||||
if len(raw_data) > MAX_MESSAGE_SIZE:
|
||||
@@ -167,12 +228,16 @@ async def dispatch_message(
|
||||
if len(content) > MAX_CONTENT_LENGTH:
|
||||
await _send_json(ws, {"type": "error", "message": "Message content too long"})
|
||||
return
|
||||
await handle_user_message(ws, graph, session_manager, callback_handler, thread_id, content)
|
||||
await handle_user_message(
|
||||
ws, graph, session_manager, callback_handler, thread_id, content,
|
||||
interrupt_manager=interrupt_manager,
|
||||
)
|
||||
|
||||
elif msg_type == "interrupt_response":
|
||||
approved = data.get("approved", False)
|
||||
await handle_interrupt_response(
|
||||
ws, graph, session_manager, callback_handler, thread_id, approved
|
||||
ws, graph, session_manager, callback_handler, thread_id, approved,
|
||||
interrupt_manager=interrupt_manager,
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user