"""FastAPI application entry point."""

from __future__ import annotations

import logging
import secrets
from contextlib import asynccontextmanager
from pathlib import Path
from typing import TYPE_CHECKING

from fastapi import Depends, FastAPI, Query, WebSocket, WebSocketDisconnect
from fastapi.staticfiles import StaticFiles

from app.analytics.api import router as analytics_router
from app.analytics.event_recorder import PostgresAnalyticsRecorder
from app.callbacks import TokenUsageCallbackHandler
from app.config import Settings
from app.conversation_tracker import PostgresConversationTracker
from app.db import create_checkpointer, create_pool, setup_app_tables
from app.escalation import NoOpEscalator, WebhookEscalator
from app.graph import build_graph
from app.intent import LLMIntentClassifier
from app.interrupt_manager import InterruptManager
from app.llm import create_llm
from app.openapi.review_api import router as openapi_router
from app.registry import AgentRegistry
from app.replay.api import router as replay_router
from app.session_manager import SessionManager
from app.ws_context import WebSocketContext
from app.ws_handler import dispatch_message

if TYPE_CHECKING:
    from collections.abc import AsyncGenerator

logger = logging.getLogger(__name__)

AGENTS_YAML = Path(__file__).parent.parent / "agents.yaml"
FRONTEND_DIST = Path(__file__).parent.parent.parent / "frontend" / "dist"


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    """Build shared application state on startup and release it on shutdown.

    Creates the database pool, checkpointer and app tables, loads the agent
    registry, builds the LLM/graph/managers/escalator and stores them on
    ``app.state`` for request handlers. The pool is closed when the context
    exits, including when any startup step after pool creation fails.
    """
    settings = Settings()
    pool = await create_pool(settings)
    # Everything after pool creation runs inside try/finally so a failure in
    # any later startup step (or an exception delivered at ``yield``) still
    # closes the pool instead of leaking its connections.
    try:
        checkpointer = await create_checkpointer(pool)
        await setup_app_tables(pool)

        # Load agents from template or default YAML
        if settings.template_name:
            registry = AgentRegistry.load_template(settings.template_name)
        else:
            registry = AgentRegistry.load(AGENTS_YAML)

        llm = create_llm(settings)
        intent_classifier = LLMIntentClassifier(llm)
        graph_ctx = build_graph(
            registry, llm, checkpointer, intent_classifier=intent_classifier
        )

        # Settings express TTLs in minutes; the managers take seconds.
        session_manager = SessionManager(
            session_ttl_seconds=settings.session_ttl_minutes * 60,
        )
        interrupt_manager = InterruptManager(
            ttl_seconds=settings.interrupt_ttl_minutes * 60,
        )

        # Configure escalation
        if settings.webhook_url:
            escalator = WebhookEscalator(
                url=settings.webhook_url,
                timeout_seconds=settings.webhook_timeout_seconds,
                max_retries=settings.webhook_max_retries,
            )
        else:
            escalator = NoOpEscalator()

        app.state.graph_ctx = graph_ctx
        app.state.session_manager = session_manager
        app.state.interrupt_manager = interrupt_manager
        app.state.escalator = escalator
        app.state.settings = settings
        app.state.pool = pool
        app.state.analytics_recorder = PostgresAnalyticsRecorder(pool=pool)
        app.state.conversation_tracker = PostgresConversationTracker()

        logger.info(
            "Smart Support started: %d agents loaded, LLM=%s/%s, template=%s",
            len(registry),
            settings.llm_provider,
            settings.llm_model,
            settings.template_name or "(default)",
        )

        yield
    finally:
        await pool.close()


_VERSION = "0.6.0"

app = FastAPI(title="Smart Support", version=_VERSION, lifespan=lifespan)
app.include_router(openapi_router)
app.include_router(replay_router)
app.include_router(analytics_router)


@app.get("/api/health")
def health_check() -> dict:
    """Health check endpoint for load balancers and monitoring."""
    return {"status": "ok", "version": _VERSION}


@app.websocket("/ws")
async def websocket_endpoint(
    ws: WebSocket,
    token: str | None = Query(default=None),
) -> None:
    """Serve one client session over a WebSocket.

    When ``settings.admin_api_key`` is set, the ``token`` query parameter must
    equal it; otherwise the socket is closed with code 4001 without being
    accepted. After acceptance, every received text frame is handed to
    ``dispatch_message`` until the client disconnects.
    """
    settings = app.state.settings

    # Verify WebSocket token when admin_api_key is configured
    if settings.admin_api_key:
        # Compare bytes, not str: secrets.compare_digest raises TypeError for
        # str arguments containing non-ASCII characters, and ``token`` is
        # attacker-controlled input from the query string.
        if token is None or not secrets.compare_digest(
            token.encode("utf-8"),
            settings.admin_api_key.encode("utf-8"),
        ):
            # NOTE(review): this closes before accept(); Starlette/uvicorn
            # typically surface that to the client as an HTTP 403 handshake
            # rejection, so code 4001 may never reach the client — confirm.
            await ws.close(code=4001, reason="Unauthorized")
            return

    await ws.accept()

    # One token-usage callback per connection, tagged with the configured model.
    callback_handler = TokenUsageCallbackHandler(model_name=settings.llm_model)
    ws_ctx = WebSocketContext(
        graph_ctx=app.state.graph_ctx,
        session_manager=app.state.session_manager,
        callback_handler=callback_handler,
        interrupt_manager=app.state.interrupt_manager,
        analytics_recorder=app.state.analytics_recorder,
        conversation_tracker=app.state.conversation_tracker,
        pool=app.state.pool,
    )

    try:
        while True:
            raw_data = await ws.receive_text()
            await dispatch_message(ws, ws_ctx, raw_data)
    except WebSocketDisconnect:
        logger.info("WebSocket client disconnected")


# Serve the built frontend at "/" only if it exists. It is registered after the
# API and WebSocket routes, so those earlier routes are matched first.
if FRONTEND_DIST.is_dir():
    app.mount("/", StaticFiles(directory=str(FRONTEND_DIST), html=True), name="frontend")