Address all architecture review findings:

P0 fixes:
- Add API key authentication for admin endpoints (analytics, replay, openapi) and WebSocket connections via ADMIN_API_KEY env var
- Add PostgreSQL-backed PgSessionManager and PgInterruptManager for multi-worker production deployments (in-memory defaults preserved)

P1 fixes:
- Implement actual tool generation in OpenAPI approve_job endpoint using generate_tool_code() and generate_agent_yaml()
- Add missing clarification, interrupt_expired, and tool_result message handlers in frontend ChatPage

P2 fixes:
- Replace monkey-patching on CompiledStateGraph with typed GraphContext
- Replace 9-param dispatch_message with WebSocketContext dataclass
- Extract duplicate _envelope() into shared app/api_utils.py
- Replace mutable module-level counter with crypto.randomUUID()
- Remove hardcoded mock data from ReviewPage, use api.ts wrappers
- Remove `as any` type escape from ReplayPage

All 516 tests passing, 0 TypeScript errors.
361 lines
12 KiB
Python
361 lines
12 KiB
Python
"""Integration tests for WebSocket message flow.
|
|
|
|
These tests exercise dispatch_message end-to-end with a mocked LangGraph
|
|
graph, verifying streaming, interrupt approval/rejection, session TTL,
|
|
and interrupt TTL expiration through the full message handling pipeline.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import time
|
|
from typing import Any
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from app.callbacks import TokenUsageCallbackHandler
|
|
from app.graph_context import GraphContext
|
|
from app.interrupt_manager import InterruptManager
|
|
from app.session_manager import SessionManager
|
|
from app.ws_context import WebSocketContext
|
|
from app.ws_handler import dispatch_message
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class AsyncIterHelper:
    """Expose a finite collection of items as an async iterator.

    The items are snapshotted at construction time, so later mutation of
    the caller's list does not affect iteration. Each item is produced
    once, in order; after that ``StopAsyncIteration`` is raised.
    """

    def __init__(self, items: list) -> None:
        # Copy first so the caller's list is decoupled, then walk the copy
        # with a plain iterator: advancing is O(1) per item, whereas
        # ``list.pop(0)`` shifts the whole remainder on every step.
        self._items_iter = iter(list(items))

    def __aiter__(self):
        """Return ``self``; the helper is its own async iterator."""
        return self

    async def __anext__(self):
        """Return the next item, or raise ``StopAsyncIteration`` when done."""
        try:
            return next(self._items_iter)
        except StopIteration:
            # Translate the sync exhaustion signal into the async one.
            raise StopAsyncIteration from None
|
|
|
|
|
|
class FakeWS:
    """Stand-in WebSocket that captures outbound JSON payloads in ``sent``."""

    def __init__(self) -> None:
        # Payloads, in the order they were handed to ``send_json``.
        self.sent: list[dict] = list()

    async def send_json(self, data: dict) -> None:
        """Record *data* instead of transmitting it anywhere."""
        self.sent.extend((data,))
|
|
|
|
|
|
def _chunk(content: str, node: str = "order_lookup") -> tuple:
|
|
c = MagicMock()
|
|
c.content = content
|
|
c.tool_calls = []
|
|
return (c, {"langgraph_node": node})
|
|
|
|
|
|
def _tool_chunk(name: str, args: dict, node: str = "order_lookup") -> tuple:
|
|
c = MagicMock()
|
|
c.content = ""
|
|
c.tool_calls = [{"name": name, "args": args}]
|
|
return (c, {"langgraph_node": node})
|
|
|
|
|
|
def _state(*, interrupt: bool = False, data: dict | None = None) -> Any:
|
|
s = MagicMock()
|
|
if interrupt:
|
|
obj = MagicMock()
|
|
obj.value = data or {"action": "cancel_order", "order_id": "1042"}
|
|
t = MagicMock()
|
|
t.interrupts = (obj,)
|
|
s.tasks = (t,)
|
|
else:
|
|
s.tasks = ()
|
|
return s
|
|
|
|
|
|
def _graph(
|
|
chunks: list | None = None,
|
|
st: Any = None,
|
|
resume_chunks: list | None = None,
|
|
) -> MagicMock:
|
|
g = MagicMock()
|
|
|
|
if st is None:
|
|
st = _state()
|
|
|
|
streams = [chunks or [], resume_chunks or []]
|
|
idx = {"n": 0}
|
|
|
|
def make_stream(*a, **kw):
|
|
i = min(idx["n"], len(streams) - 1)
|
|
idx["n"] += 1
|
|
return AsyncIterHelper(list(streams[i]))
|
|
|
|
g.astream = MagicMock(side_effect=make_stream)
|
|
g.aget_state = AsyncMock(return_value=st)
|
|
return g
|
|
|
|
|
|
def _make_graph_ctx(graph: MagicMock | None = None) -> GraphContext:
    """Wrap *graph* (or a fresh ``_graph()``) in a ``GraphContext``.

    The registry is a mock whose ``list_agents()`` returns an empty tuple,
    and ``intent_classifier`` is passed as ``None``.
    """
    mock_graph = graph or _graph()
    agent_registry = MagicMock()
    agent_registry.list_agents = MagicMock(return_value=())
    return GraphContext(
        graph=mock_graph,
        registry=agent_registry,
        intent_classifier=None,
    )
|
|
|
|
|
|
def _setup(
    graph=None,
    session_ttl: int = 1800,
    interrupt_ttl: int = 1800,
    thread_id: str = "t1",
    touch: bool = True,
):
    """Assemble the collaborators a ``dispatch_message`` test needs.

    Returns the tuple ``(graph, session_manager, interrupt_manager,
    callback_handler, fake_ws, ws_ctx)``. Unless *touch* is false, the
    session for *thread_id* is touched before returning.
    """
    mock_graph = graph or _graph()
    graph_ctx = _make_graph_ctx(mock_graph)
    sessions = SessionManager(session_ttl_seconds=session_ttl)
    interrupts = InterruptManager(ttl_seconds=interrupt_ttl)
    usage_callback = TokenUsageCallbackHandler()
    fake_ws = FakeWS()
    ws_ctx = WebSocketContext(
        graph_ctx=graph_ctx,
        session_manager=sessions,
        callback_handler=usage_callback,
        interrupt_manager=interrupts,
    )
    if touch:
        sessions.touch(thread_id)
    return mock_graph, sessions, interrupts, usage_callback, fake_ws, ws_ctx
|
|
|
|
|
|
async def _send(ws, ws_ctx, *, thread_id="t1", content="hello", msg_type="message"):
    """Serialize a client frame and pass it to ``dispatch_message``."""
    frame = {"type": msg_type, "thread_id": thread_id, "content": content}
    await dispatch_message(ws, ws_ctx, json.dumps(frame))
|
|
|
|
|
|
async def _respond(ws, ws_ctx, *, thread_id="t1", approved=True):
    """Serialize an ``interrupt_response`` frame and dispatch it."""
    frame = {
        "type": "interrupt_response",
        "thread_id": thread_id,
        "approved": approved,
    }
    await dispatch_message(ws, ws_ctx, json.dumps(frame))
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
@pytest.mark.integration
class TestWebSocketHappyPath:
    """Plain message flows that stream output and finish normally."""

    @pytest.mark.asyncio
    async def test_send_message_receives_tokens_and_complete(self) -> None:
        """Two text chunks yield two ordered tokens and one completion."""
        streamed = [_chunk("Order 1042 is "), _chunk("shipped.")]
        *_, ws, ws_ctx = _setup(graph=_graph(chunks=streamed))

        await _send(ws, ws_ctx, content="What is the status of order 1042?")

        token_msgs = [msg for msg in ws.sent if msg["type"] == "token"]
        assert len(token_msgs) == 2
        first, second = token_msgs
        assert first["content"] == "Order 1042 is "
        assert first["agent"] == "order_lookup"
        assert second["content"] == "shipped."

        done_msgs = [msg for msg in ws.sent if msg["type"] == "message_complete"]
        assert len(done_msgs) == 1
        assert done_msgs[0]["thread_id"] == "t1"

    @pytest.mark.asyncio
    async def test_tool_call_streamed(self) -> None:
        """A tool-call chunk is forwarded as one ``tool_call`` message."""
        streamed = [
            _tool_chunk("get_order_status", {"order_id": "1042"}),
            _chunk("Order shipped."),
        ]
        *_, ws, ws_ctx = _setup(graph=_graph(chunks=streamed))

        await _send(ws, ws_ctx, content="Check order 1042")

        tool_msgs = [msg for msg in ws.sent if msg["type"] == "tool_call"]
        assert len(tool_msgs) == 1
        (call,) = tool_msgs
        assert call["tool"] == "get_order_status"
        assert call["args"] == {"order_id": "1042"}

    @pytest.mark.asyncio
    async def test_multiple_messages_same_session(self) -> None:
        """Three messages on one thread each produce a completion."""
        *_, ws, ws_ctx = _setup()

        for n in range(3):
            await _send(ws, ws_ctx, content=f"msg {n}")

        done_msgs = [msg for msg in ws.sent if msg["type"] == "message_complete"]
        assert len(done_msgs) == 3
|
|
|
|
|
|
@pytest.mark.integration
class TestWebSocketInterruptApproval:
    """Interrupt raised by the graph, then resolved by the client."""

    @pytest.mark.asyncio
    async def test_interrupt_then_approve(self) -> None:
        """Approving a pending interrupt streams the resume output."""
        st_int = _state(interrupt=True, data={"action": "cancel_order", "order_id": "1042"})
        resume = [_chunk("Order 1042 cancelled.", "order_actions")]
        g = _graph(chunks=[], st=st_int, resume_chunks=resume)
        _, _, im, _, ws, ws_ctx = _setup(graph=g)

        # Send message -> triggers interrupt
        await _send(ws, ws_ctx, content="Cancel order 1042")

        interrupts = [m for m in ws.sent if m["type"] == "interrupt"]
        assert len(interrupts) == 1
        assert interrupts[0]["action"] == "cancel_order"
        assert interrupts[0]["thread_id"] == "t1"
        assert im.has_pending("t1")

        # Approve
        ws.sent.clear()
        await _respond(ws, ws_ctx, approved=True)

        tokens = [m for m in ws.sent if m["type"] == "token"]
        assert len(tokens) == 1
        assert "cancelled" in tokens[0]["content"]

        completes = [m for m in ws.sent if m["type"] == "message_complete"]
        assert len(completes) == 1
        assert not im.has_pending("t1")

    @pytest.mark.asyncio
    async def test_interrupt_then_reject(self) -> None:
        """Rejecting a pending interrupt streams the resume output."""
        st_int = _state(interrupt=True)
        resume = [_chunk("Order remains active.", "order_actions")]
        g = _graph(chunks=[], st=st_int, resume_chunks=resume)
        *_, ws, ws_ctx = _setup(graph=g)

        await _send(ws, ws_ctx, content="Cancel order 1042")
        ws.sent.clear()

        await _respond(ws, ws_ctx, approved=False)

        tokens = [m for m in ws.sent if m["type"] == "token"]
        # Guard the index below: with no tokens the test should fail as an
        # assertion with a readable message, not error out with IndexError.
        assert tokens, "expected at least one token after rejecting the interrupt"
        assert "remains active" in tokens[0]["content"]
|
|
|
|
|
|
@pytest.mark.integration
class TestWebSocketSessionTTL:
    """Session expiry and activity tracking as seen through dispatch."""

    @pytest.mark.asyncio
    async def test_expired_session_returns_error(self) -> None:
        """With a zero session TTL the first reply is an 'expired' error."""
        *_, ws, ws_ctx = _setup(session_ttl=0)
        # Session was touched in _setup, but TTL is 0 so it's already expired
        await _send(ws, ws_ctx, content="hello")
        assert ws.sent[0]["type"] == "error"
        assert "expired" in ws.sent[0]["message"].lower()

    @pytest.mark.asyncio
    async def test_new_session_not_expired(self) -> None:
        """With a one-hour TTL a message completes normally."""
        *_, ws, ws_ctx = _setup(session_ttl=3600)
        await _send(ws, ws_ctx, content="hello")
        completes = [m for m in ws.sent if m["type"] == "message_complete"]
        assert len(completes) == 1

    @pytest.mark.asyncio
    async def test_sliding_window_resets_on_message(self) -> None:
        """A second message moves ``last_activity`` forward."""
        # Function-scope import: this test is the only asyncio user here.
        import asyncio

        _, sm, _, _, ws, ws_ctx = _setup(session_ttl=3600)

        await _send(ws, ws_ctx, content="hello")
        first_state = sm.get_state("t1")
        # Fail with an assertion, not AttributeError, if the session is gone.
        assert first_state is not None
        first_activity = first_state.last_activity

        # Pause so the two timestamps can differ. Awaiting asyncio.sleep
        # yields to the event loop; time.sleep would block it.
        await asyncio.sleep(0.01)
        await _send(ws, ws_ctx, content="hello again")
        second_state = sm.get_state("t1")
        assert second_state is not None
        second_activity = second_state.last_activity

        assert second_activity > first_activity

    @pytest.mark.asyncio
    async def test_interrupt_extends_session_ttl(self) -> None:
        """After an interrupt the session is flagged pending and not expired."""
        st_int = _state(interrupt=True)
        g = _graph(chunks=[], st=st_int)
        _, sm, _, _, ws, ws_ctx = _setup(graph=g, session_ttl=3600)

        await _send(ws, ws_ctx, content="cancel order")

        state = sm.get_state("t1")
        assert state is not None
        assert state.has_pending_interrupt
        assert not sm.is_expired("t1")
|
|
|
|
|
|
@pytest.mark.integration
class TestWebSocketValidation:
    """Malformed or oversized frames are answered with an error message."""

    @pytest.mark.asyncio
    async def test_invalid_json(self) -> None:
        """A non-JSON frame yields an 'Invalid JSON' error."""
        *_, ws, ws_ctx = _setup()
        await dispatch_message(ws, ws_ctx, "not json")
        reply = ws.sent[0]
        assert reply["type"] == "error"
        assert "Invalid JSON" in reply["message"]

    @pytest.mark.asyncio
    async def test_missing_thread_id(self) -> None:
        """A frame without ``thread_id`` yields an error naming that field."""
        *_, ws, ws_ctx = _setup()
        frame = {"type": "message", "content": "hi"}
        await dispatch_message(ws, ws_ctx, json.dumps(frame))
        reply = ws.sent[0]
        assert reply["type"] == "error"
        assert "thread_id" in reply["message"]

    @pytest.mark.asyncio
    async def test_invalid_thread_id_format(self) -> None:
        """A path-like ``thread_id`` yields an error."""
        *_, ws, ws_ctx = _setup()
        frame = {"type": "message", "thread_id": "../../etc/passwd", "content": "hi"}
        await dispatch_message(ws, ws_ctx, json.dumps(frame))
        assert ws.sent[0]["type"] == "error"

    @pytest.mark.asyncio
    async def test_missing_content(self) -> None:
        """A ``message`` frame without ``content`` yields an error."""
        *_, ws, ws_ctx = _setup()
        frame = {"type": "message", "thread_id": "t1"}
        await dispatch_message(ws, ws_ctx, json.dumps(frame))
        assert ws.sent[0]["type"] == "error"

    @pytest.mark.asyncio
    async def test_unknown_message_type(self) -> None:
        """An unrecognized ``type`` yields an 'Unknown' error."""
        *_, ws, ws_ctx = _setup()
        frame = {"type": "foobar", "thread_id": "t1"}
        await dispatch_message(ws, ws_ctx, json.dumps(frame))
        reply = ws.sent[0]
        assert reply["type"] == "error"
        assert "Unknown" in reply["message"]

    @pytest.mark.asyncio
    async def test_message_too_large(self) -> None:
        """A 40,000-character raw frame yields a 'too large' error."""
        *_, ws, ws_ctx = _setup()
        oversized = "x" * 40_000
        await dispatch_message(ws, ws_ctx, oversized)
        reply = ws.sent[0]
        assert reply["type"] == "error"
        assert "too large" in reply["message"].lower()

    @pytest.mark.asyncio
    async def test_content_too_long(self) -> None:
        """A 10,001-character ``content`` yields a 'too long' error."""
        *_, ws, ws_ctx = _setup()
        frame = {"type": "message", "thread_id": "t1", "content": "x" * 10001}
        await dispatch_message(ws, ws_ctx, json.dumps(frame))
        reply = ws.sent[0]
        assert reply["type"] == "error"
        assert "too long" in reply["message"].lower()
|
|
|
|
|
|
@pytest.mark.integration
class TestWebSocketInterruptTTL:
    """Responding to an interrupt after its TTL has elapsed."""

    @pytest.mark.asyncio
    async def test_expired_interrupt_sends_retry_prompt(self) -> None:
        """A late approval yields ``interrupt_expired`` naming the action."""
        payload = {"action": "cancel_order", "order_id": "1042"}
        mock_graph = _graph(chunks=[], st=_state(interrupt=True, data=payload))
        _, _, im, _, ws, ws_ctx = _setup(graph=mock_graph, interrupt_ttl=5)

        # Trigger interrupt
        await _send(ws, ws_ctx, content="Cancel order 1042")

        raised = [msg for msg in ws.sent if msg["type"] == "interrupt"]
        assert len(raised) == 1

        # Expire the interrupt: read the stored record, then make the
        # interrupt manager's clock report 10 past creation (TTL is 5).
        pending_record = im._interrupts["t1"]
        ws.sent.clear()

        with patch("app.interrupt_manager.time") as fake_time:
            fake_time.time.return_value = pending_record.created_at + 10
            await _respond(ws, ws_ctx, approved=True)

        reply = ws.sent[0]
        assert reply["type"] == "interrupt_expired"
        assert "cancel_order" in reply["message"]
        assert reply["thread_id"] == "t1"
|