Files
smart-support/backend/app/session_manager.py
Yaojia Wang af53111928 refactor: fix architectural issues across frontend and backend
Address all architecture review findings:

P0 fixes:
- Add API key authentication for admin endpoints (analytics, replay, openapi)
  and WebSocket connections via ADMIN_API_KEY env var
- Add PostgreSQL-backed PgSessionManager and PgInterruptManager for
  multi-worker production deployments (in-memory defaults preserved)

P1 fixes:
- Implement actual tool generation in OpenAPI approve_job endpoint
  using generate_tool_code() and generate_agent_yaml()
- Add missing clarification, interrupt_expired, and tool_result message
  handlers in frontend ChatPage

P2 fixes:
- Replace monkey-patching on CompiledStateGraph with typed GraphContext
- Replace 9-param dispatch_message with WebSocketContext dataclass
- Extract duplicate _envelope() into shared app/api_utils.py
- Replace mutable module-level counter with crypto.randomUUID()
- Remove hardcoded mock data from ReviewPage, use api.ts wrappers
- Remove `as any` type escape from ReplayPage

All 516 tests passing, 0 TypeScript errors.
2026-04-06 15:59:14 +02:00

213 lines
7.2 KiB
Python

"""Session TTL management with sliding window and interrupt extension.
Provides both in-memory (SessionManager) and PostgreSQL-backed
(PgSessionManager) implementations behind a common Protocol.
"""
from __future__ import annotations
import time
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Protocol
if TYPE_CHECKING:
from psycopg_pool import AsyncConnectionPool
@dataclass(frozen=True)
class SessionState:
    """Immutable snapshot of one session's TTL bookkeeping.

    Instances are frozen: the managers in this module produce a new
    ``SessionState`` for every change instead of mutating an existing one.
    """

    # Identifier of the conversation thread this snapshot belongs to.
    thread_id: str
    # Time of the last recorded activity, as seconds since the epoch
    # (the managers fill it from ``time.time()`` / ``datetime.timestamp()``).
    last_activity: float
    # True while an interrupt is pending; the managers treat such a
    # session as not expired regardless of ``last_activity``.
    has_pending_interrupt: bool
class SessionManagerProtocol(Protocol):
    """Structural interface shared by the session TTL managers."""

    def touch(self, thread_id: str) -> SessionState:
        """Record activity for ``thread_id`` and return the resulting state."""

    def is_expired(self, thread_id: str) -> bool:
        """Report whether the session for ``thread_id`` has expired."""

    def extend_for_interrupt(self, thread_id: str) -> SessionState:
        """Flag ``thread_id`` as having a pending interrupt."""

    def resolve_interrupt(self, thread_id: str) -> SessionState:
        """Clear the pending-interrupt flag for ``thread_id``."""

    def get_state(self, thread_id: str) -> SessionState | None:
        """Return the stored state for ``thread_id``, or ``None`` if absent."""

    def remove(self, thread_id: str) -> None:
        """Forget the session for ``thread_id``."""
class SessionManager:
    """In-memory session manager for single-worker development.

    - Each message resets the TTL (sliding window).
    - A pending interrupt suspends expiration until resolved.

    State is kept in a dict on the instance, so it is neither shared
    between processes nor persisted.
    """

    def __init__(self, session_ttl_seconds: int = 1800) -> None:
        """Create an empty manager.

        Args:
            session_ttl_seconds: Idle time, in seconds, after which a
                session without a pending interrupt counts as expired.
        """
        self._session_ttl = session_ttl_seconds
        self._sessions: dict[str, SessionState] = {}

    def _store(self, state: SessionState) -> SessionState:
        """Save ``state`` under its thread id and return it.

        The session map is rebound to a new dict rather than mutated in
        place, matching how ``remove`` rebuilds the map.
        """
        self._sessions = {**self._sessions, state.thread_id: state}
        return state

    def touch(self, thread_id: str) -> SessionState:
        """Update last activity for a session (resets sliding window).

        An unknown ``thread_id`` is created; an existing pending-interrupt
        flag is carried over unchanged.
        """
        existing = self._sessions.get(thread_id)
        pending = existing.has_pending_interrupt if existing is not None else False
        return self._store(
            SessionState(
                thread_id=thread_id,
                last_activity=time.time(),
                has_pending_interrupt=pending,
            )
        )

    def is_expired(self, thread_id: str) -> bool:
        """Check if a session has expired.

        Returns:
            ``True`` for an unknown session or one idle for longer than
            the TTL; ``False`` while an interrupt is pending.
        """
        state = self._sessions.get(thread_id)
        if state is None:
            return True
        if state.has_pending_interrupt:
            return False
        return time.time() - state.last_activity > self._session_ttl

    def extend_for_interrupt(self, thread_id: str) -> SessionState:
        """Mark session as having a pending interrupt (suspends TTL).

        The last-activity time of an existing session is preserved. An
        unknown session is created with the current time and with the
        flag set; previously that case fell back to ``touch()`` and
        returned a state whose flag was ``False``, silently dropping the
        interrupt.
        """
        existing = self._sessions.get(thread_id)
        last_activity = (
            existing.last_activity if existing is not None else time.time()
        )
        return self._store(
            SessionState(
                thread_id=thread_id,
                last_activity=last_activity,
                has_pending_interrupt=True,
            )
        )

    def resolve_interrupt(self, thread_id: str) -> SessionState:
        """Remove interrupt extension and reset activity timer.

        The session is created if it did not exist.
        """
        return self._store(
            SessionState(
                thread_id=thread_id,
                last_activity=time.time(),
                has_pending_interrupt=False,
            )
        )

    def get_state(self, thread_id: str) -> SessionState | None:
        """Return the stored state for ``thread_id``, or ``None`` if absent."""
        return self._sessions.get(thread_id)

    def remove(self, thread_id: str) -> None:
        """Drop the session for ``thread_id``; unknown ids are ignored."""
        self._sessions = {
            key: state for key, state in self._sessions.items() if key != thread_id
        }


# Alias for explicit naming
InMemorySessionManager = SessionManager
class PgSessionManager:
    """PostgreSQL-backed session manager for multi-worker production.

    Session rows live in a ``sessions`` table with the columns
    ``thread_id``, ``last_activity`` and ``has_pending_interrupt``.

    The public methods are synchronous wrappers that drive the private
    coroutines on the current thread's event loop, so they cannot be
    called from code that is already running inside an event loop.
    """

    def __init__(
        self,
        pool: AsyncConnectionPool,
        session_ttl_seconds: int = 1800,
    ) -> None:
        """Store the connection pool and the TTL.

        Args:
            pool: Async connection pool used for every query.
            session_ttl_seconds: Idle time, in seconds, after which a
                session without a pending interrupt counts as expired.
        """
        self._pool = pool
        self._session_ttl = session_ttl_seconds

    def _run_sync(self, coro_fn, *args):
        """Run ``coro_fn(*args)`` to completion and return its result.

        Raises:
            RuntimeError: If called while an event loop is running in
                this thread. ``run_until_complete`` cannot nest, so the
                check is done up front with an explicit message.
        """
        import asyncio

        try:
            asyncio.get_running_loop()
        except RuntimeError:
            loop_running = False
        else:
            loop_running = True
        if loop_running:
            raise RuntimeError(
                "PgSessionManager's synchronous API cannot be used from "
                "inside a running event loop"
            )
        # The coroutine is created only after the check above, so a
        # refused call never leaves a never-awaited coroutine behind.
        return asyncio.get_event_loop().run_until_complete(coro_fn(*args))

    def touch(self, thread_id: str) -> SessionState:
        """Record activity for ``thread_id`` (resets the sliding window)."""
        return self._run_sync(self._touch, thread_id)

    async def _touch(self, thread_id: str) -> SessionState:
        """Upsert the session row with the current UTC time.

        A new row starts without a pending interrupt. On conflict only
        ``last_activity`` is updated, so an existing flag is kept; the
        stored flag is read back with ``RETURNING`` so that the returned
        state matches the row instead of always reporting ``False``.
        """
        now = datetime.now(timezone.utc)
        async with self._pool.connection() as conn:
            cursor = await conn.execute(
                """
                INSERT INTO sessions (thread_id, last_activity, has_pending_interrupt)
                VALUES (%(tid)s, %(now)s, FALSE)
                ON CONFLICT (thread_id) DO UPDATE
                SET last_activity = %(now)s
                RETURNING has_pending_interrupt
                """,
                {"tid": thread_id, "now": now},
            )
            row = await cursor.fetchone()
        # NOTE(review): rows are read by column name, as in _get_state;
        # this assumes the pool is configured with a dict-style row
        # factory - confirm against the pool setup.
        pending = bool(row["has_pending_interrupt"]) if row is not None else False
        return SessionState(
            thread_id=thread_id,
            last_activity=now.timestamp(),
            has_pending_interrupt=pending,
        )

    def is_expired(self, thread_id: str) -> bool:
        """Check if a session has expired.

        Returns:
            ``True`` for an unknown session or one idle for longer than
            the TTL; ``False`` while an interrupt is pending.
        """
        state = self.get_state(thread_id)
        if state is None:
            return True
        if state.has_pending_interrupt:
            return False
        return time.time() - state.last_activity > self._session_ttl

    def extend_for_interrupt(self, thread_id: str) -> SessionState:
        """Mark the session as having a pending interrupt (suspends TTL)."""
        return self._run_sync(self._set_interrupt, thread_id, True)

    def resolve_interrupt(self, thread_id: str) -> SessionState:
        """Clear the pending interrupt and reset the activity timer."""
        return self._run_sync(self._set_interrupt, thread_id, False)

    async def _set_interrupt(
        self, thread_id: str, has_interrupt: bool
    ) -> SessionState:
        """Upsert the session row with the given interrupt flag.

        Both the flag and ``last_activity`` (set to the current UTC
        time) are written, whether the row is inserted or updated.
        """
        now = datetime.now(timezone.utc)
        async with self._pool.connection() as conn:
            await conn.execute(
                """
                INSERT INTO sessions (thread_id, last_activity, has_pending_interrupt)
                VALUES (%(tid)s, %(now)s, %(interrupt)s)
                ON CONFLICT (thread_id) DO UPDATE
                SET last_activity = %(now)s,
                    has_pending_interrupt = %(interrupt)s
                """,
                {"tid": thread_id, "now": now, "interrupt": has_interrupt},
            )
        return SessionState(
            thread_id=thread_id,
            last_activity=now.timestamp(),
            has_pending_interrupt=has_interrupt,
        )

    def get_state(self, thread_id: str) -> SessionState | None:
        """Return the stored state for ``thread_id``, or ``None`` if absent."""
        return self._run_sync(self._get_state, thread_id)

    async def _get_state(self, thread_id: str) -> SessionState | None:
        """Fetch the session row for ``thread_id`` and convert it."""
        async with self._pool.connection() as conn:
            cursor = await conn.execute(
                "SELECT last_activity, has_pending_interrupt FROM sessions WHERE thread_id = %(tid)s",
                {"tid": thread_id},
            )
            row = await cursor.fetchone()
        if row is None:
            return None
        # NOTE(review): access by column name assumes a dict-style row
        # factory on the pool's connections, and ``last_activity`` is
        # expected to come back as a datetime - confirm.
        return SessionState(
            thread_id=thread_id,
            last_activity=row["last_activity"].timestamp(),
            has_pending_interrupt=row["has_pending_interrupt"],
        )

    def remove(self, thread_id: str) -> None:
        """Delete the session row for ``thread_id``, if any."""
        self._run_sync(self._remove, thread_id)

    async def _remove(self, thread_id: str) -> None:
        """Issue the ``DELETE`` for ``thread_id``."""
        async with self._pool.connection() as conn:
            await conn.execute(
                "DELETE FROM sessions WHERE thread_id = %(tid)s",
                {"tid": thread_id},
            )