"""Tests for app.ws_handler module."""

from __future__ import annotations

import json
from unittest.mock import AsyncMock, MagicMock

import pytest

from app.callbacks import TokenUsageCallbackHandler
from app.session_manager import SessionManager
from app.ws_handler import (
    _extract_interrupt,
    _has_interrupt,
    dispatch_message,
    handle_interrupt_response,
    handle_user_message,
)


class AsyncIterHelper:
    """Helper to make an iterable behave as an async iterator.

    The items are consumed through a private iterator, so the object passed
    in by the caller is never mutated.
    """

    def __init__(self, items):
        self._iterator = iter(items)

    def __aiter__(self):
        return self

    async def __anext__(self):
        try:
            return next(self._iterator)
        except StopIteration:
            # A StopIteration escaping a coroutine is turned into a
            # RuntimeError, so it has to be translated explicitly.
            raise StopAsyncIteration from None


def _make_ws() -> AsyncMock:
    """Return a mock websocket whose ``send_json`` is an ``AsyncMock``."""
    ws = AsyncMock()
    ws.send_json = AsyncMock()
    return ws


def _make_graph() -> AsyncMock:
    """Return a mock graph with an empty stream and a state without tasks."""
    graph = AsyncMock()
    graph.astream = MagicMock(return_value=AsyncIterHelper([]))
    state = MagicMock()
    state.tasks = ()
    graph.aget_state = AsyncMock(return_value=state)
    return graph


@pytest.mark.unit
class TestDispatchMessage:
    """Validation errors reported by ``dispatch_message``."""

    @pytest.mark.asyncio
    async def test_invalid_json(self) -> None:
        """A non-JSON frame produces exactly one "Invalid JSON" error."""
        ws = _make_ws()
        graph = _make_graph()
        sm = SessionManager()
        cb = TokenUsageCallbackHandler()

        await dispatch_message(ws, graph, sm, cb, "not json")

        ws.send_json.assert_awaited_once()
        call_data = ws.send_json.call_args[0][0]
        assert call_data["type"] == "error"
        assert "Invalid JSON" in call_data["message"]

    @pytest.mark.asyncio
    async def test_missing_thread_id(self) -> None:
        """A message without ``thread_id`` produces an error naming it."""
        ws = _make_ws()
        graph = _make_graph()
        sm = SessionManager()
        cb = TokenUsageCallbackHandler()
        msg = json.dumps({"type": "message", "content": "hello"})

        await dispatch_message(ws, graph, sm, cb, msg)

        call_data = ws.send_json.call_args[0][0]
        assert call_data["type"] == "error"
        assert "thread_id" in call_data["message"]

    @pytest.mark.asyncio
    async def test_missing_content(self) -> None:
        """A message without ``content`` produces an error."""
        ws = _make_ws()
        graph = _make_graph()
        sm = SessionManager()
        cb = TokenUsageCallbackHandler()
        msg = json.dumps({"type": "message", "thread_id": "t1"})

        await dispatch_message(ws, graph, sm, cb, msg)

        call_data = ws.send_json.call_args[0][0]
        assert call_data["type"] == "error"

    @pytest.mark.asyncio
    async def test_unknown_message_type(self) -> None:
        """An unrecognised ``type`` produces an "Unknown" error."""
        ws = _make_ws()
        graph = _make_graph()
        sm = SessionManager()
        cb = TokenUsageCallbackHandler()
        msg = json.dumps({"type": "unknown", "thread_id": "t1"})

        await dispatch_message(ws, graph, sm, cb, msg)

        call_data = ws.send_json.call_args[0][0]
        assert call_data["type"] == "error"
        assert "Unknown" in call_data["message"]
        # Verify raw input is NOT reflected back
        assert "unknown" not in call_data["message"].lower().replace("unknown message type", "")

    @pytest.mark.asyncio
    async def test_message_too_large(self) -> None:
        """A 40,000-character raw frame produces a "too large" error."""
        ws = _make_ws()
        graph = _make_graph()
        sm = SessionManager()
        cb = TokenUsageCallbackHandler()
        large_msg = "x" * 40_000

        await dispatch_message(ws, graph, sm, cb, large_msg)

        call_data = ws.send_json.call_args[0][0]
        assert call_data["type"] == "error"
        assert "too large" in call_data["message"].lower()

    @pytest.mark.asyncio
    async def test_invalid_thread_id_format(self) -> None:
        """A path-like ``thread_id`` produces an error mentioning thread_id."""
        ws = _make_ws()
        graph = _make_graph()
        sm = SessionManager()
        cb = TokenUsageCallbackHandler()
        msg = json.dumps({"type": "message", "thread_id": "../../../etc", "content": "hi"})

        await dispatch_message(ws, graph, sm, cb, msg)

        call_data = ws.send_json.call_args[0][0]
        assert call_data["type"] == "error"
        assert "thread_id" in call_data["message"].lower()

    @pytest.mark.asyncio
    async def test_content_too_long(self) -> None:
        """A 9,000-character ``content`` produces a "too long" error."""
        ws = _make_ws()
        graph = _make_graph()
        sm = SessionManager()
        cb = TokenUsageCallbackHandler()
        msg = json.dumps({"type": "message", "thread_id": "t1", "content": "x" * 9000})

        await dispatch_message(ws, graph, sm, cb, msg)

        call_data = ws.send_json.call_args[0][0]
        assert call_data["type"] == "error"
        assert "too long" in call_data["message"].lower()


@pytest.mark.unit
class TestHandleUserMessage:
    """Behaviour of ``handle_user_message``."""

    @pytest.mark.asyncio
    async def test_expired_session(self) -> None:
        """With a zero-second TTL the handler reports an expired session."""
        ws = _make_ws()
        graph = _make_graph()
        sm = SessionManager(session_ttl_seconds=0)
        cb = TokenUsageCallbackHandler()

        await handle_user_message(ws, graph, sm, cb, "t1", "hello")

        call_data = ws.send_json.call_args[0][0]
        assert call_data["type"] == "error"
        assert "expired" in call_data["message"].lower()

    @pytest.mark.asyncio
    async def test_successful_message(self) -> None:
        """For a touched session the last frame sent is message_complete."""
        ws = _make_ws()
        graph = _make_graph()
        sm = SessionManager()
        cb = TokenUsageCallbackHandler()
        sm.touch("t1")

        await handle_user_message(ws, graph, sm, cb, "t1", "hello")

        # Should end with message_complete
        last_call = ws.send_json.call_args[0][0]
        assert last_call["type"] == "message_complete"

    @pytest.mark.asyncio
    async def test_graph_error_sends_error_message(self) -> None:
        """A graph whose ``astream`` raises results in an error frame."""
        ws = _make_ws()
        graph = AsyncMock()
        graph.astream = MagicMock(side_effect=RuntimeError("boom"))
        sm = SessionManager()
        cb = TokenUsageCallbackHandler()
        sm.touch("t1")

        await handle_user_message(ws, graph, sm, cb, "t1", "hello")

        call_data = ws.send_json.call_args[0][0]
        assert call_data["type"] == "error"


@pytest.mark.unit
class TestHandleInterruptResponse:
    """Behaviour of ``handle_interrupt_response``."""

    @pytest.mark.asyncio
    async def test_approved_interrupt(self) -> None:
        """Responding ``True`` for an extended session ends with message_complete."""
        ws = _make_ws()
        graph = _make_graph()
        sm = SessionManager()
        cb = TokenUsageCallbackHandler()
        sm.touch("t1")
        sm.extend_for_interrupt("t1")

        await handle_interrupt_response(ws, graph, sm, cb, "t1", True)

        last_call = ws.send_json.call_args[0][0]
        assert last_call["type"] == "message_complete"


@pytest.mark.unit
class TestInterruptHelpers:
    """Behaviour of ``_has_interrupt`` and ``_extract_interrupt``."""

    def test_has_interrupt_false_for_empty_tasks(self) -> None:
        """A state with no tasks has no interrupt."""
        state = MagicMock()
        state.tasks = ()

        assert not _has_interrupt(state)

    def test_has_interrupt_true(self) -> None:
        """A task carrying an interrupt object is detected."""
        interrupt_obj = MagicMock()
        interrupt_obj.value = {"action": "cancel"}
        task = MagicMock()
        task.interrupts = (interrupt_obj,)
        state = MagicMock()
        state.tasks = (task,)

        assert _has_interrupt(state)

    def test_extract_interrupt_data(self) -> None:
        """The interrupt's ``value`` dict is what gets extracted."""
        interrupt_obj = MagicMock()
        interrupt_obj.value = {"action": "cancel_order", "order_id": "1042"}
        task = MagicMock()
        task.interrupts = (interrupt_obj,)
        state = MagicMock()
        state.tasks = (task,)

        data = _extract_interrupt(state)

        assert data["action"] == "cancel_order"

    def test_extract_interrupt_empty(self) -> None:
        """Without tasks the extracted action is "unknown"."""
        state = MagicMock()
        state.tasks = ()

        data = _extract_interrupt(state)

        assert data["action"] == "unknown"