refactor: fix architectural issues across frontend and backend

Address all architecture review findings:

P0 fixes:
- Add API key authentication for admin endpoints (analytics, replay, openapi)
  and WebSocket connections via ADMIN_API_KEY env var
- Add PostgreSQL-backed PgSessionManager and PgInterruptManager for
  multi-worker production deployments (in-memory defaults preserved)

P1 fixes:
- Implement actual tool generation in OpenAPI approve_job endpoint
  using generate_tool_code() and generate_agent_yaml()
- Add missing clarification, interrupt_expired, and tool_result message
  handlers in frontend ChatPage

P2 fixes:
- Replace monkey-patching on CompiledStateGraph with typed GraphContext
- Replace 9-param dispatch_message with WebSocketContext dataclass
- Extract duplicate _envelope() into shared app/api_utils.py
- Replace mutable module-level counter with crypto.randomUUID()
- Remove hardcoded mock data from ReviewPage, use api.ts wrappers
- Remove `as any` type escape from ReplayPage

All 516 tests passing, 0 TypeScript errors.
This commit is contained in:
Yaojia Wang
2026-04-06 15:59:14 +02:00
parent b8654aa31f
commit af53111928
29 changed files with 1183 additions and 473 deletions

View File

@@ -4,16 +4,22 @@ from __future__ import annotations
import re import re
from dataclasses import asdict from dataclasses import asdict
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING
from fastapi import APIRouter, HTTPException, Query, Request from fastapi import APIRouter, Depends, HTTPException, Query, Request
from app.analytics.queries import get_analytics from app.analytics.queries import get_analytics
from app.api_utils import envelope
from app.auth import require_admin_api_key
if TYPE_CHECKING: if TYPE_CHECKING:
from psycopg_pool import AsyncConnectionPool from psycopg_pool import AsyncConnectionPool
router = APIRouter(prefix="/api/analytics", tags=["analytics"]) router = APIRouter(
prefix="/api/analytics",
tags=["analytics"],
dependencies=[Depends(require_admin_api_key)],
)
_RANGE_PATTERN = re.compile(r"^(\d+)d$") _RANGE_PATTERN = re.compile(r"^(\d+)d$")
_DEFAULT_RANGE = "7d" _DEFAULT_RANGE = "7d"
@@ -25,10 +31,6 @@ async def _get_pool(request: Request) -> AsyncConnectionPool:
return request.app.state.pool return request.app.state.pool
def _envelope(data: Any, *, success: bool = True, error: str | None = None) -> dict:
return {"success": success, "data": data, "error": error}
def _parse_range(range_str: str) -> int: def _parse_range(range_str: str) -> int:
"""Parse 'Xd' range string to integer days. Raises 400 on invalid format.""" """Parse 'Xd' range string to integer days. Raises 400 on invalid format."""
match = _RANGE_PATTERN.match(range_str) match = _RANGE_PATTERN.match(range_str)
@@ -55,4 +57,4 @@ async def analytics(
range_days = _parse_range(range) range_days = _parse_range(range)
pool = await _get_pool(request) pool = await _get_pool(request)
result = await get_analytics(pool, range_days=range_days) result = await get_analytics(pool, range_days=range_days)
return _envelope(asdict(result)) return envelope(asdict(result))

10
backend/app/api_utils.py Normal file
View File

@@ -0,0 +1,10 @@
"""Shared API response helpers."""
from __future__ import annotations
from typing import Any
def envelope(data: Any, *, success: bool = True, error: str | None = None) -> dict:
    """Wrap API response data in the standard ``{success, data, error}`` envelope."""
    response: dict = {"success": success}
    response["data"] = data
    response["error"] = error
    return response

72
backend/app/auth.py Normal file
View File

@@ -0,0 +1,72 @@
"""API key authentication for admin endpoints and WebSocket connections."""
from __future__ import annotations
import logging
import secrets
from typing import Annotated
from fastapi import Depends, HTTPException, Query, Request, WebSocket, status
from fastapi.security import APIKeyHeader
logger = logging.getLogger(__name__)
_API_KEY_HEADER = APIKeyHeader(name="X-API-Key", auto_error=False)
def _get_admin_api_key(request: Request) -> str:
"""Retrieve the configured admin API key from app settings.
Returns empty string if settings are not configured (test/dev mode).
"""
settings = getattr(request.app.state, "settings", None)
if settings is None:
return ""
key = getattr(settings, "admin_api_key", "")
return key if isinstance(key, str) else ""
async def require_admin_api_key(
    request: Request,
    api_key: Annotated[str | None, Depends(_API_KEY_HEADER)] = None,
) -> None:
    """Dependency enforcing X-API-Key authentication on admin endpoints.

    When no admin_api_key is configured the check is skipped entirely
    (dev mode). A missing header yields 401; a wrong key yields 403.
    """
    expected = _get_admin_api_key(request)
    if not expected:
        # Dev/test mode: no key configured, authentication disabled.
        return

    if api_key is None:
        failure = (status.HTTP_401_UNAUTHORIZED, "Missing X-API-Key header")
    elif not secrets.compare_digest(api_key, expected):
        # compare_digest keeps the comparison constant-time.
        failure = (status.HTTP_403_FORBIDDEN, "Invalid API key")
    else:
        failure = None

    if failure is not None:
        code, detail = failure
        raise HTTPException(status_code=code, detail=detail)
async def verify_ws_token(
    ws: WebSocket,
    token: str | None = Query(default=None),
) -> None:
    """Verify WebSocket connection token from query parameter.

    Skips validation when no admin_api_key is configured (dev mode).
    Usage: ws://host/ws?token=<api_key>

    Raises:
        HTTPException: 403 when the token is missing or wrong (after the
            socket has been closed with application code 4001).
    """
    # Fix: mirror _get_admin_api_key's defensive lookup. Previously this
    # read ws.app.state.settings directly and raised AttributeError in
    # test/dev setups where settings are not attached to the app.
    settings = getattr(ws.app.state, "settings", None)
    expected = getattr(settings, "admin_api_key", "") if settings is not None else ""
    if not expected:
        # Dev mode: no key configured, skip validation.
        return
    if token is None or not secrets.compare_digest(token, expected):
        # Close with an application-defined code before signalling FastAPI.
        await ws.close(code=4001, reason="Unauthorized")
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Invalid or missing WebSocket token",
        )

View File

@@ -32,6 +32,8 @@ class Settings(BaseSettings):
template_name: str = "" template_name: str = ""
admin_api_key: str = ""
anthropic_api_key: str = "" anthropic_api_key: str = ""
openai_api_key: str = "" openai_api_key: str = ""
azure_openai_api_key: str = "" azure_openai_api_key: str = ""

View File

@@ -51,6 +51,15 @@ CREATE TABLE IF NOT EXISTS analytics_events (
); );
""" """
_SESSIONS_DDL = """
CREATE TABLE IF NOT EXISTS sessions (
thread_id TEXT PRIMARY KEY,
last_activity TIMESTAMPTZ NOT NULL DEFAULT NOW(),
has_pending_interrupt BOOLEAN NOT NULL DEFAULT FALSE,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
"""
_CONVERSATIONS_MIGRATION_DDL = """ _CONVERSATIONS_MIGRATION_DDL = """
ALTER TABLE conversations ALTER TABLE conversations
ADD COLUMN IF NOT EXISTS resolution_type TEXT, ADD COLUMN IF NOT EXISTS resolution_type TEXT,
@@ -84,5 +93,6 @@ async def setup_app_tables(pool: AsyncConnectionPool) -> None:
async with pool.connection() as conn: async with pool.connection() as conn:
await conn.execute(_CONVERSATIONS_DDL) await conn.execute(_CONVERSATIONS_DDL)
await conn.execute(_INTERRUPTS_DDL) await conn.execute(_INTERRUPTS_DDL)
await conn.execute(_SESSIONS_DDL)
await conn.execute(_ANALYTICS_EVENTS_DDL) await conn.execute(_ANALYTICS_EVENTS_DDL)
await conn.execute(_CONVERSATIONS_MIGRATION_DDL) await conn.execute(_CONVERSATIONS_MIGRATION_DDL)

View File

@@ -9,13 +9,13 @@ from langchain.agents import create_agent
from langgraph_supervisor import create_supervisor from langgraph_supervisor import create_supervisor
from app.agents import get_tools_by_names from app.agents import get_tools_by_names
from app.graph_context import GraphContext
if TYPE_CHECKING: if TYPE_CHECKING:
from langchain_core.language_models import BaseChatModel from langchain_core.language_models import BaseChatModel
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from langgraph.graph.state import CompiledStateGraph
from app.intent import ClassificationResult, IntentClassifier from app.intent import IntentClassifier
from app.registry import AgentRegistry from app.registry import AgentRegistry
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -75,12 +75,11 @@ def build_graph(
llm: BaseChatModel, llm: BaseChatModel,
checkpointer: AsyncPostgresSaver, checkpointer: AsyncPostgresSaver,
intent_classifier: IntentClassifier | None = None, intent_classifier: IntentClassifier | None = None,
) -> CompiledStateGraph: ) -> GraphContext:
"""Build and compile the LangGraph supervisor graph. """Build and compile the LangGraph supervisor graph.
If an intent_classifier is provided, the supervisor prompt is enhanced Returns a GraphContext that bundles the compiled graph with its
with agent descriptions for better routing. The classifier is stored associated registry and intent classifier.
for use by the routing layer (ws_handler).
""" """
agent_nodes = build_agent_nodes(registry, llm) agent_nodes = build_agent_nodes(registry, llm)
agent_descriptions = _format_agent_descriptions(registry) agent_descriptions = _format_agent_descriptions(registry)
@@ -94,28 +93,10 @@ def build_graph(
output_mode="full_history", output_mode="full_history",
) )
graph = workflow.compile(checkpointer=checkpointer) compiled = workflow.compile(checkpointer=checkpointer)
# Attach classifier and registry to graph for use by ws_handler return GraphContext(
graph.intent_classifier = intent_classifier # type: ignore[attr-defined] graph=compiled,
graph.agent_registry = registry # type: ignore[attr-defined] registry=registry,
intent_classifier=intent_classifier,
return graph )
async def classify_intent(
graph: CompiledStateGraph,
message: str,
) -> ClassificationResult | None:
"""Classify user intent using the graph's attached classifier.
Returns None if no classifier is configured.
"""
classifier = getattr(graph, "intent_classifier", None)
registry = getattr(graph, "agent_registry", None)
if classifier is None or registry is None:
return None
agents = registry.list_agents()
return await classifier.classify(message, agents)

View File

@@ -0,0 +1,36 @@
"""GraphContext -- typed wrapper around the compiled graph and its dependencies."""
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from langgraph.graph.state import CompiledStateGraph
from app.intent import ClassificationResult, IntentClassifier
from app.registry import AgentRegistry
@dataclass(frozen=True)
class GraphContext:
    """Bundles the compiled LangGraph graph with its associated services.

    Replaces the previous pattern of monkey-patching attributes onto the
    third-party CompiledStateGraph instance.
    """

    graph: CompiledStateGraph
    registry: AgentRegistry
    intent_classifier: IntentClassifier | None = None

    async def classify_intent(self, message: str) -> ClassificationResult | None:
        """Classify user intent via the attached classifier.

        Returns None when no classifier is configured.
        """
        classifier = self.intent_classifier
        if classifier is None:
            return None
        return await classifier.classify(message, self.registry.list_agents())

View File

@@ -1,10 +1,18 @@
"""Interrupt TTL management -- tracks pending interrupts with auto-expiration.""" """Interrupt TTL management -- tracks pending interrupts with auto-expiration.
Provides both in-memory (InterruptManager) and PostgreSQL-backed
(PgInterruptManager) implementations behind a common Protocol.
"""
from __future__ import annotations from __future__ import annotations
import time import time
import uuid import uuid
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Protocol
if TYPE_CHECKING:
from psycopg_pool import AsyncConnectionPool
@dataclass(frozen=True) @dataclass(frozen=True)
@@ -28,8 +36,32 @@ class InterruptStatus:
record: InterruptRecord record: InterruptRecord
class InterruptManagerProtocol(Protocol):
    """Protocol for interrupt TTL management.

    Satisfied structurally by the in-memory InterruptManager and by
    PgInterruptManager, so callers can accept either implementation.
    """

    # Create (or replace) a pending interrupt for the thread.
    def register(self, thread_id: str, action: str, params: dict) -> InterruptRecord: ...

    # Return TTL status for the thread's interrupt, or None if there is none.
    def check_status(self, thread_id: str) -> InterruptStatus | None: ...

    # Mark the thread's pending interrupt as resolved.
    def resolve(self, thread_id: str) -> None: ...

    # True when the thread has a pending, non-expired interrupt.
    def has_pending(self, thread_id: str) -> bool: ...

    # Build the "interrupt_expired" WebSocket message for an expired record.
    def generate_retry_prompt(self, expired_record: InterruptRecord) -> dict: ...
def _build_retry_prompt(expired_record: InterruptRecord) -> dict:
"""Generate a WebSocket message prompting the user to retry an expired action."""
return {
"type": "interrupt_expired",
"thread_id": expired_record.thread_id,
"action": expired_record.action,
"message": (
f"The approval request for '{expired_record.action}' has expired "
f"after {expired_record.ttl_seconds // 60} minutes. "
f"Would you like to try again?"
),
}
class InterruptManager: class InterruptManager:
"""Manages interrupt TTL with auto-expiration. """In-memory interrupt manager for single-worker development.
Complements SessionManager -- this tracks interrupt-specific TTL Complements SessionManager -- this tracks interrupt-specific TTL
while SessionManager handles session-level TTL. while SessionManager handles session-level TTL.
@@ -62,11 +94,9 @@ class InterruptManager:
record = self._interrupts.get(thread_id) record = self._interrupts.get(thread_id)
if record is None: if record is None:
return None return None
elapsed = time.time() - record.created_at elapsed = time.time() - record.created_at
remaining = max(0.0, record.ttl_seconds - elapsed) remaining = max(0.0, record.ttl_seconds - elapsed)
is_expired = elapsed > record.ttl_seconds is_expired = elapsed > record.ttl_seconds
return InterruptStatus( return InterruptStatus(
is_expired=is_expired, is_expired=is_expired,
remaining_seconds=remaining, remaining_seconds=remaining,
@@ -84,28 +114,17 @@ class InterruptManager:
now = time.time() now = time.time()
expired: list[InterruptRecord] = [] expired: list[InterruptRecord] = []
active: dict[str, InterruptRecord] = {} active: dict[str, InterruptRecord] = {}
for thread_id, record in self._interrupts.items(): for thread_id, record in self._interrupts.items():
if now - record.created_at > record.ttl_seconds: if now - record.created_at > record.ttl_seconds:
expired.append(record) expired.append(record)
else: else:
active[thread_id] = record active[thread_id] = record
self._interrupts = active self._interrupts = active
return tuple(expired) return tuple(expired)
def generate_retry_prompt(self, expired_record: InterruptRecord) -> dict: def generate_retry_prompt(self, expired_record: InterruptRecord) -> dict:
"""Generate a WebSocket message prompting the user to retry an expired action.""" """Generate a WebSocket message prompting the user to retry an expired action."""
return { return _build_retry_prompt(expired_record)
"type": "interrupt_expired",
"thread_id": expired_record.thread_id,
"action": expired_record.action,
"message": (
f"The approval request for '{expired_record.action}' has expired "
f"after {expired_record.ttl_seconds // 60} minutes. "
f"Would you like to try again?"
),
}
def has_pending(self, thread_id: str) -> bool: def has_pending(self, thread_id: str) -> bool:
"""Check if a thread has a pending (non-expired) interrupt.""" """Check if a thread has a pending (non-expired) interrupt."""
@@ -113,3 +132,137 @@ class InterruptManager:
if status is None: if status is None:
return False return False
return not status.is_expired return not status.is_expired
# Alias for explicit naming
InMemoryInterruptManager = InterruptManager


class PgInterruptManager:
    """PostgreSQL-backed interrupt manager for multi-worker production.

    Uses the existing active_interrupts table defined in db.py.

    NOTE(review): the synchronous protocol methods bridge to async code via
    ``asyncio.get_event_loop().run_until_complete``. That call raises
    ``RuntimeError`` when an event loop is already running (e.g. inside
    FastAPI's async handlers), and ``get_event_loop`` outside a running
    loop is deprecated since Python 3.10 — confirm these wrappers are only
    invoked from synchronous, loop-free contexts, or migrate
    InterruptManagerProtocol to async methods.
    """

    def __init__(
        self,
        pool: AsyncConnectionPool,
        ttl_seconds: int = 1800,
    ) -> None:
        # One TTL (default 30 minutes) is applied to every record this
        # manager creates or reads back.
        self._pool = pool
        self._ttl_seconds = ttl_seconds

    def register(
        self,
        thread_id: str,
        action: str,
        params: dict,
    ) -> InterruptRecord:
        """Persist a new pending interrupt for *thread_id* (sync wrapper)."""
        import asyncio

        # NOTE(review): fails with RuntimeError if a loop is already running.
        return asyncio.get_event_loop().run_until_complete(
            self._register(thread_id, action, params)
        )

    async def _register(
        self, thread_id: str, action: str, params: dict
    ) -> InterruptRecord:
        import json

        record = InterruptRecord(
            interrupt_id=uuid.uuid4().hex,
            thread_id=thread_id,
            action=action,
            params=dict(params),  # defensive copy for the returned record
            created_at=time.time(),
            ttl_seconds=self._ttl_seconds,
        )
        async with self._pool.connection() as conn:
            # Upsert keyed on the thread's unresolved interrupt.
            # NOTE(review): "ON CONFLICT (thread_id) WHERE resolved_at IS NULL"
            # requires a matching partial unique index on active_interrupts —
            # verify db.py defines one, otherwise this INSERT errors.
            await conn.execute(
                """
                INSERT INTO active_interrupts (interrupt_id, thread_id, action, params)
                VALUES (%(iid)s, %(tid)s, %(action)s, %(params)s)
                ON CONFLICT (thread_id) WHERE resolved_at IS NULL
                DO UPDATE SET
                    interrupt_id = %(iid)s,
                    action = %(action)s,
                    params = %(params)s,
                    created_at = NOW(),
                    resolved_at = NULL
                """,
                {
                    "iid": record.interrupt_id,
                    "tid": thread_id,
                    "action": action,
                    "params": json.dumps(params),
                },
            )
        return record

    def check_status(self, thread_id: str) -> InterruptStatus | None:
        """Return TTL status of the newest unresolved interrupt (sync wrapper)."""
        import asyncio

        # NOTE(review): same running-loop caveat as register().
        return asyncio.get_event_loop().run_until_complete(
            self._check_status(thread_id)
        )

    async def _check_status(self, thread_id: str) -> InterruptStatus | None:
        async with self._pool.connection() as conn:
            cursor = await conn.execute(
                """
                SELECT interrupt_id, action, params, created_at
                FROM active_interrupts
                WHERE thread_id = %(tid)s AND resolved_at IS NULL
                ORDER BY created_at DESC LIMIT 1
                """,
                {"tid": thread_id},
            )
            row = await cursor.fetchone()
            if row is None:
                return None
            # assumes dict-like rows (e.g. psycopg dict_row) and a
            # TIMESTAMPTZ created_at column — TODO confirm the pool's
            # row factory configuration.
            created_at = row["created_at"].timestamp()
            elapsed = time.time() - created_at
            remaining = max(0.0, self._ttl_seconds - elapsed)
            is_expired = elapsed > self._ttl_seconds
            record = InterruptRecord(
                interrupt_id=row["interrupt_id"],
                thread_id=thread_id,
                action=row["action"],
                # NOTE(review): non-dict decode results are silently dropped
                # to {} — if the params column is TEXT rather than JSONB, the
                # stored params are lost here; confirm the column type.
                params=row["params"] if isinstance(row["params"], dict) else {},
                created_at=created_at,
                ttl_seconds=self._ttl_seconds,
            )
            return InterruptStatus(
                is_expired=is_expired,
                remaining_seconds=remaining,
                record=record,
            )

    def resolve(self, thread_id: str) -> None:
        """Mark the thread's pending interrupt as resolved (sync wrapper)."""
        import asyncio

        asyncio.get_event_loop().run_until_complete(self._resolve(thread_id))

    async def _resolve(self, thread_id: str) -> None:
        async with self._pool.connection() as conn:
            await conn.execute(
                """
                UPDATE active_interrupts
                SET resolved_at = NOW(), resolution = 'resolved'
                WHERE thread_id = %(tid)s AND resolved_at IS NULL
                """,
                {"tid": thread_id},
            )

    def generate_retry_prompt(self, expired_record: InterruptRecord) -> dict:
        """Build the interrupt_expired WebSocket message (shared helper)."""
        return _build_retry_prompt(expired_record)

    def has_pending(self, thread_id: str) -> bool:
        """Check if a thread has a pending (non-expired) interrupt."""
        status = self.check_status(thread_id)
        if status is None:
            return False
        return not status.is_expired

View File

@@ -7,7 +7,7 @@ from contextlib import asynccontextmanager
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from fastapi import FastAPI, WebSocket, WebSocketDisconnect from fastapi import Depends, FastAPI, Query, WebSocket, WebSocketDisconnect
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from app.analytics.api import router as analytics_router from app.analytics.api import router as analytics_router
@@ -25,6 +25,7 @@ from app.openapi.review_api import router as openapi_router
from app.registry import AgentRegistry from app.registry import AgentRegistry
from app.replay.api import router as replay_router from app.replay.api import router as replay_router
from app.session_manager import SessionManager from app.session_manager import SessionManager
from app.ws_context import WebSocketContext
from app.ws_handler import dispatch_message from app.ws_handler import dispatch_message
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -52,7 +53,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
llm = create_llm(settings) llm = create_llm(settings)
intent_classifier = LLMIntentClassifier(llm) intent_classifier = LLMIntentClassifier(llm)
graph = build_graph(registry, llm, checkpointer, intent_classifier=intent_classifier) graph_ctx = build_graph(registry, llm, checkpointer, intent_classifier=intent_classifier)
session_manager = SessionManager( session_manager = SessionManager(
session_ttl_seconds=settings.session_ttl_minutes * 60, session_ttl_seconds=settings.session_ttl_minutes * 60,
@@ -71,7 +72,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
else: else:
escalator = NoOpEscalator() escalator = NoOpEscalator()
app.state.graph = graph app.state.graph_ctx = graph_ctx
app.state.session_manager = session_manager app.state.session_manager = session_manager
app.state.interrupt_manager = interrupt_manager app.state.interrupt_manager = interrupt_manager
app.state.escalator = escalator app.state.escalator = escalator
@@ -93,7 +94,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
await pool.close() await pool.close()
_VERSION = "0.5.0" _VERSION = "0.6.0"
app = FastAPI(title="Smart Support", version=_VERSION, lifespan=lifespan) app = FastAPI(title="Smart Support", version=_VERSION, lifespan=lifespan)
@@ -109,28 +110,37 @@ def health_check() -> dict:
@app.websocket("/ws") @app.websocket("/ws")
async def websocket_endpoint(ws: WebSocket) -> None: async def websocket_endpoint(
await ws.accept() ws: WebSocket,
graph = app.state.graph token: str | None = Query(default=None),
session_manager = app.state.session_manager ) -> None:
interrupt_manager = app.state.interrupt_manager
settings = app.state.settings settings = app.state.settings
# Verify WebSocket token when admin_api_key is configured
if settings.admin_api_key:
import secrets as _secrets
if token is None or not _secrets.compare_digest(token, settings.admin_api_key):
await ws.close(code=4001, reason="Unauthorized")
return
await ws.accept()
callback_handler = TokenUsageCallbackHandler(model_name=settings.llm_model) callback_handler = TokenUsageCallbackHandler(model_name=settings.llm_model)
analytics_recorder = app.state.analytics_recorder ws_ctx = WebSocketContext(
conversation_tracker = app.state.conversation_tracker graph_ctx=app.state.graph_ctx,
pool = app.state.pool session_manager=app.state.session_manager,
callback_handler=callback_handler,
interrupt_manager=app.state.interrupt_manager,
analytics_recorder=app.state.analytics_recorder,
conversation_tracker=app.state.conversation_tracker,
pool=app.state.pool,
)
try: try:
while True: while True:
raw_data = await ws.receive_text() raw_data = await ws.receive_text()
await dispatch_message( await dispatch_message(ws, ws_ctx, raw_data)
ws, graph, session_manager, callback_handler, raw_data,
interrupt_manager=interrupt_manager,
analytics_recorder=analytics_recorder,
conversation_tracker=conversation_tracker,
pool=pool,
)
except WebSocketDisconnect: except WebSocketDisconnect:
logger.info("WebSocket client disconnected") logger.info("WebSocket client disconnected")

View File

@@ -15,15 +15,21 @@ import re
import uuid import uuid
from typing import Literal from typing import Literal
from fastapi import APIRouter, BackgroundTasks, HTTPException from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
from pydantic import BaseModel, field_validator from pydantic import BaseModel, field_validator
from app.auth import require_admin_api_key
from app.openapi.generator import generate_agent_yaml, generate_tool_code
from app.openapi.importer import ImportOrchestrator from app.openapi.importer import ImportOrchestrator
from app.openapi.models import ClassificationResult, ImportJob from app.openapi.models import ClassificationResult, ImportJob
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/openapi", tags=["openapi"]) router = APIRouter(
prefix="/api/openapi",
tags=["openapi"],
dependencies=[Depends(require_admin_api_key)],
)
# In-memory store: job_id -> job dict, guarded by async lock # In-memory store: job_id -> job dict, guarded by async lock
_job_store: dict[str, dict] = {} _job_store: dict[str, dict] = {}
@@ -235,11 +241,42 @@ async def update_classification(
@router.post("/jobs/{job_id}/approve") @router.post("/jobs/{job_id}/approve")
async def approve_job(job_id: str) -> dict: async def approve_job(job_id: str) -> dict:
"""Approve a job's classifications and trigger tool generation.""" """Approve a job's classifications and trigger tool generation.
Generates Python tool code for each classified endpoint and
produces an agent YAML configuration snippet.
"""
job = _job_store.get(job_id) job = _job_store.get(job_id)
if job is None: if job is None:
raise HTTPException(status_code=404, detail=f"Job '{job_id}' not found") raise HTTPException(status_code=404, detail=f"Job '{job_id}' not found")
updated_job = {**job, "status": "approved"} classifications: list[ClassificationResult] = job.get("classifications", [])
if not classifications:
raise HTTPException(
status_code=400,
detail="No classifications to approve. Import must complete first.",
)
base_url = job["spec_url"].rsplit("/", 1)[0]
generated_tools = []
for clf in classifications:
tool = generate_tool_code(clf, base_url)
generated_tools.append({
"function_name": tool.function_name,
"agent_group": clf.agent_group,
"code": tool.code,
})
agent_yaml = generate_agent_yaml(tuple(classifications), base_url)
updated_job = {
**job,
"status": "approved",
"generated_tools": generated_tools,
"agent_yaml": agent_yaml,
}
_job_store[job_id] = updated_job _job_store[job_id] = updated_job
return _job_to_response(updated_job)
response = _job_to_response(updated_job)
response["generated_tools_count"] = len(generated_tools)
return response

View File

@@ -3,16 +3,23 @@
from __future__ import annotations from __future__ import annotations
import re import re
from typing import TYPE_CHECKING, Annotated, Any from typing import TYPE_CHECKING, Annotated
from fastapi import APIRouter, HTTPException, Query, Request from fastapi import APIRouter, Depends, HTTPException, Query, Request
from app.api_utils import envelope
from app.auth import require_admin_api_key
_THREAD_ID_PATTERN = re.compile(r"^[a-zA-Z0-9\-_]{1,128}$") _THREAD_ID_PATTERN = re.compile(r"^[a-zA-Z0-9\-_]{1,128}$")
if TYPE_CHECKING: if TYPE_CHECKING:
from psycopg_pool import AsyncConnectionPool from psycopg_pool import AsyncConnectionPool
router = APIRouter(prefix="/api", tags=["replay"]) router = APIRouter(
prefix="/api",
tags=["replay"],
dependencies=[Depends(require_admin_api_key)],
)
_COUNT_CONVERSATIONS_SQL = """ _COUNT_CONVERSATIONS_SQL = """
SELECT COUNT(*) FROM conversations SELECT COUNT(*) FROM conversations
@@ -38,10 +45,6 @@ async def get_pool(request: Request) -> AsyncConnectionPool:
return request.app.state.pool return request.app.state.pool
def _envelope(data: Any, *, success: bool = True, error: str | None = None) -> dict:
return {"success": success, "data": data, "error": error}
@router.get("/conversations") @router.get("/conversations")
async def list_conversations( async def list_conversations(
request: Request, request: Request,
@@ -62,7 +65,7 @@ async def list_conversations(
) )
rows = await cursor.fetchall() rows = await cursor.fetchall()
return _envelope({ return envelope({
"conversations": [dict(row) for row in rows], "conversations": [dict(row) for row in rows],
"total": total, "total": total,
"page": page, "page": page,
@@ -119,4 +122,4 @@ async def get_replay(
for s in page_steps for s in page_steps
], ],
} }
return _envelope(data) return envelope(data)

View File

@@ -1,9 +1,18 @@
"""Session TTL management with sliding window and interrupt extension.""" """Session TTL management with sliding window and interrupt extension.
Provides both in-memory (SessionManager) and PostgreSQL-backed
(PgSessionManager) implementations behind a common Protocol.
"""
from __future__ import annotations from __future__ import annotations
import time import time
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Protocol
if TYPE_CHECKING:
from psycopg_pool import AsyncConnectionPool
@dataclass(frozen=True) @dataclass(frozen=True)
@@ -13,8 +22,19 @@ class SessionState:
has_pending_interrupt: bool has_pending_interrupt: bool
class SessionManagerProtocol(Protocol):
    """Protocol for session TTL management.

    Satisfied structurally by the in-memory SessionManager and by
    PgSessionManager, so callers can accept either implementation.
    """

    # Record activity for a thread; resets the sliding-window TTL.
    def touch(self, thread_id: str) -> SessionState: ...

    # True when the session's TTL has lapsed (pending interrupts suspend it).
    def is_expired(self, thread_id: str) -> bool: ...

    # Flag a pending interrupt, suspending expiration until resolved.
    def extend_for_interrupt(self, thread_id: str) -> SessionState: ...

    # Clear the pending-interrupt flag, resuming normal TTL behavior.
    def resolve_interrupt(self, thread_id: str) -> SessionState: ...

    # Current state for the thread, or None if it has never been seen.
    def get_state(self, thread_id: str) -> SessionState | None: ...

    # Forget the session entirely.
    def remove(self, thread_id: str) -> None: ...
class SessionManager: class SessionManager:
"""Manages session TTL with sliding window and interrupt extensions. """In-memory session manager for single-worker development.
- Each message resets the TTL (sliding window). - Each message resets the TTL (sliding window).
- A pending interrupt suspends expiration until resolved. - A pending interrupt suspends expiration until resolved.
@@ -40,10 +60,8 @@ class SessionManager:
state = self._sessions.get(thread_id) state = self._sessions.get(thread_id)
if state is None: if state is None:
return True return True
if state.has_pending_interrupt: if state.has_pending_interrupt:
return False return False
elapsed = time.time() - state.last_activity elapsed = time.time() - state.last_activity
return elapsed > self._session_ttl return elapsed > self._session_ttl
@@ -52,7 +70,6 @@ class SessionManager:
existing = self._sessions.get(thread_id) existing = self._sessions.get(thread_id)
if existing is None: if existing is None:
return self.touch(thread_id) return self.touch(thread_id)
new_state = SessionState( new_state = SessionState(
thread_id=thread_id, thread_id=thread_id,
last_activity=existing.last_activity, last_activity=existing.last_activity,
@@ -76,3 +93,120 @@ class SessionManager:
def remove(self, thread_id: str) -> None: def remove(self, thread_id: str) -> None:
self._sessions = {k: v for k, v in self._sessions.items() if k != thread_id} self._sessions = {k: v for k, v in self._sessions.items() if k != thread_id}
# Alias for explicit naming
InMemorySessionManager = SessionManager


class PgSessionManager:
    """PostgreSQL-backed session manager for multi-worker production.

    Persists activity in the ``sessions`` table so any worker can observe
    TTL state.

    NOTE(review): the synchronous protocol methods bridge to async code via
    ``asyncio.get_event_loop().run_until_complete``. That call raises
    ``RuntimeError`` when an event loop is already running (e.g. inside
    FastAPI's async handlers), and ``get_event_loop`` outside a running
    loop is deprecated since Python 3.10 — confirm these wrappers are only
    invoked from synchronous, loop-free contexts, or migrate
    SessionManagerProtocol to async methods.
    """

    def __init__(
        self,
        pool: AsyncConnectionPool,
        session_ttl_seconds: int = 1800,
    ) -> None:
        self._pool = pool
        self._session_ttl = session_ttl_seconds

    def touch(self, thread_id: str) -> SessionState:
        """Record activity for *thread_id*, resetting the TTL (sync wrapper)."""
        import asyncio

        return asyncio.get_event_loop().run_until_complete(self._touch(thread_id))

    async def _touch(self, thread_id: str) -> SessionState:
        now = datetime.now(timezone.utc)
        async with self._pool.connection() as conn:
            cursor = await conn.execute(
                """
                INSERT INTO sessions (thread_id, last_activity, has_pending_interrupt)
                VALUES (%(tid)s, %(now)s, FALSE)
                ON CONFLICT (thread_id) DO UPDATE
                    SET last_activity = %(now)s
                RETURNING has_pending_interrupt
                """,
                {"tid": thread_id, "now": now},
            )
            row = await cursor.fetchone()
        # Bug fix: previously returned has_pending_interrupt=False
        # unconditionally, misreporting sessions whose persisted row still
        # has a pending interrupt. RETURNING yields the actual stored flag
        # (the upsert only updates last_activity on conflict).
        has_interrupt = bool(row["has_pending_interrupt"]) if row is not None else False
        return SessionState(
            thread_id=thread_id,
            last_activity=now.timestamp(),
            has_pending_interrupt=has_interrupt,
        )

    def is_expired(self, thread_id: str) -> bool:
        """True when the TTL has lapsed; pending interrupts suspend expiry.

        Unknown threads count as expired.
        """
        state = self.get_state(thread_id)
        if state is None:
            return True
        if state.has_pending_interrupt:
            return False
        elapsed = time.time() - state.last_activity
        return elapsed > self._session_ttl

    def extend_for_interrupt(self, thread_id: str) -> SessionState:
        """Flag a pending interrupt, suspending expiration (sync wrapper)."""
        import asyncio

        return asyncio.get_event_loop().run_until_complete(
            self._set_interrupt(thread_id, True)
        )

    def resolve_interrupt(self, thread_id: str) -> SessionState:
        """Clear the pending-interrupt flag, resuming TTL (sync wrapper)."""
        import asyncio

        return asyncio.get_event_loop().run_until_complete(
            self._set_interrupt(thread_id, False)
        )

    async def _set_interrupt(
        self, thread_id: str, has_interrupt: bool
    ) -> SessionState:
        now = datetime.now(timezone.utc)
        async with self._pool.connection() as conn:
            await conn.execute(
                """
                INSERT INTO sessions (thread_id, last_activity, has_pending_interrupt)
                VALUES (%(tid)s, %(now)s, %(interrupt)s)
                ON CONFLICT (thread_id) DO UPDATE
                    SET last_activity = %(now)s,
                        has_pending_interrupt = %(interrupt)s
                """,
                {"tid": thread_id, "now": now, "interrupt": has_interrupt},
            )
        return SessionState(
            thread_id=thread_id,
            last_activity=now.timestamp(),
            has_pending_interrupt=has_interrupt,
        )

    def get_state(self, thread_id: str) -> SessionState | None:
        """Fetch persisted state, or None if unknown (sync wrapper)."""
        import asyncio

        return asyncio.get_event_loop().run_until_complete(
            self._get_state(thread_id)
        )

    async def _get_state(self, thread_id: str) -> SessionState | None:
        async with self._pool.connection() as conn:
            cursor = await conn.execute(
                "SELECT last_activity, has_pending_interrupt FROM sessions WHERE thread_id = %(tid)s",
                {"tid": thread_id},
            )
            row = await cursor.fetchone()
            if row is None:
                return None
            # assumes dict-like rows (e.g. psycopg dict_row) and a
            # TIMESTAMPTZ last_activity column — TODO confirm the pool's
            # row factory configuration.
            return SessionState(
                thread_id=thread_id,
                last_activity=row["last_activity"].timestamp(),
                has_pending_interrupt=row["has_pending_interrupt"],
            )

    def remove(self, thread_id: str) -> None:
        """Delete the session row (sync wrapper)."""
        import asyncio

        asyncio.get_event_loop().run_until_complete(self._remove(thread_id))

    async def _remove(self, thread_id: str) -> None:
        async with self._pool.connection() as conn:
            await conn.execute(
                "DELETE FROM sessions WHERE thread_id = %(tid)s",
                {"tid": thread_id},
            )

30
backend/app/ws_context.py Normal file
View File

@@ -0,0 +1,30 @@
"""WebSocketContext -- bundles all dependencies needed by dispatch_message."""
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from app.analytics.event_recorder import AnalyticsRecorder
from app.callbacks import TokenUsageCallbackHandler
from app.conversation_tracker import ConversationTrackerProtocol
from app.graph_context import GraphContext
from app.interrupt_manager import InterruptManager
from app.session_manager import SessionManager
@dataclass(frozen=True)
class WebSocketContext:
"""All dependencies required for WebSocket message processing.
Replaces the previous 9-parameter function signature in dispatch_message.
"""
graph_ctx: GraphContext
session_manager: SessionManager
callback_handler: TokenUsageCallbackHandler
interrupt_manager: InterruptManager | None = None
analytics_recorder: AnalyticsRecorder | None = None
conversation_tracker: ConversationTrackerProtocol | None = None
pool: Any = None

View File

@@ -7,22 +7,19 @@ import logging
import re import re
import time import time
from collections import defaultdict from collections import defaultdict
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING
from langchain_core.messages import HumanMessage from langchain_core.messages import HumanMessage
from langgraph.types import Command from langgraph.types import Command
from app.graph import classify_intent
if TYPE_CHECKING: if TYPE_CHECKING:
from fastapi import WebSocket from fastapi import WebSocket
from langgraph.graph.state import CompiledStateGraph
from app.analytics.event_recorder import AnalyticsRecorder
from app.callbacks import TokenUsageCallbackHandler from app.callbacks import TokenUsageCallbackHandler
from app.conversation_tracker import ConversationTrackerProtocol from app.graph_context import GraphContext
from app.interrupt_manager import InterruptManager from app.interrupt_manager import InterruptManager
from app.session_manager import SessionManager from app.session_manager import SessionManager
from app.ws_context import WebSocketContext
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -46,7 +43,7 @@ def _evict_stale_threads(cutoff: float) -> None:
async def handle_user_message( async def handle_user_message(
ws: WebSocket, ws: WebSocket,
graph: CompiledStateGraph, ctx: GraphContext,
session_manager: SessionManager, session_manager: SessionManager,
callback_handler: TokenUsageCallbackHandler, callback_handler: TokenUsageCallbackHandler,
thread_id: str, thread_id: str,
@@ -54,8 +51,6 @@ async def handle_user_message(
interrupt_manager: InterruptManager | None = None, interrupt_manager: InterruptManager | None = None,
) -> None: ) -> None:
"""Process a user message through the graph and stream results back.""" """Process a user message through the graph and stream results back."""
# Touch first so new sessions are created before expiry check.
# For existing sessions, touch resets the sliding window.
existing = session_manager.get_state(thread_id) existing = session_manager.get_state(thread_id)
if existing is not None and session_manager.is_expired(thread_id): if existing is not None and session_manager.is_expired(thread_id):
msg = "Session expired. Please start a new conversation." msg = "Session expired. Please start a new conversation."
@@ -64,8 +59,7 @@ async def handle_user_message(
session_manager.touch(thread_id) session_manager.touch(thread_id)
# Run intent classification if available (for logging/future multi-intent) classification = await ctx.classify_intent(content)
classification = await classify_intent(graph, content)
if classification is not None: if classification is not None:
logger.info( logger.info(
"Intent classification for thread %s: ambiguous=%s, intents=%s", "Intent classification for thread %s: ambiguous=%s, intents=%s",
@@ -74,7 +68,6 @@ async def handle_user_message(
[i.agent_name for i in classification.intents], [i.agent_name for i in classification.intents],
) )
# If ambiguous, send clarification and return
if classification.is_ambiguous and classification.clarification_question: if classification.is_ambiguous and classification.clarification_question:
await _send_json( await _send_json(
ws, ws,
@@ -89,7 +82,6 @@ async def handle_user_message(
config = {"configurable": {"thread_id": thread_id}, "callbacks": [callback_handler]} config = {"configurable": {"thread_id": thread_id}, "callbacks": [callback_handler]}
# If multi-intent detected, add routing hint to the message
if classification and len(classification.intents) > 1: if classification and len(classification.intents) > 1:
agent_names = [i.agent_name for i in classification.intents] agent_names = [i.agent_name for i in classification.intents]
hint = ( hint = (
@@ -101,7 +93,7 @@ async def handle_user_message(
input_msg = {"messages": [HumanMessage(content=content)]} input_msg = {"messages": [HumanMessage(content=content)]}
try: try:
async for chunk in graph.astream(input_msg, config=config, stream_mode="messages"): async for chunk in ctx.graph.astream(input_msg, config=config, stream_mode="messages"):
msg_chunk, metadata = chunk msg_chunk, metadata = chunk
node = metadata.get("langgraph_node", "") node = metadata.get("langgraph_node", "")
@@ -126,12 +118,11 @@ async def handle_user_message(
}, },
) )
state = await graph.aget_state(config) state = await ctx.graph.aget_state(config)
if _has_interrupt(state): if _has_interrupt(state):
interrupt_data = _extract_interrupt(state) interrupt_data = _extract_interrupt(state)
session_manager.extend_for_interrupt(thread_id) session_manager.extend_for_interrupt(thread_id)
# Register interrupt with TTL tracking
if interrupt_manager is not None: if interrupt_manager is not None:
interrupt_manager.register( interrupt_manager.register(
thread_id=thread_id, thread_id=thread_id,
@@ -158,7 +149,7 @@ async def handle_user_message(
async def handle_interrupt_response( async def handle_interrupt_response(
ws: WebSocket, ws: WebSocket,
graph: CompiledStateGraph, ctx: GraphContext,
session_manager: SessionManager, session_manager: SessionManager,
callback_handler: TokenUsageCallbackHandler, callback_handler: TokenUsageCallbackHandler,
thread_id: str, thread_id: str,
@@ -166,7 +157,6 @@ async def handle_interrupt_response(
interrupt_manager: InterruptManager | None = None, interrupt_manager: InterruptManager | None = None,
) -> None: ) -> None:
"""Resume graph execution after interrupt approval/rejection.""" """Resume graph execution after interrupt approval/rejection."""
# Check interrupt TTL before resuming
if interrupt_manager is not None: if interrupt_manager is not None:
status = interrupt_manager.check_status(thread_id) status = interrupt_manager.check_status(thread_id)
if status is not None and status.is_expired: if status is not None and status.is_expired:
@@ -184,7 +174,7 @@ async def handle_interrupt_response(
config = {"configurable": {"thread_id": thread_id}, "callbacks": [callback_handler]} config = {"configurable": {"thread_id": thread_id}, "callbacks": [callback_handler]}
try: try:
async for chunk in graph.astream( async for chunk in ctx.graph.astream(
Command(resume=approved), Command(resume=approved),
config=config, config=config,
stream_mode="messages", stream_mode="messages",
@@ -212,14 +202,8 @@ async def handle_interrupt_response(
async def dispatch_message( async def dispatch_message(
ws: WebSocket, ws: WebSocket,
graph: CompiledStateGraph, ctx: WebSocketContext,
session_manager: SessionManager,
callback_handler: TokenUsageCallbackHandler,
raw_data: str, raw_data: str,
interrupt_manager: InterruptManager | None = None,
analytics_recorder: AnalyticsRecorder | None = None,
conversation_tracker: ConversationTrackerProtocol | None = None,
pool: Any = None,
) -> None: ) -> None:
"""Parse and route an incoming WebSocket message.""" """Parse and route an incoming WebSocket message."""
if len(raw_data) > MAX_MESSAGE_SIZE: if len(raw_data) > MAX_MESSAGE_SIZE:
@@ -268,14 +252,15 @@ async def dispatch_message(
_thread_timestamps[thread_id] = [*recent, now] _thread_timestamps[thread_id] = [*recent, now]
await handle_user_message( await handle_user_message(
ws, graph, session_manager, callback_handler, thread_id, content, ws, ctx.graph_ctx, ctx.session_manager, ctx.callback_handler,
interrupt_manager=interrupt_manager, thread_id, content,
interrupt_manager=ctx.interrupt_manager,
) )
await _fire_and_forget_tracking( await _fire_and_forget_tracking(
thread_id=thread_id, thread_id=thread_id,
pool=pool, pool=ctx.pool,
analytics_recorder=analytics_recorder, analytics_recorder=ctx.analytics_recorder,
conversation_tracker=conversation_tracker, conversation_tracker=ctx.conversation_tracker,
agent_name=None, agent_name=None,
tokens=0, tokens=0,
cost=0.0, cost=0.0,
@@ -284,8 +269,9 @@ async def dispatch_message(
elif msg_type == "interrupt_response": elif msg_type == "interrupt_response":
approved = data.get("approved", False) approved = data.get("approved", False)
await handle_interrupt_response( await handle_interrupt_response(
ws, graph, session_manager, callback_handler, thread_id, approved, ws, ctx.graph_ctx, ctx.session_manager, ctx.callback_handler,
interrupt_manager=interrupt_manager, thread_id, approved,
interrupt_manager=ctx.interrupt_manager,
) )
else: else:
@@ -294,9 +280,9 @@ async def dispatch_message(
async def _fire_and_forget_tracking( async def _fire_and_forget_tracking(
thread_id: str, thread_id: str,
pool: Any, pool: object,
analytics_recorder: Any | None, analytics_recorder: object | None,
conversation_tracker: Any | None, conversation_tracker: object | None,
agent_name: str | None, agent_name: str | None,
tokens: int, tokens: int,
cost: float, cost: float,

View File

@@ -13,10 +13,12 @@ from httpx import ASGITransport, AsyncClient
from app.analytics.api import router as analytics_router from app.analytics.api import router as analytics_router
from app.callbacks import TokenUsageCallbackHandler from app.callbacks import TokenUsageCallbackHandler
from app.graph_context import GraphContext
from app.interrupt_manager import InterruptManager from app.interrupt_manager import InterruptManager
from app.openapi.review_api import _job_store, router as openapi_router from app.openapi.review_api import _job_store, router as openapi_router
from app.replay.api import router as replay_router from app.replay.api import router as replay_router
from app.session_manager import SessionManager from app.session_manager import SessionManager
from app.ws_context import WebSocketContext
from app.ws_handler import dispatch_message from app.ws_handler import dispatch_message
@@ -74,8 +76,6 @@ def make_graph(
) -> MagicMock: ) -> MagicMock:
"""Build a mock LangGraph CompiledStateGraph.""" """Build a mock LangGraph CompiledStateGraph."""
g = MagicMock() g = MagicMock()
g.intent_classifier = None
g.agent_registry = None
if state is None: if state is None:
state = make_state() state = make_state()
@@ -93,6 +93,14 @@ def make_graph(
return g return g
def make_graph_ctx(graph: MagicMock | None = None) -> GraphContext:
"""Build a GraphContext wrapping a mock graph."""
g = graph or make_graph()
registry = MagicMock()
registry.list_agents = MagicMock(return_value=())
return GraphContext(graph=g, registry=registry, intent_classifier=None)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Fake database pool # Fake database pool
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -148,6 +156,7 @@ def create_e2e_app(
) -> FastAPI: ) -> FastAPI:
"""Create a FastAPI app wired with mocked dependencies for E2E testing.""" """Create a FastAPI app wired with mocked dependencies for E2E testing."""
g = graph or make_graph() g = graph or make_graph()
graph_ctx = make_graph_ctx(g)
p = pool or FakePool() p = pool or FakePool()
sm = SessionManager(session_ttl_seconds=session_ttl) sm = SessionManager(session_ttl_seconds=session_ttl)
im = InterruptManager(ttl_seconds=interrupt_ttl) im = InterruptManager(ttl_seconds=interrupt_ttl)
@@ -157,7 +166,7 @@ def create_e2e_app(
app.include_router(replay_router) app.include_router(replay_router)
app.include_router(analytics_router) app.include_router(analytics_router)
app.state.graph = g app.state.graph_ctx = graph_ctx
app.state.session_manager = sm app.state.session_manager = sm
app.state.interrupt_manager = im app.state.interrupt_manager = im
app.state.pool = p app.state.pool = p
@@ -175,17 +184,16 @@ def create_e2e_app(
try: try:
while True: while True:
raw_data = await ws.receive_text() raw_data = await ws.receive_text()
await dispatch_message( ws_ctx = WebSocketContext(
ws, graph_ctx=app.state.graph_ctx,
app.state.graph, session_manager=app.state.session_manager,
app.state.session_manager, callback_handler=TokenUsageCallbackHandler(model_name="test-model"),
TokenUsageCallbackHandler(model_name="test-model"),
raw_data,
interrupt_manager=app.state.interrupt_manager, interrupt_manager=app.state.interrupt_manager,
analytics_recorder=app.state.analytics_recorder, analytics_recorder=app.state.analytics_recorder,
conversation_tracker=app.state.conversation_tracker, conversation_tracker=app.state.conversation_tracker,
pool=app.state.pool, pool=app.state.pool,
) )
await dispatch_message(ws, ws_ctx, raw_data)
except WebSocketDisconnect: except WebSocketDisconnect:
pass pass

View File

@@ -20,10 +20,12 @@ import pytest
from app.callbacks import TokenUsageCallbackHandler from app.callbacks import TokenUsageCallbackHandler
from app.escalation import EscalationPayload, NoOpEscalator, WebhookEscalator from app.escalation import EscalationPayload, NoOpEscalator, WebhookEscalator
from app.graph_context import GraphContext
from app.intent import ClassificationResult, IntentTarget, LLMIntentClassifier from app.intent import ClassificationResult, IntentTarget, LLMIntentClassifier
from app.interrupt_manager import InterruptManager from app.interrupt_manager import InterruptManager
from app.registry import AgentConfig, AgentRegistry from app.registry import AgentConfig, AgentRegistry
from app.session_manager import SessionManager from app.session_manager import SessionManager
from app.ws_context import WebSocketContext
from app.ws_handler import dispatch_message from app.ws_handler import dispatch_message
TEMPLATES_DIR = Path(__file__).parent.parent.parent / "templates" TEMPLATES_DIR = Path(__file__).parent.parent.parent / "templates"
@@ -128,10 +130,8 @@ class TestCheckpoint1OrderQueryRouting:
mock_classifier.classify = AsyncMock(return_value=ClassificationResult( mock_classifier.classify = AsyncMock(return_value=ClassificationResult(
intents=(IntentTarget(agent_name="order_lookup", confidence=0.95, reasoning="query"),), intents=(IntentTarget(agent_name="order_lookup", confidence=0.95, reasoning="query"),),
)) ))
graph.intent_classifier = mock_classifier
mock_registry = MagicMock() mock_registry = MagicMock()
mock_registry.list_agents = MagicMock(return_value=()) mock_registry.list_agents = MagicMock(return_value=())
graph.agent_registry = mock_registry
# Graph streams order_lookup response # Graph streams order_lookup response
graph.astream = MagicMock(return_value=AsyncIterHelper([ graph.astream = MagicMock(return_value=AsyncIterHelper([
@@ -140,14 +140,21 @@ class TestCheckpoint1OrderQueryRouting:
])) ]))
graph.aget_state = AsyncMock(return_value=_state()) graph.aget_state = AsyncMock(return_value=_state())
graph_ctx = GraphContext(
graph=graph, registry=mock_registry, intent_classifier=mock_classifier,
)
sm = SessionManager() sm = SessionManager()
sm.touch("t1") sm.touch("t1")
im = InterruptManager() im = InterruptManager()
cb = TokenUsageCallbackHandler() cb = TokenUsageCallbackHandler()
ws = FakeWS() ws = FakeWS()
ws_ctx = WebSocketContext(
graph_ctx=graph_ctx, session_manager=sm,
callback_handler=cb, interrupt_manager=im,
)
raw = json.dumps({"type": "message", "thread_id": "t1", "content": "查询订单 1042"}) raw = json.dumps({"type": "message", "thread_id": "t1", "content": "查询订单 1042"})
await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im) await dispatch_message(ws, ws_ctx, raw)
tool_msgs = [m for m in ws.sent if m["type"] == "tool_call"] tool_msgs = [m for m in ws.sent if m["type"] == "tool_call"]
assert any(m["tool"] == "get_order_status" for m in tool_msgs) assert any(m["tool"] == "get_order_status" for m in tool_msgs)
@@ -201,25 +208,30 @@ class TestCheckpoint2MultiIntentSequential:
IntentTarget(agent_name="discount", confidence=0.85, reasoning="discount"), IntentTarget(agent_name="discount", confidence=0.85, reasoning="discount"),
), ),
)) ))
graph.intent_classifier = mock_classifier
mock_registry = MagicMock() mock_registry = MagicMock()
mock_registry.list_agents = MagicMock(return_value=()) mock_registry.list_agents = MagicMock(return_value=())
graph.agent_registry = mock_registry
graph.astream = MagicMock(return_value=AsyncIterHelper([])) graph.astream = MagicMock(return_value=AsyncIterHelper([]))
graph.aget_state = AsyncMock(return_value=_state()) graph.aget_state = AsyncMock(return_value=_state())
graph_ctx = GraphContext(
graph=graph, registry=mock_registry, intent_classifier=mock_classifier,
)
sm = SessionManager() sm = SessionManager()
sm.touch("t1") sm.touch("t1")
im = InterruptManager() im = InterruptManager()
cb = TokenUsageCallbackHandler() cb = TokenUsageCallbackHandler()
ws = FakeWS() ws = FakeWS()
ws_ctx = WebSocketContext(
graph_ctx=graph_ctx, session_manager=sm,
callback_handler=cb, interrupt_manager=im,
)
raw = json.dumps({ raw = json.dumps({
"type": "message", "type": "message",
"thread_id": "t1", "thread_id": "t1",
"content": "取消订单 1042 并给我一个 10% 折扣", "content": "取消订单 1042 并给我一个 10% 折扣",
}) })
await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im) await dispatch_message(ws, ws_ctx, raw)
# Verify the graph was called with the routing hint in the message # Verify the graph was called with the routing hint in the message
call_args = graph.astream.call_args call_args = graph.astream.call_args
@@ -267,21 +279,26 @@ class TestCheckpoint3AmbiguousClarification:
"Could you please provide more details about what you need help with?" "Could you please provide more details about what you need help with?"
), ),
)) ))
graph.intent_classifier = mock_classifier
mock_registry = MagicMock() mock_registry = MagicMock()
mock_registry.list_agents = MagicMock(return_value=()) mock_registry.list_agents = MagicMock(return_value=())
graph.agent_registry = mock_registry
graph.astream = MagicMock(return_value=AsyncIterHelper([])) graph.astream = MagicMock(return_value=AsyncIterHelper([]))
graph.aget_state = AsyncMock(return_value=_state()) graph.aget_state = AsyncMock(return_value=_state())
graph_ctx = GraphContext(
graph=graph, registry=mock_registry, intent_classifier=mock_classifier,
)
sm = SessionManager() sm = SessionManager()
sm.touch("t1") sm.touch("t1")
im = InterruptManager() im = InterruptManager()
cb = TokenUsageCallbackHandler() cb = TokenUsageCallbackHandler()
ws = FakeWS() ws = FakeWS()
ws_ctx = WebSocketContext(
graph_ctx=graph_ctx, session_manager=sm,
callback_handler=cb, interrupt_manager=im,
)
raw = json.dumps({"type": "message", "thread_id": "t1", "content": "嗯..."}) raw = json.dumps({"type": "message", "thread_id": "t1", "content": "嗯..."})
await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im) await dispatch_message(ws, ws_ctx, raw)
clarifications = [m for m in ws.sent if m["type"] == "clarification"] clarifications = [m for m in ws.sent if m["type"] == "clarification"]
assert len(clarifications) == 1 assert len(clarifications) == 1
@@ -303,20 +320,26 @@ class TestCheckpoint4InterruptTTLAutoCancel:
async def test_30min_expired_interrupt_auto_cancels(self) -> None: async def test_30min_expired_interrupt_auto_cancels(self) -> None:
st = _state(interrupt=True, data={"action": "cancel_order", "order_id": "1042"}) st = _state(interrupt=True, data={"action": "cancel_order", "order_id": "1042"})
graph = MagicMock() graph = MagicMock()
graph.intent_classifier = None
graph.agent_registry = None
graph.astream = MagicMock(return_value=AsyncIterHelper([])) graph.astream = MagicMock(return_value=AsyncIterHelper([]))
graph.aget_state = AsyncMock(return_value=st) graph.aget_state = AsyncMock(return_value=st)
mock_registry = MagicMock()
mock_registry.list_agents = MagicMock(return_value=())
graph_ctx = GraphContext(graph=graph, registry=mock_registry, intent_classifier=None)
sm = SessionManager() sm = SessionManager()
sm.touch("t1") sm.touch("t1")
im = InterruptManager(ttl_seconds=1800) # 30 minutes im = InterruptManager(ttl_seconds=1800) # 30 minutes
cb = TokenUsageCallbackHandler() cb = TokenUsageCallbackHandler()
ws = FakeWS() ws = FakeWS()
ws_ctx = WebSocketContext(
graph_ctx=graph_ctx, session_manager=sm,
callback_handler=cb, interrupt_manager=im,
)
# Trigger interrupt # Trigger interrupt
raw = json.dumps({"type": "message", "thread_id": "t1", "content": "Cancel order 1042"}) raw = json.dumps({"type": "message", "thread_id": "t1", "content": "Cancel order 1042"})
await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im) await dispatch_message(ws, ws_ctx, raw)
interrupts = [m for m in ws.sent if m["type"] == "interrupt"] interrupts = [m for m in ws.sent if m["type"] == "interrupt"]
assert len(interrupts) == 1 assert len(interrupts) == 1
@@ -333,7 +356,7 @@ class TestCheckpoint4InterruptTTLAutoCancel:
"thread_id": "t1", "thread_id": "t1",
"approved": True, "approved": True,
}) })
await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im) await dispatch_message(ws, ws_ctx, raw)
# Should get retry prompt, NOT resume the graph # Should get retry prompt, NOT resume the graph
expired_msgs = [m for m in ws.sent if m["type"] == "interrupt_expired"] expired_msgs = [m for m in ws.sent if m["type"] == "interrupt_expired"]

View File

@@ -18,10 +18,12 @@ from unittest.mock import AsyncMock, MagicMock
import pytest import pytest
from app.callbacks import TokenUsageCallbackHandler from app.callbacks import TokenUsageCallbackHandler
from app.graph_context import GraphContext
from app.intent import ClassificationResult, IntentTarget, LLMIntentClassifier from app.intent import ClassificationResult, IntentTarget, LLMIntentClassifier
from app.interrupt_manager import InterruptManager from app.interrupt_manager import InterruptManager
from app.registry import AgentConfig from app.registry import AgentConfig
from app.session_manager import SessionManager from app.session_manager import SessionManager
from app.ws_context import WebSocketContext
from app.ws_handler import dispatch_message from app.ws_handler import dispatch_message
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -103,36 +105,45 @@ def _make_classifier(result: ClassificationResult) -> AsyncMock:
return classifier return classifier
def _make_graph( def _make_graph_and_ctx(
classifier_result: ClassificationResult | None, classifier_result: ClassificationResult | None,
chunks: list, chunks: list,
state=None, state=None,
) -> MagicMock: ) -> tuple[MagicMock, GraphContext]:
"""Build a graph mock with optional intent classifier.""" """Build a graph mock and GraphContext with optional intent classifier."""
graph = MagicMock() graph = MagicMock()
if classifier_result is not None:
graph.intent_classifier = _make_classifier(classifier_result)
mock_registry = MagicMock()
mock_registry.list_agents = MagicMock(return_value=AGENTS)
graph.agent_registry = mock_registry
else:
graph.intent_classifier = None
graph.agent_registry = None
graph.astream = MagicMock(return_value=AsyncIterHelper(list(chunks))) graph.astream = MagicMock(return_value=AsyncIterHelper(list(chunks)))
graph.aget_state = AsyncMock(return_value=state or _state()) graph.aget_state = AsyncMock(return_value=state or _state())
return graph
if classifier_result is not None:
classifier = _make_classifier(classifier_result)
mock_registry = MagicMock()
mock_registry.list_agents = MagicMock(return_value=AGENTS)
graph_ctx = GraphContext(
graph=graph, registry=mock_registry, intent_classifier=classifier,
)
else:
mock_registry = MagicMock()
mock_registry.list_agents = MagicMock(return_value=())
graph_ctx = GraphContext(
graph=graph, registry=mock_registry, intent_classifier=None,
)
return graph, graph_ctx
async def _dispatch(graph, content: str, thread_id: str = "t1") -> list[dict]: async def _dispatch(graph_ctx: GraphContext, content: str, thread_id: str = "t1") -> list[dict]:
sm = SessionManager() sm = SessionManager()
sm.touch(thread_id) sm.touch(thread_id)
im = InterruptManager() im = InterruptManager()
cb = TokenUsageCallbackHandler() cb = TokenUsageCallbackHandler()
ws = FakeWS() ws = FakeWS()
ws_ctx = WebSocketContext(
graph_ctx=graph_ctx, session_manager=sm,
callback_handler=cb, interrupt_manager=im,
)
raw = json.dumps({"type": "message", "thread_id": thread_id, "content": content}) raw = json.dumps({"type": "message", "thread_id": thread_id, "content": content})
await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im) await dispatch_message(ws, ws_ctx, raw)
return ws.sent return ws.sent
@@ -151,12 +162,12 @@ class TestSingleIntentRouting:
agent_name="order_lookup", confidence=0.95, reasoning="status query", agent_name="order_lookup", confidence=0.95, reasoning="status query",
),), ),),
) )
graph = _make_graph(result, [ graph, graph_ctx = _make_graph_and_ctx(result, [
_tool_chunk("get_order_status", {"order_id": "1042"}, "order_lookup"), _tool_chunk("get_order_status", {"order_id": "1042"}, "order_lookup"),
_chunk("Order 1042 is shipped.", "order_lookup"), _chunk("Order 1042 is shipped.", "order_lookup"),
]) ])
msgs = await _dispatch(graph, "What is the status of order 1042?") msgs = await _dispatch(graph_ctx, "What is the status of order 1042?")
tools = [m for m in msgs if m["type"] == "tool_call"] tools = [m for m in msgs if m["type"] == "tool_call"]
assert len(tools) == 1 assert len(tools) == 1
@@ -171,13 +182,13 @@ class TestSingleIntentRouting:
result = ClassificationResult( result = ClassificationResult(
intents=(IntentTarget(agent_name="order_actions", confidence=0.9, reasoning="cancel"),), intents=(IntentTarget(agent_name="order_actions", confidence=0.9, reasoning="cancel"),),
) )
graph = _make_graph( graph, graph_ctx = _make_graph_and_ctx(
result, result,
[_tool_chunk("cancel_order", {"order_id": "1042"}, "order_actions")], [_tool_chunk("cancel_order", {"order_id": "1042"}, "order_actions")],
state=_state(interrupt=True, data={"action": "cancel_order", "order_id": "1042"}), state=_state(interrupt=True, data={"action": "cancel_order", "order_id": "1042"}),
) )
msgs = await _dispatch(graph, "Cancel order 1042") msgs = await _dispatch(graph_ctx, "Cancel order 1042")
tools = [m for m in msgs if m["type"] == "tool_call"] tools = [m for m in msgs if m["type"] == "tool_call"]
assert tools[0]["tool"] == "cancel_order" assert tools[0]["tool"] == "cancel_order"
@@ -191,12 +202,12 @@ class TestSingleIntentRouting:
result = ClassificationResult( result = ClassificationResult(
intents=(IntentTarget(agent_name="discount", confidence=0.88, reasoning="coupon"),), intents=(IntentTarget(agent_name="discount", confidence=0.88, reasoning="coupon"),),
) )
graph = _make_graph(result, [ graph, graph_ctx = _make_graph_and_ctx(result, [
_tool_chunk("generate_coupon", {"discount_percent": 15, "expiry_days": 30}, "discount"), _tool_chunk("generate_coupon", {"discount_percent": 15, "expiry_days": 30}, "discount"),
_chunk("Here is your coupon: SAVE15-ABC12345", "discount"), _chunk("Here is your coupon: SAVE15-ABC12345", "discount"),
]) ])
msgs = await _dispatch(graph, "Give me a 15% coupon") msgs = await _dispatch(graph_ctx, "Give me a 15% coupon")
tools = [m for m in msgs if m["type"] == "tool_call"] tools = [m for m in msgs if m["type"] == "tool_call"]
assert tools[0]["tool"] == "generate_coupon" assert tools[0]["tool"] == "generate_coupon"
@@ -207,11 +218,11 @@ class TestSingleIntentRouting:
result = ClassificationResult( result = ClassificationResult(
intents=(IntentTarget(agent_name="fallback", confidence=0.7, reasoning="general"),), intents=(IntentTarget(agent_name="fallback", confidence=0.7, reasoning="general"),),
) )
graph = _make_graph(result, [ graph, graph_ctx = _make_graph_and_ctx(result, [
_chunk("I can help with order inquiries.", "fallback"), _chunk("I can help with order inquiries.", "fallback"),
]) ])
msgs = await _dispatch(graph, "What can you do?") msgs = await _dispatch(graph_ctx, "What can you do?")
tokens = [m for m in msgs if m["type"] == "token"] tokens = [m for m in msgs if m["type"] == "token"]
assert tokens[0]["agent"] == "fallback" assert tokens[0]["agent"] == "fallback"
@@ -233,7 +244,7 @@ class TestMultiIntentRouting:
IntentTarget(agent_name="discount", confidence=0.85, reasoning="discount"), IntentTarget(agent_name="discount", confidence=0.85, reasoning="discount"),
), ),
) )
graph = _make_graph(result, [ graph, graph_ctx = _make_graph_and_ctx(result, [
_tool_chunk("cancel_order", {"order_id": "1042"}, "order_actions"), _tool_chunk("cancel_order", {"order_id": "1042"}, "order_actions"),
_tool_chunk("apply_discount", {"order_id": "1042", "discount_percent": 10}, "discount"), _tool_chunk("apply_discount", {"order_id": "1042", "discount_percent": 10}, "discount"),
]) ])
@@ -243,13 +254,17 @@ class TestMultiIntentRouting:
im = InterruptManager() im = InterruptManager()
cb = TokenUsageCallbackHandler() cb = TokenUsageCallbackHandler()
ws = FakeWS() ws = FakeWS()
ws_ctx = WebSocketContext(
graph_ctx=graph_ctx, session_manager=sm,
callback_handler=cb, interrupt_manager=im,
)
raw = json.dumps({ raw = json.dumps({
"type": "message", "type": "message",
"thread_id": "t1", "thread_id": "t1",
"content": "取消订单 1042 并给我一个 10% 折扣", "content": "取消订单 1042 并给我一个 10% 折扣",
}) })
await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im) await dispatch_message(ws, ws_ctx, raw)
# Verify routing hint was injected # Verify routing hint was injected
call_args = graph.astream.call_args[0][0] call_args = graph.astream.call_args[0][0]
@@ -269,16 +284,20 @@ class TestMultiIntentRouting:
result = ClassificationResult( result = ClassificationResult(
intents=(IntentTarget(agent_name="order_lookup", confidence=0.95, reasoning="query"),), intents=(IntentTarget(agent_name="order_lookup", confidence=0.95, reasoning="query"),),
) )
graph = _make_graph(result, [_chunk("Order shipped.", "order_lookup")]) graph, graph_ctx = _make_graph_and_ctx(result, [_chunk("Order shipped.", "order_lookup")])
sm = SessionManager() sm = SessionManager()
sm.touch("t1") sm.touch("t1")
im = InterruptManager() im = InterruptManager()
cb = TokenUsageCallbackHandler() cb = TokenUsageCallbackHandler()
ws = FakeWS() ws = FakeWS()
ws_ctx = WebSocketContext(
graph_ctx=graph_ctx, session_manager=sm,
callback_handler=cb, interrupt_manager=im,
)
raw = json.dumps({"type": "message", "thread_id": "t1", "content": "查询订单 1042"}) raw = json.dumps({"type": "message", "thread_id": "t1", "content": "查询订单 1042"})
await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im) await dispatch_message(ws, ws_ctx, raw)
msg_content = graph.astream.call_args[0][0]["messages"][0].content msg_content = graph.astream.call_args[0][0]["messages"][0].content
assert "[System:" not in msg_content assert "[System:" not in msg_content
@@ -299,9 +318,9 @@ class TestAmbiguityRouting:
is_ambiguous=True, is_ambiguous=True,
clarification_question="Could you please clarify what you need?", clarification_question="Could you please clarify what you need?",
) )
graph = _make_graph(result, []) graph, graph_ctx = _make_graph_and_ctx(result, [])
msgs = await _dispatch(graph, "嗯...") msgs = await _dispatch(graph_ctx, "嗯...")
clarifications = [m for m in msgs if m["type"] == "clarification"] clarifications = [m for m in msgs if m["type"] == "clarification"]
assert len(clarifications) == 1 assert len(clarifications) == 1
@@ -339,12 +358,12 @@ class TestNoClassifierFallback:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_no_classifier_routes_via_supervisor(self) -> None: async def test_no_classifier_routes_via_supervisor(self) -> None:
graph = _make_graph( graph, graph_ctx = _make_graph_and_ctx(
classifier_result=None, classifier_result=None,
chunks=[_chunk("Order 1042 is shipped.", "order_lookup")], chunks=[_chunk("Order 1042 is shipped.", "order_lookup")],
) )
msgs = await _dispatch(graph, "What is order 1042 status?") msgs = await _dispatch(graph_ctx, "What is order 1042 status?")
tokens = [m for m in msgs if m["type"] == "token"] tokens = [m for m in msgs if m["type"] == "token"]
assert len(tokens) == 1 assert len(tokens) == 1

View File

@@ -15,8 +15,10 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
from app.callbacks import TokenUsageCallbackHandler from app.callbacks import TokenUsageCallbackHandler
from app.graph_context import GraphContext
from app.interrupt_manager import InterruptManager from app.interrupt_manager import InterruptManager
from app.session_manager import SessionManager from app.session_manager import SessionManager
from app.ws_context import WebSocketContext
from app.ws_handler import dispatch_message from app.ws_handler import dispatch_message
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -81,8 +83,6 @@ def _graph(
resume_chunks: list | None = None, resume_chunks: list | None = None,
) -> MagicMock: ) -> MagicMock:
g = MagicMock() g = MagicMock()
g.intent_classifier = None
g.agent_registry = None
if st is None: if st is None:
st = _state() st = _state()
@@ -100,6 +100,13 @@ def _graph(
return g return g
def _make_graph_ctx(graph: MagicMock | None = None) -> GraphContext:
g = graph or _graph()
registry = MagicMock()
registry.list_agents = MagicMock(return_value=())
return GraphContext(graph=g, registry=registry, intent_classifier=None)
def _setup( def _setup(
graph=None, graph=None,
session_ttl: int = 1800, session_ttl: int = 1800,
@@ -109,23 +116,28 @@ def _setup(
): ):
"""Create test dependencies. Pre-touches session by default.""" """Create test dependencies. Pre-touches session by default."""
g = graph or _graph() g = graph or _graph()
graph_ctx = _make_graph_ctx(g)
sm = SessionManager(session_ttl_seconds=session_ttl) sm = SessionManager(session_ttl_seconds=session_ttl)
im = InterruptManager(ttl_seconds=interrupt_ttl) im = InterruptManager(ttl_seconds=interrupt_ttl)
cb = TokenUsageCallbackHandler() cb = TokenUsageCallbackHandler()
ws = FakeWS() ws = FakeWS()
ws_ctx = WebSocketContext(
graph_ctx=graph_ctx, session_manager=sm,
callback_handler=cb, interrupt_manager=im,
)
if touch: if touch:
sm.touch(thread_id) sm.touch(thread_id)
return g, sm, im, cb, ws return g, sm, im, cb, ws, ws_ctx
async def _send(ws, g, sm, im, cb, *, thread_id="t1", content="hello", msg_type="message"): async def _send(ws, ws_ctx, *, thread_id="t1", content="hello", msg_type="message"):
raw = json.dumps({"type": msg_type, "thread_id": thread_id, "content": content}) raw = json.dumps({"type": msg_type, "thread_id": thread_id, "content": content})
await dispatch_message(ws, g, sm, cb, raw, interrupt_manager=im) await dispatch_message(ws, ws_ctx, raw)
async def _respond(ws, g, sm, im, cb, *, thread_id="t1", approved=True): async def _respond(ws, ws_ctx, *, thread_id="t1", approved=True):
raw = json.dumps({"type": "interrupt_response", "thread_id": thread_id, "approved": approved}) raw = json.dumps({"type": "interrupt_response", "thread_id": thread_id, "approved": approved})
await dispatch_message(ws, g, sm, cb, raw, interrupt_manager=im) await dispatch_message(ws, ws_ctx, raw)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -136,10 +148,10 @@ async def _respond(ws, g, sm, im, cb, *, thread_id="t1", approved=True):
class TestWebSocketHappyPath: class TestWebSocketHappyPath:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_send_message_receives_tokens_and_complete(self) -> None: async def test_send_message_receives_tokens_and_complete(self) -> None:
g, sm, im, cb, ws = _setup( g, sm, im, cb, ws, ws_ctx = _setup(
graph=_graph(chunks=[_chunk("Order 1042 is "), _chunk("shipped.")]) graph=_graph(chunks=[_chunk("Order 1042 is "), _chunk("shipped.")])
) )
await _send(ws, g, sm, im, cb, content="What is the status of order 1042?") await _send(ws, ws_ctx, content="What is the status of order 1042?")
tokens = [m for m in ws.sent if m["type"] == "token"] tokens = [m for m in ws.sent if m["type"] == "token"]
assert len(tokens) == 2 assert len(tokens) == 2
@@ -153,13 +165,13 @@ class TestWebSocketHappyPath:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_tool_call_streamed(self) -> None: async def test_tool_call_streamed(self) -> None:
g, sm, im, cb, ws = _setup( g, sm, im, cb, ws, ws_ctx = _setup(
graph=_graph(chunks=[ graph=_graph(chunks=[
_tool_chunk("get_order_status", {"order_id": "1042"}), _tool_chunk("get_order_status", {"order_id": "1042"}),
_chunk("Order shipped."), _chunk("Order shipped."),
]) ])
) )
await _send(ws, g, sm, im, cb, content="Check order 1042") await _send(ws, ws_ctx, content="Check order 1042")
tools = [m for m in ws.sent if m["type"] == "tool_call"] tools = [m for m in ws.sent if m["type"] == "tool_call"]
assert len(tools) == 1 assert len(tools) == 1
@@ -168,9 +180,9 @@ class TestWebSocketHappyPath:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_multiple_messages_same_session(self) -> None: async def test_multiple_messages_same_session(self) -> None:
g, sm, im, cb, ws = _setup() g, sm, im, cb, ws, ws_ctx = _setup()
for i in range(3): for i in range(3):
await _send(ws, g, sm, im, cb, content=f"msg {i}") await _send(ws, ws_ctx, content=f"msg {i}")
completes = [m for m in ws.sent if m["type"] == "message_complete"] completes = [m for m in ws.sent if m["type"] == "message_complete"]
assert len(completes) == 3 assert len(completes) == 3
@@ -183,10 +195,10 @@ class TestWebSocketInterruptApproval:
st_int = _state(interrupt=True, data={"action": "cancel_order", "order_id": "1042"}) st_int = _state(interrupt=True, data={"action": "cancel_order", "order_id": "1042"})
resume = [_chunk("Order 1042 cancelled.", "order_actions")] resume = [_chunk("Order 1042 cancelled.", "order_actions")]
g = _graph(chunks=[], st=st_int, resume_chunks=resume) g = _graph(chunks=[], st=st_int, resume_chunks=resume)
g_, sm, im, cb, ws = _setup(graph=g) g_, sm, im, cb, ws, ws_ctx = _setup(graph=g)
# Send message -> triggers interrupt # Send message -> triggers interrupt
await _send(ws, g_, sm, im, cb, content="Cancel order 1042") await _send(ws, ws_ctx, content="Cancel order 1042")
interrupts = [m for m in ws.sent if m["type"] == "interrupt"] interrupts = [m for m in ws.sent if m["type"] == "interrupt"]
assert len(interrupts) == 1 assert len(interrupts) == 1
@@ -196,7 +208,7 @@ class TestWebSocketInterruptApproval:
# Approve # Approve
ws.sent.clear() ws.sent.clear()
await _respond(ws, g_, sm, im, cb, approved=True) await _respond(ws, ws_ctx, approved=True)
tokens = [m for m in ws.sent if m["type"] == "token"] tokens = [m for m in ws.sent if m["type"] == "token"]
assert len(tokens) == 1 assert len(tokens) == 1
@@ -211,12 +223,12 @@ class TestWebSocketInterruptApproval:
st_int = _state(interrupt=True) st_int = _state(interrupt=True)
resume = [_chunk("Order remains active.", "order_actions")] resume = [_chunk("Order remains active.", "order_actions")]
g = _graph(chunks=[], st=st_int, resume_chunks=resume) g = _graph(chunks=[], st=st_int, resume_chunks=resume)
g_, sm, im, cb, ws = _setup(graph=g) g_, sm, im, cb, ws, ws_ctx = _setup(graph=g)
await _send(ws, g_, sm, im, cb, content="Cancel order 1042") await _send(ws, ws_ctx, content="Cancel order 1042")
ws.sent.clear() ws.sent.clear()
await _respond(ws, g_, sm, im, cb, approved=False) await _respond(ws, ws_ctx, approved=False)
tokens = [m for m in ws.sent if m["type"] == "token"] tokens = [m for m in ws.sent if m["type"] == "token"]
assert "remains active" in tokens[0]["content"] assert "remains active" in tokens[0]["content"]
@@ -226,28 +238,28 @@ class TestWebSocketInterruptApproval:
class TestWebSocketSessionTTL: class TestWebSocketSessionTTL:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_expired_session_returns_error(self) -> None: async def test_expired_session_returns_error(self) -> None:
g, sm, im, cb, ws = _setup(session_ttl=0) g, sm, im, cb, ws, ws_ctx = _setup(session_ttl=0)
# Session was touched in _setup, but TTL is 0 so it's already expired # Session was touched in _setup, but TTL is 0 so it's already expired
await _send(ws, g, sm, im, cb, content="hello") await _send(ws, ws_ctx, content="hello")
assert ws.sent[0]["type"] == "error" assert ws.sent[0]["type"] == "error"
assert "expired" in ws.sent[0]["message"].lower() assert "expired" in ws.sent[0]["message"].lower()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_new_session_not_expired(self) -> None: async def test_new_session_not_expired(self) -> None:
g, sm, im, cb, ws = _setup(session_ttl=3600) g, sm, im, cb, ws, ws_ctx = _setup(session_ttl=3600)
await _send(ws, g, sm, im, cb, content="hello") await _send(ws, ws_ctx, content="hello")
completes = [m for m in ws.sent if m["type"] == "message_complete"] completes = [m for m in ws.sent if m["type"] == "message_complete"]
assert len(completes) == 1 assert len(completes) == 1
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_sliding_window_resets_on_message(self) -> None: async def test_sliding_window_resets_on_message(self) -> None:
g, sm, im, cb, ws = _setup(session_ttl=3600) g, sm, im, cb, ws, ws_ctx = _setup(session_ttl=3600)
await _send(ws, g, sm, im, cb, content="hello") await _send(ws, ws_ctx, content="hello")
first_activity = sm.get_state("t1").last_activity first_activity = sm.get_state("t1").last_activity
time.sleep(0.01) time.sleep(0.01)
await _send(ws, g, sm, im, cb, content="hello again") await _send(ws, ws_ctx, content="hello again")
second_activity = sm.get_state("t1").last_activity second_activity = sm.get_state("t1").last_activity
assert second_activity > first_activity assert second_activity > first_activity
@@ -256,9 +268,9 @@ class TestWebSocketSessionTTL:
async def test_interrupt_extends_session_ttl(self) -> None: async def test_interrupt_extends_session_ttl(self) -> None:
st_int = _state(interrupt=True) st_int = _state(interrupt=True)
g = _graph(chunks=[], st=st_int) g = _graph(chunks=[], st=st_int)
g_, sm, im, cb, ws = _setup(graph=g, session_ttl=3600) g_, sm, im, cb, ws, ws_ctx = _setup(graph=g, session_ttl=3600)
await _send(ws, g_, sm, im, cb, content="cancel order") await _send(ws, ws_ctx, content="cancel order")
state = sm.get_state("t1") state = sm.get_state("t1")
assert state is not None assert state is not None
@@ -270,53 +282,53 @@ class TestWebSocketSessionTTL:
class TestWebSocketValidation: class TestWebSocketValidation:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_invalid_json(self) -> None: async def test_invalid_json(self) -> None:
g, sm, im, cb, ws = _setup() g, sm, im, cb, ws, ws_ctx = _setup()
await dispatch_message(ws, g, sm, cb, "not json", interrupt_manager=im) await dispatch_message(ws, ws_ctx, "not json")
assert ws.sent[0]["type"] == "error" assert ws.sent[0]["type"] == "error"
assert "Invalid JSON" in ws.sent[0]["message"] assert "Invalid JSON" in ws.sent[0]["message"]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_missing_thread_id(self) -> None: async def test_missing_thread_id(self) -> None:
g, sm, im, cb, ws = _setup() g, sm, im, cb, ws, ws_ctx = _setup()
raw = json.dumps({"type": "message", "content": "hi"}) raw = json.dumps({"type": "message", "content": "hi"})
await dispatch_message(ws, g, sm, cb, raw, interrupt_manager=im) await dispatch_message(ws, ws_ctx, raw)
assert ws.sent[0]["type"] == "error" assert ws.sent[0]["type"] == "error"
assert "thread_id" in ws.sent[0]["message"] assert "thread_id" in ws.sent[0]["message"]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_invalid_thread_id_format(self) -> None: async def test_invalid_thread_id_format(self) -> None:
g, sm, im, cb, ws = _setup() g, sm, im, cb, ws, ws_ctx = _setup()
raw = json.dumps({"type": "message", "thread_id": "../../etc/passwd", "content": "hi"}) raw = json.dumps({"type": "message", "thread_id": "../../etc/passwd", "content": "hi"})
await dispatch_message(ws, g, sm, cb, raw, interrupt_manager=im) await dispatch_message(ws, ws_ctx, raw)
assert ws.sent[0]["type"] == "error" assert ws.sent[0]["type"] == "error"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_missing_content(self) -> None: async def test_missing_content(self) -> None:
g, sm, im, cb, ws = _setup() g, sm, im, cb, ws, ws_ctx = _setup()
raw = json.dumps({"type": "message", "thread_id": "t1"}) raw = json.dumps({"type": "message", "thread_id": "t1"})
await dispatch_message(ws, g, sm, cb, raw, interrupt_manager=im) await dispatch_message(ws, ws_ctx, raw)
assert ws.sent[0]["type"] == "error" assert ws.sent[0]["type"] == "error"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_unknown_message_type(self) -> None: async def test_unknown_message_type(self) -> None:
g, sm, im, cb, ws = _setup() g, sm, im, cb, ws, ws_ctx = _setup()
raw = json.dumps({"type": "foobar", "thread_id": "t1"}) raw = json.dumps({"type": "foobar", "thread_id": "t1"})
await dispatch_message(ws, g, sm, cb, raw, interrupt_manager=im) await dispatch_message(ws, ws_ctx, raw)
assert ws.sent[0]["type"] == "error" assert ws.sent[0]["type"] == "error"
assert "Unknown" in ws.sent[0]["message"] assert "Unknown" in ws.sent[0]["message"]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_message_too_large(self) -> None: async def test_message_too_large(self) -> None:
g, sm, im, cb, ws = _setup() g, sm, im, cb, ws, ws_ctx = _setup()
await dispatch_message(ws, g, sm, cb, "x" * 40_000, interrupt_manager=im) await dispatch_message(ws, ws_ctx, "x" * 40_000)
assert ws.sent[0]["type"] == "error" assert ws.sent[0]["type"] == "error"
assert "too large" in ws.sent[0]["message"].lower() assert "too large" in ws.sent[0]["message"].lower()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_content_too_long(self) -> None: async def test_content_too_long(self) -> None:
g, sm, im, cb, ws = _setup() g, sm, im, cb, ws, ws_ctx = _setup()
raw = json.dumps({"type": "message", "thread_id": "t1", "content": "x" * 10001}) raw = json.dumps({"type": "message", "thread_id": "t1", "content": "x" * 10001})
await dispatch_message(ws, g, sm, cb, raw, interrupt_manager=im) await dispatch_message(ws, ws_ctx, raw)
assert ws.sent[0]["type"] == "error" assert ws.sent[0]["type"] == "error"
assert "too long" in ws.sent[0]["message"].lower() assert "too long" in ws.sent[0]["message"].lower()
@@ -327,10 +339,10 @@ class TestWebSocketInterruptTTL:
async def test_expired_interrupt_sends_retry_prompt(self) -> None: async def test_expired_interrupt_sends_retry_prompt(self) -> None:
st_int = _state(interrupt=True, data={"action": "cancel_order", "order_id": "1042"}) st_int = _state(interrupt=True, data={"action": "cancel_order", "order_id": "1042"})
g = _graph(chunks=[], st=st_int) g = _graph(chunks=[], st=st_int)
g_, sm, im, cb, ws = _setup(graph=g, interrupt_ttl=5) g_, sm, im, cb, ws, ws_ctx = _setup(graph=g, interrupt_ttl=5)
# Trigger interrupt # Trigger interrupt
await _send(ws, g_, sm, im, cb, content="Cancel order 1042") await _send(ws, ws_ctx, content="Cancel order 1042")
interrupts = [m for m in ws.sent if m["type"] == "interrupt"] interrupts = [m for m in ws.sent if m["type"] == "interrupt"]
assert len(interrupts) == 1 assert len(interrupts) == 1
@@ -341,7 +353,7 @@ class TestWebSocketInterruptTTL:
with patch("app.interrupt_manager.time") as mock_time: with patch("app.interrupt_manager.time") as mock_time:
mock_time.time.return_value = record.created_at + 10 mock_time.time.return_value = record.created_at + 10
await _respond(ws, g_, sm, im, cb, approved=True) await _respond(ws, ws_ctx, approved=True)
assert ws.sent[0]["type"] == "interrupt_expired" assert ws.sent[0]["type"] == "interrupt_expired"
assert "cancel_order" in ws.sent[0]["message"] assert "cancel_order" in ws.sent[0]["message"]

View File

@@ -55,7 +55,7 @@ class TestDbModule:
from app.db import setup_app_tables from app.db import setup_app_tables
await setup_app_tables(mock_pool) await setup_app_tables(mock_pool)
assert mock_conn.execute.await_count == 4 assert mock_conn.execute.await_count == 5
def test_ddl_statements_valid(self) -> None: def test_ddl_statements_valid(self) -> None:
assert "CREATE TABLE IF NOT EXISTS conversations" in _CONVERSATIONS_DDL assert "CREATE TABLE IF NOT EXISTS conversations" in _CONVERSATIONS_DDL

View File

@@ -51,5 +51,5 @@ class TestAnalyticsEventsDDL:
from app.db import setup_app_tables from app.db import setup_app_tables
await setup_app_tables(mock_pool) await setup_app_tables(mock_pool)
# Now expects 4 statements: conversations, interrupts, analytics_events, migrations # Now expects 5 statements: conversations, interrupts, sessions, analytics_events, migrations
assert mock_conn.execute.await_count == 4 assert mock_conn.execute.await_count == 5

View File

@@ -8,7 +8,9 @@ from unittest.mock import AsyncMock, MagicMock
import pytest import pytest
from app.callbacks import TokenUsageCallbackHandler from app.callbacks import TokenUsageCallbackHandler
from app.graph_context import GraphContext
from app.session_manager import SessionManager from app.session_manager import SessionManager
from app.ws_context import WebSocketContext
from app.ws_handler import dispatch_message from app.ws_handler import dispatch_message
pytestmark = pytest.mark.unit pytestmark = pytest.mark.unit
@@ -20,7 +22,7 @@ def _make_ws() -> AsyncMock:
return ws return ws
def _make_graph() -> AsyncMock: def _make_graph() -> MagicMock:
graph = AsyncMock() graph = AsyncMock()
class AsyncIterHelper: class AsyncIterHelper:
@@ -34,23 +36,32 @@ def _make_graph() -> AsyncMock:
state = MagicMock() state = MagicMock()
state.tasks = () state.tasks = ()
graph.aget_state = AsyncMock(return_value=state) graph.aget_state = AsyncMock(return_value=state)
graph.intent_classifier = None
graph.agent_registry = None
return graph return graph
def _make_ws_ctx(sm: SessionManager | None = None) -> WebSocketContext:
graph = _make_graph()
registry = MagicMock()
registry.list_agents = MagicMock(return_value=())
graph_ctx = GraphContext(graph=graph, registry=registry, intent_classifier=None)
return WebSocketContext(
graph_ctx=graph_ctx,
session_manager=sm or SessionManager(),
callback_handler=TokenUsageCallbackHandler(),
)
@pytest.mark.unit @pytest.mark.unit
class TestEmptyMessageHandling: class TestEmptyMessageHandling:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_empty_message_content_returns_error(self) -> None: async def test_empty_message_content_returns_error(self) -> None:
ws = _make_ws() ws = _make_ws()
graph = _make_graph()
sm = SessionManager() sm = SessionManager()
cb = TokenUsageCallbackHandler() ws_ctx = _make_ws_ctx(sm=sm)
sm.touch("t1") sm.touch("t1")
msg = json.dumps({"type": "message", "thread_id": "t1", "content": ""}) msg = json.dumps({"type": "message", "thread_id": "t1", "content": ""})
await dispatch_message(ws, graph, sm, cb, msg) await dispatch_message(ws, ws_ctx, msg)
call_data = ws.send_json.call_args[0][0] call_data = ws.send_json.call_args[0][0]
assert call_data["type"] == "error" assert call_data["type"] == "error"
@@ -60,13 +71,12 @@ class TestEmptyMessageHandling:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_whitespace_only_message_treated_as_empty(self) -> None: async def test_whitespace_only_message_treated_as_empty(self) -> None:
ws = _make_ws() ws = _make_ws()
graph = _make_graph()
sm = SessionManager() sm = SessionManager()
cb = TokenUsageCallbackHandler() ws_ctx = _make_ws_ctx(sm=sm)
sm.touch("t1") sm.touch("t1")
msg = json.dumps({"type": "message", "thread_id": "t1", "content": " "}) msg = json.dumps({"type": "message", "thread_id": "t1", "content": " "})
await dispatch_message(ws, graph, sm, cb, msg) await dispatch_message(ws, ws_ctx, msg)
call_data = ws.send_json.call_args[0][0] call_data = ws.send_json.call_args[0][0]
assert call_data["type"] == "error" assert call_data["type"] == "error"
@@ -77,14 +87,13 @@ class TestOversizedMessageHandling:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_content_over_10000_chars_returns_error(self) -> None: async def test_content_over_10000_chars_returns_error(self) -> None:
ws = _make_ws() ws = _make_ws()
graph = _make_graph()
sm = SessionManager() sm = SessionManager()
cb = TokenUsageCallbackHandler() ws_ctx = _make_ws_ctx(sm=sm)
sm.touch("t1") sm.touch("t1")
content = "x" * 10001 content = "x" * 10001
msg = json.dumps({"type": "message", "thread_id": "t1", "content": content}) msg = json.dumps({"type": "message", "thread_id": "t1", "content": content})
await dispatch_message(ws, graph, sm, cb, msg) await dispatch_message(ws, ws_ctx, msg)
call_data = ws.send_json.call_args[0][0] call_data = ws.send_json.call_args[0][0]
assert call_data["type"] == "error" assert call_data["type"] == "error"
@@ -93,14 +102,13 @@ class TestOversizedMessageHandling:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_content_exactly_10000_chars_is_accepted(self) -> None: async def test_content_exactly_10000_chars_is_accepted(self) -> None:
ws = _make_ws() ws = _make_ws()
graph = _make_graph()
sm = SessionManager() sm = SessionManager()
cb = TokenUsageCallbackHandler() ws_ctx = _make_ws_ctx(sm=sm)
sm.touch("t1") sm.touch("t1")
content = "x" * 10000 content = "x" * 10000
msg = json.dumps({"type": "message", "thread_id": "t1", "content": content}) msg = json.dumps({"type": "message", "thread_id": "t1", "content": content})
await dispatch_message(ws, graph, sm, cb, msg) await dispatch_message(ws, ws_ctx, msg)
last_call = ws.send_json.call_args[0][0] last_call = ws.send_json.call_args[0][0]
# Should be processed, not an error about length # Should be processed, not an error about length
@@ -110,12 +118,10 @@ class TestOversizedMessageHandling:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_raw_message_over_32kb_returns_error(self) -> None: async def test_raw_message_over_32kb_returns_error(self) -> None:
ws = _make_ws() ws = _make_ws()
graph = _make_graph() ws_ctx = _make_ws_ctx()
sm = SessionManager()
cb = TokenUsageCallbackHandler()
large_msg = "x" * 40_000 large_msg = "x" * 40_000
await dispatch_message(ws, graph, sm, cb, large_msg) await dispatch_message(ws, ws_ctx, large_msg)
call_data = ws.send_json.call_args[0][0] call_data = ws.send_json.call_args[0][0]
assert call_data["type"] == "error" assert call_data["type"] == "error"
@@ -127,11 +133,9 @@ class TestInvalidJsonHandling:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_invalid_json_returns_error(self) -> None: async def test_invalid_json_returns_error(self) -> None:
ws = _make_ws() ws = _make_ws()
graph = _make_graph() ws_ctx = _make_ws_ctx()
sm = SessionManager()
cb = TokenUsageCallbackHandler()
await dispatch_message(ws, graph, sm, cb, "not valid json {{") await dispatch_message(ws, ws_ctx, "not valid json {{")
call_data = ws.send_json.call_args[0][0] call_data = ws.send_json.call_args[0][0]
assert call_data["type"] == "error" assert call_data["type"] == "error"
@@ -140,11 +144,9 @@ class TestInvalidJsonHandling:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_empty_string_returns_json_error(self) -> None: async def test_empty_string_returns_json_error(self) -> None:
ws = _make_ws() ws = _make_ws()
graph = _make_graph() ws_ctx = _make_ws_ctx()
sm = SessionManager()
cb = TokenUsageCallbackHandler()
await dispatch_message(ws, graph, sm, cb, "") await dispatch_message(ws, ws_ctx, "")
call_data = ws.send_json.call_args[0][0] call_data = ws.send_json.call_args[0][0]
assert call_data["type"] == "error" assert call_data["type"] == "error"
@@ -152,11 +154,9 @@ class TestInvalidJsonHandling:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_json_array_not_object_returns_error(self) -> None: async def test_json_array_not_object_returns_error(self) -> None:
ws = _make_ws() ws = _make_ws()
graph = _make_graph() ws_ctx = _make_ws_ctx()
sm = SessionManager()
cb = TokenUsageCallbackHandler()
await dispatch_message(ws, graph, sm, cb, '["not", "an", "object"]') await dispatch_message(ws, ws_ctx, '["not", "an", "object"]')
call_data = ws.send_json.call_args[0][0] call_data = ws.send_json.call_args[0][0]
assert call_data["type"] == "error" assert call_data["type"] == "error"
@@ -167,17 +167,15 @@ class TestRateLimiting:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_rapid_fire_messages_rate_limited(self) -> None: async def test_rapid_fire_messages_rate_limited(self) -> None:
ws = _make_ws() ws = _make_ws()
_make_graph() # ensure graph factory works, not needed directly
sm = SessionManager() sm = SessionManager()
cb = TokenUsageCallbackHandler()
sm.touch("t1") sm.touch("t1")
# Simulate 11 rapid messages (exceeds 10 per 10 seconds limit) # Simulate 11 rapid messages (exceeds 10 per 10 seconds limit)
rate_limit_triggered = False rate_limit_triggered = False
for i in range(11): for i in range(11):
graph2 = _make_graph() # fresh graph each time ws_ctx = _make_ws_ctx(sm=sm)
await dispatch_message(ws, graph2, sm, cb, json.dumps({ await dispatch_message(ws, ws_ctx, json.dumps({
"type": "message", "type": "message",
"thread_id": "t1", "thread_id": "t1",
"content": f"message {i}", "content": f"message {i}",
@@ -193,19 +191,18 @@ class TestRateLimiting:
async def test_different_threads_have_separate_rate_limits(self) -> None: async def test_different_threads_have_separate_rate_limits(self) -> None:
ws = _make_ws() ws = _make_ws()
sm = SessionManager() sm = SessionManager()
cb = TokenUsageCallbackHandler()
sm.touch("t1") sm.touch("t1")
sm.touch("t2") sm.touch("t2")
# Send 5 messages on t1 and 5 on t2 -- neither should be rate limited # Send 5 messages on t1 and 5 on t2 -- neither should be rate limited
for i in range(5): for i in range(5):
graph1 = _make_graph() ws_ctx1 = _make_ws_ctx(sm=sm)
graph2 = _make_graph() ws_ctx2 = _make_ws_ctx(sm=sm)
await dispatch_message(ws, graph1, sm, cb, json.dumps({ await dispatch_message(ws, ws_ctx1, json.dumps({
"type": "message", "thread_id": "t1", "content": f"msg {i}", "type": "message", "thread_id": "t1", "content": f"msg {i}",
})) }))
await dispatch_message(ws, graph2, sm, cb, json.dumps({ await dispatch_message(ws, ws_ctx2, json.dumps({
"type": "message", "thread_id": "t2", "content": f"msg {i}", "type": "message", "thread_id": "t2", "content": f"msg {i}",
})) }))

View File

@@ -8,7 +8,8 @@ from unittest.mock import AsyncMock, MagicMock
import pytest import pytest
from langgraph.checkpoint.memory import InMemorySaver from langgraph.checkpoint.memory import InMemorySaver
from app.graph import build_agent_nodes, build_graph, classify_intent from app.graph import build_agent_nodes, build_graph
from app.graph_context import GraphContext
from app.intent import ClassificationResult, IntentTarget from app.intent import ClassificationResult, IntentTarget
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -37,8 +38,9 @@ class TestBuildGraph:
mock_llm.with_structured_output = MagicMock(return_value=mock_llm) mock_llm.with_structured_output = MagicMock(return_value=mock_llm)
checkpointer = InMemorySaver() checkpointer = InMemorySaver()
graph = build_graph(sample_registry, mock_llm, checkpointer) graph_ctx = build_graph(sample_registry, mock_llm, checkpointer)
assert graph is not None assert graph_ctx is not None
assert graph_ctx.graph is not None
def test_graph_has_classifier_attached(self, sample_registry: AgentRegistry) -> None: def test_graph_has_classifier_attached(self, sample_registry: AgentRegistry) -> None:
mock_llm = MagicMock() mock_llm = MagicMock()
@@ -47,11 +49,11 @@ class TestBuildGraph:
checkpointer = InMemorySaver() checkpointer = InMemorySaver()
mock_classifier = MagicMock() mock_classifier = MagicMock()
graph = build_graph( graph_ctx = build_graph(
sample_registry, mock_llm, checkpointer, intent_classifier=mock_classifier sample_registry, mock_llm, checkpointer, intent_classifier=mock_classifier
) )
assert graph.intent_classifier is mock_classifier assert graph_ctx.intent_classifier is mock_classifier
assert graph.agent_registry is sample_registry assert graph_ctx.registry is sample_registry
def test_graph_without_classifier(self, sample_registry: AgentRegistry) -> None: def test_graph_without_classifier(self, sample_registry: AgentRegistry) -> None:
mock_llm = MagicMock() mock_llm = MagicMock()
@@ -59,17 +61,18 @@ class TestBuildGraph:
mock_llm.with_structured_output = MagicMock(return_value=mock_llm) mock_llm.with_structured_output = MagicMock(return_value=mock_llm)
checkpointer = InMemorySaver() checkpointer = InMemorySaver()
graph = build_graph(sample_registry, mock_llm, checkpointer) graph_ctx = build_graph(sample_registry, mock_llm, checkpointer)
assert graph.intent_classifier is None assert graph_ctx.intent_classifier is None
@pytest.mark.unit @pytest.mark.unit
class TestClassifyIntent: class TestClassifyIntent:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_returns_none_without_classifier(self) -> None: async def test_returns_none_without_classifier(self) -> None:
graph = MagicMock() mock_registry = MagicMock()
graph.intent_classifier = None mock_registry.list_agents = MagicMock(return_value=())
result = await classify_intent(graph, "hello") graph_ctx = GraphContext(graph=MagicMock(), registry=mock_registry, intent_classifier=None)
result = await graph_ctx.classify_intent("hello")
assert result is None assert result is None
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -80,11 +83,12 @@ class TestClassifyIntent:
mock_classifier = AsyncMock() mock_classifier = AsyncMock()
mock_classifier.classify = AsyncMock(return_value=expected) mock_classifier.classify = AsyncMock(return_value=expected)
graph = MagicMock() mock_registry = MagicMock()
graph.intent_classifier = mock_classifier mock_registry.list_agents = MagicMock(return_value=())
graph.agent_registry = MagicMock() graph_ctx = GraphContext(
graph.agent_registry.list_agents = MagicMock(return_value=()) graph=MagicMock(), registry=mock_registry, intent_classifier=mock_classifier,
)
result = await classify_intent(graph, "check order") result = await graph_ctx.classify_intent("check order")
assert result is not None assert result is not None
assert result.intents[0].agent_name == "order_lookup" assert result.intents[0].agent_name == "order_lookup"

View File

@@ -13,7 +13,7 @@ class TestMainModule:
assert app.title == "Smart Support" assert app.title == "Smart Support"
def test_app_version(self) -> None: def test_app_version(self) -> None:
assert app.version == "0.5.0" assert app.version == "0.6.0"
def test_agents_yaml_path_exists(self) -> None: def test_agents_yaml_path_exists(self) -> None:
assert AGENTS_YAML.name == "agents.yaml" assert AGENTS_YAML.name == "agents.yaml"
@@ -39,4 +39,4 @@ class TestMainModule:
assert "/api/health" in routes assert "/api/health" in routes
def test_app_version_is_0_5_0(self) -> None: def test_app_version_is_0_5_0(self) -> None:
assert app.version == "0.5.0" assert app.version == "0.6.0"

View File

@@ -8,8 +8,10 @@ from unittest.mock import AsyncMock, MagicMock
import pytest import pytest
from app.callbacks import TokenUsageCallbackHandler from app.callbacks import TokenUsageCallbackHandler
from app.graph_context import GraphContext
from app.interrupt_manager import InterruptManager from app.interrupt_manager import InterruptManager
from app.session_manager import SessionManager from app.session_manager import SessionManager
from app.ws_context import WebSocketContext
from app.ws_handler import ( from app.ws_handler import (
_extract_interrupt, _extract_interrupt,
_has_interrupt, _has_interrupt,
@@ -25,18 +27,42 @@ def _make_ws() -> AsyncMock:
return ws return ws
def _make_graph() -> AsyncMock: def _make_graph() -> MagicMock:
graph = AsyncMock() graph = AsyncMock()
graph.astream = MagicMock(return_value=AsyncIterHelper([])) graph.astream = MagicMock(return_value=AsyncIterHelper([]))
state = MagicMock() state = MagicMock()
state.tasks = () state.tasks = ()
graph.aget_state = AsyncMock(return_value=state) graph.aget_state = AsyncMock(return_value=state)
# Phase 2: graph needs intent_classifier and agent_registry attrs
graph.intent_classifier = None
graph.agent_registry = None
return graph return graph
def _make_graph_ctx(graph: MagicMock | None = None) -> GraphContext:
g = graph or _make_graph()
registry = MagicMock()
registry.list_agents = MagicMock(return_value=())
return GraphContext(graph=g, registry=registry, intent_classifier=None)
def _make_ws_ctx(
graph_ctx: GraphContext | None = None,
sm: SessionManager | None = None,
cb: TokenUsageCallbackHandler | None = None,
interrupt_manager: InterruptManager | None = None,
analytics_recorder=None,
conversation_tracker=None,
pool=None,
) -> WebSocketContext:
return WebSocketContext(
graph_ctx=graph_ctx or _make_graph_ctx(),
session_manager=sm or SessionManager(),
callback_handler=cb or TokenUsageCallbackHandler(),
interrupt_manager=interrupt_manager,
analytics_recorder=analytics_recorder,
conversation_tracker=conversation_tracker,
pool=pool,
)
class AsyncIterHelper: class AsyncIterHelper:
"""Helper to make a list behave as an async iterator.""" """Helper to make a list behave as an async iterator."""
@@ -57,11 +83,9 @@ class TestDispatchMessage:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_invalid_json(self) -> None: async def test_invalid_json(self) -> None:
ws = _make_ws() ws = _make_ws()
graph = _make_graph() ws_ctx = _make_ws_ctx()
sm = SessionManager()
cb = TokenUsageCallbackHandler()
await dispatch_message(ws, graph, sm, cb, "not json") await dispatch_message(ws, ws_ctx, "not json")
ws.send_json.assert_awaited_once() ws.send_json.assert_awaited_once()
call_data = ws.send_json.call_args[0][0] call_data = ws.send_json.call_args[0][0]
assert call_data["type"] == "error" assert call_data["type"] == "error"
@@ -70,12 +94,10 @@ class TestDispatchMessage:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_missing_thread_id(self) -> None: async def test_missing_thread_id(self) -> None:
ws = _make_ws() ws = _make_ws()
graph = _make_graph() ws_ctx = _make_ws_ctx()
sm = SessionManager()
cb = TokenUsageCallbackHandler()
msg = json.dumps({"type": "message", "content": "hello"}) msg = json.dumps({"type": "message", "content": "hello"})
await dispatch_message(ws, graph, sm, cb, msg) await dispatch_message(ws, ws_ctx, msg)
call_data = ws.send_json.call_args[0][0] call_data = ws.send_json.call_args[0][0]
assert call_data["type"] == "error" assert call_data["type"] == "error"
assert "thread_id" in call_data["message"] assert "thread_id" in call_data["message"]
@@ -83,24 +105,20 @@ class TestDispatchMessage:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_missing_content(self) -> None: async def test_missing_content(self) -> None:
ws = _make_ws() ws = _make_ws()
graph = _make_graph() ws_ctx = _make_ws_ctx()
sm = SessionManager()
cb = TokenUsageCallbackHandler()
msg = json.dumps({"type": "message", "thread_id": "t1"}) msg = json.dumps({"type": "message", "thread_id": "t1"})
await dispatch_message(ws, graph, sm, cb, msg) await dispatch_message(ws, ws_ctx, msg)
call_data = ws.send_json.call_args[0][0] call_data = ws.send_json.call_args[0][0]
assert call_data["type"] == "error" assert call_data["type"] == "error"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_unknown_message_type(self) -> None: async def test_unknown_message_type(self) -> None:
ws = _make_ws() ws = _make_ws()
graph = _make_graph() ws_ctx = _make_ws_ctx()
sm = SessionManager()
cb = TokenUsageCallbackHandler()
msg = json.dumps({"type": "unknown", "thread_id": "t1"}) msg = json.dumps({"type": "unknown", "thread_id": "t1"})
await dispatch_message(ws, graph, sm, cb, msg) await dispatch_message(ws, ws_ctx, msg)
call_data = ws.send_json.call_args[0][0] call_data = ws.send_json.call_args[0][0]
assert call_data["type"] == "error" assert call_data["type"] == "error"
assert "Unknown" in call_data["message"] assert "Unknown" in call_data["message"]
@@ -108,12 +126,10 @@ class TestDispatchMessage:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_message_too_large(self) -> None: async def test_message_too_large(self) -> None:
ws = _make_ws() ws = _make_ws()
graph = _make_graph() ws_ctx = _make_ws_ctx()
sm = SessionManager()
cb = TokenUsageCallbackHandler()
large_msg = "x" * 40_000 large_msg = "x" * 40_000
await dispatch_message(ws, graph, sm, cb, large_msg) await dispatch_message(ws, ws_ctx, large_msg)
call_data = ws.send_json.call_args[0][0] call_data = ws.send_json.call_args[0][0]
assert call_data["type"] == "error" assert call_data["type"] == "error"
assert "too large" in call_data["message"].lower() assert "too large" in call_data["message"].lower()
@@ -121,12 +137,10 @@ class TestDispatchMessage:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_invalid_thread_id_format(self) -> None: async def test_invalid_thread_id_format(self) -> None:
ws = _make_ws() ws = _make_ws()
graph = _make_graph() ws_ctx = _make_ws_ctx()
sm = SessionManager()
cb = TokenUsageCallbackHandler()
msg = json.dumps({"type": "message", "thread_id": "../../../etc", "content": "hi"}) msg = json.dumps({"type": "message", "thread_id": "../../../etc", "content": "hi"})
await dispatch_message(ws, graph, sm, cb, msg) await dispatch_message(ws, ws_ctx, msg)
call_data = ws.send_json.call_args[0][0] call_data = ws.send_json.call_args[0][0]
assert call_data["type"] == "error" assert call_data["type"] == "error"
assert "thread_id" in call_data["message"].lower() assert "thread_id" in call_data["message"].lower()
@@ -134,12 +148,10 @@ class TestDispatchMessage:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_content_too_long(self) -> None: async def test_content_too_long(self) -> None:
ws = _make_ws() ws = _make_ws()
graph = _make_graph() ws_ctx = _make_ws_ctx()
sm = SessionManager()
cb = TokenUsageCallbackHandler()
msg = json.dumps({"type": "message", "thread_id": "t1", "content": "x" * 10001}) msg = json.dumps({"type": "message", "thread_id": "t1", "content": "x" * 10001})
await dispatch_message(ws, graph, sm, cb, msg) await dispatch_message(ws, ws_ctx, msg)
call_data = ws.send_json.call_args[0][0] call_data = ws.send_json.call_args[0][0]
assert call_data["type"] == "error" assert call_data["type"] == "error"
assert "too long" in call_data["message"].lower() assert "too long" in call_data["message"].lower()
@@ -147,14 +159,13 @@ class TestDispatchMessage:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_dispatch_with_interrupt_manager(self) -> None: async def test_dispatch_with_interrupt_manager(self) -> None:
ws = _make_ws() ws = _make_ws()
graph = _make_graph()
sm = SessionManager() sm = SessionManager()
cb = TokenUsageCallbackHandler()
im = InterruptManager() im = InterruptManager()
ws_ctx = _make_ws_ctx(sm=sm, interrupt_manager=im)
sm.touch("t1") sm.touch("t1")
msg = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"}) msg = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"})
await dispatch_message(ws, graph, sm, cb, msg, interrupt_manager=im) await dispatch_message(ws, ws_ctx, msg)
last_call = ws.send_json.call_args[0][0] last_call = ws.send_json.call_args[0][0]
assert last_call["type"] == "message_complete" assert last_call["type"] == "message_complete"
@@ -164,14 +175,14 @@ class TestHandleUserMessage:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_expired_session(self) -> None: async def test_expired_session(self) -> None:
ws = _make_ws() ws = _make_ws()
graph = _make_graph() graph_ctx = _make_graph_ctx()
sm = SessionManager(session_ttl_seconds=0) sm = SessionManager(session_ttl_seconds=0)
cb = TokenUsageCallbackHandler() cb = TokenUsageCallbackHandler()
# First call creates the session (TTL=0) # First call creates the session (TTL=0)
await handle_user_message(ws, graph, sm, cb, "t1", "hello") await handle_user_message(ws, graph_ctx, sm, cb, "t1", "hello")
# Second call finds it expired # Second call finds it expired
await handle_user_message(ws, graph, sm, cb, "t1", "hello again") await handle_user_message(ws, graph_ctx, sm, cb, "t1", "hello again")
call_data = ws.send_json.call_args[0][0] call_data = ws.send_json.call_args[0][0]
assert call_data["type"] == "error" assert call_data["type"] == "error"
assert "expired" in call_data["message"].lower() assert "expired" in call_data["message"].lower()
@@ -179,12 +190,12 @@ class TestHandleUserMessage:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_successful_message(self) -> None: async def test_successful_message(self) -> None:
ws = _make_ws() ws = _make_ws()
graph = _make_graph() graph_ctx = _make_graph_ctx()
sm = SessionManager() sm = SessionManager()
cb = TokenUsageCallbackHandler() cb = TokenUsageCallbackHandler()
sm.touch("t1") sm.touch("t1")
await handle_user_message(ws, graph, sm, cb, "t1", "hello") await handle_user_message(ws, graph_ctx, sm, cb, "t1", "hello")
last_call = ws.send_json.call_args[0][0] last_call = ws.send_json.call_args[0][0]
assert last_call["type"] == "message_complete" assert last_call["type"] == "message_complete"
@@ -193,13 +204,12 @@ class TestHandleUserMessage:
ws = _make_ws() ws = _make_ws()
graph = AsyncMock() graph = AsyncMock()
graph.astream = MagicMock(side_effect=RuntimeError("boom")) graph.astream = MagicMock(side_effect=RuntimeError("boom"))
graph.intent_classifier = None graph_ctx = _make_graph_ctx(graph=graph)
graph.agent_registry = None
sm = SessionManager() sm = SessionManager()
cb = TokenUsageCallbackHandler() cb = TokenUsageCallbackHandler()
sm.touch("t1") sm.touch("t1")
await handle_user_message(ws, graph, sm, cb, "t1", "hello") await handle_user_message(ws, graph_ctx, sm, cb, "t1", "hello")
call_data = ws.send_json.call_args[0][0] call_data = ws.send_json.call_args[0][0]
assert call_data["type"] == "error" assert call_data["type"] == "error"
@@ -207,8 +217,6 @@ class TestHandleUserMessage:
async def test_interrupt_registered_with_manager(self) -> None: async def test_interrupt_registered_with_manager(self) -> None:
ws = _make_ws() ws = _make_ws()
graph = AsyncMock() graph = AsyncMock()
graph.intent_classifier = None
graph.agent_registry = None
graph.astream = MagicMock(return_value=AsyncIterHelper([])) graph.astream = MagicMock(return_value=AsyncIterHelper([]))
# Simulate interrupt in state # Simulate interrupt in state
@@ -220,13 +228,14 @@ class TestHandleUserMessage:
state.tasks = (task,) state.tasks = (task,)
graph.aget_state = AsyncMock(return_value=state) graph.aget_state = AsyncMock(return_value=state)
graph_ctx = _make_graph_ctx(graph=graph)
sm = SessionManager() sm = SessionManager()
cb = TokenUsageCallbackHandler() cb = TokenUsageCallbackHandler()
im = InterruptManager() im = InterruptManager()
sm.touch("t1") sm.touch("t1")
await handle_user_message( await handle_user_message(
ws, graph, sm, cb, "t1", "cancel order 1042", interrupt_manager=im, ws, graph_ctx, sm, cb, "t1", "cancel order 1042", interrupt_manager=im,
) )
# Interrupt should be registered # Interrupt should be registered
@@ -257,16 +266,17 @@ class TestHandleUserMessage:
clarification_question="What do you mean?", clarification_question="What do you mean?",
) )
) )
graph.intent_classifier = mock_classifier
mock_registry = MagicMock() mock_registry = MagicMock()
mock_registry.list_agents = MagicMock(return_value=()) mock_registry.list_agents = MagicMock(return_value=())
graph.agent_registry = mock_registry graph_ctx = GraphContext(
graph=graph, registry=mock_registry, intent_classifier=mock_classifier,
)
sm = SessionManager() sm = SessionManager()
cb = TokenUsageCallbackHandler() cb = TokenUsageCallbackHandler()
sm.touch("t1") sm.touch("t1")
await handle_user_message(ws, graph, sm, cb, "t1", "hmm") await handle_user_message(ws, graph_ctx, sm, cb, "t1", "hmm")
calls = [c[0][0] for c in ws.send_json.call_args_list] calls = [c[0][0] for c in ws.send_json.call_args_list]
clarification_msgs = [c for c in calls if c.get("type") == "clarification"] clarification_msgs = [c for c in calls if c.get("type") == "clarification"]
@@ -279,13 +289,13 @@ class TestHandleInterruptResponse:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_approved_interrupt(self) -> None: async def test_approved_interrupt(self) -> None:
ws = _make_ws() ws = _make_ws()
graph = _make_graph() graph_ctx = _make_graph_ctx()
sm = SessionManager() sm = SessionManager()
cb = TokenUsageCallbackHandler() cb = TokenUsageCallbackHandler()
sm.touch("t1") sm.touch("t1")
sm.extend_for_interrupt("t1") sm.extend_for_interrupt("t1")
await handle_interrupt_response(ws, graph, sm, cb, "t1", True) await handle_interrupt_response(ws, graph_ctx, sm, cb, "t1", True)
last_call = ws.send_json.call_args[0][0] last_call = ws.send_json.call_args[0][0]
assert last_call["type"] == "message_complete" assert last_call["type"] == "message_complete"
@@ -294,7 +304,7 @@ class TestHandleInterruptResponse:
from unittest.mock import patch from unittest.mock import patch
ws = _make_ws() ws = _make_ws()
graph = _make_graph() graph_ctx = _make_graph_ctx()
sm = SessionManager() sm = SessionManager()
cb = TokenUsageCallbackHandler() cb = TokenUsageCallbackHandler()
im = InterruptManager(ttl_seconds=5) im = InterruptManager(ttl_seconds=5)
@@ -307,7 +317,7 @@ class TestHandleInterruptResponse:
with patch("app.interrupt_manager.time") as mock_time: with patch("app.interrupt_manager.time") as mock_time:
mock_time.time.return_value = im._interrupts["t1"].created_at + 10 mock_time.time.return_value = im._interrupts["t1"].created_at + 10
await handle_interrupt_response( await handle_interrupt_response(
ws, graph, sm, cb, "t1", True, interrupt_manager=im ws, graph_ctx, sm, cb, "t1", True, interrupt_manager=im
) )
call_data = ws.send_json.call_args[0][0] call_data = ws.send_json.call_args[0][0]
@@ -317,7 +327,7 @@ class TestHandleInterruptResponse:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_valid_interrupt_resolves(self) -> None: async def test_valid_interrupt_resolves(self) -> None:
ws = _make_ws() ws = _make_ws()
graph = _make_graph() graph_ctx = _make_graph_ctx()
sm = SessionManager() sm = SessionManager()
cb = TokenUsageCallbackHandler() cb = TokenUsageCallbackHandler()
im = InterruptManager(ttl_seconds=1800) im = InterruptManager(ttl_seconds=1800)
@@ -327,7 +337,7 @@ class TestHandleInterruptResponse:
im.register("t1", "cancel_order", {}) im.register("t1", "cancel_order", {})
await handle_interrupt_response( await handle_interrupt_response(
ws, graph, sm, cb, "t1", True, interrupt_manager=im ws, graph_ctx, sm, cb, "t1", True, interrupt_manager=im
) )
# Interrupt should be resolved # Interrupt should be resolved
@@ -374,19 +384,14 @@ class TestDispatchMessageWithTracking:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_conversation_tracker_called_on_message(self) -> None: async def test_conversation_tracker_called_on_message(self) -> None:
ws = _make_ws() ws = _make_ws()
graph = _make_graph()
sm = SessionManager() sm = SessionManager()
cb = TokenUsageCallbackHandler()
tracker = AsyncMock() tracker = AsyncMock()
pool = MagicMock() pool = MagicMock()
ws_ctx = _make_ws_ctx(sm=sm, conversation_tracker=tracker, pool=pool)
sm.touch("t1") sm.touch("t1")
msg = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"}) msg = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"})
await dispatch_message( await dispatch_message(ws, ws_ctx, msg)
ws, graph, sm, cb, msg,
conversation_tracker=tracker,
pool=pool,
)
tracker.ensure_conversation.assert_awaited_once_with(pool, "t1") tracker.ensure_conversation.assert_awaited_once_with(pool, "t1")
tracker.record_turn.assert_awaited_once() tracker.record_turn.assert_awaited_once()
@@ -394,53 +399,42 @@ class TestDispatchMessageWithTracking:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_analytics_recorder_called_on_message(self) -> None: async def test_analytics_recorder_called_on_message(self) -> None:
ws = _make_ws() ws = _make_ws()
graph = _make_graph()
sm = SessionManager() sm = SessionManager()
cb = TokenUsageCallbackHandler()
recorder = AsyncMock() recorder = AsyncMock()
pool = MagicMock() pool = MagicMock()
ws_ctx = _make_ws_ctx(sm=sm, analytics_recorder=recorder, pool=pool)
sm.touch("t1") sm.touch("t1")
msg = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"}) msg = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"})
await dispatch_message( await dispatch_message(ws, ws_ctx, msg)
ws, graph, sm, cb, msg,
analytics_recorder=recorder,
pool=pool,
)
recorder.record.assert_awaited_once() recorder.record.assert_awaited_once()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_tracker_failure_does_not_break_chat(self) -> None: async def test_tracker_failure_does_not_break_chat(self) -> None:
ws = _make_ws() ws = _make_ws()
graph = _make_graph()
sm = SessionManager() sm = SessionManager()
cb = TokenUsageCallbackHandler()
tracker = AsyncMock() tracker = AsyncMock()
tracker.ensure_conversation.side_effect = RuntimeError("DB down") tracker.ensure_conversation.side_effect = RuntimeError("DB down")
pool = MagicMock() pool = MagicMock()
ws_ctx = _make_ws_ctx(sm=sm, conversation_tracker=tracker, pool=pool)
sm.touch("t1") sm.touch("t1")
msg = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"}) msg = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"})
# Should not raise despite tracker failure # Should not raise despite tracker failure
await dispatch_message( await dispatch_message(ws, ws_ctx, msg)
ws, graph, sm, cb, msg,
conversation_tracker=tracker,
pool=pool,
)
last_call = ws.send_json.call_args[0][0] last_call = ws.send_json.call_args[0][0]
assert last_call["type"] == "message_complete" assert last_call["type"] == "message_complete"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_no_tracker_no_error(self) -> None: async def test_no_tracker_no_error(self) -> None:
ws = _make_ws() ws = _make_ws()
graph = _make_graph()
sm = SessionManager() sm = SessionManager()
cb = TokenUsageCallbackHandler() ws_ctx = _make_ws_ctx(sm=sm)
sm.touch("t1") sm.touch("t1")
msg = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"}) msg = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"})
# No tracker or recorder passed -- should work fine # No tracker or recorder passed -- should work fine
await dispatch_message(ws, graph, sm, cb, msg) await dispatch_message(ws, ws_ctx, msg)
last_call = ws.send_json.call_args[0][0] last_call = ws.send_json.call_args[0][0]
assert last_call["type"] == "message_complete" assert last_call["type"] == "message_complete"

View File

@@ -101,3 +101,74 @@ export async function fetchReplay(
export async function fetchAnalytics(range = "7d"): Promise<AnalyticsData> { export async function fetchAnalytics(range = "7d"): Promise<AnalyticsData> {
return apiFetch<AnalyticsData>(`/api/analytics?range=${range}`); return apiFetch<AnalyticsData>(`/api/analytics?range=${range}`);
} }
// -- OpenAPI import --
export interface ImportJobResponse {
job_id: string;
status: string;
spec_url: string;
total_endpoints: number;
classified_count: number;
error_message: string | null;
generated_tools_count?: number;
}
export interface EndpointClassification {
index: number;
access_type: string;
needs_interrupt: boolean;
agent_group: string;
confidence: number;
customer_params: string[];
endpoint: {
path: string;
method: string;
operation_id: string;
summary: string;
description: string;
};
}
async function apiPost<T>(path: string, body: unknown): Promise<T> {
const res = await fetch(`${API_BASE}${path}`, {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify(body),
});
if (!res.ok) {
throw new Error(`API error ${res.status}: ${res.statusText}`);
}
return res.json();
}
export async function startImport(url: string): Promise<ImportJobResponse> {
return apiPost<ImportJobResponse>("/api/openapi/import", { url });
}
export async function fetchImportJob(jobId: string): Promise<ImportJobResponse> {
const res = await fetch(`${API_BASE}/api/openapi/jobs/${encodeURIComponent(jobId)}`);
if (!res.ok) {
throw new Error(`API error ${res.status}: ${res.statusText}`);
}
return res.json();
}
export async function fetchClassifications(
jobId: string
): Promise<EndpointClassification[]> {
const res = await fetch(
`${API_BASE}/api/openapi/jobs/${encodeURIComponent(jobId)}/classifications`
);
if (!res.ok) {
throw new Error(`API error ${res.status}: ${res.statusText}`);
}
return res.json();
}
export async function approveJob(jobId: string): Promise<ImportJobResponse> {
return apiPost<ImportJobResponse>(
`/api/openapi/jobs/${encodeURIComponent(jobId)}/approve`,
{}
);
}

View File

@@ -13,10 +13,8 @@ import type {
ToolAction, ToolAction,
} from "../types"; } from "../types";
let msgCounter = 0;
function nextId(): string { function nextId(): string {
msgCounter += 1; return crypto.randomUUID();
return `msg-${msgCounter}`;
} }
export function ChatPage() { export function ChatPage() {
@@ -68,6 +66,48 @@ export function ChatPage() {
setIsWaiting(false); setIsWaiting(false);
break; break;
} }
case "clarification": {
setMessages((prev) => [
...prev,
{
id: nextId(),
sender: "agent",
agent: "System",
content: msg.message,
timestamp: Date.now(),
},
]);
setIsWaiting(false);
break;
}
case "interrupt_expired": {
setCurrentInterrupt(null);
setMessages((prev) => [
...prev,
{
id: nextId(),
sender: "agent",
agent: "System",
content: msg.message,
timestamp: Date.now(),
},
]);
setIsWaiting(false);
break;
}
case "tool_result": {
setToolActions((prev) => {
const last = prev[prev.length - 1];
if (last && last.tool === msg.tool && last.agent === msg.agent) {
return [
...prev.slice(0, -1),
{ ...last, result: msg.result },
];
}
return prev;
});
break;
}
case "message_complete": { case "message_complete": {
setMessages((prev) => { setMessages((prev) => {
const last = prev[prev.length - 1]; const last = prev[prev.length - 1];

View File

@@ -81,7 +81,7 @@ export function ReplayPage() {
{/* Timeline */} {/* Timeline */}
<div className="section-card" style={{ padding: "2rem" }}> <div className="section-card" style={{ padding: "2rem" }}>
<ReplayTimeline steps={steps as any} /> <ReplayTimeline steps={steps} />
</div> </div>
</div> </div>
)} )}

View File

@@ -1,12 +1,14 @@
import { useEffect, useRef, useState } from "react"; import { useEffect, useRef, useState } from "react";
import {
approveJob,
fetchClassifications,
fetchImportJob,
startImport,
type EndpointClassification,
type ImportJobResponse,
} from "../api";
interface ImportJob { interface FlatClassification {
job_id: string;
status: "pending" | "processing" | "done" | "failed";
error_message?: string;
}
interface EndpointClassification {
path: string; path: string;
method: string; method: string;
summary: string; summary: string;
@@ -14,48 +16,23 @@ interface EndpointClassification {
agent_group: string; agent_group: string;
} }
function flattenClassification(c: EndpointClassification): FlatClassification {
return {
path: c.endpoint?.path ?? "",
method: c.endpoint?.method ?? "",
summary: c.endpoint?.summary ?? "",
access_type: c.access_type ?? "read",
agent_group: c.agent_group ?? "Unassigned",
};
}
export function ReviewPage() { export function ReviewPage() {
const [url, setUrl] = useState(""); const [url, setUrl] = useState("");
const [job, setJob] = useState<ImportJob | null>(null); const [job, setJob] = useState<ImportJobResponse | null>(null);
const [submitting, setSubmitting] = useState(false); const [submitting, setSubmitting] = useState(false);
const [submitError, setSubmitError] = useState<string | null>(null); const [submitError, setSubmitError] = useState<string | null>(null);
const [classifications, setClassifications] = useState<EndpointClassification[]>([ const [approveStatus, setApproveStatus] = useState<string | null>(null);
{ const [classifications, setClassifications] = useState<FlatClassification[]>([]);
path: "/api/v1/orders/{order_id}/cancel",
method: "post",
summary: "Cancel an active Shopify order",
access_type: "write",
agent_group: "Order Specialist",
},
{
path: "/api/v1/orders/{order_id}",
method: "get",
summary: "Retrieve detailed information about an order",
access_type: "read",
agent_group: "Order Specialist",
},
{
path: "/api/v1/payments/{charge_id}/refund",
method: "post",
summary: "Issue a full or partial refund for a charge",
access_type: "write",
agent_group: "Billing Assistant",
},
{
path: "/api/v1/customers/{email}/discounts",
method: "post",
summary: "Apply a loyalty discount to a customer account",
access_type: "write",
agent_group: "Billing Assistant",
},
{
path: "/api/v1/inventory/check",
method: "get",
summary: "Query realtime stock levels across warehouses",
access_type: "read",
agent_group: "Unassigned",
}
]);
const pollRef = useRef<ReturnType<typeof setTimeout> | null>(null); const pollRef = useRef<ReturnType<typeof setTimeout> | null>(null);
useEffect(() => { useEffect(() => {
@@ -65,25 +42,13 @@ export function ReviewPage() {
}, []); }, []);
function pollJob(jobId: string) { function pollJob(jobId: string) {
fetch(`/api/openapi/jobs/${encodeURIComponent(jobId)}`) fetchImportJob(jobId)
.then((r) => r.json()) .then((j) => {
.then((data) => {
const j: ImportJob = data.data ?? data;
setJob(j); setJob(j);
if (j.status === "done") { if (j.status === "done") {
return fetch(`/api/openapi/jobs/${encodeURIComponent(jobId)}/classifications`) return fetchClassifications(jobId).then((clfs) => {
.then((r) => r.json()) setClassifications(clfs.map(flattenClassification));
.then((clfs: EndpointClassification[]) => { });
setClassifications(
clfs.map((c: any) => ({
path: c.endpoint?.path ?? c.path ?? "",
method: c.endpoint?.method ?? c.method ?? "",
summary: c.endpoint?.summary ?? c.summary ?? "",
access_type: c.access_type ?? "read",
agent_group: c.agent_group ?? "Unassigned",
}))
);
});
} else if (j.status === "failed") { } else if (j.status === "failed") {
return; return;
} else { } else {
@@ -100,17 +65,12 @@ export function ReviewPage() {
if (!url.trim()) return; if (!url.trim()) return;
setSubmitting(true); setSubmitting(true);
setSubmitError(null); setSubmitError(null);
setApproveStatus(null);
setJob(null); setJob(null);
setClassifications([]); setClassifications([]);
fetch("/api/openapi/import", { startImport(url)
method: "POST", .then((j) => {
headers: { "Content-Type": "application/json" },
body: JSON.stringify({ url }),
})
.then((r) => r.json())
.then((data) => {
const j: ImportJob = data.data ?? data;
setJob(j); setJob(j);
if (j.job_id) pollJob(j.job_id); if (j.job_id) pollJob(j.job_id);
}) })
@@ -120,7 +80,7 @@ export function ReviewPage() {
function handleFieldChange( function handleFieldChange(
idx: number, idx: number,
field: keyof EndpointClassification, field: keyof FlatClassification,
value: string value: string
) { ) {
setClassifications((prev) => setClassifications((prev) =>
@@ -130,21 +90,26 @@ export function ReviewPage() {
function handleApprove() { function handleApprove() {
if (!job?.job_id) return; if (!job?.job_id) return;
fetch(`/api/openapi/jobs/${encodeURIComponent(job.job_id)}/approve`, { setApproveStatus(null);
method: "POST", approveJob(job.job_id)
headers: { "Content-Type": "application/json" }, .then((result) => {
body: JSON.stringify({ endpoints: classifications }), setJob(result);
}).then(() => { setApproveStatus(
alert("Approved and saved."); `Configuration saved. ${result.generated_tools_count ?? 0} tools generated.`
}); );
})
.catch((err: Error) => setApproveStatus(`Error: ${err.message}`));
} }
const groupedByAgent = classifications.reduce((acc, c, idx) => { const groupedByAgent = classifications.reduce(
const group = c.agent_group || "Unassigned"; (acc, c, idx) => {
if (!acc[group]) acc[group] = []; const group = c.agent_group || "Unassigned";
acc[group].push({ ...c, originalIdx: idx }); if (!acc[group]) acc[group] = [];
return acc; acc[group].push({ ...c, originalIdx: idx });
}, {} as Record<string, (EndpointClassification & { originalIdx: number })[]>); return acc;
},
{} as Record<string, (FlatClassification & { originalIdx: number })[]>
);
return ( return (
<div className="page-container"> <div className="page-container">
@@ -167,35 +132,105 @@ export function ReviewPage() {
</button> </button>
</form> </form>
{submitError && <div style={{ color: "var(--brand-accent)", marginBottom: "1rem" }}>Error: {submitError}</div>} {submitError && (
<div style={{ color: "var(--brand-accent)", marginBottom: "1rem" }}>
Error: {submitError}
</div>
)}
{job && ( {job && (
<div style={{ padding: "1rem", background: "var(--bg-surface)", border: "1px solid var(--border-light)", borderRadius: "var(--radius-md)", marginBottom: "1.5rem" }}> <div
style={{
padding: "1rem",
background: "var(--bg-surface)",
border: "1px solid var(--border-light)",
borderRadius: "var(--radius-md)",
marginBottom: "1.5rem",
}}
>
<strong>Job:</strong> {job.job_id} &mdash; Status:{" "} <strong>Job:</strong> {job.job_id} &mdash; Status:{" "}
<span style={{ fontWeight: 600, color: job.status === "done" ? "#10b981" : job.status === "failed" ? "var(--brand-accent)" : "#f59e0b" }}> <span
style={{
fontWeight: 600,
color:
job.status === "done" || job.status === "approved"
? "#10b981"
: job.status === "failed"
? "var(--brand-accent)"
: "#f59e0b",
}}
>
{job.status} {job.status}
</span> </span>
{job.error_message && <div style={{ marginTop: "4px", color: "var(--brand-accent)" }}>{job.error_message}</div>} {job.error_message && (
<div style={{ marginTop: "4px", color: "var(--brand-accent)" }}>
{job.error_message}
</div>
)}
</div>
)}
{approveStatus && (
<div
style={{
padding: "0.75rem 1rem",
background: approveStatus.startsWith("Error")
? "#fef2f2"
: "#f0fdf4",
border: `1px solid ${approveStatus.startsWith("Error") ? "#fecaca" : "#bbf7d0"}`,
borderRadius: "var(--radius-md)",
marginBottom: "1rem",
fontSize: "0.875rem",
}}
>
{approveStatus}
</div> </div>
)} )}
{classifications.length > 0 && ( {classifications.length > 0 && (
<> <>
<div style={{ display: "flex", justifyContent: "space-between", alignItems: "center", marginBottom: "1rem" }}> <div
<div> style={{
<h3 style={{ margin: 0, fontSize: "1.25rem", color: "var(--text-primary)" }}>Assigned Capabilities ({classifications.length})</h3> display: "flex",
<p style={{ margin: "0.25rem 0 0 0", fontSize: "0.875rem", color: "var(--text-secondary)" }}>Grouped by target Agent.</p> justifyContent: "space-between",
</div> alignItems: "center",
<button onClick={handleApprove} className="btn btn-primary"> marginBottom: "1rem",
Save Configuration }}
</button> >
<div>
<h3
style={{
margin: 0,
fontSize: "1.25rem",
color: "var(--text-primary)",
}}
>
Assigned Capabilities ({classifications.length})
</h3>
<p
style={{
margin: "0.25rem 0 0 0",
fontSize: "0.875rem",
color: "var(--text-secondary)",
}}
>
Grouped by target Agent.
</p>
</div>
<button onClick={handleApprove} className="btn btn-primary">
Save Configuration
</button>
</div> </div>
<div className="agent-grid"> <div className="agent-grid">
{Object.entries(groupedByAgent).map(([groupName, tools]) => ( {Object.entries(groupedByAgent).map(([groupName, tools]) => (
<div key={groupName} className="agent-grid-card"> <div key={groupName} className="agent-grid-card">
<div className="agent-card-header-bg"> <div className="agent-card-header-bg">
<div className="agent-avatar-lg">{groupName === "Unassigned" ? "?" : groupName.charAt(0).toUpperCase()}</div> <div className="agent-avatar-lg">
{groupName === "Unassigned"
? "?"
: groupName.charAt(0).toUpperCase()}
</div>
<div className="agent-card-meta"> <div className="agent-card-meta">
<h3>{groupName}</h3> <h3>{groupName}</h3>
<span>{tools.length} Attached Tools</span> <span>{tools.length} Attached Tools</span>
@@ -205,16 +240,36 @@ export function ReviewPage() {
{tools.map((t) => ( {tools.map((t) => (
<div key={t.originalIdx} className="tool-pill-item"> <div key={t.originalIdx} className="tool-pill-item">
<div className="tool-pill-header"> <div className="tool-pill-header">
<span className="tool-method-badge" style={{ background: t.method === "get" ? "#3b82f6" : t.method === "post" ? "#10b981" : t.method === "delete" ? "#ef4444" : "#f59e0b" }}> <span
className="tool-method-badge"
style={{
background:
t.method === "get"
? "#3b82f6"
: t.method === "post"
? "#10b981"
: t.method === "delete"
? "#ef4444"
: "#f59e0b",
}}
>
{t.method} {t.method}
</span> </span>
<span className="tool-path-text" title={t.path}>{t.path}</span> <span className="tool-path-text" title={t.path}>
{t.path}
</span>
</div> </div>
<div className="tool-summary-text">{t.summary}</div> <div className="tool-summary-text">{t.summary}</div>
<div className="tool-pill-controls"> <div className="tool-pill-controls">
<select <select
value={t.access_type} value={t.access_type}
onChange={(e) => handleFieldChange(t.originalIdx, "access_type", e.target.value)} onChange={(e) =>
handleFieldChange(
t.originalIdx,
"access_type",
e.target.value
)
}
className="tool-select" className="tool-select"
> >
<option value="read">Read Only</option> <option value="read">Read Only</option>
@@ -223,7 +278,13 @@ export function ReviewPage() {
<input <input
type="text" type="text"
value={t.agent_group} value={t.agent_group}
onChange={(e) => handleFieldChange(t.originalIdx, "agent_group", e.target.value)} onChange={(e) =>
handleFieldChange(
t.originalIdx,
"agent_group",
e.target.value
)
}
className="tool-input" className="tool-input"
placeholder="Agent Name" placeholder="Agent Name"
/> />

View File

@@ -39,13 +39,28 @@ export interface ErrorMessage {
message: string; message: string;
} }
export interface ClarificationMessage {
type: "clarification";
thread_id: string;
message: string;
}
export interface InterruptExpiredMessage {
type: "interrupt_expired";
thread_id: string;
action: string;
message: string;
}
export type ServerMessage = export type ServerMessage =
| TokenMessage | TokenMessage
| InterruptMessage | InterruptMessage
| ToolCallMessage | ToolCallMessage
| ToolResultMessage | ToolResultMessage
| MessageCompleteMessage | MessageCompleteMessage
| ErrorMessage; | ErrorMessage
| ClarificationMessage
| InterruptExpiredMessage;
// -- Client -> Server messages -- // -- Client -> Server messages --