Files
smart-support/backend/tests/e2e/test_chat_flows.py
Yaojia Wang f0699436c5 refactor: engineering improvements -- API versioning, structured logging, Alembic, error standardization, test coverage
- API versioning: all REST endpoints prefixed with /api/v1/
- Structured logging: replaced stdlib logging with structlog (console/JSON modes)
- Alembic migrations: versioned DB schema with initial migration
- Error standardization: global exception handlers for consistent envelope format
- Interrupt cleanup: asyncio background task for expired interrupt removal
- Integration tests: +30 tests (analytics, replay, openapi, error, session APIs)
- Frontend tests: +57 tests (all components, pages, useWebSocket hook)
- Backend: 557 tests, 89.75% coverage | Frontend: 80 tests, 16 test files
2026-04-06 23:19:29 +02:00

385 lines
14 KiB
Python

"""E2E tests for critical chat user flows (flows 1-4).
Flow 1: Happy path -- query order, get answer
Flow 2: Approval flow -- write operation, interrupt, approve, execute
Flow 3: Rejection flow -- write operation, interrupt, reject, no execution
Flow 4: Multi-turn context -- sequential messages in same session
"""
from __future__ import annotations
import json
import pytest
from starlette.testclient import TestClient
from tests.e2e.conftest import (
create_e2e_app,
make_chunk,
make_graph,
make_state,
make_tool_chunk,
)
# Module-level marker: pytest applies ``pytestmark`` to every test collected
# from this module, so the whole file can be selected with ``-m e2e``.
pytestmark = pytest.mark.e2e
class TestFlow1HappyPath:
    """Flow 1: query order -> get answer with streaming tokens."""

    def test_websocket_happy_path_order_query(self) -> None:
        """A read-only query yields one tool call, one token, then completion."""
        graph = make_graph(
            chunks=[
                make_tool_chunk("get_order_status", {"order_id": "1042"}),
                make_chunk("Order 1042 has been shipped and is on its way."),
            ],
        )
        app = create_e2e_app(graph=graph)
        with TestClient(app) as client:
            with client.websocket_connect("/ws") as ws:
                ws.send_json({
                    "type": "message",
                    "thread_id": "e2e-happy-1",
                    "content": "What is the status of order 1042?",
                })
                # Use the shared, bounded collector rather than an open-ended
                # ``while True`` loop, so a server that keeps streaming without
                # a terminal frame cannot keep this test reading forever.
                messages = _collect_until_complete(ws)

                tool_calls = [m for m in messages if m["type"] == "tool_call"]
                assert len(tool_calls) == 1
                assert tool_calls[0]["tool"] == "get_order_status"
                assert tool_calls[0]["args"] == {"order_id": "1042"}

                tokens = [m for m in messages if m["type"] == "token"]
                assert len(tokens) == 1
                assert "shipped" in tokens[0]["content"]

                completes = [m for m in messages if m["type"] == "message_complete"]
                assert len(completes) == 1
                assert completes[0]["thread_id"] == "e2e-happy-1"

    def test_websocket_multiple_token_stream(self) -> None:
        """Verify streaming returns multiple token chunks."""
        graph = make_graph(
            chunks=[
                make_chunk("Your order "),
                make_chunk("1042 "),
                make_chunk("was delivered "),
                make_chunk("yesterday."),
            ],
        )
        app = create_e2e_app(graph=graph)
        with TestClient(app) as client:
            with client.websocket_connect("/ws") as ws:
                ws.send_json({
                    "type": "message",
                    "thread_id": "e2e-stream-1",
                    "content": "Where is my order?",
                })
                messages = _collect_until_complete(ws)

                # The collector also stops on ``error``; make sure the turn
                # actually finished normally before inspecting the tokens.
                assert messages[-1]["type"] == "message_complete"

                tokens = [m for m in messages if m["type"] == "token"]
                assert len(tokens) == 4
                full_text = "".join(t["content"] for t in tokens)
                assert "1042" in full_text
                assert "delivered" in full_text
class TestFlow2ApprovalFlow:
    """Flow 2: write operation -> interrupt -> approve -> execute."""

    def test_interrupt_approve_executes_action(self) -> None:
        """Approving an interrupt resumes the graph and streams the result."""
        interrupt_state = make_state(
            interrupt=True,
            data={"action": "cancel_order", "order_id": "1042"},
        )
        graph = make_graph(
            chunks=[],
            state=interrupt_state,
            resume_chunks=[
                make_chunk("Order 1042 has been cancelled successfully.", "order_actions"),
            ],
        )
        app = create_e2e_app(graph=graph)
        with TestClient(app) as client:
            with client.websocket_connect("/ws") as ws:
                # Step 1: Send cancel request
                ws.send_json({
                    "type": "message",
                    "thread_id": "e2e-approve-1",
                    "content": "Cancel order 1042",
                })
                messages = _collect_until_type(ws, "interrupt")
                interrupts = [m for m in messages if m["type"] == "interrupt"]
                assert len(interrupts) == 1
                assert interrupts[0]["action"] == "cancel_order"
                assert interrupts[0]["thread_id"] == "e2e-approve-1"

                # Step 2: Approve the interrupt
                ws.send_json({
                    "type": "interrupt_response",
                    "thread_id": "e2e-approve-1",
                    "approved": True,
                })
                resume_messages = _collect_until_complete(ws)
                tokens = [m for m in resume_messages if m["type"] == "token"]
                assert len(tokens) == 1
                assert "cancelled" in tokens[0]["content"]
                assert tokens[0]["agent"] == "order_actions"
                completes = [m for m in resume_messages if m["type"] == "message_complete"]
                assert len(completes) == 1

                # Mirror of the rejection test: the approval must actually be
                # forwarded to the graph as the resume value, not just echoed.
                resume_call = graph.astream.call_args_list[-1]
                command = resume_call.args[0]
                assert command.resume is True
class TestFlow3RejectionFlow:
    """Flow 3: write operation -> interrupt -> reject -> no execution."""

    def test_interrupt_reject_does_not_execute(self) -> None:
        """Rejecting an interrupt resumes the graph with ``resume=False``."""
        interrupt_state = make_state(
            interrupt=True,
            data={"action": "cancel_order", "order_id": "1042"},
        )
        graph = make_graph(
            chunks=[],
            state=interrupt_state,
            resume_chunks=[
                make_chunk("Understood. Order 1042 will remain active.", "order_actions"),
            ],
        )
        app = create_e2e_app(graph=graph)
        with TestClient(app) as client:
            with client.websocket_connect("/ws") as ws:
                # Step 1: Trigger interrupt
                ws.send_json({
                    "type": "message",
                    "thread_id": "e2e-reject-1",
                    "content": "Cancel order 1042",
                })
                messages = _collect_until_type(ws, "interrupt")
                # Same interrupt checks as the approval flow, so both flows
                # pin the interrupt payload rather than just its presence.
                interrupts = [m for m in messages if m["type"] == "interrupt"]
                assert len(interrupts) == 1
                assert interrupts[0]["action"] == "cancel_order"
                assert interrupts[0]["thread_id"] == "e2e-reject-1"

                # Step 2: Reject
                ws.send_json({
                    "type": "interrupt_response",
                    "thread_id": "e2e-reject-1",
                    "approved": False,
                })
                resume_messages = _collect_until_complete(ws)
                tokens = [m for m in resume_messages if m["type"] == "token"]
                assert len(tokens) == 1
                assert "remain active" in tokens[0]["content"]
                completes = [m for m in resume_messages if m["type"] == "message_complete"]
                assert len(completes) == 1

                # Verify graph.astream was called with resume=False
                resume_call = graph.astream.call_args_list[-1]
                command = resume_call.args[0]
                assert command.resume is False
class TestFlow4MultiTurnContext:
    """Flow 4: multi-turn conversation in the same session."""

    @staticmethod
    def _astream_thread_id(call) -> str:
        """Return the ``configurable.thread_id`` of one ``graph.astream`` mock call.

        The config is taken from the ``config`` keyword, falling back to the
        second positional argument. A call without any config fails with a
        descriptive assertion instead of an opaque ``KeyError``.
        """
        config = call.kwargs.get("config")
        if config is None and len(call.args) > 1:
            config = call.args[1]
        assert config is not None, f"graph.astream called without a config: {call!r}"
        return config["configurable"]["thread_id"]

    def test_multi_turn_messages_share_session(self) -> None:
        """Multiple messages in the same thread_id maintain session context."""
        thread_id = "e2e-multi-1"
        turns = (
            "What is the status of order 1042?",  # Turn 1: query order
            "When will it arrive?",  # Turn 2: follow-up in same thread
            "Can you track it?",  # Turn 3: another follow-up
        )
        graph = make_graph(
            chunks=[make_chunk("Order 1042 status: shipped.")],
        )
        app = create_e2e_app(graph=graph)
        with TestClient(app) as client:
            with client.websocket_connect("/ws") as ws:
                for content in turns:
                    ws.send_json({
                        "type": "message",
                        "thread_id": thread_id,
                        "content": content,
                    })
                    replies = _collect_until_complete(ws)
                    assert any(m["type"] == "message_complete" for m in replies), content

                # Without a call count, an empty call list would make the
                # thread_id loop below pass vacuously. One astream call per
                # message turn is assumed here.
                astream_calls = graph.astream.call_args_list
                assert len(astream_calls) == len(turns)
                for call in astream_calls:
                    assert self._astream_thread_id(call) == thread_id

    def test_separate_threads_are_independent(self) -> None:
        """Different thread_ids have independent sessions."""
        thread_ids = ("e2e-thread-a", "e2e-thread-b")
        graph = make_graph(
            chunks=[make_chunk("Response.")],
        )
        app = create_e2e_app(graph=graph)
        with TestClient(app) as client:
            with client.websocket_connect("/ws") as ws:
                for thread_id, label in zip(thread_ids, ("A", "B")):
                    ws.send_json({
                        "type": "message",
                        "thread_id": thread_id,
                        "content": f"Hello from thread {label}",
                    })
                    _collect_until_complete(ws)

                # Both threads should exist as separate sessions...
                sm = app.state.session_manager
                for thread_id in thread_ids:
                    assert sm.get_state(thread_id) is not None

                # ...and the graph must have been invoked under each thread's
                # own id, never a shared or mixed-up one.
                used_ids = {
                    self._astream_thread_id(call)
                    for call in graph.astream.call_args_list
                }
                assert used_ids == set(thread_ids)
class TestChatEdgeCases:
    """Malformed frames, expired sessions, and the health probe."""

    def test_invalid_json_returns_error(self) -> None:
        """A non-JSON text frame is answered with an error envelope."""
        app = create_e2e_app()
        with TestClient(app) as client, client.websocket_connect("/ws") as ws:
            ws.send_text("not valid json")
            reply = ws.receive_json()
            assert reply["type"] == "error"
            assert "Invalid JSON" in reply["message"]

    def test_missing_thread_id_returns_error(self) -> None:
        """A chat message without ``thread_id`` is rejected and names the field."""
        app = create_e2e_app()
        with TestClient(app) as client, client.websocket_connect("/ws") as ws:
            ws.send_json({"type": "message", "content": "hello"})
            reply = ws.receive_json()
            assert reply["type"] == "error"
            assert "thread_id" in reply["message"]

    def test_empty_content_returns_error(self) -> None:
        """A chat message with empty content is rejected."""
        app = create_e2e_app()
        payload = {
            "type": "message",
            "thread_id": "e2e-err-1",
            "content": "",
        }
        with TestClient(app) as client, client.websocket_connect("/ws") as ws:
            ws.send_json(payload)
            reply = ws.receive_json()
            assert reply["type"] == "error"

    def test_expired_session_returns_error(self) -> None:
        """With a zero TTL, a follow-up on the same thread reports expiry."""
        graph = make_graph(chunks=[make_chunk("Response.")])
        app = create_e2e_app(graph=graph, session_ttl=0)
        thread_id = "e2e-expired-1"
        with TestClient(app) as client, client.websocket_connect("/ws") as ws:
            # Opening message creates the session; TTL=0 makes it stale at once.
            ws.send_json({
                "type": "message",
                "thread_id": thread_id,
                "content": "hello",
            })
            _collect_until_complete_or_error(ws)

            # The follow-up should find that session expired.
            ws.send_json({
                "type": "message",
                "thread_id": thread_id,
                "content": "hello again",
            })
            replies = _collect_until_complete_or_error(ws)
            error_frames = [r for r in replies if r["type"] == "error"]
            assert len(error_frames) >= 1
            first_error_text = error_frames[0]["message"].lower()
            assert "expired" in first_error_text

    def test_oversized_message_returns_error(self) -> None:
        """A 40 000-character frame is refused as too large."""
        app = create_e2e_app()
        with TestClient(app) as client, client.websocket_connect("/ws") as ws:
            ws.send_text("x" * 40_000)
            reply = ws.receive_json()
            assert reply["type"] == "error"
            assert "too large" in reply["message"].lower()

    def test_health_endpoint(self) -> None:
        """The versioned health route answers 200 with ``status == "ok"``."""
        app = create_e2e_app()
        with TestClient(app) as client:
            response = client.get("/api/v1/health")
            assert response.status_code == 200
            body = response.json()
            assert body["status"] == "ok"
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _collect_until_complete(ws, *, max_messages: int = 50) -> list[dict]:
    """Receive WebSocket messages until ``message_complete`` or ``error``.

    Args:
        ws: A connected test WebSocket exposing ``receive_json()``.
        max_messages: Upper bound on frames to read, so a server that keeps
            streaming without a terminal frame cannot keep the test reading
            indefinitely.

    Returns:
        Every message received, in order; the last one is the terminal frame.

    Raises:
        AssertionError: If no terminal frame arrives within ``max_messages``
            frames. Previously a partial list was returned silently, which
            surfaced later as a confusing failure in the caller.
    """
    messages: list[dict] = []
    for _ in range(max_messages):
        msg = ws.receive_json()
        messages.append(msg)
        if msg["type"] in ("message_complete", "error"):
            return messages
    received = [m.get("type") for m in messages]
    raise AssertionError(
        f"no message_complete/error within {max_messages} messages; "
        f"received types: {received}"
    )
def _collect_until_type(
    ws,
    msg_type: str,
    *,
    max_messages: int = 50,
    stop_on_error: bool = True,
) -> list[dict]:
    """Receive WebSocket messages until one of type ``msg_type`` arrives.

    Args:
        ws: A connected test WebSocket exposing ``receive_json()``.
        msg_type: The ``type`` value to wait for (e.g. ``"interrupt"``).
        max_messages: Upper bound on frames to read.
        stop_on_error: Also stop at an ``error`` frame (default). After an
            error, the awaited type is not expected to follow, and a further
            ``receive_json()`` could block indefinitely. Returning lets the
            caller's assertions report the error instead.

    Returns:
        Every message received, in order, ending with the matching frame (or
        the ``error`` frame when ``stop_on_error`` is set).

    Raises:
        AssertionError: If neither stop condition is met within
            ``max_messages`` frames.
    """
    messages: list[dict] = []
    for _ in range(max_messages):
        msg = ws.receive_json()
        messages.append(msg)
        if msg["type"] == msg_type:
            return messages
        if stop_on_error and msg["type"] == "error":
            return messages
    received = [m.get("type") for m in messages]
    raise AssertionError(
        f"no {msg_type!r} message within {max_messages} messages; "
        f"received types: {received}"
    )
def _collect_until_complete_or_error(ws, *, max_messages: int = 50) -> list[dict]:
    """Receive until ``message_complete`` or ``error``.

    Same contract as ``_collect_until_complete``. This wrapper delegates to it
    rather than duplicating the receive loop, so the two helpers cannot drift
    apart.
    """
    return _collect_until_complete(ws, max_messages=max_messages)