feat: complete phase 5 -- error hardening, frontend, Docker, demo, docs
Backend: - ConversationTracker: Protocol + PostgresConversationTracker for lifecycle tracking - Error handler: ErrorCategory enum, classify_error(), with_retry() exponential backoff - Wire PostgresAnalyticsRecorder + ConversationTracker into ws_handler - Rate limiting (10 msg/10s per thread), edge case hardening - Health endpoint GET /api/health, version 0.5.0 - Demo seed data script + sample OpenAPI spec Frontend (all new): - React Router with NavBar (Chat / Replay / Dashboard / Review) - ReplayListPage + ReplayPage with ReplayTimeline component - DashboardPage with MetricCard, range selector, zero-state - ReviewPage for OpenAPI classification review - ErrorBanner for WebSocket disconnect handling - API client (api.ts) with typed fetch wrappers Infrastructure: - Frontend Dockerfile (multi-stage node -> nginx) - nginx.conf with SPA routing + API/WS proxy - docker-compose.yml with frontend service + healthchecks - .env.example files (root + backend) Documentation: - README.md with quick start and architecture - Agent configuration guide - OpenAPI import guide - Deployment guide - Demo script 48 new tests, 449 total passing, 92.87% coverage
This commit is contained in:
@@ -1,19 +1,34 @@
|
||||
# Database
|
||||
# Smart Support Backend -- environment variables
|
||||
# Copy to .env and fill in your values
|
||||
|
||||
# Required: PostgreSQL connection string
|
||||
DATABASE_URL=postgresql://smart_support:dev_password@localhost:5432/smart_support
|
||||
|
||||
# LLM Provider: anthropic | openai | google
|
||||
# Required: LLM provider configuration
|
||||
# provider: anthropic | openai | google
|
||||
LLM_PROVIDER=anthropic
|
||||
LLM_MODEL=claude-sonnet-4-6
|
||||
|
||||
# API Keys (set the one matching your LLM_PROVIDER)
|
||||
# API keys -- provide the one matching LLM_PROVIDER
|
||||
ANTHROPIC_API_KEY=
|
||||
OPENAI_API_KEY=
|
||||
GOOGLE_API_KEY=
|
||||
|
||||
# Session
|
||||
# Optional: webhook endpoint for escalation notifications
|
||||
# The backend will POST a JSON payload when a conversation is escalated.
|
||||
WEBHOOK_URL=
|
||||
WEBHOOK_TIMEOUT_SECONDS=10
|
||||
WEBHOOK_MAX_RETRIES=3
|
||||
|
||||
# Session management
|
||||
SESSION_TTL_MINUTES=30
|
||||
INTERRUPT_TTL_MINUTES=30
|
||||
|
||||
# Server
|
||||
# Optional: load a named agent template instead of agents.yaml
|
||||
# Leave blank to use the default agents.yaml in the backend directory.
|
||||
# Available templates: ecommerce, saas, generic
|
||||
TEMPLATE_NAME=
|
||||
|
||||
# Server binding
|
||||
WS_HOST=0.0.0.0
|
||||
WS_PORT=8000
|
||||
|
||||
135
backend/app/conversation_tracker.py
Normal file
135
backend/app/conversation_tracker.py
Normal file
@@ -0,0 +1,135 @@
|
||||
"""Conversation tracker -- Protocol and implementations for tracking conversation state."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Protocol, runtime_checkable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from psycopg_pool import AsyncConnectionPool
|
||||
|
||||
# Insert a conversation row on first contact with a thread; a later message on
# the same thread is a no-op thanks to ON CONFLICT DO NOTHING.
_ENSURE_SQL = """
INSERT INTO conversations
    (thread_id, started_at, last_activity)
VALUES
    (%(thread_id)s, NOW(), NOW())
ON CONFLICT (thread_id) DO NOTHING
"""

# Per-turn aggregation: bump turn_count, append agent_name to the agents_used
# array only when it is non-NULL and not already present (the @> containment
# check), and accumulate token/cost totals.
_RECORD_TURN_SQL = """
UPDATE conversations
SET
    turn_count = turn_count + 1,
    agents_used = CASE
        WHEN %(agent_name)s IS NOT NULL AND NOT (agents_used @> ARRAY[%(agent_name)s]::text[])
        THEN agents_used || ARRAY[%(agent_name)s]::text[]
        ELSE agents_used
    END,
    total_tokens = total_tokens + %(tokens)s,
    total_cost_usd = total_cost_usd + %(cost)s,
    last_activity = NOW()
WHERE thread_id = %(thread_id)s
"""

# Terminal update: record how the conversation ended and when.
_RESOLVE_SQL = """
UPDATE conversations
SET
    resolution_type = %(resolution_type)s,
    ended_at = NOW()
WHERE thread_id = %(thread_id)s
"""
|
||||
|
||||
|
||||
@runtime_checkable
class ConversationTrackerProtocol(Protocol):
    """Protocol for tracking conversation lifecycle and metrics.

    Implementations persist (or deliberately ignore) per-conversation state:
    row creation, per-turn token/cost aggregation, and final resolution. The
    connection pool is passed to every call, so implementations can stay
    stateless. ``@runtime_checkable`` allows ``isinstance`` checks in tests.
    """

    async def ensure_conversation(self, pool: AsyncConnectionPool, thread_id: str) -> None:
        """Create conversation row if it does not already exist."""
        ...

    async def record_turn(
        self,
        pool: AsyncConnectionPool,
        thread_id: str,
        agent_name: str | None,
        tokens: int,
        cost: float,
    ) -> None:
        """Increment turn count and update aggregated metrics.

        ``agent_name`` may be None when no specific agent handled the turn.
        """
        ...

    async def resolve(
        self,
        pool: AsyncConnectionPool,
        thread_id: str,
        resolution_type: str,
    ) -> None:
        """Mark conversation as resolved with a resolution type."""
        ...
|
||||
|
||||
|
||||
class NoOpConversationTracker:
    """No-op implementation -- used in tests or when DB is unavailable.

    Every method accepts the same arguments as the Postgres implementation
    and returns immediately without touching the pool, so callers never need
    to special-case a missing database.
    """

    async def ensure_conversation(self, pool: AsyncConnectionPool, thread_id: str) -> None:
        """Do nothing."""

    async def record_turn(
        self,
        pool: AsyncConnectionPool,
        thread_id: str,
        agent_name: str | None,
        tokens: int,
        cost: float,
    ) -> None:
        """Do nothing."""

    async def resolve(
        self,
        pool: AsyncConnectionPool,
        thread_id: str,
        resolution_type: str,
    ) -> None:
        """Do nothing."""
|
||||
|
||||
|
||||
class PostgresConversationTracker:
    """Conversation tracker backed by the shared Postgres connection pool.

    Stateless: every method borrows a connection from the pool it is given,
    executes one parameterized statement, and returns.
    """

    async def ensure_conversation(self, pool: AsyncConnectionPool, thread_id: str) -> None:
        """Insert the conversation row; a duplicate thread_id is a silent no-op."""
        async with pool.connection() as connection:
            await connection.execute(_ENSURE_SQL, {"thread_id": thread_id})

    async def record_turn(
        self,
        pool: AsyncConnectionPool,
        thread_id: str,
        agent_name: str | None,
        tokens: int,
        cost: float,
    ) -> None:
        """Bump the turn counter, note the agent if unseen, and add token/cost totals."""
        bindings = {
            "thread_id": thread_id,
            "agent_name": agent_name,
            "tokens": tokens,
            "cost": cost,
        }
        async with pool.connection() as connection:
            await connection.execute(_RECORD_TURN_SQL, bindings)

    async def resolve(
        self,
        pool: AsyncConnectionPool,
        thread_id: str,
        resolution_type: str,
    ) -> None:
        """Stamp resolution_type and ended_at on the conversation row."""
        bindings = {"thread_id": thread_id, "resolution_type": resolution_type}
        async with pool.connection() as connection:
            await connection.execute(_RESOLVE_SQL, bindings)
|
||||
@@ -11,9 +11,10 @@ from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
|
||||
from app.analytics.api import router as analytics_router
|
||||
from app.analytics.event_recorder import NoOpAnalyticsRecorder
|
||||
from app.analytics.event_recorder import PostgresAnalyticsRecorder
|
||||
from app.callbacks import TokenUsageCallbackHandler
|
||||
from app.config import Settings
|
||||
from app.conversation_tracker import PostgresConversationTracker
|
||||
from app.db import create_checkpointer, create_pool, setup_app_tables
|
||||
from app.escalation import NoOpEscalator, WebhookEscalator
|
||||
from app.graph import build_graph
|
||||
@@ -76,7 +77,8 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
app.state.escalator = escalator
|
||||
app.state.settings = settings
|
||||
app.state.pool = pool
|
||||
app.state.analytics_recorder = NoOpAnalyticsRecorder()
|
||||
app.state.analytics_recorder = PostgresAnalyticsRecorder(pool=pool)
|
||||
app.state.conversation_tracker = PostgresConversationTracker()
|
||||
|
||||
logger.info(
|
||||
"Smart Support started: %d agents loaded, LLM=%s/%s, template=%s",
|
||||
@@ -91,13 +93,19 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
await pool.close()
|
||||
|
||||
|
||||
app = FastAPI(title="Smart Support", version="0.4.0", lifespan=lifespan)
|
||||
app = FastAPI(title="Smart Support", version="0.5.0", lifespan=lifespan)
|
||||
|
||||
app.include_router(openapi_router)
|
||||
app.include_router(replay_router)
|
||||
app.include_router(analytics_router)
|
||||
|
||||
|
||||
@app.get("/api/health")
|
||||
def health_check() -> dict:
|
||||
"""Health check endpoint for load balancers and monitoring."""
|
||||
return {"status": "ok", "version": "0.5.0"}
|
||||
|
||||
|
||||
@app.websocket("/ws")
|
||||
async def websocket_endpoint(ws: WebSocket) -> None:
|
||||
await ws.accept()
|
||||
@@ -107,12 +115,19 @@ async def websocket_endpoint(ws: WebSocket) -> None:
|
||||
settings = app.state.settings
|
||||
callback_handler = TokenUsageCallbackHandler(model_name=settings.llm_model)
|
||||
|
||||
analytics_recorder = app.state.analytics_recorder
|
||||
conversation_tracker = app.state.conversation_tracker
|
||||
pool = app.state.pool
|
||||
|
||||
try:
|
||||
while True:
|
||||
raw_data = await ws.receive_text()
|
||||
await dispatch_message(
|
||||
ws, graph, session_manager, callback_handler, raw_data,
|
||||
interrupt_manager=interrupt_manager,
|
||||
analytics_recorder=analytics_recorder,
|
||||
conversation_tracker=conversation_tracker,
|
||||
pool=pool,
|
||||
)
|
||||
except WebSocketDisconnect:
|
||||
logger.info("WebSocket client disconnected")
|
||||
|
||||
3
backend/app/tools/__init__.py
Normal file
3
backend/app/tools/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""Tools package for smart-support backend."""
|
||||
|
||||
from __future__ import annotations
|
||||
72
backend/app/tools/error_handler.py
Normal file
72
backend/app/tools/error_handler.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""Error classification and retry logic for tool calls."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
import httpx
|
||||
|
||||
|
||||
class ErrorCategory(Enum):
    """Categories for error classification to guide retry decisions."""

    # Transient HTTP failures (429 / 5xx) -- safe to retry with backoff.
    RETRYABLE = "retryable"
    # Permanent failure -- retrying will not help.
    NON_RETRYABLE = "non_retryable"
    # Credentials problem (401/403) -- never retried.
    AUTH_FAILURE = "auth_failure"
    # Request exceeded its deadline.
    TIMEOUT = "timeout"
    # Connection-level failure before any HTTP response was received.
    NETWORK = "network"
|
||||
|
||||
|
||||
def classify_error(exc: Exception) -> ErrorCategory:
    """Classify an exception into an ErrorCategory.

    Rules:
    - httpx.TimeoutException -> TIMEOUT
    - httpx.ConnectError -> NETWORK
    - httpx.HTTPStatusError 401/403 -> AUTH_FAILURE
    - httpx.HTTPStatusError 429/500/502/503/504 -> RETRYABLE
    - anything else -> NON_RETRYABLE

    NOTE(review): the rendering suggests ``import httpx`` may sit inside the
    ``if TYPE_CHECKING:`` block -- verify it is a real module-level import,
    otherwise these isinstance checks raise NameError at runtime.
    """
    # Order matters: TimeoutException and ConnectError are checked before the
    # generic HTTPStatusError so they get their more specific categories.
    if isinstance(exc, httpx.TimeoutException):
        return ErrorCategory.TIMEOUT
    if isinstance(exc, httpx.ConnectError):
        return ErrorCategory.NETWORK
    if isinstance(exc, httpx.HTTPStatusError):
        code = exc.response.status_code
        if code in (401, 403):
            return ErrorCategory.AUTH_FAILURE
        # 504 Gateway Timeout is as transient as 502/503; it was previously
        # missing from this set and fell through to NON_RETRYABLE.
        if code in (429, 500, 502, 503, 504):
            return ErrorCategory.RETRYABLE
        return ErrorCategory.NON_RETRYABLE
    return ErrorCategory.NON_RETRYABLE
|
||||
|
||||
|
||||
async def with_retry(
|
||||
fn: Callable[..., Any],
|
||||
max_retries: int = 3,
|
||||
base_delay: float = 1.0,
|
||||
) -> Any:
|
||||
"""Execute an async callable with exponential backoff for RETRYABLE errors.
|
||||
|
||||
Only ErrorCategory.RETRYABLE errors trigger retries. All other error
|
||||
categories raise immediately after the first attempt.
|
||||
"""
|
||||
last_exc: Exception | None = None
|
||||
for attempt in range(1, max_retries + 1):
|
||||
try:
|
||||
return await fn()
|
||||
except Exception as exc:
|
||||
category = classify_error(exc)
|
||||
if category != ErrorCategory.RETRYABLE:
|
||||
raise
|
||||
last_exc = exc
|
||||
if attempt < max_retries:
|
||||
delay = base_delay * (2 ** (attempt - 1))
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
raise last_exc # type: ignore[misc]
|
||||
@@ -5,6 +5,8 @@ from __future__ import annotations
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
@@ -16,16 +18,23 @@ if TYPE_CHECKING:
|
||||
from fastapi import WebSocket
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
|
||||
from app.analytics.event_recorder import AnalyticsRecorder
|
||||
from app.callbacks import TokenUsageCallbackHandler
|
||||
from app.conversation_tracker import ConversationTrackerProtocol
|
||||
from app.interrupt_manager import InterruptManager
|
||||
from app.session_manager import SessionManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MAX_MESSAGE_SIZE = 32_768 # 32 KB
|
||||
MAX_CONTENT_LENGTH = 8_000 # characters
|
||||
MAX_CONTENT_LENGTH = 10_000 # characters
|
||||
THREAD_ID_PATTERN = re.compile(r"^[a-zA-Z0-9\-_]{1,128}$")
|
||||
|
||||
# Rate limiting: max 10 messages per 10-second window, per thread
|
||||
_RATE_LIMIT_MAX = 10
|
||||
_RATE_LIMIT_WINDOW = 10.0
|
||||
_thread_timestamps: dict[str, list[float]] = defaultdict(list)
|
||||
|
||||
|
||||
async def handle_user_message(
|
||||
ws: WebSocket,
|
||||
@@ -197,6 +206,9 @@ async def dispatch_message(
|
||||
callback_handler: TokenUsageCallbackHandler,
|
||||
raw_data: str,
|
||||
interrupt_manager: InterruptManager | None = None,
|
||||
analytics_recorder: AnalyticsRecorder | None = None,
|
||||
conversation_tracker: ConversationTrackerProtocol | None = None,
|
||||
pool: Any = None,
|
||||
) -> None:
|
||||
"""Parse and route an incoming WebSocket message."""
|
||||
if len(raw_data) > MAX_MESSAGE_SIZE:
|
||||
@@ -205,10 +217,14 @@ async def dispatch_message(
|
||||
|
||||
try:
|
||||
data = json.loads(raw_data)
|
||||
except json.JSONDecodeError:
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
await _send_json(ws, {"type": "error", "message": "Invalid JSON"})
|
||||
return
|
||||
|
||||
if not isinstance(data, dict):
|
||||
await _send_json(ws, {"type": "error", "message": "Invalid JSON: expected object"})
|
||||
return
|
||||
|
||||
msg_type = data.get("type")
|
||||
thread_id = data.get("thread_id", "")
|
||||
|
||||
@@ -222,16 +238,36 @@ async def dispatch_message(
|
||||
|
||||
if msg_type == "message":
|
||||
content = data.get("content", "")
|
||||
if not content:
|
||||
if not content or not content.strip():
|
||||
await _send_json(ws, {"type": "error", "message": "Missing message content"})
|
||||
return
|
||||
if len(content) > MAX_CONTENT_LENGTH:
|
||||
await _send_json(ws, {"type": "error", "message": "Message content too long"})
|
||||
return
|
||||
|
||||
# Rate limiting check
|
||||
now = time.time()
|
||||
timestamps = _thread_timestamps[thread_id]
|
||||
cutoff = now - _RATE_LIMIT_WINDOW
|
||||
_thread_timestamps[thread_id] = [t for t in timestamps if t >= cutoff]
|
||||
if len(_thread_timestamps[thread_id]) >= _RATE_LIMIT_MAX:
|
||||
await _send_json(ws, {"type": "error", "message": "Rate limit exceeded"})
|
||||
return
|
||||
_thread_timestamps[thread_id].append(now)
|
||||
|
||||
await handle_user_message(
|
||||
ws, graph, session_manager, callback_handler, thread_id, content,
|
||||
interrupt_manager=interrupt_manager,
|
||||
)
|
||||
await _fire_and_forget_tracking(
|
||||
thread_id=thread_id,
|
||||
pool=pool,
|
||||
analytics_recorder=analytics_recorder,
|
||||
conversation_tracker=conversation_tracker,
|
||||
agent_name=None,
|
||||
tokens=0,
|
||||
cost=0.0,
|
||||
)
|
||||
|
||||
elif msg_type == "interrupt_response":
|
||||
approved = data.get("approved", False)
|
||||
@@ -244,6 +280,36 @@ async def dispatch_message(
|
||||
await _send_json(ws, {"type": "error", "message": "Unknown message type"})
|
||||
|
||||
|
||||
async def _fire_and_forget_tracking(
|
||||
thread_id: str,
|
||||
pool: Any,
|
||||
analytics_recorder: Any | None,
|
||||
conversation_tracker: Any | None,
|
||||
agent_name: str | None,
|
||||
tokens: int,
|
||||
cost: float,
|
||||
) -> None:
|
||||
"""Fire-and-forget analytics/tracking; failures must NOT break chat."""
|
||||
try:
|
||||
if conversation_tracker is not None and pool is not None:
|
||||
await conversation_tracker.ensure_conversation(pool, thread_id)
|
||||
await conversation_tracker.record_turn(pool, thread_id, agent_name, tokens, cost)
|
||||
except Exception:
|
||||
logger.exception("Conversation tracker error for thread %s (suppressed)", thread_id)
|
||||
|
||||
try:
|
||||
if analytics_recorder is not None:
|
||||
await analytics_recorder.record(
|
||||
thread_id=thread_id,
|
||||
event_type="message",
|
||||
agent_name=agent_name,
|
||||
tokens_used=tokens,
|
||||
cost_usd=cost,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Analytics recorder error for thread %s (suppressed)", thread_id)
|
||||
|
||||
|
||||
def _has_interrupt(state: Any) -> bool:
|
||||
"""Check if the graph state has a pending interrupt."""
|
||||
tasks = getattr(state, "tasks", ())
|
||||
|
||||
153
backend/fixtures/demo_data.py
Normal file
153
backend/fixtures/demo_data.py
Normal file
@@ -0,0 +1,153 @@
|
||||
"""Seed script -- inserts sample conversations and analytics events for demo purposes.
|
||||
|
||||
Usage:
|
||||
cd backend
|
||||
python fixtures/demo_data.py
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
import psycopg
|
||||
|
||||
DATABASE_URL = os.environ.get(
|
||||
"DATABASE_URL",
|
||||
"postgresql://smart_support:dev_password@localhost:5432/smart_support",
|
||||
)
|
||||
|
||||
SAMPLE_CONVERSATIONS = [
|
||||
{
|
||||
"thread_id": "demo-thread-001",
|
||||
"agents_used": ["order_agent"],
|
||||
"turn_count": 3,
|
||||
"total_tokens": 1250,
|
||||
"total_cost_usd": 0.00375,
|
||||
"resolution_type": "resolved",
|
||||
"minutes_ago": 5,
|
||||
},
|
||||
{
|
||||
"thread_id": "demo-thread-002",
|
||||
"agents_used": ["order_agent", "refund_agent"],
|
||||
"turn_count": 6,
|
||||
"total_tokens": 3200,
|
||||
"total_cost_usd": 0.0096,
|
||||
"resolution_type": "resolved",
|
||||
"minutes_ago": 30,
|
||||
},
|
||||
{
|
||||
"thread_id": "demo-thread-003",
|
||||
"agents_used": ["general_agent"],
|
||||
"turn_count": 2,
|
||||
"total_tokens": 800,
|
||||
"total_cost_usd": 0.0024,
|
||||
"resolution_type": None,
|
||||
"minutes_ago": 60,
|
||||
},
|
||||
{
|
||||
"thread_id": "demo-thread-004",
|
||||
"agents_used": ["order_agent", "general_agent"],
|
||||
"turn_count": 8,
|
||||
"total_tokens": 4500,
|
||||
"total_cost_usd": 0.0135,
|
||||
"resolution_type": "escalated",
|
||||
"minutes_ago": 120,
|
||||
},
|
||||
{
|
||||
"thread_id": "demo-thread-005",
|
||||
"agents_used": ["refund_agent"],
|
||||
"turn_count": 4,
|
||||
"total_tokens": 2100,
|
||||
"total_cost_usd": 0.0063,
|
||||
"resolution_type": "resolved",
|
||||
"minutes_ago": 240,
|
||||
},
|
||||
]
|
||||
|
||||
SAMPLE_EVENTS = [
|
||||
{"thread_id": "demo-thread-001", "event_type": "message", "agent_name": "order_agent", "tokens_used": 400, "cost_usd": 0.0012, "success": True},
|
||||
{"thread_id": "demo-thread-001", "event_type": "tool_call", "agent_name": "order_agent", "tool_name": "get_order_status", "tokens_used": 0, "cost_usd": 0.0, "success": True},
|
||||
{"thread_id": "demo-thread-002", "event_type": "message", "agent_name": "order_agent", "tokens_used": 1600, "cost_usd": 0.0048, "success": True},
|
||||
{"thread_id": "demo-thread-002", "event_type": "message", "agent_name": "refund_agent", "tokens_used": 1600, "cost_usd": 0.0048, "success": True},
|
||||
{"thread_id": "demo-thread-002", "event_type": "tool_call", "agent_name": "refund_agent", "tool_name": "process_refund", "tokens_used": 0, "cost_usd": 0.0, "success": True},
|
||||
{"thread_id": "demo-thread-003", "event_type": "message", "agent_name": "general_agent", "tokens_used": 800, "cost_usd": 0.0024, "success": True},
|
||||
{"thread_id": "demo-thread-004", "event_type": "message", "agent_name": "order_agent", "tokens_used": 2000, "cost_usd": 0.006, "success": True},
|
||||
{"thread_id": "demo-thread-004", "event_type": "escalation", "agent_name": "general_agent", "tokens_used": 2500, "cost_usd": 0.0075, "success": False},
|
||||
{"thread_id": "demo-thread-005", "event_type": "message", "agent_name": "refund_agent", "tokens_used": 2100, "cost_usd": 0.0063, "success": True},
|
||||
]
|
||||
|
||||
_INSERT_CONVERSATION = """
|
||||
INSERT INTO conversations
|
||||
(thread_id, started_at, last_activity, turn_count, agents_used,
|
||||
total_tokens, total_cost_usd, resolution_type, ended_at)
|
||||
VALUES
|
||||
(%(thread_id)s, %(started_at)s, %(last_activity)s, %(turn_count)s,
|
||||
%(agents_used)s, %(total_tokens)s, %(total_cost_usd)s,
|
||||
%(resolution_type)s, %(ended_at)s)
|
||||
ON CONFLICT (thread_id) DO NOTHING
|
||||
"""
|
||||
|
||||
_INSERT_EVENT = """
|
||||
INSERT INTO analytics_events
|
||||
(thread_id, event_type, agent_name, tool_name, tokens_used, cost_usd, success)
|
||||
VALUES
|
||||
(%(thread_id)s, %(event_type)s, %(agent_name)s, %(tool_name)s,
|
||||
%(tokens_used)s, %(cost_usd)s, %(success)s)
|
||||
"""
|
||||
|
||||
|
||||
async def seed() -> None:
    """Insert the sample conversations and analytics events for the demo."""
    now = datetime.now(tz=timezone.utc)

    async with await psycopg.AsyncConnection.connect(DATABASE_URL) as conn:
        print("Seeding conversations...")
        for row in SAMPLE_CONVERSATIONS:
            # Back-date the conversation start, then space activity out at
            # roughly two minutes per turn.
            started = now - timedelta(minutes=row["minutes_ago"])
            last_seen = started + timedelta(minutes=row["turn_count"] * 2)
            await conn.execute(
                _INSERT_CONVERSATION,
                {
                    "thread_id": row["thread_id"],
                    "started_at": started,
                    "last_activity": last_seen,
                    "turn_count": row["turn_count"],
                    "agents_used": row["agents_used"],
                    "total_tokens": row["total_tokens"],
                    "total_cost_usd": row["total_cost_usd"],
                    "resolution_type": row["resolution_type"],
                    # Open conversations (no resolution yet) get no end time.
                    "ended_at": last_seen if row["resolution_type"] else None,
                },
            )
            print(f"  Inserted conversation {row['thread_id']}")

        print("Seeding analytics events...")
        for ev in SAMPLE_EVENTS:
            await conn.execute(
                _INSERT_EVENT,
                {
                    "thread_id": ev["thread_id"],
                    "event_type": ev["event_type"],
                    "agent_name": ev.get("agent_name"),
                    "tool_name": ev.get("tool_name"),
                    "tokens_used": ev.get("tokens_used", 0),
                    "cost_usd": ev.get("cost_usd", 0.0),
                    "success": ev.get("success"),
                },
            )
            print(f"  Inserted event {ev['event_type']} for {ev['thread_id']}")

        await conn.commit()

    print("Done. Demo data seeded successfully.")


if __name__ == "__main__":
    asyncio.run(seed())
|
||||
238
backend/fixtures/sample_openapi.yaml
Normal file
238
backend/fixtures/sample_openapi.yaml
Normal file
@@ -0,0 +1,238 @@
|
||||
openapi: "3.0.3"
|
||||
info:
|
||||
title: "E-Commerce API"
|
||||
description: "Sample e-commerce API for Smart Support demo."
|
||||
version: "1.0.0"
|
||||
|
||||
servers:
|
||||
- url: "https://api.example-shop.com/v1"
|
||||
description: "Production server"
|
||||
|
||||
paths:
|
||||
/orders/{order_id}:
|
||||
get:
|
||||
operationId: getOrder
|
||||
summary: "Get order details"
|
||||
description: "Retrieve the full details of a specific order."
|
||||
parameters:
|
||||
- name: order_id
|
||||
in: path
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
responses:
|
||||
"200":
|
||||
description: "Order details"
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/Order"
|
||||
|
||||
/orders/{order_id}/cancel:
|
||||
post:
|
||||
operationId: cancelOrder
|
||||
summary: "Cancel an order"
|
||||
description: "Cancel an order that has not yet been shipped."
|
||||
parameters:
|
||||
- name: order_id
|
||||
in: path
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
reason:
|
||||
type: string
|
||||
responses:
|
||||
"200":
|
||||
description: "Order cancelled"
|
||||
"400":
|
||||
description: "Order cannot be cancelled (already shipped)"
|
||||
|
||||
/orders/{order_id}/refund:
|
||||
post:
|
||||
operationId: refundOrder
|
||||
summary: "Request a refund"
|
||||
description: "Submit a refund request for a completed order."
|
||||
parameters:
|
||||
- name: order_id
|
||||
in: path
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
amount:
|
||||
type: number
|
||||
description: "Refund amount in USD. Leave null for full refund."
|
||||
reason:
|
||||
type: string
|
||||
responses:
|
||||
"200":
|
||||
description: "Refund submitted"
|
||||
"400":
|
||||
description: "Invalid refund request"
|
||||
|
||||
/customers/{customer_id}:
|
||||
get:
|
||||
operationId: getCustomer
|
||||
summary: "Get customer profile"
|
||||
description: "Retrieve customer profile and account information."
|
||||
parameters:
|
||||
- name: customer_id
|
||||
in: path
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
responses:
|
||||
"200":
|
||||
description: "Customer profile"
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/Customer"
|
||||
|
||||
/customers/{customer_id}/orders:
|
||||
get:
|
||||
operationId: listCustomerOrders
|
||||
summary: "List customer orders"
|
||||
description: "Get a paginated list of orders for a customer."
|
||||
parameters:
|
||||
- name: customer_id
|
||||
in: path
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
- name: page
|
||||
in: query
|
||||
schema:
|
||||
type: integer
|
||||
default: 1
|
||||
- name: per_page
|
||||
in: query
|
||||
schema:
|
||||
type: integer
|
||||
default: 20
|
||||
responses:
|
||||
"200":
|
||||
description: "List of orders"
|
||||
|
||||
/products/{product_id}:
|
||||
get:
|
||||
operationId: getProduct
|
||||
summary: "Get product details"
|
||||
description: "Retrieve product information including inventory status."
|
||||
parameters:
|
||||
- name: product_id
|
||||
in: path
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
responses:
|
||||
"200":
|
||||
description: "Product details"
|
||||
|
||||
/support/tickets:
|
||||
post:
|
||||
operationId: createSupportTicket
|
||||
summary: "Create support ticket"
|
||||
description: "Open a new support ticket for a customer issue."
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/CreateTicketRequest"
|
||||
responses:
|
||||
"201":
|
||||
description: "Ticket created"
|
||||
|
||||
/support/tickets/{ticket_id}:
|
||||
get:
|
||||
operationId: getSupportTicket
|
||||
summary: "Get support ticket"
|
||||
description: "Retrieve a support ticket and its conversation history."
|
||||
parameters:
|
||||
- name: ticket_id
|
||||
in: path
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
responses:
|
||||
"200":
|
||||
description: "Ticket details"
|
||||
|
||||
components:
|
||||
schemas:
|
||||
Order:
|
||||
type: object
|
||||
properties:
|
||||
order_id:
|
||||
type: string
|
||||
customer_id:
|
||||
type: string
|
||||
status:
|
||||
type: string
|
||||
enum: [pending, processing, shipped, delivered, cancelled, refunded]
|
||||
items:
|
||||
type: array
|
||||
items:
|
||||
$ref: "#/components/schemas/OrderItem"
|
||||
total_usd:
|
||||
type: number
|
||||
created_at:
|
||||
type: string
|
||||
format: date-time
|
||||
|
||||
OrderItem:
|
||||
type: object
|
||||
properties:
|
||||
product_id:
|
||||
type: string
|
||||
name:
|
||||
type: string
|
||||
quantity:
|
||||
type: integer
|
||||
unit_price_usd:
|
||||
type: number
|
||||
|
||||
Customer:
|
||||
type: object
|
||||
properties:
|
||||
customer_id:
|
||||
type: string
|
||||
email:
|
||||
type: string
|
||||
name:
|
||||
type: string
|
||||
tier:
|
||||
type: string
|
||||
enum: [standard, premium, vip]
|
||||
created_at:
|
||||
type: string
|
||||
format: date-time
|
||||
|
||||
CreateTicketRequest:
|
||||
type: object
|
||||
required: [customer_id, subject, description]
|
||||
properties:
|
||||
customer_id:
|
||||
type: string
|
||||
subject:
|
||||
type: string
|
||||
description:
|
||||
type: string
|
||||
priority:
|
||||
type: string
|
||||
enum: [low, medium, high, urgent]
|
||||
default: medium
|
||||
@@ -15,6 +15,16 @@ if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def clear_rate_limit_state() -> None:
    """Clear module-level rate limit state between tests to prevent leakage.

    ws_handler keeps per-thread message timestamps in a module-level dict;
    without this autouse fixture one test's traffic could trip the rate limit
    in a later test. State is cleared on both setup and teardown.

    NOTE(review): this is a generator fixture (it yields); the ``-> None``
    annotation would more precisely be ``Iterator[None]`` -- confirm against
    the project's typing conventions.
    """
    # Local import keeps the app import inside the fixture body rather than
    # at conftest collection time.
    import app.ws_handler as ws_handler

    ws_handler._thread_timestamps.clear()
    yield
    ws_handler._thread_timestamps.clear()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_settings() -> Settings:
|
||||
return Settings(
|
||||
|
||||
@@ -315,7 +315,7 @@ class TestWebSocketValidation:
|
||||
@pytest.mark.asyncio
|
||||
async def test_content_too_long(self) -> None:
|
||||
g, sm, im, cb, ws = _setup()
|
||||
raw = json.dumps({"type": "message", "thread_id": "t1", "content": "x" * 9000})
|
||||
raw = json.dumps({"type": "message", "thread_id": "t1", "content": "x" * 10001})
|
||||
await dispatch_message(ws, g, sm, cb, raw, interrupt_manager=im)
|
||||
assert ws.sent[0]["type"] == "error"
|
||||
assert "too long" in ws.sent[0]["message"].lower()
|
||||
|
||||
156
backend/tests/unit/test_conversation_tracker.py
Normal file
156
backend/tests/unit/test_conversation_tracker.py
Normal file
@@ -0,0 +1,156 @@
|
||||
"""Tests for app.conversation_tracker module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from app.conversation_tracker import (
|
||||
ConversationTrackerProtocol,
|
||||
NoOpConversationTracker,
|
||||
PostgresConversationTracker,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def _make_pool() -> tuple[AsyncMock, AsyncMock]:
    """Create a mock async connection pool plus the connection it yields.

    Returns:
        ``(pool, conn)`` -- the annotation previously claimed a bare
        ``AsyncMock`` although the function has always returned this 2-tuple.
        ``pool.connection()`` is a synchronous call returning an async
        context manager that yields ``conn``; ``conn.execute`` is awaitable.
    """
    pool = AsyncMock()
    conn = AsyncMock()
    conn.execute = AsyncMock()
    # pool.connection() itself is not awaited by the code under test -- it
    # returns an async context manager -- so it must be a MagicMock, not an
    # AsyncMock, to avoid producing a coroutine.
    pool.connection = MagicMock(return_value=_AsyncContextManager(conn))
    return pool, conn
|
||||
|
||||
|
||||
class _AsyncContextManager:
|
||||
"""Async context manager helper."""
|
||||
|
||||
def __init__(self, value: object) -> None:
|
||||
self._value = value
|
||||
|
||||
async def __aenter__(self) -> object:
|
||||
return self._value
|
||||
|
||||
async def __aexit__(self, *args: object) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class TestConversationTrackerProtocol:
    """Both concrete trackers must structurally satisfy the runtime-checkable protocol."""

    def test_noop_satisfies_protocol(self) -> None:
        # isinstance works because the protocol is @runtime_checkable.
        tracker = NoOpConversationTracker()
        assert isinstance(tracker, ConversationTrackerProtocol)

    def test_postgres_satisfies_protocol(self) -> None:
        tracker = PostgresConversationTracker()
        assert isinstance(tracker, ConversationTrackerProtocol)
|
||||
|
||||
|
||||
class TestNoOpConversationTracker:
    """The no-op tracker must accept every call shape without raising or touching the pool."""

    @pytest.mark.asyncio
    async def test_ensure_conversation_does_nothing(self) -> None:
        tracker = NoOpConversationTracker()
        pool = AsyncMock()
        # Should not raise
        await tracker.ensure_conversation(pool, "thread-1")

    @pytest.mark.asyncio
    async def test_record_turn_does_nothing(self) -> None:
        tracker = NoOpConversationTracker()
        pool = AsyncMock()
        await tracker.record_turn(pool, "thread-1", "agent_a", 100, 0.05)

    @pytest.mark.asyncio
    async def test_resolve_does_nothing(self) -> None:
        tracker = NoOpConversationTracker()
        pool = AsyncMock()
        await tracker.resolve(pool, "thread-1", "resolved")

    @pytest.mark.asyncio
    async def test_accepts_none_agent_name(self) -> None:
        # agent_name is Optional in the protocol; None must be accepted.
        tracker = NoOpConversationTracker()
        pool = AsyncMock()
        await tracker.record_turn(pool, "thread-1", None, 0, 0.0)
|
||||
|
||||
|
||||
class TestPostgresConversationTracker:
    """Verify the SQL statements and bound parameters the tracker emits."""

    @pytest.mark.asyncio
    async def test_ensure_conversation_executes_insert(self) -> None:
        pool, conn = _make_pool()
        tracker = PostgresConversationTracker()

        await tracker.ensure_conversation(pool, "thread-abc")

        conn.execute.assert_awaited_once()
        sql, params = conn.execute.call_args[0]
        # Upsert: an INSERT guarded by ON CONFLICT so repeat calls are safe.
        assert "INSERT" in sql
        assert "ON CONFLICT" in sql
        assert params["thread_id"] == "thread-abc"

    @pytest.mark.asyncio
    async def test_record_turn_executes_update(self) -> None:
        pool, conn = _make_pool()
        tracker = PostgresConversationTracker()

        await tracker.record_turn(pool, "thread-abc", "order_agent", 250, 0.12)

        conn.execute.assert_awaited_once()
        sql, params = conn.execute.call_args[0]
        assert "UPDATE" in sql
        assert params["thread_id"] == "thread-abc"
        assert params["agent_name"] == "order_agent"
        assert params["tokens"] == 250
        assert params["cost"] == 0.12

    @pytest.mark.asyncio
    async def test_record_turn_accepts_none_agent_name(self) -> None:
        pool, conn = _make_pool()
        tracker = PostgresConversationTracker()

        await tracker.record_turn(pool, "thread-abc", None, 0, 0.0)

        conn.execute.assert_awaited_once()
        _sql, params = conn.execute.call_args[0]
        assert params["agent_name"] is None

    @pytest.mark.asyncio
    async def test_resolve_executes_update(self) -> None:
        pool, conn = _make_pool()
        tracker = PostgresConversationTracker()

        await tracker.resolve(pool, "thread-abc", "resolved")

        conn.execute.assert_awaited_once()
        sql, params = conn.execute.call_args[0]
        assert "UPDATE" in sql
        assert params["thread_id"] == "thread-abc"
        assert params["resolution_type"] == "resolved"

    @pytest.mark.asyncio
    async def test_resolve_sets_ended_at(self) -> None:
        pool, conn = _make_pool()
        tracker = PostgresConversationTracker()

        await tracker.resolve(pool, "thread-abc", "escalated")

        sql, _params = conn.execute.call_args[0]
        # Resolution must stamp the conversation end time.
        assert "ended_at" in sql.lower()

    @pytest.mark.asyncio
    async def test_ensure_conversation_with_special_thread_id(self) -> None:
        pool, conn = _make_pool()
        tracker = PostgresConversationTracker()

        await tracker.ensure_conversation(pool, "thread-123-abc-XYZ")

        conn.execute.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_record_turn_with_zero_cost(self) -> None:
        pool, conn = _make_pool()
        tracker = PostgresConversationTracker()

        await tracker.record_turn(pool, "t1", "agent", 0, 0.0)

        conn.execute.assert_awaited_once()
|
||||
213
backend/tests/unit/test_edge_cases.py
Normal file
213
backend/tests/unit/test_edge_cases.py
Normal file
@@ -0,0 +1,213 @@
|
||||
"""Edge case tests for ws_handler input validation and rate limiting."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from app.callbacks import TokenUsageCallbackHandler
|
||||
from app.session_manager import SessionManager
|
||||
from app.ws_handler import dispatch_message
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def _make_ws() -> AsyncMock:
    """Build a mock websocket whose ``send_json`` is awaitable and inspectable."""
    mock_ws = AsyncMock()
    mock_ws.send_json = AsyncMock()
    return mock_ws
|
||||
|
||||
|
||||
def _make_graph() -> AsyncMock:
    """Build a mock agent graph whose ``astream`` yields nothing.

    The graph also reports a state snapshot with no pending tasks (so no
    interrupt is detected) and carries no classifier/registry attachments.
    """
    mock_graph = AsyncMock()

    class _EmptyStream:
        """Async iterator that terminates immediately."""

        def __aiter__(self):
            return self

        async def __anext__(self):
            raise StopAsyncIteration

    # astream itself is a plain (sync) call returning an async iterator,
    # so it must be a MagicMock rather than an AsyncMock.
    mock_graph.astream = MagicMock(return_value=_EmptyStream())
    snapshot = MagicMock()
    snapshot.tasks = ()
    mock_graph.aget_state = AsyncMock(return_value=snapshot)
    mock_graph.intent_classifier = None
    mock_graph.agent_registry = None
    return mock_graph
|
||||
|
||||
|
||||
@pytest.mark.unit
class TestEmptyMessageHandling:
    """Messages with no usable content must be rejected with an error frame."""

    @pytest.mark.asyncio
    async def test_empty_message_content_returns_error(self) -> None:
        ws, graph = _make_ws(), _make_graph()
        sm, cb = SessionManager(), TokenUsageCallbackHandler()
        sm.touch("t1")

        payload = json.dumps({"type": "message", "thread_id": "t1", "content": ""})
        await dispatch_message(ws, graph, sm, cb, payload)

        reply = ws.send_json.call_args[0][0]
        assert reply["type"] == "error"
        lowered = reply["message"].lower()
        assert "content" in lowered or "missing" in lowered

    @pytest.mark.asyncio
    async def test_whitespace_only_message_treated_as_empty(self) -> None:
        ws, graph = _make_ws(), _make_graph()
        sm, cb = SessionManager(), TokenUsageCallbackHandler()
        sm.touch("t1")

        payload = json.dumps({"type": "message", "thread_id": "t1", "content": " "})
        await dispatch_message(ws, graph, sm, cb, payload)

        assert ws.send_json.call_args[0][0]["type"] == "error"
|
||||
|
||||
|
||||
@pytest.mark.unit
class TestOversizedMessageHandling:
    """Content length is capped at 10,000 chars; raw frames at 32 KB."""

    @pytest.mark.asyncio
    async def test_content_over_10000_chars_returns_error(self) -> None:
        ws, graph = _make_ws(), _make_graph()
        sm, cb = SessionManager(), TokenUsageCallbackHandler()
        sm.touch("t1")

        payload = json.dumps({"type": "message", "thread_id": "t1", "content": "x" * 10001})
        await dispatch_message(ws, graph, sm, cb, payload)

        reply = ws.send_json.call_args[0][0]
        assert reply["type"] == "error"
        assert "too long" in reply["message"].lower()

    @pytest.mark.asyncio
    async def test_content_exactly_10000_chars_is_accepted(self) -> None:
        ws, graph = _make_ws(), _make_graph()
        sm, cb = SessionManager(), TokenUsageCallbackHandler()
        sm.touch("t1")

        payload = json.dumps({"type": "message", "thread_id": "t1", "content": "x" * 10000})
        await dispatch_message(ws, graph, sm, cb, payload)

        # Boundary value: may produce any reply except a length error.
        reply = ws.send_json.call_args[0][0]
        reply_text = reply.get("message", "").lower()
        assert reply["type"] != "error" or "too long" not in reply_text

    @pytest.mark.asyncio
    async def test_raw_message_over_32kb_returns_error(self) -> None:
        ws, graph = _make_ws(), _make_graph()
        sm, cb = SessionManager(), TokenUsageCallbackHandler()

        # 40,000 raw bytes exceeds the 32 KB frame ceiling before JSON parsing.
        await dispatch_message(ws, graph, sm, cb, "x" * 40_000)

        reply = ws.send_json.call_args[0][0]
        assert reply["type"] == "error"
        assert "too large" in reply["message"].lower()
|
||||
|
||||
|
||||
@pytest.mark.unit
class TestInvalidJsonHandling:
    """Frames that fail JSON parsing, or are not objects, get an error reply."""

    @pytest.mark.asyncio
    async def test_invalid_json_returns_error(self) -> None:
        ws, graph = _make_ws(), _make_graph()
        sm, cb = SessionManager(), TokenUsageCallbackHandler()

        await dispatch_message(ws, graph, sm, cb, "not valid json {{")

        reply = ws.send_json.call_args[0][0]
        assert reply["type"] == "error"
        assert "invalid json" in reply["message"].lower()

    @pytest.mark.asyncio
    async def test_empty_string_returns_json_error(self) -> None:
        ws, graph = _make_ws(), _make_graph()
        sm, cb = SessionManager(), TokenUsageCallbackHandler()

        await dispatch_message(ws, graph, sm, cb, "")

        assert ws.send_json.call_args[0][0]["type"] == "error"

    @pytest.mark.asyncio
    async def test_json_array_not_object_returns_error(self) -> None:
        ws, graph = _make_ws(), _make_graph()
        sm, cb = SessionManager(), TokenUsageCallbackHandler()

        # Valid JSON but the wrong shape -- must still be rejected.
        await dispatch_message(ws, graph, sm, cb, '["not", "an", "object"]')

        assert ws.send_json.call_args[0][0]["type"] == "error"
|
||||
|
||||
|
||||
@pytest.mark.unit
class TestRateLimiting:
    """Per-thread rate limit: 10 messages per 10 seconds.

    Cleanup versus the original: a dead, discarded ``_make_graph()`` call
    (with a misleading comment) at the top of the first test is removed.
    """

    @pytest.mark.asyncio
    async def test_rapid_fire_messages_rate_limited(self) -> None:
        ws = _make_ws()
        sm = SessionManager()
        cb = TokenUsageCallbackHandler()
        sm.touch("t1")

        # 11 back-to-back messages on one thread exceed the 10-per-10s budget.
        rate_limit_triggered = False
        for i in range(11):
            graph = _make_graph()  # fresh mock graph each iteration
            await dispatch_message(ws, graph, sm, cb, json.dumps({
                "type": "message",
                "thread_id": "t1",
                "content": f"message {i}",
            }))
            reply = ws.send_json.call_args[0][0]
            if reply["type"] == "error" and "rate" in reply.get("message", "").lower():
                rate_limit_triggered = True
                break

        assert rate_limit_triggered, "Rate limiting should trigger after 10 rapid messages"

    @pytest.mark.asyncio
    async def test_different_threads_have_separate_rate_limits(self) -> None:
        ws = _make_ws()
        sm = SessionManager()
        cb = TokenUsageCallbackHandler()
        sm.touch("t1")
        sm.touch("t2")

        # 5 messages per thread stays under each thread's *own* budget,
        # even though 10 total messages flow through the handler.
        for i in range(5):
            for thread_id in ("t1", "t2"):
                await dispatch_message(ws, _make_graph(), sm, cb, json.dumps({
                    "type": "message", "thread_id": thread_id, "content": f"msg {i}",
                }))

        last_reply = ws.send_json.call_args[0][0]
        assert "rate" not in last_reply.get("message", "").lower()
|
||||
175
backend/tests/unit/test_error_handler.py
Normal file
175
backend/tests/unit/test_error_handler.py
Normal file
@@ -0,0 +1,175 @@
|
||||
"""Tests for app.tools.error_handler module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from app.tools.error_handler import (
|
||||
ErrorCategory,
|
||||
classify_error,
|
||||
with_retry,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
class TestErrorClassification:
    """classify_error maps exception types and HTTP statuses to categories.

    The original repeated the Request/Response/HTTPStatusError construction
    boilerplate in nine tests; it is extracted into ``_status_error``.
    """

    @staticmethod
    def _status_error(status: int) -> httpx.HTTPStatusError:
        """Build an HTTPStatusError carrying the given response status code."""
        request = httpx.Request("GET", "http://example.com")
        response = httpx.Response(status, request=request)
        return httpx.HTTPStatusError(str(status), request=request, response=response)

    def test_timeout_exception_is_timeout(self) -> None:
        assert classify_error(httpx.TimeoutException("timed out")) == ErrorCategory.TIMEOUT

    def test_connect_error_is_network(self) -> None:
        assert classify_error(httpx.ConnectError("connection refused")) == ErrorCategory.NETWORK

    def test_401_is_auth_failure(self) -> None:
        assert classify_error(self._status_error(401)) == ErrorCategory.AUTH_FAILURE

    def test_403_is_auth_failure(self) -> None:
        assert classify_error(self._status_error(403)) == ErrorCategory.AUTH_FAILURE

    def test_429_is_retryable(self) -> None:
        assert classify_error(self._status_error(429)) == ErrorCategory.RETRYABLE

    def test_500_is_retryable(self) -> None:
        assert classify_error(self._status_error(500)) == ErrorCategory.RETRYABLE

    def test_502_is_retryable(self) -> None:
        assert classify_error(self._status_error(502)) == ErrorCategory.RETRYABLE

    def test_503_is_retryable(self) -> None:
        assert classify_error(self._status_error(503)) == ErrorCategory.RETRYABLE

    def test_404_is_non_retryable(self) -> None:
        assert classify_error(self._status_error(404)) == ErrorCategory.NON_RETRYABLE

    def test_400_is_non_retryable(self) -> None:
        assert classify_error(self._status_error(400)) == ErrorCategory.NON_RETRYABLE

    def test_generic_exception_is_non_retryable(self) -> None:
        # Unknown exception types default to the safe, non-retrying bucket.
        assert classify_error(ValueError("bad value")) == ErrorCategory.NON_RETRYABLE

    def test_runtime_error_is_non_retryable(self) -> None:
        assert classify_error(RuntimeError("boom")) == ErrorCategory.NON_RETRYABLE
|
||||
|
||||
|
||||
class TestWithRetry:
    """with_retry: exponential backoff, applied to RETRYABLE errors only.

    The Request/Response/HTTPStatusError boilerplate the original repeated
    in five tests is extracted into ``_status_error``.
    """

    @staticmethod
    def _status_error(status: int) -> httpx.HTTPStatusError:
        """Build an HTTPStatusError carrying the given response status code."""
        request = httpx.Request("GET", "http://example.com")
        response = httpx.Response(status, request=request)
        return httpx.HTTPStatusError(str(status), request=request, response=response)

    @pytest.mark.asyncio
    async def test_succeeds_on_first_try(self) -> None:
        fn = AsyncMock(return_value="ok")
        result = await with_retry(fn, max_retries=3, base_delay=0.0)
        assert result == "ok"
        assert fn.call_count == 1

    @pytest.mark.asyncio
    async def test_retries_on_retryable_error(self) -> None:
        retryable_exc = self._status_error(503)
        # Fail twice, then succeed on the third attempt.
        fn = AsyncMock(side_effect=[retryable_exc, retryable_exc, "success"])

        with patch("app.tools.error_handler.asyncio.sleep", new_callable=AsyncMock):
            result = await with_retry(fn, max_retries=3, base_delay=0.0)

        assert result == "success"
        assert fn.call_count == 3

    @pytest.mark.asyncio
    async def test_does_not_retry_non_retryable_error(self) -> None:
        fn = AsyncMock(side_effect=self._status_error(404))

        with pytest.raises(httpx.HTTPStatusError):
            await with_retry(fn, max_retries=3, base_delay=0.0)

        # NON_RETRYABLE -> exactly one attempt.
        assert fn.call_count == 1

    @pytest.mark.asyncio
    async def test_does_not_retry_auth_failure(self) -> None:
        fn = AsyncMock(side_effect=self._status_error(401))

        with pytest.raises(httpx.HTTPStatusError):
            await with_retry(fn, max_retries=3, base_delay=0.0)

        assert fn.call_count == 1

    @pytest.mark.asyncio
    async def test_raises_after_max_retries_exhausted(self) -> None:
        fn = AsyncMock(side_effect=self._status_error(500))

        with (
            patch("app.tools.error_handler.asyncio.sleep", new_callable=AsyncMock),
            pytest.raises(httpx.HTTPStatusError),
        ):
            await with_retry(fn, max_retries=3, base_delay=0.0)

        assert fn.call_count == 3

    @pytest.mark.asyncio
    async def test_does_not_retry_timeout(self) -> None:
        """TimeoutException is TIMEOUT category -- not retried by default."""
        fn = AsyncMock(side_effect=httpx.TimeoutException("timed out"))

        with pytest.raises(httpx.TimeoutException):
            await with_retry(fn, max_retries=3, base_delay=0.0)

        assert fn.call_count == 1

    @pytest.mark.asyncio
    async def test_exponential_backoff_increases_delay(self) -> None:
        retryable_exc = self._status_error(503)
        fn = AsyncMock(side_effect=[retryable_exc, retryable_exc, "done"])
        sleep_delays: list[float] = []

        async def capture_sleep(delay: float) -> None:
            sleep_delays.append(delay)

        with patch("app.tools.error_handler.asyncio.sleep", side_effect=capture_sleep):
            await with_retry(fn, max_retries=3, base_delay=1.0)

        # Two retries -> two sleeps, the second strictly longer (backoff grows).
        assert len(sleep_delays) == 2
        assert sleep_delays[1] > sleep_delays[0]
|
||||
@@ -13,7 +13,7 @@ class TestMainModule:
|
||||
assert app.title == "Smart Support"
|
||||
|
||||
def test_app_version(self) -> None:
|
||||
assert app.version == "0.4.0"
|
||||
assert app.version == "0.5.0"
|
||||
|
||||
def test_agents_yaml_path_exists(self) -> None:
|
||||
assert AGENTS_YAML.name == "agents.yaml"
|
||||
@@ -33,3 +33,10 @@ class TestMainModule:
|
||||
def test_analytics_router_registered(self) -> None:
|
||||
routes = [r.path for r in app.routes if hasattr(r, "path")]
|
||||
assert any("analytics" in p for p in routes)
|
||||
|
||||
def test_health_route_registered(self) -> None:
|
||||
routes = [r.path for r in app.routes if hasattr(r, "path")]
|
||||
assert "/api/health" in routes
|
||||
|
||||
def test_app_version_is_0_5_0(self) -> None:
|
||||
assert app.version == "0.5.0"
|
||||
|
||||
@@ -138,7 +138,7 @@ class TestDispatchMessage:
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
|
||||
msg = json.dumps({"type": "message", "thread_id": "t1", "content": "x" * 9000})
|
||||
msg = json.dumps({"type": "message", "thread_id": "t1", "content": "x" * 10001})
|
||||
await dispatch_message(ws, graph, sm, cb, msg)
|
||||
call_data = ws.send_json.call_args[0][0]
|
||||
assert call_data["type"] == "error"
|
||||
@@ -364,3 +364,80 @@ class TestInterruptHelpers:
|
||||
state.tasks = ()
|
||||
data = _extract_interrupt(state)
|
||||
assert data["action"] == "unknown"
|
||||
|
||||
|
||||
@pytest.mark.unit
class TestDispatchMessageWithTracking:
    """dispatch_message wiring for conversation tracking and analytics."""

    @pytest.mark.asyncio
    async def test_conversation_tracker_called_on_message(self) -> None:
        ws, graph = _make_ws(), _make_graph()
        sm, cb = SessionManager(), TokenUsageCallbackHandler()
        tracker, pool = AsyncMock(), MagicMock()
        sm.touch("t1")

        payload = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"})
        await dispatch_message(
            ws, graph, sm, cb, payload,
            conversation_tracker=tracker,
            pool=pool,
        )

        tracker.ensure_conversation.assert_awaited_once_with(pool, "t1")
        tracker.record_turn.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_analytics_recorder_called_on_message(self) -> None:
        ws, graph = _make_ws(), _make_graph()
        sm, cb = SessionManager(), TokenUsageCallbackHandler()
        recorder, pool = AsyncMock(), MagicMock()
        sm.touch("t1")

        payload = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"})
        await dispatch_message(
            ws, graph, sm, cb, payload,
            analytics_recorder=recorder,
            pool=pool,
        )

        recorder.record.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_tracker_failure_does_not_break_chat(self) -> None:
        ws, graph = _make_ws(), _make_graph()
        sm, cb = SessionManager(), TokenUsageCallbackHandler()
        tracker = AsyncMock()
        tracker.ensure_conversation.side_effect = RuntimeError("DB down")
        pool = MagicMock()
        sm.touch("t1")

        payload = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"})
        # Tracker errors must be swallowed so the chat keeps flowing.
        await dispatch_message(
            ws, graph, sm, cb, payload,
            conversation_tracker=tracker,
            pool=pool,
        )

        assert ws.send_json.call_args[0][0]["type"] == "message_complete"

    @pytest.mark.asyncio
    async def test_no_tracker_no_error(self) -> None:
        ws, graph = _make_ws(), _make_graph()
        sm, cb = SessionManager(), TokenUsageCallbackHandler()
        sm.touch("t1")

        payload = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"})
        # Tracking is entirely optional -- defaults must work.
        await dispatch_message(ws, graph, sm, cb, payload)

        assert ws.send_json.call_args[0][0]["type"] == "message_complete"
|
||||
|
||||
Reference in New Issue
Block a user