Backend: - ConversationTracker: Protocol + PostgresConversationTracker for lifecycle tracking - Error handler: ErrorCategory enum, classify_error(), with_retry() exponential backoff - Wire PostgresAnalyticsRecorder + ConversationTracker into ws_handler - Rate limiting (10 msg/10s per thread), edge case hardening - Health endpoint GET /api/health, version 0.5.0 - Demo seed data script + sample OpenAPI spec Frontend (all new): - React Router with NavBar (Chat / Replay / Dashboard / Review) - ReplayListPage + ReplayPage with ReplayTimeline component - DashboardPage with MetricCard, range selector, zero-state - ReviewPage for OpenAPI classification review - ErrorBanner for WebSocket disconnect handling - API client (api.ts) with typed fetch wrappers Infrastructure: - Frontend Dockerfile (multi-stage node -> nginx) - nginx.conf with SPA routing + API/WS proxy - docker-compose.yml with frontend service + healthchecks - .env.example files (root + backend) Documentation: - README.md with quick start and architecture - Agent configuration guide - OpenAPI import guide - Deployment guide - Demo script 48 new tests, 449 total passing, 92.87% coverage
336 lines · 12 KiB · Python
"""WebSocket message handling logic -- extracted from main for testability."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
import re
|
|
import time
|
|
from collections import defaultdict
|
|
from typing import TYPE_CHECKING, Any
|
|
|
|
from langchain_core.messages import HumanMessage
|
|
from langgraph.types import Command
|
|
|
|
from app.graph import classify_intent
|
|
|
|
if TYPE_CHECKING:
|
|
from fastapi import WebSocket
|
|
from langgraph.graph.state import CompiledStateGraph
|
|
|
|
from app.analytics.event_recorder import AnalyticsRecorder
|
|
from app.callbacks import TokenUsageCallbackHandler
|
|
from app.conversation_tracker import ConversationTrackerProtocol
|
|
from app.interrupt_manager import InterruptManager
|
|
from app.session_manager import SessionManager
|
|
|
|
logger = logging.getLogger(__name__)

# Inbound frame limits -- checked in dispatch_message before any other work.
MAX_MESSAGE_SIZE = 32_768  # 32 KB
MAX_CONTENT_LENGTH = 10_000  # characters
# thread_id must be URL-safe: alphanumerics, dash, underscore, 1-128 chars.
THREAD_ID_PATTERN = re.compile(r"^[a-zA-Z0-9\-_]{1,128}$")

# Rate limiting: max 10 messages per 10-second window, per thread
_RATE_LIMIT_MAX = 10
_RATE_LIMIT_WINDOW = 10.0  # seconds
# Per-thread timestamps of recent messages. Entries are pruned on each check,
# but keys themselves live for the process lifetime (module-level state).
_thread_timestamps: dict[str, list[float]] = defaultdict(list)
|
|
|
|
|
|
async def handle_user_message(
    ws: WebSocket,
    graph: CompiledStateGraph,
    session_manager: SessionManager,
    callback_handler: TokenUsageCallbackHandler,
    thread_id: str,
    content: str,
    interrupt_manager: InterruptManager | None = None,
) -> None:
    """Process a user message through the graph and stream results back.

    Flow: reject expired sessions, classify intent (possibly sending a
    clarification question instead of running the graph), stream the
    graph's tool-call/token events to the client, then either report a
    pending interrupt or signal completion.

    Args:
        ws: Open WebSocket used for all outbound frames.
        graph: Compiled LangGraph the message is run through.
        session_manager: Tracks session expiry/activity per thread.
        callback_handler: LangChain callback collecting token usage.
        thread_id: Conversation identifier (also the graph config key).
        content: Raw user message text.
        interrupt_manager: Optional TTL tracker for pending interrupts.
    """
    # Expired sessions are terminal: the client must start a new thread.
    if session_manager.is_expired(thread_id):
        msg = "Session expired. Please start a new conversation."
        await _send_json(ws, {"type": "error", "message": msg})
        return

    session_manager.touch(thread_id)  # record activity for expiry tracking

    # Run intent classification if available (for logging/future multi-intent)
    classification = await classify_intent(graph, content)
    if classification is not None:
        logger.info(
            "Intent classification for thread %s: ambiguous=%s, intents=%s",
            thread_id,
            classification.is_ambiguous,
            [i.agent_name for i in classification.intents],
        )

        # If ambiguous, send clarification and return
        if classification.is_ambiguous and classification.clarification_question:
            await _send_json(
                ws,
                {
                    "type": "clarification",
                    "thread_id": thread_id,
                    "message": classification.clarification_question,
                },
            )
            # Still emit message_complete so the client knows the turn ended.
            await _send_json(ws, {"type": "message_complete", "thread_id": thread_id})
            return

    config = {"configurable": {"thread_id": thread_id}, "callbacks": [callback_handler]}

    # If multi-intent detected, add routing hint to the message
    if classification and len(classification.intents) > 1:
        agent_names = [i.agent_name for i in classification.intents]
        hint = (
            f"\n[System: This request involves multiple actions. "
            f"Execute in order: {', '.join(agent_names)}]"
        )
        input_msg = {"messages": [HumanMessage(content=content + hint)]}
    else:
        input_msg = {"messages": [HumanMessage(content=content)]}

    try:
        # stream_mode="messages" yields (message_chunk, metadata) pairs.
        async for chunk in graph.astream(input_msg, config=config, stream_mode="messages"):
            msg_chunk, metadata = chunk
            node = metadata.get("langgraph_node", "")

            if hasattr(msg_chunk, "tool_calls") and msg_chunk.tool_calls:
                # Surface each tool invocation to the client as it happens.
                for tc in msg_chunk.tool_calls:
                    await _send_json(
                        ws,
                        {
                            "type": "tool_call",
                            "agent": node,
                            "tool": tc.get("name", ""),
                            "args": tc.get("args", {}),
                        },
                    )
            elif hasattr(msg_chunk, "content") and msg_chunk.content:
                # Plain content chunk -- forward incrementally for live typing.
                await _send_json(
                    ws,
                    {
                        "type": "token",
                        "agent": node,
                        "content": msg_chunk.content,
                    },
                )

        # After the stream ends, inspect state: a pending interrupt means the
        # graph paused for approval rather than finishing the turn.
        state = await graph.aget_state(config)
        if _has_interrupt(state):
            interrupt_data = _extract_interrupt(state)
            session_manager.extend_for_interrupt(thread_id)

            # Register interrupt with TTL tracking
            if interrupt_manager is not None:
                interrupt_manager.register(
                    thread_id=thread_id,
                    action=interrupt_data.get("action", "unknown"),
                    params=interrupt_data.get("params", {}),
                )

            await _send_json(
                ws,
                {
                    "type": "interrupt",
                    "thread_id": thread_id,
                    **interrupt_data,
                },
            )
        else:
            await _send_json(ws, {"type": "message_complete", "thread_id": thread_id})

    except Exception:
        # Broad catch is deliberate: any graph failure must surface as a chat
        # error frame instead of tearing down the WebSocket connection.
        logger.exception("Error processing message for thread %s", thread_id)
        err = "An error occurred processing your message."
        await _send_json(ws, {"type": "error", "message": err})
|
|
|
|
|
|
async def handle_interrupt_response(
    ws: WebSocket,
    graph: CompiledStateGraph,
    session_manager: SessionManager,
    callback_handler: TokenUsageCallbackHandler,
    thread_id: str,
    approved: bool,
    interrupt_manager: InterruptManager | None = None,
) -> None:
    """Resume a paused graph run with the user's approval decision.

    An expired interrupt is not resumed: it is cleared on both managers
    and the client receives a retry prompt instead.
    """
    if interrupt_manager is not None:
        status = interrupt_manager.check_status(thread_id)
        if status is not None and status.is_expired:
            # Too late to act on this interrupt -- clean up and ask again.
            retry_prompt = interrupt_manager.generate_retry_prompt(status.record)
            interrupt_manager.resolve(thread_id)
            session_manager.resolve_interrupt(thread_id)
            await _send_json(ws, retry_prompt)
            return

        interrupt_manager.resolve(thread_id)

    session_manager.resolve_interrupt(thread_id)
    session_manager.touch(thread_id)

    config = {"configurable": {"thread_id": thread_id}, "callbacks": [callback_handler]}

    try:
        resume_stream = graph.astream(
            Command(resume=approved),
            config=config,
            stream_mode="messages",
        )
        async for piece, meta in resume_stream:
            agent = meta.get("langgraph_node", "")
            text = getattr(piece, "content", None)
            if not text:
                continue
            await _send_json(ws, {"type": "token", "agent": agent, "content": text})

        await _send_json(ws, {"type": "message_complete", "thread_id": thread_id})

    except Exception:
        # Surface failures as a chat error frame; never propagate to caller.
        logger.exception("Error resuming interrupt for thread %s", thread_id)
        err = "An error occurred processing your response."
        await _send_json(ws, {"type": "error", "message": err})
|
|
|
|
|
|
async def dispatch_message(
    ws: WebSocket,
    graph: CompiledStateGraph,
    session_manager: SessionManager,
    callback_handler: TokenUsageCallbackHandler,
    raw_data: str,
    interrupt_manager: InterruptManager | None = None,
    analytics_recorder: AnalyticsRecorder | None = None,
    conversation_tracker: ConversationTrackerProtocol | None = None,
    pool: Any = None,
) -> None:
    """Parse, validate, and route an incoming WebSocket message.

    Validation order: size cap, JSON parse, object shape, thread_id
    presence/format, then per-type checks. Every rejection sends a typed
    error frame and returns; malformed input never raises to the caller.

    Args:
        ws: Open WebSocket for responses.
        graph: Compiled graph passed through to the handlers.
        session_manager: Session lifecycle tracker.
        callback_handler: Token-usage callback attached to the graph run.
        raw_data: Raw text frame received from the client.
        interrupt_manager: Optional interrupt TTL tracker.
        analytics_recorder: Optional analytics sink (best effort).
        conversation_tracker: Optional conversation lifecycle sink (best effort).
        pool: DB pool handed to the tracking helpers; may be None.
    """
    if len(raw_data) > MAX_MESSAGE_SIZE:
        await _send_json(ws, {"type": "error", "message": "Message too large"})
        return

    try:
        data = json.loads(raw_data)
    except (json.JSONDecodeError, ValueError):
        await _send_json(ws, {"type": "error", "message": "Invalid JSON"})
        return

    if not isinstance(data, dict):
        await _send_json(ws, {"type": "error", "message": "Invalid JSON: expected object"})
        return

    msg_type = data.get("type")
    thread_id = data.get("thread_id", "")

    if not thread_id:
        await _send_json(ws, {"type": "error", "message": "Missing thread_id"})
        return

    # Fix: a non-string thread_id (e.g. a JSON number) previously reached
    # re.Pattern.match() and raised TypeError, crashing the dispatch path.
    if not isinstance(thread_id, str) or not THREAD_ID_PATTERN.match(thread_id):
        await _send_json(ws, {"type": "error", "message": "Invalid thread_id format"})
        return

    if msg_type == "message":
        content = data.get("content", "")
        # Fix: non-string content (JSON number/array/object) previously raised
        # AttributeError on .strip(); treat it the same as missing content.
        if not isinstance(content, str) or not content.strip():
            await _send_json(ws, {"type": "error", "message": "Missing message content"})
            return
        if len(content) > MAX_CONTENT_LENGTH:
            await _send_json(ws, {"type": "error", "message": "Message content too long"})
            return

        # Rate limiting check (sliding window, per thread).
        if _is_rate_limited(thread_id):
            await _send_json(ws, {"type": "error", "message": "Rate limit exceeded"})
            return

        await handle_user_message(
            ws, graph, session_manager, callback_handler, thread_id, content,
            interrupt_manager=interrupt_manager,
        )
        await _fire_and_forget_tracking(
            thread_id=thread_id,
            pool=pool,
            analytics_recorder=analytics_recorder,
            conversation_tracker=conversation_tracker,
            agent_name=None,
            tokens=0,
            cost=0.0,
        )

    elif msg_type == "interrupt_response":
        # Coerce to bool so handle_interrupt_response always receives the
        # type it declares, even if the client sends a truthy non-boolean.
        approved = bool(data.get("approved", False))
        await handle_interrupt_response(
            ws, graph, session_manager, callback_handler, thread_id, approved,
            interrupt_manager=interrupt_manager,
        )

    else:
        await _send_json(ws, {"type": "error", "message": "Unknown message type"})


def _is_rate_limited(thread_id: str) -> bool:
    """Sliding-window rate limiter: True when the thread is over quota.

    Prunes timestamps older than _RATE_LIMIT_WINDOW, then either rejects
    (window already holds _RATE_LIMIT_MAX entries) or records this message.
    NOTE(review): keys in _thread_timestamps are never removed, so unique
    thread_ids accumulate for the process lifetime -- fine at current scale,
    but a periodic sweep would be needed for long-lived deployments.
    """
    now = time.time()
    cutoff = now - _RATE_LIMIT_WINDOW
    recent = [t for t in _thread_timestamps[thread_id] if t >= cutoff]
    if len(recent) >= _RATE_LIMIT_MAX:
        _thread_timestamps[thread_id] = recent
        return True
    recent.append(now)
    _thread_timestamps[thread_id] = recent
    return False
|
|
|
|
|
|
async def _fire_and_forget_tracking(
|
|
thread_id: str,
|
|
pool: Any,
|
|
analytics_recorder: Any | None,
|
|
conversation_tracker: Any | None,
|
|
agent_name: str | None,
|
|
tokens: int,
|
|
cost: float,
|
|
) -> None:
|
|
"""Fire-and-forget analytics/tracking; failures must NOT break chat."""
|
|
try:
|
|
if conversation_tracker is not None and pool is not None:
|
|
await conversation_tracker.ensure_conversation(pool, thread_id)
|
|
await conversation_tracker.record_turn(pool, thread_id, agent_name, tokens, cost)
|
|
except Exception:
|
|
logger.exception("Conversation tracker error for thread %s (suppressed)", thread_id)
|
|
|
|
try:
|
|
if analytics_recorder is not None:
|
|
await analytics_recorder.record(
|
|
thread_id=thread_id,
|
|
event_type="message",
|
|
agent_name=agent_name,
|
|
tokens_used=tokens,
|
|
cost_usd=cost,
|
|
)
|
|
except Exception:
|
|
logger.exception("Analytics recorder error for thread %s (suppressed)", thread_id)
|
|
|
|
|
|
def _has_interrupt(state: Any) -> bool:
|
|
"""Check if the graph state has a pending interrupt."""
|
|
tasks = getattr(state, "tasks", ())
|
|
return any(getattr(t, "interrupts", ()) for t in tasks)
|
|
|
|
|
|
def _extract_interrupt(state: Any) -> dict:
|
|
"""Extract interrupt data from graph state."""
|
|
for task in getattr(state, "tasks", ()):
|
|
for intr in getattr(task, "interrupts", ()):
|
|
value = intr.value if hasattr(intr, "value") else {}
|
|
if not isinstance(value, dict):
|
|
value = {}
|
|
return {
|
|
"action": value.get("action", "unknown"),
|
|
"params": value,
|
|
}
|
|
return {"action": "unknown", "params": {}}
|
|
|
|
|
|
async def _send_json(ws: WebSocket, data: dict) -> None:
    """Send a JSON message through the WebSocket.

    Single choke point for every outbound frame in this module; all
    handlers route their responses through here.
    """
    await ws.send_json(data)
|