"""E2E tests for critical chat user flows (flows 1-4). Flow 1: Happy path -- query order, get answer Flow 2: Approval flow -- write operation, interrupt, approve, execute Flow 3: Rejection flow -- write operation, interrupt, reject, no execution Flow 4: Multi-turn context -- sequential messages in same session """ from __future__ import annotations import json import pytest from starlette.testclient import TestClient from tests.e2e.conftest import ( create_e2e_app, make_chunk, make_graph, make_state, make_tool_chunk, ) pytestmark = pytest.mark.e2e class TestFlow1HappyPath: """Flow 1: query order -> get answer with streaming tokens.""" def test_websocket_happy_path_order_query(self) -> None: graph = make_graph( chunks=[ make_tool_chunk("get_order_status", {"order_id": "1042"}), make_chunk("Order 1042 has been shipped and is on its way."), ], ) app = create_e2e_app(graph=graph) with TestClient(app) as client: with client.websocket_connect("/ws") as ws: ws.send_json({ "type": "message", "thread_id": "e2e-happy-1", "content": "What is the status of order 1042?", }) messages = [] while True: msg = ws.receive_json() messages.append(msg) if msg["type"] in ("message_complete", "error"): break tool_calls = [m for m in messages if m["type"] == "tool_call"] assert len(tool_calls) == 1 assert tool_calls[0]["tool"] == "get_order_status" assert tool_calls[0]["args"] == {"order_id": "1042"} tokens = [m for m in messages if m["type"] == "token"] assert len(tokens) == 1 assert "shipped" in tokens[0]["content"] completes = [m for m in messages if m["type"] == "message_complete"] assert len(completes) == 1 assert completes[0]["thread_id"] == "e2e-happy-1" def test_websocket_multiple_token_stream(self) -> None: """Verify streaming returns multiple token chunks.""" graph = make_graph( chunks=[ make_chunk("Your order "), make_chunk("1042 "), make_chunk("was delivered "), make_chunk("yesterday."), ], ) app = create_e2e_app(graph=graph) with TestClient(app) as client: with client.websocket_connect("/ws") as ws: ws.send_json({ "type": "message", "thread_id": "e2e-stream-1", "content": "Where is my order?", }) messages = _collect_until_complete(ws) tokens = [m for m in messages if m["type"] == "token"] assert len(tokens) == 4 full_text = "".join(t["content"] for t in tokens) assert "1042" in full_text assert "delivered" in full_text class TestFlow2ApprovalFlow: """Flow 2: write operation -> interrupt -> approve -> execute.""" def test_interrupt_approve_executes_action(self) -> None: interrupt_state = make_state( interrupt=True, data={"action": "cancel_order", "order_id": "1042"}, ) graph = make_graph( chunks=[], state=interrupt_state, resume_chunks=[ make_chunk("Order 1042 has been cancelled successfully.", "order_actions"), ], ) app = create_e2e_app(graph=graph) with TestClient(app) as client: with client.websocket_connect("/ws") as ws: # Step 1: Send cancel request ws.send_json({ "type": "message", "thread_id": "e2e-approve-1", "content": "Cancel order 1042", }) messages = _collect_until_type(ws, "interrupt") interrupts = [m for m in messages if m["type"] == "interrupt"] assert len(interrupts) == 1 assert interrupts[0]["action"] == "cancel_order" assert interrupts[0]["thread_id"] == "e2e-approve-1" # Step 2: Approve the interrupt ws.send_json({ "type": "interrupt_response", "thread_id": "e2e-approve-1", "approved": True, }) resume_messages = _collect_until_complete(ws) tokens = [m for m in resume_messages if m["type"] == "token"] assert len(tokens) == 1 assert "cancelled" in tokens[0]["content"] assert tokens[0]["agent"] == "order_actions" completes = [m for m in resume_messages if m["type"] == "message_complete"] assert len(completes) == 1 class TestFlow3RejectionFlow: """Flow 3: write operation -> interrupt -> reject -> no execution.""" def test_interrupt_reject_does_not_execute(self) -> None: interrupt_state = make_state( interrupt=True, data={"action": "cancel_order", "order_id": "1042"}, ) graph = make_graph( chunks=[], state=interrupt_state, resume_chunks=[ make_chunk("Understood. Order 1042 will remain active.", "order_actions"), ], ) app = create_e2e_app(graph=graph) with TestClient(app) as client: with client.websocket_connect("/ws") as ws: # Step 1: Trigger interrupt ws.send_json({ "type": "message", "thread_id": "e2e-reject-1", "content": "Cancel order 1042", }) messages = _collect_until_type(ws, "interrupt") assert any(m["type"] == "interrupt" for m in messages) # Step 2: Reject ws.send_json({ "type": "interrupt_response", "thread_id": "e2e-reject-1", "approved": False, }) resume_messages = _collect_until_complete(ws) tokens = [m for m in resume_messages if m["type"] == "token"] assert len(tokens) == 1 assert "remain active" in tokens[0]["content"] # Verify graph.astream was called with resume=False resume_call = graph.astream.call_args_list[-1] command = resume_call[0][0] assert command.resume is False class TestFlow4MultiTurnContext: """Flow 4: multi-turn conversation in the same session.""" def test_multi_turn_messages_share_session(self) -> None: """Multiple messages in the same thread_id maintain session context.""" graph = make_graph( chunks=[make_chunk("Order 1042 status: shipped.")], ) app = create_e2e_app(graph=graph) with TestClient(app) as client: with client.websocket_connect("/ws") as ws: # Turn 1: Query order ws.send_json({ "type": "message", "thread_id": "e2e-multi-1", "content": "What is the status of order 1042?", }) turn1 = _collect_until_complete(ws) assert any(m["type"] == "message_complete" for m in turn1) # Turn 2: Follow-up in same thread ws.send_json({ "type": "message", "thread_id": "e2e-multi-1", "content": "When will it arrive?", }) turn2 = _collect_until_complete(ws) assert any(m["type"] == "message_complete" for m in turn2) # Turn 3: Another follow-up ws.send_json({ "type": "message", "thread_id": "e2e-multi-1", "content": "Can you track it?", }) turn3 = _collect_until_complete(ws) assert any(m["type"] == "message_complete" for m in turn3) # Verify all turns used the same thread_id in graph calls for call in graph.astream.call_args_list: config = call[1].get("config", call[0][1] if len(call[0]) > 1 else {}) assert config["configurable"]["thread_id"] == "e2e-multi-1" def test_separate_threads_are_independent(self) -> None: """Different thread_ids have independent sessions.""" graph = make_graph( chunks=[make_chunk("Response.")], ) app = create_e2e_app(graph=graph) with TestClient(app) as client: with client.websocket_connect("/ws") as ws: # Thread A ws.send_json({ "type": "message", "thread_id": "e2e-thread-a", "content": "Hello from thread A", }) _collect_until_complete(ws) # Thread B ws.send_json({ "type": "message", "thread_id": "e2e-thread-b", "content": "Hello from thread B", }) _collect_until_complete(ws) # Both threads should exist as separate sessions sm = app.state.session_manager assert sm.get_state("e2e-thread-a") is not None assert sm.get_state("e2e-thread-b") is not None class TestChatEdgeCases: """Edge cases and error handling for the chat WebSocket.""" def test_invalid_json_returns_error(self) -> None: app = create_e2e_app() with TestClient(app) as client: with client.websocket_connect("/ws") as ws: ws.send_text("not valid json") msg = ws.receive_json() assert msg["type"] == "error" assert "Invalid JSON" in msg["message"] def test_missing_thread_id_returns_error(self) -> None: app = create_e2e_app() with TestClient(app) as client: with client.websocket_connect("/ws") as ws: ws.send_json({"type": "message", "content": "hello"}) msg = ws.receive_json() assert msg["type"] == "error" assert "thread_id" in msg["message"] def test_empty_content_returns_error(self) -> None: app = create_e2e_app() with TestClient(app) as client: with client.websocket_connect("/ws") as ws: ws.send_json({ "type": "message", "thread_id": "e2e-err-1", "content": "", }) msg = ws.receive_json() assert msg["type"] == "error" def test_expired_session_returns_error(self) -> None: graph = make_graph(chunks=[make_chunk("Response.")]) app = create_e2e_app(graph=graph, session_ttl=0) with TestClient(app) as client: with client.websocket_connect("/ws") as ws: # First message creates the session (TTL=0) ws.send_json({ "type": "message", "thread_id": "e2e-expired-1", "content": "hello", }) _collect_until_complete_or_error(ws) # Second message finds the session expired (TTL=0) ws.send_json({ "type": "message", "thread_id": "e2e-expired-1", "content": "hello again", }) messages = _collect_until_complete_or_error(ws) errors = [m for m in messages if m["type"] == "error"] assert len(errors) >= 1 assert "expired" in errors[0]["message"].lower() def test_oversized_message_returns_error(self) -> None: app = create_e2e_app() with TestClient(app) as client: with client.websocket_connect("/ws") as ws: ws.send_text("x" * 40_000) msg = ws.receive_json() assert msg["type"] == "error" assert "too large" in msg["message"].lower() def test_health_endpoint(self) -> None: app = create_e2e_app() with TestClient(app) as client: resp = client.get("/api/health") assert resp.status_code == 200 assert resp.json()["status"] == "ok" # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- def _collect_until_complete(ws, *, max_messages: int = 50) -> list[dict]: """Receive WebSocket messages until message_complete or error.""" messages = [] for _ in range(max_messages): msg = ws.receive_json() messages.append(msg) if msg["type"] in ("message_complete", "error"): break return messages def _collect_until_type(ws, msg_type: str, *, max_messages: int = 50) -> list[dict]: """Receive until a specific message type is received.""" messages = [] for _ in range(max_messages): msg = ws.receive_json() messages.append(msg) if msg["type"] == msg_type: break return messages def _collect_until_complete_or_error(ws, *, max_messages: int = 50) -> list[dict]: """Receive until message_complete or error.""" messages = [] for _ in range(max_messages): msg = ws.receive_json() messages.append(msg) if msg["type"] in ("message_complete", "error"): break return messages