"""Integration tests for WebSocket message flow. These tests exercise dispatch_message end-to-end with a mocked LangGraph graph, verifying streaming, interrupt approval/rejection, session TTL, and interrupt TTL expiration through the full message handling pipeline. """ from __future__ import annotations import json import time from typing import Any from unittest.mock import AsyncMock, MagicMock, patch import pytest from app.callbacks import TokenUsageCallbackHandler from app.interrupt_manager import InterruptManager from app.session_manager import SessionManager from app.ws_handler import dispatch_message # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- class AsyncIterHelper: """Make a list behave as an async iterator.""" def __init__(self, items: list) -> None: self._items = list(items) def __aiter__(self): return self async def __anext__(self): if not self._items: raise StopAsyncIteration return self._items.pop(0) class FakeWS: """Fake WebSocket that records sent messages.""" def __init__(self) -> None: self.sent: list[dict] = [] async def send_json(self, data: dict) -> None: self.sent.append(data) def _chunk(content: str, node: str = "order_lookup") -> tuple: c = MagicMock() c.content = content c.tool_calls = [] return (c, {"langgraph_node": node}) def _tool_chunk(name: str, args: dict, node: str = "order_lookup") -> tuple: c = MagicMock() c.content = "" c.tool_calls = [{"name": name, "args": args}] return (c, {"langgraph_node": node}) def _state(*, interrupt: bool = False, data: dict | None = None) -> Any: s = MagicMock() if interrupt: obj = MagicMock() obj.value = data or {"action": "cancel_order", "order_id": "1042"} t = MagicMock() t.interrupts = (obj,) s.tasks = (t,) else: s.tasks = () return s def _graph( chunks: list | None = None, st: Any = None, resume_chunks: list | None = None, ) -> MagicMock: g = MagicMock() g.intent_classifier = None g.agent_registry = None if st is None: st = _state() streams = [chunks or [], resume_chunks or []] idx = {"n": 0} def make_stream(*a, **kw): i = min(idx["n"], len(streams) - 1) idx["n"] += 1 return AsyncIterHelper(list(streams[i])) g.astream = MagicMock(side_effect=make_stream) g.aget_state = AsyncMock(return_value=st) return g def _setup( graph=None, session_ttl: int = 1800, interrupt_ttl: int = 1800, thread_id: str = "t1", touch: bool = True, ): """Create test dependencies. Pre-touches session by default.""" g = graph or _graph() sm = SessionManager(session_ttl_seconds=session_ttl) im = InterruptManager(ttl_seconds=interrupt_ttl) cb = TokenUsageCallbackHandler() ws = FakeWS() if touch: sm.touch(thread_id) return g, sm, im, cb, ws async def _send(ws, g, sm, im, cb, *, thread_id="t1", content="hello", msg_type="message"): raw = json.dumps({"type": msg_type, "thread_id": thread_id, "content": content}) await dispatch_message(ws, g, sm, cb, raw, interrupt_manager=im) async def _respond(ws, g, sm, im, cb, *, thread_id="t1", approved=True): raw = json.dumps({"type": "interrupt_response", "thread_id": thread_id, "approved": approved}) await dispatch_message(ws, g, sm, cb, raw, interrupt_manager=im) # --------------------------------------------------------------------------- # Tests # --------------------------------------------------------------------------- @pytest.mark.integration class TestWebSocketHappyPath: @pytest.mark.asyncio async def test_send_message_receives_tokens_and_complete(self) -> None: g, sm, im, cb, ws = _setup( graph=_graph(chunks=[_chunk("Order 1042 is "), _chunk("shipped.")]) ) await _send(ws, g, sm, im, cb, content="What is the status of order 1042?") tokens = [m for m in ws.sent if m["type"] == "token"] assert len(tokens) == 2 assert tokens[0]["content"] == "Order 1042 is " assert tokens[0]["agent"] == "order_lookup" assert tokens[1]["content"] == "shipped." completes = [m for m in ws.sent if m["type"] == "message_complete"] assert len(completes) == 1 assert completes[0]["thread_id"] == "t1" @pytest.mark.asyncio async def test_tool_call_streamed(self) -> None: g, sm, im, cb, ws = _setup( graph=_graph(chunks=[ _tool_chunk("get_order_status", {"order_id": "1042"}), _chunk("Order shipped."), ]) ) await _send(ws, g, sm, im, cb, content="Check order 1042") tools = [m for m in ws.sent if m["type"] == "tool_call"] assert len(tools) == 1 assert tools[0]["tool"] == "get_order_status" assert tools[0]["args"] == {"order_id": "1042"} @pytest.mark.asyncio async def test_multiple_messages_same_session(self) -> None: g, sm, im, cb, ws = _setup() for i in range(3): await _send(ws, g, sm, im, cb, content=f"msg {i}") completes = [m for m in ws.sent if m["type"] == "message_complete"] assert len(completes) == 3 @pytest.mark.integration class TestWebSocketInterruptApproval: @pytest.mark.asyncio async def test_interrupt_then_approve(self) -> None: st_int = _state(interrupt=True, data={"action": "cancel_order", "order_id": "1042"}) g = _graph(chunks=[], st=st_int, resume_chunks=[_chunk("Order 1042 cancelled.", "order_actions")]) g_, sm, im, cb, ws = _setup(graph=g) # Send message -> triggers interrupt await _send(ws, g_, sm, im, cb, content="Cancel order 1042") interrupts = [m for m in ws.sent if m["type"] == "interrupt"] assert len(interrupts) == 1 assert interrupts[0]["action"] == "cancel_order" assert interrupts[0]["thread_id"] == "t1" assert im.has_pending("t1") # Approve ws.sent.clear() await _respond(ws, g_, sm, im, cb, approved=True) tokens = [m for m in ws.sent if m["type"] == "token"] assert len(tokens) == 1 assert "cancelled" in tokens[0]["content"] completes = [m for m in ws.sent if m["type"] == "message_complete"] assert len(completes) == 1 assert not im.has_pending("t1") @pytest.mark.asyncio async def test_interrupt_then_reject(self) -> None: st_int = _state(interrupt=True) g = _graph(chunks=[], st=st_int, resume_chunks=[_chunk("Order remains active.", "order_actions")]) g_, sm, im, cb, ws = _setup(graph=g) await _send(ws, g_, sm, im, cb, content="Cancel order 1042") ws.sent.clear() await _respond(ws, g_, sm, im, cb, approved=False) tokens = [m for m in ws.sent if m["type"] == "token"] assert "remains active" in tokens[0]["content"] @pytest.mark.integration class TestWebSocketSessionTTL: @pytest.mark.asyncio async def test_expired_session_returns_error(self) -> None: g, sm, im, cb, ws = _setup(session_ttl=0) # Session was touched in _setup, but TTL is 0 so it's already expired await _send(ws, g, sm, im, cb, content="hello") assert ws.sent[0]["type"] == "error" assert "expired" in ws.sent[0]["message"].lower() @pytest.mark.asyncio async def test_new_session_not_expired(self) -> None: g, sm, im, cb, ws = _setup(session_ttl=3600) await _send(ws, g, sm, im, cb, content="hello") completes = [m for m in ws.sent if m["type"] == "message_complete"] assert len(completes) == 1 @pytest.mark.asyncio async def test_sliding_window_resets_on_message(self) -> None: g, sm, im, cb, ws = _setup(session_ttl=3600) await _send(ws, g, sm, im, cb, content="hello") first_activity = sm.get_state("t1").last_activity time.sleep(0.01) await _send(ws, g, sm, im, cb, content="hello again") second_activity = sm.get_state("t1").last_activity assert second_activity > first_activity @pytest.mark.asyncio async def test_interrupt_extends_session_ttl(self) -> None: st_int = _state(interrupt=True) g = _graph(chunks=[], st=st_int) g_, sm, im, cb, ws = _setup(graph=g, session_ttl=3600) await _send(ws, g_, sm, im, cb, content="cancel order") state = sm.get_state("t1") assert state is not None assert state.has_pending_interrupt assert not sm.is_expired("t1") @pytest.mark.integration class TestWebSocketValidation: @pytest.mark.asyncio async def test_invalid_json(self) -> None: g, sm, im, cb, ws = _setup() await dispatch_message(ws, g, sm, cb, "not json", interrupt_manager=im) assert ws.sent[0]["type"] == "error" assert "Invalid JSON" in ws.sent[0]["message"] @pytest.mark.asyncio async def test_missing_thread_id(self) -> None: g, sm, im, cb, ws = _setup() raw = json.dumps({"type": "message", "content": "hi"}) await dispatch_message(ws, g, sm, cb, raw, interrupt_manager=im) assert ws.sent[0]["type"] == "error" assert "thread_id" in ws.sent[0]["message"] @pytest.mark.asyncio async def test_invalid_thread_id_format(self) -> None: g, sm, im, cb, ws = _setup() raw = json.dumps({"type": "message", "thread_id": "../../etc/passwd", "content": "hi"}) await dispatch_message(ws, g, sm, cb, raw, interrupt_manager=im) assert ws.sent[0]["type"] == "error" @pytest.mark.asyncio async def test_missing_content(self) -> None: g, sm, im, cb, ws = _setup() raw = json.dumps({"type": "message", "thread_id": "t1"}) await dispatch_message(ws, g, sm, cb, raw, interrupt_manager=im) assert ws.sent[0]["type"] == "error" @pytest.mark.asyncio async def test_unknown_message_type(self) -> None: g, sm, im, cb, ws = _setup() raw = json.dumps({"type": "foobar", "thread_id": "t1"}) await dispatch_message(ws, g, sm, cb, raw, interrupt_manager=im) assert ws.sent[0]["type"] == "error" assert "Unknown" in ws.sent[0]["message"] @pytest.mark.asyncio async def test_message_too_large(self) -> None: g, sm, im, cb, ws = _setup() await dispatch_message(ws, g, sm, cb, "x" * 40_000, interrupt_manager=im) assert ws.sent[0]["type"] == "error" assert "too large" in ws.sent[0]["message"].lower() @pytest.mark.asyncio async def test_content_too_long(self) -> None: g, sm, im, cb, ws = _setup() raw = json.dumps({"type": "message", "thread_id": "t1", "content": "x" * 9000}) await dispatch_message(ws, g, sm, cb, raw, interrupt_manager=im) assert ws.sent[0]["type"] == "error" assert "too long" in ws.sent[0]["message"].lower() @pytest.mark.integration class TestWebSocketInterruptTTL: @pytest.mark.asyncio async def test_expired_interrupt_sends_retry_prompt(self) -> None: st_int = _state(interrupt=True, data={"action": "cancel_order", "order_id": "1042"}) g = _graph(chunks=[], st=st_int) g_, sm, im, cb, ws = _setup(graph=g, interrupt_ttl=5) # Trigger interrupt await _send(ws, g_, sm, im, cb, content="Cancel order 1042") interrupts = [m for m in ws.sent if m["type"] == "interrupt"] assert len(interrupts) == 1 # Expire the interrupt record = im._interrupts["t1"] ws.sent.clear() with patch("app.interrupt_manager.time") as mock_time: mock_time.time.return_value = record.created_at + 10 await _respond(ws, g_, sm, im, cb, approved=True) assert ws.sent[0]["type"] == "interrupt_expired" assert "cancel_order" in ws.sent[0]["message"] assert ws.sent[0]["thread_id"] == "t1"