"""Edge case tests for ws_handler input validation and rate limiting.""" from __future__ import annotations import json from unittest.mock import AsyncMock, MagicMock import pytest from app.callbacks import TokenUsageCallbackHandler from app.graph_context import GraphContext from app.session_manager import SessionManager from app.ws_context import WebSocketContext from app.ws_handler import dispatch_message pytestmark = pytest.mark.unit def _make_ws() -> AsyncMock: ws = AsyncMock() ws.send_json = AsyncMock() return ws def _make_graph() -> MagicMock: graph = AsyncMock() class AsyncIterHelper: def __aiter__(self): return self async def __anext__(self): raise StopAsyncIteration graph.astream = MagicMock(return_value=AsyncIterHelper()) state = MagicMock() state.tasks = () graph.aget_state = AsyncMock(return_value=state) return graph def _make_ws_ctx(sm: SessionManager | None = None) -> WebSocketContext: graph = _make_graph() registry = MagicMock() registry.list_agents = MagicMock(return_value=()) graph_ctx = GraphContext(graph=graph, registry=registry, intent_classifier=None) return WebSocketContext( graph_ctx=graph_ctx, session_manager=sm or SessionManager(), callback_handler=TokenUsageCallbackHandler(), ) @pytest.mark.unit class TestEmptyMessageHandling: @pytest.mark.asyncio async def test_empty_message_content_returns_error(self) -> None: ws = _make_ws() sm = SessionManager() ws_ctx = _make_ws_ctx(sm=sm) sm.touch("t1") msg = json.dumps({"type": "message", "thread_id": "t1", "content": ""}) await dispatch_message(ws, ws_ctx, msg) call_data = ws.send_json.call_args[0][0] assert call_data["type"] == "error" msg_lower = call_data["message"].lower() assert "content" in msg_lower or "missing" in msg_lower @pytest.mark.asyncio async def test_whitespace_only_message_treated_as_empty(self) -> None: ws = _make_ws() sm = SessionManager() ws_ctx = _make_ws_ctx(sm=sm) sm.touch("t1") msg = json.dumps({"type": "message", "thread_id": "t1", "content": " "}) await dispatch_message(ws, ws_ctx, msg) call_data = ws.send_json.call_args[0][0] assert call_data["type"] == "error" @pytest.mark.unit class TestOversizedMessageHandling: @pytest.mark.asyncio async def test_content_over_10000_chars_returns_error(self) -> None: ws = _make_ws() sm = SessionManager() ws_ctx = _make_ws_ctx(sm=sm) sm.touch("t1") content = "x" * 10001 msg = json.dumps({"type": "message", "thread_id": "t1", "content": content}) await dispatch_message(ws, ws_ctx, msg) call_data = ws.send_json.call_args[0][0] assert call_data["type"] == "error" assert "too long" in call_data["message"].lower() @pytest.mark.asyncio async def test_content_exactly_10000_chars_is_accepted(self) -> None: ws = _make_ws() sm = SessionManager() ws_ctx = _make_ws_ctx(sm=sm) sm.touch("t1") content = "x" * 10000 msg = json.dumps({"type": "message", "thread_id": "t1", "content": content}) await dispatch_message(ws, ws_ctx, msg) last_call = ws.send_json.call_args[0][0] # Should be processed, not an error about length msg_text = last_call.get("message", "").lower() assert last_call["type"] != "error" or "too long" not in msg_text @pytest.mark.asyncio async def test_raw_message_over_32kb_returns_error(self) -> None: ws = _make_ws() ws_ctx = _make_ws_ctx() large_msg = "x" * 40_000 await dispatch_message(ws, ws_ctx, large_msg) call_data = ws.send_json.call_args[0][0] assert call_data["type"] == "error" assert "too large" in call_data["message"].lower() @pytest.mark.unit class TestInvalidJsonHandling: @pytest.mark.asyncio async def test_invalid_json_returns_error(self) -> None: ws = _make_ws() ws_ctx = _make_ws_ctx() await dispatch_message(ws, ws_ctx, "not valid json {{") call_data = ws.send_json.call_args[0][0] assert call_data["type"] == "error" assert "invalid json" in call_data["message"].lower() @pytest.mark.asyncio async def test_empty_string_returns_json_error(self) -> None: ws = _make_ws() ws_ctx = _make_ws_ctx() await dispatch_message(ws, ws_ctx, "") call_data = ws.send_json.call_args[0][0] assert call_data["type"] == "error" @pytest.mark.asyncio async def test_json_array_not_object_returns_error(self) -> None: ws = _make_ws() ws_ctx = _make_ws_ctx() await dispatch_message(ws, ws_ctx, '["not", "an", "object"]') call_data = ws.send_json.call_args[0][0] assert call_data["type"] == "error" @pytest.mark.unit class TestRateLimiting: @pytest.mark.asyncio async def test_rapid_fire_messages_rate_limited(self) -> None: ws = _make_ws() sm = SessionManager() sm.touch("t1") # Simulate 11 rapid messages (exceeds 10 per 10 seconds limit) rate_limit_triggered = False for i in range(11): ws_ctx = _make_ws_ctx(sm=sm) await dispatch_message(ws, ws_ctx, json.dumps({ "type": "message", "thread_id": "t1", "content": f"message {i}", })) last_call = ws.send_json.call_args[0][0] if last_call["type"] == "error" and "rate" in last_call.get("message", "").lower(): rate_limit_triggered = True break assert rate_limit_triggered, "Rate limiting should trigger after 10 rapid messages" @pytest.mark.asyncio async def test_different_threads_have_separate_rate_limits(self) -> None: ws = _make_ws() sm = SessionManager() sm.touch("t1") sm.touch("t2") # Send 5 messages on t1 and 5 on t2 -- neither should be rate limited for i in range(5): ws_ctx1 = _make_ws_ctx(sm=sm) ws_ctx2 = _make_ws_ctx(sm=sm) await dispatch_message(ws, ws_ctx1, json.dumps({ "type": "message", "thread_id": "t1", "content": f"msg {i}", })) await dispatch_message(ws, ws_ctx2, json.dumps({ "type": "message", "thread_id": "t2", "content": f"msg {i}", })) last_call = ws.send_json.call_args[0][0] assert "rate" not in last_call.get("message", "").lower()