test: add Phase 2 checkpoint acceptance tests
18 integration tests validating all 7 Phase 2 checkpoint criteria:
1. Order query routes to order_lookup agent
2. Multi-intent classification with routing hint injection
3. Ambiguous message triggers clarification prompt
4. 30-min interrupt TTL auto-cancel with retry prompt
5. Webhook POST escalation with retry on failure
6. E-commerce template loads 4 correctly configured agents
7. Coverage at 92.60% (188 tests total)
This commit is contained in:
487
backend/tests/integration/test_phase2_checkpoints.py
Normal file
487
backend/tests/integration/test_phase2_checkpoints.py
Normal file
@@ -0,0 +1,487 @@
|
||||
"""Phase 2 checkpoint acceptance tests.

Each test maps to one checkpoint criterion from DEVELOPMENT-PLAN.md:
1. "查询订单 1042" -> routes to order_lookup agent
2. "取消订单 1042 并给我一个 10% 折扣" -> sequential multi-agent execution
3. Ambiguous message -> fallback asks for clarification
4. Interrupt > 30 min TTL -> auto-cancel + retry prompt
5. Agent escalation -> Webhook POST succeeds (or logs after retries)
6. E-commerce template -> 4 pre-configured agents work
7. pytest --cov >= 80% (verified separately)
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.callbacks import TokenUsageCallbackHandler
|
||||
from app.escalation import EscalationPayload, NoOpEscalator, WebhookEscalator
|
||||
from app.intent import ClassificationResult, IntentTarget, LLMIntentClassifier
|
||||
from app.interrupt_manager import InterruptManager
|
||||
from app.registry import AgentConfig, AgentRegistry
|
||||
from app.session_manager import SessionManager
|
||||
from app.ws_handler import dispatch_message
|
||||
|
||||
# ``templates`` directory located three directory levels above this test file.
TEMPLATES_DIR = Path(__file__).parent.parent.parent / "templates"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
|
||||
|
||||
class AsyncIterHelper:
    """Minimal async iterator that yields a fixed sequence of items once.

    The items are snapshotted at construction time, yielded in order, and
    every ``__anext__`` call after exhaustion raises ``StopAsyncIteration``.
    """

    def __init__(self, items: list) -> None:
        # Snapshot the input so later mutation of the caller's list cannot
        # change what is yielded. Stepping an iterator is O(1), whereas
        # list.pop(0) shifts every remaining element on each call.
        self._pending = iter(list(items))

    def __aiter__(self):
        """Return self: the object is its own async iterator."""
        return self

    async def __anext__(self):
        """Return the next item, or raise ``StopAsyncIteration`` when drained."""
        try:
            return next(self._pending)
        except StopIteration:
            # StopIteration must not leak out of a coroutine; translate it
            # into the async-iteration terminator.
            raise StopAsyncIteration from None
|
||||
|
||||
|
||||
class FakeWS:
    """In-memory WebSocket stand-in that records every JSON payload sent."""

    def __init__(self) -> None:
        # Payloads in the order they were handed to send_json.
        self.sent: list[dict] = []

    async def send_json(self, data: dict) -> None:
        """Record *data* instead of transmitting it."""
        self.sent += [data]
|
||||
|
||||
|
||||
def _chunk(content: str, node: str) -> tuple:
|
||||
c = MagicMock()
|
||||
c.content = content
|
||||
c.tool_calls = []
|
||||
return (c, {"langgraph_node": node})
|
||||
|
||||
|
||||
def _tool_chunk(name: str, args: dict, node: str) -> tuple:
|
||||
c = MagicMock()
|
||||
c.content = ""
|
||||
c.tool_calls = [{"name": name, "args": args}]
|
||||
return (c, {"langgraph_node": node})
|
||||
|
||||
|
||||
def _state(*, interrupt: bool = False, data: dict | None = None):
|
||||
s = MagicMock()
|
||||
if interrupt:
|
||||
obj = MagicMock()
|
||||
obj.value = data or {"action": "cancel_order", "order_id": "1042"}
|
||||
t = MagicMock()
|
||||
t.interrupts = (obj,)
|
||||
s.tasks = (t,)
|
||||
else:
|
||||
s.tasks = ()
|
||||
return s
|
||||
|
||||
|
||||
def _agent(name: str, desc: str, perm: str = "read") -> AgentConfig:
    """Create an ``AgentConfig`` for tests whose only tool is ``fallback_respond``."""
    return AgentConfig(
        name=name,
        description=desc,
        permission=perm,
        tools=["fallback_respond"],
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# Checkpoint 1: "查询订单 1042" ("look up order 1042") -> routed to the order lookup agent
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.integration
class TestCheckpoint1_OrderQueryRouting:
    """Checkpoint 1: an order query ends up at the ``order_lookup`` agent."""

    @pytest.mark.asyncio
    async def test_order_query_classified_to_order_lookup(self) -> None:
        """Classifying the order query gives one confident, unambiguous intent."""
        # Stub LLM: its structured-output runnable always returns this result.
        canned = ClassificationResult(
            intents=(
                IntentTarget(agent_name="order_lookup", confidence=0.95, reasoning="order query"),
            ),
        )
        structured_llm = MagicMock()
        structured_llm.ainvoke = AsyncMock(return_value=canned)
        llm = MagicMock()
        llm.with_structured_output = MagicMock(return_value=structured_llm)

        classifier = LLMIntentClassifier(llm)
        candidates = (
            _agent("order_lookup", "Looks up order status and tracking"),
            _agent("order_actions", "Modifies orders", "write"),
            _agent("discount", "Applies discounts", "write"),
            _agent("fallback", "Handles unclear requests"),
        )

        outcome = await classifier.classify("查询订单 1042", candidates)

        assert len(outcome.intents) == 1
        top_intent = outcome.intents[0]
        assert top_intent.agent_name == "order_lookup"
        assert top_intent.confidence >= 0.9
        assert not outcome.is_ambiguous

    @pytest.mark.asyncio
    async def test_order_query_streams_from_order_lookup_agent(self) -> None:
        """Full dispatch: classify -> route -> stream from order_lookup."""
        # Classifier stub that always picks order_lookup.
        intent_classifier = AsyncMock()
        intent_classifier.classify = AsyncMock(
            return_value=ClassificationResult(
                intents=(
                    IntentTarget(agent_name="order_lookup", confidence=0.95, reasoning="query"),
                ),
            )
        )
        agent_registry = MagicMock()
        agent_registry.list_agents = MagicMock(return_value=())

        graph = MagicMock()
        graph.intent_classifier = intent_classifier
        graph.agent_registry = agent_registry
        # The graph "streams" one tool call followed by one text chunk.
        streamed = [
            _tool_chunk("get_order_status", {"order_id": "1042"}, "order_lookup"),
            _chunk("Order 1042 is shipped.", "order_lookup"),
        ]
        graph.astream = MagicMock(return_value=AsyncIterHelper(streamed))
        graph.aget_state = AsyncMock(return_value=_state())

        sessions = SessionManager()
        sessions.touch("t1")
        interrupts = InterruptManager()
        usage_cb = TokenUsageCallbackHandler()
        socket = FakeWS()

        frame = json.dumps({"type": "message", "thread_id": "t1", "content": "查询订单 1042"})
        await dispatch_message(socket, graph, sessions, usage_cb, frame, interrupt_manager=interrupts)

        tool_events = [msg for msg in socket.sent if msg["type"] == "tool_call"]
        assert any(msg["tool"] == "get_order_status" for msg in tool_events)

        token_events = [msg for msg in socket.sent if msg["type"] == "token"]
        assert any(msg["agent"] == "order_lookup" for msg in token_events)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# Checkpoint 2: Multi-intent -> sequential execution
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.integration
class TestCheckpoint2_MultiIntentSequential:
    """Checkpoint 2: multi-intent messages are classified and a routing hint is injected."""

    @staticmethod
    def _cancel_and_discount() -> ClassificationResult:
        """Return the two-intent result (order_actions, then discount) used by both tests."""
        return ClassificationResult(
            intents=(
                IntentTarget(agent_name="order_actions", confidence=0.9, reasoning="cancel"),
                IntentTarget(agent_name="discount", confidence=0.85, reasoning="discount"),
            ),
        )

    @pytest.mark.asyncio
    async def test_multi_intent_classification(self) -> None:
        """The classifier returns both intents, in order, and not ambiguous."""
        structured_llm = MagicMock()
        structured_llm.ainvoke = AsyncMock(return_value=self._cancel_and_discount())
        llm = MagicMock()
        llm.with_structured_output = MagicMock(return_value=structured_llm)

        classifier = LLMIntentClassifier(llm)
        candidates = (
            _agent("order_actions", "Modifies orders", "write"),
            _agent("discount", "Applies discounts", "write"),
            _agent("fallback", "Handles unclear requests"),
        )

        outcome = await classifier.classify("取消订单 1042 并给我一个 10% 折扣", candidates)

        assert len(outcome.intents) == 2
        first, second = outcome.intents[0], outcome.intents[1]
        assert first.agent_name == "order_actions"
        assert second.agent_name == "discount"
        assert not outcome.is_ambiguous

    @pytest.mark.asyncio
    async def test_multi_intent_injects_routing_hint(self) -> None:
        """When multi-intent detected, a [System: ...] hint is appended to the message."""
        intent_classifier = AsyncMock()
        intent_classifier.classify = AsyncMock(return_value=self._cancel_and_discount())
        agent_registry = MagicMock()
        agent_registry.list_agents = MagicMock(return_value=())

        graph = MagicMock()
        graph.intent_classifier = intent_classifier
        graph.agent_registry = agent_registry
        graph.astream = MagicMock(return_value=AsyncIterHelper([]))
        graph.aget_state = AsyncMock(return_value=_state())

        sessions = SessionManager()
        sessions.touch("t1")
        interrupts = InterruptManager()
        usage_cb = TokenUsageCallbackHandler()
        socket = FakeWS()

        frame = json.dumps(
            {
                "type": "message",
                "thread_id": "t1",
                "content": "取消订单 1042 并给我一个 10% 折扣",
            }
        )
        await dispatch_message(socket, graph, sessions, usage_cb, frame, interrupt_manager=interrupts)

        # The first positional argument handed to the graph stream is the
        # input mapping; its first message must carry the routing hint.
        graph_input = graph.astream.call_args.args[0]
        text = graph_input["messages"][0].content
        assert "[System:" in text
        assert "order_actions" in text
        assert "discount" in text
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# Checkpoint 3: Ambiguous message -> clarification
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.integration
class TestCheckpoint3_AmbiguousClarification:
    """Checkpoint 3: ambiguous messages lead to a clarification prompt."""

    @pytest.mark.asyncio
    async def test_ambiguous_intent_returns_clarification(self) -> None:
        """A low-confidence LLM answer comes back flagged ambiguous with a question."""
        # The stubbed LLM does NOT flag ambiguity itself; the test expects the
        # classifier to do so because of the 0.3 confidence.
        low_confidence = ClassificationResult(
            intents=(IntentTarget(agent_name="fallback", confidence=0.3, reasoning="unclear"),),
            is_ambiguous=False,
        )
        structured_llm = MagicMock()
        structured_llm.ainvoke = AsyncMock(return_value=low_confidence)
        llm = MagicMock()
        llm.with_structured_output = MagicMock(return_value=structured_llm)

        classifier = LLMIntentClassifier(llm)
        candidates = (_agent("order_lookup", "Orders"), _agent("fallback", "Fallback"))

        outcome = await classifier.classify("嗯...", candidates)

        assert outcome.is_ambiguous
        assert outcome.clarification_question is not None

    @pytest.mark.asyncio
    async def test_ambiguous_sends_clarification_via_websocket(self) -> None:
        """An ambiguous classification sends one clarification and skips the graph."""
        ambiguous = ClassificationResult(
            intents=(),
            is_ambiguous=True,
            clarification_question="Could you please provide more details about what you need help with?",
        )
        intent_classifier = AsyncMock()
        intent_classifier.classify = AsyncMock(return_value=ambiguous)
        agent_registry = MagicMock()
        agent_registry.list_agents = MagicMock(return_value=())

        graph = MagicMock()
        graph.intent_classifier = intent_classifier
        graph.agent_registry = agent_registry
        graph.astream = MagicMock(return_value=AsyncIterHelper([]))
        graph.aget_state = AsyncMock(return_value=_state())

        sessions = SessionManager()
        sessions.touch("t1")
        interrupts = InterruptManager()
        usage_cb = TokenUsageCallbackHandler()
        socket = FakeWS()

        frame = json.dumps({"type": "message", "thread_id": "t1", "content": "嗯..."})
        await dispatch_message(socket, graph, sessions, usage_cb, frame, interrupt_manager=interrupts)

        prompts = [msg for msg in socket.sent if msg["type"] == "clarification"]
        assert len(prompts) == 1
        assert "more details" in prompts[0]["message"]

        # Dispatch is expected to return early, so the graph is never streamed.
        graph.astream.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# Checkpoint 4: Interrupt > 30 min -> auto-cancel + retry
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.integration
class TestCheckpoint4_InterruptTTLAutoCancel:
    """Checkpoint 4: an interrupt older than its TTL is cancelled with a retry prompt."""

    @pytest.mark.asyncio
    async def test_30min_expired_interrupt_auto_cancels(self) -> None:
        """Answering an interrupt 31 minutes later yields ``interrupt_expired``."""
        pending_state = _state(
            interrupt=True,
            data={"action": "cancel_order", "order_id": "1042"},
        )
        graph = MagicMock()
        graph.intent_classifier = None
        graph.agent_registry = None
        graph.astream = MagicMock(return_value=AsyncIterHelper([]))
        graph.aget_state = AsyncMock(return_value=pending_state)

        sessions = SessionManager()
        sessions.touch("t1")
        interrupts = InterruptManager(ttl_seconds=1800)  # 30 minutes
        usage_cb = TokenUsageCallbackHandler()
        socket = FakeWS()

        # Step 1: a normal message whose graph state reports a pending interrupt.
        request = json.dumps({"type": "message", "thread_id": "t1", "content": "Cancel order 1042"})
        await dispatch_message(socket, graph, sessions, usage_cb, request, interrupt_manager=interrupts)

        raised = [msg for msg in socket.sent if msg["type"] == "interrupt"]
        assert len(raised) == 1

        # Step 2: answer the interrupt with the manager's clock moved 31 minutes on.
        registered = interrupts._interrupts["t1"]
        socket.sent.clear()

        with patch("app.interrupt_manager.time") as fake_time:
            fake_time.time.return_value = registered.created_at + 1860  # 31 min

            approval = json.dumps(
                {
                    "type": "interrupt_response",
                    "thread_id": "t1",
                    "approved": True,
                }
            )
            await dispatch_message(socket, graph, sessions, usage_cb, approval, interrupt_manager=interrupts)

        # The reply must be the expiry/retry prompt for the original action.
        expired = [msg for msg in socket.sent if msg["type"] == "interrupt_expired"]
        assert len(expired) == 1
        notice = expired[0]
        assert "30 minutes" in notice["message"]
        assert notice["action"] == "cancel_order"
        assert notice["thread_id"] == "t1"

    def test_cleanup_expired_returns_records(self) -> None:
        """``cleanup_expired`` returns every record once the TTL has passed."""
        interrupts = InterruptManager(ttl_seconds=1800)
        interrupts.register("t1", "cancel_order", {"order_id": "1042"})
        interrupts.register("t2", "apply_discount", {"order_id": "1043"})

        with patch("app.interrupt_manager.time") as fake_time:
            first = interrupts._interrupts["t1"]
            fake_time.time.return_value = first.created_at + 1860
            swept = interrupts.cleanup_expired()

        assert len(swept) == 2
        swept_actions = {record.action for record in swept}
        assert "cancel_order" in swept_actions
        assert "apply_discount" in swept_actions
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# Checkpoint 5: Agent escalation -> Webhook POST
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.integration
class TestCheckpoint5_WebhookEscalation:
    """Checkpoint 5: webhook escalation POSTs the payload and retries on failure."""

    @staticmethod
    def _client_returning(response):
        """Build a mocked async HTTP client whose ``post`` always returns *response*.

        The mock also serves as its own ``async with`` context manager.
        """
        client = AsyncMock()
        client.post = AsyncMock(return_value=response)
        client.__aenter__ = AsyncMock(return_value=client)
        client.__aexit__ = AsyncMock(return_value=False)
        return client

    @pytest.mark.asyncio
    async def test_webhook_post_success(self) -> None:
        """A 200 response succeeds on the first attempt with the expected POST."""
        ok_response = AsyncMock(status_code=200)
        client = self._client_returning(ok_response)

        with patch("app.escalation.httpx.AsyncClient", return_value=client):
            escalator = WebhookEscalator(url="https://support.example.com/escalate")
            outcome = await escalator.escalate(
                EscalationPayload(
                    thread_id="t1",
                    reason="Agent cannot resolve customer issue",
                    conversation_summary="Customer asked about refund policy",
                    metadata={"customer_id": "C-123"},
                )
            )

        assert outcome.success
        assert outcome.status_code == 200
        assert outcome.attempts == 1

        # Inspect the recorded POST: URL is positional, body is the ``json`` keyword.
        post_call = client.post.call_args
        assert post_call.args[0] == "https://support.example.com/escalate"
        body = post_call.kwargs["json"]
        assert body["thread_id"] == "t1"
        assert body["reason"] == "Agent cannot resolve customer issue"

    @pytest.mark.asyncio
    async def test_webhook_retries_then_logs(self) -> None:
        """Persistent 503 responses exhaust all three attempts and report the error."""
        unavailable = AsyncMock(status_code=503)
        client = self._client_returning(unavailable)

        # asyncio.sleep is stubbed so the retry loop does not actually wait.
        with (
            patch("app.escalation.httpx.AsyncClient", return_value=client),
            patch("app.escalation.asyncio.sleep", new_callable=AsyncMock),
        ):
            escalator = WebhookEscalator(
                url="https://support.example.com/escalate",
                max_retries=3,
            )
            outcome = await escalator.escalate(
                EscalationPayload(
                    thread_id="t1",
                    reason="Escalation needed",
                    conversation_summary="Summary",
                )
            )

        assert not outcome.success
        assert outcome.attempts == 3
        assert "503" in outcome.error

    @pytest.mark.asyncio
    async def test_noop_escalator_when_disabled(self) -> None:
        """The no-op escalator reports failure with a "disabled" error."""
        outcome = await NoOpEscalator().escalate(
            EscalationPayload(
                thread_id="t1",
                reason="Test",
                conversation_summary="Test",
            )
        )
        assert not outcome.success
        assert "disabled" in outcome.error.lower()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# Checkpoint 6: E-commerce template -> pre-configured agents
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.integration
class TestCheckpoint6_EcommerceTemplate:
    """Checkpoint 6: the e-commerce template loads with the expected agents."""

    @staticmethod
    def _ecommerce() -> AgentRegistry:
        """Load a fresh registry from the ``e-commerce`` template."""
        return AgentRegistry.load_template("e-commerce", TEMPLATES_DIR)

    def test_ecommerce_template_loads_4_agents(self) -> None:
        """The loaded registry reports a length of four."""
        assert len(self._ecommerce()) == 4

    def test_ecommerce_template_has_correct_agents(self) -> None:
        """The agent names are exactly the four expected ones."""
        found = {agent.name for agent in self._ecommerce().list_agents()}
        assert found == {"order_lookup", "order_actions", "discount", "fallback"}

    def test_ecommerce_order_lookup_is_read(self) -> None:
        """``order_lookup`` is read-only and has the status and tracking tools."""
        lookup = self._ecommerce().get_agent("order_lookup")
        assert lookup.permission == "read"
        assert "get_order_status" in lookup.tools
        assert "get_tracking_info" in lookup.tools

    def test_ecommerce_order_actions_is_write(self) -> None:
        """``order_actions`` has write permission and the cancel tool."""
        actions = self._ecommerce().get_agent("order_actions")
        assert actions.permission == "write"
        assert "cancel_order" in actions.tools

    def test_ecommerce_discount_is_write(self) -> None:
        """``discount`` has write permission and both discount tools."""
        discount = self._ecommerce().get_agent("discount")
        assert discount.permission == "write"
        assert "apply_discount" in discount.tools
        assert "generate_coupon" in discount.tools

    def test_ecommerce_fallback_is_read(self) -> None:
        """``fallback`` is read-only."""
        fallback = self._ecommerce().get_agent("fallback")
        assert fallback.permission == "read"

    def test_all_three_templates_available(self) -> None:
        """The template listing includes e-commerce, saas and fintech."""
        available = AgentRegistry.list_templates(TEMPLATES_DIR)
        for template_name in ("e-commerce", "saas", "fintech"):
            assert template_name in available
|
||||
Reference in New Issue
Block a user