From 6e7b824b644566e725a660cb6dae91ea13942cb1 Mon Sep 17 00:00:00 2001 From: Yaojia Wang Date: Mon, 30 Mar 2026 21:24:31 +0200 Subject: [PATCH] test: add integration tests for WebSocket message flow 17 integration tests covering: - Happy path: token streaming, tool calls, multi-message sessions - Interrupt flow: approve and reject paths with manager tracking - Session TTL: expiration, sliding window reset, interrupt extension - Validation: invalid JSON, missing fields, size limits - Interrupt TTL: expired interrupt sends retry prompt Fills Phase 1 test gap for integration-level WebSocket coverage. Total: 170 tests, 92.15% coverage. --- backend/tests/integration/test_websocket.py | 347 ++++++++++++++++++++ 1 file changed, 347 insertions(+) create mode 100644 backend/tests/integration/test_websocket.py diff --git a/backend/tests/integration/test_websocket.py b/backend/tests/integration/test_websocket.py new file mode 100644 index 0000000..74222dd --- /dev/null +++ b/backend/tests/integration/test_websocket.py @@ -0,0 +1,347 @@ +"""Integration tests for WebSocket message flow. + +These tests exercise dispatch_message end-to-end with a mocked LangGraph +graph, verifying streaming, interrupt approval/rejection, session TTL, +and interrupt TTL expiration through the full message handling pipeline. +""" + +from __future__ import annotations + +import json +import time +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from app.callbacks import TokenUsageCallbackHandler +from app.interrupt_manager import InterruptManager +from app.session_manager import SessionManager +from app.ws_handler import dispatch_message + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +class AsyncIterHelper: + """Make a list behave as an async iterator.""" + + def __init__(self, items: list) -> None: + self._items = list(items) + + def __aiter__(self): + return self + + async def __anext__(self): + if not self._items: + raise StopAsyncIteration + return self._items.pop(0) + + +class FakeWS: + """Fake WebSocket that records sent messages.""" + + def __init__(self) -> None: + self.sent: list[dict] = [] + + async def send_json(self, data: dict) -> None: + self.sent.append(data) + + +def _chunk(content: str, node: str = "order_lookup") -> tuple: + c = MagicMock() + c.content = content + c.tool_calls = [] + return (c, {"langgraph_node": node}) + + +def _tool_chunk(name: str, args: dict, node: str = "order_lookup") -> tuple: + c = MagicMock() + c.content = "" + c.tool_calls = [{"name": name, "args": args}] + return (c, {"langgraph_node": node}) + + +def _state(*, interrupt: bool = False, data: dict | None = None) -> Any: + s = MagicMock() + if interrupt: + obj = MagicMock() + obj.value = data or {"action": "cancel_order", "order_id": "1042"} + t = MagicMock() + t.interrupts = (obj,) + s.tasks = (t,) + else: + s.tasks = () + return s + + +def _graph( + chunks: list | None = None, + st: Any = None, + resume_chunks: list | None = None, +) -> MagicMock: + g = MagicMock() + g.intent_classifier = None + g.agent_registry = None + + if st is None: + st = _state() + + streams = [chunks or [], resume_chunks or []] + idx = {"n": 0} + + def make_stream(*a, **kw): + i = min(idx["n"], len(streams) - 1) + idx["n"] += 1 + return AsyncIterHelper(list(streams[i])) + + g.astream = MagicMock(side_effect=make_stream) + g.aget_state = AsyncMock(return_value=st) + return g + + +def _setup( + graph=None, + session_ttl: int = 1800, + interrupt_ttl: int = 1800, + thread_id: str = "t1", + touch: bool = True, +): + """Create test dependencies. Pre-touches session by default.""" + g = graph or _graph() + sm = SessionManager(session_ttl_seconds=session_ttl) + im = InterruptManager(ttl_seconds=interrupt_ttl) + cb = TokenUsageCallbackHandler() + ws = FakeWS() + if touch: + sm.touch(thread_id) + return g, sm, im, cb, ws + + +async def _send(ws, g, sm, im, cb, *, thread_id="t1", content="hello", msg_type="message"): + raw = json.dumps({"type": msg_type, "thread_id": thread_id, "content": content}) + await dispatch_message(ws, g, sm, cb, raw, interrupt_manager=im) + + +async def _respond(ws, g, sm, im, cb, *, thread_id="t1", approved=True): + raw = json.dumps({"type": "interrupt_response", "thread_id": thread_id, "approved": approved}) + await dispatch_message(ws, g, sm, cb, raw, interrupt_manager=im) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + +@pytest.mark.integration +class TestWebSocketHappyPath: + @pytest.mark.asyncio + async def test_send_message_receives_tokens_and_complete(self) -> None: + g, sm, im, cb, ws = _setup( + graph=_graph(chunks=[_chunk("Order 1042 is "), _chunk("shipped.")]) + ) + await _send(ws, g, sm, im, cb, content="What is the status of order 1042?") + + tokens = [m for m in ws.sent if m["type"] == "token"] + assert len(tokens) == 2 + assert tokens[0]["content"] == "Order 1042 is " + assert tokens[0]["agent"] == "order_lookup" + assert tokens[1]["content"] == "shipped." + + completes = [m for m in ws.sent if m["type"] == "message_complete"] + assert len(completes) == 1 + assert completes[0]["thread_id"] == "t1" + + @pytest.mark.asyncio + async def test_tool_call_streamed(self) -> None: + g, sm, im, cb, ws = _setup( + graph=_graph(chunks=[ + _tool_chunk("get_order_status", {"order_id": "1042"}), + _chunk("Order shipped."), + ]) + ) + await _send(ws, g, sm, im, cb, content="Check order 1042") + + tools = [m for m in ws.sent if m["type"] == "tool_call"] + assert len(tools) == 1 + assert tools[0]["tool"] == "get_order_status" + assert tools[0]["args"] == {"order_id": "1042"} + + @pytest.mark.asyncio + async def test_multiple_messages_same_session(self) -> None: + g, sm, im, cb, ws = _setup() + for i in range(3): + await _send(ws, g, sm, im, cb, content=f"msg {i}") + + completes = [m for m in ws.sent if m["type"] == "message_complete"] + assert len(completes) == 3 + + +@pytest.mark.integration +class TestWebSocketInterruptApproval: + @pytest.mark.asyncio + async def test_interrupt_then_approve(self) -> None: + st_int = _state(interrupt=True, data={"action": "cancel_order", "order_id": "1042"}) + g = _graph(chunks=[], st=st_int, resume_chunks=[_chunk("Order 1042 cancelled.", "order_actions")]) + g_, sm, im, cb, ws = _setup(graph=g) + + # Send message -> triggers interrupt + await _send(ws, g_, sm, im, cb, content="Cancel order 1042") + + interrupts = [m for m in ws.sent if m["type"] == "interrupt"] + assert len(interrupts) == 1 + assert interrupts[0]["action"] == "cancel_order" + assert interrupts[0]["thread_id"] == "t1" + assert im.has_pending("t1") + + # Approve + ws.sent.clear() + await _respond(ws, g_, sm, im, cb, approved=True) + + tokens = [m for m in ws.sent if m["type"] == "token"] + assert len(tokens) == 1 + assert "cancelled" in tokens[0]["content"] + + completes = [m for m in ws.sent if m["type"] == "message_complete"] + assert len(completes) == 1 + assert not im.has_pending("t1") + + @pytest.mark.asyncio + async def test_interrupt_then_reject(self) -> None: + st_int = _state(interrupt=True) + g = _graph(chunks=[], st=st_int, resume_chunks=[_chunk("Order remains active.", "order_actions")]) + g_, sm, im, cb, ws = _setup(graph=g) + + await _send(ws, g_, sm, im, cb, content="Cancel order 1042") + ws.sent.clear() + + await _respond(ws, g_, sm, im, cb, approved=False) + + tokens = [m for m in ws.sent if m["type"] == "token"] + assert "remains active" in tokens[0]["content"] + + +@pytest.mark.integration +class TestWebSocketSessionTTL: + @pytest.mark.asyncio + async def test_expired_session_returns_error(self) -> None: + g, sm, im, cb, ws = _setup(session_ttl=0) + # Session was touched in _setup, but TTL is 0 so it's already expired + await _send(ws, g, sm, im, cb, content="hello") + assert ws.sent[0]["type"] == "error" + assert "expired" in ws.sent[0]["message"].lower() + + @pytest.mark.asyncio + async def test_new_session_not_expired(self) -> None: + g, sm, im, cb, ws = _setup(session_ttl=3600) + await _send(ws, g, sm, im, cb, content="hello") + completes = [m for m in ws.sent if m["type"] == "message_complete"] + assert len(completes) == 1 + + @pytest.mark.asyncio + async def test_sliding_window_resets_on_message(self) -> None: + g, sm, im, cb, ws = _setup(session_ttl=3600) + + await _send(ws, g, sm, im, cb, content="hello") + first_activity = sm.get_state("t1").last_activity + + time.sleep(0.01) + await _send(ws, g, sm, im, cb, content="hello again") + second_activity = sm.get_state("t1").last_activity + + assert second_activity > first_activity + + @pytest.mark.asyncio + async def test_interrupt_extends_session_ttl(self) -> None: + st_int = _state(interrupt=True) + g = _graph(chunks=[], st=st_int) + g_, sm, im, cb, ws = _setup(graph=g, session_ttl=3600) + + await _send(ws, g_, sm, im, cb, content="cancel order") + + state = sm.get_state("t1") + assert state is not None + assert state.has_pending_interrupt + assert not sm.is_expired("t1") + + +@pytest.mark.integration +class TestWebSocketValidation: + @pytest.mark.asyncio + async def test_invalid_json(self) -> None: + g, sm, im, cb, ws = _setup() + await dispatch_message(ws, g, sm, cb, "not json", interrupt_manager=im) + assert ws.sent[0]["type"] == "error" + assert "Invalid JSON" in ws.sent[0]["message"] + + @pytest.mark.asyncio + async def test_missing_thread_id(self) -> None: + g, sm, im, cb, ws = _setup() + raw = json.dumps({"type": "message", "content": "hi"}) + await dispatch_message(ws, g, sm, cb, raw, interrupt_manager=im) + assert ws.sent[0]["type"] == "error" + assert "thread_id" in ws.sent[0]["message"] + + @pytest.mark.asyncio + async def test_invalid_thread_id_format(self) -> None: + g, sm, im, cb, ws = _setup() + raw = json.dumps({"type": "message", "thread_id": "../../etc/passwd", "content": "hi"}) + await dispatch_message(ws, g, sm, cb, raw, interrupt_manager=im) + assert ws.sent[0]["type"] == "error" + + @pytest.mark.asyncio + async def test_missing_content(self) -> None: + g, sm, im, cb, ws = _setup() + raw = json.dumps({"type": "message", "thread_id": "t1"}) + await dispatch_message(ws, g, sm, cb, raw, interrupt_manager=im) + assert ws.sent[0]["type"] == "error" + + @pytest.mark.asyncio + async def test_unknown_message_type(self) -> None: + g, sm, im, cb, ws = _setup() + raw = json.dumps({"type": "foobar", "thread_id": "t1"}) + await dispatch_message(ws, g, sm, cb, raw, interrupt_manager=im) + assert ws.sent[0]["type"] == "error" + assert "Unknown" in ws.sent[0]["message"] + + @pytest.mark.asyncio + async def test_message_too_large(self) -> None: + g, sm, im, cb, ws = _setup() + await dispatch_message(ws, g, sm, cb, "x" * 40_000, interrupt_manager=im) + assert ws.sent[0]["type"] == "error" + assert "too large" in ws.sent[0]["message"].lower() + + @pytest.mark.asyncio + async def test_content_too_long(self) -> None: + g, sm, im, cb, ws = _setup() + raw = json.dumps({"type": "message", "thread_id": "t1", "content": "x" * 9000}) + await dispatch_message(ws, g, sm, cb, raw, interrupt_manager=im) + assert ws.sent[0]["type"] == "error" + assert "too long" in ws.sent[0]["message"].lower() + + +@pytest.mark.integration +class TestWebSocketInterruptTTL: + @pytest.mark.asyncio + async def test_expired_interrupt_sends_retry_prompt(self) -> None: + st_int = _state(interrupt=True, data={"action": "cancel_order", "order_id": "1042"}) + g = _graph(chunks=[], st=st_int) + g_, sm, im, cb, ws = _setup(graph=g, interrupt_ttl=5) + + # Trigger interrupt + await _send(ws, g_, sm, im, cb, content="Cancel order 1042") + + interrupts = [m for m in ws.sent if m["type"] == "interrupt"] + assert len(interrupts) == 1 + + # Expire the interrupt + record = im._interrupts["t1"] + ws.sent.clear() + + with patch("app.interrupt_manager.time") as mock_time: + mock_time.time.return_value = record.created_at + 10 + await _respond(ws, g_, sm, im, cb, approved=True) + + assert ws.sent[0]["type"] == "interrupt_expired" + assert "cancel_order" in ws.sent[0]["message"] + assert ws.sent[0]["thread_id"] == "t1"