"""FastAPI application entry point.""" from __future__ import annotations import asyncio import contextlib from contextlib import asynccontextmanager from pathlib import Path from typing import TYPE_CHECKING from fastapi import FastAPI, HTTPException, Query, WebSocket, WebSocketDisconnect from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse from fastapi.staticfiles import StaticFiles from app.analytics.api import router as analytics_router from app.analytics.event_recorder import PostgresAnalyticsRecorder from app.api_utils import envelope from app.callbacks import TokenUsageCallbackHandler from app.config import Settings from app.conversation_tracker import PostgresConversationTracker from app.db import create_checkpointer, create_pool, run_alembic_migrations from app.escalation import NoOpEscalator, WebhookEscalator from app.graph import build_graph from app.intent import LLMIntentClassifier from app.interrupt_manager import InterruptManager from app.llm import create_llm from app.logging_config import configure_logging from app.openapi.review_api import router as openapi_router from app.registry import AgentRegistry from app.replay.api import router as replay_router from app.session_manager import SessionManager from app.ws_context import WebSocketContext from app.ws_handler import dispatch_message if TYPE_CHECKING: from collections.abc import AsyncGenerator import structlog logger = structlog.get_logger() AGENTS_YAML = Path(__file__).parent.parent / "agents.yaml" FRONTEND_DIST = Path(__file__).parent.parent.parent / "frontend" / "dist" async def _interrupt_cleanup_loop( interrupt_manager: InterruptManager, interval: int = 60, ) -> None: """Periodically remove expired interrupts in the background. Runs until cancelled. Catches all exceptions to prevent the task from dying unexpectedly. 
""" while True: await asyncio.sleep(interval) try: expired = interrupt_manager.cleanup_expired() if expired: logger.info( "Cleaned up %d expired interrupt(s)", len(expired), ) except Exception: logger.exception("Error during interrupt cleanup") @asynccontextmanager async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: settings = Settings() configure_logging(settings.log_format) pool = await create_pool(settings) checkpointer = await create_checkpointer(pool) run_alembic_migrations(settings.database_url) # Load agents from template or default YAML if settings.template_name: registry = AgentRegistry.load_template(settings.template_name) else: registry = AgentRegistry.load(AGENTS_YAML) llm = create_llm(settings) intent_classifier = LLMIntentClassifier(llm) graph_ctx = build_graph(registry, llm, checkpointer, intent_classifier=intent_classifier) session_manager = SessionManager( session_ttl_seconds=settings.session_ttl_minutes * 60, ) interrupt_manager = InterruptManager( ttl_seconds=settings.interrupt_ttl_minutes * 60, ) # Configure escalation if settings.webhook_url: escalator = WebhookEscalator( url=settings.webhook_url, timeout_seconds=settings.webhook_timeout_seconds, max_retries=settings.webhook_max_retries, ) else: escalator = NoOpEscalator() app.state.graph_ctx = graph_ctx app.state.session_manager = session_manager app.state.interrupt_manager = interrupt_manager app.state.escalator = escalator app.state.settings = settings app.state.pool = pool app.state.analytics_recorder = PostgresAnalyticsRecorder(pool=pool) app.state.conversation_tracker = PostgresConversationTracker() logger.info( "Smart Support started: %d agents loaded, LLM=%s/%s, template=%s", len(registry), settings.llm_provider, settings.llm_model, settings.template_name or "(default)", ) cleanup_task = asyncio.create_task( _interrupt_cleanup_loop(interrupt_manager), ) yield cleanup_task.cancel() with contextlib.suppress(asyncio.CancelledError): await cleanup_task await pool.close() 
_VERSION = "0.6.0"

app = FastAPI(title="Smart Support", version=_VERSION, lifespan=lifespan)
app.include_router(openapi_router)
app.include_router(replay_router)
app.include_router(analytics_router)


@app.exception_handler(HTTPException)
async def http_exception_handler(request, exc):  # type: ignore[no-untyped-def]
    """Wrap HTTPException in standard envelope format.

    The exception's status code and detail are preserved, and any headers
    attached to it (for example ``WWW-Authenticate``) are forwarded instead
    of being dropped.
    """
    return JSONResponse(
        status_code=exc.status_code,
        content=envelope(None, success=False, error=exc.detail),
        headers=getattr(exc, "headers", None),
    )


@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc):  # type: ignore[no-untyped-def]
    """Wrap validation errors in standard envelope format (HTTP 422)."""
    return JSONResponse(
        status_code=422,
        content=envelope(None, success=False, error=str(exc)),
    )


@app.exception_handler(Exception)
async def general_exception_handler(request, exc):  # type: ignore[no-untyped-def]
    """Catch-all handler -- never leak stack traces.

    The exception is logged server-side; the client only receives a generic
    message with HTTP 500.
    """
    logger.exception("Unhandled exception: %s", exc)
    return JSONResponse(
        status_code=500,
        content=envelope(None, success=False, error="Internal server error"),
    )


@app.get("/api/v1/health")
def health_check() -> dict:
    """Health check endpoint for load balancers and monitoring."""
    return {"status": "ok", "version": _VERSION}


@app.websocket("/ws")
async def websocket_endpoint(
    ws: WebSocket,
    token: str | None = Query(default=None),
) -> None:
    """Serve one WebSocket client until it disconnects.

    When ``settings.admin_api_key`` is set, the ``token`` query parameter
    must match it; otherwise the socket is closed with code 4001 and the
    handler returns. After accepting, every received text frame is handed
    to ``dispatch_message`` together with a per-connection context.

    Args:
        ws: The incoming WebSocket connection.
        token: Optional token taken from the query string.
    """
    settings = app.state.settings

    # Verify WebSocket token when admin_api_key is configured
    if settings.admin_api_key:
        import secrets as _secrets

        # compare_digest() raises TypeError for str operands containing
        # non-ASCII characters, so compare the UTF-8 bytes instead. A
        # malformed token is then rejected rather than crashing the handler.
        authorized = token is not None and _secrets.compare_digest(
            token.encode("utf-8"),
            settings.admin_api_key.encode("utf-8"),
        )
        if not authorized:
            # NOTE(review): this closes before accept(); confirm clients
            # actually observe code 4001 rather than a rejected handshake.
            await ws.close(code=4001, reason="Unauthorized")
            return

    await ws.accept()

    # The token-usage callback is created per connection, not shared.
    callback_handler = TokenUsageCallbackHandler(model_name=settings.llm_model)
    ws_ctx = WebSocketContext(
        graph_ctx=app.state.graph_ctx,
        session_manager=app.state.session_manager,
        callback_handler=callback_handler,
        interrupt_manager=app.state.interrupt_manager,
        analytics_recorder=app.state.analytics_recorder,
        conversation_tracker=app.state.conversation_tracker,
        pool=app.state.pool,
    )

    try:
        while True:
            raw_data = await ws.receive_text()
            await dispatch_message(ws, ws_ctx, raw_data)
    except WebSocketDisconnect:
        logger.info("WebSocket client disconnected")


# Mounted last so the API and WebSocket routes registered above take
# precedence over the static catch-all at "/".
if FRONTEND_DIST.is_dir():
    app.mount("/", StaticFiles(directory=str(FRONTEND_DIST), html=True), name="frontend")