"""Phase 2 checkpoint acceptance tests. Each test maps to one checkpoint criterion from DEVELOPMENT-PLAN.md: 1. "查询订单 1042" -> routes to order_lookup agent 2. "取消订单 1042 并给我一个 10% 折扣" -> sequential multi-agent execution 3. Ambiguous message -> fallback asks for clarification 4. Interrupt > 30 min TTL -> auto-cancel + retry prompt 5. Agent escalation -> Webhook POST succeeds (or logs after retries) 6. E-commerce template -> 4 pre-configured agents work 7. pytest --cov >= 80% (verified separately) """ from __future__ import annotations import json from pathlib import Path from unittest.mock import AsyncMock, MagicMock, patch import pytest from app.callbacks import TokenUsageCallbackHandler from app.escalation import EscalationPayload, NoOpEscalator, WebhookEscalator from app.intent import ClassificationResult, IntentTarget, LLMIntentClassifier from app.interrupt_manager import InterruptManager from app.registry import AgentConfig, AgentRegistry from app.session_manager import SessionManager from app.ws_handler import dispatch_message TEMPLATES_DIR = Path(__file__).parent.parent.parent / "templates" # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- class AsyncIterHelper: def __init__(self, items: list) -> None: self._items = list(items) def __aiter__(self): return self async def __anext__(self): if not self._items: raise StopAsyncIteration return self._items.pop(0) class FakeWS: def __init__(self) -> None: self.sent: list[dict] = [] async def send_json(self, data: dict) -> None: self.sent.append(data) def _chunk(content: str, node: str) -> tuple: c = MagicMock() c.content = content c.tool_calls = [] return (c, {"langgraph_node": node}) def _tool_chunk(name: str, args: dict, node: str) -> tuple: c = MagicMock() c.content = "" c.tool_calls = [{"name": name, "args": args}] return (c, {"langgraph_node": node}) def _state(*, interrupt: bool = False, data: dict | None = None): s = MagicMock() if interrupt: obj = MagicMock() obj.value = data or {"action": "cancel_order", "order_id": "1042"} t = MagicMock() t.interrupts = (obj,) s.tasks = (t,) else: s.tasks = () return s def _agent(name: str, desc: str, perm: str = "read") -> AgentConfig: return AgentConfig(name=name, description=desc, permission=perm, tools=["fallback_respond"]) # --------------------------------------------------------------------------- # Checkpoint 1: "查询订单 1042" -> 路由到订单查询 Agent # --------------------------------------------------------------------------- @pytest.mark.integration class TestCheckpoint1_OrderQueryRouting: """Verify intent classifier routes order queries to order_lookup.""" @pytest.mark.asyncio async def test_order_query_classified_to_order_lookup(self) -> None: expected = ClassificationResult( intents=( IntentTarget(agent_name="order_lookup", confidence=0.95, reasoning="order query"), ), ) mock_structured = MagicMock() mock_structured.ainvoke = AsyncMock(return_value=expected) mock_llm = MagicMock() mock_llm.with_structured_output = MagicMock(return_value=mock_structured) classifier = LLMIntentClassifier(mock_llm) agents = ( _agent("order_lookup", "Looks up order status and tracking"), _agent("order_actions", "Modifies orders", "write"), _agent("discount", "Applies discounts", "write"), _agent("fallback", "Handles unclear requests"), ) result = await classifier.classify("查询订单 1042", agents) assert len(result.intents) == 1 assert result.intents[0].agent_name == "order_lookup" assert result.intents[0].confidence >= 0.9 assert not result.is_ambiguous @pytest.mark.asyncio async def test_order_query_streams_from_order_lookup_agent(self) -> None: """Full dispatch: classify -> route -> stream from order_lookup.""" graph = MagicMock() # Classifier returns order_lookup mock_classifier = AsyncMock() mock_classifier.classify = AsyncMock(return_value=ClassificationResult( intents=(IntentTarget(agent_name="order_lookup", confidence=0.95, reasoning="query"),), )) graph.intent_classifier = mock_classifier mock_registry = MagicMock() mock_registry.list_agents = MagicMock(return_value=()) graph.agent_registry = mock_registry # Graph streams order_lookup response graph.astream = MagicMock(return_value=AsyncIterHelper([ _tool_chunk("get_order_status", {"order_id": "1042"}, "order_lookup"), _chunk("Order 1042 is shipped.", "order_lookup"), ])) graph.aget_state = AsyncMock(return_value=_state()) sm = SessionManager() sm.touch("t1") im = InterruptManager() cb = TokenUsageCallbackHandler() ws = FakeWS() raw = json.dumps({"type": "message", "thread_id": "t1", "content": "查询订单 1042"}) await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im) tool_msgs = [m for m in ws.sent if m["type"] == "tool_call"] assert any(m["tool"] == "get_order_status" for m in tool_msgs) token_msgs = [m for m in ws.sent if m["type"] == "token"] assert any(m["agent"] == "order_lookup" for m in token_msgs) # --------------------------------------------------------------------------- # Checkpoint 2: Multi-intent -> sequential execution # --------------------------------------------------------------------------- @pytest.mark.integration class TestCheckpoint2_MultiIntentSequential: """Verify multi-intent classified and hint injected for sequential execution.""" @pytest.mark.asyncio async def test_multi_intent_classification(self) -> None: expected = ClassificationResult( intents=( IntentTarget(agent_name="order_actions", confidence=0.9, reasoning="cancel"), IntentTarget(agent_name="discount", confidence=0.85, reasoning="discount"), ), ) mock_structured = MagicMock() mock_structured.ainvoke = AsyncMock(return_value=expected) mock_llm = MagicMock() mock_llm.with_structured_output = MagicMock(return_value=mock_structured) classifier = LLMIntentClassifier(mock_llm) agents = ( _agent("order_actions", "Modifies orders", "write"), _agent("discount", "Applies discounts", "write"), _agent("fallback", "Handles unclear requests"), ) result = await classifier.classify("取消订单 1042 并给我一个 10% 折扣", agents) assert len(result.intents) == 2 assert result.intents[0].agent_name == "order_actions" assert result.intents[1].agent_name == "discount" assert not result.is_ambiguous @pytest.mark.asyncio async def test_multi_intent_injects_routing_hint(self) -> None: """When multi-intent detected, a [System: ...] hint is appended to the message.""" graph = MagicMock() mock_classifier = AsyncMock() mock_classifier.classify = AsyncMock(return_value=ClassificationResult( intents=( IntentTarget(agent_name="order_actions", confidence=0.9, reasoning="cancel"), IntentTarget(agent_name="discount", confidence=0.85, reasoning="discount"), ), )) graph.intent_classifier = mock_classifier mock_registry = MagicMock() mock_registry.list_agents = MagicMock(return_value=()) graph.agent_registry = mock_registry graph.astream = MagicMock(return_value=AsyncIterHelper([])) graph.aget_state = AsyncMock(return_value=_state()) sm = SessionManager() sm.touch("t1") im = InterruptManager() cb = TokenUsageCallbackHandler() ws = FakeWS() raw = json.dumps({ "type": "message", "thread_id": "t1", "content": "取消订单 1042 并给我一个 10% 折扣", }) await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im) # Verify the graph was called with the routing hint in the message call_args = graph.astream.call_args input_msg = call_args[0][0] msg_content = input_msg["messages"][0].content assert "[System:" in msg_content assert "order_actions" in msg_content assert "discount" in msg_content # --------------------------------------------------------------------------- # Checkpoint 3: Ambiguous message -> clarification # --------------------------------------------------------------------------- @pytest.mark.integration class TestCheckpoint3_AmbiguousClarification: """Verify ambiguous messages trigger clarification prompt.""" @pytest.mark.asyncio async def test_ambiguous_intent_returns_clarification(self) -> None: expected = ClassificationResult( intents=(IntentTarget(agent_name="fallback", confidence=0.3, reasoning="unclear"),), is_ambiguous=False, # low confidence will trigger ambiguity threshold ) mock_structured = MagicMock() mock_structured.ainvoke = AsyncMock(return_value=expected) mock_llm = MagicMock() mock_llm.with_structured_output = MagicMock(return_value=mock_structured) classifier = LLMIntentClassifier(mock_llm) agents = (_agent("order_lookup", "Orders"), _agent("fallback", "Fallback")) result = await classifier.classify("嗯...", agents) assert result.is_ambiguous assert result.clarification_question is not None @pytest.mark.asyncio async def test_ambiguous_sends_clarification_via_websocket(self) -> None: graph = MagicMock() mock_classifier = AsyncMock() mock_classifier.classify = AsyncMock(return_value=ClassificationResult( intents=(), is_ambiguous=True, clarification_question="Could you please provide more details about what you need help with?", )) graph.intent_classifier = mock_classifier mock_registry = MagicMock() mock_registry.list_agents = MagicMock(return_value=()) graph.agent_registry = mock_registry graph.astream = MagicMock(return_value=AsyncIterHelper([])) graph.aget_state = AsyncMock(return_value=_state()) sm = SessionManager() sm.touch("t1") im = InterruptManager() cb = TokenUsageCallbackHandler() ws = FakeWS() raw = json.dumps({"type": "message", "thread_id": "t1", "content": "嗯..."}) await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im) clarifications = [m for m in ws.sent if m["type"] == "clarification"] assert len(clarifications) == 1 assert "more details" in clarifications[0]["message"] # Should NOT call graph.astream since we returned early graph.astream.assert_not_called() # --------------------------------------------------------------------------- # Checkpoint 4: Interrupt > 30 min -> auto-cancel + retry # --------------------------------------------------------------------------- @pytest.mark.integration class TestCheckpoint4_InterruptTTLAutoCancel: """Verify interrupt TTL expiration triggers auto-cancel and retry prompt.""" @pytest.mark.asyncio async def test_30min_expired_interrupt_auto_cancels(self) -> None: st = _state(interrupt=True, data={"action": "cancel_order", "order_id": "1042"}) graph = MagicMock() graph.intent_classifier = None graph.agent_registry = None graph.astream = MagicMock(return_value=AsyncIterHelper([])) graph.aget_state = AsyncMock(return_value=st) sm = SessionManager() sm.touch("t1") im = InterruptManager(ttl_seconds=1800) # 30 minutes cb = TokenUsageCallbackHandler() ws = FakeWS() # Trigger interrupt raw = json.dumps({"type": "message", "thread_id": "t1", "content": "Cancel order 1042"}) await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im) interrupts = [m for m in ws.sent if m["type"] == "interrupt"] assert len(interrupts) == 1 # Simulate 31 minutes passing record = im._interrupts["t1"] ws.sent.clear() with patch("app.interrupt_manager.time") as mock_time: mock_time.time.return_value = record.created_at + 1860 # 31 min raw = json.dumps({ "type": "interrupt_response", "thread_id": "t1", "approved": True, }) await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im) # Should get retry prompt, NOT resume the graph expired_msgs = [m for m in ws.sent if m["type"] == "interrupt_expired"] assert len(expired_msgs) == 1 assert "30 minutes" in expired_msgs[0]["message"] assert expired_msgs[0]["action"] == "cancel_order" assert expired_msgs[0]["thread_id"] == "t1" def test_cleanup_expired_returns_records(self) -> None: im = InterruptManager(ttl_seconds=1800) im.register("t1", "cancel_order", {"order_id": "1042"}) im.register("t2", "apply_discount", {"order_id": "1043"}) with patch("app.interrupt_manager.time") as mock_time: record = im._interrupts["t1"] mock_time.time.return_value = record.created_at + 1860 expired = im.cleanup_expired() assert len(expired) == 2 actions = {r.action for r in expired} assert "cancel_order" in actions assert "apply_discount" in actions # --------------------------------------------------------------------------- # Checkpoint 5: Agent escalation -> Webhook POST # --------------------------------------------------------------------------- @pytest.mark.integration class TestCheckpoint5_WebhookEscalation: """Verify webhook escalation sends POST and retries on failure.""" @pytest.mark.asyncio async def test_webhook_post_success(self) -> None: mock_response = AsyncMock() mock_response.status_code = 200 mock_client = AsyncMock() mock_client.post = AsyncMock(return_value=mock_response) mock_client.__aenter__ = AsyncMock(return_value=mock_client) mock_client.__aexit__ = AsyncMock(return_value=False) with patch("app.escalation.httpx.AsyncClient", return_value=mock_client): escalator = WebhookEscalator(url="https://support.example.com/escalate") payload = EscalationPayload( thread_id="t1", reason="Agent cannot resolve customer issue", conversation_summary="Customer asked about refund policy", metadata={"customer_id": "C-123"}, ) result = await escalator.escalate(payload) assert result.success assert result.status_code == 200 assert result.attempts == 1 # Verify POST was called with correct payload call_args = mock_client.post.call_args assert call_args[0][0] == "https://support.example.com/escalate" posted_data = call_args[1]["json"] assert posted_data["thread_id"] == "t1" assert posted_data["reason"] == "Agent cannot resolve customer issue" @pytest.mark.asyncio async def test_webhook_retries_then_logs(self) -> None: fail_response = AsyncMock() fail_response.status_code = 503 mock_client = AsyncMock() mock_client.post = AsyncMock(return_value=fail_response) mock_client.__aenter__ = AsyncMock(return_value=mock_client) mock_client.__aexit__ = AsyncMock(return_value=False) with ( patch("app.escalation.httpx.AsyncClient", return_value=mock_client), patch("app.escalation.asyncio.sleep", new_callable=AsyncMock), ): escalator = WebhookEscalator( url="https://support.example.com/escalate", max_retries=3, ) payload = EscalationPayload( thread_id="t1", reason="Escalation needed", conversation_summary="Summary", ) result = await escalator.escalate(payload) assert not result.success assert result.attempts == 3 assert "503" in result.error @pytest.mark.asyncio async def test_noop_escalator_when_disabled(self) -> None: escalator = NoOpEscalator() payload = EscalationPayload( thread_id="t1", reason="Test", conversation_summary="Test", ) result = await escalator.escalate(payload) assert not result.success assert "disabled" in result.error.lower() # --------------------------------------------------------------------------- # Checkpoint 6: E-commerce template -> pre-configured agents # --------------------------------------------------------------------------- @pytest.mark.integration class TestCheckpoint6_EcommerceTemplate: """Verify e-commerce template loads with correct agents.""" def test_ecommerce_template_loads_4_agents(self) -> None: registry = AgentRegistry.load_template("e-commerce", TEMPLATES_DIR) assert len(registry) == 4 def test_ecommerce_template_has_correct_agents(self) -> None: registry = AgentRegistry.load_template("e-commerce", TEMPLATES_DIR) agents = registry.list_agents() names = {a.name for a in agents} assert names == {"order_lookup", "order_actions", "discount", "fallback"} def test_ecommerce_order_lookup_is_read(self) -> None: registry = AgentRegistry.load_template("e-commerce", TEMPLATES_DIR) agent = registry.get_agent("order_lookup") assert agent.permission == "read" assert "get_order_status" in agent.tools assert "get_tracking_info" in agent.tools def test_ecommerce_order_actions_is_write(self) -> None: registry = AgentRegistry.load_template("e-commerce", TEMPLATES_DIR) agent = registry.get_agent("order_actions") assert agent.permission == "write" assert "cancel_order" in agent.tools def test_ecommerce_discount_is_write(self) -> None: registry = AgentRegistry.load_template("e-commerce", TEMPLATES_DIR) agent = registry.get_agent("discount") assert agent.permission == "write" assert "apply_discount" in agent.tools assert "generate_coupon" in agent.tools def test_ecommerce_fallback_is_read(self) -> None: registry = AgentRegistry.load_template("e-commerce", TEMPLATES_DIR) agent = registry.get_agent("fallback") assert agent.permission == "read" def test_all_three_templates_available(self) -> None: templates = AgentRegistry.list_templates(TEMPLATES_DIR) assert "e-commerce" in templates assert "saas" in templates assert "fintech" in templates