refactor: fix architectural issues across frontend and backend
Address all architecture review findings: P0 fixes: - Add API key authentication for admin endpoints (analytics, replay, openapi) and WebSocket connections via ADMIN_API_KEY env var - Add PostgreSQL-backed PgSessionManager and PgInterruptManager for multi-worker production deployments (in-memory defaults preserved) P1 fixes: - Implement actual tool generation in OpenAPI approve_job endpoint using generate_tool_code() and generate_agent_yaml() - Add missing clarification, interrupt_expired, and tool_result message handlers in frontend ChatPage P2 fixes: - Replace monkey-patching on CompiledStateGraph with typed GraphContext - Replace 9-param dispatch_message with WebSocketContext dataclass - Extract duplicate _envelope() into shared app/api_utils.py - Replace mutable module-level counter with crypto.randomUUID() - Remove hardcoded mock data from ReviewPage, use api.ts wrappers - Remove `as any` type escape from ReplayPage All 516 tests passing, 0 TypeScript errors.
This commit is contained in:
@@ -8,8 +8,10 @@ from unittest.mock import AsyncMock, MagicMock
|
||||
import pytest
|
||||
|
||||
from app.callbacks import TokenUsageCallbackHandler
|
||||
from app.graph_context import GraphContext
|
||||
from app.interrupt_manager import InterruptManager
|
||||
from app.session_manager import SessionManager
|
||||
from app.ws_context import WebSocketContext
|
||||
from app.ws_handler import (
|
||||
_extract_interrupt,
|
||||
_has_interrupt,
|
||||
@@ -25,18 +27,42 @@ def _make_ws() -> AsyncMock:
|
||||
return ws
|
||||
|
||||
|
||||
def _make_graph() -> AsyncMock:
|
||||
def _make_graph() -> MagicMock:
|
||||
graph = AsyncMock()
|
||||
graph.astream = MagicMock(return_value=AsyncIterHelper([]))
|
||||
state = MagicMock()
|
||||
state.tasks = ()
|
||||
graph.aget_state = AsyncMock(return_value=state)
|
||||
# Phase 2: graph needs intent_classifier and agent_registry attrs
|
||||
graph.intent_classifier = None
|
||||
graph.agent_registry = None
|
||||
return graph
|
||||
|
||||
|
||||
def _make_graph_ctx(graph: MagicMock | None = None) -> GraphContext:
|
||||
g = graph or _make_graph()
|
||||
registry = MagicMock()
|
||||
registry.list_agents = MagicMock(return_value=())
|
||||
return GraphContext(graph=g, registry=registry, intent_classifier=None)
|
||||
|
||||
|
||||
def _make_ws_ctx(
|
||||
graph_ctx: GraphContext | None = None,
|
||||
sm: SessionManager | None = None,
|
||||
cb: TokenUsageCallbackHandler | None = None,
|
||||
interrupt_manager: InterruptManager | None = None,
|
||||
analytics_recorder=None,
|
||||
conversation_tracker=None,
|
||||
pool=None,
|
||||
) -> WebSocketContext:
|
||||
return WebSocketContext(
|
||||
graph_ctx=graph_ctx or _make_graph_ctx(),
|
||||
session_manager=sm or SessionManager(),
|
||||
callback_handler=cb or TokenUsageCallbackHandler(),
|
||||
interrupt_manager=interrupt_manager,
|
||||
analytics_recorder=analytics_recorder,
|
||||
conversation_tracker=conversation_tracker,
|
||||
pool=pool,
|
||||
)
|
||||
|
||||
|
||||
class AsyncIterHelper:
|
||||
"""Helper to make a list behave as an async iterator."""
|
||||
|
||||
@@ -57,11 +83,9 @@ class TestDispatchMessage:
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_json(self) -> None:
|
||||
ws = _make_ws()
|
||||
graph = _make_graph()
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
ws_ctx = _make_ws_ctx()
|
||||
|
||||
await dispatch_message(ws, graph, sm, cb, "not json")
|
||||
await dispatch_message(ws, ws_ctx, "not json")
|
||||
ws.send_json.assert_awaited_once()
|
||||
call_data = ws.send_json.call_args[0][0]
|
||||
assert call_data["type"] == "error"
|
||||
@@ -70,12 +94,10 @@ class TestDispatchMessage:
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_thread_id(self) -> None:
|
||||
ws = _make_ws()
|
||||
graph = _make_graph()
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
ws_ctx = _make_ws_ctx()
|
||||
|
||||
msg = json.dumps({"type": "message", "content": "hello"})
|
||||
await dispatch_message(ws, graph, sm, cb, msg)
|
||||
await dispatch_message(ws, ws_ctx, msg)
|
||||
call_data = ws.send_json.call_args[0][0]
|
||||
assert call_data["type"] == "error"
|
||||
assert "thread_id" in call_data["message"]
|
||||
@@ -83,24 +105,20 @@ class TestDispatchMessage:
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_content(self) -> None:
|
||||
ws = _make_ws()
|
||||
graph = _make_graph()
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
ws_ctx = _make_ws_ctx()
|
||||
|
||||
msg = json.dumps({"type": "message", "thread_id": "t1"})
|
||||
await dispatch_message(ws, graph, sm, cb, msg)
|
||||
await dispatch_message(ws, ws_ctx, msg)
|
||||
call_data = ws.send_json.call_args[0][0]
|
||||
assert call_data["type"] == "error"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unknown_message_type(self) -> None:
|
||||
ws = _make_ws()
|
||||
graph = _make_graph()
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
ws_ctx = _make_ws_ctx()
|
||||
|
||||
msg = json.dumps({"type": "unknown", "thread_id": "t1"})
|
||||
await dispatch_message(ws, graph, sm, cb, msg)
|
||||
await dispatch_message(ws, ws_ctx, msg)
|
||||
call_data = ws.send_json.call_args[0][0]
|
||||
assert call_data["type"] == "error"
|
||||
assert "Unknown" in call_data["message"]
|
||||
@@ -108,12 +126,10 @@ class TestDispatchMessage:
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_too_large(self) -> None:
|
||||
ws = _make_ws()
|
||||
graph = _make_graph()
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
ws_ctx = _make_ws_ctx()
|
||||
|
||||
large_msg = "x" * 40_000
|
||||
await dispatch_message(ws, graph, sm, cb, large_msg)
|
||||
await dispatch_message(ws, ws_ctx, large_msg)
|
||||
call_data = ws.send_json.call_args[0][0]
|
||||
assert call_data["type"] == "error"
|
||||
assert "too large" in call_data["message"].lower()
|
||||
@@ -121,12 +137,10 @@ class TestDispatchMessage:
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_thread_id_format(self) -> None:
|
||||
ws = _make_ws()
|
||||
graph = _make_graph()
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
ws_ctx = _make_ws_ctx()
|
||||
|
||||
msg = json.dumps({"type": "message", "thread_id": "../../../etc", "content": "hi"})
|
||||
await dispatch_message(ws, graph, sm, cb, msg)
|
||||
await dispatch_message(ws, ws_ctx, msg)
|
||||
call_data = ws.send_json.call_args[0][0]
|
||||
assert call_data["type"] == "error"
|
||||
assert "thread_id" in call_data["message"].lower()
|
||||
@@ -134,12 +148,10 @@ class TestDispatchMessage:
|
||||
@pytest.mark.asyncio
|
||||
async def test_content_too_long(self) -> None:
|
||||
ws = _make_ws()
|
||||
graph = _make_graph()
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
ws_ctx = _make_ws_ctx()
|
||||
|
||||
msg = json.dumps({"type": "message", "thread_id": "t1", "content": "x" * 10001})
|
||||
await dispatch_message(ws, graph, sm, cb, msg)
|
||||
await dispatch_message(ws, ws_ctx, msg)
|
||||
call_data = ws.send_json.call_args[0][0]
|
||||
assert call_data["type"] == "error"
|
||||
assert "too long" in call_data["message"].lower()
|
||||
@@ -147,14 +159,13 @@ class TestDispatchMessage:
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatch_with_interrupt_manager(self) -> None:
|
||||
ws = _make_ws()
|
||||
graph = _make_graph()
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
im = InterruptManager()
|
||||
ws_ctx = _make_ws_ctx(sm=sm, interrupt_manager=im)
|
||||
|
||||
sm.touch("t1")
|
||||
msg = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"})
|
||||
await dispatch_message(ws, graph, sm, cb, msg, interrupt_manager=im)
|
||||
await dispatch_message(ws, ws_ctx, msg)
|
||||
last_call = ws.send_json.call_args[0][0]
|
||||
assert last_call["type"] == "message_complete"
|
||||
|
||||
@@ -164,14 +175,14 @@ class TestHandleUserMessage:
|
||||
@pytest.mark.asyncio
|
||||
async def test_expired_session(self) -> None:
|
||||
ws = _make_ws()
|
||||
graph = _make_graph()
|
||||
graph_ctx = _make_graph_ctx()
|
||||
sm = SessionManager(session_ttl_seconds=0)
|
||||
cb = TokenUsageCallbackHandler()
|
||||
|
||||
# First call creates the session (TTL=0)
|
||||
await handle_user_message(ws, graph, sm, cb, "t1", "hello")
|
||||
await handle_user_message(ws, graph_ctx, sm, cb, "t1", "hello")
|
||||
# Second call finds it expired
|
||||
await handle_user_message(ws, graph, sm, cb, "t1", "hello again")
|
||||
await handle_user_message(ws, graph_ctx, sm, cb, "t1", "hello again")
|
||||
call_data = ws.send_json.call_args[0][0]
|
||||
assert call_data["type"] == "error"
|
||||
assert "expired" in call_data["message"].lower()
|
||||
@@ -179,12 +190,12 @@ class TestHandleUserMessage:
|
||||
@pytest.mark.asyncio
|
||||
async def test_successful_message(self) -> None:
|
||||
ws = _make_ws()
|
||||
graph = _make_graph()
|
||||
graph_ctx = _make_graph_ctx()
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
|
||||
sm.touch("t1")
|
||||
await handle_user_message(ws, graph, sm, cb, "t1", "hello")
|
||||
await handle_user_message(ws, graph_ctx, sm, cb, "t1", "hello")
|
||||
last_call = ws.send_json.call_args[0][0]
|
||||
assert last_call["type"] == "message_complete"
|
||||
|
||||
@@ -193,13 +204,12 @@ class TestHandleUserMessage:
|
||||
ws = _make_ws()
|
||||
graph = AsyncMock()
|
||||
graph.astream = MagicMock(side_effect=RuntimeError("boom"))
|
||||
graph.intent_classifier = None
|
||||
graph.agent_registry = None
|
||||
graph_ctx = _make_graph_ctx(graph=graph)
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
|
||||
sm.touch("t1")
|
||||
await handle_user_message(ws, graph, sm, cb, "t1", "hello")
|
||||
await handle_user_message(ws, graph_ctx, sm, cb, "t1", "hello")
|
||||
call_data = ws.send_json.call_args[0][0]
|
||||
assert call_data["type"] == "error"
|
||||
|
||||
@@ -207,8 +217,6 @@ class TestHandleUserMessage:
|
||||
async def test_interrupt_registered_with_manager(self) -> None:
|
||||
ws = _make_ws()
|
||||
graph = AsyncMock()
|
||||
graph.intent_classifier = None
|
||||
graph.agent_registry = None
|
||||
graph.astream = MagicMock(return_value=AsyncIterHelper([]))
|
||||
|
||||
# Simulate interrupt in state
|
||||
@@ -220,13 +228,14 @@ class TestHandleUserMessage:
|
||||
state.tasks = (task,)
|
||||
graph.aget_state = AsyncMock(return_value=state)
|
||||
|
||||
graph_ctx = _make_graph_ctx(graph=graph)
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
im = InterruptManager()
|
||||
|
||||
sm.touch("t1")
|
||||
await handle_user_message(
|
||||
ws, graph, sm, cb, "t1", "cancel order 1042", interrupt_manager=im,
|
||||
ws, graph_ctx, sm, cb, "t1", "cancel order 1042", interrupt_manager=im,
|
||||
)
|
||||
|
||||
# Interrupt should be registered
|
||||
@@ -257,16 +266,17 @@ class TestHandleUserMessage:
|
||||
clarification_question="What do you mean?",
|
||||
)
|
||||
)
|
||||
graph.intent_classifier = mock_classifier
|
||||
mock_registry = MagicMock()
|
||||
mock_registry.list_agents = MagicMock(return_value=())
|
||||
graph.agent_registry = mock_registry
|
||||
graph_ctx = GraphContext(
|
||||
graph=graph, registry=mock_registry, intent_classifier=mock_classifier,
|
||||
)
|
||||
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
|
||||
sm.touch("t1")
|
||||
await handle_user_message(ws, graph, sm, cb, "t1", "hmm")
|
||||
await handle_user_message(ws, graph_ctx, sm, cb, "t1", "hmm")
|
||||
|
||||
calls = [c[0][0] for c in ws.send_json.call_args_list]
|
||||
clarification_msgs = [c for c in calls if c.get("type") == "clarification"]
|
||||
@@ -279,13 +289,13 @@ class TestHandleInterruptResponse:
|
||||
@pytest.mark.asyncio
|
||||
async def test_approved_interrupt(self) -> None:
|
||||
ws = _make_ws()
|
||||
graph = _make_graph()
|
||||
graph_ctx = _make_graph_ctx()
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
|
||||
sm.touch("t1")
|
||||
sm.extend_for_interrupt("t1")
|
||||
await handle_interrupt_response(ws, graph, sm, cb, "t1", True)
|
||||
await handle_interrupt_response(ws, graph_ctx, sm, cb, "t1", True)
|
||||
last_call = ws.send_json.call_args[0][0]
|
||||
assert last_call["type"] == "message_complete"
|
||||
|
||||
@@ -294,7 +304,7 @@ class TestHandleInterruptResponse:
|
||||
from unittest.mock import patch
|
||||
|
||||
ws = _make_ws()
|
||||
graph = _make_graph()
|
||||
graph_ctx = _make_graph_ctx()
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
im = InterruptManager(ttl_seconds=5)
|
||||
@@ -307,7 +317,7 @@ class TestHandleInterruptResponse:
|
||||
with patch("app.interrupt_manager.time") as mock_time:
|
||||
mock_time.time.return_value = im._interrupts["t1"].created_at + 10
|
||||
await handle_interrupt_response(
|
||||
ws, graph, sm, cb, "t1", True, interrupt_manager=im
|
||||
ws, graph_ctx, sm, cb, "t1", True, interrupt_manager=im
|
||||
)
|
||||
|
||||
call_data = ws.send_json.call_args[0][0]
|
||||
@@ -317,7 +327,7 @@ class TestHandleInterruptResponse:
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_interrupt_resolves(self) -> None:
|
||||
ws = _make_ws()
|
||||
graph = _make_graph()
|
||||
graph_ctx = _make_graph_ctx()
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
im = InterruptManager(ttl_seconds=1800)
|
||||
@@ -327,7 +337,7 @@ class TestHandleInterruptResponse:
|
||||
im.register("t1", "cancel_order", {})
|
||||
|
||||
await handle_interrupt_response(
|
||||
ws, graph, sm, cb, "t1", True, interrupt_manager=im
|
||||
ws, graph_ctx, sm, cb, "t1", True, interrupt_manager=im
|
||||
)
|
||||
|
||||
# Interrupt should be resolved
|
||||
@@ -374,19 +384,14 @@ class TestDispatchMessageWithTracking:
|
||||
@pytest.mark.asyncio
|
||||
async def test_conversation_tracker_called_on_message(self) -> None:
|
||||
ws = _make_ws()
|
||||
graph = _make_graph()
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
tracker = AsyncMock()
|
||||
pool = MagicMock()
|
||||
ws_ctx = _make_ws_ctx(sm=sm, conversation_tracker=tracker, pool=pool)
|
||||
|
||||
sm.touch("t1")
|
||||
msg = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"})
|
||||
await dispatch_message(
|
||||
ws, graph, sm, cb, msg,
|
||||
conversation_tracker=tracker,
|
||||
pool=pool,
|
||||
)
|
||||
await dispatch_message(ws, ws_ctx, msg)
|
||||
|
||||
tracker.ensure_conversation.assert_awaited_once_with(pool, "t1")
|
||||
tracker.record_turn.assert_awaited_once()
|
||||
@@ -394,53 +399,42 @@ class TestDispatchMessageWithTracking:
|
||||
@pytest.mark.asyncio
|
||||
async def test_analytics_recorder_called_on_message(self) -> None:
|
||||
ws = _make_ws()
|
||||
graph = _make_graph()
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
recorder = AsyncMock()
|
||||
pool = MagicMock()
|
||||
ws_ctx = _make_ws_ctx(sm=sm, analytics_recorder=recorder, pool=pool)
|
||||
|
||||
sm.touch("t1")
|
||||
msg = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"})
|
||||
await dispatch_message(
|
||||
ws, graph, sm, cb, msg,
|
||||
analytics_recorder=recorder,
|
||||
pool=pool,
|
||||
)
|
||||
await dispatch_message(ws, ws_ctx, msg)
|
||||
|
||||
recorder.record.assert_awaited_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tracker_failure_does_not_break_chat(self) -> None:
|
||||
ws = _make_ws()
|
||||
graph = _make_graph()
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
tracker = AsyncMock()
|
||||
tracker.ensure_conversation.side_effect = RuntimeError("DB down")
|
||||
pool = MagicMock()
|
||||
ws_ctx = _make_ws_ctx(sm=sm, conversation_tracker=tracker, pool=pool)
|
||||
|
||||
sm.touch("t1")
|
||||
msg = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"})
|
||||
# Should not raise despite tracker failure
|
||||
await dispatch_message(
|
||||
ws, graph, sm, cb, msg,
|
||||
conversation_tracker=tracker,
|
||||
pool=pool,
|
||||
)
|
||||
await dispatch_message(ws, ws_ctx, msg)
|
||||
last_call = ws.send_json.call_args[0][0]
|
||||
assert last_call["type"] == "message_complete"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_tracker_no_error(self) -> None:
|
||||
ws = _make_ws()
|
||||
graph = _make_graph()
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
ws_ctx = _make_ws_ctx(sm=sm)
|
||||
|
||||
sm.touch("t1")
|
||||
msg = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"})
|
||||
# No tracker or recorder passed -- should work fine
|
||||
await dispatch_message(ws, graph, sm, cb, msg)
|
||||
await dispatch_message(ws, ws_ctx, msg)
|
||||
last_call = ws.send_json.call_args[0][0]
|
||||
assert last_call["type"] == "message_complete"
|
||||
|
||||
Reference in New Issue
Block a user