feat: complete phase 1 -- core framework with chat loop, agents, and React UI

Backend:
- FastAPI WebSocket /ws endpoint with streaming via LangGraph astream
- LangGraph Supervisor connecting 3 mock agents (order_lookup, order_actions, fallback)
- YAML Agent Registry with Pydantic validation and immutable configs
- PostgresSaver checkpoint persistence via langgraph-checkpoint-postgres
- Session TTL with 30-min sliding window and interrupt extension
- LLM provider abstraction (Anthropic/OpenAI/Google)
- Token usage + cost tracking callback handler
- Input validation: message size cap, thread_id format, content length
- Security: no hardcoded defaults, startup API key validation, no input reflection

Frontend:
- React 19 + TypeScript + Vite chat UI
- WebSocket hook with reconnect + exponential backoff
- Streaming token display with agent attribution
- Interrupt approval/reject UI for write operations
- Collapsible tool call viewer

Testing:
- 87 unit tests, 87% coverage (exceeds 80% requirement)
- Ruff lint + format clean

Infrastructure:
- Docker Compose (PostgreSQL 16 + backend)
- pyproject.toml with full dependency management
Yaojia Wang
2026-03-30 00:54:21 +02:00
commit 33488fd634 (parent e4f08576a9)
51 changed files with 4701 additions and 1 deletion

backend/app/__init__.py (empty)

backend/app/agents/__init__.py
@@ -0,0 +1,30 @@
"""Agent tools registry -- maps tool name strings to actual tool functions."""
from __future__ import annotations
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from langchain_core.tools import BaseTool
from app.agents.fallback import fallback_respond
from app.agents.order_actions import cancel_order
from app.agents.order_lookup import get_order_status, get_tracking_info
_TOOL_MAP: dict[str, BaseTool] = {
"get_order_status": get_order_status,
"get_tracking_info": get_tracking_info,
"cancel_order": cancel_order,
"fallback_respond": fallback_respond,
}
def get_tools_by_names(tool_names: list[str]) -> list[BaseTool]:
"""Resolve tool name strings from YAML config to actual tool objects."""
tools = []
for name in tool_names:
if name not in _TOOL_MAP:
available = ", ".join(sorted(_TOOL_MAP.keys()))
raise ValueError(f"Unknown tool '{name}'. Available tools: {available}")
tools.append(_TOOL_MAP[name])
return tools
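
Illustrative usage sketch (not part of the diff): resolving a tool list the way a
registry entry would supply it. Tool names come from _TOOL_MAP above.

    from app.agents import get_tools_by_names

    tools = get_tools_by_names(["get_order_status", "cancel_order"])
    print([t.name for t in tools])  # ['get_order_status', 'cancel_order']

    get_tools_by_names(["refund_order"])  # raises ValueError listing the four available tools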

backend/app/agents/fallback.py

@@ -0,0 +1,18 @@
"""Fallback agent tools -- handles unmatched intents."""
from __future__ import annotations
from langchain_core.tools import tool
@tool
def fallback_respond(query: str) -> str:
"""Provide a helpful response when the user's intent doesn't match a specific agent."""
return (
"I'm here to help with order inquiries and actions. "
"Here's what I can do:\n"
"- Check order status (e.g., 'What is the status of order 1042?')\n"
"- Get tracking information (e.g., 'Track order 1042')\n"
"- Cancel an order (e.g., 'Cancel order 1042')\n\n"
"Could you please rephrase your request?"
)

backend/app/agents/order_actions.py

@@ -0,0 +1,37 @@
"""Order action tools -- write operations requiring human approval."""
from __future__ import annotations
from langchain_core.tools import tool
from langgraph.types import interrupt
@tool
def cancel_order(order_id: str) -> dict:
"""Cancel an order. Requires human approval before execution."""
response = interrupt(
{
"action": "cancel_order",
"order_id": order_id,
"message": f"Please confirm: cancel order {order_id}?",
}
)
if isinstance(response, bool):
approved = response
elif isinstance(response, dict):
approved = response.get("approved", False)
else:
approved = bool(response)
if approved:
return {
"status": "cancelled",
"order_id": order_id,
"message": f"Order {order_id} has been successfully cancelled.",
}
return {
"status": "kept",
"order_id": order_id,
"message": f"Order {order_id} cancellation was declined. The order remains active.",
}
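
A hedged sketch of the resume path (it mirrors handle_interrupt_response in
ws_handler.py further down): the value passed to Command(resume=...) becomes the
return value of the pending interrupt() call, so all three branches above are
reachable.

    from langgraph.types import Command

    Command(resume=True)                 # bool branch     -> approved = True
    Command(resume={"approved": False})  # dict branch     -> approved = False
    Command(resume="yes")                # fallback branch -> approved = bool("yes") = True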

backend/app/agents/order_lookup.py

@@ -0,0 +1,68 @@
"""Order lookup tools -- read-only operations."""
from __future__ import annotations
from types import MappingProxyType
from langchain_core.tools import tool
MOCK_ORDERS: MappingProxyType[str, dict] = MappingProxyType(
{
"1042": {
"order_id": "1042",
"status": "shipped",
"items": ["Wireless Headphones", "USB-C Cable"],
"total": 89.99,
"placed_at": "2026-03-25",
},
"1043": {
"order_id": "1043",
"status": "processing",
"items": ["Laptop Stand"],
"total": 49.99,
"placed_at": "2026-03-28",
},
"1044": {
"order_id": "1044",
"status": "delivered",
"items": ["Mechanical Keyboard", "Mouse Pad"],
"total": 159.99,
"placed_at": "2026-03-20",
},
}
)
MOCK_TRACKING: MappingProxyType[str, dict] = MappingProxyType(
{
"1042": {
"order_id": "1042",
"carrier": "FedEx",
"tracking_number": "FX-9876543210",
"estimated_delivery": "2026-04-01",
"current_location": "Distribution Center, Chicago IL",
},
"1044": {
"order_id": "1044",
"carrier": "UPS",
"tracking_number": "1Z-5678901234",
"estimated_delivery": "2026-03-22",
"current_location": "Delivered",
},
}
)
@tool
def get_order_status(order_id: str) -> dict:
"""Look up the current status of an order by order ID."""
if order_id in MOCK_ORDERS:
return dict(MOCK_ORDERS[order_id])
return {"error": f"Order {order_id} not found", "order_id": order_id}
@tool
def get_tracking_info(order_id: str) -> dict:
"""Get shipping and tracking information for an order."""
if order_id in MOCK_TRACKING:
return dict(MOCK_TRACKING[order_id])
return {"error": f"No tracking information for order {order_id}", "order_id": order_id}

backend/app/callbacks.py

@@ -0,0 +1,60 @@
"""Token usage tracking callback handler."""
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
from langchain_core.callbacks import BaseCallbackHandler
if TYPE_CHECKING:
from langchain_core.outputs import LLMResult
COST_PER_1K_TOKENS: dict[str, dict[str, float]] = {
"claude-sonnet-4-6": {"prompt": 0.003, "completion": 0.015},
"claude-haiku-4-5-20251001": {"prompt": 0.0008, "completion": 0.004},
"gpt-4o": {"prompt": 0.0025, "completion": 0.01},
"gpt-4o-mini": {"prompt": 0.00015, "completion": 0.0006},
}
DEFAULT_COST = {"prompt": 0.003, "completion": 0.015}
@dataclass(frozen=True)
class TokenUsage:
prompt_tokens: int
completion_tokens: int
total_tokens: int
total_cost_usd: float
class TokenUsageCallbackHandler(BaseCallbackHandler):
"""Accumulates token usage and cost across LLM invocations."""
def __init__(self, model_name: str = "") -> None:
self._model_name = model_name
self._prompt_tokens = 0
self._completion_tokens = 0
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
if response.llm_output and "token_usage" in response.llm_output:
usage = response.llm_output["token_usage"]
self._prompt_tokens += usage.get("prompt_tokens", 0)
self._completion_tokens += usage.get("completion_tokens", 0)
def get_usage(self) -> TokenUsage:
costs = COST_PER_1K_TOKENS.get(self._model_name, DEFAULT_COST)
cost = (
self._prompt_tokens * costs["prompt"] / 1000
+ self._completion_tokens * costs["completion"] / 1000
)
return TokenUsage(
prompt_tokens=self._prompt_tokens,
completion_tokens=self._completion_tokens,
total_tokens=self._prompt_tokens + self._completion_tokens,
total_cost_usd=round(cost, 6),
)
def reset(self) -> None:
self._prompt_tokens = 0
self._completion_tokens = 0
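
Worked example (illustrative) of the per-1K pricing arithmetic in get_usage(),
driving on_llm_end with a hand-built LLMResult:

    from langchain_core.outputs import LLMResult

    handler = TokenUsageCallbackHandler(model_name="gpt-4o")
    handler.on_llm_end(
        LLMResult(
            generations=[],
            llm_output={"token_usage": {"prompt_tokens": 2000, "completion_tokens": 500}},
        )
    )
    usage = handler.get_usage()
    # 2000 * 0.0025 / 1000 + 500 * 0.01 / 1000 = 0.005 + 0.005 = 0.01 USD
    assert usage.total_tokens == 2500
    assert usage.total_cost_usd == 0.01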

backend/app/config.py

@@ -0,0 +1,46 @@
"""Centralized application configuration via pydantic-settings."""
from __future__ import annotations
from typing import Literal
from pydantic import model_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
model_config = SettingsConfigDict(
env_file=".env",
env_file_encoding="utf-8",
extra="ignore",
)
database_url: str
llm_provider: Literal["anthropic", "openai", "google"] = "anthropic"
llm_model: str = "claude-sonnet-4-6"
session_ttl_minutes: int = 30
interrupt_ttl_minutes: int = 30
ws_host: str = "0.0.0.0"
ws_port: int = 8000
anthropic_api_key: str = ""
openai_api_key: str = ""
google_api_key: str = ""
@model_validator(mode="after")
def validate_provider_key(self) -> Settings:
key_map = {
"anthropic": self.anthropic_api_key,
"openai": self.openai_api_key,
"google": self.google_api_key,
}
key = key_map.get(self.llm_provider, "")
if not key:
raise ValueError(
f"API key for provider '{self.llm_provider}' is required. "
f"Set the corresponding environment variable."
)
return self
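
Minimal environment sketch for booting with the default Anthropic provider
(values are placeholders; only DATABASE_URL and the active provider's key are
required):

    import os

    os.environ["DATABASE_URL"] = "postgresql://user:pass@localhost:5432/app"  # placeholder
    os.environ["ANTHROPIC_API_KEY"] = "sk-ant-..."  # placeholder

    from app.config import Settings

    settings = Settings()
    assert settings.llm_provider == "anthropic"
    # Settings() raises ValueError when the key for llm_provider is missing.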

backend/app/db.py

@@ -0,0 +1,61 @@
"""Database connection pool and PostgresSaver checkpoint setup."""
from __future__ import annotations
from typing import TYPE_CHECKING
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from psycopg.rows import dict_row
from psycopg_pool import AsyncConnectionPool
if TYPE_CHECKING:
from app.config import Settings
_CONVERSATIONS_DDL = """
CREATE TABLE IF NOT EXISTS conversations (
thread_id TEXT PRIMARY KEY,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
last_activity TIMESTAMPTZ NOT NULL DEFAULT NOW(),
total_tokens INTEGER NOT NULL DEFAULT 0,
total_cost_usd DOUBLE PRECISION NOT NULL DEFAULT 0.0,
status TEXT NOT NULL DEFAULT 'active'
);
"""
_INTERRUPTS_DDL = """
CREATE TABLE IF NOT EXISTS active_interrupts (
interrupt_id TEXT PRIMARY KEY,
thread_id TEXT NOT NULL REFERENCES conversations(thread_id),
action TEXT NOT NULL,
params JSONB NOT NULL DEFAULT '{}',
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
resolved_at TIMESTAMPTZ,
resolution TEXT
);
"""
async def create_pool(settings: Settings) -> AsyncConnectionPool:
"""Create an async connection pool with the required psycopg settings."""
    pool = AsyncConnectionPool(
        conninfo=settings.database_url,
        kwargs={"autocommit": True, "row_factory": dict_row},
        min_size=2,
        max_size=10,
        open=False,  # opening in the constructor is deprecated in psycopg_pool; open explicitly below
    )
await pool.open()
return pool
async def create_checkpointer(pool: AsyncConnectionPool) -> AsyncPostgresSaver:
"""Create and initialize the LangGraph checkpointer."""
checkpointer = AsyncPostgresSaver(conn=pool)
await checkpointer.setup()
return checkpointer
async def setup_app_tables(pool: AsyncConnectionPool) -> None:
"""Create application-specific tables (conversations, active_interrupts)."""
async with pool.connection() as conn:
await conn.execute(_CONVERSATIONS_DDL)
await conn.execute(_INTERRUPTS_DDL)

backend/app/graph.py

@@ -0,0 +1,70 @@
"""LangGraph Supervisor construction -- connects registry, agents, LLM, and persistence."""
from __future__ import annotations
from typing import TYPE_CHECKING
from langgraph.prebuilt import create_react_agent
from langgraph_supervisor import create_supervisor
from app.agents import get_tools_by_names
if TYPE_CHECKING:
from langchain_core.language_models import BaseChatModel
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from langgraph.graph.state import CompiledStateGraph
from app.registry import AgentRegistry
SUPERVISOR_PROMPT = (
"You are a customer support supervisor. "
"Route customer requests to the appropriate agent based on their description. "
"For order status and tracking queries, use the order_lookup agent. "
"For order modifications like cancellations, use the order_actions agent. "
"For anything else, use the fallback agent."
)
def build_agent_nodes(
registry: AgentRegistry,
llm: BaseChatModel,
) -> list:
"""Create LangGraph react agent nodes from registry configurations."""
agent_nodes = []
for agent_config in registry.list_agents():
tools = get_tools_by_names(agent_config.tools)
system_prompt = (
f"You are the {agent_config.name} agent. "
f"Personality: {agent_config.personality.tone}. "
f"{agent_config.description} "
f"Permission level: {agent_config.permission}."
)
agent_node = create_react_agent(
model=llm,
tools=tools,
name=agent_config.name,
prompt=system_prompt,
)
agent_nodes.append(agent_node)
return agent_nodes
def build_graph(
registry: AgentRegistry,
llm: BaseChatModel,
checkpointer: AsyncPostgresSaver,
) -> CompiledStateGraph:
"""Build and compile the LangGraph supervisor graph."""
agent_nodes = build_agent_nodes(registry, llm)
workflow = create_supervisor(
agent_nodes,
model=llm,
prompt=SUPERVISOR_PROMPT,
output_mode="full_history",
)
return workflow.compile(checkpointer=checkpointer)

backend/app/llm.py

@@ -0,0 +1,42 @@
"""LLM provider factory with prompt caching support."""
from __future__ import annotations
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from langchain_core.language_models import BaseChatModel
from app.config import Settings
def create_llm(settings: Settings) -> BaseChatModel:
"""Create an LLM instance based on the configured provider."""
provider = settings.llm_provider
model = settings.llm_model
if provider == "anthropic":
from langchain_anthropic import ChatAnthropic
return ChatAnthropic(
model=model,
api_key=settings.anthropic_api_key,
)
if provider == "openai":
from langchain_openai import ChatOpenAI
return ChatOpenAI(
model=model,
api_key=settings.openai_api_key,
)
if provider == "google":
from langchain_google_genai import ChatGoogleGenerativeAI
return ChatGoogleGenerativeAI(
model=model,
google_api_key=settings.google_api_key,
)
raise ValueError(f"Unknown LLM provider: '{provider}'. Use 'anthropic', 'openai', or 'google'.")

backend/app/main.py

@@ -0,0 +1,83 @@
"""FastAPI application entry point."""
from __future__ import annotations
import logging
from contextlib import asynccontextmanager
from pathlib import Path
from typing import TYPE_CHECKING
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.staticfiles import StaticFiles
from app.callbacks import TokenUsageCallbackHandler
from app.config import Settings
from app.db import create_checkpointer, create_pool, setup_app_tables
from app.graph import build_graph
from app.llm import create_llm
from app.registry import AgentRegistry
from app.session_manager import SessionManager
from app.ws_handler import dispatch_message
if TYPE_CHECKING:
from collections.abc import AsyncGenerator
logger = logging.getLogger(__name__)
AGENTS_YAML = Path(__file__).parent.parent / "agents.yaml"
FRONTEND_DIST = Path(__file__).parent.parent.parent / "frontend" / "dist"
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
settings = Settings()
pool = await create_pool(settings)
checkpointer = await create_checkpointer(pool)
await setup_app_tables(pool)
registry = AgentRegistry.load(AGENTS_YAML)
llm = create_llm(settings)
graph = build_graph(registry, llm, checkpointer)
session_manager = SessionManager(
session_ttl_seconds=settings.session_ttl_minutes * 60,
)
app.state.graph = graph
app.state.session_manager = session_manager
app.state.settings = settings
app.state.pool = pool
logger.info(
"Smart Support started: %d agents loaded, LLM=%s/%s",
len(registry),
settings.llm_provider,
settings.llm_model,
)
yield
await pool.close()
app = FastAPI(title="Smart Support", version="0.1.0", lifespan=lifespan)
@app.websocket("/ws")
async def websocket_endpoint(ws: WebSocket) -> None:
await ws.accept()
graph = app.state.graph
session_manager = app.state.session_manager
settings = app.state.settings
callback_handler = TokenUsageCallbackHandler(model_name=settings.llm_model)
try:
while True:
raw_data = await ws.receive_text()
await dispatch_message(ws, graph, session_manager, callback_handler, raw_data)
except WebSocketDisconnect:
logger.info("WebSocket client disconnected")
if FRONTEND_DIST.is_dir():
app.mount("/", StaticFiles(directory=str(FRONTEND_DIST), html=True), name="frontend")

backend/app/registry.py

@@ -0,0 +1,104 @@
"""YAML Agent registry loader with validation."""
from __future__ import annotations
from pathlib import Path
from typing import Literal
import yaml
from pydantic import BaseModel, field_validator
class PersonalityConfig(BaseModel, frozen=True):
tone: str = "professional and helpful"
greeting: str = "Hello! How can I help you today?"
escalation_message: str = "Let me connect you with a human agent."
class AgentConfig(BaseModel, frozen=True):
name: str
description: str
permission: Literal["read", "write"]
personality: PersonalityConfig = PersonalityConfig()
tools: list[str]
@field_validator("name")
@classmethod
def name_not_empty(cls, v: str) -> str:
if not v.strip():
raise ValueError("Agent name must not be empty")
return v.strip()
@field_validator("tools")
@classmethod
def tools_not_empty(cls, v: list[str]) -> list[str]:
if not v:
raise ValueError("Agent must have at least one tool")
return v
class AgentRegistry:
"""Immutable registry of agent configurations loaded from YAML."""
def __init__(self, agents: tuple[AgentConfig, ...]) -> None:
self._agents = {agent.name: agent for agent in agents}
@classmethod
def load(cls, yaml_path: str | Path) -> AgentRegistry:
"""Load and validate agent configurations from a YAML file."""
path = Path(yaml_path)
if not path.exists():
raise FileNotFoundError(f"Agent config file not found: {path}")
raw_text = path.read_text(encoding="utf-8")
if not raw_text.strip():
raise ValueError(f"Agent config file is empty: {path}")
try:
data = yaml.safe_load(raw_text)
except yaml.YAMLError as exc:
msg = f"Invalid YAML in {path}"
if hasattr(exc, "problem_mark") and exc.problem_mark is not None:
mark = exc.problem_mark
msg += f" at line {mark.line + 1}, column {mark.column + 1}"
raise ValueError(msg) from exc
if not isinstance(data, dict) or "agents" not in data:
raise ValueError(f"Agent config must have a top-level 'agents' key in {path}")
raw_agents = data["agents"]
if not isinstance(raw_agents, list) or not raw_agents:
raise ValueError(f"'agents' must be a non-empty list in {path}")
agents: list[AgentConfig] = []
seen_names: set[str] = set()
for i, raw in enumerate(raw_agents):
if not isinstance(raw, dict):
raise ValueError(f"Agent at index {i} must be a mapping in {path}")
try:
agent = AgentConfig(**raw)
except Exception as exc:
raise ValueError(f"Invalid agent config at index {i} in {path}: {exc}") from exc
if agent.name in seen_names:
raise ValueError(f"Duplicate agent name '{agent.name}' in {path}")
seen_names.add(agent.name)
agents.append(agent)
return cls(agents=tuple(agents))
def get_agent(self, name: str) -> AgentConfig:
if name not in self._agents:
available = ", ".join(sorted(self._agents.keys()))
raise KeyError(f"Agent '{name}' not found. Available: {available}")
return self._agents[name]
def list_agents(self) -> tuple[AgentConfig, ...]:
return tuple(self._agents.values())
def get_agents_by_permission(self, permission: str) -> tuple[AgentConfig, ...]:
return tuple(a for a in self._agents.values() if a.permission == permission)
def __len__(self) -> int:
return len(self._agents)
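
Illustrative config shape accepted by AgentRegistry.load (field values here are
assumptions; the real agents.yaml ships elsewhere in this commit):

    import tempfile

    SAMPLE_YAML = """\
    agents:
      - name: order_lookup
        description: Answers order status and tracking questions.
        permission: read
        tools: [get_order_status, get_tracking_info]
      - name: order_actions
        description: Performs order modifications such as cancellations.
        permission: write
        tools: [cancel_order]
    """

    with tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) as f:
        f.write(SAMPLE_YAML)

    registry = AgentRegistry.load(f.name)
    assert len(registry) == 2
    assert registry.get_agent("order_actions").permission == "write"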

backend/app/session_manager.py

@@ -0,0 +1,78 @@
"""Session TTL management with sliding window and interrupt extension."""
from __future__ import annotations
import time
from dataclasses import dataclass
@dataclass(frozen=True)
class SessionState:
thread_id: str
last_activity: float
has_pending_interrupt: bool
class SessionManager:
"""Manages session TTL with sliding window and interrupt extensions.
- Each message resets the TTL (sliding window).
- A pending interrupt suspends expiration until resolved.
"""
def __init__(self, session_ttl_seconds: int = 1800) -> None:
self._session_ttl = session_ttl_seconds
self._sessions: dict[str, SessionState] = {}
def touch(self, thread_id: str) -> SessionState:
"""Update last activity for a session (resets sliding window)."""
existing = self._sessions.get(thread_id)
new_state = SessionState(
thread_id=thread_id,
last_activity=time.time(),
has_pending_interrupt=existing.has_pending_interrupt if existing else False,
)
self._sessions = {**self._sessions, thread_id: new_state}
return new_state
def is_expired(self, thread_id: str) -> bool:
"""Check if a session has expired."""
state = self._sessions.get(thread_id)
if state is None:
return True
if state.has_pending_interrupt:
return False
elapsed = time.time() - state.last_activity
return elapsed > self._session_ttl
def extend_for_interrupt(self, thread_id: str) -> SessionState:
"""Mark session as having a pending interrupt (suspends TTL)."""
existing = self._sessions.get(thread_id)
if existing is None:
return self.touch(thread_id)
new_state = SessionState(
thread_id=thread_id,
last_activity=existing.last_activity,
has_pending_interrupt=True,
)
self._sessions = {**self._sessions, thread_id: new_state}
return new_state
def resolve_interrupt(self, thread_id: str) -> SessionState:
"""Remove interrupt extension and reset activity timer."""
new_state = SessionState(
thread_id=thread_id,
last_activity=time.time(),
has_pending_interrupt=False,
)
self._sessions = {**self._sessions, thread_id: new_state}
return new_state
def get_state(self, thread_id: str) -> SessionState | None:
return self._sessions.get(thread_id)
def remove(self, thread_id: str) -> None:
self._sessions = {k: v for k, v in self._sessions.items() if k != thread_id}
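
Behavioral sketch of the TTL rules documented above (sliding window plus
interrupt suspension), using a 1-second TTL for brevity:

    import time

    sm = SessionManager(session_ttl_seconds=1)
    sm.touch("t-1")
    sm.extend_for_interrupt("t-1")   # pending interrupt suspends expiry
    time.sleep(1.5)
    assert not sm.is_expired("t-1")  # alive despite the elapsed TTL
    sm.resolve_interrupt("t-1")      # clears the interrupt, resets the timer
    time.sleep(1.5)
    assert sm.is_expired("t-1")      # sliding window finally lapses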

backend/app/ws_handler.py

@@ -0,0 +1,204 @@
"""WebSocket message handling logic -- extracted from main for testability."""
from __future__ import annotations
import json
import logging
import re
from typing import TYPE_CHECKING, Any
from langchain_core.messages import HumanMessage
from langgraph.types import Command
if TYPE_CHECKING:
from fastapi import WebSocket
from langgraph.graph.state import CompiledStateGraph
from app.callbacks import TokenUsageCallbackHandler
from app.session_manager import SessionManager
logger = logging.getLogger(__name__)
MAX_MESSAGE_SIZE = 32_768 # 32 KB
MAX_CONTENT_LENGTH = 8_000 # characters
THREAD_ID_PATTERN = re.compile(r"^[a-zA-Z0-9\-_]{1,128}$")
async def handle_user_message(
ws: WebSocket,
graph: CompiledStateGraph,
session_manager: SessionManager,
callback_handler: TokenUsageCallbackHandler,
thread_id: str,
content: str,
) -> None:
"""Process a user message through the graph and stream results back."""
if session_manager.is_expired(thread_id):
msg = "Session expired. Please start a new conversation."
await _send_json(ws, {"type": "error", "message": msg})
return
session_manager.touch(thread_id)
config = {"configurable": {"thread_id": thread_id}, "callbacks": [callback_handler]}
input_msg = {"messages": [HumanMessage(content=content)]}
try:
async for chunk in graph.astream(input_msg, config=config, stream_mode="messages"):
msg_chunk, metadata = chunk
node = metadata.get("langgraph_node", "")
if hasattr(msg_chunk, "tool_calls") and msg_chunk.tool_calls:
for tc in msg_chunk.tool_calls:
await _send_json(
ws,
{
"type": "tool_call",
"agent": node,
"tool": tc.get("name", ""),
"args": tc.get("args", {}),
},
)
elif hasattr(msg_chunk, "content") and msg_chunk.content:
await _send_json(
ws,
{
"type": "token",
"agent": node,
"content": msg_chunk.content,
},
)
state = await graph.aget_state(config)
if _has_interrupt(state):
interrupt_data = _extract_interrupt(state)
session_manager.extend_for_interrupt(thread_id)
await _send_json(
ws,
{
"type": "interrupt",
"thread_id": thread_id,
**interrupt_data,
},
)
else:
await _send_json(ws, {"type": "message_complete", "thread_id": thread_id})
except Exception:
logger.exception("Error processing message for thread %s", thread_id)
err = "An error occurred processing your message."
await _send_json(ws, {"type": "error", "message": err})
async def handle_interrupt_response(
ws: WebSocket,
graph: CompiledStateGraph,
session_manager: SessionManager,
callback_handler: TokenUsageCallbackHandler,
thread_id: str,
approved: bool,
) -> None:
"""Resume graph execution after interrupt approval/rejection."""
session_manager.resolve_interrupt(thread_id)
session_manager.touch(thread_id)
config = {"configurable": {"thread_id": thread_id}, "callbacks": [callback_handler]}
try:
async for chunk in graph.astream(
Command(resume=approved),
config=config,
stream_mode="messages",
):
msg_chunk, metadata = chunk
node = metadata.get("langgraph_node", "")
if hasattr(msg_chunk, "content") and msg_chunk.content:
await _send_json(
ws,
{
"type": "token",
"agent": node,
"content": msg_chunk.content,
},
)
await _send_json(ws, {"type": "message_complete", "thread_id": thread_id})
except Exception:
logger.exception("Error resuming interrupt for thread %s", thread_id)
err = "An error occurred processing your response."
await _send_json(ws, {"type": "error", "message": err})
async def dispatch_message(
ws: WebSocket,
graph: CompiledStateGraph,
session_manager: SessionManager,
callback_handler: TokenUsageCallbackHandler,
raw_data: str,
) -> None:
"""Parse and route an incoming WebSocket message."""
if len(raw_data) > MAX_MESSAGE_SIZE:
await _send_json(ws, {"type": "error", "message": "Message too large"})
return
try:
data = json.loads(raw_data)
except json.JSONDecodeError:
await _send_json(ws, {"type": "error", "message": "Invalid JSON"})
return
msg_type = data.get("type")
thread_id = data.get("thread_id", "")
if not thread_id:
await _send_json(ws, {"type": "error", "message": "Missing thread_id"})
return
if not THREAD_ID_PATTERN.match(thread_id):
await _send_json(ws, {"type": "error", "message": "Invalid thread_id format"})
return
if msg_type == "message":
content = data.get("content", "")
if not content:
await _send_json(ws, {"type": "error", "message": "Missing message content"})
return
if len(content) > MAX_CONTENT_LENGTH:
await _send_json(ws, {"type": "error", "message": "Message content too long"})
return
await handle_user_message(ws, graph, session_manager, callback_handler, thread_id, content)
elif msg_type == "interrupt_response":
approved = data.get("approved", False)
await handle_interrupt_response(
ws, graph, session_manager, callback_handler, thread_id, approved
)
else:
await _send_json(ws, {"type": "error", "message": "Unknown message type"})
def _has_interrupt(state: Any) -> bool:
"""Check if the graph state has a pending interrupt."""
tasks = getattr(state, "tasks", ())
return any(getattr(t, "interrupts", ()) for t in tasks)
def _extract_interrupt(state: Any) -> dict:
"""Extract interrupt data from graph state."""
for task in getattr(state, "tasks", ()):
for intr in getattr(task, "interrupts", ()):
value = intr.value if hasattr(intr, "value") else {}
if not isinstance(value, dict):
value = {}
return {
"action": value.get("action", "unknown"),
"params": value,
}
return {"action": "unknown", "params": {}}
async def _send_json(ws: WebSocket, data: dict) -> None:
"""Send a JSON message through the WebSocket."""
await ws.send_json(data)
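
End-to-end protocol sketch from the client side (assumes the third-party
websockets package, which is not part of this commit; event shapes follow
dispatch_message above):

    import asyncio
    import json

    import websockets  # assumption: pip install websockets

    async def demo() -> None:
        async with websockets.connect("ws://localhost:8000/ws") as ws:
            await ws.send(json.dumps(
                {"type": "message", "thread_id": "demo-1", "content": "Cancel order 1042"}
            ))
            while True:
                event = json.loads(await ws.recv())
                if event["type"] == "interrupt":
                    # Approve the cancel_order action surfaced by the interrupt.
                    await ws.send(json.dumps(
                        {"type": "interrupt_response", "thread_id": "demo-1", "approved": True}
                    ))
                elif event["type"] == "message_complete":
                    break
                # "token" and "tool_call" events stream in between.

    asyncio.run(demo())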