diff --git a/backend/tests/integration/test_routing.py b/backend/tests/integration/test_routing.py new file mode 100644 index 0000000..4d8f24a --- /dev/null +++ b/backend/tests/integration/test_routing.py @@ -0,0 +1,339 @@ +"""Integration tests for multi-agent routing flow. + +Tests the full pipeline: intent classification -> supervisor routing -> +agent execution -> response streaming, exercising cross-module integration +with mocked LLM. + +Required by Phase 2 test plan: +- Unit: intent classification accuracy +- Unit: multi-intent sequential execution +- Integration: complete multi-agent routing flow +""" + +from __future__ import annotations + +import json +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from app.callbacks import TokenUsageCallbackHandler +from app.intent import ClassificationResult, IntentTarget, LLMIntentClassifier +from app.interrupt_manager import InterruptManager +from app.registry import AgentConfig +from app.session_manager import SessionManager +from app.ws_handler import dispatch_message + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +class AsyncIterHelper: + def __init__(self, items: list) -> None: + self._items = list(items) + + def __aiter__(self): + return self + + async def __anext__(self): + if not self._items: + raise StopAsyncIteration + return self._items.pop(0) + + +class FakeWS: + def __init__(self) -> None: + self.sent: list[dict] = [] + + async def send_json(self, data: dict) -> None: + self.sent.append(data) + + +def _chunk(content: str, node: str) -> tuple: + c = MagicMock() + c.content = content + c.tool_calls = [] + return (c, {"langgraph_node": node}) + + +def _tool_chunk(name: str, args: dict, node: str) -> tuple: + c = MagicMock() + c.content = "" + c.tool_calls = [{"name": name, "args": args}] + return (c, {"langgraph_node": node}) + + +def _state(*, interrupt: bool = 
False, data: dict | None = None): + s = MagicMock() + if interrupt: + obj = MagicMock() + obj.value = data or {} + t = MagicMock() + t.interrupts = (obj,) + s.tasks = (t,) + else: + s.tasks = () + return s + + +AGENTS = ( + AgentConfig(name="order_lookup", description="Looks up orders", permission="read", tools=["get_order_status", "get_tracking_info"]), + AgentConfig(name="order_actions", description="Modifies orders", permission="write", tools=["cancel_order"]), + AgentConfig(name="discount", description="Applies discounts", permission="write", tools=["apply_discount", "generate_coupon"]), + AgentConfig(name="fallback", description="Handles unclear requests", permission="read", tools=["fallback_respond"]), +) + + +def _make_classifier(result: ClassificationResult) -> AsyncMock: + """Create a mock classifier returning the given result.""" + classifier = AsyncMock() + classifier.classify = AsyncMock(return_value=result) + return classifier + + +def _make_graph( + classifier_result: ClassificationResult | None, + chunks: list, + state=None, +) -> MagicMock: + """Build a graph mock with optional intent classifier.""" + graph = MagicMock() + + if classifier_result is not None: + graph.intent_classifier = _make_classifier(classifier_result) + mock_registry = MagicMock() + mock_registry.list_agents = MagicMock(return_value=AGENTS) + graph.agent_registry = mock_registry + else: + graph.intent_classifier = None + graph.agent_registry = None + + graph.astream = MagicMock(return_value=AsyncIterHelper(list(chunks))) + graph.aget_state = AsyncMock(return_value=state or _state()) + return graph + + +async def _dispatch(graph, content: str, thread_id: str = "t1") -> list[dict]: + sm = SessionManager() + sm.touch(thread_id) + im = InterruptManager() + cb = TokenUsageCallbackHandler() + ws = FakeWS() + raw = json.dumps({"type": "message", "thread_id": thread_id, "content": content}) + await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im) + return ws.sent + + +# 
# ---------------------------------------------------------------------------
# Single-intent routing to each agent
# ---------------------------------------------------------------------------

@pytest.mark.integration
class TestSingleIntentRouting:
    """Verify single-intent messages route to the correct agent."""

    @pytest.mark.asyncio
    async def test_routes_to_order_lookup(self) -> None:
        result = ClassificationResult(
            intents=(IntentTarget(agent_name="order_lookup", confidence=0.95, reasoning="status query"),),
        )
        graph = _make_graph(result, [
            _tool_chunk("get_order_status", {"order_id": "1042"}, "order_lookup"),
            _chunk("Order 1042 is shipped.", "order_lookup"),
        ])

        msgs = await _dispatch(graph, "What is the status of order 1042?")

        tools = [m for m in msgs if m["type"] == "tool_call"]
        assert len(tools) == 1
        assert tools[0]["tool"] == "get_order_status"
        assert tools[0]["agent"] == "order_lookup"

        tokens = [m for m in msgs if m["type"] == "token"]
        assert any("shipped" in t["content"] for t in tokens)

    @pytest.mark.asyncio
    async def test_routes_to_order_actions(self) -> None:
        result = ClassificationResult(
            intents=(IntentTarget(agent_name="order_actions", confidence=0.9, reasoning="cancel"),),
        )
        # Write-permission agent: the run ends on a pending interrupt that
        # must surface to the client as an approval request.
        graph = _make_graph(
            result,
            [_tool_chunk("cancel_order", {"order_id": "1042"}, "order_actions")],
            state=_state(interrupt=True, data={"action": "cancel_order", "order_id": "1042"}),
        )

        msgs = await _dispatch(graph, "Cancel order 1042")

        tools = [m for m in msgs if m["type"] == "tool_call"]
        assert tools[0]["tool"] == "cancel_order"
        assert tools[0]["agent"] == "order_actions"

        interrupts = [m for m in msgs if m["type"] == "interrupt"]
        assert len(interrupts) == 1

    @pytest.mark.asyncio
    async def test_routes_to_discount(self) -> None:
        result = ClassificationResult(
            intents=(IntentTarget(agent_name="discount", confidence=0.88, reasoning="coupon"),),
        )
        graph = _make_graph(result, [
            _tool_chunk("generate_coupon", {"discount_percent": 15, "expiry_days": 30}, "discount"),
            _chunk("Here is your coupon: SAVE15-ABC12345", "discount"),
        ])

        msgs = await _dispatch(graph, "Give me a 15% coupon")

        tools = [m for m in msgs if m["type"] == "tool_call"]
        assert tools[0]["tool"] == "generate_coupon"
        assert tools[0]["agent"] == "discount"

    @pytest.mark.asyncio
    async def test_routes_to_fallback(self) -> None:
        result = ClassificationResult(
            intents=(IntentTarget(agent_name="fallback", confidence=0.7, reasoning="general"),),
        )
        graph = _make_graph(result, [
            _chunk("I can help with order inquiries.", "fallback"),
        ])

        msgs = await _dispatch(graph, "What can you do?")

        tokens = [m for m in msgs if m["type"] == "token"]
        assert tokens[0]["agent"] == "fallback"


# ---------------------------------------------------------------------------
# Multi-intent routing
# ---------------------------------------------------------------------------

@pytest.mark.integration
class TestMultiIntentRouting:
    """Verify multi-intent triggers sequential execution hint."""

    @pytest.mark.asyncio
    async def test_two_intents_inject_routing_hint(self) -> None:
        result = ClassificationResult(
            intents=(
                IntentTarget(agent_name="order_actions", confidence=0.9, reasoning="cancel"),
                IntentTarget(agent_name="discount", confidence=0.85, reasoning="discount"),
            ),
        )
        graph = _make_graph(result, [
            _tool_chunk("cancel_order", {"order_id": "1042"}, "order_actions"),
            _tool_chunk("apply_discount", {"order_id": "1042", "discount_percent": 10}, "discount"),
        ])

        # Reuse the shared _dispatch helper (same wiring the other tests
        # use); the graph mock remains inspectable afterwards.
        msgs = await _dispatch(graph, "取消订单 1042 并给我一个 10% 折扣")

        # Verify routing hint was injected into the outbound message.
        call_args = graph.astream.call_args[0][0]
        msg_content = call_args["messages"][0].content
        assert "[System:" in msg_content
        assert "order_actions" in msg_content
        assert "discount" in msg_content

        # Both tool calls should appear.
        tools = [m for m in msgs if m["type"] == "tool_call"]
        tool_names = {t["tool"] for t in tools}
        assert "cancel_order" in tool_names
        assert "apply_discount" in tool_names

    @pytest.mark.asyncio
    async def test_single_intent_no_routing_hint(self) -> None:
        result = ClassificationResult(
            intents=(IntentTarget(agent_name="order_lookup", confidence=0.95, reasoning="query"),),
        )
        graph = _make_graph(result, [_chunk("Order shipped.", "order_lookup")])

        await _dispatch(graph, "查询订单 1042")

        msg_content = graph.astream.call_args[0][0]["messages"][0].content
        assert "[System:" not in msg_content


# ---------------------------------------------------------------------------
# Ambiguity routing
# ---------------------------------------------------------------------------

@pytest.mark.integration
class TestAmbiguityRouting:
    """Verify ambiguous intents produce clarification, not agent calls."""

    @pytest.mark.asyncio
    async def test_ambiguous_skips_graph_returns_clarification(self) -> None:
        result = ClassificationResult(
            intents=(),
            is_ambiguous=True,
            clarification_question="Could you please clarify what you need?",
        )
        graph = _make_graph(result, [])

        msgs = await _dispatch(graph, "嗯...")

        clarifications = [m for m in msgs if m["type"] == "clarification"]
        assert len(clarifications) == 1
        assert "clarify" in clarifications[0]["message"]

        # Graph should NOT have been called.
        graph.astream.assert_not_called()

    @pytest.mark.asyncio
    async def test_low_confidence_triggers_ambiguity(self) -> None:
        """LLMIntentClassifier applies threshold -- low confidence -> ambiguous."""
        raw_result = ClassificationResult(
            intents=(IntentTarget(agent_name="fallback", confidence=0.2, reasoning="unclear"),),
            is_ambiguous=False,
        )
        mock_structured = MagicMock()
        mock_structured.ainvoke = AsyncMock(return_value=raw_result)
        mock_llm = MagicMock()
        mock_llm.with_structured_output = MagicMock(return_value=mock_structured)

        classifier = LLMIntentClassifier(mock_llm)
        result = await classifier.classify("hmm", AGENTS)

        assert result.is_ambiguous
        assert result.clarification_question is not None


# ---------------------------------------------------------------------------
# No classifier fallback
# ---------------------------------------------------------------------------

@pytest.mark.integration
class TestNoClassifierFallback:
    """Verify system works without intent classifier (falls back to supervisor prompt)."""

    @pytest.mark.asyncio
    async def test_no_classifier_routes_via_supervisor(self) -> None:
        graph = _make_graph(
            classifier_result=None,
            chunks=[_chunk("Order 1042 is shipped.", "order_lookup")],
        )

        msgs = await _dispatch(graph, "What is order 1042 status?")

        tokens = [m for m in msgs if m["type"] == "token"]
        assert len(tokens) == 1
        completes = [m for m in msgs if m["type"] == "message_complete"]
        assert len(completes) == 1