refactor: fix architectural issues across frontend and backend
Address all architecture review findings: P0 fixes: - Add API key authentication for admin endpoints (analytics, replay, openapi) and WebSocket connections via ADMIN_API_KEY env var - Add PostgreSQL-backed PgSessionManager and PgInterruptManager for multi-worker production deployments (in-memory defaults preserved) P1 fixes: - Implement actual tool generation in OpenAPI approve_job endpoint using generate_tool_code() and generate_agent_yaml() - Add missing clarification, interrupt_expired, and tool_result message handlers in frontend ChatPage P2 fixes: - Replace monkey-patching on CompiledStateGraph with typed GraphContext - Replace 9-param dispatch_message with WebSocketContext dataclass - Extract duplicate _envelope() into shared app/api_utils.py - Replace mutable module-level counter with crypto.randomUUID() - Remove hardcoded mock data from ReviewPage, use api.ts wrappers - Remove `as any` type escape from ReplayPage All 516 tests passing, 0 TypeScript errors.
This commit is contained in:
@@ -4,16 +4,22 @@ from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import asdict
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query, Request
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||
|
||||
from app.analytics.queries import get_analytics
|
||||
from app.api_utils import envelope
|
||||
from app.auth import require_admin_api_key
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from psycopg_pool import AsyncConnectionPool
|
||||
|
||||
router = APIRouter(prefix="/api/analytics", tags=["analytics"])
|
||||
router = APIRouter(
|
||||
prefix="/api/analytics",
|
||||
tags=["analytics"],
|
||||
dependencies=[Depends(require_admin_api_key)],
|
||||
)
|
||||
|
||||
_RANGE_PATTERN = re.compile(r"^(\d+)d$")
|
||||
_DEFAULT_RANGE = "7d"
|
||||
@@ -25,10 +31,6 @@ async def _get_pool(request: Request) -> AsyncConnectionPool:
|
||||
return request.app.state.pool
|
||||
|
||||
|
||||
def _envelope(data: Any, *, success: bool = True, error: str | None = None) -> dict:
|
||||
return {"success": success, "data": data, "error": error}
|
||||
|
||||
|
||||
def _parse_range(range_str: str) -> int:
|
||||
"""Parse 'Xd' range string to integer days. Raises 400 on invalid format."""
|
||||
match = _RANGE_PATTERN.match(range_str)
|
||||
@@ -55,4 +57,4 @@ async def analytics(
|
||||
range_days = _parse_range(range)
|
||||
pool = await _get_pool(request)
|
||||
result = await get_analytics(pool, range_days=range_days)
|
||||
return _envelope(asdict(result))
|
||||
return envelope(asdict(result))
|
||||
|
||||
10
backend/app/api_utils.py
Normal file
10
backend/app/api_utils.py
Normal file
@@ -0,0 +1,10 @@
|
||||
"""Shared API response helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
def envelope(data: Any, *, success: bool = True, error: str | None = None) -> dict:
|
||||
"""Wrap API response data in a standard envelope format."""
|
||||
return {"success": success, "data": data, "error": error}
|
||||
72
backend/app/auth.py
Normal file
72
backend/app/auth.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""API key authentication for admin endpoints and WebSocket connections."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import secrets
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Depends, HTTPException, Query, Request, WebSocket, status
|
||||
from fastapi.security import APIKeyHeader
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_API_KEY_HEADER = APIKeyHeader(name="X-API-Key", auto_error=False)
|
||||
|
||||
|
||||
def _get_admin_api_key(request: Request) -> str:
|
||||
"""Retrieve the configured admin API key from app settings.
|
||||
|
||||
Returns empty string if settings are not configured (test/dev mode).
|
||||
"""
|
||||
settings = getattr(request.app.state, "settings", None)
|
||||
if settings is None:
|
||||
return ""
|
||||
key = getattr(settings, "admin_api_key", "")
|
||||
return key if isinstance(key, str) else ""
|
||||
|
||||
|
||||
async def require_admin_api_key(
|
||||
request: Request,
|
||||
api_key: Annotated[str | None, Depends(_API_KEY_HEADER)] = None,
|
||||
) -> None:
|
||||
"""Dependency that enforces API key authentication on admin endpoints.
|
||||
|
||||
Skips validation when no admin_api_key is configured (dev mode).
|
||||
"""
|
||||
expected = _get_admin_api_key(request)
|
||||
if not expected:
|
||||
return
|
||||
|
||||
if api_key is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Missing X-API-Key header",
|
||||
)
|
||||
if not secrets.compare_digest(api_key, expected):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Invalid API key",
|
||||
)
|
||||
|
||||
|
||||
async def verify_ws_token(
|
||||
ws: WebSocket,
|
||||
token: str | None = Query(default=None),
|
||||
) -> None:
|
||||
"""Verify WebSocket connection token from query parameter.
|
||||
|
||||
Skips validation when no admin_api_key is configured (dev mode).
|
||||
Usage: ws://host/ws?token=<api_key>
|
||||
"""
|
||||
settings = ws.app.state.settings
|
||||
expected = settings.admin_api_key
|
||||
if not expected:
|
||||
return
|
||||
|
||||
if token is None or not secrets.compare_digest(token, expected):
|
||||
await ws.close(code=4001, reason="Unauthorized")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Invalid or missing WebSocket token",
|
||||
)
|
||||
@@ -32,6 +32,8 @@ class Settings(BaseSettings):
|
||||
|
||||
template_name: str = ""
|
||||
|
||||
admin_api_key: str = ""
|
||||
|
||||
anthropic_api_key: str = ""
|
||||
openai_api_key: str = ""
|
||||
azure_openai_api_key: str = ""
|
||||
|
||||
@@ -51,6 +51,15 @@ CREATE TABLE IF NOT EXISTS analytics_events (
|
||||
);
|
||||
"""
|
||||
|
||||
_SESSIONS_DDL = """
|
||||
CREATE TABLE IF NOT EXISTS sessions (
|
||||
thread_id TEXT PRIMARY KEY,
|
||||
last_activity TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
has_pending_interrupt BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
"""
|
||||
|
||||
_CONVERSATIONS_MIGRATION_DDL = """
|
||||
ALTER TABLE conversations
|
||||
ADD COLUMN IF NOT EXISTS resolution_type TEXT,
|
||||
@@ -84,5 +93,6 @@ async def setup_app_tables(pool: AsyncConnectionPool) -> None:
|
||||
async with pool.connection() as conn:
|
||||
await conn.execute(_CONVERSATIONS_DDL)
|
||||
await conn.execute(_INTERRUPTS_DDL)
|
||||
await conn.execute(_SESSIONS_DDL)
|
||||
await conn.execute(_ANALYTICS_EVENTS_DDL)
|
||||
await conn.execute(_CONVERSATIONS_MIGRATION_DDL)
|
||||
|
||||
@@ -9,13 +9,13 @@ from langchain.agents import create_agent
|
||||
from langgraph_supervisor import create_supervisor
|
||||
|
||||
from app.agents import get_tools_by_names
|
||||
from app.graph_context import GraphContext
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
|
||||
from app.intent import ClassificationResult, IntentClassifier
|
||||
from app.intent import IntentClassifier
|
||||
from app.registry import AgentRegistry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -75,12 +75,11 @@ def build_graph(
|
||||
llm: BaseChatModel,
|
||||
checkpointer: AsyncPostgresSaver,
|
||||
intent_classifier: IntentClassifier | None = None,
|
||||
) -> CompiledStateGraph:
|
||||
) -> GraphContext:
|
||||
"""Build and compile the LangGraph supervisor graph.
|
||||
|
||||
If an intent_classifier is provided, the supervisor prompt is enhanced
|
||||
with agent descriptions for better routing. The classifier is stored
|
||||
for use by the routing layer (ws_handler).
|
||||
Returns a GraphContext that bundles the compiled graph with its
|
||||
associated registry and intent classifier.
|
||||
"""
|
||||
agent_nodes = build_agent_nodes(registry, llm)
|
||||
agent_descriptions = _format_agent_descriptions(registry)
|
||||
@@ -94,28 +93,10 @@ def build_graph(
|
||||
output_mode="full_history",
|
||||
)
|
||||
|
||||
graph = workflow.compile(checkpointer=checkpointer)
|
||||
compiled = workflow.compile(checkpointer=checkpointer)
|
||||
|
||||
# Attach classifier and registry to graph for use by ws_handler
|
||||
graph.intent_classifier = intent_classifier # type: ignore[attr-defined]
|
||||
graph.agent_registry = registry # type: ignore[attr-defined]
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
async def classify_intent(
|
||||
graph: CompiledStateGraph,
|
||||
message: str,
|
||||
) -> ClassificationResult | None:
|
||||
"""Classify user intent using the graph's attached classifier.
|
||||
|
||||
Returns None if no classifier is configured.
|
||||
"""
|
||||
classifier = getattr(graph, "intent_classifier", None)
|
||||
registry = getattr(graph, "agent_registry", None)
|
||||
|
||||
if classifier is None or registry is None:
|
||||
return None
|
||||
|
||||
agents = registry.list_agents()
|
||||
return await classifier.classify(message, agents)
|
||||
return GraphContext(
|
||||
graph=compiled,
|
||||
registry=registry,
|
||||
intent_classifier=intent_classifier,
|
||||
)
|
||||
|
||||
36
backend/app/graph_context.py
Normal file
36
backend/app/graph_context.py
Normal file
@@ -0,0 +1,36 @@
|
||||
"""GraphContext -- typed wrapper around the compiled graph and its dependencies."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
|
||||
from app.intent import ClassificationResult, IntentClassifier
|
||||
from app.registry import AgentRegistry
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GraphContext:
|
||||
"""Bundles the compiled LangGraph graph with its associated services.
|
||||
|
||||
Replaces the previous pattern of monkey-patching attributes onto the
|
||||
third-party CompiledStateGraph instance.
|
||||
"""
|
||||
|
||||
graph: CompiledStateGraph
|
||||
registry: AgentRegistry
|
||||
intent_classifier: IntentClassifier | None = None
|
||||
|
||||
async def classify_intent(self, message: str) -> ClassificationResult | None:
|
||||
"""Classify user intent using the attached classifier.
|
||||
|
||||
Returns None if no classifier is configured.
|
||||
"""
|
||||
if self.intent_classifier is None:
|
||||
return None
|
||||
|
||||
agents = self.registry.list_agents()
|
||||
return await self.intent_classifier.classify(message, agents)
|
||||
@@ -1,10 +1,18 @@
|
||||
"""Interrupt TTL management -- tracks pending interrupts with auto-expiration."""
|
||||
"""Interrupt TTL management -- tracks pending interrupts with auto-expiration.
|
||||
|
||||
Provides both in-memory (InterruptManager) and PostgreSQL-backed
|
||||
(PgInterruptManager) implementations behind a common Protocol.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Protocol
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from psycopg_pool import AsyncConnectionPool
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -28,8 +36,32 @@ class InterruptStatus:
|
||||
record: InterruptRecord
|
||||
|
||||
|
||||
class InterruptManagerProtocol(Protocol):
|
||||
"""Protocol for interrupt TTL management."""
|
||||
|
||||
def register(self, thread_id: str, action: str, params: dict) -> InterruptRecord: ...
|
||||
def check_status(self, thread_id: str) -> InterruptStatus | None: ...
|
||||
def resolve(self, thread_id: str) -> None: ...
|
||||
def has_pending(self, thread_id: str) -> bool: ...
|
||||
def generate_retry_prompt(self, expired_record: InterruptRecord) -> dict: ...
|
||||
|
||||
|
||||
def _build_retry_prompt(expired_record: InterruptRecord) -> dict:
|
||||
"""Generate a WebSocket message prompting the user to retry an expired action."""
|
||||
return {
|
||||
"type": "interrupt_expired",
|
||||
"thread_id": expired_record.thread_id,
|
||||
"action": expired_record.action,
|
||||
"message": (
|
||||
f"The approval request for '{expired_record.action}' has expired "
|
||||
f"after {expired_record.ttl_seconds // 60} minutes. "
|
||||
f"Would you like to try again?"
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
class InterruptManager:
|
||||
"""Manages interrupt TTL with auto-expiration.
|
||||
"""In-memory interrupt manager for single-worker development.
|
||||
|
||||
Complements SessionManager -- this tracks interrupt-specific TTL
|
||||
while SessionManager handles session-level TTL.
|
||||
@@ -62,11 +94,9 @@ class InterruptManager:
|
||||
record = self._interrupts.get(thread_id)
|
||||
if record is None:
|
||||
return None
|
||||
|
||||
elapsed = time.time() - record.created_at
|
||||
remaining = max(0.0, record.ttl_seconds - elapsed)
|
||||
is_expired = elapsed > record.ttl_seconds
|
||||
|
||||
return InterruptStatus(
|
||||
is_expired=is_expired,
|
||||
remaining_seconds=remaining,
|
||||
@@ -84,28 +114,17 @@ class InterruptManager:
|
||||
now = time.time()
|
||||
expired: list[InterruptRecord] = []
|
||||
active: dict[str, InterruptRecord] = {}
|
||||
|
||||
for thread_id, record in self._interrupts.items():
|
||||
if now - record.created_at > record.ttl_seconds:
|
||||
expired.append(record)
|
||||
else:
|
||||
active[thread_id] = record
|
||||
|
||||
self._interrupts = active
|
||||
return tuple(expired)
|
||||
|
||||
def generate_retry_prompt(self, expired_record: InterruptRecord) -> dict:
|
||||
"""Generate a WebSocket message prompting the user to retry an expired action."""
|
||||
return {
|
||||
"type": "interrupt_expired",
|
||||
"thread_id": expired_record.thread_id,
|
||||
"action": expired_record.action,
|
||||
"message": (
|
||||
f"The approval request for '{expired_record.action}' has expired "
|
||||
f"after {expired_record.ttl_seconds // 60} minutes. "
|
||||
f"Would you like to try again?"
|
||||
),
|
||||
}
|
||||
return _build_retry_prompt(expired_record)
|
||||
|
||||
def has_pending(self, thread_id: str) -> bool:
|
||||
"""Check if a thread has a pending (non-expired) interrupt."""
|
||||
@@ -113,3 +132,137 @@ class InterruptManager:
|
||||
if status is None:
|
||||
return False
|
||||
return not status.is_expired
|
||||
|
||||
|
||||
# Alias for explicit naming
|
||||
InMemoryInterruptManager = InterruptManager
|
||||
|
||||
|
||||
class PgInterruptManager:
|
||||
"""PostgreSQL-backed interrupt manager for multi-worker production.
|
||||
|
||||
Uses the existing active_interrupts table defined in db.py.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pool: AsyncConnectionPool,
|
||||
ttl_seconds: int = 1800,
|
||||
) -> None:
|
||||
self._pool = pool
|
||||
self._ttl_seconds = ttl_seconds
|
||||
|
||||
def register(
|
||||
self,
|
||||
thread_id: str,
|
||||
action: str,
|
||||
params: dict,
|
||||
) -> InterruptRecord:
|
||||
import asyncio
|
||||
|
||||
return asyncio.get_event_loop().run_until_complete(
|
||||
self._register(thread_id, action, params)
|
||||
)
|
||||
|
||||
async def _register(
|
||||
self, thread_id: str, action: str, params: dict
|
||||
) -> InterruptRecord:
|
||||
import json
|
||||
|
||||
record = InterruptRecord(
|
||||
interrupt_id=uuid.uuid4().hex,
|
||||
thread_id=thread_id,
|
||||
action=action,
|
||||
params=dict(params),
|
||||
created_at=time.time(),
|
||||
ttl_seconds=self._ttl_seconds,
|
||||
)
|
||||
async with self._pool.connection() as conn:
|
||||
await conn.execute(
|
||||
"""
|
||||
INSERT INTO active_interrupts (interrupt_id, thread_id, action, params)
|
||||
VALUES (%(iid)s, %(tid)s, %(action)s, %(params)s)
|
||||
ON CONFLICT (thread_id) WHERE resolved_at IS NULL
|
||||
DO UPDATE SET
|
||||
interrupt_id = %(iid)s,
|
||||
action = %(action)s,
|
||||
params = %(params)s,
|
||||
created_at = NOW(),
|
||||
resolved_at = NULL
|
||||
""",
|
||||
{
|
||||
"iid": record.interrupt_id,
|
||||
"tid": thread_id,
|
||||
"action": action,
|
||||
"params": json.dumps(params),
|
||||
},
|
||||
)
|
||||
return record
|
||||
|
||||
def check_status(self, thread_id: str) -> InterruptStatus | None:
|
||||
import asyncio
|
||||
|
||||
return asyncio.get_event_loop().run_until_complete(
|
||||
self._check_status(thread_id)
|
||||
)
|
||||
|
||||
async def _check_status(self, thread_id: str) -> InterruptStatus | None:
|
||||
async with self._pool.connection() as conn:
|
||||
cursor = await conn.execute(
|
||||
"""
|
||||
SELECT interrupt_id, action, params, created_at
|
||||
FROM active_interrupts
|
||||
WHERE thread_id = %(tid)s AND resolved_at IS NULL
|
||||
ORDER BY created_at DESC LIMIT 1
|
||||
""",
|
||||
{"tid": thread_id},
|
||||
)
|
||||
row = await cursor.fetchone()
|
||||
|
||||
if row is None:
|
||||
return None
|
||||
|
||||
created_at = row["created_at"].timestamp()
|
||||
elapsed = time.time() - created_at
|
||||
remaining = max(0.0, self._ttl_seconds - elapsed)
|
||||
is_expired = elapsed > self._ttl_seconds
|
||||
|
||||
record = InterruptRecord(
|
||||
interrupt_id=row["interrupt_id"],
|
||||
thread_id=thread_id,
|
||||
action=row["action"],
|
||||
params=row["params"] if isinstance(row["params"], dict) else {},
|
||||
created_at=created_at,
|
||||
ttl_seconds=self._ttl_seconds,
|
||||
)
|
||||
|
||||
return InterruptStatus(
|
||||
is_expired=is_expired,
|
||||
remaining_seconds=remaining,
|
||||
record=record,
|
||||
)
|
||||
|
||||
def resolve(self, thread_id: str) -> None:
|
||||
import asyncio
|
||||
|
||||
asyncio.get_event_loop().run_until_complete(self._resolve(thread_id))
|
||||
|
||||
async def _resolve(self, thread_id: str) -> None:
|
||||
async with self._pool.connection() as conn:
|
||||
await conn.execute(
|
||||
"""
|
||||
UPDATE active_interrupts
|
||||
SET resolved_at = NOW(), resolution = 'resolved'
|
||||
WHERE thread_id = %(tid)s AND resolved_at IS NULL
|
||||
""",
|
||||
{"tid": thread_id},
|
||||
)
|
||||
|
||||
def generate_retry_prompt(self, expired_record: InterruptRecord) -> dict:
|
||||
return _build_retry_prompt(expired_record)
|
||||
|
||||
def has_pending(self, thread_id: str) -> bool:
|
||||
status = self.check_status(thread_id)
|
||||
if status is None:
|
||||
return False
|
||||
return not status.is_expired
|
||||
|
||||
@@ -7,7 +7,7 @@ from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||||
from fastapi import Depends, FastAPI, Query, WebSocket, WebSocketDisconnect
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
|
||||
from app.analytics.api import router as analytics_router
|
||||
@@ -25,6 +25,7 @@ from app.openapi.review_api import router as openapi_router
|
||||
from app.registry import AgentRegistry
|
||||
from app.replay.api import router as replay_router
|
||||
from app.session_manager import SessionManager
|
||||
from app.ws_context import WebSocketContext
|
||||
from app.ws_handler import dispatch_message
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -52,7 +53,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
|
||||
llm = create_llm(settings)
|
||||
intent_classifier = LLMIntentClassifier(llm)
|
||||
graph = build_graph(registry, llm, checkpointer, intent_classifier=intent_classifier)
|
||||
graph_ctx = build_graph(registry, llm, checkpointer, intent_classifier=intent_classifier)
|
||||
|
||||
session_manager = SessionManager(
|
||||
session_ttl_seconds=settings.session_ttl_minutes * 60,
|
||||
@@ -71,7 +72,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
else:
|
||||
escalator = NoOpEscalator()
|
||||
|
||||
app.state.graph = graph
|
||||
app.state.graph_ctx = graph_ctx
|
||||
app.state.session_manager = session_manager
|
||||
app.state.interrupt_manager = interrupt_manager
|
||||
app.state.escalator = escalator
|
||||
@@ -93,7 +94,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
await pool.close()
|
||||
|
||||
|
||||
_VERSION = "0.5.0"
|
||||
_VERSION = "0.6.0"
|
||||
|
||||
app = FastAPI(title="Smart Support", version=_VERSION, lifespan=lifespan)
|
||||
|
||||
@@ -109,28 +110,37 @@ def health_check() -> dict:
|
||||
|
||||
|
||||
@app.websocket("/ws")
|
||||
async def websocket_endpoint(ws: WebSocket) -> None:
|
||||
await ws.accept()
|
||||
graph = app.state.graph
|
||||
session_manager = app.state.session_manager
|
||||
interrupt_manager = app.state.interrupt_manager
|
||||
async def websocket_endpoint(
|
||||
ws: WebSocket,
|
||||
token: str | None = Query(default=None),
|
||||
) -> None:
|
||||
settings = app.state.settings
|
||||
|
||||
# Verify WebSocket token when admin_api_key is configured
|
||||
if settings.admin_api_key:
|
||||
import secrets as _secrets
|
||||
|
||||
if token is None or not _secrets.compare_digest(token, settings.admin_api_key):
|
||||
await ws.close(code=4001, reason="Unauthorized")
|
||||
return
|
||||
|
||||
await ws.accept()
|
||||
callback_handler = TokenUsageCallbackHandler(model_name=settings.llm_model)
|
||||
|
||||
analytics_recorder = app.state.analytics_recorder
|
||||
conversation_tracker = app.state.conversation_tracker
|
||||
pool = app.state.pool
|
||||
ws_ctx = WebSocketContext(
|
||||
graph_ctx=app.state.graph_ctx,
|
||||
session_manager=app.state.session_manager,
|
||||
callback_handler=callback_handler,
|
||||
interrupt_manager=app.state.interrupt_manager,
|
||||
analytics_recorder=app.state.analytics_recorder,
|
||||
conversation_tracker=app.state.conversation_tracker,
|
||||
pool=app.state.pool,
|
||||
)
|
||||
|
||||
try:
|
||||
while True:
|
||||
raw_data = await ws.receive_text()
|
||||
await dispatch_message(
|
||||
ws, graph, session_manager, callback_handler, raw_data,
|
||||
interrupt_manager=interrupt_manager,
|
||||
analytics_recorder=analytics_recorder,
|
||||
conversation_tracker=conversation_tracker,
|
||||
pool=pool,
|
||||
)
|
||||
await dispatch_message(ws, ws_ctx, raw_data)
|
||||
except WebSocketDisconnect:
|
||||
logger.info("WebSocket client disconnected")
|
||||
|
||||
|
||||
@@ -15,15 +15,21 @@ import re
|
||||
import uuid
|
||||
from typing import Literal
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, HTTPException
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
from app.auth import require_admin_api_key
|
||||
from app.openapi.generator import generate_agent_yaml, generate_tool_code
|
||||
from app.openapi.importer import ImportOrchestrator
|
||||
from app.openapi.models import ClassificationResult, ImportJob
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/openapi", tags=["openapi"])
|
||||
router = APIRouter(
|
||||
prefix="/api/openapi",
|
||||
tags=["openapi"],
|
||||
dependencies=[Depends(require_admin_api_key)],
|
||||
)
|
||||
|
||||
# In-memory store: job_id -> job dict, guarded by async lock
|
||||
_job_store: dict[str, dict] = {}
|
||||
@@ -235,11 +241,42 @@ async def update_classification(
|
||||
|
||||
@router.post("/jobs/{job_id}/approve")
|
||||
async def approve_job(job_id: str) -> dict:
|
||||
"""Approve a job's classifications and trigger tool generation."""
|
||||
"""Approve a job's classifications and trigger tool generation.
|
||||
|
||||
Generates Python tool code for each classified endpoint and
|
||||
produces an agent YAML configuration snippet.
|
||||
"""
|
||||
job = _job_store.get(job_id)
|
||||
if job is None:
|
||||
raise HTTPException(status_code=404, detail=f"Job '{job_id}' not found")
|
||||
|
||||
updated_job = {**job, "status": "approved"}
|
||||
classifications: list[ClassificationResult] = job.get("classifications", [])
|
||||
if not classifications:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="No classifications to approve. Import must complete first.",
|
||||
)
|
||||
|
||||
base_url = job["spec_url"].rsplit("/", 1)[0]
|
||||
generated_tools = []
|
||||
for clf in classifications:
|
||||
tool = generate_tool_code(clf, base_url)
|
||||
generated_tools.append({
|
||||
"function_name": tool.function_name,
|
||||
"agent_group": clf.agent_group,
|
||||
"code": tool.code,
|
||||
})
|
||||
|
||||
agent_yaml = generate_agent_yaml(tuple(classifications), base_url)
|
||||
|
||||
updated_job = {
|
||||
**job,
|
||||
"status": "approved",
|
||||
"generated_tools": generated_tools,
|
||||
"agent_yaml": agent_yaml,
|
||||
}
|
||||
_job_store[job_id] = updated_job
|
||||
return _job_to_response(updated_job)
|
||||
|
||||
response = _job_to_response(updated_job)
|
||||
response["generated_tools_count"] = len(generated_tools)
|
||||
return response
|
||||
|
||||
@@ -3,16 +3,23 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Annotated, Any
|
||||
from typing import TYPE_CHECKING, Annotated
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query, Request
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||
|
||||
from app.api_utils import envelope
|
||||
from app.auth import require_admin_api_key
|
||||
|
||||
_THREAD_ID_PATTERN = re.compile(r"^[a-zA-Z0-9\-_]{1,128}$")
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from psycopg_pool import AsyncConnectionPool
|
||||
|
||||
router = APIRouter(prefix="/api", tags=["replay"])
|
||||
router = APIRouter(
|
||||
prefix="/api",
|
||||
tags=["replay"],
|
||||
dependencies=[Depends(require_admin_api_key)],
|
||||
)
|
||||
|
||||
_COUNT_CONVERSATIONS_SQL = """
|
||||
SELECT COUNT(*) FROM conversations
|
||||
@@ -38,10 +45,6 @@ async def get_pool(request: Request) -> AsyncConnectionPool:
|
||||
return request.app.state.pool
|
||||
|
||||
|
||||
def _envelope(data: Any, *, success: bool = True, error: str | None = None) -> dict:
|
||||
return {"success": success, "data": data, "error": error}
|
||||
|
||||
|
||||
@router.get("/conversations")
|
||||
async def list_conversations(
|
||||
request: Request,
|
||||
@@ -62,7 +65,7 @@ async def list_conversations(
|
||||
)
|
||||
rows = await cursor.fetchall()
|
||||
|
||||
return _envelope({
|
||||
return envelope({
|
||||
"conversations": [dict(row) for row in rows],
|
||||
"total": total,
|
||||
"page": page,
|
||||
@@ -119,4 +122,4 @@ async def get_replay(
|
||||
for s in page_steps
|
||||
],
|
||||
}
|
||||
return _envelope(data)
|
||||
return envelope(data)
|
||||
|
||||
@@ -1,9 +1,18 @@
|
||||
"""Session TTL management with sliding window and interrupt extension."""
|
||||
"""Session TTL management with sliding window and interrupt extension.
|
||||
|
||||
Provides both in-memory (SessionManager) and PostgreSQL-backed
|
||||
(PgSessionManager) implementations behind a common Protocol.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from typing import TYPE_CHECKING, Protocol
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from psycopg_pool import AsyncConnectionPool
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -13,8 +22,19 @@ class SessionState:
|
||||
has_pending_interrupt: bool
|
||||
|
||||
|
||||
class SessionManagerProtocol(Protocol):
|
||||
"""Protocol for session TTL management."""
|
||||
|
||||
def touch(self, thread_id: str) -> SessionState: ...
|
||||
def is_expired(self, thread_id: str) -> bool: ...
|
||||
def extend_for_interrupt(self, thread_id: str) -> SessionState: ...
|
||||
def resolve_interrupt(self, thread_id: str) -> SessionState: ...
|
||||
def get_state(self, thread_id: str) -> SessionState | None: ...
|
||||
def remove(self, thread_id: str) -> None: ...
|
||||
|
||||
|
||||
class SessionManager:
|
||||
"""Manages session TTL with sliding window and interrupt extensions.
|
||||
"""In-memory session manager for single-worker development.
|
||||
|
||||
- Each message resets the TTL (sliding window).
|
||||
- A pending interrupt suspends expiration until resolved.
|
||||
@@ -40,10 +60,8 @@ class SessionManager:
|
||||
state = self._sessions.get(thread_id)
|
||||
if state is None:
|
||||
return True
|
||||
|
||||
if state.has_pending_interrupt:
|
||||
return False
|
||||
|
||||
elapsed = time.time() - state.last_activity
|
||||
return elapsed > self._session_ttl
|
||||
|
||||
@@ -52,7 +70,6 @@ class SessionManager:
|
||||
existing = self._sessions.get(thread_id)
|
||||
if existing is None:
|
||||
return self.touch(thread_id)
|
||||
|
||||
new_state = SessionState(
|
||||
thread_id=thread_id,
|
||||
last_activity=existing.last_activity,
|
||||
@@ -76,3 +93,120 @@ class SessionManager:
|
||||
|
||||
def remove(self, thread_id: str) -> None:
|
||||
self._sessions = {k: v for k, v in self._sessions.items() if k != thread_id}
|
||||
|
||||
|
||||
# Alias for explicit naming
|
||||
InMemorySessionManager = SessionManager
|
||||
|
||||
|
||||
class PgSessionManager:
|
||||
"""PostgreSQL-backed session manager for multi-worker production."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pool: AsyncConnectionPool,
|
||||
session_ttl_seconds: int = 1800,
|
||||
) -> None:
|
||||
self._pool = pool
|
||||
self._session_ttl = session_ttl_seconds
|
||||
|
||||
def touch(self, thread_id: str) -> SessionState:
|
||||
import asyncio
|
||||
|
||||
return asyncio.get_event_loop().run_until_complete(self._touch(thread_id))
|
||||
|
||||
async def _touch(self, thread_id: str) -> SessionState:
|
||||
now = datetime.now(timezone.utc)
|
||||
async with self._pool.connection() as conn:
|
||||
await conn.execute(
|
||||
"""
|
||||
INSERT INTO sessions (thread_id, last_activity, has_pending_interrupt)
|
||||
VALUES (%(tid)s, %(now)s, FALSE)
|
||||
ON CONFLICT (thread_id) DO UPDATE
|
||||
SET last_activity = %(now)s
|
||||
""",
|
||||
{"tid": thread_id, "now": now},
|
||||
)
|
||||
return SessionState(
|
||||
thread_id=thread_id,
|
||||
last_activity=now.timestamp(),
|
||||
has_pending_interrupt=False,
|
||||
)
|
||||
|
||||
def is_expired(self, thread_id: str) -> bool:
|
||||
state = self.get_state(thread_id)
|
||||
if state is None:
|
||||
return True
|
||||
if state.has_pending_interrupt:
|
||||
return False
|
||||
elapsed = time.time() - state.last_activity
|
||||
return elapsed > self._session_ttl
|
||||
|
||||
def extend_for_interrupt(self, thread_id: str) -> SessionState:
|
||||
import asyncio
|
||||
|
||||
return asyncio.get_event_loop().run_until_complete(
|
||||
self._set_interrupt(thread_id, True)
|
||||
)
|
||||
|
||||
def resolve_interrupt(self, thread_id: str) -> SessionState:
|
||||
import asyncio
|
||||
|
||||
return asyncio.get_event_loop().run_until_complete(
|
||||
self._set_interrupt(thread_id, False)
|
||||
)
|
||||
|
||||
async def _set_interrupt(
|
||||
self, thread_id: str, has_interrupt: bool
|
||||
) -> SessionState:
|
||||
now = datetime.now(timezone.utc)
|
||||
async with self._pool.connection() as conn:
|
||||
await conn.execute(
|
||||
"""
|
||||
INSERT INTO sessions (thread_id, last_activity, has_pending_interrupt)
|
||||
VALUES (%(tid)s, %(now)s, %(interrupt)s)
|
||||
ON CONFLICT (thread_id) DO UPDATE
|
||||
SET last_activity = %(now)s,
|
||||
has_pending_interrupt = %(interrupt)s
|
||||
""",
|
||||
{"tid": thread_id, "now": now, "interrupt": has_interrupt},
|
||||
)
|
||||
return SessionState(
|
||||
thread_id=thread_id,
|
||||
last_activity=now.timestamp(),
|
||||
has_pending_interrupt=has_interrupt,
|
||||
)
|
||||
|
||||
def get_state(self, thread_id: str) -> SessionState | None:
|
||||
import asyncio
|
||||
|
||||
return asyncio.get_event_loop().run_until_complete(
|
||||
self._get_state(thread_id)
|
||||
)
|
||||
|
||||
async def _get_state(self, thread_id: str) -> SessionState | None:
|
||||
async with self._pool.connection() as conn:
|
||||
cursor = await conn.execute(
|
||||
"SELECT last_activity, has_pending_interrupt FROM sessions WHERE thread_id = %(tid)s",
|
||||
{"tid": thread_id},
|
||||
)
|
||||
row = await cursor.fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
return SessionState(
|
||||
thread_id=thread_id,
|
||||
last_activity=row["last_activity"].timestamp(),
|
||||
has_pending_interrupt=row["has_pending_interrupt"],
|
||||
)
|
||||
|
||||
def remove(self, thread_id: str) -> None:
|
||||
import asyncio
|
||||
|
||||
asyncio.get_event_loop().run_until_complete(self._remove(thread_id))
|
||||
|
||||
async def _remove(self, thread_id: str) -> None:
|
||||
async with self._pool.connection() as conn:
|
||||
await conn.execute(
|
||||
"DELETE FROM sessions WHERE thread_id = %(tid)s",
|
||||
{"tid": thread_id},
|
||||
)
|
||||
|
||||
30
backend/app/ws_context.py
Normal file
30
backend/app/ws_context.py
Normal file
@@ -0,0 +1,30 @@
|
||||
"""WebSocketContext -- bundles all dependencies needed by dispatch_message."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.analytics.event_recorder import AnalyticsRecorder
|
||||
from app.callbacks import TokenUsageCallbackHandler
|
||||
from app.conversation_tracker import ConversationTrackerProtocol
|
||||
from app.graph_context import GraphContext
|
||||
from app.interrupt_manager import InterruptManager
|
||||
from app.session_manager import SessionManager
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class WebSocketContext:
|
||||
"""All dependencies required for WebSocket message processing.
|
||||
|
||||
Replaces the previous 9-parameter function signature in dispatch_message.
|
||||
"""
|
||||
|
||||
graph_ctx: GraphContext
|
||||
session_manager: SessionManager
|
||||
callback_handler: TokenUsageCallbackHandler
|
||||
interrupt_manager: InterruptManager | None = None
|
||||
analytics_recorder: AnalyticsRecorder | None = None
|
||||
conversation_tracker: ConversationTrackerProtocol | None = None
|
||||
pool: Any = None
|
||||
@@ -7,22 +7,19 @@ import logging
|
||||
import re
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langgraph.types import Command
|
||||
|
||||
from app.graph import classify_intent
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from fastapi import WebSocket
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
|
||||
from app.analytics.event_recorder import AnalyticsRecorder
|
||||
from app.callbacks import TokenUsageCallbackHandler
|
||||
from app.conversation_tracker import ConversationTrackerProtocol
|
||||
from app.graph_context import GraphContext
|
||||
from app.interrupt_manager import InterruptManager
|
||||
from app.session_manager import SessionManager
|
||||
from app.ws_context import WebSocketContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -46,7 +43,7 @@ def _evict_stale_threads(cutoff: float) -> None:
|
||||
|
||||
async def handle_user_message(
|
||||
ws: WebSocket,
|
||||
graph: CompiledStateGraph,
|
||||
ctx: GraphContext,
|
||||
session_manager: SessionManager,
|
||||
callback_handler: TokenUsageCallbackHandler,
|
||||
thread_id: str,
|
||||
@@ -54,8 +51,6 @@ async def handle_user_message(
|
||||
interrupt_manager: InterruptManager | None = None,
|
||||
) -> None:
|
||||
"""Process a user message through the graph and stream results back."""
|
||||
# Touch first so new sessions are created before expiry check.
|
||||
# For existing sessions, touch resets the sliding window.
|
||||
existing = session_manager.get_state(thread_id)
|
||||
if existing is not None and session_manager.is_expired(thread_id):
|
||||
msg = "Session expired. Please start a new conversation."
|
||||
@@ -64,8 +59,7 @@ async def handle_user_message(
|
||||
|
||||
session_manager.touch(thread_id)
|
||||
|
||||
# Run intent classification if available (for logging/future multi-intent)
|
||||
classification = await classify_intent(graph, content)
|
||||
classification = await ctx.classify_intent(content)
|
||||
if classification is not None:
|
||||
logger.info(
|
||||
"Intent classification for thread %s: ambiguous=%s, intents=%s",
|
||||
@@ -74,7 +68,6 @@ async def handle_user_message(
|
||||
[i.agent_name for i in classification.intents],
|
||||
)
|
||||
|
||||
# If ambiguous, send clarification and return
|
||||
if classification.is_ambiguous and classification.clarification_question:
|
||||
await _send_json(
|
||||
ws,
|
||||
@@ -89,7 +82,6 @@ async def handle_user_message(
|
||||
|
||||
config = {"configurable": {"thread_id": thread_id}, "callbacks": [callback_handler]}
|
||||
|
||||
# If multi-intent detected, add routing hint to the message
|
||||
if classification and len(classification.intents) > 1:
|
||||
agent_names = [i.agent_name for i in classification.intents]
|
||||
hint = (
|
||||
@@ -101,7 +93,7 @@ async def handle_user_message(
|
||||
input_msg = {"messages": [HumanMessage(content=content)]}
|
||||
|
||||
try:
|
||||
async for chunk in graph.astream(input_msg, config=config, stream_mode="messages"):
|
||||
async for chunk in ctx.graph.astream(input_msg, config=config, stream_mode="messages"):
|
||||
msg_chunk, metadata = chunk
|
||||
node = metadata.get("langgraph_node", "")
|
||||
|
||||
@@ -126,12 +118,11 @@ async def handle_user_message(
|
||||
},
|
||||
)
|
||||
|
||||
state = await graph.aget_state(config)
|
||||
state = await ctx.graph.aget_state(config)
|
||||
if _has_interrupt(state):
|
||||
interrupt_data = _extract_interrupt(state)
|
||||
session_manager.extend_for_interrupt(thread_id)
|
||||
|
||||
# Register interrupt with TTL tracking
|
||||
if interrupt_manager is not None:
|
||||
interrupt_manager.register(
|
||||
thread_id=thread_id,
|
||||
@@ -158,7 +149,7 @@ async def handle_user_message(
|
||||
|
||||
async def handle_interrupt_response(
|
||||
ws: WebSocket,
|
||||
graph: CompiledStateGraph,
|
||||
ctx: GraphContext,
|
||||
session_manager: SessionManager,
|
||||
callback_handler: TokenUsageCallbackHandler,
|
||||
thread_id: str,
|
||||
@@ -166,7 +157,6 @@ async def handle_interrupt_response(
|
||||
interrupt_manager: InterruptManager | None = None,
|
||||
) -> None:
|
||||
"""Resume graph execution after interrupt approval/rejection."""
|
||||
# Check interrupt TTL before resuming
|
||||
if interrupt_manager is not None:
|
||||
status = interrupt_manager.check_status(thread_id)
|
||||
if status is not None and status.is_expired:
|
||||
@@ -184,7 +174,7 @@ async def handle_interrupt_response(
|
||||
config = {"configurable": {"thread_id": thread_id}, "callbacks": [callback_handler]}
|
||||
|
||||
try:
|
||||
async for chunk in graph.astream(
|
||||
async for chunk in ctx.graph.astream(
|
||||
Command(resume=approved),
|
||||
config=config,
|
||||
stream_mode="messages",
|
||||
@@ -212,14 +202,8 @@ async def handle_interrupt_response(
|
||||
|
||||
async def dispatch_message(
|
||||
ws: WebSocket,
|
||||
graph: CompiledStateGraph,
|
||||
session_manager: SessionManager,
|
||||
callback_handler: TokenUsageCallbackHandler,
|
||||
ctx: WebSocketContext,
|
||||
raw_data: str,
|
||||
interrupt_manager: InterruptManager | None = None,
|
||||
analytics_recorder: AnalyticsRecorder | None = None,
|
||||
conversation_tracker: ConversationTrackerProtocol | None = None,
|
||||
pool: Any = None,
|
||||
) -> None:
|
||||
"""Parse and route an incoming WebSocket message."""
|
||||
if len(raw_data) > MAX_MESSAGE_SIZE:
|
||||
@@ -268,14 +252,15 @@ async def dispatch_message(
|
||||
_thread_timestamps[thread_id] = [*recent, now]
|
||||
|
||||
await handle_user_message(
|
||||
ws, graph, session_manager, callback_handler, thread_id, content,
|
||||
interrupt_manager=interrupt_manager,
|
||||
ws, ctx.graph_ctx, ctx.session_manager, ctx.callback_handler,
|
||||
thread_id, content,
|
||||
interrupt_manager=ctx.interrupt_manager,
|
||||
)
|
||||
await _fire_and_forget_tracking(
|
||||
thread_id=thread_id,
|
||||
pool=pool,
|
||||
analytics_recorder=analytics_recorder,
|
||||
conversation_tracker=conversation_tracker,
|
||||
pool=ctx.pool,
|
||||
analytics_recorder=ctx.analytics_recorder,
|
||||
conversation_tracker=ctx.conversation_tracker,
|
||||
agent_name=None,
|
||||
tokens=0,
|
||||
cost=0.0,
|
||||
@@ -284,8 +269,9 @@ async def dispatch_message(
|
||||
elif msg_type == "interrupt_response":
|
||||
approved = data.get("approved", False)
|
||||
await handle_interrupt_response(
|
||||
ws, graph, session_manager, callback_handler, thread_id, approved,
|
||||
interrupt_manager=interrupt_manager,
|
||||
ws, ctx.graph_ctx, ctx.session_manager, ctx.callback_handler,
|
||||
thread_id, approved,
|
||||
interrupt_manager=ctx.interrupt_manager,
|
||||
)
|
||||
|
||||
else:
|
||||
@@ -294,9 +280,9 @@ async def dispatch_message(
|
||||
|
||||
async def _fire_and_forget_tracking(
|
||||
thread_id: str,
|
||||
pool: Any,
|
||||
analytics_recorder: Any | None,
|
||||
conversation_tracker: Any | None,
|
||||
pool: object,
|
||||
analytics_recorder: object | None,
|
||||
conversation_tracker: object | None,
|
||||
agent_name: str | None,
|
||||
tokens: int,
|
||||
cost: float,
|
||||
|
||||
Reference in New Issue
Block a user