feat(ui): implement premium beige design system and ux refinements

This commit is contained in:
Yaojia Wang
2026-04-05 22:35:48 +02:00
parent d2b4610df9
commit 189a0fad34
30 changed files with 3651 additions and 801 deletions

View File

@@ -4,6 +4,8 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
from psycopg.types.json import Json
if TYPE_CHECKING:
from psycopg_pool import AsyncConnectionPool
@@ -89,7 +91,7 @@ class PostgresAnalyticsRecorder:
"duration_ms": duration_ms,
"success": success,
"error_message": error_message,
"metadata": metadata or {},
"metadata": Json(metadata or {}),
}
async with self._pool.connection() as conn:
await conn.execute(_INSERT_SQL, params)

View File

@@ -17,7 +17,7 @@ class Settings(BaseSettings):
database_url: str
llm_provider: Literal["anthropic", "openai", "google"] = "anthropic"
llm_provider: Literal["anthropic", "openai", "azure_openai", "google"] = "anthropic"
llm_model: str = "claude-sonnet-4-6"
session_ttl_minutes: int = 30
@@ -34,6 +34,10 @@ class Settings(BaseSettings):
anthropic_api_key: str = ""
openai_api_key: str = ""
azure_openai_api_key: str = ""
azure_openai_endpoint: str = ""
azure_openai_api_version: str = "2024-12-01-preview"
azure_openai_deployment: str = ""
google_api_key: str = ""
@model_validator(mode="after")
@@ -41,6 +45,7 @@ class Settings(BaseSettings):
key_map = {
"anthropic": self.anthropic_api_key,
"openai": self.openai_api_key,
"azure_openai": self.azure_openai_api_key,
"google": self.google_api_key,
}
key = key_map.get(self.llm_provider, "")
@@ -49,4 +54,13 @@ class Settings(BaseSettings):
f"API key for provider '{self.llm_provider}' is required. "
f"Set the corresponding environment variable."
)
if self.llm_provider == "azure_openai":
if not self.azure_openai_endpoint:
raise ValueError(
"AZURE_OPENAI_ENDPOINT is required for azure_openai provider."
)
if not self.azure_openai_deployment:
raise ValueError(
"AZURE_OPENAI_DEPLOYMENT is required for azure_openai provider."
)
return self

View File

@@ -9,7 +9,7 @@ if TYPE_CHECKING:
_ENSURE_SQL = """
INSERT INTO conversations
(thread_id, started_at, last_activity)
(thread_id, created_at, last_activity)
VALUES
(%(thread_id)s, NOW(), NOW())
ON CONFLICT (thread_id) DO NOTHING

View File

@@ -31,6 +31,16 @@ def create_llm(settings: Settings) -> BaseChatModel:
api_key=settings.openai_api_key,
)
if provider == "azure_openai":
from langchain_openai import AzureChatOpenAI
return AzureChatOpenAI(
azure_deployment=settings.azure_openai_deployment,
azure_endpoint=settings.azure_openai_endpoint,
api_key=settings.azure_openai_api_key,
api_version=settings.azure_openai_api_version,
)
if provider == "google":
from langchain_google_genai import ChatGoogleGenerativeAI
@@ -39,4 +49,7 @@ def create_llm(settings: Settings) -> BaseChatModel:
google_api_key=settings.google_api_key,
)
raise ValueError(f"Unknown LLM provider: '{provider}'. Use 'anthropic', 'openai', or 'google'.")
raise ValueError(
f"Unknown LLM provider: '{provider}'. "
"Use 'anthropic', 'openai', 'azure_openai', or 'google'."
)

View File

@@ -54,7 +54,10 @@ async def handle_user_message(
interrupt_manager: InterruptManager | None = None,
) -> None:
"""Process a user message through the graph and stream results back."""
if session_manager.is_expired(thread_id):
# Touch first so new sessions are created before expiry check.
# For existing sessions, touch resets the sliding window.
existing = session_manager.get_state(thread_id)
if existing is not None and session_manager.is_expired(thread_id):
msg = "Session expired. Please start a new conversation."
await _send_json(ws, {"type": "error", "message": msg})
return

View File

@@ -0,0 +1,219 @@
"""E2E test fixtures -- full FastAPI app with mocked LLM and database."""
from __future__ import annotations
import asyncio
from contextlib import asynccontextmanager
from typing import Any
from unittest.mock import AsyncMock, MagicMock
import pytest
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from httpx import ASGITransport, AsyncClient
from app.analytics.api import router as analytics_router
from app.callbacks import TokenUsageCallbackHandler
from app.interrupt_manager import InterruptManager
from app.openapi.review_api import _job_store, router as openapi_router
from app.replay.api import router as replay_router
from app.session_manager import SessionManager
from app.ws_handler import dispatch_message
# ---------------------------------------------------------------------------
# Graph helpers -- simulate LangGraph streaming behaviour
# ---------------------------------------------------------------------------
class AsyncIterHelper:
    """Wrap a plain list so it can be consumed with ``async for``."""

    def __init__(self, items: list) -> None:
        # Copy so exhausting the iterator never mutates the caller's list.
        self._items = list(items)

    def __aiter__(self):
        return self

    async def __anext__(self):
        try:
            return self._items.pop(0)
        except IndexError:
            raise StopAsyncIteration from None
def make_chunk(content: str, node: str = "order_lookup") -> tuple:
    """Build a (chunk, metadata) pair mimicking a streamed LLM text chunk."""
    chunk = MagicMock()
    chunk.content = content
    chunk.tool_calls = []
    metadata = {"langgraph_node": node}
    return (chunk, metadata)
def make_tool_chunk(name: str, args: dict, node: str = "order_lookup") -> tuple:
    """Build a (chunk, metadata) pair mimicking a streamed tool-call chunk."""
    chunk = MagicMock()
    chunk.content = ""
    chunk.tool_calls = [{"name": name, "args": args}]
    metadata = {"langgraph_node": node}
    return (chunk, metadata)
def make_state(*, interrupt: bool = False, data: dict | None = None) -> Any:
s = MagicMock()
if interrupt:
obj = MagicMock()
obj.value = data or {"action": "cancel_order", "order_id": "1042"}
t = MagicMock()
t.interrupts = (obj,)
s.tasks = (t,)
else:
s.tasks = ()
return s
def make_graph(
    chunks: list | None = None,
    state: Any = None,
    resume_chunks: list | None = None,
) -> MagicMock:
    """Build a mock LangGraph CompiledStateGraph.

    The first ``astream`` call yields *chunks*; every subsequent call yields
    *resume_chunks* (the last configured stream is reused once reached).
    """
    graph = MagicMock()
    graph.intent_classifier = None
    graph.agent_registry = None
    if state is None:
        state = make_state()
    streams = [chunks or [], resume_chunks or []]
    call_count = {"n": 0}

    def _next_stream(*_args, **_kwargs):
        position = min(call_count["n"], len(streams) - 1)
        call_count["n"] += 1
        return AsyncIterHelper(list(streams[position]))

    graph.astream = MagicMock(side_effect=_next_stream)
    graph.aget_state = AsyncMock(return_value=state)
    return graph
# ---------------------------------------------------------------------------
# Fake database pool
# ---------------------------------------------------------------------------
class FakeCursor:
"""Minimal async cursor returning pre-configured rows."""
def __init__(self, rows: list[dict]) -> None:
self._rows = rows
async def fetchall(self) -> list[dict]:
return self._rows
class FakeConnection:
"""Fake async connection that returns a FakeCursor."""
def __init__(self, rows: list[dict]) -> None:
self._rows = rows
async def execute(self, query: str, params: dict | None = None) -> FakeCursor:
return FakeCursor(self._rows)
class FakePool:
"""Minimal pool that yields a fake connection."""
def __init__(self, rows: list[dict] | None = None) -> None:
self._rows = rows or []
@asynccontextmanager
async def connection(self):
yield FakeConnection(self._rows)
async def close(self) -> None:
pass
# ---------------------------------------------------------------------------
# App factory
# ---------------------------------------------------------------------------
def create_e2e_app(
    graph: MagicMock | None = None,
    pool: FakePool | None = None,
    session_ttl: int = 3600,
    interrupt_ttl: int = 1800,
) -> FastAPI:
    """Create a FastAPI app wired with mocked dependencies for E2E testing."""
    app = FastAPI(title="Smart Support E2E Test")

    # REST routers under test, in production registration order.
    for router in (openapi_router, replay_router, analytics_router):
        app.include_router(router)

    # Shared application state mirroring the production wiring.
    app.state.graph = graph or make_graph()
    app.state.session_manager = SessionManager(session_ttl_seconds=session_ttl)
    app.state.interrupt_manager = InterruptManager(ttl_seconds=interrupt_ttl)
    app.state.pool = pool or FakePool()
    app.state.settings = MagicMock(llm_model="test-model")
    app.state.analytics_recorder = AsyncMock()
    app.state.conversation_tracker = AsyncMock()

    @app.get("/api/health")
    def health_check() -> dict:
        return {"status": "ok", "version": "test"}

    @app.websocket("/ws")
    async def websocket_endpoint(ws: WebSocket) -> None:
        await ws.accept()
        try:
            while True:
                raw_data = await ws.receive_text()
                await dispatch_message(
                    ws,
                    app.state.graph,
                    app.state.session_manager,
                    TokenUsageCallbackHandler(model_name="test-model"),
                    raw_data,
                    interrupt_manager=app.state.interrupt_manager,
                    analytics_recorder=app.state.analytics_recorder,
                    conversation_tracker=app.state.conversation_tracker,
                    pool=app.state.pool,
                )
        except WebSocketDisconnect:
            # Normal client-side disconnect ends the receive loop silently.
            pass

    return app
@pytest.fixture
def e2e_graph():
    """Default graph fixture -- streams two text tokens then completes."""
    chunks = [make_chunk("Order 1042 is "), make_chunk("shipped.")]
    return make_graph(chunks=chunks)
@pytest.fixture
def e2e_app(e2e_graph):
    """Default E2E app fixture built around the default mock graph."""
    return create_e2e_app(graph=e2e_graph)
@pytest.fixture
async def e2e_client(e2e_app):
    """Async HTTP client for E2E tests, bound to the app over ASGI."""
    async with AsyncClient(
        transport=ASGITransport(app=e2e_app), base_url="http://test"
    ) as client:
        yield client
@pytest.fixture(autouse=True)
def clear_openapi_job_store():
    """Reset the in-memory OpenAPI job store before and after each test."""
    _job_store.clear()
    yield
    _job_store.clear()

View File

@@ -0,0 +1,384 @@
"""E2E tests for critical chat user flows (flows 1-4).
Flow 1: Happy path -- query order, get answer
Flow 2: Approval flow -- write operation, interrupt, approve, execute
Flow 3: Rejection flow -- write operation, interrupt, reject, no execution
Flow 4: Multi-turn context -- sequential messages in same session
"""
from __future__ import annotations
import json
import pytest
from starlette.testclient import TestClient
from tests.e2e.conftest import (
create_e2e_app,
make_chunk,
make_graph,
make_state,
make_tool_chunk,
)
pytestmark = pytest.mark.e2e
class TestFlow1HappyPath:
    """Flow 1: query order -> get answer with streaming tokens."""

    def test_websocket_happy_path_order_query(self) -> None:
        """One tool call plus one token, then a completion for the thread."""
        graph = make_graph(
            chunks=[
                make_tool_chunk("get_order_status", {"order_id": "1042"}),
                make_chunk("Order 1042 has been shipped and is on its way."),
            ],
        )
        app = create_e2e_app(graph=graph)
        with TestClient(app) as client:
            with client.websocket_connect("/ws") as ws:
                ws.send_json({
                    "type": "message",
                    "thread_id": "e2e-happy-1",
                    "content": "What is the status of order 1042?",
                })
                # Use the shared module helper instead of duplicating the
                # receive loop inline (matches the other flow tests).
                messages = _collect_until_complete(ws)
                tool_calls = [m for m in messages if m["type"] == "tool_call"]
                assert len(tool_calls) == 1
                assert tool_calls[0]["tool"] == "get_order_status"
                assert tool_calls[0]["args"] == {"order_id": "1042"}
                tokens = [m for m in messages if m["type"] == "token"]
                assert len(tokens) == 1
                assert "shipped" in tokens[0]["content"]
                completes = [m for m in messages if m["type"] == "message_complete"]
                assert len(completes) == 1
                assert completes[0]["thread_id"] == "e2e-happy-1"

    def test_websocket_multiple_token_stream(self) -> None:
        """Verify streaming returns multiple token chunks."""
        graph = make_graph(
            chunks=[
                make_chunk("Your order "),
                make_chunk("1042 "),
                make_chunk("was delivered "),
                make_chunk("yesterday."),
            ],
        )
        app = create_e2e_app(graph=graph)
        with TestClient(app) as client:
            with client.websocket_connect("/ws") as ws:
                ws.send_json({
                    "type": "message",
                    "thread_id": "e2e-stream-1",
                    "content": "Where is my order?",
                })
                messages = _collect_until_complete(ws)
                tokens = [m for m in messages if m["type"] == "token"]
                assert len(tokens) == 4
                full_text = "".join(t["content"] for t in tokens)
                assert "1042" in full_text
                assert "delivered" in full_text
class TestFlow2ApprovalFlow:
    """Flow 2: write operation -> interrupt -> approve -> execute."""

    def test_interrupt_approve_executes_action(self) -> None:
        graph = make_graph(
            chunks=[],
            state=make_state(
                interrupt=True,
                data={"action": "cancel_order", "order_id": "1042"},
            ),
            resume_chunks=[
                make_chunk("Order 1042 has been cancelled successfully.", "order_actions"),
            ],
        )
        app = create_e2e_app(graph=graph)
        with TestClient(app) as client:
            with client.websocket_connect("/ws") as ws:
                # Step 1: a write request raises an approval interrupt.
                ws.send_json({
                    "type": "message",
                    "thread_id": "e2e-approve-1",
                    "content": "Cancel order 1042",
                })
                pre_approval = _collect_until_type(ws, "interrupt")
                interrupts = [m for m in pre_approval if m["type"] == "interrupt"]
                assert len(interrupts) == 1
                assert interrupts[0]["action"] == "cancel_order"
                assert interrupts[0]["thread_id"] == "e2e-approve-1"
                # Step 2: approving resumes the graph and executes the action.
                ws.send_json({
                    "type": "interrupt_response",
                    "thread_id": "e2e-approve-1",
                    "approved": True,
                })
                post_approval = _collect_until_complete(ws)
                tokens = [m for m in post_approval if m["type"] == "token"]
                assert len(tokens) == 1
                assert "cancelled" in tokens[0]["content"]
                assert tokens[0]["agent"] == "order_actions"
                completes = [m for m in post_approval if m["type"] == "message_complete"]
                assert len(completes) == 1
class TestFlow3RejectionFlow:
    """Flow 3: write operation -> interrupt -> reject -> no execution."""

    def test_interrupt_reject_does_not_execute(self) -> None:
        graph = make_graph(
            chunks=[],
            state=make_state(
                interrupt=True,
                data={"action": "cancel_order", "order_id": "1042"},
            ),
            resume_chunks=[
                make_chunk("Understood. Order 1042 will remain active.", "order_actions"),
            ],
        )
        app = create_e2e_app(graph=graph)
        with TestClient(app) as client:
            with client.websocket_connect("/ws") as ws:
                # Step 1: the write request triggers an interrupt.
                ws.send_json({
                    "type": "message",
                    "thread_id": "e2e-reject-1",
                    "content": "Cancel order 1042",
                })
                pre_rejection = _collect_until_type(ws, "interrupt")
                assert any(m["type"] == "interrupt" for m in pre_rejection)
                # Step 2: rejecting resumes the graph without executing.
                ws.send_json({
                    "type": "interrupt_response",
                    "thread_id": "e2e-reject-1",
                    "approved": False,
                })
                post_rejection = _collect_until_complete(ws)
                tokens = [m for m in post_rejection if m["type"] == "token"]
                assert len(tokens) == 1
                assert "remain active" in tokens[0]["content"]
                # The graph must have been resumed with approved=False.
                command = graph.astream.call_args_list[-1][0][0]
                assert command.resume is False
class TestFlow4MultiTurnContext:
    """Flow 4: multi-turn conversation in the same session."""

    def test_multi_turn_messages_share_session(self) -> None:
        """Multiple messages in the same thread_id maintain session context."""
        graph = make_graph(chunks=[make_chunk("Order 1042 status: shipped.")])
        app = create_e2e_app(graph=graph)
        prompts = [
            "What is the status of order 1042?",
            "When will it arrive?",
            "Can you track it?",
        ]
        with TestClient(app) as client:
            with client.websocket_connect("/ws") as ws:
                # Drive three sequential turns on the same thread.
                for prompt in prompts:
                    ws.send_json({
                        "type": "message",
                        "thread_id": "e2e-multi-1",
                        "content": prompt,
                    })
                    turn = _collect_until_complete(ws)
                    assert any(m["type"] == "message_complete" for m in turn)
                # Every graph invocation must carry the same thread_id.
                for call in graph.astream.call_args_list:
                    config = call[1].get("config", call[0][1] if len(call[0]) > 1 else {})
                    assert config["configurable"]["thread_id"] == "e2e-multi-1"

    def test_separate_threads_are_independent(self) -> None:
        """Different thread_ids have independent sessions."""
        graph = make_graph(chunks=[make_chunk("Response.")])
        app = create_e2e_app(graph=graph)
        with TestClient(app) as client:
            with client.websocket_connect("/ws") as ws:
                for thread_id, greeting in (
                    ("e2e-thread-a", "Hello from thread A"),
                    ("e2e-thread-b", "Hello from thread B"),
                ):
                    ws.send_json({
                        "type": "message",
                        "thread_id": thread_id,
                        "content": greeting,
                    })
                    _collect_until_complete(ws)
                # Each thread must have its own session state.
                sessions = app.state.session_manager
                assert sessions.get_state("e2e-thread-a") is not None
                assert sessions.get_state("e2e-thread-b") is not None
class TestChatEdgeCases:
    """Edge cases and error handling for the chat WebSocket."""

    def test_invalid_json_returns_error(self) -> None:
        app = create_e2e_app()
        with TestClient(app) as client:
            with client.websocket_connect("/ws") as ws:
                ws.send_text("not valid json")
                reply = ws.receive_json()
                assert reply["type"] == "error"
                assert "Invalid JSON" in reply["message"]

    def test_missing_thread_id_returns_error(self) -> None:
        app = create_e2e_app()
        with TestClient(app) as client:
            with client.websocket_connect("/ws") as ws:
                ws.send_json({"type": "message", "content": "hello"})
                reply = ws.receive_json()
                assert reply["type"] == "error"
                assert "thread_id" in reply["message"]

    def test_empty_content_returns_error(self) -> None:
        app = create_e2e_app()
        with TestClient(app) as client:
            with client.websocket_connect("/ws") as ws:
                ws.send_json({
                    "type": "message",
                    "thread_id": "e2e-err-1",
                    "content": "",
                })
                assert ws.receive_json()["type"] == "error"

    def test_expired_session_returns_error(self) -> None:
        graph = make_graph(chunks=[make_chunk("Response.")])
        app = create_e2e_app(graph=graph, session_ttl=0)
        with TestClient(app) as client:
            with client.websocket_connect("/ws") as ws:
                # First message creates the session; with TTL=0 it expires
                # immediately afterwards.
                ws.send_json({
                    "type": "message",
                    "thread_id": "e2e-expired-1",
                    "content": "hello",
                })
                _collect_until_complete_or_error(ws)
                # Second message must hit the now-expired session.
                ws.send_json({
                    "type": "message",
                    "thread_id": "e2e-expired-1",
                    "content": "hello again",
                })
                replies = _collect_until_complete_or_error(ws)
                errors = [m for m in replies if m["type"] == "error"]
                assert len(errors) >= 1
                assert "expired" in errors[0]["message"].lower()

    def test_oversized_message_returns_error(self) -> None:
        app = create_e2e_app()
        with TestClient(app) as client:
            with client.websocket_connect("/ws") as ws:
                ws.send_text("x" * 40_000)
                reply = ws.receive_json()
                assert reply["type"] == "error"
                assert "too large" in reply["message"].lower()

    def test_health_endpoint(self) -> None:
        app = create_e2e_app()
        with TestClient(app) as client:
            resp = client.get("/api/health")
            assert resp.status_code == 200
            assert resp.json()["status"] == "ok"
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _collect_until_complete(ws, *, max_messages: int = 50) -> list[dict]:
"""Receive WebSocket messages until message_complete or error."""
messages = []
for _ in range(max_messages):
msg = ws.receive_json()
messages.append(msg)
if msg["type"] in ("message_complete", "error"):
break
return messages
def _collect_until_type(ws, msg_type: str, *, max_messages: int = 50) -> list[dict]:
"""Receive until a specific message type is received."""
messages = []
for _ in range(max_messages):
msg = ws.receive_json()
messages.append(msg)
if msg["type"] == msg_type:
break
return messages
def _collect_until_complete_or_error(ws, *, max_messages: int = 50) -> list[dict]:
    """Receive until message_complete or error.

    Kept as a named alias for readability at call sites; the body was a
    line-for-line duplicate of ``_collect_until_complete``, so delegate to it.
    """
    return _collect_until_complete(ws, max_messages=max_messages)

View File

@@ -0,0 +1,201 @@
"""E2E tests for OpenAPI import flow (flow 5).
Flow 5: paste OpenAPI spec URL -> import job -> classify endpoints ->
review classifications -> approve -> tool generation.
"""
from __future__ import annotations
from unittest.mock import AsyncMock, patch
import pytest
from starlette.testclient import TestClient
from app.openapi.models import ClassificationResult, EndpointInfo
from app.openapi.review_api import _job_store
from tests.e2e.conftest import create_e2e_app
pytestmark = pytest.mark.e2e
def _fake_endpoint(
    path: str = "/orders/{id}",
    method: str = "GET",
    operation_id: str = "getOrder",
    summary: str = "Get order details",
) -> EndpointInfo:
    """Build a minimal EndpointInfo with empty schemas for classification tests."""
    return EndpointInfo(
        path=path,
        method=method,
        operation_id=operation_id,
        summary=summary,
        description="",
        parameters=(),
        request_body_schema=None,
        response_schema=None,
    )
def _fake_classification(
    endpoint: EndpointInfo | None = None,
    access_type: str = "read",
    needs_interrupt: bool = False,
    agent_group: str = "order_lookup",
) -> ClassificationResult:
    """Build a high-confidence ClassificationResult around *endpoint*."""
    target = endpoint if endpoint is not None else _fake_endpoint()
    return ClassificationResult(
        endpoint=target,
        access_type=access_type,
        customer_params=["order_id"],
        agent_group=agent_group,
        confidence=0.95,
        needs_interrupt=needs_interrupt,
    )
class TestFlow5OpenAPIImport:
    """Flow 5: full OpenAPI import lifecycle."""

    @staticmethod
    def _seed_job(job_id: str, spec_url: str, classifications: list) -> None:
        """Insert a completed import job directly into the in-memory store."""
        _job_store[job_id] = {
            "job_id": job_id,
            "status": "done",
            "spec_url": spec_url,
            "total_endpoints": len(classifications),
            "classified_count": len(classifications),
            "error_message": None,
            "classifications": classifications,
        }

    def test_import_job_lifecycle(self) -> None:
        """Start import -> check status -> review classifications -> approve."""
        app = create_e2e_app()
        with TestClient(app) as client:
            # Step 1: start the import job.
            resp = client.post(
                "/api/openapi/import",
                json={"url": "https://api.example.com/openapi.json"},
            )
            assert resp.status_code == 202
            payload = resp.json()
            assert payload["status"] == "pending"
            job_id = payload["job_id"]
            # Step 2: job status is retrievable (still pending -- the
            # background task has not run yet).
            resp = client.get(f"/api/openapi/jobs/{job_id}")
            assert resp.status_code == 200
            assert resp.json()["job_id"] == job_id

    def test_import_job_with_classifications(self) -> None:
        """Simulate completed import and review classified endpoints."""
        app = create_e2e_app()
        read_clf = _fake_classification(
            _fake_endpoint("/orders/{id}", "GET", "getOrder", "Get order"),
            "read", False, "order_lookup",
        )
        write_clf = _fake_classification(
            _fake_endpoint("/orders/{id}/cancel", "POST", "cancelOrder", "Cancel order"),
            "write", True, "order_actions",
        )
        job_id = "test-job-001"
        self._seed_job(job_id, "https://api.example.com/openapi.json", [read_clf, write_clf])
        with TestClient(app) as client:
            # Step 1: fetch the classifications.
            resp = client.get(f"/api/openapi/jobs/{job_id}/classifications")
            assert resp.status_code == 200
            classifications = resp.json()
            assert len(classifications) == 2
            # The read endpoint must not require an interrupt.
            first = classifications[0]
            assert first["access_type"] == "read"
            assert first["needs_interrupt"] is False
            assert first["endpoint"]["path"] == "/orders/{id}"
            # The write endpoint must require an interrupt.
            second = classifications[1]
            assert second["access_type"] == "write"
            assert second["needs_interrupt"] is True
            assert second["endpoint"]["path"] == "/orders/{id}/cancel"
            # Step 2: update one classification.
            resp = client.put(
                f"/api/openapi/jobs/{job_id}/classifications/0",
                json={
                    "access_type": "write",
                    "needs_interrupt": True,
                    "agent_group": "order_actions",
                },
            )
            assert resp.status_code == 200
            updated = resp.json()
            assert updated["access_type"] == "write"
            assert updated["needs_interrupt"] is True
            assert updated["agent_group"] == "order_actions"
            # Step 3: approve the job.
            resp = client.post(f"/api/openapi/jobs/{job_id}/approve")
            assert resp.status_code == 200
            assert resp.json()["status"] == "approved"

    def test_import_nonexistent_job_returns_404(self) -> None:
        app = create_e2e_app()
        with TestClient(app) as client:
            resp = client.get("/api/openapi/jobs/nonexistent")
            assert resp.status_code == 404

    def test_import_invalid_url_returns_422(self) -> None:
        app = create_e2e_app()
        with TestClient(app) as client:
            resp = client.post("/api/openapi/import", json={"url": "not-a-url"})
            assert resp.status_code == 422

    def test_classification_index_out_of_range(self) -> None:
        app = create_e2e_app()
        job_id = "test-job-range"
        self._seed_job(job_id, "https://example.com/spec.json", [_fake_classification()])
        with TestClient(app) as client:
            resp = client.put(
                f"/api/openapi/jobs/{job_id}/classifications/99",
                json={
                    "access_type": "read",
                    "needs_interrupt": False,
                    "agent_group": "order_lookup",
                },
            )
            assert resp.status_code == 404

    def test_update_classification_invalid_agent_group(self) -> None:
        app = create_e2e_app()
        job_id = "test-job-invalid"
        self._seed_job(job_id, "https://example.com/spec.json", [_fake_classification()])
        with TestClient(app) as client:
            resp = client.put(
                f"/api/openapi/jobs/{job_id}/classifications/0",
                json={
                    "access_type": "read",
                    "needs_interrupt": False,
                    "agent_group": "invalid group!",  # spaces and special chars
                },
            )
            assert resp.status_code == 422

View File

@@ -0,0 +1,214 @@
"""E2E tests for replay and analytics flows (flow 6).
Flow 6: list conversations -> select one -> step-by-step replay.
Also tests the analytics dashboard endpoint.
"""
from __future__ import annotations
from datetime import datetime, timezone
import pytest
from starlette.testclient import TestClient
from tests.e2e.conftest import FakePool, create_e2e_app
pytestmark = pytest.mark.e2e
# ---------------------------------------------------------------------------
# Custom pool that returns specific data per query
# ---------------------------------------------------------------------------
class ReplayPool(FakePool):
    """Pool that returns different data depending on the SQL query."""

    def __init__(
        self,
        conversations: list[dict] | None = None,
        checkpoints: list[dict] | None = None,
        analytics_rows: list[dict] | None = None,
    ) -> None:
        super().__init__()
        self._conversations = conversations or []
        self._checkpoints = checkpoints or []
        self._analytics = analytics_rows or []

    class _Conn:
        """Connection stub that routes each query to the matching row set."""

        def __init__(self, convos, checkpoints, analytics):
            self._convos = convos
            self._checkpoints = checkpoints
            self._analytics = analytics

        async def execute(self, query: str, params=None):
            from tests.e2e.conftest import FakeCursor

            # Crude keyword routing on the SQL text.
            if "conversations" in query and "SELECT" in query:
                return FakeCursor(self._convos)
            if "checkpoints" in query:
                return FakeCursor(self._checkpoints)
            # Everything else is treated as an analytics query.
            return FakeCursor(self._analytics)

    def connection(self):
        from contextlib import asynccontextmanager

        routed = self._Conn(self._conversations, self._checkpoints, self._analytics)

        @asynccontextmanager
        async def _ctx():
            yield routed

        return _ctx()
class TestFlow6ReplayConversation:
    """Flow 6: list conversations -> select one -> step replay."""

    @staticmethod
    def _conversation(thread_id: str, ts: str, status: str, tokens: int, cost: float) -> dict:
        """Build one conversation row as the replay API expects it."""
        return {
            "thread_id": thread_id,
            "created_at": ts,
            "last_activity": ts,
            "status": status,
            "total_tokens": tokens,
            "total_cost_usd": cost,
        }

    def test_list_conversations(self) -> None:
        now = datetime.now(tz=timezone.utc).isoformat()
        pool = ReplayPool(conversations=[
            self._conversation("conv-001", now, "active", 150, 0.003),
            self._conversation("conv-002", now, "completed", 300, 0.006),
        ])
        app = create_e2e_app(pool=pool)
        with TestClient(app) as client:
            resp = client.get("/api/conversations")
            assert resp.status_code == 200
            body = resp.json()
            assert body["success"] is True
            assert len(body["data"]) == 2
            assert body["data"][0]["thread_id"] == "conv-001"
            assert body["data"][1]["thread_id"] == "conv-002"

    def test_list_conversations_pagination(self) -> None:
        rows = [
            self._conversation(f"conv-{i:03d}", "2026-04-01T00:00:00Z", "active", 100, 0.001)
            for i in range(5)
        ]
        pool = ReplayPool(conversations=rows)
        app = create_e2e_app(pool=pool)
        with TestClient(app) as client:
            resp = client.get("/api/conversations", params={"page": 1, "per_page": 2})
            assert resp.status_code == 200
            assert resp.json()["success"] is True

    def test_replay_thread_not_found(self) -> None:
        pool = ReplayPool(checkpoints=[])
        app = create_e2e_app(pool=pool)
        with TestClient(app) as client:
            resp = client.get("/api/replay/nonexistent-thread")
            assert resp.status_code == 404

    def test_replay_invalid_thread_id_format(self) -> None:
        app = create_e2e_app()
        with TestClient(app) as client:
            # Thread ID with special chars fails regex validation
            resp = client.get("/api/replay/invalid%20thread%21%40")
            assert resp.status_code == 400
class TestAnalyticsDashboard:
    """Analytics endpoint tests."""

    def _assert_range_rejected(self, range_value: str) -> None:
        """The analytics endpoint must reject *range_value* with HTTP 400."""
        app = create_e2e_app()
        with TestClient(app) as client:
            resp = client.get("/api/analytics", params={"range": range_value})
            assert resp.status_code == 400

    def test_analytics_invalid_range_format(self) -> None:
        self._assert_range_rejected("invalid")

    def test_analytics_range_too_large(self) -> None:
        self._assert_range_rejected("999d")

    def test_analytics_range_zero_rejected(self) -> None:
        self._assert_range_rejected("0d")
class TestFullUserJourney:
    """End-to-end journey: chat -> then check replay list shows the conversation."""

    def test_chat_then_check_conversations_endpoint(self) -> None:
        """After chatting via WebSocket, the conversations endpoint is reachable."""
        from tests.e2e.conftest import make_chunk, make_graph

        graph = make_graph(chunks=[make_chunk("Your order is shipped.")])
        now = datetime.now(tz=timezone.utc).isoformat()
        pool = ReplayPool(
            conversations=[
                {
                    "thread_id": "e2e-journey-1",
                    "created_at": now,
                    "last_activity": now,
                    "status": "active",
                    "total_tokens": 50,
                    "total_cost_usd": 0.001,
                },
            ],
        )
        app = create_e2e_app(graph=graph, pool=pool)
        with TestClient(app) as client:
            # Step 1: chat over the WebSocket until the stream finishes.
            with client.websocket_connect("/ws") as ws:
                ws.send_json({
                    "type": "message",
                    "thread_id": "e2e-journey-1",
                    "content": "Where is my order?",
                })
                received = []
                for _ in range(20):
                    reply = ws.receive_json()
                    received.append(reply)
                    if reply["type"] in ("message_complete", "error"):
                        break
                assert any(m["type"] == "message_complete" for m in received)
            # Step 2: the conversation shows up in the replay listing.
            resp = client.get("/api/conversations")
            assert resp.status_code == 200
            body = resp.json()
            assert body["success"] is True
            assert any(c["thread_id"] == "e2e-journey-1" for c in body["data"])
            # Step 3: the health check still works.
            resp = client.get("/api/health")
            assert resp.status_code == 200

View File

@@ -168,7 +168,10 @@ class TestHandleUserMessage:
sm = SessionManager(session_ttl_seconds=0)
cb = TokenUsageCallbackHandler()
# First call creates the session (TTL=0)
await handle_user_message(ws, graph, sm, cb, "t1", "hello")
# Second call finds it expired
await handle_user_message(ws, graph, sm, cb, "t1", "hello again")
call_data = ws.send_json.call_args[0][0]
assert call_data["type"] == "error"
assert "expired" in call_data["message"].lower()