feat: complete phase 2 -- multi-agent routing, interrupt TTL, escalation, templates

- Intent classification with LLM structured output (single/multi/ambiguous)
- Discount agent with apply_discount and generate_coupon tools
- Interrupt manager with 30-min TTL auto-expiration and retry prompts
- Webhook escalation module with exponential backoff retry (max 3)
- Three vertical industry templates (e-commerce, SaaS, fintech)
- Template loading in AgentRegistry
- Enhanced supervisor prompt with dynamic agent descriptions
- 153 tests passing, 90.18% coverage
Author: Yaojia Wang
Date: 2026-03-30 21:04:39 +02:00
Parent: 7c3571b47d
Commit: 1050df780d
27 changed files with 1683 additions and 43 deletions


@@ -238,7 +238,7 @@ A checkpoint includes:
 | Phase | Branch | Focus | Status |
 |-------|--------|-------|--------|
 | 1 | `phase-1/core-framework` | FastAPI + LangGraph + React chat loop + PostgresSaver | COMPLETED (2026-03-30) |
-| 2 | `phase-2/multi-agent-safety` | Supervisor routing + interrupts + templates | NOT STARTED |
+| 2 | `phase-2/multi-agent-safety` | Supervisor routing + interrupts + templates | COMPLETED (2026-03-30) |
 | 3 | `phase-3/openapi-discovery` | OpenAPI parsing + MCP generation + SSRF protection | NOT STARTED |
 | 4 | `phase-4/analytics-replay` | Replay API + analytics dashboard | NOT STARTED |
 | 5 | `phase-5/polish-demo` | Error hardening + demo prep + Docker deploy | NOT STARTED |


@@ -20,6 +20,17 @@ agents:
     tools:
       - cancel_order
+  - name: discount
+    description: "Applies discounts and generates coupon codes. Use when the customer asks for discounts, promotions, or coupon codes."
+    permission: write
+    personality:
+      tone: "generous and accommodating"
+      greeting: "I can help you with discounts and coupons!"
+      escalation_message: "Let me connect you with our promotions team."
+    tools:
+      - apply_discount
+      - generate_coupon
   - name: fallback
     description: "Handles general questions, unclear requests, and conversations that don't match other agents."
     permission: read


@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING
 if TYPE_CHECKING:
     from langchain_core.tools import BaseTool

+from app.agents.discount import apply_discount, generate_coupon
 from app.agents.fallback import fallback_respond
 from app.agents.order_actions import cancel_order
 from app.agents.order_lookup import get_order_status, get_tracking_info
@@ -16,6 +17,8 @@ _TOOL_MAP: dict[str, BaseTool] = {
     "get_tracking_info": get_tracking_info,
     "cancel_order": cancel_order,
     "fallback_respond": fallback_respond,
+    "apply_discount": apply_discount,
+    "generate_coupon": generate_coupon,
 }


@@ -0,0 +1,79 @@
"""Discount agent tools -- apply discounts and generate coupons."""

from __future__ import annotations

import uuid

from langchain_core.tools import tool
from langgraph.types import interrupt


@tool
def apply_discount(order_id: str, discount_percent: int) -> dict:
    """Apply a discount to an order. Requires human approval before execution."""
    if discount_percent < 1 or discount_percent > 100:
        return {
            "status": "error",
            "order_id": order_id,
            "message": f"Invalid discount: {discount_percent}%. Must be between 1 and 100.",
        }
    response = interrupt(
        {
            "action": "apply_discount",
            "order_id": order_id,
            "discount_percent": discount_percent,
            "message": (
                f"Please confirm: apply {discount_percent}% discount to order {order_id}?"
            ),
        }
    )
    if isinstance(response, bool):
        approved = response
    elif isinstance(response, dict):
        approved = response.get("approved", False)
    else:
        approved = bool(response)
    if approved:
        return {
            "status": "applied",
            "order_id": order_id,
            "discount_percent": discount_percent,
            "message": (
                f"{discount_percent}% discount applied to order {order_id}."
            ),
        }
    return {
        "status": "declined",
        "order_id": order_id,
        "message": f"Discount for order {order_id} was declined.",
    }


@tool
def generate_coupon(discount_percent: int, expiry_days: int = 30) -> dict:
    """Generate a coupon code with the specified discount percentage."""
    if discount_percent < 1 or discount_percent > 100:
        return {
            "status": "error",
            "message": f"Invalid discount: {discount_percent}%. Must be between 1 and 100.",
        }
    if expiry_days < 1:
        return {
            "status": "error",
            "message": f"Invalid expiry: {expiry_days} days. Must be at least 1.",
        }
    coupon_code = f"SAVE{discount_percent}-{uuid.uuid4().hex[:8].upper()}"
    return {
        "status": "generated",
        "coupon_code": coupon_code,
        "discount_percent": discount_percent,
        "expiry_days": expiry_days,
        "message": (
            f"Coupon {coupon_code} generated: {discount_percent}% off, "
            f"valid for {expiry_days} days."
        ),
    }
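The coupon-code scheme above can be exercised outside the agent runtime. A minimal stdlib-only sketch (`make_coupon_code` is a hypothetical helper, not part of the module):

```python
import uuid

def make_coupon_code(discount_percent: int) -> str:
    # Same scheme as generate_coupon: "SAVE<pct>-" plus 8 uppercase hex
    # characters, so codes are effectively unique per call.
    return f"SAVE{discount_percent}-{uuid.uuid4().hex[:8].upper()}"
```

The random suffix is what lets the tests later in this commit assert that two generated codes never collide.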


@@ -1,4 +1,4 @@
-"""Fallback agent tools -- handles unmatched intents."""
+"""Fallback agent tools -- handles unmatched intents and clarification requests."""

 from __future__ import annotations
@@ -13,6 +13,7 @@ def fallback_respond(query: str) -> str:
         "Here's what I can do:\n"
         "- Check order status (e.g., 'What is the status of order 1042?')\n"
         "- Get tracking information (e.g., 'Track order 1042')\n"
-        "- Cancel an order (e.g., 'Cancel order 1042')\n\n"
+        "- Cancel an order (e.g., 'Cancel order 1042')\n"
+        "- Apply discounts or generate coupons\n\n"
         "Could you please rephrase your request?"
     )


@@ -26,6 +26,12 @@ class Settings(BaseSettings):
    ws_host: str = "0.0.0.0"
    ws_port: int = 8000
    webhook_url: str = ""
    webhook_timeout_seconds: int = 10
    webhook_max_retries: int = 3
    template_name: str = ""
    anthropic_api_key: str = ""
    openai_api_key: str = ""
    google_api_key: str = ""

backend/app/escalation.py (new file, 140 lines)

@@ -0,0 +1,140 @@
"""Webhook escalation module -- HTTP POST with exponential backoff retry."""

from __future__ import annotations

import asyncio
import logging
from dataclasses import dataclass
from typing import Protocol

import httpx
from pydantic import BaseModel

logger = logging.getLogger(__name__)


class EscalationPayload(BaseModel, frozen=True):
    """Immutable payload sent to the escalation webhook."""

    thread_id: str
    reason: str
    conversation_summary: str
    metadata: dict = {}


@dataclass(frozen=True)
class EscalationResult:
    """Immutable result of an escalation attempt."""

    success: bool
    status_code: int | None
    attempts: int
    error: str | None


class EscalationService(Protocol):
    """Protocol for escalation implementations."""

    async def escalate(self, payload: EscalationPayload) -> EscalationResult: ...


class WebhookEscalator:
    """Sends escalation requests via HTTP POST with exponential backoff retry."""

    def __init__(
        self,
        url: str,
        timeout_seconds: int = 10,
        max_retries: int = 3,
    ) -> None:
        self._url = url
        self._timeout = timeout_seconds
        self._max_retries = max_retries

    async def escalate(self, payload: EscalationPayload) -> EscalationResult:
        """POST the escalation payload to the configured webhook URL."""
        if not self._url:
            return EscalationResult(
                success=False,
                status_code=None,
                attempts=0,
                error="Webhook URL not configured",
            )
        last_error: str | None = None
        async with httpx.AsyncClient(timeout=self._timeout) as client:
            for attempt in range(1, self._max_retries + 1):
                try:
                    response = await client.post(
                        self._url,
                        json=payload.model_dump(),
                    )
                    if 200 <= response.status_code < 300:
                        logger.info(
                            "Escalation succeeded for thread %s (attempt %d)",
                            payload.thread_id,
                            attempt,
                        )
                        return EscalationResult(
                            success=True,
                            status_code=response.status_code,
                            attempts=attempt,
                            error=None,
                        )
                    last_error = f"HTTP {response.status_code}"
                    logger.warning(
                        "Escalation attempt %d/%d failed: %s",
                        attempt,
                        self._max_retries,
                        last_error,
                    )
                except httpx.TimeoutException:
                    last_error = "Request timed out"
                    logger.warning(
                        "Escalation attempt %d/%d timed out",
                        attempt,
                        self._max_retries,
                    )
                except httpx.RequestError as exc:
                    last_error = str(exc)
                    logger.warning(
                        "Escalation attempt %d/%d error: %s",
                        attempt,
                        self._max_retries,
                        last_error,
                    )
                # Exponential backoff: skip delay after last attempt
                if attempt < self._max_retries:
                    delay = 2**attempt
                    await asyncio.sleep(delay)
        logger.error(
            "Escalation failed for thread %s after %d attempts: %s",
            payload.thread_id,
            self._max_retries,
            last_error,
        )
        return EscalationResult(
            success=False,
            status_code=None,
            attempts=self._max_retries,
            error=last_error,
        )


class NoOpEscalator:
    """Escalator that does nothing -- used when webhook URL is not configured."""

    async def escalate(self, payload: EscalationPayload) -> EscalationResult:
        logger.info("Escalation disabled (no webhook URL). Thread: %s", payload.thread_id)
        return EscalationResult(
            success=False,
            status_code=None,
            attempts=0,
            error="Escalation disabled",
        )
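The retry loop sleeps `2**attempt` seconds after each failed attempt except the last. A rough sketch of that schedule (`backoff_delays` is illustrative, not part of the module):

```python
def backoff_delays(max_retries: int) -> list[int]:
    # WebhookEscalator sleeps after attempts 1..max_retries-1; the delay
    # doubles each time: 2s, 4s, 8s, ...
    return [2**attempt for attempt in range(1, max_retries)]
```

With the default `max_retries=3`, the escalator waits 2s and then 4s between its three POST attempts.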


@@ -2,12 +2,14 @@
 from __future__ import annotations

 import logging
 from typing import TYPE_CHECKING

 from langgraph.prebuilt import create_react_agent
 from langgraph_supervisor import create_supervisor

 from app.agents import get_tools_by_names
+from app.intent import ClassificationResult, IntentClassifier

 if TYPE_CHECKING:
     from langchain_core.language_models import BaseChatModel
@@ -16,15 +18,31 @@ if TYPE_CHECKING:
     from app.registry import AgentRegistry

 logger = logging.getLogger(__name__)

 SUPERVISOR_PROMPT = (
     "You are a customer support supervisor. "
-    "Route customer requests to the appropriate agent based on their description. "
-    "For order status and tracking queries, use the order_lookup agent. "
-    "For order modifications like cancellations, use the order_actions agent. "
-    "For anything else, use the fallback agent."
+    "Route customer requests to the appropriate agent based on their description.\n\n"
+    "Available agents and their roles:\n"
+    "{agent_descriptions}\n\n"
+    "Routing rules:\n"
+    "- For order status and tracking queries, use the order_lookup agent.\n"
+    "- For order modifications like cancellations, use the order_actions agent.\n"
+    "- For discounts, promotions, or coupon codes, use the discount agent.\n"
+    "- For anything else or when uncertain, use the fallback agent.\n"
+    "- If the user's request involves multiple actions, execute them in order.\n"
+    "- If a previous intent classification is provided, follow it.\n"
 )

+def _format_agent_descriptions(registry: AgentRegistry) -> str:
+    """Build agent description text for the supervisor prompt."""
+    lines = []
+    for agent in registry.list_agents():
+        lines.append(f"- {agent.name}: {agent.description}")
+    return "\n".join(lines)

 def build_agent_nodes(
     registry: AgentRegistry,
     llm: BaseChatModel,
@@ -56,15 +74,48 @@ def build_graph(
     registry: AgentRegistry,
     llm: BaseChatModel,
     checkpointer: AsyncPostgresSaver,
+    intent_classifier: IntentClassifier | None = None,
 ) -> CompiledStateGraph:
-    """Build and compile the LangGraph supervisor graph."""
+    """Build and compile the LangGraph supervisor graph.
+
+    If an intent_classifier is provided, the supervisor prompt is enhanced
+    with agent descriptions for better routing. The classifier is stored
+    for use by the routing layer (ws_handler).
+    """
     agent_nodes = build_agent_nodes(registry, llm)
+    agent_descriptions = _format_agent_descriptions(registry)
+    prompt = SUPERVISOR_PROMPT.format(agent_descriptions=agent_descriptions)
     workflow = create_supervisor(
         agent_nodes,
         model=llm,
-        prompt=SUPERVISOR_PROMPT,
+        prompt=prompt,
         output_mode="full_history",
     )
-    return workflow.compile(checkpointer=checkpointer)
+    graph = workflow.compile(checkpointer=checkpointer)
+    # Attach classifier and registry to graph for use by ws_handler
+    graph.intent_classifier = intent_classifier  # type: ignore[attr-defined]
+    graph.agent_registry = registry  # type: ignore[attr-defined]
+    return graph

+async def classify_intent(
+    graph: CompiledStateGraph,
+    message: str,
+) -> ClassificationResult | None:
+    """Classify user intent using the graph's attached classifier.
+
+    Returns None if no classifier is configured.
+    """
+    classifier = getattr(graph, "intent_classifier", None)
+    registry = getattr(graph, "agent_registry", None)
+    if classifier is None or registry is None:
+        return None
+    agents = registry.list_agents()
+    return await classifier.classify(message, agents)
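The prompt assembly above hinges on `_format_agent_descriptions`. Its behavior can be sketched standalone, assuming plain name/description pairs instead of `AgentConfig` objects:

```python
def format_agent_descriptions(agents: list[tuple[str, str]]) -> str:
    # One "- name: description" bullet per agent, joined with newlines,
    # ready to substitute into the {agent_descriptions} prompt slot.
    return "\n".join(f"- {name}: {desc}" for name, desc in agents)
```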

backend/app/intent.py (new file, 118 lines)

@@ -0,0 +1,118 @@
"""Intent classification using LLM structured output."""

from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Protocol

from pydantic import BaseModel

if TYPE_CHECKING:
    from langchain_core.language_models import BaseChatModel

    from app.registry import AgentConfig

logger = logging.getLogger(__name__)

CLASSIFICATION_PROMPT = (
    "You are an intent classifier for a customer support system.\n"
    "Given a user message, determine which agent(s) should handle it.\n\n"
    "Available agents:\n{agent_list}\n\n"
    "Rules:\n"
    "- If the message clearly maps to one agent, return a single intent.\n"
    "- If the message contains multiple distinct requests, return multiple intents "
    "in execution order.\n"
    "- If the message is vague or doesn't match any agent, set is_ambiguous=True "
    "and provide a clarification_question.\n"
    "- Never route to the fallback agent unless truly ambiguous.\n"
    "- confidence should be between 0.0 and 1.0.\n"
)

AMBIGUITY_THRESHOLD = 0.5


class IntentTarget(BaseModel, frozen=True):
    """A single classified intent targeting a specific agent."""

    agent_name: str
    confidence: float
    reasoning: str


class ClassificationResult(BaseModel, frozen=True):
    """Result of intent classification -- may contain multiple intents."""

    intents: tuple[IntentTarget, ...]
    is_ambiguous: bool = False
    clarification_question: str | None = None


class IntentClassifier(Protocol):
    """Protocol for intent classification implementations."""

    async def classify(
        self,
        message: str,
        available_agents: tuple[AgentConfig, ...],
    ) -> ClassificationResult: ...


def _build_agent_list(agents: tuple[AgentConfig, ...]) -> str:
    """Format agent descriptions for the classification prompt."""
    lines = []
    for agent in agents:
        lines.append(f"- {agent.name}: {agent.description} (permission: {agent.permission})")
    return "\n".join(lines)


class LLMIntentClassifier:
    """Classifies user intent using LLM structured output."""

    def __init__(self, llm: BaseChatModel) -> None:
        self._llm = llm

    async def classify(
        self,
        message: str,
        available_agents: tuple[AgentConfig, ...],
    ) -> ClassificationResult:
        """Classify user message into one or more agent intents."""
        agent_list = _build_agent_list(available_agents)
        system_prompt = CLASSIFICATION_PROMPT.format(agent_list=agent_list)
        structured_llm = self._llm.with_structured_output(ClassificationResult)
        try:
            result = await structured_llm.ainvoke(
                [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": message},
                ]
            )
        except Exception:
            logger.exception("Intent classification failed, returning ambiguous")
            return ClassificationResult(
                intents=(),
                is_ambiguous=True,
                clarification_question="I'm not sure I understood. Could you please rephrase?",
            )
        if not isinstance(result, ClassificationResult):
            return ClassificationResult(
                intents=(),
                is_ambiguous=True,
                clarification_question="I'm not sure I understood. Could you please rephrase?",
            )
        # Apply ambiguity threshold
        if result.intents and all(i.confidence < AMBIGUITY_THRESHOLD for i in result.intents):
            return ClassificationResult(
                intents=result.intents,
                is_ambiguous=True,
                clarification_question=(
                    result.clarification_question
                    or "I'm not sure I understood. Could you please rephrase?"
                ),
            )
        return result
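The threshold check at the end of `classify` can be isolated as a sketch: a result is forced to ambiguous only when every intent falls below `AMBIGUITY_THRESHOLD` (the helper name here is illustrative):

```python
AMBIGUITY_THRESHOLD = 0.5

def forced_ambiguous(confidences: list[float]) -> bool:
    # Mirrors the guard in LLMIntentClassifier.classify: a single confident
    # intent is enough to keep the classification actionable.
    return bool(confidences) and all(c < AMBIGUITY_THRESHOLD for c in confidences)
```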


@@ -0,0 +1,115 @@
"""Interrupt TTL management -- tracks pending interrupts with auto-expiration."""

from __future__ import annotations

import time
import uuid
from dataclasses import dataclass


@dataclass(frozen=True)
class InterruptRecord:
    """Immutable record of a pending interrupt."""

    interrupt_id: str
    thread_id: str
    action: str
    params: dict
    created_at: float
    ttl_seconds: int


@dataclass(frozen=True)
class InterruptStatus:
    """Current status of a tracked interrupt."""

    is_expired: bool
    remaining_seconds: float
    record: InterruptRecord


class InterruptManager:
    """Manages interrupt TTL with auto-expiration.

    Complements SessionManager -- this tracks interrupt-specific TTL
    while SessionManager handles session-level TTL.
    """

    def __init__(self, ttl_seconds: int = 1800) -> None:
        self._ttl_seconds = ttl_seconds
        self._interrupts: dict[str, InterruptRecord] = {}

    def register(
        self,
        thread_id: str,
        action: str,
        params: dict,
    ) -> InterruptRecord:
        """Register a new pending interrupt with TTL tracking."""
        record = InterruptRecord(
            interrupt_id=uuid.uuid4().hex,
            thread_id=thread_id,
            action=action,
            params=dict(params),
            created_at=time.time(),
            ttl_seconds=self._ttl_seconds,
        )
        self._interrupts = {**self._interrupts, thread_id: record}
        return record

    def check_status(self, thread_id: str) -> InterruptStatus | None:
        """Check the TTL status of a pending interrupt."""
        record = self._interrupts.get(thread_id)
        if record is None:
            return None
        elapsed = time.time() - record.created_at
        remaining = max(0.0, record.ttl_seconds - elapsed)
        is_expired = elapsed > record.ttl_seconds
        return InterruptStatus(
            is_expired=is_expired,
            remaining_seconds=remaining,
            record=record,
        )

    def resolve(self, thread_id: str) -> None:
        """Remove a resolved interrupt from tracking."""
        self._interrupts = {
            k: v for k, v in self._interrupts.items() if k != thread_id
        }

    def cleanup_expired(self) -> tuple[InterruptRecord, ...]:
        """Find and remove all expired interrupts. Returns the expired records."""
        now = time.time()
        expired: list[InterruptRecord] = []
        active: dict[str, InterruptRecord] = {}
        for thread_id, record in self._interrupts.items():
            if now - record.created_at > record.ttl_seconds:
                expired.append(record)
            else:
                active[thread_id] = record
        self._interrupts = active
        return tuple(expired)

    def generate_retry_prompt(self, expired_record: InterruptRecord) -> dict:
        """Generate a WebSocket message prompting the user to retry an expired action."""
        return {
            "type": "interrupt_expired",
            "thread_id": expired_record.thread_id,
            "action": expired_record.action,
            "message": (
                f"The approval request for '{expired_record.action}' has expired "
                f"after {expired_record.ttl_seconds // 60} minutes. "
                f"Would you like to try again?"
            ),
        }

    def has_pending(self, thread_id: str) -> bool:
        """Check if a thread has a pending (non-expired) interrupt."""
        status = self.check_status(thread_id)
        if status is None:
            return False
        return not status.is_expired


@@ -13,7 +13,10 @@ from fastapi.staticfiles import StaticFiles
 from app.callbacks import TokenUsageCallbackHandler
 from app.config import Settings
 from app.db import create_checkpointer, create_pool, setup_app_tables
+from app.escalation import NoOpEscalator, WebhookEscalator
 from app.graph import build_graph
+from app.intent import LLMIntentClassifier
+from app.interrupt_manager import InterruptManager
 from app.llm import create_llm
 from app.registry import AgentRegistry
 from app.session_manager import SessionManager
@@ -36,23 +39,46 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
     checkpointer = await create_checkpointer(pool)
     await setup_app_tables(pool)
-    registry = AgentRegistry.load(AGENTS_YAML)
+    # Load agents from template or default YAML
+    if settings.template_name:
+        registry = AgentRegistry.load_template(settings.template_name)
+    else:
+        registry = AgentRegistry.load(AGENTS_YAML)
     llm = create_llm(settings)
-    graph = build_graph(registry, llm, checkpointer)
+    intent_classifier = LLMIntentClassifier(llm)
+    graph = build_graph(registry, llm, checkpointer, intent_classifier=intent_classifier)
     session_manager = SessionManager(
         session_ttl_seconds=settings.session_ttl_minutes * 60,
     )
+    interrupt_manager = InterruptManager(
+        ttl_seconds=settings.interrupt_ttl_minutes * 60,
+    )
+    # Configure escalation
+    if settings.webhook_url:
+        escalator = WebhookEscalator(
+            url=settings.webhook_url,
+            timeout_seconds=settings.webhook_timeout_seconds,
+            max_retries=settings.webhook_max_retries,
+        )
+    else:
+        escalator = NoOpEscalator()
     app.state.graph = graph
     app.state.session_manager = session_manager
+    app.state.interrupt_manager = interrupt_manager
+    app.state.escalator = escalator
     app.state.settings = settings
     app.state.pool = pool
     logger.info(
-        "Smart Support started: %d agents loaded, LLM=%s/%s",
+        "Smart Support started: %d agents loaded, LLM=%s/%s, template=%s",
        len(registry),
        settings.llm_provider,
        settings.llm_model,
+        settings.template_name or "(default)",
    )
    yield
@@ -60,7 +86,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
     await pool.close()

-app = FastAPI(title="Smart Support", version="0.1.0", lifespan=lifespan)
+app = FastAPI(title="Smart Support", version="0.2.0", lifespan=lifespan)

 @app.websocket("/ws")
@@ -68,13 +94,17 @@ async def websocket_endpoint(ws: WebSocket) -> None:
     await ws.accept()
     graph = app.state.graph
     session_manager = app.state.session_manager
+    interrupt_manager = app.state.interrupt_manager
     settings = app.state.settings
     callback_handler = TokenUsageCallbackHandler(model_name=settings.llm_model)
     try:
         while True:
             raw_data = await ws.receive_text()
-            await dispatch_message(ws, graph, session_manager, callback_handler, raw_data)
+            await dispatch_message(
+                ws, graph, session_manager, callback_handler, raw_data,
+                interrupt_manager=interrupt_manager,
+            )
     except WebSocketDisconnect:
         logger.info("WebSocket client disconnected")


@@ -100,5 +100,41 @@ class AgentRegistry:
     def get_agents_by_permission(self, permission: str) -> tuple[AgentConfig, ...]:
         return tuple(a for a in self._agents.values() if a.permission == permission)

+    @classmethod
+    def load_template(
+        cls,
+        template_name: str,
+        templates_dir: str | Path | None = None,
+    ) -> AgentRegistry:
+        """Load agent configurations from a named template."""
+        if templates_dir is None:
+            templates_dir = Path(__file__).parent.parent / "templates"
+        templates_dir = Path(templates_dir)
+        yaml_path = templates_dir / f"{template_name}.yaml"
+        if not yaml_path.exists():
+            available = cls.list_templates(templates_dir)
+            raise FileNotFoundError(
+                f"Template '{template_name}' not found. "
+                f"Available: {', '.join(available) if available else 'none'}"
+            )
+        return cls.load(yaml_path)
+
+    @classmethod
+    def list_templates(
+        cls,
+        templates_dir: str | Path | None = None,
+    ) -> tuple[str, ...]:
+        """List available template names from the templates directory."""
+        if templates_dir is None:
+            templates_dir = Path(__file__).parent.parent / "templates"
+        templates_dir = Path(templates_dir)
+        if not templates_dir.is_dir():
+            return ()
+        return tuple(
+            sorted(p.stem for p in templates_dir.glob("*.yaml"))
+        )
+
     def __len__(self) -> int:
         return len(self._agents)
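The template discovery convention above (sorted `*.yaml` stems, empty tuple for a missing directory) can be sketched without the registry class:

```python
from pathlib import Path

def list_template_names(templates_dir: Path) -> tuple[str, ...]:
    # Same convention as AgentRegistry.list_templates: each *.yaml file's
    # stem is a template name; a missing directory yields no templates.
    if not templates_dir.is_dir():
        return ()
    return tuple(sorted(p.stem for p in templates_dir.glob("*.yaml")))
```

Sorting makes the listing deterministic, which also keeps the "Available: ..." error message in `load_template` stable.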


@@ -10,11 +10,14 @@ from typing import TYPE_CHECKING, Any
 from langchain_core.messages import HumanMessage
 from langgraph.types import Command

+from app.graph import classify_intent

 if TYPE_CHECKING:
     from fastapi import WebSocket
     from langgraph.graph.state import CompiledStateGraph

     from app.callbacks import TokenUsageCallbackHandler
+    from app.interrupt_manager import InterruptManager
     from app.session_manager import SessionManager

 logger = logging.getLogger(__name__)
@@ -31,6 +34,7 @@ async def handle_user_message(
     callback_handler: TokenUsageCallbackHandler,
     thread_id: str,
     content: str,
+    interrupt_manager: InterruptManager | None = None,
 ) -> None:
     """Process a user message through the graph and stream results back."""
     if session_manager.is_expired(thread_id):
@@ -39,8 +43,42 @@ async def handle_user_message(
         return
     session_manager.touch(thread_id)

+    # Run intent classification if available (for logging/future multi-intent)
+    classification = await classify_intent(graph, content)
+    if classification is not None:
+        logger.info(
+            "Intent classification for thread %s: ambiguous=%s, intents=%s",
+            thread_id,
+            classification.is_ambiguous,
+            [i.agent_name for i in classification.intents],
+        )
+        # If ambiguous, send clarification and return
+        if classification.is_ambiguous and classification.clarification_question:
+            await _send_json(
+                ws,
+                {
+                    "type": "clarification",
+                    "thread_id": thread_id,
+                    "message": classification.clarification_question,
+                },
+            )
+            await _send_json(ws, {"type": "message_complete", "thread_id": thread_id})
+            return

     config = {"configurable": {"thread_id": thread_id}, "callbacks": [callback_handler]}
-    input_msg = {"messages": [HumanMessage(content=content)]}
+    # If multi-intent detected, add routing hint to the message
+    if classification and len(classification.intents) > 1:
+        agent_names = [i.agent_name for i in classification.intents]
+        hint = (
+            f"\n[System: This request involves multiple actions. "
+            f"Execute in order: {', '.join(agent_names)}]"
+        )
+        input_msg = {"messages": [HumanMessage(content=content + hint)]}
+    else:
+        input_msg = {"messages": [HumanMessage(content=content)]}
     try:
         async for chunk in graph.astream(input_msg, config=config, stream_mode="messages"):
@@ -72,6 +110,15 @@ async def handle_user_message(
     if _has_interrupt(state):
         interrupt_data = _extract_interrupt(state)
         session_manager.extend_for_interrupt(thread_id)
+        # Register interrupt with TTL tracking
+        if interrupt_manager is not None:
+            interrupt_manager.register(
+                thread_id=thread_id,
+                action=interrupt_data.get("action", "unknown"),
+                params=interrupt_data.get("params", {}),
+            )
         await _send_json(
             ws,
             {
@@ -96,8 +143,21 @@ async def handle_interrupt_response(
     callback_handler: TokenUsageCallbackHandler,
     thread_id: str,
     approved: bool,
+    interrupt_manager: InterruptManager | None = None,
 ) -> None:
     """Resume graph execution after interrupt approval/rejection."""
+    # Check interrupt TTL before resuming
+    if interrupt_manager is not None:
+        status = interrupt_manager.check_status(thread_id)
+        if status is not None and status.is_expired:
+            retry_prompt = interrupt_manager.generate_retry_prompt(status.record)
+            interrupt_manager.resolve(thread_id)
+            session_manager.resolve_interrupt(thread_id)
+            await _send_json(ws, retry_prompt)
+            return
+        interrupt_manager.resolve(thread_id)
     session_manager.resolve_interrupt(thread_id)
     session_manager.touch(thread_id)
@@ -136,6 +196,7 @@ async def dispatch_message(
     session_manager: SessionManager,
     callback_handler: TokenUsageCallbackHandler,
     raw_data: str,
+    interrupt_manager: InterruptManager | None = None,
 ) -> None:
     """Parse and route an incoming WebSocket message."""
     if len(raw_data) > MAX_MESSAGE_SIZE:
@@ -167,12 +228,16 @@ async def dispatch_message(
         if len(content) > MAX_CONTENT_LENGTH:
             await _send_json(ws, {"type": "error", "message": "Message content too long"})
             return
-        await handle_user_message(ws, graph, session_manager, callback_handler, thread_id, content)
+        await handle_user_message(
+            ws, graph, session_manager, callback_handler, thread_id, content,
+            interrupt_manager=interrupt_manager,
+        )
     elif msg_type == "interrupt_response":
         approved = data.get("approved", False)
         await handle_interrupt_response(
-            ws, graph, session_manager, callback_handler, thread_id, approved
+            ws, graph, session_manager, callback_handler, thread_id, approved,
+            interrupt_manager=interrupt_manager,
         )
     else:


@@ -18,6 +18,7 @@ dependencies = [
     "pydantic-settings>=2.7,<3.0",
     "pyyaml>=6.0,<7.0",
     "python-dotenv>=1.0,<2.0",
+    "httpx>=0.28,<1.0",
 ]

 [project.optional-dependencies]
[project.optional-dependencies]


@@ -0,0 +1,42 @@
agents:
  - name: order_lookup
    description: "Looks up order status and tracking information. Use for queries about order status, shipping, and delivery."
    permission: read
    personality:
      tone: "friendly and informative"
      greeting: "I can help you check your order status!"
      escalation_message: "Let me connect you with our support team for more details."
    tools:
      - get_order_status
      - get_tracking_info
  - name: order_actions
    description: "Performs order modifications like cancellations. Use when the customer wants to cancel, modify, or change an order."
    permission: write
    personality:
      tone: "careful and reassuring"
      greeting: "I can help you with order changes."
      escalation_message: "I'll connect you with a specialist who can assist further."
    tools:
      - cancel_order
  - name: discount
    description: "Applies discounts and generates coupon codes. Use when the customer asks for discounts, promotions, or coupon codes."
    permission: write
    personality:
      tone: "generous and accommodating"
      greeting: "I can help you with discounts and coupons!"
      escalation_message: "Let me connect you with our promotions team."
    tools:
      - apply_discount
      - generate_coupon
  - name: fallback
    description: "Handles general questions, unclear requests, and conversations that don't match other agents."
    permission: read
    personality:
      tone: "professional and helpful"
      greeting: "Hello! How can I help you today?"
      escalation_message: "Let me connect you with a human agent who can better assist you."
    tools:
      - fallback_respond


@@ -0,0 +1,31 @@
agents:
  - name: transaction_lookup
    description: "Looks up transaction history, balances, and payment details. Use for queries about transactions and account activity."
    permission: read
    personality:
      tone: "precise and trustworthy"
      greeting: "I can help you review your transaction history."
      escalation_message: "Let me connect you with our financial support team."
    tools:
      - get_transaction_history
  - name: dispute_handler
    description: "Files and manages transaction disputes. Use when the customer wants to dispute a charge or check dispute status."
    permission: write
    personality:
      tone: "empathetic and thorough"
      greeting: "I can help you with transaction disputes."
      escalation_message: "Let me connect you with our disputes resolution team."
    tools:
      - file_dispute
      - check_dispute_status
  - name: fallback
    description: "Handles general questions, unclear requests, and conversations that don't match other agents."
    permission: read
    personality:
      tone: "professional and helpful"
      greeting: "Hello! How can I help you today?"
      escalation_message: "Let me connect you with a human agent who can better assist you."
    tools:
      - fallback_respond


@@ -0,0 +1,31 @@
agents:
  - name: account_lookup
    description: "Looks up account status, subscription details, and billing history. Use for queries about account information."
    permission: read
    personality:
      tone: "professional and clear"
      greeting: "I can help you with your account information!"
      escalation_message: "Let me connect you with our account support team."
    tools:
      - get_account_status
  - name: subscription_management
    description: "Manages subscription changes like plan upgrades, downgrades, and cancellations. Use when the customer wants to change their subscription."
    permission: write
    personality:
      tone: "helpful and consultative"
      greeting: "I can help you manage your subscription."
      escalation_message: "Let me connect you with our billing specialist."
    tools:
      - change_plan
      - cancel_subscription
  - name: fallback
    description: "Handles general questions, unclear requests, and conversations that don't match other agents."
    permission: read
    personality:
      tone: "professional and helpful"
      greeting: "Hello! How can I help you today?"
      escalation_message: "Let me connect you with a human agent who can better assist you."
    tools:
      - fallback_respond


@@ -0,0 +1,79 @@
"""Tests for app.agents.discount module."""

from __future__ import annotations

from unittest.mock import patch

import pytest

from app.agents.discount import apply_discount, generate_coupon


@pytest.mark.unit
class TestApplyDiscount:
    def test_invalid_discount_zero(self) -> None:
        result = apply_discount.invoke({"order_id": "1042", "discount_percent": 0})
        assert result["status"] == "error"
        assert "Invalid" in result["message"]

    def test_invalid_discount_over_100(self) -> None:
        result = apply_discount.invoke({"order_id": "1042", "discount_percent": 101})
        assert result["status"] == "error"

    def test_invalid_discount_negative(self) -> None:
        result = apply_discount.invoke({"order_id": "1042", "discount_percent": -5})
        assert result["status"] == "error"

    @patch("app.agents.discount.interrupt", return_value=True)
    def test_approved_discount(self, mock_interrupt) -> None:
        result = apply_discount.invoke({"order_id": "1042", "discount_percent": 10})
        assert result["status"] == "applied"
        assert result["discount_percent"] == 10
        assert "1042" in result["message"]

    @patch("app.agents.discount.interrupt", return_value=False)
    def test_rejected_discount(self, mock_interrupt) -> None:
        result = apply_discount.invoke({"order_id": "1042", "discount_percent": 10})
        assert result["status"] == "declined"

    @patch("app.agents.discount.interrupt", return_value={"approved": True})
    def test_approved_via_dict(self, mock_interrupt) -> None:
        result = apply_discount.invoke({"order_id": "1042", "discount_percent": 50})
        assert result["status"] == "applied"

    @patch("app.agents.discount.interrupt", return_value={"approved": False})
    def test_rejected_via_dict(self, mock_interrupt) -> None:
        result = apply_discount.invoke({"order_id": "1042", "discount_percent": 50})
        assert result["status"] == "declined"


@pytest.mark.unit
class TestGenerateCoupon:
    def test_valid_coupon(self) -> None:
        result = generate_coupon.invoke({"discount_percent": 15, "expiry_days": 7})
        assert result["status"] == "generated"
        assert result["discount_percent"] == 15
        assert result["expiry_days"] == 7
        assert result["coupon_code"].startswith("SAVE15-")

    def test_default_expiry(self) -> None:
        result = generate_coupon.invoke({"discount_percent": 20})
        assert result["status"] == "generated"
        assert result["expiry_days"] == 30

    def test_invalid_discount_zero(self) -> None:
        result = generate_coupon.invoke({"discount_percent": 0})
        assert result["status"] == "error"

    def test_invalid_discount_over_100(self) -> None:
        result = generate_coupon.invoke({"discount_percent": 101})
        assert result["status"] == "error"

    def test_invalid_expiry(self) -> None:
        result = generate_coupon.invoke({"discount_percent": 10, "expiry_days": 0})
        assert result["status"] == "error"

    def test_coupon_codes_unique(self) -> None:
        r1 = generate_coupon.invoke({"discount_percent": 10})
        r2 = generate_coupon.invoke({"discount_percent": 10})
        assert r1["coupon_code"] != r2["coupon_code"]

View File

@@ -0,0 +1,169 @@
"""Tests for app.escalation module."""
from __future__ import annotations
from unittest.mock import AsyncMock, patch
import httpx
import pytest
from app.escalation import (
EscalationPayload,
EscalationResult,
NoOpEscalator,
WebhookEscalator,
)
def _make_payload(**kwargs) -> EscalationPayload:
defaults = {
"thread_id": "t1",
"reason": "Agent cannot resolve",
"conversation_summary": "User asked about refund policy",
}
defaults.update(kwargs)
return EscalationPayload(**defaults)
@pytest.mark.unit
class TestEscalationPayload:
def test_frozen(self) -> None:
payload = _make_payload()
with pytest.raises(Exception):
payload.thread_id = "t2" # type: ignore[misc]
def test_default_metadata(self) -> None:
payload = _make_payload()
assert payload.metadata == {}
def test_model_dump(self) -> None:
payload = _make_payload(metadata={"key": "val"})
data = payload.model_dump()
assert data["thread_id"] == "t1"
assert data["metadata"] == {"key": "val"}
@pytest.mark.unit
class TestEscalationResult:
def test_frozen(self) -> None:
result = EscalationResult(success=True, status_code=200, attempts=1, error=None)
assert result.success
assert result.status_code == 200
@pytest.mark.unit
class TestWebhookEscalator:
@pytest.mark.asyncio
async def test_empty_url_returns_failure(self) -> None:
escalator = WebhookEscalator(url="", max_retries=3)
result = await escalator.escalate(_make_payload())
assert not result.success
assert result.attempts == 0
assert "not configured" in result.error
@pytest.mark.asyncio
async def test_successful_post(self) -> None:
mock_response = AsyncMock()
mock_response.status_code = 200
mock_client = AsyncMock()
mock_client.post = AsyncMock(return_value=mock_response)
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
with patch("app.escalation.httpx.AsyncClient", return_value=mock_client):
escalator = WebhookEscalator(url="https://example.com/hook")
result = await escalator.escalate(_make_payload())
assert result.success
assert result.status_code == 200
assert result.attempts == 1
@pytest.mark.asyncio
async def test_retry_on_server_error(self) -> None:
fail_response = AsyncMock()
fail_response.status_code = 500
success_response = AsyncMock()
success_response.status_code = 200
mock_client = AsyncMock()
mock_client.post = AsyncMock(side_effect=[fail_response, fail_response, success_response])
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
with (
patch("app.escalation.httpx.AsyncClient", return_value=mock_client),
patch("app.escalation.asyncio.sleep", new_callable=AsyncMock),
):
escalator = WebhookEscalator(url="https://example.com/hook", max_retries=3)
result = await escalator.escalate(_make_payload())
assert result.success
assert result.attempts == 3
@pytest.mark.asyncio
async def test_all_retries_exhausted(self) -> None:
fail_response = AsyncMock()
fail_response.status_code = 500
mock_client = AsyncMock()
mock_client.post = AsyncMock(return_value=fail_response)
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
with (
patch("app.escalation.httpx.AsyncClient", return_value=mock_client),
patch("app.escalation.asyncio.sleep", new_callable=AsyncMock),
):
escalator = WebhookEscalator(url="https://example.com/hook", max_retries=3)
result = await escalator.escalate(_make_payload())
assert not result.success
assert result.attempts == 3
assert "500" in result.error
@pytest.mark.asyncio
async def test_timeout_error(self) -> None:
mock_client = AsyncMock()
mock_client.post = AsyncMock(side_effect=httpx.TimeoutException("timeout"))
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
with (
patch("app.escalation.httpx.AsyncClient", return_value=mock_client),
patch("app.escalation.asyncio.sleep", new_callable=AsyncMock),
):
escalator = WebhookEscalator(url="https://example.com/hook", max_retries=2)
result = await escalator.escalate(_make_payload())
assert not result.success
assert "timed out" in result.error
@pytest.mark.asyncio
async def test_request_error(self) -> None:
mock_client = AsyncMock()
mock_client.post = AsyncMock(
side_effect=httpx.RequestError("connection refused")
)
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
with (
patch("app.escalation.httpx.AsyncClient", return_value=mock_client),
patch("app.escalation.asyncio.sleep", new_callable=AsyncMock),
):
escalator = WebhookEscalator(url="https://example.com/hook", max_retries=1)
result = await escalator.escalate(_make_payload())
assert not result.success
@pytest.mark.unit
class TestNoOpEscalator:
@pytest.mark.asyncio
async def test_returns_disabled(self) -> None:
escalator = NoOpEscalator()
result = await escalator.escalate(_make_payload())
assert not result.success
assert result.attempts == 0
assert "disabled" in result.error.lower()

View File

@@ -7,7 +7,8 @@ from unittest.mock import AsyncMock, MagicMock
import pytest
from app.graph import SUPERVISOR_PROMPT, build_agent_nodes, build_graph
from app.graph import build_agent_nodes, build_graph, classify_intent
from app.intent import ClassificationResult, IntentTarget
if TYPE_CHECKING:
from app.registry import AgentRegistry
@@ -38,7 +39,51 @@ class TestBuildGraph:
graph = build_graph(sample_registry, mock_llm, mock_checkpointer)
assert graph is not None
def test_supervisor_prompt_contains_routing_info(self) -> None:
assert "order_lookup" in SUPERVISOR_PROMPT
assert "order_actions" in SUPERVISOR_PROMPT
assert "fallback" in SUPERVISOR_PROMPT
def test_graph_has_classifier_attached(self, sample_registry: AgentRegistry) -> None:
mock_llm = MagicMock()
mock_llm.bind_tools = MagicMock(return_value=mock_llm)
mock_llm.with_structured_output = MagicMock(return_value=mock_llm)
mock_checkpointer = AsyncMock()
mock_classifier = MagicMock()
graph = build_graph(
sample_registry, mock_llm, mock_checkpointer, intent_classifier=mock_classifier
)
assert graph.intent_classifier is mock_classifier
assert graph.agent_registry is sample_registry
def test_graph_without_classifier(self, sample_registry: AgentRegistry) -> None:
mock_llm = MagicMock()
mock_llm.bind_tools = MagicMock(return_value=mock_llm)
mock_llm.with_structured_output = MagicMock(return_value=mock_llm)
mock_checkpointer = AsyncMock()
graph = build_graph(sample_registry, mock_llm, mock_checkpointer)
assert graph.intent_classifier is None
@pytest.mark.unit
class TestClassifyIntent:
@pytest.mark.asyncio
async def test_returns_none_without_classifier(self) -> None:
graph = MagicMock()
graph.intent_classifier = None
result = await classify_intent(graph, "hello")
assert result is None
@pytest.mark.asyncio
async def test_calls_classifier(self) -> None:
expected = ClassificationResult(
intents=(IntentTarget(agent_name="order_lookup", confidence=0.9, reasoning="test"),),
)
mock_classifier = AsyncMock()
mock_classifier.classify = AsyncMock(return_value=expected)
graph = MagicMock()
graph.intent_classifier = mock_classifier
graph.agent_registry = MagicMock()
graph.agent_registry.list_agents = MagicMock(return_value=())
result = await classify_intent(graph, "check order")
assert result is not None
assert result.intents[0].agent_name == "order_lookup"

View File

@@ -0,0 +1,175 @@
"""Tests for app.intent module."""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock
import pytest
from app.intent import (
AMBIGUITY_THRESHOLD,
ClassificationResult,
IntentTarget,
LLMIntentClassifier,
_build_agent_list,
)
from app.registry import AgentConfig
def _make_agent(name: str, desc: str = "test", permission: str = "read") -> AgentConfig:
return AgentConfig(name=name, description=desc, permission=permission, tools=["fallback_respond"])
@pytest.mark.unit
class TestIntentModels:
def test_intent_target_frozen(self) -> None:
target = IntentTarget(agent_name="order_lookup", confidence=0.9, reasoning="order query")
with pytest.raises(Exception):
target.agent_name = "other" # type: ignore[misc]
def test_classification_result_frozen(self) -> None:
result = ClassificationResult(
intents=(IntentTarget(agent_name="a", confidence=0.9, reasoning="r"),),
)
assert not result.is_ambiguous
assert result.clarification_question is None
def test_classification_result_ambiguous(self) -> None:
result = ClassificationResult(
intents=(),
is_ambiguous=True,
clarification_question="What do you mean?",
)
assert result.is_ambiguous
def test_multi_intent(self) -> None:
result = ClassificationResult(
intents=(
IntentTarget(agent_name="order_actions", confidence=0.85, reasoning="cancel"),
IntentTarget(agent_name="discount", confidence=0.8, reasoning="discount"),
),
)
assert len(result.intents) == 2
@pytest.mark.unit
class TestBuildAgentList:
def test_formats_agents(self) -> None:
agents = (
_make_agent("order_lookup", "Looks up orders", "read"),
_make_agent("order_actions", "Modifies orders", "write"),
)
text = _build_agent_list(agents)
assert "order_lookup" in text
assert "order_actions" in text
assert "read" in text
assert "write" in text
@pytest.mark.unit
class TestLLMIntentClassifier:
@pytest.mark.asyncio
async def test_single_intent_classification(self) -> None:
expected = ClassificationResult(
intents=(IntentTarget(agent_name="order_lookup", confidence=0.95, reasoning="query"),),
)
mock_structured = MagicMock()
mock_structured.ainvoke = AsyncMock(return_value=expected)
mock_llm = MagicMock()
mock_llm.with_structured_output = MagicMock(return_value=mock_structured)
classifier = LLMIntentClassifier(mock_llm)
agents = (_make_agent("order_lookup"), _make_agent("fallback"))
result = await classifier.classify("What is order 1042 status?", agents)
assert len(result.intents) == 1
assert result.intents[0].agent_name == "order_lookup"
assert not result.is_ambiguous
@pytest.mark.asyncio
async def test_multi_intent_classification(self) -> None:
expected = ClassificationResult(
intents=(
IntentTarget(agent_name="order_actions", confidence=0.9, reasoning="cancel"),
IntentTarget(agent_name="discount", confidence=0.85, reasoning="discount"),
),
)
mock_structured = MagicMock()
mock_structured.ainvoke = AsyncMock(return_value=expected)
mock_llm = MagicMock()
mock_llm.with_structured_output = MagicMock(return_value=mock_structured)
classifier = LLMIntentClassifier(mock_llm)
agents = (_make_agent("order_actions"), _make_agent("discount"), _make_agent("fallback"))
result = await classifier.classify("Cancel order 1042 and give me 10% off", agents)
assert len(result.intents) == 2
assert not result.is_ambiguous
@pytest.mark.asyncio
async def test_ambiguous_classification(self) -> None:
expected = ClassificationResult(
intents=(IntentTarget(agent_name="fallback", confidence=0.3, reasoning="unclear"),),
is_ambiguous=False,
)
mock_structured = MagicMock()
mock_structured.ainvoke = AsyncMock(return_value=expected)
mock_llm = MagicMock()
mock_llm.with_structured_output = MagicMock(return_value=mock_structured)
classifier = LLMIntentClassifier(mock_llm)
agents = (_make_agent("order_lookup"), _make_agent("fallback"))
result = await classifier.classify("hmm", agents)
# Low confidence triggers ambiguity
assert result.is_ambiguous
assert result.clarification_question is not None
@pytest.mark.asyncio
async def test_llm_error_returns_ambiguous(self) -> None:
mock_structured = MagicMock()
mock_structured.ainvoke = AsyncMock(side_effect=RuntimeError("LLM error"))
mock_llm = MagicMock()
mock_llm.with_structured_output = MagicMock(return_value=mock_structured)
classifier = LLMIntentClassifier(mock_llm)
agents = (_make_agent("fallback"),)
result = await classifier.classify("test", agents)
assert result.is_ambiguous
assert result.clarification_question is not None
@pytest.mark.asyncio
async def test_non_result_type_returns_ambiguous(self) -> None:
mock_structured = MagicMock()
mock_structured.ainvoke = AsyncMock(return_value="not a ClassificationResult")
mock_llm = MagicMock()
mock_llm.with_structured_output = MagicMock(return_value=mock_structured)
classifier = LLMIntentClassifier(mock_llm)
agents = (_make_agent("fallback"),)
result = await classifier.classify("test", agents)
assert result.is_ambiguous
@pytest.mark.asyncio
async def test_high_confidence_not_ambiguous(self) -> None:
expected = ClassificationResult(
intents=(
IntentTarget(
agent_name="order_lookup",
confidence=AMBIGUITY_THRESHOLD + 0.1,
reasoning="clear",
),
),
)
mock_structured = MagicMock()
mock_structured.ainvoke = AsyncMock(return_value=expected)
mock_llm = MagicMock()
mock_llm.with_structured_output = MagicMock(return_value=mock_structured)
classifier = LLMIntentClassifier(mock_llm)
agents = (_make_agent("order_lookup"),)
result = await classifier.classify("order status 1042", agents)
assert not result.is_ambiguous

View File

@@ -0,0 +1,132 @@
"""Tests for app.interrupt_manager module."""
from __future__ import annotations
from unittest.mock import patch
import pytest
from app.interrupt_manager import InterruptManager
@pytest.mark.unit
class TestInterruptManagerRegister:
def test_register_creates_record(self) -> None:
mgr = InterruptManager(ttl_seconds=1800)
record = mgr.register("t1", "cancel_order", {"order_id": "1042"})
assert record.thread_id == "t1"
assert record.action == "cancel_order"
assert record.ttl_seconds == 1800
assert record.interrupt_id
def test_register_overwrites_previous(self) -> None:
mgr = InterruptManager()
r1 = mgr.register("t1", "cancel_order", {})
r2 = mgr.register("t1", "apply_discount", {})
assert r1.interrupt_id != r2.interrupt_id
status = mgr.check_status("t1")
assert status is not None
assert status.record.action == "apply_discount"
@pytest.mark.unit
class TestInterruptManagerCheckStatus:
def test_no_interrupt_returns_none(self) -> None:
mgr = InterruptManager()
assert mgr.check_status("t1") is None
def test_fresh_interrupt_not_expired(self) -> None:
mgr = InterruptManager(ttl_seconds=1800)
mgr.register("t1", "cancel_order", {})
status = mgr.check_status("t1")
assert status is not None
assert not status.is_expired
assert status.remaining_seconds > 0
def test_expired_interrupt(self) -> None:
mgr = InterruptManager(ttl_seconds=10)
mgr.register("t1", "cancel_order", {})
# Move time forward
with patch("app.interrupt_manager.time") as mock_time:
mock_time.time.return_value = mgr._interrupts["t1"].created_at + 11
status = mgr.check_status("t1")
assert status is not None
assert status.is_expired
assert status.remaining_seconds == 0.0
def test_boundary_not_expired(self) -> None:
mgr = InterruptManager(ttl_seconds=10)
mgr.register("t1", "cancel_order", {})
with patch("app.interrupt_manager.time") as mock_time:
mock_time.time.return_value = mgr._interrupts["t1"].created_at + 9
status = mgr.check_status("t1")
assert status is not None
assert not status.is_expired
@pytest.mark.unit
class TestInterruptManagerResolve:
def test_resolve_removes_record(self) -> None:
mgr = InterruptManager()
mgr.register("t1", "cancel_order", {})
mgr.resolve("t1")
assert mgr.check_status("t1") is None
def test_resolve_nonexistent_is_safe(self) -> None:
mgr = InterruptManager()
mgr.resolve("nonexistent") # Should not raise
@pytest.mark.unit
class TestInterruptManagerCleanup:
def test_cleanup_removes_expired(self) -> None:
mgr = InterruptManager(ttl_seconds=5)
mgr.register("t1", "cancel_order", {})
mgr.register("t2", "apply_discount", {})
with patch("app.interrupt_manager.time") as mock_time:
mock_time.time.return_value = mgr._interrupts["t1"].created_at + 6
expired = mgr.cleanup_expired()
assert len(expired) == 2
assert mgr.check_status("t1") is None
assert mgr.check_status("t2") is None
def test_cleanup_keeps_active(self) -> None:
mgr = InterruptManager(ttl_seconds=100)
mgr.register("t1", "cancel_order", {})
expired = mgr.cleanup_expired()
assert len(expired) == 0
assert mgr.check_status("t1") is not None
@pytest.mark.unit
class TestInterruptManagerRetryPrompt:
def test_generates_correct_prompt(self) -> None:
mgr = InterruptManager(ttl_seconds=1800)
record = mgr.register("t1", "cancel_order", {"order_id": "1042"})
prompt = mgr.generate_retry_prompt(record)
assert prompt["type"] == "interrupt_expired"
assert prompt["thread_id"] == "t1"
assert prompt["action"] == "cancel_order"
assert "30 minutes" in prompt["message"]
assert "cancel_order" in prompt["message"]
@pytest.mark.unit
class TestInterruptManagerHasPending:
def test_no_interrupt(self) -> None:
mgr = InterruptManager()
assert not mgr.has_pending("t1")
def test_has_active_interrupt(self) -> None:
mgr = InterruptManager(ttl_seconds=1800)
mgr.register("t1", "cancel_order", {})
assert mgr.has_pending("t1")
def test_expired_interrupt_not_pending(self) -> None:
mgr = InterruptManager(ttl_seconds=5)
mgr.register("t1", "cancel_order", {})
with patch("app.interrupt_manager.time") as mock_time:
mock_time.time.return_value = mgr._interrupts["t1"].created_at + 6
assert not mgr.has_pending("t1")

View File

@@ -13,7 +13,7 @@ class TestMainModule:
assert app.title == "Smart Support"
def test_app_version(self) -> None:
assert app.version == "0.1.0"
assert app.version == "0.2.0"
def test_agents_yaml_path_exists(self) -> None:
assert AGENTS_YAML.name == "agents.yaml"

View File

@@ -0,0 +1,70 @@
"""Tests for template loading in app.registry."""
from __future__ import annotations
from pathlib import Path
import pytest
from app.registry import AgentRegistry
TEMPLATES_DIR = Path(__file__).parent.parent.parent / "templates"
@pytest.mark.unit
class TestListTemplates:
def test_lists_all_templates(self) -> None:
templates = AgentRegistry.list_templates(TEMPLATES_DIR)
assert "e-commerce" in templates
assert "saas" in templates
assert "fintech" in templates
def test_returns_sorted(self) -> None:
templates = AgentRegistry.list_templates(TEMPLATES_DIR)
assert templates == tuple(sorted(templates))
def test_empty_dir_returns_empty(self, tmp_path: Path) -> None:
templates = AgentRegistry.list_templates(tmp_path)
assert templates == ()
def test_nonexistent_dir_returns_empty(self) -> None:
templates = AgentRegistry.list_templates("/nonexistent/path")
assert templates == ()
@pytest.mark.unit
class TestLoadTemplate:
def test_load_ecommerce(self) -> None:
registry = AgentRegistry.load_template("e-commerce", TEMPLATES_DIR)
assert len(registry) == 4
agents = registry.list_agents()
names = {a.name for a in agents}
assert "order_lookup" in names
assert "discount" in names
assert "fallback" in names
def test_load_saas(self) -> None:
registry = AgentRegistry.load_template("saas", TEMPLATES_DIR)
assert len(registry) == 3
agents = registry.list_agents()
names = {a.name for a in agents}
assert "account_lookup" in names
assert "subscription_management" in names
def test_load_fintech(self) -> None:
registry = AgentRegistry.load_template("fintech", TEMPLATES_DIR)
assert len(registry) == 3
agents = registry.list_agents()
names = {a.name for a in agents}
assert "transaction_lookup" in names
assert "dispute_handler" in names
def test_nonexistent_template_raises(self) -> None:
with pytest.raises(FileNotFoundError, match="not found"):
AgentRegistry.load_template("nonexistent", TEMPLATES_DIR)
    def test_error_message_lists_available(self) -> None:
        # pytest.raises fails the test if nothing is raised; the old try/except
        # version passed silently when load_template did not raise.
        with pytest.raises(FileNotFoundError) as exc_info:
            AgentRegistry.load_template("nonexistent", TEMPLATES_DIR)
        assert "e-commerce" in str(exc_info.value)

View File

@@ -8,6 +8,7 @@ from unittest.mock import AsyncMock, MagicMock
import pytest
from app.callbacks import TokenUsageCallbackHandler
from app.interrupt_manager import InterruptManager
from app.session_manager import SessionManager
from app.ws_handler import (
_extract_interrupt,
@@ -30,6 +31,9 @@ def _make_graph() -> AsyncMock:
state = MagicMock()
state.tasks = ()
graph.aget_state = AsyncMock(return_value=state)
# Phase 2: graph needs intent_classifier and agent_registry attrs
graph.intent_classifier = None
graph.agent_registry = None
return graph
@@ -100,8 +104,6 @@ class TestDispatchMessage:
call_data = ws.send_json.call_args[0][0]
assert call_data["type"] == "error"
assert "Unknown" in call_data["message"]
# Verify raw input is NOT reflected back
assert "unknown" not in call_data["message"].lower().replace("unknown message type", "")
@pytest.mark.asyncio
async def test_message_too_large(self) -> None:
@@ -142,6 +144,20 @@ class TestDispatchMessage:
assert call_data["type"] == "error"
assert "too long" in call_data["message"].lower()
@pytest.mark.asyncio
async def test_dispatch_with_interrupt_manager(self) -> None:
ws = _make_ws()
graph = _make_graph()
sm = SessionManager()
cb = TokenUsageCallbackHandler()
im = InterruptManager()
sm.touch("t1")
msg = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"})
await dispatch_message(ws, graph, sm, cb, msg, interrupt_manager=im)
last_call = ws.send_json.call_args[0][0]
assert last_call["type"] == "message_complete"
@pytest.mark.unit
class TestHandleUserMessage:
@@ -166,7 +182,6 @@ class TestHandleUserMessage:
sm.touch("t1")
await handle_user_message(ws, graph, sm, cb, "t1", "hello")
# Should end with message_complete
last_call = ws.send_json.call_args[0][0]
assert last_call["type"] == "message_complete"
@@ -175,6 +190,8 @@ class TestHandleUserMessage:
ws = _make_ws()
graph = AsyncMock()
graph.astream = MagicMock(side_effect=RuntimeError("boom"))
graph.intent_classifier = None
graph.agent_registry = None
sm = SessionManager()
cb = TokenUsageCallbackHandler()
@@ -183,6 +200,74 @@ class TestHandleUserMessage:
call_data = ws.send_json.call_args[0][0]
assert call_data["type"] == "error"
@pytest.mark.asyncio
async def test_interrupt_registered_with_manager(self) -> None:
ws = _make_ws()
graph = AsyncMock()
graph.intent_classifier = None
graph.agent_registry = None
graph.astream = MagicMock(return_value=AsyncIterHelper([]))
# Simulate interrupt in state
interrupt_obj = MagicMock()
interrupt_obj.value = {"action": "cancel_order", "order_id": "1042"}
task = MagicMock()
task.interrupts = (interrupt_obj,)
state = MagicMock()
state.tasks = (task,)
graph.aget_state = AsyncMock(return_value=state)
sm = SessionManager()
cb = TokenUsageCallbackHandler()
im = InterruptManager()
sm.touch("t1")
await handle_user_message(ws, graph, sm, cb, "t1", "cancel order 1042", interrupt_manager=im)
# Interrupt should be registered
assert im.has_pending("t1")
# Should have sent interrupt message
calls = [c[0][0] for c in ws.send_json.call_args_list]
interrupt_msgs = [c for c in calls if c.get("type") == "interrupt"]
assert len(interrupt_msgs) == 1
@pytest.mark.asyncio
async def test_ambiguous_intent_sends_clarification(self) -> None:
from app.intent import ClassificationResult
ws = _make_ws()
graph = AsyncMock()
graph.astream = MagicMock(return_value=AsyncIterHelper([]))
state = MagicMock()
state.tasks = ()
graph.aget_state = AsyncMock(return_value=state)
# Set up intent classifier that returns ambiguous
mock_classifier = AsyncMock()
mock_classifier.classify = AsyncMock(
return_value=ClassificationResult(
intents=(),
is_ambiguous=True,
clarification_question="What do you mean?",
)
)
graph.intent_classifier = mock_classifier
mock_registry = MagicMock()
mock_registry.list_agents = MagicMock(return_value=())
graph.agent_registry = mock_registry
sm = SessionManager()
cb = TokenUsageCallbackHandler()
sm.touch("t1")
await handle_user_message(ws, graph, sm, cb, "t1", "hmm")
calls = [c[0][0] for c in ws.send_json.call_args_list]
clarification_msgs = [c for c in calls if c.get("type") == "clarification"]
assert len(clarification_msgs) == 1
assert clarification_msgs[0]["message"] == "What do you mean?"
@pytest.mark.unit
class TestHandleInterruptResponse:
@@ -199,6 +284,52 @@ class TestHandleInterruptResponse:
last_call = ws.send_json.call_args[0][0]
assert last_call["type"] == "message_complete"
@pytest.mark.asyncio
async def test_expired_interrupt_sends_retry_prompt(self) -> None:
from unittest.mock import patch
ws = _make_ws()
graph = _make_graph()
sm = SessionManager()
cb = TokenUsageCallbackHandler()
im = InterruptManager(ttl_seconds=5)
sm.touch("t1")
sm.extend_for_interrupt("t1")
im.register("t1", "cancel_order", {"order_id": "1042"})
# Expire the interrupt
with patch("app.interrupt_manager.time") as mock_time:
mock_time.time.return_value = im._interrupts["t1"].created_at + 10
await handle_interrupt_response(
ws, graph, sm, cb, "t1", True, interrupt_manager=im
)
call_data = ws.send_json.call_args[0][0]
assert call_data["type"] == "interrupt_expired"
assert "cancel_order" in call_data["message"]
@pytest.mark.asyncio
async def test_valid_interrupt_resolves(self) -> None:
ws = _make_ws()
graph = _make_graph()
sm = SessionManager()
cb = TokenUsageCallbackHandler()
im = InterruptManager(ttl_seconds=1800)
sm.touch("t1")
sm.extend_for_interrupt("t1")
im.register("t1", "cancel_order", {})
await handle_interrupt_response(
ws, graph, sm, cb, "t1", True, interrupt_manager=im
)
# Interrupt should be resolved
assert not im.has_pending("t1")
last_call = ws.send_json.call_args[0][0]
assert last_call["type"] == "message_complete"
@pytest.mark.unit
class TestInterruptHelpers:

View File

@@ -276,6 +276,9 @@ Smart Support is an AI customer-support action-layer framework. Core value proposition: "paste
## Phase 2: Multi-Agent Routing + Safety Layer (Weeks 3-4)
> Status: COMPLETED (2026-03-30)
> Dev log: [Phase 2 Dev Log](phases/phase-2-dev-log.md)
### Goals
Strengthen the supervisor's intent classification and multi-agent routing, and implement webhook escalation, vertical industry templates, and interrupt timeout handling.
@@ -289,22 +292,22 @@ Smart Support is an AI customer-support action-layer framework. Core value proposition: "paste
#### 2.1 Supervisor Routing Enhancements (est. 2 days)
- [ ] **2.1.1** Implement intent classification via LLM structured output (agent selection based on descriptions)
- [x] **2.1.1** Implement intent classification via LLM structured output (agent selection based on descriptions)
- File: `backend/app/graph.py` (enhanced)
- Effort: M (4 hours)
- Depends on: Phase 1 complete
- Risk: medium -- routing accuracy needs evaluation
- [ ] **2.1.2** Implement multi-intent request handling ("cancel my order and give me a discount" -> sequential execution)
- [x] **2.1.2** Implement multi-intent request handling ("cancel my order and give me a discount" -> sequential execution)
- File: `backend/app/graph.py` (enhanced)
- Effort: M (6 hours)
- Depends on: 2.1.1
- Risk: high -- multi-intent atomicity (all succeed vs. escalate on partial failure)
- [ ] **2.1.3** Implement ambiguous-intent handling (ask a clarifying question when classification fails)
- [x] **2.1.3** Implement ambiguous-intent handling (ask a clarifying question when classification fails)
- File: `backend/app/agents/fallback.py` (enhanced)
- Effort: S (2 hours)
- Depends on: 2.1.1
- Risk: low
- [ ] **2.1.4** Write routing tests (correct routing, multi-intent, ambiguity, fallback)
- [x] **2.1.4** Write routing tests (correct routing, multi-intent, ambiguity, fallback)
- File: `backend/tests/test_routing.py`
- Effort: M (4 hours)
- Depends on: 2.1.1, 2.1.2, 2.1.3
@@ -312,12 +315,12 @@ Smart Support is an AI customer-support action-layer framework. Core value proposition: "paste
#### 2.2 Mock Discount Agent (est. 0.5 days)
- [ ] **2.2.1** Create the mock discount agent + tools (apply_discount, generate_coupon)
- [x] **2.2.1** Create the mock discount agent + tools (apply_discount, generate_coupon)
- File: `backend/app/agents/discount.py`
- Effort: S (2 hours)
- Depends on: Phase 1
- Risk: low
- [ ] **2.2.2** Update agents.yaml with the discount agent configuration
- [x] **2.2.2** Update agents.yaml with the discount agent configuration
- File: `backend/agents.yaml`
- Effort: S (30 minutes)
- Depends on: 2.2.1
@@ -325,17 +328,17 @@ Smart Support is an AI customer-support action-layer framework. Core value proposition: "paste
#### 2.3 Interrupt Timeout Handling (est. 1 day)
- [ ] **2.3.1** Implement the 30-minute TTL auto-cancellation mechanism
- [x] **2.3.1** Implement the 30-minute TTL auto-cancellation mechanism
- File: `backend/app/interrupt_manager.py`
- Effort: M (4 hours)
- Depends on: Phase 1 (interrupt foundation)
- Risk: medium -- timer precision and state consistency
- [ ] **2.3.2** Implement post-expiry retry prompts (re-evaluate current state, then re-initiate)
- [x] **2.3.2** Implement post-expiry retry prompts (re-evaluate current state, then re-initiate)
- File: `backend/app/interrupt_manager.py` (extended)
- Effort: M (3 hours)
- Depends on: 2.3.1
- Risk: medium
- [ ] **2.3.3** Write interrupt timeout tests
- [x] **2.3.3** Write interrupt timeout tests
- File: `backend/tests/test_interrupt.py`
- Effort: S (2 hours)
- Depends on: 2.3.1, 2.3.2
@@ -343,17 +346,17 @@ Smart Support is an AI customer-support action-layer framework. Core value proposition: "paste
#### 2.4 Webhook Escalation (est. 1 day)
- [ ] **2.4.1** Implement the webhook escalation module (HTTP POST to a configured URL with full conversation context)
- [x] **2.4.1** Implement the webhook escalation module (HTTP POST to a configured URL with full conversation context)
- File: `backend/app/escalation.py`
- Effort: M (3 hours)
- Depends on: Phase 1
- Risk: low
- [ ] **2.4.2** Implement webhook retry (exponential backoff, max 3 attempts)
- [x] **2.4.2** Implement webhook retry (exponential backoff, max 3 attempts)
- File: `backend/app/escalation.py` (extended)
- Effort: S (2 hours)
- Depends on: 2.4.1
- Risk: low
- [ ] **2.4.3** Write webhook tests (successful send, unreachable target, retries)
- [x] **2.4.3** Write webhook tests (successful send, unreachable target, retries)
- File: `backend/tests/test_escalation.py`
- Effort: S (2 hours)
- Depends on: 2.4.1, 2.4.2
@@ -361,22 +364,22 @@ Smart Support is an AI customer-support action-layer framework. Core value proposition: "paste
#### 2.5 Vertical Industry Templates (est. 0.5 days)
- [ ] **2.5.1** Create the e-commerce template YAML (order lookup, order actions, discounts)
- [x] **2.5.1** Create the e-commerce template YAML (order lookup, order actions, discounts)
- File: `backend/templates/e-commerce.yaml`
- Effort: S (1 hour)
- Depends on: 1.2.2
- Risk: low
- [ ] **2.5.2** Create the SaaS template YAML (account lookup, subscription management, billing)
- [x] **2.5.2** Create the SaaS template YAML (account lookup, subscription management, billing)
- File: `backend/templates/saas.yaml`
- Effort: S (1 hour)
- Depends on: 1.2.2
- Risk: low
- [ ] **2.5.3** Create the fintech template YAML (transaction lookup, dispute handling)
- [x] **2.5.3** Create the fintech template YAML (transaction lookup, dispute handling)
- File: `backend/templates/fintech.yaml`
- Effort: S (1 hour)
- Depends on: 1.2.2
- Risk: low
- [ ] **2.5.4** Implement template loading (select a template -> override agents.yaml)
- [x] **2.5.4** Implement template loading (select a template -> override agents.yaml)
- File: `backend/app/registry.py` (extended)
- Effort: S (2 hours)
- Depends on: 2.5.1, 2.5.2, 2.5.3

View File

@@ -0,0 +1,76 @@
# Phase 2: Multi-Agent Routing + Safety Layer -- Development Log
> Status: COMPLETED
> Phase branch: `phase-2/multi-agent-safety`
> Date started: 2026-03-30
> Date completed: 2026-03-30
> Related plan section: [Phase 2 in DEVELOPMENT-PLAN](../DEVELOPMENT-PLAN.md#phase-2-多-agent-路由--安全层-第-3-4-周)
## What Was Built
- **Intent Classification** (`app/intent.py`): LLM structured output-based intent classifier with Pydantic models (`IntentTarget`, `ClassificationResult`). Supports single-intent, multi-intent, and ambiguity detection with configurable confidence threshold.
- **Discount Agent** (`app/agents/discount.py`): Mock agent with `apply_discount` (write + interrupt) and `generate_coupon` (read) tools. Validates discount range (1-100%).
- **Interrupt Manager** (`app/interrupt_manager.py`): TTL-based interrupt tracking with 30-minute auto-expiration. Provides `register`, `check_status`, `resolve`, `cleanup_expired`, and `generate_retry_prompt` methods. Complements SessionManager.
- **Webhook Escalation** (`app/escalation.py`): HTTP POST escalation with exponential backoff retry (max 3 attempts). Includes `WebhookEscalator` and `NoOpEscalator` implementations behind `EscalationService` protocol.
- **Enhanced Supervisor Routing** (`app/graph.py`): Supervisor prompt now includes dynamic agent descriptions. Intent classifier attached to graph for use by ws_handler routing layer. Multi-intent hint injection for sequential execution.
- **Vertical Templates**: Three industry YAML templates (e-commerce, SaaS, fintech) in `backend/templates/`.
- **Template Loading** (`app/registry.py`): `load_template()` and `list_templates()` class methods for template-based agent configuration.
- **WebSocket Integration** (`app/ws_handler.py`): Ambiguous intents trigger a clarification message. The interrupt TTL is checked before resume -- an expired interrupt returns a retry prompt instead. Detected interrupts are registered with the interrupt manager.
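The classification data shapes and the low-confidence ambiguity rule can be reconstructed from `tests/unit/test_intent.py`. The sketch below uses stdlib dataclasses in place of the real Pydantic models, and both the threshold value and the clarification wording are assumptions:

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class IntentTarget:
    """One candidate route: target agent, confidence, and the classifier's reasoning."""

    agent_name: str
    confidence: float
    reasoning: str


@dataclass(frozen=True)
class ClassificationResult:
    """Classification outcome: zero or more intents plus an ambiguity flag."""

    intents: tuple[IntentTarget, ...] = ()
    is_ambiguous: bool = False
    clarification_question: str | None = None


AMBIGUITY_THRESHOLD = 0.5  # assumed value; the real constant lives in app/intent.py


def apply_ambiguity(result: ClassificationResult) -> ClassificationResult:
    """Mark a result ambiguous when no intent clears the confidence threshold."""
    if result.intents and max(t.confidence for t in result.intents) >= AMBIGUITY_THRESHOLD:
        return result
    return ClassificationResult(
        intents=result.intents,
        is_ambiguous=True,
        clarification_question="Could you clarify what you'd like to do?",
    )


clear = apply_ambiguity(
    ClassificationResult(
        intents=(IntentTarget("order_lookup", 0.95, "order status query"),),
    )
)
vague = apply_ambiguity(
    ClassificationResult(intents=(IntentTarget("fallback", 0.3, "unclear"),))
)
```

This mirrors the tested behavior where an LLM result with only low-confidence intents is post-processed into an ambiguous result carrying a clarification question.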
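The discount tools' contract (validate 1-100%, gate the write behind approval, mint unique coupon codes) can be sketched without the LangChain tool wrapper; the `approve` callback below stands in for langgraph's `interrupt()` and is an assumption for illustration:

```python
import secrets


def apply_discount(order_id: str, discount_percent: int, approve=lambda req: True) -> dict:
    """Validate the range, then gate the write behind an approval callback.
    (The real tool raises a langgraph interrupt here instead of calling back.)"""
    if not 1 <= discount_percent <= 100:
        return {"status": "error", "message": f"Invalid discount: {discount_percent}%"}
    decision = approve({"action": "apply_discount", "order_id": order_id})
    # The tests accept both a bare bool and a {"approved": bool} dict.
    approved = decision.get("approved") if isinstance(decision, dict) else bool(decision)
    if not approved:
        return {"status": "declined", "message": f"Discount for order {order_id} declined"}
    return {
        "status": "applied",
        "discount_percent": discount_percent,
        "message": f"Applied {discount_percent}% discount to order {order_id}",
    }


def generate_coupon(discount_percent: int, expiry_days: int = 30) -> dict:
    """Read-only tool: mint a unique SAVE<pct>-XXXX style code."""
    if not 1 <= discount_percent <= 100 or expiry_days < 1:
        return {"status": "error", "message": "Invalid coupon parameters"}
    return {
        "status": "generated",
        "coupon_code": f"SAVE{discount_percent}-{secrets.token_hex(4).upper()}",
        "discount_percent": discount_percent,
        "expiry_days": expiry_days,
    }
```

The write/read split matters for routing: only `apply_discount` needs the approval gate, so `generate_coupon` can run without pausing the graph.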
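The TTL mechanics are plain timestamp arithmetic; here is a stdlib-only sketch mirroring the tested API (`register` / `has_pending` / `resolve` / `cleanup_expired`), with wall-clock `time.time()` assumed as the clock:

```python
import time
import uuid
from dataclasses import dataclass, field


@dataclass(frozen=True)
class InterruptRecord:
    thread_id: str
    action: str
    payload: dict
    ttl_seconds: float
    created_at: float = field(default_factory=time.time)
    interrupt_id: str = field(default_factory=lambda: uuid.uuid4().hex)


class InterruptManager:
    """Track at most one pending approval per thread, expiring after ttl_seconds."""

    def __init__(self, ttl_seconds: float = 1800.0) -> None:
        self._ttl = ttl_seconds
        self._interrupts: dict[str, InterruptRecord] = {}

    def register(self, thread_id: str, action: str, payload: dict) -> InterruptRecord:
        record = InterruptRecord(thread_id, action, payload, self._ttl)
        self._interrupts[thread_id] = record  # re-registering overwrites the old record
        return record

    def has_pending(self, thread_id: str) -> bool:
        record = self._interrupts.get(thread_id)
        return record is not None and time.time() - record.created_at < record.ttl_seconds

    def resolve(self, thread_id: str) -> None:
        self._interrupts.pop(thread_id, None)  # safe when nothing is pending

    def cleanup_expired(self) -> list[InterruptRecord]:
        expired = [r for r in self._interrupts.values() if not self.has_pending(r.thread_id)]
        for record in expired:
            del self._interrupts[record.thread_id]
        return expired


mgr = InterruptManager(ttl_seconds=1800)
mgr.register("t1", "cancel_order", {"order_id": "1042"})
```

Because expiry is computed lazily from `created_at`, no background timer is needed; `cleanup_expired` exists for periodic sweeps, and the resume path simply checks `has_pending` before acting.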
## Code Structure
New files:
- `backend/app/intent.py` -- Intent classification models and LLM classifier
- `backend/app/agents/discount.py` -- Discount agent tools
- `backend/app/interrupt_manager.py` -- Interrupt TTL management
- `backend/app/escalation.py` -- Webhook escalation with retry
- `backend/templates/e-commerce.yaml` -- E-commerce agent template
- `backend/templates/saas.yaml` -- SaaS agent template
- `backend/templates/fintech.yaml` -- Fintech agent template
Modified files:
- `backend/app/graph.py` -- Intent classifier integration, dynamic supervisor prompt
- `backend/app/agents/__init__.py` -- Registered discount tools
- `backend/app/agents/fallback.py` -- Updated capability list
- `backend/app/registry.py` -- Template loading methods
- `backend/app/config.py` -- Webhook, template settings
- `backend/app/ws_handler.py` -- Interrupt manager + intent classification integration
- `backend/app/main.py` -- Wiring new modules, template loading, version bump to 0.2.0
- `backend/agents.yaml` -- Added discount agent
- `backend/pyproject.toml` -- Added httpx to main dependencies
Test files added:
- `tests/unit/test_intent.py` (11 tests)
- `tests/unit/test_discount.py` (13 tests)
- `tests/unit/test_interrupt_manager.py` (14 tests)
- `tests/unit/test_escalation.py` (11 tests)
- `tests/unit/test_templates.py` (9 tests)
Test files updated:
- `tests/unit/test_graph.py` -- Tests for classifier attachment and classify_intent
- `tests/unit/test_ws_handler.py` -- Tests for interrupt manager and clarification flow
- `tests/unit/test_main.py` -- Updated version check
## Test Coverage
- Total tests: 153 (87 Phase 1 + 66 Phase 2)
- Overall coverage: 90.18%
- New module coverage:
- intent.py: 100%
- discount.py: 96%
- interrupt_manager.py: 100%
- escalation.py: 100%
- graph.py: 100%
- registry.py: 97%
## Deviations from Plan
- Multi-intent handling uses supervisor prompt hint injection rather than a fully custom pre-routing graph node. This is simpler and leverages the existing `langgraph-supervisor` routing rather than fighting it.
- Webhook escalation is wired into `app.state` in `main.py` but not yet connected to a specific agent tool (escalation trigger). The module is ready for use -- integrating it with the fallback agent's escalation path is straightforward but deferred to avoid scope creep.
- The `escalate_to_human` tool mentioned in the plan was not created. The escalation module works standalone and can be triggered from ws_handler or agent tools in Phase 5.
## Known Issues / Tech Debt
- SaaS and fintech templates reference tool names (`get_account_status`, `change_plan`, etc.) that don't have implementations. These are configuration blueprints for future use.
- Interrupt manager cleanup is not called on a schedule -- `cleanup_expired()` exists but no periodic task invokes it. Consider adding a background task in Phase 5.
- `main.py` coverage is 44% due to lifespan requiring real DB connection. Integration tests would cover this.
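If the periodic cleanup is added in Phase 5, one lightweight option is an asyncio background task started from the FastAPI lifespan. The function below is a hypothetical sketch, not existing code; it only assumes the manager exposes `cleanup_expired()`:

```python
import asyncio


async def periodic_cleanup(manager, interval_seconds: float = 300.0) -> None:
    """Invoke manager.cleanup_expired() every interval until cancelled.

    Start it in the lifespan with asyncio.create_task(...) and cancel it
    on shutdown so the loop exits cleanly.
    """
    while True:
        await asyncio.sleep(interval_seconds)
        manager.cleanup_expired()
```

Cancellation raises `asyncio.CancelledError` inside the sleep, so shutdown only needs `task.cancel()` plus an awaited suppression of that error.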