Files
smart-support/backend/tests/unit/test_ws_handler.py
Yaojia Wang af53111928 refactor: fix architectural issues across frontend and backend
Address all architecture review findings:

P0 fixes:
- Add API key authentication for admin endpoints (analytics, replay, openapi)
  and WebSocket connections via ADMIN_API_KEY env var
- Add PostgreSQL-backed PgSessionManager and PgInterruptManager for
  multi-worker production deployments (in-memory defaults preserved)

P1 fixes:
- Implement actual tool generation in OpenAPI approve_job endpoint
  using generate_tool_code() and generate_agent_yaml()
- Add missing clarification, interrupt_expired, and tool_result message
  handlers in frontend ChatPage

P2 fixes:
- Replace monkey-patching on CompiledStateGraph with typed GraphContext
- Replace 9-param dispatch_message with WebSocketContext dataclass
- Extract duplicate _envelope() into shared app/api_utils.py
- Replace mutable module-level counter with crypto.randomUUID()
- Remove hardcoded mock data from ReviewPage, use api.ts wrappers
- Remove `as any` type escape from ReplayPage

All 516 tests passing, 0 TypeScript errors.
2026-04-06 15:59:14 +02:00

441 lines
15 KiB
Python

"""Tests for app.ws_handler module."""
from __future__ import annotations
import json
from unittest.mock import AsyncMock, MagicMock
import pytest
from app.callbacks import TokenUsageCallbackHandler
from app.graph_context import GraphContext
from app.interrupt_manager import InterruptManager
from app.session_manager import SessionManager
from app.ws_context import WebSocketContext
from app.ws_handler import (
_extract_interrupt,
_has_interrupt,
dispatch_message,
handle_interrupt_response,
handle_user_message,
)
def _make_ws() -> AsyncMock:
ws = AsyncMock()
ws.send_json = AsyncMock()
return ws
def _make_graph() -> MagicMock:
    """Build a graph double that streams nothing and reports no tasks.

    ``astream`` is a plain (synchronous) mock returning an ``AsyncIterHelper``
    over an empty list; ``aget_state`` is awaitable and resolves to a state
    mock whose ``tasks`` is an empty tuple.

    NOTE: the object returned is an ``AsyncMock``; the ``MagicMock``
    annotation is kept as-is for compatibility with existing callers.
    """
    idle_state = MagicMock(tasks=())
    compiled = AsyncMock()
    compiled.aget_state = AsyncMock(return_value=idle_state)
    compiled.astream = MagicMock(return_value=AsyncIterHelper([]))
    return compiled
def _make_graph_ctx(graph: MagicMock | None = None) -> GraphContext:
    """Wrap a graph double in a ``GraphContext`` with an empty agent registry.

    Args:
        graph: Graph double to embed. When ``None``, a fresh one is obtained
            from ``_make_graph()``.

    Returns:
        A ``GraphContext`` built with the graph, a registry mock whose
        ``list_agents()`` returns ``()``, and ``intent_classifier=None``.
    """
    # Test against None explicitly rather than relying on truthiness: a
    # caller-supplied double that happens to evaluate falsy (e.g. one with a
    # configured ``__bool__``/``__len__``) must not be silently replaced.
    if graph is None:
        graph = _make_graph()
    registry = MagicMock()
    registry.list_agents = MagicMock(return_value=())
    return GraphContext(graph=graph, registry=registry, intent_classifier=None)
def _make_ws_ctx(
    graph_ctx: GraphContext | None = None,
    sm: SessionManager | None = None,
    cb: TokenUsageCallbackHandler | None = None,
    interrupt_manager: InterruptManager | None = None,
    analytics_recorder=None,
    conversation_tracker=None,
    pool=None,
) -> WebSocketContext:
    """Assemble a ``WebSocketContext`` for a test, filling in omitted parts.

    Args:
        graph_ctx: Graph context to use; ``None`` means ``_make_graph_ctx()``.
        sm: Session manager to use; ``None`` means a new ``SessionManager()``.
        cb: Callback handler to use; ``None`` means a new
            ``TokenUsageCallbackHandler()``.
        interrupt_manager: Passed through unchanged (may be ``None``).
        analytics_recorder: Passed through unchanged (may be ``None``).
        conversation_tracker: Passed through unchanged (may be ``None``).
        pool: Passed through unchanged (may be ``None``).

    Returns:
        The populated ``WebSocketContext``.
    """
    # Defaults are chosen with ``is None`` instead of ``x or default``.  Tests
    # keep a reference to the object they pass in (e.g. they call
    # ``sm.touch(...)`` afterwards), so the context must hold that very
    # instance even if it evaluates falsy -- for example an empty manager,
    # should the class define ``__len__`` or ``__bool__`` (not visible here).
    if graph_ctx is None:
        graph_ctx = _make_graph_ctx()
    if sm is None:
        sm = SessionManager()
    if cb is None:
        cb = TokenUsageCallbackHandler()
    return WebSocketContext(
        graph_ctx=graph_ctx,
        session_manager=sm,
        callback_handler=cb,
        interrupt_manager=interrupt_manager,
        analytics_recorder=analytics_recorder,
        conversation_tracker=conversation_tracker,
        pool=pool,
    )
class AsyncIterHelper:
    """Expose a finite iterable through the asynchronous iterator protocol.

    Values are read through the input's own iterator, so a list handed in by
    the caller is left unmodified and any iterable (list, tuple, generator)
    is accepted.  ``None`` is treated as "no items".
    """

    def __init__(self, items):
        """Store *items* for later consumption.

        Args:
            items: Iterable of values to hand out, in order, from
                ``__anext__``; ``None`` behaves like an empty iterable.
        """
        # iter() rather than list.pop(0): no mutation of the caller's object
        # and no O(n) element shift on every step.
        self._items = iter(items if items is not None else ())

    def __aiter__(self):
        """Return ``self``; the helper is its own asynchronous iterator."""
        return self

    async def __anext__(self):
        """Return the next stored value.

        Raises:
            StopAsyncIteration: once every value has been handed out.
        """
        try:
            return next(self._items)
        except StopIteration:
            # Translate the synchronous end-of-iteration signal into the
            # asynchronous one expected by ``async for``.
            raise StopAsyncIteration from None
@pytest.mark.unit
class TestDispatchMessage:
    """How ``dispatch_message`` answers malformed and well-formed frames."""

    @staticmethod
    def _last_payload(socket: AsyncMock) -> dict:
        """Return the payload given to the most recent ``send_json`` call."""
        return socket.send_json.call_args.args[0]

    @pytest.mark.asyncio
    async def test_invalid_json(self) -> None:
        """Non-JSON text produces exactly one ``Invalid JSON`` error reply."""
        socket = _make_ws()
        context = _make_ws_ctx()

        await dispatch_message(socket, context, "not json")

        socket.send_json.assert_awaited_once()
        payload = self._last_payload(socket)
        assert payload["type"] == "error"
        assert "Invalid JSON" in payload["message"]

    @pytest.mark.asyncio
    async def test_missing_thread_id(self) -> None:
        """A frame without ``thread_id`` is answered with an error naming it."""
        socket = _make_ws()
        context = _make_ws_ctx()
        frame = {"type": "message", "content": "hello"}

        await dispatch_message(socket, context, json.dumps(frame))

        payload = self._last_payload(socket)
        assert payload["type"] == "error"
        assert "thread_id" in payload["message"]

    @pytest.mark.asyncio
    async def test_missing_content(self) -> None:
        """A ``message`` frame without ``content`` is answered with an error."""
        socket = _make_ws()
        context = _make_ws_ctx()
        frame = {"type": "message", "thread_id": "t1"}

        await dispatch_message(socket, context, json.dumps(frame))

        assert self._last_payload(socket)["type"] == "error"

    @pytest.mark.asyncio
    async def test_unknown_message_type(self) -> None:
        """An unrecognised ``type`` is answered with an ``Unknown`` error."""
        socket = _make_ws()
        context = _make_ws_ctx()
        frame = {"type": "unknown", "thread_id": "t1"}

        await dispatch_message(socket, context, json.dumps(frame))

        payload = self._last_payload(socket)
        assert payload["type"] == "error"
        assert "Unknown" in payload["message"]

    @pytest.mark.asyncio
    async def test_message_too_large(self) -> None:
        """A 40,000-character raw frame is answered with a "too large" error."""
        socket = _make_ws()
        context = _make_ws_ctx()
        oversized = "x" * 40_000

        await dispatch_message(socket, context, oversized)

        payload = self._last_payload(socket)
        assert payload["type"] == "error"
        assert "too large" in payload["message"].lower()

    @pytest.mark.asyncio
    async def test_invalid_thread_id_format(self) -> None:
        """A path-like ``thread_id`` is answered with an error mentioning it."""
        socket = _make_ws()
        context = _make_ws_ctx()
        frame = {"type": "message", "thread_id": "../../../etc", "content": "hi"}

        await dispatch_message(socket, context, json.dumps(frame))

        payload = self._last_payload(socket)
        assert payload["type"] == "error"
        assert "thread_id" in payload["message"].lower()

    @pytest.mark.asyncio
    async def test_content_too_long(self) -> None:
        """10,001 characters of content are answered with a "too long" error."""
        socket = _make_ws()
        context = _make_ws_ctx()
        frame = {"type": "message", "thread_id": "t1", "content": "x" * 10001}

        await dispatch_message(socket, context, json.dumps(frame))

        payload = self._last_payload(socket)
        assert payload["type"] == "error"
        assert "too long" in payload["message"].lower()

    @pytest.mark.asyncio
    async def test_dispatch_with_interrupt_manager(self) -> None:
        """With an interrupt manager attached, a valid message still completes."""
        socket = _make_ws()
        sessions = SessionManager()
        interrupts = InterruptManager()
        context = _make_ws_ctx(sm=sessions, interrupt_manager=interrupts)
        sessions.touch("t1")
        frame = {"type": "message", "thread_id": "t1", "content": "hello"}

        await dispatch_message(socket, context, json.dumps(frame))

        assert self._last_payload(socket)["type"] == "message_complete"
@pytest.mark.unit
class TestHandleUserMessage:
    """Scenarios that call ``handle_user_message`` directly."""

    @pytest.mark.asyncio
    async def test_expired_session(self) -> None:
        """With a zero-second TTL the repeat message gets an "expired" error."""
        socket = _make_ws()
        context = _make_graph_ctx()
        sessions = SessionManager(session_ttl_seconds=0)
        usage = TokenUsageCallbackHandler()

        # The first call creates the session (TTL=0); the second call on the
        # same thread then finds it expired.
        for text in ("hello", "hello again"):
            await handle_user_message(socket, context, sessions, usage, "t1", text)

        reply = socket.send_json.call_args.args[0]
        assert reply["type"] == "error"
        assert "expired" in reply["message"].lower()

    @pytest.mark.asyncio
    async def test_successful_message(self) -> None:
        """A message on a touched session ends with ``message_complete``."""
        socket = _make_ws()
        context = _make_graph_ctx()
        sessions = SessionManager()
        usage = TokenUsageCallbackHandler()
        sessions.touch("t1")

        await handle_user_message(socket, context, sessions, usage, "t1", "hello")

        final = socket.send_json.call_args.args[0]
        assert final["type"] == "message_complete"

    @pytest.mark.asyncio
    async def test_graph_error_sends_error_message(self) -> None:
        """A ``RuntimeError`` raised by ``astream`` is reported as an error frame."""
        socket = _make_ws()
        failing_graph = AsyncMock()
        failing_graph.astream = MagicMock(side_effect=RuntimeError("boom"))
        context = _make_graph_ctx(graph=failing_graph)
        sessions = SessionManager()
        usage = TokenUsageCallbackHandler()
        sessions.touch("t1")

        await handle_user_message(socket, context, sessions, usage, "t1", "hello")

        reply = socket.send_json.call_args.args[0]
        assert reply["type"] == "error"

    @pytest.mark.asyncio
    async def test_interrupt_registered_with_manager(self) -> None:
        """An interrupt found in graph state is registered and sent once."""
        socket = _make_ws()

        # Graph state carrying a single task with a single pending interrupt.
        pending = MagicMock(value={"action": "cancel_order", "order_id": "1042"})
        paused_task = MagicMock(interrupts=(pending,))
        paused_state = MagicMock(tasks=(paused_task,))

        paused_graph = AsyncMock()
        paused_graph.astream = MagicMock(return_value=AsyncIterHelper([]))
        paused_graph.aget_state = AsyncMock(return_value=paused_state)

        context = _make_graph_ctx(graph=paused_graph)
        sessions = SessionManager()
        usage = TokenUsageCallbackHandler()
        interrupts = InterruptManager()
        sessions.touch("t1")

        await handle_user_message(
            socket,
            context,
            sessions,
            usage,
            "t1",
            "cancel order 1042",
            interrupt_manager=interrupts,
        )

        # The manager now tracks a pending interrupt for the thread ...
        assert interrupts.has_pending("t1")

        # ... and exactly one "interrupt" frame went out over the socket.
        interrupt_frames = [
            sent.args[0]
            for sent in socket.send_json.call_args_list
            if sent.args[0].get("type") == "interrupt"
        ]
        assert len(interrupt_frames) == 1

    @pytest.mark.asyncio
    async def test_ambiguous_intent_sends_clarification(self) -> None:
        """An ambiguous classification yields one ``clarification`` frame."""
        from app.intent import ClassificationResult

        socket = _make_ws()

        quiet_graph = AsyncMock()
        quiet_graph.astream = MagicMock(return_value=AsyncIterHelper([]))
        quiet_graph.aget_state = AsyncMock(return_value=MagicMock(tasks=()))

        # Classifier double that always reports the input as ambiguous.
        ambiguous = ClassificationResult(
            intents=(),
            is_ambiguous=True,
            clarification_question="What do you mean?",
        )
        classifier = AsyncMock()
        classifier.classify = AsyncMock(return_value=ambiguous)

        agent_registry = MagicMock()
        agent_registry.list_agents = MagicMock(return_value=())

        context = GraphContext(
            graph=quiet_graph,
            registry=agent_registry,
            intent_classifier=classifier,
        )
        sessions = SessionManager()
        usage = TokenUsageCallbackHandler()
        sessions.touch("t1")

        await handle_user_message(socket, context, sessions, usage, "t1", "hmm")

        clarifications = [
            sent.args[0]
            for sent in socket.send_json.call_args_list
            if sent.args[0].get("type") == "clarification"
        ]
        assert len(clarifications) == 1
        assert clarifications[0]["message"] == "What do you mean?"
@pytest.mark.unit
class TestHandleInterruptResponse:
    """Scenarios that call ``handle_interrupt_response`` directly."""

    @pytest.mark.asyncio
    async def test_approved_interrupt(self) -> None:
        """Approving without an interrupt manager ends with ``message_complete``."""
        socket = _make_ws()
        context = _make_graph_ctx()
        sessions = SessionManager()
        usage = TokenUsageCallbackHandler()
        sessions.touch("t1")
        sessions.extend_for_interrupt("t1")

        await handle_interrupt_response(socket, context, sessions, usage, "t1", True)

        final = socket.send_json.call_args.args[0]
        assert final["type"] == "message_complete"

    @pytest.mark.asyncio
    async def test_expired_interrupt_sends_retry_prompt(self) -> None:
        """Answering after the interrupt TTL yields ``interrupt_expired``."""
        from unittest.mock import patch

        socket = _make_ws()
        context = _make_graph_ctx()
        sessions = SessionManager()
        usage = TokenUsageCallbackHandler()
        interrupts = InterruptManager(ttl_seconds=5)
        sessions.touch("t1")
        sessions.extend_for_interrupt("t1")
        interrupts.register("t1", "cancel_order", {"order_id": "1042"})

        # Move the interrupt manager's clock 10 seconds past registration,
        # which is beyond the 5-second TTL configured above.
        with patch("app.interrupt_manager.time") as fake_clock:
            registered_at = interrupts._interrupts["t1"].created_at
            fake_clock.time.return_value = registered_at + 10
            await handle_interrupt_response(
                socket,
                context,
                sessions,
                usage,
                "t1",
                True,
                interrupt_manager=interrupts,
            )

        reply = socket.send_json.call_args.args[0]
        assert reply["type"] == "interrupt_expired"
        assert "cancel_order" in reply["message"]

    @pytest.mark.asyncio
    async def test_valid_interrupt_resolves(self) -> None:
        """Answering within the TTL clears the pending interrupt and completes."""
        socket = _make_ws()
        context = _make_graph_ctx()
        sessions = SessionManager()
        usage = TokenUsageCallbackHandler()
        interrupts = InterruptManager(ttl_seconds=1800)
        sessions.touch("t1")
        sessions.extend_for_interrupt("t1")
        interrupts.register("t1", "cancel_order", {})

        await handle_interrupt_response(
            socket,
            context,
            sessions,
            usage,
            "t1",
            True,
            interrupt_manager=interrupts,
        )

        # Nothing should remain pending for the thread afterwards.
        assert not interrupts.has_pending("t1")
        final = socket.send_json.call_args.args[0]
        assert final["type"] == "message_complete"
@pytest.mark.unit
class TestInterruptHelpers:
    """Direct checks of ``_has_interrupt`` and ``_extract_interrupt``."""

    @staticmethod
    def _state_with(*payloads) -> MagicMock:
        """Build a state mock holding one single-interrupt task per payload.

        Called with no payloads, the state's ``tasks`` is an empty tuple.
        """
        state = MagicMock()
        state.tasks = tuple(
            MagicMock(interrupts=(MagicMock(value=payload),))
            for payload in payloads
        )
        return state

    def test_has_interrupt_false_for_empty_tasks(self) -> None:
        """A state without tasks reports no interrupt."""
        assert not _has_interrupt(self._state_with())

    def test_has_interrupt_true(self) -> None:
        """A task carrying an interrupt makes the state report one."""
        assert _has_interrupt(self._state_with({"action": "cancel"}))

    def test_extract_interrupt_data(self) -> None:
        """The extracted data exposes the interrupt's ``action``."""
        state = self._state_with({"action": "cancel_order", "order_id": "1042"})
        extracted = _extract_interrupt(state)
        assert extracted["action"] == "cancel_order"

    def test_extract_interrupt_empty(self) -> None:
        """With no tasks the extracted ``action`` is ``"unknown"``."""
        extracted = _extract_interrupt(self._state_with())
        assert extracted["action"] == "unknown"
@pytest.mark.unit
class TestDispatchMessageWithTracking:
    """``dispatch_message`` with the optional tracker/recorder collaborators."""

    @staticmethod
    def _hello_frame() -> str:
        """Return the serialized "hello" message for thread ``t1``."""
        return json.dumps({"type": "message", "thread_id": "t1", "content": "hello"})

    @pytest.mark.asyncio
    async def test_conversation_tracker_called_on_message(self) -> None:
        """The tracker is asked to ensure the conversation and record a turn."""
        socket = _make_ws()
        sessions = SessionManager()
        tracker = AsyncMock()
        db_pool = MagicMock()
        context = _make_ws_ctx(
            sm=sessions, conversation_tracker=tracker, pool=db_pool
        )
        sessions.touch("t1")

        await dispatch_message(socket, context, self._hello_frame())

        tracker.ensure_conversation.assert_awaited_once_with(db_pool, "t1")
        tracker.record_turn.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_analytics_recorder_called_on_message(self) -> None:
        """The analytics recorder's ``record`` is awaited exactly once."""
        socket = _make_ws()
        sessions = SessionManager()
        recorder = AsyncMock()
        db_pool = MagicMock()
        context = _make_ws_ctx(
            sm=sessions, analytics_recorder=recorder, pool=db_pool
        )
        sessions.touch("t1")

        await dispatch_message(socket, context, self._hello_frame())

        recorder.record.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_tracker_failure_does_not_break_chat(self) -> None:
        """A tracker that raises does not stop the message from completing."""
        socket = _make_ws()
        sessions = SessionManager()
        tracker = AsyncMock()
        tracker.ensure_conversation.side_effect = RuntimeError("DB down")
        db_pool = MagicMock()
        context = _make_ws_ctx(
            sm=sessions, conversation_tracker=tracker, pool=db_pool
        )
        sessions.touch("t1")

        # Must not raise even though the tracker fails.
        await dispatch_message(socket, context, self._hello_frame())

        final = socket.send_json.call_args.args[0]
        assert final["type"] == "message_complete"

    @pytest.mark.asyncio
    async def test_no_tracker_no_error(self) -> None:
        """Without any tracker or recorder the message still completes."""
        socket = _make_ws()
        sessions = SessionManager()
        context = _make_ws_ctx(sm=sessions)
        sessions.touch("t1")

        # Neither tracker nor recorder is supplied here.
        await dispatch_message(socket, context, self._hello_frame())

        final = socket.send_json.call_args.args[0]
        assert final["type"] == "message_complete"