- API versioning: all REST endpoints prefixed with /api/v1/
- Structured logging: replaced stdlib logging with structlog (console/JSON modes)
- Alembic migrations: versioned DB schema with initial migration
- Error standardization: global exception handlers for consistent envelope format
- Interrupt cleanup: asyncio background task for expired interrupt removal
- Integration tests: +30 tests (analytics, replay, openapi, error, session APIs)
- Frontend tests: +57 tests (all components, pages, useWebSocket hook)
- Backend: 557 tests, 89.75% coverage | Frontend: 80 tests, 16 test files
216 lines
6.8 KiB
Python
216 lines
6.8 KiB
Python
"""FastAPI application entry point."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import contextlib
|
|
from contextlib import asynccontextmanager
|
|
from pathlib import Path
|
|
from typing import TYPE_CHECKING
|
|
|
|
from fastapi import FastAPI, HTTPException, Query, WebSocket, WebSocketDisconnect
|
|
from fastapi.exceptions import RequestValidationError
|
|
from fastapi.responses import JSONResponse
|
|
from fastapi.staticfiles import StaticFiles
|
|
|
|
from app.analytics.api import router as analytics_router
|
|
from app.analytics.event_recorder import PostgresAnalyticsRecorder
|
|
from app.api_utils import envelope
|
|
from app.callbacks import TokenUsageCallbackHandler
|
|
from app.config import Settings
|
|
from app.conversation_tracker import PostgresConversationTracker
|
|
from app.db import create_checkpointer, create_pool, run_alembic_migrations
|
|
from app.escalation import NoOpEscalator, WebhookEscalator
|
|
from app.graph import build_graph
|
|
from app.intent import LLMIntentClassifier
|
|
from app.interrupt_manager import InterruptManager
|
|
from app.llm import create_llm
|
|
from app.logging_config import configure_logging
|
|
from app.openapi.review_api import router as openapi_router
|
|
from app.registry import AgentRegistry
|
|
from app.replay.api import router as replay_router
|
|
from app.session_manager import SessionManager
|
|
from app.ws_context import WebSocketContext
|
|
from app.ws_handler import dispatch_message
|
|
|
|
if TYPE_CHECKING:
|
|
from collections.abc import AsyncGenerator
|
|
|
|
import structlog
|
|
|
|
# Module-wide structlog logger shared by the handlers and background task below.
logger = structlog.get_logger()

# Directory containing this module; the paths below are resolved relative to it.
_APP_DIR = Path(__file__).parent

# Agent definitions loaded when no template name is configured.
AGENTS_YAML = _APP_DIR.parent / "agents.yaml"

# Built frontend bundle; served as static files only if the directory exists.
FRONTEND_DIST = _APP_DIR.parent.parent / "frontend" / "dist"
|
|
|
|
|
|
async def _interrupt_cleanup_loop(
    interrupt_manager: InterruptManager,
    interval: int = 60,
) -> None:
    """Sweep expired interrupts forever, sleeping between passes.

    Each iteration first sleeps for ``interval`` seconds and then calls
    ``interrupt_manager.cleanup_expired()``. A non-empty result is logged
    with its length. Any ``Exception`` raised by a pass is logged and
    swallowed so that one failed sweep does not end the loop.

    Args:
        interrupt_manager: Object whose ``cleanup_expired()`` is invoked.
        interval: Seconds to sleep before every pass.

    The coroutine never returns on its own; it ends only when the task
    running it is cancelled (``asyncio.CancelledError`` is not an
    ``Exception`` subclass, so it is not caught here).
    """
    while True:
        await asyncio.sleep(interval)
        try:
            removed = interrupt_manager.cleanup_expired()
            if not removed:
                # Nothing expired during this pass; stay quiet.
                continue
            logger.info(
                "Cleaned up %d expired interrupt(s)",
                len(removed),
            )
        except Exception:
            logger.exception("Error during interrupt cleanup")
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    """Build application-wide resources on startup and release them on shutdown.

    Startup, in order: load ``Settings`` and configure logging, create the
    database pool and checkpointer, run Alembic migrations, load the agent
    registry (from the configured template, otherwise ``AGENTS_YAML``),
    build the LLM, intent classifier and graph, create the session and
    interrupt managers, choose an escalator, publish everything on
    ``app.state`` and start the background interrupt-cleanup task.

    Shutdown: cancel and await the cleanup task, then close the pool.

    Everything after pool creation runs inside ``try/finally`` so the pool
    is closed (and the cleanup task cancelled, if it was started) even when
    a later startup step raises or an exception is thrown in at ``yield``.

    Args:
        app: The FastAPI application whose ``state`` receives the shared objects.

    Yields:
        ``None`` once startup has completed; control returns on shutdown.
    """
    settings = Settings()
    configure_logging(settings.log_format)

    pool = await create_pool(settings)
    cleanup_task: asyncio.Task[None] | None = None
    try:
        checkpointer = await create_checkpointer(pool)
        # NOTE(review): called without ``await``, so this presumably runs
        # synchronously on the event loop during startup -- confirm.
        run_alembic_migrations(settings.database_url)

        # Load agents from template or default YAML
        if settings.template_name:
            registry = AgentRegistry.load_template(settings.template_name)
        else:
            registry = AgentRegistry.load(AGENTS_YAML)

        llm = create_llm(settings)
        intent_classifier = LLMIntentClassifier(llm)
        graph_ctx = build_graph(registry, llm, checkpointer, intent_classifier=intent_classifier)

        # Settings express TTLs in minutes; the managers take seconds.
        session_manager = SessionManager(
            session_ttl_seconds=settings.session_ttl_minutes * 60,
        )
        interrupt_manager = InterruptManager(
            ttl_seconds=settings.interrupt_ttl_minutes * 60,
        )

        # Configure escalation
        if settings.webhook_url:
            escalator = WebhookEscalator(
                url=settings.webhook_url,
                timeout_seconds=settings.webhook_timeout_seconds,
                max_retries=settings.webhook_max_retries,
            )
        else:
            escalator = NoOpEscalator()

        app.state.graph_ctx = graph_ctx
        app.state.session_manager = session_manager
        app.state.interrupt_manager = interrupt_manager
        app.state.escalator = escalator
        app.state.settings = settings
        app.state.pool = pool
        app.state.analytics_recorder = PostgresAnalyticsRecorder(pool=pool)
        app.state.conversation_tracker = PostgresConversationTracker()

        logger.info(
            "Smart Support started: %d agents loaded, LLM=%s/%s, template=%s",
            len(registry),
            settings.llm_provider,
            settings.llm_model,
            settings.template_name or "(default)",
        )

        cleanup_task = asyncio.create_task(
            _interrupt_cleanup_loop(interrupt_manager),
        )

        yield
    finally:
        try:
            # The task only exists if startup got as far as creating it.
            if cleanup_task is not None:
                cleanup_task.cancel()
                with contextlib.suppress(asyncio.CancelledError):
                    await cleanup_task
        finally:
            # Always release the pool, even if stopping the task failed.
            await pool.close()
|
|
|
|
|
|
# Single source for the version string: passed to FastAPI here and also
# returned by the health endpoint defined further down.
_VERSION = "0.6.0"

# Application instance; startup and shutdown are driven by ``lifespan``.
app = FastAPI(title="Smart Support", version=_VERSION, lifespan=lifespan)

# Attach the routers imported from the feature packages.
app.include_router(openapi_router)
app.include_router(replay_router)
app.include_router(analytics_router)
|
|
|
|
|
|
@app.exception_handler(HTTPException)
async def http_exception_handler(request, exc):  # type: ignore[no-untyped-def]
    """Wrap HTTPException in standard envelope format.

    The response keeps the exception's status code, puts ``exc.detail`` in
    the envelope's ``error`` field, and forwards any headers attached to
    the exception (for example ``WWW-Authenticate``) so they are not lost.

    Args:
        request: The incoming request (unused).
        exc: The raised ``HTTPException``.

    Returns:
        A ``JSONResponse`` carrying the failure envelope.
    """
    # NOTE(review): registered for fastapi.HTTPException only. FastAPI's docs
    # suggest registering on Starlette's HTTPException so that errors raised
    # by Starlette itself are handled too -- confirm whether that is wanted.
    return JSONResponse(
        status_code=exc.status_code,
        content=envelope(None, success=False, error=exc.detail),
        # ``headers`` may be absent or None; JSONResponse accepts None.
        headers=getattr(exc, "headers", None),
    )
|
|
|
|
|
|
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc):  # type: ignore[no-untyped-def]
    """Report request validation failures using the standard envelope.

    The envelope's ``error`` field is the string form of the validation
    exception and the response status is always 422.

    Args:
        request: The incoming request (unused).
        exc: The raised ``RequestValidationError``.

    Returns:
        A ``JSONResponse`` carrying the failure envelope.
    """
    failure_body = envelope(None, success=False, error=str(exc))
    return JSONResponse(status_code=422, content=failure_body)
|
|
|
|
|
|
@app.exception_handler(Exception)
async def general_exception_handler(request, exc):  # type: ignore[no-untyped-def]
    """Last-resort handler: log the failure and answer with a generic 500.

    The exception is logged via ``logger.exception``; the client only ever
    receives the fixed text "Internal server error", never exception details.

    Args:
        request: The incoming request (unused).
        exc: The unhandled exception.

    Returns:
        A ``JSONResponse`` with status 500 carrying the failure envelope.
    """
    logger.exception("Unhandled exception: %s", exc)
    failure_body = envelope(None, success=False, error="Internal server error")
    return JSONResponse(status_code=500, content=failure_body)
|
|
|
|
|
|
@app.get("/api/v1/health")
def health_check() -> dict:
    """Report liveness for load balancers and monitoring.

    Returns:
        A dict with a constant ``"status"`` of ``"ok"`` and the application
        ``"version"`` string.
    """
    payload = {"status": "ok", "version": _VERSION}
    return payload
|
|
|
|
|
|
@app.websocket("/ws")
async def websocket_endpoint(
    ws: WebSocket,
    token: str | None = Query(default=None),
) -> None:
    """Serve one WebSocket client: authenticate, then dispatch messages until disconnect.

    When ``settings.admin_api_key`` is set, the ``token`` query parameter
    must match it; otherwise the socket is closed with code 4001 before
    being accepted. After accepting, a per-connection ``WebSocketContext``
    is built from the shared objects on ``app.state`` and every received
    text frame is passed to ``dispatch_message``.

    Args:
        ws: The WebSocket connection.
        token: Optional token taken from the query string.
    """
    settings = app.state.settings

    # Verify WebSocket token when admin_api_key is configured
    if settings.admin_api_key:
        import secrets as _secrets

        # compare_digest() raises TypeError for str arguments that contain
        # non-ASCII characters. The token is client-controlled, so compare
        # UTF-8 bytes instead: a malformed token is then rejected cleanly
        # rather than crashing the handler.
        authorized = token is not None and _secrets.compare_digest(
            token.encode("utf-8"),
            settings.admin_api_key.encode("utf-8"),
        )
        if not authorized:
            await ws.close(code=4001, reason="Unauthorized")
            return

    await ws.accept()
    # A fresh handler per connection, constructed with the configured model name.
    callback_handler = TokenUsageCallbackHandler(model_name=settings.llm_model)

    ws_ctx = WebSocketContext(
        graph_ctx=app.state.graph_ctx,
        session_manager=app.state.session_manager,
        callback_handler=callback_handler,
        interrupt_manager=app.state.interrupt_manager,
        analytics_recorder=app.state.analytics_recorder,
        conversation_tracker=app.state.conversation_tracker,
        pool=app.state.pool,
    )

    try:
        while True:
            raw_data = await ws.receive_text()
            await dispatch_message(ws, ws_ctx, raw_data)
    except WebSocketDisconnect:
        # The client going away ends the loop; it is logged, not re-raised.
        logger.info("WebSocket client disconnected")
|
|
|
|
|
|
# Serve the built frontend from "/" only when the dist directory exists.
# This statement comes after the route definitions above.
if FRONTEND_DIST.is_dir():
    app.mount(
        "/",
        StaticFiles(directory=str(FRONTEND_DIST), html=True),
        name="frontend",
    )
|