"""Session TTL management with sliding window and interrupt extension. Provides both in-memory (SessionManager) and PostgreSQL-backed (PgSessionManager) implementations behind a common Protocol. """ from __future__ import annotations import time from dataclasses import dataclass from datetime import datetime, timezone from typing import TYPE_CHECKING, Protocol if TYPE_CHECKING: from psycopg_pool import AsyncConnectionPool @dataclass(frozen=True) class SessionState: thread_id: str last_activity: float has_pending_interrupt: bool class SessionManagerProtocol(Protocol): """Protocol for session TTL management.""" def touch(self, thread_id: str) -> SessionState: ... def is_expired(self, thread_id: str) -> bool: ... def extend_for_interrupt(self, thread_id: str) -> SessionState: ... def resolve_interrupt(self, thread_id: str) -> SessionState: ... def get_state(self, thread_id: str) -> SessionState | None: ... def remove(self, thread_id: str) -> None: ... class SessionManager: """In-memory session manager for single-worker development. - Each message resets the TTL (sliding window). - A pending interrupt suspends expiration until resolved. """ def __init__(self, session_ttl_seconds: int = 1800) -> None: self._session_ttl = session_ttl_seconds self._sessions: dict[str, SessionState] = {} def touch(self, thread_id: str) -> SessionState: """Update last activity for a session (resets sliding window).""" existing = self._sessions.get(thread_id) new_state = SessionState( thread_id=thread_id, last_activity=time.time(), has_pending_interrupt=existing.has_pending_interrupt if existing else False, ) self._sessions = {**self._sessions, thread_id: new_state} return new_state def is_expired(self, thread_id: str) -> bool: """Check if a session has expired.""" state = self._sessions.get(thread_id) if state is None: return True if state.has_pending_interrupt: return False elapsed = time.time() - state.last_activity return elapsed > self._session_ttl def extend_for_interrupt(self, thread_id: str) -> SessionState: """Mark session as having a pending interrupt (suspends TTL).""" existing = self._sessions.get(thread_id) if existing is None: return self.touch(thread_id) new_state = SessionState( thread_id=thread_id, last_activity=existing.last_activity, has_pending_interrupt=True, ) self._sessions = {**self._sessions, thread_id: new_state} return new_state def resolve_interrupt(self, thread_id: str) -> SessionState: """Remove interrupt extension and reset activity timer.""" new_state = SessionState( thread_id=thread_id, last_activity=time.time(), has_pending_interrupt=False, ) self._sessions = {**self._sessions, thread_id: new_state} return new_state def get_state(self, thread_id: str) -> SessionState | None: return self._sessions.get(thread_id) def remove(self, thread_id: str) -> None: self._sessions = {k: v for k, v in self._sessions.items() if k != thread_id} # Alias for explicit naming InMemorySessionManager = SessionManager class PgSessionManager: """PostgreSQL-backed session manager for multi-worker production.""" def __init__( self, pool: AsyncConnectionPool, session_ttl_seconds: int = 1800, ) -> None: self._pool = pool self._session_ttl = session_ttl_seconds def touch(self, thread_id: str) -> SessionState: import asyncio return asyncio.get_event_loop().run_until_complete(self._touch(thread_id)) async def _touch(self, thread_id: str) -> SessionState: now = datetime.now(timezone.utc) async with self._pool.connection() as conn: await conn.execute( """ INSERT INTO sessions (thread_id, last_activity, has_pending_interrupt) VALUES (%(tid)s, %(now)s, FALSE) ON CONFLICT (thread_id) DO UPDATE SET last_activity = %(now)s """, {"tid": thread_id, "now": now}, ) return SessionState( thread_id=thread_id, last_activity=now.timestamp(), has_pending_interrupt=False, ) def is_expired(self, thread_id: str) -> bool: state = self.get_state(thread_id) if state is None: return True if state.has_pending_interrupt: return False elapsed = time.time() - state.last_activity return elapsed > self._session_ttl def extend_for_interrupt(self, thread_id: str) -> SessionState: import asyncio return asyncio.get_event_loop().run_until_complete( self._set_interrupt(thread_id, True) ) def resolve_interrupt(self, thread_id: str) -> SessionState: import asyncio return asyncio.get_event_loop().run_until_complete( self._set_interrupt(thread_id, False) ) async def _set_interrupt( self, thread_id: str, has_interrupt: bool ) -> SessionState: now = datetime.now(timezone.utc) async with self._pool.connection() as conn: await conn.execute( """ INSERT INTO sessions (thread_id, last_activity, has_pending_interrupt) VALUES (%(tid)s, %(now)s, %(interrupt)s) ON CONFLICT (thread_id) DO UPDATE SET last_activity = %(now)s, has_pending_interrupt = %(interrupt)s """, {"tid": thread_id, "now": now, "interrupt": has_interrupt}, ) return SessionState( thread_id=thread_id, last_activity=now.timestamp(), has_pending_interrupt=has_interrupt, ) def get_state(self, thread_id: str) -> SessionState | None: import asyncio return asyncio.get_event_loop().run_until_complete( self._get_state(thread_id) ) async def _get_state(self, thread_id: str) -> SessionState | None: async with self._pool.connection() as conn: cursor = await conn.execute( "SELECT last_activity, has_pending_interrupt FROM sessions WHERE thread_id = %(tid)s", {"tid": thread_id}, ) row = await cursor.fetchone() if row is None: return None return SessionState( thread_id=thread_id, last_activity=row["last_activity"].timestamp(), has_pending_interrupt=row["has_pending_interrupt"], ) def remove(self, thread_id: str) -> None: import asyncio asyncio.get_event_loop().run_until_complete(self._remove(thread_id)) async def _remove(self, thread_id: str) -> None: async with self._pool.connection() as conn: await conn.execute( "DELETE FROM sessions WHERE thread_id = %(tid)s", {"tid": thread_id}, )