Files
smart-support/backend/tests/integration/test_routing.py
Yaojia Wang af53111928 refactor: fix architectural issues across frontend and backend
Address all architecture review findings:

P0 fixes:
- Add API key authentication for admin endpoints (analytics, replay, openapi)
  and WebSocket connections via ADMIN_API_KEY env var
- Add PostgreSQL-backed PgSessionManager and PgInterruptManager for
  multi-worker production deployments (in-memory defaults preserved)

P1 fixes:
- Implement actual tool generation in OpenAPI approve_job endpoint
  using generate_tool_code() and generate_agent_yaml()
- Add missing clarification, interrupt_expired, and tool_result message
  handlers in frontend ChatPage

P2 fixes:
- Replace monkey-patching on CompiledStateGraph with typed GraphContext
- Replace 9-param dispatch_message with WebSocketContext dataclass
- Extract duplicate _envelope() into shared app/api_utils.py
- Replace mutable module-level counter with crypto.randomUUID()
- Remove hardcoded mock data from ReviewPage, use api.ts wrappers
- Remove `as any` type escape from ReplayPage

All 516 tests passing, 0 TypeScript errors.
2026-04-06 15:59:14 +02:00

372 lines
13 KiB
Python

"""Integration tests for multi-agent routing flow.
Tests the full pipeline: intent classification -> supervisor routing ->
agent execution -> response streaming, exercising cross-module integration
with mocked LLM.
Required by Phase 2 test plan:
- Unit: intent classification accuracy
- Unit: multi-intent sequential execution
- Integration: complete multi-agent routing flow
"""
from __future__ import annotations
import json
from unittest.mock import AsyncMock, MagicMock
import pytest
from app.callbacks import TokenUsageCallbackHandler
from app.graph_context import GraphContext
from app.intent import ClassificationResult, IntentTarget, LLMIntentClassifier
from app.interrupt_manager import InterruptManager
from app.registry import AgentConfig
from app.session_manager import SessionManager
from app.ws_context import WebSocketContext
from app.ws_handler import dispatch_message
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
class AsyncIterHelper:
    """Async iterator over a fixed snapshot of items, used to fake ``graph.astream``.

    The items are copied when the helper is constructed, so later changes to
    the caller's list do not affect what is yielded. The helper is single-use:
    once exhausted, every further ``__anext__`` raises ``StopAsyncIteration``.
    """

    def __init__(self, items: list) -> None:
        # Snapshot the input, then walk the snapshot with a plain iterator.
        self._remaining = iter(tuple(items))

    def __aiter__(self):
        return self

    async def __anext__(self):
        try:
            return next(self._remaining)
        except StopIteration:
            # Translate synchronous exhaustion into the async-iteration protocol.
            raise StopAsyncIteration from None
class FakeWS:
    """Stand-in for a WebSocket that records every JSON payload sent to it."""

    def __init__(self) -> None:
        # Payloads in the order ``send_json`` received them.
        self.sent: list[dict] = []

    async def send_json(self, data: dict) -> None:
        """Record *data* instead of transmitting it."""
        self.sent += [data]
def _chunk(content: str, node: str) -> tuple:
c = MagicMock()
c.content = content
c.tool_calls = []
return (c, {"langgraph_node": node})
def _tool_chunk(name: str, args: dict, node: str) -> tuple:
c = MagicMock()
c.content = ""
c.tool_calls = [{"name": name, "args": args}]
return (c, {"langgraph_node": node})
def _state(*, interrupt: bool = False, data: dict | None = None):
s = MagicMock()
if interrupt:
obj = MagicMock()
obj.value = data or {}
t = MagicMock()
t.interrupts = (obj,)
s.tasks = (t,)
else:
s.tasks = ()
return s
# Agent roster listed by the mocked registry when a classifier is configured,
# and passed directly to LLMIntentClassifier.classify in the ambiguity tests.
# The names match the ``langgraph_node`` values used in the fake stream chunks.
AGENTS = (
    AgentConfig(
        name="order_lookup",
        description="Looks up orders",
        permission="read",
        tools=["get_order_status", "get_tracking_info"],
    ),
    AgentConfig(
        name="order_actions",
        description="Modifies orders",
        permission="write",
        tools=["cancel_order"],
    ),
    AgentConfig(
        name="discount",
        description="Applies discounts",
        permission="write",
        tools=["apply_discount", "generate_coupon"],
    ),
    AgentConfig(
        name="fallback",
        description="Handles unclear requests",
        permission="read",
        tools=["fallback_respond"],
    ),
)
def _make_classifier(result: ClassificationResult) -> AsyncMock:
"""Create a mock classifier returning the given result."""
classifier = AsyncMock()
classifier.classify = AsyncMock(return_value=result)
return classifier
def _make_graph_and_ctx(
    classifier_result: ClassificationResult | None,
    chunks: list,
    state=None,
) -> tuple[MagicMock, GraphContext]:
    """Build a mocked compiled graph and the ``GraphContext`` that wraps it.

    Args:
        classifier_result: When not ``None``, a mock intent classifier that
            returns this result is attached, and the mock registry lists
            ``AGENTS``. When ``None``, no classifier is attached and the
            registry lists no agents.
        chunks: ``(message_chunk, metadata)`` pairs yielded by ``graph.astream``.
        state: Value returned by ``graph.aget_state``. When ``None``, a state
            with no pending tasks is used.

    Returns:
        ``(graph, graph_ctx)``, so tests can inspect calls made on the graph mock.
    """
    graph = MagicMock()
    # Every astream call returns the same single-use iterator, so only the
    # first call yields the chunks.
    graph.astream = MagicMock(return_value=AsyncIterHelper(chunks))
    # Check for None explicitly so a caller-supplied state is never replaced
    # just because it happens to be falsy.
    graph.aget_state = AsyncMock(
        return_value=_state() if state is None else state,
    )

    if classifier_result is None:
        classifier, agents = None, ()
    else:
        classifier, agents = _make_classifier(classifier_result), AGENTS

    registry = MagicMock()
    registry.list_agents = MagicMock(return_value=agents)
    graph_ctx = GraphContext(
        graph=graph, registry=registry, intent_classifier=classifier,
    )
    return graph, graph_ctx
async def _dispatch(graph_ctx: GraphContext, content: str, thread_id: str = "t1") -> list[dict]:
    """Push one user ``message`` frame through ``dispatch_message``.

    Fresh session, interrupt, and token-usage managers are created on every
    call, and *thread_id* is touched in the session manager beforehand.

    Returns:
        Every JSON payload the handler sent to the fake socket, in send order.
    """
    sessions = SessionManager()
    sessions.touch(thread_id)
    ws_ctx = WebSocketContext(
        graph_ctx=graph_ctx,
        session_manager=sessions,
        interrupt_manager=InterruptManager(),
        callback_handler=TokenUsageCallbackHandler(),
    )
    socket = FakeWS()
    frame = {"type": "message", "thread_id": thread_id, "content": content}
    await dispatch_message(socket, ws_ctx, json.dumps(frame))
    return socket.sent
# ---------------------------------------------------------------------------
# Single-intent routing to each agent
# ---------------------------------------------------------------------------
@pytest.mark.integration
class TestSingleIntentRouting:
    """Verify single-intent messages route to the correct agent."""

    @pytest.mark.asyncio
    async def test_routes_to_order_lookup(self) -> None:
        """A status query streams the lookup tool call and the agent's answer."""
        result = ClassificationResult(
            intents=(IntentTarget(
                agent_name="order_lookup", confidence=0.95, reasoning="status query",
            ),),
        )
        _, graph_ctx = _make_graph_and_ctx(result, [
            _tool_chunk("get_order_status", {"order_id": "1042"}, "order_lookup"),
            _chunk("Order 1042 is shipped.", "order_lookup"),
        ])
        msgs = await _dispatch(graph_ctx, "What is the status of order 1042?")
        tools = [m for m in msgs if m["type"] == "tool_call"]
        assert len(tools) == 1
        assert tools[0]["tool"] == "get_order_status"
        assert tools[0]["agent"] == "order_lookup"
        tokens = [m for m in msgs if m["type"] == "token"]
        assert any("shipped" in t["content"] for t in tokens)

    @pytest.mark.asyncio
    async def test_routes_to_order_actions(self) -> None:
        """A cancel request streams the write tool call and one interrupt message."""
        result = ClassificationResult(
            intents=(IntentTarget(agent_name="order_actions", confidence=0.9, reasoning="cancel"),),
        )
        _, graph_ctx = _make_graph_and_ctx(
            result,
            [_tool_chunk("cancel_order", {"order_id": "1042"}, "order_actions")],
            # aget_state reports a pending interrupt carrying the action details.
            state=_state(interrupt=True, data={"action": "cancel_order", "order_id": "1042"}),
        )
        msgs = await _dispatch(graph_ctx, "Cancel order 1042")
        tools = [m for m in msgs if m["type"] == "tool_call"]
        # Guard before indexing, so a missing message fails as an assertion
        # rather than an IndexError.
        assert tools, "expected a tool_call message"
        assert tools[0]["tool"] == "cancel_order"
        assert tools[0]["agent"] == "order_actions"
        interrupts = [m for m in msgs if m["type"] == "interrupt"]
        assert len(interrupts) == 1

    @pytest.mark.asyncio
    async def test_routes_to_discount(self) -> None:
        """A coupon request streams the coupon tool call and the coupon text."""
        result = ClassificationResult(
            intents=(IntentTarget(agent_name="discount", confidence=0.88, reasoning="coupon"),),
        )
        _, graph_ctx = _make_graph_and_ctx(result, [
            _tool_chunk("generate_coupon", {"discount_percent": 15, "expiry_days": 30}, "discount"),
            _chunk("Here is your coupon: SAVE15-ABC12345", "discount"),
        ])
        msgs = await _dispatch(graph_ctx, "Give me a 15% coupon")
        tools = [m for m in msgs if m["type"] == "tool_call"]
        assert tools, "expected a tool_call message"
        assert tools[0]["tool"] == "generate_coupon"
        assert tools[0]["agent"] == "discount"
        # The agent's text chunk is fed in above, so check that it reaches the client.
        tokens = [m for m in msgs if m["type"] == "token"]
        assert any("SAVE15" in t["content"] for t in tokens)

    @pytest.mark.asyncio
    async def test_routes_to_fallback(self) -> None:
        """A generic question streams tokens attributed to the fallback agent."""
        result = ClassificationResult(
            intents=(IntentTarget(agent_name="fallback", confidence=0.7, reasoning="general"),),
        )
        _, graph_ctx = _make_graph_and_ctx(result, [
            _chunk("I can help with order inquiries.", "fallback"),
        ])
        msgs = await _dispatch(graph_ctx, "What can you do?")
        tokens = [m for m in msgs if m["type"] == "token"]
        assert tokens, "expected at least one token message"
        assert tokens[0]["agent"] == "fallback"
# ---------------------------------------------------------------------------
# Multi-intent routing
# ---------------------------------------------------------------------------
@pytest.mark.integration
class TestMultiIntentRouting:
    """Verify multi-intent triggers sequential execution hint."""

    @pytest.mark.asyncio
    async def test_two_intents_inject_routing_hint(self) -> None:
        """Two intents add a ``[System: ...]`` hint naming both agents to the graph input."""
        result = ClassificationResult(
            intents=(
                IntentTarget(agent_name="order_actions", confidence=0.9, reasoning="cancel"),
                IntentTarget(agent_name="discount", confidence=0.85, reasoning="discount"),
            ),
        )
        graph, graph_ctx = _make_graph_and_ctx(result, [
            _tool_chunk("cancel_order", {"order_id": "1042"}, "order_actions"),
            _tool_chunk("apply_discount", {"order_id": "1042", "discount_percent": 10}, "discount"),
        ])
        # Reuse the shared helper instead of copying its setup inline.
        # Message text: "Cancel order 1042 and give me a 10% discount".
        msgs = await _dispatch(graph_ctx, "取消订单 1042 并给我一个 10% 折扣")

        # Verify the routing hint was injected into the first message passed to the graph.
        graph.astream.assert_called()
        msg_content = graph.astream.call_args.args[0]["messages"][0].content
        assert "[System:" in msg_content
        assert "order_actions" in msg_content
        assert "discount" in msg_content

        # Both tool calls should appear.
        tool_names = {m["tool"] for m in msgs if m["type"] == "tool_call"}
        assert "cancel_order" in tool_names
        assert "apply_discount" in tool_names

    @pytest.mark.asyncio
    async def test_single_intent_no_routing_hint(self) -> None:
        """A single intent leaves the user message without a ``[System: ...]`` hint."""
        result = ClassificationResult(
            intents=(IntentTarget(agent_name="order_lookup", confidence=0.95, reasoning="query"),),
        )
        graph, graph_ctx = _make_graph_and_ctx(result, [_chunk("Order shipped.", "order_lookup")])
        # Message text: "Look up order 1042".
        await _dispatch(graph_ctx, "查询订单 1042")

        graph.astream.assert_called()
        msg_content = graph.astream.call_args.args[0]["messages"][0].content
        assert "[System:" not in msg_content
# ---------------------------------------------------------------------------
# Ambiguity routing
# ---------------------------------------------------------------------------
@pytest.mark.integration
class TestAmbiguityRouting:
    """Verify ambiguous intents produce clarification, not agent calls."""

    @pytest.mark.asyncio
    async def test_ambiguous_skips_graph_returns_clarification(self) -> None:
        """An ambiguous classification yields one clarification message and no graph run."""
        ambiguous = ClassificationResult(
            intents=(),
            is_ambiguous=True,
            clarification_question="Could you please clarify what you need?",
        )
        graph, graph_ctx = _make_graph_and_ctx(ambiguous, [])

        # Message text: "Um..." -- deliberately vague.
        sent = await _dispatch(graph_ctx, "嗯...")

        clarifications = [frame for frame in sent if frame["type"] == "clarification"]
        assert len(clarifications) == 1
        assert "clarify" in clarifications[0]["message"]
        # The graph must not be streamed for an ambiguous request.
        graph.astream.assert_not_called()

    @pytest.mark.asyncio
    async def test_low_confidence_triggers_ambiguity(self) -> None:
        """LLMIntentClassifier marks a low-confidence raw result as ambiguous."""
        raw_result = ClassificationResult(
            intents=(IntentTarget(agent_name="fallback", confidence=0.2, reasoning="unclear"),),
            is_ambiguous=False,
        )
        # with_structured_output(...) returns an object whose ainvoke yields raw_result.
        structured = MagicMock(ainvoke=AsyncMock(return_value=raw_result))
        llm = MagicMock(with_structured_output=MagicMock(return_value=structured))

        verdict = await LLMIntentClassifier(llm).classify("hmm", AGENTS)

        assert verdict.is_ambiguous
        assert verdict.clarification_question is not None
# ---------------------------------------------------------------------------
# No classifier fallback
# ---------------------------------------------------------------------------
@pytest.mark.integration
class TestNoClassifierFallback:
    """Verify system works without intent classifier (falls back to supervisor prompt)."""

    @pytest.mark.asyncio
    async def test_no_classifier_routes_via_supervisor(self) -> None:
        """With no classifier configured, the streamed answer and completion still arrive."""
        # The graph mock is not inspected here, so discard it.
        _, graph_ctx = _make_graph_and_ctx(
            classifier_result=None,
            chunks=[_chunk("Order 1042 is shipped.", "order_lookup")],
        )
        msgs = await _dispatch(graph_ctx, "What is order 1042 status?")
        tokens = [m for m in msgs if m["type"] == "token"]
        assert len(tokens) == 1
        completes = [m for m in msgs if m["type"] == "message_complete"]
        assert len(completes) == 1