refactor: fix architectural issues across frontend and backend

Address all architecture review findings:

P0 fixes:
- Add API key authentication for admin endpoints (analytics, replay, openapi)
  and WebSocket connections via ADMIN_API_KEY env var
- Add PostgreSQL-backed PgSessionManager and PgInterruptManager for
  multi-worker production deployments (in-memory defaults preserved)

P1 fixes:
- Implement actual tool generation in OpenAPI approve_job endpoint
  using generate_tool_code() and generate_agent_yaml()
- Add missing clarification, interrupt_expired, and tool_result message
  handlers in frontend ChatPage

P2 fixes:
- Replace monkey-patching on CompiledStateGraph with typed GraphContext
- Replace the 9-parameter dispatch_message signature with a WebSocketContext dataclass
- Extract the duplicated _envelope() helper into a shared app/api_utils.py module
- Replace mutable module-level counter with crypto.randomUUID()
- Remove hardcoded mock data from ReviewPage, use api.ts wrappers
- Remove `as any` type escape from ReplayPage

All 516 tests passing, 0 TypeScript errors.
This commit is contained in:
Yaojia Wang
2026-04-06 15:59:14 +02:00
parent b8654aa31f
commit af53111928
29 changed files with 1183 additions and 473 deletions

View File

@@ -7,22 +7,19 @@ import logging
import re
import time
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING
from langchain_core.messages import HumanMessage
from langgraph.types import Command
from app.graph import classify_intent
if TYPE_CHECKING:
from fastapi import WebSocket
from langgraph.graph.state import CompiledStateGraph
from app.analytics.event_recorder import AnalyticsRecorder
from app.callbacks import TokenUsageCallbackHandler
from app.conversation_tracker import ConversationTrackerProtocol
from app.graph_context import GraphContext
from app.interrupt_manager import InterruptManager
from app.session_manager import SessionManager
from app.ws_context import WebSocketContext
logger = logging.getLogger(__name__)
@@ -46,7 +43,7 @@ def _evict_stale_threads(cutoff: float) -> None:
async def handle_user_message(
ws: WebSocket,
graph: CompiledStateGraph,
ctx: GraphContext,
session_manager: SessionManager,
callback_handler: TokenUsageCallbackHandler,
thread_id: str,
@@ -54,8 +51,6 @@ async def handle_user_message(
interrupt_manager: InterruptManager | None = None,
) -> None:
"""Process a user message through the graph and stream results back."""
# Touch first so new sessions are created before expiry check.
# For existing sessions, touch resets the sliding window.
existing = session_manager.get_state(thread_id)
if existing is not None and session_manager.is_expired(thread_id):
msg = "Session expired. Please start a new conversation."
@@ -64,8 +59,7 @@ async def handle_user_message(
session_manager.touch(thread_id)
# Run intent classification if available (for logging/future multi-intent)
classification = await classify_intent(graph, content)
classification = await ctx.classify_intent(content)
if classification is not None:
logger.info(
"Intent classification for thread %s: ambiguous=%s, intents=%s",
@@ -74,7 +68,6 @@ async def handle_user_message(
[i.agent_name for i in classification.intents],
)
# If ambiguous, send clarification and return
if classification.is_ambiguous and classification.clarification_question:
await _send_json(
ws,
@@ -89,7 +82,6 @@ async def handle_user_message(
config = {"configurable": {"thread_id": thread_id}, "callbacks": [callback_handler]}
# If multi-intent detected, add routing hint to the message
if classification and len(classification.intents) > 1:
agent_names = [i.agent_name for i in classification.intents]
hint = (
@@ -101,7 +93,7 @@ async def handle_user_message(
input_msg = {"messages": [HumanMessage(content=content)]}
try:
async for chunk in graph.astream(input_msg, config=config, stream_mode="messages"):
async for chunk in ctx.graph.astream(input_msg, config=config, stream_mode="messages"):
msg_chunk, metadata = chunk
node = metadata.get("langgraph_node", "")
@@ -126,12 +118,11 @@ async def handle_user_message(
},
)
state = await graph.aget_state(config)
state = await ctx.graph.aget_state(config)
if _has_interrupt(state):
interrupt_data = _extract_interrupt(state)
session_manager.extend_for_interrupt(thread_id)
# Register interrupt with TTL tracking
if interrupt_manager is not None:
interrupt_manager.register(
thread_id=thread_id,
@@ -158,7 +149,7 @@ async def handle_user_message(
async def handle_interrupt_response(
ws: WebSocket,
graph: CompiledStateGraph,
ctx: GraphContext,
session_manager: SessionManager,
callback_handler: TokenUsageCallbackHandler,
thread_id: str,
@@ -166,7 +157,6 @@ async def handle_interrupt_response(
interrupt_manager: InterruptManager | None = None,
) -> None:
"""Resume graph execution after interrupt approval/rejection."""
# Check interrupt TTL before resuming
if interrupt_manager is not None:
status = interrupt_manager.check_status(thread_id)
if status is not None and status.is_expired:
@@ -184,7 +174,7 @@ async def handle_interrupt_response(
config = {"configurable": {"thread_id": thread_id}, "callbacks": [callback_handler]}
try:
async for chunk in graph.astream(
async for chunk in ctx.graph.astream(
Command(resume=approved),
config=config,
stream_mode="messages",
@@ -212,14 +202,8 @@ async def handle_interrupt_response(
async def dispatch_message(
ws: WebSocket,
graph: CompiledStateGraph,
session_manager: SessionManager,
callback_handler: TokenUsageCallbackHandler,
ctx: WebSocketContext,
raw_data: str,
interrupt_manager: InterruptManager | None = None,
analytics_recorder: AnalyticsRecorder | None = None,
conversation_tracker: ConversationTrackerProtocol | None = None,
pool: Any = None,
) -> None:
"""Parse and route an incoming WebSocket message."""
if len(raw_data) > MAX_MESSAGE_SIZE:
@@ -268,14 +252,15 @@ async def dispatch_message(
_thread_timestamps[thread_id] = [*recent, now]
await handle_user_message(
ws, graph, session_manager, callback_handler, thread_id, content,
interrupt_manager=interrupt_manager,
ws, ctx.graph_ctx, ctx.session_manager, ctx.callback_handler,
thread_id, content,
interrupt_manager=ctx.interrupt_manager,
)
await _fire_and_forget_tracking(
thread_id=thread_id,
pool=pool,
analytics_recorder=analytics_recorder,
conversation_tracker=conversation_tracker,
pool=ctx.pool,
analytics_recorder=ctx.analytics_recorder,
conversation_tracker=ctx.conversation_tracker,
agent_name=None,
tokens=0,
cost=0.0,
@@ -284,8 +269,9 @@ async def dispatch_message(
elif msg_type == "interrupt_response":
approved = data.get("approved", False)
await handle_interrupt_response(
ws, graph, session_manager, callback_handler, thread_id, approved,
interrupt_manager=interrupt_manager,
ws, ctx.graph_ctx, ctx.session_manager, ctx.callback_handler,
thread_id, approved,
interrupt_manager=ctx.interrupt_manager,
)
else:
@@ -294,9 +280,9 @@ async def dispatch_message(
async def _fire_and_forget_tracking(
thread_id: str,
pool: Any,
analytics_recorder: Any | None,
conversation_tracker: Any | None,
pool: object,
analytics_recorder: object | None,
conversation_tracker: object | None,
agent_name: str | None,
tokens: int,
cost: float,