"""WebSocket message handling logic -- extracted from main for testability.""" from __future__ import annotations import json import logging import re from typing import TYPE_CHECKING, Any from langchain_core.messages import HumanMessage from langgraph.types import Command if TYPE_CHECKING: from fastapi import WebSocket from langgraph.graph.state import CompiledStateGraph from app.callbacks import TokenUsageCallbackHandler from app.session_manager import SessionManager logger = logging.getLogger(__name__) MAX_MESSAGE_SIZE = 32_768 # 32 KB MAX_CONTENT_LENGTH = 8_000 # characters THREAD_ID_PATTERN = re.compile(r"^[a-zA-Z0-9\-_]{1,128}$") async def handle_user_message( ws: WebSocket, graph: CompiledStateGraph, session_manager: SessionManager, callback_handler: TokenUsageCallbackHandler, thread_id: str, content: str, ) -> None: """Process a user message through the graph and stream results back.""" if session_manager.is_expired(thread_id): msg = "Session expired. Please start a new conversation." await _send_json(ws, {"type": "error", "message": msg}) return session_manager.touch(thread_id) config = {"configurable": {"thread_id": thread_id}, "callbacks": [callback_handler]} input_msg = {"messages": [HumanMessage(content=content)]} try: async for chunk in graph.astream(input_msg, config=config, stream_mode="messages"): msg_chunk, metadata = chunk node = metadata.get("langgraph_node", "") if hasattr(msg_chunk, "tool_calls") and msg_chunk.tool_calls: for tc in msg_chunk.tool_calls: await _send_json( ws, { "type": "tool_call", "agent": node, "tool": tc.get("name", ""), "args": tc.get("args", {}), }, ) elif hasattr(msg_chunk, "content") and msg_chunk.content: await _send_json( ws, { "type": "token", "agent": node, "content": msg_chunk.content, }, ) state = await graph.aget_state(config) if _has_interrupt(state): interrupt_data = _extract_interrupt(state) session_manager.extend_for_interrupt(thread_id) await _send_json( ws, { "type": "interrupt", "thread_id": thread_id, **interrupt_data, }, ) else: await _send_json(ws, {"type": "message_complete", "thread_id": thread_id}) except Exception: logger.exception("Error processing message for thread %s", thread_id) err = "An error occurred processing your message." await _send_json(ws, {"type": "error", "message": err}) async def handle_interrupt_response( ws: WebSocket, graph: CompiledStateGraph, session_manager: SessionManager, callback_handler: TokenUsageCallbackHandler, thread_id: str, approved: bool, ) -> None: """Resume graph execution after interrupt approval/rejection.""" session_manager.resolve_interrupt(thread_id) session_manager.touch(thread_id) config = {"configurable": {"thread_id": thread_id}, "callbacks": [callback_handler]} try: async for chunk in graph.astream( Command(resume=approved), config=config, stream_mode="messages", ): msg_chunk, metadata = chunk node = metadata.get("langgraph_node", "") if hasattr(msg_chunk, "content") and msg_chunk.content: await _send_json( ws, { "type": "token", "agent": node, "content": msg_chunk.content, }, ) await _send_json(ws, {"type": "message_complete", "thread_id": thread_id}) except Exception: logger.exception("Error resuming interrupt for thread %s", thread_id) err = "An error occurred processing your response." await _send_json(ws, {"type": "error", "message": err}) async def dispatch_message( ws: WebSocket, graph: CompiledStateGraph, session_manager: SessionManager, callback_handler: TokenUsageCallbackHandler, raw_data: str, ) -> None: """Parse and route an incoming WebSocket message.""" if len(raw_data) > MAX_MESSAGE_SIZE: await _send_json(ws, {"type": "error", "message": "Message too large"}) return try: data = json.loads(raw_data) except json.JSONDecodeError: await _send_json(ws, {"type": "error", "message": "Invalid JSON"}) return msg_type = data.get("type") thread_id = data.get("thread_id", "") if not thread_id: await _send_json(ws, {"type": "error", "message": "Missing thread_id"}) return if not THREAD_ID_PATTERN.match(thread_id): await _send_json(ws, {"type": "error", "message": "Invalid thread_id format"}) return if msg_type == "message": content = data.get("content", "") if not content: await _send_json(ws, {"type": "error", "message": "Missing message content"}) return if len(content) > MAX_CONTENT_LENGTH: await _send_json(ws, {"type": "error", "message": "Message content too long"}) return await handle_user_message(ws, graph, session_manager, callback_handler, thread_id, content) elif msg_type == "interrupt_response": approved = data.get("approved", False) await handle_interrupt_response( ws, graph, session_manager, callback_handler, thread_id, approved ) else: await _send_json(ws, {"type": "error", "message": "Unknown message type"}) def _has_interrupt(state: Any) -> bool: """Check if the graph state has a pending interrupt.""" tasks = getattr(state, "tasks", ()) return any(getattr(t, "interrupts", ()) for t in tasks) def _extract_interrupt(state: Any) -> dict: """Extract interrupt data from graph state.""" for task in getattr(state, "tasks", ()): for intr in getattr(task, "interrupts", ()): value = intr.value if hasattr(intr, "value") else {} if not isinstance(value, dict): value = {} return { "action": value.get("action", "unknown"), "params": value, } return {"action": "unknown", "params": {}} async def _send_json(ws: WebSocket, data: dict) -> None: """Send a JSON message through the WebSocket.""" await ws.send_json(data)