Files
smart-support/backend/tests/integration/test_phase2_checkpoints.py
Yaojia Wang af53111928 refactor: fix architectural issues across frontend and backend
Address all architecture review findings:

P0 fixes:
- Add API key authentication for admin endpoints (analytics, replay, openapi)
  and WebSocket connections via ADMIN_API_KEY env var
- Add PostgreSQL-backed PgSessionManager and PgInterruptManager for
  multi-worker production deployments (in-memory defaults preserved)

P1 fixes:
- Implement actual tool generation in OpenAPI approve_job endpoint
  using generate_tool_code() and generate_agent_yaml()
- Add missing clarification, interrupt_expired, and tool_result message
  handlers in frontend ChatPage

P2 fixes:
- Replace monkey-patching on CompiledStateGraph with typed GraphContext
- Replace 9-param dispatch_message with WebSocketContext dataclass
- Extract duplicate _envelope() into shared app/api_utils.py
- Replace mutable module-level counter with crypto.randomUUID()
- Remove hardcoded mock data from ReviewPage, use api.ts wrappers
- Remove `as any` type escape from ReplayPage

All 516 tests passing, 0 TypeScript errors.
2026-04-06 15:59:14 +02:00

513 lines
20 KiB
Python

"""Phase 2 checkpoint acceptance tests.
Each test maps to one checkpoint criterion from DEVELOPMENT-PLAN.md:
1. "查询订单 1042" -> routes to order_lookup agent
2. "取消订单 1042 并给我一个 10% 折扣" -> sequential multi-agent execution
3. Ambiguous message -> fallback asks for clarification
4. Interrupt > 30 min TTL -> auto-cancel + retry prompt
5. Agent escalation -> Webhook POST succeeds (or logs after retries)
6. E-commerce template -> 4 pre-configured agents work
7. pytest --cov >= 80% (verified separately)
"""
from __future__ import annotations
import json
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from app.callbacks import TokenUsageCallbackHandler
from app.escalation import EscalationPayload, NoOpEscalator, WebhookEscalator
from app.graph_context import GraphContext
from app.intent import ClassificationResult, IntentTarget, LLMIntentClassifier
from app.interrupt_manager import InterruptManager
from app.registry import AgentConfig, AgentRegistry
from app.session_manager import SessionManager
from app.ws_context import WebSocketContext
from app.ws_handler import dispatch_message
TEMPLATES_DIR = Path(__file__).parent.parent.parent / "templates"
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
class AsyncIterHelper:
    """Minimal async iterator over a fixed sequence of items.

    Used in this module as the return value of the mocked ``graph.astream``.
    """

    def __init__(self, items: list) -> None:
        # Snapshot the input so later changes to the caller's list cannot
        # alter what is yielded. Walking the snapshot with an iterator keeps
        # each step O(1), unlike ``list.pop(0)`` which shifts the whole list.
        self._items = iter(list(items))

    def __aiter__(self) -> "AsyncIterHelper":
        return self

    async def __anext__(self):
        try:
            return next(self._items)
        except StopIteration:
            # The async-iterator protocol signals exhaustion with
            # StopAsyncIteration; a StopIteration must not leak out of a
            # coroutine.
            raise StopAsyncIteration from None
class FakeWS:
    """In-memory WebSocket stand-in that records outgoing JSON payloads."""

    def __init__(self) -> None:
        # Every payload passed to send_json, in call order.
        self.sent: list[dict] = []

    async def send_json(self, data: dict) -> None:
        """Append ``data`` to ``self.sent`` instead of transmitting it."""
        self.sent.append(data)
def _chunk(content: str, node: str) -> tuple:
    """Build a ``(message, metadata)`` pair carrying text content.

    Args:
        content: Text stored on the mock message's ``content`` attribute.
        node: Value stored under the ``"langgraph_node"`` metadata key.

    Returns:
        A 2-tuple: a ``MagicMock`` with ``content`` set and an empty
        ``tool_calls`` list, followed by the metadata dict.
    """
    message = MagicMock()
    message.content = content
    message.tool_calls = []
    return (message, {"langgraph_node": node})
def _tool_chunk(name: str, args: dict, node: str) -> tuple:
    """Build a ``(message, metadata)`` pair carrying a single tool call.

    Args:
        name: Tool name placed in the tool-call entry.
        args: Tool arguments placed in the tool-call entry.
        node: Value stored under the ``"langgraph_node"`` metadata key.

    Returns:
        A 2-tuple: a ``MagicMock`` with empty ``content`` and a one-element
        ``tool_calls`` list, followed by the metadata dict.
    """
    message = MagicMock()
    message.content = ""
    message.tool_calls = [{"name": name, "args": args}]
    return (message, {"langgraph_node": node})
def _state(*, interrupt: bool = False, data: dict | None = None):
s = MagicMock()
if interrupt:
obj = MagicMock()
obj.value = data or {"action": "cancel_order", "order_id": "1042"}
t = MagicMock()
t.interrupts = (obj,)
s.tasks = (t,)
else:
s.tasks = ()
return s
def _agent(name: str, desc: str, perm: str = "read") -> AgentConfig:
    """Create an ``AgentConfig`` whose only tool is ``fallback_respond``.

    Args:
        name: Passed through as the config's ``name``.
        desc: Passed through as the config's ``description``.
        perm: Passed through as the config's ``permission``.
    """
    return AgentConfig(
        name=name,
        description=desc,
        permission=perm,
        tools=["fallback_respond"],
    )
# ---------------------------------------------------------------------------
# Checkpoint 1: "查询订单 1042" ("Look up order 1042") -> routes to the order_lookup agent
# ---------------------------------------------------------------------------
@pytest.mark.integration
class TestCheckpoint1OrderQueryRouting:
    """Verify intent classifier routes order queries to order_lookup."""

    @pytest.mark.asyncio
    async def test_order_query_classified_to_order_lookup(self) -> None:
        """The classifier yields one high-confidence, unambiguous order_lookup intent."""
        expected = ClassificationResult(
            intents=(
                IntentTarget(agent_name="order_lookup", confidence=0.95, reasoning="order query"),
            ),
        )
        # The LLM is mocked: with_structured_output(...) returns an object
        # whose ainvoke resolves to the canned classification above.
        mock_structured = MagicMock()
        mock_structured.ainvoke = AsyncMock(return_value=expected)
        mock_llm = MagicMock()
        mock_llm.with_structured_output = MagicMock(return_value=mock_structured)
        classifier = LLMIntentClassifier(mock_llm)
        agents = (
            _agent("order_lookup", "Looks up order status and tracking"),
            _agent("order_actions", "Modifies orders", "write"),
            _agent("discount", "Applies discounts", "write"),
            _agent("fallback", "Handles unclear requests"),
        )

        result = await classifier.classify("查询订单 1042", agents)

        assert len(result.intents) == 1
        assert result.intents[0].agent_name == "order_lookup"
        assert result.intents[0].confidence >= 0.9
        assert not result.is_ambiguous

    @pytest.mark.asyncio
    async def test_order_query_streams_from_order_lookup_agent(self) -> None:
        """Full dispatch: classify -> route -> stream from order_lookup."""
        graph = MagicMock()
        # Classifier returns order_lookup
        mock_classifier = AsyncMock()
        mock_classifier.classify = AsyncMock(return_value=ClassificationResult(
            intents=(IntentTarget(agent_name="order_lookup", confidence=0.95, reasoning="query"),),
        ))
        mock_registry = MagicMock()
        mock_registry.list_agents = MagicMock(return_value=())
        # Graph streams order_lookup response: one tool call, then one text chunk.
        graph.astream = MagicMock(return_value=AsyncIterHelper([
            _tool_chunk("get_order_status", {"order_id": "1042"}, "order_lookup"),
            _chunk("Order 1042 is shipped.", "order_lookup"),
        ]))
        # No pending interrupt after the stream finishes.
        graph.aget_state = AsyncMock(return_value=_state())
        graph_ctx = GraphContext(
            graph=graph, registry=mock_registry, intent_classifier=mock_classifier,
        )
        sm = SessionManager()
        sm.touch("t1")
        im = InterruptManager()
        cb = TokenUsageCallbackHandler()
        ws = FakeWS()
        ws_ctx = WebSocketContext(
            graph_ctx=graph_ctx, session_manager=sm,
            callback_handler=cb, interrupt_manager=im,
        )
        raw = json.dumps({"type": "message", "thread_id": "t1", "content": "查询订单 1042"})

        await dispatch_message(ws, ws_ctx, raw)

        tool_msgs = [m for m in ws.sent if m["type"] == "tool_call"]
        assert any(m["tool"] == "get_order_status" for m in tool_msgs)
        token_msgs = [m for m in ws.sent if m["type"] == "token"]
        assert any(m["agent"] == "order_lookup" for m in token_msgs)
# ---------------------------------------------------------------------------
# Checkpoint 2: Multi-intent -> sequential execution
# ---------------------------------------------------------------------------
@pytest.mark.integration
class TestCheckpoint2MultiIntentSequential:
    """Verify multi-intent classified and hint injected for sequential execution."""

    @pytest.mark.asyncio
    async def test_multi_intent_classification(self) -> None:
        """The classifier yields both intents, order_actions first, unambiguous."""
        expected = ClassificationResult(
            intents=(
                IntentTarget(agent_name="order_actions", confidence=0.9, reasoning="cancel"),
                IntentTarget(agent_name="discount", confidence=0.85, reasoning="discount"),
            ),
        )
        # The LLM is mocked: with_structured_output(...) returns an object
        # whose ainvoke resolves to the canned classification above.
        mock_structured = MagicMock()
        mock_structured.ainvoke = AsyncMock(return_value=expected)
        mock_llm = MagicMock()
        mock_llm.with_structured_output = MagicMock(return_value=mock_structured)
        classifier = LLMIntentClassifier(mock_llm)
        agents = (
            _agent("order_actions", "Modifies orders", "write"),
            _agent("discount", "Applies discounts", "write"),
            _agent("fallback", "Handles unclear requests"),
        )

        result = await classifier.classify("取消订单 1042 并给我一个 10% 折扣", agents)

        assert len(result.intents) == 2
        assert result.intents[0].agent_name == "order_actions"
        assert result.intents[1].agent_name == "discount"
        assert not result.is_ambiguous

    @pytest.mark.asyncio
    async def test_multi_intent_injects_routing_hint(self) -> None:
        """When multi-intent detected, a [System: ...] hint is appended to the message."""
        graph = MagicMock()
        mock_classifier = AsyncMock()
        mock_classifier.classify = AsyncMock(return_value=ClassificationResult(
            intents=(
                IntentTarget(agent_name="order_actions", confidence=0.9, reasoning="cancel"),
                IntentTarget(agent_name="discount", confidence=0.85, reasoning="discount"),
            ),
        ))
        mock_registry = MagicMock()
        mock_registry.list_agents = MagicMock(return_value=())
        # The stream yields nothing: only the input handed to the graph matters here.
        graph.astream = MagicMock(return_value=AsyncIterHelper([]))
        graph.aget_state = AsyncMock(return_value=_state())
        graph_ctx = GraphContext(
            graph=graph, registry=mock_registry, intent_classifier=mock_classifier,
        )
        sm = SessionManager()
        sm.touch("t1")
        im = InterruptManager()
        cb = TokenUsageCallbackHandler()
        ws = FakeWS()
        ws_ctx = WebSocketContext(
            graph_ctx=graph_ctx, session_manager=sm,
            callback_handler=cb, interrupt_manager=im,
        )
        raw = json.dumps({
            "type": "message",
            "thread_id": "t1",
            "content": "取消订单 1042 并给我一个 10% 折扣",
        })

        await dispatch_message(ws, ws_ctx, raw)

        # Verify the graph was called with the routing hint in the message.
        # The graph input is the first positional argument of astream.
        input_msg = graph.astream.call_args.args[0]
        msg_content = input_msg["messages"][0].content
        assert "[System:" in msg_content
        assert "order_actions" in msg_content
        assert "discount" in msg_content
# ---------------------------------------------------------------------------
# Checkpoint 3: Ambiguous message -> clarification
# ---------------------------------------------------------------------------
@pytest.mark.integration
class TestCheckpoint3AmbiguousClarification:
    """Verify ambiguous messages trigger clarification prompt."""

    @pytest.mark.asyncio
    async def test_ambiguous_intent_returns_clarification(self) -> None:
        """A low-confidence LLM result comes back flagged ambiguous with a question."""
        # The LLM itself reports is_ambiguous=False; the test expects the
        # classifier to flip it because of the 0.3 confidence.
        expected = ClassificationResult(
            intents=(IntentTarget(agent_name="fallback", confidence=0.3, reasoning="unclear"),),
            is_ambiguous=False,  # low confidence will trigger ambiguity threshold
        )
        mock_structured = MagicMock()
        mock_structured.ainvoke = AsyncMock(return_value=expected)
        mock_llm = MagicMock()
        mock_llm.with_structured_output = MagicMock(return_value=mock_structured)
        classifier = LLMIntentClassifier(mock_llm)
        agents = (_agent("order_lookup", "Orders"), _agent("fallback", "Fallback"))

        result = await classifier.classify("嗯...", agents)

        assert result.is_ambiguous
        assert result.clarification_question is not None

    @pytest.mark.asyncio
    async def test_ambiguous_sends_clarification_via_websocket(self) -> None:
        """An ambiguous classification sends one clarification and skips the graph."""
        graph = MagicMock()
        mock_classifier = AsyncMock()
        mock_classifier.classify = AsyncMock(return_value=ClassificationResult(
            intents=(),
            is_ambiguous=True,
            clarification_question=(
                "Could you please provide more details about what you need help with?"
            ),
        ))
        mock_registry = MagicMock()
        mock_registry.list_agents = MagicMock(return_value=())
        graph.astream = MagicMock(return_value=AsyncIterHelper([]))
        graph.aget_state = AsyncMock(return_value=_state())
        graph_ctx = GraphContext(
            graph=graph, registry=mock_registry, intent_classifier=mock_classifier,
        )
        sm = SessionManager()
        sm.touch("t1")
        im = InterruptManager()
        cb = TokenUsageCallbackHandler()
        ws = FakeWS()
        ws_ctx = WebSocketContext(
            graph_ctx=graph_ctx, session_manager=sm,
            callback_handler=cb, interrupt_manager=im,
        )
        raw = json.dumps({"type": "message", "thread_id": "t1", "content": "嗯..."})

        await dispatch_message(ws, ws_ctx, raw)

        clarifications = [m for m in ws.sent if m["type"] == "clarification"]
        assert len(clarifications) == 1
        assert "more details" in clarifications[0]["message"]
        # Should NOT call graph.astream since we returned early
        graph.astream.assert_not_called()
# ---------------------------------------------------------------------------
# Checkpoint 4: Interrupt > 30 min -> auto-cancel + retry
# ---------------------------------------------------------------------------
@pytest.mark.integration
class TestCheckpoint4InterruptTTLAutoCancel:
    """Verify interrupt TTL expiration triggers auto-cancel and retry prompt."""

    @pytest.mark.asyncio
    async def test_30min_expired_interrupt_auto_cancels(self) -> None:
        """Answering an interrupt 31 minutes late yields an interrupt_expired message."""
        st = _state(interrupt=True, data={"action": "cancel_order", "order_id": "1042"})
        graph = MagicMock()
        graph.astream = MagicMock(return_value=AsyncIterHelper([]))
        graph.aget_state = AsyncMock(return_value=st)
        mock_registry = MagicMock()
        mock_registry.list_agents = MagicMock(return_value=())
        graph_ctx = GraphContext(graph=graph, registry=mock_registry, intent_classifier=None)
        sm = SessionManager()
        sm.touch("t1")
        im = InterruptManager(ttl_seconds=1800)  # 30 minutes
        cb = TokenUsageCallbackHandler()
        ws = FakeWS()
        ws_ctx = WebSocketContext(
            graph_ctx=graph_ctx, session_manager=sm,
            callback_handler=cb, interrupt_manager=im,
        )

        # Trigger interrupt
        raw = json.dumps({"type": "message", "thread_id": "t1", "content": "Cancel order 1042"})
        await dispatch_message(ws, ws_ctx, raw)
        interrupts = [m for m in ws.sent if m["type"] == "interrupt"]
        assert len(interrupts) == 1

        # Simulate 31 minutes passing
        record = im._interrupts["t1"]
        ws.sent.clear()
        # The response must be dispatched while the clock is patched so the
        # manager sees a time past the TTL.
        with patch("app.interrupt_manager.time") as mock_time:
            mock_time.time.return_value = record.created_at + 1860  # 31 min
            raw = json.dumps({
                "type": "interrupt_response",
                "thread_id": "t1",
                "approved": True,
            })
            await dispatch_message(ws, ws_ctx, raw)

        # Should get retry prompt, NOT resume the graph
        expired_msgs = [m for m in ws.sent if m["type"] == "interrupt_expired"]
        assert len(expired_msgs) == 1
        assert "30 minutes" in expired_msgs[0]["message"]
        assert expired_msgs[0]["action"] == "cancel_order"
        assert expired_msgs[0]["thread_id"] == "t1"

    def test_cleanup_expired_returns_records(self) -> None:
        """cleanup_expired returns both registered records once 31 minutes have passed."""
        im = InterruptManager(ttl_seconds=1800)
        im.register("t1", "cancel_order", {"order_id": "1042"})
        im.register("t2", "apply_discount", {"order_id": "1043"})

        with patch("app.interrupt_manager.time") as mock_time:
            record = im._interrupts["t1"]
            # Move the clock 31 minutes past the first record's creation time.
            mock_time.time.return_value = record.created_at + 1860
            expired = im.cleanup_expired()

        assert len(expired) == 2
        actions = {r.action for r in expired}
        assert "cancel_order" in actions
        assert "apply_discount" in actions
# ---------------------------------------------------------------------------
# Checkpoint 5: Agent escalation -> Webhook POST
# ---------------------------------------------------------------------------
@pytest.mark.integration
class TestCheckpoint5WebhookEscalation:
    """Verify webhook escalation sends POST and retries on failure."""

    @pytest.mark.asyncio
    async def test_webhook_post_success(self) -> None:
        """A 200 response gives a successful single-attempt result with the right POST."""
        mock_response = AsyncMock()
        mock_response.status_code = 200
        # The mocked client is its own async-context-manager result.
        mock_client = AsyncMock()
        mock_client.post = AsyncMock(return_value=mock_response)
        mock_client.__aenter__ = AsyncMock(return_value=mock_client)
        mock_client.__aexit__ = AsyncMock(return_value=False)

        with patch("app.escalation.httpx.AsyncClient", return_value=mock_client):
            escalator = WebhookEscalator(url="https://support.example.com/escalate")
            payload = EscalationPayload(
                thread_id="t1",
                reason="Agent cannot resolve customer issue",
                conversation_summary="Customer asked about refund policy",
                metadata={"customer_id": "C-123"},
            )
            result = await escalator.escalate(payload)

        assert result.success
        assert result.status_code == 200
        assert result.attempts == 1
        # Verify POST was called with correct payload: URL as the first
        # positional argument, body under the ``json`` keyword.
        call_args = mock_client.post.call_args
        assert call_args.args[0] == "https://support.example.com/escalate"
        posted_data = call_args.kwargs["json"]
        assert posted_data["thread_id"] == "t1"
        assert posted_data["reason"] == "Agent cannot resolve customer issue"

    @pytest.mark.asyncio
    async def test_webhook_retries_then_logs(self) -> None:
        """Persistent 503 responses exhaust max_retries and report the status in the error."""
        fail_response = AsyncMock()
        fail_response.status_code = 503
        mock_client = AsyncMock()
        mock_client.post = AsyncMock(return_value=fail_response)
        mock_client.__aenter__ = AsyncMock(return_value=mock_client)
        mock_client.__aexit__ = AsyncMock(return_value=False)

        # asyncio.sleep is patched out so the test does not wait -- presumably
        # the delay between retries; confirm against app.escalation.
        with (
            patch("app.escalation.httpx.AsyncClient", return_value=mock_client),
            patch("app.escalation.asyncio.sleep", new_callable=AsyncMock),
        ):
            escalator = WebhookEscalator(
                url="https://support.example.com/escalate",
                max_retries=3,
            )
            payload = EscalationPayload(
                thread_id="t1",
                reason="Escalation needed",
                conversation_summary="Summary",
            )
            result = await escalator.escalate(payload)

        assert not result.success
        assert result.attempts == 3
        assert "503" in result.error

    @pytest.mark.asyncio
    async def test_noop_escalator_when_disabled(self) -> None:
        """NoOpEscalator reports failure with an error mentioning 'disabled'."""
        escalator = NoOpEscalator()
        payload = EscalationPayload(
            thread_id="t1",
            reason="Test",
            conversation_summary="Test",
        )

        result = await escalator.escalate(payload)

        assert not result.success
        assert "disabled" in result.error.lower()
# ---------------------------------------------------------------------------
# Checkpoint 6: E-commerce template -> pre-configured agents
# ---------------------------------------------------------------------------
@pytest.mark.integration
class TestCheckpoint6EcommerceTemplate:
    """Verify e-commerce template loads with correct agents."""

    @staticmethod
    def _ecommerce_registry() -> AgentRegistry:
        """Load a fresh registry from the e-commerce template for one test."""
        return AgentRegistry.load_template("e-commerce", TEMPLATES_DIR)

    def test_ecommerce_template_loads_4_agents(self) -> None:
        """The loaded registry has a length of four."""
        registry = self._ecommerce_registry()
        assert len(registry) == 4

    def test_ecommerce_template_has_correct_agents(self) -> None:
        """The registry lists exactly the four expected agent names."""
        registry = self._ecommerce_registry()
        agents = registry.list_agents()
        names = {a.name for a in agents}
        assert names == {"order_lookup", "order_actions", "discount", "fallback"}

    def test_ecommerce_order_lookup_is_read(self) -> None:
        """order_lookup has read permission plus the status and tracking tools."""
        registry = self._ecommerce_registry()
        agent = registry.get_agent("order_lookup")
        assert agent.permission == "read"
        assert "get_order_status" in agent.tools
        assert "get_tracking_info" in agent.tools

    def test_ecommerce_order_actions_is_write(self) -> None:
        """order_actions has write permission and the cancel_order tool."""
        registry = self._ecommerce_registry()
        agent = registry.get_agent("order_actions")
        assert agent.permission == "write"
        assert "cancel_order" in agent.tools

    def test_ecommerce_discount_is_write(self) -> None:
        """discount has write permission plus the discount and coupon tools."""
        registry = self._ecommerce_registry()
        agent = registry.get_agent("discount")
        assert agent.permission == "write"
        assert "apply_discount" in agent.tools
        assert "generate_coupon" in agent.tools

    def test_ecommerce_fallback_is_read(self) -> None:
        """fallback has read permission."""
        registry = self._ecommerce_registry()
        agent = registry.get_agent("fallback")
        assert agent.permission == "read"

    def test_all_three_templates_available(self) -> None:
        """The templates directory lists e-commerce, saas, and fintech."""
        templates = AgentRegistry.list_templates(TEMPLATES_DIR)
        assert "e-commerce" in templates
        assert "saas" in templates
        assert "fintech" in templates