feat: complete phase 1 -- core framework with chat loop, agents, and React UI
Backend: - FastAPI WebSocket /ws endpoint with streaming via LangGraph astream - LangGraph Supervisor connecting 3 mock agents (order_lookup, order_actions, fallback) - YAML Agent Registry with Pydantic validation and immutable configs - PostgresSaver checkpoint persistence via langgraph-checkpoint-postgres - Session TTL with 30-min sliding window and interrupt extension - LLM provider abstraction (Anthropic/OpenAI/Google) - Token usage + cost tracking callback handler - Input validation: message size cap, thread_id format, content length - Security: no hardcoded defaults, startup API key validation, no input reflection Frontend: - React 19 + TypeScript + Vite chat UI - WebSocket hook with reconnect + exponential backoff - Streaming token display with agent attribution - Interrupt approval/reject UI for write operations - Collapsible tool call viewer Testing: - 87 unit tests, 87% coverage (exceeds 80% requirement) - Ruff lint + format clean Infrastructure: - Docker Compose (PostgreSQL 16 + backend) - pyproject.toml with full dependency management
This commit is contained in:
204
backend/app/ws_handler.py
Normal file
204
backend/app/ws_handler.py
Normal file
@@ -0,0 +1,204 @@
|
||||
"""WebSocket message handling logic -- extracted from main for testability."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langgraph.types import Command
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from fastapi import WebSocket
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
|
||||
from app.callbacks import TokenUsageCallbackHandler
|
||||
from app.session_manager import SessionManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MAX_MESSAGE_SIZE = 32_768 # 32 KB
|
||||
MAX_CONTENT_LENGTH = 8_000 # characters
|
||||
THREAD_ID_PATTERN = re.compile(r"^[a-zA-Z0-9\-_]{1,128}$")
|
||||
|
||||
|
||||
async def handle_user_message(
    ws: WebSocket,
    graph: CompiledStateGraph,
    session_manager: SessionManager,
    callback_handler: TokenUsageCallbackHandler,
    thread_id: str,
    content: str,
) -> None:
    """Run one user turn through the graph, streaming output to the client.

    Rejects expired sessions up front; otherwise refreshes the session TTL,
    streams tokens and tool-call events over the WebSocket, and finishes the
    turn with either an "interrupt" frame (approval required) or
    "message_complete". All failures are logged server-side and reported to
    the client as a generic error — no internals are reflected back.
    """
    # Expired sessions get a terse error; the client must start a new thread.
    if session_manager.is_expired(thread_id):
        msg = "Session expired. Please start a new conversation."
        await _send_json(ws, {"type": "error", "message": msg})
        return

    session_manager.touch(thread_id)
    config = {"configurable": {"thread_id": thread_id}, "callbacks": [callback_handler]}
    input_msg = {"messages": [HumanMessage(content=content)]}

    try:
        async for msg_chunk, metadata in graph.astream(
            input_msg, config=config, stream_mode="messages"
        ):
            node = metadata.get("langgraph_node", "")
            tool_calls = getattr(msg_chunk, "tool_calls", None)

            if tool_calls:
                # Surface each requested tool invocation with agent attribution.
                for tc in tool_calls:
                    payload = {
                        "type": "tool_call",
                        "agent": node,
                        "tool": tc.get("name", ""),
                        "args": tc.get("args", {}),
                    }
                    await _send_json(ws, payload)
            elif getattr(msg_chunk, "content", None):
                token_frame = {
                    "type": "token",
                    "agent": node,
                    "content": msg_chunk.content,
                }
                await _send_json(ws, token_frame)

        # Once the stream drains, check whether the graph paused for approval.
        state = await graph.aget_state(config)
        if _has_interrupt(state):
            session_manager.extend_for_interrupt(thread_id)
            interrupt_frame = {"type": "interrupt", "thread_id": thread_id}
            interrupt_frame.update(_extract_interrupt(state))
            await _send_json(ws, interrupt_frame)
        else:
            await _send_json(ws, {"type": "message_complete", "thread_id": thread_id})

    except Exception:
        # Log the full traceback server-side; the client only sees a generic message.
        logger.exception("Error processing message for thread %s", thread_id)
        err = "An error occurred processing your message."
        await _send_json(ws, {"type": "error", "message": err})
|
||||
|
||||
|
||||
async def handle_interrupt_response(
    ws: WebSocket,
    graph: CompiledStateGraph,
    session_manager: SessionManager,
    callback_handler: TokenUsageCallbackHandler,
    thread_id: str,
    approved: bool,
) -> None:
    """Resume graph execution after interrupt approval/rejection.

    Marks the pending interrupt resolved, refreshes the session TTL, and
    resumes the graph with the user's decision. Streams tokens *and* tool
    calls back to the client, and — mirroring handle_user_message — checks
    for a newly pending interrupt once the resumed run drains, since a
    single turn may require more than one approval.
    """
    session_manager.resolve_interrupt(thread_id)
    session_manager.touch(thread_id)

    config = {"configurable": {"thread_id": thread_id}, "callbacks": [callback_handler]}

    try:
        async for chunk in graph.astream(
            Command(resume=approved),
            config=config,
            stream_mode="messages",
        ):
            msg_chunk, metadata = chunk
            node = metadata.get("langgraph_node", "")

            # Forward tool calls too (previously only plain tokens were sent
            # on the resume path), keeping this consistent with
            # handle_user_message so the UI's tool viewer stays populated.
            if hasattr(msg_chunk, "tool_calls") and msg_chunk.tool_calls:
                for tc in msg_chunk.tool_calls:
                    await _send_json(
                        ws,
                        {
                            "type": "tool_call",
                            "agent": node,
                            "tool": tc.get("name", ""),
                            "args": tc.get("args", {}),
                        },
                    )
            elif hasattr(msg_chunk, "content") and msg_chunk.content:
                await _send_json(
                    ws,
                    {
                        "type": "token",
                        "agent": node,
                        "content": msg_chunk.content,
                    },
                )

        # The resumed run may itself pause on another write operation; without
        # this check the client would receive "message_complete" while the
        # graph is still waiting for approval.
        state = await graph.aget_state(config)
        if _has_interrupt(state):
            interrupt_data = _extract_interrupt(state)
            session_manager.extend_for_interrupt(thread_id)
            await _send_json(
                ws,
                {
                    "type": "interrupt",
                    "thread_id": thread_id,
                    **interrupt_data,
                },
            )
        else:
            await _send_json(ws, {"type": "message_complete", "thread_id": thread_id})

    except Exception:
        # Log the full traceback server-side; the client only sees a generic message.
        logger.exception("Error resuming interrupt for thread %s", thread_id)
        err = "An error occurred processing your response."
        await _send_json(ws, {"type": "error", "message": err})
|
||||
|
||||
|
||||
async def dispatch_message(
    ws: WebSocket,
    graph: CompiledStateGraph,
    session_manager: SessionManager,
    callback_handler: TokenUsageCallbackHandler,
    raw_data: str,
) -> None:
    """Parse, validate, and route one incoming WebSocket frame.

    Validation order: size cap, JSON syntax, payload shape (must be a JSON
    object), thread_id presence/format, then type-specific checks. Every
    rejection sends a structured {"type": "error"} frame rather than raising,
    so a malformed client message can never tear down the connection.
    """
    if len(raw_data) > MAX_MESSAGE_SIZE:
        await _send_json(ws, {"type": "error", "message": "Message too large"})
        return

    try:
        data = json.loads(raw_data)
    except json.JSONDecodeError:
        await _send_json(ws, {"type": "error", "message": "Invalid JSON"})
        return

    # json.loads happily returns lists/strings/numbers; everything below
    # assumes a dict, so reject non-object payloads instead of crashing on
    # data.get with an AttributeError.
    if not isinstance(data, dict):
        await _send_json(ws, {"type": "error", "message": "Invalid message format"})
        return

    msg_type = data.get("type")
    thread_id = data.get("thread_id", "")

    if not thread_id:
        await _send_json(ws, {"type": "error", "message": "Missing thread_id"})
        return

    # A non-string thread_id (e.g. a JSON number) would make .match() raise
    # TypeError; treat it as a format violation instead.
    if not isinstance(thread_id, str) or not THREAD_ID_PATTERN.match(thread_id):
        await _send_json(ws, {"type": "error", "message": "Invalid thread_id format"})
        return

    if msg_type == "message":
        content = data.get("content", "")
        if not content:
            await _send_json(ws, {"type": "error", "message": "Missing message content"})
            return
        # Non-string content would crash len() below and confuse HumanMessage.
        if not isinstance(content, str):
            await _send_json(ws, {"type": "error", "message": "Invalid message content"})
            return
        if len(content) > MAX_CONTENT_LENGTH:
            await _send_json(ws, {"type": "error", "message": "Message content too long"})
            return
        await handle_user_message(ws, graph, session_manager, callback_handler, thread_id, content)

    elif msg_type == "interrupt_response":
        # Coerce to a strict bool so arbitrary truthy JSON values are not
        # forwarded verbatim into the graph resume command.
        approved = bool(data.get("approved", False))
        await handle_interrupt_response(
            ws, graph, session_manager, callback_handler, thread_id, approved
        )

    else:
        await _send_json(ws, {"type": "error", "message": "Unknown message type"})
|
||||
|
||||
|
||||
def _has_interrupt(state: Any) -> bool:
|
||||
"""Check if the graph state has a pending interrupt."""
|
||||
tasks = getattr(state, "tasks", ())
|
||||
return any(getattr(t, "interrupts", ()) for t in tasks)
|
||||
|
||||
|
||||
def _extract_interrupt(state: Any) -> dict:
|
||||
"""Extract interrupt data from graph state."""
|
||||
for task in getattr(state, "tasks", ()):
|
||||
for intr in getattr(task, "interrupts", ()):
|
||||
value = intr.value if hasattr(intr, "value") else {}
|
||||
if not isinstance(value, dict):
|
||||
value = {}
|
||||
return {
|
||||
"action": value.get("action", "unknown"),
|
||||
"params": value,
|
||||
}
|
||||
return {"action": "unknown", "params": {}}
|
||||
|
||||
|
||||
async def _send_json(ws: WebSocket, data: dict) -> None:
    """Send *data* to the client as a single JSON frame.

    Thin seam around ``ws.send_json`` — presumably kept as one chokepoint so
    the module's outbound traffic is easy to stub in tests (the module was
    extracted from main for testability); confirm before inlining.
    """
    await ws.send_json(data)
|
||||
Reference in New Issue
Block a user