refactor: fix architectural issues across frontend and backend
Address all architecture review findings.

P0 fixes:
- Add API key authentication for admin endpoints (analytics, replay, openapi) and WebSocket connections via the ADMIN_API_KEY env var
- Add PostgreSQL-backed PgSessionManager and PgInterruptManager for multi-worker production deployments (in-memory defaults preserved)

P1 fixes:
- Implement actual tool generation in the OpenAPI approve_job endpoint using generate_tool_code() and generate_agent_yaml()
- Add missing clarification, interrupt_expired, and tool_result message handlers in the frontend ChatPage

P2 fixes:
- Replace monkey-patching on CompiledStateGraph with a typed GraphContext
- Replace the 9-param dispatch_message with a WebSocketContext dataclass
- Extract the duplicate _envelope() into shared app/api_utils.py
- Replace the mutable module-level counter with crypto.randomUUID()
- Remove hardcoded mock data from ReviewPage; use the api.ts wrappers
- Remove the `as any` type escape from ReplayPage

All 516 tests passing, 0 TypeScript errors.
This commit is contained in:
@@ -1,10 +1,18 @@
|
||||
"""Interrupt TTL management -- tracks pending interrupts with auto-expiration."""
|
||||
"""Interrupt TTL management -- tracks pending interrupts with auto-expiration.
|
||||
|
||||
Provides both in-memory (InterruptManager) and PostgreSQL-backed
|
||||
(PgInterruptManager) implementations behind a common Protocol.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Protocol
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from psycopg_pool import AsyncConnectionPool
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -28,8 +36,32 @@ class InterruptStatus:
|
||||
record: InterruptRecord
|
||||
|
||||
|
||||
class InterruptManagerProtocol(Protocol):
    """Protocol for interrupt TTL management."""

    # Structural interface that the interrupt managers in this module are meant
    # to satisfy; every method is keyed by the conversation's ``thread_id``.

    # Record a pending interrupt for ``thread_id`` (the action awaiting approval
    # and its parameters) and return the resulting record.
    def register(self, thread_id: str, action: str, params: dict) -> InterruptRecord: ...

    # Return the status (expired flag, remaining seconds, record) of the
    # thread's interrupt, or None when the thread has no tracked interrupt.
    def check_status(self, thread_id: str) -> InterruptStatus | None: ...

    # Mark the thread's pending interrupt as resolved.
    def resolve(self, thread_id: str) -> None: ...

    # Check if a thread has a pending (non-expired) interrupt.
    def has_pending(self, thread_id: str) -> bool: ...

    # Generate a WebSocket message prompting the user to retry an expired action.
    def generate_retry_prompt(self, expired_record: InterruptRecord) -> dict: ...
|
||||
|
||||
|
||||
def _build_retry_prompt(expired_record: InterruptRecord) -> dict:
    """Generate a WebSocket message prompting the user to retry an expired action.

    Args:
        expired_record: The interrupt whose TTL has elapsed. Its ``thread_id``,
            ``action`` and ``ttl_seconds`` attributes are read.

    Returns:
        An ``interrupt_expired`` payload carrying the thread id, the action
        name and a human-readable message stating how long the approval
        request was valid for.
    """
    ttl = expired_record.ttl_seconds
    # Whole minutes read best for typical TTLs, but flooring a sub-minute TTL
    # would claim the request expired "after 0 minutes", so use seconds there.
    if ttl < 60:
        amount, unit = ttl, "second"
    else:
        amount, unit = ttl // 60, "minute"
    # Avoid "1 minutes" / "1 seconds".
    duration = f"{amount} {unit}" if amount == 1 else f"{amount} {unit}s"

    return {
        "type": "interrupt_expired",
        "thread_id": expired_record.thread_id,
        "action": expired_record.action,
        "message": (
            f"The approval request for '{expired_record.action}' has expired "
            f"after {duration}. "
            "Would you like to try again?"
        ),
    }
|
||||
|
||||
|
||||
class InterruptManager:
|
||||
"""Manages interrupt TTL with auto-expiration.
|
||||
"""In-memory interrupt manager for single-worker development.
|
||||
|
||||
Complements SessionManager -- this tracks interrupt-specific TTL
|
||||
while SessionManager handles session-level TTL.
|
||||
@@ -62,11 +94,9 @@ class InterruptManager:
|
||||
record = self._interrupts.get(thread_id)
|
||||
if record is None:
|
||||
return None
|
||||
|
||||
elapsed = time.time() - record.created_at
|
||||
remaining = max(0.0, record.ttl_seconds - elapsed)
|
||||
is_expired = elapsed > record.ttl_seconds
|
||||
|
||||
return InterruptStatus(
|
||||
is_expired=is_expired,
|
||||
remaining_seconds=remaining,
|
||||
@@ -84,28 +114,17 @@ class InterruptManager:
|
||||
now = time.time()
|
||||
expired: list[InterruptRecord] = []
|
||||
active: dict[str, InterruptRecord] = {}
|
||||
|
||||
for thread_id, record in self._interrupts.items():
|
||||
if now - record.created_at > record.ttl_seconds:
|
||||
expired.append(record)
|
||||
else:
|
||||
active[thread_id] = record
|
||||
|
||||
self._interrupts = active
|
||||
return tuple(expired)
|
||||
|
||||
def generate_retry_prompt(self, expired_record: InterruptRecord) -> dict:
    """Generate a WebSocket message prompting the user to retry an expired action.

    The payload is built by the module-level ``_build_retry_prompt`` helper
    rather than inline, so there is a single definition of the message format.

    Args:
        expired_record: The interrupt record whose TTL has elapsed.

    Returns:
        The ``interrupt_expired`` message dict produced by
        ``_build_retry_prompt``.
    """
    return _build_retry_prompt(expired_record)
|
||||
|
||||
def has_pending(self, thread_id: str) -> bool:
|
||||
"""Check if a thread has a pending (non-expired) interrupt."""
|
||||
@@ -113,3 +132,137 @@ class InterruptManager:
|
||||
if status is None:
|
||||
return False
|
||||
return not status.is_expired
|
||||
|
||||
|
||||
# Alias for explicit naming: lets callers refer to the in-memory implementation
# by a name that says which backend it is. Both names are the same class object.
InMemoryInterruptManager = InterruptManager
|
||||
|
||||
|
||||
class PgInterruptManager:
    """PostgreSQL-backed interrupt manager for multi-worker production.

    Uses the existing active_interrupts table defined in db.py.

    The synchronous methods (``register``, ``check_status``, ``resolve`` and
    ``has_pending``) block on the current thread's event loop, so they can
    only be used while no loop is running in that thread. Code that already
    runs inside the loop must await the ``a``-prefixed coroutine variants
    (``aregister``, ``acheck_status``, ``aresolve``, ``ahas_pending``).

    The TTL is not written to the table: expiry is always computed from this
    instance's ``ttl_seconds`` and the row's ``created_at``.
    """

    def __init__(
        self,
        pool: AsyncConnectionPool,
        ttl_seconds: int = 1800,
    ) -> None:
        """Store the connection pool and the TTL applied to every interrupt.

        Args:
            pool: Async connection pool used for all queries.
            ttl_seconds: Seconds after creation at which an interrupt counts
                as expired.
        """
        self._pool = pool
        self._ttl_seconds = ttl_seconds

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _run_sync(coro):
        """Drive *coro* to completion from synchronous code and return its result.

        Args:
            coro: The coroutine object to run.

        Raises:
            RuntimeError: If an event loop is already running in this thread.
                Blocking on the loop from inside itself cannot work, so the
                coroutine is closed and the caller is pointed at the
                awaitable variants.
        """
        import asyncio

        try:
            asyncio.get_running_loop()
        except RuntimeError:
            loop_is_running = False
        else:
            loop_is_running = True

        if loop_is_running:
            # Close it so it is not reported as "coroutine was never awaited".
            coro.close()
            raise RuntimeError(
                "PgInterruptManager's synchronous methods cannot be called "
                "while an event loop is running in this thread; await the "
                "a-prefixed coroutine methods (e.g. aregister) instead."
            )
        return asyncio.get_event_loop().run_until_complete(coro)

    @staticmethod
    def _unpack_row(row) -> tuple:
        """Return ``(interrupt_id, action, params, created_at)`` from a result row.

        Accepts mapping-style rows keyed by column name as well as plain
        sequences in the SELECT's column order, so the lookup does not depend
        on how the pool's connections build their rows.
        """
        from collections.abc import Mapping

        if isinstance(row, Mapping):
            return (
                row["interrupt_id"],
                row["action"],
                row["params"],
                row["created_at"],
            )
        interrupt_id, action, raw_params, created_at = row
        return (interrupt_id, action, raw_params, created_at)

    @staticmethod
    def _decode_params(raw) -> dict:
        """Return the stored params as a dict.

        The value is written with ``json.dumps``; depending on the column type
        it may come back already decoded or still as JSON text. Text is
        decoded here instead of being discarded. Anything that is not (or does
        not decode to) a dict yields ``{}``.
        """
        import json

        if isinstance(raw, (str, bytes, bytearray)):
            try:
                raw = json.loads(raw)
            except ValueError:
                return {}
        return raw if isinstance(raw, dict) else {}

    # ------------------------------------------------------------------
    # register
    # ------------------------------------------------------------------

    def register(
        self,
        thread_id: str,
        action: str,
        params: dict,
    ) -> InterruptRecord:
        """Persist a pending interrupt for ``thread_id`` and return its record.

        Blocking wrapper around ``aregister``.

        Raises:
            RuntimeError: If called while an event loop is running in this
                thread.
        """
        return self._run_sync(self._register(thread_id, action, params))

    async def aregister(
        self, thread_id: str, action: str, params: dict
    ) -> InterruptRecord:
        """Awaitable variant of ``register`` for callers inside the event loop."""
        return await self._register(thread_id, action, params)

    async def _register(
        self, thread_id: str, action: str, params: dict
    ) -> InterruptRecord:
        """Insert the interrupt, replacing the thread's unresolved one if present."""
        import json

        record = InterruptRecord(
            interrupt_id=uuid.uuid4().hex,
            thread_id=thread_id,
            action=action,
            params=dict(params),
            # Local clock; the row itself is stamped by the database's NOW().
            created_at=time.time(),
            ttl_seconds=self._ttl_seconds,
        )
        async with self._pool.connection() as conn:
            await conn.execute(
                """
                INSERT INTO active_interrupts (interrupt_id, thread_id, action, params)
                VALUES (%(iid)s, %(tid)s, %(action)s, %(params)s)
                ON CONFLICT (thread_id) WHERE resolved_at IS NULL
                DO UPDATE SET
                    interrupt_id = %(iid)s,
                    action = %(action)s,
                    params = %(params)s,
                    created_at = NOW(),
                    resolved_at = NULL
                """,
                {
                    "iid": record.interrupt_id,
                    "tid": thread_id,
                    "action": action,
                    "params": json.dumps(params),
                },
            )
        return record

    # ------------------------------------------------------------------
    # check_status
    # ------------------------------------------------------------------

    def check_status(self, thread_id: str) -> InterruptStatus | None:
        """Return the status of the thread's unresolved interrupt, if any.

        Blocking wrapper around ``acheck_status``.

        Raises:
            RuntimeError: If called while an event loop is running in this
                thread.
        """
        return self._run_sync(self._check_status(thread_id))

    async def acheck_status(self, thread_id: str) -> InterruptStatus | None:
        """Awaitable variant of ``check_status`` for callers inside the event loop."""
        return await self._check_status(thread_id)

    async def _check_status(self, thread_id: str) -> InterruptStatus | None:
        """Load the newest unresolved interrupt and compute its expiry state.

        Returns:
            ``None`` when the thread has no unresolved interrupt, otherwise an
            ``InterruptStatus`` with the expired flag, the remaining seconds
            (never negative) and the reconstructed record.
        """
        async with self._pool.connection() as conn:
            cursor = await conn.execute(
                """
                SELECT interrupt_id, action, params, created_at
                FROM active_interrupts
                WHERE thread_id = %(tid)s AND resolved_at IS NULL
                ORDER BY created_at DESC LIMIT 1
                """,
                {"tid": thread_id},
            )
            row = await cursor.fetchone()

        if row is None:
            return None

        interrupt_id, action, raw_params, created = self._unpack_row(row)

        created_at = created.timestamp()
        elapsed = time.time() - created_at
        remaining = max(0.0, self._ttl_seconds - elapsed)
        is_expired = elapsed > self._ttl_seconds

        record = InterruptRecord(
            interrupt_id=interrupt_id,
            thread_id=thread_id,
            action=action,
            params=self._decode_params(raw_params),
            created_at=created_at,
            ttl_seconds=self._ttl_seconds,
        )

        return InterruptStatus(
            is_expired=is_expired,
            remaining_seconds=remaining,
            record=record,
        )

    # ------------------------------------------------------------------
    # resolve
    # ------------------------------------------------------------------

    def resolve(self, thread_id: str) -> None:
        """Mark the thread's unresolved interrupt(s) as resolved.

        Blocking wrapper around ``aresolve``.

        Raises:
            RuntimeError: If called while an event loop is running in this
                thread.
        """
        self._run_sync(self._resolve(thread_id))

    async def aresolve(self, thread_id: str) -> None:
        """Awaitable variant of ``resolve`` for callers inside the event loop."""
        await self._resolve(thread_id)

    async def _resolve(self, thread_id: str) -> None:
        """Stamp ``resolved_at`` on every unresolved row of the thread."""
        async with self._pool.connection() as conn:
            await conn.execute(
                """
                UPDATE active_interrupts
                SET resolved_at = NOW(), resolution = 'resolved'
                WHERE thread_id = %(tid)s AND resolved_at IS NULL
                """,
                {"tid": thread_id},
            )

    # ------------------------------------------------------------------
    # misc
    # ------------------------------------------------------------------

    def generate_retry_prompt(self, expired_record: InterruptRecord) -> dict:
        """Generate a WebSocket message prompting the user to retry an expired action."""
        return _build_retry_prompt(expired_record)

    def has_pending(self, thread_id: str) -> bool:
        """Check if a thread has a pending (non-expired) interrupt.

        Raises:
            RuntimeError: If called while an event loop is running in this
                thread.
        """
        status = self.check_status(thread_id)
        if status is None:
            return False
        return not status.is_expired

    async def ahas_pending(self, thread_id: str) -> bool:
        """Awaitable variant of ``has_pending`` for callers inside the event loop."""
        status = await self._check_status(thread_id)
        return status is not None and not status.is_expired
|
||||
|
||||
Reference in New Issue
Block a user