"""WebSocket message handling logic -- extracted from main for testability."""

from __future__ import annotations

import json
import re
import time
from collections import defaultdict
from typing import TYPE_CHECKING, Any

import structlog
from langchain_core.messages import HumanMessage
from langgraph.types import Command

if TYPE_CHECKING:
    from fastapi import WebSocket

    from app.callbacks import TokenUsageCallbackHandler
    from app.graph_context import GraphContext
    from app.interrupt_manager import InterruptManager
    from app.session_manager import SessionManager
    from app.ws_context import WebSocketContext

logger = structlog.get_logger()

MAX_MESSAGE_SIZE = 32_768  # 32 K characters of the received text frame
MAX_CONTENT_LENGTH = 10_000  # characters
# Always apply with ``fullmatch``: with ``match``, ``$`` also matches just
# before a trailing newline, so "abc\n" would be accepted.
THREAD_ID_PATTERN = re.compile(r"^[a-zA-Z0-9\-_]{1,128}$")

# Rate limiting: max 10 messages per 10-second window, per thread
_RATE_LIMIT_MAX = 10
_RATE_LIMIT_WINDOW = 10.0
_MAX_TRACKED_THREADS = 10_000
# thread_id -> time.monotonic() timestamps of accepted messages.
_thread_timestamps: dict[str, list[float]] = defaultdict(list)


def _evict_stale_threads(cutoff: float) -> None:
    """Remove thread entries with no recent timestamps to prevent memory leak.

    An entry is stale when its list is empty or its newest timestamp is older
    than *cutoff*.
    """
    stale = [tid for tid, ts in _thread_timestamps.items() if not ts or ts[-1] < cutoff]
    for tid in stale:
        del _thread_timestamps[tid]


def _check_rate_limit(thread_id: str) -> bool:
    """Return True and record the attempt if *thread_id* is under its limit.

    Returns False (without recording) when the thread already has
    ``_RATE_LIMIT_MAX`` accepted messages inside the last
    ``_RATE_LIMIT_WINDOW`` seconds.
    """
    # Monotonic clock: wall-clock jumps (NTP, DST, manual changes) must not
    # stretch or shrink the window.
    now = time.monotonic()
    cutoff = now - _RATE_LIMIT_WINDOW
    if len(_thread_timestamps) > _MAX_TRACKED_THREADS:
        _evict_stale_threads(cutoff)
    recent = [t for t in _thread_timestamps[thread_id] if t >= cutoff]
    if len(recent) >= _RATE_LIMIT_MAX:
        return False
    _thread_timestamps[thread_id] = [*recent, now]
    return True


async def handle_user_message(
    ws: WebSocket,
    ctx: GraphContext,
    session_manager: SessionManager,
    callback_handler: TokenUsageCallbackHandler,
    thread_id: str,
    content: str,
    interrupt_manager: InterruptManager | None = None,
) -> None:
    """Process a user message through the graph and stream results back.

    Sends an ``error`` event if the thread's session has expired. Sends a
    ``clarification`` event followed by ``message_complete`` when intent
    classification flags the request as ambiguous. Otherwise it runs the
    graph and streams its output via ``_stream_graph_run``. Any exception
    raised during the run is logged and reported to the client as a generic
    ``error`` event.
    """
    existing = session_manager.get_state(thread_id)
    if existing is not None and session_manager.is_expired(thread_id):
        msg = "Session expired. Please start a new conversation."
        await _send_json(ws, {"type": "error", "message": msg})
        return

    session_manager.touch(thread_id)

    classification = await ctx.classify_intent(content)
    if classification is not None:
        logger.info(
            "Intent classification for thread %s: ambiguous=%s, intents=%s",
            thread_id,
            classification.is_ambiguous,
            [i.agent_name for i in classification.intents],
        )
        if classification.is_ambiguous and classification.clarification_question:
            await _send_json(
                ws,
                {
                    "type": "clarification",
                    "thread_id": thread_id,
                    "message": classification.clarification_question,
                },
            )
            await _send_json(ws, {"type": "message_complete", "thread_id": thread_id})
            return

    config: dict[str, Any] = {
        "configurable": {"thread_id": thread_id},
        "callbacks": [callback_handler],
    }

    if classification and len(classification.intents) > 1:
        # Multi-intent request: append an ordering hint so the graph executes
        # the detected agents in sequence.
        agent_names = [i.agent_name for i in classification.intents]
        hint = (
            f"\n[System: This request involves multiple actions. "
            f"Execute in order: {', '.join(agent_names)}]"
        )
        input_msg = {"messages": [HumanMessage(content=content + hint)]}
    else:
        input_msg = {"messages": [HumanMessage(content=content)]}

    try:
        await _stream_graph_run(
            ws, ctx, session_manager, thread_id, input_msg, config, interrupt_manager
        )
    except Exception:
        logger.exception("Error processing message for thread %s", thread_id)
        err = "An error occurred processing your message."
        await _send_json(ws, {"type": "error", "message": err})


async def handle_interrupt_response(
    ws: WebSocket,
    ctx: GraphContext,
    session_manager: SessionManager,
    callback_handler: TokenUsageCallbackHandler,
    thread_id: str,
    approved: bool,
    interrupt_manager: InterruptManager | None = None,
) -> None:
    """Resume graph execution after interrupt approval/rejection.

    If *interrupt_manager* reports the pending interrupt as expired, the
    interrupt is resolved and the manager's retry prompt is sent instead of
    resuming. Otherwise the graph is resumed with ``Command(resume=approved)``
    and streamed via ``_stream_graph_run``. That helper also reports any
    follow-up interrupt the resumed run hits. Exceptions are logged and
    reported as a generic ``error`` event.
    """
    if interrupt_manager is not None:
        status = interrupt_manager.check_status(thread_id)
        if status is not None and status.is_expired:
            retry_prompt = interrupt_manager.generate_retry_prompt(status.record)
            interrupt_manager.resolve(thread_id)
            session_manager.resolve_interrupt(thread_id)
            await _send_json(ws, retry_prompt)
            return
        interrupt_manager.resolve(thread_id)

    # Mirrors the unconditional extend_for_interrupt() done when the
    # interrupt was raised.
    session_manager.resolve_interrupt(thread_id)
    session_manager.touch(thread_id)

    config: dict[str, Any] = {
        "configurable": {"thread_id": thread_id},
        "callbacks": [callback_handler],
    }

    try:
        await _stream_graph_run(
            ws,
            ctx,
            session_manager,
            thread_id,
            Command(resume=approved),
            config,
            interrupt_manager,
        )
    except Exception:
        logger.exception("Error resuming interrupt for thread %s", thread_id)
        err = "An error occurred processing your response."
        await _send_json(ws, {"type": "error", "message": err})


async def dispatch_message(
    ws: WebSocket,
    ctx: WebSocketContext,
    raw_data: str,
) -> None:
    """Parse and route an incoming WebSocket message.

    All fields are untrusted client input and are validated before use:

    - payload size;
    - JSON shape (must be an object);
    - ``thread_id`` (non-empty string matching ``THREAD_ID_PATTERN``);
    - per-type fields.

    ``message`` payloads are rate-limited per thread, processed by
    ``handle_user_message`` and then tracked. ``interrupt_response`` payloads
    are passed to ``handle_interrupt_response``. Every validation failure or
    unknown type is answered with an ``error`` event rather than an exception.
    """
    if len(raw_data) > MAX_MESSAGE_SIZE:
        await _send_json(ws, {"type": "error", "message": "Message too large"})
        return

    try:
        data = json.loads(raw_data)
    except (ValueError, RecursionError):
        # ValueError covers json.JSONDecodeError. RecursionError is raised by
        # deeply nested arrays/objects, which fit easily under the size cap.
        await _send_json(ws, {"type": "error", "message": "Invalid JSON"})
        return

    if not isinstance(data, dict):
        await _send_json(ws, {"type": "error", "message": "Invalid JSON: expected object"})
        return

    msg_type = data.get("type")
    thread_id = data.get("thread_id", "")
    if not thread_id:
        await _send_json(ws, {"type": "error", "message": "Missing thread_id"})
        return
    if not isinstance(thread_id, str) or not THREAD_ID_PATTERN.fullmatch(thread_id):
        await _send_json(ws, {"type": "error", "message": "Invalid thread_id format"})
        return

    if msg_type == "message":
        content = data.get("content", "")
        if content and not isinstance(content, str):
            await _send_json(ws, {"type": "error", "message": "Invalid message content"})
            return
        if not content or not content.strip():
            await _send_json(ws, {"type": "error", "message": "Missing message content"})
            return
        if len(content) > MAX_CONTENT_LENGTH:
            await _send_json(ws, {"type": "error", "message": "Message content too long"})
            return

        # NOTE(review): the limit is per thread_id, which the client chooses;
        # a client rotating thread_ids is not throttled by this check.
        if not _check_rate_limit(thread_id):
            await _send_json(ws, {"type": "error", "message": "Rate limit exceeded"})
            return

        await handle_user_message(
            ws,
            ctx.graph_ctx,
            ctx.session_manager,
            ctx.callback_handler,
            thread_id,
            content,
            interrupt_manager=ctx.interrupt_manager,
        )
        await _fire_and_forget_tracking(
            thread_id=thread_id,
            pool=ctx.pool,
            analytics_recorder=ctx.analytics_recorder,
            conversation_tracker=ctx.conversation_tracker,
            agent_name=None,
            tokens=0,
            cost=0.0,
        )
    elif msg_type == "interrupt_response":
        approved = data.get("approved", False)
        if not isinstance(approved, bool):
            # Reject e.g. the string "false", which is truthy and would
            # otherwise be passed to the graph as the resume value.
            await _send_json(ws, {"type": "error", "message": "Invalid approved value"})
            return
        await handle_interrupt_response(
            ws,
            ctx.graph_ctx,
            ctx.session_manager,
            ctx.callback_handler,
            thread_id,
            approved,
            interrupt_manager=ctx.interrupt_manager,
        )
    else:
        await _send_json(ws, {"type": "error", "message": "Unknown message type"})


async def _fire_and_forget_tracking(
    thread_id: str,
    pool: object,
    analytics_recorder: object | None,
    conversation_tracker: object | None,
    agent_name: str | None,
    tokens: int,
    cost: float,
) -> None:
    """Best-effort analytics/tracking; failures must NOT break chat.

    Despite the name, this is awaited inline by the caller, not scheduled in
    the background. Each sink (conversation tracker, analytics recorder) is
    skipped when not configured. Exceptions from either sink are logged and
    suppressed independently.
    """
    try:
        if conversation_tracker is not None and pool is not None:
            await conversation_tracker.ensure_conversation(pool, thread_id)
            await conversation_tracker.record_turn(pool, thread_id, agent_name, tokens, cost)
    except Exception:
        logger.exception("Conversation tracker error for thread %s (suppressed)", thread_id)

    try:
        if analytics_recorder is not None:
            await analytics_recorder.record(
                thread_id=thread_id,
                event_type="message",
                agent_name=agent_name,
                tokens_used=tokens,
                cost_usd=cost,
            )
    except Exception:
        logger.exception("Analytics recorder error for thread %s (suppressed)", thread_id)


async def _stream_graph_run(
    ws: WebSocket,
    ctx: GraphContext,
    session_manager: SessionManager,
    thread_id: str,
    graph_input: Any,
    config: dict[str, Any],
    interrupt_manager: InterruptManager | None,
) -> None:
    """Stream one graph run to the client, then report an interrupt or completion.

    Each streamed chunk is forwarded as one ``tool_call`` event per tool call
    or, failing that, as a ``token`` event when it has content. Once the
    stream ends, the graph state is fetched. If a task holds an interrupt,
    the session is extended, the interrupt is registered with
    *interrupt_manager* (when given) and an ``interrupt`` event is sent.
    Otherwise ``message_complete`` is sent. Exceptions propagate to the
    caller.
    """
    async for msg_chunk, metadata in ctx.graph.astream(
        graph_input, config=config, stream_mode="messages"
    ):
        node = metadata.get("langgraph_node", "")
        tool_calls = getattr(msg_chunk, "tool_calls", None)
        if tool_calls:
            for tc in tool_calls:
                await _send_json(
                    ws,
                    {
                        "type": "tool_call",
                        "agent": node,
                        "tool": tc.get("name", ""),
                        "args": tc.get("args", {}),
                    },
                )
        elif getattr(msg_chunk, "content", None):
            await _send_json(
                ws,
                {
                    "type": "token",
                    "agent": node,
                    "content": msg_chunk.content,
                },
            )

    state = await ctx.graph.aget_state(config)
    if not _has_interrupt(state):
        await _send_json(ws, {"type": "message_complete", "thread_id": thread_id})
        return

    interrupt_data = _extract_interrupt(state)
    session_manager.extend_for_interrupt(thread_id)
    if interrupt_manager is not None:
        interrupt_manager.register(
            thread_id=thread_id,
            action=interrupt_data.get("action", "unknown"),
            params=interrupt_data.get("params", {}),
        )
    await _send_json(
        ws,
        {
            "type": "interrupt",
            "thread_id": thread_id,
            **interrupt_data,
        },
    )


def _has_interrupt(state: Any) -> bool:
    """Check if the graph state has a pending interrupt.

    True when any entry of ``state.tasks`` has a non-empty ``interrupts``.
    Missing attributes count as empty.
    """
    tasks = getattr(state, "tasks", ())
    return any(getattr(t, "interrupts", ()) for t in tasks)


def _extract_interrupt(state: Any) -> dict[str, Any]:
    """Extract interrupt data from graph state.

    Uses the first interrupt found. Its ``value`` is used as ``params`` when
    it is a dict (otherwise ``{}``), and ``action`` is read from that dict
    (default ``"unknown"``).
    """
    for task in getattr(state, "tasks", ()):
        for intr in getattr(task, "interrupts", ()):
            value = getattr(intr, "value", {})
            if not isinstance(value, dict):
                value = {}
            return {
                "action": value.get("action", "unknown"),
                "params": value,
            }
    return {"action": "unknown", "params": {}}


async def _send_json(ws: WebSocket, data: dict) -> None:
    """Send a JSON message through the WebSocket."""
    await ws.send_json(data)