"""Interrupt TTL management -- tracks pending interrupts with auto-expiration. Provides both in-memory (InterruptManager) and PostgreSQL-backed (PgInterruptManager) implementations behind a common Protocol. """ from __future__ import annotations import time import uuid from dataclasses import dataclass from typing import TYPE_CHECKING, Protocol if TYPE_CHECKING: from psycopg_pool import AsyncConnectionPool @dataclass(frozen=True) class InterruptRecord: """Immutable record of a pending interrupt.""" interrupt_id: str thread_id: str action: str params: dict created_at: float ttl_seconds: int @dataclass(frozen=True) class InterruptStatus: """Current status of a tracked interrupt.""" is_expired: bool remaining_seconds: float record: InterruptRecord class InterruptManagerProtocol(Protocol): """Protocol for interrupt TTL management.""" def register(self, thread_id: str, action: str, params: dict) -> InterruptRecord: ... def check_status(self, thread_id: str) -> InterruptStatus | None: ... def resolve(self, thread_id: str) -> None: ... def has_pending(self, thread_id: str) -> bool: ... def generate_retry_prompt(self, expired_record: InterruptRecord) -> dict: ... def _build_retry_prompt(expired_record: InterruptRecord) -> dict: """Generate a WebSocket message prompting the user to retry an expired action.""" return { "type": "interrupt_expired", "thread_id": expired_record.thread_id, "action": expired_record.action, "message": ( f"The approval request for '{expired_record.action}' has expired " f"after {expired_record.ttl_seconds // 60} minutes. " f"Would you like to try again?" ), } class InterruptManager: """In-memory interrupt manager for single-worker development. Complements SessionManager -- this tracks interrupt-specific TTL while SessionManager handles session-level TTL. """ def __init__(self, ttl_seconds: int = 1800) -> None: self._ttl_seconds = ttl_seconds self._interrupts: dict[str, InterruptRecord] = {} def register( self, thread_id: str, action: str, params: dict, ) -> InterruptRecord: """Register a new pending interrupt with TTL tracking.""" record = InterruptRecord( interrupt_id=uuid.uuid4().hex, thread_id=thread_id, action=action, params=dict(params), created_at=time.time(), ttl_seconds=self._ttl_seconds, ) self._interrupts = {**self._interrupts, thread_id: record} return record def check_status(self, thread_id: str) -> InterruptStatus | None: """Check the TTL status of a pending interrupt.""" record = self._interrupts.get(thread_id) if record is None: return None elapsed = time.time() - record.created_at remaining = max(0.0, record.ttl_seconds - elapsed) is_expired = elapsed > record.ttl_seconds return InterruptStatus( is_expired=is_expired, remaining_seconds=remaining, record=record, ) def resolve(self, thread_id: str) -> None: """Remove a resolved interrupt from tracking.""" self._interrupts = { k: v for k, v in self._interrupts.items() if k != thread_id } def cleanup_expired(self) -> tuple[InterruptRecord, ...]: """Find and remove all expired interrupts. Returns the expired records.""" now = time.time() expired: list[InterruptRecord] = [] active: dict[str, InterruptRecord] = {} for thread_id, record in self._interrupts.items(): if now - record.created_at > record.ttl_seconds: expired.append(record) else: active[thread_id] = record self._interrupts = active return tuple(expired) def generate_retry_prompt(self, expired_record: InterruptRecord) -> dict: """Generate a WebSocket message prompting the user to retry an expired action.""" return _build_retry_prompt(expired_record) def has_pending(self, thread_id: str) -> bool: """Check if a thread has a pending (non-expired) interrupt.""" status = self.check_status(thread_id) if status is None: return False return not status.is_expired # Alias for explicit naming InMemoryInterruptManager = InterruptManager class PgInterruptManager: """PostgreSQL-backed interrupt manager for multi-worker production. Uses the existing active_interrupts table defined in db.py. """ def __init__( self, pool: AsyncConnectionPool, ttl_seconds: int = 1800, ) -> None: self._pool = pool self._ttl_seconds = ttl_seconds def register( self, thread_id: str, action: str, params: dict, ) -> InterruptRecord: import asyncio return asyncio.get_event_loop().run_until_complete( self._register(thread_id, action, params) ) async def _register( self, thread_id: str, action: str, params: dict ) -> InterruptRecord: import json record = InterruptRecord( interrupt_id=uuid.uuid4().hex, thread_id=thread_id, action=action, params=dict(params), created_at=time.time(), ttl_seconds=self._ttl_seconds, ) async with self._pool.connection() as conn: await conn.execute( """ INSERT INTO active_interrupts (interrupt_id, thread_id, action, params) VALUES (%(iid)s, %(tid)s, %(action)s, %(params)s) ON CONFLICT (thread_id) WHERE resolved_at IS NULL DO UPDATE SET interrupt_id = %(iid)s, action = %(action)s, params = %(params)s, created_at = NOW(), resolved_at = NULL """, { "iid": record.interrupt_id, "tid": thread_id, "action": action, "params": json.dumps(params), }, ) return record def check_status(self, thread_id: str) -> InterruptStatus | None: import asyncio return asyncio.get_event_loop().run_until_complete( self._check_status(thread_id) ) async def _check_status(self, thread_id: str) -> InterruptStatus | None: async with self._pool.connection() as conn: cursor = await conn.execute( """ SELECT interrupt_id, action, params, created_at FROM active_interrupts WHERE thread_id = %(tid)s AND resolved_at IS NULL ORDER BY created_at DESC LIMIT 1 """, {"tid": thread_id}, ) row = await cursor.fetchone() if row is None: return None created_at = row["created_at"].timestamp() elapsed = time.time() - created_at remaining = max(0.0, self._ttl_seconds - elapsed) is_expired = elapsed > self._ttl_seconds record = InterruptRecord( interrupt_id=row["interrupt_id"], thread_id=thread_id, action=row["action"], params=row["params"] if isinstance(row["params"], dict) else {}, created_at=created_at, ttl_seconds=self._ttl_seconds, ) return InterruptStatus( is_expired=is_expired, remaining_seconds=remaining, record=record, ) def resolve(self, thread_id: str) -> None: import asyncio asyncio.get_event_loop().run_until_complete(self._resolve(thread_id)) async def _resolve(self, thread_id: str) -> None: async with self._pool.connection() as conn: await conn.execute( """ UPDATE active_interrupts SET resolved_at = NOW(), resolution = 'resolved' WHERE thread_id = %(tid)s AND resolved_at IS NULL """, {"tid": thread_id}, ) def generate_retry_prompt(self, expired_record: InterruptRecord) -> dict: return _build_retry_prompt(expired_record) def has_pending(self, thread_id: str) -> bool: status = self.check_status(thread_id) if status is None: return False return not status.is_expired