Files
smart-support/backend/app/interrupt_manager.py
Yaojia Wang af53111928 refactor: fix architectural issues across frontend and backend
Address all architecture review findings:

P0 fixes:
- Add API key authentication for admin endpoints (analytics, replay, openapi)
  and WebSocket connections via ADMIN_API_KEY env var
- Add PostgreSQL-backed PgSessionManager and PgInterruptManager for
  multi-worker production deployments (in-memory defaults preserved)

P1 fixes:
- Implement actual tool generation in OpenAPI approve_job endpoint
  using generate_tool_code() and generate_agent_yaml()
- Add missing clarification, interrupt_expired, and tool_result message
  handlers in frontend ChatPage

P2 fixes:
- Replace monkey-patching on CompiledStateGraph with typed GraphContext
- Replace 9-param dispatch_message with WebSocketContext dataclass
- Extract duplicate _envelope() into shared app/api_utils.py
- Replace mutable module-level counter with crypto.randomUUID()
- Remove hardcoded mock data from ReviewPage, use api.ts wrappers
- Remove `as any` type escape from ReplayPage

All 516 tests passing, 0 TypeScript errors.
2026-04-06 15:59:14 +02:00

269 lines
8.5 KiB
Python

"""Interrupt TTL management -- tracks pending interrupts and treats each one as
expired once its TTL has elapsed.
Provides both in-memory (InterruptManager) and PostgreSQL-backed
(PgInterruptManager) implementations behind a common Protocol.
"""
from __future__ import annotations

import asyncio
import json
import time
import uuid
from dataclasses import dataclass
from typing import TYPE_CHECKING, Protocol

if TYPE_CHECKING:
    from psycopg_pool import AsyncConnectionPool
@dataclass(frozen=True)
class InterruptRecord:
    """Immutable record of a pending interrupt.

    Fields cannot be reassigned (``frozen=True``), but ``params`` is still a
    plain, mutable dict. Because the generated ``__hash__`` hashes every
    field, hashing an instance raises ``TypeError`` (dicts are unhashable).
    """
    # Unique identifier; both managers in this module use a uuid4 hex string.
    interrupt_id: str
    # Thread identifier the managers key pending interrupts by.
    thread_id: str
    # Name of the action awaiting approval; echoed in the retry prompt.
    action: str
    # Parameters of the pending action (managers store a shallow copy).
    params: dict
    # Creation time as a Unix timestamp in seconds.
    created_at: float
    # Lifetime in seconds; the interrupt counts as expired once exceeded.
    ttl_seconds: int
@dataclass(frozen=True)
class InterruptStatus:
    """Current status of a tracked interrupt, as computed by ``check_status``."""
    # True once more than ``record.ttl_seconds`` have elapsed since creation.
    is_expired: bool
    # Seconds left before expiry, clamped at 0.0 once expired.
    remaining_seconds: float
    # The interrupt record this status describes.
    record: InterruptRecord
class InterruptManagerProtocol(Protocol):
    """Structural interface shared by the interrupt-manager implementations.

    Any object exposing these synchronous methods satisfies the protocol;
    explicit subclassing is not required.
    """

    def register(self, thread_id: str, action: str, params: dict) -> InterruptRecord:
        """Begin tracking a pending interrupt for ``thread_id`` and return its record."""

    def check_status(self, thread_id: str) -> InterruptStatus | None:
        """Return the TTL status of the thread's pending interrupt, or None if untracked."""

    def resolve(self, thread_id: str) -> None:
        """Stop tracking the thread's pending interrupt."""

    def has_pending(self, thread_id: str) -> bool:
        """Return True when the thread has a pending interrupt that has not expired."""

    def generate_retry_prompt(self, expired_record: InterruptRecord) -> dict:
        """Build the message offering the user a retry of an expired interrupt."""
def _build_retry_prompt(expired_record: InterruptRecord) -> dict:
"""Generate a WebSocket message prompting the user to retry an expired action."""
return {
"type": "interrupt_expired",
"thread_id": expired_record.thread_id,
"action": expired_record.action,
"message": (
f"The approval request for '{expired_record.action}' has expired "
f"after {expired_record.ttl_seconds // 60} minutes. "
f"Would you like to try again?"
),
}
class InterruptManager:
    """Single-process, in-memory tracker of pending interrupts and their TTLs.

    Intended for single-worker development. It complements SessionManager:
    this class owns interrupt-level TTL bookkeeping, while SessionManager
    owns session-level TTL.

    At most one pending interrupt is kept per ``thread_id``; registering a
    new one for the same thread replaces the earlier record. Every mutation
    rebinds ``self._interrupts`` to a fresh dict rather than editing the
    existing one in place.
    """

    def __init__(self, ttl_seconds: int = 1800) -> None:
        # TTL (seconds) stamped onto every record created by register().
        self._ttl_seconds = ttl_seconds
        # thread_id -> most recently registered record for that thread.
        self._interrupts: dict[str, InterruptRecord] = {}

    def register(
        self,
        thread_id: str,
        action: str,
        params: dict,
    ) -> InterruptRecord:
        """Start tracking a pending interrupt for ``thread_id``.

        ``params`` is shallow-copied into the record, and any earlier record
        for the same thread is replaced.

        Returns:
            The newly created InterruptRecord.
        """
        new_record = InterruptRecord(
            interrupt_id=uuid.uuid4().hex,
            thread_id=thread_id,
            action=action,
            params=dict(params),
            created_at=time.time(),
            ttl_seconds=self._ttl_seconds,
        )
        snapshot = dict(self._interrupts)
        snapshot[thread_id] = new_record
        self._interrupts = snapshot
        return new_record

    def check_status(self, thread_id: str) -> InterruptStatus | None:
        """Report the TTL status of the thread's pending interrupt.

        Returns:
            None when nothing is tracked for ``thread_id``; otherwise an
            InterruptStatus. ``is_expired`` is True once the elapsed time
            strictly exceeds the TTL, and ``remaining_seconds`` is clamped
            at 0.0.
        """
        tracked = self._interrupts.get(thread_id)
        if tracked is None:
            return None
        age = time.time() - tracked.created_at
        return InterruptStatus(
            is_expired=age > tracked.ttl_seconds,
            remaining_seconds=max(0.0, tracked.ttl_seconds - age),
            record=tracked,
        )

    def resolve(self, thread_id: str) -> None:
        """Stop tracking the thread's interrupt; unknown threads are a no-op."""
        survivors = dict(self._interrupts)
        survivors.pop(thread_id, None)
        self._interrupts = survivors

    def cleanup_expired(self) -> tuple[InterruptRecord, ...]:
        """Drop every expired interrupt and return the dropped records.

        A single "now" is sampled so every record is judged against the same
        instant. Returned records keep their registration order.
        """
        reference_time = time.time()

        def _is_stale(rec: InterruptRecord) -> bool:
            return reference_time - rec.created_at > rec.ttl_seconds

        stale = tuple(rec for rec in self._interrupts.values() if _is_stale(rec))
        self._interrupts = {
            tid: rec for tid, rec in self._interrupts.items() if not _is_stale(rec)
        }
        return stale

    def generate_retry_prompt(self, expired_record: InterruptRecord) -> dict:
        """Return the WebSocket payload inviting the user to retry an expired action."""
        return _build_retry_prompt(expired_record)

    def has_pending(self, thread_id: str) -> bool:
        """Return True only if the thread has a tracked, not-yet-expired interrupt."""
        status = self.check_status(thread_id)
        return status is not None and not status.is_expired
# Alias for explicit naming: lets callers refer to the in-memory implementation
# by a name that contrasts with PgInterruptManager. It is the very same class
# object as InterruptManager, not a subclass.
InMemoryInterruptManager = InterruptManager
class PgInterruptManager:
    """PostgreSQL-backed interrupt manager for multi-worker production.

    Uses the existing ``active_interrupts`` table defined in db.py. A row is
    pending while ``resolved_at IS NULL``. ``register`` upserts that row; its
    ``ON CONFLICT (thread_id) WHERE resolved_at IS NULL`` clause relies on a
    matching unique partial index (presumably created in db.py). ``resolve``
    stamps ``resolved_at``.

    Two call styles are provided:

    * Async -- ``aregister``, ``acheck_status``, ``aresolve`` and
      ``ahas_pending``. Use these from coroutines.
    * Sync -- ``register``, ``check_status``, ``resolve`` and ``has_pending``,
      which satisfy InterruptManagerProtocol. They drive the work to
      completion on this thread's event loop and therefore raise
      ``RuntimeError`` when an event loop is already running in the caller's
      thread.
    """

    def __init__(
        self,
        pool: AsyncConnectionPool,
        ttl_seconds: int = 1800,
    ) -> None:
        self._pool = pool
        # TTL (seconds) applied to every interrupt; expiry is computed on read.
        self._ttl_seconds = ttl_seconds

    # -- helpers ------------------------------------------------------------

    @staticmethod
    def _run_sync(coro):
        """Run ``coro`` to completion from synchronous code and return its result.

        Blocking on a loop that is already running in this thread is
        impossible: ``run_until_complete`` would raise "This event loop is
        already running" and leave the coroutine un-awaited. In that case the
        coroutine is closed and a RuntimeError naming the async API is raised
        instead.

        Otherwise the thread's current event loop is reused. A fresh loop is
        created and installed when there is none, or when the current one is
        closed. There is no current loop, for example, after ``asyncio.run()``
        has finished and cleared it, or on newer Python versions where
        ``get_event_loop()`` no longer creates one.

        NOTE(review): pool connections are presumably tied to the loop they
        were opened on; prefer the async API from application code.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            pass  # No loop running in this thread: safe to block below.
        else:
            coro.close()  # Avoid a "coroutine was never awaited" warning.
            raise RuntimeError(
                "PgInterruptManager sync methods cannot be called while an "
                "event loop is running in this thread; await aregister / "
                "acheck_status / aresolve / ahas_pending instead."
            )
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            loop = None
        if loop is None or loop.is_closed():
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
        return loop.run_until_complete(coro)

    @staticmethod
    def _decode_params(raw: object) -> dict:
        """Normalize a ``params`` column value to a dict.

        Dicts (e.g. a JSON/JSONB column already decoded by the driver) are
        returned as-is. JSON text (str/bytes) is parsed. Anything else,
        including invalid JSON or JSON that is not an object, yields ``{}``.
        """
        if isinstance(raw, dict):
            return raw
        if isinstance(raw, (str, bytes, bytearray)):
            try:
                decoded = json.loads(raw)
            except ValueError:  # Covers JSONDecodeError and UnicodeDecodeError.
                return {}
            if isinstance(decoded, dict):
                return decoded
        return {}

    # -- sync API (InterruptManagerProtocol) ---------------------------------

    def register(
        self,
        thread_id: str,
        action: str,
        params: dict,
    ) -> InterruptRecord:
        """Synchronously upsert the pending interrupt for ``thread_id``.

        Raises:
            RuntimeError: if called while an event loop runs in this thread.
        """
        return self._run_sync(self._register(thread_id, action, params))

    def check_status(self, thread_id: str) -> InterruptStatus | None:
        """Synchronously report the TTL status of the thread's pending interrupt.

        Raises:
            RuntimeError: if called while an event loop runs in this thread.
        """
        return self._run_sync(self._check_status(thread_id))

    def resolve(self, thread_id: str) -> None:
        """Synchronously mark the thread's pending interrupt as resolved.

        Raises:
            RuntimeError: if called while an event loop runs in this thread.
        """
        self._run_sync(self._resolve(thread_id))

    def generate_retry_prompt(self, expired_record: InterruptRecord) -> dict:
        """Return the WebSocket payload inviting the user to retry an expired action."""
        return _build_retry_prompt(expired_record)

    def has_pending(self, thread_id: str) -> bool:
        """Synchronously check for an unresolved, non-expired interrupt."""
        status = self.check_status(thread_id)
        return status is not None and not status.is_expired

    # -- async API -----------------------------------------------------------

    async def aregister(
        self,
        thread_id: str,
        action: str,
        params: dict,
    ) -> InterruptRecord:
        """Upsert the pending interrupt for ``thread_id`` and return its record."""
        return await self._register(thread_id, action, params)

    async def acheck_status(self, thread_id: str) -> InterruptStatus | None:
        """Return the TTL status of the thread's pending interrupt, or None."""
        return await self._check_status(thread_id)

    async def aresolve(self, thread_id: str) -> None:
        """Mark the thread's pending interrupt as resolved."""
        await self._resolve(thread_id)

    async def ahas_pending(self, thread_id: str) -> bool:
        """Return True if the thread has an unresolved, non-expired interrupt."""
        status = await self._check_status(thread_id)
        return status is not None and not status.is_expired

    # -- implementation ------------------------------------------------------

    async def _register(
        self, thread_id: str, action: str, params: dict
    ) -> InterruptRecord:
        """Insert or replace the unresolved row for ``thread_id``."""
        record = InterruptRecord(
            interrupt_id=uuid.uuid4().hex,
            thread_id=thread_id,
            action=action,
            params=dict(params),
            created_at=time.time(),
            ttl_seconds=self._ttl_seconds,
        )
        # NOTE: the returned created_at comes from this process's clock, while
        # the stored created_at is set by the database (NOW() on conflict,
        # presumably a column default on insert); they may differ slightly.
        async with self._pool.connection() as conn:
            await conn.execute(
                """
                INSERT INTO active_interrupts (interrupt_id, thread_id, action, params)
                VALUES (%(iid)s, %(tid)s, %(action)s, %(params)s)
                ON CONFLICT (thread_id) WHERE resolved_at IS NULL
                DO UPDATE SET
                interrupt_id = %(iid)s,
                action = %(action)s,
                params = %(params)s,
                created_at = NOW(),
                resolved_at = NULL
                """,
                {
                    "iid": record.interrupt_id,
                    "tid": thread_id,
                    "action": action,
                    "params": json.dumps(params),
                },
            )
        return record

    async def _check_status(self, thread_id: str) -> InterruptStatus | None:
        """Load the newest unresolved row for ``thread_id`` and compute its TTL status."""
        async with self._pool.connection() as conn:
            cursor = await conn.execute(
                """
                SELECT interrupt_id, action, params, created_at
                FROM active_interrupts
                WHERE thread_id = %(tid)s AND resolved_at IS NULL
                ORDER BY created_at DESC LIMIT 1
                """,
                {"tid": thread_id},
            )
            row = await cursor.fetchone()
        if row is None:
            return None
        # Rows are read by column name, so the pool is presumably configured
        # with a dict-style row factory -- confirm in db.py.
        # NOTE(review): .timestamp() on a naive datetime assumes local time;
        # confirm created_at is TIMESTAMPTZ or expiry maths may be skewed.
        created_at = row["created_at"].timestamp()
        elapsed = time.time() - created_at
        record = InterruptRecord(
            interrupt_id=row["interrupt_id"],
            thread_id=thread_id,
            action=row["action"],
            params=self._decode_params(row["params"]),
            created_at=created_at,
            ttl_seconds=self._ttl_seconds,
        )
        return InterruptStatus(
            is_expired=elapsed > self._ttl_seconds,
            remaining_seconds=max(0.0, self._ttl_seconds - elapsed),
            record=record,
        )

    async def _resolve(self, thread_id: str) -> None:
        """Stamp resolved_at on the thread's unresolved row (no-op if none)."""
        async with self._pool.connection() as conn:
            await conn.execute(
                """
                UPDATE active_interrupts
                SET resolved_at = NOW(), resolution = 'resolved'
                WHERE thread_id = %(tid)s AND resolved_at IS NULL
                """,
                {"tid": thread_id},
            )