"""WebSocket message handling logic -- extracted from main for testability.""" from __future__ import annotations import json import logging import re from typing import TYPE_CHECKING, Any from langchain_core.messages import HumanMessage from langgraph.types import Command from app.graph import classify_intent if TYPE_CHECKING: from fastapi import WebSocket from langgraph.graph.state import CompiledStateGraph from app.callbacks import TokenUsageCallbackHandler from app.interrupt_manager import InterruptManager from app.session_manager import SessionManager logger = logging.getLogger(__name__) MAX_MESSAGE_SIZE = 32_768 # 32 KB MAX_CONTENT_LENGTH = 8_000 # characters THREAD_ID_PATTERN = re.compile(r"^[a-zA-Z0-9\-_]{1,128}$") async def handle_user_message( ws: WebSocket, graph: CompiledStateGraph, session_manager: SessionManager, callback_handler: TokenUsageCallbackHandler, thread_id: str, content: str, interrupt_manager: InterruptManager | None = None, ) -> None: """Process a user message through the graph and stream results back.""" if session_manager.is_expired(thread_id): msg = "Session expired. Please start a new conversation." await _send_json(ws, {"type": "error", "message": msg}) return session_manager.touch(thread_id) # Run intent classification if available (for logging/future multi-intent) classification = await classify_intent(graph, content) if classification is not None: logger.info( "Intent classification for thread %s: ambiguous=%s, intents=%s", thread_id, classification.is_ambiguous, [i.agent_name for i in classification.intents], ) # If ambiguous, send clarification and return if classification.is_ambiguous and classification.clarification_question: await _send_json( ws, { "type": "clarification", "thread_id": thread_id, "message": classification.clarification_question, }, ) await _send_json(ws, {"type": "message_complete", "thread_id": thread_id}) return config = {"configurable": {"thread_id": thread_id}, "callbacks": [callback_handler]} # If multi-intent detected, add routing hint to the message if classification and len(classification.intents) > 1: agent_names = [i.agent_name for i in classification.intents] hint = ( f"\n[System: This request involves multiple actions. " f"Execute in order: {', '.join(agent_names)}]" ) input_msg = {"messages": [HumanMessage(content=content + hint)]} else: input_msg = {"messages": [HumanMessage(content=content)]} try: async for chunk in graph.astream(input_msg, config=config, stream_mode="messages"): msg_chunk, metadata = chunk node = metadata.get("langgraph_node", "") if hasattr(msg_chunk, "tool_calls") and msg_chunk.tool_calls: for tc in msg_chunk.tool_calls: await _send_json( ws, { "type": "tool_call", "agent": node, "tool": tc.get("name", ""), "args": tc.get("args", {}), }, ) elif hasattr(msg_chunk, "content") and msg_chunk.content: await _send_json( ws, { "type": "token", "agent": node, "content": msg_chunk.content, }, ) state = await graph.aget_state(config) if _has_interrupt(state): interrupt_data = _extract_interrupt(state) session_manager.extend_for_interrupt(thread_id) # Register interrupt with TTL tracking if interrupt_manager is not None: interrupt_manager.register( thread_id=thread_id, action=interrupt_data.get("action", "unknown"), params=interrupt_data.get("params", {}), ) await _send_json( ws, { "type": "interrupt", "thread_id": thread_id, **interrupt_data, }, ) else: await _send_json(ws, {"type": "message_complete", "thread_id": thread_id}) except Exception: logger.exception("Error processing message for thread %s", thread_id) err = "An error occurred processing your message." await _send_json(ws, {"type": "error", "message": err}) async def handle_interrupt_response( ws: WebSocket, graph: CompiledStateGraph, session_manager: SessionManager, callback_handler: TokenUsageCallbackHandler, thread_id: str, approved: bool, interrupt_manager: InterruptManager | None = None, ) -> None: """Resume graph execution after interrupt approval/rejection.""" # Check interrupt TTL before resuming if interrupt_manager is not None: status = interrupt_manager.check_status(thread_id) if status is not None and status.is_expired: retry_prompt = interrupt_manager.generate_retry_prompt(status.record) interrupt_manager.resolve(thread_id) session_manager.resolve_interrupt(thread_id) await _send_json(ws, retry_prompt) return interrupt_manager.resolve(thread_id) session_manager.resolve_interrupt(thread_id) session_manager.touch(thread_id) config = {"configurable": {"thread_id": thread_id}, "callbacks": [callback_handler]} try: async for chunk in graph.astream( Command(resume=approved), config=config, stream_mode="messages", ): msg_chunk, metadata = chunk node = metadata.get("langgraph_node", "") if hasattr(msg_chunk, "content") and msg_chunk.content: await _send_json( ws, { "type": "token", "agent": node, "content": msg_chunk.content, }, ) await _send_json(ws, {"type": "message_complete", "thread_id": thread_id}) except Exception: logger.exception("Error resuming interrupt for thread %s", thread_id) err = "An error occurred processing your response." await _send_json(ws, {"type": "error", "message": err}) async def dispatch_message( ws: WebSocket, graph: CompiledStateGraph, session_manager: SessionManager, callback_handler: TokenUsageCallbackHandler, raw_data: str, interrupt_manager: InterruptManager | None = None, ) -> None: """Parse and route an incoming WebSocket message.""" if len(raw_data) > MAX_MESSAGE_SIZE: await _send_json(ws, {"type": "error", "message": "Message too large"}) return try: data = json.loads(raw_data) except json.JSONDecodeError: await _send_json(ws, {"type": "error", "message": "Invalid JSON"}) return msg_type = data.get("type") thread_id = data.get("thread_id", "") if not thread_id: await _send_json(ws, {"type": "error", "message": "Missing thread_id"}) return if not THREAD_ID_PATTERN.match(thread_id): await _send_json(ws, {"type": "error", "message": "Invalid thread_id format"}) return if msg_type == "message": content = data.get("content", "") if not content: await _send_json(ws, {"type": "error", "message": "Missing message content"}) return if len(content) > MAX_CONTENT_LENGTH: await _send_json(ws, {"type": "error", "message": "Message content too long"}) return await handle_user_message( ws, graph, session_manager, callback_handler, thread_id, content, interrupt_manager=interrupt_manager, ) elif msg_type == "interrupt_response": approved = data.get("approved", False) await handle_interrupt_response( ws, graph, session_manager, callback_handler, thread_id, approved, interrupt_manager=interrupt_manager, ) else: await _send_json(ws, {"type": "error", "message": "Unknown message type"}) def _has_interrupt(state: Any) -> bool: """Check if the graph state has a pending interrupt.""" tasks = getattr(state, "tasks", ()) return any(getattr(t, "interrupts", ()) for t in tasks) def _extract_interrupt(state: Any) -> dict: """Extract interrupt data from graph state.""" for task in getattr(state, "tasks", ()): for intr in getattr(task, "interrupts", ()): value = intr.value if hasattr(intr, "value") else {} if not isinstance(value, dict): value = {} return { "action": value.get("action", "unknown"), "params": value, } return {"action": "unknown", "params": {}} async def _send_json(ws: WebSocket, data: dict) -> None: """Send a JSON message through the WebSocket.""" await ws.send_json(data)