Files
smart-support/backend/app/ws_handler.py
Yaojia Wang d2b4610df9 fix: address code and security review findings for Phase 5
- Add nginx security headers (X-Frame-Options, X-Content-Type-Options, etc.)
- Fix postgres networking: add to app_network, comment out host port exposure
- Fix rate limit memory leak: add bounded eviction for stale thread entries
- Use immutable update pattern in rate limit check (no .append mutation)
- Extract _VERSION constant to avoid duplicate hardcoded version string
2026-03-31 21:35:13 +02:00

345 lines
12 KiB
Python

"""WebSocket message handling logic -- extracted from main for testability."""
from __future__ import annotations
import json
import logging
import re
import time
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from langchain_core.messages import HumanMessage
from langgraph.types import Command
from app.graph import classify_intent
if TYPE_CHECKING:
from fastapi import WebSocket
from langgraph.graph.state import CompiledStateGraph
from app.analytics.event_recorder import AnalyticsRecorder
from app.callbacks import TokenUsageCallbackHandler
from app.conversation_tracker import ConversationTrackerProtocol
from app.interrupt_manager import InterruptManager
from app.session_manager import SessionManager
logger = logging.getLogger(__name__)

# Hard limits on inbound WebSocket payloads (enforced in dispatch_message).
MAX_MESSAGE_SIZE = 32_768  # 32 KB
MAX_CONTENT_LENGTH = 10_000  # characters
# Thread IDs must be URL-safe characters only, with a bounded length.
THREAD_ID_PATTERN = re.compile(r"^[a-zA-Z0-9\-_]{1,128}$")
# Rate limiting: max 10 messages per 10-second window, per thread
_RATE_LIMIT_MAX = 10
_RATE_LIMIT_WINDOW = 10.0  # seconds
# Eviction threshold for the per-thread timestamp table (memory bound).
_MAX_TRACKED_THREADS = 10_000
_thread_timestamps: dict[str, list[float]] = defaultdict(list)
def _evict_stale_threads(cutoff: float) -> None:
"""Remove thread entries with no recent timestamps to prevent memory leak."""
stale = [tid for tid, ts in _thread_timestamps.items() if not ts or ts[-1] < cutoff]
for tid in stale:
del _thread_timestamps[tid]
async def handle_user_message(
    ws: WebSocket,
    graph: CompiledStateGraph,
    session_manager: SessionManager,
    callback_handler: TokenUsageCallbackHandler,
    thread_id: str,
    content: str,
    interrupt_manager: InterruptManager | None = None,
) -> None:
    """Process a user message through the graph and stream results back.

    Flow: session-expiry check -> intent classification (may short-circuit
    with a "clarification" frame) -> stream graph output to the client as
    "tool_call"/"token" frames -> inspect the final graph state for a
    pending interrupt. All failures are reported to the client as an
    "error" frame; nothing is raised to the caller.

    Args:
        ws: Open WebSocket to the client.
        graph: Compiled LangGraph to run the message through.
        session_manager: Tracks per-thread session liveness.
        callback_handler: Attached to the run config for token accounting.
        thread_id: Conversation thread identifier (already validated upstream).
        content: Raw user message text.
        interrupt_manager: Optional TTL tracker for pending interrupts.
    """
    if session_manager.is_expired(thread_id):
        msg = "Session expired. Please start a new conversation."
        await _send_json(ws, {"type": "error", "message": msg})
        return
    session_manager.touch(thread_id)
    # Run intent classification if available (for logging/future multi-intent)
    classification = await classify_intent(graph, content)
    if classification is not None:
        logger.info(
            "Intent classification for thread %s: ambiguous=%s, intents=%s",
            thread_id,
            classification.is_ambiguous,
            [i.agent_name for i in classification.intents],
        )
        # If ambiguous, send clarification and return
        if classification.is_ambiguous and classification.clarification_question:
            await _send_json(
                ws,
                {
                    "type": "clarification",
                    "thread_id": thread_id,
                    "message": classification.clarification_question,
                },
            )
            await _send_json(ws, {"type": "message_complete", "thread_id": thread_id})
            return
    config = {"configurable": {"thread_id": thread_id}, "callbacks": [callback_handler]}
    # If multi-intent detected, add routing hint to the message
    if classification and len(classification.intents) > 1:
        agent_names = [i.agent_name for i in classification.intents]
        hint = (
            f"\n[System: This request involves multiple actions. "
            f"Execute in order: {', '.join(agent_names)}]"
        )
        input_msg = {"messages": [HumanMessage(content=content + hint)]}
    else:
        input_msg = {"messages": [HumanMessage(content=content)]}
    try:
        # stream_mode="messages" yields (message_chunk, metadata) pairs.
        async for chunk in graph.astream(input_msg, config=config, stream_mode="messages"):
            msg_chunk, metadata = chunk
            node = metadata.get("langgraph_node", "")
            # Tool-call chunks take priority over plain content chunks.
            if hasattr(msg_chunk, "tool_calls") and msg_chunk.tool_calls:
                for tc in msg_chunk.tool_calls:
                    await _send_json(
                        ws,
                        {
                            "type": "tool_call",
                            "agent": node,
                            "tool": tc.get("name", ""),
                            "args": tc.get("args", {}),
                        },
                    )
            elif hasattr(msg_chunk, "content") and msg_chunk.content:
                await _send_json(
                    ws,
                    {
                        "type": "token",
                        "agent": node,
                        "content": msg_chunk.content,
                    },
                )
        # After streaming finishes, check whether the graph paused on an
        # interrupt (human-in-the-loop approval) rather than completing.
        state = await graph.aget_state(config)
        if _has_interrupt(state):
            interrupt_data = _extract_interrupt(state)
            session_manager.extend_for_interrupt(thread_id)
            # Register interrupt with TTL tracking
            if interrupt_manager is not None:
                interrupt_manager.register(
                    thread_id=thread_id,
                    action=interrupt_data.get("action", "unknown"),
                    params=interrupt_data.get("params", {}),
                )
            await _send_json(
                ws,
                {
                    "type": "interrupt",
                    "thread_id": thread_id,
                    **interrupt_data,
                },
            )
        else:
            await _send_json(ws, {"type": "message_complete", "thread_id": thread_id})
    except Exception:
        # Generic message to the client; full traceback stays in the server log.
        logger.exception("Error processing message for thread %s", thread_id)
        err = "An error occurred processing your message."
        await _send_json(ws, {"type": "error", "message": err})
async def handle_interrupt_response(
    ws: WebSocket,
    graph: CompiledStateGraph,
    session_manager: SessionManager,
    callback_handler: TokenUsageCallbackHandler,
    thread_id: str,
    approved: bool,
    interrupt_manager: InterruptManager | None = None,
) -> None:
    """Resume graph execution after interrupt approval/rejection.

    If the pending interrupt has expired (per interrupt_manager's TTL), the
    client receives a retry prompt and the graph is NOT resumed. Otherwise
    the graph is resumed with Command(resume=approved) and its output is
    streamed back as "token" frames.

    Args:
        approved: Client's decision, injected into the paused graph node
            via Command(resume=...).
        interrupt_manager: Optional TTL tracker; when None, no expiry check
            is performed and the resume proceeds unconditionally.
    """
    # Check interrupt TTL before resuming
    if interrupt_manager is not None:
        status = interrupt_manager.check_status(thread_id)
        if status is not None and status.is_expired:
            retry_prompt = interrupt_manager.generate_retry_prompt(status.record)
            interrupt_manager.resolve(thread_id)
            session_manager.resolve_interrupt(thread_id)
            await _send_json(ws, retry_prompt)
            return
        interrupt_manager.resolve(thread_id)
    # Session-level interrupt bookkeeping is cleared regardless of whether a
    # TTL tracker is configured (mirrors handle_user_message's registration).
    session_manager.resolve_interrupt(thread_id)
    session_manager.touch(thread_id)
    config = {"configurable": {"thread_id": thread_id}, "callbacks": [callback_handler]}
    try:
        # Command(resume=approved) restarts the paused graph at the interrupt.
        async for chunk in graph.astream(
            Command(resume=approved),
            config=config,
            stream_mode="messages",
        ):
            msg_chunk, metadata = chunk
            node = metadata.get("langgraph_node", "")
            if hasattr(msg_chunk, "content") and msg_chunk.content:
                await _send_json(
                    ws,
                    {
                        "type": "token",
                        "agent": node,
                        "content": msg_chunk.content,
                    },
                )
        await _send_json(ws, {"type": "message_complete", "thread_id": thread_id})
    except Exception:
        logger.exception("Error resuming interrupt for thread %s", thread_id)
        err = "An error occurred processing your response."
        await _send_json(ws, {"type": "error", "message": err})
async def dispatch_message(
    ws: WebSocket,
    graph: CompiledStateGraph,
    session_manager: SessionManager,
    callback_handler: TokenUsageCallbackHandler,
    raw_data: str,
    interrupt_manager: InterruptManager | None = None,
    analytics_recorder: AnalyticsRecorder | None = None,
    conversation_tracker: ConversationTrackerProtocol | None = None,
    pool: Any = None,
) -> None:
    """Parse and route an incoming WebSocket message.

    Validates payload size, JSON shape, thread_id format, message content,
    and the per-thread rate limit before delegating to the type-specific
    handler. Every validation failure is reported to the client as an
    {"type": "error", ...} frame; nothing is raised to the caller.

    Args:
        raw_data: The raw text frame received from the WebSocket.
        pool: Optional DB pool handed through to fire-and-forget tracking.
    """
    if len(raw_data) > MAX_MESSAGE_SIZE:
        await _send_json(ws, {"type": "error", "message": "Message too large"})
        return
    try:
        data = json.loads(raw_data)
    except (json.JSONDecodeError, ValueError):
        await _send_json(ws, {"type": "error", "message": "Invalid JSON"})
        return
    if not isinstance(data, dict):
        await _send_json(ws, {"type": "error", "message": "Invalid JSON: expected object"})
        return
    msg_type = data.get("type")
    thread_id = data.get("thread_id", "")
    # BUG FIX: a non-string thread_id (e.g. a JSON number) previously reached
    # THREAD_ID_PATTERN.match() and raised TypeError, killing the handler.
    # Reject non-strings with the same client-facing error instead.
    if not isinstance(thread_id, str):
        await _send_json(ws, {"type": "error", "message": "Invalid thread_id format"})
        return
    if not thread_id:
        await _send_json(ws, {"type": "error", "message": "Missing thread_id"})
        return
    if not THREAD_ID_PATTERN.match(thread_id):
        await _send_json(ws, {"type": "error", "message": "Invalid thread_id format"})
        return
    if msg_type == "message":
        content = data.get("content", "")
        # BUG FIX: content must be a string — .strip() on a non-string value
        # previously raised AttributeError instead of a clean error frame.
        if not isinstance(content, str) or not content.strip():
            await _send_json(ws, {"type": "error", "message": "Missing message content"})
            return
        if len(content) > MAX_CONTENT_LENGTH:
            await _send_json(ws, {"type": "error", "message": "Message content too long"})
            return
        # Rate limiting check (per-thread, with bounded memory)
        now = time.time()
        cutoff = now - _RATE_LIMIT_WINDOW
        if len(_thread_timestamps) > _MAX_TRACKED_THREADS:
            _evict_stale_threads(cutoff)
        recent = [t for t in _thread_timestamps[thread_id] if t >= cutoff]
        if len(recent) >= _RATE_LIMIT_MAX:
            await _send_json(ws, {"type": "error", "message": "Rate limit exceeded"})
            return
        # Immutable update: replace the list so stale timestamps are dropped.
        _thread_timestamps[thread_id] = [*recent, now]
        await handle_user_message(
            ws, graph, session_manager, callback_handler, thread_id, content,
            interrupt_manager=interrupt_manager,
        )
        # Best-effort tracking; failures are logged and suppressed inside.
        await _fire_and_forget_tracking(
            thread_id=thread_id,
            pool=pool,
            analytics_recorder=analytics_recorder,
            conversation_tracker=conversation_tracker,
            agent_name=None,
            tokens=0,
            cost=0.0,
        )
    elif msg_type == "interrupt_response":
        approved = data.get("approved", False)
        await handle_interrupt_response(
            ws, graph, session_manager, callback_handler, thread_id, approved,
            interrupt_manager=interrupt_manager,
        )
    else:
        await _send_json(ws, {"type": "error", "message": "Unknown message type"})
async def _fire_and_forget_tracking(
thread_id: str,
pool: Any,
analytics_recorder: Any | None,
conversation_tracker: Any | None,
agent_name: str | None,
tokens: int,
cost: float,
) -> None:
"""Fire-and-forget analytics/tracking; failures must NOT break chat."""
try:
if conversation_tracker is not None and pool is not None:
await conversation_tracker.ensure_conversation(pool, thread_id)
await conversation_tracker.record_turn(pool, thread_id, agent_name, tokens, cost)
except Exception:
logger.exception("Conversation tracker error for thread %s (suppressed)", thread_id)
try:
if analytics_recorder is not None:
await analytics_recorder.record(
thread_id=thread_id,
event_type="message",
agent_name=agent_name,
tokens_used=tokens,
cost_usd=cost,
)
except Exception:
logger.exception("Analytics recorder error for thread %s (suppressed)", thread_id)
def _has_interrupt(state: Any) -> bool:
"""Check if the graph state has a pending interrupt."""
tasks = getattr(state, "tasks", ())
return any(getattr(t, "interrupts", ()) for t in tasks)
def _extract_interrupt(state: Any) -> dict:
"""Extract interrupt data from graph state."""
for task in getattr(state, "tasks", ()):
for intr in getattr(task, "interrupts", ()):
value = intr.value if hasattr(intr, "value") else {}
if not isinstance(value, dict):
value = {}
return {
"action": value.get("action", "unknown"),
"params": value,
}
return {"action": "unknown", "params": {}}
async def _send_json(ws: WebSocket, data: dict) -> None:
    """Send *data* as one JSON frame over the WebSocket.

    Single send path used by every handler in this module, so the transport
    can be swapped or mocked in one place.
    """
    await ws.send_json(data)