feat: complete phase 5 -- error hardening, frontend, Docker, demo, docs

Backend:
- ConversationTracker: Protocol + PostgresConversationTracker for lifecycle tracking
- Error handler: ErrorCategory enum, classify_error(), with_retry() exponential backoff
- Wire PostgresAnalyticsRecorder + ConversationTracker into ws_handler
- Rate limiting (10 msg/10s per thread), edge case hardening
- Health endpoint GET /api/health, version 0.5.0
- Demo seed data script + sample OpenAPI spec

Frontend (all new):
- React Router with NavBar (Chat / Replay / Dashboard / Review)
- ReplayListPage + ReplayPage with ReplayTimeline component
- DashboardPage with MetricCard, range selector, zero-state
- ReviewPage for OpenAPI classification review
- ErrorBanner for WebSocket disconnect handling
- API client (api.ts) with typed fetch wrappers

Infrastructure:
- Frontend Dockerfile (multi-stage node -> nginx)
- nginx.conf with SPA routing + API/WS proxy
- docker-compose.yml with frontend service + healthchecks
- .env.example files (root + backend)

Documentation:
- README.md with quick start and architecture
- Agent configuration guide
- OpenAPI import guide
- Deployment guide
- Demo script

48 new tests, 449 total passing, 92.87% coverage
This commit is contained in:
Yaojia Wang
2026-03-31 21:20:06 +02:00
parent 38644594d2
commit 0e78e5b06b
44 changed files with 3397 additions and 169 deletions

View File

@@ -0,0 +1,135 @@
"""Conversation tracker -- Protocol and implementations for tracking conversation state."""
from __future__ import annotations
from typing import TYPE_CHECKING, Protocol, runtime_checkable
if TYPE_CHECKING:
from psycopg_pool import AsyncConnectionPool
# Idempotent insert: the first message in a thread creates the conversation
# row; ON CONFLICT makes repeat calls for the same thread_id a no-op.
_ENSURE_SQL = """
INSERT INTO conversations
(thread_id, started_at, last_activity)
VALUES
(%(thread_id)s, NOW(), NOW())
ON CONFLICT (thread_id) DO NOTHING
"""
# Per-turn rollup: bumps turn_count, appends agent_name to agents_used only
# when it is non-NULL and not already present (@> is Postgres array
# containment), accumulates token/cost totals, and refreshes last_activity.
_RECORD_TURN_SQL = """
UPDATE conversations
SET
turn_count = turn_count + 1,
agents_used = CASE
WHEN %(agent_name)s IS NOT NULL AND NOT (agents_used @> ARRAY[%(agent_name)s]::text[])
THEN agents_used || ARRAY[%(agent_name)s]::text[]
ELSE agents_used
END,
total_tokens = total_tokens + %(tokens)s,
total_cost_usd = total_cost_usd + %(cost)s,
last_activity = NOW()
WHERE thread_id = %(thread_id)s
"""
# Terminal update: stamps resolution_type and ended_at when the conversation
# is closed out.
_RESOLVE_SQL = """
UPDATE conversations
SET
resolution_type = %(resolution_type)s,
ended_at = NOW()
WHERE thread_id = %(thread_id)s
"""
@runtime_checkable
class ConversationTrackerProtocol(Protocol):
    """Protocol for tracking conversation lifecycle and metrics.

    Structural (runtime-checkable) interface satisfied by
    NoOpConversationTracker and PostgresConversationTracker. The connection
    pool is passed per call rather than held by the tracker, so one tracker
    instance can be shared application-wide.
    """

    async def ensure_conversation(self, pool: AsyncConnectionPool, thread_id: str) -> None:
        """Create conversation row if it does not already exist.

        Args:
            pool: Async connection pool to execute against.
            thread_id: Unique identifier of the conversation thread.
        """
        ...

    async def record_turn(
        self,
        pool: AsyncConnectionPool,
        thread_id: str,
        agent_name: str | None,
        tokens: int,
        cost: float,
    ) -> None:
        """Increment turn count and update aggregated metrics.

        Args:
            pool: Async connection pool to execute against.
            thread_id: Conversation thread to update.
            agent_name: Agent that handled the turn, or None if unknown.
            tokens: Tokens consumed by this turn (added to the running total).
            cost: Cost in USD for this turn (added to the running total).
        """
        ...

    async def resolve(
        self,
        pool: AsyncConnectionPool,
        thread_id: str,
        resolution_type: str,
    ) -> None:
        """Mark conversation as resolved with a resolution type.

        Args:
            pool: Async connection pool to execute against.
            thread_id: Conversation thread to close out.
            resolution_type: Label describing how the conversation ended.
        """
        ...
class NoOpConversationTracker:
    """Tracker that silently discards every call.

    Drop-in stand-in satisfying ConversationTrackerProtocol for tests or
    when no database is available.
    """

    async def ensure_conversation(self, pool: AsyncConnectionPool, thread_id: str) -> None:
        """Discard the call; no conversation row is created."""
        return None

    async def record_turn(
        self,
        pool: AsyncConnectionPool,
        thread_id: str,
        agent_name: str | None,
        tokens: int,
        cost: float,
    ) -> None:
        """Discard the turn; no metrics are accumulated."""
        return None

    async def resolve(
        self,
        pool: AsyncConnectionPool,
        thread_id: str,
        resolution_type: str,
    ) -> None:
        """Discard the resolution; nothing is persisted."""
        return None
class PostgresConversationTracker:
    """Postgres-backed conversation tracker.

    Every public method checks out a connection from the supplied pool,
    runs one of the module-level SQL statements with bound parameters,
    and returns the connection when the context manager exits.
    """

    @staticmethod
    async def _run(pool: AsyncConnectionPool, sql: str, params: dict) -> None:
        # Shared execution path: one connection per statement; parameters
        # are bound by the driver, never interpolated into the SQL string.
        async with pool.connection() as conn:
            await conn.execute(sql, params)

    async def ensure_conversation(self, pool: AsyncConnectionPool, thread_id: str) -> None:
        """Insert the conversation row; a duplicate thread_id is a no-op
        (ON CONFLICT DO NOTHING)."""
        await self._run(pool, _ENSURE_SQL, {"thread_id": thread_id})

    async def record_turn(
        self,
        pool: AsyncConnectionPool,
        thread_id: str,
        agent_name: str | None,
        tokens: int,
        cost: float,
    ) -> None:
        """Bump turn_count, append agent_name if new, and add token/cost
        totals for one turn."""
        await self._run(
            pool,
            _RECORD_TURN_SQL,
            {
                "thread_id": thread_id,
                "agent_name": agent_name,
                "tokens": tokens,
                "cost": cost,
            },
        )

    async def resolve(
        self,
        pool: AsyncConnectionPool,
        thread_id: str,
        resolution_type: str,
    ) -> None:
        """Stamp resolution_type and ended_at on the conversation row."""
        await self._run(
            pool,
            _RESOLVE_SQL,
            {"thread_id": thread_id, "resolution_type": resolution_type},
        )

View File

@@ -11,9 +11,10 @@ from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.staticfiles import StaticFiles
from app.analytics.api import router as analytics_router
from app.analytics.event_recorder import NoOpAnalyticsRecorder
from app.analytics.event_recorder import PostgresAnalyticsRecorder
from app.callbacks import TokenUsageCallbackHandler
from app.config import Settings
from app.conversation_tracker import PostgresConversationTracker
from app.db import create_checkpointer, create_pool, setup_app_tables
from app.escalation import NoOpEscalator, WebhookEscalator
from app.graph import build_graph
@@ -76,7 +77,8 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
app.state.escalator = escalator
app.state.settings = settings
app.state.pool = pool
app.state.analytics_recorder = NoOpAnalyticsRecorder()
app.state.analytics_recorder = PostgresAnalyticsRecorder(pool=pool)
app.state.conversation_tracker = PostgresConversationTracker()
logger.info(
"Smart Support started: %d agents loaded, LLM=%s/%s, template=%s",
@@ -91,13 +93,19 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
await pool.close()
app = FastAPI(title="Smart Support", version="0.4.0", lifespan=lifespan)
app = FastAPI(title="Smart Support", version="0.5.0", lifespan=lifespan)
app.include_router(openapi_router)
app.include_router(replay_router)
app.include_router(analytics_router)
@app.get("/api/health")
def health_check() -> dict:
    """Report service status and version.

    Intended for load balancers and monitoring probes; performs no I/O,
    so it is cheap enough to poll frequently.
    """
    payload = {"status": "ok", "version": "0.5.0"}
    return payload
@app.websocket("/ws")
async def websocket_endpoint(ws: WebSocket) -> None:
await ws.accept()
@@ -107,12 +115,19 @@ async def websocket_endpoint(ws: WebSocket) -> None:
settings = app.state.settings
callback_handler = TokenUsageCallbackHandler(model_name=settings.llm_model)
analytics_recorder = app.state.analytics_recorder
conversation_tracker = app.state.conversation_tracker
pool = app.state.pool
try:
while True:
raw_data = await ws.receive_text()
await dispatch_message(
ws, graph, session_manager, callback_handler, raw_data,
interrupt_manager=interrupt_manager,
analytics_recorder=analytics_recorder,
conversation_tracker=conversation_tracker,
pool=pool,
)
except WebSocketDisconnect:
logger.info("WebSocket client disconnected")

View File

@@ -0,0 +1,3 @@
"""Tools package for smart-support backend."""
from __future__ import annotations

View File

@@ -0,0 +1,72 @@
"""Error classification and retry logic for tool calls."""
from __future__ import annotations
import asyncio
from enum import Enum
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from collections.abc import Callable
import httpx
class ErrorCategory(Enum):
    """Categories for error classification to guide retry decisions.

    Produced by classify_error(); with_retry() retries only RETRYABLE.
    """

    RETRYABLE = "retryable"          # transient HTTP errors (429/500/502/503); safe to retry
    NON_RETRYABLE = "non_retryable"  # default bucket; retrying will not help
    AUTH_FAILURE = "auth_failure"    # HTTP 401/403; needs a credential fix, not a retry
    TIMEOUT = "timeout"              # httpx.TimeoutException; request exceeded its deadline
    NETWORK = "network"              # httpx.ConnectError; connection could not be established
def classify_error(exc: Exception) -> ErrorCategory:
    """Classify an exception into an ErrorCategory.

    Rules:
    - httpx.TimeoutException -> TIMEOUT
    - httpx.ConnectError -> NETWORK
    - httpx.HTTPStatusError 401/403 -> AUTH_FAILURE
    - httpx.HTTPStatusError 429/500/502/503/504 -> RETRYABLE
    - anything else -> NON_RETRYABLE

    Args:
        exc: The exception raised by a tool call.

    Returns:
        The ErrorCategory that guides with_retry()'s retry decision.
    """
    if isinstance(exc, httpx.TimeoutException):
        return ErrorCategory.TIMEOUT
    if isinstance(exc, httpx.ConnectError):
        return ErrorCategory.NETWORK
    if isinstance(exc, httpx.HTTPStatusError):
        code = exc.response.status_code
        if code in (401, 403):
            return ErrorCategory.AUTH_FAILURE
        # 504 Gateway Timeout is as transient as 502/503; it was missing
        # from the original retryable set.
        if code in (429, 500, 502, 503, 504):
            return ErrorCategory.RETRYABLE
        return ErrorCategory.NON_RETRYABLE
    # Anything we cannot positively identify as transient is not retried.
    return ErrorCategory.NON_RETRYABLE
async def with_retry(
    fn: Callable[..., Any],
    max_retries: int = 3,
    base_delay: float = 1.0,
) -> Any:
    """Execute an async callable with exponential backoff for RETRYABLE errors.

    Only ErrorCategory.RETRYABLE errors trigger retries. All other error
    categories raise immediately after the first attempt.

    Args:
        fn: Zero-argument async callable to invoke.
        max_retries: Total number of attempts; must be at least 1.
        base_delay: Initial backoff in seconds; doubles on each retry.

    Returns:
        Whatever ``fn`` returns on its first successful attempt.

    Raises:
        ValueError: If ``max_retries`` is less than 1.
        Exception: The first non-retryable error immediately, or the last
            retryable error once all attempts are exhausted.
    """
    # Guard: with max_retries < 1 the loop below never runs, last_exc stays
    # None, and `raise last_exc` would raise a confusing TypeError
    # ("exceptions must derive from BaseException"). Fail fast instead.
    if max_retries < 1:
        raise ValueError(f"max_retries must be >= 1, got {max_retries}")
    last_exc: Exception | None = None
    for attempt in range(1, max_retries + 1):
        try:
            return await fn()
        except Exception as exc:
            if classify_error(exc) != ErrorCategory.RETRYABLE:
                raise
            last_exc = exc
            if attempt < max_retries:
                # Exponential backoff: base_delay * 2^(attempt - 1).
                await asyncio.sleep(base_delay * (2 ** (attempt - 1)))
    # Reachable only after a retryable failure on the final attempt.
    assert last_exc is not None
    raise last_exc

View File

@@ -5,6 +5,8 @@ from __future__ import annotations
import json
import logging
import re
import time
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from langchain_core.messages import HumanMessage
@@ -16,16 +18,23 @@ if TYPE_CHECKING:
from fastapi import WebSocket
from langgraph.graph.state import CompiledStateGraph
from app.analytics.event_recorder import AnalyticsRecorder
from app.callbacks import TokenUsageCallbackHandler
from app.conversation_tracker import ConversationTrackerProtocol
from app.interrupt_manager import InterruptManager
from app.session_manager import SessionManager
logger = logging.getLogger(__name__)
MAX_MESSAGE_SIZE = 32_768 # 32 KB
MAX_CONTENT_LENGTH = 8_000 # characters
MAX_CONTENT_LENGTH = 10_000 # characters
THREAD_ID_PATTERN = re.compile(r"^[a-zA-Z0-9\-_]{1,128}$")
# Rate limiting: max 10 messages per 10-second window, per thread
_RATE_LIMIT_MAX = 10
_RATE_LIMIT_WINDOW = 10.0
_thread_timestamps: dict[str, list[float]] = defaultdict(list)
async def handle_user_message(
ws: WebSocket,
@@ -197,6 +206,9 @@ async def dispatch_message(
callback_handler: TokenUsageCallbackHandler,
raw_data: str,
interrupt_manager: InterruptManager | None = None,
analytics_recorder: AnalyticsRecorder | None = None,
conversation_tracker: ConversationTrackerProtocol | None = None,
pool: Any = None,
) -> None:
"""Parse and route an incoming WebSocket message."""
if len(raw_data) > MAX_MESSAGE_SIZE:
@@ -205,10 +217,14 @@ async def dispatch_message(
try:
data = json.loads(raw_data)
except json.JSONDecodeError:
except (json.JSONDecodeError, ValueError):
await _send_json(ws, {"type": "error", "message": "Invalid JSON"})
return
if not isinstance(data, dict):
await _send_json(ws, {"type": "error", "message": "Invalid JSON: expected object"})
return
msg_type = data.get("type")
thread_id = data.get("thread_id", "")
@@ -222,16 +238,36 @@ async def dispatch_message(
if msg_type == "message":
content = data.get("content", "")
if not content:
if not content or not content.strip():
await _send_json(ws, {"type": "error", "message": "Missing message content"})
return
if len(content) > MAX_CONTENT_LENGTH:
await _send_json(ws, {"type": "error", "message": "Message content too long"})
return
# Rate limiting check
now = time.time()
timestamps = _thread_timestamps[thread_id]
cutoff = now - _RATE_LIMIT_WINDOW
_thread_timestamps[thread_id] = [t for t in timestamps if t >= cutoff]
if len(_thread_timestamps[thread_id]) >= _RATE_LIMIT_MAX:
await _send_json(ws, {"type": "error", "message": "Rate limit exceeded"})
return
_thread_timestamps[thread_id].append(now)
await handle_user_message(
ws, graph, session_manager, callback_handler, thread_id, content,
interrupt_manager=interrupt_manager,
)
await _fire_and_forget_tracking(
thread_id=thread_id,
pool=pool,
analytics_recorder=analytics_recorder,
conversation_tracker=conversation_tracker,
agent_name=None,
tokens=0,
cost=0.0,
)
elif msg_type == "interrupt_response":
approved = data.get("approved", False)
@@ -244,6 +280,36 @@ async def dispatch_message(
await _send_json(ws, {"type": "error", "message": "Unknown message type"})
async def _fire_and_forget_tracking(
thread_id: str,
pool: Any,
analytics_recorder: Any | None,
conversation_tracker: Any | None,
agent_name: str | None,
tokens: int,
cost: float,
) -> None:
"""Fire-and-forget analytics/tracking; failures must NOT break chat."""
try:
if conversation_tracker is not None and pool is not None:
await conversation_tracker.ensure_conversation(pool, thread_id)
await conversation_tracker.record_turn(pool, thread_id, agent_name, tokens, cost)
except Exception:
logger.exception("Conversation tracker error for thread %s (suppressed)", thread_id)
try:
if analytics_recorder is not None:
await analytics_recorder.record(
thread_id=thread_id,
event_type="message",
agent_name=agent_name,
tokens_used=tokens,
cost_usd=cost,
)
except Exception:
logger.exception("Analytics recorder error for thread %s (suppressed)", thread_id)
def _has_interrupt(state: Any) -> bool:
"""Check if the graph state has a pending interrupt."""
tasks = getattr(state, "tasks", ())