refactor: fix architectural issues across frontend and backend
Address all architecture review findings: P0 fixes: - Add API key authentication for admin endpoints (analytics, replay, openapi) and WebSocket connections via ADMIN_API_KEY env var - Add PostgreSQL-backed PgSessionManager and PgInterruptManager for multi-worker production deployments (in-memory defaults preserved) P1 fixes: - Implement actual tool generation in OpenAPI approve_job endpoint using generate_tool_code() and generate_agent_yaml() - Add missing clarification, interrupt_expired, and tool_result message handlers in frontend ChatPage P2 fixes: - Replace monkey-patching on CompiledStateGraph with typed GraphContext - Replace 9-param dispatch_message with WebSocketContext dataclass - Extract duplicate _envelope() into shared app/api_utils.py - Replace mutable module-level counter with crypto.randomUUID() - Remove hardcoded mock data from ReviewPage, use api.ts wrappers - Remove `as any` type escape from ReplayPage All 516 tests passing, 0 TypeScript errors.
This commit is contained in:
@@ -20,10 +20,12 @@ import pytest
|
||||
|
||||
from app.callbacks import TokenUsageCallbackHandler
|
||||
from app.escalation import EscalationPayload, NoOpEscalator, WebhookEscalator
|
||||
from app.graph_context import GraphContext
|
||||
from app.intent import ClassificationResult, IntentTarget, LLMIntentClassifier
|
||||
from app.interrupt_manager import InterruptManager
|
||||
from app.registry import AgentConfig, AgentRegistry
|
||||
from app.session_manager import SessionManager
|
||||
from app.ws_context import WebSocketContext
|
||||
from app.ws_handler import dispatch_message
|
||||
|
||||
TEMPLATES_DIR = Path(__file__).parent.parent.parent / "templates"
|
||||
@@ -128,10 +130,8 @@ class TestCheckpoint1OrderQueryRouting:
|
||||
mock_classifier.classify = AsyncMock(return_value=ClassificationResult(
|
||||
intents=(IntentTarget(agent_name="order_lookup", confidence=0.95, reasoning="query"),),
|
||||
))
|
||||
graph.intent_classifier = mock_classifier
|
||||
mock_registry = MagicMock()
|
||||
mock_registry.list_agents = MagicMock(return_value=())
|
||||
graph.agent_registry = mock_registry
|
||||
|
||||
# Graph streams order_lookup response
|
||||
graph.astream = MagicMock(return_value=AsyncIterHelper([
|
||||
@@ -140,14 +140,21 @@ class TestCheckpoint1OrderQueryRouting:
|
||||
]))
|
||||
graph.aget_state = AsyncMock(return_value=_state())
|
||||
|
||||
graph_ctx = GraphContext(
|
||||
graph=graph, registry=mock_registry, intent_classifier=mock_classifier,
|
||||
)
|
||||
sm = SessionManager()
|
||||
sm.touch("t1")
|
||||
im = InterruptManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
ws = FakeWS()
|
||||
ws_ctx = WebSocketContext(
|
||||
graph_ctx=graph_ctx, session_manager=sm,
|
||||
callback_handler=cb, interrupt_manager=im,
|
||||
)
|
||||
|
||||
raw = json.dumps({"type": "message", "thread_id": "t1", "content": "查询订单 1042"})
|
||||
await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im)
|
||||
await dispatch_message(ws, ws_ctx, raw)
|
||||
|
||||
tool_msgs = [m for m in ws.sent if m["type"] == "tool_call"]
|
||||
assert any(m["tool"] == "get_order_status" for m in tool_msgs)
|
||||
@@ -201,25 +208,30 @@ class TestCheckpoint2MultiIntentSequential:
|
||||
IntentTarget(agent_name="discount", confidence=0.85, reasoning="discount"),
|
||||
),
|
||||
))
|
||||
graph.intent_classifier = mock_classifier
|
||||
mock_registry = MagicMock()
|
||||
mock_registry.list_agents = MagicMock(return_value=())
|
||||
graph.agent_registry = mock_registry
|
||||
graph.astream = MagicMock(return_value=AsyncIterHelper([]))
|
||||
graph.aget_state = AsyncMock(return_value=_state())
|
||||
|
||||
graph_ctx = GraphContext(
|
||||
graph=graph, registry=mock_registry, intent_classifier=mock_classifier,
|
||||
)
|
||||
sm = SessionManager()
|
||||
sm.touch("t1")
|
||||
im = InterruptManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
ws = FakeWS()
|
||||
ws_ctx = WebSocketContext(
|
||||
graph_ctx=graph_ctx, session_manager=sm,
|
||||
callback_handler=cb, interrupt_manager=im,
|
||||
)
|
||||
|
||||
raw = json.dumps({
|
||||
"type": "message",
|
||||
"thread_id": "t1",
|
||||
"content": "取消订单 1042 并给我一个 10% 折扣",
|
||||
})
|
||||
await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im)
|
||||
await dispatch_message(ws, ws_ctx, raw)
|
||||
|
||||
# Verify the graph was called with the routing hint in the message
|
||||
call_args = graph.astream.call_args
|
||||
@@ -267,21 +279,26 @@ class TestCheckpoint3AmbiguousClarification:
|
||||
"Could you please provide more details about what you need help with?"
|
||||
),
|
||||
))
|
||||
graph.intent_classifier = mock_classifier
|
||||
mock_registry = MagicMock()
|
||||
mock_registry.list_agents = MagicMock(return_value=())
|
||||
graph.agent_registry = mock_registry
|
||||
graph.astream = MagicMock(return_value=AsyncIterHelper([]))
|
||||
graph.aget_state = AsyncMock(return_value=_state())
|
||||
|
||||
graph_ctx = GraphContext(
|
||||
graph=graph, registry=mock_registry, intent_classifier=mock_classifier,
|
||||
)
|
||||
sm = SessionManager()
|
||||
sm.touch("t1")
|
||||
im = InterruptManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
ws = FakeWS()
|
||||
ws_ctx = WebSocketContext(
|
||||
graph_ctx=graph_ctx, session_manager=sm,
|
||||
callback_handler=cb, interrupt_manager=im,
|
||||
)
|
||||
|
||||
raw = json.dumps({"type": "message", "thread_id": "t1", "content": "嗯..."})
|
||||
await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im)
|
||||
await dispatch_message(ws, ws_ctx, raw)
|
||||
|
||||
clarifications = [m for m in ws.sent if m["type"] == "clarification"]
|
||||
assert len(clarifications) == 1
|
||||
@@ -303,20 +320,26 @@ class TestCheckpoint4InterruptTTLAutoCancel:
|
||||
async def test_30min_expired_interrupt_auto_cancels(self) -> None:
|
||||
st = _state(interrupt=True, data={"action": "cancel_order", "order_id": "1042"})
|
||||
graph = MagicMock()
|
||||
graph.intent_classifier = None
|
||||
graph.agent_registry = None
|
||||
graph.astream = MagicMock(return_value=AsyncIterHelper([]))
|
||||
graph.aget_state = AsyncMock(return_value=st)
|
||||
|
||||
mock_registry = MagicMock()
|
||||
mock_registry.list_agents = MagicMock(return_value=())
|
||||
graph_ctx = GraphContext(graph=graph, registry=mock_registry, intent_classifier=None)
|
||||
|
||||
sm = SessionManager()
|
||||
sm.touch("t1")
|
||||
im = InterruptManager(ttl_seconds=1800) # 30 minutes
|
||||
cb = TokenUsageCallbackHandler()
|
||||
ws = FakeWS()
|
||||
ws_ctx = WebSocketContext(
|
||||
graph_ctx=graph_ctx, session_manager=sm,
|
||||
callback_handler=cb, interrupt_manager=im,
|
||||
)
|
||||
|
||||
# Trigger interrupt
|
||||
raw = json.dumps({"type": "message", "thread_id": "t1", "content": "Cancel order 1042"})
|
||||
await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im)
|
||||
await dispatch_message(ws, ws_ctx, raw)
|
||||
|
||||
interrupts = [m for m in ws.sent if m["type"] == "interrupt"]
|
||||
assert len(interrupts) == 1
|
||||
@@ -333,7 +356,7 @@ class TestCheckpoint4InterruptTTLAutoCancel:
|
||||
"thread_id": "t1",
|
||||
"approved": True,
|
||||
})
|
||||
await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im)
|
||||
await dispatch_message(ws, ws_ctx, raw)
|
||||
|
||||
# Should get retry prompt, NOT resume the graph
|
||||
expired_msgs = [m for m in ws.sent if m["type"] == "interrupt_expired"]
|
||||
|
||||
Reference in New Issue
Block a user