"""Tests for app.ws_handler module.""" from __future__ import annotations import json from unittest.mock import AsyncMock, MagicMock import pytest from app.callbacks import TokenUsageCallbackHandler from app.interrupt_manager import InterruptManager from app.session_manager import SessionManager from app.ws_handler import ( _extract_interrupt, _has_interrupt, dispatch_message, handle_interrupt_response, handle_user_message, ) def _make_ws() -> AsyncMock: ws = AsyncMock() ws.send_json = AsyncMock() return ws def _make_graph() -> AsyncMock: graph = AsyncMock() graph.astream = MagicMock(return_value=AsyncIterHelper([])) state = MagicMock() state.tasks = () graph.aget_state = AsyncMock(return_value=state) # Phase 2: graph needs intent_classifier and agent_registry attrs graph.intent_classifier = None graph.agent_registry = None return graph class AsyncIterHelper: """Helper to make a list behave as an async iterator.""" def __init__(self, items): self._items = items def __aiter__(self): return self async def __anext__(self): if not self._items: raise StopAsyncIteration return self._items.pop(0) @pytest.mark.unit class TestDispatchMessage: @pytest.mark.asyncio async def test_invalid_json(self) -> None: ws = _make_ws() graph = _make_graph() sm = SessionManager() cb = TokenUsageCallbackHandler() await dispatch_message(ws, graph, sm, cb, "not json") ws.send_json.assert_awaited_once() call_data = ws.send_json.call_args[0][0] assert call_data["type"] == "error" assert "Invalid JSON" in call_data["message"] @pytest.mark.asyncio async def test_missing_thread_id(self) -> None: ws = _make_ws() graph = _make_graph() sm = SessionManager() cb = TokenUsageCallbackHandler() msg = json.dumps({"type": "message", "content": "hello"}) await dispatch_message(ws, graph, sm, cb, msg) call_data = ws.send_json.call_args[0][0] assert call_data["type"] == "error" assert "thread_id" in call_data["message"] @pytest.mark.asyncio async def test_missing_content(self) -> None: ws = _make_ws() graph = 
_make_graph() sm = SessionManager() cb = TokenUsageCallbackHandler() msg = json.dumps({"type": "message", "thread_id": "t1"}) await dispatch_message(ws, graph, sm, cb, msg) call_data = ws.send_json.call_args[0][0] assert call_data["type"] == "error" @pytest.mark.asyncio async def test_unknown_message_type(self) -> None: ws = _make_ws() graph = _make_graph() sm = SessionManager() cb = TokenUsageCallbackHandler() msg = json.dumps({"type": "unknown", "thread_id": "t1"}) await dispatch_message(ws, graph, sm, cb, msg) call_data = ws.send_json.call_args[0][0] assert call_data["type"] == "error" assert "Unknown" in call_data["message"] @pytest.mark.asyncio async def test_message_too_large(self) -> None: ws = _make_ws() graph = _make_graph() sm = SessionManager() cb = TokenUsageCallbackHandler() large_msg = "x" * 40_000 await dispatch_message(ws, graph, sm, cb, large_msg) call_data = ws.send_json.call_args[0][0] assert call_data["type"] == "error" assert "too large" in call_data["message"].lower() @pytest.mark.asyncio async def test_invalid_thread_id_format(self) -> None: ws = _make_ws() graph = _make_graph() sm = SessionManager() cb = TokenUsageCallbackHandler() msg = json.dumps({"type": "message", "thread_id": "../../../etc", "content": "hi"}) await dispatch_message(ws, graph, sm, cb, msg) call_data = ws.send_json.call_args[0][0] assert call_data["type"] == "error" assert "thread_id" in call_data["message"].lower() @pytest.mark.asyncio async def test_content_too_long(self) -> None: ws = _make_ws() graph = _make_graph() sm = SessionManager() cb = TokenUsageCallbackHandler() msg = json.dumps({"type": "message", "thread_id": "t1", "content": "x" * 10001}) await dispatch_message(ws, graph, sm, cb, msg) call_data = ws.send_json.call_args[0][0] assert call_data["type"] == "error" assert "too long" in call_data["message"].lower() @pytest.mark.asyncio async def test_dispatch_with_interrupt_manager(self) -> None: ws = _make_ws() graph = _make_graph() sm = SessionManager() cb = 
TokenUsageCallbackHandler() im = InterruptManager() sm.touch("t1") msg = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"}) await dispatch_message(ws, graph, sm, cb, msg, interrupt_manager=im) last_call = ws.send_json.call_args[0][0] assert last_call["type"] == "message_complete" @pytest.mark.unit class TestHandleUserMessage: @pytest.mark.asyncio async def test_expired_session(self) -> None: ws = _make_ws() graph = _make_graph() sm = SessionManager(session_ttl_seconds=0) cb = TokenUsageCallbackHandler() await handle_user_message(ws, graph, sm, cb, "t1", "hello") call_data = ws.send_json.call_args[0][0] assert call_data["type"] == "error" assert "expired" in call_data["message"].lower() @pytest.mark.asyncio async def test_successful_message(self) -> None: ws = _make_ws() graph = _make_graph() sm = SessionManager() cb = TokenUsageCallbackHandler() sm.touch("t1") await handle_user_message(ws, graph, sm, cb, "t1", "hello") last_call = ws.send_json.call_args[0][0] assert last_call["type"] == "message_complete" @pytest.mark.asyncio async def test_graph_error_sends_error_message(self) -> None: ws = _make_ws() graph = AsyncMock() graph.astream = MagicMock(side_effect=RuntimeError("boom")) graph.intent_classifier = None graph.agent_registry = None sm = SessionManager() cb = TokenUsageCallbackHandler() sm.touch("t1") await handle_user_message(ws, graph, sm, cb, "t1", "hello") call_data = ws.send_json.call_args[0][0] assert call_data["type"] == "error" @pytest.mark.asyncio async def test_interrupt_registered_with_manager(self) -> None: ws = _make_ws() graph = AsyncMock() graph.intent_classifier = None graph.agent_registry = None graph.astream = MagicMock(return_value=AsyncIterHelper([])) # Simulate interrupt in state interrupt_obj = MagicMock() interrupt_obj.value = {"action": "cancel_order", "order_id": "1042"} task = MagicMock() task.interrupts = (interrupt_obj,) state = MagicMock() state.tasks = (task,) graph.aget_state = AsyncMock(return_value=state) 
sm = SessionManager() cb = TokenUsageCallbackHandler() im = InterruptManager() sm.touch("t1") await handle_user_message( ws, graph, sm, cb, "t1", "cancel order 1042", interrupt_manager=im, ) # Interrupt should be registered assert im.has_pending("t1") # Should have sent interrupt message calls = [c[0][0] for c in ws.send_json.call_args_list] interrupt_msgs = [c for c in calls if c.get("type") == "interrupt"] assert len(interrupt_msgs) == 1 @pytest.mark.asyncio async def test_ambiguous_intent_sends_clarification(self) -> None: from app.intent import ClassificationResult ws = _make_ws() graph = AsyncMock() graph.astream = MagicMock(return_value=AsyncIterHelper([])) state = MagicMock() state.tasks = () graph.aget_state = AsyncMock(return_value=state) # Set up intent classifier that returns ambiguous mock_classifier = AsyncMock() mock_classifier.classify = AsyncMock( return_value=ClassificationResult( intents=(), is_ambiguous=True, clarification_question="What do you mean?", ) ) graph.intent_classifier = mock_classifier mock_registry = MagicMock() mock_registry.list_agents = MagicMock(return_value=()) graph.agent_registry = mock_registry sm = SessionManager() cb = TokenUsageCallbackHandler() sm.touch("t1") await handle_user_message(ws, graph, sm, cb, "t1", "hmm") calls = [c[0][0] for c in ws.send_json.call_args_list] clarification_msgs = [c for c in calls if c.get("type") == "clarification"] assert len(clarification_msgs) == 1 assert clarification_msgs[0]["message"] == "What do you mean?" 
@pytest.mark.unit
class TestHandleInterruptResponse:
    """Behaviour of ``handle_interrupt_response`` for approve/expire paths."""

    @pytest.mark.asyncio
    async def test_approved_interrupt(self) -> None:
        ws = _make_ws()
        graph = _make_graph()
        sm = SessionManager()
        cb = TokenUsageCallbackHandler()
        sm.touch("t1")
        sm.extend_for_interrupt("t1")

        await handle_interrupt_response(ws, graph, sm, cb, "t1", True)

        last_call = ws.send_json.call_args[0][0]
        assert last_call["type"] == "message_complete"

    @pytest.mark.asyncio
    async def test_expired_interrupt_sends_retry_prompt(self) -> None:
        from unittest.mock import patch

        ws = _make_ws()
        graph = _make_graph()
        sm = SessionManager()
        cb = TokenUsageCallbackHandler()
        im = InterruptManager(ttl_seconds=5)
        sm.touch("t1")
        sm.extend_for_interrupt("t1")
        im.register("t1", "cancel_order", {"order_id": "1042"})

        # Expire the interrupt: move the clock past created_at + ttl (5s).
        with patch("app.interrupt_manager.time") as mock_time:
            mock_time.time.return_value = im._interrupts["t1"].created_at + 10
            await handle_interrupt_response(
                ws, graph, sm, cb, "t1", True, interrupt_manager=im
            )

        call_data = ws.send_json.call_args[0][0]
        assert call_data["type"] == "interrupt_expired"
        assert "cancel_order" in call_data["message"]

    @pytest.mark.asyncio
    async def test_valid_interrupt_resolves(self) -> None:
        ws = _make_ws()
        graph = _make_graph()
        sm = SessionManager()
        cb = TokenUsageCallbackHandler()
        im = InterruptManager(ttl_seconds=1800)
        sm.touch("t1")
        sm.extend_for_interrupt("t1")
        im.register("t1", "cancel_order", {})

        await handle_interrupt_response(
            ws, graph, sm, cb, "t1", True, interrupt_manager=im
        )

        # Interrupt should be resolved
        assert not im.has_pending("t1")
        last_call = ws.send_json.call_args[0][0]
        assert last_call["type"] == "message_complete"


def _make_interrupt_state(value: dict) -> MagicMock:
    """Return a mock graph state holding one task with one interrupt.

    Args:
        value: payload exposed as the interrupt object's ``value`` attribute.

    Returns:
        A ``MagicMock`` whose ``tasks`` is a 1-tuple of a task whose
        ``interrupts`` is a 1-tuple containing the interrupt object.
    """
    interrupt_obj = MagicMock()
    interrupt_obj.value = value
    task = MagicMock()
    task.interrupts = (interrupt_obj,)
    state = MagicMock()
    state.tasks = (task,)
    return state


@pytest.mark.unit
class TestInterruptHelpers:
    """Unit tests for ``_has_interrupt`` and ``_extract_interrupt``."""

    def test_has_interrupt_false_for_empty_tasks(self) -> None:
        state = MagicMock()
        state.tasks = ()
        assert not _has_interrupt(state)

    def test_has_interrupt_true(self) -> None:
        state = _make_interrupt_state({"action": "cancel"})
        assert _has_interrupt(state)

    def test_extract_interrupt_data(self) -> None:
        state = _make_interrupt_state(
            {"action": "cancel_order", "order_id": "1042"}
        )
        data = _extract_interrupt(state)
        assert data["action"] == "cancel_order"

    def test_extract_interrupt_empty(self) -> None:
        state = MagicMock()
        state.tasks = ()
        data = _extract_interrupt(state)
        assert data["action"] == "unknown"


@pytest.mark.unit
class TestDispatchMessageWithTracking:
    """``dispatch_message`` integration with the optional tracker/recorder."""

    @pytest.mark.asyncio
    async def test_conversation_tracker_called_on_message(self) -> None:
        ws = _make_ws()
        graph = _make_graph()
        sm = SessionManager()
        cb = TokenUsageCallbackHandler()
        tracker = AsyncMock()
        pool = MagicMock()
        sm.touch("t1")
        msg = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"})

        await dispatch_message(
            ws,
            graph,
            sm,
            cb,
            msg,
            conversation_tracker=tracker,
            pool=pool,
        )

        tracker.ensure_conversation.assert_awaited_once_with(pool, "t1")
        tracker.record_turn.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_analytics_recorder_called_on_message(self) -> None:
        ws = _make_ws()
        graph = _make_graph()
        sm = SessionManager()
        cb = TokenUsageCallbackHandler()
        recorder = AsyncMock()
        pool = MagicMock()
        sm.touch("t1")
        msg = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"})

        await dispatch_message(
            ws,
            graph,
            sm,
            cb,
            msg,
            analytics_recorder=recorder,
            pool=pool,
        )

        recorder.record.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_tracker_failure_does_not_break_chat(self) -> None:
        ws = _make_ws()
        graph = _make_graph()
        sm = SessionManager()
        cb = TokenUsageCallbackHandler()
        tracker = AsyncMock()
        tracker.ensure_conversation.side_effect = RuntimeError("DB down")
        pool = MagicMock()
        sm.touch("t1")
        msg = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"})

        # Should not raise despite tracker failure
        await dispatch_message(
            ws,
            graph,
            sm,
            cb,
            msg,
            conversation_tracker=tracker,
            pool=pool,
        )

        last_call = ws.send_json.call_args[0][0]
        assert last_call["type"] == "message_complete"

    @pytest.mark.asyncio
    async def test_no_tracker_no_error(self) -> None:
        ws = _make_ws()
        graph = _make_graph()
        sm = SessionManager()
        cb = TokenUsageCallbackHandler()
        sm.touch("t1")
        msg = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"})

        # No tracker or recorder passed -- should work fine
        await dispatch_message(ws, graph, sm, cb, msg)

        last_call = ws.send_json.call_args[0][0]
        assert last_call["type"] == "message_complete"