"""Interrupt TTL management -- tracks pending interrupts with auto-expiration.""" from __future__ import annotations import time import uuid from dataclasses import dataclass @dataclass(frozen=True) class InterruptRecord: """Immutable record of a pending interrupt.""" interrupt_id: str thread_id: str action: str params: dict created_at: float ttl_seconds: int @dataclass(frozen=True) class InterruptStatus: """Current status of a tracked interrupt.""" is_expired: bool remaining_seconds: float record: InterruptRecord class InterruptManager: """Manages interrupt TTL with auto-expiration. Complements SessionManager -- this tracks interrupt-specific TTL while SessionManager handles session-level TTL. """ def __init__(self, ttl_seconds: int = 1800) -> None: self._ttl_seconds = ttl_seconds self._interrupts: dict[str, InterruptRecord] = {} def register( self, thread_id: str, action: str, params: dict, ) -> InterruptRecord: """Register a new pending interrupt with TTL tracking.""" record = InterruptRecord( interrupt_id=uuid.uuid4().hex, thread_id=thread_id, action=action, params=dict(params), created_at=time.time(), ttl_seconds=self._ttl_seconds, ) self._interrupts = {**self._interrupts, thread_id: record} return record def check_status(self, thread_id: str) -> InterruptStatus | None: """Check the TTL status of a pending interrupt.""" record = self._interrupts.get(thread_id) if record is None: return None elapsed = time.time() - record.created_at remaining = max(0.0, record.ttl_seconds - elapsed) is_expired = elapsed > record.ttl_seconds return InterruptStatus( is_expired=is_expired, remaining_seconds=remaining, record=record, ) def resolve(self, thread_id: str) -> None: """Remove a resolved interrupt from tracking.""" self._interrupts = { k: v for k, v in self._interrupts.items() if k != thread_id } def cleanup_expired(self) -> tuple[InterruptRecord, ...]: """Find and remove all expired interrupts. Returns the expired records.""" now = time.time() expired: list[InterruptRecord] = [] active: dict[str, InterruptRecord] = {} for thread_id, record in self._interrupts.items(): if now - record.created_at > record.ttl_seconds: expired.append(record) else: active[thread_id] = record self._interrupts = active return tuple(expired) def generate_retry_prompt(self, expired_record: InterruptRecord) -> dict: """Generate a WebSocket message prompting the user to retry an expired action.""" return { "type": "interrupt_expired", "thread_id": expired_record.thread_id, "action": expired_record.action, "message": ( f"The approval request for '{expired_record.action}' has expired " f"after {expired_record.ttl_seconds // 60} minutes. " f"Would you like to try again?" ), } def has_pending(self, thread_id: str) -> bool: """Check if a thread has a pending (non-expired) interrupt.""" status = self.check_status(thread_id) if status is None: return False return not status.is_expired