refactor: fix architectural issues across frontend and backend

Address all architecture review findings:

P0 fixes:
- Add API key authentication for admin endpoints (analytics, replay, openapi)
  and WebSocket connections via ADMIN_API_KEY env var
- Add PostgreSQL-backed PgSessionManager and PgInterruptManager for
  multi-worker production deployments (in-memory defaults preserved)

P1 fixes:
- Implement actual tool generation in OpenAPI approve_job endpoint
  using generate_tool_code() and generate_agent_yaml()
- Add missing clarification, interrupt_expired, and tool_result message
  handlers in frontend ChatPage

P2 fixes:
- Replace monkey-patching on CompiledStateGraph with typed GraphContext
- Replace 9-param dispatch_message with WebSocketContext dataclass
- Extract duplicate _envelope() into shared app/api_utils.py
- Replace mutable module-level counter with crypto.randomUUID()
- Remove hardcoded mock data from ReviewPage, use api.ts wrappers
- Remove `as any` type escape from ReplayPage

All 516 tests passing, 0 TypeScript errors.
This commit is contained in:
Yaojia Wang
2026-04-06 15:59:14 +02:00
parent b8654aa31f
commit af53111928
29 changed files with 1183 additions and 473 deletions

View File

@@ -8,7 +8,9 @@ from unittest.mock import AsyncMock, MagicMock
import pytest
from app.callbacks import TokenUsageCallbackHandler
from app.graph_context import GraphContext
from app.session_manager import SessionManager
from app.ws_context import WebSocketContext
from app.ws_handler import dispatch_message
pytestmark = pytest.mark.unit
@@ -20,7 +22,7 @@ def _make_ws() -> AsyncMock:
return ws
def _make_graph() -> AsyncMock:
def _make_graph() -> MagicMock:
graph = AsyncMock()
class AsyncIterHelper:
@@ -34,23 +36,32 @@ def _make_graph() -> AsyncMock:
state = MagicMock()
state.tasks = ()
graph.aget_state = AsyncMock(return_value=state)
graph.intent_classifier = None
graph.agent_registry = None
return graph
def _make_ws_ctx(sm: SessionManager | None = None) -> WebSocketContext:
graph = _make_graph()
registry = MagicMock()
registry.list_agents = MagicMock(return_value=())
graph_ctx = GraphContext(graph=graph, registry=registry, intent_classifier=None)
return WebSocketContext(
graph_ctx=graph_ctx,
session_manager=sm or SessionManager(),
callback_handler=TokenUsageCallbackHandler(),
)
@pytest.mark.unit
class TestEmptyMessageHandling:
@pytest.mark.asyncio
async def test_empty_message_content_returns_error(self) -> None:
ws = _make_ws()
graph = _make_graph()
sm = SessionManager()
cb = TokenUsageCallbackHandler()
ws_ctx = _make_ws_ctx(sm=sm)
sm.touch("t1")
msg = json.dumps({"type": "message", "thread_id": "t1", "content": ""})
await dispatch_message(ws, graph, sm, cb, msg)
await dispatch_message(ws, ws_ctx, msg)
call_data = ws.send_json.call_args[0][0]
assert call_data["type"] == "error"
@@ -60,13 +71,12 @@ class TestEmptyMessageHandling:
@pytest.mark.asyncio
async def test_whitespace_only_message_treated_as_empty(self) -> None:
ws = _make_ws()
graph = _make_graph()
sm = SessionManager()
cb = TokenUsageCallbackHandler()
ws_ctx = _make_ws_ctx(sm=sm)
sm.touch("t1")
msg = json.dumps({"type": "message", "thread_id": "t1", "content": " "})
await dispatch_message(ws, graph, sm, cb, msg)
await dispatch_message(ws, ws_ctx, msg)
call_data = ws.send_json.call_args[0][0]
assert call_data["type"] == "error"
@@ -77,14 +87,13 @@ class TestOversizedMessageHandling:
@pytest.mark.asyncio
async def test_content_over_10000_chars_returns_error(self) -> None:
ws = _make_ws()
graph = _make_graph()
sm = SessionManager()
cb = TokenUsageCallbackHandler()
ws_ctx = _make_ws_ctx(sm=sm)
sm.touch("t1")
content = "x" * 10001
msg = json.dumps({"type": "message", "thread_id": "t1", "content": content})
await dispatch_message(ws, graph, sm, cb, msg)
await dispatch_message(ws, ws_ctx, msg)
call_data = ws.send_json.call_args[0][0]
assert call_data["type"] == "error"
@@ -93,14 +102,13 @@ class TestOversizedMessageHandling:
@pytest.mark.asyncio
async def test_content_exactly_10000_chars_is_accepted(self) -> None:
ws = _make_ws()
graph = _make_graph()
sm = SessionManager()
cb = TokenUsageCallbackHandler()
ws_ctx = _make_ws_ctx(sm=sm)
sm.touch("t1")
content = "x" * 10000
msg = json.dumps({"type": "message", "thread_id": "t1", "content": content})
await dispatch_message(ws, graph, sm, cb, msg)
await dispatch_message(ws, ws_ctx, msg)
last_call = ws.send_json.call_args[0][0]
# Should be processed, not an error about length
@@ -110,12 +118,10 @@ class TestOversizedMessageHandling:
@pytest.mark.asyncio
async def test_raw_message_over_32kb_returns_error(self) -> None:
ws = _make_ws()
graph = _make_graph()
sm = SessionManager()
cb = TokenUsageCallbackHandler()
ws_ctx = _make_ws_ctx()
large_msg = "x" * 40_000
await dispatch_message(ws, graph, sm, cb, large_msg)
await dispatch_message(ws, ws_ctx, large_msg)
call_data = ws.send_json.call_args[0][0]
assert call_data["type"] == "error"
@@ -127,11 +133,9 @@ class TestInvalidJsonHandling:
@pytest.mark.asyncio
async def test_invalid_json_returns_error(self) -> None:
ws = _make_ws()
graph = _make_graph()
sm = SessionManager()
cb = TokenUsageCallbackHandler()
ws_ctx = _make_ws_ctx()
await dispatch_message(ws, graph, sm, cb, "not valid json {{")
await dispatch_message(ws, ws_ctx, "not valid json {{")
call_data = ws.send_json.call_args[0][0]
assert call_data["type"] == "error"
@@ -140,11 +144,9 @@ class TestInvalidJsonHandling:
@pytest.mark.asyncio
async def test_empty_string_returns_json_error(self) -> None:
ws = _make_ws()
graph = _make_graph()
sm = SessionManager()
cb = TokenUsageCallbackHandler()
ws_ctx = _make_ws_ctx()
await dispatch_message(ws, graph, sm, cb, "")
await dispatch_message(ws, ws_ctx, "")
call_data = ws.send_json.call_args[0][0]
assert call_data["type"] == "error"
@@ -152,11 +154,9 @@ class TestInvalidJsonHandling:
@pytest.mark.asyncio
async def test_json_array_not_object_returns_error(self) -> None:
ws = _make_ws()
graph = _make_graph()
sm = SessionManager()
cb = TokenUsageCallbackHandler()
ws_ctx = _make_ws_ctx()
await dispatch_message(ws, graph, sm, cb, '["not", "an", "object"]')
await dispatch_message(ws, ws_ctx, '["not", "an", "object"]')
call_data = ws.send_json.call_args[0][0]
assert call_data["type"] == "error"
@@ -167,17 +167,15 @@ class TestRateLimiting:
@pytest.mark.asyncio
async def test_rapid_fire_messages_rate_limited(self) -> None:
ws = _make_ws()
_make_graph() # ensure graph factory works, not needed directly
sm = SessionManager()
cb = TokenUsageCallbackHandler()
sm.touch("t1")
# Simulate 11 rapid messages (exceeds 10 per 10 seconds limit)
rate_limit_triggered = False
for i in range(11):
graph2 = _make_graph() # fresh graph each time
await dispatch_message(ws, graph2, sm, cb, json.dumps({
ws_ctx = _make_ws_ctx(sm=sm)
await dispatch_message(ws, ws_ctx, json.dumps({
"type": "message",
"thread_id": "t1",
"content": f"message {i}",
@@ -193,19 +191,18 @@ class TestRateLimiting:
async def test_different_threads_have_separate_rate_limits(self) -> None:
ws = _make_ws()
sm = SessionManager()
cb = TokenUsageCallbackHandler()
sm.touch("t1")
sm.touch("t2")
# Send 5 messages on t1 and 5 on t2 -- neither should be rate limited
for i in range(5):
graph1 = _make_graph()
graph2 = _make_graph()
await dispatch_message(ws, graph1, sm, cb, json.dumps({
ws_ctx1 = _make_ws_ctx(sm=sm)
ws_ctx2 = _make_ws_ctx(sm=sm)
await dispatch_message(ws, ws_ctx1, json.dumps({
"type": "message", "thread_id": "t1", "content": f"msg {i}",
}))
await dispatch_message(ws, graph2, sm, cb, json.dumps({
await dispatch_message(ws, ws_ctx2, json.dumps({
"type": "message", "thread_id": "t2", "content": f"msg {i}",
}))