From af531119286e307aa952474e1c29256e96dffa03 Mon Sep 17 00:00:00 2001 From: Yaojia Wang Date: Mon, 6 Apr 2026 15:59:14 +0200 Subject: [PATCH] refactor: fix architectural issues across frontend and backend Address all architecture review findings: P0 fixes: - Add API key authentication for admin endpoints (analytics, replay, openapi) and WebSocket connections via ADMIN_API_KEY env var - Add PostgreSQL-backed PgSessionManager and PgInterruptManager for multi-worker production deployments (in-memory defaults preserved) P1 fixes: - Implement actual tool generation in OpenAPI approve_job endpoint using generate_tool_code() and generate_agent_yaml() - Add missing clarification, interrupt_expired, and tool_result message handlers in frontend ChatPage P2 fixes: - Replace monkey-patching on CompiledStateGraph with typed GraphContext - Replace 9-param dispatch_message with WebSocketContext dataclass - Extract duplicate _envelope() into shared app/api_utils.py - Replace mutable module-level counter with crypto.randomUUID() - Remove hardcoded mock data from ReviewPage, use api.ts wrappers - Remove `as any` type escape from ReplayPage All 516 tests passing, 0 TypeScript errors. --- backend/app/analytics/api.py | 18 +- backend/app/api_utils.py | 10 + backend/app/auth.py | 72 +++++ backend/app/config.py | 2 + backend/app/db.py | 10 + backend/app/graph.py | 41 +-- backend/app/graph_context.py | 36 +++ backend/app/interrupt_manager.py | 185 ++++++++++-- backend/app/main.py | 48 ++-- backend/app/openapi/review_api.py | 47 +++- backend/app/replay/api.py | 21 +- backend/app/session_manager.py | 144 +++++++++- backend/app/ws_context.py | 30 ++ backend/app/ws_handler.py | 58 ++-- backend/tests/e2e/conftest.py | 26 +- .../integration/test_phase2_checkpoints.py | 49 +++- backend/tests/integration/test_routing.py | 83 +++--- backend/tests/integration/test_websocket.py | 102 ++++--- backend/tests/unit/test_db.py | 2 +- backend/tests/unit/test_db_phase4.py | 4 +- backend/tests/unit/test_edge_cases.py | 77 +++-- backend/tests/unit/test_graph.py | 36 +-- backend/tests/unit/test_main.py | 4 +- backend/tests/unit/test_ws_handler.py | 152 +++++----- frontend/src/api.ts | 71 +++++ frontend/src/pages/ChatPage.tsx | 46 ++- frontend/src/pages/ReplayPage.tsx | 2 +- frontend/src/pages/ReviewPage.tsx | 263 +++++++++++------- frontend/src/types.ts | 17 +- 29 files changed, 1183 insertions(+), 473 deletions(-) create mode 100644 backend/app/api_utils.py create mode 100644 backend/app/auth.py create mode 100644 backend/app/graph_context.py create mode 100644 backend/app/ws_context.py diff --git a/backend/app/analytics/api.py b/backend/app/analytics/api.py index d8ed694..7f7d890 100644 --- a/backend/app/analytics/api.py +++ b/backend/app/analytics/api.py @@ -4,16 +4,22 @@ from __future__ import annotations import re from dataclasses import asdict -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING -from fastapi import APIRouter, HTTPException, Query, Request +from fastapi import APIRouter, Depends, HTTPException, Query, Request from app.analytics.queries import get_analytics +from app.api_utils import envelope +from app.auth import require_admin_api_key if TYPE_CHECKING: from psycopg_pool import AsyncConnectionPool -router = APIRouter(prefix="/api/analytics", tags=["analytics"]) +router = APIRouter( + prefix="/api/analytics", + tags=["analytics"], + dependencies=[Depends(require_admin_api_key)], +) _RANGE_PATTERN = re.compile(r"^(\d+)d$") _DEFAULT_RANGE = "7d" @@ -25,10 +31,6 @@ async def _get_pool(request: Request) -> AsyncConnectionPool: return request.app.state.pool -def _envelope(data: Any, *, success: bool = True, error: str | None = None) -> dict: - return {"success": success, "data": data, "error": error} - - def _parse_range(range_str: str) -> int: """Parse 'Xd' range string to integer days. Raises 400 on invalid format.""" match = _RANGE_PATTERN.match(range_str) @@ -55,4 +57,4 @@ async def analytics( range_days = _parse_range(range) pool = await _get_pool(request) result = await get_analytics(pool, range_days=range_days) - return _envelope(asdict(result)) + return envelope(asdict(result)) diff --git a/backend/app/api_utils.py b/backend/app/api_utils.py new file mode 100644 index 0000000..9e66a75 --- /dev/null +++ b/backend/app/api_utils.py @@ -0,0 +1,10 @@ +"""Shared API response helpers.""" + +from __future__ import annotations + +from typing import Any + + +def envelope(data: Any, *, success: bool = True, error: str | None = None) -> dict: + """Wrap API response data in a standard envelope format.""" + return {"success": success, "data": data, "error": error} diff --git a/backend/app/auth.py b/backend/app/auth.py new file mode 100644 index 0000000..3a186d4 --- /dev/null +++ b/backend/app/auth.py @@ -0,0 +1,72 @@ +"""API key authentication for admin endpoints and WebSocket connections.""" + +from __future__ import annotations + +import logging +import secrets +from typing import Annotated + +from fastapi import Depends, HTTPException, Query, Request, WebSocket, status +from fastapi.security import APIKeyHeader + +logger = logging.getLogger(__name__) + +_API_KEY_HEADER = APIKeyHeader(name="X-API-Key", auto_error=False) + + +def _get_admin_api_key(request: Request) -> str: + """Retrieve the configured admin API key from app settings. + + Returns empty string if settings are not configured (test/dev mode). + """ + settings = getattr(request.app.state, "settings", None) + if settings is None: + return "" + key = getattr(settings, "admin_api_key", "") + return key if isinstance(key, str) else "" + + +async def require_admin_api_key( + request: Request, + api_key: Annotated[str | None, Depends(_API_KEY_HEADER)] = None, +) -> None: + """Dependency that enforces API key authentication on admin endpoints. + + Skips validation when no admin_api_key is configured (dev mode). + """ + expected = _get_admin_api_key(request) + if not expected: + return + + if api_key is None: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Missing X-API-Key header", + ) + if not secrets.compare_digest(api_key, expected): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Invalid API key", + ) + + +async def verify_ws_token( + ws: WebSocket, + token: str | None = Query(default=None), +) -> None: + """Verify WebSocket connection token from query parameter. + + Skips validation when no admin_api_key is configured (dev mode). + Usage: ws://host/ws?token= + """ + settings = ws.app.state.settings + expected = settings.admin_api_key + if not expected: + return + + if token is None or not secrets.compare_digest(token, expected): + await ws.close(code=4001, reason="Unauthorized") + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Invalid or missing WebSocket token", + ) diff --git a/backend/app/config.py b/backend/app/config.py index 56c3f0a..6857506 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -32,6 +32,8 @@ class Settings(BaseSettings): template_name: str = "" + admin_api_key: str = "" + anthropic_api_key: str = "" openai_api_key: str = "" azure_openai_api_key: str = "" diff --git a/backend/app/db.py b/backend/app/db.py index 62b81e1..01aa95e 100644 --- a/backend/app/db.py +++ b/backend/app/db.py @@ -51,6 +51,15 @@ CREATE TABLE IF NOT EXISTS analytics_events ( ); """ +_SESSIONS_DDL = """ +CREATE TABLE IF NOT EXISTS sessions ( + thread_id TEXT PRIMARY KEY, + last_activity TIMESTAMPTZ NOT NULL DEFAULT NOW(), + has_pending_interrupt BOOLEAN NOT NULL DEFAULT FALSE, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); +""" + _CONVERSATIONS_MIGRATION_DDL = """ ALTER TABLE conversations ADD COLUMN IF NOT EXISTS resolution_type TEXT, @@ -84,5 +93,6 @@ async def setup_app_tables(pool: AsyncConnectionPool) -> None: async with pool.connection() as conn: await conn.execute(_CONVERSATIONS_DDL) await conn.execute(_INTERRUPTS_DDL) + await conn.execute(_SESSIONS_DDL) await conn.execute(_ANALYTICS_EVENTS_DDL) await conn.execute(_CONVERSATIONS_MIGRATION_DDL) diff --git a/backend/app/graph.py b/backend/app/graph.py index 70aa2ea..ee11b43 100644 --- a/backend/app/graph.py +++ b/backend/app/graph.py @@ -9,13 +9,13 @@ from langchain.agents import create_agent from langgraph_supervisor import create_supervisor from app.agents import get_tools_by_names +from app.graph_context import GraphContext if TYPE_CHECKING: from langchain_core.language_models import BaseChatModel from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver - from langgraph.graph.state import CompiledStateGraph - from app.intent import ClassificationResult, IntentClassifier + from app.intent import IntentClassifier from app.registry import AgentRegistry logger = logging.getLogger(__name__) @@ -75,12 +75,11 @@ def build_graph( llm: BaseChatModel, checkpointer: AsyncPostgresSaver, intent_classifier: IntentClassifier | None = None, -) -> CompiledStateGraph: +) -> GraphContext: """Build and compile the LangGraph supervisor graph. - If an intent_classifier is provided, the supervisor prompt is enhanced - with agent descriptions for better routing. The classifier is stored - for use by the routing layer (ws_handler). + Returns a GraphContext that bundles the compiled graph with its + associated registry and intent classifier. """ agent_nodes = build_agent_nodes(registry, llm) agent_descriptions = _format_agent_descriptions(registry) @@ -94,28 +93,10 @@ def build_graph( output_mode="full_history", ) - graph = workflow.compile(checkpointer=checkpointer) + compiled = workflow.compile(checkpointer=checkpointer) - # Attach classifier and registry to graph for use by ws_handler - graph.intent_classifier = intent_classifier # type: ignore[attr-defined] - graph.agent_registry = registry # type: ignore[attr-defined] - - return graph - - -async def classify_intent( - graph: CompiledStateGraph, - message: str, -) -> ClassificationResult | None: - """Classify user intent using the graph's attached classifier. - - Returns None if no classifier is configured. - """ - classifier = getattr(graph, "intent_classifier", None) - registry = getattr(graph, "agent_registry", None) - - if classifier is None or registry is None: - return None - - agents = registry.list_agents() - return await classifier.classify(message, agents) + return GraphContext( + graph=compiled, + registry=registry, + intent_classifier=intent_classifier, + ) diff --git a/backend/app/graph_context.py b/backend/app/graph_context.py new file mode 100644 index 0000000..0b1f7f4 --- /dev/null +++ b/backend/app/graph_context.py @@ -0,0 +1,36 @@ +"""GraphContext -- typed wrapper around the compiled graph and its dependencies.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from langgraph.graph.state import CompiledStateGraph + + from app.intent import ClassificationResult, IntentClassifier + from app.registry import AgentRegistry + + +@dataclass(frozen=True) +class GraphContext: + """Bundles the compiled LangGraph graph with its associated services. + + Replaces the previous pattern of monkey-patching attributes onto the + third-party CompiledStateGraph instance. + """ + + graph: CompiledStateGraph + registry: AgentRegistry + intent_classifier: IntentClassifier | None = None + + async def classify_intent(self, message: str) -> ClassificationResult | None: + """Classify user intent using the attached classifier. + + Returns None if no classifier is configured. + """ + if self.intent_classifier is None: + return None + + agents = self.registry.list_agents() + return await self.intent_classifier.classify(message, agents) diff --git a/backend/app/interrupt_manager.py b/backend/app/interrupt_manager.py index 2b04871..3f7880b 100644 --- a/backend/app/interrupt_manager.py +++ b/backend/app/interrupt_manager.py @@ -1,10 +1,18 @@ -"""Interrupt TTL management -- tracks pending interrupts with auto-expiration.""" +"""Interrupt TTL management -- tracks pending interrupts with auto-expiration. + +Provides both in-memory (InterruptManager) and PostgreSQL-backed +(PgInterruptManager) implementations behind a common Protocol. +""" from __future__ import annotations import time import uuid from dataclasses import dataclass +from typing import TYPE_CHECKING, Protocol + +if TYPE_CHECKING: + from psycopg_pool import AsyncConnectionPool @dataclass(frozen=True) @@ -28,8 +36,32 @@ class InterruptStatus: record: InterruptRecord +class InterruptManagerProtocol(Protocol): + """Protocol for interrupt TTL management.""" + + def register(self, thread_id: str, action: str, params: dict) -> InterruptRecord: ... + def check_status(self, thread_id: str) -> InterruptStatus | None: ... + def resolve(self, thread_id: str) -> None: ... + def has_pending(self, thread_id: str) -> bool: ... + def generate_retry_prompt(self, expired_record: InterruptRecord) -> dict: ... + + +def _build_retry_prompt(expired_record: InterruptRecord) -> dict: + """Generate a WebSocket message prompting the user to retry an expired action.""" + return { + "type": "interrupt_expired", + "thread_id": expired_record.thread_id, + "action": expired_record.action, + "message": ( + f"The approval request for '{expired_record.action}' has expired " + f"after {expired_record.ttl_seconds // 60} minutes. " + f"Would you like to try again?" + ), + } + + class InterruptManager: - """Manages interrupt TTL with auto-expiration. + """In-memory interrupt manager for single-worker development. Complements SessionManager -- this tracks interrupt-specific TTL while SessionManager handles session-level TTL. @@ -62,11 +94,9 @@ class InterruptManager: record = self._interrupts.get(thread_id) if record is None: return None - elapsed = time.time() - record.created_at remaining = max(0.0, record.ttl_seconds - elapsed) is_expired = elapsed > record.ttl_seconds - return InterruptStatus( is_expired=is_expired, remaining_seconds=remaining, @@ -84,28 +114,17 @@ class InterruptManager: now = time.time() expired: list[InterruptRecord] = [] active: dict[str, InterruptRecord] = {} - for thread_id, record in self._interrupts.items(): if now - record.created_at > record.ttl_seconds: expired.append(record) else: active[thread_id] = record - self._interrupts = active return tuple(expired) def generate_retry_prompt(self, expired_record: InterruptRecord) -> dict: """Generate a WebSocket message prompting the user to retry an expired action.""" - return { - "type": "interrupt_expired", - "thread_id": expired_record.thread_id, - "action": expired_record.action, - "message": ( - f"The approval request for '{expired_record.action}' has expired " - f"after {expired_record.ttl_seconds // 60} minutes. " - f"Would you like to try again?" - ), - } + return _build_retry_prompt(expired_record) def has_pending(self, thread_id: str) -> bool: """Check if a thread has a pending (non-expired) interrupt.""" @@ -113,3 +132,137 @@ class InterruptManager: if status is None: return False return not status.is_expired + + +# Alias for explicit naming +InMemoryInterruptManager = InterruptManager + + +class PgInterruptManager: + """PostgreSQL-backed interrupt manager for multi-worker production. + + Uses the existing active_interrupts table defined in db.py. + """ + + def __init__( + self, + pool: AsyncConnectionPool, + ttl_seconds: int = 1800, + ) -> None: + self._pool = pool + self._ttl_seconds = ttl_seconds + + def register( + self, + thread_id: str, + action: str, + params: dict, + ) -> InterruptRecord: + import asyncio + + return asyncio.get_event_loop().run_until_complete( + self._register(thread_id, action, params) + ) + + async def _register( + self, thread_id: str, action: str, params: dict + ) -> InterruptRecord: + import json + + record = InterruptRecord( + interrupt_id=uuid.uuid4().hex, + thread_id=thread_id, + action=action, + params=dict(params), + created_at=time.time(), + ttl_seconds=self._ttl_seconds, + ) + async with self._pool.connection() as conn: + await conn.execute( + """ + INSERT INTO active_interrupts (interrupt_id, thread_id, action, params) + VALUES (%(iid)s, %(tid)s, %(action)s, %(params)s) + ON CONFLICT (thread_id) WHERE resolved_at IS NULL + DO UPDATE SET + interrupt_id = %(iid)s, + action = %(action)s, + params = %(params)s, + created_at = NOW(), + resolved_at = NULL + """, + { + "iid": record.interrupt_id, + "tid": thread_id, + "action": action, + "params": json.dumps(params), + }, + ) + return record + + def check_status(self, thread_id: str) -> InterruptStatus | None: + import asyncio + + return asyncio.get_event_loop().run_until_complete( + self._check_status(thread_id) + ) + + async def _check_status(self, thread_id: str) -> InterruptStatus | None: + async with self._pool.connection() as conn: + cursor = await conn.execute( + """ + SELECT interrupt_id, action, params, created_at + FROM active_interrupts + WHERE thread_id = %(tid)s AND resolved_at IS NULL + ORDER BY created_at DESC LIMIT 1 + """, + {"tid": thread_id}, + ) + row = await cursor.fetchone() + + if row is None: + return None + + created_at = row["created_at"].timestamp() + elapsed = time.time() - created_at + remaining = max(0.0, self._ttl_seconds - elapsed) + is_expired = elapsed > self._ttl_seconds + + record = InterruptRecord( + interrupt_id=row["interrupt_id"], + thread_id=thread_id, + action=row["action"], + params=row["params"] if isinstance(row["params"], dict) else {}, + created_at=created_at, + ttl_seconds=self._ttl_seconds, + ) + + return InterruptStatus( + is_expired=is_expired, + remaining_seconds=remaining, + record=record, + ) + + def resolve(self, thread_id: str) -> None: + import asyncio + + asyncio.get_event_loop().run_until_complete(self._resolve(thread_id)) + + async def _resolve(self, thread_id: str) -> None: + async with self._pool.connection() as conn: + await conn.execute( + """ + UPDATE active_interrupts + SET resolved_at = NOW(), resolution = 'resolved' + WHERE thread_id = %(tid)s AND resolved_at IS NULL + """, + {"tid": thread_id}, + ) + + def generate_retry_prompt(self, expired_record: InterruptRecord) -> dict: + return _build_retry_prompt(expired_record) + + def has_pending(self, thread_id: str) -> bool: + status = self.check_status(thread_id) + if status is None: + return False + return not status.is_expired diff --git a/backend/app/main.py b/backend/app/main.py index 9a70380..a69f92e 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -7,7 +7,7 @@ from contextlib import asynccontextmanager from pathlib import Path from typing import TYPE_CHECKING -from fastapi import FastAPI, WebSocket, WebSocketDisconnect +from fastapi import Depends, FastAPI, Query, WebSocket, WebSocketDisconnect from fastapi.staticfiles import StaticFiles from app.analytics.api import router as analytics_router @@ -25,6 +25,7 @@ from app.openapi.review_api import router as openapi_router from app.registry import AgentRegistry from app.replay.api import router as replay_router from app.session_manager import SessionManager +from app.ws_context import WebSocketContext from app.ws_handler import dispatch_message if TYPE_CHECKING: @@ -52,7 +53,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: llm = create_llm(settings) intent_classifier = LLMIntentClassifier(llm) - graph = build_graph(registry, llm, checkpointer, intent_classifier=intent_classifier) + graph_ctx = build_graph(registry, llm, checkpointer, intent_classifier=intent_classifier) session_manager = SessionManager( session_ttl_seconds=settings.session_ttl_minutes * 60, @@ -71,7 +72,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: else: escalator = NoOpEscalator() - app.state.graph = graph + app.state.graph_ctx = graph_ctx app.state.session_manager = session_manager app.state.interrupt_manager = interrupt_manager app.state.escalator = escalator @@ -93,7 +94,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: await pool.close() -_VERSION = "0.5.0" +_VERSION = "0.6.0" app = FastAPI(title="Smart Support", version=_VERSION, lifespan=lifespan) @@ -109,28 +110,37 @@ def health_check() -> dict: @app.websocket("/ws") -async def websocket_endpoint(ws: WebSocket) -> None: - await ws.accept() - graph = app.state.graph - session_manager = app.state.session_manager - interrupt_manager = app.state.interrupt_manager +async def websocket_endpoint( + ws: WebSocket, + token: str | None = Query(default=None), +) -> None: settings = app.state.settings + + # Verify WebSocket token when admin_api_key is configured + if settings.admin_api_key: + import secrets as _secrets + + if token is None or not _secrets.compare_digest(token, settings.admin_api_key): + await ws.close(code=4001, reason="Unauthorized") + return + + await ws.accept() callback_handler = TokenUsageCallbackHandler(model_name=settings.llm_model) - analytics_recorder = app.state.analytics_recorder - conversation_tracker = app.state.conversation_tracker - pool = app.state.pool + ws_ctx = WebSocketContext( + graph_ctx=app.state.graph_ctx, + session_manager=app.state.session_manager, + callback_handler=callback_handler, + interrupt_manager=app.state.interrupt_manager, + analytics_recorder=app.state.analytics_recorder, + conversation_tracker=app.state.conversation_tracker, + pool=app.state.pool, + ) try: while True: raw_data = await ws.receive_text() - await dispatch_message( - ws, graph, session_manager, callback_handler, raw_data, - interrupt_manager=interrupt_manager, - analytics_recorder=analytics_recorder, - conversation_tracker=conversation_tracker, - pool=pool, - ) + await dispatch_message(ws, ws_ctx, raw_data) except WebSocketDisconnect: logger.info("WebSocket client disconnected") diff --git a/backend/app/openapi/review_api.py b/backend/app/openapi/review_api.py index 49713db..6a8a452 100644 --- a/backend/app/openapi/review_api.py +++ b/backend/app/openapi/review_api.py @@ -15,15 +15,21 @@ import re import uuid from typing import Literal -from fastapi import APIRouter, BackgroundTasks, HTTPException +from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException from pydantic import BaseModel, field_validator +from app.auth import require_admin_api_key +from app.openapi.generator import generate_agent_yaml, generate_tool_code from app.openapi.importer import ImportOrchestrator from app.openapi.models import ClassificationResult, ImportJob logger = logging.getLogger(__name__) -router = APIRouter(prefix="/api/openapi", tags=["openapi"]) +router = APIRouter( + prefix="/api/openapi", + tags=["openapi"], + dependencies=[Depends(require_admin_api_key)], +) # In-memory store: job_id -> job dict, guarded by async lock _job_store: dict[str, dict] = {} @@ -235,11 +241,42 @@ async def update_classification( @router.post("/jobs/{job_id}/approve") async def approve_job(job_id: str) -> dict: - """Approve a job's classifications and trigger tool generation.""" + """Approve a job's classifications and trigger tool generation. + + Generates Python tool code for each classified endpoint and + produces an agent YAML configuration snippet. + """ job = _job_store.get(job_id) if job is None: raise HTTPException(status_code=404, detail=f"Job '{job_id}' not found") - updated_job = {**job, "status": "approved"} + classifications: list[ClassificationResult] = job.get("classifications", []) + if not classifications: + raise HTTPException( + status_code=400, + detail="No classifications to approve. Import must complete first.", + ) + + base_url = job["spec_url"].rsplit("/", 1)[0] + generated_tools = [] + for clf in classifications: + tool = generate_tool_code(clf, base_url) + generated_tools.append({ + "function_name": tool.function_name, + "agent_group": clf.agent_group, + "code": tool.code, + }) + + agent_yaml = generate_agent_yaml(tuple(classifications), base_url) + + updated_job = { + **job, + "status": "approved", + "generated_tools": generated_tools, + "agent_yaml": agent_yaml, + } _job_store[job_id] = updated_job - return _job_to_response(updated_job) + + response = _job_to_response(updated_job) + response["generated_tools_count"] = len(generated_tools) + return response diff --git a/backend/app/replay/api.py b/backend/app/replay/api.py index 4681b69..7f1d94e 100644 --- a/backend/app/replay/api.py +++ b/backend/app/replay/api.py @@ -3,16 +3,23 @@ from __future__ import annotations import re -from typing import TYPE_CHECKING, Annotated, Any +from typing import TYPE_CHECKING, Annotated -from fastapi import APIRouter, HTTPException, Query, Request +from fastapi import APIRouter, Depends, HTTPException, Query, Request + +from app.api_utils import envelope +from app.auth import require_admin_api_key _THREAD_ID_PATTERN = re.compile(r"^[a-zA-Z0-9\-_]{1,128}$") if TYPE_CHECKING: from psycopg_pool import AsyncConnectionPool -router = APIRouter(prefix="/api", tags=["replay"]) +router = APIRouter( + prefix="/api", + tags=["replay"], + dependencies=[Depends(require_admin_api_key)], +) _COUNT_CONVERSATIONS_SQL = """ SELECT COUNT(*) FROM conversations @@ -38,10 +45,6 @@ async def get_pool(request: Request) -> AsyncConnectionPool: return request.app.state.pool -def _envelope(data: Any, *, success: bool = True, error: str | None = None) -> dict: - return {"success": success, "data": data, "error": error} - - @router.get("/conversations") async def list_conversations( request: Request, @@ -62,7 +65,7 @@ async def list_conversations( ) rows = await cursor.fetchall() - return _envelope({ + return envelope({ "conversations": [dict(row) for row in rows], "total": total, "page": page, @@ -119,4 +122,4 @@ async def get_replay( for s in page_steps ], } - return _envelope(data) + return envelope(data) diff --git a/backend/app/session_manager.py b/backend/app/session_manager.py index 6b02dfb..0ad88ba 100644 --- a/backend/app/session_manager.py +++ b/backend/app/session_manager.py @@ -1,9 +1,18 @@ -"""Session TTL management with sliding window and interrupt extension.""" +"""Session TTL management with sliding window and interrupt extension. + +Provides both in-memory (SessionManager) and PostgreSQL-backed +(PgSessionManager) implementations behind a common Protocol. +""" from __future__ import annotations import time from dataclasses import dataclass +from datetime import datetime, timezone +from typing import TYPE_CHECKING, Protocol + +if TYPE_CHECKING: + from psycopg_pool import AsyncConnectionPool @dataclass(frozen=True) @@ -13,8 +22,19 @@ class SessionState: has_pending_interrupt: bool +class SessionManagerProtocol(Protocol): + """Protocol for session TTL management.""" + + def touch(self, thread_id: str) -> SessionState: ... + def is_expired(self, thread_id: str) -> bool: ... + def extend_for_interrupt(self, thread_id: str) -> SessionState: ... + def resolve_interrupt(self, thread_id: str) -> SessionState: ... + def get_state(self, thread_id: str) -> SessionState | None: ... + def remove(self, thread_id: str) -> None: ... + + class SessionManager: - """Manages session TTL with sliding window and interrupt extensions. + """In-memory session manager for single-worker development. - Each message resets the TTL (sliding window). - A pending interrupt suspends expiration until resolved. @@ -40,10 +60,8 @@ class SessionManager: state = self._sessions.get(thread_id) if state is None: return True - if state.has_pending_interrupt: return False - elapsed = time.time() - state.last_activity return elapsed > self._session_ttl @@ -52,7 +70,6 @@ class SessionManager: existing = self._sessions.get(thread_id) if existing is None: return self.touch(thread_id) - new_state = SessionState( thread_id=thread_id, last_activity=existing.last_activity, @@ -76,3 +93,120 @@ class SessionManager: def remove(self, thread_id: str) -> None: self._sessions = {k: v for k, v in self._sessions.items() if k != thread_id} + + +# Alias for explicit naming +InMemorySessionManager = SessionManager + + +class PgSessionManager: + """PostgreSQL-backed session manager for multi-worker production.""" + + def __init__( + self, + pool: AsyncConnectionPool, + session_ttl_seconds: int = 1800, + ) -> None: + self._pool = pool + self._session_ttl = session_ttl_seconds + + def touch(self, thread_id: str) -> SessionState: + import asyncio + + return asyncio.get_event_loop().run_until_complete(self._touch(thread_id)) + + async def _touch(self, thread_id: str) -> SessionState: + now = datetime.now(timezone.utc) + async with self._pool.connection() as conn: + await conn.execute( + """ + INSERT INTO sessions (thread_id, last_activity, has_pending_interrupt) + VALUES (%(tid)s, %(now)s, FALSE) + ON CONFLICT (thread_id) DO UPDATE + SET last_activity = %(now)s + """, + {"tid": thread_id, "now": now}, + ) + return SessionState( + thread_id=thread_id, + last_activity=now.timestamp(), + has_pending_interrupt=False, + ) + + def is_expired(self, thread_id: str) -> bool: + state = self.get_state(thread_id) + if state is None: + return True + if state.has_pending_interrupt: + return False + elapsed = time.time() - state.last_activity + return elapsed > self._session_ttl + + def extend_for_interrupt(self, thread_id: str) -> SessionState: + import asyncio + + return asyncio.get_event_loop().run_until_complete( + self._set_interrupt(thread_id, True) + ) + + def resolve_interrupt(self, thread_id: str) -> SessionState: + import asyncio + + return asyncio.get_event_loop().run_until_complete( + self._set_interrupt(thread_id, False) + ) + + async def _set_interrupt( + self, thread_id: str, has_interrupt: bool + ) -> SessionState: + now = datetime.now(timezone.utc) + async with self._pool.connection() as conn: + await conn.execute( + """ + INSERT INTO sessions (thread_id, last_activity, has_pending_interrupt) + VALUES (%(tid)s, %(now)s, %(interrupt)s) + ON CONFLICT (thread_id) DO UPDATE + SET last_activity = %(now)s, + has_pending_interrupt = %(interrupt)s + """, + {"tid": thread_id, "now": now, "interrupt": has_interrupt}, + ) + return SessionState( + thread_id=thread_id, + last_activity=now.timestamp(), + has_pending_interrupt=has_interrupt, + ) + + def get_state(self, thread_id: str) -> SessionState | None: + import asyncio + + return asyncio.get_event_loop().run_until_complete( + self._get_state(thread_id) + ) + + async def _get_state(self, thread_id: str) -> SessionState | None: + async with self._pool.connection() as conn: + cursor = await conn.execute( + "SELECT last_activity, has_pending_interrupt FROM sessions WHERE thread_id = %(tid)s", + {"tid": thread_id}, + ) + row = await cursor.fetchone() + if row is None: + return None + return SessionState( + thread_id=thread_id, + last_activity=row["last_activity"].timestamp(), + has_pending_interrupt=row["has_pending_interrupt"], + ) + + def remove(self, thread_id: str) -> None: + import asyncio + + asyncio.get_event_loop().run_until_complete(self._remove(thread_id)) + + async def _remove(self, thread_id: str) -> None: + async with self._pool.connection() as conn: + await conn.execute( + "DELETE FROM sessions WHERE thread_id = %(tid)s", + {"tid": thread_id}, + ) diff --git a/backend/app/ws_context.py b/backend/app/ws_context.py new file mode 100644 index 0000000..91f198d --- /dev/null +++ b/backend/app/ws_context.py @@ -0,0 +1,30 @@ +"""WebSocketContext -- bundles all dependencies needed by dispatch_message.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from app.analytics.event_recorder import AnalyticsRecorder + from app.callbacks import TokenUsageCallbackHandler + from app.conversation_tracker import ConversationTrackerProtocol + from app.graph_context import GraphContext + from app.interrupt_manager import InterruptManager + from app.session_manager import SessionManager + + +@dataclass(frozen=True) +class WebSocketContext: + """All dependencies required for WebSocket message processing. + + Replaces the previous 9-parameter function signature in dispatch_message. + """ + + graph_ctx: GraphContext + session_manager: SessionManager + callback_handler: TokenUsageCallbackHandler + interrupt_manager: InterruptManager | None = None + analytics_recorder: AnalyticsRecorder | None = None + conversation_tracker: ConversationTrackerProtocol | None = None + pool: Any = None diff --git a/backend/app/ws_handler.py b/backend/app/ws_handler.py index 81ddfde..9eaa29d 100644 --- a/backend/app/ws_handler.py +++ b/backend/app/ws_handler.py @@ -7,22 +7,19 @@ import logging import re import time from collections import defaultdict -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING from langchain_core.messages import HumanMessage from langgraph.types import Command -from app.graph import classify_intent - if TYPE_CHECKING: from fastapi import WebSocket - from langgraph.graph.state import CompiledStateGraph - from app.analytics.event_recorder import AnalyticsRecorder from app.callbacks import TokenUsageCallbackHandler - from app.conversation_tracker import ConversationTrackerProtocol + from app.graph_context import GraphContext from app.interrupt_manager import InterruptManager from app.session_manager import SessionManager + from app.ws_context import WebSocketContext logger = logging.getLogger(__name__) @@ -46,7 +43,7 @@ def _evict_stale_threads(cutoff: float) -> None: async def handle_user_message( ws: WebSocket, - graph: CompiledStateGraph, + ctx: GraphContext, session_manager: SessionManager, callback_handler: TokenUsageCallbackHandler, thread_id: str, @@ -54,8 +51,6 @@ async def handle_user_message( interrupt_manager: InterruptManager | None = None, ) -> None: """Process a user message through the graph and stream results back.""" - # Touch first so new sessions are created before expiry check. - # For existing sessions, touch resets the sliding window. existing = session_manager.get_state(thread_id) if existing is not None and session_manager.is_expired(thread_id): msg = "Session expired. Please start a new conversation." @@ -64,8 +59,7 @@ async def handle_user_message( session_manager.touch(thread_id) - # Run intent classification if available (for logging/future multi-intent) - classification = await classify_intent(graph, content) + classification = await ctx.classify_intent(content) if classification is not None: logger.info( "Intent classification for thread %s: ambiguous=%s, intents=%s", @@ -74,7 +68,6 @@ async def handle_user_message( [i.agent_name for i in classification.intents], ) - # If ambiguous, send clarification and return if classification.is_ambiguous and classification.clarification_question: await _send_json( ws, @@ -89,7 +82,6 @@ async def handle_user_message( config = {"configurable": {"thread_id": thread_id}, "callbacks": [callback_handler]} - # If multi-intent detected, add routing hint to the message if classification and len(classification.intents) > 1: agent_names = [i.agent_name for i in classification.intents] hint = ( @@ -101,7 +93,7 @@ async def handle_user_message( input_msg = {"messages": [HumanMessage(content=content)]} try: - async for chunk in graph.astream(input_msg, config=config, stream_mode="messages"): + async for chunk in ctx.graph.astream(input_msg, config=config, stream_mode="messages"): msg_chunk, metadata = chunk node = metadata.get("langgraph_node", "") @@ -126,12 +118,11 @@ async def handle_user_message( }, ) - state = await graph.aget_state(config) + state = await ctx.graph.aget_state(config) if _has_interrupt(state): interrupt_data = _extract_interrupt(state) session_manager.extend_for_interrupt(thread_id) - # Register interrupt with TTL tracking if interrupt_manager is not None: interrupt_manager.register( thread_id=thread_id, @@ -158,7 +149,7 @@ async def handle_user_message( async def handle_interrupt_response( ws: WebSocket, - graph: CompiledStateGraph, + ctx: GraphContext, session_manager: SessionManager, callback_handler: TokenUsageCallbackHandler, thread_id: str, @@ -166,7 +157,6 @@ async def handle_interrupt_response( interrupt_manager: InterruptManager | None = None, ) -> None: """Resume graph execution after interrupt approval/rejection.""" - # Check interrupt TTL before resuming if interrupt_manager is not None: status = interrupt_manager.check_status(thread_id) if status is not None and status.is_expired: @@ -184,7 +174,7 @@ async def handle_interrupt_response( config = {"configurable": {"thread_id": thread_id}, "callbacks": [callback_handler]} try: - async for chunk in graph.astream( + async for chunk in ctx.graph.astream( Command(resume=approved), config=config, stream_mode="messages", @@ -212,14 +202,8 @@ async def handle_interrupt_response( async def dispatch_message( ws: WebSocket, - graph: CompiledStateGraph, - session_manager: SessionManager, - callback_handler: TokenUsageCallbackHandler, + ctx: WebSocketContext, raw_data: str, - interrupt_manager: InterruptManager | None = None, - analytics_recorder: AnalyticsRecorder | None = None, - conversation_tracker: ConversationTrackerProtocol | None = None, - pool: Any = None, ) -> None: """Parse and route an incoming WebSocket message.""" if len(raw_data) > MAX_MESSAGE_SIZE: @@ -268,14 +252,15 @@ async def dispatch_message( _thread_timestamps[thread_id] = [*recent, now] await handle_user_message( - ws, graph, session_manager, callback_handler, thread_id, content, - interrupt_manager=interrupt_manager, + ws, ctx.graph_ctx, ctx.session_manager, ctx.callback_handler, + thread_id, content, + interrupt_manager=ctx.interrupt_manager, ) await _fire_and_forget_tracking( thread_id=thread_id, - pool=pool, - analytics_recorder=analytics_recorder, - conversation_tracker=conversation_tracker, + pool=ctx.pool, + analytics_recorder=ctx.analytics_recorder, + conversation_tracker=ctx.conversation_tracker, agent_name=None, tokens=0, cost=0.0, @@ -284,8 +269,9 @@ async def dispatch_message( elif msg_type == "interrupt_response": approved = data.get("approved", False) await handle_interrupt_response( - ws, graph, session_manager, callback_handler, thread_id, approved, - interrupt_manager=interrupt_manager, + ws, ctx.graph_ctx, ctx.session_manager, ctx.callback_handler, + thread_id, approved, + interrupt_manager=ctx.interrupt_manager, ) else: @@ -294,9 +280,9 @@ async def dispatch_message( async def _fire_and_forget_tracking( thread_id: str, - pool: Any, - analytics_recorder: Any | None, - conversation_tracker: Any | None, + pool: object, + analytics_recorder: object | None, + conversation_tracker: object | None, agent_name: str | None, tokens: int, cost: float, diff --git a/backend/tests/e2e/conftest.py b/backend/tests/e2e/conftest.py index 98ebdb6..5bc87c1 100644 --- a/backend/tests/e2e/conftest.py +++ b/backend/tests/e2e/conftest.py @@ -13,10 +13,12 @@ from httpx import ASGITransport, AsyncClient from app.analytics.api import router as analytics_router from app.callbacks import TokenUsageCallbackHandler +from app.graph_context import GraphContext from app.interrupt_manager import InterruptManager from app.openapi.review_api import _job_store, router as openapi_router from app.replay.api import router as replay_router from app.session_manager import SessionManager +from app.ws_context import WebSocketContext from app.ws_handler import dispatch_message @@ -74,8 +76,6 @@ def make_graph( ) -> MagicMock: """Build a mock LangGraph CompiledStateGraph.""" g = MagicMock() - g.intent_classifier = None - g.agent_registry = None if state is None: state = make_state() @@ -93,6 +93,14 @@ def make_graph( return g +def make_graph_ctx(graph: MagicMock | None = None) -> GraphContext: + """Build a GraphContext wrapping a mock graph.""" + g = graph or make_graph() + registry = MagicMock() + registry.list_agents = MagicMock(return_value=()) + return GraphContext(graph=g, registry=registry, intent_classifier=None) + + # --------------------------------------------------------------------------- # Fake database pool # --------------------------------------------------------------------------- @@ -148,6 +156,7 @@ def create_e2e_app( ) -> FastAPI: """Create a FastAPI app wired with mocked dependencies for E2E testing.""" g = graph or make_graph() + graph_ctx = make_graph_ctx(g) p = pool or FakePool() sm = SessionManager(session_ttl_seconds=session_ttl) im = InterruptManager(ttl_seconds=interrupt_ttl) @@ -157,7 +166,7 @@ def create_e2e_app( app.include_router(replay_router) app.include_router(analytics_router) - app.state.graph = g + app.state.graph_ctx = graph_ctx app.state.session_manager = sm app.state.interrupt_manager = im app.state.pool = p @@ -175,17 +184,16 @@ def create_e2e_app( try: while True: raw_data = await ws.receive_text() - await dispatch_message( - ws, - app.state.graph, - app.state.session_manager, - TokenUsageCallbackHandler(model_name="test-model"), - raw_data, + ws_ctx = WebSocketContext( + graph_ctx=app.state.graph_ctx, + session_manager=app.state.session_manager, + callback_handler=TokenUsageCallbackHandler(model_name="test-model"), interrupt_manager=app.state.interrupt_manager, analytics_recorder=app.state.analytics_recorder, conversation_tracker=app.state.conversation_tracker, pool=app.state.pool, ) + await dispatch_message(ws, ws_ctx, raw_data) except WebSocketDisconnect: pass diff --git a/backend/tests/integration/test_phase2_checkpoints.py b/backend/tests/integration/test_phase2_checkpoints.py index 3ed517c..0271b8c 100644 --- a/backend/tests/integration/test_phase2_checkpoints.py +++ b/backend/tests/integration/test_phase2_checkpoints.py @@ -20,10 +20,12 @@ import pytest from app.callbacks import TokenUsageCallbackHandler from app.escalation import EscalationPayload, NoOpEscalator, WebhookEscalator +from app.graph_context import GraphContext from app.intent import ClassificationResult, IntentTarget, LLMIntentClassifier from app.interrupt_manager import InterruptManager from app.registry import AgentConfig, AgentRegistry from app.session_manager import SessionManager +from app.ws_context import WebSocketContext from app.ws_handler import dispatch_message TEMPLATES_DIR = Path(__file__).parent.parent.parent / "templates" @@ -128,10 +130,8 @@ class TestCheckpoint1OrderQueryRouting: mock_classifier.classify = AsyncMock(return_value=ClassificationResult( intents=(IntentTarget(agent_name="order_lookup", confidence=0.95, reasoning="query"),), )) - graph.intent_classifier = mock_classifier mock_registry = MagicMock() mock_registry.list_agents = MagicMock(return_value=()) - graph.agent_registry = mock_registry # Graph streams order_lookup response graph.astream = MagicMock(return_value=AsyncIterHelper([ @@ -140,14 +140,21 @@ class TestCheckpoint1OrderQueryRouting: ])) graph.aget_state = AsyncMock(return_value=_state()) + graph_ctx = GraphContext( + graph=graph, registry=mock_registry, intent_classifier=mock_classifier, + ) sm = SessionManager() sm.touch("t1") im = InterruptManager() cb = TokenUsageCallbackHandler() ws = FakeWS() + ws_ctx = WebSocketContext( + graph_ctx=graph_ctx, session_manager=sm, + callback_handler=cb, interrupt_manager=im, + ) raw = json.dumps({"type": "message", "thread_id": "t1", "content": "查询订单 1042"}) - await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im) + await dispatch_message(ws, ws_ctx, raw) tool_msgs = [m for m in ws.sent if m["type"] == "tool_call"] assert any(m["tool"] == "get_order_status" for m in tool_msgs) @@ -201,25 +208,30 @@ class TestCheckpoint2MultiIntentSequential: IntentTarget(agent_name="discount", confidence=0.85, reasoning="discount"), ), )) - graph.intent_classifier = mock_classifier mock_registry = MagicMock() mock_registry.list_agents = MagicMock(return_value=()) - graph.agent_registry = mock_registry graph.astream = MagicMock(return_value=AsyncIterHelper([])) graph.aget_state = AsyncMock(return_value=_state()) + graph_ctx = GraphContext( + graph=graph, registry=mock_registry, intent_classifier=mock_classifier, + ) sm = SessionManager() sm.touch("t1") im = InterruptManager() cb = TokenUsageCallbackHandler() ws = FakeWS() + ws_ctx = WebSocketContext( + graph_ctx=graph_ctx, session_manager=sm, + callback_handler=cb, interrupt_manager=im, + ) raw = json.dumps({ "type": "message", "thread_id": "t1", "content": "取消订单 1042 并给我一个 10% 折扣", }) - await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im) + await dispatch_message(ws, ws_ctx, raw) # Verify the graph was called with the routing hint in the message call_args = graph.astream.call_args @@ -267,21 +279,26 @@ class TestCheckpoint3AmbiguousClarification: "Could you please provide more details about what you need help with?" ), )) - graph.intent_classifier = mock_classifier mock_registry = MagicMock() mock_registry.list_agents = MagicMock(return_value=()) - graph.agent_registry = mock_registry graph.astream = MagicMock(return_value=AsyncIterHelper([])) graph.aget_state = AsyncMock(return_value=_state()) + graph_ctx = GraphContext( + graph=graph, registry=mock_registry, intent_classifier=mock_classifier, + ) sm = SessionManager() sm.touch("t1") im = InterruptManager() cb = TokenUsageCallbackHandler() ws = FakeWS() + ws_ctx = WebSocketContext( + graph_ctx=graph_ctx, session_manager=sm, + callback_handler=cb, interrupt_manager=im, + ) raw = json.dumps({"type": "message", "thread_id": "t1", "content": "嗯..."}) - await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im) + await dispatch_message(ws, ws_ctx, raw) clarifications = [m for m in ws.sent if m["type"] == "clarification"] assert len(clarifications) == 1 @@ -303,20 +320,26 @@ class TestCheckpoint4InterruptTTLAutoCancel: async def test_30min_expired_interrupt_auto_cancels(self) -> None: st = _state(interrupt=True, data={"action": "cancel_order", "order_id": "1042"}) graph = MagicMock() - graph.intent_classifier = None - graph.agent_registry = None graph.astream = MagicMock(return_value=AsyncIterHelper([])) graph.aget_state = AsyncMock(return_value=st) + mock_registry = MagicMock() + mock_registry.list_agents = MagicMock(return_value=()) + graph_ctx = GraphContext(graph=graph, registry=mock_registry, intent_classifier=None) + sm = SessionManager() sm.touch("t1") im = InterruptManager(ttl_seconds=1800) # 30 minutes cb = TokenUsageCallbackHandler() ws = FakeWS() + ws_ctx = WebSocketContext( + graph_ctx=graph_ctx, session_manager=sm, + callback_handler=cb, interrupt_manager=im, + ) # Trigger interrupt raw = json.dumps({"type": "message", "thread_id": "t1", "content": "Cancel order 1042"}) - await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im) + await dispatch_message(ws, ws_ctx, raw) interrupts = [m for m in ws.sent if m["type"] == "interrupt"] assert len(interrupts) == 1 @@ -333,7 +356,7 @@ class TestCheckpoint4InterruptTTLAutoCancel: "thread_id": "t1", "approved": True, }) - await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im) + await dispatch_message(ws, ws_ctx, raw) # Should get retry prompt, NOT resume the graph expired_msgs = [m for m in ws.sent if m["type"] == "interrupt_expired"] diff --git a/backend/tests/integration/test_routing.py b/backend/tests/integration/test_routing.py index 2ec8634..82bae24 100644 --- a/backend/tests/integration/test_routing.py +++ b/backend/tests/integration/test_routing.py @@ -18,10 +18,12 @@ from unittest.mock import AsyncMock, MagicMock import pytest from app.callbacks import TokenUsageCallbackHandler +from app.graph_context import GraphContext from app.intent import ClassificationResult, IntentTarget, LLMIntentClassifier from app.interrupt_manager import InterruptManager from app.registry import AgentConfig from app.session_manager import SessionManager +from app.ws_context import WebSocketContext from app.ws_handler import dispatch_message # --------------------------------------------------------------------------- @@ -103,36 +105,45 @@ def _make_classifier(result: ClassificationResult) -> AsyncMock: return classifier -def _make_graph( +def _make_graph_and_ctx( classifier_result: ClassificationResult | None, chunks: list, state=None, -) -> MagicMock: - """Build a graph mock with optional intent classifier.""" +) -> tuple[MagicMock, GraphContext]: + """Build a graph mock and GraphContext with optional intent classifier.""" graph = MagicMock() - - if classifier_result is not None: - graph.intent_classifier = _make_classifier(classifier_result) - mock_registry = MagicMock() - mock_registry.list_agents = MagicMock(return_value=AGENTS) - graph.agent_registry = mock_registry - else: - graph.intent_classifier = None - graph.agent_registry = None - graph.astream = MagicMock(return_value=AsyncIterHelper(list(chunks))) graph.aget_state = AsyncMock(return_value=state or _state()) - return graph + + if classifier_result is not None: + classifier = _make_classifier(classifier_result) + mock_registry = MagicMock() + mock_registry.list_agents = MagicMock(return_value=AGENTS) + graph_ctx = GraphContext( + graph=graph, registry=mock_registry, intent_classifier=classifier, + ) + else: + mock_registry = MagicMock() + mock_registry.list_agents = MagicMock(return_value=()) + graph_ctx = GraphContext( + graph=graph, registry=mock_registry, intent_classifier=None, + ) + + return graph, graph_ctx -async def _dispatch(graph, content: str, thread_id: str = "t1") -> list[dict]: +async def _dispatch(graph_ctx: GraphContext, content: str, thread_id: str = "t1") -> list[dict]: sm = SessionManager() sm.touch(thread_id) im = InterruptManager() cb = TokenUsageCallbackHandler() ws = FakeWS() + ws_ctx = WebSocketContext( + graph_ctx=graph_ctx, session_manager=sm, + callback_handler=cb, interrupt_manager=im, + ) raw = json.dumps({"type": "message", "thread_id": thread_id, "content": content}) - await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im) + await dispatch_message(ws, ws_ctx, raw) return ws.sent @@ -151,12 +162,12 @@ class TestSingleIntentRouting: agent_name="order_lookup", confidence=0.95, reasoning="status query", ),), ) - graph = _make_graph(result, [ + graph, graph_ctx = _make_graph_and_ctx(result, [ _tool_chunk("get_order_status", {"order_id": "1042"}, "order_lookup"), _chunk("Order 1042 is shipped.", "order_lookup"), ]) - msgs = await _dispatch(graph, "What is the status of order 1042?") + msgs = await _dispatch(graph_ctx, "What is the status of order 1042?") tools = [m for m in msgs if m["type"] == "tool_call"] assert len(tools) == 1 @@ -171,13 +182,13 @@ class TestSingleIntentRouting: result = ClassificationResult( intents=(IntentTarget(agent_name="order_actions", confidence=0.9, reasoning="cancel"),), ) - graph = _make_graph( + graph, graph_ctx = _make_graph_and_ctx( result, [_tool_chunk("cancel_order", {"order_id": "1042"}, "order_actions")], state=_state(interrupt=True, data={"action": "cancel_order", "order_id": "1042"}), ) - msgs = await _dispatch(graph, "Cancel order 1042") + msgs = await _dispatch(graph_ctx, "Cancel order 1042") tools = [m for m in msgs if m["type"] == "tool_call"] assert tools[0]["tool"] == "cancel_order" @@ -191,12 +202,12 @@ class TestSingleIntentRouting: result = ClassificationResult( intents=(IntentTarget(agent_name="discount", confidence=0.88, reasoning="coupon"),), ) - graph = _make_graph(result, [ + graph, graph_ctx = _make_graph_and_ctx(result, [ _tool_chunk("generate_coupon", {"discount_percent": 15, "expiry_days": 30}, "discount"), _chunk("Here is your coupon: SAVE15-ABC12345", "discount"), ]) - msgs = await _dispatch(graph, "Give me a 15% coupon") + msgs = await _dispatch(graph_ctx, "Give me a 15% coupon") tools = [m for m in msgs if m["type"] == "tool_call"] assert tools[0]["tool"] == "generate_coupon" @@ -207,11 +218,11 @@ class TestSingleIntentRouting: result = ClassificationResult( intents=(IntentTarget(agent_name="fallback", confidence=0.7, reasoning="general"),), ) - graph = _make_graph(result, [ + graph, graph_ctx = _make_graph_and_ctx(result, [ _chunk("I can help with order inquiries.", "fallback"), ]) - msgs = await _dispatch(graph, "What can you do?") + msgs = await _dispatch(graph_ctx, "What can you do?") tokens = [m for m in msgs if m["type"] == "token"] assert tokens[0]["agent"] == "fallback" @@ -233,7 +244,7 @@ class TestMultiIntentRouting: IntentTarget(agent_name="discount", confidence=0.85, reasoning="discount"), ), ) - graph = _make_graph(result, [ + graph, graph_ctx = _make_graph_and_ctx(result, [ _tool_chunk("cancel_order", {"order_id": "1042"}, "order_actions"), _tool_chunk("apply_discount", {"order_id": "1042", "discount_percent": 10}, "discount"), ]) @@ -243,13 +254,17 @@ class TestMultiIntentRouting: im = InterruptManager() cb = TokenUsageCallbackHandler() ws = FakeWS() + ws_ctx = WebSocketContext( + graph_ctx=graph_ctx, session_manager=sm, + callback_handler=cb, interrupt_manager=im, + ) raw = json.dumps({ "type": "message", "thread_id": "t1", "content": "取消订单 1042 并给我一个 10% 折扣", }) - await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im) + await dispatch_message(ws, ws_ctx, raw) # Verify routing hint was injected call_args = graph.astream.call_args[0][0] @@ -269,16 +284,20 @@ class TestMultiIntentRouting: result = ClassificationResult( intents=(IntentTarget(agent_name="order_lookup", confidence=0.95, reasoning="query"),), ) - graph = _make_graph(result, [_chunk("Order shipped.", "order_lookup")]) + graph, graph_ctx = _make_graph_and_ctx(result, [_chunk("Order shipped.", "order_lookup")]) sm = SessionManager() sm.touch("t1") im = InterruptManager() cb = TokenUsageCallbackHandler() ws = FakeWS() + ws_ctx = WebSocketContext( + graph_ctx=graph_ctx, session_manager=sm, + callback_handler=cb, interrupt_manager=im, + ) raw = json.dumps({"type": "message", "thread_id": "t1", "content": "查询订单 1042"}) - await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im) + await dispatch_message(ws, ws_ctx, raw) msg_content = graph.astream.call_args[0][0]["messages"][0].content assert "[System:" not in msg_content @@ -299,9 +318,9 @@ class TestAmbiguityRouting: is_ambiguous=True, clarification_question="Could you please clarify what you need?", ) - graph = _make_graph(result, []) + graph, graph_ctx = _make_graph_and_ctx(result, []) - msgs = await _dispatch(graph, "嗯...") + msgs = await _dispatch(graph_ctx, "嗯...") clarifications = [m for m in msgs if m["type"] == "clarification"] assert len(clarifications) == 1 @@ -339,12 +358,12 @@ class TestNoClassifierFallback: @pytest.mark.asyncio async def test_no_classifier_routes_via_supervisor(self) -> None: - graph = _make_graph( + graph, graph_ctx = _make_graph_and_ctx( classifier_result=None, chunks=[_chunk("Order 1042 is shipped.", "order_lookup")], ) - msgs = await _dispatch(graph, "What is order 1042 status?") + msgs = await _dispatch(graph_ctx, "What is order 1042 status?") tokens = [m for m in msgs if m["type"] == "token"] assert len(tokens) == 1 diff --git a/backend/tests/integration/test_websocket.py b/backend/tests/integration/test_websocket.py index ebff166..ffa32a6 100644 --- a/backend/tests/integration/test_websocket.py +++ b/backend/tests/integration/test_websocket.py @@ -15,8 +15,10 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest from app.callbacks import TokenUsageCallbackHandler +from app.graph_context import GraphContext from app.interrupt_manager import InterruptManager from app.session_manager import SessionManager +from app.ws_context import WebSocketContext from app.ws_handler import dispatch_message # --------------------------------------------------------------------------- @@ -81,8 +83,6 @@ def _graph( resume_chunks: list | None = None, ) -> MagicMock: g = MagicMock() - g.intent_classifier = None - g.agent_registry = None if st is None: st = _state() @@ -100,6 +100,13 @@ def _graph( return g +def _make_graph_ctx(graph: MagicMock | None = None) -> GraphContext: + g = graph or _graph() + registry = MagicMock() + registry.list_agents = MagicMock(return_value=()) + return GraphContext(graph=g, registry=registry, intent_classifier=None) + + def _setup( graph=None, session_ttl: int = 1800, @@ -109,23 +116,28 @@ def _setup( ): """Create test dependencies. Pre-touches session by default.""" g = graph or _graph() + graph_ctx = _make_graph_ctx(g) sm = SessionManager(session_ttl_seconds=session_ttl) im = InterruptManager(ttl_seconds=interrupt_ttl) cb = TokenUsageCallbackHandler() ws = FakeWS() + ws_ctx = WebSocketContext( + graph_ctx=graph_ctx, session_manager=sm, + callback_handler=cb, interrupt_manager=im, + ) if touch: sm.touch(thread_id) - return g, sm, im, cb, ws + return g, sm, im, cb, ws, ws_ctx -async def _send(ws, g, sm, im, cb, *, thread_id="t1", content="hello", msg_type="message"): +async def _send(ws, ws_ctx, *, thread_id="t1", content="hello", msg_type="message"): raw = json.dumps({"type": msg_type, "thread_id": thread_id, "content": content}) - await dispatch_message(ws, g, sm, cb, raw, interrupt_manager=im) + await dispatch_message(ws, ws_ctx, raw) -async def _respond(ws, g, sm, im, cb, *, thread_id="t1", approved=True): +async def _respond(ws, ws_ctx, *, thread_id="t1", approved=True): raw = json.dumps({"type": "interrupt_response", "thread_id": thread_id, "approved": approved}) - await dispatch_message(ws, g, sm, cb, raw, interrupt_manager=im) + await dispatch_message(ws, ws_ctx, raw) # --------------------------------------------------------------------------- @@ -136,10 +148,10 @@ async def _respond(ws, g, sm, im, cb, *, thread_id="t1", approved=True): class TestWebSocketHappyPath: @pytest.mark.asyncio async def test_send_message_receives_tokens_and_complete(self) -> None: - g, sm, im, cb, ws = _setup( + g, sm, im, cb, ws, ws_ctx = _setup( graph=_graph(chunks=[_chunk("Order 1042 is "), _chunk("shipped.")]) ) - await _send(ws, g, sm, im, cb, content="What is the status of order 1042?") + await _send(ws, ws_ctx, content="What is the status of order 1042?") tokens = [m for m in ws.sent if m["type"] == "token"] assert len(tokens) == 2 @@ -153,13 +165,13 @@ class TestWebSocketHappyPath: @pytest.mark.asyncio async def test_tool_call_streamed(self) -> None: - g, sm, im, cb, ws = _setup( + g, sm, im, cb, ws, ws_ctx = _setup( graph=_graph(chunks=[ _tool_chunk("get_order_status", {"order_id": "1042"}), _chunk("Order shipped."), ]) ) - await _send(ws, g, sm, im, cb, content="Check order 1042") + await _send(ws, ws_ctx, content="Check order 1042") tools = [m for m in ws.sent if m["type"] == "tool_call"] assert len(tools) == 1 @@ -168,9 +180,9 @@ class TestWebSocketHappyPath: @pytest.mark.asyncio async def test_multiple_messages_same_session(self) -> None: - g, sm, im, cb, ws = _setup() + g, sm, im, cb, ws, ws_ctx = _setup() for i in range(3): - await _send(ws, g, sm, im, cb, content=f"msg {i}") + await _send(ws, ws_ctx, content=f"msg {i}") completes = [m for m in ws.sent if m["type"] == "message_complete"] assert len(completes) == 3 @@ -183,10 +195,10 @@ class TestWebSocketInterruptApproval: st_int = _state(interrupt=True, data={"action": "cancel_order", "order_id": "1042"}) resume = [_chunk("Order 1042 cancelled.", "order_actions")] g = _graph(chunks=[], st=st_int, resume_chunks=resume) - g_, sm, im, cb, ws = _setup(graph=g) + g_, sm, im, cb, ws, ws_ctx = _setup(graph=g) # Send message -> triggers interrupt - await _send(ws, g_, sm, im, cb, content="Cancel order 1042") + await _send(ws, ws_ctx, content="Cancel order 1042") interrupts = [m for m in ws.sent if m["type"] == "interrupt"] assert len(interrupts) == 1 @@ -196,7 +208,7 @@ class TestWebSocketInterruptApproval: # Approve ws.sent.clear() - await _respond(ws, g_, sm, im, cb, approved=True) + await _respond(ws, ws_ctx, approved=True) tokens = [m for m in ws.sent if m["type"] == "token"] assert len(tokens) == 1 @@ -211,12 +223,12 @@ class TestWebSocketInterruptApproval: st_int = _state(interrupt=True) resume = [_chunk("Order remains active.", "order_actions")] g = _graph(chunks=[], st=st_int, resume_chunks=resume) - g_, sm, im, cb, ws = _setup(graph=g) + g_, sm, im, cb, ws, ws_ctx = _setup(graph=g) - await _send(ws, g_, sm, im, cb, content="Cancel order 1042") + await _send(ws, ws_ctx, content="Cancel order 1042") ws.sent.clear() - await _respond(ws, g_, sm, im, cb, approved=False) + await _respond(ws, ws_ctx, approved=False) tokens = [m for m in ws.sent if m["type"] == "token"] assert "remains active" in tokens[0]["content"] @@ -226,28 +238,28 @@ class TestWebSocketInterruptApproval: class TestWebSocketSessionTTL: @pytest.mark.asyncio async def test_expired_session_returns_error(self) -> None: - g, sm, im, cb, ws = _setup(session_ttl=0) + g, sm, im, cb, ws, ws_ctx = _setup(session_ttl=0) # Session was touched in _setup, but TTL is 0 so it's already expired - await _send(ws, g, sm, im, cb, content="hello") + await _send(ws, ws_ctx, content="hello") assert ws.sent[0]["type"] == "error" assert "expired" in ws.sent[0]["message"].lower() @pytest.mark.asyncio async def test_new_session_not_expired(self) -> None: - g, sm, im, cb, ws = _setup(session_ttl=3600) - await _send(ws, g, sm, im, cb, content="hello") + g, sm, im, cb, ws, ws_ctx = _setup(session_ttl=3600) + await _send(ws, ws_ctx, content="hello") completes = [m for m in ws.sent if m["type"] == "message_complete"] assert len(completes) == 1 @pytest.mark.asyncio async def test_sliding_window_resets_on_message(self) -> None: - g, sm, im, cb, ws = _setup(session_ttl=3600) + g, sm, im, cb, ws, ws_ctx = _setup(session_ttl=3600) - await _send(ws, g, sm, im, cb, content="hello") + await _send(ws, ws_ctx, content="hello") first_activity = sm.get_state("t1").last_activity time.sleep(0.01) - await _send(ws, g, sm, im, cb, content="hello again") + await _send(ws, ws_ctx, content="hello again") second_activity = sm.get_state("t1").last_activity assert second_activity > first_activity @@ -256,9 +268,9 @@ class TestWebSocketSessionTTL: async def test_interrupt_extends_session_ttl(self) -> None: st_int = _state(interrupt=True) g = _graph(chunks=[], st=st_int) - g_, sm, im, cb, ws = _setup(graph=g, session_ttl=3600) + g_, sm, im, cb, ws, ws_ctx = _setup(graph=g, session_ttl=3600) - await _send(ws, g_, sm, im, cb, content="cancel order") + await _send(ws, ws_ctx, content="cancel order") state = sm.get_state("t1") assert state is not None @@ -270,53 +282,53 @@ class TestWebSocketSessionTTL: class TestWebSocketValidation: @pytest.mark.asyncio async def test_invalid_json(self) -> None: - g, sm, im, cb, ws = _setup() - await dispatch_message(ws, g, sm, cb, "not json", interrupt_manager=im) + g, sm, im, cb, ws, ws_ctx = _setup() + await dispatch_message(ws, ws_ctx, "not json") assert ws.sent[0]["type"] == "error" assert "Invalid JSON" in ws.sent[0]["message"] @pytest.mark.asyncio async def test_missing_thread_id(self) -> None: - g, sm, im, cb, ws = _setup() + g, sm, im, cb, ws, ws_ctx = _setup() raw = json.dumps({"type": "message", "content": "hi"}) - await dispatch_message(ws, g, sm, cb, raw, interrupt_manager=im) + await dispatch_message(ws, ws_ctx, raw) assert ws.sent[0]["type"] == "error" assert "thread_id" in ws.sent[0]["message"] @pytest.mark.asyncio async def test_invalid_thread_id_format(self) -> None: - g, sm, im, cb, ws = _setup() + g, sm, im, cb, ws, ws_ctx = _setup() raw = json.dumps({"type": "message", "thread_id": "../../etc/passwd", "content": "hi"}) - await dispatch_message(ws, g, sm, cb, raw, interrupt_manager=im) + await dispatch_message(ws, ws_ctx, raw) assert ws.sent[0]["type"] == "error" @pytest.mark.asyncio async def test_missing_content(self) -> None: - g, sm, im, cb, ws = _setup() + g, sm, im, cb, ws, ws_ctx = _setup() raw = json.dumps({"type": "message", "thread_id": "t1"}) - await dispatch_message(ws, g, sm, cb, raw, interrupt_manager=im) + await dispatch_message(ws, ws_ctx, raw) assert ws.sent[0]["type"] == "error" @pytest.mark.asyncio async def test_unknown_message_type(self) -> None: - g, sm, im, cb, ws = _setup() + g, sm, im, cb, ws, ws_ctx = _setup() raw = json.dumps({"type": "foobar", "thread_id": "t1"}) - await dispatch_message(ws, g, sm, cb, raw, interrupt_manager=im) + await dispatch_message(ws, ws_ctx, raw) assert ws.sent[0]["type"] == "error" assert "Unknown" in ws.sent[0]["message"] @pytest.mark.asyncio async def test_message_too_large(self) -> None: - g, sm, im, cb, ws = _setup() - await dispatch_message(ws, g, sm, cb, "x" * 40_000, interrupt_manager=im) + g, sm, im, cb, ws, ws_ctx = _setup() + await dispatch_message(ws, ws_ctx, "x" * 40_000) assert ws.sent[0]["type"] == "error" assert "too large" in ws.sent[0]["message"].lower() @pytest.mark.asyncio async def test_content_too_long(self) -> None: - g, sm, im, cb, ws = _setup() + g, sm, im, cb, ws, ws_ctx = _setup() raw = json.dumps({"type": "message", "thread_id": "t1", "content": "x" * 10001}) - await dispatch_message(ws, g, sm, cb, raw, interrupt_manager=im) + await dispatch_message(ws, ws_ctx, raw) assert ws.sent[0]["type"] == "error" assert "too long" in ws.sent[0]["message"].lower() @@ -327,10 +339,10 @@ class TestWebSocketInterruptTTL: async def test_expired_interrupt_sends_retry_prompt(self) -> None: st_int = _state(interrupt=True, data={"action": "cancel_order", "order_id": "1042"}) g = _graph(chunks=[], st=st_int) - g_, sm, im, cb, ws = _setup(graph=g, interrupt_ttl=5) + g_, sm, im, cb, ws, ws_ctx = _setup(graph=g, interrupt_ttl=5) # Trigger interrupt - await _send(ws, g_, sm, im, cb, content="Cancel order 1042") + await _send(ws, ws_ctx, content="Cancel order 1042") interrupts = [m for m in ws.sent if m["type"] == "interrupt"] assert len(interrupts) == 1 @@ -341,7 +353,7 @@ class TestWebSocketInterruptTTL: with patch("app.interrupt_manager.time") as mock_time: mock_time.time.return_value = record.created_at + 10 - await _respond(ws, g_, sm, im, cb, approved=True) + await _respond(ws, ws_ctx, approved=True) assert ws.sent[0]["type"] == "interrupt_expired" assert "cancel_order" in ws.sent[0]["message"] diff --git a/backend/tests/unit/test_db.py b/backend/tests/unit/test_db.py index caf34ce..edd1653 100644 --- a/backend/tests/unit/test_db.py +++ b/backend/tests/unit/test_db.py @@ -55,7 +55,7 @@ class TestDbModule: from app.db import setup_app_tables await setup_app_tables(mock_pool) - assert mock_conn.execute.await_count == 4 + assert mock_conn.execute.await_count == 5 def test_ddl_statements_valid(self) -> None: assert "CREATE TABLE IF NOT EXISTS conversations" in _CONVERSATIONS_DDL diff --git a/backend/tests/unit/test_db_phase4.py b/backend/tests/unit/test_db_phase4.py index f49f97c..7d42627 100644 --- a/backend/tests/unit/test_db_phase4.py +++ b/backend/tests/unit/test_db_phase4.py @@ -51,5 +51,5 @@ class TestAnalyticsEventsDDL: from app.db import setup_app_tables await setup_app_tables(mock_pool) - # Now expects 4 statements: conversations, interrupts, analytics_events, migrations - assert mock_conn.execute.await_count == 4 + # Now expects 5 statements: conversations, interrupts, sessions, analytics_events, migrations + assert mock_conn.execute.await_count == 5 diff --git a/backend/tests/unit/test_edge_cases.py b/backend/tests/unit/test_edge_cases.py index 29aa103..2a3b68e 100644 --- a/backend/tests/unit/test_edge_cases.py +++ b/backend/tests/unit/test_edge_cases.py @@ -8,7 +8,9 @@ from unittest.mock import AsyncMock, MagicMock import pytest from app.callbacks import TokenUsageCallbackHandler +from app.graph_context import GraphContext from app.session_manager import SessionManager +from app.ws_context import WebSocketContext from app.ws_handler import dispatch_message pytestmark = pytest.mark.unit @@ -20,7 +22,7 @@ def _make_ws() -> AsyncMock: return ws -def _make_graph() -> AsyncMock: +def _make_graph() -> MagicMock: graph = AsyncMock() class AsyncIterHelper: @@ -34,23 +36,32 @@ def _make_graph() -> AsyncMock: state = MagicMock() state.tasks = () graph.aget_state = AsyncMock(return_value=state) - graph.intent_classifier = None - graph.agent_registry = None return graph +def _make_ws_ctx(sm: SessionManager | None = None) -> WebSocketContext: + graph = _make_graph() + registry = MagicMock() + registry.list_agents = MagicMock(return_value=()) + graph_ctx = GraphContext(graph=graph, registry=registry, intent_classifier=None) + return WebSocketContext( + graph_ctx=graph_ctx, + session_manager=sm or SessionManager(), + callback_handler=TokenUsageCallbackHandler(), + ) + + @pytest.mark.unit class TestEmptyMessageHandling: @pytest.mark.asyncio async def test_empty_message_content_returns_error(self) -> None: ws = _make_ws() - graph = _make_graph() sm = SessionManager() - cb = TokenUsageCallbackHandler() + ws_ctx = _make_ws_ctx(sm=sm) sm.touch("t1") msg = json.dumps({"type": "message", "thread_id": "t1", "content": ""}) - await dispatch_message(ws, graph, sm, cb, msg) + await dispatch_message(ws, ws_ctx, msg) call_data = ws.send_json.call_args[0][0] assert call_data["type"] == "error" @@ -60,13 +71,12 @@ class TestEmptyMessageHandling: @pytest.mark.asyncio async def test_whitespace_only_message_treated_as_empty(self) -> None: ws = _make_ws() - graph = _make_graph() sm = SessionManager() - cb = TokenUsageCallbackHandler() + ws_ctx = _make_ws_ctx(sm=sm) sm.touch("t1") msg = json.dumps({"type": "message", "thread_id": "t1", "content": " "}) - await dispatch_message(ws, graph, sm, cb, msg) + await dispatch_message(ws, ws_ctx, msg) call_data = ws.send_json.call_args[0][0] assert call_data["type"] == "error" @@ -77,14 +87,13 @@ class TestOversizedMessageHandling: @pytest.mark.asyncio async def test_content_over_10000_chars_returns_error(self) -> None: ws = _make_ws() - graph = _make_graph() sm = SessionManager() - cb = TokenUsageCallbackHandler() + ws_ctx = _make_ws_ctx(sm=sm) sm.touch("t1") content = "x" * 10001 msg = json.dumps({"type": "message", "thread_id": "t1", "content": content}) - await dispatch_message(ws, graph, sm, cb, msg) + await dispatch_message(ws, ws_ctx, msg) call_data = ws.send_json.call_args[0][0] assert call_data["type"] == "error" @@ -93,14 +102,13 @@ class TestOversizedMessageHandling: @pytest.mark.asyncio async def test_content_exactly_10000_chars_is_accepted(self) -> None: ws = _make_ws() - graph = _make_graph() sm = SessionManager() - cb = TokenUsageCallbackHandler() + ws_ctx = _make_ws_ctx(sm=sm) sm.touch("t1") content = "x" * 10000 msg = json.dumps({"type": "message", "thread_id": "t1", "content": content}) - await dispatch_message(ws, graph, sm, cb, msg) + await dispatch_message(ws, ws_ctx, msg) last_call = ws.send_json.call_args[0][0] # Should be processed, not an error about length @@ -110,12 +118,10 @@ class TestOversizedMessageHandling: @pytest.mark.asyncio async def test_raw_message_over_32kb_returns_error(self) -> None: ws = _make_ws() - graph = _make_graph() - sm = SessionManager() - cb = TokenUsageCallbackHandler() + ws_ctx = _make_ws_ctx() large_msg = "x" * 40_000 - await dispatch_message(ws, graph, sm, cb, large_msg) + await dispatch_message(ws, ws_ctx, large_msg) call_data = ws.send_json.call_args[0][0] assert call_data["type"] == "error" @@ -127,11 +133,9 @@ class TestInvalidJsonHandling: @pytest.mark.asyncio async def test_invalid_json_returns_error(self) -> None: ws = _make_ws() - graph = _make_graph() - sm = SessionManager() - cb = TokenUsageCallbackHandler() + ws_ctx = _make_ws_ctx() - await dispatch_message(ws, graph, sm, cb, "not valid json {{") + await dispatch_message(ws, ws_ctx, "not valid json {{") call_data = ws.send_json.call_args[0][0] assert call_data["type"] == "error" @@ -140,11 +144,9 @@ class TestInvalidJsonHandling: @pytest.mark.asyncio async def test_empty_string_returns_json_error(self) -> None: ws = _make_ws() - graph = _make_graph() - sm = SessionManager() - cb = TokenUsageCallbackHandler() + ws_ctx = _make_ws_ctx() - await dispatch_message(ws, graph, sm, cb, "") + await dispatch_message(ws, ws_ctx, "") call_data = ws.send_json.call_args[0][0] assert call_data["type"] == "error" @@ -152,11 +154,9 @@ class TestInvalidJsonHandling: @pytest.mark.asyncio async def test_json_array_not_object_returns_error(self) -> None: ws = _make_ws() - graph = _make_graph() - sm = SessionManager() - cb = TokenUsageCallbackHandler() + ws_ctx = _make_ws_ctx() - await dispatch_message(ws, graph, sm, cb, '["not", "an", "object"]') + await dispatch_message(ws, ws_ctx, '["not", "an", "object"]') call_data = ws.send_json.call_args[0][0] assert call_data["type"] == "error" @@ -167,17 +167,15 @@ class TestRateLimiting: @pytest.mark.asyncio async def test_rapid_fire_messages_rate_limited(self) -> None: ws = _make_ws() - _make_graph() # ensure graph factory works, not needed directly sm = SessionManager() - cb = TokenUsageCallbackHandler() sm.touch("t1") # Simulate 11 rapid messages (exceeds 10 per 10 seconds limit) rate_limit_triggered = False for i in range(11): - graph2 = _make_graph() # fresh graph each time - await dispatch_message(ws, graph2, sm, cb, json.dumps({ + ws_ctx = _make_ws_ctx(sm=sm) + await dispatch_message(ws, ws_ctx, json.dumps({ "type": "message", "thread_id": "t1", "content": f"message {i}", @@ -193,19 +191,18 @@ class TestRateLimiting: async def test_different_threads_have_separate_rate_limits(self) -> None: ws = _make_ws() sm = SessionManager() - cb = TokenUsageCallbackHandler() sm.touch("t1") sm.touch("t2") # Send 5 messages on t1 and 5 on t2 -- neither should be rate limited for i in range(5): - graph1 = _make_graph() - graph2 = _make_graph() - await dispatch_message(ws, graph1, sm, cb, json.dumps({ + ws_ctx1 = _make_ws_ctx(sm=sm) + ws_ctx2 = _make_ws_ctx(sm=sm) + await dispatch_message(ws, ws_ctx1, json.dumps({ "type": "message", "thread_id": "t1", "content": f"msg {i}", })) - await dispatch_message(ws, graph2, sm, cb, json.dumps({ + await dispatch_message(ws, ws_ctx2, json.dumps({ "type": "message", "thread_id": "t2", "content": f"msg {i}", })) diff --git a/backend/tests/unit/test_graph.py b/backend/tests/unit/test_graph.py index 4556dfd..c52a65a 100644 --- a/backend/tests/unit/test_graph.py +++ b/backend/tests/unit/test_graph.py @@ -8,7 +8,8 @@ from unittest.mock import AsyncMock, MagicMock import pytest from langgraph.checkpoint.memory import InMemorySaver -from app.graph import build_agent_nodes, build_graph, classify_intent +from app.graph import build_agent_nodes, build_graph +from app.graph_context import GraphContext from app.intent import ClassificationResult, IntentTarget if TYPE_CHECKING: @@ -37,8 +38,9 @@ class TestBuildGraph: mock_llm.with_structured_output = MagicMock(return_value=mock_llm) checkpointer = InMemorySaver() - graph = build_graph(sample_registry, mock_llm, checkpointer) - assert graph is not None + graph_ctx = build_graph(sample_registry, mock_llm, checkpointer) + assert graph_ctx is not None + assert graph_ctx.graph is not None def test_graph_has_classifier_attached(self, sample_registry: AgentRegistry) -> None: mock_llm = MagicMock() @@ -47,11 +49,11 @@ class TestBuildGraph: checkpointer = InMemorySaver() mock_classifier = MagicMock() - graph = build_graph( + graph_ctx = build_graph( sample_registry, mock_llm, checkpointer, intent_classifier=mock_classifier ) - assert graph.intent_classifier is mock_classifier - assert graph.agent_registry is sample_registry + assert graph_ctx.intent_classifier is mock_classifier + assert graph_ctx.registry is sample_registry def test_graph_without_classifier(self, sample_registry: AgentRegistry) -> None: mock_llm = MagicMock() @@ -59,17 +61,18 @@ class TestBuildGraph: mock_llm.with_structured_output = MagicMock(return_value=mock_llm) checkpointer = InMemorySaver() - graph = build_graph(sample_registry, mock_llm, checkpointer) - assert graph.intent_classifier is None + graph_ctx = build_graph(sample_registry, mock_llm, checkpointer) + assert graph_ctx.intent_classifier is None @pytest.mark.unit class TestClassifyIntent: @pytest.mark.asyncio async def test_returns_none_without_classifier(self) -> None: - graph = MagicMock() - graph.intent_classifier = None - result = await classify_intent(graph, "hello") + mock_registry = MagicMock() + mock_registry.list_agents = MagicMock(return_value=()) + graph_ctx = GraphContext(graph=MagicMock(), registry=mock_registry, intent_classifier=None) + result = await graph_ctx.classify_intent("hello") assert result is None @pytest.mark.asyncio @@ -80,11 +83,12 @@ class TestClassifyIntent: mock_classifier = AsyncMock() mock_classifier.classify = AsyncMock(return_value=expected) - graph = MagicMock() - graph.intent_classifier = mock_classifier - graph.agent_registry = MagicMock() - graph.agent_registry.list_agents = MagicMock(return_value=()) + mock_registry = MagicMock() + mock_registry.list_agents = MagicMock(return_value=()) + graph_ctx = GraphContext( + graph=MagicMock(), registry=mock_registry, intent_classifier=mock_classifier, + ) - result = await classify_intent(graph, "check order") + result = await graph_ctx.classify_intent("check order") assert result is not None assert result.intents[0].agent_name == "order_lookup" diff --git a/backend/tests/unit/test_main.py b/backend/tests/unit/test_main.py index f448f22..8a02f53 100644 --- a/backend/tests/unit/test_main.py +++ b/backend/tests/unit/test_main.py @@ -13,7 +13,7 @@ class TestMainModule: assert app.title == "Smart Support" def test_app_version(self) -> None: - assert app.version == "0.5.0" + assert app.version == "0.6.0" def test_agents_yaml_path_exists(self) -> None: assert AGENTS_YAML.name == "agents.yaml" @@ -39,4 +39,4 @@ class TestMainModule: assert "/api/health" in routes def test_app_version_is_0_5_0(self) -> None: - assert app.version == "0.5.0" + assert app.version == "0.6.0" diff --git a/backend/tests/unit/test_ws_handler.py b/backend/tests/unit/test_ws_handler.py index 446589d..e8460d3 100644 --- a/backend/tests/unit/test_ws_handler.py +++ b/backend/tests/unit/test_ws_handler.py @@ -8,8 +8,10 @@ from unittest.mock import AsyncMock, MagicMock import pytest from app.callbacks import TokenUsageCallbackHandler +from app.graph_context import GraphContext from app.interrupt_manager import InterruptManager from app.session_manager import SessionManager +from app.ws_context import WebSocketContext from app.ws_handler import ( _extract_interrupt, _has_interrupt, @@ -25,18 +27,42 @@ def _make_ws() -> AsyncMock: return ws -def _make_graph() -> AsyncMock: +def _make_graph() -> MagicMock: graph = AsyncMock() graph.astream = MagicMock(return_value=AsyncIterHelper([])) state = MagicMock() state.tasks = () graph.aget_state = AsyncMock(return_value=state) - # Phase 2: graph needs intent_classifier and agent_registry attrs - graph.intent_classifier = None - graph.agent_registry = None return graph +def _make_graph_ctx(graph: MagicMock | None = None) -> GraphContext: + g = graph or _make_graph() + registry = MagicMock() + registry.list_agents = MagicMock(return_value=()) + return GraphContext(graph=g, registry=registry, intent_classifier=None) + + +def _make_ws_ctx( + graph_ctx: GraphContext | None = None, + sm: SessionManager | None = None, + cb: TokenUsageCallbackHandler | None = None, + interrupt_manager: InterruptManager | None = None, + analytics_recorder=None, + conversation_tracker=None, + pool=None, +) -> WebSocketContext: + return WebSocketContext( + graph_ctx=graph_ctx or _make_graph_ctx(), + session_manager=sm or SessionManager(), + callback_handler=cb or TokenUsageCallbackHandler(), + interrupt_manager=interrupt_manager, + analytics_recorder=analytics_recorder, + conversation_tracker=conversation_tracker, + pool=pool, + ) + + class AsyncIterHelper: """Helper to make a list behave as an async iterator.""" @@ -57,11 +83,9 @@ class TestDispatchMessage: @pytest.mark.asyncio async def test_invalid_json(self) -> None: ws = _make_ws() - graph = _make_graph() - sm = SessionManager() - cb = TokenUsageCallbackHandler() + ws_ctx = _make_ws_ctx() - await dispatch_message(ws, graph, sm, cb, "not json") + await dispatch_message(ws, ws_ctx, "not json") ws.send_json.assert_awaited_once() call_data = ws.send_json.call_args[0][0] assert call_data["type"] == "error" @@ -70,12 +94,10 @@ class TestDispatchMessage: @pytest.mark.asyncio async def test_missing_thread_id(self) -> None: ws = _make_ws() - graph = _make_graph() - sm = SessionManager() - cb = TokenUsageCallbackHandler() + ws_ctx = _make_ws_ctx() msg = json.dumps({"type": "message", "content": "hello"}) - await dispatch_message(ws, graph, sm, cb, msg) + await dispatch_message(ws, ws_ctx, msg) call_data = ws.send_json.call_args[0][0] assert call_data["type"] == "error" assert "thread_id" in call_data["message"] @@ -83,24 +105,20 @@ class TestDispatchMessage: @pytest.mark.asyncio async def test_missing_content(self) -> None: ws = _make_ws() - graph = _make_graph() - sm = SessionManager() - cb = TokenUsageCallbackHandler() + ws_ctx = _make_ws_ctx() msg = json.dumps({"type": "message", "thread_id": "t1"}) - await dispatch_message(ws, graph, sm, cb, msg) + await dispatch_message(ws, ws_ctx, msg) call_data = ws.send_json.call_args[0][0] assert call_data["type"] == "error" @pytest.mark.asyncio async def test_unknown_message_type(self) -> None: ws = _make_ws() - graph = _make_graph() - sm = SessionManager() - cb = TokenUsageCallbackHandler() + ws_ctx = _make_ws_ctx() msg = json.dumps({"type": "unknown", "thread_id": "t1"}) - await dispatch_message(ws, graph, sm, cb, msg) + await dispatch_message(ws, ws_ctx, msg) call_data = ws.send_json.call_args[0][0] assert call_data["type"] == "error" assert "Unknown" in call_data["message"] @@ -108,12 +126,10 @@ class TestDispatchMessage: @pytest.mark.asyncio async def test_message_too_large(self) -> None: ws = _make_ws() - graph = _make_graph() - sm = SessionManager() - cb = TokenUsageCallbackHandler() + ws_ctx = _make_ws_ctx() large_msg = "x" * 40_000 - await dispatch_message(ws, graph, sm, cb, large_msg) + await dispatch_message(ws, ws_ctx, large_msg) call_data = ws.send_json.call_args[0][0] assert call_data["type"] == "error" assert "too large" in call_data["message"].lower() @@ -121,12 +137,10 @@ class TestDispatchMessage: @pytest.mark.asyncio async def test_invalid_thread_id_format(self) -> None: ws = _make_ws() - graph = _make_graph() - sm = SessionManager() - cb = TokenUsageCallbackHandler() + ws_ctx = _make_ws_ctx() msg = json.dumps({"type": "message", "thread_id": "../../../etc", "content": "hi"}) - await dispatch_message(ws, graph, sm, cb, msg) + await dispatch_message(ws, ws_ctx, msg) call_data = ws.send_json.call_args[0][0] assert call_data["type"] == "error" assert "thread_id" in call_data["message"].lower() @@ -134,12 +148,10 @@ class TestDispatchMessage: @pytest.mark.asyncio async def test_content_too_long(self) -> None: ws = _make_ws() - graph = _make_graph() - sm = SessionManager() - cb = TokenUsageCallbackHandler() + ws_ctx = _make_ws_ctx() msg = json.dumps({"type": "message", "thread_id": "t1", "content": "x" * 10001}) - await dispatch_message(ws, graph, sm, cb, msg) + await dispatch_message(ws, ws_ctx, msg) call_data = ws.send_json.call_args[0][0] assert call_data["type"] == "error" assert "too long" in call_data["message"].lower() @@ -147,14 +159,13 @@ class TestDispatchMessage: @pytest.mark.asyncio async def test_dispatch_with_interrupt_manager(self) -> None: ws = _make_ws() - graph = _make_graph() sm = SessionManager() - cb = TokenUsageCallbackHandler() im = InterruptManager() + ws_ctx = _make_ws_ctx(sm=sm, interrupt_manager=im) sm.touch("t1") msg = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"}) - await dispatch_message(ws, graph, sm, cb, msg, interrupt_manager=im) + await dispatch_message(ws, ws_ctx, msg) last_call = ws.send_json.call_args[0][0] assert last_call["type"] == "message_complete" @@ -164,14 +175,14 @@ class TestHandleUserMessage: @pytest.mark.asyncio async def test_expired_session(self) -> None: ws = _make_ws() - graph = _make_graph() + graph_ctx = _make_graph_ctx() sm = SessionManager(session_ttl_seconds=0) cb = TokenUsageCallbackHandler() # First call creates the session (TTL=0) - await handle_user_message(ws, graph, sm, cb, "t1", "hello") + await handle_user_message(ws, graph_ctx, sm, cb, "t1", "hello") # Second call finds it expired - await handle_user_message(ws, graph, sm, cb, "t1", "hello again") + await handle_user_message(ws, graph_ctx, sm, cb, "t1", "hello again") call_data = ws.send_json.call_args[0][0] assert call_data["type"] == "error" assert "expired" in call_data["message"].lower() @@ -179,12 +190,12 @@ class TestHandleUserMessage: @pytest.mark.asyncio async def test_successful_message(self) -> None: ws = _make_ws() - graph = _make_graph() + graph_ctx = _make_graph_ctx() sm = SessionManager() cb = TokenUsageCallbackHandler() sm.touch("t1") - await handle_user_message(ws, graph, sm, cb, "t1", "hello") + await handle_user_message(ws, graph_ctx, sm, cb, "t1", "hello") last_call = ws.send_json.call_args[0][0] assert last_call["type"] == "message_complete" @@ -193,13 +204,12 @@ class TestHandleUserMessage: ws = _make_ws() graph = AsyncMock() graph.astream = MagicMock(side_effect=RuntimeError("boom")) - graph.intent_classifier = None - graph.agent_registry = None + graph_ctx = _make_graph_ctx(graph=graph) sm = SessionManager() cb = TokenUsageCallbackHandler() sm.touch("t1") - await handle_user_message(ws, graph, sm, cb, "t1", "hello") + await handle_user_message(ws, graph_ctx, sm, cb, "t1", "hello") call_data = ws.send_json.call_args[0][0] assert call_data["type"] == "error" @@ -207,8 +217,6 @@ class TestHandleUserMessage: async def test_interrupt_registered_with_manager(self) -> None: ws = _make_ws() graph = AsyncMock() - graph.intent_classifier = None - graph.agent_registry = None graph.astream = MagicMock(return_value=AsyncIterHelper([])) # Simulate interrupt in state @@ -220,13 +228,14 @@ class TestHandleUserMessage: state.tasks = (task,) graph.aget_state = AsyncMock(return_value=state) + graph_ctx = _make_graph_ctx(graph=graph) sm = SessionManager() cb = TokenUsageCallbackHandler() im = InterruptManager() sm.touch("t1") await handle_user_message( - ws, graph, sm, cb, "t1", "cancel order 1042", interrupt_manager=im, + ws, graph_ctx, sm, cb, "t1", "cancel order 1042", interrupt_manager=im, ) # Interrupt should be registered @@ -257,16 +266,17 @@ class TestHandleUserMessage: clarification_question="What do you mean?", ) ) - graph.intent_classifier = mock_classifier mock_registry = MagicMock() mock_registry.list_agents = MagicMock(return_value=()) - graph.agent_registry = mock_registry + graph_ctx = GraphContext( + graph=graph, registry=mock_registry, intent_classifier=mock_classifier, + ) sm = SessionManager() cb = TokenUsageCallbackHandler() sm.touch("t1") - await handle_user_message(ws, graph, sm, cb, "t1", "hmm") + await handle_user_message(ws, graph_ctx, sm, cb, "t1", "hmm") calls = [c[0][0] for c in ws.send_json.call_args_list] clarification_msgs = [c for c in calls if c.get("type") == "clarification"] @@ -279,13 +289,13 @@ class TestHandleInterruptResponse: @pytest.mark.asyncio async def test_approved_interrupt(self) -> None: ws = _make_ws() - graph = _make_graph() + graph_ctx = _make_graph_ctx() sm = SessionManager() cb = TokenUsageCallbackHandler() sm.touch("t1") sm.extend_for_interrupt("t1") - await handle_interrupt_response(ws, graph, sm, cb, "t1", True) + await handle_interrupt_response(ws, graph_ctx, sm, cb, "t1", True) last_call = ws.send_json.call_args[0][0] assert last_call["type"] == "message_complete" @@ -294,7 +304,7 @@ class TestHandleInterruptResponse: from unittest.mock import patch ws = _make_ws() - graph = _make_graph() + graph_ctx = _make_graph_ctx() sm = SessionManager() cb = TokenUsageCallbackHandler() im = InterruptManager(ttl_seconds=5) @@ -307,7 +317,7 @@ class TestHandleInterruptResponse: with patch("app.interrupt_manager.time") as mock_time: mock_time.time.return_value = im._interrupts["t1"].created_at + 10 await handle_interrupt_response( - ws, graph, sm, cb, "t1", True, interrupt_manager=im + ws, graph_ctx, sm, cb, "t1", True, interrupt_manager=im ) call_data = ws.send_json.call_args[0][0] @@ -317,7 +327,7 @@ class TestHandleInterruptResponse: @pytest.mark.asyncio async def test_valid_interrupt_resolves(self) -> None: ws = _make_ws() - graph = _make_graph() + graph_ctx = _make_graph_ctx() sm = SessionManager() cb = TokenUsageCallbackHandler() im = InterruptManager(ttl_seconds=1800) @@ -327,7 +337,7 @@ class TestHandleInterruptResponse: im.register("t1", "cancel_order", {}) await handle_interrupt_response( - ws, graph, sm, cb, "t1", True, interrupt_manager=im + ws, graph_ctx, sm, cb, "t1", True, interrupt_manager=im ) # Interrupt should be resolved @@ -374,19 +384,14 @@ class TestDispatchMessageWithTracking: @pytest.mark.asyncio async def test_conversation_tracker_called_on_message(self) -> None: ws = _make_ws() - graph = _make_graph() sm = SessionManager() - cb = TokenUsageCallbackHandler() tracker = AsyncMock() pool = MagicMock() + ws_ctx = _make_ws_ctx(sm=sm, conversation_tracker=tracker, pool=pool) sm.touch("t1") msg = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"}) - await dispatch_message( - ws, graph, sm, cb, msg, - conversation_tracker=tracker, - pool=pool, - ) + await dispatch_message(ws, ws_ctx, msg) tracker.ensure_conversation.assert_awaited_once_with(pool, "t1") tracker.record_turn.assert_awaited_once() @@ -394,53 +399,42 @@ class TestDispatchMessageWithTracking: @pytest.mark.asyncio async def test_analytics_recorder_called_on_message(self) -> None: ws = _make_ws() - graph = _make_graph() sm = SessionManager() - cb = TokenUsageCallbackHandler() recorder = AsyncMock() pool = MagicMock() + ws_ctx = _make_ws_ctx(sm=sm, analytics_recorder=recorder, pool=pool) sm.touch("t1") msg = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"}) - await dispatch_message( - ws, graph, sm, cb, msg, - analytics_recorder=recorder, - pool=pool, - ) + await dispatch_message(ws, ws_ctx, msg) recorder.record.assert_awaited_once() @pytest.mark.asyncio async def test_tracker_failure_does_not_break_chat(self) -> None: ws = _make_ws() - graph = _make_graph() sm = SessionManager() - cb = TokenUsageCallbackHandler() tracker = AsyncMock() tracker.ensure_conversation.side_effect = RuntimeError("DB down") pool = MagicMock() + ws_ctx = _make_ws_ctx(sm=sm, conversation_tracker=tracker, pool=pool) sm.touch("t1") msg = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"}) # Should not raise despite tracker failure - await dispatch_message( - ws, graph, sm, cb, msg, - conversation_tracker=tracker, - pool=pool, - ) + await dispatch_message(ws, ws_ctx, msg) last_call = ws.send_json.call_args[0][0] assert last_call["type"] == "message_complete" @pytest.mark.asyncio async def test_no_tracker_no_error(self) -> None: ws = _make_ws() - graph = _make_graph() sm = SessionManager() - cb = TokenUsageCallbackHandler() + ws_ctx = _make_ws_ctx(sm=sm) sm.touch("t1") msg = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"}) # No tracker or recorder passed -- should work fine - await dispatch_message(ws, graph, sm, cb, msg) + await dispatch_message(ws, ws_ctx, msg) last_call = ws.send_json.call_args[0][0] assert last_call["type"] == "message_complete" diff --git a/frontend/src/api.ts b/frontend/src/api.ts index 4a146fa..b49263d 100644 --- a/frontend/src/api.ts +++ b/frontend/src/api.ts @@ -101,3 +101,74 @@ export async function fetchReplay( export async function fetchAnalytics(range = "7d"): Promise { return apiFetch(`/api/analytics?range=${range}`); } + +// -- OpenAPI import -- + +export interface ImportJobResponse { + job_id: string; + status: string; + spec_url: string; + total_endpoints: number; + classified_count: number; + error_message: string | null; + generated_tools_count?: number; +} + +export interface EndpointClassification { + index: number; + access_type: string; + needs_interrupt: boolean; + agent_group: string; + confidence: number; + customer_params: string[]; + endpoint: { + path: string; + method: string; + operation_id: string; + summary: string; + description: string; + }; +} + +async function apiPost(path: string, body: unknown): Promise { + const res = await fetch(`${API_BASE}${path}`, { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify(body), + }); + if (!res.ok) { + throw new Error(`API error ${res.status}: ${res.statusText}`); + } + return res.json(); +} + +export async function startImport(url: string): Promise { + return apiPost("/api/openapi/import", { url }); +} + +export async function fetchImportJob(jobId: string): Promise { + const res = await fetch(`${API_BASE}/api/openapi/jobs/${encodeURIComponent(jobId)}`); + if (!res.ok) { + throw new Error(`API error ${res.status}: ${res.statusText}`); + } + return res.json(); +} + +export async function fetchClassifications( + jobId: string +): Promise { + const res = await fetch( + `${API_BASE}/api/openapi/jobs/${encodeURIComponent(jobId)}/classifications` + ); + if (!res.ok) { + throw new Error(`API error ${res.status}: ${res.statusText}`); + } + return res.json(); +} + +export async function approveJob(jobId: string): Promise { + return apiPost( + `/api/openapi/jobs/${encodeURIComponent(jobId)}/approve`, + {} + ); +} diff --git a/frontend/src/pages/ChatPage.tsx b/frontend/src/pages/ChatPage.tsx index 72617e8..e2cbacf 100644 --- a/frontend/src/pages/ChatPage.tsx +++ b/frontend/src/pages/ChatPage.tsx @@ -13,10 +13,8 @@ import type { ToolAction, } from "../types"; -let msgCounter = 0; function nextId(): string { - msgCounter += 1; - return `msg-${msgCounter}`; + return crypto.randomUUID(); } export function ChatPage() { @@ -68,6 +66,48 @@ export function ChatPage() { setIsWaiting(false); break; } + case "clarification": { + setMessages((prev) => [ + ...prev, + { + id: nextId(), + sender: "agent", + agent: "System", + content: msg.message, + timestamp: Date.now(), + }, + ]); + setIsWaiting(false); + break; + } + case "interrupt_expired": { + setCurrentInterrupt(null); + setMessages((prev) => [ + ...prev, + { + id: nextId(), + sender: "agent", + agent: "System", + content: msg.message, + timestamp: Date.now(), + }, + ]); + setIsWaiting(false); + break; + } + case "tool_result": { + setToolActions((prev) => { + const last = prev[prev.length - 1]; + if (last && last.tool === msg.tool && last.agent === msg.agent) { + return [ + ...prev.slice(0, -1), + { ...last, result: msg.result }, + ]; + } + return prev; + }); + break; + } case "message_complete": { setMessages((prev) => { const last = prev[prev.length - 1]; diff --git a/frontend/src/pages/ReplayPage.tsx b/frontend/src/pages/ReplayPage.tsx index b3615bd..0555354 100644 --- a/frontend/src/pages/ReplayPage.tsx +++ b/frontend/src/pages/ReplayPage.tsx @@ -81,7 +81,7 @@ export function ReplayPage() { {/* Timeline */}
- +
)} diff --git a/frontend/src/pages/ReviewPage.tsx b/frontend/src/pages/ReviewPage.tsx index a04c404..894a963 100644 --- a/frontend/src/pages/ReviewPage.tsx +++ b/frontend/src/pages/ReviewPage.tsx @@ -1,12 +1,14 @@ import { useEffect, useRef, useState } from "react"; +import { + approveJob, + fetchClassifications, + fetchImportJob, + startImport, + type EndpointClassification, + type ImportJobResponse, +} from "../api"; -interface ImportJob { - job_id: string; - status: "pending" | "processing" | "done" | "failed"; - error_message?: string; -} - -interface EndpointClassification { +interface FlatClassification { path: string; method: string; summary: string; @@ -14,48 +16,23 @@ interface EndpointClassification { agent_group: string; } +function flattenClassification(c: EndpointClassification): FlatClassification { + return { + path: c.endpoint?.path ?? "", + method: c.endpoint?.method ?? "", + summary: c.endpoint?.summary ?? "", + access_type: c.access_type ?? "read", + agent_group: c.agent_group ?? "Unassigned", + }; +} + export function ReviewPage() { const [url, setUrl] = useState(""); - const [job, setJob] = useState(null); + const [job, setJob] = useState(null); const [submitting, setSubmitting] = useState(false); const [submitError, setSubmitError] = useState(null); - const [classifications, setClassifications] = useState([ - { - path: "/api/v1/orders/{order_id}/cancel", - method: "post", - summary: "Cancel an active Shopify order", - access_type: "write", - agent_group: "Order Specialist", - }, - { - path: "/api/v1/orders/{order_id}", - method: "get", - summary: "Retrieve detailed information about an order", - access_type: "read", - agent_group: "Order Specialist", - }, - { - path: "/api/v1/payments/{charge_id}/refund", - method: "post", - summary: "Issue a full or partial refund for a charge", - access_type: "write", - agent_group: "Billing Assistant", - }, - { - path: "/api/v1/customers/{email}/discounts", - method: "post", - summary: "Apply a loyalty discount to a customer account", - access_type: "write", - agent_group: "Billing Assistant", - }, - { - path: "/api/v1/inventory/check", - method: "get", - summary: "Query realtime stock levels across warehouses", - access_type: "read", - agent_group: "Unassigned", - } - ]); + const [approveStatus, setApproveStatus] = useState(null); + const [classifications, setClassifications] = useState([]); const pollRef = useRef | null>(null); useEffect(() => { @@ -65,25 +42,13 @@ export function ReviewPage() { }, []); function pollJob(jobId: string) { - fetch(`/api/openapi/jobs/${encodeURIComponent(jobId)}`) - .then((r) => r.json()) - .then((data) => { - const j: ImportJob = data.data ?? data; + fetchImportJob(jobId) + .then((j) => { setJob(j); if (j.status === "done") { - return fetch(`/api/openapi/jobs/${encodeURIComponent(jobId)}/classifications`) - .then((r) => r.json()) - .then((clfs: EndpointClassification[]) => { - setClassifications( - clfs.map((c: any) => ({ - path: c.endpoint?.path ?? c.path ?? "", - method: c.endpoint?.method ?? c.method ?? "", - summary: c.endpoint?.summary ?? c.summary ?? "", - access_type: c.access_type ?? "read", - agent_group: c.agent_group ?? "Unassigned", - })) - ); - }); + return fetchClassifications(jobId).then((clfs) => { + setClassifications(clfs.map(flattenClassification)); + }); } else if (j.status === "failed") { return; } else { @@ -100,17 +65,12 @@ export function ReviewPage() { if (!url.trim()) return; setSubmitting(true); setSubmitError(null); + setApproveStatus(null); setJob(null); setClassifications([]); - fetch("/api/openapi/import", { - method: "POST", - headers: { "Content-Type": "application/json" }, - body: JSON.stringify({ url }), - }) - .then((r) => r.json()) - .then((data) => { - const j: ImportJob = data.data ?? data; + startImport(url) + .then((j) => { setJob(j); if (j.job_id) pollJob(j.job_id); }) @@ -120,7 +80,7 @@ export function ReviewPage() { function handleFieldChange( idx: number, - field: keyof EndpointClassification, + field: keyof FlatClassification, value: string ) { setClassifications((prev) => @@ -130,21 +90,26 @@ export function ReviewPage() { function handleApprove() { if (!job?.job_id) return; - fetch(`/api/openapi/jobs/${encodeURIComponent(job.job_id)}/approve`, { - method: "POST", - headers: { "Content-Type": "application/json" }, - body: JSON.stringify({ endpoints: classifications }), - }).then(() => { - alert("Approved and saved."); - }); + setApproveStatus(null); + approveJob(job.job_id) + .then((result) => { + setJob(result); + setApproveStatus( + `Configuration saved. ${result.generated_tools_count ?? 0} tools generated.` + ); + }) + .catch((err: Error) => setApproveStatus(`Error: ${err.message}`)); } - const groupedByAgent = classifications.reduce((acc, c, idx) => { - const group = c.agent_group || "Unassigned"; - if (!acc[group]) acc[group] = []; - acc[group].push({ ...c, originalIdx: idx }); - return acc; - }, {} as Record); + const groupedByAgent = classifications.reduce( + (acc, c, idx) => { + const group = c.agent_group || "Unassigned"; + if (!acc[group]) acc[group] = []; + acc[group].push({ ...c, originalIdx: idx }); + return acc; + }, + {} as Record + ); return (
@@ -167,35 +132,105 @@ export function ReviewPage() { - {submitError &&
Error: {submitError}
} + {submitError && ( +
+ Error: {submitError} +
+ )} {job && ( -
+
Job: {job.job_id} — Status:{" "} - + {job.status} - {job.error_message &&
{job.error_message}
} + {job.error_message && ( +
+ {job.error_message} +
+ )} +
+ )} + + {approveStatus && ( +
+ {approveStatus}
)} {classifications.length > 0 && ( <> -
-
-

Assigned Capabilities ({classifications.length})

-

Grouped by target Agent.

-
- +
+
+

+ Assigned Capabilities ({classifications.length}) +

+

+ Grouped by target Agent. +

+
+
{Object.entries(groupedByAgent).map(([groupName, tools]) => (
-
{groupName === "Unassigned" ? "?" : groupName.charAt(0).toUpperCase()}
+
+ {groupName === "Unassigned" + ? "?" + : groupName.charAt(0).toUpperCase()} +

{groupName}

{tools.length} Attached Tools @@ -205,16 +240,36 @@ export function ReviewPage() { {tools.map((t) => (
- + {t.method} - {t.path} + + {t.path} +
{t.summary}
handleFieldChange(t.originalIdx, "agent_group", e.target.value)} + onChange={(e) => + handleFieldChange( + t.originalIdx, + "agent_group", + e.target.value + ) + } className="tool-input" placeholder="Agent Name" /> diff --git a/frontend/src/types.ts b/frontend/src/types.ts index 7394b75..2b2cd3e 100644 --- a/frontend/src/types.ts +++ b/frontend/src/types.ts @@ -39,13 +39,28 @@ export interface ErrorMessage { message: string; } +export interface ClarificationMessage { + type: "clarification"; + thread_id: string; + message: string; +} + +export interface InterruptExpiredMessage { + type: "interrupt_expired"; + thread_id: string; + action: string; + message: string; +} + export type ServerMessage = | TokenMessage | InterruptMessage | ToolCallMessage | ToolResultMessage | MessageCompleteMessage - | ErrorMessage; + | ErrorMessage + | ClarificationMessage + | InterruptExpiredMessage; // -- Client -> Server messages --