feat: complete phase 5 -- error hardening, frontend, Docker, demo, docs

Backend:
- ConversationTracker: Protocol + PostgresConversationTracker for lifecycle tracking
- Error handler: ErrorCategory enum, classify_error(), with_retry() exponential backoff
- Wire PostgresAnalyticsRecorder + ConversationTracker into ws_handler
- Rate limiting (10 msg/10s per thread), edge case hardening
- Health endpoint GET /api/health, version 0.5.0
- Demo seed data script + sample OpenAPI spec

Frontend (all new):
- React Router with NavBar (Chat / Replay / Dashboard / Review)
- ReplayListPage + ReplayPage with ReplayTimeline component
- DashboardPage with MetricCard, range selector, zero-state
- ReviewPage for OpenAPI classification review
- ErrorBanner for WebSocket disconnect handling
- API client (api.ts) with typed fetch wrappers

Infrastructure:
- Frontend Dockerfile (multi-stage node -> nginx)
- nginx.conf with SPA routing + API/WS proxy
- docker-compose.yml with frontend service + healthchecks
- .env.example files (root + backend)

Documentation:
- README.md with quick start and architecture
- Agent configuration guide
- OpenAPI import guide
- Deployment guide
- Demo script

48 new tests, 449 total passing, 92.87% coverage
This commit is contained in:
Yaojia Wang
2026-03-31 21:20:06 +02:00
parent 38644594d2
commit 0e78e5b06b
44 changed files with 3397 additions and 169 deletions

View File

@@ -1,19 +1,34 @@
# Database
# Smart Support Backend -- environment variables
# Copy to .env and fill in your values
# Required: PostgreSQL connection string
DATABASE_URL=postgresql://smart_support:dev_password@localhost:5432/smart_support
# LLM Provider: anthropic | openai | google
# Required: LLM provider configuration
# provider: anthropic | openai | google
LLM_PROVIDER=anthropic
LLM_MODEL=claude-sonnet-4-6
# API Keys (set the one matching your LLM_PROVIDER)
# API keys -- provide the one matching LLM_PROVIDER
ANTHROPIC_API_KEY=
OPENAI_API_KEY=
GOOGLE_API_KEY=
# Session
# Optional: webhook endpoint for escalation notifications
# The backend will POST a JSON payload when a conversation is escalated.
WEBHOOK_URL=
WEBHOOK_TIMEOUT_SECONDS=10
WEBHOOK_MAX_RETRIES=3
# Session management
SESSION_TTL_MINUTES=30
INTERRUPT_TTL_MINUTES=30
# Server
# Optional: load a named agent template instead of agents.yaml
# Leave blank to use the default agents.yaml in the backend directory.
# Available templates: ecommerce, saas, generic
TEMPLATE_NAME=
# Server binding
WS_HOST=0.0.0.0
WS_PORT=8000

View File

@@ -0,0 +1,135 @@
"""Conversation tracker -- Protocol and implementations for tracking conversation state."""
from __future__ import annotations
from typing import TYPE_CHECKING, Protocol, runtime_checkable
if TYPE_CHECKING:
from psycopg_pool import AsyncConnectionPool
# SQL statements used by PostgresConversationTracker.

# Insert the conversation row on first contact. ON CONFLICT DO NOTHING makes
# the statement idempotent, so it is safe to call on every turn.
_ENSURE_SQL = """
INSERT INTO conversations
(thread_id, started_at, last_activity)
VALUES
(%(thread_id)s, NOW(), NOW())
ON CONFLICT (thread_id) DO NOTHING
"""
# Bump per-turn aggregates. agents_used only grows when agent_name is
# non-NULL and not already contained in the array (@> is array containment),
# keeping the list duplicate-free.
_RECORD_TURN_SQL = """
UPDATE conversations
SET
turn_count = turn_count + 1,
agents_used = CASE
WHEN %(agent_name)s IS NOT NULL AND NOT (agents_used @> ARRAY[%(agent_name)s]::text[])
THEN agents_used || ARRAY[%(agent_name)s]::text[]
ELSE agents_used
END,
total_tokens = total_tokens + %(tokens)s,
total_cost_usd = total_cost_usd + %(cost)s,
last_activity = NOW()
WHERE thread_id = %(thread_id)s
"""
# Close out a conversation: stamp the resolution type and the end time.
_RESOLVE_SQL = """
UPDATE conversations
SET
resolution_type = %(resolution_type)s,
ended_at = NOW()
WHERE thread_id = %(thread_id)s
"""
@runtime_checkable
class ConversationTrackerProtocol(Protocol):
    """Protocol for tracking conversation lifecycle and metrics.

    The connection pool is passed per call rather than stored on the
    instance, so implementations carry no connection state of their own.
    """

    async def ensure_conversation(self, pool: AsyncConnectionPool, thread_id: str) -> None:
        """Create conversation row if it does not already exist."""
        ...

    async def record_turn(
        self,
        pool: AsyncConnectionPool,
        thread_id: str,
        agent_name: str | None,
        tokens: int,
        cost: float,
    ) -> None:
        """Increment turn count and update aggregated metrics.

        agent_name may be None when the turn is not attributed to a
        specific agent; tokens and cost are added to running totals.
        """
        ...

    async def resolve(
        self,
        pool: AsyncConnectionPool,
        thread_id: str,
        resolution_type: str,
    ) -> None:
        """Mark conversation as resolved with a resolution type."""
        ...
class NoOpConversationTracker:
    """Tracker that discards every call.

    Useful in unit tests and when running without a database: callers can
    keep the same code path while nothing is persisted.
    """

    async def ensure_conversation(self, pool: AsyncConnectionPool, thread_id: str) -> None:
        """Intentionally a no-op."""
        return None

    async def record_turn(
        self,
        pool: AsyncConnectionPool,
        thread_id: str,
        agent_name: str | None,
        tokens: int,
        cost: float,
    ) -> None:
        """Intentionally a no-op."""
        return None

    async def resolve(
        self,
        pool: AsyncConnectionPool,
        thread_id: str,
        resolution_type: str,
    ) -> None:
        """Intentionally a no-op."""
        return None
class PostgresConversationTracker:
    """Conversation tracker that persists lifecycle data to Postgres.

    Each method borrows a connection from the supplied pool for the
    duration of a single statement.
    """

    async def ensure_conversation(self, pool: AsyncConnectionPool, thread_id: str) -> None:
        """Create the conversation row on first contact (idempotent insert)."""
        async with pool.connection() as conn:
            await conn.execute(_ENSURE_SQL, {"thread_id": thread_id})

    async def record_turn(
        self,
        pool: AsyncConnectionPool,
        thread_id: str,
        agent_name: str | None,
        tokens: int,
        cost: float,
    ) -> None:
        """Increment turn count, append agent if new, update token/cost totals."""
        bound = {
            "thread_id": thread_id,
            "agent_name": agent_name,
            "tokens": tokens,
            "cost": cost,
        }
        async with pool.connection() as conn:
            await conn.execute(_RECORD_TURN_SQL, bound)

    async def resolve(
        self,
        pool: AsyncConnectionPool,
        thread_id: str,
        resolution_type: str,
    ) -> None:
        """Stamp resolution_type and ended_at on the conversation row."""
        bound = {
            "thread_id": thread_id,
            "resolution_type": resolution_type,
        }
        async with pool.connection() as conn:
            await conn.execute(_RESOLVE_SQL, bound)

View File

@@ -11,9 +11,10 @@ from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.staticfiles import StaticFiles
from app.analytics.api import router as analytics_router
from app.analytics.event_recorder import NoOpAnalyticsRecorder
from app.analytics.event_recorder import PostgresAnalyticsRecorder
from app.callbacks import TokenUsageCallbackHandler
from app.config import Settings
from app.conversation_tracker import PostgresConversationTracker
from app.db import create_checkpointer, create_pool, setup_app_tables
from app.escalation import NoOpEscalator, WebhookEscalator
from app.graph import build_graph
@@ -76,7 +77,8 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
app.state.escalator = escalator
app.state.settings = settings
app.state.pool = pool
app.state.analytics_recorder = NoOpAnalyticsRecorder()
app.state.analytics_recorder = PostgresAnalyticsRecorder(pool=pool)
app.state.conversation_tracker = PostgresConversationTracker()
logger.info(
"Smart Support started: %d agents loaded, LLM=%s/%s, template=%s",
@@ -91,13 +93,19 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
await pool.close()
app = FastAPI(title="Smart Support", version="0.4.0", lifespan=lifespan)
app = FastAPI(title="Smart Support", version="0.5.0", lifespan=lifespan)
app.include_router(openapi_router)
app.include_router(replay_router)
app.include_router(analytics_router)
@app.get("/api/health")
def health_check() -> dict:
    """Health check endpoint for load balancers and monitoring.

    Returns a small JSON document with the service status and version.
    The version is read from the FastAPI app object rather than being
    hard-coded a second time, so it cannot drift from the value passed
    to FastAPI(..., version=...).
    """
    return {"status": "ok", "version": app.version}
@app.websocket("/ws")
async def websocket_endpoint(ws: WebSocket) -> None:
await ws.accept()
@@ -107,12 +115,19 @@ async def websocket_endpoint(ws: WebSocket) -> None:
settings = app.state.settings
callback_handler = TokenUsageCallbackHandler(model_name=settings.llm_model)
analytics_recorder = app.state.analytics_recorder
conversation_tracker = app.state.conversation_tracker
pool = app.state.pool
try:
while True:
raw_data = await ws.receive_text()
await dispatch_message(
ws, graph, session_manager, callback_handler, raw_data,
interrupt_manager=interrupt_manager,
analytics_recorder=analytics_recorder,
conversation_tracker=conversation_tracker,
pool=pool,
)
except WebSocketDisconnect:
logger.info("WebSocket client disconnected")

View File

@@ -0,0 +1,3 @@
"""Tools package for smart-support backend."""
from __future__ import annotations

View File

@@ -0,0 +1,72 @@
"""Error classification and retry logic for tool calls."""
from __future__ import annotations
import asyncio
from enum import Enum
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from collections.abc import Callable
import httpx
class ErrorCategory(Enum):
    """Categories for error classification to guide retry decisions."""

    RETRYABLE = "retryable"          # transient failure (429/5xx); safe to retry
    NON_RETRYABLE = "non_retryable"  # permanent/unknown failure; retrying won't help
    AUTH_FAILURE = "auth_failure"    # 401/403 -- credentials problem, needs operator action
    TIMEOUT = "timeout"              # request exceeded its deadline
    NETWORK = "network"              # connection could not be established
def classify_error(exc: Exception) -> ErrorCategory:
    """Classify an exception into an ErrorCategory.

    Rules:
    - httpx.TimeoutException -> TIMEOUT
    - httpx.ConnectError -> NETWORK
    - httpx.HTTPStatusError 401/403 -> AUTH_FAILURE
    - httpx.HTTPStatusError 429/500/502/503/504 -> RETRYABLE
    - anything else -> NON_RETRYABLE
    """
    if isinstance(exc, httpx.TimeoutException):
        return ErrorCategory.TIMEOUT
    if isinstance(exc, httpx.ConnectError):
        return ErrorCategory.NETWORK
    if isinstance(exc, httpx.HTTPStatusError):
        code = exc.response.status_code
        if code in (401, 403):
            return ErrorCategory.AUTH_FAILURE
        # 504 Gateway Timeout is as transient as 502/503 but was missing
        # from the original retryable set -- include it.
        if code in (429, 500, 502, 503, 504):
            return ErrorCategory.RETRYABLE
        return ErrorCategory.NON_RETRYABLE
    return ErrorCategory.NON_RETRYABLE
async def with_retry(
    fn: Callable[..., Any],
    max_retries: int = 3,
    base_delay: float = 1.0,
) -> Any:
    """Execute an async callable with exponential backoff for RETRYABLE errors.

    Only ErrorCategory.RETRYABLE errors trigger retries. All other error
    categories raise immediately after the first attempt.

    Args:
        fn: Zero-argument async callable to invoke.
        max_retries: Total number of attempts; must be >= 1.
        base_delay: Initial backoff in seconds; doubles after each failure
            (base, 2*base, 4*base, ...).

    Raises:
        ValueError: If max_retries < 1.
        Exception: The exception raised by fn -- immediately for
            non-retryable categories, or the last one after all attempts.
    """
    if max_retries < 1:
        # Guard: with zero attempts the loop below never runs, and the
        # trailing `raise last_exc` would be `raise None` (a confusing
        # TypeError). Fail fast with a clear message instead.
        raise ValueError("max_retries must be >= 1")
    last_exc: Exception | None = None
    for attempt in range(1, max_retries + 1):
        try:
            return await fn()
        except Exception as exc:
            if classify_error(exc) != ErrorCategory.RETRYABLE:
                raise
            last_exc = exc
            if attempt < max_retries:
                await asyncio.sleep(base_delay * (2 ** (attempt - 1)))
    # All attempts exhausted -- re-raise the last retryable failure.
    raise last_exc  # type: ignore[misc]

View File

@@ -5,6 +5,8 @@ from __future__ import annotations
import json
import logging
import re
import time
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from langchain_core.messages import HumanMessage
@@ -16,16 +18,23 @@ if TYPE_CHECKING:
from fastapi import WebSocket
from langgraph.graph.state import CompiledStateGraph
from app.analytics.event_recorder import AnalyticsRecorder
from app.callbacks import TokenUsageCallbackHandler
from app.conversation_tracker import ConversationTrackerProtocol
from app.interrupt_manager import InterruptManager
from app.session_manager import SessionManager
logger = logging.getLogger(__name__)
MAX_MESSAGE_SIZE = 32_768 # 32 KB
MAX_CONTENT_LENGTH = 8_000 # characters
MAX_CONTENT_LENGTH = 10_000 # characters
THREAD_ID_PATTERN = re.compile(r"^[a-zA-Z0-9\-_]{1,128}$")
# Rate limiting: max 10 messages per 10-second window, per thread
_RATE_LIMIT_MAX = 10
_RATE_LIMIT_WINDOW = 10.0
_thread_timestamps: dict[str, list[float]] = defaultdict(list)
async def handle_user_message(
ws: WebSocket,
@@ -197,6 +206,9 @@ async def dispatch_message(
callback_handler: TokenUsageCallbackHandler,
raw_data: str,
interrupt_manager: InterruptManager | None = None,
analytics_recorder: AnalyticsRecorder | None = None,
conversation_tracker: ConversationTrackerProtocol | None = None,
pool: Any = None,
) -> None:
"""Parse and route an incoming WebSocket message."""
if len(raw_data) > MAX_MESSAGE_SIZE:
@@ -205,10 +217,14 @@ async def dispatch_message(
try:
data = json.loads(raw_data)
except json.JSONDecodeError:
except (json.JSONDecodeError, ValueError):
await _send_json(ws, {"type": "error", "message": "Invalid JSON"})
return
if not isinstance(data, dict):
await _send_json(ws, {"type": "error", "message": "Invalid JSON: expected object"})
return
msg_type = data.get("type")
thread_id = data.get("thread_id", "")
@@ -222,16 +238,36 @@ async def dispatch_message(
if msg_type == "message":
content = data.get("content", "")
if not content:
if not content or not content.strip():
await _send_json(ws, {"type": "error", "message": "Missing message content"})
return
if len(content) > MAX_CONTENT_LENGTH:
await _send_json(ws, {"type": "error", "message": "Message content too long"})
return
# Rate limiting check
now = time.time()
timestamps = _thread_timestamps[thread_id]
cutoff = now - _RATE_LIMIT_WINDOW
_thread_timestamps[thread_id] = [t for t in timestamps if t >= cutoff]
if len(_thread_timestamps[thread_id]) >= _RATE_LIMIT_MAX:
await _send_json(ws, {"type": "error", "message": "Rate limit exceeded"})
return
_thread_timestamps[thread_id].append(now)
await handle_user_message(
ws, graph, session_manager, callback_handler, thread_id, content,
interrupt_manager=interrupt_manager,
)
await _fire_and_forget_tracking(
thread_id=thread_id,
pool=pool,
analytics_recorder=analytics_recorder,
conversation_tracker=conversation_tracker,
agent_name=None,
tokens=0,
cost=0.0,
)
elif msg_type == "interrupt_response":
approved = data.get("approved", False)
@@ -244,6 +280,36 @@ async def dispatch_message(
await _send_json(ws, {"type": "error", "message": "Unknown message type"})
async def _fire_and_forget_tracking(
thread_id: str,
pool: Any,
analytics_recorder: Any | None,
conversation_tracker: Any | None,
agent_name: str | None,
tokens: int,
cost: float,
) -> None:
"""Fire-and-forget analytics/tracking; failures must NOT break chat."""
try:
if conversation_tracker is not None and pool is not None:
await conversation_tracker.ensure_conversation(pool, thread_id)
await conversation_tracker.record_turn(pool, thread_id, agent_name, tokens, cost)
except Exception:
logger.exception("Conversation tracker error for thread %s (suppressed)", thread_id)
try:
if analytics_recorder is not None:
await analytics_recorder.record(
thread_id=thread_id,
event_type="message",
agent_name=agent_name,
tokens_used=tokens,
cost_usd=cost,
)
except Exception:
logger.exception("Analytics recorder error for thread %s (suppressed)", thread_id)
def _has_interrupt(state: Any) -> bool:
"""Check if the graph state has a pending interrupt."""
tasks = getattr(state, "tasks", ())

View File

@@ -0,0 +1,153 @@
"""Seed script -- inserts sample conversations and analytics events for demo purposes.
Usage:
cd backend
python fixtures/demo_data.py
"""
from __future__ import annotations
import asyncio
import os
import sys
from datetime import datetime, timedelta, timezone
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
import psycopg
# Connection string; overridable via the DATABASE_URL environment variable.
# The default matches the local development database credentials.
DATABASE_URL = os.environ.get(
    "DATABASE_URL",
    "postgresql://smart_support:dev_password@localhost:5432/smart_support",
)
# Demo conversations. `minutes_ago` places each start time in the recent
# past; a `resolution_type` of None marks a still-open conversation.
SAMPLE_CONVERSATIONS = [
    {
        "thread_id": "demo-thread-001",
        "agents_used": ["order_agent"],
        "turn_count": 3,
        "total_tokens": 1250,
        "total_cost_usd": 0.00375,
        "resolution_type": "resolved",
        "minutes_ago": 5,
    },
    {
        "thread_id": "demo-thread-002",
        "agents_used": ["order_agent", "refund_agent"],
        "turn_count": 6,
        "total_tokens": 3200,
        "total_cost_usd": 0.0096,
        "resolution_type": "resolved",
        "minutes_ago": 30,
    },
    {
        "thread_id": "demo-thread-003",
        "agents_used": ["general_agent"],
        "turn_count": 2,
        "total_tokens": 800,
        "total_cost_usd": 0.0024,
        "resolution_type": None,
        "minutes_ago": 60,
    },
    {
        "thread_id": "demo-thread-004",
        "agents_used": ["order_agent", "general_agent"],
        "turn_count": 8,
        "total_tokens": 4500,
        "total_cost_usd": 0.0135,
        "resolution_type": "escalated",
        "minutes_ago": 120,
    },
    {
        "thread_id": "demo-thread-005",
        "agents_used": ["refund_agent"],
        "turn_count": 4,
        "total_tokens": 2100,
        "total_cost_usd": 0.0063,
        "resolution_type": "resolved",
        "minutes_ago": 240,
    },
]
# Per-conversation analytics events: messages, tool calls, and one
# escalation (demo-thread-004, which also has resolution_type "escalated").
SAMPLE_EVENTS = [
    {"thread_id": "demo-thread-001", "event_type": "message", "agent_name": "order_agent", "tokens_used": 400, "cost_usd": 0.0012, "success": True},
    {"thread_id": "demo-thread-001", "event_type": "tool_call", "agent_name": "order_agent", "tool_name": "get_order_status", "tokens_used": 0, "cost_usd": 0.0, "success": True},
    {"thread_id": "demo-thread-002", "event_type": "message", "agent_name": "order_agent", "tokens_used": 1600, "cost_usd": 0.0048, "success": True},
    {"thread_id": "demo-thread-002", "event_type": "message", "agent_name": "refund_agent", "tokens_used": 1600, "cost_usd": 0.0048, "success": True},
    {"thread_id": "demo-thread-002", "event_type": "tool_call", "agent_name": "refund_agent", "tool_name": "process_refund", "tokens_used": 0, "cost_usd": 0.0, "success": True},
    {"thread_id": "demo-thread-003", "event_type": "message", "agent_name": "general_agent", "tokens_used": 800, "cost_usd": 0.0024, "success": True},
    {"thread_id": "demo-thread-004", "event_type": "message", "agent_name": "order_agent", "tokens_used": 2000, "cost_usd": 0.006, "success": True},
    {"thread_id": "demo-thread-004", "event_type": "escalation", "agent_name": "general_agent", "tokens_used": 2500, "cost_usd": 0.0075, "success": False},
    {"thread_id": "demo-thread-005", "event_type": "message", "agent_name": "refund_agent", "tokens_used": 2100, "cost_usd": 0.0063, "success": True},
]
# Conversation insert is idempotent (ON CONFLICT DO NOTHING) so the script
# can be re-run without duplicating rows; events are always appended.
_INSERT_CONVERSATION = """
INSERT INTO conversations
(thread_id, started_at, last_activity, turn_count, agents_used,
total_tokens, total_cost_usd, resolution_type, ended_at)
VALUES
(%(thread_id)s, %(started_at)s, %(last_activity)s, %(turn_count)s,
%(agents_used)s, %(total_tokens)s, %(total_cost_usd)s,
%(resolution_type)s, %(ended_at)s)
ON CONFLICT (thread_id) DO NOTHING
"""
_INSERT_EVENT = """
INSERT INTO analytics_events
(thread_id, event_type, agent_name, tool_name, tokens_used, cost_usd, success)
VALUES
(%(thread_id)s, %(event_type)s, %(agent_name)s, %(tool_name)s,
%(tokens_used)s, %(cost_usd)s, %(success)s)
"""
async def seed() -> None:
    """Insert the sample conversations and analytics events.

    Timestamps are derived from `minutes_ago` relative to the current UTC
    time so the dashboard's recent-activity views always have data.
    """
    reference = datetime.now(tz=timezone.utc)
    async with await psycopg.AsyncConnection.connect(DATABASE_URL) as conn:
        print("Seeding conversations...")
        for conv in SAMPLE_CONVERSATIONS:
            began = reference - timedelta(minutes=conv["minutes_ago"])
            # Approximate duration: two minutes per turn.
            touched = began + timedelta(minutes=conv["turn_count"] * 2)
            row = {
                "thread_id": conv["thread_id"],
                "started_at": began,
                "last_activity": touched,
                "turn_count": conv["turn_count"],
                "agents_used": conv["agents_used"],
                "total_tokens": conv["total_tokens"],
                "total_cost_usd": conv["total_cost_usd"],
                "resolution_type": conv["resolution_type"],
                # Only finished conversations get an end timestamp.
                "ended_at": touched if conv["resolution_type"] else None,
            }
            await conn.execute(_INSERT_CONVERSATION, row)
            print(f" Inserted conversation {conv['thread_id']}")
        print("Seeding analytics events...")
        for event in SAMPLE_EVENTS:
            row = {
                "thread_id": event["thread_id"],
                "event_type": event["event_type"],
                "agent_name": event.get("agent_name"),
                "tool_name": event.get("tool_name"),
                "tokens_used": event.get("tokens_used", 0),
                "cost_usd": event.get("cost_usd", 0.0),
                "success": event.get("success"),
            }
            await conn.execute(_INSERT_EVENT, row)
            print(f" Inserted event {event['event_type']} for {event['thread_id']}")
        await conn.commit()
    print("Done. Demo data seeded successfully.")


if __name__ == "__main__":
    asyncio.run(seed())

View File

@@ -0,0 +1,238 @@
openapi: "3.0.3"
info:
title: "E-Commerce API"
description: "Sample e-commerce API for Smart Support demo."
version: "1.0.0"
servers:
- url: "https://api.example-shop.com/v1"
description: "Production server"
paths:
/orders/{order_id}:
get:
operationId: getOrder
summary: "Get order details"
description: "Retrieve the full details of a specific order."
parameters:
- name: order_id
in: path
required: true
schema:
type: string
responses:
"200":
description: "Order details"
content:
application/json:
schema:
$ref: "#/components/schemas/Order"
/orders/{order_id}/cancel:
post:
operationId: cancelOrder
summary: "Cancel an order"
description: "Cancel an order that has not yet been shipped."
parameters:
- name: order_id
in: path
required: true
schema:
type: string
requestBody:
required: true
content:
application/json:
schema:
type: object
properties:
reason:
type: string
responses:
"200":
description: "Order cancelled"
"400":
description: "Order cannot be cancelled (already shipped)"
/orders/{order_id}/refund:
post:
operationId: refundOrder
summary: "Request a refund"
description: "Submit a refund request for a completed order."
parameters:
- name: order_id
in: path
required: true
schema:
type: string
requestBody:
required: true
content:
application/json:
schema:
type: object
properties:
amount:
type: number
description: "Refund amount in USD. Leave null for full refund."
reason:
type: string
responses:
"200":
description: "Refund submitted"
"400":
description: "Invalid refund request"
/customers/{customer_id}:
get:
operationId: getCustomer
summary: "Get customer profile"
description: "Retrieve customer profile and account information."
parameters:
- name: customer_id
in: path
required: true
schema:
type: string
responses:
"200":
description: "Customer profile"
content:
application/json:
schema:
$ref: "#/components/schemas/Customer"
/customers/{customer_id}/orders:
get:
operationId: listCustomerOrders
summary: "List customer orders"
description: "Get a paginated list of orders for a customer."
parameters:
- name: customer_id
in: path
required: true
schema:
type: string
- name: page
in: query
schema:
type: integer
default: 1
- name: per_page
in: query
schema:
type: integer
default: 20
responses:
"200":
description: "List of orders"
/products/{product_id}:
get:
operationId: getProduct
summary: "Get product details"
description: "Retrieve product information including inventory status."
parameters:
- name: product_id
in: path
required: true
schema:
type: string
responses:
"200":
description: "Product details"
/support/tickets:
post:
operationId: createSupportTicket
summary: "Create support ticket"
description: "Open a new support ticket for a customer issue."
requestBody:
required: true
content:
application/json:
schema:
$ref: "#/components/schemas/CreateTicketRequest"
responses:
"201":
description: "Ticket created"
/support/tickets/{ticket_id}:
get:
operationId: getSupportTicket
summary: "Get support ticket"
description: "Retrieve a support ticket and its conversation history."
parameters:
- name: ticket_id
in: path
required: true
schema:
type: string
responses:
"200":
description: "Ticket details"
components:
schemas:
Order:
type: object
properties:
order_id:
type: string
customer_id:
type: string
status:
type: string
enum: [pending, processing, shipped, delivered, cancelled, refunded]
items:
type: array
items:
$ref: "#/components/schemas/OrderItem"
total_usd:
type: number
created_at:
type: string
format: date-time
OrderItem:
type: object
properties:
product_id:
type: string
name:
type: string
quantity:
type: integer
unit_price_usd:
type: number
Customer:
type: object
properties:
customer_id:
type: string
email:
type: string
name:
type: string
tier:
type: string
enum: [standard, premium, vip]
created_at:
type: string
format: date-time
CreateTicketRequest:
type: object
required: [customer_id, subject, description]
properties:
customer_id:
type: string
subject:
type: string
description:
type: string
priority:
type: string
enum: [low, medium, high, urgent]
default: medium

View File

@@ -15,6 +15,16 @@ if TYPE_CHECKING:
from pathlib import Path
@pytest.fixture(autouse=True)
def clear_rate_limit_state() -> None:
    """Clear module-level rate limit state between tests to prevent leakage."""
    # NOTE(review): this is a yield fixture, so `-> Iterator[None]` would be
    # the more precise return annotation -- confirm the project's convention.
    # Import locally so merely collecting conftest does not import ws_handler.
    import app.ws_handler as ws_handler
    ws_handler._thread_timestamps.clear()
    yield
    # Clear again on teardown in case the test added entries.
    ws_handler._thread_timestamps.clear()
@pytest.fixture
def test_settings() -> Settings:
return Settings(

View File

@@ -315,7 +315,7 @@ class TestWebSocketValidation:
@pytest.mark.asyncio
async def test_content_too_long(self) -> None:
g, sm, im, cb, ws = _setup()
raw = json.dumps({"type": "message", "thread_id": "t1", "content": "x" * 9000})
raw = json.dumps({"type": "message", "thread_id": "t1", "content": "x" * 10001})
await dispatch_message(ws, g, sm, cb, raw, interrupt_manager=im)
assert ws.sent[0]["type"] == "error"
assert "too long" in ws.sent[0]["message"].lower()

View File

@@ -0,0 +1,156 @@
"""Tests for app.conversation_tracker module."""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock
import pytest
from app.conversation_tracker import (
ConversationTrackerProtocol,
NoOpConversationTracker,
PostgresConversationTracker,
)
pytestmark = pytest.mark.unit
def _make_pool() -> tuple[AsyncMock, AsyncMock]:
    """Create a (pool, connection) pair of mocks.

    The original annotation claimed `-> AsyncMock`, but the function has
    always returned a two-tuple; the annotation is corrected here.
    `pool.connection()` returns an async context manager that yields the
    mock connection, mirroring psycopg_pool's AsyncConnectionPool API.
    """
    pool = AsyncMock()
    conn = AsyncMock()
    conn.execute = AsyncMock()
    pool.connection = MagicMock(return_value=_AsyncContextManager(conn))
    return pool, conn
class _AsyncContextManager:
"""Async context manager helper."""
def __init__(self, value: object) -> None:
self._value = value
async def __aenter__(self) -> object:
return self._value
async def __aexit__(self, *args: object) -> None:
pass
class TestConversationTrackerProtocol:
    """Both implementations must structurally satisfy the Protocol."""

    def test_noop_satisfies_protocol(self) -> None:
        assert isinstance(NoOpConversationTracker(), ConversationTrackerProtocol)

    def test_postgres_satisfies_protocol(self) -> None:
        assert isinstance(PostgresConversationTracker(), ConversationTrackerProtocol)
class TestNoOpConversationTracker:
    """The no-op tracker must accept every call without raising."""

    @pytest.mark.asyncio
    async def test_ensure_conversation_does_nothing(self) -> None:
        await NoOpConversationTracker().ensure_conversation(AsyncMock(), "thread-1")

    @pytest.mark.asyncio
    async def test_record_turn_does_nothing(self) -> None:
        await NoOpConversationTracker().record_turn(AsyncMock(), "thread-1", "agent_a", 100, 0.05)

    @pytest.mark.asyncio
    async def test_resolve_does_nothing(self) -> None:
        await NoOpConversationTracker().resolve(AsyncMock(), "thread-1", "resolved")

    @pytest.mark.asyncio
    async def test_accepts_none_agent_name(self) -> None:
        await NoOpConversationTracker().record_turn(AsyncMock(), "thread-1", None, 0, 0.0)
class TestPostgresConversationTracker:
    """SQL and bound parameters issued by the Postgres-backed tracker."""

    @pytest.mark.asyncio
    async def test_ensure_conversation_executes_insert(self) -> None:
        pool, conn = _make_pool()
        await PostgresConversationTracker().ensure_conversation(pool, "thread-abc")
        conn.execute.assert_awaited_once()
        query, bound = conn.execute.call_args[0]
        assert "INSERT" in query
        assert "ON CONFLICT" in query
        assert bound["thread_id"] == "thread-abc"

    @pytest.mark.asyncio
    async def test_record_turn_executes_update(self) -> None:
        pool, conn = _make_pool()
        await PostgresConversationTracker().record_turn(pool, "thread-abc", "order_agent", 250, 0.12)
        conn.execute.assert_awaited_once()
        query, bound = conn.execute.call_args[0]
        assert "UPDATE" in query
        assert bound["thread_id"] == "thread-abc"
        assert bound["agent_name"] == "order_agent"
        assert bound["tokens"] == 250
        assert bound["cost"] == 0.12

    @pytest.mark.asyncio
    async def test_record_turn_accepts_none_agent_name(self) -> None:
        pool, conn = _make_pool()
        await PostgresConversationTracker().record_turn(pool, "thread-abc", None, 0, 0.0)
        conn.execute.assert_awaited_once()
        _, bound = conn.execute.call_args[0]
        assert bound["agent_name"] is None

    @pytest.mark.asyncio
    async def test_resolve_executes_update(self) -> None:
        pool, conn = _make_pool()
        await PostgresConversationTracker().resolve(pool, "thread-abc", "resolved")
        conn.execute.assert_awaited_once()
        query, bound = conn.execute.call_args[0]
        assert "UPDATE" in query
        assert bound["thread_id"] == "thread-abc"
        assert bound["resolution_type"] == "resolved"

    @pytest.mark.asyncio
    async def test_resolve_sets_ended_at(self) -> None:
        pool, conn = _make_pool()
        await PostgresConversationTracker().resolve(pool, "thread-abc", "escalated")
        query, _ = conn.execute.call_args[0]
        assert "ended_at" in query.lower()

    @pytest.mark.asyncio
    async def test_ensure_conversation_with_special_thread_id(self) -> None:
        pool, conn = _make_pool()
        await PostgresConversationTracker().ensure_conversation(pool, "thread-123-abc-XYZ")
        conn.execute.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_record_turn_with_zero_cost(self) -> None:
        pool, conn = _make_pool()
        await PostgresConversationTracker().record_turn(pool, "t1", "agent", 0, 0.0)
        conn.execute.assert_awaited_once()

View File

@@ -0,0 +1,213 @@
"""Edge case tests for ws_handler input validation and rate limiting."""
from __future__ import annotations
import json
from unittest.mock import AsyncMock, MagicMock
import pytest
from app.callbacks import TokenUsageCallbackHandler
from app.session_manager import SessionManager
from app.ws_handler import dispatch_message
pytestmark = pytest.mark.unit
def _make_ws() -> AsyncMock:
ws = AsyncMock()
ws.send_json = AsyncMock()
return ws
def _make_graph() -> AsyncMock:
graph = AsyncMock()
class AsyncIterHelper:
def __aiter__(self):
return self
async def __anext__(self):
raise StopAsyncIteration
graph.astream = MagicMock(return_value=AsyncIterHelper())
state = MagicMock()
state.tasks = ()
graph.aget_state = AsyncMock(return_value=state)
graph.intent_classifier = None
graph.agent_registry = None
return graph
@pytest.mark.unit
class TestEmptyMessageHandling:
    """Empty or whitespace-only content must be rejected with an error."""

    @pytest.mark.asyncio
    async def test_empty_message_content_returns_error(self) -> None:
        ws, graph = _make_ws(), _make_graph()
        session_mgr = SessionManager()
        session_mgr.touch("t1")
        callbacks = TokenUsageCallbackHandler()
        payload = json.dumps({"type": "message", "thread_id": "t1", "content": ""})
        await dispatch_message(ws, graph, session_mgr, callbacks, payload)
        sent = ws.send_json.call_args[0][0]
        assert sent["type"] == "error"
        lowered = sent["message"].lower()
        assert "content" in lowered or "missing" in lowered

    @pytest.mark.asyncio
    async def test_whitespace_only_message_treated_as_empty(self) -> None:
        ws, graph = _make_ws(), _make_graph()
        session_mgr = SessionManager()
        session_mgr.touch("t1")
        callbacks = TokenUsageCallbackHandler()
        payload = json.dumps({"type": "message", "thread_id": "t1", "content": " "})
        await dispatch_message(ws, graph, session_mgr, callbacks, payload)
        assert ws.send_json.call_args[0][0]["type"] == "error"
@pytest.mark.unit
class TestOversizedMessageHandling:
    """Boundary behavior for content length and raw frame size limits."""

    @pytest.mark.asyncio
    async def test_content_over_10000_chars_returns_error(self) -> None:
        ws, graph = _make_ws(), _make_graph()
        session_mgr = SessionManager()
        session_mgr.touch("t1")
        callbacks = TokenUsageCallbackHandler()
        payload = json.dumps(
            {"type": "message", "thread_id": "t1", "content": "x" * 10001}
        )
        await dispatch_message(ws, graph, session_mgr, callbacks, payload)
        sent = ws.send_json.call_args[0][0]
        assert sent["type"] == "error"
        assert "too long" in sent["message"].lower()

    @pytest.mark.asyncio
    async def test_content_exactly_10000_chars_is_accepted(self) -> None:
        ws, graph = _make_ws(), _make_graph()
        session_mgr = SessionManager()
        session_mgr.touch("t1")
        callbacks = TokenUsageCallbackHandler()
        payload = json.dumps(
            {"type": "message", "thread_id": "t1", "content": "x" * 10000}
        )
        await dispatch_message(ws, graph, session_mgr, callbacks, payload)
        sent = ws.send_json.call_args[0][0]
        # Exactly at the limit: must not be rejected for length.
        lowered = sent.get("message", "").lower()
        assert sent["type"] != "error" or "too long" not in lowered

    @pytest.mark.asyncio
    async def test_raw_message_over_32kb_returns_error(self) -> None:
        ws, graph = _make_ws(), _make_graph()
        session_mgr = SessionManager()
        callbacks = TokenUsageCallbackHandler()
        await dispatch_message(ws, graph, session_mgr, callbacks, "x" * 40_000)
        sent = ws.send_json.call_args[0][0]
        assert sent["type"] == "error"
        assert "too large" in sent["message"].lower()
@pytest.mark.unit
class TestInvalidJsonHandling:
    """Malformed or non-object payloads must produce a JSON error reply."""

    @pytest.mark.asyncio
    async def test_invalid_json_returns_error(self) -> None:
        ws, graph = _make_ws(), _make_graph()
        session_mgr = SessionManager()
        callbacks = TokenUsageCallbackHandler()
        await dispatch_message(ws, graph, session_mgr, callbacks, "not valid json {{")
        sent = ws.send_json.call_args[0][0]
        assert sent["type"] == "error"
        assert "invalid json" in sent["message"].lower()

    @pytest.mark.asyncio
    async def test_empty_string_returns_json_error(self) -> None:
        ws, graph = _make_ws(), _make_graph()
        session_mgr = SessionManager()
        callbacks = TokenUsageCallbackHandler()
        await dispatch_message(ws, graph, session_mgr, callbacks, "")
        assert ws.send_json.call_args[0][0]["type"] == "error"

    @pytest.mark.asyncio
    async def test_json_array_not_object_returns_error(self) -> None:
        ws, graph = _make_ws(), _make_graph()
        session_mgr = SessionManager()
        callbacks = TokenUsageCallbackHandler()
        await dispatch_message(ws, graph, session_mgr, callbacks, '["not", "an", "object"]')
        assert ws.send_json.call_args[0][0]["type"] == "error"
@pytest.mark.unit
class TestRateLimiting:
    """Per-thread rate limiting in dispatch_message (10 messages / 10 seconds)."""

    @pytest.mark.asyncio
    async def test_rapid_fire_messages_rate_limited(self) -> None:
        """Sending 11 messages back-to-back on one thread must trip the limit."""
        ws = _make_ws()
        _make_graph()  # ensure graph factory works, not needed directly
        sm = SessionManager()
        cb = TokenUsageCallbackHandler()
        sm.touch("t1")
        # Simulate 11 rapid messages (exceeds 10 per 10 seconds limit)
        rate_limit_triggered = False
        for i in range(11):
            graph2 = _make_graph()  # fresh graph each time
            await dispatch_message(ws, graph2, sm, cb, json.dumps({
                "type": "message",
                "thread_id": "t1",
                "content": f"message {i}",
            }))
            last_call = ws.send_json.call_args[0][0]
            if last_call["type"] == "error" and "rate" in last_call.get("message", "").lower():
                rate_limit_triggered = True
                break
        assert rate_limit_triggered, "Rate limiting should trigger after 10 rapid messages"

    @pytest.mark.asyncio
    async def test_different_threads_have_separate_rate_limits(self) -> None:
        """Five messages each on two threads stay under the per-thread limit.

        Fix: the previous version inspected only the FINAL send_json payload,
        so a rate-limit error emitted on any earlier message went undetected.
        Now every frame the handler sent is checked.
        """
        ws = _make_ws()
        sm = SessionManager()
        cb = TokenUsageCallbackHandler()
        sm.touch("t1")
        sm.touch("t2")
        # Send 5 messages on t1 and 5 on t2 -- neither should be rate limited
        for i in range(5):
            graph1 = _make_graph()
            graph2 = _make_graph()
            await dispatch_message(ws, graph1, sm, cb, json.dumps({
                "type": "message", "thread_id": "t1", "content": f"msg {i}",
            }))
            await dispatch_message(ws, graph2, sm, cb, json.dumps({
                "type": "message", "thread_id": "t2", "content": f"msg {i}",
            }))
        # Check every sent frame, not just the last, so an intermittent
        # rate-limit error cannot slip through unnoticed.
        for sent in ws.send_json.call_args_list:
            payload = sent[0][0]
            assert "rate" not in payload.get("message", "").lower()

View File

@@ -0,0 +1,175 @@
"""Tests for app.tools.error_handler module."""
from __future__ import annotations
from unittest.mock import AsyncMock, patch
import httpx
import pytest
from app.tools.error_handler import (
ErrorCategory,
classify_error,
with_retry,
)
# Apply the "unit" marker to every test collected from this module.
pytestmark = pytest.mark.unit
class TestErrorClassification:
    """classify_error() maps transport/HTTP exceptions onto ErrorCategory.

    The request/response/exception boilerplate that was previously repeated
    in every status-code test is deduplicated into _status_error(); each
    test keeps its original name and asserted category.
    """

    @staticmethod
    def _status_error(status: int) -> httpx.HTTPStatusError:
        """Build an HTTPStatusError whose response carries *status*."""
        request = httpx.Request("GET", "http://example.com")
        response = httpx.Response(status, request=request)
        return httpx.HTTPStatusError(str(status), request=request, response=response)

    def test_timeout_exception_is_timeout(self) -> None:
        exc = httpx.TimeoutException("timed out")
        assert classify_error(exc) == ErrorCategory.TIMEOUT

    def test_connect_error_is_network(self) -> None:
        exc = httpx.ConnectError("connection refused")
        assert classify_error(exc) == ErrorCategory.NETWORK

    def test_401_is_auth_failure(self) -> None:
        assert classify_error(self._status_error(401)) == ErrorCategory.AUTH_FAILURE

    def test_403_is_auth_failure(self) -> None:
        assert classify_error(self._status_error(403)) == ErrorCategory.AUTH_FAILURE

    def test_429_is_retryable(self) -> None:
        assert classify_error(self._status_error(429)) == ErrorCategory.RETRYABLE

    def test_500_is_retryable(self) -> None:
        assert classify_error(self._status_error(500)) == ErrorCategory.RETRYABLE

    def test_502_is_retryable(self) -> None:
        assert classify_error(self._status_error(502)) == ErrorCategory.RETRYABLE

    def test_503_is_retryable(self) -> None:
        assert classify_error(self._status_error(503)) == ErrorCategory.RETRYABLE

    def test_404_is_non_retryable(self) -> None:
        assert classify_error(self._status_error(404)) == ErrorCategory.NON_RETRYABLE

    def test_400_is_non_retryable(self) -> None:
        assert classify_error(self._status_error(400)) == ErrorCategory.NON_RETRYABLE

    def test_generic_exception_is_non_retryable(self) -> None:
        exc = ValueError("bad value")
        assert classify_error(exc) == ErrorCategory.NON_RETRYABLE

    def test_runtime_error_is_non_retryable(self) -> None:
        exc = RuntimeError("boom")
        assert classify_error(exc) == ErrorCategory.NON_RETRYABLE
class TestWithRetry:
    """with_retry(): retry policy per ErrorCategory plus exponential backoff.

    HTTPStatusError construction is deduplicated into _status_error();
    every test keeps its original name, call signature, and assertions.
    """

    @staticmethod
    def _status_error(status: int) -> httpx.HTTPStatusError:
        """Build an HTTPStatusError whose response carries *status*."""
        request = httpx.Request("GET", "http://example.com")
        response = httpx.Response(status, request=request)
        return httpx.HTTPStatusError(str(status), request=request, response=response)

    @pytest.mark.asyncio
    async def test_succeeds_on_first_try(self) -> None:
        fn = AsyncMock(return_value="ok")
        result = await with_retry(fn, max_retries=3, base_delay=0.0)
        assert result == "ok"
        assert fn.call_count == 1

    @pytest.mark.asyncio
    async def test_retries_on_retryable_error(self) -> None:
        """A 503 is retried until the call succeeds (two failures, one success)."""
        retryable_exc = self._status_error(503)
        fn = AsyncMock(side_effect=[retryable_exc, retryable_exc, "success"])
        with patch("app.tools.error_handler.asyncio.sleep", new_callable=AsyncMock):
            result = await with_retry(fn, max_retries=3, base_delay=0.0)
        assert result == "success"
        assert fn.call_count == 3

    @pytest.mark.asyncio
    async def test_does_not_retry_non_retryable_error(self) -> None:
        fn = AsyncMock(side_effect=self._status_error(404))
        with pytest.raises(httpx.HTTPStatusError):
            await with_retry(fn, max_retries=3, base_delay=0.0)
        assert fn.call_count == 1

    @pytest.mark.asyncio
    async def test_does_not_retry_auth_failure(self) -> None:
        fn = AsyncMock(side_effect=self._status_error(401))
        with pytest.raises(httpx.HTTPStatusError):
            await with_retry(fn, max_retries=3, base_delay=0.0)
        assert fn.call_count == 1

    @pytest.mark.asyncio
    async def test_raises_after_max_retries_exhausted(self) -> None:
        """A persistently failing retryable call re-raises after max_retries."""
        fn = AsyncMock(side_effect=self._status_error(500))
        with (
            patch("app.tools.error_handler.asyncio.sleep", new_callable=AsyncMock),
            pytest.raises(httpx.HTTPStatusError),
        ):
            await with_retry(fn, max_retries=3, base_delay=0.0)
        assert fn.call_count == 3

    @pytest.mark.asyncio
    async def test_does_not_retry_timeout(self) -> None:
        """TimeoutException is TIMEOUT category -- not retried by default."""
        fn = AsyncMock(side_effect=httpx.TimeoutException("timed out"))
        with pytest.raises(httpx.TimeoutException):
            await with_retry(fn, max_retries=3, base_delay=0.0)
        assert fn.call_count == 1

    @pytest.mark.asyncio
    async def test_exponential_backoff_increases_delay(self) -> None:
        """Successive sleep delays must grow (exponential backoff)."""
        retryable_exc = self._status_error(503)
        fn = AsyncMock(side_effect=[retryable_exc, retryable_exc, "done"])
        sleep_delays: list[float] = []

        async def capture_sleep(delay: float) -> None:
            sleep_delays.append(delay)

        with patch("app.tools.error_handler.asyncio.sleep", side_effect=capture_sleep):
            await with_retry(fn, max_retries=3, base_delay=1.0)
        assert len(sleep_delays) == 2
        assert sleep_delays[1] > sleep_delays[0]

View File

@@ -13,7 +13,7 @@ class TestMainModule:
assert app.title == "Smart Support"
def test_app_version(self) -> None:
assert app.version == "0.4.0"
assert app.version == "0.5.0"
def test_agents_yaml_path_exists(self) -> None:
assert AGENTS_YAML.name == "agents.yaml"
@@ -33,3 +33,10 @@ class TestMainModule:
def test_analytics_router_registered(self) -> None:
routes = [r.path for r in app.routes if hasattr(r, "path")]
assert any("analytics" in p for p in routes)
def test_health_route_registered(self) -> None:
routes = [r.path for r in app.routes if hasattr(r, "path")]
assert "/api/health" in routes
def test_app_version_is_0_5_0(self) -> None:
assert app.version == "0.5.0"

View File

@@ -138,7 +138,7 @@ class TestDispatchMessage:
sm = SessionManager()
cb = TokenUsageCallbackHandler()
msg = json.dumps({"type": "message", "thread_id": "t1", "content": "x" * 9000})
msg = json.dumps({"type": "message", "thread_id": "t1", "content": "x" * 10001})
await dispatch_message(ws, graph, sm, cb, msg)
call_data = ws.send_json.call_args[0][0]
assert call_data["type"] == "error"
@@ -364,3 +364,80 @@ class TestInterruptHelpers:
state.tasks = ()
data = _extract_interrupt(state)
assert data["action"] == "unknown"
@pytest.mark.unit
class TestDispatchMessageWithTracking:
@pytest.mark.asyncio
async def test_conversation_tracker_called_on_message(self) -> None:
ws = _make_ws()
graph = _make_graph()
sm = SessionManager()
cb = TokenUsageCallbackHandler()
tracker = AsyncMock()
pool = MagicMock()
sm.touch("t1")
msg = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"})
await dispatch_message(
ws, graph, sm, cb, msg,
conversation_tracker=tracker,
pool=pool,
)
tracker.ensure_conversation.assert_awaited_once_with(pool, "t1")
tracker.record_turn.assert_awaited_once()
@pytest.mark.asyncio
async def test_analytics_recorder_called_on_message(self) -> None:
ws = _make_ws()
graph = _make_graph()
sm = SessionManager()
cb = TokenUsageCallbackHandler()
recorder = AsyncMock()
pool = MagicMock()
sm.touch("t1")
msg = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"})
await dispatch_message(
ws, graph, sm, cb, msg,
analytics_recorder=recorder,
pool=pool,
)
recorder.record.assert_awaited_once()
@pytest.mark.asyncio
async def test_tracker_failure_does_not_break_chat(self) -> None:
ws = _make_ws()
graph = _make_graph()
sm = SessionManager()
cb = TokenUsageCallbackHandler()
tracker = AsyncMock()
tracker.ensure_conversation.side_effect = RuntimeError("DB down")
pool = MagicMock()
sm.touch("t1")
msg = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"})
# Should not raise despite tracker failure
await dispatch_message(
ws, graph, sm, cb, msg,
conversation_tracker=tracker,
pool=pool,
)
last_call = ws.send_json.call_args[0][0]
assert last_call["type"] == "message_complete"
@pytest.mark.asyncio
async def test_no_tracker_no_error(self) -> None:
ws = _make_ws()
graph = _make_graph()
sm = SessionManager()
cb = TokenUsageCallbackHandler()
sm.touch("t1")
msg = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"})
# No tracker or recorder passed -- should work fine
await dispatch_message(ws, graph, sm, cb, msg)
last_call = ws.send_json.call_args[0][0]
assert last_call["type"] == "message_complete"