385 lines
14 KiB
Python
385 lines
14 KiB
Python
"""E2E tests for critical chat user flows (flows 1-4).

Flow 1: Happy path -- query order, get answer
Flow 2: Approval flow -- write operation, interrupt, approve, execute
Flow 3: Rejection flow -- write operation, interrupt, reject, no execution
Flow 4: Multi-turn context -- sequential messages in same session

The module also covers chat WebSocket edge cases (invalid JSON, missing
thread_id, empty content, expired session, oversized message) and the
health endpoint.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
|
|
import pytest
|
|
from starlette.testclient import TestClient
|
|
|
|
from tests.e2e.conftest import (
|
|
create_e2e_app,
|
|
make_chunk,
|
|
make_graph,
|
|
make_state,
|
|
make_tool_chunk,
|
|
)
|
|
|
|
# Apply the ``e2e`` marker to every test in this module so the suite can be
# selected or deselected with ``pytest -m e2e`` / ``pytest -m "not e2e"``.
pytestmark = pytest.mark.e2e
|
|
|
|
|
|
class TestFlow1HappyPath:
    """Flow 1: query order -> get answer with streaming tokens."""

    def test_websocket_happy_path_order_query(self) -> None:
        """An order query yields one tool call, one answer token and a completion."""
        graph = make_graph(
            chunks=[
                make_tool_chunk("get_order_status", {"order_id": "1042"}),
                make_chunk("Order 1042 has been shipped and is on its way."),
            ],
        )
        app = create_e2e_app(graph=graph)

        with TestClient(app) as client:
            with client.websocket_connect("/ws") as ws:
                ws.send_json({
                    "type": "message",
                    "thread_id": "e2e-happy-1",
                    "content": "What is the status of order 1042?",
                })

                # Reuse the shared, bounded collector instead of an ad-hoc
                # ``while True`` loop with no cap on the number of reads.
                messages = _collect_until_complete(ws)

                # Fail with the whole transcript if the turn ended in an error.
                assert messages[-1]["type"] == "message_complete", messages

                tool_calls = [m for m in messages if m["type"] == "tool_call"]
                assert len(tool_calls) == 1
                assert tool_calls[0]["tool"] == "get_order_status"
                assert tool_calls[0]["args"] == {"order_id": "1042"}

                tokens = [m for m in messages if m["type"] == "token"]
                assert len(tokens) == 1
                assert "shipped" in tokens[0]["content"]

                completes = [m for m in messages if m["type"] == "message_complete"]
                assert len(completes) == 1
                assert completes[0]["thread_id"] == "e2e-happy-1"

    def test_websocket_multiple_token_stream(self) -> None:
        """Verify streaming returns multiple token chunks."""
        graph = make_graph(
            chunks=[
                make_chunk("Your order "),
                make_chunk("1042 "),
                make_chunk("was delivered "),
                make_chunk("yesterday."),
            ],
        )
        app = create_e2e_app(graph=graph)

        with TestClient(app) as client:
            with client.websocket_connect("/ws") as ws:
                ws.send_json({
                    "type": "message",
                    "thread_id": "e2e-stream-1",
                    "content": "Where is my order?",
                })

                messages = _collect_until_complete(ws)

                # Fail with the whole transcript if the turn ended in an error.
                assert messages[-1]["type"] == "message_complete", messages

                tokens = [m for m in messages if m["type"] == "token"]
                assert len(tokens) == 4

                full_text = "".join(t["content"] for t in tokens)
                assert "1042" in full_text
                assert "delivered" in full_text
|
|
|
|
|
|
class TestFlow2ApprovalFlow:
    """Flow 2: write operation -> interrupt -> approve -> execute."""

    def test_interrupt_approve_executes_action(self) -> None:
        """Approving an interrupt resumes the graph and streams the action result."""
        interrupt_state = make_state(
            interrupt=True,
            data={"action": "cancel_order", "order_id": "1042"},
        )
        graph = make_graph(
            chunks=[],
            state=interrupt_state,
            resume_chunks=[
                make_chunk("Order 1042 has been cancelled successfully.", "order_actions"),
            ],
        )
        app = create_e2e_app(graph=graph)

        with TestClient(app) as client:
            with client.websocket_connect("/ws") as ws:
                # Step 1: Send cancel request
                ws.send_json({
                    "type": "message",
                    "thread_id": "e2e-approve-1",
                    "content": "Cancel order 1042",
                })

                messages = _collect_until_type(ws, "interrupt")

                interrupts = [m for m in messages if m["type"] == "interrupt"]
                assert len(interrupts) == 1
                assert interrupts[0]["action"] == "cancel_order"
                assert interrupts[0]["thread_id"] == "e2e-approve-1"

                # Step 2: Approve the interrupt
                ws.send_json({
                    "type": "interrupt_response",
                    "thread_id": "e2e-approve-1",
                    "approved": True,
                })

                resume_messages = _collect_until_complete(ws)

                tokens = [m for m in resume_messages if m["type"] == "token"]
                assert len(tokens) == 1
                assert "cancelled" in tokens[0]["content"]
                assert tokens[0]["agent"] == "order_actions"

                completes = [m for m in resume_messages if m["type"] == "message_complete"]
                assert len(completes) == 1

                # Mirror the rejection flow: approval must resume the graph with
                # resume=True. The resume command is the first positional
                # argument of the latest astream call.
                command = graph.astream.call_args_list[-1].args[0]
                assert command.resume is True
|
|
|
|
|
|
class TestFlow3RejectionFlow:
    """Flow 3: write operation -> interrupt -> reject -> no execution."""

    def test_interrupt_reject_does_not_execute(self) -> None:
        """Rejecting an interrupt resumes with resume=False and runs no tool."""
        interrupt_state = make_state(
            interrupt=True,
            data={"action": "cancel_order", "order_id": "1042"},
        )
        graph = make_graph(
            chunks=[],
            state=interrupt_state,
            resume_chunks=[
                make_chunk("Understood. Order 1042 will remain active.", "order_actions"),
            ],
        )
        app = create_e2e_app(graph=graph)

        with TestClient(app) as client:
            with client.websocket_connect("/ws") as ws:
                # Step 1: Trigger interrupt
                ws.send_json({
                    "type": "message",
                    "thread_id": "e2e-reject-1",
                    "content": "Cancel order 1042",
                })

                messages = _collect_until_type(ws, "interrupt")

                # Check the interrupt payload the same way the approval flow does.
                interrupts = [m for m in messages if m["type"] == "interrupt"]
                assert len(interrupts) == 1
                assert interrupts[0]["action"] == "cancel_order"
                assert interrupts[0]["thread_id"] == "e2e-reject-1"

                # Step 2: Reject
                ws.send_json({
                    "type": "interrupt_response",
                    "thread_id": "e2e-reject-1",
                    "approved": False,
                })

                resume_messages = _collect_until_complete(ws)
                assert resume_messages[-1]["type"] == "message_complete", resume_messages

                # A rejected write must not surface any tool execution.
                tool_calls = [m for m in resume_messages if m["type"] == "tool_call"]
                assert not tool_calls, tool_calls

                tokens = [m for m in resume_messages if m["type"] == "token"]
                assert len(tokens) == 1
                assert "remain active" in tokens[0]["content"]

                # Verify graph.astream was resumed with resume=False; the resume
                # command is the first positional argument of the latest call.
                command = graph.astream.call_args_list[-1].args[0]
                assert command.resume is False
|
|
|
|
|
|
class TestFlow4MultiTurnContext:
    """Flow 4: multi-turn conversation in the same session."""

    @staticmethod
    def _astream_thread_ids(graph) -> list[str]:
        """Return the ``thread_id`` passed to each ``graph.astream`` call, in order.

        The run config is read from the ``config`` keyword argument when
        present, otherwise from the second positional argument. A call with
        no config at all fails with a descriptive assertion.
        """
        thread_ids: list[str] = []
        for call in graph.astream.call_args_list:
            if "config" in call.kwargs:
                config = call.kwargs["config"]
            else:
                assert len(call.args) > 1, f"graph.astream called without a config: {call!r}"
                config = call.args[1]
            thread_ids.append(config["configurable"]["thread_id"])
        return thread_ids

    def test_multi_turn_messages_share_session(self) -> None:
        """Multiple messages in the same thread_id maintain session context."""
        graph = make_graph(
            chunks=[make_chunk("Order 1042 status: shipped.")],
        )
        app = create_e2e_app(graph=graph)
        questions = (
            "What is the status of order 1042?",  # Turn 1: query order
            "When will it arrive?",  # Turn 2: follow-up in same thread
            "Can you track it?",  # Turn 3: another follow-up
        )

        with TestClient(app) as client:
            with client.websocket_connect("/ws") as ws:
                for question in questions:
                    ws.send_json({
                        "type": "message",
                        "thread_id": "e2e-multi-1",
                        "content": question,
                    })
                    turn = _collect_until_complete(ws)
                    assert any(m["type"] == "message_complete" for m in turn), turn

                # Every turn must start one graph run on the same thread.
                # Comparing the full list also keeps the check from passing
                # vacuously if astream was never called.
                assert self._astream_thread_ids(graph) == ["e2e-multi-1"] * len(questions)

    def test_separate_threads_are_independent(self) -> None:
        """Different thread_ids have independent sessions."""
        graph = make_graph(
            chunks=[make_chunk("Response.")],
        )
        app = create_e2e_app(graph=graph)

        with TestClient(app) as client:
            with client.websocket_connect("/ws") as ws:
                # Thread A
                ws.send_json({
                    "type": "message",
                    "thread_id": "e2e-thread-a",
                    "content": "Hello from thread A",
                })
                turn_a = _collect_until_complete(ws)

                # Thread B
                ws.send_json({
                    "type": "message",
                    "thread_id": "e2e-thread-b",
                    "content": "Hello from thread B",
                })
                turn_b = _collect_until_complete(ws)

                # Each turn completes and is attributed to its own thread.
                assert turn_a[-1]["type"] == "message_complete", turn_a
                assert turn_a[-1]["thread_id"] == "e2e-thread-a"
                assert turn_b[-1]["type"] == "message_complete", turn_b
                assert turn_b[-1]["thread_id"] == "e2e-thread-b"

                # The graph ran once per thread, each with its own thread_id.
                assert self._astream_thread_ids(graph) == ["e2e-thread-a", "e2e-thread-b"]

                # Both threads should exist as separate sessions
                sm = app.state.session_manager
                assert sm.get_state("e2e-thread-a") is not None
                assert sm.get_state("e2e-thread-b") is not None
|
|
|
|
|
|
class TestChatEdgeCases:
    """Edge cases and error handling for the chat WebSocket."""

    @staticmethod
    def _first_reply(app, send) -> dict:
        """Connect to ``/ws``, let ``send`` write to the socket, return the first JSON reply."""
        with TestClient(app) as client, client.websocket_connect("/ws") as ws:
            send(ws)
            return ws.receive_json()

    def test_invalid_json_returns_error(self) -> None:
        """A frame that is not JSON is answered with an ``Invalid JSON`` error."""
        reply = self._first_reply(
            create_e2e_app(),
            lambda ws: ws.send_text("not valid json"),
        )
        assert reply["type"] == "error"
        assert "Invalid JSON" in reply["message"]

    def test_missing_thread_id_returns_error(self) -> None:
        """A message without ``thread_id`` gets an error that names the field."""
        reply = self._first_reply(
            create_e2e_app(),
            lambda ws: ws.send_json({"type": "message", "content": "hello"}),
        )
        assert reply["type"] == "error"
        assert "thread_id" in reply["message"]

    def test_empty_content_returns_error(self) -> None:
        """A message whose ``content`` is empty is answered with an error."""
        empty_message = {
            "type": "message",
            "thread_id": "e2e-err-1",
            "content": "",
        }
        reply = self._first_reply(create_e2e_app(), lambda ws: ws.send_json(empty_message))
        assert reply["type"] == "error"

    def test_expired_session_returns_error(self) -> None:
        """With ``session_ttl=0`` a follow-up on the same thread reports expiry."""
        graph = make_graph(chunks=[make_chunk("Response.")])
        app = create_e2e_app(graph=graph, session_ttl=0)

        with TestClient(app) as client, client.websocket_connect("/ws") as ws:
            # The first message creates the session (TTL=0); the second one
            # finds it already expired. Only the second reply is inspected.
            for content in ("hello", "hello again"):
                ws.send_json({
                    "type": "message",
                    "thread_id": "e2e-expired-1",
                    "content": content,
                })
                replies = _collect_until_complete_or_error(ws)

            errors = [reply for reply in replies if reply["type"] == "error"]
            assert len(errors) >= 1
            assert "expired" in errors[0]["message"].lower()

    def test_oversized_message_returns_error(self) -> None:
        """A 40,000-character text frame is rejected as too large."""
        reply = self._first_reply(
            create_e2e_app(),
            lambda ws: ws.send_text("x" * 40_000),
        )
        assert reply["type"] == "error"
        assert "too large" in reply["message"].lower()

    def test_health_endpoint(self) -> None:
        """``GET /api/health`` answers 200 with ``status == "ok"``."""
        with TestClient(create_e2e_app()) as client:
            response = client.get("/api/health")
            assert response.status_code == 200
            assert response.json()["status"] == "ok"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _collect_until_complete(ws, *, max_messages: int = 50) -> list[dict]:
|
|
"""Receive WebSocket messages until message_complete or error."""
|
|
messages = []
|
|
for _ in range(max_messages):
|
|
msg = ws.receive_json()
|
|
messages.append(msg)
|
|
if msg["type"] in ("message_complete", "error"):
|
|
break
|
|
return messages
|
|
|
|
|
|
def _collect_until_type(ws, msg_type: str, *, max_messages: int = 50) -> list[dict]:
|
|
"""Receive until a specific message type is received."""
|
|
messages = []
|
|
for _ in range(max_messages):
|
|
msg = ws.receive_json()
|
|
messages.append(msg)
|
|
if msg["type"] == msg_type:
|
|
break
|
|
return messages
|
|
|
|
|
|
def _collect_until_complete_or_error(ws, *, max_messages: int = 50) -> list[dict]:
    """Receive until ``message_complete`` or ``error``.

    This was a line-for-line copy of ``_collect_until_complete``, which
    already stops on both message types. It now delegates so the two cannot
    drift apart, and is kept so existing call sites read naturally.
    """
    return _collect_until_complete(ws, max_messages=max_messages)
|