Files
smart-support/backend/app/main.py
Yaojia Wang f0699436c5 refactor: engineering improvements -- API versioning, structured logging, Alembic, error standardization, test coverage
- API versioning: all REST endpoints prefixed with /api/v1/
- Structured logging: replaced stdlib logging with structlog (console/JSON modes)
- Alembic migrations: versioned DB schema with initial migration
- Error standardization: global exception handlers for consistent envelope format
- Interrupt cleanup: asyncio background task for expired interrupt removal
- Integration tests: +30 tests (analytics, replay, openapi, error, session APIs)
- Frontend tests: +57 tests (all components, pages, useWebSocket hook)
- Backend: 557 tests, 89.75% coverage | Frontend: 80 tests, 16 test files
2026-04-06 23:19:29 +02:00

216 lines
6.8 KiB
Python

"""FastAPI application entry point."""
from __future__ import annotations
import asyncio
import contextlib
from contextlib import asynccontextmanager
from pathlib import Path
from typing import TYPE_CHECKING
from fastapi import FastAPI, HTTPException, Query, WebSocket, WebSocketDisconnect
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles
from app.analytics.api import router as analytics_router
from app.analytics.event_recorder import PostgresAnalyticsRecorder
from app.api_utils import envelope
from app.callbacks import TokenUsageCallbackHandler
from app.config import Settings
from app.conversation_tracker import PostgresConversationTracker
from app.db import create_checkpointer, create_pool, run_alembic_migrations
from app.escalation import NoOpEscalator, WebhookEscalator
from app.graph import build_graph
from app.intent import LLMIntentClassifier
from app.interrupt_manager import InterruptManager
from app.llm import create_llm
from app.logging_config import configure_logging
from app.openapi.review_api import router as openapi_router
from app.registry import AgentRegistry
from app.replay.api import router as replay_router
from app.session_manager import SessionManager
from app.ws_context import WebSocketContext
from app.ws_handler import dispatch_message
if TYPE_CHECKING:
from collections.abc import AsyncGenerator
import structlog
# Module-level structlog logger shared by the lifespan hook, the interrupt
# cleanup task, the exception handlers and the WebSocket endpoint in this file.
logger = structlog.get_logger()
# Directory two levels above this module; both paths below hang off it.
_BACKEND_ROOT = Path(__file__).parent.parent

# Default agent definitions, loaded when no template name is configured.
AGENTS_YAML = _BACKEND_ROOT / "agents.yaml"

# Built frontend assets, one directory further up, under frontend/dist.
FRONTEND_DIST = _BACKEND_ROOT.parent / "frontend" / "dist"
async def _interrupt_cleanup_loop(
    interrupt_manager: InterruptManager,
    interval: int = 60,
) -> None:
    """Sweep expired interrupts on a fixed cadence until the task is cancelled.

    Each cycle sleeps for ``interval`` seconds first, then asks the
    interrupt manager to drop whatever has expired. Any ``Exception``
    raised during a sweep is logged and swallowed so a single failure
    does not end the background task.

    Args:
        interrupt_manager: Manager whose ``cleanup_expired()`` is invoked
            on every cycle.
        interval: Seconds to sleep before each sweep.
    """
    while True:
        await asyncio.sleep(interval)
        try:
            removed = interrupt_manager.cleanup_expired()
            if not removed:
                # Nothing expired this round; stay quiet in the logs.
                continue
            logger.info(
                "Cleaned up %d expired interrupt(s)",
                len(removed),
            )
        except Exception:
            logger.exception("Error during interrupt cleanup")
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    """Build application-wide resources on startup and release them on shutdown.

    Startup creates the database pool and checkpointer, runs Alembic
    migrations, loads the agent registry, builds the graph, and stores the
    shared managers on ``app.state``. A background task that periodically
    removes expired interrupts is started just before control is yielded
    to the application.

    Everything after pool creation runs inside ``try``/``finally`` so the
    pool is closed and the background task is cancelled even when a later
    startup step raises or an exception is thrown in at the ``yield``.

    Args:
        app: The FastAPI application whose ``state`` is populated.

    Yields:
        None, once startup has completed.
    """
    settings = Settings()
    configure_logging(settings.log_format)

    pool = await create_pool(settings)
    cleanup_task: asyncio.Task[None] | None = None
    try:
        checkpointer = await create_checkpointer(pool)
        # NOTE(review): invoked synchronously from the event loop; if the
        # migration run blocks for long it delays startup -- confirm acceptable.
        run_alembic_migrations(settings.database_url)

        # Load agents from template or default YAML
        if settings.template_name:
            registry = AgentRegistry.load_template(settings.template_name)
        else:
            registry = AgentRegistry.load(AGENTS_YAML)

        llm = create_llm(settings)
        intent_classifier = LLMIntentClassifier(llm)
        graph_ctx = build_graph(
            registry,
            llm,
            checkpointer,
            intent_classifier=intent_classifier,
        )

        session_manager = SessionManager(
            session_ttl_seconds=settings.session_ttl_minutes * 60,
        )
        interrupt_manager = InterruptManager(
            ttl_seconds=settings.interrupt_ttl_minutes * 60,
        )

        # Configure escalation
        if settings.webhook_url:
            escalator = WebhookEscalator(
                url=settings.webhook_url,
                timeout_seconds=settings.webhook_timeout_seconds,
                max_retries=settings.webhook_max_retries,
            )
        else:
            escalator = NoOpEscalator()

        app.state.graph_ctx = graph_ctx
        app.state.session_manager = session_manager
        app.state.interrupt_manager = interrupt_manager
        app.state.escalator = escalator
        app.state.settings = settings
        app.state.pool = pool
        app.state.analytics_recorder = PostgresAnalyticsRecorder(pool=pool)
        app.state.conversation_tracker = PostgresConversationTracker()

        logger.info(
            "Smart Support started: %d agents loaded, LLM=%s/%s, template=%s",
            len(registry),
            settings.llm_provider,
            settings.llm_model,
            settings.template_name or "(default)",
        )

        cleanup_task = asyncio.create_task(
            _interrupt_cleanup_loop(interrupt_manager),
        )

        yield
    finally:
        # The task only exists if startup got as far as creating it.
        if cleanup_task is not None:
            cleanup_task.cancel()
            with contextlib.suppress(asyncio.CancelledError):
                await cleanup_task
        await pool.close()
# Single source for the version reported by the FastAPI metadata and the
# health endpoint.
_VERSION = "0.6.0"

app = FastAPI(
    lifespan=lifespan,
    title="Smart Support",
    version=_VERSION,
)

# Feature routers, registered in this order: OpenAPI review, replay, analytics.
app.include_router(openapi_router)
app.include_router(replay_router)
app.include_router(analytics_router)
@app.exception_handler(HTTPException)
async def http_exception_handler(request, exc):  # type: ignore[no-untyped-def]
    """Wrap HTTPException in standard envelope format.

    The response reuses the exception's status code and places its
    ``detail`` in the envelope's ``error`` field. Headers attached to the
    exception (for example ``WWW-Authenticate`` or ``Retry-After``) are
    forwarded to the client rather than dropped.

    Args:
        request: The incoming request (unused).
        exc: The raised ``HTTPException``.

    Returns:
        A ``JSONResponse`` carrying the error envelope.
    """
    return JSONResponse(
        status_code=exc.status_code,
        content=envelope(None, success=False, error=exc.detail),
        # ``headers`` is optional on the exception; None means "no extras".
        headers=getattr(exc, "headers", None),
    )
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc):  # type: ignore[no-untyped-def]
    """Report request validation failures using the standard envelope.

    The stringified validation error becomes the envelope's ``error``
    value and the response status is always 422.
    """
    body = envelope(None, success=False, error=str(exc))
    return JSONResponse(content=body, status_code=422)
@app.exception_handler(Exception)
async def general_exception_handler(request, exc):  # type: ignore[no-untyped-def]
    """Last-resort handler: log the failure, answer with a generic 500.

    The exception is written to the log; the client only ever receives
    the fixed "Internal server error" message inside the envelope.
    """
    logger.exception("Unhandled exception: %s", exc)
    body = envelope(None, success=False, error="Internal server error")
    return JSONResponse(content=body, status_code=500)
@app.get("/api/v1/health")
def health_check() -> dict:
    """Report liveness together with the running application version."""
    payload: dict = {"status": "ok", "version": _VERSION}
    return payload
@app.websocket("/ws")
async def websocket_endpoint(
    ws: WebSocket,
    token: str | None = Query(default=None),
) -> None:
    """Serve one WebSocket client: authenticate, then dispatch its messages.

    When ``settings.admin_api_key`` is set, the ``token`` query parameter
    must match it; otherwise the socket is closed with code 4001 and the
    handler returns. After the connection is accepted, every text frame is
    passed to ``dispatch_message`` until the client disconnects.

    Args:
        ws: The WebSocket connection.
        token: Optional token taken from the query string.
    """
    settings = app.state.settings

    # Verify WebSocket token when admin_api_key is configured
    if settings.admin_api_key:
        import secrets as _secrets

        # compare_digest() raises TypeError when a str argument contains
        # non-ASCII characters. The token is client-supplied, so compare
        # UTF-8 bytes: a malformed token is then rejected instead of
        # crashing the handler.
        authorized = token is not None and _secrets.compare_digest(
            token.encode("utf-8"),
            settings.admin_api_key.encode("utf-8"),
        )
        if not authorized:
            # NOTE(review): the socket is closed before accept(); confirm
            # that clients actually observe the 4001 close code.
            await ws.close(code=4001, reason="Unauthorized")
            return

    await ws.accept()

    # One token-usage callback handler per connection.
    callback_handler = TokenUsageCallbackHandler(model_name=settings.llm_model)
    ws_ctx = WebSocketContext(
        graph_ctx=app.state.graph_ctx,
        session_manager=app.state.session_manager,
        callback_handler=callback_handler,
        interrupt_manager=app.state.interrupt_manager,
        analytics_recorder=app.state.analytics_recorder,
        conversation_tracker=app.state.conversation_tracker,
        pool=app.state.pool,
    )

    try:
        while True:
            raw_data = await ws.receive_text()
            await dispatch_message(ws, ws_ctx, raw_data)
    except WebSocketDisconnect:
        logger.info("WebSocket client disconnected")
# Serve the built frontend only when its dist directory exists. The mount is
# registered after the API and WebSocket routes defined earlier in this module.
if FRONTEND_DIST.is_dir():
    app.mount(
        "/",
        StaticFiles(directory=str(FRONTEND_DIST), html=True),
        name="frontend",
    )