"""Session TTL management with sliding window and interrupt extension."""

from __future__ import annotations

import time
from collections.abc import Callable
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class SessionState:
    """Immutable snapshot of one session's TTL bookkeeping."""

    thread_id: str
    # Timestamp of the last activity that reset the sliding window, taken from
    # the manager's clock (``time.time()`` unless a custom clock is injected).
    last_activity: float
    # While True the session never expires, regardless of ``last_activity``.
    has_pending_interrupt: bool


class SessionManager:
    """Manages session TTL with sliding window and interrupt extensions.

    - Each message resets the TTL (sliding window).
    - A pending interrupt suspends expiration until resolved.

    States are immutable: every update builds a new ``SessionState`` and
    rebinds ``self._sessions`` to a fresh dict instead of mutating it.
    NOTE(review): presumably this is so holders of the previous mapping keep a
    consistent snapshot; confirm before switching to in-place updates.

    Expired sessions are not evicted automatically. Call ``remove`` or
    ``purge_expired`` to reclaim them.
    """

    def __init__(
        self,
        session_ttl_seconds: int = 1800,
        *,
        clock: Callable[[], float] | None = None,
    ) -> None:
        """Create a manager.

        Args:
            session_ttl_seconds: Idle time, in seconds, after which a session
                without a pending interrupt counts as expired.
            clock: Optional zero-argument callable returning the current time
                in seconds. When omitted, ``time.time()`` is used. Pass
                ``time.monotonic`` to make TTLs immune to wall-clock
                adjustments, or a fake clock in tests.
        """
        self._session_ttl = session_ttl_seconds
        self._clock = clock
        self._sessions: dict[str, SessionState] = {}

    def _now(self) -> float:
        """Return the current time from the injected clock or ``time.time``."""
        # Look time.time up on each call (not as a default argument) so code
        # that patches time.time keeps affecting this manager.
        return time.time() if self._clock is None else self._clock()

    def _store(self, state: SessionState) -> SessionState:
        """Publish *state* by rebinding a copied dict, then return it."""
        self._sessions = {**self._sessions, state.thread_id: state}
        return state

    def _is_state_expired(self, state: SessionState, now: float) -> bool:
        """Apply the expiry rule to *state* at time *now*."""
        if state.has_pending_interrupt:
            return False
        return now - state.last_activity > self._session_ttl

    def touch(self, thread_id: str) -> SessionState:
        """Update last activity for a session (resets sliding window).

        Unknown sessions are created. An existing pending-interrupt flag is
        preserved.
        """
        existing = self._sessions.get(thread_id)
        return self._store(
            SessionState(
                thread_id=thread_id,
                last_activity=self._now(),
                has_pending_interrupt=(
                    existing.has_pending_interrupt if existing else False
                ),
            )
        )

    def is_expired(self, thread_id: str) -> bool:
        """Return True if the session is unknown or idle longer than the TTL.

        A session with a pending interrupt never expires.
        """
        state = self._sessions.get(thread_id)
        if state is None:
            return True
        return self._is_state_expired(state, self._now())

    def extend_for_interrupt(self, thread_id: str) -> SessionState:
        """Mark session as having a pending interrupt (suspends TTL).

        An existing session keeps its ``last_activity``. An unknown session is
        created at the current time with the interrupt flag already set.
        """
        existing = self._sessions.get(thread_id)
        if existing is None:
            # Set the flag here as well: creating the session via touch()
            # alone would leave it False and let the session expire while the
            # interrupt is still pending.
            return self._store(
                SessionState(
                    thread_id=thread_id,
                    last_activity=self._now(),
                    has_pending_interrupt=True,
                )
            )
        return self._store(replace(existing, has_pending_interrupt=True))

    def resolve_interrupt(self, thread_id: str) -> SessionState:
        """Remove interrupt extension and reset activity timer.

        Unknown sessions are created, as with ``touch``.
        """
        return self._store(
            SessionState(
                thread_id=thread_id,
                last_activity=self._now(),
                has_pending_interrupt=False,
            )
        )

    def get_state(self, thread_id: str) -> SessionState | None:
        """Return the session's current state, or None if it is unknown."""
        return self._sessions.get(thread_id)

    def remove(self, thread_id: str) -> None:
        """Forget a session. Removing an unknown session is a no-op."""
        self._sessions = {
            key: value for key, value in self._sessions.items() if key != thread_id
        }

    def purge_expired(self) -> list[str]:
        """Drop every expired session and return the removed thread ids.

        Sessions with a pending interrupt are never purged. Nothing in this
        class calls this method automatically; callers must invoke it.
        """
        now = self._now()
        expired = [
            tid
            for tid, state in self._sessions.items()
            if self._is_state_expired(state, now)
        ]
        if expired:
            dead = set(expired)
            self._sessions = {
                key: value for key, value in self._sessions.items() if key not in dead
            }
        return expired