feat: complete phase 2 -- multi-agent routing, interrupt TTL, escalation, templates

- Intent classification with LLM structured output (single/multi/ambiguous)
- Discount agent with apply_discount and generate_coupon tools
- Interrupt manager with 30-min TTL auto-expiration and retry prompts
- Webhook escalation module with exponential backoff retry (max 3)
- Three vertical industry templates (e-commerce, SaaS, fintech)
- Template loading in AgentRegistry
- Enhanced supervisor prompt with dynamic agent descriptions
- 153 tests passing, 90.18% coverage
Author: Yaojia Wang
Date:   2026-03-30 21:04:39 +02:00
Parent: 7c3571b47d
Commit: 1050df780d
27 changed files with 1683 additions and 43 deletions

View File

@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from langchain_core.tools import BaseTool
from app.agents.discount import apply_discount, generate_coupon
from app.agents.fallback import fallback_respond
from app.agents.order_actions import cancel_order
from app.agents.order_lookup import get_order_status, get_tracking_info
@@ -16,6 +17,8 @@ _TOOL_MAP: dict[str, BaseTool] = {
"get_tracking_info": get_tracking_info,
"cancel_order": cancel_order,
"fallback_respond": fallback_respond,
"apply_discount": apply_discount,
"generate_coupon": generate_coupon,
}

View File

@@ -0,0 +1,79 @@
"""Discount agent tools -- apply discounts and generate coupons."""
from __future__ import annotations
import uuid
from langchain_core.tools import tool
from langgraph.types import interrupt
@tool
def apply_discount(order_id: str, discount_percent: int) -> dict:
"""Apply a discount to an order. Requires human approval before execution."""
if discount_percent < 1 or discount_percent > 100:
return {
"status": "error",
"order_id": order_id,
"message": f"Invalid discount: {discount_percent}%. Must be between 1 and 100.",
}
response = interrupt(
{
"action": "apply_discount",
"order_id": order_id,
"discount_percent": discount_percent,
"message": (
f"Please confirm: apply {discount_percent}% discount to order {order_id}?"
),
}
)
if isinstance(response, bool):
approved = response
elif isinstance(response, dict):
approved = response.get("approved", False)
else:
approved = bool(response)
if approved:
return {
"status": "applied",
"order_id": order_id,
"discount_percent": discount_percent,
"message": (
f"{discount_percent}% discount applied to order {order_id}."
),
}
return {
"status": "declined",
"order_id": order_id,
"message": f"Discount for order {order_id} was declined.",
}
@tool
def generate_coupon(discount_percent: int, expiry_days: int = 30) -> dict:
"""Generate a coupon code with the specified discount percentage."""
if discount_percent < 1 or discount_percent > 100:
return {
"status": "error",
"message": f"Invalid discount: {discount_percent}%. Must be between 1 and 100.",
}
if expiry_days < 1:
return {
"status": "error",
"message": f"Invalid expiry: {expiry_days} days. Must be at least 1.",
}
coupon_code = f"SAVE{discount_percent}-{uuid.uuid4().hex[:8].upper()}"
return {
"status": "generated",
"coupon_code": coupon_code,
"discount_percent": discount_percent,
"expiry_days": expiry_days,
"message": (
f"Coupon {coupon_code} generated: {discount_percent}% off, "
f"valid for {expiry_days} days."
),
}
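Worth noting how the approval loop closes: `interrupt()` suspends the run, and whatever value the caller resumes with becomes `response` in `apply_discount` above. A minimal sketch of the round trip, assuming a compiled graph with a checkpointer; the thread id is illustrative, and where the pending payload surfaces varies by langgraph version:

```python
# Sketch only: the approve/decline round trip for apply_discount.
# "graph" is a compiled LangGraph with a checkpointer; the thread id is
# illustrative, and coercion of ("user", text) message tuples is assumed.
from langgraph.types import Command

async def demo(graph) -> None:
    config = {"configurable": {"thread_id": "demo-thread"}}
    # First run pauses inside apply_discount at interrupt(); the pending
    # payload surfaces in the thread state (or under "__interrupt__",
    # depending on the langgraph version).
    await graph.ainvoke(
        {"messages": [("user", "Apply 15% off to order 1042")]}, config
    )
    # Command(resume=...) becomes the interrupt's return value -- True/False
    # maps onto the isinstance(response, bool) branch in apply_discount.
    await graph.ainvoke(Command(resume=True), config)
```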

View File

@@ -1,4 +1,4 @@
"""Fallback agent tools -- handles unmatched intents."""
"""Fallback agent tools -- handles unmatched intents and clarification requests."""
from __future__ import annotations
@@ -13,6 +13,7 @@ def fallback_respond(query: str) -> str:
"Here's what I can do:\n"
"- Check order status (e.g., 'What is the status of order 1042?')\n"
"- Get tracking information (e.g., 'Track order 1042')\n"
"- Cancel an order (e.g., 'Cancel order 1042')\n\n"
"- Cancel an order (e.g., 'Cancel order 1042')\n"
"- Apply discounts or generate coupons\n\n"
"Could you please rephrase your request?"
)

View File

@@ -26,6 +26,12 @@ class Settings(BaseSettings):
ws_host: str = "0.0.0.0"
ws_port: int = 8000
webhook_url: str = ""
webhook_timeout_seconds: int = 10
webhook_max_retries: int = 3
template_name: str = ""
anthropic_api_key: str = ""
openai_api_key: str = ""
google_api_key: str = ""
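Since these ride on the existing pydantic-settings `Settings` class, they can be overridden from the environment. A hedged sketch -- the env var names assume no `env_prefix` is configured, and the URL and template name are made up:

```python
# Sketch: pydantic-settings populates fields from the environment
# (matching is case-insensitive by default). Values are illustrative.
import os

from app.config import Settings

os.environ["WEBHOOK_URL"] = "https://hooks.example.com/escalations"
os.environ["WEBHOOK_MAX_RETRIES"] = "5"
os.environ["TEMPLATE_NAME"] = "ecommerce"  # hypothetical template stem

settings = Settings()
assert settings.webhook_max_retries == 5  # coerced from the env string
```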

backend/app/escalation.py (new file, 140 lines)
View File

@@ -0,0 +1,140 @@
"""Webhook escalation module -- HTTP POST with exponential backoff retry."""
from __future__ import annotations
import asyncio
import logging
from dataclasses import dataclass
from typing import Protocol
import httpx
from pydantic import BaseModel
logger = logging.getLogger(__name__)
class EscalationPayload(BaseModel, frozen=True):
"""Immutable payload sent to the escalation webhook."""
thread_id: str
reason: str
conversation_summary: str
metadata: dict = {}
@dataclass(frozen=True)
class EscalationResult:
"""Immutable result of an escalation attempt."""
success: bool
status_code: int | None
attempts: int
error: str | None
class EscalationService(Protocol):
"""Protocol for escalation implementations."""
async def escalate(self, payload: EscalationPayload) -> EscalationResult: ...
class WebhookEscalator:
"""Sends escalation requests via HTTP POST with exponential backoff retry."""
def __init__(
self,
url: str,
timeout_seconds: int = 10,
max_retries: int = 3,
) -> None:
self._url = url
self._timeout = timeout_seconds
self._max_retries = max_retries
async def escalate(self, payload: EscalationPayload) -> EscalationResult:
"""POST the escalation payload to the configured webhook URL."""
if not self._url:
return EscalationResult(
success=False,
status_code=None,
attempts=0,
error="Webhook URL not configured",
)
last_error: str | None = None
async with httpx.AsyncClient(timeout=self._timeout) as client:
for attempt in range(1, self._max_retries + 1):
try:
response = await client.post(
self._url,
json=payload.model_dump(),
)
if 200 <= response.status_code < 300:
logger.info(
"Escalation succeeded for thread %s (attempt %d)",
payload.thread_id,
attempt,
)
return EscalationResult(
success=True,
status_code=response.status_code,
attempts=attempt,
error=None,
)
last_error = f"HTTP {response.status_code}"
logger.warning(
"Escalation attempt %d/%d failed: %s",
attempt,
self._max_retries,
last_error,
)
except httpx.TimeoutException:
last_error = "Request timed out"
logger.warning(
"Escalation attempt %d/%d timed out",
attempt,
self._max_retries,
)
except httpx.RequestError as exc:
last_error = str(exc)
logger.warning(
"Escalation attempt %d/%d error: %s",
attempt,
self._max_retries,
last_error,
)
# Exponential backoff: skip delay after last attempt
if attempt < self._max_retries:
delay = 2**attempt
await asyncio.sleep(delay)
logger.error(
"Escalation failed for thread %s after %d attempts: %s",
payload.thread_id,
self._max_retries,
last_error,
)
return EscalationResult(
success=False,
status_code=None,
attempts=self._max_retries,
error=last_error,
)
class NoOpEscalator:
"""Escalator that does nothing -- used when webhook URL is not configured."""
async def escalate(self, payload: EscalationPayload) -> EscalationResult:
logger.info("Escalation disabled (no webhook URL). Thread: %s", payload.thread_id)
return EscalationResult(
success=False,
status_code=None,
attempts=0,
error="Escalation disabled",
)
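The retry arithmetic follows from the loop above: with `max_retries=3`, backoff sleeps 2s after attempt 1 and 4s after attempt 2, so a fully failed escalation costs ~6s of backoff plus up to three request timeouts. A sketch of direct usage, with an illustrative URL and payload:

```python
# Sketch: invoking the escalator by hand, outside the FastAPI lifespan.
import asyncio

from app.escalation import EscalationPayload, WebhookEscalator

async def main() -> None:
    escalator = WebhookEscalator(url="https://hooks.example.com/escalations")
    payload = EscalationPayload(
        thread_id="demo-thread",
        reason="user_requested_human",
        conversation_summary="Customer asked for a human after two failed lookups.",
    )
    result = await escalator.escalate(payload)
    print(result.success, result.status_code, result.attempts, result.error)

asyncio.run(main())
```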

View File

@@ -2,12 +2,14 @@
from __future__ import annotations
import logging
from typing import TYPE_CHECKING
from langgraph.prebuilt import create_react_agent
from langgraph_supervisor import create_supervisor
from app.agents import get_tools_by_names
from app.intent import ClassificationResult, IntentClassifier
if TYPE_CHECKING:
from langchain_core.language_models import BaseChatModel
@@ -16,15 +18,31 @@ if TYPE_CHECKING:
from app.registry import AgentRegistry
logger = logging.getLogger(__name__)
SUPERVISOR_PROMPT = (
"You are a customer support supervisor. "
"Route customer requests to the appropriate agent based on their description. "
"For order status and tracking queries, use the order_lookup agent. "
"For order modifications like cancellations, use the order_actions agent. "
"For anything else, use the fallback agent."
"Route customer requests to the appropriate agent based on their description.\n\n"
"Available agents and their roles:\n"
"{agent_descriptions}\n\n"
"Routing rules:\n"
"- For order status and tracking queries, use the order_lookup agent.\n"
"- For order modifications like cancellations, use the order_actions agent.\n"
"- For discounts, promotions, or coupon codes, use the discount agent.\n"
"- For anything else or when uncertain, use the fallback agent.\n"
"- If the user's request involves multiple actions, execute them in order.\n"
"- If a previous intent classification is provided, follow it.\n"
)
def _format_agent_descriptions(registry: AgentRegistry) -> str:
"""Build agent description text for the supervisor prompt."""
lines = []
for agent in registry.list_agents():
lines.append(f"- {agent.name}: {agent.description}")
return "\n".join(lines)
def build_agent_nodes(
registry: AgentRegistry,
llm: BaseChatModel,
@@ -56,15 +74,48 @@ def build_graph(
registry: AgentRegistry,
llm: BaseChatModel,
checkpointer: AsyncPostgresSaver,
intent_classifier: IntentClassifier | None = None,
) -> CompiledStateGraph:
"""Build and compile the LangGraph supervisor graph."""
"""Build and compile the LangGraph supervisor graph.
If an intent_classifier is provided, the supervisor prompt is enhanced
with agent descriptions for better routing. The classifier is stored
for use by the routing layer (ws_handler).
"""
agent_nodes = build_agent_nodes(registry, llm)
agent_descriptions = _format_agent_descriptions(registry)
prompt = SUPERVISOR_PROMPT.format(agent_descriptions=agent_descriptions)
workflow = create_supervisor(
agent_nodes,
model=llm,
-        prompt=SUPERVISOR_PROMPT,
+        prompt=prompt,
output_mode="full_history",
)
-    return workflow.compile(checkpointer=checkpointer)
+    graph = workflow.compile(checkpointer=checkpointer)
# Attach classifier and registry to graph for use by ws_handler
graph.intent_classifier = intent_classifier # type: ignore[attr-defined]
graph.agent_registry = registry # type: ignore[attr-defined]
return graph
async def classify_intent(
graph: CompiledStateGraph,
message: str,
) -> ClassificationResult | None:
"""Classify user intent using the graph's attached classifier.
Returns None if no classifier is configured.
"""
classifier = getattr(graph, "intent_classifier", None)
registry = getattr(graph, "agent_registry", None)
if classifier is None or registry is None:
return None
agents = registry.list_agents()
return await classifier.classify(message, agents)
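Attaching the classifier and registry to the compiled graph keeps a single object flowing through `app.state`, at the cost of two `type: ignore` attributes. A sketch of the wiring, mirroring the lifespan in main.py below; `llm`, `registry`, and `checkpointer` construction is elided and the sample output is illustrative:

```python
# Sketch: build the graph with a classifier attached, then query it.
from app.graph import build_graph, classify_intent
from app.intent import LLMIntentClassifier

async def demo(registry, llm, checkpointer) -> None:
    classifier = LLMIntentClassifier(llm)
    graph = build_graph(registry, llm, checkpointer, intent_classifier=classifier)

    result = await classify_intent(graph, "Cancel order 1042 and send me a coupon")
    if result is not None:
        # e.g. ['order_actions', 'discount'] for a two-intent message
        print([i.agent_name for i in result.intents])
```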

backend/app/intent.py (new file, 118 lines)
View File

@@ -0,0 +1,118 @@
"""Intent classification using LLM structured output."""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Protocol
from pydantic import BaseModel
if TYPE_CHECKING:
from langchain_core.language_models import BaseChatModel
from app.registry import AgentConfig
logger = logging.getLogger(__name__)
CLASSIFICATION_PROMPT = (
"You are an intent classifier for a customer support system.\n"
"Given a user message, determine which agent(s) should handle it.\n\n"
"Available agents:\n{agent_list}\n\n"
"Rules:\n"
"- If the message clearly maps to one agent, return a single intent.\n"
"- If the message contains multiple distinct requests, return multiple intents "
"in execution order.\n"
"- If the message is vague or doesn't match any agent, set is_ambiguous=True "
"and provide a clarification_question.\n"
"- Never route to the fallback agent unless truly ambiguous.\n"
"- confidence should be between 0.0 and 1.0.\n"
)
AMBIGUITY_THRESHOLD = 0.5
class IntentTarget(BaseModel, frozen=True):
"""A single classified intent targeting a specific agent."""
agent_name: str
confidence: float
reasoning: str
class ClassificationResult(BaseModel, frozen=True):
"""Result of intent classification -- may contain multiple intents."""
intents: tuple[IntentTarget, ...]
is_ambiguous: bool = False
clarification_question: str | None = None
class IntentClassifier(Protocol):
"""Protocol for intent classification implementations."""
async def classify(
self,
message: str,
available_agents: tuple[AgentConfig, ...],
) -> ClassificationResult: ...
def _build_agent_list(agents: tuple[AgentConfig, ...]) -> str:
"""Format agent descriptions for the classification prompt."""
lines = []
for agent in agents:
lines.append(f"- {agent.name}: {agent.description} (permission: {agent.permission})")
return "\n".join(lines)
class LLMIntentClassifier:
"""Classifies user intent using LLM structured output."""
def __init__(self, llm: BaseChatModel) -> None:
self._llm = llm
async def classify(
self,
message: str,
available_agents: tuple[AgentConfig, ...],
) -> ClassificationResult:
"""Classify user message into one or more agent intents."""
agent_list = _build_agent_list(available_agents)
system_prompt = CLASSIFICATION_PROMPT.format(agent_list=agent_list)
structured_llm = self._llm.with_structured_output(ClassificationResult)
try:
result = await structured_llm.ainvoke(
[
{"role": "system", "content": system_prompt},
{"role": "user", "content": message},
]
)
except Exception:
logger.exception("Intent classification failed, returning ambiguous")
return ClassificationResult(
intents=(),
is_ambiguous=True,
clarification_question="I'm not sure I understood. Could you please rephrase?",
)
if not isinstance(result, ClassificationResult):
return ClassificationResult(
intents=(),
is_ambiguous=True,
clarification_question="I'm not sure I understood. Could you please rephrase?",
)
# Apply ambiguity threshold
if result.intents and all(i.confidence < AMBIGUITY_THRESHOLD for i in result.intents):
return ClassificationResult(
intents=result.intents,
is_ambiguous=True,
clarification_question=(
result.clarification_question
or "I'm not sure I understood. Could you please rephrase?"
),
)
return result
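Because `IntentClassifier` is a `Protocol`, tests can substitute a canned stub without subclassing or touching an LLM -- structural typing only requires a matching `classify()`. A sketch; the stub name and canned values are illustrative:

```python
# Sketch: a test double satisfying the IntentClassifier Protocol.
from app.intent import ClassificationResult, IntentTarget

class StubClassifier:
    def __init__(self, result: ClassificationResult) -> None:
        self._result = result

    async def classify(self, message, available_agents) -> ClassificationResult:
        return self._result

stub = StubClassifier(
    ClassificationResult(
        intents=(
            IntentTarget(agent_name="discount", confidence=0.9, reasoning="canned"),
        ),
    )
)
```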

View File

@@ -0,0 +1,115 @@
"""Interrupt TTL management -- tracks pending interrupts with auto-expiration."""
from __future__ import annotations
import time
import uuid
from dataclasses import dataclass
@dataclass(frozen=True)
class InterruptRecord:
"""Immutable record of a pending interrupt."""
interrupt_id: str
thread_id: str
action: str
params: dict
created_at: float
ttl_seconds: int
@dataclass(frozen=True)
class InterruptStatus:
"""Current status of a tracked interrupt."""
is_expired: bool
remaining_seconds: float
record: InterruptRecord
class InterruptManager:
"""Manages interrupt TTL with auto-expiration.
Complements SessionManager -- this tracks interrupt-specific TTL
while SessionManager handles session-level TTL.
"""
def __init__(self, ttl_seconds: int = 1800) -> None:
self._ttl_seconds = ttl_seconds
self._interrupts: dict[str, InterruptRecord] = {}
def register(
self,
thread_id: str,
action: str,
params: dict,
) -> InterruptRecord:
"""Register a new pending interrupt with TTL tracking."""
record = InterruptRecord(
interrupt_id=uuid.uuid4().hex,
thread_id=thread_id,
action=action,
params=dict(params),
created_at=time.time(),
ttl_seconds=self._ttl_seconds,
)
self._interrupts = {**self._interrupts, thread_id: record}
return record
def check_status(self, thread_id: str) -> InterruptStatus | None:
"""Check the TTL status of a pending interrupt."""
record = self._interrupts.get(thread_id)
if record is None:
return None
elapsed = time.time() - record.created_at
remaining = max(0.0, record.ttl_seconds - elapsed)
is_expired = elapsed > record.ttl_seconds
return InterruptStatus(
is_expired=is_expired,
remaining_seconds=remaining,
record=record,
)
def resolve(self, thread_id: str) -> None:
"""Remove a resolved interrupt from tracking."""
self._interrupts = {
k: v for k, v in self._interrupts.items() if k != thread_id
}
def cleanup_expired(self) -> tuple[InterruptRecord, ...]:
"""Find and remove all expired interrupts. Returns the expired records."""
now = time.time()
expired: list[InterruptRecord] = []
active: dict[str, InterruptRecord] = {}
for thread_id, record in self._interrupts.items():
if now - record.created_at > record.ttl_seconds:
expired.append(record)
else:
active[thread_id] = record
self._interrupts = active
return tuple(expired)
def generate_retry_prompt(self, expired_record: InterruptRecord) -> dict:
"""Generate a WebSocket message prompting the user to retry an expired action."""
return {
"type": "interrupt_expired",
"thread_id": expired_record.thread_id,
"action": expired_record.action,
"message": (
f"The approval request for '{expired_record.action}' has expired "
f"after {expired_record.ttl_seconds // 60} minutes. "
f"Would you like to try again?"
),
}
def has_pending(self, thread_id: str) -> bool:
"""Check if a thread has a pending (non-expired) interrupt."""
status = self.check_status(thread_id)
if status is None:
return False
return not status.is_expired
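`cleanup_expired` is designed to be polled, since nothing in the manager schedules itself. A sketch of a background sweep that pairs it with `generate_retry_prompt`; the `send()` helper and the 60-second cadence are illustrative, not part of this module:

```python
# Sketch: periodically expire stale interrupts and push retry prompts
# to whichever transport is attached.
import asyncio

from app.interrupt_manager import InterruptManager

async def sweep_expired(manager: InterruptManager, send) -> None:
    while True:
        for record in manager.cleanup_expired():
            await send(record.thread_id, manager.generate_retry_prompt(record))
        await asyncio.sleep(60)  # sweep cadence is a deployment choice
```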

View File

@@ -13,7 +13,10 @@ from fastapi.staticfiles import StaticFiles
from app.callbacks import TokenUsageCallbackHandler
from app.config import Settings
from app.db import create_checkpointer, create_pool, setup_app_tables
from app.escalation import NoOpEscalator, WebhookEscalator
from app.graph import build_graph
from app.intent import LLMIntentClassifier
from app.interrupt_manager import InterruptManager
from app.llm import create_llm
from app.registry import AgentRegistry
from app.session_manager import SessionManager
@@ -36,23 +39,46 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
checkpointer = await create_checkpointer(pool)
await setup_app_tables(pool)
-    registry = AgentRegistry.load(AGENTS_YAML)
+    # Load agents from template or default YAML
+    if settings.template_name:
+        registry = AgentRegistry.load_template(settings.template_name)
+    else:
+        registry = AgentRegistry.load(AGENTS_YAML)
llm = create_llm(settings)
-    graph = build_graph(registry, llm, checkpointer)
+    intent_classifier = LLMIntentClassifier(llm)
+    graph = build_graph(registry, llm, checkpointer, intent_classifier=intent_classifier)
session_manager = SessionManager(
session_ttl_seconds=settings.session_ttl_minutes * 60,
)
interrupt_manager = InterruptManager(
ttl_seconds=settings.interrupt_ttl_minutes * 60,
)
# Configure escalation
if settings.webhook_url:
escalator = WebhookEscalator(
url=settings.webhook_url,
timeout_seconds=settings.webhook_timeout_seconds,
max_retries=settings.webhook_max_retries,
)
else:
escalator = NoOpEscalator()
app.state.graph = graph
app.state.session_manager = session_manager
app.state.interrupt_manager = interrupt_manager
app.state.escalator = escalator
app.state.settings = settings
app.state.pool = pool
logger.info(
"Smart Support started: %d agents loaded, LLM=%s/%s",
"Smart Support started: %d agents loaded, LLM=%s/%s, template=%s",
len(registry),
settings.llm_provider,
settings.llm_model,
+        settings.template_name or "(default)",
)
yield
@@ -60,7 +86,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
await pool.close()
-app = FastAPI(title="Smart Support", version="0.1.0", lifespan=lifespan)
+app = FastAPI(title="Smart Support", version="0.2.0", lifespan=lifespan)
@app.websocket("/ws")
@@ -68,13 +94,17 @@ async def websocket_endpoint(ws: WebSocket) -> None:
await ws.accept()
graph = app.state.graph
session_manager = app.state.session_manager
interrupt_manager = app.state.interrupt_manager
settings = app.state.settings
callback_handler = TokenUsageCallbackHandler(model_name=settings.llm_model)
try:
while True:
raw_data = await ws.receive_text()
-            await dispatch_message(ws, graph, session_manager, callback_handler, raw_data)
+            await dispatch_message(
+                ws, graph, session_manager, callback_handler, raw_data,
+                interrupt_manager=interrupt_manager,
+            )
except WebSocketDisconnect:
logger.info("WebSocket client disconnected")

View File

@@ -100,5 +100,41 @@ class AgentRegistry:
def get_agents_by_permission(self, permission: str) -> tuple[AgentConfig, ...]:
return tuple(a for a in self._agents.values() if a.permission == permission)
@classmethod
def load_template(
cls,
template_name: str,
templates_dir: str | Path | None = None,
) -> AgentRegistry:
"""Load agent configurations from a named template."""
if templates_dir is None:
templates_dir = Path(__file__).parent.parent / "templates"
templates_dir = Path(templates_dir)
yaml_path = templates_dir / f"{template_name}.yaml"
if not yaml_path.exists():
available = cls.list_templates(templates_dir)
raise FileNotFoundError(
f"Template '{template_name}' not found. "
f"Available: {', '.join(available) if available else 'none'}"
)
return cls.load(yaml_path)
@classmethod
def list_templates(
cls,
templates_dir: str | Path | None = None,
) -> tuple[str, ...]:
"""List available template names from the templates directory."""
if templates_dir is None:
templates_dir = Path(__file__).parent.parent / "templates"
templates_dir = Path(templates_dir)
if not templates_dir.is_dir():
return ()
return tuple(
sorted(p.stem for p in templates_dir.glob("*.yaml"))
)
def __len__(self) -> int:
return len(self._agents)
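Template selection is then a one-liner at startup. A sketch; the template stems are whatever YAML files live under backend/templates/, so the names printed here are guesses:

```python
# Sketch: discover and load a vertical template.
from app.registry import AgentRegistry

print(AgentRegistry.list_templates())  # e.g. ('ecommerce', 'fintech', 'saas')

registry = AgentRegistry.load_template("fintech")
print(len(registry), "agents loaded")
```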

View File

@@ -10,11 +10,14 @@ from typing import TYPE_CHECKING, Any
from langchain_core.messages import HumanMessage
from langgraph.types import Command
from app.graph import classify_intent
if TYPE_CHECKING:
from fastapi import WebSocket
from langgraph.graph.state import CompiledStateGraph
from app.callbacks import TokenUsageCallbackHandler
from app.interrupt_manager import InterruptManager
from app.session_manager import SessionManager
logger = logging.getLogger(__name__)
@@ -31,6 +34,7 @@ async def handle_user_message(
callback_handler: TokenUsageCallbackHandler,
thread_id: str,
content: str,
interrupt_manager: InterruptManager | None = None,
) -> None:
"""Process a user message through the graph and stream results back."""
if session_manager.is_expired(thread_id):
@@ -39,8 +43,42 @@ async def handle_user_message(
return
session_manager.touch(thread_id)
# Run intent classification if available (for logging/future multi-intent)
classification = await classify_intent(graph, content)
if classification is not None:
logger.info(
"Intent classification for thread %s: ambiguous=%s, intents=%s",
thread_id,
classification.is_ambiguous,
[i.agent_name for i in classification.intents],
)
# If ambiguous, send clarification and return
if classification.is_ambiguous and classification.clarification_question:
await _send_json(
ws,
{
"type": "clarification",
"thread_id": thread_id,
"message": classification.clarification_question,
},
)
await _send_json(ws, {"type": "message_complete", "thread_id": thread_id})
return
config = {"configurable": {"thread_id": thread_id}, "callbacks": [callback_handler]}
input_msg = {"messages": [HumanMessage(content=content)]}
# If multi-intent detected, add routing hint to the message
if classification and len(classification.intents) > 1:
agent_names = [i.agent_name for i in classification.intents]
hint = (
f"\n[System: This request involves multiple actions. "
f"Execute in order: {', '.join(agent_names)}]"
)
input_msg = {"messages": [HumanMessage(content=content + hint)]}
else:
input_msg = {"messages": [HumanMessage(content=content)]}
try:
async for chunk in graph.astream(input_msg, config=config, stream_mode="messages"):
@@ -72,6 +110,15 @@ async def handle_user_message(
if _has_interrupt(state):
interrupt_data = _extract_interrupt(state)
session_manager.extend_for_interrupt(thread_id)
# Register interrupt with TTL tracking
if interrupt_manager is not None:
interrupt_manager.register(
thread_id=thread_id,
action=interrupt_data.get("action", "unknown"),
params=interrupt_data.get("params", {}),
)
await _send_json(
ws,
{
@@ -96,8 +143,21 @@ async def handle_interrupt_response(
callback_handler: TokenUsageCallbackHandler,
thread_id: str,
approved: bool,
interrupt_manager: InterruptManager | None = None,
) -> None:
"""Resume graph execution after interrupt approval/rejection."""
# Check interrupt TTL before resuming
if interrupt_manager is not None:
status = interrupt_manager.check_status(thread_id)
if status is not None and status.is_expired:
retry_prompt = interrupt_manager.generate_retry_prompt(status.record)
interrupt_manager.resolve(thread_id)
session_manager.resolve_interrupt(thread_id)
await _send_json(ws, retry_prompt)
return
interrupt_manager.resolve(thread_id)
session_manager.resolve_interrupt(thread_id)
session_manager.touch(thread_id)
@@ -136,6 +196,7 @@ async def dispatch_message(
session_manager: SessionManager,
callback_handler: TokenUsageCallbackHandler,
raw_data: str,
interrupt_manager: InterruptManager | None = None,
) -> None:
"""Parse and route an incoming WebSocket message."""
if len(raw_data) > MAX_MESSAGE_SIZE:
@@ -167,12 +228,16 @@ async def dispatch_message(
if len(content) > MAX_CONTENT_LENGTH:
await _send_json(ws, {"type": "error", "message": "Message content too long"})
return
-        await handle_user_message(ws, graph, session_manager, callback_handler, thread_id, content)
+        await handle_user_message(
+            ws, graph, session_manager, callback_handler, thread_id, content,
+            interrupt_manager=interrupt_manager,
+        )
elif msg_type == "interrupt_response":
approved = data.get("approved", False)
await handle_interrupt_response(
-            ws, graph, session_manager, callback_handler, thread_id, approved
+            ws, graph, session_manager, callback_handler, thread_id, approved,
+            interrupt_manager=interrupt_manager,
)
else: