Address all architecture review findings.

P0 fixes:
- Add API key authentication for admin endpoints (analytics, replay, openapi) and WebSocket connections via the ADMIN_API_KEY env var
- Add PostgreSQL-backed PgSessionManager and PgInterruptManager for multi-worker production deployments (in-memory defaults preserved)

P1 fixes:
- Implement actual tool generation in the OpenAPI approve_job endpoint using generate_tool_code() and generate_agent_yaml()
- Add missing clarification, interrupt_expired, and tool_result message handlers in the frontend ChatPage

P2 fixes:
- Replace monkey-patching on CompiledStateGraph with a typed GraphContext
- Replace the 9-parameter dispatch_message with a WebSocketContext dataclass
- Extract the duplicated _envelope() into shared app/api_utils.py
- Replace the mutable module-level counter with crypto.randomUUID()
- Remove hardcoded mock data from ReviewPage; use the api.ts wrappers
- Remove the `as any` type escape from ReplayPage

All 516 tests passing, 0 TypeScript errors.
213 lines
7.2 KiB
Python
213 lines
7.2 KiB
Python
"""Session TTL management with sliding window and interrupt extension.

Provides both in-memory (SessionManager) and PostgreSQL-backed
(PgSessionManager) implementations behind a common Protocol
(SessionManagerProtocol).
"""
|
|
|
|
from __future__ import annotations

import asyncio
import time
from collections.abc import Coroutine, Mapping
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, Protocol

if TYPE_CHECKING:
    from psycopg_pool import AsyncConnectionPool
|
|
|
|
|
|
@dataclass(frozen=True)
class SessionState:
    """Immutable snapshot of one session's TTL bookkeeping.

    Managers never mutate an instance; they build a new one for every change.
    """

    # Identifier of the thread this session tracks (the lookup key).
    thread_id: str
    # Unix time in seconds (as produced by time.time()) of the latest activity.
    last_activity: float
    # While True, the managers' is_expired() reports the session as alive.
    has_pending_interrupt: bool
|
|
|
|
|
|
class SessionManagerProtocol(Protocol):
    """Protocol for session TTL management.

    Satisfied structurally (without inheritance) by SessionManager and
    PgSessionManager.
    """

    def touch(self, thread_id: str) -> SessionState:
        """Record activity for *thread_id*, restarting its TTL window."""
        ...

    def is_expired(self, thread_id: str) -> bool:
        """Tell whether the session should be treated as expired."""
        ...

    def extend_for_interrupt(self, thread_id: str) -> SessionState:
        """Flag the session as waiting on an interrupt, suspending its TTL."""
        ...

    def resolve_interrupt(self, thread_id: str) -> SessionState:
        """Clear the pending interrupt and restart the TTL window."""
        ...

    def get_state(self, thread_id: str) -> SessionState | None:
        """Return the current state, or None when the session is unknown."""
        ...

    def remove(self, thread_id: str) -> None:
        """Forget the session."""
        ...
|
|
|
|
|
|
class SessionManager:
    """In-memory session manager for single-worker development.

    - Each message resets the TTL (sliding window).
    - A pending interrupt suspends expiration until resolved.

    State lives in a dict on the instance, so it is neither persisted nor
    shared between processes.
    """

    def __init__(self, session_ttl_seconds: int = 1800) -> None:
        """Create an empty manager.

        Args:
            session_ttl_seconds: Idle time after which a session without a
                pending interrupt is considered expired.
        """
        self._session_ttl = session_ttl_seconds
        # Rebound to a fresh dict on every write instead of mutated in place.
        self._sessions: dict[str, SessionState] = {}

    def touch(self, thread_id: str) -> SessionState:
        """Update last activity for a session (resets sliding window).

        An existing pending-interrupt flag is carried over; a previously
        unknown session starts without one.
        """
        existing = self._sessions.get(thread_id)
        new_state = SessionState(
            thread_id=thread_id,
            last_activity=time.time(),
            has_pending_interrupt=existing.has_pending_interrupt if existing else False,
        )
        self._sessions = {**self._sessions, thread_id: new_state}
        return new_state

    def is_expired(self, thread_id: str) -> bool:
        """Check if a session has expired.

        Unknown sessions count as expired. Sessions with a pending interrupt
        never expire. Otherwise a session expires once more than
        ``session_ttl_seconds`` have elapsed since its last activity.
        """
        state = self._sessions.get(thread_id)
        if state is None:
            return True
        if state.has_pending_interrupt:
            return False
        elapsed = time.time() - state.last_activity
        return elapsed > self._session_ttl

    def extend_for_interrupt(self, thread_id: str) -> SessionState:
        """Mark session as having a pending interrupt (suspends TTL).

        A known session keeps its last-activity timestamp. An unknown session
        is created with the current time. In both cases the returned (and
        stored) state has ``has_pending_interrupt=True``.
        """
        existing = self._sessions.get(thread_id)
        # Unknown sessions must be flagged too: creating them unflagged (as
        # touch() does) would leave their TTL running during the interrupt.
        last_activity = existing.last_activity if existing is not None else time.time()
        new_state = SessionState(
            thread_id=thread_id,
            last_activity=last_activity,
            has_pending_interrupt=True,
        )
        self._sessions = {**self._sessions, thread_id: new_state}
        return new_state

    def resolve_interrupt(self, thread_id: str) -> SessionState:
        """Remove interrupt extension and reset activity timer."""
        new_state = SessionState(
            thread_id=thread_id,
            last_activity=time.time(),
            has_pending_interrupt=False,
        )
        self._sessions = {**self._sessions, thread_id: new_state}
        return new_state

    def get_state(self, thread_id: str) -> SessionState | None:
        """Return the stored state, or None if the session is unknown."""
        return self._sessions.get(thread_id)

    def remove(self, thread_id: str) -> None:
        """Forget the session (no-op if it is unknown)."""
        self._sessions = {k: v for k, v in self._sessions.items() if k != thread_id}


# Alias for explicit naming
InMemorySessionManager = SessionManager
|
|
|
|
|
|
class PgSessionManager:
    """PostgreSQL-backed session manager for multi-worker production.

    State lives in a ``sessions`` table with columns ``thread_id`` (the
    upsert conflict target, so it needs a unique constraint),
    ``last_activity`` and ``has_pending_interrupt``.

    Two APIs are offered:

    * Coroutine methods (``atouch``, ``ais_expired``, ``aextend_for_interrupt``,
      ``aresolve_interrupt``, ``aget_state``, ``aremove``) for callers that
      already run inside an event loop.
    * Synchronous methods satisfying ``SessionManagerProtocol``. These block on
      the current thread's event loop. That is impossible while the loop is
      running, so in that case they raise ``RuntimeError``.
    """

    def __init__(
        self,
        pool: AsyncConnectionPool,
        session_ttl_seconds: int = 1800,
    ) -> None:
        """Create a manager that runs every query through *pool*.

        Args:
            pool: Async psycopg connection pool.
            session_ttl_seconds: Idle time after which a session without a
                pending interrupt is considered expired.
        """
        self._pool = pool
        self._session_ttl = session_ttl_seconds

    # -- Synchronous SessionManagerProtocol API ------------------------------

    def touch(self, thread_id: str) -> SessionState:
        """Update last activity (resets the sliding window).

        Raises:
            RuntimeError: if an event loop is running in this thread.
        """
        return self._run_sync(self._touch(thread_id))

    def is_expired(self, thread_id: str) -> bool:
        """Check if a session has expired (unknown sessions count as expired).

        Raises:
            RuntimeError: if an event loop is running in this thread.
        """
        return self._state_expired(self.get_state(thread_id))

    def extend_for_interrupt(self, thread_id: str) -> SessionState:
        """Mark the session as having a pending interrupt (suspends TTL).

        Raises:
            RuntimeError: if an event loop is running in this thread.
        """
        return self._run_sync(self._set_interrupt(thread_id, True))

    def resolve_interrupt(self, thread_id: str) -> SessionState:
        """Clear the pending interrupt and reset the activity timer.

        Raises:
            RuntimeError: if an event loop is running in this thread.
        """
        return self._run_sync(self._set_interrupt(thread_id, False))

    def get_state(self, thread_id: str) -> SessionState | None:
        """Return the stored state, or None if the session is unknown.

        Raises:
            RuntimeError: if an event loop is running in this thread.
        """
        return self._run_sync(self._get_state(thread_id))

    def remove(self, thread_id: str) -> None:
        """Delete the session row (no-op if it does not exist).

        Raises:
            RuntimeError: if an event loop is running in this thread.
        """
        self._run_sync(self._remove(thread_id))

    # -- Async API (for callers inside an event loop) ------------------------

    async def atouch(self, thread_id: str) -> SessionState:
        """Async version of :meth:`touch`."""
        return await self._touch(thread_id)

    async def ais_expired(self, thread_id: str) -> bool:
        """Async version of :meth:`is_expired`."""
        return self._state_expired(await self._get_state(thread_id))

    async def aextend_for_interrupt(self, thread_id: str) -> SessionState:
        """Async version of :meth:`extend_for_interrupt`."""
        return await self._set_interrupt(thread_id, True)

    async def aresolve_interrupt(self, thread_id: str) -> SessionState:
        """Async version of :meth:`resolve_interrupt`."""
        return await self._set_interrupt(thread_id, False)

    async def aget_state(self, thread_id: str) -> SessionState | None:
        """Async version of :meth:`get_state`."""
        return await self._get_state(thread_id)

    async def aremove(self, thread_id: str) -> None:
        """Async version of :meth:`remove`."""
        await self._remove(thread_id)

    # -- Internals -----------------------------------------------------------

    @staticmethod
    def _run_sync(coro: Coroutine[Any, Any, Any]) -> Any:
        """Drive *coro* to completion from synchronous code.

        Raises:
            RuntimeError: if an event loop is already running in this thread,
                because ``run_until_complete`` cannot be used on a running
                loop. The coroutine is closed first so it does not trigger a
                "coroutine was never awaited" warning.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            pass  # no running loop in this thread: safe to block on it
        else:
            coro.close()
            raise RuntimeError(
                "PgSessionManager synchronous methods cannot be called while an "
                "event loop is running; await the async variants (atouch, "
                "ais_expired, aextend_for_interrupt, aresolve_interrupt, "
                "aget_state, aremove) instead."
            )
        return asyncio.get_event_loop().run_until_complete(coro)

    def _state_expired(self, state: SessionState | None) -> bool:
        """Apply the TTL rules to an already-fetched state."""
        if state is None:
            return True
        if state.has_pending_interrupt:
            return False
        elapsed = time.time() - state.last_activity
        return elapsed > self._session_ttl

    @staticmethod
    def _row_to_state(thread_id: str, row: Any) -> SessionState:
        """Build a SessionState from a ``(last_activity, has_pending_interrupt)`` row.

        The row factory is chosen by whoever configured the pool, so both
        mapping rows (e.g. psycopg's ``dict_row``) and sequence rows (the
        default tuples) are accepted.
        """
        if isinstance(row, Mapping):
            last_activity = row["last_activity"]
            pending = row["has_pending_interrupt"]
        else:
            last_activity, pending = row[0], row[1]
        # NOTE(review): assumes last_activity is TIMESTAMPTZ so psycopg yields an
        # aware datetime; a naive value would be read as local time by
        # .timestamp() — confirm against the schema.
        return SessionState(
            thread_id=thread_id,
            last_activity=last_activity.timestamp(),
            has_pending_interrupt=bool(pending),
        )

    async def _touch(self, thread_id: str) -> SessionState:
        """Upsert last_activity to now and return the row as stored.

        The UPDATE branch leaves has_pending_interrupt untouched. The stored
        value is therefore read back via RETURNING rather than assumed False.
        """
        now = datetime.now(timezone.utc)
        async with self._pool.connection() as conn:
            cursor = await conn.execute(
                """
                INSERT INTO sessions (thread_id, last_activity, has_pending_interrupt)
                VALUES (%(tid)s, %(now)s, FALSE)
                ON CONFLICT (thread_id) DO UPDATE
                SET last_activity = %(now)s
                RETURNING last_activity, has_pending_interrupt
                """,
                {"tid": thread_id, "now": now},
            )
            row = await cursor.fetchone()
        if row is None:
            # An upsert with RETURNING yields one row; fall back defensively.
            return SessionState(
                thread_id=thread_id,
                last_activity=now.timestamp(),
                has_pending_interrupt=False,
            )
        return self._row_to_state(thread_id, row)

    async def _set_interrupt(
        self, thread_id: str, has_interrupt: bool
    ) -> SessionState:
        """Upsert the interrupt flag, stamping last_activity with now."""
        now = datetime.now(timezone.utc)
        async with self._pool.connection() as conn:
            await conn.execute(
                """
                INSERT INTO sessions (thread_id, last_activity, has_pending_interrupt)
                VALUES (%(tid)s, %(now)s, %(interrupt)s)
                ON CONFLICT (thread_id) DO UPDATE
                SET last_activity = %(now)s,
                    has_pending_interrupt = %(interrupt)s
                """,
                {"tid": thread_id, "now": now, "interrupt": has_interrupt},
            )
        return SessionState(
            thread_id=thread_id,
            last_activity=now.timestamp(),
            has_pending_interrupt=has_interrupt,
        )

    async def _get_state(self, thread_id: str) -> SessionState | None:
        """Fetch the session row, or return None if it does not exist."""
        async with self._pool.connection() as conn:
            cursor = await conn.execute(
                "SELECT last_activity, has_pending_interrupt FROM sessions WHERE thread_id = %(tid)s",
                {"tid": thread_id},
            )
            row = await cursor.fetchone()
        if row is None:
            return None
        return self._row_to_state(thread_id, row)

    async def _remove(self, thread_id: str) -> None:
        """Delete the session row, if any."""
        async with self._pool.connection() as conn:
            await conn.execute(
                "DELETE FROM sessions WHERE thread_id = %(tid)s",
                {"tid": thread_id},
            )
|