refactor: fix architectural issues across frontend and backend
Address all architecture review findings: P0 fixes: - Add API key authentication for admin endpoints (analytics, replay, openapi) and WebSocket connections via ADMIN_API_KEY env var - Add PostgreSQL-backed PgSessionManager and PgInterruptManager for multi-worker production deployments (in-memory defaults preserved) P1 fixes: - Implement actual tool generation in OpenAPI approve_job endpoint using generate_tool_code() and generate_agent_yaml() - Add missing clarification, interrupt_expired, and tool_result message handlers in frontend ChatPage P2 fixes: - Replace monkey-patching on CompiledStateGraph with typed GraphContext - Replace 9-param dispatch_message with WebSocketContext dataclass - Extract duplicate _envelope() into shared app/api_utils.py - Replace mutable module-level counter with crypto.randomUUID() - Remove hardcoded mock data from ReviewPage, use api.ts wrappers - Remove `as any` type escape from ReplayPage All 516 tests passing, 0 TypeScript errors.
This commit is contained in:
@@ -4,16 +4,22 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import re
|
import re
|
||||||
from dataclasses import asdict
|
from dataclasses import asdict
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException, Query, Request
|
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||||
|
|
||||||
from app.analytics.queries import get_analytics
|
from app.analytics.queries import get_analytics
|
||||||
|
from app.api_utils import envelope
|
||||||
|
from app.auth import require_admin_api_key
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from psycopg_pool import AsyncConnectionPool
|
from psycopg_pool import AsyncConnectionPool
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/analytics", tags=["analytics"])
|
router = APIRouter(
|
||||||
|
prefix="/api/analytics",
|
||||||
|
tags=["analytics"],
|
||||||
|
dependencies=[Depends(require_admin_api_key)],
|
||||||
|
)
|
||||||
|
|
||||||
_RANGE_PATTERN = re.compile(r"^(\d+)d$")
|
_RANGE_PATTERN = re.compile(r"^(\d+)d$")
|
||||||
_DEFAULT_RANGE = "7d"
|
_DEFAULT_RANGE = "7d"
|
||||||
@@ -25,10 +31,6 @@ async def _get_pool(request: Request) -> AsyncConnectionPool:
|
|||||||
return request.app.state.pool
|
return request.app.state.pool
|
||||||
|
|
||||||
|
|
||||||
def _envelope(data: Any, *, success: bool = True, error: str | None = None) -> dict:
|
|
||||||
return {"success": success, "data": data, "error": error}
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_range(range_str: str) -> int:
|
def _parse_range(range_str: str) -> int:
|
||||||
"""Parse 'Xd' range string to integer days. Raises 400 on invalid format."""
|
"""Parse 'Xd' range string to integer days. Raises 400 on invalid format."""
|
||||||
match = _RANGE_PATTERN.match(range_str)
|
match = _RANGE_PATTERN.match(range_str)
|
||||||
@@ -55,4 +57,4 @@ async def analytics(
|
|||||||
range_days = _parse_range(range)
|
range_days = _parse_range(range)
|
||||||
pool = await _get_pool(request)
|
pool = await _get_pool(request)
|
||||||
result = await get_analytics(pool, range_days=range_days)
|
result = await get_analytics(pool, range_days=range_days)
|
||||||
return _envelope(asdict(result))
|
return envelope(asdict(result))
|
||||||
|
|||||||
10
backend/app/api_utils.py
Normal file
10
backend/app/api_utils.py
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
"""Shared API response helpers."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
def envelope(data: Any, *, success: bool = True, error: str | None = None) -> dict:
|
||||||
|
"""Wrap API response data in a standard envelope format."""
|
||||||
|
return {"success": success, "data": data, "error": error}
|
||||||
72
backend/app/auth.py
Normal file
72
backend/app/auth.py
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
"""API key authentication for admin endpoints and WebSocket connections."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import secrets
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import Depends, HTTPException, Query, Request, WebSocket, status
|
||||||
|
from fastapi.security import APIKeyHeader
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_API_KEY_HEADER = APIKeyHeader(name="X-API-Key", auto_error=False)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_admin_api_key(request: Request) -> str:
|
||||||
|
"""Retrieve the configured admin API key from app settings.
|
||||||
|
|
||||||
|
Returns empty string if settings are not configured (test/dev mode).
|
||||||
|
"""
|
||||||
|
settings = getattr(request.app.state, "settings", None)
|
||||||
|
if settings is None:
|
||||||
|
return ""
|
||||||
|
key = getattr(settings, "admin_api_key", "")
|
||||||
|
return key if isinstance(key, str) else ""
|
||||||
|
|
||||||
|
|
||||||
|
async def require_admin_api_key(
|
||||||
|
request: Request,
|
||||||
|
api_key: Annotated[str | None, Depends(_API_KEY_HEADER)] = None,
|
||||||
|
) -> None:
|
||||||
|
"""Dependency that enforces API key authentication on admin endpoints.
|
||||||
|
|
||||||
|
Skips validation when no admin_api_key is configured (dev mode).
|
||||||
|
"""
|
||||||
|
expected = _get_admin_api_key(request)
|
||||||
|
if not expected:
|
||||||
|
return
|
||||||
|
|
||||||
|
if api_key is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Missing X-API-Key header",
|
||||||
|
)
|
||||||
|
if not secrets.compare_digest(api_key, expected):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="Invalid API key",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def verify_ws_token(
|
||||||
|
ws: WebSocket,
|
||||||
|
token: str | None = Query(default=None),
|
||||||
|
) -> None:
|
||||||
|
"""Verify WebSocket connection token from query parameter.
|
||||||
|
|
||||||
|
Skips validation when no admin_api_key is configured (dev mode).
|
||||||
|
Usage: ws://host/ws?token=<api_key>
|
||||||
|
"""
|
||||||
|
settings = ws.app.state.settings
|
||||||
|
expected = settings.admin_api_key
|
||||||
|
if not expected:
|
||||||
|
return
|
||||||
|
|
||||||
|
if token is None or not secrets.compare_digest(token, expected):
|
||||||
|
await ws.close(code=4001, reason="Unauthorized")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="Invalid or missing WebSocket token",
|
||||||
|
)
|
||||||
@@ -32,6 +32,8 @@ class Settings(BaseSettings):
|
|||||||
|
|
||||||
template_name: str = ""
|
template_name: str = ""
|
||||||
|
|
||||||
|
admin_api_key: str = ""
|
||||||
|
|
||||||
anthropic_api_key: str = ""
|
anthropic_api_key: str = ""
|
||||||
openai_api_key: str = ""
|
openai_api_key: str = ""
|
||||||
azure_openai_api_key: str = ""
|
azure_openai_api_key: str = ""
|
||||||
|
|||||||
@@ -51,6 +51,15 @@ CREATE TABLE IF NOT EXISTS analytics_events (
|
|||||||
);
|
);
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
_SESSIONS_DDL = """
|
||||||
|
CREATE TABLE IF NOT EXISTS sessions (
|
||||||
|
thread_id TEXT PRIMARY KEY,
|
||||||
|
last_activity TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||||
|
has_pending_interrupt BOOLEAN NOT NULL DEFAULT FALSE,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
|
||||||
_CONVERSATIONS_MIGRATION_DDL = """
|
_CONVERSATIONS_MIGRATION_DDL = """
|
||||||
ALTER TABLE conversations
|
ALTER TABLE conversations
|
||||||
ADD COLUMN IF NOT EXISTS resolution_type TEXT,
|
ADD COLUMN IF NOT EXISTS resolution_type TEXT,
|
||||||
@@ -84,5 +93,6 @@ async def setup_app_tables(pool: AsyncConnectionPool) -> None:
|
|||||||
async with pool.connection() as conn:
|
async with pool.connection() as conn:
|
||||||
await conn.execute(_CONVERSATIONS_DDL)
|
await conn.execute(_CONVERSATIONS_DDL)
|
||||||
await conn.execute(_INTERRUPTS_DDL)
|
await conn.execute(_INTERRUPTS_DDL)
|
||||||
|
await conn.execute(_SESSIONS_DDL)
|
||||||
await conn.execute(_ANALYTICS_EVENTS_DDL)
|
await conn.execute(_ANALYTICS_EVENTS_DDL)
|
||||||
await conn.execute(_CONVERSATIONS_MIGRATION_DDL)
|
await conn.execute(_CONVERSATIONS_MIGRATION_DDL)
|
||||||
|
|||||||
@@ -9,13 +9,13 @@ from langchain.agents import create_agent
|
|||||||
from langgraph_supervisor import create_supervisor
|
from langgraph_supervisor import create_supervisor
|
||||||
|
|
||||||
from app.agents import get_tools_by_names
|
from app.agents import get_tools_by_names
|
||||||
|
from app.graph_context import GraphContext
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from langchain_core.language_models import BaseChatModel
|
from langchain_core.language_models import BaseChatModel
|
||||||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||||
from langgraph.graph.state import CompiledStateGraph
|
|
||||||
|
|
||||||
from app.intent import ClassificationResult, IntentClassifier
|
from app.intent import IntentClassifier
|
||||||
from app.registry import AgentRegistry
|
from app.registry import AgentRegistry
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -75,12 +75,11 @@ def build_graph(
|
|||||||
llm: BaseChatModel,
|
llm: BaseChatModel,
|
||||||
checkpointer: AsyncPostgresSaver,
|
checkpointer: AsyncPostgresSaver,
|
||||||
intent_classifier: IntentClassifier | None = None,
|
intent_classifier: IntentClassifier | None = None,
|
||||||
) -> CompiledStateGraph:
|
) -> GraphContext:
|
||||||
"""Build and compile the LangGraph supervisor graph.
|
"""Build and compile the LangGraph supervisor graph.
|
||||||
|
|
||||||
If an intent_classifier is provided, the supervisor prompt is enhanced
|
Returns a GraphContext that bundles the compiled graph with its
|
||||||
with agent descriptions for better routing. The classifier is stored
|
associated registry and intent classifier.
|
||||||
for use by the routing layer (ws_handler).
|
|
||||||
"""
|
"""
|
||||||
agent_nodes = build_agent_nodes(registry, llm)
|
agent_nodes = build_agent_nodes(registry, llm)
|
||||||
agent_descriptions = _format_agent_descriptions(registry)
|
agent_descriptions = _format_agent_descriptions(registry)
|
||||||
@@ -94,28 +93,10 @@ def build_graph(
|
|||||||
output_mode="full_history",
|
output_mode="full_history",
|
||||||
)
|
)
|
||||||
|
|
||||||
graph = workflow.compile(checkpointer=checkpointer)
|
compiled = workflow.compile(checkpointer=checkpointer)
|
||||||
|
|
||||||
# Attach classifier and registry to graph for use by ws_handler
|
return GraphContext(
|
||||||
graph.intent_classifier = intent_classifier # type: ignore[attr-defined]
|
graph=compiled,
|
||||||
graph.agent_registry = registry # type: ignore[attr-defined]
|
registry=registry,
|
||||||
|
intent_classifier=intent_classifier,
|
||||||
return graph
|
)
|
||||||
|
|
||||||
|
|
||||||
async def classify_intent(
|
|
||||||
graph: CompiledStateGraph,
|
|
||||||
message: str,
|
|
||||||
) -> ClassificationResult | None:
|
|
||||||
"""Classify user intent using the graph's attached classifier.
|
|
||||||
|
|
||||||
Returns None if no classifier is configured.
|
|
||||||
"""
|
|
||||||
classifier = getattr(graph, "intent_classifier", None)
|
|
||||||
registry = getattr(graph, "agent_registry", None)
|
|
||||||
|
|
||||||
if classifier is None or registry is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
agents = registry.list_agents()
|
|
||||||
return await classifier.classify(message, agents)
|
|
||||||
|
|||||||
36
backend/app/graph_context.py
Normal file
36
backend/app/graph_context.py
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
"""GraphContext -- typed wrapper around the compiled graph and its dependencies."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from langgraph.graph.state import CompiledStateGraph
|
||||||
|
|
||||||
|
from app.intent import ClassificationResult, IntentClassifier
|
||||||
|
from app.registry import AgentRegistry
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class GraphContext:
|
||||||
|
"""Bundles the compiled LangGraph graph with its associated services.
|
||||||
|
|
||||||
|
Replaces the previous pattern of monkey-patching attributes onto the
|
||||||
|
third-party CompiledStateGraph instance.
|
||||||
|
"""
|
||||||
|
|
||||||
|
graph: CompiledStateGraph
|
||||||
|
registry: AgentRegistry
|
||||||
|
intent_classifier: IntentClassifier | None = None
|
||||||
|
|
||||||
|
async def classify_intent(self, message: str) -> ClassificationResult | None:
|
||||||
|
"""Classify user intent using the attached classifier.
|
||||||
|
|
||||||
|
Returns None if no classifier is configured.
|
||||||
|
"""
|
||||||
|
if self.intent_classifier is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
agents = self.registry.list_agents()
|
||||||
|
return await self.intent_classifier.classify(message, agents)
|
||||||
@@ -1,10 +1,18 @@
|
|||||||
"""Interrupt TTL management -- tracks pending interrupts with auto-expiration."""
|
"""Interrupt TTL management -- tracks pending interrupts with auto-expiration.
|
||||||
|
|
||||||
|
Provides both in-memory (InterruptManager) and PostgreSQL-backed
|
||||||
|
(PgInterruptManager) implementations behind a common Protocol.
|
||||||
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from typing import TYPE_CHECKING, Protocol
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from psycopg_pool import AsyncConnectionPool
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
@@ -28,8 +36,32 @@ class InterruptStatus:
|
|||||||
record: InterruptRecord
|
record: InterruptRecord
|
||||||
|
|
||||||
|
|
||||||
|
class InterruptManagerProtocol(Protocol):
|
||||||
|
"""Protocol for interrupt TTL management."""
|
||||||
|
|
||||||
|
def register(self, thread_id: str, action: str, params: dict) -> InterruptRecord: ...
|
||||||
|
def check_status(self, thread_id: str) -> InterruptStatus | None: ...
|
||||||
|
def resolve(self, thread_id: str) -> None: ...
|
||||||
|
def has_pending(self, thread_id: str) -> bool: ...
|
||||||
|
def generate_retry_prompt(self, expired_record: InterruptRecord) -> dict: ...
|
||||||
|
|
||||||
|
|
||||||
|
def _build_retry_prompt(expired_record: InterruptRecord) -> dict:
|
||||||
|
"""Generate a WebSocket message prompting the user to retry an expired action."""
|
||||||
|
return {
|
||||||
|
"type": "interrupt_expired",
|
||||||
|
"thread_id": expired_record.thread_id,
|
||||||
|
"action": expired_record.action,
|
||||||
|
"message": (
|
||||||
|
f"The approval request for '{expired_record.action}' has expired "
|
||||||
|
f"after {expired_record.ttl_seconds // 60} minutes. "
|
||||||
|
f"Would you like to try again?"
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class InterruptManager:
|
class InterruptManager:
|
||||||
"""Manages interrupt TTL with auto-expiration.
|
"""In-memory interrupt manager for single-worker development.
|
||||||
|
|
||||||
Complements SessionManager -- this tracks interrupt-specific TTL
|
Complements SessionManager -- this tracks interrupt-specific TTL
|
||||||
while SessionManager handles session-level TTL.
|
while SessionManager handles session-level TTL.
|
||||||
@@ -62,11 +94,9 @@ class InterruptManager:
|
|||||||
record = self._interrupts.get(thread_id)
|
record = self._interrupts.get(thread_id)
|
||||||
if record is None:
|
if record is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
elapsed = time.time() - record.created_at
|
elapsed = time.time() - record.created_at
|
||||||
remaining = max(0.0, record.ttl_seconds - elapsed)
|
remaining = max(0.0, record.ttl_seconds - elapsed)
|
||||||
is_expired = elapsed > record.ttl_seconds
|
is_expired = elapsed > record.ttl_seconds
|
||||||
|
|
||||||
return InterruptStatus(
|
return InterruptStatus(
|
||||||
is_expired=is_expired,
|
is_expired=is_expired,
|
||||||
remaining_seconds=remaining,
|
remaining_seconds=remaining,
|
||||||
@@ -84,28 +114,17 @@ class InterruptManager:
|
|||||||
now = time.time()
|
now = time.time()
|
||||||
expired: list[InterruptRecord] = []
|
expired: list[InterruptRecord] = []
|
||||||
active: dict[str, InterruptRecord] = {}
|
active: dict[str, InterruptRecord] = {}
|
||||||
|
|
||||||
for thread_id, record in self._interrupts.items():
|
for thread_id, record in self._interrupts.items():
|
||||||
if now - record.created_at > record.ttl_seconds:
|
if now - record.created_at > record.ttl_seconds:
|
||||||
expired.append(record)
|
expired.append(record)
|
||||||
else:
|
else:
|
||||||
active[thread_id] = record
|
active[thread_id] = record
|
||||||
|
|
||||||
self._interrupts = active
|
self._interrupts = active
|
||||||
return tuple(expired)
|
return tuple(expired)
|
||||||
|
|
||||||
def generate_retry_prompt(self, expired_record: InterruptRecord) -> dict:
|
def generate_retry_prompt(self, expired_record: InterruptRecord) -> dict:
|
||||||
"""Generate a WebSocket message prompting the user to retry an expired action."""
|
"""Generate a WebSocket message prompting the user to retry an expired action."""
|
||||||
return {
|
return _build_retry_prompt(expired_record)
|
||||||
"type": "interrupt_expired",
|
|
||||||
"thread_id": expired_record.thread_id,
|
|
||||||
"action": expired_record.action,
|
|
||||||
"message": (
|
|
||||||
f"The approval request for '{expired_record.action}' has expired "
|
|
||||||
f"after {expired_record.ttl_seconds // 60} minutes. "
|
|
||||||
f"Would you like to try again?"
|
|
||||||
),
|
|
||||||
}
|
|
||||||
|
|
||||||
def has_pending(self, thread_id: str) -> bool:
|
def has_pending(self, thread_id: str) -> bool:
|
||||||
"""Check if a thread has a pending (non-expired) interrupt."""
|
"""Check if a thread has a pending (non-expired) interrupt."""
|
||||||
@@ -113,3 +132,137 @@ class InterruptManager:
|
|||||||
if status is None:
|
if status is None:
|
||||||
return False
|
return False
|
||||||
return not status.is_expired
|
return not status.is_expired
|
||||||
|
|
||||||
|
|
||||||
|
# Alias for explicit naming
|
||||||
|
InMemoryInterruptManager = InterruptManager
|
||||||
|
|
||||||
|
|
||||||
|
class PgInterruptManager:
|
||||||
|
"""PostgreSQL-backed interrupt manager for multi-worker production.
|
||||||
|
|
||||||
|
Uses the existing active_interrupts table defined in db.py.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
pool: AsyncConnectionPool,
|
||||||
|
ttl_seconds: int = 1800,
|
||||||
|
) -> None:
|
||||||
|
self._pool = pool
|
||||||
|
self._ttl_seconds = ttl_seconds
|
||||||
|
|
||||||
|
def register(
|
||||||
|
self,
|
||||||
|
thread_id: str,
|
||||||
|
action: str,
|
||||||
|
params: dict,
|
||||||
|
) -> InterruptRecord:
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
return asyncio.get_event_loop().run_until_complete(
|
||||||
|
self._register(thread_id, action, params)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _register(
|
||||||
|
self, thread_id: str, action: str, params: dict
|
||||||
|
) -> InterruptRecord:
|
||||||
|
import json
|
||||||
|
|
||||||
|
record = InterruptRecord(
|
||||||
|
interrupt_id=uuid.uuid4().hex,
|
||||||
|
thread_id=thread_id,
|
||||||
|
action=action,
|
||||||
|
params=dict(params),
|
||||||
|
created_at=time.time(),
|
||||||
|
ttl_seconds=self._ttl_seconds,
|
||||||
|
)
|
||||||
|
async with self._pool.connection() as conn:
|
||||||
|
await conn.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO active_interrupts (interrupt_id, thread_id, action, params)
|
||||||
|
VALUES (%(iid)s, %(tid)s, %(action)s, %(params)s)
|
||||||
|
ON CONFLICT (thread_id) WHERE resolved_at IS NULL
|
||||||
|
DO UPDATE SET
|
||||||
|
interrupt_id = %(iid)s,
|
||||||
|
action = %(action)s,
|
||||||
|
params = %(params)s,
|
||||||
|
created_at = NOW(),
|
||||||
|
resolved_at = NULL
|
||||||
|
""",
|
||||||
|
{
|
||||||
|
"iid": record.interrupt_id,
|
||||||
|
"tid": thread_id,
|
||||||
|
"action": action,
|
||||||
|
"params": json.dumps(params),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return record
|
||||||
|
|
||||||
|
def check_status(self, thread_id: str) -> InterruptStatus | None:
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
return asyncio.get_event_loop().run_until_complete(
|
||||||
|
self._check_status(thread_id)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _check_status(self, thread_id: str) -> InterruptStatus | None:
|
||||||
|
async with self._pool.connection() as conn:
|
||||||
|
cursor = await conn.execute(
|
||||||
|
"""
|
||||||
|
SELECT interrupt_id, action, params, created_at
|
||||||
|
FROM active_interrupts
|
||||||
|
WHERE thread_id = %(tid)s AND resolved_at IS NULL
|
||||||
|
ORDER BY created_at DESC LIMIT 1
|
||||||
|
""",
|
||||||
|
{"tid": thread_id},
|
||||||
|
)
|
||||||
|
row = await cursor.fetchone()
|
||||||
|
|
||||||
|
if row is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
created_at = row["created_at"].timestamp()
|
||||||
|
elapsed = time.time() - created_at
|
||||||
|
remaining = max(0.0, self._ttl_seconds - elapsed)
|
||||||
|
is_expired = elapsed > self._ttl_seconds
|
||||||
|
|
||||||
|
record = InterruptRecord(
|
||||||
|
interrupt_id=row["interrupt_id"],
|
||||||
|
thread_id=thread_id,
|
||||||
|
action=row["action"],
|
||||||
|
params=row["params"] if isinstance(row["params"], dict) else {},
|
||||||
|
created_at=created_at,
|
||||||
|
ttl_seconds=self._ttl_seconds,
|
||||||
|
)
|
||||||
|
|
||||||
|
return InterruptStatus(
|
||||||
|
is_expired=is_expired,
|
||||||
|
remaining_seconds=remaining,
|
||||||
|
record=record,
|
||||||
|
)
|
||||||
|
|
||||||
|
def resolve(self, thread_id: str) -> None:
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
asyncio.get_event_loop().run_until_complete(self._resolve(thread_id))
|
||||||
|
|
||||||
|
async def _resolve(self, thread_id: str) -> None:
|
||||||
|
async with self._pool.connection() as conn:
|
||||||
|
await conn.execute(
|
||||||
|
"""
|
||||||
|
UPDATE active_interrupts
|
||||||
|
SET resolved_at = NOW(), resolution = 'resolved'
|
||||||
|
WHERE thread_id = %(tid)s AND resolved_at IS NULL
|
||||||
|
""",
|
||||||
|
{"tid": thread_id},
|
||||||
|
)
|
||||||
|
|
||||||
|
def generate_retry_prompt(self, expired_record: InterruptRecord) -> dict:
|
||||||
|
return _build_retry_prompt(expired_record)
|
||||||
|
|
||||||
|
def has_pending(self, thread_id: str) -> bool:
|
||||||
|
status = self.check_status(thread_id)
|
||||||
|
if status is None:
|
||||||
|
return False
|
||||||
|
return not status.is_expired
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from contextlib import asynccontextmanager
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
from fastapi import Depends, FastAPI, Query, WebSocket, WebSocketDisconnect
|
||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
|
|
||||||
from app.analytics.api import router as analytics_router
|
from app.analytics.api import router as analytics_router
|
||||||
@@ -25,6 +25,7 @@ from app.openapi.review_api import router as openapi_router
|
|||||||
from app.registry import AgentRegistry
|
from app.registry import AgentRegistry
|
||||||
from app.replay.api import router as replay_router
|
from app.replay.api import router as replay_router
|
||||||
from app.session_manager import SessionManager
|
from app.session_manager import SessionManager
|
||||||
|
from app.ws_context import WebSocketContext
|
||||||
from app.ws_handler import dispatch_message
|
from app.ws_handler import dispatch_message
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -52,7 +53,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
|||||||
|
|
||||||
llm = create_llm(settings)
|
llm = create_llm(settings)
|
||||||
intent_classifier = LLMIntentClassifier(llm)
|
intent_classifier = LLMIntentClassifier(llm)
|
||||||
graph = build_graph(registry, llm, checkpointer, intent_classifier=intent_classifier)
|
graph_ctx = build_graph(registry, llm, checkpointer, intent_classifier=intent_classifier)
|
||||||
|
|
||||||
session_manager = SessionManager(
|
session_manager = SessionManager(
|
||||||
session_ttl_seconds=settings.session_ttl_minutes * 60,
|
session_ttl_seconds=settings.session_ttl_minutes * 60,
|
||||||
@@ -71,7 +72,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
|||||||
else:
|
else:
|
||||||
escalator = NoOpEscalator()
|
escalator = NoOpEscalator()
|
||||||
|
|
||||||
app.state.graph = graph
|
app.state.graph_ctx = graph_ctx
|
||||||
app.state.session_manager = session_manager
|
app.state.session_manager = session_manager
|
||||||
app.state.interrupt_manager = interrupt_manager
|
app.state.interrupt_manager = interrupt_manager
|
||||||
app.state.escalator = escalator
|
app.state.escalator = escalator
|
||||||
@@ -93,7 +94,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
|||||||
await pool.close()
|
await pool.close()
|
||||||
|
|
||||||
|
|
||||||
_VERSION = "0.5.0"
|
_VERSION = "0.6.0"
|
||||||
|
|
||||||
app = FastAPI(title="Smart Support", version=_VERSION, lifespan=lifespan)
|
app = FastAPI(title="Smart Support", version=_VERSION, lifespan=lifespan)
|
||||||
|
|
||||||
@@ -109,28 +110,37 @@ def health_check() -> dict:
|
|||||||
|
|
||||||
|
|
||||||
@app.websocket("/ws")
|
@app.websocket("/ws")
|
||||||
async def websocket_endpoint(ws: WebSocket) -> None:
|
async def websocket_endpoint(
|
||||||
await ws.accept()
|
ws: WebSocket,
|
||||||
graph = app.state.graph
|
token: str | None = Query(default=None),
|
||||||
session_manager = app.state.session_manager
|
) -> None:
|
||||||
interrupt_manager = app.state.interrupt_manager
|
|
||||||
settings = app.state.settings
|
settings = app.state.settings
|
||||||
|
|
||||||
|
# Verify WebSocket token when admin_api_key is configured
|
||||||
|
if settings.admin_api_key:
|
||||||
|
import secrets as _secrets
|
||||||
|
|
||||||
|
if token is None or not _secrets.compare_digest(token, settings.admin_api_key):
|
||||||
|
await ws.close(code=4001, reason="Unauthorized")
|
||||||
|
return
|
||||||
|
|
||||||
|
await ws.accept()
|
||||||
callback_handler = TokenUsageCallbackHandler(model_name=settings.llm_model)
|
callback_handler = TokenUsageCallbackHandler(model_name=settings.llm_model)
|
||||||
|
|
||||||
analytics_recorder = app.state.analytics_recorder
|
ws_ctx = WebSocketContext(
|
||||||
conversation_tracker = app.state.conversation_tracker
|
graph_ctx=app.state.graph_ctx,
|
||||||
pool = app.state.pool
|
session_manager=app.state.session_manager,
|
||||||
|
callback_handler=callback_handler,
|
||||||
|
interrupt_manager=app.state.interrupt_manager,
|
||||||
|
analytics_recorder=app.state.analytics_recorder,
|
||||||
|
conversation_tracker=app.state.conversation_tracker,
|
||||||
|
pool=app.state.pool,
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
raw_data = await ws.receive_text()
|
raw_data = await ws.receive_text()
|
||||||
await dispatch_message(
|
await dispatch_message(ws, ws_ctx, raw_data)
|
||||||
ws, graph, session_manager, callback_handler, raw_data,
|
|
||||||
interrupt_manager=interrupt_manager,
|
|
||||||
analytics_recorder=analytics_recorder,
|
|
||||||
conversation_tracker=conversation_tracker,
|
|
||||||
pool=pool,
|
|
||||||
)
|
|
||||||
except WebSocketDisconnect:
|
except WebSocketDisconnect:
|
||||||
logger.info("WebSocket client disconnected")
|
logger.info("WebSocket client disconnected")
|
||||||
|
|
||||||
|
|||||||
@@ -15,15 +15,21 @@ import re
|
|||||||
import uuid
|
import uuid
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
from fastapi import APIRouter, BackgroundTasks, HTTPException
|
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
|
||||||
from pydantic import BaseModel, field_validator
|
from pydantic import BaseModel, field_validator
|
||||||
|
|
||||||
|
from app.auth import require_admin_api_key
|
||||||
|
from app.openapi.generator import generate_agent_yaml, generate_tool_code
|
||||||
from app.openapi.importer import ImportOrchestrator
|
from app.openapi.importer import ImportOrchestrator
|
||||||
from app.openapi.models import ClassificationResult, ImportJob
|
from app.openapi.models import ClassificationResult, ImportJob
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/openapi", tags=["openapi"])
|
router = APIRouter(
|
||||||
|
prefix="/api/openapi",
|
||||||
|
tags=["openapi"],
|
||||||
|
dependencies=[Depends(require_admin_api_key)],
|
||||||
|
)
|
||||||
|
|
||||||
# In-memory store: job_id -> job dict, guarded by async lock
|
# In-memory store: job_id -> job dict, guarded by async lock
|
||||||
_job_store: dict[str, dict] = {}
|
_job_store: dict[str, dict] = {}
|
||||||
@@ -235,11 +241,42 @@ async def update_classification(
|
|||||||
|
|
||||||
@router.post("/jobs/{job_id}/approve")
|
@router.post("/jobs/{job_id}/approve")
|
||||||
async def approve_job(job_id: str) -> dict:
|
async def approve_job(job_id: str) -> dict:
|
||||||
"""Approve a job's classifications and trigger tool generation."""
|
"""Approve a job's classifications and trigger tool generation.
|
||||||
|
|
||||||
|
Generates Python tool code for each classified endpoint and
|
||||||
|
produces an agent YAML configuration snippet.
|
||||||
|
"""
|
||||||
job = _job_store.get(job_id)
|
job = _job_store.get(job_id)
|
||||||
if job is None:
|
if job is None:
|
||||||
raise HTTPException(status_code=404, detail=f"Job '{job_id}' not found")
|
raise HTTPException(status_code=404, detail=f"Job '{job_id}' not found")
|
||||||
|
|
||||||
updated_job = {**job, "status": "approved"}
|
classifications: list[ClassificationResult] = job.get("classifications", [])
|
||||||
|
if not classifications:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail="No classifications to approve. Import must complete first.",
|
||||||
|
)
|
||||||
|
|
||||||
|
base_url = job["spec_url"].rsplit("/", 1)[0]
|
||||||
|
generated_tools = []
|
||||||
|
for clf in classifications:
|
||||||
|
tool = generate_tool_code(clf, base_url)
|
||||||
|
generated_tools.append({
|
||||||
|
"function_name": tool.function_name,
|
||||||
|
"agent_group": clf.agent_group,
|
||||||
|
"code": tool.code,
|
||||||
|
})
|
||||||
|
|
||||||
|
agent_yaml = generate_agent_yaml(tuple(classifications), base_url)
|
||||||
|
|
||||||
|
updated_job = {
|
||||||
|
**job,
|
||||||
|
"status": "approved",
|
||||||
|
"generated_tools": generated_tools,
|
||||||
|
"agent_yaml": agent_yaml,
|
||||||
|
}
|
||||||
_job_store[job_id] = updated_job
|
_job_store[job_id] = updated_job
|
||||||
return _job_to_response(updated_job)
|
|
||||||
|
response = _job_to_response(updated_job)
|
||||||
|
response["generated_tools_count"] = len(generated_tools)
|
||||||
|
return response
|
||||||
|
|||||||
@@ -3,16 +3,23 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import re
|
import re
|
||||||
from typing import TYPE_CHECKING, Annotated, Any
|
from typing import TYPE_CHECKING, Annotated
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException, Query, Request
|
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||||
|
|
||||||
|
from app.api_utils import envelope
|
||||||
|
from app.auth import require_admin_api_key
|
||||||
|
|
||||||
_THREAD_ID_PATTERN = re.compile(r"^[a-zA-Z0-9\-_]{1,128}$")
|
_THREAD_ID_PATTERN = re.compile(r"^[a-zA-Z0-9\-_]{1,128}$")
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from psycopg_pool import AsyncConnectionPool
|
from psycopg_pool import AsyncConnectionPool
|
||||||
|
|
||||||
router = APIRouter(prefix="/api", tags=["replay"])
|
router = APIRouter(
|
||||||
|
prefix="/api",
|
||||||
|
tags=["replay"],
|
||||||
|
dependencies=[Depends(require_admin_api_key)],
|
||||||
|
)
|
||||||
|
|
||||||
_COUNT_CONVERSATIONS_SQL = """
|
_COUNT_CONVERSATIONS_SQL = """
|
||||||
SELECT COUNT(*) FROM conversations
|
SELECT COUNT(*) FROM conversations
|
||||||
@@ -38,10 +45,6 @@ async def get_pool(request: Request) -> AsyncConnectionPool:
|
|||||||
return request.app.state.pool
|
return request.app.state.pool
|
||||||
|
|
||||||
|
|
||||||
def _envelope(data: Any, *, success: bool = True, error: str | None = None) -> dict:
|
|
||||||
return {"success": success, "data": data, "error": error}
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/conversations")
|
@router.get("/conversations")
|
||||||
async def list_conversations(
|
async def list_conversations(
|
||||||
request: Request,
|
request: Request,
|
||||||
@@ -62,7 +65,7 @@ async def list_conversations(
|
|||||||
)
|
)
|
||||||
rows = await cursor.fetchall()
|
rows = await cursor.fetchall()
|
||||||
|
|
||||||
return _envelope({
|
return envelope({
|
||||||
"conversations": [dict(row) for row in rows],
|
"conversations": [dict(row) for row in rows],
|
||||||
"total": total,
|
"total": total,
|
||||||
"page": page,
|
"page": page,
|
||||||
@@ -119,4 +122,4 @@ async def get_replay(
|
|||||||
for s in page_steps
|
for s in page_steps
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
return _envelope(data)
|
return envelope(data)
|
||||||
|
|||||||
@@ -1,9 +1,18 @@
|
|||||||
"""Session TTL management with sliding window and interrupt extension."""
|
"""Session TTL management with sliding window and interrupt extension.
|
||||||
|
|
||||||
|
Provides both in-memory (SessionManager) and PostgreSQL-backed
|
||||||
|
(PgSessionManager) implementations behind a common Protocol.
|
||||||
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import time
|
import time
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import TYPE_CHECKING, Protocol
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from psycopg_pool import AsyncConnectionPool
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
@@ -13,8 +22,19 @@ class SessionState:
|
|||||||
has_pending_interrupt: bool
|
has_pending_interrupt: bool
|
||||||
|
|
||||||
|
|
||||||
|
class SessionManagerProtocol(Protocol):
|
||||||
|
"""Protocol for session TTL management."""
|
||||||
|
|
||||||
|
def touch(self, thread_id: str) -> SessionState: ...
|
||||||
|
def is_expired(self, thread_id: str) -> bool: ...
|
||||||
|
def extend_for_interrupt(self, thread_id: str) -> SessionState: ...
|
||||||
|
def resolve_interrupt(self, thread_id: str) -> SessionState: ...
|
||||||
|
def get_state(self, thread_id: str) -> SessionState | None: ...
|
||||||
|
def remove(self, thread_id: str) -> None: ...
|
||||||
|
|
||||||
|
|
||||||
class SessionManager:
|
class SessionManager:
|
||||||
"""Manages session TTL with sliding window and interrupt extensions.
|
"""In-memory session manager for single-worker development.
|
||||||
|
|
||||||
- Each message resets the TTL (sliding window).
|
- Each message resets the TTL (sliding window).
|
||||||
- A pending interrupt suspends expiration until resolved.
|
- A pending interrupt suspends expiration until resolved.
|
||||||
@@ -40,10 +60,8 @@ class SessionManager:
|
|||||||
state = self._sessions.get(thread_id)
|
state = self._sessions.get(thread_id)
|
||||||
if state is None:
|
if state is None:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
if state.has_pending_interrupt:
|
if state.has_pending_interrupt:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
elapsed = time.time() - state.last_activity
|
elapsed = time.time() - state.last_activity
|
||||||
return elapsed > self._session_ttl
|
return elapsed > self._session_ttl
|
||||||
|
|
||||||
@@ -52,7 +70,6 @@ class SessionManager:
|
|||||||
existing = self._sessions.get(thread_id)
|
existing = self._sessions.get(thread_id)
|
||||||
if existing is None:
|
if existing is None:
|
||||||
return self.touch(thread_id)
|
return self.touch(thread_id)
|
||||||
|
|
||||||
new_state = SessionState(
|
new_state = SessionState(
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
last_activity=existing.last_activity,
|
last_activity=existing.last_activity,
|
||||||
@@ -76,3 +93,120 @@ class SessionManager:
|
|||||||
|
|
||||||
def remove(self, thread_id: str) -> None:
|
def remove(self, thread_id: str) -> None:
|
||||||
self._sessions = {k: v for k, v in self._sessions.items() if k != thread_id}
|
self._sessions = {k: v for k, v in self._sessions.items() if k != thread_id}
|
||||||
|
|
||||||
|
|
||||||
|
# Alias for explicit naming
|
||||||
|
InMemorySessionManager = SessionManager
|
||||||
|
|
||||||
|
|
||||||
|
class PgSessionManager:
|
||||||
|
"""PostgreSQL-backed session manager for multi-worker production."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
pool: AsyncConnectionPool,
|
||||||
|
session_ttl_seconds: int = 1800,
|
||||||
|
) -> None:
|
||||||
|
self._pool = pool
|
||||||
|
self._session_ttl = session_ttl_seconds
|
||||||
|
|
||||||
|
def touch(self, thread_id: str) -> SessionState:
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
return asyncio.get_event_loop().run_until_complete(self._touch(thread_id))
|
||||||
|
|
||||||
|
async def _touch(self, thread_id: str) -> SessionState:
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
async with self._pool.connection() as conn:
|
||||||
|
await conn.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO sessions (thread_id, last_activity, has_pending_interrupt)
|
||||||
|
VALUES (%(tid)s, %(now)s, FALSE)
|
||||||
|
ON CONFLICT (thread_id) DO UPDATE
|
||||||
|
SET last_activity = %(now)s
|
||||||
|
""",
|
||||||
|
{"tid": thread_id, "now": now},
|
||||||
|
)
|
||||||
|
return SessionState(
|
||||||
|
thread_id=thread_id,
|
||||||
|
last_activity=now.timestamp(),
|
||||||
|
has_pending_interrupt=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def is_expired(self, thread_id: str) -> bool:
|
||||||
|
state = self.get_state(thread_id)
|
||||||
|
if state is None:
|
||||||
|
return True
|
||||||
|
if state.has_pending_interrupt:
|
||||||
|
return False
|
||||||
|
elapsed = time.time() - state.last_activity
|
||||||
|
return elapsed > self._session_ttl
|
||||||
|
|
||||||
|
def extend_for_interrupt(self, thread_id: str) -> SessionState:
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
return asyncio.get_event_loop().run_until_complete(
|
||||||
|
self._set_interrupt(thread_id, True)
|
||||||
|
)
|
||||||
|
|
||||||
|
def resolve_interrupt(self, thread_id: str) -> SessionState:
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
return asyncio.get_event_loop().run_until_complete(
|
||||||
|
self._set_interrupt(thread_id, False)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _set_interrupt(
|
||||||
|
self, thread_id: str, has_interrupt: bool
|
||||||
|
) -> SessionState:
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
async with self._pool.connection() as conn:
|
||||||
|
await conn.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO sessions (thread_id, last_activity, has_pending_interrupt)
|
||||||
|
VALUES (%(tid)s, %(now)s, %(interrupt)s)
|
||||||
|
ON CONFLICT (thread_id) DO UPDATE
|
||||||
|
SET last_activity = %(now)s,
|
||||||
|
has_pending_interrupt = %(interrupt)s
|
||||||
|
""",
|
||||||
|
{"tid": thread_id, "now": now, "interrupt": has_interrupt},
|
||||||
|
)
|
||||||
|
return SessionState(
|
||||||
|
thread_id=thread_id,
|
||||||
|
last_activity=now.timestamp(),
|
||||||
|
has_pending_interrupt=has_interrupt,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_state(self, thread_id: str) -> SessionState | None:
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
return asyncio.get_event_loop().run_until_complete(
|
||||||
|
self._get_state(thread_id)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _get_state(self, thread_id: str) -> SessionState | None:
|
||||||
|
async with self._pool.connection() as conn:
|
||||||
|
cursor = await conn.execute(
|
||||||
|
"SELECT last_activity, has_pending_interrupt FROM sessions WHERE thread_id = %(tid)s",
|
||||||
|
{"tid": thread_id},
|
||||||
|
)
|
||||||
|
row = await cursor.fetchone()
|
||||||
|
if row is None:
|
||||||
|
return None
|
||||||
|
return SessionState(
|
||||||
|
thread_id=thread_id,
|
||||||
|
last_activity=row["last_activity"].timestamp(),
|
||||||
|
has_pending_interrupt=row["has_pending_interrupt"],
|
||||||
|
)
|
||||||
|
|
||||||
|
def remove(self, thread_id: str) -> None:
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
asyncio.get_event_loop().run_until_complete(self._remove(thread_id))
|
||||||
|
|
||||||
|
async def _remove(self, thread_id: str) -> None:
|
||||||
|
async with self._pool.connection() as conn:
|
||||||
|
await conn.execute(
|
||||||
|
"DELETE FROM sessions WHERE thread_id = %(tid)s",
|
||||||
|
{"tid": thread_id},
|
||||||
|
)
|
||||||
|
|||||||
30
backend/app/ws_context.py
Normal file
30
backend/app/ws_context.py
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
"""WebSocketContext -- bundles all dependencies needed by dispatch_message."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from app.analytics.event_recorder import AnalyticsRecorder
|
||||||
|
from app.callbacks import TokenUsageCallbackHandler
|
||||||
|
from app.conversation_tracker import ConversationTrackerProtocol
|
||||||
|
from app.graph_context import GraphContext
|
||||||
|
from app.interrupt_manager import InterruptManager
|
||||||
|
from app.session_manager import SessionManager
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class WebSocketContext:
|
||||||
|
"""All dependencies required for WebSocket message processing.
|
||||||
|
|
||||||
|
Replaces the previous 9-parameter function signature in dispatch_message.
|
||||||
|
"""
|
||||||
|
|
||||||
|
graph_ctx: GraphContext
|
||||||
|
session_manager: SessionManager
|
||||||
|
callback_handler: TokenUsageCallbackHandler
|
||||||
|
interrupt_manager: InterruptManager | None = None
|
||||||
|
analytics_recorder: AnalyticsRecorder | None = None
|
||||||
|
conversation_tracker: ConversationTrackerProtocol | None = None
|
||||||
|
pool: Any = None
|
||||||
@@ -7,22 +7,19 @@ import logging
|
|||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from langchain_core.messages import HumanMessage
|
from langchain_core.messages import HumanMessage
|
||||||
from langgraph.types import Command
|
from langgraph.types import Command
|
||||||
|
|
||||||
from app.graph import classify_intent
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from fastapi import WebSocket
|
from fastapi import WebSocket
|
||||||
from langgraph.graph.state import CompiledStateGraph
|
|
||||||
|
|
||||||
from app.analytics.event_recorder import AnalyticsRecorder
|
|
||||||
from app.callbacks import TokenUsageCallbackHandler
|
from app.callbacks import TokenUsageCallbackHandler
|
||||||
from app.conversation_tracker import ConversationTrackerProtocol
|
from app.graph_context import GraphContext
|
||||||
from app.interrupt_manager import InterruptManager
|
from app.interrupt_manager import InterruptManager
|
||||||
from app.session_manager import SessionManager
|
from app.session_manager import SessionManager
|
||||||
|
from app.ws_context import WebSocketContext
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -46,7 +43,7 @@ def _evict_stale_threads(cutoff: float) -> None:
|
|||||||
|
|
||||||
async def handle_user_message(
|
async def handle_user_message(
|
||||||
ws: WebSocket,
|
ws: WebSocket,
|
||||||
graph: CompiledStateGraph,
|
ctx: GraphContext,
|
||||||
session_manager: SessionManager,
|
session_manager: SessionManager,
|
||||||
callback_handler: TokenUsageCallbackHandler,
|
callback_handler: TokenUsageCallbackHandler,
|
||||||
thread_id: str,
|
thread_id: str,
|
||||||
@@ -54,8 +51,6 @@ async def handle_user_message(
|
|||||||
interrupt_manager: InterruptManager | None = None,
|
interrupt_manager: InterruptManager | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Process a user message through the graph and stream results back."""
|
"""Process a user message through the graph and stream results back."""
|
||||||
# Touch first so new sessions are created before expiry check.
|
|
||||||
# For existing sessions, touch resets the sliding window.
|
|
||||||
existing = session_manager.get_state(thread_id)
|
existing = session_manager.get_state(thread_id)
|
||||||
if existing is not None and session_manager.is_expired(thread_id):
|
if existing is not None and session_manager.is_expired(thread_id):
|
||||||
msg = "Session expired. Please start a new conversation."
|
msg = "Session expired. Please start a new conversation."
|
||||||
@@ -64,8 +59,7 @@ async def handle_user_message(
|
|||||||
|
|
||||||
session_manager.touch(thread_id)
|
session_manager.touch(thread_id)
|
||||||
|
|
||||||
# Run intent classification if available (for logging/future multi-intent)
|
classification = await ctx.classify_intent(content)
|
||||||
classification = await classify_intent(graph, content)
|
|
||||||
if classification is not None:
|
if classification is not None:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Intent classification for thread %s: ambiguous=%s, intents=%s",
|
"Intent classification for thread %s: ambiguous=%s, intents=%s",
|
||||||
@@ -74,7 +68,6 @@ async def handle_user_message(
|
|||||||
[i.agent_name for i in classification.intents],
|
[i.agent_name for i in classification.intents],
|
||||||
)
|
)
|
||||||
|
|
||||||
# If ambiguous, send clarification and return
|
|
||||||
if classification.is_ambiguous and classification.clarification_question:
|
if classification.is_ambiguous and classification.clarification_question:
|
||||||
await _send_json(
|
await _send_json(
|
||||||
ws,
|
ws,
|
||||||
@@ -89,7 +82,6 @@ async def handle_user_message(
|
|||||||
|
|
||||||
config = {"configurable": {"thread_id": thread_id}, "callbacks": [callback_handler]}
|
config = {"configurable": {"thread_id": thread_id}, "callbacks": [callback_handler]}
|
||||||
|
|
||||||
# If multi-intent detected, add routing hint to the message
|
|
||||||
if classification and len(classification.intents) > 1:
|
if classification and len(classification.intents) > 1:
|
||||||
agent_names = [i.agent_name for i in classification.intents]
|
agent_names = [i.agent_name for i in classification.intents]
|
||||||
hint = (
|
hint = (
|
||||||
@@ -101,7 +93,7 @@ async def handle_user_message(
|
|||||||
input_msg = {"messages": [HumanMessage(content=content)]}
|
input_msg = {"messages": [HumanMessage(content=content)]}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async for chunk in graph.astream(input_msg, config=config, stream_mode="messages"):
|
async for chunk in ctx.graph.astream(input_msg, config=config, stream_mode="messages"):
|
||||||
msg_chunk, metadata = chunk
|
msg_chunk, metadata = chunk
|
||||||
node = metadata.get("langgraph_node", "")
|
node = metadata.get("langgraph_node", "")
|
||||||
|
|
||||||
@@ -126,12 +118,11 @@ async def handle_user_message(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
state = await graph.aget_state(config)
|
state = await ctx.graph.aget_state(config)
|
||||||
if _has_interrupt(state):
|
if _has_interrupt(state):
|
||||||
interrupt_data = _extract_interrupt(state)
|
interrupt_data = _extract_interrupt(state)
|
||||||
session_manager.extend_for_interrupt(thread_id)
|
session_manager.extend_for_interrupt(thread_id)
|
||||||
|
|
||||||
# Register interrupt with TTL tracking
|
|
||||||
if interrupt_manager is not None:
|
if interrupt_manager is not None:
|
||||||
interrupt_manager.register(
|
interrupt_manager.register(
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
@@ -158,7 +149,7 @@ async def handle_user_message(
|
|||||||
|
|
||||||
async def handle_interrupt_response(
|
async def handle_interrupt_response(
|
||||||
ws: WebSocket,
|
ws: WebSocket,
|
||||||
graph: CompiledStateGraph,
|
ctx: GraphContext,
|
||||||
session_manager: SessionManager,
|
session_manager: SessionManager,
|
||||||
callback_handler: TokenUsageCallbackHandler,
|
callback_handler: TokenUsageCallbackHandler,
|
||||||
thread_id: str,
|
thread_id: str,
|
||||||
@@ -166,7 +157,6 @@ async def handle_interrupt_response(
|
|||||||
interrupt_manager: InterruptManager | None = None,
|
interrupt_manager: InterruptManager | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Resume graph execution after interrupt approval/rejection."""
|
"""Resume graph execution after interrupt approval/rejection."""
|
||||||
# Check interrupt TTL before resuming
|
|
||||||
if interrupt_manager is not None:
|
if interrupt_manager is not None:
|
||||||
status = interrupt_manager.check_status(thread_id)
|
status = interrupt_manager.check_status(thread_id)
|
||||||
if status is not None and status.is_expired:
|
if status is not None and status.is_expired:
|
||||||
@@ -184,7 +174,7 @@ async def handle_interrupt_response(
|
|||||||
config = {"configurable": {"thread_id": thread_id}, "callbacks": [callback_handler]}
|
config = {"configurable": {"thread_id": thread_id}, "callbacks": [callback_handler]}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async for chunk in graph.astream(
|
async for chunk in ctx.graph.astream(
|
||||||
Command(resume=approved),
|
Command(resume=approved),
|
||||||
config=config,
|
config=config,
|
||||||
stream_mode="messages",
|
stream_mode="messages",
|
||||||
@@ -212,14 +202,8 @@ async def handle_interrupt_response(
|
|||||||
|
|
||||||
async def dispatch_message(
|
async def dispatch_message(
|
||||||
ws: WebSocket,
|
ws: WebSocket,
|
||||||
graph: CompiledStateGraph,
|
ctx: WebSocketContext,
|
||||||
session_manager: SessionManager,
|
|
||||||
callback_handler: TokenUsageCallbackHandler,
|
|
||||||
raw_data: str,
|
raw_data: str,
|
||||||
interrupt_manager: InterruptManager | None = None,
|
|
||||||
analytics_recorder: AnalyticsRecorder | None = None,
|
|
||||||
conversation_tracker: ConversationTrackerProtocol | None = None,
|
|
||||||
pool: Any = None,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Parse and route an incoming WebSocket message."""
|
"""Parse and route an incoming WebSocket message."""
|
||||||
if len(raw_data) > MAX_MESSAGE_SIZE:
|
if len(raw_data) > MAX_MESSAGE_SIZE:
|
||||||
@@ -268,14 +252,15 @@ async def dispatch_message(
|
|||||||
_thread_timestamps[thread_id] = [*recent, now]
|
_thread_timestamps[thread_id] = [*recent, now]
|
||||||
|
|
||||||
await handle_user_message(
|
await handle_user_message(
|
||||||
ws, graph, session_manager, callback_handler, thread_id, content,
|
ws, ctx.graph_ctx, ctx.session_manager, ctx.callback_handler,
|
||||||
interrupt_manager=interrupt_manager,
|
thread_id, content,
|
||||||
|
interrupt_manager=ctx.interrupt_manager,
|
||||||
)
|
)
|
||||||
await _fire_and_forget_tracking(
|
await _fire_and_forget_tracking(
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
pool=pool,
|
pool=ctx.pool,
|
||||||
analytics_recorder=analytics_recorder,
|
analytics_recorder=ctx.analytics_recorder,
|
||||||
conversation_tracker=conversation_tracker,
|
conversation_tracker=ctx.conversation_tracker,
|
||||||
agent_name=None,
|
agent_name=None,
|
||||||
tokens=0,
|
tokens=0,
|
||||||
cost=0.0,
|
cost=0.0,
|
||||||
@@ -284,8 +269,9 @@ async def dispatch_message(
|
|||||||
elif msg_type == "interrupt_response":
|
elif msg_type == "interrupt_response":
|
||||||
approved = data.get("approved", False)
|
approved = data.get("approved", False)
|
||||||
await handle_interrupt_response(
|
await handle_interrupt_response(
|
||||||
ws, graph, session_manager, callback_handler, thread_id, approved,
|
ws, ctx.graph_ctx, ctx.session_manager, ctx.callback_handler,
|
||||||
interrupt_manager=interrupt_manager,
|
thread_id, approved,
|
||||||
|
interrupt_manager=ctx.interrupt_manager,
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
@@ -294,9 +280,9 @@ async def dispatch_message(
|
|||||||
|
|
||||||
async def _fire_and_forget_tracking(
|
async def _fire_and_forget_tracking(
|
||||||
thread_id: str,
|
thread_id: str,
|
||||||
pool: Any,
|
pool: object,
|
||||||
analytics_recorder: Any | None,
|
analytics_recorder: object | None,
|
||||||
conversation_tracker: Any | None,
|
conversation_tracker: object | None,
|
||||||
agent_name: str | None,
|
agent_name: str | None,
|
||||||
tokens: int,
|
tokens: int,
|
||||||
cost: float,
|
cost: float,
|
||||||
|
|||||||
@@ -13,10 +13,12 @@ from httpx import ASGITransport, AsyncClient
|
|||||||
|
|
||||||
from app.analytics.api import router as analytics_router
|
from app.analytics.api import router as analytics_router
|
||||||
from app.callbacks import TokenUsageCallbackHandler
|
from app.callbacks import TokenUsageCallbackHandler
|
||||||
|
from app.graph_context import GraphContext
|
||||||
from app.interrupt_manager import InterruptManager
|
from app.interrupt_manager import InterruptManager
|
||||||
from app.openapi.review_api import _job_store, router as openapi_router
|
from app.openapi.review_api import _job_store, router as openapi_router
|
||||||
from app.replay.api import router as replay_router
|
from app.replay.api import router as replay_router
|
||||||
from app.session_manager import SessionManager
|
from app.session_manager import SessionManager
|
||||||
|
from app.ws_context import WebSocketContext
|
||||||
from app.ws_handler import dispatch_message
|
from app.ws_handler import dispatch_message
|
||||||
|
|
||||||
|
|
||||||
@@ -74,8 +76,6 @@ def make_graph(
|
|||||||
) -> MagicMock:
|
) -> MagicMock:
|
||||||
"""Build a mock LangGraph CompiledStateGraph."""
|
"""Build a mock LangGraph CompiledStateGraph."""
|
||||||
g = MagicMock()
|
g = MagicMock()
|
||||||
g.intent_classifier = None
|
|
||||||
g.agent_registry = None
|
|
||||||
|
|
||||||
if state is None:
|
if state is None:
|
||||||
state = make_state()
|
state = make_state()
|
||||||
@@ -93,6 +93,14 @@ def make_graph(
|
|||||||
return g
|
return g
|
||||||
|
|
||||||
|
|
||||||
|
def make_graph_ctx(graph: MagicMock | None = None) -> GraphContext:
|
||||||
|
"""Build a GraphContext wrapping a mock graph."""
|
||||||
|
g = graph or make_graph()
|
||||||
|
registry = MagicMock()
|
||||||
|
registry.list_agents = MagicMock(return_value=())
|
||||||
|
return GraphContext(graph=g, registry=registry, intent_classifier=None)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Fake database pool
|
# Fake database pool
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -148,6 +156,7 @@ def create_e2e_app(
|
|||||||
) -> FastAPI:
|
) -> FastAPI:
|
||||||
"""Create a FastAPI app wired with mocked dependencies for E2E testing."""
|
"""Create a FastAPI app wired with mocked dependencies for E2E testing."""
|
||||||
g = graph or make_graph()
|
g = graph or make_graph()
|
||||||
|
graph_ctx = make_graph_ctx(g)
|
||||||
p = pool or FakePool()
|
p = pool or FakePool()
|
||||||
sm = SessionManager(session_ttl_seconds=session_ttl)
|
sm = SessionManager(session_ttl_seconds=session_ttl)
|
||||||
im = InterruptManager(ttl_seconds=interrupt_ttl)
|
im = InterruptManager(ttl_seconds=interrupt_ttl)
|
||||||
@@ -157,7 +166,7 @@ def create_e2e_app(
|
|||||||
app.include_router(replay_router)
|
app.include_router(replay_router)
|
||||||
app.include_router(analytics_router)
|
app.include_router(analytics_router)
|
||||||
|
|
||||||
app.state.graph = g
|
app.state.graph_ctx = graph_ctx
|
||||||
app.state.session_manager = sm
|
app.state.session_manager = sm
|
||||||
app.state.interrupt_manager = im
|
app.state.interrupt_manager = im
|
||||||
app.state.pool = p
|
app.state.pool = p
|
||||||
@@ -175,17 +184,16 @@ def create_e2e_app(
|
|||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
raw_data = await ws.receive_text()
|
raw_data = await ws.receive_text()
|
||||||
await dispatch_message(
|
ws_ctx = WebSocketContext(
|
||||||
ws,
|
graph_ctx=app.state.graph_ctx,
|
||||||
app.state.graph,
|
session_manager=app.state.session_manager,
|
||||||
app.state.session_manager,
|
callback_handler=TokenUsageCallbackHandler(model_name="test-model"),
|
||||||
TokenUsageCallbackHandler(model_name="test-model"),
|
|
||||||
raw_data,
|
|
||||||
interrupt_manager=app.state.interrupt_manager,
|
interrupt_manager=app.state.interrupt_manager,
|
||||||
analytics_recorder=app.state.analytics_recorder,
|
analytics_recorder=app.state.analytics_recorder,
|
||||||
conversation_tracker=app.state.conversation_tracker,
|
conversation_tracker=app.state.conversation_tracker,
|
||||||
pool=app.state.pool,
|
pool=app.state.pool,
|
||||||
)
|
)
|
||||||
|
await dispatch_message(ws, ws_ctx, raw_data)
|
||||||
except WebSocketDisconnect:
|
except WebSocketDisconnect:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
@@ -20,10 +20,12 @@ import pytest
|
|||||||
|
|
||||||
from app.callbacks import TokenUsageCallbackHandler
|
from app.callbacks import TokenUsageCallbackHandler
|
||||||
from app.escalation import EscalationPayload, NoOpEscalator, WebhookEscalator
|
from app.escalation import EscalationPayload, NoOpEscalator, WebhookEscalator
|
||||||
|
from app.graph_context import GraphContext
|
||||||
from app.intent import ClassificationResult, IntentTarget, LLMIntentClassifier
|
from app.intent import ClassificationResult, IntentTarget, LLMIntentClassifier
|
||||||
from app.interrupt_manager import InterruptManager
|
from app.interrupt_manager import InterruptManager
|
||||||
from app.registry import AgentConfig, AgentRegistry
|
from app.registry import AgentConfig, AgentRegistry
|
||||||
from app.session_manager import SessionManager
|
from app.session_manager import SessionManager
|
||||||
|
from app.ws_context import WebSocketContext
|
||||||
from app.ws_handler import dispatch_message
|
from app.ws_handler import dispatch_message
|
||||||
|
|
||||||
TEMPLATES_DIR = Path(__file__).parent.parent.parent / "templates"
|
TEMPLATES_DIR = Path(__file__).parent.parent.parent / "templates"
|
||||||
@@ -128,10 +130,8 @@ class TestCheckpoint1OrderQueryRouting:
|
|||||||
mock_classifier.classify = AsyncMock(return_value=ClassificationResult(
|
mock_classifier.classify = AsyncMock(return_value=ClassificationResult(
|
||||||
intents=(IntentTarget(agent_name="order_lookup", confidence=0.95, reasoning="query"),),
|
intents=(IntentTarget(agent_name="order_lookup", confidence=0.95, reasoning="query"),),
|
||||||
))
|
))
|
||||||
graph.intent_classifier = mock_classifier
|
|
||||||
mock_registry = MagicMock()
|
mock_registry = MagicMock()
|
||||||
mock_registry.list_agents = MagicMock(return_value=())
|
mock_registry.list_agents = MagicMock(return_value=())
|
||||||
graph.agent_registry = mock_registry
|
|
||||||
|
|
||||||
# Graph streams order_lookup response
|
# Graph streams order_lookup response
|
||||||
graph.astream = MagicMock(return_value=AsyncIterHelper([
|
graph.astream = MagicMock(return_value=AsyncIterHelper([
|
||||||
@@ -140,14 +140,21 @@ class TestCheckpoint1OrderQueryRouting:
|
|||||||
]))
|
]))
|
||||||
graph.aget_state = AsyncMock(return_value=_state())
|
graph.aget_state = AsyncMock(return_value=_state())
|
||||||
|
|
||||||
|
graph_ctx = GraphContext(
|
||||||
|
graph=graph, registry=mock_registry, intent_classifier=mock_classifier,
|
||||||
|
)
|
||||||
sm = SessionManager()
|
sm = SessionManager()
|
||||||
sm.touch("t1")
|
sm.touch("t1")
|
||||||
im = InterruptManager()
|
im = InterruptManager()
|
||||||
cb = TokenUsageCallbackHandler()
|
cb = TokenUsageCallbackHandler()
|
||||||
ws = FakeWS()
|
ws = FakeWS()
|
||||||
|
ws_ctx = WebSocketContext(
|
||||||
|
graph_ctx=graph_ctx, session_manager=sm,
|
||||||
|
callback_handler=cb, interrupt_manager=im,
|
||||||
|
)
|
||||||
|
|
||||||
raw = json.dumps({"type": "message", "thread_id": "t1", "content": "查询订单 1042"})
|
raw = json.dumps({"type": "message", "thread_id": "t1", "content": "查询订单 1042"})
|
||||||
await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im)
|
await dispatch_message(ws, ws_ctx, raw)
|
||||||
|
|
||||||
tool_msgs = [m for m in ws.sent if m["type"] == "tool_call"]
|
tool_msgs = [m for m in ws.sent if m["type"] == "tool_call"]
|
||||||
assert any(m["tool"] == "get_order_status" for m in tool_msgs)
|
assert any(m["tool"] == "get_order_status" for m in tool_msgs)
|
||||||
@@ -201,25 +208,30 @@ class TestCheckpoint2MultiIntentSequential:
|
|||||||
IntentTarget(agent_name="discount", confidence=0.85, reasoning="discount"),
|
IntentTarget(agent_name="discount", confidence=0.85, reasoning="discount"),
|
||||||
),
|
),
|
||||||
))
|
))
|
||||||
graph.intent_classifier = mock_classifier
|
|
||||||
mock_registry = MagicMock()
|
mock_registry = MagicMock()
|
||||||
mock_registry.list_agents = MagicMock(return_value=())
|
mock_registry.list_agents = MagicMock(return_value=())
|
||||||
graph.agent_registry = mock_registry
|
|
||||||
graph.astream = MagicMock(return_value=AsyncIterHelper([]))
|
graph.astream = MagicMock(return_value=AsyncIterHelper([]))
|
||||||
graph.aget_state = AsyncMock(return_value=_state())
|
graph.aget_state = AsyncMock(return_value=_state())
|
||||||
|
|
||||||
|
graph_ctx = GraphContext(
|
||||||
|
graph=graph, registry=mock_registry, intent_classifier=mock_classifier,
|
||||||
|
)
|
||||||
sm = SessionManager()
|
sm = SessionManager()
|
||||||
sm.touch("t1")
|
sm.touch("t1")
|
||||||
im = InterruptManager()
|
im = InterruptManager()
|
||||||
cb = TokenUsageCallbackHandler()
|
cb = TokenUsageCallbackHandler()
|
||||||
ws = FakeWS()
|
ws = FakeWS()
|
||||||
|
ws_ctx = WebSocketContext(
|
||||||
|
graph_ctx=graph_ctx, session_manager=sm,
|
||||||
|
callback_handler=cb, interrupt_manager=im,
|
||||||
|
)
|
||||||
|
|
||||||
raw = json.dumps({
|
raw = json.dumps({
|
||||||
"type": "message",
|
"type": "message",
|
||||||
"thread_id": "t1",
|
"thread_id": "t1",
|
||||||
"content": "取消订单 1042 并给我一个 10% 折扣",
|
"content": "取消订单 1042 并给我一个 10% 折扣",
|
||||||
})
|
})
|
||||||
await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im)
|
await dispatch_message(ws, ws_ctx, raw)
|
||||||
|
|
||||||
# Verify the graph was called with the routing hint in the message
|
# Verify the graph was called with the routing hint in the message
|
||||||
call_args = graph.astream.call_args
|
call_args = graph.astream.call_args
|
||||||
@@ -267,21 +279,26 @@ class TestCheckpoint3AmbiguousClarification:
|
|||||||
"Could you please provide more details about what you need help with?"
|
"Could you please provide more details about what you need help with?"
|
||||||
),
|
),
|
||||||
))
|
))
|
||||||
graph.intent_classifier = mock_classifier
|
|
||||||
mock_registry = MagicMock()
|
mock_registry = MagicMock()
|
||||||
mock_registry.list_agents = MagicMock(return_value=())
|
mock_registry.list_agents = MagicMock(return_value=())
|
||||||
graph.agent_registry = mock_registry
|
|
||||||
graph.astream = MagicMock(return_value=AsyncIterHelper([]))
|
graph.astream = MagicMock(return_value=AsyncIterHelper([]))
|
||||||
graph.aget_state = AsyncMock(return_value=_state())
|
graph.aget_state = AsyncMock(return_value=_state())
|
||||||
|
|
||||||
|
graph_ctx = GraphContext(
|
||||||
|
graph=graph, registry=mock_registry, intent_classifier=mock_classifier,
|
||||||
|
)
|
||||||
sm = SessionManager()
|
sm = SessionManager()
|
||||||
sm.touch("t1")
|
sm.touch("t1")
|
||||||
im = InterruptManager()
|
im = InterruptManager()
|
||||||
cb = TokenUsageCallbackHandler()
|
cb = TokenUsageCallbackHandler()
|
||||||
ws = FakeWS()
|
ws = FakeWS()
|
||||||
|
ws_ctx = WebSocketContext(
|
||||||
|
graph_ctx=graph_ctx, session_manager=sm,
|
||||||
|
callback_handler=cb, interrupt_manager=im,
|
||||||
|
)
|
||||||
|
|
||||||
raw = json.dumps({"type": "message", "thread_id": "t1", "content": "嗯..."})
|
raw = json.dumps({"type": "message", "thread_id": "t1", "content": "嗯..."})
|
||||||
await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im)
|
await dispatch_message(ws, ws_ctx, raw)
|
||||||
|
|
||||||
clarifications = [m for m in ws.sent if m["type"] == "clarification"]
|
clarifications = [m for m in ws.sent if m["type"] == "clarification"]
|
||||||
assert len(clarifications) == 1
|
assert len(clarifications) == 1
|
||||||
@@ -303,20 +320,26 @@ class TestCheckpoint4InterruptTTLAutoCancel:
|
|||||||
async def test_30min_expired_interrupt_auto_cancels(self) -> None:
|
async def test_30min_expired_interrupt_auto_cancels(self) -> None:
|
||||||
st = _state(interrupt=True, data={"action": "cancel_order", "order_id": "1042"})
|
st = _state(interrupt=True, data={"action": "cancel_order", "order_id": "1042"})
|
||||||
graph = MagicMock()
|
graph = MagicMock()
|
||||||
graph.intent_classifier = None
|
|
||||||
graph.agent_registry = None
|
|
||||||
graph.astream = MagicMock(return_value=AsyncIterHelper([]))
|
graph.astream = MagicMock(return_value=AsyncIterHelper([]))
|
||||||
graph.aget_state = AsyncMock(return_value=st)
|
graph.aget_state = AsyncMock(return_value=st)
|
||||||
|
|
||||||
|
mock_registry = MagicMock()
|
||||||
|
mock_registry.list_agents = MagicMock(return_value=())
|
||||||
|
graph_ctx = GraphContext(graph=graph, registry=mock_registry, intent_classifier=None)
|
||||||
|
|
||||||
sm = SessionManager()
|
sm = SessionManager()
|
||||||
sm.touch("t1")
|
sm.touch("t1")
|
||||||
im = InterruptManager(ttl_seconds=1800) # 30 minutes
|
im = InterruptManager(ttl_seconds=1800) # 30 minutes
|
||||||
cb = TokenUsageCallbackHandler()
|
cb = TokenUsageCallbackHandler()
|
||||||
ws = FakeWS()
|
ws = FakeWS()
|
||||||
|
ws_ctx = WebSocketContext(
|
||||||
|
graph_ctx=graph_ctx, session_manager=sm,
|
||||||
|
callback_handler=cb, interrupt_manager=im,
|
||||||
|
)
|
||||||
|
|
||||||
# Trigger interrupt
|
# Trigger interrupt
|
||||||
raw = json.dumps({"type": "message", "thread_id": "t1", "content": "Cancel order 1042"})
|
raw = json.dumps({"type": "message", "thread_id": "t1", "content": "Cancel order 1042"})
|
||||||
await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im)
|
await dispatch_message(ws, ws_ctx, raw)
|
||||||
|
|
||||||
interrupts = [m for m in ws.sent if m["type"] == "interrupt"]
|
interrupts = [m for m in ws.sent if m["type"] == "interrupt"]
|
||||||
assert len(interrupts) == 1
|
assert len(interrupts) == 1
|
||||||
@@ -333,7 +356,7 @@ class TestCheckpoint4InterruptTTLAutoCancel:
|
|||||||
"thread_id": "t1",
|
"thread_id": "t1",
|
||||||
"approved": True,
|
"approved": True,
|
||||||
})
|
})
|
||||||
await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im)
|
await dispatch_message(ws, ws_ctx, raw)
|
||||||
|
|
||||||
# Should get retry prompt, NOT resume the graph
|
# Should get retry prompt, NOT resume the graph
|
||||||
expired_msgs = [m for m in ws.sent if m["type"] == "interrupt_expired"]
|
expired_msgs = [m for m in ws.sent if m["type"] == "interrupt_expired"]
|
||||||
|
|||||||
@@ -18,10 +18,12 @@ from unittest.mock import AsyncMock, MagicMock
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from app.callbacks import TokenUsageCallbackHandler
|
from app.callbacks import TokenUsageCallbackHandler
|
||||||
|
from app.graph_context import GraphContext
|
||||||
from app.intent import ClassificationResult, IntentTarget, LLMIntentClassifier
|
from app.intent import ClassificationResult, IntentTarget, LLMIntentClassifier
|
||||||
from app.interrupt_manager import InterruptManager
|
from app.interrupt_manager import InterruptManager
|
||||||
from app.registry import AgentConfig
|
from app.registry import AgentConfig
|
||||||
from app.session_manager import SessionManager
|
from app.session_manager import SessionManager
|
||||||
|
from app.ws_context import WebSocketContext
|
||||||
from app.ws_handler import dispatch_message
|
from app.ws_handler import dispatch_message
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -103,36 +105,45 @@ def _make_classifier(result: ClassificationResult) -> AsyncMock:
|
|||||||
return classifier
|
return classifier
|
||||||
|
|
||||||
|
|
||||||
def _make_graph(
|
def _make_graph_and_ctx(
|
||||||
classifier_result: ClassificationResult | None,
|
classifier_result: ClassificationResult | None,
|
||||||
chunks: list,
|
chunks: list,
|
||||||
state=None,
|
state=None,
|
||||||
) -> MagicMock:
|
) -> tuple[MagicMock, GraphContext]:
|
||||||
"""Build a graph mock with optional intent classifier."""
|
"""Build a graph mock and GraphContext with optional intent classifier."""
|
||||||
graph = MagicMock()
|
graph = MagicMock()
|
||||||
|
|
||||||
if classifier_result is not None:
|
|
||||||
graph.intent_classifier = _make_classifier(classifier_result)
|
|
||||||
mock_registry = MagicMock()
|
|
||||||
mock_registry.list_agents = MagicMock(return_value=AGENTS)
|
|
||||||
graph.agent_registry = mock_registry
|
|
||||||
else:
|
|
||||||
graph.intent_classifier = None
|
|
||||||
graph.agent_registry = None
|
|
||||||
|
|
||||||
graph.astream = MagicMock(return_value=AsyncIterHelper(list(chunks)))
|
graph.astream = MagicMock(return_value=AsyncIterHelper(list(chunks)))
|
||||||
graph.aget_state = AsyncMock(return_value=state or _state())
|
graph.aget_state = AsyncMock(return_value=state or _state())
|
||||||
return graph
|
|
||||||
|
if classifier_result is not None:
|
||||||
|
classifier = _make_classifier(classifier_result)
|
||||||
|
mock_registry = MagicMock()
|
||||||
|
mock_registry.list_agents = MagicMock(return_value=AGENTS)
|
||||||
|
graph_ctx = GraphContext(
|
||||||
|
graph=graph, registry=mock_registry, intent_classifier=classifier,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
mock_registry = MagicMock()
|
||||||
|
mock_registry.list_agents = MagicMock(return_value=())
|
||||||
|
graph_ctx = GraphContext(
|
||||||
|
graph=graph, registry=mock_registry, intent_classifier=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
return graph, graph_ctx
|
||||||
|
|
||||||
|
|
||||||
async def _dispatch(graph, content: str, thread_id: str = "t1") -> list[dict]:
|
async def _dispatch(graph_ctx: GraphContext, content: str, thread_id: str = "t1") -> list[dict]:
|
||||||
sm = SessionManager()
|
sm = SessionManager()
|
||||||
sm.touch(thread_id)
|
sm.touch(thread_id)
|
||||||
im = InterruptManager()
|
im = InterruptManager()
|
||||||
cb = TokenUsageCallbackHandler()
|
cb = TokenUsageCallbackHandler()
|
||||||
ws = FakeWS()
|
ws = FakeWS()
|
||||||
|
ws_ctx = WebSocketContext(
|
||||||
|
graph_ctx=graph_ctx, session_manager=sm,
|
||||||
|
callback_handler=cb, interrupt_manager=im,
|
||||||
|
)
|
||||||
raw = json.dumps({"type": "message", "thread_id": thread_id, "content": content})
|
raw = json.dumps({"type": "message", "thread_id": thread_id, "content": content})
|
||||||
await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im)
|
await dispatch_message(ws, ws_ctx, raw)
|
||||||
return ws.sent
|
return ws.sent
|
||||||
|
|
||||||
|
|
||||||
@@ -151,12 +162,12 @@ class TestSingleIntentRouting:
|
|||||||
agent_name="order_lookup", confidence=0.95, reasoning="status query",
|
agent_name="order_lookup", confidence=0.95, reasoning="status query",
|
||||||
),),
|
),),
|
||||||
)
|
)
|
||||||
graph = _make_graph(result, [
|
graph, graph_ctx = _make_graph_and_ctx(result, [
|
||||||
_tool_chunk("get_order_status", {"order_id": "1042"}, "order_lookup"),
|
_tool_chunk("get_order_status", {"order_id": "1042"}, "order_lookup"),
|
||||||
_chunk("Order 1042 is shipped.", "order_lookup"),
|
_chunk("Order 1042 is shipped.", "order_lookup"),
|
||||||
])
|
])
|
||||||
|
|
||||||
msgs = await _dispatch(graph, "What is the status of order 1042?")
|
msgs = await _dispatch(graph_ctx, "What is the status of order 1042?")
|
||||||
|
|
||||||
tools = [m for m in msgs if m["type"] == "tool_call"]
|
tools = [m for m in msgs if m["type"] == "tool_call"]
|
||||||
assert len(tools) == 1
|
assert len(tools) == 1
|
||||||
@@ -171,13 +182,13 @@ class TestSingleIntentRouting:
|
|||||||
result = ClassificationResult(
|
result = ClassificationResult(
|
||||||
intents=(IntentTarget(agent_name="order_actions", confidence=0.9, reasoning="cancel"),),
|
intents=(IntentTarget(agent_name="order_actions", confidence=0.9, reasoning="cancel"),),
|
||||||
)
|
)
|
||||||
graph = _make_graph(
|
graph, graph_ctx = _make_graph_and_ctx(
|
||||||
result,
|
result,
|
||||||
[_tool_chunk("cancel_order", {"order_id": "1042"}, "order_actions")],
|
[_tool_chunk("cancel_order", {"order_id": "1042"}, "order_actions")],
|
||||||
state=_state(interrupt=True, data={"action": "cancel_order", "order_id": "1042"}),
|
state=_state(interrupt=True, data={"action": "cancel_order", "order_id": "1042"}),
|
||||||
)
|
)
|
||||||
|
|
||||||
msgs = await _dispatch(graph, "Cancel order 1042")
|
msgs = await _dispatch(graph_ctx, "Cancel order 1042")
|
||||||
|
|
||||||
tools = [m for m in msgs if m["type"] == "tool_call"]
|
tools = [m for m in msgs if m["type"] == "tool_call"]
|
||||||
assert tools[0]["tool"] == "cancel_order"
|
assert tools[0]["tool"] == "cancel_order"
|
||||||
@@ -191,12 +202,12 @@ class TestSingleIntentRouting:
|
|||||||
result = ClassificationResult(
|
result = ClassificationResult(
|
||||||
intents=(IntentTarget(agent_name="discount", confidence=0.88, reasoning="coupon"),),
|
intents=(IntentTarget(agent_name="discount", confidence=0.88, reasoning="coupon"),),
|
||||||
)
|
)
|
||||||
graph = _make_graph(result, [
|
graph, graph_ctx = _make_graph_and_ctx(result, [
|
||||||
_tool_chunk("generate_coupon", {"discount_percent": 15, "expiry_days": 30}, "discount"),
|
_tool_chunk("generate_coupon", {"discount_percent": 15, "expiry_days": 30}, "discount"),
|
||||||
_chunk("Here is your coupon: SAVE15-ABC12345", "discount"),
|
_chunk("Here is your coupon: SAVE15-ABC12345", "discount"),
|
||||||
])
|
])
|
||||||
|
|
||||||
msgs = await _dispatch(graph, "Give me a 15% coupon")
|
msgs = await _dispatch(graph_ctx, "Give me a 15% coupon")
|
||||||
|
|
||||||
tools = [m for m in msgs if m["type"] == "tool_call"]
|
tools = [m for m in msgs if m["type"] == "tool_call"]
|
||||||
assert tools[0]["tool"] == "generate_coupon"
|
assert tools[0]["tool"] == "generate_coupon"
|
||||||
@@ -207,11 +218,11 @@ class TestSingleIntentRouting:
|
|||||||
result = ClassificationResult(
|
result = ClassificationResult(
|
||||||
intents=(IntentTarget(agent_name="fallback", confidence=0.7, reasoning="general"),),
|
intents=(IntentTarget(agent_name="fallback", confidence=0.7, reasoning="general"),),
|
||||||
)
|
)
|
||||||
graph = _make_graph(result, [
|
graph, graph_ctx = _make_graph_and_ctx(result, [
|
||||||
_chunk("I can help with order inquiries.", "fallback"),
|
_chunk("I can help with order inquiries.", "fallback"),
|
||||||
])
|
])
|
||||||
|
|
||||||
msgs = await _dispatch(graph, "What can you do?")
|
msgs = await _dispatch(graph_ctx, "What can you do?")
|
||||||
|
|
||||||
tokens = [m for m in msgs if m["type"] == "token"]
|
tokens = [m for m in msgs if m["type"] == "token"]
|
||||||
assert tokens[0]["agent"] == "fallback"
|
assert tokens[0]["agent"] == "fallback"
|
||||||
@@ -233,7 +244,7 @@ class TestMultiIntentRouting:
|
|||||||
IntentTarget(agent_name="discount", confidence=0.85, reasoning="discount"),
|
IntentTarget(agent_name="discount", confidence=0.85, reasoning="discount"),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
graph = _make_graph(result, [
|
graph, graph_ctx = _make_graph_and_ctx(result, [
|
||||||
_tool_chunk("cancel_order", {"order_id": "1042"}, "order_actions"),
|
_tool_chunk("cancel_order", {"order_id": "1042"}, "order_actions"),
|
||||||
_tool_chunk("apply_discount", {"order_id": "1042", "discount_percent": 10}, "discount"),
|
_tool_chunk("apply_discount", {"order_id": "1042", "discount_percent": 10}, "discount"),
|
||||||
])
|
])
|
||||||
@@ -243,13 +254,17 @@ class TestMultiIntentRouting:
|
|||||||
im = InterruptManager()
|
im = InterruptManager()
|
||||||
cb = TokenUsageCallbackHandler()
|
cb = TokenUsageCallbackHandler()
|
||||||
ws = FakeWS()
|
ws = FakeWS()
|
||||||
|
ws_ctx = WebSocketContext(
|
||||||
|
graph_ctx=graph_ctx, session_manager=sm,
|
||||||
|
callback_handler=cb, interrupt_manager=im,
|
||||||
|
)
|
||||||
|
|
||||||
raw = json.dumps({
|
raw = json.dumps({
|
||||||
"type": "message",
|
"type": "message",
|
||||||
"thread_id": "t1",
|
"thread_id": "t1",
|
||||||
"content": "取消订单 1042 并给我一个 10% 折扣",
|
"content": "取消订单 1042 并给我一个 10% 折扣",
|
||||||
})
|
})
|
||||||
await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im)
|
await dispatch_message(ws, ws_ctx, raw)
|
||||||
|
|
||||||
# Verify routing hint was injected
|
# Verify routing hint was injected
|
||||||
call_args = graph.astream.call_args[0][0]
|
call_args = graph.astream.call_args[0][0]
|
||||||
@@ -269,16 +284,20 @@ class TestMultiIntentRouting:
|
|||||||
result = ClassificationResult(
|
result = ClassificationResult(
|
||||||
intents=(IntentTarget(agent_name="order_lookup", confidence=0.95, reasoning="query"),),
|
intents=(IntentTarget(agent_name="order_lookup", confidence=0.95, reasoning="query"),),
|
||||||
)
|
)
|
||||||
graph = _make_graph(result, [_chunk("Order shipped.", "order_lookup")])
|
graph, graph_ctx = _make_graph_and_ctx(result, [_chunk("Order shipped.", "order_lookup")])
|
||||||
|
|
||||||
sm = SessionManager()
|
sm = SessionManager()
|
||||||
sm.touch("t1")
|
sm.touch("t1")
|
||||||
im = InterruptManager()
|
im = InterruptManager()
|
||||||
cb = TokenUsageCallbackHandler()
|
cb = TokenUsageCallbackHandler()
|
||||||
ws = FakeWS()
|
ws = FakeWS()
|
||||||
|
ws_ctx = WebSocketContext(
|
||||||
|
graph_ctx=graph_ctx, session_manager=sm,
|
||||||
|
callback_handler=cb, interrupt_manager=im,
|
||||||
|
)
|
||||||
|
|
||||||
raw = json.dumps({"type": "message", "thread_id": "t1", "content": "查询订单 1042"})
|
raw = json.dumps({"type": "message", "thread_id": "t1", "content": "查询订单 1042"})
|
||||||
await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im)
|
await dispatch_message(ws, ws_ctx, raw)
|
||||||
|
|
||||||
msg_content = graph.astream.call_args[0][0]["messages"][0].content
|
msg_content = graph.astream.call_args[0][0]["messages"][0].content
|
||||||
assert "[System:" not in msg_content
|
assert "[System:" not in msg_content
|
||||||
@@ -299,9 +318,9 @@ class TestAmbiguityRouting:
|
|||||||
is_ambiguous=True,
|
is_ambiguous=True,
|
||||||
clarification_question="Could you please clarify what you need?",
|
clarification_question="Could you please clarify what you need?",
|
||||||
)
|
)
|
||||||
graph = _make_graph(result, [])
|
graph, graph_ctx = _make_graph_and_ctx(result, [])
|
||||||
|
|
||||||
msgs = await _dispatch(graph, "嗯...")
|
msgs = await _dispatch(graph_ctx, "嗯...")
|
||||||
|
|
||||||
clarifications = [m for m in msgs if m["type"] == "clarification"]
|
clarifications = [m for m in msgs if m["type"] == "clarification"]
|
||||||
assert len(clarifications) == 1
|
assert len(clarifications) == 1
|
||||||
@@ -339,12 +358,12 @@ class TestNoClassifierFallback:
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_no_classifier_routes_via_supervisor(self) -> None:
|
async def test_no_classifier_routes_via_supervisor(self) -> None:
|
||||||
graph = _make_graph(
|
graph, graph_ctx = _make_graph_and_ctx(
|
||||||
classifier_result=None,
|
classifier_result=None,
|
||||||
chunks=[_chunk("Order 1042 is shipped.", "order_lookup")],
|
chunks=[_chunk("Order 1042 is shipped.", "order_lookup")],
|
||||||
)
|
)
|
||||||
|
|
||||||
msgs = await _dispatch(graph, "What is order 1042 status?")
|
msgs = await _dispatch(graph_ctx, "What is order 1042 status?")
|
||||||
|
|
||||||
tokens = [m for m in msgs if m["type"] == "token"]
|
tokens = [m for m in msgs if m["type"] == "token"]
|
||||||
assert len(tokens) == 1
|
assert len(tokens) == 1
|
||||||
|
|||||||
@@ -15,8 +15,10 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from app.callbacks import TokenUsageCallbackHandler
|
from app.callbacks import TokenUsageCallbackHandler
|
||||||
|
from app.graph_context import GraphContext
|
||||||
from app.interrupt_manager import InterruptManager
|
from app.interrupt_manager import InterruptManager
|
||||||
from app.session_manager import SessionManager
|
from app.session_manager import SessionManager
|
||||||
|
from app.ws_context import WebSocketContext
|
||||||
from app.ws_handler import dispatch_message
|
from app.ws_handler import dispatch_message
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -81,8 +83,6 @@ def _graph(
|
|||||||
resume_chunks: list | None = None,
|
resume_chunks: list | None = None,
|
||||||
) -> MagicMock:
|
) -> MagicMock:
|
||||||
g = MagicMock()
|
g = MagicMock()
|
||||||
g.intent_classifier = None
|
|
||||||
g.agent_registry = None
|
|
||||||
|
|
||||||
if st is None:
|
if st is None:
|
||||||
st = _state()
|
st = _state()
|
||||||
@@ -100,6 +100,13 @@ def _graph(
|
|||||||
return g
|
return g
|
||||||
|
|
||||||
|
|
||||||
|
def _make_graph_ctx(graph: MagicMock | None = None) -> GraphContext:
|
||||||
|
g = graph or _graph()
|
||||||
|
registry = MagicMock()
|
||||||
|
registry.list_agents = MagicMock(return_value=())
|
||||||
|
return GraphContext(graph=g, registry=registry, intent_classifier=None)
|
||||||
|
|
||||||
|
|
||||||
def _setup(
|
def _setup(
|
||||||
graph=None,
|
graph=None,
|
||||||
session_ttl: int = 1800,
|
session_ttl: int = 1800,
|
||||||
@@ -109,23 +116,28 @@ def _setup(
|
|||||||
):
|
):
|
||||||
"""Create test dependencies. Pre-touches session by default."""
|
"""Create test dependencies. Pre-touches session by default."""
|
||||||
g = graph or _graph()
|
g = graph or _graph()
|
||||||
|
graph_ctx = _make_graph_ctx(g)
|
||||||
sm = SessionManager(session_ttl_seconds=session_ttl)
|
sm = SessionManager(session_ttl_seconds=session_ttl)
|
||||||
im = InterruptManager(ttl_seconds=interrupt_ttl)
|
im = InterruptManager(ttl_seconds=interrupt_ttl)
|
||||||
cb = TokenUsageCallbackHandler()
|
cb = TokenUsageCallbackHandler()
|
||||||
ws = FakeWS()
|
ws = FakeWS()
|
||||||
|
ws_ctx = WebSocketContext(
|
||||||
|
graph_ctx=graph_ctx, session_manager=sm,
|
||||||
|
callback_handler=cb, interrupt_manager=im,
|
||||||
|
)
|
||||||
if touch:
|
if touch:
|
||||||
sm.touch(thread_id)
|
sm.touch(thread_id)
|
||||||
return g, sm, im, cb, ws
|
return g, sm, im, cb, ws, ws_ctx
|
||||||
|
|
||||||
|
|
||||||
async def _send(ws, g, sm, im, cb, *, thread_id="t1", content="hello", msg_type="message"):
|
async def _send(ws, ws_ctx, *, thread_id="t1", content="hello", msg_type="message"):
|
||||||
raw = json.dumps({"type": msg_type, "thread_id": thread_id, "content": content})
|
raw = json.dumps({"type": msg_type, "thread_id": thread_id, "content": content})
|
||||||
await dispatch_message(ws, g, sm, cb, raw, interrupt_manager=im)
|
await dispatch_message(ws, ws_ctx, raw)
|
||||||
|
|
||||||
|
|
||||||
async def _respond(ws, g, sm, im, cb, *, thread_id="t1", approved=True):
|
async def _respond(ws, ws_ctx, *, thread_id="t1", approved=True):
|
||||||
raw = json.dumps({"type": "interrupt_response", "thread_id": thread_id, "approved": approved})
|
raw = json.dumps({"type": "interrupt_response", "thread_id": thread_id, "approved": approved})
|
||||||
await dispatch_message(ws, g, sm, cb, raw, interrupt_manager=im)
|
await dispatch_message(ws, ws_ctx, raw)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -136,10 +148,10 @@ async def _respond(ws, g, sm, im, cb, *, thread_id="t1", approved=True):
|
|||||||
class TestWebSocketHappyPath:
|
class TestWebSocketHappyPath:
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_send_message_receives_tokens_and_complete(self) -> None:
|
async def test_send_message_receives_tokens_and_complete(self) -> None:
|
||||||
g, sm, im, cb, ws = _setup(
|
g, sm, im, cb, ws, ws_ctx = _setup(
|
||||||
graph=_graph(chunks=[_chunk("Order 1042 is "), _chunk("shipped.")])
|
graph=_graph(chunks=[_chunk("Order 1042 is "), _chunk("shipped.")])
|
||||||
)
|
)
|
||||||
await _send(ws, g, sm, im, cb, content="What is the status of order 1042?")
|
await _send(ws, ws_ctx, content="What is the status of order 1042?")
|
||||||
|
|
||||||
tokens = [m for m in ws.sent if m["type"] == "token"]
|
tokens = [m for m in ws.sent if m["type"] == "token"]
|
||||||
assert len(tokens) == 2
|
assert len(tokens) == 2
|
||||||
@@ -153,13 +165,13 @@ class TestWebSocketHappyPath:
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_tool_call_streamed(self) -> None:
|
async def test_tool_call_streamed(self) -> None:
|
||||||
g, sm, im, cb, ws = _setup(
|
g, sm, im, cb, ws, ws_ctx = _setup(
|
||||||
graph=_graph(chunks=[
|
graph=_graph(chunks=[
|
||||||
_tool_chunk("get_order_status", {"order_id": "1042"}),
|
_tool_chunk("get_order_status", {"order_id": "1042"}),
|
||||||
_chunk("Order shipped."),
|
_chunk("Order shipped."),
|
||||||
])
|
])
|
||||||
)
|
)
|
||||||
await _send(ws, g, sm, im, cb, content="Check order 1042")
|
await _send(ws, ws_ctx, content="Check order 1042")
|
||||||
|
|
||||||
tools = [m for m in ws.sent if m["type"] == "tool_call"]
|
tools = [m for m in ws.sent if m["type"] == "tool_call"]
|
||||||
assert len(tools) == 1
|
assert len(tools) == 1
|
||||||
@@ -168,9 +180,9 @@ class TestWebSocketHappyPath:
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_multiple_messages_same_session(self) -> None:
|
async def test_multiple_messages_same_session(self) -> None:
|
||||||
g, sm, im, cb, ws = _setup()
|
g, sm, im, cb, ws, ws_ctx = _setup()
|
||||||
for i in range(3):
|
for i in range(3):
|
||||||
await _send(ws, g, sm, im, cb, content=f"msg {i}")
|
await _send(ws, ws_ctx, content=f"msg {i}")
|
||||||
|
|
||||||
completes = [m for m in ws.sent if m["type"] == "message_complete"]
|
completes = [m for m in ws.sent if m["type"] == "message_complete"]
|
||||||
assert len(completes) == 3
|
assert len(completes) == 3
|
||||||
@@ -183,10 +195,10 @@ class TestWebSocketInterruptApproval:
|
|||||||
st_int = _state(interrupt=True, data={"action": "cancel_order", "order_id": "1042"})
|
st_int = _state(interrupt=True, data={"action": "cancel_order", "order_id": "1042"})
|
||||||
resume = [_chunk("Order 1042 cancelled.", "order_actions")]
|
resume = [_chunk("Order 1042 cancelled.", "order_actions")]
|
||||||
g = _graph(chunks=[], st=st_int, resume_chunks=resume)
|
g = _graph(chunks=[], st=st_int, resume_chunks=resume)
|
||||||
g_, sm, im, cb, ws = _setup(graph=g)
|
g_, sm, im, cb, ws, ws_ctx = _setup(graph=g)
|
||||||
|
|
||||||
# Send message -> triggers interrupt
|
# Send message -> triggers interrupt
|
||||||
await _send(ws, g_, sm, im, cb, content="Cancel order 1042")
|
await _send(ws, ws_ctx, content="Cancel order 1042")
|
||||||
|
|
||||||
interrupts = [m for m in ws.sent if m["type"] == "interrupt"]
|
interrupts = [m for m in ws.sent if m["type"] == "interrupt"]
|
||||||
assert len(interrupts) == 1
|
assert len(interrupts) == 1
|
||||||
@@ -196,7 +208,7 @@ class TestWebSocketInterruptApproval:
|
|||||||
|
|
||||||
# Approve
|
# Approve
|
||||||
ws.sent.clear()
|
ws.sent.clear()
|
||||||
await _respond(ws, g_, sm, im, cb, approved=True)
|
await _respond(ws, ws_ctx, approved=True)
|
||||||
|
|
||||||
tokens = [m for m in ws.sent if m["type"] == "token"]
|
tokens = [m for m in ws.sent if m["type"] == "token"]
|
||||||
assert len(tokens) == 1
|
assert len(tokens) == 1
|
||||||
@@ -211,12 +223,12 @@ class TestWebSocketInterruptApproval:
|
|||||||
st_int = _state(interrupt=True)
|
st_int = _state(interrupt=True)
|
||||||
resume = [_chunk("Order remains active.", "order_actions")]
|
resume = [_chunk("Order remains active.", "order_actions")]
|
||||||
g = _graph(chunks=[], st=st_int, resume_chunks=resume)
|
g = _graph(chunks=[], st=st_int, resume_chunks=resume)
|
||||||
g_, sm, im, cb, ws = _setup(graph=g)
|
g_, sm, im, cb, ws, ws_ctx = _setup(graph=g)
|
||||||
|
|
||||||
await _send(ws, g_, sm, im, cb, content="Cancel order 1042")
|
await _send(ws, ws_ctx, content="Cancel order 1042")
|
||||||
ws.sent.clear()
|
ws.sent.clear()
|
||||||
|
|
||||||
await _respond(ws, g_, sm, im, cb, approved=False)
|
await _respond(ws, ws_ctx, approved=False)
|
||||||
|
|
||||||
tokens = [m for m in ws.sent if m["type"] == "token"]
|
tokens = [m for m in ws.sent if m["type"] == "token"]
|
||||||
assert "remains active" in tokens[0]["content"]
|
assert "remains active" in tokens[0]["content"]
|
||||||
@@ -226,28 +238,28 @@ class TestWebSocketInterruptApproval:
|
|||||||
class TestWebSocketSessionTTL:
|
class TestWebSocketSessionTTL:
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_expired_session_returns_error(self) -> None:
|
async def test_expired_session_returns_error(self) -> None:
|
||||||
g, sm, im, cb, ws = _setup(session_ttl=0)
|
g, sm, im, cb, ws, ws_ctx = _setup(session_ttl=0)
|
||||||
# Session was touched in _setup, but TTL is 0 so it's already expired
|
# Session was touched in _setup, but TTL is 0 so it's already expired
|
||||||
await _send(ws, g, sm, im, cb, content="hello")
|
await _send(ws, ws_ctx, content="hello")
|
||||||
assert ws.sent[0]["type"] == "error"
|
assert ws.sent[0]["type"] == "error"
|
||||||
assert "expired" in ws.sent[0]["message"].lower()
|
assert "expired" in ws.sent[0]["message"].lower()
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_new_session_not_expired(self) -> None:
|
async def test_new_session_not_expired(self) -> None:
|
||||||
g, sm, im, cb, ws = _setup(session_ttl=3600)
|
g, sm, im, cb, ws, ws_ctx = _setup(session_ttl=3600)
|
||||||
await _send(ws, g, sm, im, cb, content="hello")
|
await _send(ws, ws_ctx, content="hello")
|
||||||
completes = [m for m in ws.sent if m["type"] == "message_complete"]
|
completes = [m for m in ws.sent if m["type"] == "message_complete"]
|
||||||
assert len(completes) == 1
|
assert len(completes) == 1
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_sliding_window_resets_on_message(self) -> None:
|
async def test_sliding_window_resets_on_message(self) -> None:
|
||||||
g, sm, im, cb, ws = _setup(session_ttl=3600)
|
g, sm, im, cb, ws, ws_ctx = _setup(session_ttl=3600)
|
||||||
|
|
||||||
await _send(ws, g, sm, im, cb, content="hello")
|
await _send(ws, ws_ctx, content="hello")
|
||||||
first_activity = sm.get_state("t1").last_activity
|
first_activity = sm.get_state("t1").last_activity
|
||||||
|
|
||||||
time.sleep(0.01)
|
time.sleep(0.01)
|
||||||
await _send(ws, g, sm, im, cb, content="hello again")
|
await _send(ws, ws_ctx, content="hello again")
|
||||||
second_activity = sm.get_state("t1").last_activity
|
second_activity = sm.get_state("t1").last_activity
|
||||||
|
|
||||||
assert second_activity > first_activity
|
assert second_activity > first_activity
|
||||||
@@ -256,9 +268,9 @@ class TestWebSocketSessionTTL:
|
|||||||
async def test_interrupt_extends_session_ttl(self) -> None:
|
async def test_interrupt_extends_session_ttl(self) -> None:
|
||||||
st_int = _state(interrupt=True)
|
st_int = _state(interrupt=True)
|
||||||
g = _graph(chunks=[], st=st_int)
|
g = _graph(chunks=[], st=st_int)
|
||||||
g_, sm, im, cb, ws = _setup(graph=g, session_ttl=3600)
|
g_, sm, im, cb, ws, ws_ctx = _setup(graph=g, session_ttl=3600)
|
||||||
|
|
||||||
await _send(ws, g_, sm, im, cb, content="cancel order")
|
await _send(ws, ws_ctx, content="cancel order")
|
||||||
|
|
||||||
state = sm.get_state("t1")
|
state = sm.get_state("t1")
|
||||||
assert state is not None
|
assert state is not None
|
||||||
@@ -270,53 +282,53 @@ class TestWebSocketSessionTTL:
|
|||||||
class TestWebSocketValidation:
|
class TestWebSocketValidation:
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_invalid_json(self) -> None:
|
async def test_invalid_json(self) -> None:
|
||||||
g, sm, im, cb, ws = _setup()
|
g, sm, im, cb, ws, ws_ctx = _setup()
|
||||||
await dispatch_message(ws, g, sm, cb, "not json", interrupt_manager=im)
|
await dispatch_message(ws, ws_ctx, "not json")
|
||||||
assert ws.sent[0]["type"] == "error"
|
assert ws.sent[0]["type"] == "error"
|
||||||
assert "Invalid JSON" in ws.sent[0]["message"]
|
assert "Invalid JSON" in ws.sent[0]["message"]
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_missing_thread_id(self) -> None:
|
async def test_missing_thread_id(self) -> None:
|
||||||
g, sm, im, cb, ws = _setup()
|
g, sm, im, cb, ws, ws_ctx = _setup()
|
||||||
raw = json.dumps({"type": "message", "content": "hi"})
|
raw = json.dumps({"type": "message", "content": "hi"})
|
||||||
await dispatch_message(ws, g, sm, cb, raw, interrupt_manager=im)
|
await dispatch_message(ws, ws_ctx, raw)
|
||||||
assert ws.sent[0]["type"] == "error"
|
assert ws.sent[0]["type"] == "error"
|
||||||
assert "thread_id" in ws.sent[0]["message"]
|
assert "thread_id" in ws.sent[0]["message"]
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_invalid_thread_id_format(self) -> None:
|
async def test_invalid_thread_id_format(self) -> None:
|
||||||
g, sm, im, cb, ws = _setup()
|
g, sm, im, cb, ws, ws_ctx = _setup()
|
||||||
raw = json.dumps({"type": "message", "thread_id": "../../etc/passwd", "content": "hi"})
|
raw = json.dumps({"type": "message", "thread_id": "../../etc/passwd", "content": "hi"})
|
||||||
await dispatch_message(ws, g, sm, cb, raw, interrupt_manager=im)
|
await dispatch_message(ws, ws_ctx, raw)
|
||||||
assert ws.sent[0]["type"] == "error"
|
assert ws.sent[0]["type"] == "error"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_missing_content(self) -> None:
|
async def test_missing_content(self) -> None:
|
||||||
g, sm, im, cb, ws = _setup()
|
g, sm, im, cb, ws, ws_ctx = _setup()
|
||||||
raw = json.dumps({"type": "message", "thread_id": "t1"})
|
raw = json.dumps({"type": "message", "thread_id": "t1"})
|
||||||
await dispatch_message(ws, g, sm, cb, raw, interrupt_manager=im)
|
await dispatch_message(ws, ws_ctx, raw)
|
||||||
assert ws.sent[0]["type"] == "error"
|
assert ws.sent[0]["type"] == "error"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_unknown_message_type(self) -> None:
|
async def test_unknown_message_type(self) -> None:
|
||||||
g, sm, im, cb, ws = _setup()
|
g, sm, im, cb, ws, ws_ctx = _setup()
|
||||||
raw = json.dumps({"type": "foobar", "thread_id": "t1"})
|
raw = json.dumps({"type": "foobar", "thread_id": "t1"})
|
||||||
await dispatch_message(ws, g, sm, cb, raw, interrupt_manager=im)
|
await dispatch_message(ws, ws_ctx, raw)
|
||||||
assert ws.sent[0]["type"] == "error"
|
assert ws.sent[0]["type"] == "error"
|
||||||
assert "Unknown" in ws.sent[0]["message"]
|
assert "Unknown" in ws.sent[0]["message"]
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_message_too_large(self) -> None:
|
async def test_message_too_large(self) -> None:
|
||||||
g, sm, im, cb, ws = _setup()
|
g, sm, im, cb, ws, ws_ctx = _setup()
|
||||||
await dispatch_message(ws, g, sm, cb, "x" * 40_000, interrupt_manager=im)
|
await dispatch_message(ws, ws_ctx, "x" * 40_000)
|
||||||
assert ws.sent[0]["type"] == "error"
|
assert ws.sent[0]["type"] == "error"
|
||||||
assert "too large" in ws.sent[0]["message"].lower()
|
assert "too large" in ws.sent[0]["message"].lower()
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_content_too_long(self) -> None:
|
async def test_content_too_long(self) -> None:
|
||||||
g, sm, im, cb, ws = _setup()
|
g, sm, im, cb, ws, ws_ctx = _setup()
|
||||||
raw = json.dumps({"type": "message", "thread_id": "t1", "content": "x" * 10001})
|
raw = json.dumps({"type": "message", "thread_id": "t1", "content": "x" * 10001})
|
||||||
await dispatch_message(ws, g, sm, cb, raw, interrupt_manager=im)
|
await dispatch_message(ws, ws_ctx, raw)
|
||||||
assert ws.sent[0]["type"] == "error"
|
assert ws.sent[0]["type"] == "error"
|
||||||
assert "too long" in ws.sent[0]["message"].lower()
|
assert "too long" in ws.sent[0]["message"].lower()
|
||||||
|
|
||||||
@@ -327,10 +339,10 @@ class TestWebSocketInterruptTTL:
|
|||||||
async def test_expired_interrupt_sends_retry_prompt(self) -> None:
|
async def test_expired_interrupt_sends_retry_prompt(self) -> None:
|
||||||
st_int = _state(interrupt=True, data={"action": "cancel_order", "order_id": "1042"})
|
st_int = _state(interrupt=True, data={"action": "cancel_order", "order_id": "1042"})
|
||||||
g = _graph(chunks=[], st=st_int)
|
g = _graph(chunks=[], st=st_int)
|
||||||
g_, sm, im, cb, ws = _setup(graph=g, interrupt_ttl=5)
|
g_, sm, im, cb, ws, ws_ctx = _setup(graph=g, interrupt_ttl=5)
|
||||||
|
|
||||||
# Trigger interrupt
|
# Trigger interrupt
|
||||||
await _send(ws, g_, sm, im, cb, content="Cancel order 1042")
|
await _send(ws, ws_ctx, content="Cancel order 1042")
|
||||||
|
|
||||||
interrupts = [m for m in ws.sent if m["type"] == "interrupt"]
|
interrupts = [m for m in ws.sent if m["type"] == "interrupt"]
|
||||||
assert len(interrupts) == 1
|
assert len(interrupts) == 1
|
||||||
@@ -341,7 +353,7 @@ class TestWebSocketInterruptTTL:
|
|||||||
|
|
||||||
with patch("app.interrupt_manager.time") as mock_time:
|
with patch("app.interrupt_manager.time") as mock_time:
|
||||||
mock_time.time.return_value = record.created_at + 10
|
mock_time.time.return_value = record.created_at + 10
|
||||||
await _respond(ws, g_, sm, im, cb, approved=True)
|
await _respond(ws, ws_ctx, approved=True)
|
||||||
|
|
||||||
assert ws.sent[0]["type"] == "interrupt_expired"
|
assert ws.sent[0]["type"] == "interrupt_expired"
|
||||||
assert "cancel_order" in ws.sent[0]["message"]
|
assert "cancel_order" in ws.sent[0]["message"]
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ class TestDbModule:
|
|||||||
from app.db import setup_app_tables
|
from app.db import setup_app_tables
|
||||||
|
|
||||||
await setup_app_tables(mock_pool)
|
await setup_app_tables(mock_pool)
|
||||||
assert mock_conn.execute.await_count == 4
|
assert mock_conn.execute.await_count == 5
|
||||||
|
|
||||||
def test_ddl_statements_valid(self) -> None:
|
def test_ddl_statements_valid(self) -> None:
|
||||||
assert "CREATE TABLE IF NOT EXISTS conversations" in _CONVERSATIONS_DDL
|
assert "CREATE TABLE IF NOT EXISTS conversations" in _CONVERSATIONS_DDL
|
||||||
|
|||||||
@@ -51,5 +51,5 @@ class TestAnalyticsEventsDDL:
|
|||||||
from app.db import setup_app_tables
|
from app.db import setup_app_tables
|
||||||
|
|
||||||
await setup_app_tables(mock_pool)
|
await setup_app_tables(mock_pool)
|
||||||
# Now expects 4 statements: conversations, interrupts, analytics_events, migrations
|
# Now expects 5 statements: conversations, interrupts, sessions, analytics_events, migrations
|
||||||
assert mock_conn.execute.await_count == 4
|
assert mock_conn.execute.await_count == 5
|
||||||
|
|||||||
@@ -8,7 +8,9 @@ from unittest.mock import AsyncMock, MagicMock
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from app.callbacks import TokenUsageCallbackHandler
|
from app.callbacks import TokenUsageCallbackHandler
|
||||||
|
from app.graph_context import GraphContext
|
||||||
from app.session_manager import SessionManager
|
from app.session_manager import SessionManager
|
||||||
|
from app.ws_context import WebSocketContext
|
||||||
from app.ws_handler import dispatch_message
|
from app.ws_handler import dispatch_message
|
||||||
|
|
||||||
pytestmark = pytest.mark.unit
|
pytestmark = pytest.mark.unit
|
||||||
@@ -20,7 +22,7 @@ def _make_ws() -> AsyncMock:
|
|||||||
return ws
|
return ws
|
||||||
|
|
||||||
|
|
||||||
def _make_graph() -> AsyncMock:
|
def _make_graph() -> MagicMock:
|
||||||
graph = AsyncMock()
|
graph = AsyncMock()
|
||||||
|
|
||||||
class AsyncIterHelper:
|
class AsyncIterHelper:
|
||||||
@@ -34,23 +36,32 @@ def _make_graph() -> AsyncMock:
|
|||||||
state = MagicMock()
|
state = MagicMock()
|
||||||
state.tasks = ()
|
state.tasks = ()
|
||||||
graph.aget_state = AsyncMock(return_value=state)
|
graph.aget_state = AsyncMock(return_value=state)
|
||||||
graph.intent_classifier = None
|
|
||||||
graph.agent_registry = None
|
|
||||||
return graph
|
return graph
|
||||||
|
|
||||||
|
|
||||||
|
def _make_ws_ctx(sm: SessionManager | None = None) -> WebSocketContext:
|
||||||
|
graph = _make_graph()
|
||||||
|
registry = MagicMock()
|
||||||
|
registry.list_agents = MagicMock(return_value=())
|
||||||
|
graph_ctx = GraphContext(graph=graph, registry=registry, intent_classifier=None)
|
||||||
|
return WebSocketContext(
|
||||||
|
graph_ctx=graph_ctx,
|
||||||
|
session_manager=sm or SessionManager(),
|
||||||
|
callback_handler=TokenUsageCallbackHandler(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.unit
|
@pytest.mark.unit
|
||||||
class TestEmptyMessageHandling:
|
class TestEmptyMessageHandling:
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_empty_message_content_returns_error(self) -> None:
|
async def test_empty_message_content_returns_error(self) -> None:
|
||||||
ws = _make_ws()
|
ws = _make_ws()
|
||||||
graph = _make_graph()
|
|
||||||
sm = SessionManager()
|
sm = SessionManager()
|
||||||
cb = TokenUsageCallbackHandler()
|
ws_ctx = _make_ws_ctx(sm=sm)
|
||||||
|
|
||||||
sm.touch("t1")
|
sm.touch("t1")
|
||||||
msg = json.dumps({"type": "message", "thread_id": "t1", "content": ""})
|
msg = json.dumps({"type": "message", "thread_id": "t1", "content": ""})
|
||||||
await dispatch_message(ws, graph, sm, cb, msg)
|
await dispatch_message(ws, ws_ctx, msg)
|
||||||
|
|
||||||
call_data = ws.send_json.call_args[0][0]
|
call_data = ws.send_json.call_args[0][0]
|
||||||
assert call_data["type"] == "error"
|
assert call_data["type"] == "error"
|
||||||
@@ -60,13 +71,12 @@ class TestEmptyMessageHandling:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_whitespace_only_message_treated_as_empty(self) -> None:
|
async def test_whitespace_only_message_treated_as_empty(self) -> None:
|
||||||
ws = _make_ws()
|
ws = _make_ws()
|
||||||
graph = _make_graph()
|
|
||||||
sm = SessionManager()
|
sm = SessionManager()
|
||||||
cb = TokenUsageCallbackHandler()
|
ws_ctx = _make_ws_ctx(sm=sm)
|
||||||
|
|
||||||
sm.touch("t1")
|
sm.touch("t1")
|
||||||
msg = json.dumps({"type": "message", "thread_id": "t1", "content": " "})
|
msg = json.dumps({"type": "message", "thread_id": "t1", "content": " "})
|
||||||
await dispatch_message(ws, graph, sm, cb, msg)
|
await dispatch_message(ws, ws_ctx, msg)
|
||||||
|
|
||||||
call_data = ws.send_json.call_args[0][0]
|
call_data = ws.send_json.call_args[0][0]
|
||||||
assert call_data["type"] == "error"
|
assert call_data["type"] == "error"
|
||||||
@@ -77,14 +87,13 @@ class TestOversizedMessageHandling:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_content_over_10000_chars_returns_error(self) -> None:
|
async def test_content_over_10000_chars_returns_error(self) -> None:
|
||||||
ws = _make_ws()
|
ws = _make_ws()
|
||||||
graph = _make_graph()
|
|
||||||
sm = SessionManager()
|
sm = SessionManager()
|
||||||
cb = TokenUsageCallbackHandler()
|
ws_ctx = _make_ws_ctx(sm=sm)
|
||||||
|
|
||||||
sm.touch("t1")
|
sm.touch("t1")
|
||||||
content = "x" * 10001
|
content = "x" * 10001
|
||||||
msg = json.dumps({"type": "message", "thread_id": "t1", "content": content})
|
msg = json.dumps({"type": "message", "thread_id": "t1", "content": content})
|
||||||
await dispatch_message(ws, graph, sm, cb, msg)
|
await dispatch_message(ws, ws_ctx, msg)
|
||||||
|
|
||||||
call_data = ws.send_json.call_args[0][0]
|
call_data = ws.send_json.call_args[0][0]
|
||||||
assert call_data["type"] == "error"
|
assert call_data["type"] == "error"
|
||||||
@@ -93,14 +102,13 @@ class TestOversizedMessageHandling:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_content_exactly_10000_chars_is_accepted(self) -> None:
|
async def test_content_exactly_10000_chars_is_accepted(self) -> None:
|
||||||
ws = _make_ws()
|
ws = _make_ws()
|
||||||
graph = _make_graph()
|
|
||||||
sm = SessionManager()
|
sm = SessionManager()
|
||||||
cb = TokenUsageCallbackHandler()
|
ws_ctx = _make_ws_ctx(sm=sm)
|
||||||
|
|
||||||
sm.touch("t1")
|
sm.touch("t1")
|
||||||
content = "x" * 10000
|
content = "x" * 10000
|
||||||
msg = json.dumps({"type": "message", "thread_id": "t1", "content": content})
|
msg = json.dumps({"type": "message", "thread_id": "t1", "content": content})
|
||||||
await dispatch_message(ws, graph, sm, cb, msg)
|
await dispatch_message(ws, ws_ctx, msg)
|
||||||
|
|
||||||
last_call = ws.send_json.call_args[0][0]
|
last_call = ws.send_json.call_args[0][0]
|
||||||
# Should be processed, not an error about length
|
# Should be processed, not an error about length
|
||||||
@@ -110,12 +118,10 @@ class TestOversizedMessageHandling:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_raw_message_over_32kb_returns_error(self) -> None:
|
async def test_raw_message_over_32kb_returns_error(self) -> None:
|
||||||
ws = _make_ws()
|
ws = _make_ws()
|
||||||
graph = _make_graph()
|
ws_ctx = _make_ws_ctx()
|
||||||
sm = SessionManager()
|
|
||||||
cb = TokenUsageCallbackHandler()
|
|
||||||
|
|
||||||
large_msg = "x" * 40_000
|
large_msg = "x" * 40_000
|
||||||
await dispatch_message(ws, graph, sm, cb, large_msg)
|
await dispatch_message(ws, ws_ctx, large_msg)
|
||||||
|
|
||||||
call_data = ws.send_json.call_args[0][0]
|
call_data = ws.send_json.call_args[0][0]
|
||||||
assert call_data["type"] == "error"
|
assert call_data["type"] == "error"
|
||||||
@@ -127,11 +133,9 @@ class TestInvalidJsonHandling:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_invalid_json_returns_error(self) -> None:
|
async def test_invalid_json_returns_error(self) -> None:
|
||||||
ws = _make_ws()
|
ws = _make_ws()
|
||||||
graph = _make_graph()
|
ws_ctx = _make_ws_ctx()
|
||||||
sm = SessionManager()
|
|
||||||
cb = TokenUsageCallbackHandler()
|
|
||||||
|
|
||||||
await dispatch_message(ws, graph, sm, cb, "not valid json {{")
|
await dispatch_message(ws, ws_ctx, "not valid json {{")
|
||||||
|
|
||||||
call_data = ws.send_json.call_args[0][0]
|
call_data = ws.send_json.call_args[0][0]
|
||||||
assert call_data["type"] == "error"
|
assert call_data["type"] == "error"
|
||||||
@@ -140,11 +144,9 @@ class TestInvalidJsonHandling:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_empty_string_returns_json_error(self) -> None:
|
async def test_empty_string_returns_json_error(self) -> None:
|
||||||
ws = _make_ws()
|
ws = _make_ws()
|
||||||
graph = _make_graph()
|
ws_ctx = _make_ws_ctx()
|
||||||
sm = SessionManager()
|
|
||||||
cb = TokenUsageCallbackHandler()
|
|
||||||
|
|
||||||
await dispatch_message(ws, graph, sm, cb, "")
|
await dispatch_message(ws, ws_ctx, "")
|
||||||
|
|
||||||
call_data = ws.send_json.call_args[0][0]
|
call_data = ws.send_json.call_args[0][0]
|
||||||
assert call_data["type"] == "error"
|
assert call_data["type"] == "error"
|
||||||
@@ -152,11 +154,9 @@ class TestInvalidJsonHandling:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_json_array_not_object_returns_error(self) -> None:
|
async def test_json_array_not_object_returns_error(self) -> None:
|
||||||
ws = _make_ws()
|
ws = _make_ws()
|
||||||
graph = _make_graph()
|
ws_ctx = _make_ws_ctx()
|
||||||
sm = SessionManager()
|
|
||||||
cb = TokenUsageCallbackHandler()
|
|
||||||
|
|
||||||
await dispatch_message(ws, graph, sm, cb, '["not", "an", "object"]')
|
await dispatch_message(ws, ws_ctx, '["not", "an", "object"]')
|
||||||
|
|
||||||
call_data = ws.send_json.call_args[0][0]
|
call_data = ws.send_json.call_args[0][0]
|
||||||
assert call_data["type"] == "error"
|
assert call_data["type"] == "error"
|
||||||
@@ -167,17 +167,15 @@ class TestRateLimiting:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_rapid_fire_messages_rate_limited(self) -> None:
|
async def test_rapid_fire_messages_rate_limited(self) -> None:
|
||||||
ws = _make_ws()
|
ws = _make_ws()
|
||||||
_make_graph() # ensure graph factory works, not needed directly
|
|
||||||
sm = SessionManager()
|
sm = SessionManager()
|
||||||
cb = TokenUsageCallbackHandler()
|
|
||||||
|
|
||||||
sm.touch("t1")
|
sm.touch("t1")
|
||||||
|
|
||||||
# Simulate 11 rapid messages (exceeds 10 per 10 seconds limit)
|
# Simulate 11 rapid messages (exceeds 10 per 10 seconds limit)
|
||||||
rate_limit_triggered = False
|
rate_limit_triggered = False
|
||||||
for i in range(11):
|
for i in range(11):
|
||||||
graph2 = _make_graph() # fresh graph each time
|
ws_ctx = _make_ws_ctx(sm=sm)
|
||||||
await dispatch_message(ws, graph2, sm, cb, json.dumps({
|
await dispatch_message(ws, ws_ctx, json.dumps({
|
||||||
"type": "message",
|
"type": "message",
|
||||||
"thread_id": "t1",
|
"thread_id": "t1",
|
||||||
"content": f"message {i}",
|
"content": f"message {i}",
|
||||||
@@ -193,19 +191,18 @@ class TestRateLimiting:
|
|||||||
async def test_different_threads_have_separate_rate_limits(self) -> None:
|
async def test_different_threads_have_separate_rate_limits(self) -> None:
|
||||||
ws = _make_ws()
|
ws = _make_ws()
|
||||||
sm = SessionManager()
|
sm = SessionManager()
|
||||||
cb = TokenUsageCallbackHandler()
|
|
||||||
|
|
||||||
sm.touch("t1")
|
sm.touch("t1")
|
||||||
sm.touch("t2")
|
sm.touch("t2")
|
||||||
|
|
||||||
# Send 5 messages on t1 and 5 on t2 -- neither should be rate limited
|
# Send 5 messages on t1 and 5 on t2 -- neither should be rate limited
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
graph1 = _make_graph()
|
ws_ctx1 = _make_ws_ctx(sm=sm)
|
||||||
graph2 = _make_graph()
|
ws_ctx2 = _make_ws_ctx(sm=sm)
|
||||||
await dispatch_message(ws, graph1, sm, cb, json.dumps({
|
await dispatch_message(ws, ws_ctx1, json.dumps({
|
||||||
"type": "message", "thread_id": "t1", "content": f"msg {i}",
|
"type": "message", "thread_id": "t1", "content": f"msg {i}",
|
||||||
}))
|
}))
|
||||||
await dispatch_message(ws, graph2, sm, cb, json.dumps({
|
await dispatch_message(ws, ws_ctx2, json.dumps({
|
||||||
"type": "message", "thread_id": "t2", "content": f"msg {i}",
|
"type": "message", "thread_id": "t2", "content": f"msg {i}",
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
|||||||
@@ -8,7 +8,8 @@ from unittest.mock import AsyncMock, MagicMock
|
|||||||
import pytest
|
import pytest
|
||||||
from langgraph.checkpoint.memory import InMemorySaver
|
from langgraph.checkpoint.memory import InMemorySaver
|
||||||
|
|
||||||
from app.graph import build_agent_nodes, build_graph, classify_intent
|
from app.graph import build_agent_nodes, build_graph
|
||||||
|
from app.graph_context import GraphContext
|
||||||
from app.intent import ClassificationResult, IntentTarget
|
from app.intent import ClassificationResult, IntentTarget
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -37,8 +38,9 @@ class TestBuildGraph:
|
|||||||
mock_llm.with_structured_output = MagicMock(return_value=mock_llm)
|
mock_llm.with_structured_output = MagicMock(return_value=mock_llm)
|
||||||
checkpointer = InMemorySaver()
|
checkpointer = InMemorySaver()
|
||||||
|
|
||||||
graph = build_graph(sample_registry, mock_llm, checkpointer)
|
graph_ctx = build_graph(sample_registry, mock_llm, checkpointer)
|
||||||
assert graph is not None
|
assert graph_ctx is not None
|
||||||
|
assert graph_ctx.graph is not None
|
||||||
|
|
||||||
def test_graph_has_classifier_attached(self, sample_registry: AgentRegistry) -> None:
|
def test_graph_has_classifier_attached(self, sample_registry: AgentRegistry) -> None:
|
||||||
mock_llm = MagicMock()
|
mock_llm = MagicMock()
|
||||||
@@ -47,11 +49,11 @@ class TestBuildGraph:
|
|||||||
checkpointer = InMemorySaver()
|
checkpointer = InMemorySaver()
|
||||||
mock_classifier = MagicMock()
|
mock_classifier = MagicMock()
|
||||||
|
|
||||||
graph = build_graph(
|
graph_ctx = build_graph(
|
||||||
sample_registry, mock_llm, checkpointer, intent_classifier=mock_classifier
|
sample_registry, mock_llm, checkpointer, intent_classifier=mock_classifier
|
||||||
)
|
)
|
||||||
assert graph.intent_classifier is mock_classifier
|
assert graph_ctx.intent_classifier is mock_classifier
|
||||||
assert graph.agent_registry is sample_registry
|
assert graph_ctx.registry is sample_registry
|
||||||
|
|
||||||
def test_graph_without_classifier(self, sample_registry: AgentRegistry) -> None:
|
def test_graph_without_classifier(self, sample_registry: AgentRegistry) -> None:
|
||||||
mock_llm = MagicMock()
|
mock_llm = MagicMock()
|
||||||
@@ -59,17 +61,18 @@ class TestBuildGraph:
|
|||||||
mock_llm.with_structured_output = MagicMock(return_value=mock_llm)
|
mock_llm.with_structured_output = MagicMock(return_value=mock_llm)
|
||||||
checkpointer = InMemorySaver()
|
checkpointer = InMemorySaver()
|
||||||
|
|
||||||
graph = build_graph(sample_registry, mock_llm, checkpointer)
|
graph_ctx = build_graph(sample_registry, mock_llm, checkpointer)
|
||||||
assert graph.intent_classifier is None
|
assert graph_ctx.intent_classifier is None
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.unit
|
@pytest.mark.unit
|
||||||
class TestClassifyIntent:
|
class TestClassifyIntent:
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_returns_none_without_classifier(self) -> None:
|
async def test_returns_none_without_classifier(self) -> None:
|
||||||
graph = MagicMock()
|
mock_registry = MagicMock()
|
||||||
graph.intent_classifier = None
|
mock_registry.list_agents = MagicMock(return_value=())
|
||||||
result = await classify_intent(graph, "hello")
|
graph_ctx = GraphContext(graph=MagicMock(), registry=mock_registry, intent_classifier=None)
|
||||||
|
result = await graph_ctx.classify_intent("hello")
|
||||||
assert result is None
|
assert result is None
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -80,11 +83,12 @@ class TestClassifyIntent:
|
|||||||
mock_classifier = AsyncMock()
|
mock_classifier = AsyncMock()
|
||||||
mock_classifier.classify = AsyncMock(return_value=expected)
|
mock_classifier.classify = AsyncMock(return_value=expected)
|
||||||
|
|
||||||
graph = MagicMock()
|
mock_registry = MagicMock()
|
||||||
graph.intent_classifier = mock_classifier
|
mock_registry.list_agents = MagicMock(return_value=())
|
||||||
graph.agent_registry = MagicMock()
|
graph_ctx = GraphContext(
|
||||||
graph.agent_registry.list_agents = MagicMock(return_value=())
|
graph=MagicMock(), registry=mock_registry, intent_classifier=mock_classifier,
|
||||||
|
)
|
||||||
|
|
||||||
result = await classify_intent(graph, "check order")
|
result = await graph_ctx.classify_intent("check order")
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert result.intents[0].agent_name == "order_lookup"
|
assert result.intents[0].agent_name == "order_lookup"
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ class TestMainModule:
|
|||||||
assert app.title == "Smart Support"
|
assert app.title == "Smart Support"
|
||||||
|
|
||||||
def test_app_version(self) -> None:
|
def test_app_version(self) -> None:
|
||||||
assert app.version == "0.5.0"
|
assert app.version == "0.6.0"
|
||||||
|
|
||||||
def test_agents_yaml_path_exists(self) -> None:
|
def test_agents_yaml_path_exists(self) -> None:
|
||||||
assert AGENTS_YAML.name == "agents.yaml"
|
assert AGENTS_YAML.name == "agents.yaml"
|
||||||
@@ -39,4 +39,4 @@ class TestMainModule:
|
|||||||
assert "/api/health" in routes
|
assert "/api/health" in routes
|
||||||
|
|
||||||
def test_app_version_is_0_5_0(self) -> None:
|
def test_app_version_is_0_5_0(self) -> None:
|
||||||
assert app.version == "0.5.0"
|
assert app.version == "0.6.0"
|
||||||
|
|||||||
@@ -8,8 +8,10 @@ from unittest.mock import AsyncMock, MagicMock
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from app.callbacks import TokenUsageCallbackHandler
|
from app.callbacks import TokenUsageCallbackHandler
|
||||||
|
from app.graph_context import GraphContext
|
||||||
from app.interrupt_manager import InterruptManager
|
from app.interrupt_manager import InterruptManager
|
||||||
from app.session_manager import SessionManager
|
from app.session_manager import SessionManager
|
||||||
|
from app.ws_context import WebSocketContext
|
||||||
from app.ws_handler import (
|
from app.ws_handler import (
|
||||||
_extract_interrupt,
|
_extract_interrupt,
|
||||||
_has_interrupt,
|
_has_interrupt,
|
||||||
@@ -25,18 +27,42 @@ def _make_ws() -> AsyncMock:
|
|||||||
return ws
|
return ws
|
||||||
|
|
||||||
|
|
||||||
def _make_graph() -> AsyncMock:
|
def _make_graph() -> MagicMock:
|
||||||
graph = AsyncMock()
|
graph = AsyncMock()
|
||||||
graph.astream = MagicMock(return_value=AsyncIterHelper([]))
|
graph.astream = MagicMock(return_value=AsyncIterHelper([]))
|
||||||
state = MagicMock()
|
state = MagicMock()
|
||||||
state.tasks = ()
|
state.tasks = ()
|
||||||
graph.aget_state = AsyncMock(return_value=state)
|
graph.aget_state = AsyncMock(return_value=state)
|
||||||
# Phase 2: graph needs intent_classifier and agent_registry attrs
|
|
||||||
graph.intent_classifier = None
|
|
||||||
graph.agent_registry = None
|
|
||||||
return graph
|
return graph
|
||||||
|
|
||||||
|
|
||||||
|
def _make_graph_ctx(graph: MagicMock | None = None) -> GraphContext:
|
||||||
|
g = graph or _make_graph()
|
||||||
|
registry = MagicMock()
|
||||||
|
registry.list_agents = MagicMock(return_value=())
|
||||||
|
return GraphContext(graph=g, registry=registry, intent_classifier=None)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_ws_ctx(
|
||||||
|
graph_ctx: GraphContext | None = None,
|
||||||
|
sm: SessionManager | None = None,
|
||||||
|
cb: TokenUsageCallbackHandler | None = None,
|
||||||
|
interrupt_manager: InterruptManager | None = None,
|
||||||
|
analytics_recorder=None,
|
||||||
|
conversation_tracker=None,
|
||||||
|
pool=None,
|
||||||
|
) -> WebSocketContext:
|
||||||
|
return WebSocketContext(
|
||||||
|
graph_ctx=graph_ctx or _make_graph_ctx(),
|
||||||
|
session_manager=sm or SessionManager(),
|
||||||
|
callback_handler=cb or TokenUsageCallbackHandler(),
|
||||||
|
interrupt_manager=interrupt_manager,
|
||||||
|
analytics_recorder=analytics_recorder,
|
||||||
|
conversation_tracker=conversation_tracker,
|
||||||
|
pool=pool,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class AsyncIterHelper:
|
class AsyncIterHelper:
|
||||||
"""Helper to make a list behave as an async iterator."""
|
"""Helper to make a list behave as an async iterator."""
|
||||||
|
|
||||||
@@ -57,11 +83,9 @@ class TestDispatchMessage:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_invalid_json(self) -> None:
|
async def test_invalid_json(self) -> None:
|
||||||
ws = _make_ws()
|
ws = _make_ws()
|
||||||
graph = _make_graph()
|
ws_ctx = _make_ws_ctx()
|
||||||
sm = SessionManager()
|
|
||||||
cb = TokenUsageCallbackHandler()
|
|
||||||
|
|
||||||
await dispatch_message(ws, graph, sm, cb, "not json")
|
await dispatch_message(ws, ws_ctx, "not json")
|
||||||
ws.send_json.assert_awaited_once()
|
ws.send_json.assert_awaited_once()
|
||||||
call_data = ws.send_json.call_args[0][0]
|
call_data = ws.send_json.call_args[0][0]
|
||||||
assert call_data["type"] == "error"
|
assert call_data["type"] == "error"
|
||||||
@@ -70,12 +94,10 @@ class TestDispatchMessage:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_missing_thread_id(self) -> None:
|
async def test_missing_thread_id(self) -> None:
|
||||||
ws = _make_ws()
|
ws = _make_ws()
|
||||||
graph = _make_graph()
|
ws_ctx = _make_ws_ctx()
|
||||||
sm = SessionManager()
|
|
||||||
cb = TokenUsageCallbackHandler()
|
|
||||||
|
|
||||||
msg = json.dumps({"type": "message", "content": "hello"})
|
msg = json.dumps({"type": "message", "content": "hello"})
|
||||||
await dispatch_message(ws, graph, sm, cb, msg)
|
await dispatch_message(ws, ws_ctx, msg)
|
||||||
call_data = ws.send_json.call_args[0][0]
|
call_data = ws.send_json.call_args[0][0]
|
||||||
assert call_data["type"] == "error"
|
assert call_data["type"] == "error"
|
||||||
assert "thread_id" in call_data["message"]
|
assert "thread_id" in call_data["message"]
|
||||||
@@ -83,24 +105,20 @@ class TestDispatchMessage:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_missing_content(self) -> None:
|
async def test_missing_content(self) -> None:
|
||||||
ws = _make_ws()
|
ws = _make_ws()
|
||||||
graph = _make_graph()
|
ws_ctx = _make_ws_ctx()
|
||||||
sm = SessionManager()
|
|
||||||
cb = TokenUsageCallbackHandler()
|
|
||||||
|
|
||||||
msg = json.dumps({"type": "message", "thread_id": "t1"})
|
msg = json.dumps({"type": "message", "thread_id": "t1"})
|
||||||
await dispatch_message(ws, graph, sm, cb, msg)
|
await dispatch_message(ws, ws_ctx, msg)
|
||||||
call_data = ws.send_json.call_args[0][0]
|
call_data = ws.send_json.call_args[0][0]
|
||||||
assert call_data["type"] == "error"
|
assert call_data["type"] == "error"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_unknown_message_type(self) -> None:
|
async def test_unknown_message_type(self) -> None:
|
||||||
ws = _make_ws()
|
ws = _make_ws()
|
||||||
graph = _make_graph()
|
ws_ctx = _make_ws_ctx()
|
||||||
sm = SessionManager()
|
|
||||||
cb = TokenUsageCallbackHandler()
|
|
||||||
|
|
||||||
msg = json.dumps({"type": "unknown", "thread_id": "t1"})
|
msg = json.dumps({"type": "unknown", "thread_id": "t1"})
|
||||||
await dispatch_message(ws, graph, sm, cb, msg)
|
await dispatch_message(ws, ws_ctx, msg)
|
||||||
call_data = ws.send_json.call_args[0][0]
|
call_data = ws.send_json.call_args[0][0]
|
||||||
assert call_data["type"] == "error"
|
assert call_data["type"] == "error"
|
||||||
assert "Unknown" in call_data["message"]
|
assert "Unknown" in call_data["message"]
|
||||||
@@ -108,12 +126,10 @@ class TestDispatchMessage:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_message_too_large(self) -> None:
|
async def test_message_too_large(self) -> None:
|
||||||
ws = _make_ws()
|
ws = _make_ws()
|
||||||
graph = _make_graph()
|
ws_ctx = _make_ws_ctx()
|
||||||
sm = SessionManager()
|
|
||||||
cb = TokenUsageCallbackHandler()
|
|
||||||
|
|
||||||
large_msg = "x" * 40_000
|
large_msg = "x" * 40_000
|
||||||
await dispatch_message(ws, graph, sm, cb, large_msg)
|
await dispatch_message(ws, ws_ctx, large_msg)
|
||||||
call_data = ws.send_json.call_args[0][0]
|
call_data = ws.send_json.call_args[0][0]
|
||||||
assert call_data["type"] == "error"
|
assert call_data["type"] == "error"
|
||||||
assert "too large" in call_data["message"].lower()
|
assert "too large" in call_data["message"].lower()
|
||||||
@@ -121,12 +137,10 @@ class TestDispatchMessage:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_invalid_thread_id_format(self) -> None:
|
async def test_invalid_thread_id_format(self) -> None:
|
||||||
ws = _make_ws()
|
ws = _make_ws()
|
||||||
graph = _make_graph()
|
ws_ctx = _make_ws_ctx()
|
||||||
sm = SessionManager()
|
|
||||||
cb = TokenUsageCallbackHandler()
|
|
||||||
|
|
||||||
msg = json.dumps({"type": "message", "thread_id": "../../../etc", "content": "hi"})
|
msg = json.dumps({"type": "message", "thread_id": "../../../etc", "content": "hi"})
|
||||||
await dispatch_message(ws, graph, sm, cb, msg)
|
await dispatch_message(ws, ws_ctx, msg)
|
||||||
call_data = ws.send_json.call_args[0][0]
|
call_data = ws.send_json.call_args[0][0]
|
||||||
assert call_data["type"] == "error"
|
assert call_data["type"] == "error"
|
||||||
assert "thread_id" in call_data["message"].lower()
|
assert "thread_id" in call_data["message"].lower()
|
||||||
@@ -134,12 +148,10 @@ class TestDispatchMessage:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_content_too_long(self) -> None:
|
async def test_content_too_long(self) -> None:
|
||||||
ws = _make_ws()
|
ws = _make_ws()
|
||||||
graph = _make_graph()
|
ws_ctx = _make_ws_ctx()
|
||||||
sm = SessionManager()
|
|
||||||
cb = TokenUsageCallbackHandler()
|
|
||||||
|
|
||||||
msg = json.dumps({"type": "message", "thread_id": "t1", "content": "x" * 10001})
|
msg = json.dumps({"type": "message", "thread_id": "t1", "content": "x" * 10001})
|
||||||
await dispatch_message(ws, graph, sm, cb, msg)
|
await dispatch_message(ws, ws_ctx, msg)
|
||||||
call_data = ws.send_json.call_args[0][0]
|
call_data = ws.send_json.call_args[0][0]
|
||||||
assert call_data["type"] == "error"
|
assert call_data["type"] == "error"
|
||||||
assert "too long" in call_data["message"].lower()
|
assert "too long" in call_data["message"].lower()
|
||||||
@@ -147,14 +159,13 @@ class TestDispatchMessage:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_dispatch_with_interrupt_manager(self) -> None:
|
async def test_dispatch_with_interrupt_manager(self) -> None:
|
||||||
ws = _make_ws()
|
ws = _make_ws()
|
||||||
graph = _make_graph()
|
|
||||||
sm = SessionManager()
|
sm = SessionManager()
|
||||||
cb = TokenUsageCallbackHandler()
|
|
||||||
im = InterruptManager()
|
im = InterruptManager()
|
||||||
|
ws_ctx = _make_ws_ctx(sm=sm, interrupt_manager=im)
|
||||||
|
|
||||||
sm.touch("t1")
|
sm.touch("t1")
|
||||||
msg = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"})
|
msg = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"})
|
||||||
await dispatch_message(ws, graph, sm, cb, msg, interrupt_manager=im)
|
await dispatch_message(ws, ws_ctx, msg)
|
||||||
last_call = ws.send_json.call_args[0][0]
|
last_call = ws.send_json.call_args[0][0]
|
||||||
assert last_call["type"] == "message_complete"
|
assert last_call["type"] == "message_complete"
|
||||||
|
|
||||||
@@ -164,14 +175,14 @@ class TestHandleUserMessage:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_expired_session(self) -> None:
|
async def test_expired_session(self) -> None:
|
||||||
ws = _make_ws()
|
ws = _make_ws()
|
||||||
graph = _make_graph()
|
graph_ctx = _make_graph_ctx()
|
||||||
sm = SessionManager(session_ttl_seconds=0)
|
sm = SessionManager(session_ttl_seconds=0)
|
||||||
cb = TokenUsageCallbackHandler()
|
cb = TokenUsageCallbackHandler()
|
||||||
|
|
||||||
# First call creates the session (TTL=0)
|
# First call creates the session (TTL=0)
|
||||||
await handle_user_message(ws, graph, sm, cb, "t1", "hello")
|
await handle_user_message(ws, graph_ctx, sm, cb, "t1", "hello")
|
||||||
# Second call finds it expired
|
# Second call finds it expired
|
||||||
await handle_user_message(ws, graph, sm, cb, "t1", "hello again")
|
await handle_user_message(ws, graph_ctx, sm, cb, "t1", "hello again")
|
||||||
call_data = ws.send_json.call_args[0][0]
|
call_data = ws.send_json.call_args[0][0]
|
||||||
assert call_data["type"] == "error"
|
assert call_data["type"] == "error"
|
||||||
assert "expired" in call_data["message"].lower()
|
assert "expired" in call_data["message"].lower()
|
||||||
@@ -179,12 +190,12 @@ class TestHandleUserMessage:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_successful_message(self) -> None:
|
async def test_successful_message(self) -> None:
|
||||||
ws = _make_ws()
|
ws = _make_ws()
|
||||||
graph = _make_graph()
|
graph_ctx = _make_graph_ctx()
|
||||||
sm = SessionManager()
|
sm = SessionManager()
|
||||||
cb = TokenUsageCallbackHandler()
|
cb = TokenUsageCallbackHandler()
|
||||||
|
|
||||||
sm.touch("t1")
|
sm.touch("t1")
|
||||||
await handle_user_message(ws, graph, sm, cb, "t1", "hello")
|
await handle_user_message(ws, graph_ctx, sm, cb, "t1", "hello")
|
||||||
last_call = ws.send_json.call_args[0][0]
|
last_call = ws.send_json.call_args[0][0]
|
||||||
assert last_call["type"] == "message_complete"
|
assert last_call["type"] == "message_complete"
|
||||||
|
|
||||||
@@ -193,13 +204,12 @@ class TestHandleUserMessage:
|
|||||||
ws = _make_ws()
|
ws = _make_ws()
|
||||||
graph = AsyncMock()
|
graph = AsyncMock()
|
||||||
graph.astream = MagicMock(side_effect=RuntimeError("boom"))
|
graph.astream = MagicMock(side_effect=RuntimeError("boom"))
|
||||||
graph.intent_classifier = None
|
graph_ctx = _make_graph_ctx(graph=graph)
|
||||||
graph.agent_registry = None
|
|
||||||
sm = SessionManager()
|
sm = SessionManager()
|
||||||
cb = TokenUsageCallbackHandler()
|
cb = TokenUsageCallbackHandler()
|
||||||
|
|
||||||
sm.touch("t1")
|
sm.touch("t1")
|
||||||
await handle_user_message(ws, graph, sm, cb, "t1", "hello")
|
await handle_user_message(ws, graph_ctx, sm, cb, "t1", "hello")
|
||||||
call_data = ws.send_json.call_args[0][0]
|
call_data = ws.send_json.call_args[0][0]
|
||||||
assert call_data["type"] == "error"
|
assert call_data["type"] == "error"
|
||||||
|
|
||||||
@@ -207,8 +217,6 @@ class TestHandleUserMessage:
|
|||||||
async def test_interrupt_registered_with_manager(self) -> None:
|
async def test_interrupt_registered_with_manager(self) -> None:
|
||||||
ws = _make_ws()
|
ws = _make_ws()
|
||||||
graph = AsyncMock()
|
graph = AsyncMock()
|
||||||
graph.intent_classifier = None
|
|
||||||
graph.agent_registry = None
|
|
||||||
graph.astream = MagicMock(return_value=AsyncIterHelper([]))
|
graph.astream = MagicMock(return_value=AsyncIterHelper([]))
|
||||||
|
|
||||||
# Simulate interrupt in state
|
# Simulate interrupt in state
|
||||||
@@ -220,13 +228,14 @@ class TestHandleUserMessage:
|
|||||||
state.tasks = (task,)
|
state.tasks = (task,)
|
||||||
graph.aget_state = AsyncMock(return_value=state)
|
graph.aget_state = AsyncMock(return_value=state)
|
||||||
|
|
||||||
|
graph_ctx = _make_graph_ctx(graph=graph)
|
||||||
sm = SessionManager()
|
sm = SessionManager()
|
||||||
cb = TokenUsageCallbackHandler()
|
cb = TokenUsageCallbackHandler()
|
||||||
im = InterruptManager()
|
im = InterruptManager()
|
||||||
|
|
||||||
sm.touch("t1")
|
sm.touch("t1")
|
||||||
await handle_user_message(
|
await handle_user_message(
|
||||||
ws, graph, sm, cb, "t1", "cancel order 1042", interrupt_manager=im,
|
ws, graph_ctx, sm, cb, "t1", "cancel order 1042", interrupt_manager=im,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Interrupt should be registered
|
# Interrupt should be registered
|
||||||
@@ -257,16 +266,17 @@ class TestHandleUserMessage:
|
|||||||
clarification_question="What do you mean?",
|
clarification_question="What do you mean?",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
graph.intent_classifier = mock_classifier
|
|
||||||
mock_registry = MagicMock()
|
mock_registry = MagicMock()
|
||||||
mock_registry.list_agents = MagicMock(return_value=())
|
mock_registry.list_agents = MagicMock(return_value=())
|
||||||
graph.agent_registry = mock_registry
|
graph_ctx = GraphContext(
|
||||||
|
graph=graph, registry=mock_registry, intent_classifier=mock_classifier,
|
||||||
|
)
|
||||||
|
|
||||||
sm = SessionManager()
|
sm = SessionManager()
|
||||||
cb = TokenUsageCallbackHandler()
|
cb = TokenUsageCallbackHandler()
|
||||||
|
|
||||||
sm.touch("t1")
|
sm.touch("t1")
|
||||||
await handle_user_message(ws, graph, sm, cb, "t1", "hmm")
|
await handle_user_message(ws, graph_ctx, sm, cb, "t1", "hmm")
|
||||||
|
|
||||||
calls = [c[0][0] for c in ws.send_json.call_args_list]
|
calls = [c[0][0] for c in ws.send_json.call_args_list]
|
||||||
clarification_msgs = [c for c in calls if c.get("type") == "clarification"]
|
clarification_msgs = [c for c in calls if c.get("type") == "clarification"]
|
||||||
@@ -279,13 +289,13 @@ class TestHandleInterruptResponse:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_approved_interrupt(self) -> None:
|
async def test_approved_interrupt(self) -> None:
|
||||||
ws = _make_ws()
|
ws = _make_ws()
|
||||||
graph = _make_graph()
|
graph_ctx = _make_graph_ctx()
|
||||||
sm = SessionManager()
|
sm = SessionManager()
|
||||||
cb = TokenUsageCallbackHandler()
|
cb = TokenUsageCallbackHandler()
|
||||||
|
|
||||||
sm.touch("t1")
|
sm.touch("t1")
|
||||||
sm.extend_for_interrupt("t1")
|
sm.extend_for_interrupt("t1")
|
||||||
await handle_interrupt_response(ws, graph, sm, cb, "t1", True)
|
await handle_interrupt_response(ws, graph_ctx, sm, cb, "t1", True)
|
||||||
last_call = ws.send_json.call_args[0][0]
|
last_call = ws.send_json.call_args[0][0]
|
||||||
assert last_call["type"] == "message_complete"
|
assert last_call["type"] == "message_complete"
|
||||||
|
|
||||||
@@ -294,7 +304,7 @@ class TestHandleInterruptResponse:
|
|||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
ws = _make_ws()
|
ws = _make_ws()
|
||||||
graph = _make_graph()
|
graph_ctx = _make_graph_ctx()
|
||||||
sm = SessionManager()
|
sm = SessionManager()
|
||||||
cb = TokenUsageCallbackHandler()
|
cb = TokenUsageCallbackHandler()
|
||||||
im = InterruptManager(ttl_seconds=5)
|
im = InterruptManager(ttl_seconds=5)
|
||||||
@@ -307,7 +317,7 @@ class TestHandleInterruptResponse:
|
|||||||
with patch("app.interrupt_manager.time") as mock_time:
|
with patch("app.interrupt_manager.time") as mock_time:
|
||||||
mock_time.time.return_value = im._interrupts["t1"].created_at + 10
|
mock_time.time.return_value = im._interrupts["t1"].created_at + 10
|
||||||
await handle_interrupt_response(
|
await handle_interrupt_response(
|
||||||
ws, graph, sm, cb, "t1", True, interrupt_manager=im
|
ws, graph_ctx, sm, cb, "t1", True, interrupt_manager=im
|
||||||
)
|
)
|
||||||
|
|
||||||
call_data = ws.send_json.call_args[0][0]
|
call_data = ws.send_json.call_args[0][0]
|
||||||
@@ -317,7 +327,7 @@ class TestHandleInterruptResponse:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_valid_interrupt_resolves(self) -> None:
|
async def test_valid_interrupt_resolves(self) -> None:
|
||||||
ws = _make_ws()
|
ws = _make_ws()
|
||||||
graph = _make_graph()
|
graph_ctx = _make_graph_ctx()
|
||||||
sm = SessionManager()
|
sm = SessionManager()
|
||||||
cb = TokenUsageCallbackHandler()
|
cb = TokenUsageCallbackHandler()
|
||||||
im = InterruptManager(ttl_seconds=1800)
|
im = InterruptManager(ttl_seconds=1800)
|
||||||
@@ -327,7 +337,7 @@ class TestHandleInterruptResponse:
|
|||||||
im.register("t1", "cancel_order", {})
|
im.register("t1", "cancel_order", {})
|
||||||
|
|
||||||
await handle_interrupt_response(
|
await handle_interrupt_response(
|
||||||
ws, graph, sm, cb, "t1", True, interrupt_manager=im
|
ws, graph_ctx, sm, cb, "t1", True, interrupt_manager=im
|
||||||
)
|
)
|
||||||
|
|
||||||
# Interrupt should be resolved
|
# Interrupt should be resolved
|
||||||
@@ -374,19 +384,14 @@ class TestDispatchMessageWithTracking:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_conversation_tracker_called_on_message(self) -> None:
|
async def test_conversation_tracker_called_on_message(self) -> None:
|
||||||
ws = _make_ws()
|
ws = _make_ws()
|
||||||
graph = _make_graph()
|
|
||||||
sm = SessionManager()
|
sm = SessionManager()
|
||||||
cb = TokenUsageCallbackHandler()
|
|
||||||
tracker = AsyncMock()
|
tracker = AsyncMock()
|
||||||
pool = MagicMock()
|
pool = MagicMock()
|
||||||
|
ws_ctx = _make_ws_ctx(sm=sm, conversation_tracker=tracker, pool=pool)
|
||||||
|
|
||||||
sm.touch("t1")
|
sm.touch("t1")
|
||||||
msg = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"})
|
msg = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"})
|
||||||
await dispatch_message(
|
await dispatch_message(ws, ws_ctx, msg)
|
||||||
ws, graph, sm, cb, msg,
|
|
||||||
conversation_tracker=tracker,
|
|
||||||
pool=pool,
|
|
||||||
)
|
|
||||||
|
|
||||||
tracker.ensure_conversation.assert_awaited_once_with(pool, "t1")
|
tracker.ensure_conversation.assert_awaited_once_with(pool, "t1")
|
||||||
tracker.record_turn.assert_awaited_once()
|
tracker.record_turn.assert_awaited_once()
|
||||||
@@ -394,53 +399,42 @@ class TestDispatchMessageWithTracking:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_analytics_recorder_called_on_message(self) -> None:
|
async def test_analytics_recorder_called_on_message(self) -> None:
|
||||||
ws = _make_ws()
|
ws = _make_ws()
|
||||||
graph = _make_graph()
|
|
||||||
sm = SessionManager()
|
sm = SessionManager()
|
||||||
cb = TokenUsageCallbackHandler()
|
|
||||||
recorder = AsyncMock()
|
recorder = AsyncMock()
|
||||||
pool = MagicMock()
|
pool = MagicMock()
|
||||||
|
ws_ctx = _make_ws_ctx(sm=sm, analytics_recorder=recorder, pool=pool)
|
||||||
|
|
||||||
sm.touch("t1")
|
sm.touch("t1")
|
||||||
msg = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"})
|
msg = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"})
|
||||||
await dispatch_message(
|
await dispatch_message(ws, ws_ctx, msg)
|
||||||
ws, graph, sm, cb, msg,
|
|
||||||
analytics_recorder=recorder,
|
|
||||||
pool=pool,
|
|
||||||
)
|
|
||||||
|
|
||||||
recorder.record.assert_awaited_once()
|
recorder.record.assert_awaited_once()
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_tracker_failure_does_not_break_chat(self) -> None:
|
async def test_tracker_failure_does_not_break_chat(self) -> None:
|
||||||
ws = _make_ws()
|
ws = _make_ws()
|
||||||
graph = _make_graph()
|
|
||||||
sm = SessionManager()
|
sm = SessionManager()
|
||||||
cb = TokenUsageCallbackHandler()
|
|
||||||
tracker = AsyncMock()
|
tracker = AsyncMock()
|
||||||
tracker.ensure_conversation.side_effect = RuntimeError("DB down")
|
tracker.ensure_conversation.side_effect = RuntimeError("DB down")
|
||||||
pool = MagicMock()
|
pool = MagicMock()
|
||||||
|
ws_ctx = _make_ws_ctx(sm=sm, conversation_tracker=tracker, pool=pool)
|
||||||
|
|
||||||
sm.touch("t1")
|
sm.touch("t1")
|
||||||
msg = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"})
|
msg = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"})
|
||||||
# Should not raise despite tracker failure
|
# Should not raise despite tracker failure
|
||||||
await dispatch_message(
|
await dispatch_message(ws, ws_ctx, msg)
|
||||||
ws, graph, sm, cb, msg,
|
|
||||||
conversation_tracker=tracker,
|
|
||||||
pool=pool,
|
|
||||||
)
|
|
||||||
last_call = ws.send_json.call_args[0][0]
|
last_call = ws.send_json.call_args[0][0]
|
||||||
assert last_call["type"] == "message_complete"
|
assert last_call["type"] == "message_complete"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_no_tracker_no_error(self) -> None:
|
async def test_no_tracker_no_error(self) -> None:
|
||||||
ws = _make_ws()
|
ws = _make_ws()
|
||||||
graph = _make_graph()
|
|
||||||
sm = SessionManager()
|
sm = SessionManager()
|
||||||
cb = TokenUsageCallbackHandler()
|
ws_ctx = _make_ws_ctx(sm=sm)
|
||||||
|
|
||||||
sm.touch("t1")
|
sm.touch("t1")
|
||||||
msg = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"})
|
msg = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"})
|
||||||
# No tracker or recorder passed -- should work fine
|
# No tracker or recorder passed -- should work fine
|
||||||
await dispatch_message(ws, graph, sm, cb, msg)
|
await dispatch_message(ws, ws_ctx, msg)
|
||||||
last_call = ws.send_json.call_args[0][0]
|
last_call = ws.send_json.call_args[0][0]
|
||||||
assert last_call["type"] == "message_complete"
|
assert last_call["type"] == "message_complete"
|
||||||
|
|||||||
@@ -101,3 +101,74 @@ export async function fetchReplay(
|
|||||||
export async function fetchAnalytics(range = "7d"): Promise<AnalyticsData> {
|
export async function fetchAnalytics(range = "7d"): Promise<AnalyticsData> {
|
||||||
return apiFetch<AnalyticsData>(`/api/analytics?range=${range}`);
|
return apiFetch<AnalyticsData>(`/api/analytics?range=${range}`);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -- OpenAPI import --
|
||||||
|
|
||||||
|
export interface ImportJobResponse {
|
||||||
|
job_id: string;
|
||||||
|
status: string;
|
||||||
|
spec_url: string;
|
||||||
|
total_endpoints: number;
|
||||||
|
classified_count: number;
|
||||||
|
error_message: string | null;
|
||||||
|
generated_tools_count?: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface EndpointClassification {
|
||||||
|
index: number;
|
||||||
|
access_type: string;
|
||||||
|
needs_interrupt: boolean;
|
||||||
|
agent_group: string;
|
||||||
|
confidence: number;
|
||||||
|
customer_params: string[];
|
||||||
|
endpoint: {
|
||||||
|
path: string;
|
||||||
|
method: string;
|
||||||
|
operation_id: string;
|
||||||
|
summary: string;
|
||||||
|
description: string;
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
async function apiPost<T>(path: string, body: unknown): Promise<T> {
|
||||||
|
const res = await fetch(`${API_BASE}${path}`, {
|
||||||
|
method: "POST",
|
||||||
|
headers: { "Content-Type": "application/json" },
|
||||||
|
body: JSON.stringify(body),
|
||||||
|
});
|
||||||
|
if (!res.ok) {
|
||||||
|
throw new Error(`API error ${res.status}: ${res.statusText}`);
|
||||||
|
}
|
||||||
|
return res.json();
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function startImport(url: string): Promise<ImportJobResponse> {
|
||||||
|
return apiPost<ImportJobResponse>("/api/openapi/import", { url });
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function fetchImportJob(jobId: string): Promise<ImportJobResponse> {
|
||||||
|
const res = await fetch(`${API_BASE}/api/openapi/jobs/${encodeURIComponent(jobId)}`);
|
||||||
|
if (!res.ok) {
|
||||||
|
throw new Error(`API error ${res.status}: ${res.statusText}`);
|
||||||
|
}
|
||||||
|
return res.json();
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function fetchClassifications(
|
||||||
|
jobId: string
|
||||||
|
): Promise<EndpointClassification[]> {
|
||||||
|
const res = await fetch(
|
||||||
|
`${API_BASE}/api/openapi/jobs/${encodeURIComponent(jobId)}/classifications`
|
||||||
|
);
|
||||||
|
if (!res.ok) {
|
||||||
|
throw new Error(`API error ${res.status}: ${res.statusText}`);
|
||||||
|
}
|
||||||
|
return res.json();
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function approveJob(jobId: string): Promise<ImportJobResponse> {
|
||||||
|
return apiPost<ImportJobResponse>(
|
||||||
|
`/api/openapi/jobs/${encodeURIComponent(jobId)}/approve`,
|
||||||
|
{}
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|||||||
@@ -13,10 +13,8 @@ import type {
|
|||||||
ToolAction,
|
ToolAction,
|
||||||
} from "../types";
|
} from "../types";
|
||||||
|
|
||||||
let msgCounter = 0;
|
|
||||||
function nextId(): string {
|
function nextId(): string {
|
||||||
msgCounter += 1;
|
return crypto.randomUUID();
|
||||||
return `msg-${msgCounter}`;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
export function ChatPage() {
|
export function ChatPage() {
|
||||||
@@ -68,6 +66,48 @@ export function ChatPage() {
|
|||||||
setIsWaiting(false);
|
setIsWaiting(false);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
case "clarification": {
|
||||||
|
setMessages((prev) => [
|
||||||
|
...prev,
|
||||||
|
{
|
||||||
|
id: nextId(),
|
||||||
|
sender: "agent",
|
||||||
|
agent: "System",
|
||||||
|
content: msg.message,
|
||||||
|
timestamp: Date.now(),
|
||||||
|
},
|
||||||
|
]);
|
||||||
|
setIsWaiting(false);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case "interrupt_expired": {
|
||||||
|
setCurrentInterrupt(null);
|
||||||
|
setMessages((prev) => [
|
||||||
|
...prev,
|
||||||
|
{
|
||||||
|
id: nextId(),
|
||||||
|
sender: "agent",
|
||||||
|
agent: "System",
|
||||||
|
content: msg.message,
|
||||||
|
timestamp: Date.now(),
|
||||||
|
},
|
||||||
|
]);
|
||||||
|
setIsWaiting(false);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case "tool_result": {
|
||||||
|
setToolActions((prev) => {
|
||||||
|
const last = prev[prev.length - 1];
|
||||||
|
if (last && last.tool === msg.tool && last.agent === msg.agent) {
|
||||||
|
return [
|
||||||
|
...prev.slice(0, -1),
|
||||||
|
{ ...last, result: msg.result },
|
||||||
|
];
|
||||||
|
}
|
||||||
|
return prev;
|
||||||
|
});
|
||||||
|
break;
|
||||||
|
}
|
||||||
case "message_complete": {
|
case "message_complete": {
|
||||||
setMessages((prev) => {
|
setMessages((prev) => {
|
||||||
const last = prev[prev.length - 1];
|
const last = prev[prev.length - 1];
|
||||||
|
|||||||
@@ -81,7 +81,7 @@ export function ReplayPage() {
|
|||||||
|
|
||||||
{/* Timeline */}
|
{/* Timeline */}
|
||||||
<div className="section-card" style={{ padding: "2rem" }}>
|
<div className="section-card" style={{ padding: "2rem" }}>
|
||||||
<ReplayTimeline steps={steps as any} />
|
<ReplayTimeline steps={steps} />
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
|
|||||||
@@ -1,12 +1,14 @@
|
|||||||
import { useEffect, useRef, useState } from "react";
|
import { useEffect, useRef, useState } from "react";
|
||||||
|
import {
|
||||||
|
approveJob,
|
||||||
|
fetchClassifications,
|
||||||
|
fetchImportJob,
|
||||||
|
startImport,
|
||||||
|
type EndpointClassification,
|
||||||
|
type ImportJobResponse,
|
||||||
|
} from "../api";
|
||||||
|
|
||||||
interface ImportJob {
|
interface FlatClassification {
|
||||||
job_id: string;
|
|
||||||
status: "pending" | "processing" | "done" | "failed";
|
|
||||||
error_message?: string;
|
|
||||||
}
|
|
||||||
|
|
||||||
interface EndpointClassification {
|
|
||||||
path: string;
|
path: string;
|
||||||
method: string;
|
method: string;
|
||||||
summary: string;
|
summary: string;
|
||||||
@@ -14,48 +16,23 @@ interface EndpointClassification {
|
|||||||
agent_group: string;
|
agent_group: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function flattenClassification(c: EndpointClassification): FlatClassification {
|
||||||
|
return {
|
||||||
|
path: c.endpoint?.path ?? "",
|
||||||
|
method: c.endpoint?.method ?? "",
|
||||||
|
summary: c.endpoint?.summary ?? "",
|
||||||
|
access_type: c.access_type ?? "read",
|
||||||
|
agent_group: c.agent_group ?? "Unassigned",
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
export function ReviewPage() {
|
export function ReviewPage() {
|
||||||
const [url, setUrl] = useState("");
|
const [url, setUrl] = useState("");
|
||||||
const [job, setJob] = useState<ImportJob | null>(null);
|
const [job, setJob] = useState<ImportJobResponse | null>(null);
|
||||||
const [submitting, setSubmitting] = useState(false);
|
const [submitting, setSubmitting] = useState(false);
|
||||||
const [submitError, setSubmitError] = useState<string | null>(null);
|
const [submitError, setSubmitError] = useState<string | null>(null);
|
||||||
const [classifications, setClassifications] = useState<EndpointClassification[]>([
|
const [approveStatus, setApproveStatus] = useState<string | null>(null);
|
||||||
{
|
const [classifications, setClassifications] = useState<FlatClassification[]>([]);
|
||||||
path: "/api/v1/orders/{order_id}/cancel",
|
|
||||||
method: "post",
|
|
||||||
summary: "Cancel an active Shopify order",
|
|
||||||
access_type: "write",
|
|
||||||
agent_group: "Order Specialist",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
path: "/api/v1/orders/{order_id}",
|
|
||||||
method: "get",
|
|
||||||
summary: "Retrieve detailed information about an order",
|
|
||||||
access_type: "read",
|
|
||||||
agent_group: "Order Specialist",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
path: "/api/v1/payments/{charge_id}/refund",
|
|
||||||
method: "post",
|
|
||||||
summary: "Issue a full or partial refund for a charge",
|
|
||||||
access_type: "write",
|
|
||||||
agent_group: "Billing Assistant",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
path: "/api/v1/customers/{email}/discounts",
|
|
||||||
method: "post",
|
|
||||||
summary: "Apply a loyalty discount to a customer account",
|
|
||||||
access_type: "write",
|
|
||||||
agent_group: "Billing Assistant",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
path: "/api/v1/inventory/check",
|
|
||||||
method: "get",
|
|
||||||
summary: "Query realtime stock levels across warehouses",
|
|
||||||
access_type: "read",
|
|
||||||
agent_group: "Unassigned",
|
|
||||||
}
|
|
||||||
]);
|
|
||||||
const pollRef = useRef<ReturnType<typeof setTimeout> | null>(null);
|
const pollRef = useRef<ReturnType<typeof setTimeout> | null>(null);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
@@ -65,24 +42,12 @@ export function ReviewPage() {
|
|||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
function pollJob(jobId: string) {
|
function pollJob(jobId: string) {
|
||||||
fetch(`/api/openapi/jobs/${encodeURIComponent(jobId)}`)
|
fetchImportJob(jobId)
|
||||||
.then((r) => r.json())
|
.then((j) => {
|
||||||
.then((data) => {
|
|
||||||
const j: ImportJob = data.data ?? data;
|
|
||||||
setJob(j);
|
setJob(j);
|
||||||
if (j.status === "done") {
|
if (j.status === "done") {
|
||||||
return fetch(`/api/openapi/jobs/${encodeURIComponent(jobId)}/classifications`)
|
return fetchClassifications(jobId).then((clfs) => {
|
||||||
.then((r) => r.json())
|
setClassifications(clfs.map(flattenClassification));
|
||||||
.then((clfs: EndpointClassification[]) => {
|
|
||||||
setClassifications(
|
|
||||||
clfs.map((c: any) => ({
|
|
||||||
path: c.endpoint?.path ?? c.path ?? "",
|
|
||||||
method: c.endpoint?.method ?? c.method ?? "",
|
|
||||||
summary: c.endpoint?.summary ?? c.summary ?? "",
|
|
||||||
access_type: c.access_type ?? "read",
|
|
||||||
agent_group: c.agent_group ?? "Unassigned",
|
|
||||||
}))
|
|
||||||
);
|
|
||||||
});
|
});
|
||||||
} else if (j.status === "failed") {
|
} else if (j.status === "failed") {
|
||||||
return;
|
return;
|
||||||
@@ -100,17 +65,12 @@ export function ReviewPage() {
|
|||||||
if (!url.trim()) return;
|
if (!url.trim()) return;
|
||||||
setSubmitting(true);
|
setSubmitting(true);
|
||||||
setSubmitError(null);
|
setSubmitError(null);
|
||||||
|
setApproveStatus(null);
|
||||||
setJob(null);
|
setJob(null);
|
||||||
setClassifications([]);
|
setClassifications([]);
|
||||||
|
|
||||||
fetch("/api/openapi/import", {
|
startImport(url)
|
||||||
method: "POST",
|
.then((j) => {
|
||||||
headers: { "Content-Type": "application/json" },
|
|
||||||
body: JSON.stringify({ url }),
|
|
||||||
})
|
|
||||||
.then((r) => r.json())
|
|
||||||
.then((data) => {
|
|
||||||
const j: ImportJob = data.data ?? data;
|
|
||||||
setJob(j);
|
setJob(j);
|
||||||
if (j.job_id) pollJob(j.job_id);
|
if (j.job_id) pollJob(j.job_id);
|
||||||
})
|
})
|
||||||
@@ -120,7 +80,7 @@ export function ReviewPage() {
|
|||||||
|
|
||||||
function handleFieldChange(
|
function handleFieldChange(
|
||||||
idx: number,
|
idx: number,
|
||||||
field: keyof EndpointClassification,
|
field: keyof FlatClassification,
|
||||||
value: string
|
value: string
|
||||||
) {
|
) {
|
||||||
setClassifications((prev) =>
|
setClassifications((prev) =>
|
||||||
@@ -130,21 +90,26 @@ export function ReviewPage() {
|
|||||||
|
|
||||||
function handleApprove() {
|
function handleApprove() {
|
||||||
if (!job?.job_id) return;
|
if (!job?.job_id) return;
|
||||||
fetch(`/api/openapi/jobs/${encodeURIComponent(job.job_id)}/approve`, {
|
setApproveStatus(null);
|
||||||
method: "POST",
|
approveJob(job.job_id)
|
||||||
headers: { "Content-Type": "application/json" },
|
.then((result) => {
|
||||||
body: JSON.stringify({ endpoints: classifications }),
|
setJob(result);
|
||||||
}).then(() => {
|
setApproveStatus(
|
||||||
alert("Approved and saved.");
|
`Configuration saved. ${result.generated_tools_count ?? 0} tools generated.`
|
||||||
});
|
);
|
||||||
|
})
|
||||||
|
.catch((err: Error) => setApproveStatus(`Error: ${err.message}`));
|
||||||
}
|
}
|
||||||
|
|
||||||
const groupedByAgent = classifications.reduce((acc, c, idx) => {
|
const groupedByAgent = classifications.reduce(
|
||||||
|
(acc, c, idx) => {
|
||||||
const group = c.agent_group || "Unassigned";
|
const group = c.agent_group || "Unassigned";
|
||||||
if (!acc[group]) acc[group] = [];
|
if (!acc[group]) acc[group] = [];
|
||||||
acc[group].push({ ...c, originalIdx: idx });
|
acc[group].push({ ...c, originalIdx: idx });
|
||||||
return acc;
|
return acc;
|
||||||
}, {} as Record<string, (EndpointClassification & { originalIdx: number })[]>);
|
},
|
||||||
|
{} as Record<string, (FlatClassification & { originalIdx: number })[]>
|
||||||
|
);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="page-container">
|
<div className="page-container">
|
||||||
@@ -167,24 +132,90 @@ export function ReviewPage() {
|
|||||||
</button>
|
</button>
|
||||||
</form>
|
</form>
|
||||||
|
|
||||||
{submitError && <div style={{ color: "var(--brand-accent)", marginBottom: "1rem" }}>Error: {submitError}</div>}
|
{submitError && (
|
||||||
|
<div style={{ color: "var(--brand-accent)", marginBottom: "1rem" }}>
|
||||||
|
Error: {submitError}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
{job && (
|
{job && (
|
||||||
<div style={{ padding: "1rem", background: "var(--bg-surface)", border: "1px solid var(--border-light)", borderRadius: "var(--radius-md)", marginBottom: "1.5rem" }}>
|
<div
|
||||||
|
style={{
|
||||||
|
padding: "1rem",
|
||||||
|
background: "var(--bg-surface)",
|
||||||
|
border: "1px solid var(--border-light)",
|
||||||
|
borderRadius: "var(--radius-md)",
|
||||||
|
marginBottom: "1.5rem",
|
||||||
|
}}
|
||||||
|
>
|
||||||
<strong>Job:</strong> {job.job_id} — Status:{" "}
|
<strong>Job:</strong> {job.job_id} — Status:{" "}
|
||||||
<span style={{ fontWeight: 600, color: job.status === "done" ? "#10b981" : job.status === "failed" ? "var(--brand-accent)" : "#f59e0b" }}>
|
<span
|
||||||
|
style={{
|
||||||
|
fontWeight: 600,
|
||||||
|
color:
|
||||||
|
job.status === "done" || job.status === "approved"
|
||||||
|
? "#10b981"
|
||||||
|
: job.status === "failed"
|
||||||
|
? "var(--brand-accent)"
|
||||||
|
: "#f59e0b",
|
||||||
|
}}
|
||||||
|
>
|
||||||
{job.status}
|
{job.status}
|
||||||
</span>
|
</span>
|
||||||
{job.error_message && <div style={{ marginTop: "4px", color: "var(--brand-accent)" }}>{job.error_message}</div>}
|
{job.error_message && (
|
||||||
|
<div style={{ marginTop: "4px", color: "var(--brand-accent)" }}>
|
||||||
|
{job.error_message}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{approveStatus && (
|
||||||
|
<div
|
||||||
|
style={{
|
||||||
|
padding: "0.75rem 1rem",
|
||||||
|
background: approveStatus.startsWith("Error")
|
||||||
|
? "#fef2f2"
|
||||||
|
: "#f0fdf4",
|
||||||
|
border: `1px solid ${approveStatus.startsWith("Error") ? "#fecaca" : "#bbf7d0"}`,
|
||||||
|
borderRadius: "var(--radius-md)",
|
||||||
|
marginBottom: "1rem",
|
||||||
|
fontSize: "0.875rem",
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
{approveStatus}
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
|
|
||||||
{classifications.length > 0 && (
|
{classifications.length > 0 && (
|
||||||
<>
|
<>
|
||||||
<div style={{ display: "flex", justifyContent: "space-between", alignItems: "center", marginBottom: "1rem" }}>
|
<div
|
||||||
|
style={{
|
||||||
|
display: "flex",
|
||||||
|
justifyContent: "space-between",
|
||||||
|
alignItems: "center",
|
||||||
|
marginBottom: "1rem",
|
||||||
|
}}
|
||||||
|
>
|
||||||
<div>
|
<div>
|
||||||
<h3 style={{ margin: 0, fontSize: "1.25rem", color: "var(--text-primary)" }}>Assigned Capabilities ({classifications.length})</h3>
|
<h3
|
||||||
<p style={{ margin: "0.25rem 0 0 0", fontSize: "0.875rem", color: "var(--text-secondary)" }}>Grouped by target Agent.</p>
|
style={{
|
||||||
|
margin: 0,
|
||||||
|
fontSize: "1.25rem",
|
||||||
|
color: "var(--text-primary)",
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
Assigned Capabilities ({classifications.length})
|
||||||
|
</h3>
|
||||||
|
<p
|
||||||
|
style={{
|
||||||
|
margin: "0.25rem 0 0 0",
|
||||||
|
fontSize: "0.875rem",
|
||||||
|
color: "var(--text-secondary)",
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
Grouped by target Agent.
|
||||||
|
</p>
|
||||||
</div>
|
</div>
|
||||||
<button onClick={handleApprove} className="btn btn-primary">
|
<button onClick={handleApprove} className="btn btn-primary">
|
||||||
Save Configuration
|
Save Configuration
|
||||||
@@ -195,7 +226,11 @@ export function ReviewPage() {
|
|||||||
{Object.entries(groupedByAgent).map(([groupName, tools]) => (
|
{Object.entries(groupedByAgent).map(([groupName, tools]) => (
|
||||||
<div key={groupName} className="agent-grid-card">
|
<div key={groupName} className="agent-grid-card">
|
||||||
<div className="agent-card-header-bg">
|
<div className="agent-card-header-bg">
|
||||||
<div className="agent-avatar-lg">{groupName === "Unassigned" ? "?" : groupName.charAt(0).toUpperCase()}</div>
|
<div className="agent-avatar-lg">
|
||||||
|
{groupName === "Unassigned"
|
||||||
|
? "?"
|
||||||
|
: groupName.charAt(0).toUpperCase()}
|
||||||
|
</div>
|
||||||
<div className="agent-card-meta">
|
<div className="agent-card-meta">
|
||||||
<h3>{groupName}</h3>
|
<h3>{groupName}</h3>
|
||||||
<span>{tools.length} Attached Tools</span>
|
<span>{tools.length} Attached Tools</span>
|
||||||
@@ -205,16 +240,36 @@ export function ReviewPage() {
|
|||||||
{tools.map((t) => (
|
{tools.map((t) => (
|
||||||
<div key={t.originalIdx} className="tool-pill-item">
|
<div key={t.originalIdx} className="tool-pill-item">
|
||||||
<div className="tool-pill-header">
|
<div className="tool-pill-header">
|
||||||
<span className="tool-method-badge" style={{ background: t.method === "get" ? "#3b82f6" : t.method === "post" ? "#10b981" : t.method === "delete" ? "#ef4444" : "#f59e0b" }}>
|
<span
|
||||||
|
className="tool-method-badge"
|
||||||
|
style={{
|
||||||
|
background:
|
||||||
|
t.method === "get"
|
||||||
|
? "#3b82f6"
|
||||||
|
: t.method === "post"
|
||||||
|
? "#10b981"
|
||||||
|
: t.method === "delete"
|
||||||
|
? "#ef4444"
|
||||||
|
: "#f59e0b",
|
||||||
|
}}
|
||||||
|
>
|
||||||
{t.method}
|
{t.method}
|
||||||
</span>
|
</span>
|
||||||
<span className="tool-path-text" title={t.path}>{t.path}</span>
|
<span className="tool-path-text" title={t.path}>
|
||||||
|
{t.path}
|
||||||
|
</span>
|
||||||
</div>
|
</div>
|
||||||
<div className="tool-summary-text">{t.summary}</div>
|
<div className="tool-summary-text">{t.summary}</div>
|
||||||
<div className="tool-pill-controls">
|
<div className="tool-pill-controls">
|
||||||
<select
|
<select
|
||||||
value={t.access_type}
|
value={t.access_type}
|
||||||
onChange={(e) => handleFieldChange(t.originalIdx, "access_type", e.target.value)}
|
onChange={(e) =>
|
||||||
|
handleFieldChange(
|
||||||
|
t.originalIdx,
|
||||||
|
"access_type",
|
||||||
|
e.target.value
|
||||||
|
)
|
||||||
|
}
|
||||||
className="tool-select"
|
className="tool-select"
|
||||||
>
|
>
|
||||||
<option value="read">Read Only</option>
|
<option value="read">Read Only</option>
|
||||||
@@ -223,7 +278,13 @@ export function ReviewPage() {
|
|||||||
<input
|
<input
|
||||||
type="text"
|
type="text"
|
||||||
value={t.agent_group}
|
value={t.agent_group}
|
||||||
onChange={(e) => handleFieldChange(t.originalIdx, "agent_group", e.target.value)}
|
onChange={(e) =>
|
||||||
|
handleFieldChange(
|
||||||
|
t.originalIdx,
|
||||||
|
"agent_group",
|
||||||
|
e.target.value
|
||||||
|
)
|
||||||
|
}
|
||||||
className="tool-input"
|
className="tool-input"
|
||||||
placeholder="Agent Name"
|
placeholder="Agent Name"
|
||||||
/>
|
/>
|
||||||
|
|||||||
@@ -39,13 +39,28 @@ export interface ErrorMessage {
|
|||||||
message: string;
|
message: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export interface ClarificationMessage {
|
||||||
|
type: "clarification";
|
||||||
|
thread_id: string;
|
||||||
|
message: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface InterruptExpiredMessage {
|
||||||
|
type: "interrupt_expired";
|
||||||
|
thread_id: string;
|
||||||
|
action: string;
|
||||||
|
message: string;
|
||||||
|
}
|
||||||
|
|
||||||
export type ServerMessage =
|
export type ServerMessage =
|
||||||
| TokenMessage
|
| TokenMessage
|
||||||
| InterruptMessage
|
| InterruptMessage
|
||||||
| ToolCallMessage
|
| ToolCallMessage
|
||||||
| ToolResultMessage
|
| ToolResultMessage
|
||||||
| MessageCompleteMessage
|
| MessageCompleteMessage
|
||||||
| ErrorMessage;
|
| ErrorMessage
|
||||||
|
| ClarificationMessage
|
||||||
|
| InterruptExpiredMessage;
|
||||||
|
|
||||||
// -- Client -> Server messages --
|
// -- Client -> Server messages --
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user