Files
smart-support/backend/tests/integration/test_websocket.py
Yaojia Wang 0e78e5b06b feat: complete phase 5 -- error hardening, frontend, Docker, demo, docs
Backend:
- ConversationTracker: Protocol + PostgresConversationTracker for lifecycle tracking
- Error handler: ErrorCategory enum, classify_error(), with_retry() exponential backoff
- Wire PostgresAnalyticsRecorder + ConversationTracker into ws_handler
- Rate limiting (10 msg/10s per thread), edge case hardening
- Health endpoint GET /api/health, version 0.5.0
- Demo seed data script + sample OpenAPI spec

Frontend (all new):
- React Router with NavBar (Chat / Replay / Dashboard / Review)
- ReplayListPage + ReplayPage with ReplayTimeline component
- DashboardPage with MetricCard, range selector, zero-state
- ReviewPage for OpenAPI classification review
- ErrorBanner for WebSocket disconnect handling
- API client (api.ts) with typed fetch wrappers

Infrastructure:
- Frontend Dockerfile (multi-stage node -> nginx)
- nginx.conf with SPA routing + API/WS proxy
- docker-compose.yml with frontend service + healthchecks
- .env.example files (root + backend)

Documentation:
- README.md with quick start and architecture
- Agent configuration guide
- OpenAPI import guide
- Deployment guide
- Demo script

48 new tests, 449 total passing, 92.87% coverage
2026-03-31 21:20:06 +02:00

349 lines
12 KiB
Python

"""Integration tests for WebSocket message flow.
These tests exercise dispatch_message end-to-end with a mocked LangGraph
graph, verifying streaming, interrupt approval/rejection, session TTL,
and interrupt TTL expiration through the full message handling pipeline.
"""
from __future__ import annotations
import json
import time
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from app.callbacks import TokenUsageCallbackHandler
from app.interrupt_manager import InterruptManager
from app.session_manager import SessionManager
from app.ws_handler import dispatch_message
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
class AsyncIterHelper:
    """Expose a finite list of items through the async-iterator protocol.

    The items are copied at construction time, so the caller's list is never
    mutated and later changes to it are not seen by the iterator. The helper
    is single-pass: once exhausted it keeps raising ``StopAsyncIteration``.
    """

    def __init__(self, items: list) -> None:
        """Snapshot *items*; iteration yields them in their original order."""
        self._items = list(items)
        # Cursor into the snapshot. Advancing an index is O(1) per step,
        # whereas popping from the front of a list shifts every element.
        self._pos = 0

    def __aiter__(self):
        """Return self, since the helper is its own iterator."""
        return self

    async def __anext__(self):
        """Return the next item, or raise ``StopAsyncIteration`` when done."""
        if self._pos >= len(self._items):
            raise StopAsyncIteration
        item = self._items[self._pos]
        self._pos += 1
        return item
class FakeWS:
    """In-memory WebSocket stand-in that captures outgoing JSON payloads.

    Every payload passed to ``send_json`` is appended to ``sent`` so tests can
    inspect what the handler would have written to the client.
    """

    def __init__(self) -> None:
        """Start with an empty log of sent payloads."""
        self.sent: list[dict] = []

    async def send_json(self, data: dict) -> None:
        """Record *data* in ``sent`` instead of writing to a real socket."""
        log = self.sent
        log.append(data)
def _chunk(content: str, node: str = "order_lookup") -> tuple:
    """Build a ``(message, metadata)`` pair imitating a streamed text chunk.

    The message mock carries *content* and an empty ``tool_calls`` list; the
    metadata dict names the emitting node under ``"langgraph_node"``.
    """
    message = MagicMock(content=content, tool_calls=[])
    metadata = {"langgraph_node": node}
    return message, metadata
def _tool_chunk(name: str, args: dict, node: str = "order_lookup") -> tuple:
    """Build a ``(message, metadata)`` pair imitating a streamed tool call.

    The message mock has empty text content and a single ``tool_calls`` entry
    holding *name* and *args*; the metadata dict names the emitting node.
    """
    call = {"name": name, "args": args}
    message = MagicMock(content="", tool_calls=[call])
    return message, {"langgraph_node": node}
def _state(*, interrupt: bool = False, data: dict | None = None) -> Any:
s = MagicMock()
if interrupt:
obj = MagicMock()
obj.value = data or {"action": "cancel_order", "order_id": "1042"}
t = MagicMock()
t.interrupts = (obj,)
s.tasks = (t,)
else:
s.tasks = ()
return s
def _graph(
chunks: list | None = None,
st: Any = None,
resume_chunks: list | None = None,
) -> MagicMock:
g = MagicMock()
g.intent_classifier = None
g.agent_registry = None
if st is None:
st = _state()
streams = [chunks or [], resume_chunks or []]
idx = {"n": 0}
def make_stream(*a, **kw):
i = min(idx["n"], len(streams) - 1)
idx["n"] += 1
return AsyncIterHelper(list(streams[i]))
g.astream = MagicMock(side_effect=make_stream)
g.aget_state = AsyncMock(return_value=st)
return g
def _setup(
    graph=None,
    session_ttl: int = 1800,
    interrupt_ttl: int = 1800,
    thread_id: str = "t1",
    touch: bool = True,
):
    """Assemble the collaborators a ``dispatch_message`` call needs.

    Returns ``(graph, session_manager, interrupt_manager, callback, ws)``.
    A default mock graph is built when *graph* is not supplied. Unless
    *touch* is false, the session for *thread_id* is touched up front.
    """
    fake_graph = graph if graph else _graph()
    sessions = SessionManager(session_ttl_seconds=session_ttl)
    interrupts = InterruptManager(ttl_seconds=interrupt_ttl)
    usage_callback = TokenUsageCallbackHandler()
    socket = FakeWS()

    if touch:
        sessions.touch(thread_id)

    return fake_graph, sessions, interrupts, usage_callback, socket
async def _send(
    ws,
    g,
    sm,
    im,
    cb,
    *,
    thread_id="t1",
    content="hello",
    msg_type="message",
):
    """Serialize a client frame to JSON and pass it to ``dispatch_message``.

    The frame carries ``type``, ``thread_id`` and ``content``; *im* is
    forwarded as the ``interrupt_manager`` keyword argument.
    """
    frame = {"type": msg_type, "thread_id": thread_id, "content": content}
    payload = json.dumps(frame)
    await dispatch_message(ws, g, sm, cb, payload, interrupt_manager=im)
async def _respond(ws, g, sm, im, cb, *, thread_id="t1", approved=True):
    """Send an ``interrupt_response`` frame through ``dispatch_message``.

    The frame carries ``type``, ``thread_id`` and the *approved* flag; *im*
    is forwarded as the ``interrupt_manager`` keyword argument.
    """
    frame = {
        "type": "interrupt_response",
        "thread_id": thread_id,
        "approved": approved,
    }
    payload = json.dumps(frame)
    await dispatch_message(ws, g, sm, cb, payload, interrupt_manager=im)
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
@pytest.mark.integration
class TestWebSocketHappyPath:
    """Message flows that stream to completion without any interrupt."""

    @pytest.mark.asyncio
    async def test_send_message_receives_tokens_and_complete(self) -> None:
        """Two text chunks yield two token frames and one message_complete."""
        scripted = [_chunk("Order 1042 is "), _chunk("shipped.")]
        g, sm, im, cb, ws = _setup(graph=_graph(chunks=scripted))

        await _send(ws, g, sm, im, cb, content="What is the status of order 1042?")

        token_frames = [frame for frame in ws.sent if frame["type"] == "token"]
        assert len(token_frames) == 2
        first, second = token_frames
        assert first["content"] == "Order 1042 is "
        assert first["agent"] == "order_lookup"
        assert second["content"] == "shipped."

        done_frames = [frame for frame in ws.sent if frame["type"] == "message_complete"]
        assert len(done_frames) == 1
        assert done_frames[0]["thread_id"] == "t1"

    @pytest.mark.asyncio
    async def test_tool_call_streamed(self) -> None:
        """A chunk carrying a tool call is forwarded as one tool_call frame."""
        scripted = [
            _tool_chunk("get_order_status", {"order_id": "1042"}),
            _chunk("Order shipped."),
        ]
        g, sm, im, cb, ws = _setup(graph=_graph(chunks=scripted))

        await _send(ws, g, sm, im, cb, content="Check order 1042")

        tool_frames = [frame for frame in ws.sent if frame["type"] == "tool_call"]
        assert len(tool_frames) == 1
        only_call = tool_frames[0]
        assert only_call["tool"] == "get_order_status"
        assert only_call["args"] == {"order_id": "1042"}

    @pytest.mark.asyncio
    async def test_multiple_messages_same_session(self) -> None:
        """Three messages on one thread each produce a message_complete."""
        g, sm, im, cb, ws = _setup()

        for n in range(3):
            await _send(ws, g, sm, im, cb, content=f"msg {n}")

        completed = sum(frame["type"] == "message_complete" for frame in ws.sent)
        assert completed == 3
@pytest.mark.integration
class TestWebSocketInterruptApproval:
    """Interrupt round-trips: the graph pauses, the client approves or rejects."""

    @pytest.mark.asyncio
    async def test_interrupt_then_approve(self) -> None:
        """Approving a pending interrupt streams the resume output and clears it."""
        st_int = _state(interrupt=True, data={"action": "cancel_order", "order_id": "1042"})
        resume = [_chunk("Order 1042 cancelled.", "order_actions")]
        g, sm, im, cb, ws = _setup(graph=_graph(chunks=[], st=st_int, resume_chunks=resume))

        # Send message -> triggers interrupt
        await _send(ws, g, sm, im, cb, content="Cancel order 1042")
        interrupts = [m for m in ws.sent if m["type"] == "interrupt"]
        assert len(interrupts) == 1
        assert interrupts[0]["action"] == "cancel_order"
        assert interrupts[0]["thread_id"] == "t1"
        assert im.has_pending("t1")

        # Approve; clear the log first so only resume-phase frames are inspected.
        ws.sent.clear()
        await _respond(ws, g, sm, im, cb, approved=True)
        tokens = [m for m in ws.sent if m["type"] == "token"]
        assert len(tokens) == 1
        assert "cancelled" in tokens[0]["content"]
        completes = [m for m in ws.sent if m["type"] == "message_complete"]
        assert len(completes) == 1
        assert not im.has_pending("t1")

    @pytest.mark.asyncio
    async def test_interrupt_then_reject(self) -> None:
        """Rejecting a pending interrupt still streams the resume output."""
        st_int = _state(interrupt=True)
        resume = [_chunk("Order remains active.", "order_actions")]
        g, sm, im, cb, ws = _setup(graph=_graph(chunks=[], st=st_int, resume_chunks=resume))

        await _send(ws, g, sm, im, cb, content="Cancel order 1042")

        # Clear the log so only frames produced by the rejection are inspected.
        ws.sent.clear()
        await _respond(ws, g, sm, im, cb, approved=False)
        tokens = [m for m in ws.sent if m["type"] == "token"]
        # Guard the index below so a regression fails as an assertion
        # rather than an IndexError.
        assert tokens, "expected at least one token frame after rejection"
        assert "remains active" in tokens[0]["content"]
@pytest.mark.integration
class TestWebSocketSessionTTL:
    """Session expiry and activity tracking as seen through dispatch_message."""

    @pytest.mark.asyncio
    async def test_expired_session_returns_error(self) -> None:
        """With a zero TTL the pre-touched session is reported as expired."""
        # _setup touches the session, but a TTL of 0 means it is already stale.
        g, sm, im, cb, ws = _setup(session_ttl=0)

        await _send(ws, g, sm, im, cb, content="hello")

        first_frame = ws.sent[0]
        assert first_frame["type"] == "error"
        assert "expired" in first_frame["message"].lower()

    @pytest.mark.asyncio
    async def test_new_session_not_expired(self) -> None:
        """A freshly touched session with a long TTL completes normally."""
        g, sm, im, cb, ws = _setup(session_ttl=3600)

        await _send(ws, g, sm, im, cb, content="hello")

        done_frames = [frame for frame in ws.sent if frame["type"] == "message_complete"]
        assert len(done_frames) == 1

    @pytest.mark.asyncio
    async def test_sliding_window_resets_on_message(self) -> None:
        """A later message moves the session's last_activity forward."""
        g, sm, im, cb, ws = _setup(session_ttl=3600)

        await _send(ws, g, sm, im, cb, content="hello")
        earlier = sm.get_state("t1").last_activity

        # Brief pause so the two activity readings can differ.
        time.sleep(0.01)

        await _send(ws, g, sm, im, cb, content="hello again")
        later = sm.get_state("t1").last_activity

        assert later > earlier

    @pytest.mark.asyncio
    async def test_interrupt_extends_session_ttl(self) -> None:
        """After an interrupt the session is flagged pending and not expired."""
        pending_state = _state(interrupt=True)
        g, sm, im, cb, ws = _setup(
            graph=_graph(chunks=[], st=pending_state),
            session_ttl=3600,
        )

        await _send(ws, g, sm, im, cb, content="cancel order")

        session = sm.get_state("t1")
        assert session is not None
        assert session.has_pending_interrupt
        assert not sm.is_expired("t1")
@pytest.mark.integration
class TestWebSocketValidation:
    """Malformed or oversized frames are answered with an error frame."""

    @pytest.mark.asyncio
    async def test_invalid_json(self) -> None:
        """A non-JSON frame produces an 'Invalid JSON' error."""
        g, sm, im, cb, ws = _setup()

        await dispatch_message(ws, g, sm, cb, "not json", interrupt_manager=im)

        reply = ws.sent[0]
        assert reply["type"] == "error"
        assert "Invalid JSON" in reply["message"]

    @pytest.mark.asyncio
    async def test_missing_thread_id(self) -> None:
        """A frame without thread_id produces an error naming that field."""
        g, sm, im, cb, ws = _setup()
        payload = json.dumps({"type": "message", "content": "hi"})

        await dispatch_message(ws, g, sm, cb, payload, interrupt_manager=im)

        reply = ws.sent[0]
        assert reply["type"] == "error"
        assert "thread_id" in reply["message"]

    @pytest.mark.asyncio
    async def test_invalid_thread_id_format(self) -> None:
        """A path-like thread_id is answered with an error frame."""
        g, sm, im, cb, ws = _setup()
        frame = {"type": "message", "thread_id": "../../etc/passwd", "content": "hi"}
        payload = json.dumps(frame)

        await dispatch_message(ws, g, sm, cb, payload, interrupt_manager=im)

        assert ws.sent[0]["type"] == "error"

    @pytest.mark.asyncio
    async def test_missing_content(self) -> None:
        """A message frame without content is answered with an error frame."""
        g, sm, im, cb, ws = _setup()
        payload = json.dumps({"type": "message", "thread_id": "t1"})

        await dispatch_message(ws, g, sm, cb, payload, interrupt_manager=im)

        assert ws.sent[0]["type"] == "error"

    @pytest.mark.asyncio
    async def test_unknown_message_type(self) -> None:
        """An unrecognised frame type produces an 'Unknown' error."""
        g, sm, im, cb, ws = _setup()
        payload = json.dumps({"type": "foobar", "thread_id": "t1"})

        await dispatch_message(ws, g, sm, cb, payload, interrupt_manager=im)

        reply = ws.sent[0]
        assert reply["type"] == "error"
        assert "Unknown" in reply["message"]

    @pytest.mark.asyncio
    async def test_message_too_large(self) -> None:
        """A 40,000-character raw frame produces a 'too large' error."""
        g, sm, im, cb, ws = _setup()
        oversized = "x" * 40_000

        await dispatch_message(ws, g, sm, cb, oversized, interrupt_manager=im)

        reply = ws.sent[0]
        assert reply["type"] == "error"
        assert "too large" in reply["message"].lower()

    @pytest.mark.asyncio
    async def test_content_too_long(self) -> None:
        """A 10,001-character content field produces a 'too long' error."""
        g, sm, im, cb, ws = _setup()
        frame = {"type": "message", "thread_id": "t1", "content": "x" * 10001}
        payload = json.dumps(frame)

        await dispatch_message(ws, g, sm, cb, payload, interrupt_manager=im)

        reply = ws.sent[0]
        assert reply["type"] == "error"
        assert "too long" in reply["message"].lower()
@pytest.mark.integration
class TestWebSocketInterruptTTL:
    """Responses that arrive after the interrupt's TTL has elapsed."""

    @pytest.mark.asyncio
    async def test_expired_interrupt_sends_retry_prompt(self) -> None:
        """Approving a stale interrupt yields an interrupt_expired frame."""
        pending_state = _state(
            interrupt=True,
            data={"action": "cancel_order", "order_id": "1042"},
        )
        g, sm, im, cb, ws = _setup(
            graph=_graph(chunks=[], st=pending_state),
            interrupt_ttl=5,
        )

        # Raise the interrupt.
        await _send(ws, g, sm, im, cb, content="Cancel order 1042")
        raised = [frame for frame in ws.sent if frame["type"] == "interrupt"]
        assert len(raised) == 1

        # Move the interrupt manager's clock 10 seconds past the record's
        # creation time, which exceeds the 5-second TTL configured above.
        record = im._interrupts["t1"]
        ws.sent.clear()
        with patch("app.interrupt_manager.time") as fake_clock:
            fake_clock.time.return_value = record.created_at + 10
            await _respond(ws, g, sm, im, cb, approved=True)

        reply = ws.sent[0]
        assert reply["type"] == "interrupt_expired"
        assert "cancel_order" in reply["message"]
        assert reply["thread_id"] == "t1"