"""Integration tests for multi-agent routing flow. Tests the full pipeline: intent classification -> supervisor routing -> agent execution -> response streaming, exercising cross-module integration with mocked LLM. Required by Phase 2 test plan: - Unit: intent classification accuracy - Unit: multi-intent sequential execution - Integration: complete multi-agent routing flow """ from __future__ import annotations import json from unittest.mock import AsyncMock, MagicMock import pytest from app.callbacks import TokenUsageCallbackHandler from app.intent import ClassificationResult, IntentTarget, LLMIntentClassifier from app.interrupt_manager import InterruptManager from app.registry import AgentConfig from app.session_manager import SessionManager from app.ws_handler import dispatch_message # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- class AsyncIterHelper: def __init__(self, items: list) -> None: self._items = list(items) def __aiter__(self): return self async def __anext__(self): if not self._items: raise StopAsyncIteration return self._items.pop(0) class FakeWS: def __init__(self) -> None: self.sent: list[dict] = [] async def send_json(self, data: dict) -> None: self.sent.append(data) def _chunk(content: str, node: str) -> tuple: c = MagicMock() c.content = content c.tool_calls = [] return (c, {"langgraph_node": node}) def _tool_chunk(name: str, args: dict, node: str) -> tuple: c = MagicMock() c.content = "" c.tool_calls = [{"name": name, "args": args}] return (c, {"langgraph_node": node}) def _state(*, interrupt: bool = False, data: dict | None = None): s = MagicMock() if interrupt: obj = MagicMock() obj.value = data or {} t = MagicMock() t.interrupts = (obj,) s.tasks = (t,) else: s.tasks = () return s AGENTS = ( AgentConfig(name="order_lookup", description="Looks up orders", permission="read", tools=["get_order_status", "get_tracking_info"]), 
AgentConfig(name="order_actions", description="Modifies orders", permission="write", tools=["cancel_order"]), AgentConfig(name="discount", description="Applies discounts", permission="write", tools=["apply_discount", "generate_coupon"]), AgentConfig(name="fallback", description="Handles unclear requests", permission="read", tools=["fallback_respond"]), ) def _make_classifier(result: ClassificationResult) -> AsyncMock: """Create a mock classifier returning the given result.""" classifier = AsyncMock() classifier.classify = AsyncMock(return_value=result) return classifier def _make_graph( classifier_result: ClassificationResult | None, chunks: list, state=None, ) -> MagicMock: """Build a graph mock with optional intent classifier.""" graph = MagicMock() if classifier_result is not None: graph.intent_classifier = _make_classifier(classifier_result) mock_registry = MagicMock() mock_registry.list_agents = MagicMock(return_value=AGENTS) graph.agent_registry = mock_registry else: graph.intent_classifier = None graph.agent_registry = None graph.astream = MagicMock(return_value=AsyncIterHelper(list(chunks))) graph.aget_state = AsyncMock(return_value=state or _state()) return graph async def _dispatch(graph, content: str, thread_id: str = "t1") -> list[dict]: sm = SessionManager() sm.touch(thread_id) im = InterruptManager() cb = TokenUsageCallbackHandler() ws = FakeWS() raw = json.dumps({"type": "message", "thread_id": thread_id, "content": content}) await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im) return ws.sent # --------------------------------------------------------------------------- # Single-intent routing to each agent # --------------------------------------------------------------------------- @pytest.mark.integration class TestSingleIntentRouting: """Verify single-intent messages route to the correct agent.""" @pytest.mark.asyncio async def test_routes_to_order_lookup(self) -> None: result = ClassificationResult( 
intents=(IntentTarget(agent_name="order_lookup", confidence=0.95, reasoning="status query"),), ) graph = _make_graph(result, [ _tool_chunk("get_order_status", {"order_id": "1042"}, "order_lookup"), _chunk("Order 1042 is shipped.", "order_lookup"), ]) msgs = await _dispatch(graph, "What is the status of order 1042?") tools = [m for m in msgs if m["type"] == "tool_call"] assert len(tools) == 1 assert tools[0]["tool"] == "get_order_status" assert tools[0]["agent"] == "order_lookup" tokens = [m for m in msgs if m["type"] == "token"] assert any("shipped" in t["content"] for t in tokens) @pytest.mark.asyncio async def test_routes_to_order_actions(self) -> None: result = ClassificationResult( intents=(IntentTarget(agent_name="order_actions", confidence=0.9, reasoning="cancel"),), ) graph = _make_graph( result, [_tool_chunk("cancel_order", {"order_id": "1042"}, "order_actions")], state=_state(interrupt=True, data={"action": "cancel_order", "order_id": "1042"}), ) msgs = await _dispatch(graph, "Cancel order 1042") tools = [m for m in msgs if m["type"] == "tool_call"] assert tools[0]["tool"] == "cancel_order" assert tools[0]["agent"] == "order_actions" interrupts = [m for m in msgs if m["type"] == "interrupt"] assert len(interrupts) == 1 @pytest.mark.asyncio async def test_routes_to_discount(self) -> None: result = ClassificationResult( intents=(IntentTarget(agent_name="discount", confidence=0.88, reasoning="coupon"),), ) graph = _make_graph(result, [ _tool_chunk("generate_coupon", {"discount_percent": 15, "expiry_days": 30}, "discount"), _chunk("Here is your coupon: SAVE15-ABC12345", "discount"), ]) msgs = await _dispatch(graph, "Give me a 15% coupon") tools = [m for m in msgs if m["type"] == "tool_call"] assert tools[0]["tool"] == "generate_coupon" assert tools[0]["agent"] == "discount" @pytest.mark.asyncio async def test_routes_to_fallback(self) -> None: result = ClassificationResult( intents=(IntentTarget(agent_name="fallback", confidence=0.7, reasoning="general"),), ) 
graph = _make_graph(result, [ _chunk("I can help with order inquiries.", "fallback"), ]) msgs = await _dispatch(graph, "What can you do?") tokens = [m for m in msgs if m["type"] == "token"] assert tokens[0]["agent"] == "fallback" # --------------------------------------------------------------------------- # Multi-intent routing # --------------------------------------------------------------------------- @pytest.mark.integration class TestMultiIntentRouting: """Verify multi-intent triggers sequential execution hint.""" @pytest.mark.asyncio async def test_two_intents_inject_routing_hint(self) -> None: result = ClassificationResult( intents=( IntentTarget(agent_name="order_actions", confidence=0.9, reasoning="cancel"), IntentTarget(agent_name="discount", confidence=0.85, reasoning="discount"), ), ) graph = _make_graph(result, [ _tool_chunk("cancel_order", {"order_id": "1042"}, "order_actions"), _tool_chunk("apply_discount", {"order_id": "1042", "discount_percent": 10}, "discount"), ]) sm = SessionManager() sm.touch("t1") im = InterruptManager() cb = TokenUsageCallbackHandler() ws = FakeWS() raw = json.dumps({ "type": "message", "thread_id": "t1", "content": "取消订单 1042 并给我一个 10% 折扣", }) await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im) # Verify routing hint was injected call_args = graph.astream.call_args[0][0] msg_content = call_args["messages"][0].content assert "[System:" in msg_content assert "order_actions" in msg_content assert "discount" in msg_content # Both tool calls should appear tools = [m for m in ws.sent if m["type"] == "tool_call"] tool_names = {t["tool"] for t in tools} assert "cancel_order" in tool_names assert "apply_discount" in tool_names @pytest.mark.asyncio async def test_single_intent_no_routing_hint(self) -> None: result = ClassificationResult( intents=(IntentTarget(agent_name="order_lookup", confidence=0.95, reasoning="query"),), ) graph = _make_graph(result, [_chunk("Order shipped.", "order_lookup")]) sm = SessionManager() 
sm.touch("t1") im = InterruptManager() cb = TokenUsageCallbackHandler() ws = FakeWS() raw = json.dumps({"type": "message", "thread_id": "t1", "content": "查询订单 1042"}) await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im) msg_content = graph.astream.call_args[0][0]["messages"][0].content assert "[System:" not in msg_content # --------------------------------------------------------------------------- # Ambiguity routing # --------------------------------------------------------------------------- @pytest.mark.integration class TestAmbiguityRouting: """Verify ambiguous intents produce clarification, not agent calls.""" @pytest.mark.asyncio async def test_ambiguous_skips_graph_returns_clarification(self) -> None: result = ClassificationResult( intents=(), is_ambiguous=True, clarification_question="Could you please clarify what you need?", ) graph = _make_graph(result, []) msgs = await _dispatch(graph, "嗯...") clarifications = [m for m in msgs if m["type"] == "clarification"] assert len(clarifications) == 1 assert "clarify" in clarifications[0]["message"] # Graph should NOT have been called graph.astream.assert_not_called() @pytest.mark.asyncio async def test_low_confidence_triggers_ambiguity(self) -> None: """LLMIntentClassifier applies threshold -- low confidence -> ambiguous.""" raw_result = ClassificationResult( intents=(IntentTarget(agent_name="fallback", confidence=0.2, reasoning="unclear"),), is_ambiguous=False, ) mock_structured = MagicMock() mock_structured.ainvoke = AsyncMock(return_value=raw_result) mock_llm = MagicMock() mock_llm.with_structured_output = MagicMock(return_value=mock_structured) classifier = LLMIntentClassifier(mock_llm) result = await classifier.classify("hmm", AGENTS) assert result.is_ambiguous assert result.clarification_question is not None # --------------------------------------------------------------------------- # No classifier fallback # --------------------------------------------------------------------------- 
@pytest.mark.integration
class TestNoClassifierFallback:
    """Verify the system works without an intent classifier (supervisor-prompt fallback)."""

    @pytest.mark.asyncio
    async def test_no_classifier_routes_via_supervisor(self) -> None:
        graph = _make_graph(
            classifier_result=None,
            chunks=[_chunk("Order 1042 is shipped.", "order_lookup")],
        )
        msgs = await _dispatch(graph, "What is order 1042 status?")

        # Without a classifier the message must still reach the graph, exactly once.
        graph.astream.assert_called_once()

        tokens = [m for m in msgs if m["type"] == "token"]
        assert len(tokens) == 1
        # The streamed token is attributed to the node that produced it.
        assert tokens[0]["agent"] == "order_lookup"
        assert "shipped" in tokens[0]["content"]

        completes = [m for m in msgs if m["type"] == "message_complete"]
        assert len(completes) == 1