"""Phase 2 checkpoint acceptance tests. Each test maps to one checkpoint criterion from DEVELOPMENT-PLAN.md: 1. "查询订单 1042" -> routes to order_lookup agent 2. "取消订单 1042 并给我一个 10% 折扣" -> sequential multi-agent execution 3. Ambiguous message -> fallback asks for clarification 4. Interrupt > 30 min TTL -> auto-cancel + retry prompt 5. Agent escalation -> Webhook POST succeeds (or logs after retries) 6. E-commerce template -> 4 pre-configured agents work 7. pytest --cov >= 80% (verified separately) """ from __future__ import annotations import json from pathlib import Path from unittest.mock import AsyncMock, MagicMock, patch import pytest from app.callbacks import TokenUsageCallbackHandler from app.escalation import EscalationPayload, NoOpEscalator, WebhookEscalator from app.graph_context import GraphContext from app.intent import ClassificationResult, IntentTarget, LLMIntentClassifier from app.interrupt_manager import InterruptManager from app.registry import AgentConfig, AgentRegistry from app.session_manager import SessionManager from app.ws_context import WebSocketContext from app.ws_handler import dispatch_message TEMPLATES_DIR = Path(__file__).parent.parent.parent / "templates" # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- class AsyncIterHelper: def __init__(self, items: list) -> None: self._items = list(items) def __aiter__(self): return self async def __anext__(self): if not self._items: raise StopAsyncIteration return self._items.pop(0) class FakeWS: def __init__(self) -> None: self.sent: list[dict] = [] async def send_json(self, data: dict) -> None: self.sent.append(data) def _chunk(content: str, node: str) -> tuple: c = MagicMock() c.content = content c.tool_calls = [] return (c, {"langgraph_node": node}) def _tool_chunk(name: str, args: dict, node: str) -> tuple: c = MagicMock() c.content = "" c.tool_calls = [{"name": name, "args": args}] return (c, {"langgraph_node": node}) def _state(*, interrupt: bool = False, data: dict | None = None): s = MagicMock() if interrupt: obj = MagicMock() obj.value = data or {"action": "cancel_order", "order_id": "1042"} t = MagicMock() t.interrupts = (obj,) s.tasks = (t,) else: s.tasks = () return s def _agent(name: str, desc: str, perm: str = "read") -> AgentConfig: return AgentConfig(name=name, description=desc, permission=perm, tools=["fallback_respond"]) # --------------------------------------------------------------------------- # Checkpoint 1: "查询订单 1042" -> 路由到订单查询 Agent # --------------------------------------------------------------------------- @pytest.mark.integration class TestCheckpoint1OrderQueryRouting: """Verify intent classifier routes order queries to order_lookup.""" @pytest.mark.asyncio async def test_order_query_classified_to_order_lookup(self) -> None: expected = ClassificationResult( intents=( IntentTarget(agent_name="order_lookup", confidence=0.95, reasoning="order query"), ), ) mock_structured = MagicMock() mock_structured.ainvoke = AsyncMock(return_value=expected) mock_llm = MagicMock() mock_llm.with_structured_output = MagicMock(return_value=mock_structured) classifier = LLMIntentClassifier(mock_llm) agents = ( _agent("order_lookup", "Looks up order status and tracking"), _agent("order_actions", "Modifies orders", "write"), _agent("discount", "Applies discounts", "write"), _agent("fallback", "Handles unclear requests"), ) result = await classifier.classify("查询订单 1042", agents) assert len(result.intents) == 1 assert result.intents[0].agent_name == "order_lookup" assert result.intents[0].confidence >= 0.9 assert not result.is_ambiguous @pytest.mark.asyncio async def test_order_query_streams_from_order_lookup_agent(self) -> None: """Full dispatch: classify -> route -> stream from order_lookup.""" graph = MagicMock() # Classifier returns order_lookup mock_classifier = AsyncMock() mock_classifier.classify = AsyncMock(return_value=ClassificationResult( intents=(IntentTarget(agent_name="order_lookup", confidence=0.95, reasoning="query"),), )) mock_registry = MagicMock() mock_registry.list_agents = MagicMock(return_value=()) # Graph streams order_lookup response graph.astream = MagicMock(return_value=AsyncIterHelper([ _tool_chunk("get_order_status", {"order_id": "1042"}, "order_lookup"), _chunk("Order 1042 is shipped.", "order_lookup"), ])) graph.aget_state = AsyncMock(return_value=_state()) graph_ctx = GraphContext( graph=graph, registry=mock_registry, intent_classifier=mock_classifier, ) sm = SessionManager() sm.touch("t1") im = InterruptManager() cb = TokenUsageCallbackHandler() ws = FakeWS() ws_ctx = WebSocketContext( graph_ctx=graph_ctx, session_manager=sm, callback_handler=cb, interrupt_manager=im, ) raw = json.dumps({"type": "message", "thread_id": "t1", "content": "查询订单 1042"}) await dispatch_message(ws, ws_ctx, raw) tool_msgs = [m for m in ws.sent if m["type"] == "tool_call"] assert any(m["tool"] == "get_order_status" for m in tool_msgs) token_msgs = [m for m in ws.sent if m["type"] == "token"] assert any(m["agent"] == "order_lookup" for m in token_msgs) # --------------------------------------------------------------------------- # Checkpoint 2: Multi-intent -> sequential execution # --------------------------------------------------------------------------- @pytest.mark.integration class TestCheckpoint2MultiIntentSequential: """Verify multi-intent classified and hint injected for sequential execution.""" @pytest.mark.asyncio async def test_multi_intent_classification(self) -> None: expected = ClassificationResult( intents=( IntentTarget(agent_name="order_actions", confidence=0.9, reasoning="cancel"), IntentTarget(agent_name="discount", confidence=0.85, reasoning="discount"), ), ) mock_structured = MagicMock() mock_structured.ainvoke = AsyncMock(return_value=expected) mock_llm = MagicMock() mock_llm.with_structured_output = MagicMock(return_value=mock_structured) classifier = LLMIntentClassifier(mock_llm) agents = ( _agent("order_actions", "Modifies orders", "write"), _agent("discount", "Applies discounts", "write"), _agent("fallback", "Handles unclear requests"), ) result = await classifier.classify("取消订单 1042 并给我一个 10% 折扣", agents) assert len(result.intents) == 2 assert result.intents[0].agent_name == "order_actions" assert result.intents[1].agent_name == "discount" assert not result.is_ambiguous @pytest.mark.asyncio async def test_multi_intent_injects_routing_hint(self) -> None: """When multi-intent detected, a [System: ...] hint is appended to the message.""" graph = MagicMock() mock_classifier = AsyncMock() mock_classifier.classify = AsyncMock(return_value=ClassificationResult( intents=( IntentTarget(agent_name="order_actions", confidence=0.9, reasoning="cancel"), IntentTarget(agent_name="discount", confidence=0.85, reasoning="discount"), ), )) mock_registry = MagicMock() mock_registry.list_agents = MagicMock(return_value=()) graph.astream = MagicMock(return_value=AsyncIterHelper([])) graph.aget_state = AsyncMock(return_value=_state()) graph_ctx = GraphContext( graph=graph, registry=mock_registry, intent_classifier=mock_classifier, ) sm = SessionManager() sm.touch("t1") im = InterruptManager() cb = TokenUsageCallbackHandler() ws = FakeWS() ws_ctx = WebSocketContext( graph_ctx=graph_ctx, session_manager=sm, callback_handler=cb, interrupt_manager=im, ) raw = json.dumps({ "type": "message", "thread_id": "t1", "content": "取消订单 1042 并给我一个 10% 折扣", }) await dispatch_message(ws, ws_ctx, raw) # Verify the graph was called with the routing hint in the message call_args = graph.astream.call_args input_msg = call_args[0][0] msg_content = input_msg["messages"][0].content assert "[System:" in msg_content assert "order_actions" in msg_content assert "discount" in msg_content # --------------------------------------------------------------------------- # Checkpoint 3: Ambiguous message -> clarification # --------------------------------------------------------------------------- @pytest.mark.integration class TestCheckpoint3AmbiguousClarification: """Verify ambiguous messages trigger clarification prompt.""" @pytest.mark.asyncio async def test_ambiguous_intent_returns_clarification(self) -> None: expected = ClassificationResult( intents=(IntentTarget(agent_name="fallback", confidence=0.3, reasoning="unclear"),), is_ambiguous=False, # low confidence will trigger ambiguity threshold ) mock_structured = MagicMock() mock_structured.ainvoke = AsyncMock(return_value=expected) mock_llm = MagicMock() mock_llm.with_structured_output = MagicMock(return_value=mock_structured) classifier = LLMIntentClassifier(mock_llm) agents = (_agent("order_lookup", "Orders"), _agent("fallback", "Fallback")) result = await classifier.classify("嗯...", agents) assert result.is_ambiguous assert result.clarification_question is not None @pytest.mark.asyncio async def test_ambiguous_sends_clarification_via_websocket(self) -> None: graph = MagicMock() mock_classifier = AsyncMock() mock_classifier.classify = AsyncMock(return_value=ClassificationResult( intents=(), is_ambiguous=True, clarification_question=( "Could you please provide more details about what you need help with?" ), )) mock_registry = MagicMock() mock_registry.list_agents = MagicMock(return_value=()) graph.astream = MagicMock(return_value=AsyncIterHelper([])) graph.aget_state = AsyncMock(return_value=_state()) graph_ctx = GraphContext( graph=graph, registry=mock_registry, intent_classifier=mock_classifier, ) sm = SessionManager() sm.touch("t1") im = InterruptManager() cb = TokenUsageCallbackHandler() ws = FakeWS() ws_ctx = WebSocketContext( graph_ctx=graph_ctx, session_manager=sm, callback_handler=cb, interrupt_manager=im, ) raw = json.dumps({"type": "message", "thread_id": "t1", "content": "嗯..."}) await dispatch_message(ws, ws_ctx, raw) clarifications = [m for m in ws.sent if m["type"] == "clarification"] assert len(clarifications) == 1 assert "more details" in clarifications[0]["message"] # Should NOT call graph.astream since we returned early graph.astream.assert_not_called() # --------------------------------------------------------------------------- # Checkpoint 4: Interrupt > 30 min -> auto-cancel + retry # --------------------------------------------------------------------------- @pytest.mark.integration class TestCheckpoint4InterruptTTLAutoCancel: """Verify interrupt TTL expiration triggers auto-cancel and retry prompt.""" @pytest.mark.asyncio async def test_30min_expired_interrupt_auto_cancels(self) -> None: st = _state(interrupt=True, data={"action": "cancel_order", "order_id": "1042"}) graph = MagicMock() graph.astream = MagicMock(return_value=AsyncIterHelper([])) graph.aget_state = AsyncMock(return_value=st) mock_registry = MagicMock() mock_registry.list_agents = MagicMock(return_value=()) graph_ctx = GraphContext(graph=graph, registry=mock_registry, intent_classifier=None) sm = SessionManager() sm.touch("t1") im = InterruptManager(ttl_seconds=1800) # 30 minutes cb = TokenUsageCallbackHandler() ws = FakeWS() ws_ctx = WebSocketContext( graph_ctx=graph_ctx, session_manager=sm, callback_handler=cb, interrupt_manager=im, ) # Trigger interrupt raw = json.dumps({"type": "message", "thread_id": "t1", "content": "Cancel order 1042"}) await dispatch_message(ws, ws_ctx, raw) interrupts = [m for m in ws.sent if m["type"] == "interrupt"] assert len(interrupts) == 1 # Simulate 31 minutes passing record = im._interrupts["t1"] ws.sent.clear() with patch("app.interrupt_manager.time") as mock_time: mock_time.time.return_value = record.created_at + 1860 # 31 min raw = json.dumps({ "type": "interrupt_response", "thread_id": "t1", "approved": True, }) await dispatch_message(ws, ws_ctx, raw) # Should get retry prompt, NOT resume the graph expired_msgs = [m for m in ws.sent if m["type"] == "interrupt_expired"] assert len(expired_msgs) == 1 assert "30 minutes" in expired_msgs[0]["message"] assert expired_msgs[0]["action"] == "cancel_order" assert expired_msgs[0]["thread_id"] == "t1" def test_cleanup_expired_returns_records(self) -> None: im = InterruptManager(ttl_seconds=1800) im.register("t1", "cancel_order", {"order_id": "1042"}) im.register("t2", "apply_discount", {"order_id": "1043"}) with patch("app.interrupt_manager.time") as mock_time: record = im._interrupts["t1"] mock_time.time.return_value = record.created_at + 1860 expired = im.cleanup_expired() assert len(expired) == 2 actions = {r.action for r in expired} assert "cancel_order" in actions assert "apply_discount" in actions # --------------------------------------------------------------------------- # Checkpoint 5: Agent escalation -> Webhook POST # --------------------------------------------------------------------------- @pytest.mark.integration class TestCheckpoint5WebhookEscalation: """Verify webhook escalation sends POST and retries on failure.""" @pytest.mark.asyncio async def test_webhook_post_success(self) -> None: mock_response = AsyncMock() mock_response.status_code = 200 mock_client = AsyncMock() mock_client.post = AsyncMock(return_value=mock_response) mock_client.__aenter__ = AsyncMock(return_value=mock_client) mock_client.__aexit__ = AsyncMock(return_value=False) with patch("app.escalation.httpx.AsyncClient", return_value=mock_client): escalator = WebhookEscalator(url="https://support.example.com/escalate") payload = EscalationPayload( thread_id="t1", reason="Agent cannot resolve customer issue", conversation_summary="Customer asked about refund policy", metadata={"customer_id": "C-123"}, ) result = await escalator.escalate(payload) assert result.success assert result.status_code == 200 assert result.attempts == 1 # Verify POST was called with correct payload call_args = mock_client.post.call_args assert call_args[0][0] == "https://support.example.com/escalate" posted_data = call_args[1]["json"] assert posted_data["thread_id"] == "t1" assert posted_data["reason"] == "Agent cannot resolve customer issue" @pytest.mark.asyncio async def test_webhook_retries_then_logs(self) -> None: fail_response = AsyncMock() fail_response.status_code = 503 mock_client = AsyncMock() mock_client.post = AsyncMock(return_value=fail_response) mock_client.__aenter__ = AsyncMock(return_value=mock_client) mock_client.__aexit__ = AsyncMock(return_value=False) with ( patch("app.escalation.httpx.AsyncClient", return_value=mock_client), patch("app.escalation.asyncio.sleep", new_callable=AsyncMock), ): escalator = WebhookEscalator( url="https://support.example.com/escalate", max_retries=3, ) payload = EscalationPayload( thread_id="t1", reason="Escalation needed", conversation_summary="Summary", ) result = await escalator.escalate(payload) assert not result.success assert result.attempts == 3 assert "503" in result.error @pytest.mark.asyncio async def test_noop_escalator_when_disabled(self) -> None: escalator = NoOpEscalator() payload = EscalationPayload( thread_id="t1", reason="Test", conversation_summary="Test", ) result = await escalator.escalate(payload) assert not result.success assert "disabled" in result.error.lower() # --------------------------------------------------------------------------- # Checkpoint 6: E-commerce template -> pre-configured agents # --------------------------------------------------------------------------- @pytest.mark.integration class TestCheckpoint6EcommerceTemplate: """Verify e-commerce template loads with correct agents.""" def test_ecommerce_template_loads_4_agents(self) -> None: registry = AgentRegistry.load_template("e-commerce", TEMPLATES_DIR) assert len(registry) == 4 def test_ecommerce_template_has_correct_agents(self) -> None: registry = AgentRegistry.load_template("e-commerce", TEMPLATES_DIR) agents = registry.list_agents() names = {a.name for a in agents} assert names == {"order_lookup", "order_actions", "discount", "fallback"} def test_ecommerce_order_lookup_is_read(self) -> None: registry = AgentRegistry.load_template("e-commerce", TEMPLATES_DIR) agent = registry.get_agent("order_lookup") assert agent.permission == "read" assert "get_order_status" in agent.tools assert "get_tracking_info" in agent.tools def test_ecommerce_order_actions_is_write(self) -> None: registry = AgentRegistry.load_template("e-commerce", TEMPLATES_DIR) agent = registry.get_agent("order_actions") assert agent.permission == "write" assert "cancel_order" in agent.tools def test_ecommerce_discount_is_write(self) -> None: registry = AgentRegistry.load_template("e-commerce", TEMPLATES_DIR) agent = registry.get_agent("discount") assert agent.permission == "write" assert "apply_discount" in agent.tools assert "generate_coupon" in agent.tools def test_ecommerce_fallback_is_read(self) -> None: registry = AgentRegistry.load_template("e-commerce", TEMPLATES_DIR) agent = registry.get_agent("fallback") assert agent.permission == "read" def test_all_three_templates_available(self) -> None: templates = AgentRegistry.list_templates(TEMPLATES_DIR) assert "e-commerce" in templates assert "saas" in templates assert "fintech" in templates