"""FastAPI application entry point.

Builds the shared services (database pool, agent registry, LLM graph, session
and interrupt managers, escalation) during the application lifespan, exposes a
``/ws`` WebSocket endpoint, and serves the built frontend when it is present.
"""

from __future__ import annotations

import logging
from contextlib import asynccontextmanager
from pathlib import Path
from typing import TYPE_CHECKING

from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.staticfiles import StaticFiles

from app.callbacks import TokenUsageCallbackHandler
from app.config import Settings
from app.db import create_checkpointer, create_pool, setup_app_tables
from app.escalation import NoOpEscalator, WebhookEscalator
from app.graph import build_graph
from app.intent import LLMIntentClassifier
from app.interrupt_manager import InterruptManager
from app.llm import create_llm
from app.registry import AgentRegistry
from app.session_manager import SessionManager
from app.ws_handler import dispatch_message

if TYPE_CHECKING:
    from collections.abc import AsyncGenerator

logger = logging.getLogger(__name__)

# Agent definitions used when no template name is configured.
AGENTS_YAML = Path(__file__).parent.parent / "agents.yaml"
# Built frontend bundle; mounted only if the directory exists (see bottom).
FRONTEND_DIST = Path(__file__).parent.parent.parent / "frontend" / "dist"


def _load_registry(settings: Settings) -> AgentRegistry:
    """Load agents from the configured template, or from ``AGENTS_YAML``."""
    if settings.template_name:
        return AgentRegistry.load_template(settings.template_name)
    return AgentRegistry.load(AGENTS_YAML)


def _create_escalator(settings: Settings) -> WebhookEscalator | NoOpEscalator:
    """Return a webhook escalator if a webhook URL is set, else a no-op one."""
    if settings.webhook_url:
        return WebhookEscalator(
            url=settings.webhook_url,
            timeout_seconds=settings.webhook_timeout_seconds,
            max_retries=settings.webhook_max_retries,
        )
    return NoOpEscalator()


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    """Create shared services on startup and release them on shutdown.

    The objects built here are attached to ``app.state`` as ``graph``,
    ``session_manager``, ``interrupt_manager``, ``escalator``, ``settings``
    and ``pool``, where the WebSocket endpoint reads them.

    The database pool is closed in a ``finally`` clause. This releases it on
    a normal shutdown, and also when any later startup step raises or the
    application exits with an error.

    Args:
        app: The FastAPI application whose ``state`` is populated.

    Yields:
        Control back to FastAPI while the application is serving.
    """
    settings = Settings()
    pool = await create_pool(settings)
    try:
        checkpointer = await create_checkpointer(pool)
        await setup_app_tables(pool)

        registry = _load_registry(settings)
        llm = create_llm(settings)
        intent_classifier = LLMIntentClassifier(llm)
        graph = build_graph(
            registry, llm, checkpointer, intent_classifier=intent_classifier
        )

        # Settings express TTLs in minutes; the managers take seconds.
        session_manager = SessionManager(
            session_ttl_seconds=settings.session_ttl_minutes * 60,
        )
        interrupt_manager = InterruptManager(
            ttl_seconds=settings.interrupt_ttl_minutes * 60,
        )
        escalator = _create_escalator(settings)

        app.state.graph = graph
        app.state.session_manager = session_manager
        app.state.interrupt_manager = interrupt_manager
        app.state.escalator = escalator
        app.state.settings = settings
        app.state.pool = pool

        logger.info(
            "Smart Support started: %d agents loaded, LLM=%s/%s, template=%s",
            len(registry),
            settings.llm_provider,
            settings.llm_model,
            settings.template_name or "(default)",
        )
        yield
    finally:
        await pool.close()


app = FastAPI(title="Smart Support", version="0.2.0", lifespan=lifespan)


@app.websocket("/ws")
async def websocket_endpoint(ws: WebSocket) -> None:
    """Accept a WebSocket client and dispatch each text frame it sends.

    Each connection gets its own ``TokenUsageCallbackHandler``. Frames are
    handed to ``dispatch_message`` together with the shared graph and
    managers stored on ``app.state`` by ``lifespan``. The loop ends when the
    client disconnects. Any other exception propagates to the server.
    """
    await ws.accept()
    state = app.state
    graph = state.graph
    session_manager = state.session_manager
    interrupt_manager = state.interrupt_manager
    settings = state.settings
    callback_handler = TokenUsageCallbackHandler(model_name=settings.llm_model)
    try:
        while True:
            raw_data = await ws.receive_text()
            await dispatch_message(
                ws,
                graph,
                session_manager,
                callback_handler,
                raw_data,
                interrupt_manager=interrupt_manager,
            )
    except WebSocketDisconnect:
        logger.info("WebSocket client disconnected")


# The "/" mount is registered after the /ws route. Starlette matches routes in
# registration order, so this catch-all must stay last.
if FRONTEND_DIST.is_dir():
    app.mount("/", StaticFiles(directory=str(FRONTEND_DIST), html=True), name="frontend")