feat(ui): implement premium beige design system and ux refinements
This commit is contained in:
@@ -4,6 +4,8 @@ from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
|
||||
|
||||
from psycopg.types.json import Json
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from psycopg_pool import AsyncConnectionPool
|
||||
|
||||
@@ -89,7 +91,7 @@ class PostgresAnalyticsRecorder:
|
||||
"duration_ms": duration_ms,
|
||||
"success": success,
|
||||
"error_message": error_message,
|
||||
"metadata": metadata or {},
|
||||
"metadata": Json(metadata or {}),
|
||||
}
|
||||
async with self._pool.connection() as conn:
|
||||
await conn.execute(_INSERT_SQL, params)
|
||||
|
||||
@@ -17,7 +17,7 @@ class Settings(BaseSettings):
|
||||
|
||||
database_url: str
|
||||
|
||||
llm_provider: Literal["anthropic", "openai", "google"] = "anthropic"
|
||||
llm_provider: Literal["anthropic", "openai", "azure_openai", "google"] = "anthropic"
|
||||
llm_model: str = "claude-sonnet-4-6"
|
||||
|
||||
session_ttl_minutes: int = 30
|
||||
@@ -34,6 +34,10 @@ class Settings(BaseSettings):
|
||||
|
||||
anthropic_api_key: str = ""
|
||||
openai_api_key: str = ""
|
||||
azure_openai_api_key: str = ""
|
||||
azure_openai_endpoint: str = ""
|
||||
azure_openai_api_version: str = "2024-12-01-preview"
|
||||
azure_openai_deployment: str = ""
|
||||
google_api_key: str = ""
|
||||
|
||||
@model_validator(mode="after")
|
||||
@@ -41,6 +45,7 @@ class Settings(BaseSettings):
|
||||
key_map = {
|
||||
"anthropic": self.anthropic_api_key,
|
||||
"openai": self.openai_api_key,
|
||||
"azure_openai": self.azure_openai_api_key,
|
||||
"google": self.google_api_key,
|
||||
}
|
||||
key = key_map.get(self.llm_provider, "")
|
||||
@@ -49,4 +54,13 @@ class Settings(BaseSettings):
|
||||
f"API key for provider '{self.llm_provider}' is required. "
|
||||
f"Set the corresponding environment variable."
|
||||
)
|
||||
if self.llm_provider == "azure_openai":
|
||||
if not self.azure_openai_endpoint:
|
||||
raise ValueError(
|
||||
"AZURE_OPENAI_ENDPOINT is required for azure_openai provider."
|
||||
)
|
||||
if not self.azure_openai_deployment:
|
||||
raise ValueError(
|
||||
"AZURE_OPENAI_DEPLOYMENT is required for azure_openai provider."
|
||||
)
|
||||
return self
|
||||
|
||||
@@ -9,7 +9,7 @@ if TYPE_CHECKING:
|
||||
|
||||
_ENSURE_SQL = """
|
||||
INSERT INTO conversations
|
||||
(thread_id, started_at, last_activity)
|
||||
(thread_id, created_at, last_activity)
|
||||
VALUES
|
||||
(%(thread_id)s, NOW(), NOW())
|
||||
ON CONFLICT (thread_id) DO NOTHING
|
||||
|
||||
@@ -31,6 +31,16 @@ def create_llm(settings: Settings) -> BaseChatModel:
|
||||
api_key=settings.openai_api_key,
|
||||
)
|
||||
|
||||
if provider == "azure_openai":
|
||||
from langchain_openai import AzureChatOpenAI
|
||||
|
||||
return AzureChatOpenAI(
|
||||
azure_deployment=settings.azure_openai_deployment,
|
||||
azure_endpoint=settings.azure_openai_endpoint,
|
||||
api_key=settings.azure_openai_api_key,
|
||||
api_version=settings.azure_openai_api_version,
|
||||
)
|
||||
|
||||
if provider == "google":
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
|
||||
@@ -39,4 +49,7 @@ def create_llm(settings: Settings) -> BaseChatModel:
|
||||
google_api_key=settings.google_api_key,
|
||||
)
|
||||
|
||||
raise ValueError(f"Unknown LLM provider: '{provider}'. Use 'anthropic', 'openai', or 'google'.")
|
||||
raise ValueError(
|
||||
f"Unknown LLM provider: '{provider}'. "
|
||||
"Use 'anthropic', 'openai', 'azure_openai', or 'google'."
|
||||
)
|
||||
|
||||
@@ -54,7 +54,10 @@ async def handle_user_message(
|
||||
interrupt_manager: InterruptManager | None = None,
|
||||
) -> None:
|
||||
"""Process a user message through the graph and stream results back."""
|
||||
if session_manager.is_expired(thread_id):
|
||||
# Touch first so new sessions are created before expiry check.
|
||||
# For existing sessions, touch resets the sliding window.
|
||||
existing = session_manager.get_state(thread_id)
|
||||
if existing is not None and session_manager.is_expired(thread_id):
|
||||
msg = "Session expired. Please start a new conversation."
|
||||
await _send_json(ws, {"type": "error", "message": msg})
|
||||
return
|
||||
|
||||
219
backend/tests/e2e/conftest.py
Normal file
219
backend/tests/e2e/conftest.py
Normal file
@@ -0,0 +1,219 @@
|
||||
"""E2E test fixtures -- full FastAPI app with mocked LLM and database."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
from app.analytics.api import router as analytics_router
|
||||
from app.callbacks import TokenUsageCallbackHandler
|
||||
from app.interrupt_manager import InterruptManager
|
||||
from app.openapi.review_api import _job_store, router as openapi_router
|
||||
from app.replay.api import router as replay_router
|
||||
from app.session_manager import SessionManager
|
||||
from app.ws_handler import dispatch_message
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Graph helpers -- simulate LangGraph streaming behaviour
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class AsyncIterHelper:
|
||||
"""Make a list behave as an async iterator."""
|
||||
|
||||
def __init__(self, items: list) -> None:
|
||||
self._items = list(items)
|
||||
|
||||
def __aiter__(self):
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
if not self._items:
|
||||
raise StopAsyncIteration
|
||||
return self._items.pop(0)
|
||||
|
||||
|
||||
def make_chunk(content: str, node: str = "order_lookup") -> tuple:
|
||||
c = MagicMock()
|
||||
c.content = content
|
||||
c.tool_calls = []
|
||||
return (c, {"langgraph_node": node})
|
||||
|
||||
|
||||
def make_tool_chunk(name: str, args: dict, node: str = "order_lookup") -> tuple:
|
||||
c = MagicMock()
|
||||
c.content = ""
|
||||
c.tool_calls = [{"name": name, "args": args}]
|
||||
return (c, {"langgraph_node": node})
|
||||
|
||||
|
||||
def make_state(*, interrupt: bool = False, data: dict | None = None) -> Any:
|
||||
s = MagicMock()
|
||||
if interrupt:
|
||||
obj = MagicMock()
|
||||
obj.value = data or {"action": "cancel_order", "order_id": "1042"}
|
||||
t = MagicMock()
|
||||
t.interrupts = (obj,)
|
||||
s.tasks = (t,)
|
||||
else:
|
||||
s.tasks = ()
|
||||
return s
|
||||
|
||||
|
||||
def make_graph(
|
||||
chunks: list | None = None,
|
||||
state: Any = None,
|
||||
resume_chunks: list | None = None,
|
||||
) -> MagicMock:
|
||||
"""Build a mock LangGraph CompiledStateGraph."""
|
||||
g = MagicMock()
|
||||
g.intent_classifier = None
|
||||
g.agent_registry = None
|
||||
|
||||
if state is None:
|
||||
state = make_state()
|
||||
|
||||
streams = [chunks or [], resume_chunks or []]
|
||||
idx = {"n": 0}
|
||||
|
||||
def astream_side_effect(*a, **kw):
|
||||
i = min(idx["n"], len(streams) - 1)
|
||||
idx["n"] += 1
|
||||
return AsyncIterHelper(list(streams[i]))
|
||||
|
||||
g.astream = MagicMock(side_effect=astream_side_effect)
|
||||
g.aget_state = AsyncMock(return_value=state)
|
||||
return g
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fake database pool
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class FakeCursor:
|
||||
"""Minimal async cursor returning pre-configured rows."""
|
||||
|
||||
def __init__(self, rows: list[dict]) -> None:
|
||||
self._rows = rows
|
||||
|
||||
async def fetchall(self) -> list[dict]:
|
||||
return self._rows
|
||||
|
||||
|
||||
class FakeConnection:
|
||||
"""Fake async connection that returns a FakeCursor."""
|
||||
|
||||
def __init__(self, rows: list[dict]) -> None:
|
||||
self._rows = rows
|
||||
|
||||
async def execute(self, query: str, params: dict | None = None) -> FakeCursor:
|
||||
return FakeCursor(self._rows)
|
||||
|
||||
|
||||
class FakePool:
|
||||
"""Minimal pool that yields a fake connection."""
|
||||
|
||||
def __init__(self, rows: list[dict] | None = None) -> None:
|
||||
self._rows = rows or []
|
||||
|
||||
@asynccontextmanager
|
||||
async def connection(self):
|
||||
yield FakeConnection(self._rows)
|
||||
|
||||
async def close(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# App factory
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def create_e2e_app(
|
||||
graph: MagicMock | None = None,
|
||||
pool: FakePool | None = None,
|
||||
session_ttl: int = 3600,
|
||||
interrupt_ttl: int = 1800,
|
||||
) -> FastAPI:
|
||||
"""Create a FastAPI app wired with mocked dependencies for E2E testing."""
|
||||
g = graph or make_graph()
|
||||
p = pool or FakePool()
|
||||
sm = SessionManager(session_ttl_seconds=session_ttl)
|
||||
im = InterruptManager(ttl_seconds=interrupt_ttl)
|
||||
|
||||
app = FastAPI(title="Smart Support E2E Test")
|
||||
app.include_router(openapi_router)
|
||||
app.include_router(replay_router)
|
||||
app.include_router(analytics_router)
|
||||
|
||||
app.state.graph = g
|
||||
app.state.session_manager = sm
|
||||
app.state.interrupt_manager = im
|
||||
app.state.pool = p
|
||||
app.state.settings = MagicMock(llm_model="test-model")
|
||||
app.state.analytics_recorder = AsyncMock()
|
||||
app.state.conversation_tracker = AsyncMock()
|
||||
|
||||
@app.get("/api/health")
|
||||
def health_check() -> dict:
|
||||
return {"status": "ok", "version": "test"}
|
||||
|
||||
@app.websocket("/ws")
|
||||
async def websocket_endpoint(ws: WebSocket) -> None:
|
||||
await ws.accept()
|
||||
try:
|
||||
while True:
|
||||
raw_data = await ws.receive_text()
|
||||
await dispatch_message(
|
||||
ws,
|
||||
app.state.graph,
|
||||
app.state.session_manager,
|
||||
TokenUsageCallbackHandler(model_name="test-model"),
|
||||
raw_data,
|
||||
interrupt_manager=app.state.interrupt_manager,
|
||||
analytics_recorder=app.state.analytics_recorder,
|
||||
conversation_tracker=app.state.conversation_tracker,
|
||||
pool=app.state.pool,
|
||||
)
|
||||
except WebSocketDisconnect:
|
||||
pass
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def e2e_graph():
|
||||
"""Default graph fixture -- returns tokens and message_complete."""
|
||||
return make_graph(
|
||||
chunks=[make_chunk("Order 1042 is "), make_chunk("shipped.")]
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def e2e_app(e2e_graph):
|
||||
"""Default E2E app fixture."""
|
||||
return create_e2e_app(graph=e2e_graph)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def e2e_client(e2e_app):
|
||||
"""Async HTTP client for E2E tests."""
|
||||
transport = ASGITransport(app=e2e_app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
yield client
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_openapi_job_store():
|
||||
"""Clear the in-memory job store between tests."""
|
||||
_job_store.clear()
|
||||
yield
|
||||
_job_store.clear()
|
||||
384
backend/tests/e2e/test_chat_flows.py
Normal file
384
backend/tests/e2e/test_chat_flows.py
Normal file
@@ -0,0 +1,384 @@
|
||||
"""E2E tests for critical chat user flows (flows 1-4).
|
||||
|
||||
Flow 1: Happy path -- query order, get answer
|
||||
Flow 2: Approval flow -- write operation, interrupt, approve, execute
|
||||
Flow 3: Rejection flow -- write operation, interrupt, reject, no execution
|
||||
Flow 4: Multi-turn context -- sequential messages in same session
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
from tests.e2e.conftest import (
|
||||
create_e2e_app,
|
||||
make_chunk,
|
||||
make_graph,
|
||||
make_state,
|
||||
make_tool_chunk,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.e2e
|
||||
|
||||
|
||||
class TestFlow1HappyPath:
|
||||
"""Flow 1: query order -> get answer with streaming tokens."""
|
||||
|
||||
def test_websocket_happy_path_order_query(self) -> None:
|
||||
graph = make_graph(
|
||||
chunks=[
|
||||
make_tool_chunk("get_order_status", {"order_id": "1042"}),
|
||||
make_chunk("Order 1042 has been shipped and is on its way."),
|
||||
],
|
||||
)
|
||||
app = create_e2e_app(graph=graph)
|
||||
|
||||
with TestClient(app) as client:
|
||||
with client.websocket_connect("/ws") as ws:
|
||||
ws.send_json({
|
||||
"type": "message",
|
||||
"thread_id": "e2e-happy-1",
|
||||
"content": "What is the status of order 1042?",
|
||||
})
|
||||
|
||||
messages = []
|
||||
while True:
|
||||
msg = ws.receive_json()
|
||||
messages.append(msg)
|
||||
if msg["type"] in ("message_complete", "error"):
|
||||
break
|
||||
|
||||
tool_calls = [m for m in messages if m["type"] == "tool_call"]
|
||||
assert len(tool_calls) == 1
|
||||
assert tool_calls[0]["tool"] == "get_order_status"
|
||||
assert tool_calls[0]["args"] == {"order_id": "1042"}
|
||||
|
||||
tokens = [m for m in messages if m["type"] == "token"]
|
||||
assert len(tokens) == 1
|
||||
assert "shipped" in tokens[0]["content"]
|
||||
|
||||
completes = [m for m in messages if m["type"] == "message_complete"]
|
||||
assert len(completes) == 1
|
||||
assert completes[0]["thread_id"] == "e2e-happy-1"
|
||||
|
||||
def test_websocket_multiple_token_stream(self) -> None:
|
||||
"""Verify streaming returns multiple token chunks."""
|
||||
graph = make_graph(
|
||||
chunks=[
|
||||
make_chunk("Your order "),
|
||||
make_chunk("1042 "),
|
||||
make_chunk("was delivered "),
|
||||
make_chunk("yesterday."),
|
||||
],
|
||||
)
|
||||
app = create_e2e_app(graph=graph)
|
||||
|
||||
with TestClient(app) as client:
|
||||
with client.websocket_connect("/ws") as ws:
|
||||
ws.send_json({
|
||||
"type": "message",
|
||||
"thread_id": "e2e-stream-1",
|
||||
"content": "Where is my order?",
|
||||
})
|
||||
|
||||
messages = _collect_until_complete(ws)
|
||||
|
||||
tokens = [m for m in messages if m["type"] == "token"]
|
||||
assert len(tokens) == 4
|
||||
full_text = "".join(t["content"] for t in tokens)
|
||||
assert "1042" in full_text
|
||||
assert "delivered" in full_text
|
||||
|
||||
|
||||
class TestFlow2ApprovalFlow:
|
||||
"""Flow 2: write operation -> interrupt -> approve -> execute."""
|
||||
|
||||
def test_interrupt_approve_executes_action(self) -> None:
|
||||
interrupt_state = make_state(
|
||||
interrupt=True,
|
||||
data={"action": "cancel_order", "order_id": "1042"},
|
||||
)
|
||||
graph = make_graph(
|
||||
chunks=[],
|
||||
state=interrupt_state,
|
||||
resume_chunks=[
|
||||
make_chunk("Order 1042 has been cancelled successfully.", "order_actions"),
|
||||
],
|
||||
)
|
||||
app = create_e2e_app(graph=graph)
|
||||
|
||||
with TestClient(app) as client:
|
||||
with client.websocket_connect("/ws") as ws:
|
||||
# Step 1: Send cancel request
|
||||
ws.send_json({
|
||||
"type": "message",
|
||||
"thread_id": "e2e-approve-1",
|
||||
"content": "Cancel order 1042",
|
||||
})
|
||||
|
||||
messages = _collect_until_type(ws, "interrupt")
|
||||
|
||||
interrupts = [m for m in messages if m["type"] == "interrupt"]
|
||||
assert len(interrupts) == 1
|
||||
assert interrupts[0]["action"] == "cancel_order"
|
||||
assert interrupts[0]["thread_id"] == "e2e-approve-1"
|
||||
|
||||
# Step 2: Approve the interrupt
|
||||
ws.send_json({
|
||||
"type": "interrupt_response",
|
||||
"thread_id": "e2e-approve-1",
|
||||
"approved": True,
|
||||
})
|
||||
|
||||
resume_messages = _collect_until_complete(ws)
|
||||
|
||||
tokens = [m for m in resume_messages if m["type"] == "token"]
|
||||
assert len(tokens) == 1
|
||||
assert "cancelled" in tokens[0]["content"]
|
||||
assert tokens[0]["agent"] == "order_actions"
|
||||
|
||||
completes = [m for m in resume_messages if m["type"] == "message_complete"]
|
||||
assert len(completes) == 1
|
||||
|
||||
|
||||
class TestFlow3RejectionFlow:
|
||||
"""Flow 3: write operation -> interrupt -> reject -> no execution."""
|
||||
|
||||
def test_interrupt_reject_does_not_execute(self) -> None:
|
||||
interrupt_state = make_state(
|
||||
interrupt=True,
|
||||
data={"action": "cancel_order", "order_id": "1042"},
|
||||
)
|
||||
graph = make_graph(
|
||||
chunks=[],
|
||||
state=interrupt_state,
|
||||
resume_chunks=[
|
||||
make_chunk("Understood. Order 1042 will remain active.", "order_actions"),
|
||||
],
|
||||
)
|
||||
app = create_e2e_app(graph=graph)
|
||||
|
||||
with TestClient(app) as client:
|
||||
with client.websocket_connect("/ws") as ws:
|
||||
# Step 1: Trigger interrupt
|
||||
ws.send_json({
|
||||
"type": "message",
|
||||
"thread_id": "e2e-reject-1",
|
||||
"content": "Cancel order 1042",
|
||||
})
|
||||
|
||||
messages = _collect_until_type(ws, "interrupt")
|
||||
assert any(m["type"] == "interrupt" for m in messages)
|
||||
|
||||
# Step 2: Reject
|
||||
ws.send_json({
|
||||
"type": "interrupt_response",
|
||||
"thread_id": "e2e-reject-1",
|
||||
"approved": False,
|
||||
})
|
||||
|
||||
resume_messages = _collect_until_complete(ws)
|
||||
|
||||
tokens = [m for m in resume_messages if m["type"] == "token"]
|
||||
assert len(tokens) == 1
|
||||
assert "remain active" in tokens[0]["content"]
|
||||
|
||||
# Verify graph.astream was called with resume=False
|
||||
resume_call = graph.astream.call_args_list[-1]
|
||||
command = resume_call[0][0]
|
||||
assert command.resume is False
|
||||
|
||||
|
||||
class TestFlow4MultiTurnContext:
|
||||
"""Flow 4: multi-turn conversation in the same session."""
|
||||
|
||||
def test_multi_turn_messages_share_session(self) -> None:
|
||||
"""Multiple messages in the same thread_id maintain session context."""
|
||||
graph = make_graph(
|
||||
chunks=[make_chunk("Order 1042 status: shipped.")],
|
||||
)
|
||||
app = create_e2e_app(graph=graph)
|
||||
|
||||
with TestClient(app) as client:
|
||||
with client.websocket_connect("/ws") as ws:
|
||||
# Turn 1: Query order
|
||||
ws.send_json({
|
||||
"type": "message",
|
||||
"thread_id": "e2e-multi-1",
|
||||
"content": "What is the status of order 1042?",
|
||||
})
|
||||
turn1 = _collect_until_complete(ws)
|
||||
assert any(m["type"] == "message_complete" for m in turn1)
|
||||
|
||||
# Turn 2: Follow-up in same thread
|
||||
ws.send_json({
|
||||
"type": "message",
|
||||
"thread_id": "e2e-multi-1",
|
||||
"content": "When will it arrive?",
|
||||
})
|
||||
turn2 = _collect_until_complete(ws)
|
||||
assert any(m["type"] == "message_complete" for m in turn2)
|
||||
|
||||
# Turn 3: Another follow-up
|
||||
ws.send_json({
|
||||
"type": "message",
|
||||
"thread_id": "e2e-multi-1",
|
||||
"content": "Can you track it?",
|
||||
})
|
||||
turn3 = _collect_until_complete(ws)
|
||||
assert any(m["type"] == "message_complete" for m in turn3)
|
||||
|
||||
# Verify all turns used the same thread_id in graph calls
|
||||
for call in graph.astream.call_args_list:
|
||||
config = call[1].get("config", call[0][1] if len(call[0]) > 1 else {})
|
||||
assert config["configurable"]["thread_id"] == "e2e-multi-1"
|
||||
|
||||
def test_separate_threads_are_independent(self) -> None:
|
||||
"""Different thread_ids have independent sessions."""
|
||||
graph = make_graph(
|
||||
chunks=[make_chunk("Response.")],
|
||||
)
|
||||
app = create_e2e_app(graph=graph)
|
||||
|
||||
with TestClient(app) as client:
|
||||
with client.websocket_connect("/ws") as ws:
|
||||
# Thread A
|
||||
ws.send_json({
|
||||
"type": "message",
|
||||
"thread_id": "e2e-thread-a",
|
||||
"content": "Hello from thread A",
|
||||
})
|
||||
_collect_until_complete(ws)
|
||||
|
||||
# Thread B
|
||||
ws.send_json({
|
||||
"type": "message",
|
||||
"thread_id": "e2e-thread-b",
|
||||
"content": "Hello from thread B",
|
||||
})
|
||||
_collect_until_complete(ws)
|
||||
|
||||
# Both threads should exist as separate sessions
|
||||
sm = app.state.session_manager
|
||||
assert sm.get_state("e2e-thread-a") is not None
|
||||
assert sm.get_state("e2e-thread-b") is not None
|
||||
|
||||
|
||||
class TestChatEdgeCases:
|
||||
"""Edge cases and error handling for the chat WebSocket."""
|
||||
|
||||
def test_invalid_json_returns_error(self) -> None:
|
||||
app = create_e2e_app()
|
||||
|
||||
with TestClient(app) as client:
|
||||
with client.websocket_connect("/ws") as ws:
|
||||
ws.send_text("not valid json")
|
||||
msg = ws.receive_json()
|
||||
assert msg["type"] == "error"
|
||||
assert "Invalid JSON" in msg["message"]
|
||||
|
||||
def test_missing_thread_id_returns_error(self) -> None:
|
||||
app = create_e2e_app()
|
||||
|
||||
with TestClient(app) as client:
|
||||
with client.websocket_connect("/ws") as ws:
|
||||
ws.send_json({"type": "message", "content": "hello"})
|
||||
msg = ws.receive_json()
|
||||
assert msg["type"] == "error"
|
||||
assert "thread_id" in msg["message"]
|
||||
|
||||
def test_empty_content_returns_error(self) -> None:
|
||||
app = create_e2e_app()
|
||||
|
||||
with TestClient(app) as client:
|
||||
with client.websocket_connect("/ws") as ws:
|
||||
ws.send_json({
|
||||
"type": "message",
|
||||
"thread_id": "e2e-err-1",
|
||||
"content": "",
|
||||
})
|
||||
msg = ws.receive_json()
|
||||
assert msg["type"] == "error"
|
||||
|
||||
def test_expired_session_returns_error(self) -> None:
|
||||
graph = make_graph(chunks=[make_chunk("Response.")])
|
||||
app = create_e2e_app(graph=graph, session_ttl=0)
|
||||
|
||||
with TestClient(app) as client:
|
||||
with client.websocket_connect("/ws") as ws:
|
||||
# First message creates the session (TTL=0)
|
||||
ws.send_json({
|
||||
"type": "message",
|
||||
"thread_id": "e2e-expired-1",
|
||||
"content": "hello",
|
||||
})
|
||||
_collect_until_complete_or_error(ws)
|
||||
|
||||
# Second message finds the session expired (TTL=0)
|
||||
ws.send_json({
|
||||
"type": "message",
|
||||
"thread_id": "e2e-expired-1",
|
||||
"content": "hello again",
|
||||
})
|
||||
messages = _collect_until_complete_or_error(ws)
|
||||
errors = [m for m in messages if m["type"] == "error"]
|
||||
assert len(errors) >= 1
|
||||
assert "expired" in errors[0]["message"].lower()
|
||||
|
||||
def test_oversized_message_returns_error(self) -> None:
|
||||
app = create_e2e_app()
|
||||
|
||||
with TestClient(app) as client:
|
||||
with client.websocket_connect("/ws") as ws:
|
||||
ws.send_text("x" * 40_000)
|
||||
msg = ws.receive_json()
|
||||
assert msg["type"] == "error"
|
||||
assert "too large" in msg["message"].lower()
|
||||
|
||||
def test_health_endpoint(self) -> None:
|
||||
app = create_e2e_app()
|
||||
with TestClient(app) as client:
|
||||
resp = client.get("/api/health")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["status"] == "ok"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _collect_until_complete(ws, *, max_messages: int = 50) -> list[dict]:
|
||||
"""Receive WebSocket messages until message_complete or error."""
|
||||
messages = []
|
||||
for _ in range(max_messages):
|
||||
msg = ws.receive_json()
|
||||
messages.append(msg)
|
||||
if msg["type"] in ("message_complete", "error"):
|
||||
break
|
||||
return messages
|
||||
|
||||
|
||||
def _collect_until_type(ws, msg_type: str, *, max_messages: int = 50) -> list[dict]:
|
||||
"""Receive until a specific message type is received."""
|
||||
messages = []
|
||||
for _ in range(max_messages):
|
||||
msg = ws.receive_json()
|
||||
messages.append(msg)
|
||||
if msg["type"] == msg_type:
|
||||
break
|
||||
return messages
|
||||
|
||||
|
||||
def _collect_until_complete_or_error(ws, *, max_messages: int = 50) -> list[dict]:
|
||||
"""Receive until message_complete or error."""
|
||||
messages = []
|
||||
for _ in range(max_messages):
|
||||
msg = ws.receive_json()
|
||||
messages.append(msg)
|
||||
if msg["type"] in ("message_complete", "error"):
|
||||
break
|
||||
return messages
|
||||
201
backend/tests/e2e/test_openapi_import.py
Normal file
201
backend/tests/e2e/test_openapi_import.py
Normal file
@@ -0,0 +1,201 @@
|
||||
"""E2E tests for OpenAPI import flow (flow 5).
|
||||
|
||||
Flow 5: paste OpenAPI spec URL -> import job -> classify endpoints ->
|
||||
review classifications -> approve -> tool generation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
from app.openapi.models import ClassificationResult, EndpointInfo
|
||||
from app.openapi.review_api import _job_store
|
||||
from tests.e2e.conftest import create_e2e_app
|
||||
|
||||
pytestmark = pytest.mark.e2e
|
||||
|
||||
|
||||
def _fake_endpoint(
|
||||
path: str = "/orders/{id}",
|
||||
method: str = "GET",
|
||||
operation_id: str = "getOrder",
|
||||
summary: str = "Get order details",
|
||||
) -> EndpointInfo:
|
||||
return EndpointInfo(
|
||||
path=path,
|
||||
method=method,
|
||||
operation_id=operation_id,
|
||||
summary=summary,
|
||||
description="",
|
||||
parameters=(),
|
||||
request_body_schema=None,
|
||||
response_schema=None,
|
||||
)
|
||||
|
||||
|
||||
def _fake_classification(
|
||||
endpoint: EndpointInfo | None = None,
|
||||
access_type: str = "read",
|
||||
needs_interrupt: bool = False,
|
||||
agent_group: str = "order_lookup",
|
||||
) -> ClassificationResult:
|
||||
return ClassificationResult(
|
||||
endpoint=endpoint or _fake_endpoint(),
|
||||
access_type=access_type,
|
||||
customer_params=["order_id"],
|
||||
agent_group=agent_group,
|
||||
confidence=0.95,
|
||||
needs_interrupt=needs_interrupt,
|
||||
)
|
||||
|
||||
|
||||
class TestFlow5OpenAPIImport:
|
||||
"""Flow 5: full OpenAPI import lifecycle."""
|
||||
|
||||
def test_import_job_lifecycle(self) -> None:
|
||||
"""Start import -> check status -> review classifications -> approve."""
|
||||
app = create_e2e_app()
|
||||
|
||||
with TestClient(app) as client:
|
||||
# Step 1: Start import job
|
||||
resp = client.post(
|
||||
"/api/openapi/import",
|
||||
json={"url": "https://api.example.com/openapi.json"},
|
||||
)
|
||||
assert resp.status_code == 202
|
||||
body = resp.json()
|
||||
assert body["status"] == "pending"
|
||||
job_id = body["job_id"]
|
||||
|
||||
# Step 2: Check job status (still pending since background task hasn't run)
|
||||
resp = client.get(f"/api/openapi/jobs/{job_id}")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["job_id"] == job_id
|
||||
|
||||
def test_import_job_with_classifications(self) -> None:
|
||||
"""Simulate completed import and review classified endpoints."""
|
||||
app = create_e2e_app()
|
||||
|
||||
# Seed a completed job directly
|
||||
ep_read = _fake_endpoint("/orders/{id}", "GET", "getOrder", "Get order")
|
||||
ep_write = _fake_endpoint("/orders/{id}/cancel", "POST", "cancelOrder", "Cancel order")
|
||||
|
||||
clf_read = _fake_classification(ep_read, "read", False, "order_lookup")
|
||||
clf_write = _fake_classification(ep_write, "write", True, "order_actions")
|
||||
|
||||
job_id = "test-job-001"
|
||||
_job_store[job_id] = {
|
||||
"job_id": job_id,
|
||||
"status": "done",
|
||||
"spec_url": "https://api.example.com/openapi.json",
|
||||
"total_endpoints": 2,
|
||||
"classified_count": 2,
|
||||
"error_message": None,
|
||||
"classifications": [clf_read, clf_write],
|
||||
}
|
||||
|
||||
with TestClient(app) as client:
|
||||
# Step 1: Get classifications
|
||||
resp = client.get(f"/api/openapi/jobs/{job_id}/classifications")
|
||||
assert resp.status_code == 200
|
||||
classifications = resp.json()
|
||||
assert len(classifications) == 2
|
||||
|
||||
# Verify read endpoint
|
||||
read_clf = classifications[0]
|
||||
assert read_clf["access_type"] == "read"
|
||||
assert read_clf["needs_interrupt"] is False
|
||||
assert read_clf["endpoint"]["path"] == "/orders/{id}"
|
||||
|
||||
# Verify write endpoint
|
||||
write_clf = classifications[1]
|
||||
assert write_clf["access_type"] == "write"
|
||||
assert write_clf["needs_interrupt"] is True
|
||||
assert write_clf["endpoint"]["path"] == "/orders/{id}/cancel"
|
||||
|
||||
# Step 2: Update a classification
|
||||
resp = client.put(
|
||||
f"/api/openapi/jobs/{job_id}/classifications/0",
|
||||
json={
|
||||
"access_type": "write",
|
||||
"needs_interrupt": True,
|
||||
"agent_group": "order_actions",
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
updated = resp.json()
|
||||
assert updated["access_type"] == "write"
|
||||
assert updated["needs_interrupt"] is True
|
||||
assert updated["agent_group"] == "order_actions"
|
||||
|
||||
# Step 3: Approve the job
|
||||
resp = client.post(f"/api/openapi/jobs/{job_id}/approve")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["status"] == "approved"
|
||||
|
||||
def test_import_nonexistent_job_returns_404(self) -> None:
|
||||
app = create_e2e_app()
|
||||
|
||||
with TestClient(app) as client:
|
||||
resp = client.get("/api/openapi/jobs/nonexistent")
|
||||
assert resp.status_code == 404
|
||||
|
||||
def test_import_invalid_url_returns_422(self) -> None:
|
||||
app = create_e2e_app()
|
||||
|
||||
with TestClient(app) as client:
|
||||
resp = client.post("/api/openapi/import", json={"url": "not-a-url"})
|
||||
assert resp.status_code == 422
|
||||
|
||||
def test_classification_index_out_of_range(self) -> None:
|
||||
app = create_e2e_app()
|
||||
|
||||
job_id = "test-job-range"
|
||||
_job_store[job_id] = {
|
||||
"job_id": job_id,
|
||||
"status": "done",
|
||||
"spec_url": "https://example.com/spec.json",
|
||||
"total_endpoints": 1,
|
||||
"classified_count": 1,
|
||||
"error_message": None,
|
||||
"classifications": [_fake_classification()],
|
||||
}
|
||||
|
||||
with TestClient(app) as client:
|
||||
resp = client.put(
|
||||
f"/api/openapi/jobs/{job_id}/classifications/99",
|
||||
json={
|
||||
"access_type": "read",
|
||||
"needs_interrupt": False,
|
||||
"agent_group": "order_lookup",
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 404
|
||||
|
||||
def test_update_classification_invalid_agent_group(self) -> None:
|
||||
app = create_e2e_app()
|
||||
|
||||
job_id = "test-job-invalid"
|
||||
_job_store[job_id] = {
|
||||
"job_id": job_id,
|
||||
"status": "done",
|
||||
"spec_url": "https://example.com/spec.json",
|
||||
"total_endpoints": 1,
|
||||
"classified_count": 1,
|
||||
"error_message": None,
|
||||
"classifications": [_fake_classification()],
|
||||
}
|
||||
|
||||
with TestClient(app) as client:
|
||||
resp = client.put(
|
||||
f"/api/openapi/jobs/{job_id}/classifications/0",
|
||||
json={
|
||||
"access_type": "read",
|
||||
"needs_interrupt": False,
|
||||
"agent_group": "invalid group!", # spaces and special chars
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 422
|
||||
214
backend/tests/e2e/test_replay_analytics.py
Normal file
214
backend/tests/e2e/test_replay_analytics.py
Normal file
@@ -0,0 +1,214 @@
|
||||
"""E2E tests for replay and analytics flows (flow 6).
|
||||
|
||||
Flow 6: list conversations -> select one -> step-by-step replay.
|
||||
Also tests the analytics dashboard endpoint.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import pytest
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
from tests.e2e.conftest import FakePool, create_e2e_app
|
||||
|
||||
pytestmark = pytest.mark.e2e
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Custom pool that returns specific data per query
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ReplayPool(FakePool):
    """Fake pool whose connections route each SQL statement to canned rows.

    Routing (checked in order against the raw query text):
      * contains both "conversations" and "SELECT" -> conversation rows
      * contains "checkpoints"                     -> checkpoint rows
      * anything else                              -> analytics rows
    """

    def __init__(
        self,
        conversations: list[dict] | None = None,
        checkpoints: list[dict] | None = None,
        analytics_rows: list[dict] | None = None,
    ) -> None:
        super().__init__()
        # Normalize every optional argument to a concrete list.
        self._conversations = conversations or []
        self._checkpoints = checkpoints or []
        self._analytics = analytics_rows or []

    class _Conn:
        """Minimal async connection stub implementing only ``execute``."""

        def __init__(self, convos, checkpoints, analytics):
            self._convos = convos
            self._checkpoints = checkpoints
            self._analytics = analytics

        async def execute(self, query: str, params=None):
            from tests.e2e.conftest import FakeCursor

            looks_like_conversation_list = (
                "conversations" in query and "SELECT" in query
            )
            if looks_like_conversation_list:
                return FakeCursor(self._convos)
            if "checkpoints" in query:
                return FakeCursor(self._checkpoints)
            # Everything else is treated as an analytics query.
            return FakeCursor(self._analytics)

    def connection(self):
        from contextlib import asynccontextmanager

        stub = self._Conn(self._conversations, self._checkpoints, self._analytics)

        @asynccontextmanager
        async def _scope():
            # Hand out the same stub connection for the lifetime of the context.
            yield stub

        return _scope()
|
||||
|
||||
|
||||
class TestFlow6ReplayConversation:
    """Flow 6: list conversations -> pick one -> step through the replay."""

    @staticmethod
    def _row(thread_id: str, when: str, status: str, tokens: int, cost: float) -> dict:
        """Build one conversation-list row as the API is expected to return it."""
        return {
            "thread_id": thread_id,
            "created_at": when,
            "last_activity": when,
            "status": status,
            "total_tokens": tokens,
            "total_cost_usd": cost,
        }

    def test_list_conversations(self) -> None:
        stamp = datetime.now(tz=timezone.utc).isoformat()
        rows = [
            self._row("conv-001", stamp, "active", 150, 0.003),
            self._row("conv-002", stamp, "completed", 300, 0.006),
        ]
        app = create_e2e_app(pool=ReplayPool(conversations=rows))

        with TestClient(app) as client:
            response = client.get("/api/conversations")
            assert response.status_code == 200
            payload = response.json()
            assert payload["success"] is True
            assert len(payload["data"]) == 2
            # Rows come back in insertion order.
            assert payload["data"][0]["thread_id"] == "conv-001"
            assert payload["data"][1]["thread_id"] == "conv-002"

    def test_list_conversations_pagination(self) -> None:
        rows = [
            self._row(f"conv-{i:03d}", "2026-04-01T00:00:00Z", "active", 100, 0.001)
            for i in range(5)
        ]
        app = create_e2e_app(pool=ReplayPool(conversations=rows))

        with TestClient(app) as client:
            response = client.get(
                "/api/conversations", params={"page": 1, "per_page": 2}
            )
            assert response.status_code == 200
            assert response.json()["success"] is True

    def test_replay_thread_not_found(self) -> None:
        app = create_e2e_app(pool=ReplayPool(checkpoints=[]))

        with TestClient(app) as client:
            assert client.get("/api/replay/nonexistent-thread").status_code == 404

    def test_replay_invalid_thread_id_format(self) -> None:
        app = create_e2e_app()

        with TestClient(app) as client:
            # URL-encoded spaces/punctuation must fail thread-id validation.
            response = client.get("/api/replay/invalid%20thread%21%40")
            assert response.status_code == 400
|
||||
|
||||
|
||||
class TestAnalyticsDashboard:
    """Validation tests for the analytics dashboard endpoint."""

    @staticmethod
    def _analytics_status(range_value: str) -> int:
        """Hit /api/analytics with *range_value* and return the HTTP status."""
        app = create_e2e_app()
        with TestClient(app) as client:
            response = client.get("/api/analytics", params={"range": range_value})
            return response.status_code

    def test_analytics_invalid_range_format(self) -> None:
        # Non "<N>d" strings are rejected outright.
        assert self._analytics_status("invalid") == 400

    def test_analytics_range_too_large(self) -> None:
        # Ranges beyond the allowed window are rejected.
        assert self._analytics_status("999d") == 400

    def test_analytics_range_zero_rejected(self) -> None:
        # A zero-length window is not a valid range.
        assert self._analytics_status("0d") == 400
|
||||
|
||||
|
||||
class TestFullUserJourney:
    """End-to-end journey: chat first, then verify the replay listing."""

    def test_chat_then_check_conversations_endpoint(self) -> None:
        """After chatting via WebSocket, the conversations endpoint is reachable."""
        from tests.e2e.conftest import make_chunk, make_graph

        graph = make_graph(chunks=[make_chunk("Your order is shipped.")])
        stamp = datetime.now(tz=timezone.utc).isoformat()
        journey_row = {
            "thread_id": "e2e-journey-1",
            "created_at": stamp,
            "last_activity": stamp,
            "status": "active",
            "total_tokens": 50,
            "total_cost_usd": 0.001,
        }
        app = create_e2e_app(graph=graph, pool=ReplayPool(conversations=[journey_row]))

        with TestClient(app) as client:
            # Step 1: chat over the WebSocket until the turn completes
            # (bounded loop so a missing terminal event cannot hang the test).
            received = []
            with client.websocket_connect("/ws") as ws:
                ws.send_json({
                    "type": "message",
                    "thread_id": "e2e-journey-1",
                    "content": "Where is my order?",
                })
                for _ in range(20):
                    event = ws.receive_json()
                    received.append(event)
                    if event["type"] in ("message_complete", "error"):
                        break
                assert any(e["type"] == "message_complete" for e in received)

            # Step 2: the conversation shows up in the replay listing.
            listing = client.get("/api/conversations")
            assert listing.status_code == 200
            body = listing.json()
            assert body["success"] is True
            assert any(
                entry["thread_id"] == "e2e-journey-1" for entry in body["data"]
            )

            # Step 3: health check still works after the journey.
            assert client.get("/api/health").status_code == 200
||||
@@ -168,7 +168,10 @@ class TestHandleUserMessage:
|
||||
sm = SessionManager(session_ttl_seconds=0)
|
||||
cb = TokenUsageCallbackHandler()
|
||||
|
||||
# First call creates the session (TTL=0)
|
||||
await handle_user_message(ws, graph, sm, cb, "t1", "hello")
|
||||
# Second call finds it expired
|
||||
await handle_user_message(ws, graph, sm, cb, "t1", "hello again")
|
||||
call_data = ws.send_json.call_args[0][0]
|
||||
assert call_data["type"] == "error"
|
||||
assert "expired" in call_data["message"].lower()
|
||||
|
||||
Reference in New Issue
Block a user