diff --git a/.gitignore b/.gitignore index ffa9be3..0a2882d 100644 --- a/.gitignore +++ b/.gitignore @@ -7,10 +7,15 @@ build/ .venv/ venv/ .env +.pytest_cache/ +.coverage +htmlcov/ +.ruff_cache/ # Node node_modules/ .next/ +frontend/dist/ # IDE .vscode/ diff --git a/CLAUDE.md b/CLAUDE.md index 64c1138..6298847 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -234,7 +234,7 @@ A checkpoint includes: | Phase | Branch | Focus | Status | |-------|--------|-------|--------| -| 1 | `phase-1/core-framework` | FastAPI + LangGraph + React chat loop + PostgresSaver | NOT STARTED | +| 1 | `phase-1/core-framework` | FastAPI + LangGraph + React chat loop + PostgresSaver | IN PROGRESS | | 2 | `phase-2/multi-agent-safety` | Supervisor routing + interrupts + templates | NOT STARTED | | 3 | `phase-3/openapi-discovery` | OpenAPI parsing + MCP generation + SSRF protection | NOT STARTED | | 4 | `phase-4/analytics-replay` | Replay API + analytics dashboard | NOT STARTED | diff --git a/backend/.env.example b/backend/.env.example new file mode 100644 index 0000000..79b360b --- /dev/null +++ b/backend/.env.example @@ -0,0 +1,19 @@ +# Database +DATABASE_URL=postgresql://smart_support:dev_password@localhost:5432/smart_support + +# LLM Provider: anthropic | openai | google +LLM_PROVIDER=anthropic +LLM_MODEL=claude-sonnet-4-6 + +# API Keys (set the one matching your LLM_PROVIDER) +ANTHROPIC_API_KEY= +OPENAI_API_KEY= +GOOGLE_API_KEY= + +# Session +SESSION_TTL_MINUTES=30 +INTERRUPT_TTL_MINUTES=30 + +# Server +WS_HOST=0.0.0.0 +WS_PORT=8000 diff --git a/backend/Dockerfile b/backend/Dockerfile new file mode 100644 index 0000000..61e5da6 --- /dev/null +++ b/backend/Dockerfile @@ -0,0 +1,12 @@ +FROM python:3.11-slim + +WORKDIR /app + +COPY pyproject.toml . +RUN pip install --no-cache-dir -e . + +COPY . . + +EXPOSE 8000 + +CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"] diff --git a/backend/agents.yaml b/backend/agents.yaml new file mode 100644 index 0000000..10e4721 --- /dev/null +++ b/backend/agents.yaml @@ -0,0 +1,31 @@ +agents: + - name: order_lookup + description: "Looks up order status and tracking information. Use for queries about order status, shipping, and delivery." + permission: read + personality: + tone: "friendly and informative" + greeting: "I can help you check your order status!" + escalation_message: "Let me connect you with our support team for more details." + tools: + - get_order_status + - get_tracking_info + + - name: order_actions + description: "Performs order modifications like cancellations. Use when the customer wants to cancel, modify, or change an order." + permission: write + personality: + tone: "careful and reassuring" + greeting: "I can help you with order changes." + escalation_message: "I'll connect you with a specialist who can assist further." + tools: + - cancel_order + + - name: fallback + description: "Handles general questions, unclear requests, and conversations that don't match other agents." + permission: read + personality: + tone: "professional and helpful" + greeting: "Hello! How can I help you today?" + escalation_message: "Let me connect you with a human agent who can better assist you." + tools: + - fallback_respond diff --git a/backend/app/__init__.py b/backend/app/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/app/agents/__init__.py b/backend/app/agents/__init__.py new file mode 100644 index 0000000..9153db2 --- /dev/null +++ b/backend/app/agents/__init__.py @@ -0,0 +1,30 @@ +"""Agent tools registry -- maps tool name strings to actual tool functions.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from langchain_core.tools import BaseTool + +from app.agents.fallback import fallback_respond +from app.agents.order_actions import cancel_order +from app.agents.order_lookup import get_order_status, get_tracking_info + +_TOOL_MAP: dict[str, BaseTool] = { + "get_order_status": get_order_status, + "get_tracking_info": get_tracking_info, + "cancel_order": cancel_order, + "fallback_respond": fallback_respond, +} + + +def get_tools_by_names(tool_names: list[str]) -> list[BaseTool]: + """Resolve tool name strings from YAML config to actual tool objects.""" + tools = [] + for name in tool_names: + if name not in _TOOL_MAP: + available = ", ".join(sorted(_TOOL_MAP.keys())) + raise ValueError(f"Unknown tool '{name}'. Available tools: {available}") + tools.append(_TOOL_MAP[name]) + return tools diff --git a/backend/app/agents/fallback.py b/backend/app/agents/fallback.py new file mode 100644 index 0000000..bd10271 --- /dev/null +++ b/backend/app/agents/fallback.py @@ -0,0 +1,18 @@ +"""Fallback agent tools -- handles unmatched intents.""" + +from __future__ import annotations + +from langchain_core.tools import tool + + +@tool +def fallback_respond(query: str) -> str: + """Provide a helpful response when the user's intent doesn't match a specific agent.""" + return ( + "I'm here to help with order inquiries and actions. " + "Here's what I can do:\n" + "- Check order status (e.g., 'What is the status of order 1042?')\n" + "- Get tracking information (e.g., 'Track order 1042')\n" + "- Cancel an order (e.g., 'Cancel order 1042')\n\n" + "Could you please rephrase your request?" + ) diff --git a/backend/app/agents/order_actions.py b/backend/app/agents/order_actions.py new file mode 100644 index 0000000..7df45ea --- /dev/null +++ b/backend/app/agents/order_actions.py @@ -0,0 +1,37 @@ +"""Order action tools -- write operations requiring human approval.""" + +from __future__ import annotations + +from langchain_core.tools import tool +from langgraph.types import interrupt + + +@tool +def cancel_order(order_id: str) -> dict: + """Cancel an order. Requires human approval before execution.""" + response = interrupt( + { + "action": "cancel_order", + "order_id": order_id, + "message": f"Please confirm: cancel order {order_id}?", + } + ) + + if isinstance(response, bool): + approved = response + elif isinstance(response, dict): + approved = response.get("approved", False) + else: + approved = bool(response) + + if approved: + return { + "status": "cancelled", + "order_id": order_id, + "message": f"Order {order_id} has been successfully cancelled.", + } + return { + "status": "kept", + "order_id": order_id, + "message": f"Order {order_id} cancellation was declined. The order remains active.", + } diff --git a/backend/app/agents/order_lookup.py b/backend/app/agents/order_lookup.py new file mode 100644 index 0000000..888139e --- /dev/null +++ b/backend/app/agents/order_lookup.py @@ -0,0 +1,68 @@ +"""Order lookup tools -- read-only operations.""" + +from __future__ import annotations + +from types import MappingProxyType + +from langchain_core.tools import tool + +MOCK_ORDERS: MappingProxyType[str, dict] = MappingProxyType( + { + "1042": { + "order_id": "1042", + "status": "shipped", + "items": ["Wireless Headphones", "USB-C Cable"], + "total": 89.99, + "placed_at": "2026-03-25", + }, + "1043": { + "order_id": "1043", + "status": "processing", + "items": ["Laptop Stand"], + "total": 49.99, + "placed_at": "2026-03-28", + }, + "1044": { + "order_id": "1044", + "status": "delivered", + "items": ["Mechanical Keyboard", "Mouse Pad"], + "total": 159.99, + "placed_at": "2026-03-20", + }, + } +) + +MOCK_TRACKING: MappingProxyType[str, dict] = MappingProxyType( + { + "1042": { + "order_id": "1042", + "carrier": "FedEx", + "tracking_number": "FX-9876543210", + "estimated_delivery": "2026-04-01", + "current_location": "Distribution Center, Chicago IL", + }, + "1044": { + "order_id": "1044", + "carrier": "UPS", + "tracking_number": "1Z-5678901234", + "estimated_delivery": "2026-03-22", + "current_location": "Delivered", + }, + } +) + + +@tool +def get_order_status(order_id: str) -> dict: + """Look up the current status of an order by order ID.""" + if order_id in MOCK_ORDERS: + return dict(MOCK_ORDERS[order_id]) + return {"error": f"Order {order_id} not found", "order_id": order_id} + + +@tool +def get_tracking_info(order_id: str) -> dict: + """Get shipping and tracking information for an order.""" + if order_id in MOCK_TRACKING: + return dict(MOCK_TRACKING[order_id]) + return {"error": f"No tracking information for order {order_id}", "order_id": order_id} diff --git a/backend/app/callbacks.py b/backend/app/callbacks.py new file mode 100644 index 0000000..3c3b2d7 --- /dev/null +++ b/backend/app/callbacks.py @@ -0,0 +1,60 @@ +"""Token usage tracking callback handler.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +from langchain_core.callbacks import BaseCallbackHandler + +if TYPE_CHECKING: + from langchain_core.outputs import LLMResult + +COST_PER_1K_TOKENS: dict[str, dict[str, float]] = { + "claude-sonnet-4-6": {"prompt": 0.003, "completion": 0.015}, + "claude-haiku-4-5-20251001": {"prompt": 0.0008, "completion": 0.004}, + "gpt-4o": {"prompt": 0.0025, "completion": 0.01}, + "gpt-4o-mini": {"prompt": 0.00015, "completion": 0.0006}, +} + +DEFAULT_COST = {"prompt": 0.003, "completion": 0.015} + + +@dataclass(frozen=True) +class TokenUsage: + prompt_tokens: int + completion_tokens: int + total_tokens: int + total_cost_usd: float + + +class TokenUsageCallbackHandler(BaseCallbackHandler): + """Accumulates token usage and cost across LLM invocations.""" + + def __init__(self, model_name: str = "") -> None: + self._model_name = model_name + self._prompt_tokens = 0 + self._completion_tokens = 0 + + def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: + if response.llm_output and "token_usage" in response.llm_output: + usage = response.llm_output["token_usage"] + self._prompt_tokens += usage.get("prompt_tokens", 0) + self._completion_tokens += usage.get("completion_tokens", 0) + + def get_usage(self) -> TokenUsage: + costs = COST_PER_1K_TOKENS.get(self._model_name, DEFAULT_COST) + cost = ( + self._prompt_tokens * costs["prompt"] / 1000 + + self._completion_tokens * costs["completion"] / 1000 + ) + return TokenUsage( + prompt_tokens=self._prompt_tokens, + completion_tokens=self._completion_tokens, + total_tokens=self._prompt_tokens + self._completion_tokens, + total_cost_usd=round(cost, 6), + ) + + def reset(self) -> None: + self._prompt_tokens = 0 + self._completion_tokens = 0 diff --git a/backend/app/config.py b/backend/app/config.py new file mode 100644 index 0000000..9198ab9 --- /dev/null +++ b/backend/app/config.py @@ -0,0 +1,46 @@ +"""Centralized application configuration via pydantic-settings.""" + +from __future__ import annotations + +from typing import Literal + +from pydantic import model_validator +from pydantic_settings import BaseSettings, SettingsConfigDict + + +class Settings(BaseSettings): + model_config = SettingsConfigDict( + env_file=".env", + env_file_encoding="utf-8", + extra="ignore", + ) + + database_url: str + + llm_provider: Literal["anthropic", "openai", "google"] = "anthropic" + llm_model: str = "claude-sonnet-4-6" + + session_ttl_minutes: int = 30 + interrupt_ttl_minutes: int = 30 + + ws_host: str = "0.0.0.0" + ws_port: int = 8000 + + anthropic_api_key: str = "" + openai_api_key: str = "" + google_api_key: str = "" + + @model_validator(mode="after") + def validate_provider_key(self) -> Settings: + key_map = { + "anthropic": self.anthropic_api_key, + "openai": self.openai_api_key, + "google": self.google_api_key, + } + key = key_map.get(self.llm_provider, "") + if not key: + raise ValueError( + f"API key for provider '{self.llm_provider}' is required. " + f"Set the corresponding environment variable." + ) + return self diff --git a/backend/app/db.py b/backend/app/db.py new file mode 100644 index 0000000..082e6c6 --- /dev/null +++ b/backend/app/db.py @@ -0,0 +1,61 @@ +"""Database connection pool and PostgresSaver checkpoint setup.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver +from psycopg.rows import dict_row +from psycopg_pool import AsyncConnectionPool + +if TYPE_CHECKING: + from app.config import Settings + +_CONVERSATIONS_DDL = """ +CREATE TABLE IF NOT EXISTS conversations ( + thread_id TEXT PRIMARY KEY, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + last_activity TIMESTAMPTZ NOT NULL DEFAULT NOW(), + total_tokens INTEGER NOT NULL DEFAULT 0, + total_cost_usd DOUBLE PRECISION NOT NULL DEFAULT 0.0, + status TEXT NOT NULL DEFAULT 'active' +); +""" + +_INTERRUPTS_DDL = """ +CREATE TABLE IF NOT EXISTS active_interrupts ( + interrupt_id TEXT PRIMARY KEY, + thread_id TEXT NOT NULL REFERENCES conversations(thread_id), + action TEXT NOT NULL, + params JSONB NOT NULL DEFAULT '{}', + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + resolved_at TIMESTAMPTZ, + resolution TEXT +); +""" + + +async def create_pool(settings: Settings) -> AsyncConnectionPool: + """Create an async connection pool with the required psycopg settings.""" + pool = AsyncConnectionPool( + conninfo=settings.database_url, + kwargs={"autocommit": True, "row_factory": dict_row}, + min_size=2, + max_size=10, + ) + await pool.open() + return pool + + +async def create_checkpointer(pool: AsyncConnectionPool) -> AsyncPostgresSaver: + """Create and initialize the LangGraph checkpointer.""" + checkpointer = AsyncPostgresSaver(conn=pool) + await checkpointer.setup() + return checkpointer + + +async def setup_app_tables(pool: AsyncConnectionPool) -> None: + """Create application-specific tables (conversations, active_interrupts).""" + async with pool.connection() as conn: + await conn.execute(_CONVERSATIONS_DDL) + await conn.execute(_INTERRUPTS_DDL) diff --git a/backend/app/graph.py b/backend/app/graph.py new file mode 100644 index 0000000..699b4d8 --- /dev/null +++ b/backend/app/graph.py @@ -0,0 +1,70 @@ +"""LangGraph Supervisor construction -- connects registry, agents, LLM, and persistence.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from langgraph.prebuilt import create_react_agent +from langgraph_supervisor import create_supervisor + +from app.agents import get_tools_by_names + +if TYPE_CHECKING: + from langchain_core.language_models import BaseChatModel + from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver + from langgraph.graph.state import CompiledStateGraph + + from app.registry import AgentRegistry + +SUPERVISOR_PROMPT = ( + "You are a customer support supervisor. " + "Route customer requests to the appropriate agent based on their description. " + "For order status and tracking queries, use the order_lookup agent. " + "For order modifications like cancellations, use the order_actions agent. " + "For anything else, use the fallback agent." +) + + +def build_agent_nodes( + registry: AgentRegistry, + llm: BaseChatModel, +) -> list: + """Create LangGraph react agent nodes from registry configurations.""" + agent_nodes = [] + for agent_config in registry.list_agents(): + tools = get_tools_by_names(agent_config.tools) + + system_prompt = ( + f"You are the {agent_config.name} agent. " + f"Personality: {agent_config.personality.tone}. " + f"{agent_config.description} " + f"Permission level: {agent_config.permission}." + ) + + agent_node = create_react_agent( + model=llm, + tools=tools, + name=agent_config.name, + prompt=system_prompt, + ) + agent_nodes.append(agent_node) + + return agent_nodes + + +def build_graph( + registry: AgentRegistry, + llm: BaseChatModel, + checkpointer: AsyncPostgresSaver, +) -> CompiledStateGraph: + """Build and compile the LangGraph supervisor graph.""" + agent_nodes = build_agent_nodes(registry, llm) + + workflow = create_supervisor( + agent_nodes, + model=llm, + prompt=SUPERVISOR_PROMPT, + output_mode="full_history", + ) + + return workflow.compile(checkpointer=checkpointer) diff --git a/backend/app/llm.py b/backend/app/llm.py new file mode 100644 index 0000000..22a26d1 --- /dev/null +++ b/backend/app/llm.py @@ -0,0 +1,42 @@ +"""LLM provider factory with prompt caching support.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from langchain_core.language_models import BaseChatModel + + from app.config import Settings + + +def create_llm(settings: Settings) -> BaseChatModel: + """Create an LLM instance based on the configured provider.""" + provider = settings.llm_provider + model = settings.llm_model + + if provider == "anthropic": + from langchain_anthropic import ChatAnthropic + + return ChatAnthropic( + model=model, + api_key=settings.anthropic_api_key, + ) + + if provider == "openai": + from langchain_openai import ChatOpenAI + + return ChatOpenAI( + model=model, + api_key=settings.openai_api_key, + ) + + if provider == "google": + from langchain_google_genai import ChatGoogleGenerativeAI + + return ChatGoogleGenerativeAI( + model=model, + google_api_key=settings.google_api_key, + ) + + raise ValueError(f"Unknown LLM provider: '{provider}'. Use 'anthropic', 'openai', or 'google'.") diff --git a/backend/app/main.py b/backend/app/main.py new file mode 100644 index 0000000..2f64e9e --- /dev/null +++ b/backend/app/main.py @@ -0,0 +1,83 @@ +"""FastAPI application entry point.""" + +from __future__ import annotations + +import logging +from contextlib import asynccontextmanager +from pathlib import Path +from typing import TYPE_CHECKING + +from fastapi import FastAPI, WebSocket, WebSocketDisconnect +from fastapi.staticfiles import StaticFiles + +from app.callbacks import TokenUsageCallbackHandler +from app.config import Settings +from app.db import create_checkpointer, create_pool, setup_app_tables +from app.graph import build_graph +from app.llm import create_llm +from app.registry import AgentRegistry +from app.session_manager import SessionManager +from app.ws_handler import dispatch_message + +if TYPE_CHECKING: + from collections.abc import AsyncGenerator + +logger = logging.getLogger(__name__) + +AGENTS_YAML = Path(__file__).parent.parent / "agents.yaml" +FRONTEND_DIST = Path(__file__).parent.parent.parent / "frontend" / "dist" + + +@asynccontextmanager +async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: + settings = Settings() + + pool = await create_pool(settings) + checkpointer = await create_checkpointer(pool) + await setup_app_tables(pool) + + registry = AgentRegistry.load(AGENTS_YAML) + llm = create_llm(settings) + graph = build_graph(registry, llm, checkpointer) + session_manager = SessionManager( + session_ttl_seconds=settings.session_ttl_minutes * 60, + ) + + app.state.graph = graph + app.state.session_manager = session_manager + app.state.settings = settings + app.state.pool = pool + + logger.info( + "Smart Support started: %d agents loaded, LLM=%s/%s", + len(registry), + settings.llm_provider, + settings.llm_model, + ) + + yield + + await pool.close() + + +app = FastAPI(title="Smart Support", version="0.1.0", lifespan=lifespan) + + +@app.websocket("/ws") +async def websocket_endpoint(ws: WebSocket) -> None: + await ws.accept() + graph = app.state.graph + session_manager = app.state.session_manager + settings = app.state.settings + callback_handler = TokenUsageCallbackHandler(model_name=settings.llm_model) + + try: + while True: + raw_data = await ws.receive_text() + await dispatch_message(ws, graph, session_manager, callback_handler, raw_data) + except WebSocketDisconnect: + logger.info("WebSocket client disconnected") + + +if FRONTEND_DIST.is_dir(): + app.mount("/", StaticFiles(directory=str(FRONTEND_DIST), html=True), name="frontend") diff --git a/backend/app/registry.py b/backend/app/registry.py new file mode 100644 index 0000000..0b1de0d --- /dev/null +++ b/backend/app/registry.py @@ -0,0 +1,104 @@ +"""YAML Agent registry loader with validation.""" + +from __future__ import annotations + +from pathlib import Path +from typing import Literal + +import yaml +from pydantic import BaseModel, field_validator + + +class PersonalityConfig(BaseModel, frozen=True): + tone: str = "professional and helpful" + greeting: str = "Hello! How can I help you today?" + escalation_message: str = "Let me connect you with a human agent." + + +class AgentConfig(BaseModel, frozen=True): + name: str + description: str + permission: Literal["read", "write"] + personality: PersonalityConfig = PersonalityConfig() + tools: list[str] + + @field_validator("name") + @classmethod + def name_not_empty(cls, v: str) -> str: + if not v.strip(): + raise ValueError("Agent name must not be empty") + return v.strip() + + @field_validator("tools") + @classmethod + def tools_not_empty(cls, v: list[str]) -> list[str]: + if not v: + raise ValueError("Agent must have at least one tool") + return v + + +class AgentRegistry: + """Immutable registry of agent configurations loaded from YAML.""" + + def __init__(self, agents: tuple[AgentConfig, ...]) -> None: + self._agents = {agent.name: agent for agent in agents} + + @classmethod + def load(cls, yaml_path: str | Path) -> AgentRegistry: + """Load and validate agent configurations from a YAML file.""" + path = Path(yaml_path) + if not path.exists(): + raise FileNotFoundError(f"Agent config file not found: {path}") + + raw_text = path.read_text(encoding="utf-8") + if not raw_text.strip(): + raise ValueError(f"Agent config file is empty: {path}") + + try: + data = yaml.safe_load(raw_text) + except yaml.YAMLError as exc: + msg = f"Invalid YAML in {path}" + if hasattr(exc, "problem_mark") and exc.problem_mark is not None: + mark = exc.problem_mark + msg += f" at line {mark.line + 1}, column {mark.column + 1}" + raise ValueError(msg) from exc + + if not isinstance(data, dict) or "agents" not in data: + raise ValueError(f"Agent config must have a top-level 'agents' key in {path}") + + raw_agents = data["agents"] + if not isinstance(raw_agents, list) or not raw_agents: + raise ValueError(f"'agents' must be a non-empty list in {path}") + + agents: list[AgentConfig] = [] + seen_names: set[str] = set() + + for i, raw in enumerate(raw_agents): + if not isinstance(raw, dict): + raise ValueError(f"Agent at index {i} must be a mapping in {path}") + try: + agent = AgentConfig(**raw) + except Exception as exc: + raise ValueError(f"Invalid agent config at index {i} in {path}: {exc}") from exc + + if agent.name in seen_names: + raise ValueError(f"Duplicate agent name '{agent.name}' in {path}") + seen_names.add(agent.name) + agents.append(agent) + + return cls(agents=tuple(agents)) + + def get_agent(self, name: str) -> AgentConfig: + if name not in self._agents: + available = ", ".join(sorted(self._agents.keys())) + raise KeyError(f"Agent '{name}' not found. Available: {available}") + return self._agents[name] + + def list_agents(self) -> tuple[AgentConfig, ...]: + return tuple(self._agents.values()) + + def get_agents_by_permission(self, permission: str) -> tuple[AgentConfig, ...]: + return tuple(a for a in self._agents.values() if a.permission == permission) + + def __len__(self) -> int: + return len(self._agents) diff --git a/backend/app/session_manager.py b/backend/app/session_manager.py new file mode 100644 index 0000000..6b02dfb --- /dev/null +++ b/backend/app/session_manager.py @@ -0,0 +1,78 @@ +"""Session TTL management with sliding window and interrupt extension.""" + +from __future__ import annotations + +import time +from dataclasses import dataclass + + +@dataclass(frozen=True) +class SessionState: + thread_id: str + last_activity: float + has_pending_interrupt: bool + + +class SessionManager: + """Manages session TTL with sliding window and interrupt extensions. + + - Each message resets the TTL (sliding window). + - A pending interrupt suspends expiration until resolved. + """ + + def __init__(self, session_ttl_seconds: int = 1800) -> None: + self._session_ttl = session_ttl_seconds + self._sessions: dict[str, SessionState] = {} + + def touch(self, thread_id: str) -> SessionState: + """Update last activity for a session (resets sliding window).""" + existing = self._sessions.get(thread_id) + new_state = SessionState( + thread_id=thread_id, + last_activity=time.time(), + has_pending_interrupt=existing.has_pending_interrupt if existing else False, + ) + self._sessions = {**self._sessions, thread_id: new_state} + return new_state + + def is_expired(self, thread_id: str) -> bool: + """Check if a session has expired.""" + state = self._sessions.get(thread_id) + if state is None: + return True + + if state.has_pending_interrupt: + return False + + elapsed = time.time() - state.last_activity + return elapsed > self._session_ttl + + def extend_for_interrupt(self, thread_id: str) -> SessionState: + """Mark session as having a pending interrupt (suspends TTL).""" + existing = self._sessions.get(thread_id) + if existing is None: + return self.touch(thread_id) + + new_state = SessionState( + thread_id=thread_id, + last_activity=existing.last_activity, + has_pending_interrupt=True, + ) + self._sessions = {**self._sessions, thread_id: new_state} + return new_state + + def resolve_interrupt(self, thread_id: str) -> SessionState: + """Remove interrupt extension and reset activity timer.""" + new_state = SessionState( + thread_id=thread_id, + last_activity=time.time(), + has_pending_interrupt=False, + ) + self._sessions = {**self._sessions, thread_id: new_state} + return new_state + + def get_state(self, thread_id: str) -> SessionState | None: + return self._sessions.get(thread_id) + + def remove(self, thread_id: str) -> None: + self._sessions = {k: v for k, v in self._sessions.items() if k != thread_id} diff --git a/backend/app/ws_handler.py b/backend/app/ws_handler.py new file mode 100644 index 0000000..c501921 --- /dev/null +++ b/backend/app/ws_handler.py @@ -0,0 +1,204 @@ +"""WebSocket message handling logic -- extracted from main for testability.""" + +from __future__ import annotations + +import json +import logging +import re +from typing import TYPE_CHECKING, Any + +from langchain_core.messages import HumanMessage +from langgraph.types import Command + +if TYPE_CHECKING: + from fastapi import WebSocket + from langgraph.graph.state import CompiledStateGraph + + from app.callbacks import TokenUsageCallbackHandler + from app.session_manager import SessionManager + +logger = logging.getLogger(__name__) + +MAX_MESSAGE_SIZE = 32_768 # 32 KB +MAX_CONTENT_LENGTH = 8_000 # characters +THREAD_ID_PATTERN = re.compile(r"^[a-zA-Z0-9\-_]{1,128}$") + + +async def handle_user_message( + ws: WebSocket, + graph: CompiledStateGraph, + session_manager: SessionManager, + callback_handler: TokenUsageCallbackHandler, + thread_id: str, + content: str, +) -> None: + """Process a user message through the graph and stream results back.""" + if session_manager.is_expired(thread_id): + msg = "Session expired. Please start a new conversation." + await _send_json(ws, {"type": "error", "message": msg}) + return + + session_manager.touch(thread_id) + config = {"configurable": {"thread_id": thread_id}, "callbacks": [callback_handler]} + input_msg = {"messages": [HumanMessage(content=content)]} + + try: + async for chunk in graph.astream(input_msg, config=config, stream_mode="messages"): + msg_chunk, metadata = chunk + node = metadata.get("langgraph_node", "") + + if hasattr(msg_chunk, "tool_calls") and msg_chunk.tool_calls: + for tc in msg_chunk.tool_calls: + await _send_json( + ws, + { + "type": "tool_call", + "agent": node, + "tool": tc.get("name", ""), + "args": tc.get("args", {}), + }, + ) + elif hasattr(msg_chunk, "content") and msg_chunk.content: + await _send_json( + ws, + { + "type": "token", + "agent": node, + "content": msg_chunk.content, + }, + ) + + state = await graph.aget_state(config) + if _has_interrupt(state): + interrupt_data = _extract_interrupt(state) + session_manager.extend_for_interrupt(thread_id) + await _send_json( + ws, + { + "type": "interrupt", + "thread_id": thread_id, + **interrupt_data, + }, + ) + else: + await _send_json(ws, {"type": "message_complete", "thread_id": thread_id}) + + except Exception: + logger.exception("Error processing message for thread %s", thread_id) + err = "An error occurred processing your message." + await _send_json(ws, {"type": "error", "message": err}) + + +async def handle_interrupt_response( + ws: WebSocket, + graph: CompiledStateGraph, + session_manager: SessionManager, + callback_handler: TokenUsageCallbackHandler, + thread_id: str, + approved: bool, +) -> None: + """Resume graph execution after interrupt approval/rejection.""" + session_manager.resolve_interrupt(thread_id) + session_manager.touch(thread_id) + + config = {"configurable": {"thread_id": thread_id}, "callbacks": [callback_handler]} + + try: + async for chunk in graph.astream( + Command(resume=approved), + config=config, + stream_mode="messages", + ): + msg_chunk, metadata = chunk + node = metadata.get("langgraph_node", "") + + if hasattr(msg_chunk, "content") and msg_chunk.content: + await _send_json( + ws, + { + "type": "token", + "agent": node, + "content": msg_chunk.content, + }, + ) + + await _send_json(ws, {"type": "message_complete", "thread_id": thread_id}) + + except Exception: + logger.exception("Error resuming interrupt for thread %s", thread_id) + err = "An error occurred processing your response." + await _send_json(ws, {"type": "error", "message": err}) + + +async def dispatch_message( + ws: WebSocket, + graph: CompiledStateGraph, + session_manager: SessionManager, + callback_handler: TokenUsageCallbackHandler, + raw_data: str, +) -> None: + """Parse and route an incoming WebSocket message.""" + if len(raw_data) > MAX_MESSAGE_SIZE: + await _send_json(ws, {"type": "error", "message": "Message too large"}) + return + + try: + data = json.loads(raw_data) + except json.JSONDecodeError: + await _send_json(ws, {"type": "error", "message": "Invalid JSON"}) + return + + msg_type = data.get("type") + thread_id = data.get("thread_id", "") + + if not thread_id: + await _send_json(ws, {"type": "error", "message": "Missing thread_id"}) + return + + if not THREAD_ID_PATTERN.match(thread_id): + await _send_json(ws, {"type": "error", "message": "Invalid thread_id format"}) + return + + if msg_type == "message": + content = data.get("content", "") + if not content: + await _send_json(ws, {"type": "error", "message": "Missing message content"}) + return + if len(content) > MAX_CONTENT_LENGTH: + await _send_json(ws, {"type": "error", "message": "Message content too long"}) + return + await handle_user_message(ws, graph, session_manager, callback_handler, thread_id, content) + + elif msg_type == "interrupt_response": + approved = data.get("approved", False) + await handle_interrupt_response( + ws, graph, session_manager, callback_handler, thread_id, approved + ) + + else: + await _send_json(ws, {"type": "error", "message": "Unknown message type"}) + + +def _has_interrupt(state: Any) -> bool: + """Check if the graph state has a pending interrupt.""" + tasks = getattr(state, "tasks", ()) + return any(getattr(t, "interrupts", ()) for t in tasks) + + +def _extract_interrupt(state: Any) -> dict: + """Extract interrupt data from graph state.""" + for task in getattr(state, "tasks", ()): + for intr in getattr(task, "interrupts", ()): + value = intr.value if hasattr(intr, "value") else {} + if not isinstance(value, dict): + value = {} + return { + "action": value.get("action", "unknown"), + "params": value, + } + return {"action": "unknown", "params": {}} + + +async def _send_json(ws: WebSocket, data: dict) -> None: + """Send a JSON message through the WebSocket.""" + await ws.send_json(data) diff --git a/backend/pyproject.toml b/backend/pyproject.toml new file mode 100644 index 0000000..bc89701 --- /dev/null +++ b/backend/pyproject.toml @@ -0,0 +1,64 @@ +[project] +name = "smart-support" +version = "0.1.0" +description = "AI customer support action-layer framework" +requires-python = ">=3.11" +dependencies = [ + "fastapi>=0.115,<1.0", + "uvicorn[standard]>=0.34,<1.0", + "langgraph>=0.4,<1.0", + "langgraph-supervisor>=0.0.12,<1.0", + "langgraph-checkpoint-postgres>=3.0,<4.0", + "langchain-core>=0.3,<1.0", + "langchain-anthropic>=0.3,<2.0", + "langchain-openai>=0.3,<1.0", + "langchain-google-genai>=2.1,<3.0", + "psycopg[binary,pool]>=3.2,<4.0", + "pydantic>=2.10,<3.0", + "pydantic-settings>=2.7,<3.0", + "pyyaml>=6.0,<7.0", + "python-dotenv>=1.0,<2.0", +] + +[project.optional-dependencies] +dev = [ + "pytest>=8.3,<9.0", + "pytest-asyncio>=0.25,<1.0", + "pytest-cov>=6.0,<7.0", + "httpx>=0.28,<1.0", + "ruff>=0.9,<1.0", +] + +[build-system] +requires = ["setuptools>=75.0"] +build-backend = "setuptools.build_meta" + +[tool.pytest.ini_options] +asyncio_mode = "auto" +testpaths = ["tests"] +markers = [ + "unit: per-module isolated tests", + "integration: cross-module with real PostgreSQL", + "e2e: full-stack user flow tests", +] +addopts = "--strict-markers" + +[tool.coverage.run] +source = ["app"] + +[tool.coverage.report] +fail_under = 80 +show_missing = true + +[tool.ruff] +target-version = "py311" +line-length = 100 + +[tool.ruff.lint] +select = ["E", "F", "I", "N", "W", "UP", "B", "A", "SIM", "TCH"] + +[tool.ruff.lint.per-file-ignores] +"tests/**" = ["N806", "B017"] + +[tool.ruff.format] +quote-style = "double" diff --git a/backend/tests/__init__.py b/backend/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py new file mode 100644 index 0000000..a588adf --- /dev/null +++ b/backend/tests/conftest.py @@ -0,0 +1,69 @@ +"""Shared test fixtures and marker registration.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest +import yaml + +from app.config import Settings +from app.registry import AgentRegistry +from app.session_manager import SessionManager + +if TYPE_CHECKING: + from pathlib import Path + + +@pytest.fixture +def test_settings() -> Settings: + return Settings( + database_url="postgresql://test:test@localhost:5432/test_db", + llm_provider="anthropic", + llm_model="claude-sonnet-4-6", + anthropic_api_key="test-key", + ) + + +@pytest.fixture +def sample_yaml_path(tmp_path: Path) -> Path: + data = { + "agents": [ + { + "name": "test_reader", + "description": "A test read agent", + "permission": "read", + "tools": ["get_order_status"], + }, + { + "name": "test_writer", + "description": "A test write agent", + "permission": "write", + "personality": { + "tone": "formal", + "greeting": "Greetings.", + "escalation_message": "Escalating now.", + }, + "tools": ["cancel_order"], + }, + { + "name": "test_fallback", + "description": "A fallback agent", + "permission": "read", + "tools": ["fallback_respond"], + }, + ] + } + path = tmp_path / "test_agents.yaml" + path.write_text(yaml.dump(data), encoding="utf-8") + return path + + +@pytest.fixture +def sample_registry(sample_yaml_path: Path) -> AgentRegistry: + return AgentRegistry.load(sample_yaml_path) + + +@pytest.fixture +def session_manager() -> SessionManager: + return SessionManager(session_ttl_seconds=60) diff --git a/backend/tests/e2e/__init__.py b/backend/tests/e2e/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/tests/integration/__init__.py b/backend/tests/integration/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/tests/unit/__init__.py b/backend/tests/unit/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/tests/unit/test_agents.py b/backend/tests/unit/test_agents.py new file mode 100644 index 0000000..17ace97 --- /dev/null +++ b/backend/tests/unit/test_agents.py @@ -0,0 +1,82 @@ +"""Tests for agent tools.""" + +from __future__ import annotations + +from unittest.mock import patch + +import pytest + +from app.agents.fallback import fallback_respond +from app.agents.order_lookup import get_order_status, get_tracking_info + + +@pytest.mark.unit +class TestOrderLookup: + def test_get_order_status_existing(self) -> None: + result = get_order_status.invoke({"order_id": "1042"}) + assert result["order_id"] == "1042" + assert result["status"] == "shipped" + + def test_get_order_status_not_found(self) -> None: + result = get_order_status.invoke({"order_id": "9999"}) + assert "error" in result + assert "9999" in result["error"] + + def test_get_tracking_info_existing(self) -> None: + result = get_tracking_info.invoke({"order_id": "1042"}) + assert result["carrier"] == "FedEx" + assert result["tracking_number"] == "FX-9876543210" + + def test_get_tracking_info_not_found(self) -> None: + result = get_tracking_info.invoke({"order_id": "1043"}) + assert "error" in result + + def test_all_mock_orders_have_required_fields(self) -> None: + from app.agents.order_lookup import MOCK_ORDERS + + for oid, order in MOCK_ORDERS.items(): + assert "order_id" in order + assert "status" in order + assert order["order_id"] == oid + + +@pytest.mark.unit +class TestFallback: + def test_fallback_respond_returns_help(self) -> None: + result = fallback_respond.invoke({"query": "random question"}) + assert "order" in result.lower() + assert "help" in result.lower() or "can do" in result.lower() + + +@pytest.mark.unit +class TestOrderActions: + def test_cancel_order_approved(self) -> None: + with patch("app.agents.order_actions.interrupt", return_value=True): + from app.agents.order_actions import cancel_order + + result = cancel_order.invoke({"order_id": "1042"}) + assert result["status"] == "cancelled" + assert "1042" in result["message"] + + def test_cancel_order_rejected(self) -> None: + with patch("app.agents.order_actions.interrupt", return_value=False): + from app.agents.order_actions import cancel_order + + result = cancel_order.invoke({"order_id": "1042"}) + assert result["status"] == "kept" + assert "declined" in result["message"] + + +@pytest.mark.unit +class TestToolBridge: + def test_get_tools_by_names(self) -> None: + from app.agents import get_tools_by_names + + tools = get_tools_by_names(["get_order_status", "cancel_order"]) + assert len(tools) == 2 + + def test_unknown_tool_raises(self) -> None: + from app.agents import get_tools_by_names + + with pytest.raises(ValueError, match="Unknown tool"): + get_tools_by_names(["nonexistent_tool"]) diff --git a/backend/tests/unit/test_callbacks.py b/backend/tests/unit/test_callbacks.py new file mode 100644 index 0000000..ca7d4a3 --- /dev/null +++ b/backend/tests/unit/test_callbacks.py @@ -0,0 +1,102 @@ +"""Tests for app.callbacks module.""" + +from __future__ import annotations + +import pytest + +from app.callbacks import TokenUsageCallbackHandler + + +@pytest.mark.unit +class TestTokenUsageCallbackHandler: + def test_initial_state(self) -> None: + handler = TokenUsageCallbackHandler(model_name="claude-sonnet-4-6") + usage = handler.get_usage() + assert usage.prompt_tokens == 0 + assert usage.completion_tokens == 0 + assert usage.total_tokens == 0 + assert usage.total_cost_usd == 0.0 + + def test_accumulates_tokens(self) -> None: + handler = TokenUsageCallbackHandler(model_name="claude-sonnet-4-6") + + class FakeResult: + llm_output = {"token_usage": {"prompt_tokens": 100, "completion_tokens": 50}} + + handler.on_llm_end(FakeResult()) + usage = handler.get_usage() + assert usage.prompt_tokens == 100 + assert usage.completion_tokens == 50 + assert usage.total_tokens == 150 + + def test_accumulates_across_calls(self) -> None: + handler = TokenUsageCallbackHandler(model_name="claude-sonnet-4-6") + + class FakeResult: + llm_output = {"token_usage": {"prompt_tokens": 100, "completion_tokens": 50}} + + handler.on_llm_end(FakeResult()) + handler.on_llm_end(FakeResult()) + usage = handler.get_usage() + assert usage.prompt_tokens == 200 + assert usage.completion_tokens == 100 + assert usage.total_tokens == 300 + + def test_cost_calculation(self) -> None: + handler = TokenUsageCallbackHandler(model_name="claude-sonnet-4-6") + + class FakeResult: + llm_output = {"token_usage": {"prompt_tokens": 1000, "completion_tokens": 1000}} + + handler.on_llm_end(FakeResult()) + usage = handler.get_usage() + # claude-sonnet-4-6: prompt $0.003/1K, completion $0.015/1K + expected_cost = 1000 * 0.003 / 1000 + 1000 * 0.015 / 1000 + assert usage.total_cost_usd == pytest.approx(expected_cost) + + def test_reset(self) -> None: + handler = TokenUsageCallbackHandler(model_name="claude-sonnet-4-6") + + class FakeResult: + llm_output = {"token_usage": {"prompt_tokens": 100, "completion_tokens": 50}} + + handler.on_llm_end(FakeResult()) + handler.reset() + usage = handler.get_usage() + assert usage.total_tokens == 0 + + def test_usage_is_immutable(self) -> None: + handler = TokenUsageCallbackHandler() + usage = handler.get_usage() + with pytest.raises(Exception): + usage.prompt_tokens = 999 + + def test_unknown_model_uses_default_cost(self) -> None: + handler = TokenUsageCallbackHandler(model_name="unknown-model") + + class FakeResult: + llm_output = {"token_usage": {"prompt_tokens": 1000, "completion_tokens": 1000}} + + handler.on_llm_end(FakeResult()) + usage = handler.get_usage() + assert usage.total_cost_usd > 0 + + def test_handles_missing_token_usage(self) -> None: + handler = TokenUsageCallbackHandler() + + class FakeResult: + llm_output = {} + + handler.on_llm_end(FakeResult()) + usage = handler.get_usage() + assert usage.total_tokens == 0 + + def test_handles_none_llm_output(self) -> None: + handler = TokenUsageCallbackHandler() + + class FakeResult: + llm_output = None + + handler.on_llm_end(FakeResult()) + usage = handler.get_usage() + assert usage.total_tokens == 0 diff --git a/backend/tests/unit/test_config.py b/backend/tests/unit/test_config.py new file mode 100644 index 0000000..824464e --- /dev/null +++ b/backend/tests/unit/test_config.py @@ -0,0 +1,60 @@ +"""Tests for app.config module.""" + +from __future__ import annotations + +import pytest + +from app.config import Settings + + +@pytest.mark.unit +class TestSettings: + def test_default_values(self) -> None: + settings = Settings( + database_url="postgresql://x:x@localhost/db", + anthropic_api_key="key", + ) + assert settings.llm_provider == "anthropic" + assert settings.llm_model == "claude-sonnet-4-6" + assert settings.session_ttl_minutes == 30 + assert settings.interrupt_ttl_minutes == 30 + + def test_custom_values(self) -> None: + settings = Settings( + database_url="postgresql://x:x@localhost/db", + llm_provider="openai", + llm_model="gpt-4o", + session_ttl_minutes=15, + openai_api_key="sk-test", + ) + assert settings.llm_provider == "openai" + assert settings.llm_model == "gpt-4o" + assert settings.session_ttl_minutes == 15 + + def test_invalid_provider_rejected(self) -> None: + with pytest.raises(Exception): + Settings( + database_url="postgresql://x:x@localhost/db", + llm_provider="invalid", + ) + + def test_missing_database_url_rejected(self) -> None: + with pytest.raises(Exception): + Settings(anthropic_api_key="key") + + def test_empty_api_key_for_provider_rejected(self) -> None: + with pytest.raises(ValueError, match="API key"): + Settings( + database_url="postgresql://x:x@localhost/db", + llm_provider="anthropic", + anthropic_api_key="", + ) + + def test_wrong_provider_key_rejected(self) -> None: + with pytest.raises(ValueError, match="API key"): + Settings( + database_url="postgresql://x:x@localhost/db", + llm_provider="openai", + anthropic_api_key="key", + openai_api_key="", + ) diff --git a/backend/tests/unit/test_db.py b/backend/tests/unit/test_db.py new file mode 100644 index 0000000..6eaff60 --- /dev/null +++ b/backend/tests/unit/test_db.py @@ -0,0 +1,64 @@ +"""Tests for app.db module.""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from app.config import Settings +from app.db import _CONVERSATIONS_DDL, _INTERRUPTS_DDL + + +@pytest.mark.unit +class TestDbModule: + @pytest.mark.asyncio + async def test_create_pool_sets_correct_params(self) -> None: + settings = Settings( + database_url="postgresql://user:pass@localhost:5432/testdb", + anthropic_api_key="key", + ) + with patch("app.db.AsyncConnectionPool") as MockPool: + mock_pool = AsyncMock() + MockPool.return_value = mock_pool + + from app.db import create_pool + + await create_pool(settings) + MockPool.assert_called_once() + call_kwargs = MockPool.call_args + assert "postgresql://user:pass@localhost:5432/testdb" in str(call_kwargs) + mock_pool.open.assert_awaited_once() + + @pytest.mark.asyncio + async def test_create_checkpointer_calls_setup(self) -> None: + mock_pool = AsyncMock() + with patch("app.db.AsyncPostgresSaver") as MockSaver: + mock_saver = AsyncMock() + MockSaver.return_value = mock_saver + + from app.db import create_checkpointer + + await create_checkpointer(mock_pool) + MockSaver.assert_called_once_with(conn=mock_pool) + mock_saver.setup.assert_awaited_once() + + @pytest.mark.asyncio + async def test_setup_app_tables_executes_ddl(self) -> None: + mock_conn = AsyncMock() + mock_ctx = AsyncMock() + mock_ctx.__aenter__ = AsyncMock(return_value=mock_conn) + mock_ctx.__aexit__ = AsyncMock(return_value=None) + mock_pool = MagicMock() + mock_pool.connection.return_value = mock_ctx + + from app.db import setup_app_tables + + await setup_app_tables(mock_pool) + assert mock_conn.execute.await_count == 2 + + def test_ddl_statements_valid(self) -> None: + assert "CREATE TABLE IF NOT EXISTS conversations" in _CONVERSATIONS_DDL + assert "CREATE TABLE IF NOT EXISTS active_interrupts" in _INTERRUPTS_DDL + assert "thread_id" in _CONVERSATIONS_DDL + assert "interrupt_id" in _INTERRUPTS_DDL diff --git a/backend/tests/unit/test_graph.py b/backend/tests/unit/test_graph.py new file mode 100644 index 0000000..7e0a890 --- /dev/null +++ b/backend/tests/unit/test_graph.py @@ -0,0 +1,44 @@ +"""Tests for app.graph module.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from app.graph import SUPERVISOR_PROMPT, build_agent_nodes, build_graph + +if TYPE_CHECKING: + from app.registry import AgentRegistry + + +@pytest.mark.unit +class TestBuildAgentNodes: + def test_creates_correct_number_of_nodes(self, sample_registry: AgentRegistry) -> None: + mock_llm = MagicMock() + nodes = build_agent_nodes(sample_registry, mock_llm) + assert len(nodes) == 3 + + def test_nodes_are_runnable(self, sample_registry: AgentRegistry) -> None: + mock_llm = MagicMock() + nodes = build_agent_nodes(sample_registry, mock_llm) + for node in nodes: + assert hasattr(node, "invoke") or hasattr(node, "ainvoke") + + +@pytest.mark.unit +class TestBuildGraph: + def test_graph_compiles_with_mock_checkpointer(self, sample_registry: AgentRegistry) -> None: + mock_llm = MagicMock() + mock_llm.bind_tools = MagicMock(return_value=mock_llm) + mock_llm.with_structured_output = MagicMock(return_value=mock_llm) + mock_checkpointer = AsyncMock() + + graph = build_graph(sample_registry, mock_llm, mock_checkpointer) + assert graph is not None + + def test_supervisor_prompt_contains_routing_info(self) -> None: + assert "order_lookup" in SUPERVISOR_PROMPT + assert "order_actions" in SUPERVISOR_PROMPT + assert "fallback" in SUPERVISOR_PROMPT diff --git a/backend/tests/unit/test_llm.py b/backend/tests/unit/test_llm.py new file mode 100644 index 0000000..c442869 --- /dev/null +++ b/backend/tests/unit/test_llm.py @@ -0,0 +1,41 @@ +"""Tests for app.llm module.""" + +from __future__ import annotations + +import pytest + +from app.config import Settings +from app.llm import create_llm + + +@pytest.mark.unit +class TestCreateLlm: + def test_anthropic_provider(self) -> None: + settings = Settings( + database_url="postgresql://x:x@localhost/db", + llm_provider="anthropic", + llm_model="claude-sonnet-4-6", + anthropic_api_key="test-key", + ) + llm = create_llm(settings) + assert type(llm).__name__ == "ChatAnthropic" + + def test_openai_provider(self) -> None: + settings = Settings( + database_url="postgresql://x:x@localhost/db", + llm_provider="openai", + llm_model="gpt-4o", + openai_api_key="sk-test", + ) + llm = create_llm(settings) + assert type(llm).__name__ == "ChatOpenAI" + + def test_google_provider(self) -> None: + settings = Settings( + database_url="postgresql://x:x@localhost/db", + llm_provider="google", + llm_model="gemini-pro", + google_api_key="test-key", + ) + llm = create_llm(settings) + assert type(llm).__name__ == "ChatGoogleGenerativeAI" diff --git a/backend/tests/unit/test_main.py b/backend/tests/unit/test_main.py new file mode 100644 index 0000000..cc73634 --- /dev/null +++ b/backend/tests/unit/test_main.py @@ -0,0 +1,27 @@ +"""Tests for app.main module.""" + +from __future__ import annotations + +import pytest + +from app.main import AGENTS_YAML, FRONTEND_DIST, app + + +@pytest.mark.unit +class TestMainModule: + def test_app_title(self) -> None: + assert app.title == "Smart Support" + + def test_app_version(self) -> None: + assert app.version == "0.1.0" + + def test_agents_yaml_path_exists(self) -> None: + assert AGENTS_YAML.name == "agents.yaml" + + def test_frontend_dist_path(self) -> None: + assert "frontend" in str(FRONTEND_DIST) + assert "dist" in str(FRONTEND_DIST) + + def test_websocket_route_registered(self) -> None: + routes = [r.path for r in app.routes if hasattr(r, "path")] + assert "/ws" in routes diff --git a/backend/tests/unit/test_registry.py b/backend/tests/unit/test_registry.py new file mode 100644 index 0000000..477f420 --- /dev/null +++ b/backend/tests/unit/test_registry.py @@ -0,0 +1,147 @@ +"""Tests for app.registry module.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest +import yaml + +from app.registry import AgentConfig, AgentRegistry, PersonalityConfig + +if TYPE_CHECKING: + from pathlib import Path + + +@pytest.mark.unit +class TestPersonalityConfig: + def test_defaults(self) -> None: + p = PersonalityConfig() + assert p.tone == "professional and helpful" + assert "Hello" in p.greeting + assert "human agent" in p.escalation_message + + def test_custom_values(self) -> None: + p = PersonalityConfig(tone="casual", greeting="Hey!", escalation_message="Hold on.") + assert p.tone == "casual" + + def test_immutable(self) -> None: + p = PersonalityConfig() + with pytest.raises(Exception): + p.tone = "new tone" + + +@pytest.mark.unit +class TestAgentConfig: + def test_valid_config(self) -> None: + ac = AgentConfig( + name="test", + description="A test agent", + permission="read", + tools=["tool1"], + ) + assert ac.name == "test" + assert ac.permission == "read" + + def test_empty_name_rejected(self) -> None: + with pytest.raises(ValueError, match="must not be empty"): + AgentConfig(name=" ", description="d", permission="read", tools=["t"]) + + def test_empty_tools_rejected(self) -> None: + with pytest.raises(ValueError, match="at least one tool"): + AgentConfig(name="x", description="d", permission="read", tools=[]) + + def test_invalid_permission_rejected(self) -> None: + with pytest.raises(Exception): + AgentConfig(name="x", description="d", permission="admin", tools=["t"]) + + def test_name_stripped(self) -> None: + ac = AgentConfig(name=" test ", description="d", permission="read", tools=["t"]) + assert ac.name == "test" + + def test_immutable(self) -> None: + ac = AgentConfig(name="test", description="d", permission="read", tools=["t"]) + with pytest.raises(Exception): + ac.name = "new" + + +@pytest.mark.unit +class TestAgentRegistry: + def test_load_valid_yaml(self, sample_yaml_path: Path) -> None: + registry = AgentRegistry.load(sample_yaml_path) + assert len(registry) == 3 + assert registry.get_agent("test_reader").permission == "read" + assert registry.get_agent("test_writer").permission == "write" + + def test_list_agents(self, sample_registry: AgentRegistry) -> None: + agents = sample_registry.list_agents() + assert len(agents) == 3 + names = {a.name for a in agents} + assert names == {"test_reader", "test_writer", "test_fallback"} + + def test_get_agents_by_permission(self, sample_registry: AgentRegistry) -> None: + readers = sample_registry.get_agents_by_permission("read") + assert len(readers) == 2 + writers = sample_registry.get_agents_by_permission("write") + assert len(writers) == 1 + + def test_get_nonexistent_agent(self, sample_registry: AgentRegistry) -> None: + with pytest.raises(KeyError, match="not found"): + sample_registry.get_agent("nonexistent") + + def test_personality_defaults_applied(self, sample_registry: AgentRegistry) -> None: + agent = sample_registry.get_agent("test_reader") + assert agent.personality.tone == "professional and helpful" + + def test_personality_custom_applied(self, sample_registry: AgentRegistry) -> None: + agent = sample_registry.get_agent("test_writer") + assert agent.personality.tone == "formal" + assert agent.personality.greeting == "Greetings." + + def test_file_not_found(self) -> None: + with pytest.raises(FileNotFoundError): + AgentRegistry.load("/nonexistent/path.yaml") + + def test_empty_file(self, tmp_path: Path) -> None: + path = tmp_path / "empty.yaml" + path.write_text("", encoding="utf-8") + with pytest.raises(ValueError, match="empty"): + AgentRegistry.load(path) + + def test_invalid_yaml_syntax(self, tmp_path: Path) -> None: + path = tmp_path / "bad.yaml" + path.write_text("agents:\n - name: [invalid\n", encoding="utf-8") + with pytest.raises(ValueError, match="Invalid YAML"): + AgentRegistry.load(path) + + def test_missing_agents_key(self, tmp_path: Path) -> None: + path = tmp_path / "no_agents.yaml" + path.write_text(yaml.dump({"items": []}), encoding="utf-8") + with pytest.raises(ValueError, match="agents"): + AgentRegistry.load(path) + + def test_duplicate_agent_names(self, tmp_path: Path) -> None: + data = { + "agents": [ + {"name": "dup", "description": "a", "permission": "read", "tools": ["t1"]}, + {"name": "dup", "description": "b", "permission": "read", "tools": ["t2"]}, + ] + } + path = tmp_path / "dups.yaml" + path.write_text(yaml.dump(data), encoding="utf-8") + with pytest.raises(ValueError, match="Duplicate"): + AgentRegistry.load(path) + + def test_missing_required_fields(self, tmp_path: Path) -> None: + data = {"agents": [{"name": "x"}]} + path = tmp_path / "missing.yaml" + path.write_text(yaml.dump(data), encoding="utf-8") + with pytest.raises(ValueError, match="Invalid agent config"): + AgentRegistry.load(path) + + def test_empty_agents_list(self, tmp_path: Path) -> None: + data = {"agents": []} + path = tmp_path / "empty_list.yaml" + path.write_text(yaml.dump(data), encoding="utf-8") + with pytest.raises(ValueError, match="non-empty"): + AgentRegistry.load(path) diff --git a/backend/tests/unit/test_session_manager.py b/backend/tests/unit/test_session_manager.py new file mode 100644 index 0000000..2e126b9 --- /dev/null +++ b/backend/tests/unit/test_session_manager.py @@ -0,0 +1,70 @@ +"""Tests for app.session_manager module.""" + +from __future__ import annotations + +import time +from unittest.mock import patch + +import pytest + +from app.session_manager import SessionManager + + +@pytest.mark.unit +class TestSessionManager: + def test_new_session_not_expired(self, session_manager: SessionManager) -> None: + session_manager.touch("thread-1") + assert not session_manager.is_expired("thread-1") + + def test_unknown_session_is_expired(self, session_manager: SessionManager) -> None: + assert session_manager.is_expired("unknown") + + def test_session_expires_after_ttl(self) -> None: + mgr = SessionManager(session_ttl_seconds=1) + mgr.touch("t1") + with patch("app.session_manager.time") as mock_time: + mock_time.time.return_value = time.time() + 2 + assert mgr.is_expired("t1") + + def test_touch_resets_ttl(self) -> None: + mgr = SessionManager(session_ttl_seconds=5) + mgr.touch("t1") + initial_state = mgr.get_state("t1") + # Touch again after some time + with patch("app.session_manager.time") as mock_time: + mock_time.time.return_value = time.time() + 3 + mgr.touch("t1") + new_state = mgr.get_state("t1") + assert new_state.last_activity > initial_state.last_activity + + def test_interrupt_suspends_expiration(self) -> None: + mgr = SessionManager(session_ttl_seconds=1) + mgr.touch("t1") + mgr.extend_for_interrupt("t1") + with patch("app.session_manager.time") as mock_time: + mock_time.time.return_value = time.time() + 100 + assert not mgr.is_expired("t1") + + def test_resolve_interrupt_resumes_ttl(self) -> None: + mgr = SessionManager(session_ttl_seconds=1) + mgr.touch("t1") + mgr.extend_for_interrupt("t1") + mgr.resolve_interrupt("t1") + state = mgr.get_state("t1") + assert not state.has_pending_interrupt + + def test_extend_for_nonexistent_creates_session(self) -> None: + mgr = SessionManager() + mgr.extend_for_interrupt("new-thread") + state = mgr.get_state("new-thread") + assert state is not None + + def test_remove_session(self, session_manager: SessionManager) -> None: + session_manager.touch("t1") + session_manager.remove("t1") + assert session_manager.get_state("t1") is None + + def test_session_state_is_immutable(self, session_manager: SessionManager) -> None: + state = session_manager.touch("t1") + with pytest.raises(Exception): + state.thread_id = "new" diff --git a/backend/tests/unit/test_ws_handler.py b/backend/tests/unit/test_ws_handler.py new file mode 100644 index 0000000..ff8eb8a --- /dev/null +++ b/backend/tests/unit/test_ws_handler.py @@ -0,0 +1,233 @@ +"""Tests for app.ws_handler module.""" + +from __future__ import annotations + +import json +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from app.callbacks import TokenUsageCallbackHandler +from app.session_manager import SessionManager +from app.ws_handler import ( + _extract_interrupt, + _has_interrupt, + dispatch_message, + handle_interrupt_response, + handle_user_message, +) + + +def _make_ws() -> AsyncMock: + ws = AsyncMock() + ws.send_json = AsyncMock() + return ws + + +def _make_graph() -> AsyncMock: + graph = AsyncMock() + graph.astream = MagicMock(return_value=AsyncIterHelper([])) + state = MagicMock() + state.tasks = () + graph.aget_state = AsyncMock(return_value=state) + return graph + + +class AsyncIterHelper: + """Helper to make a list behave as an async iterator.""" + + def __init__(self, items): + self._items = items + + def __aiter__(self): + return self + + async def __anext__(self): + if not self._items: + raise StopAsyncIteration + return self._items.pop(0) + + +@pytest.mark.unit +class TestDispatchMessage: + @pytest.mark.asyncio + async def test_invalid_json(self) -> None: + ws = _make_ws() + graph = _make_graph() + sm = SessionManager() + cb = TokenUsageCallbackHandler() + + await dispatch_message(ws, graph, sm, cb, "not json") + ws.send_json.assert_awaited_once() + call_data = ws.send_json.call_args[0][0] + assert call_data["type"] == "error" + assert "Invalid JSON" in call_data["message"] + + @pytest.mark.asyncio + async def test_missing_thread_id(self) -> None: + ws = _make_ws() + graph = _make_graph() + sm = SessionManager() + cb = TokenUsageCallbackHandler() + + msg = json.dumps({"type": "message", "content": "hello"}) + await dispatch_message(ws, graph, sm, cb, msg) + call_data = ws.send_json.call_args[0][0] + assert call_data["type"] == "error" + assert "thread_id" in call_data["message"] + + @pytest.mark.asyncio + async def test_missing_content(self) -> None: + ws = _make_ws() + graph = _make_graph() + sm = SessionManager() + cb = TokenUsageCallbackHandler() + + msg = json.dumps({"type": "message", "thread_id": "t1"}) + await dispatch_message(ws, graph, sm, cb, msg) + call_data = ws.send_json.call_args[0][0] + assert call_data["type"] == "error" + + @pytest.mark.asyncio + async def test_unknown_message_type(self) -> None: + ws = _make_ws() + graph = _make_graph() + sm = SessionManager() + cb = TokenUsageCallbackHandler() + + msg = json.dumps({"type": "unknown", "thread_id": "t1"}) + await dispatch_message(ws, graph, sm, cb, msg) + call_data = ws.send_json.call_args[0][0] + assert call_data["type"] == "error" + assert "Unknown" in call_data["message"] + # Verify raw input is NOT reflected back + assert "unknown" not in call_data["message"].lower().replace("unknown message type", "") + + @pytest.mark.asyncio + async def test_message_too_large(self) -> None: + ws = _make_ws() + graph = _make_graph() + sm = SessionManager() + cb = TokenUsageCallbackHandler() + + large_msg = "x" * 40_000 + await dispatch_message(ws, graph, sm, cb, large_msg) + call_data = ws.send_json.call_args[0][0] + assert call_data["type"] == "error" + assert "too large" in call_data["message"].lower() + + @pytest.mark.asyncio + async def test_invalid_thread_id_format(self) -> None: + ws = _make_ws() + graph = _make_graph() + sm = SessionManager() + cb = TokenUsageCallbackHandler() + + msg = json.dumps({"type": "message", "thread_id": "../../../etc", "content": "hi"}) + await dispatch_message(ws, graph, sm, cb, msg) + call_data = ws.send_json.call_args[0][0] + assert call_data["type"] == "error" + assert "thread_id" in call_data["message"].lower() + + @pytest.mark.asyncio + async def test_content_too_long(self) -> None: + ws = _make_ws() + graph = _make_graph() + sm = SessionManager() + cb = TokenUsageCallbackHandler() + + msg = json.dumps({"type": "message", "thread_id": "t1", "content": "x" * 9000}) + await dispatch_message(ws, graph, sm, cb, msg) + call_data = ws.send_json.call_args[0][0] + assert call_data["type"] == "error" + assert "too long" in call_data["message"].lower() + + +@pytest.mark.unit +class TestHandleUserMessage: + @pytest.mark.asyncio + async def test_expired_session(self) -> None: + ws = _make_ws() + graph = _make_graph() + sm = SessionManager(session_ttl_seconds=0) + cb = TokenUsageCallbackHandler() + + await handle_user_message(ws, graph, sm, cb, "t1", "hello") + call_data = ws.send_json.call_args[0][0] + assert call_data["type"] == "error" + assert "expired" in call_data["message"].lower() + + @pytest.mark.asyncio + async def test_successful_message(self) -> None: + ws = _make_ws() + graph = _make_graph() + sm = SessionManager() + cb = TokenUsageCallbackHandler() + + sm.touch("t1") + await handle_user_message(ws, graph, sm, cb, "t1", "hello") + # Should end with message_complete + last_call = ws.send_json.call_args[0][0] + assert last_call["type"] == "message_complete" + + @pytest.mark.asyncio + async def test_graph_error_sends_error_message(self) -> None: + ws = _make_ws() + graph = AsyncMock() + graph.astream = MagicMock(side_effect=RuntimeError("boom")) + sm = SessionManager() + cb = TokenUsageCallbackHandler() + + sm.touch("t1") + await handle_user_message(ws, graph, sm, cb, "t1", "hello") + call_data = ws.send_json.call_args[0][0] + assert call_data["type"] == "error" + + +@pytest.mark.unit +class TestHandleInterruptResponse: + @pytest.mark.asyncio + async def test_approved_interrupt(self) -> None: + ws = _make_ws() + graph = _make_graph() + sm = SessionManager() + cb = TokenUsageCallbackHandler() + + sm.touch("t1") + sm.extend_for_interrupt("t1") + await handle_interrupt_response(ws, graph, sm, cb, "t1", True) + last_call = ws.send_json.call_args[0][0] + assert last_call["type"] == "message_complete" + + +@pytest.mark.unit +class TestInterruptHelpers: + def test_has_interrupt_false_for_empty_tasks(self) -> None: + state = MagicMock() + state.tasks = () + assert not _has_interrupt(state) + + def test_has_interrupt_true(self) -> None: + interrupt_obj = MagicMock() + interrupt_obj.value = {"action": "cancel"} + task = MagicMock() + task.interrupts = (interrupt_obj,) + state = MagicMock() + state.tasks = (task,) + assert _has_interrupt(state) + + def test_extract_interrupt_data(self) -> None: + interrupt_obj = MagicMock() + interrupt_obj.value = {"action": "cancel_order", "order_id": "1042"} + task = MagicMock() + task.interrupts = (interrupt_obj,) + state = MagicMock() + state.tasks = (task,) + data = _extract_interrupt(state) + assert data["action"] == "cancel_order" + + def test_extract_interrupt_empty(self) -> None: + state = MagicMock() + state.tasks = () + data = _extract_interrupt(state) + assert data["action"] == "unknown" diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..6232d46 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,39 @@ +services: + postgres: + image: postgres:16 + environment: + POSTGRES_DB: smart_support + POSTGRES_USER: smart_support + POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-dev_password} + ports: + - "5432:5432" + volumes: + - pgdata:/var/lib/postgresql/data + healthcheck: + test: ["CMD-SHELL", "pg_isready -U smart_support -d smart_support"] + interval: 5s + timeout: 3s + retries: 5 + + backend: + build: + context: ./backend + dockerfile: Dockerfile + ports: + - "8000:8000" + environment: + DATABASE_URL: postgresql://smart_support:${POSTGRES_PASSWORD:-dev_password}@postgres:5432/smart_support + LLM_PROVIDER: ${LLM_PROVIDER:-anthropic} + LLM_MODEL: ${LLM_MODEL:-claude-sonnet-4-6} + ANTHROPIC_API_KEY: ${ANTHROPIC_API_KEY:-} + OPENAI_API_KEY: ${OPENAI_API_KEY:-} + GOOGLE_API_KEY: ${GOOGLE_API_KEY:-} + depends_on: + postgres: + condition: service_healthy + volumes: + - ./backend:/app + command: uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload + +volumes: + pgdata: diff --git a/docs/phases/phase-1-dev-log.md b/docs/phases/phase-1-dev-log.md new file mode 100644 index 0000000..32b1838 --- /dev/null +++ b/docs/phases/phase-1-dev-log.md @@ -0,0 +1,88 @@ +# Phase 1: Core Framework -- Development Log + +> Status: IN PROGRESS +> Phase branch: `phase-1/core-framework` +> Date started: 2026-03-30 +> Date completed: -- +> Related plan section: [Phase 1 in DEVELOPMENT-PLAN](../DEVELOPMENT-PLAN.md#phase-1-æ ¸å¿ƒæ¡†æž¶-第-1-3-周) + +## What Was Built + +- FastAPI WebSocket backend with `/ws` endpoint for real-time chat +- LangGraph Supervisor (via `langgraph-supervisor`) connecting 3 agents +- YAML-based Agent Registry with Pydantic validation +- 3 Mock Agents: order_lookup (read), order_actions (write + interrupt), fallback +- PostgresSaver checkpoint persistence via `langgraph-checkpoint-postgres` +- Session TTL management with 30-minute sliding window and interrupt extension +- LLM provider abstraction (Anthropic/OpenAI/Google) with prompt caching support +- Token usage tracking callback handler +- React Chat UI with streaming display, interrupt confirmation, and agent action viewer +- Docker Compose configuration (PostgreSQL 16 + backend) + +## Code Structure + +### New files + +Backend (`backend/app/`): +- `config.py` -- pydantic-settings centralized configuration +- `db.py` -- Async PostgreSQL pool + AsyncPostgresSaver setup +- `llm.py` -- LLM provider factory (ChatAnthropic/ChatOpenAI/ChatGoogleGenerativeAI) +- `callbacks.py` -- Token usage + cost tracking callback handler +- `registry.py` -- YAML agent registry with validation + immutable config models +- `session_manager.py` -- Session TTL with sliding window + interrupt extension +- `graph.py` -- LangGraph Supervisor construction from registry +- `ws_handler.py` -- WebSocket message dispatch + streaming logic +- `main.py` -- FastAPI app entry with lifespan + WebSocket endpoint +- `agents/__init__.py` -- Tool name-to-function bridge +- `agents/order_lookup.py` -- Mock order status/tracking tools +- `agents/order_actions.py` -- Mock cancel_order with interrupt() +- `agents/fallback.py` -- Fallback response tool + +Frontend (`frontend/src/`): +- `types.ts` -- WebSocket message protocol TypeScript types +- `hooks/useWebSocket.ts` -- WebSocket connection + reconnect + message dispatch +- `components/ChatMessages.tsx` -- Streaming message display +- `components/ChatInput.tsx` -- Message input +- `components/InterruptPrompt.tsx` -- Approve/reject interrupt UI +- `components/AgentAction.tsx` -- Tool call inline display +- `pages/ChatPage.tsx` -- Main chat page composing all components + +Infrastructure: +- `backend/pyproject.toml` -- Dependencies + pytest + ruff config +- `backend/agents.yaml` -- Agent registry YAML config +- `backend/Dockerfile` -- Backend container +- `docker-compose.yml` -- PostgreSQL 16 + backend services +- `.gitignore` -- Updated for Python + Node artifacts + +Tests (`backend/tests/unit/`): +- `test_config.py` -- Settings validation tests +- `test_registry.py` -- 17 tests for registry loading/validation +- `test_agents.py` -- 10 tests for tool functions + tool bridge +- `test_llm.py` -- 3 tests for LLM provider factory +- `test_callbacks.py` -- 9 tests for token usage tracking +- `test_session_manager.py` -- 9 tests for session TTL logic +- `test_graph.py` -- 4 tests for supervisor construction +- `test_db.py` -- 5 tests for database setup +- `test_ws_handler.py` -- 12 tests for WebSocket message handling +- `test_main.py` -- 5 tests for app configuration + +## Test Coverage + +- Unit test count: 82 +- Integration test count: 0 (requires running PostgreSQL) +- E2E test count: 0 (manual verification in plan) +- Overall coverage: 88% + +## Deviations from Plan + +- Used `astream(stream_mode="messages")` instead of `astream_events()` per langgraph best practices +- Separated WebSocket handler logic into `ws_handler.py` for testability (not in original plan) +- Session manager uses in-memory storage instead of DB-backed (sufficient for Phase 1 single-instance) + +## Known Issues / Tech Debt + +- Session manager not DB-backed (loses state on restart) -- acceptable for Phase 1 single-instance +- WebSocket reconnect does not re-send pending interrupt state from server +- No rate limiting on WebSocket endpoint (Phase 2) +- No authentication (Phase 2) +- `main.py` coverage at 47% -- lifespan function not unit-testable without full DB diff --git a/frontend/index.html b/frontend/index.html new file mode 100644 index 0000000..9f5a3b7 --- /dev/null +++ b/frontend/index.html @@ -0,0 +1,12 @@ + + +
+ + +{JSON.stringify(action.args, null, 2)}
+ {JSON.stringify(action.result, null, 2)}
+