refactor: fix architectural issues across frontend and backend
Address all architecture review findings: P0 fixes: - Add API key authentication for admin endpoints (analytics, replay, openapi) and WebSocket connections via ADMIN_API_KEY env var - Add PostgreSQL-backed PgSessionManager and PgInterruptManager for multi-worker production deployments (in-memory defaults preserved) P1 fixes: - Implement actual tool generation in OpenAPI approve_job endpoint using generate_tool_code() and generate_agent_yaml() - Add missing clarification, interrupt_expired, and tool_result message handlers in frontend ChatPage P2 fixes: - Replace monkey-patching on CompiledStateGraph with typed GraphContext - Replace 9-param dispatch_message with WebSocketContext dataclass - Extract duplicate _envelope() into shared app/api_utils.py - Replace mutable module-level counter with crypto.randomUUID() - Remove hardcoded mock data from ReviewPage, use api.ts wrappers - Remove `as any` type escape from ReplayPage All 516 tests passing, 0 TypeScript errors.
This commit is contained in:
@@ -18,10 +18,12 @@ from unittest.mock import AsyncMock, MagicMock
|
||||
import pytest
|
||||
|
||||
from app.callbacks import TokenUsageCallbackHandler
|
||||
from app.graph_context import GraphContext
|
||||
from app.intent import ClassificationResult, IntentTarget, LLMIntentClassifier
|
||||
from app.interrupt_manager import InterruptManager
|
||||
from app.registry import AgentConfig
|
||||
from app.session_manager import SessionManager
|
||||
from app.ws_context import WebSocketContext
|
||||
from app.ws_handler import dispatch_message
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -103,36 +105,45 @@ def _make_classifier(result: ClassificationResult) -> AsyncMock:
|
||||
return classifier
|
||||
|
||||
|
||||
def _make_graph(
|
||||
def _make_graph_and_ctx(
|
||||
classifier_result: ClassificationResult | None,
|
||||
chunks: list,
|
||||
state=None,
|
||||
) -> MagicMock:
|
||||
"""Build a graph mock with optional intent classifier."""
|
||||
) -> tuple[MagicMock, GraphContext]:
|
||||
"""Build a graph mock and GraphContext with optional intent classifier."""
|
||||
graph = MagicMock()
|
||||
|
||||
if classifier_result is not None:
|
||||
graph.intent_classifier = _make_classifier(classifier_result)
|
||||
mock_registry = MagicMock()
|
||||
mock_registry.list_agents = MagicMock(return_value=AGENTS)
|
||||
graph.agent_registry = mock_registry
|
||||
else:
|
||||
graph.intent_classifier = None
|
||||
graph.agent_registry = None
|
||||
|
||||
graph.astream = MagicMock(return_value=AsyncIterHelper(list(chunks)))
|
||||
graph.aget_state = AsyncMock(return_value=state or _state())
|
||||
return graph
|
||||
|
||||
if classifier_result is not None:
|
||||
classifier = _make_classifier(classifier_result)
|
||||
mock_registry = MagicMock()
|
||||
mock_registry.list_agents = MagicMock(return_value=AGENTS)
|
||||
graph_ctx = GraphContext(
|
||||
graph=graph, registry=mock_registry, intent_classifier=classifier,
|
||||
)
|
||||
else:
|
||||
mock_registry = MagicMock()
|
||||
mock_registry.list_agents = MagicMock(return_value=())
|
||||
graph_ctx = GraphContext(
|
||||
graph=graph, registry=mock_registry, intent_classifier=None,
|
||||
)
|
||||
|
||||
return graph, graph_ctx
|
||||
|
||||
|
||||
async def _dispatch(graph, content: str, thread_id: str = "t1") -> list[dict]:
|
||||
async def _dispatch(graph_ctx: GraphContext, content: str, thread_id: str = "t1") -> list[dict]:
|
||||
sm = SessionManager()
|
||||
sm.touch(thread_id)
|
||||
im = InterruptManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
ws = FakeWS()
|
||||
ws_ctx = WebSocketContext(
|
||||
graph_ctx=graph_ctx, session_manager=sm,
|
||||
callback_handler=cb, interrupt_manager=im,
|
||||
)
|
||||
raw = json.dumps({"type": "message", "thread_id": thread_id, "content": content})
|
||||
await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im)
|
||||
await dispatch_message(ws, ws_ctx, raw)
|
||||
return ws.sent
|
||||
|
||||
|
||||
@@ -151,12 +162,12 @@ class TestSingleIntentRouting:
|
||||
agent_name="order_lookup", confidence=0.95, reasoning="status query",
|
||||
),),
|
||||
)
|
||||
graph = _make_graph(result, [
|
||||
graph, graph_ctx = _make_graph_and_ctx(result, [
|
||||
_tool_chunk("get_order_status", {"order_id": "1042"}, "order_lookup"),
|
||||
_chunk("Order 1042 is shipped.", "order_lookup"),
|
||||
])
|
||||
|
||||
msgs = await _dispatch(graph, "What is the status of order 1042?")
|
||||
msgs = await _dispatch(graph_ctx, "What is the status of order 1042?")
|
||||
|
||||
tools = [m for m in msgs if m["type"] == "tool_call"]
|
||||
assert len(tools) == 1
|
||||
@@ -171,13 +182,13 @@ class TestSingleIntentRouting:
|
||||
result = ClassificationResult(
|
||||
intents=(IntentTarget(agent_name="order_actions", confidence=0.9, reasoning="cancel"),),
|
||||
)
|
||||
graph = _make_graph(
|
||||
graph, graph_ctx = _make_graph_and_ctx(
|
||||
result,
|
||||
[_tool_chunk("cancel_order", {"order_id": "1042"}, "order_actions")],
|
||||
state=_state(interrupt=True, data={"action": "cancel_order", "order_id": "1042"}),
|
||||
)
|
||||
|
||||
msgs = await _dispatch(graph, "Cancel order 1042")
|
||||
msgs = await _dispatch(graph_ctx, "Cancel order 1042")
|
||||
|
||||
tools = [m for m in msgs if m["type"] == "tool_call"]
|
||||
assert tools[0]["tool"] == "cancel_order"
|
||||
@@ -191,12 +202,12 @@ class TestSingleIntentRouting:
|
||||
result = ClassificationResult(
|
||||
intents=(IntentTarget(agent_name="discount", confidence=0.88, reasoning="coupon"),),
|
||||
)
|
||||
graph = _make_graph(result, [
|
||||
graph, graph_ctx = _make_graph_and_ctx(result, [
|
||||
_tool_chunk("generate_coupon", {"discount_percent": 15, "expiry_days": 30}, "discount"),
|
||||
_chunk("Here is your coupon: SAVE15-ABC12345", "discount"),
|
||||
])
|
||||
|
||||
msgs = await _dispatch(graph, "Give me a 15% coupon")
|
||||
msgs = await _dispatch(graph_ctx, "Give me a 15% coupon")
|
||||
|
||||
tools = [m for m in msgs if m["type"] == "tool_call"]
|
||||
assert tools[0]["tool"] == "generate_coupon"
|
||||
@@ -207,11 +218,11 @@ class TestSingleIntentRouting:
|
||||
result = ClassificationResult(
|
||||
intents=(IntentTarget(agent_name="fallback", confidence=0.7, reasoning="general"),),
|
||||
)
|
||||
graph = _make_graph(result, [
|
||||
graph, graph_ctx = _make_graph_and_ctx(result, [
|
||||
_chunk("I can help with order inquiries.", "fallback"),
|
||||
])
|
||||
|
||||
msgs = await _dispatch(graph, "What can you do?")
|
||||
msgs = await _dispatch(graph_ctx, "What can you do?")
|
||||
|
||||
tokens = [m for m in msgs if m["type"] == "token"]
|
||||
assert tokens[0]["agent"] == "fallback"
|
||||
@@ -233,7 +244,7 @@ class TestMultiIntentRouting:
|
||||
IntentTarget(agent_name="discount", confidence=0.85, reasoning="discount"),
|
||||
),
|
||||
)
|
||||
graph = _make_graph(result, [
|
||||
graph, graph_ctx = _make_graph_and_ctx(result, [
|
||||
_tool_chunk("cancel_order", {"order_id": "1042"}, "order_actions"),
|
||||
_tool_chunk("apply_discount", {"order_id": "1042", "discount_percent": 10}, "discount"),
|
||||
])
|
||||
@@ -243,13 +254,17 @@ class TestMultiIntentRouting:
|
||||
im = InterruptManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
ws = FakeWS()
|
||||
ws_ctx = WebSocketContext(
|
||||
graph_ctx=graph_ctx, session_manager=sm,
|
||||
callback_handler=cb, interrupt_manager=im,
|
||||
)
|
||||
|
||||
raw = json.dumps({
|
||||
"type": "message",
|
||||
"thread_id": "t1",
|
||||
"content": "取消订单 1042 并给我一个 10% 折扣",
|
||||
})
|
||||
await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im)
|
||||
await dispatch_message(ws, ws_ctx, raw)
|
||||
|
||||
# Verify routing hint was injected
|
||||
call_args = graph.astream.call_args[0][0]
|
||||
@@ -269,16 +284,20 @@ class TestMultiIntentRouting:
|
||||
result = ClassificationResult(
|
||||
intents=(IntentTarget(agent_name="order_lookup", confidence=0.95, reasoning="query"),),
|
||||
)
|
||||
graph = _make_graph(result, [_chunk("Order shipped.", "order_lookup")])
|
||||
graph, graph_ctx = _make_graph_and_ctx(result, [_chunk("Order shipped.", "order_lookup")])
|
||||
|
||||
sm = SessionManager()
|
||||
sm.touch("t1")
|
||||
im = InterruptManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
ws = FakeWS()
|
||||
ws_ctx = WebSocketContext(
|
||||
graph_ctx=graph_ctx, session_manager=sm,
|
||||
callback_handler=cb, interrupt_manager=im,
|
||||
)
|
||||
|
||||
raw = json.dumps({"type": "message", "thread_id": "t1", "content": "查询订单 1042"})
|
||||
await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im)
|
||||
await dispatch_message(ws, ws_ctx, raw)
|
||||
|
||||
msg_content = graph.astream.call_args[0][0]["messages"][0].content
|
||||
assert "[System:" not in msg_content
|
||||
@@ -299,9 +318,9 @@ class TestAmbiguityRouting:
|
||||
is_ambiguous=True,
|
||||
clarification_question="Could you please clarify what you need?",
|
||||
)
|
||||
graph = _make_graph(result, [])
|
||||
graph, graph_ctx = _make_graph_and_ctx(result, [])
|
||||
|
||||
msgs = await _dispatch(graph, "嗯...")
|
||||
msgs = await _dispatch(graph_ctx, "嗯...")
|
||||
|
||||
clarifications = [m for m in msgs if m["type"] == "clarification"]
|
||||
assert len(clarifications) == 1
|
||||
@@ -339,12 +358,12 @@ class TestNoClassifierFallback:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_classifier_routes_via_supervisor(self) -> None:
|
||||
graph = _make_graph(
|
||||
graph, graph_ctx = _make_graph_and_ctx(
|
||||
classifier_result=None,
|
||||
chunks=[_chunk("Order 1042 is shipped.", "order_lookup")],
|
||||
)
|
||||
|
||||
msgs = await _dispatch(graph, "What is order 1042 status?")
|
||||
msgs = await _dispatch(graph_ctx, "What is order 1042 status?")
|
||||
|
||||
tokens = [m for m in msgs if m["type"] == "token"]
|
||||
assert len(tokens) == 1
|
||||
|
||||
Reference in New Issue
Block a user