feat: complete phase 2 -- multi-agent routing, interrupt TTL, escalation, templates
- Intent classification with LLM structured output (single/multi/ambiguous) - Discount agent with apply_discount and generate_coupon tools - Interrupt manager with 30-min TTL auto-expiration and retry prompts - Webhook escalation module with exponential backoff retry (max 3) - Three vertical industry templates (e-commerce, SaaS, fintech) - Template loading in AgentRegistry - Enhanced supervisor prompt with dynamic agent descriptions - 153 tests passing, 90.18% coverage
This commit is contained in:
@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from app.agents.discount import apply_discount, generate_coupon
|
||||
from app.agents.fallback import fallback_respond
|
||||
from app.agents.order_actions import cancel_order
|
||||
from app.agents.order_lookup import get_order_status, get_tracking_info
|
||||
@@ -16,6 +17,8 @@ _TOOL_MAP: dict[str, BaseTool] = {
|
||||
"get_tracking_info": get_tracking_info,
|
||||
"cancel_order": cancel_order,
|
||||
"fallback_respond": fallback_respond,
|
||||
"apply_discount": apply_discount,
|
||||
"generate_coupon": generate_coupon,
|
||||
}
|
||||
|
||||
|
||||
|
||||
79
backend/app/agents/discount.py
Normal file
79
backend/app/agents/discount.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""Discount agent tools -- apply discounts and generate coupons."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.types import interrupt
|
||||
|
||||
|
||||
@tool
def apply_discount(order_id: str, discount_percent: int) -> dict:
    """Apply a discount to an order. Requires human approval before execution."""
    # Reject out-of-range percentages up front, before pausing for approval.
    if not 1 <= discount_percent <= 100:
        return {
            "status": "error",
            "order_id": order_id,
            "message": f"Invalid discount: {discount_percent}%. Must be between 1 and 100.",
        }

    # Pause the graph and wait for a human decision on this action.
    decision = interrupt(
        {
            "action": "apply_discount",
            "order_id": order_id,
            "discount_percent": discount_percent,
            "message": (
                f"Please confirm: apply {discount_percent}% discount to order {order_id}?"
            ),
        }
    )

    # Normalize the resume payload: plain bool, {"approved": ...} dict,
    # or anything else coerced via truthiness.
    if isinstance(decision, dict):
        approved = decision.get("approved", False)
    else:
        approved = decision if isinstance(decision, bool) else bool(decision)

    if not approved:
        return {
            "status": "declined",
            "order_id": order_id,
            "message": f"Discount for order {order_id} was declined.",
        }
    return {
        "status": "applied",
        "order_id": order_id,
        "discount_percent": discount_percent,
        "message": (
            f"{discount_percent}% discount applied to order {order_id}."
        ),
    }
|
||||
|
||||
|
||||
@tool
def generate_coupon(discount_percent: int, expiry_days: int = 30) -> dict:
    """Generate a coupon code with the specified discount percentage."""
    # Validate both inputs before minting a code.
    if not 1 <= discount_percent <= 100:
        return {
            "status": "error",
            "message": f"Invalid discount: {discount_percent}%. Must be between 1 and 100.",
        }
    if expiry_days < 1:
        return {
            "status": "error",
            "message": f"Invalid expiry: {expiry_days} days. Must be at least 1.",
        }

    # Code format: SAVE<percent>-<8 random hex chars, uppercased>.
    random_suffix = uuid.uuid4().hex[:8].upper()
    coupon_code = f"SAVE{discount_percent}-{random_suffix}"
    return {
        "status": "generated",
        "coupon_code": coupon_code,
        "discount_percent": discount_percent,
        "expiry_days": expiry_days,
        "message": (
            f"Coupon {coupon_code} generated: {discount_percent}% off, "
            f"valid for {expiry_days} days."
        ),
    }
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Fallback agent tools -- handles unmatched intents."""
|
||||
"""Fallback agent tools -- handles unmatched intents and clarification requests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -13,6 +13,7 @@ def fallback_respond(query: str) -> str:
|
||||
"Here's what I can do:\n"
|
||||
"- Check order status (e.g., 'What is the status of order 1042?')\n"
|
||||
"- Get tracking information (e.g., 'Track order 1042')\n"
|
||||
"- Cancel an order (e.g., 'Cancel order 1042')\n\n"
|
||||
"- Cancel an order (e.g., 'Cancel order 1042')\n"
|
||||
"- Apply discounts or generate coupons\n\n"
|
||||
"Could you please rephrase your request?"
|
||||
)
|
||||
|
||||
@@ -26,6 +26,12 @@ class Settings(BaseSettings):
|
||||
ws_host: str = "0.0.0.0"
|
||||
ws_port: int = 8000
|
||||
|
||||
webhook_url: str = ""
|
||||
webhook_timeout_seconds: int = 10
|
||||
webhook_max_retries: int = 3
|
||||
|
||||
template_name: str = ""
|
||||
|
||||
anthropic_api_key: str = ""
|
||||
openai_api_key: str = ""
|
||||
google_api_key: str = ""
|
||||
|
||||
140
backend/app/escalation.py
Normal file
140
backend/app/escalation.py
Normal file
@@ -0,0 +1,140 @@
|
||||
"""Webhook escalation module -- HTTP POST with exponential backoff retry."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Protocol
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EscalationPayload(BaseModel, frozen=True):
    """Immutable payload sent to the escalation webhook."""

    # Conversation thread this escalation belongs to.
    thread_id: str
    # Why the conversation is being escalated.
    reason: str
    # Human-readable summary forwarded to the receiving system.
    conversation_summary: str
    # Arbitrary extra context. NOTE: pydantic copies field defaults per
    # instance, so the mutable {} default is not shared between payloads.
    metadata: dict = {}
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class EscalationResult:
    """Immutable result of an escalation attempt."""

    # True once a 2xx response was received.
    success: bool
    # HTTP status of the successful response; None when no attempt succeeded.
    status_code: int | None
    # Number of POST attempts actually made (0 when escalation is disabled).
    attempts: int
    # Description of the final failure; None on success.
    error: str | None
|
||||
|
||||
|
||||
class EscalationService(Protocol):
    """Protocol for escalation implementations.

    Structural (duck-typed) interface satisfied by both WebhookEscalator
    and NoOpEscalator; callers depend only on this shape.
    """

    async def escalate(self, payload: EscalationPayload) -> EscalationResult: ...
|
||||
|
||||
|
||||
class WebhookEscalator:
    """Sends escalation requests via HTTP POST with exponential backoff retry."""

    def __init__(
        self,
        url: str,
        timeout_seconds: int = 10,
        max_retries: int = 3,
    ) -> None:
        # Target webhook endpoint; an empty string disables sending.
        self._url = url
        # Applied as the httpx client-wide timeout for each request.
        self._timeout = timeout_seconds
        # Total number of POST attempts (not retries after the first).
        self._max_retries = max_retries

    async def escalate(self, payload: EscalationPayload) -> EscalationResult:
        """POST the escalation payload to the configured webhook URL.

        Retries up to ``max_retries`` times with exponential backoff
        (2, 4, 8, ... seconds between attempts). Never raises transport
        errors to the caller: all failures are folded into the returned
        EscalationResult.
        """
        if not self._url:
            # Guard: misconfiguration is reported, not raised.
            return EscalationResult(
                success=False,
                status_code=None,
                attempts=0,
                error="Webhook URL not configured",
            )

        last_error: str | None = None

        # One client for all attempts so the connection can be reused.
        async with httpx.AsyncClient(timeout=self._timeout) as client:
            for attempt in range(1, self._max_retries + 1):
                try:
                    response = await client.post(
                        self._url,
                        json=payload.model_dump(),
                    )

                    # Any 2xx counts as success; stop retrying immediately.
                    if 200 <= response.status_code < 300:
                        logger.info(
                            "Escalation succeeded for thread %s (attempt %d)",
                            payload.thread_id,
                            attempt,
                        )
                        return EscalationResult(
                            success=True,
                            status_code=response.status_code,
                            attempts=attempt,
                            error=None,
                        )

                    # Non-2xx: remember it and fall through to backoff/retry.
                    last_error = f"HTTP {response.status_code}"
                    logger.warning(
                        "Escalation attempt %d/%d failed: %s",
                        attempt,
                        self._max_retries,
                        last_error,
                    )

                except httpx.TimeoutException:
                    last_error = "Request timed out"
                    logger.warning(
                        "Escalation attempt %d/%d timed out",
                        attempt,
                        self._max_retries,
                    )
                except httpx.RequestError as exc:
                    # Connection errors, DNS failures, etc.
                    last_error = str(exc)
                    logger.warning(
                        "Escalation attempt %d/%d error: %s",
                        attempt,
                        self._max_retries,
                        last_error,
                    )

                # Exponential backoff: skip delay after last attempt
                if attempt < self._max_retries:
                    delay = 2**attempt
                    await asyncio.sleep(delay)

        # All attempts exhausted.
        # NOTE(review): status_code is reported as None here even when the
        # final attempt received a non-2xx response (the code survives only
        # inside the "HTTP <n>" error string) -- confirm that is intended.
        logger.error(
            "Escalation failed for thread %s after %d attempts: %s",
            payload.thread_id,
            self._max_retries,
            last_error,
        )
        return EscalationResult(
            success=False,
            status_code=None,
            attempts=self._max_retries,
            error=last_error,
        )
|
||||
|
||||
|
||||
class NoOpEscalator:
    """Escalator that does nothing -- used when webhook URL is not configured."""

    async def escalate(self, payload: EscalationPayload) -> EscalationResult:
        # Record that an escalation was requested, then report it as
        # disabled so callers can distinguish "not sent" from "send failed".
        logger.info("Escalation disabled (no webhook URL). Thread: %s", payload.thread_id)
        return EscalationResult(
            success=False,
            status_code=None,
            attempts=0,
            error="Escalation disabled",
        )
|
||||
@@ -2,12 +2,14 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from langgraph.prebuilt import create_react_agent
|
||||
from langgraph_supervisor import create_supervisor
|
||||
|
||||
from app.agents import get_tools_by_names
|
||||
from app.intent import ClassificationResult, IntentClassifier
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
@@ -16,15 +18,31 @@ if TYPE_CHECKING:
|
||||
|
||||
from app.registry import AgentRegistry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SUPERVISOR_PROMPT = (
|
||||
"You are a customer support supervisor. "
|
||||
"Route customer requests to the appropriate agent based on their description. "
|
||||
"For order status and tracking queries, use the order_lookup agent. "
|
||||
"For order modifications like cancellations, use the order_actions agent. "
|
||||
"For anything else, use the fallback agent."
|
||||
"Route customer requests to the appropriate agent based on their description.\n\n"
|
||||
"Available agents and their roles:\n"
|
||||
"{agent_descriptions}\n\n"
|
||||
"Routing rules:\n"
|
||||
"- For order status and tracking queries, use the order_lookup agent.\n"
|
||||
"- For order modifications like cancellations, use the order_actions agent.\n"
|
||||
"- For discounts, promotions, or coupon codes, use the discount agent.\n"
|
||||
"- For anything else or when uncertain, use the fallback agent.\n"
|
||||
"- If the user's request involves multiple actions, execute them in order.\n"
|
||||
"- If a previous intent classification is provided, follow it.\n"
|
||||
)
|
||||
|
||||
|
||||
def _format_agent_descriptions(registry: AgentRegistry) -> str:
|
||||
"""Build agent description text for the supervisor prompt."""
|
||||
lines = []
|
||||
for agent in registry.list_agents():
|
||||
lines.append(f"- {agent.name}: {agent.description}")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def build_agent_nodes(
|
||||
registry: AgentRegistry,
|
||||
llm: BaseChatModel,
|
||||
@@ -56,15 +74,48 @@ def build_graph(
|
||||
registry: AgentRegistry,
|
||||
llm: BaseChatModel,
|
||||
checkpointer: AsyncPostgresSaver,
|
||||
intent_classifier: IntentClassifier | None = None,
|
||||
) -> CompiledStateGraph:
|
||||
"""Build and compile the LangGraph supervisor graph."""
|
||||
"""Build and compile the LangGraph supervisor graph.
|
||||
|
||||
If an intent_classifier is provided, the supervisor prompt is enhanced
|
||||
with agent descriptions for better routing. The classifier is stored
|
||||
for use by the routing layer (ws_handler).
|
||||
"""
|
||||
agent_nodes = build_agent_nodes(registry, llm)
|
||||
agent_descriptions = _format_agent_descriptions(registry)
|
||||
|
||||
prompt = SUPERVISOR_PROMPT.format(agent_descriptions=agent_descriptions)
|
||||
|
||||
workflow = create_supervisor(
|
||||
agent_nodes,
|
||||
model=llm,
|
||||
prompt=SUPERVISOR_PROMPT,
|
||||
prompt=prompt,
|
||||
output_mode="full_history",
|
||||
)
|
||||
|
||||
return workflow.compile(checkpointer=checkpointer)
|
||||
graph = workflow.compile(checkpointer=checkpointer)
|
||||
|
||||
# Attach classifier and registry to graph for use by ws_handler
|
||||
graph.intent_classifier = intent_classifier # type: ignore[attr-defined]
|
||||
graph.agent_registry = registry # type: ignore[attr-defined]
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
async def classify_intent(
    graph: CompiledStateGraph,
    message: str,
) -> ClassificationResult | None:
    """Classify *message* via the classifier attached to *graph*.

    Returns None when the graph was built without an intent classifier
    (or without a registry), i.e. classification is disabled.
    """
    # Both attributes are attached by build_graph; either may be absent.
    classifier = getattr(graph, "intent_classifier", None)
    registry = getattr(graph, "agent_registry", None)

    if classifier is not None and registry is not None:
        return await classifier.classify(message, registry.list_agents())
    return None
|
||||
|
||||
118
backend/app/intent.py
Normal file
118
backend/app/intent.py
Normal file
@@ -0,0 +1,118 @@
|
||||
"""Intent classification using LLM structured output."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Protocol
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
|
||||
from app.registry import AgentConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
CLASSIFICATION_PROMPT = (
|
||||
"You are an intent classifier for a customer support system.\n"
|
||||
"Given a user message, determine which agent(s) should handle it.\n\n"
|
||||
"Available agents:\n{agent_list}\n\n"
|
||||
"Rules:\n"
|
||||
"- If the message clearly maps to one agent, return a single intent.\n"
|
||||
"- If the message contains multiple distinct requests, return multiple intents "
|
||||
"in execution order.\n"
|
||||
"- If the message is vague or doesn't match any agent, set is_ambiguous=True "
|
||||
"and provide a clarification_question.\n"
|
||||
"- Never route to the fallback agent unless truly ambiguous.\n"
|
||||
"- confidence should be between 0.0 and 1.0.\n"
|
||||
)
|
||||
|
||||
AMBIGUITY_THRESHOLD = 0.5
|
||||
|
||||
|
||||
class IntentTarget(BaseModel, frozen=True):
    """A single classified intent targeting a specific agent."""

    # Name of the agent that should handle this intent.
    agent_name: str
    # Classifier confidence in [0.0, 1.0]; compared against AMBIGUITY_THRESHOLD.
    confidence: float
    # Short LLM-provided explanation of why this agent was chosen.
    reasoning: str
|
||||
|
||||
|
||||
class ClassificationResult(BaseModel, frozen=True):
    """Result of intent classification -- may contain multiple intents."""

    # Classified intents in execution order; empty when classification failed.
    intents: tuple[IntentTarget, ...]
    # True when the message is vague or confidence is uniformly low.
    is_ambiguous: bool = False
    # Question to send back to the user when is_ambiguous is True.
    clarification_question: str | None = None
|
||||
|
||||
|
||||
class IntentClassifier(Protocol):
    """Protocol for intent classification implementations.

    Structural interface; LLMIntentClassifier is the production
    implementation, tests may substitute any object with this shape.
    """

    async def classify(
        self,
        message: str,
        available_agents: tuple[AgentConfig, ...],
    ) -> ClassificationResult: ...
|
||||
|
||||
|
||||
def _build_agent_list(agents: tuple[AgentConfig, ...]) -> str:
|
||||
"""Format agent descriptions for the classification prompt."""
|
||||
lines = []
|
||||
for agent in agents:
|
||||
lines.append(f"- {agent.name}: {agent.description} (permission: {agent.permission})")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _ambiguous_fallback(
    intents: tuple[IntentTarget, ...] = (),
    clarification: str | None = None,
) -> ClassificationResult:
    """Build the ambiguous fallback result used on classifier failure.

    Consolidates the result that was previously constructed, identically,
    in three separate places inside ``classify``.
    """
    return ClassificationResult(
        intents=intents,
        is_ambiguous=True,
        clarification_question=(
            clarification or "I'm not sure I understood. Could you please rephrase?"
        ),
    )


class LLMIntentClassifier:
    """Classifies user intent using LLM structured output."""

    def __init__(self, llm: BaseChatModel) -> None:
        self._llm = llm

    async def classify(
        self,
        message: str,
        available_agents: tuple[AgentConfig, ...],
    ) -> ClassificationResult:
        """Classify user message into one or more agent intents.

        Best-effort: any LLM failure or malformed payload yields an
        ambiguous result with a clarification question instead of raising,
        so message handling is never broken by classification.
        """
        agent_list = _build_agent_list(available_agents)
        system_prompt = CLASSIFICATION_PROMPT.format(agent_list=agent_list)

        structured_llm = self._llm.with_structured_output(ClassificationResult)

        try:
            result = await structured_llm.ainvoke(
                [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": message},
                ]
            )
        except Exception:
            # Broad catch is deliberate at this boundary: log and degrade.
            logger.exception("Intent classification failed, returning ambiguous")
            return _ambiguous_fallback()

        # Defensive: some providers can return a raw dict/str even when
        # structured output was requested.
        if not isinstance(result, ClassificationResult):
            return _ambiguous_fallback()

        # Uniformly low confidence -> treat as ambiguous, preserving the
        # model's own clarification question when it offered one.
        if result.intents and all(i.confidence < AMBIGUITY_THRESHOLD for i in result.intents):
            return _ambiguous_fallback(
                intents=result.intents,
                clarification=result.clarification_question,
            )

        return result
|
||||
115
backend/app/interrupt_manager.py
Normal file
115
backend/app/interrupt_manager.py
Normal file
@@ -0,0 +1,115 @@
|
||||
"""Interrupt TTL management -- tracks pending interrupts with auto-expiration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class InterruptRecord:
    """Immutable record of a pending interrupt."""

    # Unique identifier for this interrupt (uuid4 hex).
    interrupt_id: str
    # Conversation thread the interrupt belongs to.
    thread_id: str
    # Name of the action awaiting approval (e.g. "apply_discount").
    action: str
    # Action parameters captured at registration time (the manager stores
    # a defensive copy of the caller's dict).
    params: dict
    # Unix timestamp (time.time()) when the interrupt was registered.
    created_at: float
    # Time-to-live; the interrupt expires once this many seconds elapse.
    ttl_seconds: int
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class InterruptStatus:
    """Current status of a tracked interrupt."""

    # True once the record's TTL has elapsed.
    is_expired: bool
    # Seconds left before expiration, clamped to >= 0.0.
    remaining_seconds: float
    # The underlying tracked record.
    record: InterruptRecord
|
||||
|
||||
|
||||
class InterruptManager:
    """Manages interrupt TTL with auto-expiration.

    Complements SessionManager -- this tracks interrupt-specific TTL
    while SessionManager handles session-level TTL.
    """

    def __init__(self, ttl_seconds: int = 1800) -> None:
        # Default TTL is 1800 s (30 minutes).
        self._ttl_seconds = ttl_seconds
        # One pending interrupt per thread_id. The mapping is replaced
        # wholesale on every mutation (copy-on-write style), never
        # mutated in place.
        self._interrupts: dict[str, InterruptRecord] = {}

    def register(
        self,
        thread_id: str,
        action: str,
        params: dict,
    ) -> InterruptRecord:
        """Register a new pending interrupt with TTL tracking."""
        # NOTE(review): silently replaces any existing pending interrupt
        # for the same thread_id -- confirm that is the intended policy.
        record = InterruptRecord(
            interrupt_id=uuid.uuid4().hex,
            thread_id=thread_id,
            action=action,
            # Defensive copy: the frozen record must not alias caller state.
            params=dict(params),
            created_at=time.time(),
            ttl_seconds=self._ttl_seconds,
        )
        self._interrupts = {**self._interrupts, thread_id: record}
        return record

    def check_status(self, thread_id: str) -> InterruptStatus | None:
        """Check the TTL status of a pending interrupt.

        Returns None when the thread has no tracked interrupt.
        """
        record = self._interrupts.get(thread_id)
        if record is None:
            return None

        elapsed = time.time() - record.created_at
        remaining = max(0.0, record.ttl_seconds - elapsed)
        # Strict ">" means an interrupt at exactly ttl_seconds elapsed is
        # still reported as not expired (with 0.0 remaining).
        is_expired = elapsed > record.ttl_seconds

        return InterruptStatus(
            is_expired=is_expired,
            remaining_seconds=remaining,
            record=record,
        )

    def resolve(self, thread_id: str) -> None:
        """Remove a resolved interrupt from tracking."""
        # No-op if the thread has no tracked interrupt.
        self._interrupts = {
            k: v for k, v in self._interrupts.items() if k != thread_id
        }

    def cleanup_expired(self) -> tuple[InterruptRecord, ...]:
        """Find and remove all expired interrupts. Returns the expired records."""
        now = time.time()
        expired: list[InterruptRecord] = []
        active: dict[str, InterruptRecord] = {}

        # Single pass: partition into expired (returned) and active (kept).
        for thread_id, record in self._interrupts.items():
            if now - record.created_at > record.ttl_seconds:
                expired.append(record)
            else:
                active[thread_id] = record

        self._interrupts = active
        return tuple(expired)

    def generate_retry_prompt(self, expired_record: InterruptRecord) -> dict:
        """Generate a WebSocket message prompting the user to retry an expired action."""
        return {
            "type": "interrupt_expired",
            "thread_id": expired_record.thread_id,
            "action": expired_record.action,
            # Integer division: a TTL under 60 s reads as "0 minutes".
            "message": (
                f"The approval request for '{expired_record.action}' has expired "
                f"after {expired_record.ttl_seconds // 60} minutes. "
                f"Would you like to try again?"
            ),
        }

    def has_pending(self, thread_id: str) -> bool:
        """Check if a thread has a pending (non-expired) interrupt."""
        status = self.check_status(thread_id)
        if status is None:
            return False
        return not status.is_expired
|
||||
@@ -13,7 +13,10 @@ from fastapi.staticfiles import StaticFiles
|
||||
from app.callbacks import TokenUsageCallbackHandler
|
||||
from app.config import Settings
|
||||
from app.db import create_checkpointer, create_pool, setup_app_tables
|
||||
from app.escalation import NoOpEscalator, WebhookEscalator
|
||||
from app.graph import build_graph
|
||||
from app.intent import LLMIntentClassifier
|
||||
from app.interrupt_manager import InterruptManager
|
||||
from app.llm import create_llm
|
||||
from app.registry import AgentRegistry
|
||||
from app.session_manager import SessionManager
|
||||
@@ -36,23 +39,46 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
checkpointer = await create_checkpointer(pool)
|
||||
await setup_app_tables(pool)
|
||||
|
||||
registry = AgentRegistry.load(AGENTS_YAML)
|
||||
# Load agents from template or default YAML
|
||||
if settings.template_name:
|
||||
registry = AgentRegistry.load_template(settings.template_name)
|
||||
else:
|
||||
registry = AgentRegistry.load(AGENTS_YAML)
|
||||
|
||||
llm = create_llm(settings)
|
||||
graph = build_graph(registry, llm, checkpointer)
|
||||
intent_classifier = LLMIntentClassifier(llm)
|
||||
graph = build_graph(registry, llm, checkpointer, intent_classifier=intent_classifier)
|
||||
|
||||
session_manager = SessionManager(
|
||||
session_ttl_seconds=settings.session_ttl_minutes * 60,
|
||||
)
|
||||
interrupt_manager = InterruptManager(
|
||||
ttl_seconds=settings.interrupt_ttl_minutes * 60,
|
||||
)
|
||||
|
||||
# Configure escalation
|
||||
if settings.webhook_url:
|
||||
escalator = WebhookEscalator(
|
||||
url=settings.webhook_url,
|
||||
timeout_seconds=settings.webhook_timeout_seconds,
|
||||
max_retries=settings.webhook_max_retries,
|
||||
)
|
||||
else:
|
||||
escalator = NoOpEscalator()
|
||||
|
||||
app.state.graph = graph
|
||||
app.state.session_manager = session_manager
|
||||
app.state.interrupt_manager = interrupt_manager
|
||||
app.state.escalator = escalator
|
||||
app.state.settings = settings
|
||||
app.state.pool = pool
|
||||
|
||||
logger.info(
|
||||
"Smart Support started: %d agents loaded, LLM=%s/%s",
|
||||
"Smart Support started: %d agents loaded, LLM=%s/%s, template=%s",
|
||||
len(registry),
|
||||
settings.llm_provider,
|
||||
settings.llm_model,
|
||||
settings.template_name or "(default)",
|
||||
)
|
||||
|
||||
yield
|
||||
@@ -60,7 +86,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
await pool.close()
|
||||
|
||||
|
||||
app = FastAPI(title="Smart Support", version="0.1.0", lifespan=lifespan)
|
||||
app = FastAPI(title="Smart Support", version="0.2.0", lifespan=lifespan)
|
||||
|
||||
|
||||
@app.websocket("/ws")
|
||||
@@ -68,13 +94,17 @@ async def websocket_endpoint(ws: WebSocket) -> None:
|
||||
await ws.accept()
|
||||
graph = app.state.graph
|
||||
session_manager = app.state.session_manager
|
||||
interrupt_manager = app.state.interrupt_manager
|
||||
settings = app.state.settings
|
||||
callback_handler = TokenUsageCallbackHandler(model_name=settings.llm_model)
|
||||
|
||||
try:
|
||||
while True:
|
||||
raw_data = await ws.receive_text()
|
||||
await dispatch_message(ws, graph, session_manager, callback_handler, raw_data)
|
||||
await dispatch_message(
|
||||
ws, graph, session_manager, callback_handler, raw_data,
|
||||
interrupt_manager=interrupt_manager,
|
||||
)
|
||||
except WebSocketDisconnect:
|
||||
logger.info("WebSocket client disconnected")
|
||||
|
||||
|
||||
@@ -100,5 +100,41 @@ class AgentRegistry:
|
||||
def get_agents_by_permission(self, permission: str) -> tuple[AgentConfig, ...]:
|
||||
return tuple(a for a in self._agents.values() if a.permission == permission)
|
||||
|
||||
@classmethod
def load_template(
    cls,
    template_name: str,
    templates_dir: str | Path | None = None,
) -> AgentRegistry:
    """Load agent configurations from a named template.

    Args:
        template_name: Template stem, resolved to ``<name>.yaml`` inside
            the templates directory.
        templates_dir: Optional override for the templates directory;
            defaults to ``<package parent>/templates``.

    Raises:
        FileNotFoundError: If no matching template file exists; the
            message lists the available template names.
    """
    if templates_dir is None:
        # Default: the "templates" directory two levels up from this module.
        templates_dir = Path(__file__).parent.parent / "templates"
    templates_dir = Path(templates_dir)

    yaml_path = templates_dir / f"{template_name}.yaml"
    if not yaml_path.exists():
        # Include the discoverable templates in the error for operators.
        available = cls.list_templates(templates_dir)
        raise FileNotFoundError(
            f"Template '{template_name}' not found. "
            f"Available: {', '.join(available) if available else 'none'}"
        )
    return cls.load(yaml_path)
|
||||
|
||||
@classmethod
def list_templates(
    cls,
    templates_dir: str | Path | None = None,
) -> tuple[str, ...]:
    """List available template names from the templates directory.

    Names are the ``.yaml`` file stems, returned sorted; an absent
    directory yields an empty tuple.
    """
    base = Path(
        templates_dir
        if templates_dir is not None
        else Path(__file__).parent.parent / "templates"
    )

    if not base.is_dir():
        return ()

    stems = [candidate.stem for candidate in base.glob("*.yaml")]
    stems.sort()
    return tuple(stems)
|
||||
|
||||
def __len__(self) -> int:
    """Return the number of registered agents."""
    return len(self._agents)
|
||||
|
||||
@@ -10,11 +10,14 @@ from typing import TYPE_CHECKING, Any
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langgraph.types import Command
|
||||
|
||||
from app.graph import classify_intent
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from fastapi import WebSocket
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
|
||||
from app.callbacks import TokenUsageCallbackHandler
|
||||
from app.interrupt_manager import InterruptManager
|
||||
from app.session_manager import SessionManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -31,6 +34,7 @@ async def handle_user_message(
|
||||
callback_handler: TokenUsageCallbackHandler,
|
||||
thread_id: str,
|
||||
content: str,
|
||||
interrupt_manager: InterruptManager | None = None,
|
||||
) -> None:
|
||||
"""Process a user message through the graph and stream results back."""
|
||||
if session_manager.is_expired(thread_id):
|
||||
@@ -39,8 +43,42 @@ async def handle_user_message(
|
||||
return
|
||||
|
||||
session_manager.touch(thread_id)
|
||||
|
||||
# Run intent classification if available (for logging/future multi-intent)
|
||||
classification = await classify_intent(graph, content)
|
||||
if classification is not None:
|
||||
logger.info(
|
||||
"Intent classification for thread %s: ambiguous=%s, intents=%s",
|
||||
thread_id,
|
||||
classification.is_ambiguous,
|
||||
[i.agent_name for i in classification.intents],
|
||||
)
|
||||
|
||||
# If ambiguous, send clarification and return
|
||||
if classification.is_ambiguous and classification.clarification_question:
|
||||
await _send_json(
|
||||
ws,
|
||||
{
|
||||
"type": "clarification",
|
||||
"thread_id": thread_id,
|
||||
"message": classification.clarification_question,
|
||||
},
|
||||
)
|
||||
await _send_json(ws, {"type": "message_complete", "thread_id": thread_id})
|
||||
return
|
||||
|
||||
config = {"configurable": {"thread_id": thread_id}, "callbacks": [callback_handler]}
|
||||
input_msg = {"messages": [HumanMessage(content=content)]}
|
||||
|
||||
# If multi-intent detected, add routing hint to the message
|
||||
if classification and len(classification.intents) > 1:
|
||||
agent_names = [i.agent_name for i in classification.intents]
|
||||
hint = (
|
||||
f"\n[System: This request involves multiple actions. "
|
||||
f"Execute in order: {', '.join(agent_names)}]"
|
||||
)
|
||||
input_msg = {"messages": [HumanMessage(content=content + hint)]}
|
||||
else:
|
||||
input_msg = {"messages": [HumanMessage(content=content)]}
|
||||
|
||||
try:
|
||||
async for chunk in graph.astream(input_msg, config=config, stream_mode="messages"):
|
||||
@@ -72,6 +110,15 @@ async def handle_user_message(
|
||||
if _has_interrupt(state):
|
||||
interrupt_data = _extract_interrupt(state)
|
||||
session_manager.extend_for_interrupt(thread_id)
|
||||
|
||||
# Register interrupt with TTL tracking
|
||||
if interrupt_manager is not None:
|
||||
interrupt_manager.register(
|
||||
thread_id=thread_id,
|
||||
action=interrupt_data.get("action", "unknown"),
|
||||
params=interrupt_data.get("params", {}),
|
||||
)
|
||||
|
||||
await _send_json(
|
||||
ws,
|
||||
{
|
||||
@@ -96,8 +143,21 @@ async def handle_interrupt_response(
|
||||
callback_handler: TokenUsageCallbackHandler,
|
||||
thread_id: str,
|
||||
approved: bool,
|
||||
interrupt_manager: InterruptManager | None = None,
|
||||
) -> None:
|
||||
"""Resume graph execution after interrupt approval/rejection."""
|
||||
# Check interrupt TTL before resuming
|
||||
if interrupt_manager is not None:
|
||||
status = interrupt_manager.check_status(thread_id)
|
||||
if status is not None and status.is_expired:
|
||||
retry_prompt = interrupt_manager.generate_retry_prompt(status.record)
|
||||
interrupt_manager.resolve(thread_id)
|
||||
session_manager.resolve_interrupt(thread_id)
|
||||
await _send_json(ws, retry_prompt)
|
||||
return
|
||||
|
||||
interrupt_manager.resolve(thread_id)
|
||||
|
||||
session_manager.resolve_interrupt(thread_id)
|
||||
session_manager.touch(thread_id)
|
||||
|
||||
@@ -136,6 +196,7 @@ async def dispatch_message(
|
||||
session_manager: SessionManager,
|
||||
callback_handler: TokenUsageCallbackHandler,
|
||||
raw_data: str,
|
||||
interrupt_manager: InterruptManager | None = None,
|
||||
) -> None:
|
||||
"""Parse and route an incoming WebSocket message."""
|
||||
if len(raw_data) > MAX_MESSAGE_SIZE:
|
||||
@@ -167,12 +228,16 @@ async def dispatch_message(
|
||||
if len(content) > MAX_CONTENT_LENGTH:
|
||||
await _send_json(ws, {"type": "error", "message": "Message content too long"})
|
||||
return
|
||||
await handle_user_message(ws, graph, session_manager, callback_handler, thread_id, content)
|
||||
await handle_user_message(
|
||||
ws, graph, session_manager, callback_handler, thread_id, content,
|
||||
interrupt_manager=interrupt_manager,
|
||||
)
|
||||
|
||||
elif msg_type == "interrupt_response":
|
||||
approved = data.get("approved", False)
|
||||
await handle_interrupt_response(
|
||||
ws, graph, session_manager, callback_handler, thread_id, approved
|
||||
ws, graph, session_manager, callback_handler, thread_id, approved,
|
||||
interrupt_manager=interrupt_manager,
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user