refactor: fix architectural issues across frontend and backend

Address all architecture review findings:

P0 fixes:
- Add API key authentication for admin endpoints (analytics, replay, openapi)
  and WebSocket connections via ADMIN_API_KEY env var
- Add PostgreSQL-backed PgSessionManager and PgInterruptManager for
  multi-worker production deployments (in-memory defaults preserved)

P1 fixes:
- Implement actual tool generation in OpenAPI approve_job endpoint
  using generate_tool_code() and generate_agent_yaml()
- Add missing clarification, interrupt_expired, and tool_result message
  handlers in frontend ChatPage

P2 fixes:
- Replace monkey-patching on CompiledStateGraph with typed GraphContext
- Replace 9-param dispatch_message with WebSocketContext dataclass
- Extract duplicate _envelope() into shared app/api_utils.py
- Replace mutable module-level counter with crypto.randomUUID()
- Remove hardcoded mock data from ReviewPage, use api.ts wrappers
- Remove `as any` type escape from ReplayPage

All 516 tests passing, 0 TypeScript errors.
This commit is contained in:
Yaojia Wang
2026-04-06 15:59:14 +02:00
parent b8654aa31f
commit af53111928
29 changed files with 1183 additions and 473 deletions

View File

@@ -20,10 +20,12 @@ import pytest
from app.callbacks import TokenUsageCallbackHandler
from app.escalation import EscalationPayload, NoOpEscalator, WebhookEscalator
from app.graph_context import GraphContext
from app.intent import ClassificationResult, IntentTarget, LLMIntentClassifier
from app.interrupt_manager import InterruptManager
from app.registry import AgentConfig, AgentRegistry
from app.session_manager import SessionManager
from app.ws_context import WebSocketContext
from app.ws_handler import dispatch_message
TEMPLATES_DIR = Path(__file__).parent.parent.parent / "templates"
@@ -128,10 +130,8 @@ class TestCheckpoint1OrderQueryRouting:
mock_classifier.classify = AsyncMock(return_value=ClassificationResult(
intents=(IntentTarget(agent_name="order_lookup", confidence=0.95, reasoning="query"),),
))
graph.intent_classifier = mock_classifier
mock_registry = MagicMock()
mock_registry.list_agents = MagicMock(return_value=())
graph.agent_registry = mock_registry
# Graph streams order_lookup response
graph.astream = MagicMock(return_value=AsyncIterHelper([
@@ -140,14 +140,21 @@ class TestCheckpoint1OrderQueryRouting:
]))
graph.aget_state = AsyncMock(return_value=_state())
graph_ctx = GraphContext(
graph=graph, registry=mock_registry, intent_classifier=mock_classifier,
)
sm = SessionManager()
sm.touch("t1")
im = InterruptManager()
cb = TokenUsageCallbackHandler()
ws = FakeWS()
ws_ctx = WebSocketContext(
graph_ctx=graph_ctx, session_manager=sm,
callback_handler=cb, interrupt_manager=im,
)
raw = json.dumps({"type": "message", "thread_id": "t1", "content": "查询订单 1042"})
await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im)
await dispatch_message(ws, ws_ctx, raw)
tool_msgs = [m for m in ws.sent if m["type"] == "tool_call"]
assert any(m["tool"] == "get_order_status" for m in tool_msgs)
@@ -201,25 +208,30 @@ class TestCheckpoint2MultiIntentSequential:
IntentTarget(agent_name="discount", confidence=0.85, reasoning="discount"),
),
))
graph.intent_classifier = mock_classifier
mock_registry = MagicMock()
mock_registry.list_agents = MagicMock(return_value=())
graph.agent_registry = mock_registry
graph.astream = MagicMock(return_value=AsyncIterHelper([]))
graph.aget_state = AsyncMock(return_value=_state())
graph_ctx = GraphContext(
graph=graph, registry=mock_registry, intent_classifier=mock_classifier,
)
sm = SessionManager()
sm.touch("t1")
im = InterruptManager()
cb = TokenUsageCallbackHandler()
ws = FakeWS()
ws_ctx = WebSocketContext(
graph_ctx=graph_ctx, session_manager=sm,
callback_handler=cb, interrupt_manager=im,
)
raw = json.dumps({
"type": "message",
"thread_id": "t1",
"content": "取消订单 1042 并给我一个 10% 折扣",
})
await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im)
await dispatch_message(ws, ws_ctx, raw)
# Verify the graph was called with the routing hint in the message
call_args = graph.astream.call_args
@@ -267,21 +279,26 @@ class TestCheckpoint3AmbiguousClarification:
"Could you please provide more details about what you need help with?"
),
))
graph.intent_classifier = mock_classifier
mock_registry = MagicMock()
mock_registry.list_agents = MagicMock(return_value=())
graph.agent_registry = mock_registry
graph.astream = MagicMock(return_value=AsyncIterHelper([]))
graph.aget_state = AsyncMock(return_value=_state())
graph_ctx = GraphContext(
graph=graph, registry=mock_registry, intent_classifier=mock_classifier,
)
sm = SessionManager()
sm.touch("t1")
im = InterruptManager()
cb = TokenUsageCallbackHandler()
ws = FakeWS()
ws_ctx = WebSocketContext(
graph_ctx=graph_ctx, session_manager=sm,
callback_handler=cb, interrupt_manager=im,
)
raw = json.dumps({"type": "message", "thread_id": "t1", "content": "嗯..."})
await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im)
await dispatch_message(ws, ws_ctx, raw)
clarifications = [m for m in ws.sent if m["type"] == "clarification"]
assert len(clarifications) == 1
@@ -303,20 +320,26 @@ class TestCheckpoint4InterruptTTLAutoCancel:
async def test_30min_expired_interrupt_auto_cancels(self) -> None:
st = _state(interrupt=True, data={"action": "cancel_order", "order_id": "1042"})
graph = MagicMock()
graph.intent_classifier = None
graph.agent_registry = None
graph.astream = MagicMock(return_value=AsyncIterHelper([]))
graph.aget_state = AsyncMock(return_value=st)
mock_registry = MagicMock()
mock_registry.list_agents = MagicMock(return_value=())
graph_ctx = GraphContext(graph=graph, registry=mock_registry, intent_classifier=None)
sm = SessionManager()
sm.touch("t1")
im = InterruptManager(ttl_seconds=1800) # 30 minutes
cb = TokenUsageCallbackHandler()
ws = FakeWS()
ws_ctx = WebSocketContext(
graph_ctx=graph_ctx, session_manager=sm,
callback_handler=cb, interrupt_manager=im,
)
# Trigger interrupt
raw = json.dumps({"type": "message", "thread_id": "t1", "content": "Cancel order 1042"})
await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im)
await dispatch_message(ws, ws_ctx, raw)
interrupts = [m for m in ws.sent if m["type"] == "interrupt"]
assert len(interrupts) == 1
@@ -333,7 +356,7 @@ class TestCheckpoint4InterruptTTLAutoCancel:
"thread_id": "t1",
"approved": True,
})
await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im)
await dispatch_message(ws, ws_ctx, raw)
# Should get retry prompt, NOT resume the graph
expired_msgs = [m for m in ws.sent if m["type"] == "interrupt_expired"]

View File

@@ -18,10 +18,12 @@ from unittest.mock import AsyncMock, MagicMock
import pytest
from app.callbacks import TokenUsageCallbackHandler
from app.graph_context import GraphContext
from app.intent import ClassificationResult, IntentTarget, LLMIntentClassifier
from app.interrupt_manager import InterruptManager
from app.registry import AgentConfig
from app.session_manager import SessionManager
from app.ws_context import WebSocketContext
from app.ws_handler import dispatch_message
# ---------------------------------------------------------------------------
@@ -103,36 +105,45 @@ def _make_classifier(result: ClassificationResult) -> AsyncMock:
return classifier
def _make_graph(
def _make_graph_and_ctx(
classifier_result: ClassificationResult | None,
chunks: list,
state=None,
) -> MagicMock:
"""Build a graph mock with optional intent classifier."""
) -> tuple[MagicMock, GraphContext]:
"""Build a graph mock and GraphContext with optional intent classifier."""
graph = MagicMock()
if classifier_result is not None:
graph.intent_classifier = _make_classifier(classifier_result)
mock_registry = MagicMock()
mock_registry.list_agents = MagicMock(return_value=AGENTS)
graph.agent_registry = mock_registry
else:
graph.intent_classifier = None
graph.agent_registry = None
graph.astream = MagicMock(return_value=AsyncIterHelper(list(chunks)))
graph.aget_state = AsyncMock(return_value=state or _state())
return graph
if classifier_result is not None:
classifier = _make_classifier(classifier_result)
mock_registry = MagicMock()
mock_registry.list_agents = MagicMock(return_value=AGENTS)
graph_ctx = GraphContext(
graph=graph, registry=mock_registry, intent_classifier=classifier,
)
else:
mock_registry = MagicMock()
mock_registry.list_agents = MagicMock(return_value=())
graph_ctx = GraphContext(
graph=graph, registry=mock_registry, intent_classifier=None,
)
return graph, graph_ctx
async def _dispatch(graph, content: str, thread_id: str = "t1") -> list[dict]:
async def _dispatch(graph_ctx: GraphContext, content: str, thread_id: str = "t1") -> list[dict]:
sm = SessionManager()
sm.touch(thread_id)
im = InterruptManager()
cb = TokenUsageCallbackHandler()
ws = FakeWS()
ws_ctx = WebSocketContext(
graph_ctx=graph_ctx, session_manager=sm,
callback_handler=cb, interrupt_manager=im,
)
raw = json.dumps({"type": "message", "thread_id": thread_id, "content": content})
await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im)
await dispatch_message(ws, ws_ctx, raw)
return ws.sent
@@ -151,12 +162,12 @@ class TestSingleIntentRouting:
agent_name="order_lookup", confidence=0.95, reasoning="status query",
),),
)
graph = _make_graph(result, [
graph, graph_ctx = _make_graph_and_ctx(result, [
_tool_chunk("get_order_status", {"order_id": "1042"}, "order_lookup"),
_chunk("Order 1042 is shipped.", "order_lookup"),
])
msgs = await _dispatch(graph, "What is the status of order 1042?")
msgs = await _dispatch(graph_ctx, "What is the status of order 1042?")
tools = [m for m in msgs if m["type"] == "tool_call"]
assert len(tools) == 1
@@ -171,13 +182,13 @@ class TestSingleIntentRouting:
result = ClassificationResult(
intents=(IntentTarget(agent_name="order_actions", confidence=0.9, reasoning="cancel"),),
)
graph = _make_graph(
graph, graph_ctx = _make_graph_and_ctx(
result,
[_tool_chunk("cancel_order", {"order_id": "1042"}, "order_actions")],
state=_state(interrupt=True, data={"action": "cancel_order", "order_id": "1042"}),
)
msgs = await _dispatch(graph, "Cancel order 1042")
msgs = await _dispatch(graph_ctx, "Cancel order 1042")
tools = [m for m in msgs if m["type"] == "tool_call"]
assert tools[0]["tool"] == "cancel_order"
@@ -191,12 +202,12 @@ class TestSingleIntentRouting:
result = ClassificationResult(
intents=(IntentTarget(agent_name="discount", confidence=0.88, reasoning="coupon"),),
)
graph = _make_graph(result, [
graph, graph_ctx = _make_graph_and_ctx(result, [
_tool_chunk("generate_coupon", {"discount_percent": 15, "expiry_days": 30}, "discount"),
_chunk("Here is your coupon: SAVE15-ABC12345", "discount"),
])
msgs = await _dispatch(graph, "Give me a 15% coupon")
msgs = await _dispatch(graph_ctx, "Give me a 15% coupon")
tools = [m for m in msgs if m["type"] == "tool_call"]
assert tools[0]["tool"] == "generate_coupon"
@@ -207,11 +218,11 @@ class TestSingleIntentRouting:
result = ClassificationResult(
intents=(IntentTarget(agent_name="fallback", confidence=0.7, reasoning="general"),),
)
graph = _make_graph(result, [
graph, graph_ctx = _make_graph_and_ctx(result, [
_chunk("I can help with order inquiries.", "fallback"),
])
msgs = await _dispatch(graph, "What can you do?")
msgs = await _dispatch(graph_ctx, "What can you do?")
tokens = [m for m in msgs if m["type"] == "token"]
assert tokens[0]["agent"] == "fallback"
@@ -233,7 +244,7 @@ class TestMultiIntentRouting:
IntentTarget(agent_name="discount", confidence=0.85, reasoning="discount"),
),
)
graph = _make_graph(result, [
graph, graph_ctx = _make_graph_and_ctx(result, [
_tool_chunk("cancel_order", {"order_id": "1042"}, "order_actions"),
_tool_chunk("apply_discount", {"order_id": "1042", "discount_percent": 10}, "discount"),
])
@@ -243,13 +254,17 @@ class TestMultiIntentRouting:
im = InterruptManager()
cb = TokenUsageCallbackHandler()
ws = FakeWS()
ws_ctx = WebSocketContext(
graph_ctx=graph_ctx, session_manager=sm,
callback_handler=cb, interrupt_manager=im,
)
raw = json.dumps({
"type": "message",
"thread_id": "t1",
"content": "取消订单 1042 并给我一个 10% 折扣",
})
await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im)
await dispatch_message(ws, ws_ctx, raw)
# Verify routing hint was injected
call_args = graph.astream.call_args[0][0]
@@ -269,16 +284,20 @@ class TestMultiIntentRouting:
result = ClassificationResult(
intents=(IntentTarget(agent_name="order_lookup", confidence=0.95, reasoning="query"),),
)
graph = _make_graph(result, [_chunk("Order shipped.", "order_lookup")])
graph, graph_ctx = _make_graph_and_ctx(result, [_chunk("Order shipped.", "order_lookup")])
sm = SessionManager()
sm.touch("t1")
im = InterruptManager()
cb = TokenUsageCallbackHandler()
ws = FakeWS()
ws_ctx = WebSocketContext(
graph_ctx=graph_ctx, session_manager=sm,
callback_handler=cb, interrupt_manager=im,
)
raw = json.dumps({"type": "message", "thread_id": "t1", "content": "查询订单 1042"})
await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im)
await dispatch_message(ws, ws_ctx, raw)
msg_content = graph.astream.call_args[0][0]["messages"][0].content
assert "[System:" not in msg_content
@@ -299,9 +318,9 @@ class TestAmbiguityRouting:
is_ambiguous=True,
clarification_question="Could you please clarify what you need?",
)
graph = _make_graph(result, [])
graph, graph_ctx = _make_graph_and_ctx(result, [])
msgs = await _dispatch(graph, "嗯...")
msgs = await _dispatch(graph_ctx, "嗯...")
clarifications = [m for m in msgs if m["type"] == "clarification"]
assert len(clarifications) == 1
@@ -339,12 +358,12 @@ class TestNoClassifierFallback:
@pytest.mark.asyncio
async def test_no_classifier_routes_via_supervisor(self) -> None:
graph = _make_graph(
graph, graph_ctx = _make_graph_and_ctx(
classifier_result=None,
chunks=[_chunk("Order 1042 is shipped.", "order_lookup")],
)
msgs = await _dispatch(graph, "What is order 1042 status?")
msgs = await _dispatch(graph_ctx, "What is order 1042 status?")
tokens = [m for m in msgs if m["type"] == "token"]
assert len(tokens) == 1

View File

@@ -15,8 +15,10 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from app.callbacks import TokenUsageCallbackHandler
from app.graph_context import GraphContext
from app.interrupt_manager import InterruptManager
from app.session_manager import SessionManager
from app.ws_context import WebSocketContext
from app.ws_handler import dispatch_message
# ---------------------------------------------------------------------------
@@ -81,8 +83,6 @@ def _graph(
resume_chunks: list | None = None,
) -> MagicMock:
g = MagicMock()
g.intent_classifier = None
g.agent_registry = None
if st is None:
st = _state()
@@ -100,6 +100,13 @@ def _graph(
return g
def _make_graph_ctx(graph: MagicMock | None = None) -> GraphContext:
g = graph or _graph()
registry = MagicMock()
registry.list_agents = MagicMock(return_value=())
return GraphContext(graph=g, registry=registry, intent_classifier=None)
def _setup(
graph=None,
session_ttl: int = 1800,
@@ -109,23 +116,28 @@ def _setup(
):
"""Create test dependencies. Pre-touches session by default."""
g = graph or _graph()
graph_ctx = _make_graph_ctx(g)
sm = SessionManager(session_ttl_seconds=session_ttl)
im = InterruptManager(ttl_seconds=interrupt_ttl)
cb = TokenUsageCallbackHandler()
ws = FakeWS()
ws_ctx = WebSocketContext(
graph_ctx=graph_ctx, session_manager=sm,
callback_handler=cb, interrupt_manager=im,
)
if touch:
sm.touch(thread_id)
return g, sm, im, cb, ws
return g, sm, im, cb, ws, ws_ctx
async def _send(ws, g, sm, im, cb, *, thread_id="t1", content="hello", msg_type="message"):
async def _send(ws, ws_ctx, *, thread_id="t1", content="hello", msg_type="message"):
raw = json.dumps({"type": msg_type, "thread_id": thread_id, "content": content})
await dispatch_message(ws, g, sm, cb, raw, interrupt_manager=im)
await dispatch_message(ws, ws_ctx, raw)
async def _respond(ws, g, sm, im, cb, *, thread_id="t1", approved=True):
async def _respond(ws, ws_ctx, *, thread_id="t1", approved=True):
raw = json.dumps({"type": "interrupt_response", "thread_id": thread_id, "approved": approved})
await dispatch_message(ws, g, sm, cb, raw, interrupt_manager=im)
await dispatch_message(ws, ws_ctx, raw)
# ---------------------------------------------------------------------------
@@ -136,10 +148,10 @@ async def _respond(ws, g, sm, im, cb, *, thread_id="t1", approved=True):
class TestWebSocketHappyPath:
@pytest.mark.asyncio
async def test_send_message_receives_tokens_and_complete(self) -> None:
g, sm, im, cb, ws = _setup(
g, sm, im, cb, ws, ws_ctx = _setup(
graph=_graph(chunks=[_chunk("Order 1042 is "), _chunk("shipped.")])
)
await _send(ws, g, sm, im, cb, content="What is the status of order 1042?")
await _send(ws, ws_ctx, content="What is the status of order 1042?")
tokens = [m for m in ws.sent if m["type"] == "token"]
assert len(tokens) == 2
@@ -153,13 +165,13 @@ class TestWebSocketHappyPath:
@pytest.mark.asyncio
async def test_tool_call_streamed(self) -> None:
g, sm, im, cb, ws = _setup(
g, sm, im, cb, ws, ws_ctx = _setup(
graph=_graph(chunks=[
_tool_chunk("get_order_status", {"order_id": "1042"}),
_chunk("Order shipped."),
])
)
await _send(ws, g, sm, im, cb, content="Check order 1042")
await _send(ws, ws_ctx, content="Check order 1042")
tools = [m for m in ws.sent if m["type"] == "tool_call"]
assert len(tools) == 1
@@ -168,9 +180,9 @@ class TestWebSocketHappyPath:
@pytest.mark.asyncio
async def test_multiple_messages_same_session(self) -> None:
g, sm, im, cb, ws = _setup()
g, sm, im, cb, ws, ws_ctx = _setup()
for i in range(3):
await _send(ws, g, sm, im, cb, content=f"msg {i}")
await _send(ws, ws_ctx, content=f"msg {i}")
completes = [m for m in ws.sent if m["type"] == "message_complete"]
assert len(completes) == 3
@@ -183,10 +195,10 @@ class TestWebSocketInterruptApproval:
st_int = _state(interrupt=True, data={"action": "cancel_order", "order_id": "1042"})
resume = [_chunk("Order 1042 cancelled.", "order_actions")]
g = _graph(chunks=[], st=st_int, resume_chunks=resume)
g_, sm, im, cb, ws = _setup(graph=g)
g_, sm, im, cb, ws, ws_ctx = _setup(graph=g)
# Send message -> triggers interrupt
await _send(ws, g_, sm, im, cb, content="Cancel order 1042")
await _send(ws, ws_ctx, content="Cancel order 1042")
interrupts = [m for m in ws.sent if m["type"] == "interrupt"]
assert len(interrupts) == 1
@@ -196,7 +208,7 @@ class TestWebSocketInterruptApproval:
# Approve
ws.sent.clear()
await _respond(ws, g_, sm, im, cb, approved=True)
await _respond(ws, ws_ctx, approved=True)
tokens = [m for m in ws.sent if m["type"] == "token"]
assert len(tokens) == 1
@@ -211,12 +223,12 @@ class TestWebSocketInterruptApproval:
st_int = _state(interrupt=True)
resume = [_chunk("Order remains active.", "order_actions")]
g = _graph(chunks=[], st=st_int, resume_chunks=resume)
g_, sm, im, cb, ws = _setup(graph=g)
g_, sm, im, cb, ws, ws_ctx = _setup(graph=g)
await _send(ws, g_, sm, im, cb, content="Cancel order 1042")
await _send(ws, ws_ctx, content="Cancel order 1042")
ws.sent.clear()
await _respond(ws, g_, sm, im, cb, approved=False)
await _respond(ws, ws_ctx, approved=False)
tokens = [m for m in ws.sent if m["type"] == "token"]
assert "remains active" in tokens[0]["content"]
@@ -226,28 +238,28 @@ class TestWebSocketInterruptApproval:
class TestWebSocketSessionTTL:
@pytest.mark.asyncio
async def test_expired_session_returns_error(self) -> None:
g, sm, im, cb, ws = _setup(session_ttl=0)
g, sm, im, cb, ws, ws_ctx = _setup(session_ttl=0)
# Session was touched in _setup, but TTL is 0 so it's already expired
await _send(ws, g, sm, im, cb, content="hello")
await _send(ws, ws_ctx, content="hello")
assert ws.sent[0]["type"] == "error"
assert "expired" in ws.sent[0]["message"].lower()
@pytest.mark.asyncio
async def test_new_session_not_expired(self) -> None:
g, sm, im, cb, ws = _setup(session_ttl=3600)
await _send(ws, g, sm, im, cb, content="hello")
g, sm, im, cb, ws, ws_ctx = _setup(session_ttl=3600)
await _send(ws, ws_ctx, content="hello")
completes = [m for m in ws.sent if m["type"] == "message_complete"]
assert len(completes) == 1
@pytest.mark.asyncio
async def test_sliding_window_resets_on_message(self) -> None:
g, sm, im, cb, ws = _setup(session_ttl=3600)
g, sm, im, cb, ws, ws_ctx = _setup(session_ttl=3600)
await _send(ws, g, sm, im, cb, content="hello")
await _send(ws, ws_ctx, content="hello")
first_activity = sm.get_state("t1").last_activity
time.sleep(0.01)
await _send(ws, g, sm, im, cb, content="hello again")
await _send(ws, ws_ctx, content="hello again")
second_activity = sm.get_state("t1").last_activity
assert second_activity > first_activity
@@ -256,9 +268,9 @@ class TestWebSocketSessionTTL:
async def test_interrupt_extends_session_ttl(self) -> None:
st_int = _state(interrupt=True)
g = _graph(chunks=[], st=st_int)
g_, sm, im, cb, ws = _setup(graph=g, session_ttl=3600)
g_, sm, im, cb, ws, ws_ctx = _setup(graph=g, session_ttl=3600)
await _send(ws, g_, sm, im, cb, content="cancel order")
await _send(ws, ws_ctx, content="cancel order")
state = sm.get_state("t1")
assert state is not None
@@ -270,53 +282,53 @@ class TestWebSocketSessionTTL:
class TestWebSocketValidation:
@pytest.mark.asyncio
async def test_invalid_json(self) -> None:
g, sm, im, cb, ws = _setup()
await dispatch_message(ws, g, sm, cb, "not json", interrupt_manager=im)
g, sm, im, cb, ws, ws_ctx = _setup()
await dispatch_message(ws, ws_ctx, "not json")
assert ws.sent[0]["type"] == "error"
assert "Invalid JSON" in ws.sent[0]["message"]
@pytest.mark.asyncio
async def test_missing_thread_id(self) -> None:
g, sm, im, cb, ws = _setup()
g, sm, im, cb, ws, ws_ctx = _setup()
raw = json.dumps({"type": "message", "content": "hi"})
await dispatch_message(ws, g, sm, cb, raw, interrupt_manager=im)
await dispatch_message(ws, ws_ctx, raw)
assert ws.sent[0]["type"] == "error"
assert "thread_id" in ws.sent[0]["message"]
@pytest.mark.asyncio
async def test_invalid_thread_id_format(self) -> None:
g, sm, im, cb, ws = _setup()
g, sm, im, cb, ws, ws_ctx = _setup()
raw = json.dumps({"type": "message", "thread_id": "../../etc/passwd", "content": "hi"})
await dispatch_message(ws, g, sm, cb, raw, interrupt_manager=im)
await dispatch_message(ws, ws_ctx, raw)
assert ws.sent[0]["type"] == "error"
@pytest.mark.asyncio
async def test_missing_content(self) -> None:
g, sm, im, cb, ws = _setup()
g, sm, im, cb, ws, ws_ctx = _setup()
raw = json.dumps({"type": "message", "thread_id": "t1"})
await dispatch_message(ws, g, sm, cb, raw, interrupt_manager=im)
await dispatch_message(ws, ws_ctx, raw)
assert ws.sent[0]["type"] == "error"
@pytest.mark.asyncio
async def test_unknown_message_type(self) -> None:
g, sm, im, cb, ws = _setup()
g, sm, im, cb, ws, ws_ctx = _setup()
raw = json.dumps({"type": "foobar", "thread_id": "t1"})
await dispatch_message(ws, g, sm, cb, raw, interrupt_manager=im)
await dispatch_message(ws, ws_ctx, raw)
assert ws.sent[0]["type"] == "error"
assert "Unknown" in ws.sent[0]["message"]
@pytest.mark.asyncio
async def test_message_too_large(self) -> None:
g, sm, im, cb, ws = _setup()
await dispatch_message(ws, g, sm, cb, "x" * 40_000, interrupt_manager=im)
g, sm, im, cb, ws, ws_ctx = _setup()
await dispatch_message(ws, ws_ctx, "x" * 40_000)
assert ws.sent[0]["type"] == "error"
assert "too large" in ws.sent[0]["message"].lower()
@pytest.mark.asyncio
async def test_content_too_long(self) -> None:
g, sm, im, cb, ws = _setup()
g, sm, im, cb, ws, ws_ctx = _setup()
raw = json.dumps({"type": "message", "thread_id": "t1", "content": "x" * 10001})
await dispatch_message(ws, g, sm, cb, raw, interrupt_manager=im)
await dispatch_message(ws, ws_ctx, raw)
assert ws.sent[0]["type"] == "error"
assert "too long" in ws.sent[0]["message"].lower()
@@ -327,10 +339,10 @@ class TestWebSocketInterruptTTL:
async def test_expired_interrupt_sends_retry_prompt(self) -> None:
st_int = _state(interrupt=True, data={"action": "cancel_order", "order_id": "1042"})
g = _graph(chunks=[], st=st_int)
g_, sm, im, cb, ws = _setup(graph=g, interrupt_ttl=5)
g_, sm, im, cb, ws, ws_ctx = _setup(graph=g, interrupt_ttl=5)
# Trigger interrupt
await _send(ws, g_, sm, im, cb, content="Cancel order 1042")
await _send(ws, ws_ctx, content="Cancel order 1042")
interrupts = [m for m in ws.sent if m["type"] == "interrupt"]
assert len(interrupts) == 1
@@ -341,7 +353,7 @@ class TestWebSocketInterruptTTL:
with patch("app.interrupt_manager.time") as mock_time:
mock_time.time.return_value = record.created_at + 10
await _respond(ws, g_, sm, im, cb, approved=True)
await _respond(ws, ws_ctx, approved=True)
assert ws.sent[0]["type"] == "interrupt_expired"
assert "cancel_order" in ws.sent[0]["message"]