feat(ui): implement premium beige design system and ux refinements
This commit is contained in:
384
backend/tests/e2e/test_chat_flows.py
Normal file
384
backend/tests/e2e/test_chat_flows.py
Normal file
@@ -0,0 +1,384 @@
|
||||
"""E2E tests for critical chat user flows (flows 1-4).
|
||||
|
||||
Flow 1: Happy path -- query order, get answer
|
||||
Flow 2: Approval flow -- write operation, interrupt, approve, execute
|
||||
Flow 3: Rejection flow -- write operation, interrupt, reject, no execution
|
||||
Flow 4: Multi-turn context -- sequential messages in same session
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
from tests.e2e.conftest import (
|
||||
create_e2e_app,
|
||||
make_chunk,
|
||||
make_graph,
|
||||
make_state,
|
||||
make_tool_chunk,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.e2e
|
||||
|
||||
|
||||
class TestFlow1HappyPath:
|
||||
"""Flow 1: query order -> get answer with streaming tokens."""
|
||||
|
||||
def test_websocket_happy_path_order_query(self) -> None:
|
||||
graph = make_graph(
|
||||
chunks=[
|
||||
make_tool_chunk("get_order_status", {"order_id": "1042"}),
|
||||
make_chunk("Order 1042 has been shipped and is on its way."),
|
||||
],
|
||||
)
|
||||
app = create_e2e_app(graph=graph)
|
||||
|
||||
with TestClient(app) as client:
|
||||
with client.websocket_connect("/ws") as ws:
|
||||
ws.send_json({
|
||||
"type": "message",
|
||||
"thread_id": "e2e-happy-1",
|
||||
"content": "What is the status of order 1042?",
|
||||
})
|
||||
|
||||
messages = []
|
||||
while True:
|
||||
msg = ws.receive_json()
|
||||
messages.append(msg)
|
||||
if msg["type"] in ("message_complete", "error"):
|
||||
break
|
||||
|
||||
tool_calls = [m for m in messages if m["type"] == "tool_call"]
|
||||
assert len(tool_calls) == 1
|
||||
assert tool_calls[0]["tool"] == "get_order_status"
|
||||
assert tool_calls[0]["args"] == {"order_id": "1042"}
|
||||
|
||||
tokens = [m for m in messages if m["type"] == "token"]
|
||||
assert len(tokens) == 1
|
||||
assert "shipped" in tokens[0]["content"]
|
||||
|
||||
completes = [m for m in messages if m["type"] == "message_complete"]
|
||||
assert len(completes) == 1
|
||||
assert completes[0]["thread_id"] == "e2e-happy-1"
|
||||
|
||||
def test_websocket_multiple_token_stream(self) -> None:
|
||||
"""Verify streaming returns multiple token chunks."""
|
||||
graph = make_graph(
|
||||
chunks=[
|
||||
make_chunk("Your order "),
|
||||
make_chunk("1042 "),
|
||||
make_chunk("was delivered "),
|
||||
make_chunk("yesterday."),
|
||||
],
|
||||
)
|
||||
app = create_e2e_app(graph=graph)
|
||||
|
||||
with TestClient(app) as client:
|
||||
with client.websocket_connect("/ws") as ws:
|
||||
ws.send_json({
|
||||
"type": "message",
|
||||
"thread_id": "e2e-stream-1",
|
||||
"content": "Where is my order?",
|
||||
})
|
||||
|
||||
messages = _collect_until_complete(ws)
|
||||
|
||||
tokens = [m for m in messages if m["type"] == "token"]
|
||||
assert len(tokens) == 4
|
||||
full_text = "".join(t["content"] for t in tokens)
|
||||
assert "1042" in full_text
|
||||
assert "delivered" in full_text
|
||||
|
||||
|
||||
class TestFlow2ApprovalFlow:
|
||||
"""Flow 2: write operation -> interrupt -> approve -> execute."""
|
||||
|
||||
def test_interrupt_approve_executes_action(self) -> None:
|
||||
interrupt_state = make_state(
|
||||
interrupt=True,
|
||||
data={"action": "cancel_order", "order_id": "1042"},
|
||||
)
|
||||
graph = make_graph(
|
||||
chunks=[],
|
||||
state=interrupt_state,
|
||||
resume_chunks=[
|
||||
make_chunk("Order 1042 has been cancelled successfully.", "order_actions"),
|
||||
],
|
||||
)
|
||||
app = create_e2e_app(graph=graph)
|
||||
|
||||
with TestClient(app) as client:
|
||||
with client.websocket_connect("/ws") as ws:
|
||||
# Step 1: Send cancel request
|
||||
ws.send_json({
|
||||
"type": "message",
|
||||
"thread_id": "e2e-approve-1",
|
||||
"content": "Cancel order 1042",
|
||||
})
|
||||
|
||||
messages = _collect_until_type(ws, "interrupt")
|
||||
|
||||
interrupts = [m for m in messages if m["type"] == "interrupt"]
|
||||
assert len(interrupts) == 1
|
||||
assert interrupts[0]["action"] == "cancel_order"
|
||||
assert interrupts[0]["thread_id"] == "e2e-approve-1"
|
||||
|
||||
# Step 2: Approve the interrupt
|
||||
ws.send_json({
|
||||
"type": "interrupt_response",
|
||||
"thread_id": "e2e-approve-1",
|
||||
"approved": True,
|
||||
})
|
||||
|
||||
resume_messages = _collect_until_complete(ws)
|
||||
|
||||
tokens = [m for m in resume_messages if m["type"] == "token"]
|
||||
assert len(tokens) == 1
|
||||
assert "cancelled" in tokens[0]["content"]
|
||||
assert tokens[0]["agent"] == "order_actions"
|
||||
|
||||
completes = [m for m in resume_messages if m["type"] == "message_complete"]
|
||||
assert len(completes) == 1
|
||||
|
||||
|
||||
class TestFlow3RejectionFlow:
|
||||
"""Flow 3: write operation -> interrupt -> reject -> no execution."""
|
||||
|
||||
def test_interrupt_reject_does_not_execute(self) -> None:
|
||||
interrupt_state = make_state(
|
||||
interrupt=True,
|
||||
data={"action": "cancel_order", "order_id": "1042"},
|
||||
)
|
||||
graph = make_graph(
|
||||
chunks=[],
|
||||
state=interrupt_state,
|
||||
resume_chunks=[
|
||||
make_chunk("Understood. Order 1042 will remain active.", "order_actions"),
|
||||
],
|
||||
)
|
||||
app = create_e2e_app(graph=graph)
|
||||
|
||||
with TestClient(app) as client:
|
||||
with client.websocket_connect("/ws") as ws:
|
||||
# Step 1: Trigger interrupt
|
||||
ws.send_json({
|
||||
"type": "message",
|
||||
"thread_id": "e2e-reject-1",
|
||||
"content": "Cancel order 1042",
|
||||
})
|
||||
|
||||
messages = _collect_until_type(ws, "interrupt")
|
||||
assert any(m["type"] == "interrupt" for m in messages)
|
||||
|
||||
# Step 2: Reject
|
||||
ws.send_json({
|
||||
"type": "interrupt_response",
|
||||
"thread_id": "e2e-reject-1",
|
||||
"approved": False,
|
||||
})
|
||||
|
||||
resume_messages = _collect_until_complete(ws)
|
||||
|
||||
tokens = [m for m in resume_messages if m["type"] == "token"]
|
||||
assert len(tokens) == 1
|
||||
assert "remain active" in tokens[0]["content"]
|
||||
|
||||
# Verify graph.astream was called with resume=False
|
||||
resume_call = graph.astream.call_args_list[-1]
|
||||
command = resume_call[0][0]
|
||||
assert command.resume is False
|
||||
|
||||
|
||||
class TestFlow4MultiTurnContext:
|
||||
"""Flow 4: multi-turn conversation in the same session."""
|
||||
|
||||
def test_multi_turn_messages_share_session(self) -> None:
|
||||
"""Multiple messages in the same thread_id maintain session context."""
|
||||
graph = make_graph(
|
||||
chunks=[make_chunk("Order 1042 status: shipped.")],
|
||||
)
|
||||
app = create_e2e_app(graph=graph)
|
||||
|
||||
with TestClient(app) as client:
|
||||
with client.websocket_connect("/ws") as ws:
|
||||
# Turn 1: Query order
|
||||
ws.send_json({
|
||||
"type": "message",
|
||||
"thread_id": "e2e-multi-1",
|
||||
"content": "What is the status of order 1042?",
|
||||
})
|
||||
turn1 = _collect_until_complete(ws)
|
||||
assert any(m["type"] == "message_complete" for m in turn1)
|
||||
|
||||
# Turn 2: Follow-up in same thread
|
||||
ws.send_json({
|
||||
"type": "message",
|
||||
"thread_id": "e2e-multi-1",
|
||||
"content": "When will it arrive?",
|
||||
})
|
||||
turn2 = _collect_until_complete(ws)
|
||||
assert any(m["type"] == "message_complete" for m in turn2)
|
||||
|
||||
# Turn 3: Another follow-up
|
||||
ws.send_json({
|
||||
"type": "message",
|
||||
"thread_id": "e2e-multi-1",
|
||||
"content": "Can you track it?",
|
||||
})
|
||||
turn3 = _collect_until_complete(ws)
|
||||
assert any(m["type"] == "message_complete" for m in turn3)
|
||||
|
||||
# Verify all turns used the same thread_id in graph calls
|
||||
for call in graph.astream.call_args_list:
|
||||
config = call[1].get("config", call[0][1] if len(call[0]) > 1 else {})
|
||||
assert config["configurable"]["thread_id"] == "e2e-multi-1"
|
||||
|
||||
def test_separate_threads_are_independent(self) -> None:
|
||||
"""Different thread_ids have independent sessions."""
|
||||
graph = make_graph(
|
||||
chunks=[make_chunk("Response.")],
|
||||
)
|
||||
app = create_e2e_app(graph=graph)
|
||||
|
||||
with TestClient(app) as client:
|
||||
with client.websocket_connect("/ws") as ws:
|
||||
# Thread A
|
||||
ws.send_json({
|
||||
"type": "message",
|
||||
"thread_id": "e2e-thread-a",
|
||||
"content": "Hello from thread A",
|
||||
})
|
||||
_collect_until_complete(ws)
|
||||
|
||||
# Thread B
|
||||
ws.send_json({
|
||||
"type": "message",
|
||||
"thread_id": "e2e-thread-b",
|
||||
"content": "Hello from thread B",
|
||||
})
|
||||
_collect_until_complete(ws)
|
||||
|
||||
# Both threads should exist as separate sessions
|
||||
sm = app.state.session_manager
|
||||
assert sm.get_state("e2e-thread-a") is not None
|
||||
assert sm.get_state("e2e-thread-b") is not None
|
||||
|
||||
|
||||
class TestChatEdgeCases:
|
||||
"""Edge cases and error handling for the chat WebSocket."""
|
||||
|
||||
def test_invalid_json_returns_error(self) -> None:
|
||||
app = create_e2e_app()
|
||||
|
||||
with TestClient(app) as client:
|
||||
with client.websocket_connect("/ws") as ws:
|
||||
ws.send_text("not valid json")
|
||||
msg = ws.receive_json()
|
||||
assert msg["type"] == "error"
|
||||
assert "Invalid JSON" in msg["message"]
|
||||
|
||||
def test_missing_thread_id_returns_error(self) -> None:
|
||||
app = create_e2e_app()
|
||||
|
||||
with TestClient(app) as client:
|
||||
with client.websocket_connect("/ws") as ws:
|
||||
ws.send_json({"type": "message", "content": "hello"})
|
||||
msg = ws.receive_json()
|
||||
assert msg["type"] == "error"
|
||||
assert "thread_id" in msg["message"]
|
||||
|
||||
def test_empty_content_returns_error(self) -> None:
|
||||
app = create_e2e_app()
|
||||
|
||||
with TestClient(app) as client:
|
||||
with client.websocket_connect("/ws") as ws:
|
||||
ws.send_json({
|
||||
"type": "message",
|
||||
"thread_id": "e2e-err-1",
|
||||
"content": "",
|
||||
})
|
||||
msg = ws.receive_json()
|
||||
assert msg["type"] == "error"
|
||||
|
||||
def test_expired_session_returns_error(self) -> None:
|
||||
graph = make_graph(chunks=[make_chunk("Response.")])
|
||||
app = create_e2e_app(graph=graph, session_ttl=0)
|
||||
|
||||
with TestClient(app) as client:
|
||||
with client.websocket_connect("/ws") as ws:
|
||||
# First message creates the session (TTL=0)
|
||||
ws.send_json({
|
||||
"type": "message",
|
||||
"thread_id": "e2e-expired-1",
|
||||
"content": "hello",
|
||||
})
|
||||
_collect_until_complete_or_error(ws)
|
||||
|
||||
# Second message finds the session expired (TTL=0)
|
||||
ws.send_json({
|
||||
"type": "message",
|
||||
"thread_id": "e2e-expired-1",
|
||||
"content": "hello again",
|
||||
})
|
||||
messages = _collect_until_complete_or_error(ws)
|
||||
errors = [m for m in messages if m["type"] == "error"]
|
||||
assert len(errors) >= 1
|
||||
assert "expired" in errors[0]["message"].lower()
|
||||
|
||||
def test_oversized_message_returns_error(self) -> None:
|
||||
app = create_e2e_app()
|
||||
|
||||
with TestClient(app) as client:
|
||||
with client.websocket_connect("/ws") as ws:
|
||||
ws.send_text("x" * 40_000)
|
||||
msg = ws.receive_json()
|
||||
assert msg["type"] == "error"
|
||||
assert "too large" in msg["message"].lower()
|
||||
|
||||
def test_health_endpoint(self) -> None:
|
||||
app = create_e2e_app()
|
||||
with TestClient(app) as client:
|
||||
resp = client.get("/api/health")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["status"] == "ok"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _collect_until_complete(ws, *, max_messages: int = 50) -> list[dict]:
|
||||
"""Receive WebSocket messages until message_complete or error."""
|
||||
messages = []
|
||||
for _ in range(max_messages):
|
||||
msg = ws.receive_json()
|
||||
messages.append(msg)
|
||||
if msg["type"] in ("message_complete", "error"):
|
||||
break
|
||||
return messages
|
||||
|
||||
|
||||
def _collect_until_type(ws, msg_type: str, *, max_messages: int = 50) -> list[dict]:
|
||||
"""Receive until a specific message type is received."""
|
||||
messages = []
|
||||
for _ in range(max_messages):
|
||||
msg = ws.receive_json()
|
||||
messages.append(msg)
|
||||
if msg["type"] == msg_type:
|
||||
break
|
||||
return messages
|
||||
|
||||
|
||||
def _collect_until_complete_or_error(ws, *, max_messages: int = 50) -> list[dict]:
|
||||
"""Receive until message_complete or error."""
|
||||
messages = []
|
||||
for _ in range(max_messages):
|
||||
msg = ws.receive_json()
|
||||
messages.append(msg)
|
||||
if msg["type"] in ("message_complete", "error"):
|
||||
break
|
||||
return messages
|
||||
Reference in New Issue
Block a user