feat: complete phase 1 -- core framework with chat loop, agents, and React UI
Backend: - FastAPI WebSocket /ws endpoint with streaming via LangGraph astream - LangGraph Supervisor connecting 3 mock agents (order_lookup, order_actions, fallback) - YAML Agent Registry with Pydantic validation and immutable configs - PostgresSaver checkpoint persistence via langgraph-checkpoint-postgres - Session TTL with 30-min sliding window and interrupt extension - LLM provider abstraction (Anthropic/OpenAI/Google) - Token usage + cost tracking callback handler - Input validation: message size cap, thread_id format, content length - Security: no hardcoded defaults, startup API key validation, no input reflection Frontend: - React 19 + TypeScript + Vite chat UI - WebSocket hook with reconnect + exponential backoff - Streaming token display with agent attribution - Interrupt approval/reject UI for write operations - Collapsible tool call viewer Testing: - 87 unit tests, 87% coverage (exceeds 80% requirement) - Ruff lint + format clean Infrastructure: - Docker Compose (PostgreSQL 16 + backend) - pyproject.toml with full dependency management
This commit is contained in:
0
backend/app/__init__.py
Normal file
0
backend/app/__init__.py
Normal file
30
backend/app/agents/__init__.py
Normal file
30
backend/app/agents/__init__.py
Normal file
@@ -0,0 +1,30 @@
|
||||
"""Agent tools registry -- maps tool name strings to actual tool functions."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from app.agents.fallback import fallback_respond
|
||||
from app.agents.order_actions import cancel_order
|
||||
from app.agents.order_lookup import get_order_status, get_tracking_info
|
||||
|
||||
_TOOL_MAP: dict[str, BaseTool] = {
|
||||
"get_order_status": get_order_status,
|
||||
"get_tracking_info": get_tracking_info,
|
||||
"cancel_order": cancel_order,
|
||||
"fallback_respond": fallback_respond,
|
||||
}
|
||||
|
||||
|
||||
def get_tools_by_names(tool_names: list[str]) -> list[BaseTool]:
|
||||
"""Resolve tool name strings from YAML config to actual tool objects."""
|
||||
tools = []
|
||||
for name in tool_names:
|
||||
if name not in _TOOL_MAP:
|
||||
available = ", ".join(sorted(_TOOL_MAP.keys()))
|
||||
raise ValueError(f"Unknown tool '{name}'. Available tools: {available}")
|
||||
tools.append(_TOOL_MAP[name])
|
||||
return tools
|
||||
18
backend/app/agents/fallback.py
Normal file
18
backend/app/agents/fallback.py
Normal file
@@ -0,0 +1,18 @@
|
||||
"""Fallback agent tools -- handles unmatched intents."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
|
||||
@tool
|
||||
def fallback_respond(query: str) -> str:
|
||||
"""Provide a helpful response when the user's intent doesn't match a specific agent."""
|
||||
return (
|
||||
"I'm here to help with order inquiries and actions. "
|
||||
"Here's what I can do:\n"
|
||||
"- Check order status (e.g., 'What is the status of order 1042?')\n"
|
||||
"- Get tracking information (e.g., 'Track order 1042')\n"
|
||||
"- Cancel an order (e.g., 'Cancel order 1042')\n\n"
|
||||
"Could you please rephrase your request?"
|
||||
)
|
||||
37
backend/app/agents/order_actions.py
Normal file
37
backend/app/agents/order_actions.py
Normal file
@@ -0,0 +1,37 @@
|
||||
"""Order action tools -- write operations requiring human approval."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.types import interrupt
|
||||
|
||||
|
||||
@tool
|
||||
def cancel_order(order_id: str) -> dict:
|
||||
"""Cancel an order. Requires human approval before execution."""
|
||||
response = interrupt(
|
||||
{
|
||||
"action": "cancel_order",
|
||||
"order_id": order_id,
|
||||
"message": f"Please confirm: cancel order {order_id}?",
|
||||
}
|
||||
)
|
||||
|
||||
if isinstance(response, bool):
|
||||
approved = response
|
||||
elif isinstance(response, dict):
|
||||
approved = response.get("approved", False)
|
||||
else:
|
||||
approved = bool(response)
|
||||
|
||||
if approved:
|
||||
return {
|
||||
"status": "cancelled",
|
||||
"order_id": order_id,
|
||||
"message": f"Order {order_id} has been successfully cancelled.",
|
||||
}
|
||||
return {
|
||||
"status": "kept",
|
||||
"order_id": order_id,
|
||||
"message": f"Order {order_id} cancellation was declined. The order remains active.",
|
||||
}
|
||||
68
backend/app/agents/order_lookup.py
Normal file
68
backend/app/agents/order_lookup.py
Normal file
@@ -0,0 +1,68 @@
|
||||
"""Order lookup tools -- read-only operations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import MappingProxyType
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
MOCK_ORDERS: MappingProxyType[str, dict] = MappingProxyType(
|
||||
{
|
||||
"1042": {
|
||||
"order_id": "1042",
|
||||
"status": "shipped",
|
||||
"items": ["Wireless Headphones", "USB-C Cable"],
|
||||
"total": 89.99,
|
||||
"placed_at": "2026-03-25",
|
||||
},
|
||||
"1043": {
|
||||
"order_id": "1043",
|
||||
"status": "processing",
|
||||
"items": ["Laptop Stand"],
|
||||
"total": 49.99,
|
||||
"placed_at": "2026-03-28",
|
||||
},
|
||||
"1044": {
|
||||
"order_id": "1044",
|
||||
"status": "delivered",
|
||||
"items": ["Mechanical Keyboard", "Mouse Pad"],
|
||||
"total": 159.99,
|
||||
"placed_at": "2026-03-20",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
MOCK_TRACKING: MappingProxyType[str, dict] = MappingProxyType(
|
||||
{
|
||||
"1042": {
|
||||
"order_id": "1042",
|
||||
"carrier": "FedEx",
|
||||
"tracking_number": "FX-9876543210",
|
||||
"estimated_delivery": "2026-04-01",
|
||||
"current_location": "Distribution Center, Chicago IL",
|
||||
},
|
||||
"1044": {
|
||||
"order_id": "1044",
|
||||
"carrier": "UPS",
|
||||
"tracking_number": "1Z-5678901234",
|
||||
"estimated_delivery": "2026-03-22",
|
||||
"current_location": "Delivered",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@tool
|
||||
def get_order_status(order_id: str) -> dict:
|
||||
"""Look up the current status of an order by order ID."""
|
||||
if order_id in MOCK_ORDERS:
|
||||
return dict(MOCK_ORDERS[order_id])
|
||||
return {"error": f"Order {order_id} not found", "order_id": order_id}
|
||||
|
||||
|
||||
@tool
|
||||
def get_tracking_info(order_id: str) -> dict:
|
||||
"""Get shipping and tracking information for an order."""
|
||||
if order_id in MOCK_TRACKING:
|
||||
return dict(MOCK_TRACKING[order_id])
|
||||
return {"error": f"No tracking information for order {order_id}", "order_id": order_id}
|
||||
60
backend/app/callbacks.py
Normal file
60
backend/app/callbacks.py
Normal file
@@ -0,0 +1,60 @@
|
||||
"""Token usage tracking callback handler."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.outputs import LLMResult
|
||||
|
||||
COST_PER_1K_TOKENS: dict[str, dict[str, float]] = {
|
||||
"claude-sonnet-4-6": {"prompt": 0.003, "completion": 0.015},
|
||||
"claude-haiku-4-5-20251001": {"prompt": 0.0008, "completion": 0.004},
|
||||
"gpt-4o": {"prompt": 0.0025, "completion": 0.01},
|
||||
"gpt-4o-mini": {"prompt": 0.00015, "completion": 0.0006},
|
||||
}
|
||||
|
||||
DEFAULT_COST = {"prompt": 0.003, "completion": 0.015}
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TokenUsage:
|
||||
prompt_tokens: int
|
||||
completion_tokens: int
|
||||
total_tokens: int
|
||||
total_cost_usd: float
|
||||
|
||||
|
||||
class TokenUsageCallbackHandler(BaseCallbackHandler):
|
||||
"""Accumulates token usage and cost across LLM invocations."""
|
||||
|
||||
def __init__(self, model_name: str = "") -> None:
|
||||
self._model_name = model_name
|
||||
self._prompt_tokens = 0
|
||||
self._completion_tokens = 0
|
||||
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
if response.llm_output and "token_usage" in response.llm_output:
|
||||
usage = response.llm_output["token_usage"]
|
||||
self._prompt_tokens += usage.get("prompt_tokens", 0)
|
||||
self._completion_tokens += usage.get("completion_tokens", 0)
|
||||
|
||||
def get_usage(self) -> TokenUsage:
|
||||
costs = COST_PER_1K_TOKENS.get(self._model_name, DEFAULT_COST)
|
||||
cost = (
|
||||
self._prompt_tokens * costs["prompt"] / 1000
|
||||
+ self._completion_tokens * costs["completion"] / 1000
|
||||
)
|
||||
return TokenUsage(
|
||||
prompt_tokens=self._prompt_tokens,
|
||||
completion_tokens=self._completion_tokens,
|
||||
total_tokens=self._prompt_tokens + self._completion_tokens,
|
||||
total_cost_usd=round(cost, 6),
|
||||
)
|
||||
|
||||
def reset(self) -> None:
|
||||
self._prompt_tokens = 0
|
||||
self._completion_tokens = 0
|
||||
46
backend/app/config.py
Normal file
46
backend/app/config.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""Centralized application configuration via pydantic-settings."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import model_validator
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=".env",
|
||||
env_file_encoding="utf-8",
|
||||
extra="ignore",
|
||||
)
|
||||
|
||||
database_url: str
|
||||
|
||||
llm_provider: Literal["anthropic", "openai", "google"] = "anthropic"
|
||||
llm_model: str = "claude-sonnet-4-6"
|
||||
|
||||
session_ttl_minutes: int = 30
|
||||
interrupt_ttl_minutes: int = 30
|
||||
|
||||
ws_host: str = "0.0.0.0"
|
||||
ws_port: int = 8000
|
||||
|
||||
anthropic_api_key: str = ""
|
||||
openai_api_key: str = ""
|
||||
google_api_key: str = ""
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_provider_key(self) -> Settings:
|
||||
key_map = {
|
||||
"anthropic": self.anthropic_api_key,
|
||||
"openai": self.openai_api_key,
|
||||
"google": self.google_api_key,
|
||||
}
|
||||
key = key_map.get(self.llm_provider, "")
|
||||
if not key:
|
||||
raise ValueError(
|
||||
f"API key for provider '{self.llm_provider}' is required. "
|
||||
f"Set the corresponding environment variable."
|
||||
)
|
||||
return self
|
||||
61
backend/app/db.py
Normal file
61
backend/app/db.py
Normal file
@@ -0,0 +1,61 @@
|
||||
"""Database connection pool and PostgresSaver checkpoint setup."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||
from psycopg.rows import dict_row
|
||||
from psycopg_pool import AsyncConnectionPool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.config import Settings
|
||||
|
||||
_CONVERSATIONS_DDL = """
|
||||
CREATE TABLE IF NOT EXISTS conversations (
|
||||
thread_id TEXT PRIMARY KEY,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
last_activity TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
total_tokens INTEGER NOT NULL DEFAULT 0,
|
||||
total_cost_usd DOUBLE PRECISION NOT NULL DEFAULT 0.0,
|
||||
status TEXT NOT NULL DEFAULT 'active'
|
||||
);
|
||||
"""
|
||||
|
||||
_INTERRUPTS_DDL = """
|
||||
CREATE TABLE IF NOT EXISTS active_interrupts (
|
||||
interrupt_id TEXT PRIMARY KEY,
|
||||
thread_id TEXT NOT NULL REFERENCES conversations(thread_id),
|
||||
action TEXT NOT NULL,
|
||||
params JSONB NOT NULL DEFAULT '{}',
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
resolved_at TIMESTAMPTZ,
|
||||
resolution TEXT
|
||||
);
|
||||
"""
|
||||
|
||||
|
||||
async def create_pool(settings: Settings) -> AsyncConnectionPool:
|
||||
"""Create an async connection pool with the required psycopg settings."""
|
||||
pool = AsyncConnectionPool(
|
||||
conninfo=settings.database_url,
|
||||
kwargs={"autocommit": True, "row_factory": dict_row},
|
||||
min_size=2,
|
||||
max_size=10,
|
||||
)
|
||||
await pool.open()
|
||||
return pool
|
||||
|
||||
|
||||
async def create_checkpointer(pool: AsyncConnectionPool) -> AsyncPostgresSaver:
|
||||
"""Create and initialize the LangGraph checkpointer."""
|
||||
checkpointer = AsyncPostgresSaver(conn=pool)
|
||||
await checkpointer.setup()
|
||||
return checkpointer
|
||||
|
||||
|
||||
async def setup_app_tables(pool: AsyncConnectionPool) -> None:
|
||||
"""Create application-specific tables (conversations, active_interrupts)."""
|
||||
async with pool.connection() as conn:
|
||||
await conn.execute(_CONVERSATIONS_DDL)
|
||||
await conn.execute(_INTERRUPTS_DDL)
|
||||
70
backend/app/graph.py
Normal file
70
backend/app/graph.py
Normal file
@@ -0,0 +1,70 @@
|
||||
"""LangGraph Supervisor construction -- connects registry, agents, LLM, and persistence."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from langgraph.prebuilt import create_react_agent
|
||||
from langgraph_supervisor import create_supervisor
|
||||
|
||||
from app.agents import get_tools_by_names
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
|
||||
from app.registry import AgentRegistry
|
||||
|
||||
SUPERVISOR_PROMPT = (
|
||||
"You are a customer support supervisor. "
|
||||
"Route customer requests to the appropriate agent based on their description. "
|
||||
"For order status and tracking queries, use the order_lookup agent. "
|
||||
"For order modifications like cancellations, use the order_actions agent. "
|
||||
"For anything else, use the fallback agent."
|
||||
)
|
||||
|
||||
|
||||
def build_agent_nodes(
|
||||
registry: AgentRegistry,
|
||||
llm: BaseChatModel,
|
||||
) -> list:
|
||||
"""Create LangGraph react agent nodes from registry configurations."""
|
||||
agent_nodes = []
|
||||
for agent_config in registry.list_agents():
|
||||
tools = get_tools_by_names(agent_config.tools)
|
||||
|
||||
system_prompt = (
|
||||
f"You are the {agent_config.name} agent. "
|
||||
f"Personality: {agent_config.personality.tone}. "
|
||||
f"{agent_config.description} "
|
||||
f"Permission level: {agent_config.permission}."
|
||||
)
|
||||
|
||||
agent_node = create_react_agent(
|
||||
model=llm,
|
||||
tools=tools,
|
||||
name=agent_config.name,
|
||||
prompt=system_prompt,
|
||||
)
|
||||
agent_nodes.append(agent_node)
|
||||
|
||||
return agent_nodes
|
||||
|
||||
|
||||
def build_graph(
|
||||
registry: AgentRegistry,
|
||||
llm: BaseChatModel,
|
||||
checkpointer: AsyncPostgresSaver,
|
||||
) -> CompiledStateGraph:
|
||||
"""Build and compile the LangGraph supervisor graph."""
|
||||
agent_nodes = build_agent_nodes(registry, llm)
|
||||
|
||||
workflow = create_supervisor(
|
||||
agent_nodes,
|
||||
model=llm,
|
||||
prompt=SUPERVISOR_PROMPT,
|
||||
output_mode="full_history",
|
||||
)
|
||||
|
||||
return workflow.compile(checkpointer=checkpointer)
|
||||
42
backend/app/llm.py
Normal file
42
backend/app/llm.py
Normal file
@@ -0,0 +1,42 @@
|
||||
"""LLM provider factory with prompt caching support."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
|
||||
from app.config import Settings
|
||||
|
||||
|
||||
def create_llm(settings: Settings) -> BaseChatModel:
|
||||
"""Create an LLM instance based on the configured provider."""
|
||||
provider = settings.llm_provider
|
||||
model = settings.llm_model
|
||||
|
||||
if provider == "anthropic":
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
|
||||
return ChatAnthropic(
|
||||
model=model,
|
||||
api_key=settings.anthropic_api_key,
|
||||
)
|
||||
|
||||
if provider == "openai":
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
return ChatOpenAI(
|
||||
model=model,
|
||||
api_key=settings.openai_api_key,
|
||||
)
|
||||
|
||||
if provider == "google":
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
|
||||
return ChatGoogleGenerativeAI(
|
||||
model=model,
|
||||
google_api_key=settings.google_api_key,
|
||||
)
|
||||
|
||||
raise ValueError(f"Unknown LLM provider: '{provider}'. Use 'anthropic', 'openai', or 'google'.")
|
||||
83
backend/app/main.py
Normal file
83
backend/app/main.py
Normal file
@@ -0,0 +1,83 @@
|
||||
"""FastAPI application entry point."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
|
||||
from app.callbacks import TokenUsageCallbackHandler
|
||||
from app.config import Settings
|
||||
from app.db import create_checkpointer, create_pool, setup_app_tables
|
||||
from app.graph import build_graph
|
||||
from app.llm import create_llm
|
||||
from app.registry import AgentRegistry
|
||||
from app.session_manager import SessionManager
|
||||
from app.ws_handler import dispatch_message
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
AGENTS_YAML = Path(__file__).parent.parent / "agents.yaml"
|
||||
FRONTEND_DIST = Path(__file__).parent.parent.parent / "frontend" / "dist"
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
settings = Settings()
|
||||
|
||||
pool = await create_pool(settings)
|
||||
checkpointer = await create_checkpointer(pool)
|
||||
await setup_app_tables(pool)
|
||||
|
||||
registry = AgentRegistry.load(AGENTS_YAML)
|
||||
llm = create_llm(settings)
|
||||
graph = build_graph(registry, llm, checkpointer)
|
||||
session_manager = SessionManager(
|
||||
session_ttl_seconds=settings.session_ttl_minutes * 60,
|
||||
)
|
||||
|
||||
app.state.graph = graph
|
||||
app.state.session_manager = session_manager
|
||||
app.state.settings = settings
|
||||
app.state.pool = pool
|
||||
|
||||
logger.info(
|
||||
"Smart Support started: %d agents loaded, LLM=%s/%s",
|
||||
len(registry),
|
||||
settings.llm_provider,
|
||||
settings.llm_model,
|
||||
)
|
||||
|
||||
yield
|
||||
|
||||
await pool.close()
|
||||
|
||||
|
||||
app = FastAPI(title="Smart Support", version="0.1.0", lifespan=lifespan)
|
||||
|
||||
|
||||
@app.websocket("/ws")
|
||||
async def websocket_endpoint(ws: WebSocket) -> None:
|
||||
await ws.accept()
|
||||
graph = app.state.graph
|
||||
session_manager = app.state.session_manager
|
||||
settings = app.state.settings
|
||||
callback_handler = TokenUsageCallbackHandler(model_name=settings.llm_model)
|
||||
|
||||
try:
|
||||
while True:
|
||||
raw_data = await ws.receive_text()
|
||||
await dispatch_message(ws, graph, session_manager, callback_handler, raw_data)
|
||||
except WebSocketDisconnect:
|
||||
logger.info("WebSocket client disconnected")
|
||||
|
||||
|
||||
if FRONTEND_DIST.is_dir():
|
||||
app.mount("/", StaticFiles(directory=str(FRONTEND_DIST), html=True), name="frontend")
|
||||
104
backend/app/registry.py
Normal file
104
backend/app/registry.py
Normal file
@@ -0,0 +1,104 @@
|
||||
"""YAML Agent registry loader with validation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
import yaml
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
|
||||
class PersonalityConfig(BaseModel, frozen=True):
|
||||
tone: str = "professional and helpful"
|
||||
greeting: str = "Hello! How can I help you today?"
|
||||
escalation_message: str = "Let me connect you with a human agent."
|
||||
|
||||
|
||||
class AgentConfig(BaseModel, frozen=True):
|
||||
name: str
|
||||
description: str
|
||||
permission: Literal["read", "write"]
|
||||
personality: PersonalityConfig = PersonalityConfig()
|
||||
tools: list[str]
|
||||
|
||||
@field_validator("name")
|
||||
@classmethod
|
||||
def name_not_empty(cls, v: str) -> str:
|
||||
if not v.strip():
|
||||
raise ValueError("Agent name must not be empty")
|
||||
return v.strip()
|
||||
|
||||
@field_validator("tools")
|
||||
@classmethod
|
||||
def tools_not_empty(cls, v: list[str]) -> list[str]:
|
||||
if not v:
|
||||
raise ValueError("Agent must have at least one tool")
|
||||
return v
|
||||
|
||||
|
||||
class AgentRegistry:
    """Immutable registry of agent configurations loaded from YAML."""

    def __init__(self, agents: tuple[AgentConfig, ...]) -> None:
        # Keyed by agent name for O(1) lookup; insertion order is preserved.
        self._agents = {agent.name: agent for agent in agents}

    @classmethod
    def load(cls, yaml_path: str | Path) -> AgentRegistry:
        """Load and validate agent configurations from a YAML file."""
        path = Path(yaml_path)
        if not path.exists():
            raise FileNotFoundError(f"Agent config file not found: {path}")

        raw_text = path.read_text(encoding="utf-8")
        if not raw_text.strip():
            raise ValueError(f"Agent config file is empty: {path}")

        try:
            data = yaml.safe_load(raw_text)
        except yaml.YAMLError as exc:
            # Surface the parser's location info when it is available.
            msg = f"Invalid YAML in {path}"
            mark = getattr(exc, "problem_mark", None)
            if mark is not None:
                msg += f" at line {mark.line + 1}, column {mark.column + 1}"
            raise ValueError(msg) from exc

        if not isinstance(data, dict) or "agents" not in data:
            raise ValueError(f"Agent config must have a top-level 'agents' key in {path}")

        raw_agents = data["agents"]
        if not isinstance(raw_agents, list) or not raw_agents:
            raise ValueError(f"'agents' must be a non-empty list in {path}")

        validated: list[AgentConfig] = []
        names_seen: set[str] = set()

        for idx, entry in enumerate(raw_agents):
            if not isinstance(entry, dict):
                raise ValueError(f"Agent at index {idx} must be a mapping in {path}")
            try:
                agent = AgentConfig(**entry)
            except Exception as exc:
                raise ValueError(f"Invalid agent config at index {idx} in {path}: {exc}") from exc

            if agent.name in names_seen:
                raise ValueError(f"Duplicate agent name '{agent.name}' in {path}")
            names_seen.add(agent.name)
            validated.append(agent)

        return cls(agents=tuple(validated))

    def get_agent(self, name: str) -> AgentConfig:
        """Return the config for *name*; raise KeyError listing alternatives."""
        if name not in self._agents:
            available = ", ".join(sorted(self._agents))
            raise KeyError(f"Agent '{name}' not found. Available: {available}")
        return self._agents[name]

    def list_agents(self) -> tuple[AgentConfig, ...]:
        """All agent configs, in load order."""
        return tuple(self._agents.values())

    def get_agents_by_permission(self, permission: str) -> tuple[AgentConfig, ...]:
        """Agents whose permission level matches *permission* exactly."""
        matching = [agent for agent in self._agents.values() if agent.permission == permission]
        return tuple(matching)

    def __len__(self) -> int:
        return len(self._agents)
|
||||
78
backend/app/session_manager.py
Normal file
78
backend/app/session_manager.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""Session TTL management with sliding window and interrupt extension."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SessionState:
|
||||
thread_id: str
|
||||
last_activity: float
|
||||
has_pending_interrupt: bool
|
||||
|
||||
|
||||
class SessionManager:
|
||||
"""Manages session TTL with sliding window and interrupt extensions.
|
||||
|
||||
- Each message resets the TTL (sliding window).
|
||||
- A pending interrupt suspends expiration until resolved.
|
||||
"""
|
||||
|
||||
def __init__(self, session_ttl_seconds: int = 1800) -> None:
|
||||
self._session_ttl = session_ttl_seconds
|
||||
self._sessions: dict[str, SessionState] = {}
|
||||
|
||||
def touch(self, thread_id: str) -> SessionState:
|
||||
"""Update last activity for a session (resets sliding window)."""
|
||||
existing = self._sessions.get(thread_id)
|
||||
new_state = SessionState(
|
||||
thread_id=thread_id,
|
||||
last_activity=time.time(),
|
||||
has_pending_interrupt=existing.has_pending_interrupt if existing else False,
|
||||
)
|
||||
self._sessions = {**self._sessions, thread_id: new_state}
|
||||
return new_state
|
||||
|
||||
def is_expired(self, thread_id: str) -> bool:
|
||||
"""Check if a session has expired."""
|
||||
state = self._sessions.get(thread_id)
|
||||
if state is None:
|
||||
return True
|
||||
|
||||
if state.has_pending_interrupt:
|
||||
return False
|
||||
|
||||
elapsed = time.time() - state.last_activity
|
||||
return elapsed > self._session_ttl
|
||||
|
||||
def extend_for_interrupt(self, thread_id: str) -> SessionState:
|
||||
"""Mark session as having a pending interrupt (suspends TTL)."""
|
||||
existing = self._sessions.get(thread_id)
|
||||
if existing is None:
|
||||
return self.touch(thread_id)
|
||||
|
||||
new_state = SessionState(
|
||||
thread_id=thread_id,
|
||||
last_activity=existing.last_activity,
|
||||
has_pending_interrupt=True,
|
||||
)
|
||||
self._sessions = {**self._sessions, thread_id: new_state}
|
||||
return new_state
|
||||
|
||||
def resolve_interrupt(self, thread_id: str) -> SessionState:
|
||||
"""Remove interrupt extension and reset activity timer."""
|
||||
new_state = SessionState(
|
||||
thread_id=thread_id,
|
||||
last_activity=time.time(),
|
||||
has_pending_interrupt=False,
|
||||
)
|
||||
self._sessions = {**self._sessions, thread_id: new_state}
|
||||
return new_state
|
||||
|
||||
def get_state(self, thread_id: str) -> SessionState | None:
|
||||
return self._sessions.get(thread_id)
|
||||
|
||||
def remove(self, thread_id: str) -> None:
|
||||
self._sessions = {k: v for k, v in self._sessions.items() if k != thread_id}
|
||||
204
backend/app/ws_handler.py
Normal file
204
backend/app/ws_handler.py
Normal file
@@ -0,0 +1,204 @@
|
||||
"""WebSocket message handling logic -- extracted from main for testability."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langgraph.types import Command
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from fastapi import WebSocket
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
|
||||
from app.callbacks import TokenUsageCallbackHandler
|
||||
from app.session_manager import SessionManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MAX_MESSAGE_SIZE = 32_768 # 32 KB
|
||||
MAX_CONTENT_LENGTH = 8_000 # characters
|
||||
THREAD_ID_PATTERN = re.compile(r"^[a-zA-Z0-9\-_]{1,128}$")
|
||||
|
||||
|
||||
async def handle_user_message(
    ws: WebSocket,
    graph: CompiledStateGraph,
    session_manager: SessionManager,
    callback_handler: TokenUsageCallbackHandler,
    thread_id: str,
    content: str,
) -> None:
    """Process a user message through the graph and stream results back."""
    if session_manager.is_expired(thread_id):
        await _send_json(
            ws,
            {"type": "error", "message": "Session expired. Please start a new conversation."},
        )
        return

    # Sliding-window TTL: every message resets the clock.
    session_manager.touch(thread_id)
    run_config = {"configurable": {"thread_id": thread_id}, "callbacks": [callback_handler]}
    graph_input = {"messages": [HumanMessage(content=content)]}

    try:
        async for message, meta in graph.astream(
            graph_input, config=run_config, stream_mode="messages"
        ):
            agent_name = meta.get("langgraph_node", "")

            if getattr(message, "tool_calls", None):
                for call in message.tool_calls:
                    await _send_json(
                        ws,
                        {
                            "type": "tool_call",
                            "agent": agent_name,
                            "tool": call.get("name", ""),
                            "args": call.get("args", {}),
                        },
                    )
            elif getattr(message, "content", None):
                await _send_json(
                    ws,
                    {"type": "token", "agent": agent_name, "content": message.content},
                )

        # After the stream drains, either surface a pending interrupt for
        # human approval or signal normal completion.
        state = await graph.aget_state(run_config)
        if _has_interrupt(state):
            session_manager.extend_for_interrupt(thread_id)
            await _send_json(
                ws,
                {"type": "interrupt", "thread_id": thread_id, **_extract_interrupt(state)},
            )
        else:
            await _send_json(ws, {"type": "message_complete", "thread_id": thread_id})

    except Exception:
        # Log the full traceback server-side; never reflect details to the client.
        logger.exception("Error processing message for thread %s", thread_id)
        await _send_json(
            ws, {"type": "error", "message": "An error occurred processing your message."}
        )
|
||||
|
||||
|
||||
async def handle_interrupt_response(
    ws: WebSocket,
    graph: CompiledStateGraph,
    session_manager: SessionManager,
    callback_handler: TokenUsageCallbackHandler,
    thread_id: str,
    approved: bool,
) -> None:
    """Resume graph execution after interrupt approval/rejection."""
    # Lift the interrupt hold, then restart the sliding TTL window.
    session_manager.resolve_interrupt(thread_id)
    session_manager.touch(thread_id)

    run_config = {"configurable": {"thread_id": thread_id}, "callbacks": [callback_handler]}

    try:
        async for message, meta in graph.astream(
            Command(resume=approved),
            config=run_config,
            stream_mode="messages",
        ):
            if not getattr(message, "content", None):
                continue
            await _send_json(
                ws,
                {
                    "type": "token",
                    "agent": meta.get("langgraph_node", ""),
                    "content": message.content,
                },
            )

        await _send_json(ws, {"type": "message_complete", "thread_id": thread_id})

    except Exception:
        logger.exception("Error resuming interrupt for thread %s", thread_id)
        await _send_json(
            ws, {"type": "error", "message": "An error occurred processing your response."}
        )
|
||||
|
||||
|
||||
async def dispatch_message(
    ws: WebSocket,
    graph: CompiledStateGraph,
    session_manager: SessionManager,
    callback_handler: TokenUsageCallbackHandler,
    raw_data: str,
) -> None:
    """Parse and route an incoming WebSocket message."""
    # Size cap is enforced on the raw frame, before any parsing work.
    if len(raw_data) > MAX_MESSAGE_SIZE:
        await _send_json(ws, {"type": "error", "message": "Message too large"})
        return

    try:
        payload = json.loads(raw_data)
    except json.JSONDecodeError:
        await _send_json(ws, {"type": "error", "message": "Invalid JSON"})
        return

    # thread_id is mandatory and must match the allow-listed format.
    thread_id = payload.get("thread_id", "")
    if not thread_id:
        await _send_json(ws, {"type": "error", "message": "Missing thread_id"})
        return
    if not THREAD_ID_PATTERN.match(thread_id):
        await _send_json(ws, {"type": "error", "message": "Invalid thread_id format"})
        return

    msg_type = payload.get("type")

    if msg_type == "message":
        content = payload.get("content", "")
        if not content:
            await _send_json(ws, {"type": "error", "message": "Missing message content"})
            return
        if len(content) > MAX_CONTENT_LENGTH:
            await _send_json(ws, {"type": "error", "message": "Message content too long"})
            return
        await handle_user_message(ws, graph, session_manager, callback_handler, thread_id, content)

    elif msg_type == "interrupt_response":
        await handle_interrupt_response(
            ws, graph, session_manager, callback_handler, thread_id, payload.get("approved", False)
        )

    else:
        await _send_json(ws, {"type": "error", "message": "Unknown message type"})
|
||||
|
||||
|
||||
def _has_interrupt(state: Any) -> bool:
|
||||
"""Check if the graph state has a pending interrupt."""
|
||||
tasks = getattr(state, "tasks", ())
|
||||
return any(getattr(t, "interrupts", ()) for t in tasks)
|
||||
|
||||
|
||||
def _extract_interrupt(state: Any) -> dict:
|
||||
"""Extract interrupt data from graph state."""
|
||||
for task in getattr(state, "tasks", ()):
|
||||
for intr in getattr(task, "interrupts", ()):
|
||||
value = intr.value if hasattr(intr, "value") else {}
|
||||
if not isinstance(value, dict):
|
||||
value = {}
|
||||
return {
|
||||
"action": value.get("action", "unknown"),
|
||||
"params": value,
|
||||
}
|
||||
return {"action": "unknown", "params": {}}
|
||||
|
||||
|
||||
async def _send_json(ws: WebSocket, data: dict) -> None:
    """Serialize *data* and push it to the client over the WebSocket."""
    # Single choke point for outbound frames keeps the send style uniform.
    await ws.send_json(data)
|
||||
Reference in New Issue
Block a user