feat: complete phase 1 -- core framework with chat loop, agents, and React UI
Backend: - FastAPI WebSocket /ws endpoint with streaming via LangGraph astream - LangGraph Supervisor connecting 3 mock agents (order_lookup, order_actions, fallback) - YAML Agent Registry with Pydantic validation and immutable configs - PostgresSaver checkpoint persistence via langgraph-checkpoint-postgres - Session TTL with 30-min sliding window and interrupt extension - LLM provider abstraction (Anthropic/OpenAI/Google) - Token usage + cost tracking callback handler - Input validation: message size cap, thread_id format, content length - Security: no hardcoded defaults, startup API key validation, no input reflection Frontend: - React 19 + TypeScript + Vite chat UI - WebSocket hook with reconnect + exponential backoff - Streaming token display with agent attribution - Interrupt approval/reject UI for write operations - Collapsible tool call viewer Testing: - 87 unit tests, 87% coverage (exceeds 80% requirement) - Ruff lint + format clean Infrastructure: - Docker Compose (PostgreSQL 16 + backend) - pyproject.toml with full dependency management
This commit is contained in:
5
.gitignore
vendored
5
.gitignore
vendored
@@ -7,10 +7,15 @@ build/
|
|||||||
.venv/
|
.venv/
|
||||||
venv/
|
venv/
|
||||||
.env
|
.env
|
||||||
|
.pytest_cache/
|
||||||
|
.coverage
|
||||||
|
htmlcov/
|
||||||
|
.ruff_cache/
|
||||||
|
|
||||||
# Node
|
# Node
|
||||||
node_modules/
|
node_modules/
|
||||||
.next/
|
.next/
|
||||||
|
frontend/dist/
|
||||||
|
|
||||||
# IDE
|
# IDE
|
||||||
.vscode/
|
.vscode/
|
||||||
|
|||||||
@@ -234,7 +234,7 @@ A checkpoint includes:
|
|||||||
|
|
||||||
| Phase | Branch | Focus | Status |
|
| Phase | Branch | Focus | Status |
|
||||||
|-------|--------|-------|--------|
|
|-------|--------|-------|--------|
|
||||||
| 1 | `phase-1/core-framework` | FastAPI + LangGraph + React chat loop + PostgresSaver | NOT STARTED |
|
| 1 | `phase-1/core-framework` | FastAPI + LangGraph + React chat loop + PostgresSaver | IN PROGRESS |
|
||||||
| 2 | `phase-2/multi-agent-safety` | Supervisor routing + interrupts + templates | NOT STARTED |
|
| 2 | `phase-2/multi-agent-safety` | Supervisor routing + interrupts + templates | NOT STARTED |
|
||||||
| 3 | `phase-3/openapi-discovery` | OpenAPI parsing + MCP generation + SSRF protection | NOT STARTED |
|
| 3 | `phase-3/openapi-discovery` | OpenAPI parsing + MCP generation + SSRF protection | NOT STARTED |
|
||||||
| 4 | `phase-4/analytics-replay` | Replay API + analytics dashboard | NOT STARTED |
|
| 4 | `phase-4/analytics-replay` | Replay API + analytics dashboard | NOT STARTED |
|
||||||
|
|||||||
19
backend/.env.example
Normal file
19
backend/.env.example
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
# Database
|
||||||
|
DATABASE_URL=postgresql://smart_support:dev_password@localhost:5432/smart_support
|
||||||
|
|
||||||
|
# LLM Provider: anthropic | openai | google
|
||||||
|
LLM_PROVIDER=anthropic
|
||||||
|
LLM_MODEL=claude-sonnet-4-6
|
||||||
|
|
||||||
|
# API Keys (set the one matching your LLM_PROVIDER)
|
||||||
|
ANTHROPIC_API_KEY=
|
||||||
|
OPENAI_API_KEY=
|
||||||
|
GOOGLE_API_KEY=
|
||||||
|
|
||||||
|
# Session
|
||||||
|
SESSION_TTL_MINUTES=30
|
||||||
|
INTERRUPT_TTL_MINUTES=30
|
||||||
|
|
||||||
|
# Server
|
||||||
|
WS_HOST=0.0.0.0
|
||||||
|
WS_PORT=8000
|
||||||
12
backend/Dockerfile
Normal file
12
backend/Dockerfile
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
# Backend image for the Smart Support FastAPI service.
FROM python:3.11-slim

WORKDIR /app

# BUG FIX: the original ran `pip install -e .` while only pyproject.toml was
# present, then copied the source afterward. An editable install without the
# package source is fragile (setuptools may fail or register an empty
# distribution, and COPY . . does not re-run the install). Copy the full
# source first, then install.
COPY . .
RUN pip install --no-cache-dir -e .

EXPOSE 8000

# Bind to all interfaces inside the container; port matches EXPOSE and the
# docker-compose port mapping.
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
31
backend/agents.yaml
Normal file
31
backend/agents.yaml
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
# Agent registry consumed by AgentRegistry.load (backend/app/registry.py).
# - permission: read  -> informational tools only.
#   permission: write -> state-changing tools; cancel_order pauses the graph
#   with a LangGraph interrupt for human approval before executing.
# - Every entry under `tools` must match a key in the _TOOL_MAP of
#   backend/app/agents/__init__.py, or startup fails with a ValueError.
agents:
  - name: order_lookup
    description: "Looks up order status and tracking information. Use for queries about order status, shipping, and delivery."
    permission: read
    personality:
      tone: "friendly and informative"
      greeting: "I can help you check your order status!"
      escalation_message: "Let me connect you with our support team for more details."
    tools:
      - get_order_status
      - get_tracking_info

  - name: order_actions
    description: "Performs order modifications like cancellations. Use when the customer wants to cancel, modify, or change an order."
    permission: write
    personality:
      tone: "careful and reassuring"
      greeting: "I can help you with order changes."
      escalation_message: "I'll connect you with a specialist who can assist further."
    tools:
      - cancel_order

  - name: fallback
    description: "Handles general questions, unclear requests, and conversations that don't match other agents."
    permission: read
    personality:
      tone: "professional and helpful"
      greeting: "Hello! How can I help you today?"
      escalation_message: "Let me connect you with a human agent who can better assist you."
    tools:
      - fallback_respond
0
backend/app/__init__.py
Normal file
0
backend/app/__init__.py
Normal file
30
backend/app/agents/__init__.py
Normal file
30
backend/app/agents/__init__.py
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
"""Agent tools registry -- maps tool name strings to actual tool functions."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from langchain_core.tools import BaseTool
|
||||||
|
|
||||||
|
from app.agents.fallback import fallback_respond
|
||||||
|
from app.agents.order_actions import cancel_order
|
||||||
|
from app.agents.order_lookup import get_order_status, get_tracking_info
|
||||||
|
|
||||||
|
_TOOL_MAP: dict[str, BaseTool] = {
|
||||||
|
"get_order_status": get_order_status,
|
||||||
|
"get_tracking_info": get_tracking_info,
|
||||||
|
"cancel_order": cancel_order,
|
||||||
|
"fallback_respond": fallback_respond,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_tools_by_names(tool_names: list[str]) -> list[BaseTool]:
|
||||||
|
"""Resolve tool name strings from YAML config to actual tool objects."""
|
||||||
|
tools = []
|
||||||
|
for name in tool_names:
|
||||||
|
if name not in _TOOL_MAP:
|
||||||
|
available = ", ".join(sorted(_TOOL_MAP.keys()))
|
||||||
|
raise ValueError(f"Unknown tool '{name}'. Available tools: {available}")
|
||||||
|
tools.append(_TOOL_MAP[name])
|
||||||
|
return tools
|
||||||
18
backend/app/agents/fallback.py
Normal file
18
backend/app/agents/fallback.py
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
"""Fallback agent tools -- handles unmatched intents."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def fallback_respond(query: str) -> str:
|
||||||
|
"""Provide a helpful response when the user's intent doesn't match a specific agent."""
|
||||||
|
return (
|
||||||
|
"I'm here to help with order inquiries and actions. "
|
||||||
|
"Here's what I can do:\n"
|
||||||
|
"- Check order status (e.g., 'What is the status of order 1042?')\n"
|
||||||
|
"- Get tracking information (e.g., 'Track order 1042')\n"
|
||||||
|
"- Cancel an order (e.g., 'Cancel order 1042')\n\n"
|
||||||
|
"Could you please rephrase your request?"
|
||||||
|
)
|
||||||
37
backend/app/agents/order_actions.py
Normal file
37
backend/app/agents/order_actions.py
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
"""Order action tools -- write operations requiring human approval."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
from langgraph.types import interrupt
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def cancel_order(order_id: str) -> dict:
|
||||||
|
"""Cancel an order. Requires human approval before execution."""
|
||||||
|
response = interrupt(
|
||||||
|
{
|
||||||
|
"action": "cancel_order",
|
||||||
|
"order_id": order_id,
|
||||||
|
"message": f"Please confirm: cancel order {order_id}?",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(response, bool):
|
||||||
|
approved = response
|
||||||
|
elif isinstance(response, dict):
|
||||||
|
approved = response.get("approved", False)
|
||||||
|
else:
|
||||||
|
approved = bool(response)
|
||||||
|
|
||||||
|
if approved:
|
||||||
|
return {
|
||||||
|
"status": "cancelled",
|
||||||
|
"order_id": order_id,
|
||||||
|
"message": f"Order {order_id} has been successfully cancelled.",
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
"status": "kept",
|
||||||
|
"order_id": order_id,
|
||||||
|
"message": f"Order {order_id} cancellation was declined. The order remains active.",
|
||||||
|
}
|
||||||
68
backend/app/agents/order_lookup.py
Normal file
68
backend/app/agents/order_lookup.py
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
"""Order lookup tools -- read-only operations."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from types import MappingProxyType
|
||||||
|
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
|
MOCK_ORDERS: MappingProxyType[str, dict] = MappingProxyType(
|
||||||
|
{
|
||||||
|
"1042": {
|
||||||
|
"order_id": "1042",
|
||||||
|
"status": "shipped",
|
||||||
|
"items": ["Wireless Headphones", "USB-C Cable"],
|
||||||
|
"total": 89.99,
|
||||||
|
"placed_at": "2026-03-25",
|
||||||
|
},
|
||||||
|
"1043": {
|
||||||
|
"order_id": "1043",
|
||||||
|
"status": "processing",
|
||||||
|
"items": ["Laptop Stand"],
|
||||||
|
"total": 49.99,
|
||||||
|
"placed_at": "2026-03-28",
|
||||||
|
},
|
||||||
|
"1044": {
|
||||||
|
"order_id": "1044",
|
||||||
|
"status": "delivered",
|
||||||
|
"items": ["Mechanical Keyboard", "Mouse Pad"],
|
||||||
|
"total": 159.99,
|
||||||
|
"placed_at": "2026-03-20",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
MOCK_TRACKING: MappingProxyType[str, dict] = MappingProxyType(
|
||||||
|
{
|
||||||
|
"1042": {
|
||||||
|
"order_id": "1042",
|
||||||
|
"carrier": "FedEx",
|
||||||
|
"tracking_number": "FX-9876543210",
|
||||||
|
"estimated_delivery": "2026-04-01",
|
||||||
|
"current_location": "Distribution Center, Chicago IL",
|
||||||
|
},
|
||||||
|
"1044": {
|
||||||
|
"order_id": "1044",
|
||||||
|
"carrier": "UPS",
|
||||||
|
"tracking_number": "1Z-5678901234",
|
||||||
|
"estimated_delivery": "2026-03-22",
|
||||||
|
"current_location": "Delivered",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def get_order_status(order_id: str) -> dict:
|
||||||
|
"""Look up the current status of an order by order ID."""
|
||||||
|
if order_id in MOCK_ORDERS:
|
||||||
|
return dict(MOCK_ORDERS[order_id])
|
||||||
|
return {"error": f"Order {order_id} not found", "order_id": order_id}
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def get_tracking_info(order_id: str) -> dict:
|
||||||
|
"""Get shipping and tracking information for an order."""
|
||||||
|
if order_id in MOCK_TRACKING:
|
||||||
|
return dict(MOCK_TRACKING[order_id])
|
||||||
|
return {"error": f"No tracking information for order {order_id}", "order_id": order_id}
|
||||||
60
backend/app/callbacks.py
Normal file
60
backend/app/callbacks.py
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
"""Token usage tracking callback handler."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
from langchain_core.callbacks import BaseCallbackHandler
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from langchain_core.outputs import LLMResult
|
||||||
|
|
||||||
|
COST_PER_1K_TOKENS: dict[str, dict[str, float]] = {
|
||||||
|
"claude-sonnet-4-6": {"prompt": 0.003, "completion": 0.015},
|
||||||
|
"claude-haiku-4-5-20251001": {"prompt": 0.0008, "completion": 0.004},
|
||||||
|
"gpt-4o": {"prompt": 0.0025, "completion": 0.01},
|
||||||
|
"gpt-4o-mini": {"prompt": 0.00015, "completion": 0.0006},
|
||||||
|
}
|
||||||
|
|
||||||
|
DEFAULT_COST = {"prompt": 0.003, "completion": 0.015}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class TokenUsage:
|
||||||
|
prompt_tokens: int
|
||||||
|
completion_tokens: int
|
||||||
|
total_tokens: int
|
||||||
|
total_cost_usd: float
|
||||||
|
|
||||||
|
|
||||||
|
class TokenUsageCallbackHandler(BaseCallbackHandler):
|
||||||
|
"""Accumulates token usage and cost across LLM invocations."""
|
||||||
|
|
||||||
|
def __init__(self, model_name: str = "") -> None:
|
||||||
|
self._model_name = model_name
|
||||||
|
self._prompt_tokens = 0
|
||||||
|
self._completion_tokens = 0
|
||||||
|
|
||||||
|
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||||
|
if response.llm_output and "token_usage" in response.llm_output:
|
||||||
|
usage = response.llm_output["token_usage"]
|
||||||
|
self._prompt_tokens += usage.get("prompt_tokens", 0)
|
||||||
|
self._completion_tokens += usage.get("completion_tokens", 0)
|
||||||
|
|
||||||
|
def get_usage(self) -> TokenUsage:
|
||||||
|
costs = COST_PER_1K_TOKENS.get(self._model_name, DEFAULT_COST)
|
||||||
|
cost = (
|
||||||
|
self._prompt_tokens * costs["prompt"] / 1000
|
||||||
|
+ self._completion_tokens * costs["completion"] / 1000
|
||||||
|
)
|
||||||
|
return TokenUsage(
|
||||||
|
prompt_tokens=self._prompt_tokens,
|
||||||
|
completion_tokens=self._completion_tokens,
|
||||||
|
total_tokens=self._prompt_tokens + self._completion_tokens,
|
||||||
|
total_cost_usd=round(cost, 6),
|
||||||
|
)
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
self._prompt_tokens = 0
|
||||||
|
self._completion_tokens = 0
|
||||||
46
backend/app/config.py
Normal file
46
backend/app/config.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
"""Centralized application configuration via pydantic-settings."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from pydantic import model_validator
|
||||||
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
|
|
||||||
|
|
||||||
|
class Settings(BaseSettings):
|
||||||
|
model_config = SettingsConfigDict(
|
||||||
|
env_file=".env",
|
||||||
|
env_file_encoding="utf-8",
|
||||||
|
extra="ignore",
|
||||||
|
)
|
||||||
|
|
||||||
|
database_url: str
|
||||||
|
|
||||||
|
llm_provider: Literal["anthropic", "openai", "google"] = "anthropic"
|
||||||
|
llm_model: str = "claude-sonnet-4-6"
|
||||||
|
|
||||||
|
session_ttl_minutes: int = 30
|
||||||
|
interrupt_ttl_minutes: int = 30
|
||||||
|
|
||||||
|
ws_host: str = "0.0.0.0"
|
||||||
|
ws_port: int = 8000
|
||||||
|
|
||||||
|
anthropic_api_key: str = ""
|
||||||
|
openai_api_key: str = ""
|
||||||
|
google_api_key: str = ""
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def validate_provider_key(self) -> Settings:
|
||||||
|
key_map = {
|
||||||
|
"anthropic": self.anthropic_api_key,
|
||||||
|
"openai": self.openai_api_key,
|
||||||
|
"google": self.google_api_key,
|
||||||
|
}
|
||||||
|
key = key_map.get(self.llm_provider, "")
|
||||||
|
if not key:
|
||||||
|
raise ValueError(
|
||||||
|
f"API key for provider '{self.llm_provider}' is required. "
|
||||||
|
f"Set the corresponding environment variable."
|
||||||
|
)
|
||||||
|
return self
|
||||||
61
backend/app/db.py
Normal file
61
backend/app/db.py
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
"""Database connection pool and PostgresSaver checkpoint setup."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||||
|
from psycopg.rows import dict_row
|
||||||
|
from psycopg_pool import AsyncConnectionPool
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from app.config import Settings
|
||||||
|
|
||||||
|
_CONVERSATIONS_DDL = """
|
||||||
|
CREATE TABLE IF NOT EXISTS conversations (
|
||||||
|
thread_id TEXT PRIMARY KEY,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||||
|
last_activity TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||||
|
total_tokens INTEGER NOT NULL DEFAULT 0,
|
||||||
|
total_cost_usd DOUBLE PRECISION NOT NULL DEFAULT 0.0,
|
||||||
|
status TEXT NOT NULL DEFAULT 'active'
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
|
||||||
|
_INTERRUPTS_DDL = """
|
||||||
|
CREATE TABLE IF NOT EXISTS active_interrupts (
|
||||||
|
interrupt_id TEXT PRIMARY KEY,
|
||||||
|
thread_id TEXT NOT NULL REFERENCES conversations(thread_id),
|
||||||
|
action TEXT NOT NULL,
|
||||||
|
params JSONB NOT NULL DEFAULT '{}',
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||||
|
resolved_at TIMESTAMPTZ,
|
||||||
|
resolution TEXT
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
async def create_pool(settings: Settings) -> AsyncConnectionPool:
|
||||||
|
"""Create an async connection pool with the required psycopg settings."""
|
||||||
|
pool = AsyncConnectionPool(
|
||||||
|
conninfo=settings.database_url,
|
||||||
|
kwargs={"autocommit": True, "row_factory": dict_row},
|
||||||
|
min_size=2,
|
||||||
|
max_size=10,
|
||||||
|
)
|
||||||
|
await pool.open()
|
||||||
|
return pool
|
||||||
|
|
||||||
|
|
||||||
|
async def create_checkpointer(pool: AsyncConnectionPool) -> AsyncPostgresSaver:
|
||||||
|
"""Create and initialize the LangGraph checkpointer."""
|
||||||
|
checkpointer = AsyncPostgresSaver(conn=pool)
|
||||||
|
await checkpointer.setup()
|
||||||
|
return checkpointer
|
||||||
|
|
||||||
|
|
||||||
|
async def setup_app_tables(pool: AsyncConnectionPool) -> None:
|
||||||
|
"""Create application-specific tables (conversations, active_interrupts)."""
|
||||||
|
async with pool.connection() as conn:
|
||||||
|
await conn.execute(_CONVERSATIONS_DDL)
|
||||||
|
await conn.execute(_INTERRUPTS_DDL)
|
||||||
70
backend/app/graph.py
Normal file
70
backend/app/graph.py
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
"""LangGraph Supervisor construction -- connects registry, agents, LLM, and persistence."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from langgraph.prebuilt import create_react_agent
|
||||||
|
from langgraph_supervisor import create_supervisor
|
||||||
|
|
||||||
|
from app.agents import get_tools_by_names
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from langchain_core.language_models import BaseChatModel
|
||||||
|
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||||
|
from langgraph.graph.state import CompiledStateGraph
|
||||||
|
|
||||||
|
from app.registry import AgentRegistry
|
||||||
|
|
||||||
|
SUPERVISOR_PROMPT = (
|
||||||
|
"You are a customer support supervisor. "
|
||||||
|
"Route customer requests to the appropriate agent based on their description. "
|
||||||
|
"For order status and tracking queries, use the order_lookup agent. "
|
||||||
|
"For order modifications like cancellations, use the order_actions agent. "
|
||||||
|
"For anything else, use the fallback agent."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_agent_nodes(
|
||||||
|
registry: AgentRegistry,
|
||||||
|
llm: BaseChatModel,
|
||||||
|
) -> list:
|
||||||
|
"""Create LangGraph react agent nodes from registry configurations."""
|
||||||
|
agent_nodes = []
|
||||||
|
for agent_config in registry.list_agents():
|
||||||
|
tools = get_tools_by_names(agent_config.tools)
|
||||||
|
|
||||||
|
system_prompt = (
|
||||||
|
f"You are the {agent_config.name} agent. "
|
||||||
|
f"Personality: {agent_config.personality.tone}. "
|
||||||
|
f"{agent_config.description} "
|
||||||
|
f"Permission level: {agent_config.permission}."
|
||||||
|
)
|
||||||
|
|
||||||
|
agent_node = create_react_agent(
|
||||||
|
model=llm,
|
||||||
|
tools=tools,
|
||||||
|
name=agent_config.name,
|
||||||
|
prompt=system_prompt,
|
||||||
|
)
|
||||||
|
agent_nodes.append(agent_node)
|
||||||
|
|
||||||
|
return agent_nodes
|
||||||
|
|
||||||
|
|
||||||
|
def build_graph(
|
||||||
|
registry: AgentRegistry,
|
||||||
|
llm: BaseChatModel,
|
||||||
|
checkpointer: AsyncPostgresSaver,
|
||||||
|
) -> CompiledStateGraph:
|
||||||
|
"""Build and compile the LangGraph supervisor graph."""
|
||||||
|
agent_nodes = build_agent_nodes(registry, llm)
|
||||||
|
|
||||||
|
workflow = create_supervisor(
|
||||||
|
agent_nodes,
|
||||||
|
model=llm,
|
||||||
|
prompt=SUPERVISOR_PROMPT,
|
||||||
|
output_mode="full_history",
|
||||||
|
)
|
||||||
|
|
||||||
|
return workflow.compile(checkpointer=checkpointer)
|
||||||
42
backend/app/llm.py
Normal file
42
backend/app/llm.py
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
"""LLM provider factory with prompt caching support."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from langchain_core.language_models import BaseChatModel
|
||||||
|
|
||||||
|
from app.config import Settings
|
||||||
|
|
||||||
|
|
||||||
|
def create_llm(settings: Settings) -> BaseChatModel:
|
||||||
|
"""Create an LLM instance based on the configured provider."""
|
||||||
|
provider = settings.llm_provider
|
||||||
|
model = settings.llm_model
|
||||||
|
|
||||||
|
if provider == "anthropic":
|
||||||
|
from langchain_anthropic import ChatAnthropic
|
||||||
|
|
||||||
|
return ChatAnthropic(
|
||||||
|
model=model,
|
||||||
|
api_key=settings.anthropic_api_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
if provider == "openai":
|
||||||
|
from langchain_openai import ChatOpenAI
|
||||||
|
|
||||||
|
return ChatOpenAI(
|
||||||
|
model=model,
|
||||||
|
api_key=settings.openai_api_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
if provider == "google":
|
||||||
|
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||||
|
|
||||||
|
return ChatGoogleGenerativeAI(
|
||||||
|
model=model,
|
||||||
|
google_api_key=settings.google_api_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
raise ValueError(f"Unknown LLM provider: '{provider}'. Use 'anthropic', 'openai', or 'google'.")
|
||||||
83
backend/app/main.py
Normal file
83
backend/app/main.py
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
"""FastAPI application entry point."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||||||
|
from fastapi.staticfiles import StaticFiles
|
||||||
|
|
||||||
|
from app.callbacks import TokenUsageCallbackHandler
|
||||||
|
from app.config import Settings
|
||||||
|
from app.db import create_checkpointer, create_pool, setup_app_tables
|
||||||
|
from app.graph import build_graph
|
||||||
|
from app.llm import create_llm
|
||||||
|
from app.registry import AgentRegistry
|
||||||
|
from app.session_manager import SessionManager
|
||||||
|
from app.ws_handler import dispatch_message
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
AGENTS_YAML = Path(__file__).parent.parent / "agents.yaml"
|
||||||
|
FRONTEND_DIST = Path(__file__).parent.parent.parent / "frontend" / "dist"
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||||
|
settings = Settings()
|
||||||
|
|
||||||
|
pool = await create_pool(settings)
|
||||||
|
checkpointer = await create_checkpointer(pool)
|
||||||
|
await setup_app_tables(pool)
|
||||||
|
|
||||||
|
registry = AgentRegistry.load(AGENTS_YAML)
|
||||||
|
llm = create_llm(settings)
|
||||||
|
graph = build_graph(registry, llm, checkpointer)
|
||||||
|
session_manager = SessionManager(
|
||||||
|
session_ttl_seconds=settings.session_ttl_minutes * 60,
|
||||||
|
)
|
||||||
|
|
||||||
|
app.state.graph = graph
|
||||||
|
app.state.session_manager = session_manager
|
||||||
|
app.state.settings = settings
|
||||||
|
app.state.pool = pool
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Smart Support started: %d agents loaded, LLM=%s/%s",
|
||||||
|
len(registry),
|
||||||
|
settings.llm_provider,
|
||||||
|
settings.llm_model,
|
||||||
|
)
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
await pool.close()
|
||||||
|
|
||||||
|
|
||||||
|
app = FastAPI(title="Smart Support", version="0.1.0", lifespan=lifespan)
|
||||||
|
|
||||||
|
|
||||||
|
@app.websocket("/ws")
|
||||||
|
async def websocket_endpoint(ws: WebSocket) -> None:
|
||||||
|
await ws.accept()
|
||||||
|
graph = app.state.graph
|
||||||
|
session_manager = app.state.session_manager
|
||||||
|
settings = app.state.settings
|
||||||
|
callback_handler = TokenUsageCallbackHandler(model_name=settings.llm_model)
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
raw_data = await ws.receive_text()
|
||||||
|
await dispatch_message(ws, graph, session_manager, callback_handler, raw_data)
|
||||||
|
except WebSocketDisconnect:
|
||||||
|
logger.info("WebSocket client disconnected")
|
||||||
|
|
||||||
|
|
||||||
|
if FRONTEND_DIST.is_dir():
|
||||||
|
app.mount("/", StaticFiles(directory=str(FRONTEND_DIST), html=True), name="frontend")
|
||||||
104
backend/app/registry.py
Normal file
104
backend/app/registry.py
Normal file
@@ -0,0 +1,104 @@
|
|||||||
|
"""YAML Agent registry loader with validation."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
from pydantic import BaseModel, field_validator
|
||||||
|
|
||||||
|
|
||||||
|
class PersonalityConfig(BaseModel, frozen=True):
|
||||||
|
tone: str = "professional and helpful"
|
||||||
|
greeting: str = "Hello! How can I help you today?"
|
||||||
|
escalation_message: str = "Let me connect you with a human agent."
|
||||||
|
|
||||||
|
|
||||||
|
class AgentConfig(BaseModel, frozen=True):
|
||||||
|
name: str
|
||||||
|
description: str
|
||||||
|
permission: Literal["read", "write"]
|
||||||
|
personality: PersonalityConfig = PersonalityConfig()
|
||||||
|
tools: list[str]
|
||||||
|
|
||||||
|
@field_validator("name")
|
||||||
|
@classmethod
|
||||||
|
def name_not_empty(cls, v: str) -> str:
|
||||||
|
if not v.strip():
|
||||||
|
raise ValueError("Agent name must not be empty")
|
||||||
|
return v.strip()
|
||||||
|
|
||||||
|
@field_validator("tools")
|
||||||
|
@classmethod
|
||||||
|
def tools_not_empty(cls, v: list[str]) -> list[str]:
|
||||||
|
if not v:
|
||||||
|
raise ValueError("Agent must have at least one tool")
|
||||||
|
return v
|
||||||
|
|
||||||
|
|
||||||
|
class AgentRegistry:
|
||||||
|
"""Immutable registry of agent configurations loaded from YAML."""
|
||||||
|
|
||||||
|
def __init__(self, agents: tuple[AgentConfig, ...]) -> None:
|
||||||
|
self._agents = {agent.name: agent for agent in agents}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load(cls, yaml_path: str | Path) -> AgentRegistry:
|
||||||
|
"""Load and validate agent configurations from a YAML file."""
|
||||||
|
path = Path(yaml_path)
|
||||||
|
if not path.exists():
|
||||||
|
raise FileNotFoundError(f"Agent config file not found: {path}")
|
||||||
|
|
||||||
|
raw_text = path.read_text(encoding="utf-8")
|
||||||
|
if not raw_text.strip():
|
||||||
|
raise ValueError(f"Agent config file is empty: {path}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = yaml.safe_load(raw_text)
|
||||||
|
except yaml.YAMLError as exc:
|
||||||
|
msg = f"Invalid YAML in {path}"
|
||||||
|
if hasattr(exc, "problem_mark") and exc.problem_mark is not None:
|
||||||
|
mark = exc.problem_mark
|
||||||
|
msg += f" at line {mark.line + 1}, column {mark.column + 1}"
|
||||||
|
raise ValueError(msg) from exc
|
||||||
|
|
||||||
|
if not isinstance(data, dict) or "agents" not in data:
|
||||||
|
raise ValueError(f"Agent config must have a top-level 'agents' key in {path}")
|
||||||
|
|
||||||
|
raw_agents = data["agents"]
|
||||||
|
if not isinstance(raw_agents, list) or not raw_agents:
|
||||||
|
raise ValueError(f"'agents' must be a non-empty list in {path}")
|
||||||
|
|
||||||
|
agents: list[AgentConfig] = []
|
||||||
|
seen_names: set[str] = set()
|
||||||
|
|
||||||
|
for i, raw in enumerate(raw_agents):
|
||||||
|
if not isinstance(raw, dict):
|
||||||
|
raise ValueError(f"Agent at index {i} must be a mapping in {path}")
|
||||||
|
try:
|
||||||
|
agent = AgentConfig(**raw)
|
||||||
|
except Exception as exc:
|
||||||
|
raise ValueError(f"Invalid agent config at index {i} in {path}: {exc}") from exc
|
||||||
|
|
||||||
|
if agent.name in seen_names:
|
||||||
|
raise ValueError(f"Duplicate agent name '{agent.name}' in {path}")
|
||||||
|
seen_names.add(agent.name)
|
||||||
|
agents.append(agent)
|
||||||
|
|
||||||
|
return cls(agents=tuple(agents))
|
||||||
|
|
||||||
|
def get_agent(self, name: str) -> AgentConfig:
|
||||||
|
if name not in self._agents:
|
||||||
|
available = ", ".join(sorted(self._agents.keys()))
|
||||||
|
raise KeyError(f"Agent '{name}' not found. Available: {available}")
|
||||||
|
return self._agents[name]
|
||||||
|
|
||||||
|
def list_agents(self) -> tuple[AgentConfig, ...]:
|
||||||
|
return tuple(self._agents.values())
|
||||||
|
|
||||||
|
def get_agents_by_permission(self, permission: str) -> tuple[AgentConfig, ...]:
|
||||||
|
return tuple(a for a in self._agents.values() if a.permission == permission)
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return len(self._agents)
|
||||||
78
backend/app/session_manager.py
Normal file
78
backend/app/session_manager.py
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
"""Session TTL management with sliding window and interrupt extension."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class SessionState:
|
||||||
|
thread_id: str
|
||||||
|
last_activity: float
|
||||||
|
has_pending_interrupt: bool
|
||||||
|
|
||||||
|
|
||||||
|
class SessionManager:
|
||||||
|
"""Manages session TTL with sliding window and interrupt extensions.
|
||||||
|
|
||||||
|
- Each message resets the TTL (sliding window).
|
||||||
|
- A pending interrupt suspends expiration until resolved.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, session_ttl_seconds: int = 1800) -> None:
|
||||||
|
self._session_ttl = session_ttl_seconds
|
||||||
|
self._sessions: dict[str, SessionState] = {}
|
||||||
|
|
||||||
|
def touch(self, thread_id: str) -> SessionState:
|
||||||
|
"""Update last activity for a session (resets sliding window)."""
|
||||||
|
existing = self._sessions.get(thread_id)
|
||||||
|
new_state = SessionState(
|
||||||
|
thread_id=thread_id,
|
||||||
|
last_activity=time.time(),
|
||||||
|
has_pending_interrupt=existing.has_pending_interrupt if existing else False,
|
||||||
|
)
|
||||||
|
self._sessions = {**self._sessions, thread_id: new_state}
|
||||||
|
return new_state
|
||||||
|
|
||||||
|
def is_expired(self, thread_id: str) -> bool:
|
||||||
|
"""Check if a session has expired."""
|
||||||
|
state = self._sessions.get(thread_id)
|
||||||
|
if state is None:
|
||||||
|
return True
|
||||||
|
|
||||||
|
if state.has_pending_interrupt:
|
||||||
|
return False
|
||||||
|
|
||||||
|
elapsed = time.time() - state.last_activity
|
||||||
|
return elapsed > self._session_ttl
|
||||||
|
|
||||||
|
def extend_for_interrupt(self, thread_id: str) -> SessionState:
|
||||||
|
"""Mark session as having a pending interrupt (suspends TTL)."""
|
||||||
|
existing = self._sessions.get(thread_id)
|
||||||
|
if existing is None:
|
||||||
|
return self.touch(thread_id)
|
||||||
|
|
||||||
|
new_state = SessionState(
|
||||||
|
thread_id=thread_id,
|
||||||
|
last_activity=existing.last_activity,
|
||||||
|
has_pending_interrupt=True,
|
||||||
|
)
|
||||||
|
self._sessions = {**self._sessions, thread_id: new_state}
|
||||||
|
return new_state
|
||||||
|
|
||||||
|
def resolve_interrupt(self, thread_id: str) -> SessionState:
|
||||||
|
"""Remove interrupt extension and reset activity timer."""
|
||||||
|
new_state = SessionState(
|
||||||
|
thread_id=thread_id,
|
||||||
|
last_activity=time.time(),
|
||||||
|
has_pending_interrupt=False,
|
||||||
|
)
|
||||||
|
self._sessions = {**self._sessions, thread_id: new_state}
|
||||||
|
return new_state
|
||||||
|
|
||||||
|
def get_state(self, thread_id: str) -> SessionState | None:
|
||||||
|
return self._sessions.get(thread_id)
|
||||||
|
|
||||||
|
def remove(self, thread_id: str) -> None:
|
||||||
|
self._sessions = {k: v for k, v in self._sessions.items() if k != thread_id}
|
||||||
204
backend/app/ws_handler.py
Normal file
204
backend/app/ws_handler.py
Normal file
@@ -0,0 +1,204 @@
|
|||||||
|
"""WebSocket message handling logic -- extracted from main for testability."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
from langchain_core.messages import HumanMessage
|
||||||
|
from langgraph.types import Command
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from fastapi import WebSocket
|
||||||
|
from langgraph.graph.state import CompiledStateGraph
|
||||||
|
|
||||||
|
from app.callbacks import TokenUsageCallbackHandler
|
||||||
|
from app.session_manager import SessionManager
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Validation limits applied to every incoming WebSocket frame.
MAX_MESSAGE_SIZE = 32_768  # 32 KB cap on the raw JSON frame
MAX_CONTENT_LENGTH = 8_000  # characters allowed in a single user message
# thread_id: 1-128 chars of [a-zA-Z0-9_-]; anchored so metacharacters are rejected
THREAD_ID_PATTERN = re.compile(r"^[a-zA-Z0-9\-_]{1,128}$")
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_user_message(
    ws: WebSocket,
    graph: CompiledStateGraph,
    session_manager: SessionManager,
    callback_handler: TokenUsageCallbackHandler,
    thread_id: str,
    content: str,
) -> None:
    """Process a user message through the graph and stream results back.

    Flow: reject expired sessions -> touch the session (reset sliding TTL)
    -> stream graph output as "tool_call"/"token" frames -> finish with
    either an "interrupt" frame (write approval pending, TTL suspended) or
    a "message_complete" frame.  Failures are reported to the client as a
    generic "error" frame; details go only to the server log (no input
    reflection).
    """
    if session_manager.is_expired(thread_id):
        msg = "Session expired. Please start a new conversation."
        await _send_json(ws, {"type": "error", "message": msg})
        return

    session_manager.touch(thread_id)
    config = {"configurable": {"thread_id": thread_id}, "callbacks": [callback_handler]}
    input_msg = {"messages": [HumanMessage(content=content)]}

    try:
        # stream_mode="messages" yields (message_chunk, metadata) pairs.
        async for chunk in graph.astream(input_msg, config=config, stream_mode="messages"):
            msg_chunk, metadata = chunk
            # Graph node (agent) that produced this chunk -- used for agent
            # attribution in the UI.
            node = metadata.get("langgraph_node", "")

            # NOTE(review): a chunk carrying both tool_calls and content only
            # emits the tool_call frame -- confirm that is intended.
            if hasattr(msg_chunk, "tool_calls") and msg_chunk.tool_calls:
                for tc in msg_chunk.tool_calls:
                    await _send_json(
                        ws,
                        {
                            "type": "tool_call",
                            "agent": node,
                            "tool": tc.get("name", ""),
                            "args": tc.get("args", {}),
                        },
                    )
            elif hasattr(msg_chunk, "content") and msg_chunk.content:
                await _send_json(
                    ws,
                    {
                        "type": "token",
                        "agent": node,
                        "content": msg_chunk.content,
                    },
                )

        # After streaming ends, inspect the checkpointed state: a pending
        # interrupt means a write tool is awaiting user approval.
        state = await graph.aget_state(config)
        if _has_interrupt(state):
            interrupt_data = _extract_interrupt(state)
            # Suspend session expiry while the approval is outstanding.
            session_manager.extend_for_interrupt(thread_id)
            await _send_json(
                ws,
                {
                    "type": "interrupt",
                    "thread_id": thread_id,
                    **interrupt_data,
                },
            )
        else:
            await _send_json(ws, {"type": "message_complete", "thread_id": thread_id})

    except Exception:
        logger.exception("Error processing message for thread %s", thread_id)
        err = "An error occurred processing your message."
        await _send_json(ws, {"type": "error", "message": err})
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_interrupt_response(
    ws: WebSocket,
    graph: CompiledStateGraph,
    session_manager: SessionManager,
    callback_handler: TokenUsageCallbackHandler,
    thread_id: str,
    approved: bool,
) -> None:
    """Resume graph execution after interrupt approval/rejection.

    Clears the interrupt TTL extension, resets the sliding window, then
    resumes the graph with Command(resume=approved) so the paused tool
    receives the user's decision.  Streams resulting tokens and closes
    with "message_complete"; failures become a generic "error" frame.
    """
    # Lift the TTL suspension before resuming; touch() resets the window.
    session_manager.resolve_interrupt(thread_id)
    session_manager.touch(thread_id)

    config = {"configurable": {"thread_id": thread_id}, "callbacks": [callback_handler]}

    try:
        # Command(resume=...) feeds `approved` back as the return value of
        # the interrupt() call inside the paused tool.
        async for chunk in graph.astream(
            Command(resume=approved),
            config=config,
            stream_mode="messages",
        ):
            msg_chunk, metadata = chunk
            node = metadata.get("langgraph_node", "")

            if hasattr(msg_chunk, "content") and msg_chunk.content:
                await _send_json(
                    ws,
                    {
                        "type": "token",
                        "agent": node,
                        "content": msg_chunk.content,
                    },
                )

        await _send_json(ws, {"type": "message_complete", "thread_id": thread_id})

    except Exception:
        logger.exception("Error resuming interrupt for thread %s", thread_id)
        err = "An error occurred processing your response."
        await _send_json(ws, {"type": "error", "message": err})
|
||||||
|
|
||||||
|
|
||||||
|
async def dispatch_message(
    ws: WebSocket,
    graph: CompiledStateGraph,
    session_manager: SessionManager,
    callback_handler: TokenUsageCallbackHandler,
    raw_data: str,
) -> None:
    """Parse, validate, and route an incoming WebSocket message.

    Validation order: raw size cap -> JSON parse -> payload is an object ->
    thread_id present/string/format -> per-type content checks.  Every
    rejection is reported to the client as a {"type": "error"} frame with a
    fixed message (no input reflection); nothing in here should raise.
    """
    if len(raw_data) > MAX_MESSAGE_SIZE:
        await _send_json(ws, {"type": "error", "message": "Message too large"})
        return

    try:
        data = json.loads(raw_data)
    except json.JSONDecodeError:
        await _send_json(ws, {"type": "error", "message": "Invalid JSON"})
        return

    # Fix: valid JSON is not necessarily an object (a list/string/number
    # parses fine) -- the subsequent .get() calls would raise AttributeError.
    if not isinstance(data, dict):
        await _send_json(ws, {"type": "error", "message": "Invalid JSON"})
        return

    msg_type = data.get("type")
    thread_id = data.get("thread_id", "")

    if not thread_id:
        await _send_json(ws, {"type": "error", "message": "Missing thread_id"})
        return

    # Fix: a non-string thread_id (e.g. a JSON number) made re.match raise
    # TypeError; treat it as a format violation instead of crashing.
    if not isinstance(thread_id, str) or not THREAD_ID_PATTERN.match(thread_id):
        await _send_json(ws, {"type": "error", "message": "Invalid thread_id format"})
        return

    if msg_type == "message":
        content = data.get("content", "")
        if not content:
            await _send_json(ws, {"type": "error", "message": "Missing message content"})
            return
        # Fix: reject non-string content -- lists/dicts passed len() and
        # leaked into the graph, and unsized values crashed with TypeError.
        if not isinstance(content, str):
            await _send_json(ws, {"type": "error", "message": "Invalid message content"})
            return
        if len(content) > MAX_CONTENT_LENGTH:
            await _send_json(ws, {"type": "error", "message": "Message content too long"})
            return
        await handle_user_message(ws, graph, session_manager, callback_handler, thread_id, content)

    elif msg_type == "interrupt_response":
        # Coerce to a real bool so a JSON string/number cannot smuggle a
        # truthy non-bool into the graph resume value.
        approved = bool(data.get("approved", False))
        await handle_interrupt_response(
            ws, graph, session_manager, callback_handler, thread_id, approved
        )

    else:
        await _send_json(ws, {"type": "error", "message": "Unknown message type"})
|
||||||
|
|
||||||
|
|
||||||
|
def _has_interrupt(state: Any) -> bool:
|
||||||
|
"""Check if the graph state has a pending interrupt."""
|
||||||
|
tasks = getattr(state, "tasks", ())
|
||||||
|
return any(getattr(t, "interrupts", ()) for t in tasks)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_interrupt(state: Any) -> dict:
|
||||||
|
"""Extract interrupt data from graph state."""
|
||||||
|
for task in getattr(state, "tasks", ()):
|
||||||
|
for intr in getattr(task, "interrupts", ()):
|
||||||
|
value = intr.value if hasattr(intr, "value") else {}
|
||||||
|
if not isinstance(value, dict):
|
||||||
|
value = {}
|
||||||
|
return {
|
||||||
|
"action": value.get("action", "unknown"),
|
||||||
|
"params": value,
|
||||||
|
}
|
||||||
|
return {"action": "unknown", "params": {}}
|
||||||
|
|
||||||
|
|
||||||
|
async def _send_json(ws: WebSocket, data: dict) -> None:
    """Send a JSON message through the WebSocket.

    Single choke point for all outbound frames in this module.
    """
    await ws.send_json(data)
|
||||||
64
backend/pyproject.toml
Normal file
64
backend/pyproject.toml
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
[project]
|
||||||
|
name = "smart-support"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "AI customer support action-layer framework"
|
||||||
|
requires-python = ">=3.11"
|
||||||
|
dependencies = [
|
||||||
|
"fastapi>=0.115,<1.0",
|
||||||
|
"uvicorn[standard]>=0.34,<1.0",
|
||||||
|
"langgraph>=0.4,<1.0",
|
||||||
|
"langgraph-supervisor>=0.0.12,<1.0",
|
||||||
|
"langgraph-checkpoint-postgres>=3.0,<4.0",
|
||||||
|
"langchain-core>=0.3,<1.0",
|
||||||
|
"langchain-anthropic>=0.3,<2.0",
|
||||||
|
"langchain-openai>=0.3,<1.0",
|
||||||
|
"langchain-google-genai>=2.1,<3.0",
|
||||||
|
"psycopg[binary,pool]>=3.2,<4.0",
|
||||||
|
"pydantic>=2.10,<3.0",
|
||||||
|
"pydantic-settings>=2.7,<3.0",
|
||||||
|
"pyyaml>=6.0,<7.0",
|
||||||
|
"python-dotenv>=1.0,<2.0",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.optional-dependencies]
|
||||||
|
dev = [
|
||||||
|
"pytest>=8.3,<9.0",
|
||||||
|
"pytest-asyncio>=0.25,<1.0",
|
||||||
|
"pytest-cov>=6.0,<7.0",
|
||||||
|
"httpx>=0.28,<1.0",
|
||||||
|
"ruff>=0.9,<1.0",
|
||||||
|
]
|
||||||
|
|
||||||
|
[build-system]
|
||||||
|
requires = ["setuptools>=75.0"]
|
||||||
|
build-backend = "setuptools.build_meta"
|
||||||
|
|
||||||
|
[tool.pytest.ini_options]
|
||||||
|
asyncio_mode = "auto"
|
||||||
|
testpaths = ["tests"]
|
||||||
|
markers = [
|
||||||
|
"unit: per-module isolated tests",
|
||||||
|
"integration: cross-module with real PostgreSQL",
|
||||||
|
"e2e: full-stack user flow tests",
|
||||||
|
]
|
||||||
|
addopts = "--strict-markers"
|
||||||
|
|
||||||
|
[tool.coverage.run]
|
||||||
|
source = ["app"]
|
||||||
|
|
||||||
|
[tool.coverage.report]
|
||||||
|
fail_under = 80
|
||||||
|
show_missing = true
|
||||||
|
|
||||||
|
[tool.ruff]
|
||||||
|
target-version = "py311"
|
||||||
|
line-length = 100
|
||||||
|
|
||||||
|
[tool.ruff.lint]
|
||||||
|
select = ["E", "F", "I", "N", "W", "UP", "B", "A", "SIM", "TCH"]
|
||||||
|
|
||||||
|
[tool.ruff.lint.per-file-ignores]
|
||||||
|
"tests/**" = ["N806", "B017"]
|
||||||
|
|
||||||
|
[tool.ruff.format]
|
||||||
|
quote-style = "double"
|
||||||
0
backend/tests/__init__.py
Normal file
0
backend/tests/__init__.py
Normal file
69
backend/tests/conftest.py
Normal file
69
backend/tests/conftest.py
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
"""Shared test fixtures and marker registration."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
from app.config import Settings
|
||||||
|
from app.registry import AgentRegistry
|
||||||
|
from app.session_manager import SessionManager
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def test_settings() -> Settings:
    """Settings populated with dummy credentials safe for unit tests."""
    return Settings(
        database_url="postgresql://test:test@localhost:5432/test_db",
        llm_provider="anthropic",
        llm_model="claude-sonnet-4-6",
        anthropic_api_key="test-key",
    )
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def sample_yaml_path(tmp_path: Path) -> Path:
    """Write a minimal three-agent registry YAML under tmp_path.

    Covers one read agent, one write agent with a custom personality, and
    one fallback agent -- the shapes the registry tests exercise.
    """
    data = {
        "agents": [
            {
                "name": "test_reader",
                "description": "A test read agent",
                "permission": "read",
                "tools": ["get_order_status"],
            },
            {
                "name": "test_writer",
                "description": "A test write agent",
                "permission": "write",
                # Explicit personality overrides the defaults.
                "personality": {
                    "tone": "formal",
                    "greeting": "Greetings.",
                    "escalation_message": "Escalating now.",
                },
                "tools": ["cancel_order"],
            },
            {
                "name": "test_fallback",
                "description": "A fallback agent",
                "permission": "read",
                "tools": ["fallback_respond"],
            },
        ]
    }
    path = tmp_path / "test_agents.yaml"
    path.write_text(yaml.dump(data), encoding="utf-8")
    return path
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def sample_registry(sample_yaml_path: Path) -> AgentRegistry:
    """Registry loaded from the sample YAML (three agents)."""
    return AgentRegistry.load(sample_yaml_path)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def session_manager() -> SessionManager:
    """SessionManager with a short 60s TTL for expiry-related tests."""
    return SessionManager(session_ttl_seconds=60)
|
||||||
0
backend/tests/e2e/__init__.py
Normal file
0
backend/tests/e2e/__init__.py
Normal file
0
backend/tests/integration/__init__.py
Normal file
0
backend/tests/integration/__init__.py
Normal file
0
backend/tests/unit/__init__.py
Normal file
0
backend/tests/unit/__init__.py
Normal file
82
backend/tests/unit/test_agents.py
Normal file
82
backend/tests/unit/test_agents.py
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
"""Tests for agent tools."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.agents.fallback import fallback_respond
|
||||||
|
from app.agents.order_lookup import get_order_status, get_tracking_info
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
class TestOrderLookup:
    """Tests for the read-only order-lookup tools (mock-data backed)."""

    def test_get_order_status_existing(self) -> None:
        # Tools are LangChain tool objects, so they are called via .invoke().
        result = get_order_status.invoke({"order_id": "1042"})
        assert result["order_id"] == "1042"
        assert result["status"] == "shipped"

    def test_get_order_status_not_found(self) -> None:
        result = get_order_status.invoke({"order_id": "9999"})
        assert "error" in result
        assert "9999" in result["error"]

    def test_get_tracking_info_existing(self) -> None:
        result = get_tracking_info.invoke({"order_id": "1042"})
        assert result["carrier"] == "FedEx"
        assert result["tracking_number"] == "FX-9876543210"

    def test_get_tracking_info_not_found(self) -> None:
        # NOTE(review): presumably order 1043 exists without tracking data --
        # confirm against MOCK_ORDERS.
        result = get_tracking_info.invoke({"order_id": "1043"})
        assert "error" in result

    def test_all_mock_orders_have_required_fields(self) -> None:
        # Imported locally to keep the mock dataset out of module scope.
        from app.agents.order_lookup import MOCK_ORDERS

        for oid, order in MOCK_ORDERS.items():
            assert "order_id" in order
            assert "status" in order
            # The dict key must match the record's own order_id.
            assert order["order_id"] == oid
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
class TestFallback:
    """Tests for the fallback agent's catch-all responder."""

    def test_fallback_respond_returns_help(self) -> None:
        result = fallback_respond.invoke({"query": "random question"})
        # The canned reply should steer the user toward order support.
        assert "order" in result.lower()
        assert "help" in result.lower() or "can do" in result.lower()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
class TestOrderActions:
    """Tests for write tools, with the LangGraph interrupt() patched out."""

    def test_cancel_order_approved(self) -> None:
        # Patch the module attribute, then import and invoke inside the
        # context so the tool's interrupt() lookup hits the patched object.
        # return_value=True simulates the user approving the action.
        with patch("app.agents.order_actions.interrupt", return_value=True):
            from app.agents.order_actions import cancel_order

            result = cancel_order.invoke({"order_id": "1042"})
            assert result["status"] == "cancelled"
            assert "1042" in result["message"]

    def test_cancel_order_rejected(self) -> None:
        # return_value=False simulates the user declining the action.
        with patch("app.agents.order_actions.interrupt", return_value=False):
            from app.agents.order_actions import cancel_order

            result = cancel_order.invoke({"order_id": "1042"})
            assert result["status"] == "kept"
            assert "declined" in result["message"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
class TestToolBridge:
    """Tests for resolving registry tool names to tool objects."""

    def test_get_tools_by_names(self) -> None:
        from app.agents import get_tools_by_names

        tools = get_tools_by_names(["get_order_status", "cancel_order"])
        assert len(tools) == 2

    def test_unknown_tool_raises(self) -> None:
        from app.agents import get_tools_by_names

        # Unknown names must fail loudly rather than be silently skipped.
        with pytest.raises(ValueError, match="Unknown tool"):
            get_tools_by_names(["nonexistent_tool"])
|
||||||
102
backend/tests/unit/test_callbacks.py
Normal file
102
backend/tests/unit/test_callbacks.py
Normal file
@@ -0,0 +1,102 @@
|
|||||||
|
"""Tests for app.callbacks module."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.callbacks import TokenUsageCallbackHandler
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
class TestTokenUsageCallbackHandler:
    """Tests for token accumulation, cost math, reset, and edge inputs."""

    def test_initial_state(self) -> None:
        handler = TokenUsageCallbackHandler(model_name="claude-sonnet-4-6")
        usage = handler.get_usage()
        assert usage.prompt_tokens == 0
        assert usage.completion_tokens == 0
        assert usage.total_tokens == 0
        assert usage.total_cost_usd == 0.0

    def test_accumulates_tokens(self) -> None:
        handler = TokenUsageCallbackHandler(model_name="claude-sonnet-4-6")

        # Minimal stand-in for the LLMResult passed to on_llm_end.
        class FakeResult:
            llm_output = {"token_usage": {"prompt_tokens": 100, "completion_tokens": 50}}

        handler.on_llm_end(FakeResult())
        usage = handler.get_usage()
        assert usage.prompt_tokens == 100
        assert usage.completion_tokens == 50
        assert usage.total_tokens == 150

    def test_accumulates_across_calls(self) -> None:
        handler = TokenUsageCallbackHandler(model_name="claude-sonnet-4-6")

        class FakeResult:
            llm_output = {"token_usage": {"prompt_tokens": 100, "completion_tokens": 50}}

        # Two identical results: totals should be exactly doubled.
        handler.on_llm_end(FakeResult())
        handler.on_llm_end(FakeResult())
        usage = handler.get_usage()
        assert usage.prompt_tokens == 200
        assert usage.completion_tokens == 100
        assert usage.total_tokens == 300

    def test_cost_calculation(self) -> None:
        handler = TokenUsageCallbackHandler(model_name="claude-sonnet-4-6")

        class FakeResult:
            llm_output = {"token_usage": {"prompt_tokens": 1000, "completion_tokens": 1000}}

        handler.on_llm_end(FakeResult())
        usage = handler.get_usage()
        # claude-sonnet-4-6: prompt $0.003/1K, completion $0.015/1K
        expected_cost = 1000 * 0.003 / 1000 + 1000 * 0.015 / 1000
        # approx() guards against float rounding in the handler's math.
        assert usage.total_cost_usd == pytest.approx(expected_cost)

    def test_reset(self) -> None:
        handler = TokenUsageCallbackHandler(model_name="claude-sonnet-4-6")

        class FakeResult:
            llm_output = {"token_usage": {"prompt_tokens": 100, "completion_tokens": 50}}

        handler.on_llm_end(FakeResult())
        handler.reset()
        usage = handler.get_usage()
        assert usage.total_tokens == 0

    def test_usage_is_immutable(self) -> None:
        handler = TokenUsageCallbackHandler()
        usage = handler.get_usage()
        # Frozen usage object: attribute writes must raise.
        with pytest.raises(Exception):
            usage.prompt_tokens = 999

    def test_unknown_model_uses_default_cost(self) -> None:
        handler = TokenUsageCallbackHandler(model_name="unknown-model")

        class FakeResult:
            llm_output = {"token_usage": {"prompt_tokens": 1000, "completion_tokens": 1000}}

        handler.on_llm_end(FakeResult())
        usage = handler.get_usage()
        # Exact fallback pricing is not pinned -- only that cost is non-zero.
        assert usage.total_cost_usd > 0

    def test_handles_missing_token_usage(self) -> None:
        handler = TokenUsageCallbackHandler()

        class FakeResult:
            llm_output = {}

        # An llm_output without token_usage must be ignored, not crash.
        handler.on_llm_end(FakeResult())
        usage = handler.get_usage()
        assert usage.total_tokens == 0

    def test_handles_none_llm_output(self) -> None:
        handler = TokenUsageCallbackHandler()

        class FakeResult:
            llm_output = None

        # Some providers return no llm_output at all.
        handler.on_llm_end(FakeResult())
        usage = handler.get_usage()
        assert usage.total_tokens == 0
|
||||||
60
backend/tests/unit/test_config.py
Normal file
60
backend/tests/unit/test_config.py
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
"""Tests for app.config module."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.config import Settings
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
class TestSettings:
    """Tests for Settings defaults, overrides, and API-key validation."""

    def test_default_values(self) -> None:
        settings = Settings(
            database_url="postgresql://x:x@localhost/db",
            anthropic_api_key="key",
        )
        assert settings.llm_provider == "anthropic"
        assert settings.llm_model == "claude-sonnet-4-6"
        assert settings.session_ttl_minutes == 30
        assert settings.interrupt_ttl_minutes == 30

    def test_custom_values(self) -> None:
        settings = Settings(
            database_url="postgresql://x:x@localhost/db",
            llm_provider="openai",
            llm_model="gpt-4o",
            session_ttl_minutes=15,
            openai_api_key="sk-test",
        )
        assert settings.llm_provider == "openai"
        assert settings.llm_model == "gpt-4o"
        assert settings.session_ttl_minutes == 15

    def test_invalid_provider_rejected(self) -> None:
        # Broad Exception keeps the test decoupled from pydantic's
        # ValidationError hierarchy.
        with pytest.raises(Exception):
            Settings(
                database_url="postgresql://x:x@localhost/db",
                llm_provider="invalid",
            )

    def test_missing_database_url_rejected(self) -> None:
        # database_url has no default -- construction must fail without it.
        with pytest.raises(Exception):
            Settings(anthropic_api_key="key")

    def test_empty_api_key_for_provider_rejected(self) -> None:
        with pytest.raises(ValueError, match="API key"):
            Settings(
                database_url="postgresql://x:x@localhost/db",
                llm_provider="anthropic",
                anthropic_api_key="",
            )

    def test_wrong_provider_key_rejected(self) -> None:
        # Supplying only another provider's key must not satisfy validation.
        with pytest.raises(ValueError, match="API key"):
            Settings(
                database_url="postgresql://x:x@localhost/db",
                llm_provider="openai",
                anthropic_api_key="key",
                openai_api_key="",
            )
|
||||||
64
backend/tests/unit/test_db.py
Normal file
64
backend/tests/unit/test_db.py
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
"""Tests for app.db module."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.config import Settings
|
||||||
|
from app.db import _CONVERSATIONS_DDL, _INTERRUPTS_DDL
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
class TestDbModule:
    """Tests for pool/checkpointer creation and DDL, with psycopg mocked."""

    @pytest.mark.asyncio
    async def test_create_pool_sets_correct_params(self) -> None:
        settings = Settings(
            database_url="postgresql://user:pass@localhost:5432/testdb",
            anthropic_api_key="key",
        )
        with patch("app.db.AsyncConnectionPool") as MockPool:
            mock_pool = AsyncMock()
            MockPool.return_value = mock_pool

            from app.db import create_pool

            await create_pool(settings)
            MockPool.assert_called_once()
            # Loose check: the DSN must appear somewhere in the call args,
            # whether passed positionally or by keyword.
            call_kwargs = MockPool.call_args
            assert "postgresql://user:pass@localhost:5432/testdb" in str(call_kwargs)
            mock_pool.open.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_create_checkpointer_calls_setup(self) -> None:
        mock_pool = AsyncMock()
        with patch("app.db.AsyncPostgresSaver") as MockSaver:
            mock_saver = AsyncMock()
            MockSaver.return_value = mock_saver

            from app.db import create_checkpointer

            await create_checkpointer(mock_pool)
            MockSaver.assert_called_once_with(conn=mock_pool)
            # setup() is what creates the checkpoint tables on first run.
            mock_saver.setup.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_setup_app_tables_executes_ddl(self) -> None:
        # Hand-build an async context manager around the mocked connection,
        # mirroring `async with pool.connection() as conn`.
        mock_conn = AsyncMock()
        mock_ctx = AsyncMock()
        mock_ctx.__aenter__ = AsyncMock(return_value=mock_conn)
        mock_ctx.__aexit__ = AsyncMock(return_value=None)
        mock_pool = MagicMock()
        mock_pool.connection.return_value = mock_ctx

        from app.db import setup_app_tables

        await setup_app_tables(mock_pool)
        # One execute per DDL statement (conversations + active_interrupts).
        assert mock_conn.execute.await_count == 2

    def test_ddl_statements_valid(self) -> None:
        # Sanity-check the DDL strings contain the expected tables/columns.
        assert "CREATE TABLE IF NOT EXISTS conversations" in _CONVERSATIONS_DDL
        assert "CREATE TABLE IF NOT EXISTS active_interrupts" in _INTERRUPTS_DDL
        assert "thread_id" in _CONVERSATIONS_DDL
        assert "interrupt_id" in _INTERRUPTS_DDL
|
||||||
44
backend/tests/unit/test_graph.py
Normal file
44
backend/tests/unit/test_graph.py
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
"""Tests for app.graph module."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.graph import SUPERVISOR_PROMPT, build_agent_nodes, build_graph
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from app.registry import AgentRegistry
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
class TestBuildAgentNodes:
    """Tests for turning registry configs into agent graph nodes."""

    def test_creates_correct_number_of_nodes(self, sample_registry: AgentRegistry) -> None:
        mock_llm = MagicMock()
        nodes = build_agent_nodes(sample_registry, mock_llm)
        # One node per agent in the three-agent sample registry fixture.
        assert len(nodes) == 3

    def test_nodes_are_runnable(self, sample_registry: AgentRegistry) -> None:
        mock_llm = MagicMock()
        nodes = build_agent_nodes(sample_registry, mock_llm)
        for node in nodes:
            # LangChain runnables expose invoke and/or ainvoke.
            assert hasattr(node, "invoke") or hasattr(node, "ainvoke")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
class TestBuildGraph:
    """Tests for supervisor graph assembly and the routing prompt."""

    def test_graph_compiles_with_mock_checkpointer(self, sample_registry: AgentRegistry) -> None:
        mock_llm = MagicMock()
        # Graph construction binds tools / structured output onto the LLM,
        # so those methods must return an LLM-like object.
        mock_llm.bind_tools = MagicMock(return_value=mock_llm)
        mock_llm.with_structured_output = MagicMock(return_value=mock_llm)
        mock_checkpointer = AsyncMock()

        graph = build_graph(sample_registry, mock_llm, mock_checkpointer)
        assert graph is not None

    def test_supervisor_prompt_contains_routing_info(self) -> None:
        # The routing prompt must mention every routable agent by name.
        assert "order_lookup" in SUPERVISOR_PROMPT
        assert "order_actions" in SUPERVISOR_PROMPT
        assert "fallback" in SUPERVISOR_PROMPT
|
||||||
41
backend/tests/unit/test_llm.py
Normal file
41
backend/tests/unit/test_llm.py
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
"""Tests for app.llm module."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.config import Settings
|
||||||
|
from app.llm import create_llm
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
class TestCreateLlm:
    """Tests that create_llm picks the chat class matching the provider."""

    def test_anthropic_provider(self) -> None:
        settings = Settings(
            database_url="postgresql://x:x@localhost/db",
            llm_provider="anthropic",
            llm_model="claude-sonnet-4-6",
            anthropic_api_key="test-key",
        )
        llm = create_llm(settings)
        # Compare by class name to avoid importing provider classes here.
        assert type(llm).__name__ == "ChatAnthropic"

    def test_openai_provider(self) -> None:
        settings = Settings(
            database_url="postgresql://x:x@localhost/db",
            llm_provider="openai",
            llm_model="gpt-4o",
            openai_api_key="sk-test",
        )
        llm = create_llm(settings)
        assert type(llm).__name__ == "ChatOpenAI"

    def test_google_provider(self) -> None:
        settings = Settings(
            database_url="postgresql://x:x@localhost/db",
            llm_provider="google",
            llm_model="gemini-pro",
            google_api_key="test-key",
        )
        llm = create_llm(settings)
        assert type(llm).__name__ == "ChatGoogleGenerativeAI"
|
||||||
27
backend/tests/unit/test_main.py
Normal file
27
backend/tests/unit/test_main.py
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
"""Tests for app.main module."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.main import AGENTS_YAML, FRONTEND_DIST, app
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
class TestMainModule:
    """Smoke tests for FastAPI app metadata, paths, and route wiring."""

    def test_app_title(self) -> None:
        assert app.title == "Smart Support"

    def test_app_version(self) -> None:
        assert app.version == "0.1.0"

    def test_agents_yaml_path_exists(self) -> None:
        # Checks only the configured filename, not filesystem existence.
        assert AGENTS_YAML.name == "agents.yaml"

    def test_frontend_dist_path(self) -> None:
        assert "frontend" in str(FRONTEND_DIST)
        assert "dist" in str(FRONTEND_DIST)

    def test_websocket_route_registered(self) -> None:
        routes = [r.path for r in app.routes if hasattr(r, "path")]
        assert "/ws" in routes
|
||||||
147
backend/tests/unit/test_registry.py
Normal file
147
backend/tests/unit/test_registry.py
Normal file
@@ -0,0 +1,147 @@
|
|||||||
|
"""Tests for app.registry module."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
from app.registry import AgentConfig, AgentRegistry, PersonalityConfig
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
class TestPersonalityConfig:
    """Tests for personality defaults, overrides, and immutability."""

    def test_defaults(self) -> None:
        p = PersonalityConfig()
        assert p.tone == "professional and helpful"
        assert "Hello" in p.greeting
        assert "human agent" in p.escalation_message

    def test_custom_values(self) -> None:
        p = PersonalityConfig(tone="casual", greeting="Hey!", escalation_message="Hold on.")
        assert p.tone == "casual"

    def test_immutable(self) -> None:
        p = PersonalityConfig()
        # Frozen model: any attribute write must raise.
        with pytest.raises(Exception):
            p.tone = "new tone"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
class TestAgentConfig:
    """Validation rules enforced by the AgentConfig model."""

    def test_valid_config(self) -> None:
        """A fully-specified read agent round-trips its fields."""
        config = AgentConfig(
            name="test",
            description="A test agent",
            permission="read",
            tools=["tool1"],
        )
        assert config.name == "test"
        assert config.permission == "read"

    def test_empty_name_rejected(self) -> None:
        """A whitespace-only name is rejected with a clear message."""
        with pytest.raises(ValueError, match="must not be empty"):
            AgentConfig(name=" ", description="d", permission="read", tools=["t"])

    def test_empty_tools_rejected(self) -> None:
        """Every agent must declare at least one tool."""
        with pytest.raises(ValueError, match="at least one tool"):
            AgentConfig(name="x", description="d", permission="read", tools=[])

    def test_invalid_permission_rejected(self) -> None:
        """Only the known permission levels are accepted."""
        with pytest.raises(Exception):
            AgentConfig(name="x", description="d", permission="admin", tools=["t"])

    def test_name_stripped(self) -> None:
        """Surrounding whitespace is stripped from the agent name."""
        config = AgentConfig(name=" test ", description="d", permission="read", tools=["t"])
        assert config.name == "test"

    def test_immutable(self) -> None:
        """Configs are frozen after construction."""
        config = AgentConfig(name="test", description="d", permission="read", tools=["t"])
        with pytest.raises(Exception):
            config.name = "new"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
class TestAgentRegistry:
    """Loading and validation behaviour of the YAML-backed agent registry.

    Uses the ``sample_yaml_path`` / ``sample_registry`` fixtures (defined in
    conftest — presumably three agents: test_reader, test_writer,
    test_fallback; confirm against conftest.py).
    """

    def test_load_valid_yaml(self, sample_yaml_path: Path) -> None:
        """A valid YAML file loads all agents with their permissions."""
        registry = AgentRegistry.load(sample_yaml_path)
        assert len(registry) == 3
        assert registry.get_agent("test_reader").permission == "read"
        assert registry.get_agent("test_writer").permission == "write"

    def test_list_agents(self, sample_registry: AgentRegistry) -> None:
        """list_agents returns every configured agent exactly once."""
        agents = sample_registry.list_agents()
        assert len(agents) == 3
        names = {a.name for a in agents}
        assert names == {"test_reader", "test_writer", "test_fallback"}

    def test_get_agents_by_permission(self, sample_registry: AgentRegistry) -> None:
        """Agents can be filtered by their permission level."""
        readers = sample_registry.get_agents_by_permission("read")
        assert len(readers) == 2
        writers = sample_registry.get_agents_by_permission("write")
        assert len(writers) == 1

    def test_get_nonexistent_agent(self, sample_registry: AgentRegistry) -> None:
        """Looking up an unknown agent raises KeyError, not silent None."""
        with pytest.raises(KeyError, match="not found"):
            sample_registry.get_agent("nonexistent")

    def test_personality_defaults_applied(self, sample_registry: AgentRegistry) -> None:
        """Agents without a personality section get the default tone."""
        agent = sample_registry.get_agent("test_reader")
        assert agent.personality.tone == "professional and helpful"

    def test_personality_custom_applied(self, sample_registry: AgentRegistry) -> None:
        """A per-agent personality section overrides the defaults."""
        agent = sample_registry.get_agent("test_writer")
        assert agent.personality.tone == "formal"
        assert agent.personality.greeting == "Greetings."

    def test_file_not_found(self) -> None:
        """A missing registry file surfaces as FileNotFoundError."""
        with pytest.raises(FileNotFoundError):
            AgentRegistry.load("/nonexistent/path.yaml")

    def test_empty_file(self, tmp_path: Path) -> None:
        """An empty YAML file is rejected rather than yielding no agents."""
        path = tmp_path / "empty.yaml"
        path.write_text("", encoding="utf-8")
        with pytest.raises(ValueError, match="empty"):
            AgentRegistry.load(path)

    def test_invalid_yaml_syntax(self, tmp_path: Path) -> None:
        """Malformed YAML is wrapped into a ValueError, not a raw parser error."""
        path = tmp_path / "bad.yaml"
        path.write_text("agents:\n - name: [invalid\n", encoding="utf-8")
        with pytest.raises(ValueError, match="Invalid YAML"):
            AgentRegistry.load(path)

    def test_missing_agents_key(self, tmp_path: Path) -> None:
        """A document without a top-level 'agents' key is rejected."""
        path = tmp_path / "no_agents.yaml"
        path.write_text(yaml.dump({"items": []}), encoding="utf-8")
        with pytest.raises(ValueError, match="agents"):
            AgentRegistry.load(path)

    def test_duplicate_agent_names(self, tmp_path: Path) -> None:
        """Two agents sharing a name are rejected at load time."""
        data = {
            "agents": [
                {"name": "dup", "description": "a", "permission": "read", "tools": ["t1"]},
                {"name": "dup", "description": "b", "permission": "read", "tools": ["t2"]},
            ]
        }
        path = tmp_path / "dups.yaml"
        path.write_text(yaml.dump(data), encoding="utf-8")
        with pytest.raises(ValueError, match="Duplicate"):
            AgentRegistry.load(path)

    def test_missing_required_fields(self, tmp_path: Path) -> None:
        """An agent entry missing required fields fails Pydantic validation."""
        data = {"agents": [{"name": "x"}]}
        path = tmp_path / "missing.yaml"
        path.write_text(yaml.dump(data), encoding="utf-8")
        with pytest.raises(ValueError, match="Invalid agent config"):
            AgentRegistry.load(path)

    def test_empty_agents_list(self, tmp_path: Path) -> None:
        """An explicitly empty agents list is rejected (registry must be non-empty)."""
        data = {"agents": []}
        path = tmp_path / "empty_list.yaml"
        path.write_text(yaml.dump(data), encoding="utf-8")
        with pytest.raises(ValueError, match="non-empty"):
            AgentRegistry.load(path)
|
||||||
70
backend/tests/unit/test_session_manager.py
Normal file
70
backend/tests/unit/test_session_manager.py
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
"""Tests for app.session_manager module."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.session_manager import SessionManager
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
class TestSessionManager:
    """Session TTL behaviour: sliding window, interrupt extension, removal.

    Tests that need to "advance the clock" patch ``app.session_manager.time``
    so only the manager's view of time moves; assertions that depend on the
    advanced clock must stay inside the ``with patch(...)`` block.
    """

    def test_new_session_not_expired(self, session_manager: SessionManager) -> None:
        """A freshly touched session is live."""
        session_manager.touch("thread-1")
        assert not session_manager.is_expired("thread-1")

    def test_unknown_session_is_expired(self, session_manager: SessionManager) -> None:
        """A thread the manager has never seen counts as expired."""
        assert session_manager.is_expired("unknown")

    def test_session_expires_after_ttl(self) -> None:
        """A session is expired once the TTL has elapsed with no activity."""
        mgr = SessionManager(session_ttl_seconds=1)
        mgr.touch("t1")
        with patch("app.session_manager.time") as mock_time:
            # Pretend 2s passed (> 1s TTL) from the manager's perspective.
            mock_time.time.return_value = time.time() + 2
            assert mgr.is_expired("t1")

    def test_touch_resets_ttl(self) -> None:
        """touch() implements the sliding window: it refreshes last_activity."""
        mgr = SessionManager(session_ttl_seconds=5)
        mgr.touch("t1")
        initial_state = mgr.get_state("t1")
        # Touch again after some time
        with patch("app.session_manager.time") as mock_time:
            mock_time.time.return_value = time.time() + 3
            mgr.touch("t1")
        new_state = mgr.get_state("t1")
        assert new_state.last_activity > initial_state.last_activity

    def test_interrupt_suspends_expiration(self) -> None:
        """A pending interrupt keeps the session alive well past the TTL."""
        mgr = SessionManager(session_ttl_seconds=1)
        mgr.touch("t1")
        mgr.extend_for_interrupt("t1")
        with patch("app.session_manager.time") as mock_time:
            # 100s >> 1s TTL, yet the interrupt extension keeps it live.
            mock_time.time.return_value = time.time() + 100
            assert not mgr.is_expired("t1")

    def test_resolve_interrupt_resumes_ttl(self) -> None:
        """Resolving the interrupt clears the pending flag (normal TTL resumes)."""
        mgr = SessionManager(session_ttl_seconds=1)
        mgr.touch("t1")
        mgr.extend_for_interrupt("t1")
        mgr.resolve_interrupt("t1")
        state = mgr.get_state("t1")
        assert not state.has_pending_interrupt

    def test_extend_for_nonexistent_creates_session(self) -> None:
        """extend_for_interrupt on an unknown thread creates the session."""
        mgr = SessionManager()
        mgr.extend_for_interrupt("new-thread")
        state = mgr.get_state("new-thread")
        assert state is not None

    def test_remove_session(self, session_manager: SessionManager) -> None:
        """remove() drops all state for the thread."""
        session_manager.touch("t1")
        session_manager.remove("t1")
        assert session_manager.get_state("t1") is None

    def test_session_state_is_immutable(self, session_manager: SessionManager) -> None:
        """Session state objects are frozen snapshots."""
        state = session_manager.touch("t1")
        with pytest.raises(Exception):
            state.thread_id = "new"
|
||||||
233
backend/tests/unit/test_ws_handler.py
Normal file
233
backend/tests/unit/test_ws_handler.py
Normal file
@@ -0,0 +1,233 @@
|
|||||||
|
"""Tests for app.ws_handler module."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.callbacks import TokenUsageCallbackHandler
|
||||||
|
from app.session_manager import SessionManager
|
||||||
|
from app.ws_handler import (
|
||||||
|
_extract_interrupt,
|
||||||
|
_has_interrupt,
|
||||||
|
dispatch_message,
|
||||||
|
handle_interrupt_response,
|
||||||
|
handle_user_message,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_ws() -> AsyncMock:
|
||||||
|
ws = AsyncMock()
|
||||||
|
ws.send_json = AsyncMock()
|
||||||
|
return ws
|
||||||
|
|
||||||
|
|
||||||
|
def _make_graph() -> AsyncMock:
    """Build a mock LangGraph graph: astream yields nothing, and the
    checkpointed state reports no pending tasks (so no interrupt is seen).

    ``astream`` uses ``side_effect`` rather than ``return_value`` so every
    call gets a *fresh* async iterator; a shared ``return_value`` iterator
    would be exhausted after the first stream and silently break any test
    that streams more than once.
    """
    graph = AsyncMock()
    # astream is invoked synchronously (it returns an async iterator), so it
    # must be a plain MagicMock, not an AsyncMock.
    graph.astream = MagicMock(side_effect=lambda *args, **kwargs: AsyncIterHelper([]))
    state = MagicMock()
    state.tasks = ()  # empty tasks => _has_interrupt(state) is False
    graph.aget_state = AsyncMock(return_value=state)
    return graph
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncIterHelper:
    """Wrap a list so it can be consumed with ``async for`` in tests.

    Unlike the naive ``list.pop(0)`` approach, this snapshots the input and
    advances an index cursor, so it does not mutate (consume) the caller's
    list and each step is O(1) instead of O(n).
    """

    def __init__(self, items):
        # Copy defensively: iteration must not destroy the caller's fixture.
        self._items = list(items)
        self._index = 0

    def __aiter__(self):
        return self

    async def __anext__(self):
        if self._index >= len(self._items):
            raise StopAsyncIteration
        item = self._items[self._index]
        self._index += 1
        return item
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
class TestDispatchMessage:
    """Input validation performed by dispatch_message before any graph work.

    Every test asserts the handler replied with a single ``{"type": "error"}``
    frame via ws.send_json and never reached the graph.
    """

    @pytest.mark.asyncio
    async def test_invalid_json(self) -> None:
        """Non-JSON payloads are rejected with an 'Invalid JSON' error frame."""
        ws = _make_ws()
        graph = _make_graph()
        sm = SessionManager()
        cb = TokenUsageCallbackHandler()

        await dispatch_message(ws, graph, sm, cb, "not json")
        ws.send_json.assert_awaited_once()
        call_data = ws.send_json.call_args[0][0]
        assert call_data["type"] == "error"
        assert "Invalid JSON" in call_data["message"]

    @pytest.mark.asyncio
    async def test_missing_thread_id(self) -> None:
        """A message without thread_id is rejected, naming the missing field."""
        ws = _make_ws()
        graph = _make_graph()
        sm = SessionManager()
        cb = TokenUsageCallbackHandler()

        msg = json.dumps({"type": "message", "content": "hello"})
        await dispatch_message(ws, graph, sm, cb, msg)
        call_data = ws.send_json.call_args[0][0]
        assert call_data["type"] == "error"
        assert "thread_id" in call_data["message"]

    @pytest.mark.asyncio
    async def test_missing_content(self) -> None:
        """A message without content is rejected."""
        ws = _make_ws()
        graph = _make_graph()
        sm = SessionManager()
        cb = TokenUsageCallbackHandler()

        msg = json.dumps({"type": "message", "thread_id": "t1"})
        await dispatch_message(ws, graph, sm, cb, msg)
        call_data = ws.send_json.call_args[0][0]
        assert call_data["type"] == "error"

    @pytest.mark.asyncio
    async def test_unknown_message_type(self) -> None:
        """An unrecognized type is rejected WITHOUT echoing the raw input
        (anti-reflection: attacker-controlled strings must not appear in
        server responses)."""
        ws = _make_ws()
        graph = _make_graph()
        sm = SessionManager()
        cb = TokenUsageCallbackHandler()

        msg = json.dumps({"type": "unknown", "thread_id": "t1"})
        await dispatch_message(ws, graph, sm, cb, msg)
        call_data = ws.send_json.call_args[0][0]
        assert call_data["type"] == "error"
        assert "Unknown" in call_data["message"]
        # Verify raw input is NOT reflected back
        # (strip the legitimate phrase first so only a reflected value of the
        # client-supplied "unknown" type would remain and fail the check)
        assert "unknown" not in call_data["message"].lower().replace("unknown message type", "")

    @pytest.mark.asyncio
    async def test_message_too_large(self) -> None:
        """Raw frames over the size cap are rejected before JSON parsing."""
        ws = _make_ws()
        graph = _make_graph()
        sm = SessionManager()
        cb = TokenUsageCallbackHandler()

        large_msg = "x" * 40_000
        await dispatch_message(ws, graph, sm, cb, large_msg)
        call_data = ws.send_json.call_args[0][0]
        assert call_data["type"] == "error"
        assert "too large" in call_data["message"].lower()

    @pytest.mark.asyncio
    async def test_invalid_thread_id_format(self) -> None:
        """Path-traversal-style thread_ids fail the format check."""
        ws = _make_ws()
        graph = _make_graph()
        sm = SessionManager()
        cb = TokenUsageCallbackHandler()

        msg = json.dumps({"type": "message", "thread_id": "../../../etc", "content": "hi"})
        await dispatch_message(ws, graph, sm, cb, msg)
        call_data = ws.send_json.call_args[0][0]
        assert call_data["type"] == "error"
        assert "thread_id" in call_data["message"].lower()

    @pytest.mark.asyncio
    async def test_content_too_long(self) -> None:
        """Message content over the length cap is rejected."""
        ws = _make_ws()
        graph = _make_graph()
        sm = SessionManager()
        cb = TokenUsageCallbackHandler()

        msg = json.dumps({"type": "message", "thread_id": "t1", "content": "x" * 9000})
        await dispatch_message(ws, graph, sm, cb, msg)
        call_data = ws.send_json.call_args[0][0]
        assert call_data["type"] == "error"
        assert "too long" in call_data["message"].lower()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
class TestHandleUserMessage:
    """End-to-end paths through handle_user_message with a mocked graph."""

    @pytest.mark.asyncio
    async def test_expired_session(self) -> None:
        """A TTL of 0 makes the just-touched session expired; the handler
        replies with an 'expired' error instead of streaming."""
        ws = _make_ws()
        graph = _make_graph()
        sm = SessionManager(session_ttl_seconds=0)
        cb = TokenUsageCallbackHandler()

        await handle_user_message(ws, graph, sm, cb, "t1", "hello")
        call_data = ws.send_json.call_args[0][0]
        assert call_data["type"] == "error"
        assert "expired" in call_data["message"].lower()

    @pytest.mark.asyncio
    async def test_successful_message(self) -> None:
        """With a live session and an (empty) stream, the final frame sent
        to the client is message_complete."""
        ws = _make_ws()
        graph = _make_graph()
        sm = SessionManager()
        cb = TokenUsageCallbackHandler()

        sm.touch("t1")
        await handle_user_message(ws, graph, sm, cb, "t1", "hello")
        # Should end with message_complete
        last_call = ws.send_json.call_args[0][0]
        assert last_call["type"] == "message_complete"

    @pytest.mark.asyncio
    async def test_graph_error_sends_error_message(self) -> None:
        """If the graph raises, the client gets an error frame (no crash)."""
        ws = _make_ws()
        graph = AsyncMock()
        graph.astream = MagicMock(side_effect=RuntimeError("boom"))
        sm = SessionManager()
        cb = TokenUsageCallbackHandler()

        sm.touch("t1")
        await handle_user_message(ws, graph, sm, cb, "t1", "hello")
        call_data = ws.send_json.call_args[0][0]
        assert call_data["type"] == "error"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
class TestHandleInterruptResponse:
    """Resuming a graph after the user answers an interrupt prompt."""

    @pytest.mark.asyncio
    async def test_approved_interrupt(self) -> None:
        """Approving a pending interrupt resumes the (empty) stream and
        finishes with a message_complete frame."""
        ws = _make_ws()
        graph = _make_graph()
        sm = SessionManager()
        cb = TokenUsageCallbackHandler()

        sm.touch("t1")
        sm.extend_for_interrupt("t1")
        await handle_interrupt_response(ws, graph, sm, cb, "t1", True)
        last_call = ws.send_json.call_args[0][0]
        assert last_call["type"] == "message_complete"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
class TestInterruptHelpers:
    """Unit tests for the interrupt inspection helpers on graph state."""

    @staticmethod
    def _state_with_interrupts(*values):
        """Build a mock graph state; each value becomes one interrupt payload
        on a single task. With no values, the state has no tasks at all."""
        state = MagicMock()
        if not values:
            state.tasks = ()
            return state
        interrupts = []
        for value in values:
            interrupt = MagicMock()
            interrupt.value = value
            interrupts.append(interrupt)
        task = MagicMock()
        task.interrupts = tuple(interrupts)
        state.tasks = (task,)
        return state

    def test_has_interrupt_false_for_empty_tasks(self) -> None:
        """No tasks means no interrupt."""
        assert not _has_interrupt(self._state_with_interrupts())

    def test_has_interrupt_true(self) -> None:
        """A task carrying an interrupt is detected."""
        state = self._state_with_interrupts({"action": "cancel"})
        assert _has_interrupt(state)

    def test_extract_interrupt_data(self) -> None:
        """The interrupt payload dict is surfaced as-is."""
        state = self._state_with_interrupts({"action": "cancel_order", "order_id": "1042"})
        data = _extract_interrupt(state)
        assert data["action"] == "cancel_order"

    def test_extract_interrupt_empty(self) -> None:
        """With no interrupt present, a sentinel 'unknown' action is returned."""
        data = _extract_interrupt(self._state_with_interrupts())
        assert data["action"] == "unknown"
|
||||||
39
docker-compose.yml
Normal file
39
docker-compose.yml
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
# Local development stack: PostgreSQL 16 (checkpoint store) + backend API.
services:
  postgres:
    image: postgres:16
    environment:
      POSTGRES_DB: smart_support
      POSTGRES_USER: smart_support
      # NOTE(review): falls back to "dev_password" when POSTGRES_PASSWORD is
      # unset. Fine for local dev, but confirm this file is never used for a
      # shared/deployed environment (the commit claims "no hardcoded defaults").
      POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-dev_password}
    ports:
      - "5432:5432"
    volumes:
      # Persist database files across container restarts.
      - pgdata:/var/lib/postgresql/data
    healthcheck:
      # Gates backend startup (see depends_on) on Postgres accepting connections.
      test: ["CMD-SHELL", "pg_isready -U smart_support -d smart_support"]
      interval: 5s
      timeout: 3s
      retries: 5

  backend:
    build:
      context: ./backend
      dockerfile: Dockerfile
    ports:
      - "8000:8000"
    environment:
      # Must match the credentials configured on the postgres service above.
      DATABASE_URL: postgresql://smart_support:${POSTGRES_PASSWORD:-dev_password}@postgres:5432/smart_support
      LLM_PROVIDER: ${LLM_PROVIDER:-anthropic}
      LLM_MODEL: ${LLM_MODEL:-claude-sonnet-4-6}
      # Provider keys default to empty; the active provider's key is expected
      # to be supplied via the host environment or a .env file.
      ANTHROPIC_API_KEY: ${ANTHROPIC_API_KEY:-}
      OPENAI_API_KEY: ${OPENAI_API_KEY:-}
      GOOGLE_API_KEY: ${GOOGLE_API_KEY:-}
    depends_on:
      postgres:
        condition: service_healthy
    volumes:
      # Bind-mount source so uvicorn --reload picks up edits (development only).
      - ./backend:/app
    command: uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload

volumes:
  pgdata:
|
||||||
88
docs/phases/phase-1-dev-log.md
Normal file
88
docs/phases/phase-1-dev-log.md
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
# Phase 1: Core Framework -- Development Log
|
||||||
|
|
||||||
|
> Status: IN PROGRESS
|
||||||
|
> Phase branch: `phase-1/core-framework`
|
||||||
|
> Date started: 2026-03-30
|
||||||
|
> Date completed: --
|
||||||
|
> Related plan section: [Phase 1 in DEVELOPMENT-PLAN](../DEVELOPMENT-PLAN.md#phase-1-核心框架-第-1-3-周)
|
||||||
|
|
||||||
|
## What Was Built
|
||||||
|
|
||||||
|
- FastAPI WebSocket backend with `/ws` endpoint for real-time chat
|
||||||
|
- LangGraph Supervisor (via `langgraph-supervisor`) connecting 3 agents
|
||||||
|
- YAML-based Agent Registry with Pydantic validation
|
||||||
|
- 3 Mock Agents: order_lookup (read), order_actions (write + interrupt), fallback
|
||||||
|
- PostgresSaver checkpoint persistence via `langgraph-checkpoint-postgres`
|
||||||
|
- Session TTL management with 30-minute sliding window and interrupt extension
|
||||||
|
- LLM provider abstraction (Anthropic/OpenAI/Google) with prompt caching support
|
||||||
|
- Token usage tracking callback handler
|
||||||
|
- React Chat UI with streaming display, interrupt confirmation, and agent action viewer
|
||||||
|
- Docker Compose configuration (PostgreSQL 16 + backend)
|
||||||
|
|
||||||
|
## Code Structure
|
||||||
|
|
||||||
|
### New files
|
||||||
|
|
||||||
|
Backend (`backend/app/`):
|
||||||
|
- `config.py` -- pydantic-settings centralized configuration
|
||||||
|
- `db.py` -- Async PostgreSQL pool + AsyncPostgresSaver setup
|
||||||
|
- `llm.py` -- LLM provider factory (ChatAnthropic/ChatOpenAI/ChatGoogleGenerativeAI)
|
||||||
|
- `callbacks.py` -- Token usage + cost tracking callback handler
|
||||||
|
- `registry.py` -- YAML agent registry with validation + immutable config models
|
||||||
|
- `session_manager.py` -- Session TTL with sliding window + interrupt extension
|
||||||
|
- `graph.py` -- LangGraph Supervisor construction from registry
|
||||||
|
- `ws_handler.py` -- WebSocket message dispatch + streaming logic
|
||||||
|
- `main.py` -- FastAPI app entry with lifespan + WebSocket endpoint
|
||||||
|
- `agents/__init__.py` -- Tool name-to-function bridge
|
||||||
|
- `agents/order_lookup.py` -- Mock order status/tracking tools
|
||||||
|
- `agents/order_actions.py` -- Mock cancel_order with interrupt()
|
||||||
|
- `agents/fallback.py` -- Fallback response tool
|
||||||
|
|
||||||
|
Frontend (`frontend/src/`):
|
||||||
|
- `types.ts` -- WebSocket message protocol TypeScript types
|
||||||
|
- `hooks/useWebSocket.ts` -- WebSocket connection + reconnect + message dispatch
|
||||||
|
- `components/ChatMessages.tsx` -- Streaming message display
|
||||||
|
- `components/ChatInput.tsx` -- Message input
|
||||||
|
- `components/InterruptPrompt.tsx` -- Approve/reject interrupt UI
|
||||||
|
- `components/AgentAction.tsx` -- Tool call inline display
|
||||||
|
- `pages/ChatPage.tsx` -- Main chat page composing all components
|
||||||
|
|
||||||
|
Infrastructure:
|
||||||
|
- `backend/pyproject.toml` -- Dependencies + pytest + ruff config
|
||||||
|
- `backend/agents.yaml` -- Agent registry YAML config
|
||||||
|
- `backend/Dockerfile` -- Backend container
|
||||||
|
- `docker-compose.yml` -- PostgreSQL 16 + backend services
|
||||||
|
- `.gitignore` -- Updated for Python + Node artifacts
|
||||||
|
|
||||||
|
Tests (`backend/tests/unit/`):
|
||||||
|
- `test_config.py` -- Settings validation tests
|
||||||
|
- `test_registry.py` -- 22 tests for registry loading/validation
|
||||||
|
- `test_agents.py` -- 10 tests for tool functions + tool bridge
|
||||||
|
- `test_llm.py` -- 3 tests for LLM provider factory
|
||||||
|
- `test_callbacks.py` -- 9 tests for token usage tracking
|
||||||
|
- `test_session_manager.py` -- 9 tests for session TTL logic
|
||||||
|
- `test_graph.py` -- 4 tests for supervisor construction
|
||||||
|
- `test_db.py` -- 5 tests for database setup
|
||||||
|
- `test_ws_handler.py` -- 15 tests for WebSocket message handling
|
||||||
|
- `test_main.py` -- 5 tests for app configuration
|
||||||
|
|
||||||
|
## Test Coverage
|
||||||
|
|
||||||
|
- Unit test count: 82
|
||||||
|
- Integration test count: 0 (requires running PostgreSQL)
|
||||||
|
- E2E test count: 0 (manual verification in plan)
|
||||||
|
- Overall coverage: 88%
|
||||||
|
|
||||||
|
## Deviations from Plan
|
||||||
|
|
||||||
|
- Used `astream(stream_mode="messages")` instead of `astream_events()` per langgraph best practices
|
||||||
|
- Separated WebSocket handler logic into `ws_handler.py` for testability (not in original plan)
|
||||||
|
- Session manager uses in-memory storage instead of DB-backed (sufficient for Phase 1 single-instance)
|
||||||
|
|
||||||
|
## Known Issues / Tech Debt
|
||||||
|
|
||||||
|
- Session manager not DB-backed (loses state on restart) -- acceptable for Phase 1 single-instance
|
||||||
|
- WebSocket reconnect does not re-send pending interrupt state from server
|
||||||
|
- No rate limiting on WebSocket endpoint (Phase 2)
|
||||||
|
- No authentication (Phase 2)
|
||||||
|
- `main.py` coverage at 47% -- lifespan function not unit-testable without full DB
|
||||||
12
frontend/index.html
Normal file
12
frontend/index.html
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
<!DOCTYPE html>
<html lang="en">
  <head>
    <meta charset="UTF-8" />
    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    <title>Smart Support</title>
  </head>
  <body>
    <!-- React mount point; the app is bootstrapped from src/main.tsx by Vite. -->
    <div id="root"></div>
    <script type="module" src="/src/main.tsx"></script>
  </body>
</html>
|
||||||
1820
frontend/package-lock.json
generated
Normal file
1820
frontend/package-lock.json
generated
Normal file
File diff suppressed because it is too large
Load Diff
22
frontend/package.json
Normal file
22
frontend/package.json
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
{
|
||||||
|
"name": "smart-support-frontend",
|
||||||
|
"private": true,
|
||||||
|
"version": "0.1.0",
|
||||||
|
"type": "module",
|
||||||
|
"scripts": {
|
||||||
|
"dev": "vite",
|
||||||
|
"build": "tsc -b && vite build",
|
||||||
|
"preview": "vite preview"
|
||||||
|
},
|
||||||
|
"dependencies": {
|
||||||
|
"react": "^19.0.0",
|
||||||
|
"react-dom": "^19.0.0"
|
||||||
|
},
|
||||||
|
"devDependencies": {
|
||||||
|
"@types/react": "^19.0.0",
|
||||||
|
"@types/react-dom": "^19.0.0",
|
||||||
|
"@vitejs/plugin-react": "^4.3.0",
|
||||||
|
"typescript": "~5.7.0",
|
||||||
|
"vite": "^6.2.0"
|
||||||
|
}
|
||||||
|
}
|
||||||
5
frontend/src/App.tsx
Normal file
5
frontend/src/App.tsx
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
import { ChatPage } from "./pages/ChatPage";
|
||||||
|
|
||||||
|
/** Application root: Phase 1 renders the single chat page. */
export default function App() {
  return <ChatPage />;
}
|
||||||
77
frontend/src/components/AgentAction.tsx
Normal file
77
frontend/src/components/AgentAction.tsx
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
import { useState } from "react";
|
||||||
|
import type { ToolAction } from "../types";
|
||||||
|
|
||||||
|
interface Props {
|
||||||
|
action: ToolAction;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
 * Collapsible inline view of one agent tool call.
 * Collapsed: agent name + tool name. Expanded (header click): pretty-printed
 * args, and the result if it has arrived (`result === undefined` while the
 * call is still in flight).
 */
export function AgentAction({ action }: Props) {
  // Collapsed by default; toggled by clicking the header row.
  const [expanded, setExpanded] = useState(false);

  return (
    <div style={styles.container}>
      <div style={styles.header} onClick={() => setExpanded(!expanded)}>
        <span style={styles.icon}>{expanded ? "v" : ">"}</span>
        <span style={styles.agent}>{action.agent}</span>
        <span style={styles.tool}>{action.tool}</span>
      </div>
      {expanded && (
        <div style={styles.details}>
          <div style={styles.section}>
            <strong>Args:</strong>
            <pre style={styles.code}>{JSON.stringify(action.args, null, 2)}</pre>
          </div>
          {/* Result is rendered only once present (undefined = still pending). */}
          {action.result !== undefined && (
            <div style={styles.section}>
              <strong>Result:</strong>
              <pre style={styles.code}>{JSON.stringify(action.result, null, 2)}</pre>
            </div>
          )}
        </div>
      )}
    </div>
  );
}
|
||||||
|
|
||||||
|
// Inline styles for the tool-call viewer (muted grey chip with expandable body).
const styles: Record<string, React.CSSProperties> = {
  container: {
    margin: "4px 16px",
    padding: "6px 10px",
    background: "#f5f5f5",
    borderRadius: "6px",
    fontSize: "12px",
    color: "#666",
  },
  header: {
    display: "flex",
    alignItems: "center",
    gap: "6px",
    cursor: "pointer",
  },
  icon: {
    fontFamily: "monospace",
    width: "12px",
  },
  agent: {
    fontWeight: 600,
  },
  tool: {
    color: "#0066cc",
    fontFamily: "monospace",
  },
  details: {
    marginTop: "6px",
    paddingLeft: "18px",
  },
  section: {
    marginBottom: "4px",
  },
  code: {
    background: "#e8e8e8",
    padding: "4px 8px",
    borderRadius: "4px",
    fontSize: "11px",
    overflowX: "auto",
    margin: "4px 0",
  },
};
|
||||||
68
frontend/src/components/ChatInput.tsx
Normal file
68
frontend/src/components/ChatInput.tsx
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
import { useState } from "react";
|
||||||
|
|
||||||
|
interface Props {
|
||||||
|
onSend: (content: string) => void;
|
||||||
|
disabled: boolean;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function ChatInput({ onSend, disabled }: Props) {
|
||||||
|
const [value, setValue] = useState("");
|
||||||
|
|
||||||
|
const handleSubmit = () => {
|
||||||
|
const trimmed = value.trim();
|
||||||
|
if (!trimmed || disabled) return;
|
||||||
|
onSend(trimmed);
|
||||||
|
setValue("");
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleKeyDown = (e: React.KeyboardEvent) => {
|
||||||
|
if (e.key === "Enter" && !e.shiftKey) {
|
||||||
|
e.preventDefault();
|
||||||
|
handleSubmit();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div style={styles.container}>
|
||||||
|
<input
|
||||||
|
type="text"
|
||||||
|
value={value}
|
||||||
|
onChange={(e) => setValue(e.target.value)}
|
||||||
|
onKeyDown={handleKeyDown}
|
||||||
|
placeholder={disabled ? "Waiting for response..." : "Type a message..."}
|
||||||
|
disabled={disabled}
|
||||||
|
style={styles.input}
|
||||||
|
/>
|
||||||
|
<button onClick={handleSubmit} disabled={disabled || !value.trim()} style={styles.button}>
|
||||||
|
Send
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Inline styles for the composer bar pinned below the message list.
const styles: Record<string, React.CSSProperties> = {
  container: {
    display: "flex",
    gap: "8px",
    padding: "12px 16px",
    borderTop: "1px solid #e0e0e0",
    background: "white",
  },
  input: {
    flex: 1,
    padding: "10px 14px",
    border: "1px solid #ccc",
    borderRadius: "8px",
    fontSize: "14px",
    outline: "none",
  },
  button: {
    padding: "10px 20px",
    background: "#0066cc",
    color: "white",
    border: "none",
    borderRadius: "8px",
    fontSize: "14px",
    cursor: "pointer",
  },
};
|
||||||
82
frontend/src/components/ChatMessages.tsx
Normal file
82
frontend/src/components/ChatMessages.tsx
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
import { useEffect, useRef } from "react";
|
||||||
|
import type { ChatMessage } from "../types";
|
||||||
|
|
||||||
|
interface Props {
|
||||||
|
messages: ChatMessage[];
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
 * Scrollable transcript of the conversation.
 * User messages align right (blue); agent messages align left (grey) and are
 * attributed to `msg.agent` when set. A blinking bar marks a message that is
 * still streaming. Auto-scrolls to the bottom whenever `messages` changes.
 */
export function ChatMessages({ messages }: Props) {
  // Invisible sentinel at the end of the list used as the scroll target.
  const bottomRef = useRef<HTMLDivElement>(null);

  useEffect(() => {
    // Keep the latest message in view as tokens stream in.
    bottomRef.current?.scrollIntoView({ behavior: "smooth" });
  }, [messages]);

  return (
    <div style={styles.container}>
      {messages.map((msg) => (
        <div
          key={msg.id}
          style={{
            ...styles.message,
            ...(msg.sender === "user" ? styles.userMessage : styles.agentMessage),
          }}
        >
          <div style={styles.header}>
            <span style={styles.sender}>
              {msg.sender === "user" ? "You" : msg.agent || "Agent"}
            </span>
          </div>
          <div style={styles.content}>
            {msg.content}
            {msg.isStreaming && <span style={styles.cursor}>|</span>}
          </div>
        </div>
      ))}
      <div ref={bottomRef} />
    </div>
  );
}
|
||||||
|
|
||||||
|
// Inline styles for the message list. React resolves conflicting
// shorthand/longhand style keys by insertion order, so key order is
// load-bearing and left untouched.
const styles: Record<string, React.CSSProperties> = {
  // Scrollable column hosting the message bubbles.
  container: {
    flex: 1,
    overflowY: "auto",
    padding: "16px",
    display: "flex",
    flexDirection: "column",
    gap: "12px",
  },
  // Shared bubble geometry; merged with userMessage/agentMessage below.
  message: {
    maxWidth: "80%",
    padding: "10px 14px",
    borderRadius: "12px",
    lineHeight: 1.5,
  },
  // Right-aligned blue bubble for the local user.
  userMessage: {
    alignSelf: "flex-end",
    background: "#0066cc",
    color: "white",
  },
  // Left-aligned grey bubble for agent replies.
  agentMessage: {
    alignSelf: "flex-start",
    background: "#f0f0f0",
    color: "#333",
  },
  header: {
    marginBottom: "4px",
  },
  // Sender label ("You" or the agent name).
  sender: {
    fontSize: "12px",
    fontWeight: 600,
    opacity: 0.8,
  },
  // pre-wrap preserves newlines/whitespace in streamed content.
  content: {
    fontSize: "14px",
    whiteSpace: "pre-wrap",
  },
  // Streaming caret. NOTE(review): relies on a global "blink" @keyframes
  // rule that is not defined in this file — confirm it exists in a global
  // stylesheet, otherwise the caret will render but not animate.
  cursor: {
    animation: "blink 1s infinite",
    opacity: 0.7,
  },
};
|
||||||
81
frontend/src/components/InterruptPrompt.tsx
Normal file
81
frontend/src/components/InterruptPrompt.tsx
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
import type { InterruptMessage } from "../types";
|
||||||
|
|
||||||
|
interface Props {
  interrupt: InterruptMessage;
  onRespond: (approved: boolean) => void;
}

/**
 * Approval card shown when the backend pauses on a write operation.
 * Displays the pending action (plus optional "message"/"order_id" params)
 * and reports the user's Approve/Reject decision via `onRespond`.
 */
export function InterruptPrompt({ interrupt, onRespond }: Props) {
  const { action, params } = interrupt;

  // Only render detail rows for params that are actually present.
  const hasMessage = "message" in params && params.message != null;
  const hasOrderId = "order_id" in params && params.order_id != null;

  return (
    <div style={styles.container}>
      <div style={styles.header}>Action Requires Approval</div>
      <div style={styles.action}>
        <strong>Action:</strong> {action}
      </div>
      {hasMessage && <div style={styles.detail}>{String(params.message)}</div>}
      {hasOrderId && (
        <div style={styles.detail}>
          <strong>Order:</strong> {String(params.order_id)}
        </div>
      )}
      <div style={styles.buttons}>
        <button onClick={() => onRespond(true)} style={styles.approveBtn}>
          Approve
        </button>
        <button onClick={() => onRespond(false)} style={styles.rejectBtn}>
          Reject
        </button>
      </div>
    </div>
  );
}
|
||||||
|
|
||||||
|
// Inline styles for the approval card. Amber palette signals a pending,
// attention-required state. Key order preserved (React applies style keys
// in insertion order).
const styles: Record<string, React.CSSProperties> = {
  // Outer card: orange border on pale amber background.
  container: {
    margin: "12px 16px",
    padding: "16px",
    border: "2px solid #ff9800",
    borderRadius: "12px",
    background: "#fff8e1",
  },
  header: {
    fontWeight: 700,
    fontSize: "14px",
    color: "#e65100",
    marginBottom: "8px",
  },
  action: {
    fontSize: "14px",
    marginBottom: "4px",
  },
  // Secondary rows for optional interrupt params (message / order id).
  detail: {
    fontSize: "13px",
    color: "#555",
    marginBottom: "4px",
  },
  buttons: {
    display: "flex",
    gap: "8px",
    marginTop: "12px",
  },
  // Green = approve; identical shape to rejectBtn except color.
  approveBtn: {
    padding: "8px 20px",
    background: "#4caf50",
    color: "white",
    border: "none",
    borderRadius: "6px",
    cursor: "pointer",
    fontWeight: 600,
  },
  // Red = reject.
  rejectBtn: {
    padding: "8px 20px",
    background: "#f44336",
    color: "white",
    border: "none",
    borderRadius: "6px",
    cursor: "pointer",
    fontWeight: 600,
  },
};
|
||||||
104
frontend/src/hooks/useWebSocket.ts
Normal file
104
frontend/src/hooks/useWebSocket.ts
Normal file
@@ -0,0 +1,104 @@
|
|||||||
|
import { useCallback, useEffect, useRef, useState } from "react";
|
||||||
|
import type {
|
||||||
|
ClientMessage,
|
||||||
|
ConnectionStatus,
|
||||||
|
InterruptResponse,
|
||||||
|
SendMessage,
|
||||||
|
ServerMessage,
|
||||||
|
} from "../types";
|
||||||
|
|
||||||
|
// Derive ws:// vs wss:// from the page protocol so the socket works both in
// dev (http) and behind TLS; same-host "/ws" relies on the Vite proxy in dev.
const WS_URL = `${window.location.protocol === "https:" ? "wss:" : "ws:"}//${window.location.host}/ws`;
// Reconnect policy: exponential backoff starting at 1s, doubling each
// attempt (1s, 2s, 4s, 8s, 16s), giving up after MAX_RETRIES failures.
const MAX_RETRIES = 5;
const BASE_DELAY_MS = 1000;
|
||||||
|
function getOrCreateThreadId(): string {
|
||||||
|
const key = "smart_support_thread_id";
|
||||||
|
let id = sessionStorage.getItem(key);
|
||||||
|
if (!id) {
|
||||||
|
id = crypto.randomUUID();
|
||||||
|
sessionStorage.setItem(key, id);
|
||||||
|
}
|
||||||
|
return id;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function useWebSocket(onMessage: (msg: ServerMessage) => void) {
|
||||||
|
const [status, setStatus] = useState<ConnectionStatus>("disconnected");
|
||||||
|
const [threadId] = useState(getOrCreateThreadId);
|
||||||
|
const wsRef = useRef<WebSocket | null>(null);
|
||||||
|
const retriesRef = useRef(0);
|
||||||
|
const onMessageRef = useRef(onMessage);
|
||||||
|
onMessageRef.current = onMessage;
|
||||||
|
|
||||||
|
const connect = useCallback(() => {
|
||||||
|
if (wsRef.current?.readyState === WebSocket.OPEN) return;
|
||||||
|
|
||||||
|
setStatus("connecting");
|
||||||
|
const ws = new WebSocket(WS_URL);
|
||||||
|
|
||||||
|
ws.onopen = () => {
|
||||||
|
setStatus("connected");
|
||||||
|
retriesRef.current = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
ws.onmessage = (event) => {
|
||||||
|
try {
|
||||||
|
const data = JSON.parse(event.data) as ServerMessage;
|
||||||
|
onMessageRef.current(data);
|
||||||
|
} catch {
|
||||||
|
// ignore non-JSON messages
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
ws.onclose = () => {
|
||||||
|
setStatus("disconnected");
|
||||||
|
wsRef.current = null;
|
||||||
|
|
||||||
|
if (retriesRef.current < MAX_RETRIES) {
|
||||||
|
const delay = BASE_DELAY_MS * Math.pow(2, retriesRef.current);
|
||||||
|
retriesRef.current += 1;
|
||||||
|
setTimeout(connect, delay);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
ws.onerror = () => {
|
||||||
|
ws.close();
|
||||||
|
};
|
||||||
|
|
||||||
|
wsRef.current = ws;
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
connect();
|
||||||
|
return () => {
|
||||||
|
wsRef.current?.close();
|
||||||
|
};
|
||||||
|
}, [connect]);
|
||||||
|
|
||||||
|
const send = useCallback((msg: ClientMessage) => {
|
||||||
|
if (wsRef.current?.readyState === WebSocket.OPEN) {
|
||||||
|
wsRef.current.send(JSON.stringify(msg));
|
||||||
|
}
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
const sendMessage = useCallback(
|
||||||
|
(content: string) => {
|
||||||
|
const msg: SendMessage = { type: "message", thread_id: threadId, content };
|
||||||
|
send(msg);
|
||||||
|
},
|
||||||
|
[send, threadId]
|
||||||
|
);
|
||||||
|
|
||||||
|
const sendInterruptResponse = useCallback(
|
||||||
|
(approved: boolean) => {
|
||||||
|
const msg: InterruptResponse = {
|
||||||
|
type: "interrupt_response",
|
||||||
|
thread_id: threadId,
|
||||||
|
approved,
|
||||||
|
};
|
||||||
|
send(msg);
|
||||||
|
},
|
||||||
|
[send, threadId]
|
||||||
|
);
|
||||||
|
|
||||||
|
return { status, threadId, sendMessage, sendInterruptResponse };
|
||||||
|
}
|
||||||
9
frontend/src/main.tsx
Normal file
9
frontend/src/main.tsx
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
import { StrictMode } from "react";
|
||||||
|
import { createRoot } from "react-dom/client";
|
||||||
|
import App from "./App";
|
||||||
|
|
||||||
|
createRoot(document.getElementById("root")!).render(
|
||||||
|
<StrictMode>
|
||||||
|
<App />
|
||||||
|
</StrictMode>
|
||||||
|
);
|
||||||
200
frontend/src/pages/ChatPage.tsx
Normal file
200
frontend/src/pages/ChatPage.tsx
Normal file
@@ -0,0 +1,200 @@
|
|||||||
|
import { useCallback, useState } from "react";
|
||||||
|
import { AgentAction } from "../components/AgentAction";
|
||||||
|
import { ChatInput } from "../components/ChatInput";
|
||||||
|
import { ChatMessages } from "../components/ChatMessages";
|
||||||
|
import { InterruptPrompt } from "../components/InterruptPrompt";
|
||||||
|
import { useWebSocket } from "../hooks/useWebSocket";
|
||||||
|
import type {
|
||||||
|
ChatMessage,
|
||||||
|
ConnectionStatus,
|
||||||
|
InterruptMessage,
|
||||||
|
ServerMessage,
|
||||||
|
ToolAction,
|
||||||
|
} from "../types";
|
||||||
|
|
||||||
|
let msgCounter = 0;
|
||||||
|
function nextId(): string {
|
||||||
|
msgCounter += 1;
|
||||||
|
return `msg-${msgCounter}`;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function ChatPage() {
|
||||||
|
const [messages, setMessages] = useState<ChatMessage[]>([]);
|
||||||
|
const [toolActions, setToolActions] = useState<ToolAction[]>([]);
|
||||||
|
const [currentInterrupt, setCurrentInterrupt] = useState<InterruptMessage | null>(null);
|
||||||
|
const [isWaiting, setIsWaiting] = useState(false);
|
||||||
|
|
||||||
|
const handleServerMessage = useCallback((msg: ServerMessage) => {
|
||||||
|
switch (msg.type) {
|
||||||
|
case "token": {
|
||||||
|
setMessages((prev) => {
|
||||||
|
const last = prev[prev.length - 1];
|
||||||
|
if (last && last.sender === "agent" && last.isStreaming) {
|
||||||
|
return [
|
||||||
|
...prev.slice(0, -1),
|
||||||
|
{ ...last, content: last.content + msg.content },
|
||||||
|
];
|
||||||
|
}
|
||||||
|
return [
|
||||||
|
...prev,
|
||||||
|
{
|
||||||
|
id: nextId(),
|
||||||
|
sender: "agent",
|
||||||
|
agent: msg.agent,
|
||||||
|
content: msg.content,
|
||||||
|
timestamp: Date.now(),
|
||||||
|
isStreaming: true,
|
||||||
|
},
|
||||||
|
];
|
||||||
|
});
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case "tool_call": {
|
||||||
|
setToolActions((prev) => [
|
||||||
|
...prev,
|
||||||
|
{
|
||||||
|
id: nextId(),
|
||||||
|
agent: msg.agent,
|
||||||
|
tool: msg.tool,
|
||||||
|
args: msg.args,
|
||||||
|
timestamp: Date.now(),
|
||||||
|
},
|
||||||
|
]);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case "interrupt": {
|
||||||
|
setCurrentInterrupt(msg);
|
||||||
|
setIsWaiting(false);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case "message_complete": {
|
||||||
|
setMessages((prev) => {
|
||||||
|
const last = prev[prev.length - 1];
|
||||||
|
if (last && last.isStreaming) {
|
||||||
|
return [...prev.slice(0, -1), { ...last, isStreaming: false }];
|
||||||
|
}
|
||||||
|
return prev;
|
||||||
|
});
|
||||||
|
setIsWaiting(false);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case "error": {
|
||||||
|
setMessages((prev) => [
|
||||||
|
...prev,
|
||||||
|
{
|
||||||
|
id: nextId(),
|
||||||
|
sender: "agent",
|
||||||
|
agent: "System",
|
||||||
|
content: `Error: ${msg.message}`,
|
||||||
|
timestamp: Date.now(),
|
||||||
|
},
|
||||||
|
]);
|
||||||
|
setIsWaiting(false);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
const { status, sendMessage, sendInterruptResponse } =
|
||||||
|
useWebSocket(handleServerMessage);
|
||||||
|
|
||||||
|
const handleSend = useCallback(
|
||||||
|
(content: string) => {
|
||||||
|
setMessages((prev) => [
|
||||||
|
...prev,
|
||||||
|
{
|
||||||
|
id: nextId(),
|
||||||
|
sender: "user",
|
||||||
|
content,
|
||||||
|
timestamp: Date.now(),
|
||||||
|
},
|
||||||
|
]);
|
||||||
|
setIsWaiting(true);
|
||||||
|
sendMessage(content);
|
||||||
|
},
|
||||||
|
[sendMessage]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleInterruptResponse = useCallback(
|
||||||
|
(approved: boolean) => {
|
||||||
|
sendInterruptResponse(approved);
|
||||||
|
setCurrentInterrupt(null);
|
||||||
|
setIsWaiting(true);
|
||||||
|
},
|
||||||
|
[sendInterruptResponse]
|
||||||
|
);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div style={styles.page}>
|
||||||
|
<div style={styles.header}>
|
||||||
|
<h1 style={styles.title}>Smart Support</h1>
|
||||||
|
<StatusIndicator status={status} />
|
||||||
|
</div>
|
||||||
|
<ChatMessages messages={messages} />
|
||||||
|
{toolActions.length > 0 && (
|
||||||
|
<div style={styles.actionsBar}>
|
||||||
|
{toolActions.slice(-3).map((action) => (
|
||||||
|
<AgentAction key={action.id} action={action} />
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
{currentInterrupt && (
|
||||||
|
<InterruptPrompt
|
||||||
|
interrupt={currentInterrupt}
|
||||||
|
onRespond={handleInterruptResponse}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
<ChatInput onSend={handleSend} disabled={isWaiting || status !== "connected"} />
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Small colored dot plus text label reflecting the WebSocket state. */
function StatusIndicator({ status }: { status: ConnectionStatus }) {
  const dotColors: Record<ConnectionStatus, string> = {
    connected: "#4caf50",
    connecting: "#ff9800",
    disconnected: "#f44336",
  };

  const rowStyle: React.CSSProperties = {
    display: "flex",
    alignItems: "center",
    gap: "6px",
  };
  const dotStyle: React.CSSProperties = {
    width: "8px",
    height: "8px",
    borderRadius: "50%",
    background: dotColors[status],
  };
  const labelStyle: React.CSSProperties = { fontSize: "12px", color: "#666" };

  return (
    <div style={rowStyle}>
      <div style={dotStyle} />
      <span style={labelStyle}>{status}</span>
    </div>
  );
}
|
||||||
|
|
||||||
|
// Inline styles for the page shell. Key order preserved (React applies
// style keys in insertion order).
const styles: Record<string, React.CSSProperties> = {
  // Full-height centered column, capped at 800px, with a card-like shadow.
  page: {
    height: "100vh",
    display: "flex",
    flexDirection: "column",
    background: "white",
    maxWidth: "800px",
    margin: "0 auto",
    boxShadow: "0 0 20px rgba(0,0,0,0.1)",
  },
  // Title bar with the connection indicator pushed to the right edge.
  header: {
    display: "flex",
    justifyContent: "space-between",
    alignItems: "center",
    padding: "12px 16px",
    borderBottom: "1px solid #e0e0e0",
  },
  title: {
    fontSize: "18px",
    fontWeight: 700,
    margin: 0,
    color: "#333",
  },
  // Strip above the input showing recent tool activity.
  actionsBar: {
    borderTop: "1px solid #eee",
    paddingTop: "4px",
  },
};
|
||||||
86
frontend/src/types.ts
Normal file
86
frontend/src/types.ts
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
/**
 * WebSocket message protocol types matching ARCHITECTURE.md Section 6.1.
 * Server->client and client->server frames are discriminated unions on the
 * literal `type` field, so handlers can narrow with a `switch (msg.type)`.
 */

// -- Server -> Client messages --

/** One streamed LLM token, attributed to the agent that produced it. */
export interface TokenMessage {
  type: "token";
  agent: string;
  content: string;
}

/** Human-in-the-loop pause: a write action awaiting user approval. */
export interface InterruptMessage {
  type: "interrupt";
  thread_id: string;
  action: string;
  // Arbitrary action arguments; keys the UI inspects are "message" and
  // "order_id" (see InterruptPrompt).
  params: Record<string, unknown>;
}

/** Notification that an agent invoked a tool. */
export interface ToolCallMessage {
  type: "tool_call";
  agent: string;
  tool: string;
  args: Record<string, unknown>;
}

/** Result of a previously announced tool call. */
export interface ToolResultMessage {
  type: "tool_result";
  agent: string;
  tool: string;
  result: unknown;
}

/** Marks the end of a streamed assistant turn for a thread. */
export interface MessageCompleteMessage {
  type: "message_complete";
  thread_id: string;
}

/** Server-side failure surfaced to the UI. */
export interface ErrorMessage {
  type: "error";
  message: string;
}

/** Union of every frame the server can send; narrow on `type`. */
export type ServerMessage =
  | TokenMessage
  | InterruptMessage
  | ToolCallMessage
  | ToolResultMessage
  | MessageCompleteMessage
  | ErrorMessage;

// -- Client -> Server messages --

/** User chat message scoped to a conversation thread. */
export interface SendMessage {
  type: "message";
  thread_id: string;
  content: string;
}

/** The user's approve/reject decision for a pending interrupt. */
export interface InterruptResponse {
  type: "interrupt_response";
  thread_id: string;
  approved: boolean;
}

/** Union of every frame the client can send. */
export type ClientMessage = SendMessage | InterruptResponse;

// -- UI state --

/** A rendered chat bubble (user or agent). */
export interface ChatMessage {
  id: string;
  sender: "user" | "agent";
  // Display name of the producing agent; absent for user messages.
  agent?: string;
  content: string;
  // Epoch milliseconds (Date.now()).
  timestamp: number;
  // True while tokens are still being appended to this message.
  isStreaming?: boolean;
}

/** A tool invocation shown in the tool activity viewer. */
export interface ToolAction {
  id: string;
  agent: string;
  tool: string;
  args: Record<string, unknown>;
  // Optional payload from a matching tool_result frame.
  result?: unknown;
  // Epoch milliseconds (Date.now()).
  timestamp: number;
}

/** WebSocket lifecycle state shown by the status indicator. */
export type ConnectionStatus = "connecting" | "connected" | "disconnected";
|
||||||
21
frontend/tsconfig.json
Normal file
21
frontend/tsconfig.json
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
{
|
||||||
|
"compilerOptions": {
|
||||||
|
"target": "ES2020",
|
||||||
|
"useDefineForClassFields": true,
|
||||||
|
"lib": ["ES2020", "DOM", "DOM.Iterable"],
|
||||||
|
"module": "ESNext",
|
||||||
|
"skipLibCheck": true,
|
||||||
|
"moduleResolution": "bundler",
|
||||||
|
"allowImportingTsExtensions": true,
|
||||||
|
"isolatedModules": true,
|
||||||
|
"moduleDetection": "force",
|
||||||
|
"noEmit": true,
|
||||||
|
"jsx": "react-jsx",
|
||||||
|
"strict": true,
|
||||||
|
"noUnusedLocals": true,
|
||||||
|
"noUnusedParameters": true,
|
||||||
|
"noFallthroughCasesInSwitch": true,
|
||||||
|
"noUncheckedSideEffectImports": true
|
||||||
|
},
|
||||||
|
"include": ["src"]
|
||||||
|
}
|
||||||
15
frontend/vite.config.ts
Normal file
15
frontend/vite.config.ts
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
import react from "@vitejs/plugin-react";
import { defineConfig } from "vite";

// Vite dev-server configuration: React plugin plus a WebSocket-aware proxy
// so the frontend (port 5173) can reach the backend's /ws endpoint on
// localhost:8000 during development without cross-origin issues.
export default defineConfig({
  plugins: [react()],
  server: {
    port: 5173,
    proxy: {
      "/ws": {
        target: "ws://localhost:8000",
        // Proxy WebSocket upgrade requests, not just HTTP.
        ws: true,
      },
    },
  },
});
|
||||||
Reference in New Issue
Block a user