Backend: - ConversationTracker: Protocol + PostgresConversationTracker for lifecycle tracking - Error handler: ErrorCategory enum, classify_error(), with_retry() exponential backoff - Wire PostgresAnalyticsRecorder + ConversationTracker into ws_handler - Rate limiting (10 msg/10s per thread), edge case hardening - Health endpoint GET /api/health, version 0.5.0 - Demo seed data script + sample OpenAPI spec Frontend (all new): - React Router with NavBar (Chat / Replay / Dashboard / Review) - ReplayListPage + ReplayPage with ReplayTimeline component - DashboardPage with MetricCard, range selector, zero-state - ReviewPage for OpenAPI classification review - ErrorBanner for WebSocket disconnect handling - API client (api.ts) with typed fetch wrappers Infrastructure: - Frontend Dockerfile (multi-stage node -> nginx) - nginx.conf with SPA routing + API/WS proxy - docker-compose.yml with frontend service + healthchecks - .env.example files (root + backend) Documentation: - README.md with quick start and architecture - Agent configuration guide - OpenAPI import guide - Deployment guide - Demo script 48 new tests, 449 total passing, 92.87% coverage
214 lines
6.8 KiB
Python
"""Edge case tests for ws_handler input validation and rate limiting."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
|
|
import pytest
|
|
|
|
from app.callbacks import TokenUsageCallbackHandler
|
|
from app.session_manager import SessionManager
|
|
from app.ws_handler import dispatch_message
|
|
|
|
pytestmark = pytest.mark.unit
|
|
|
|
|
|
def _make_ws() -> AsyncMock:
|
|
ws = AsyncMock()
|
|
ws.send_json = AsyncMock()
|
|
return ws
|
|
|
|
|
|
def _make_graph() -> AsyncMock:
|
|
graph = AsyncMock()
|
|
|
|
class AsyncIterHelper:
|
|
def __aiter__(self):
|
|
return self
|
|
|
|
async def __anext__(self):
|
|
raise StopAsyncIteration
|
|
|
|
graph.astream = MagicMock(return_value=AsyncIterHelper())
|
|
state = MagicMock()
|
|
state.tasks = ()
|
|
graph.aget_state = AsyncMock(return_value=state)
|
|
graph.intent_classifier = None
|
|
graph.agent_registry = None
|
|
return graph
|
|
|
|
|
|
@pytest.mark.unit
class TestEmptyMessageHandling:
    """Blank message content must be rejected with an error frame."""

    @pytest.mark.asyncio
    async def test_empty_message_content_returns_error(self) -> None:
        """An empty-string content yields an error naming the content field."""
        websocket = _make_ws()
        graph = _make_graph()
        sessions = SessionManager()
        usage_cb = TokenUsageCallbackHandler()
        sessions.touch("t1")

        payload = {"type": "message", "thread_id": "t1", "content": ""}
        await dispatch_message(websocket, graph, sessions, usage_cb, json.dumps(payload))

        sent = websocket.send_json.call_args[0][0]
        assert sent["type"] == "error"
        lowered = sent["message"].lower()
        assert "content" in lowered or "missing" in lowered

    @pytest.mark.asyncio
    async def test_whitespace_only_message_treated_as_empty(self) -> None:
        """Whitespace-only content is equivalent to empty and rejected."""
        websocket = _make_ws()
        graph = _make_graph()
        sessions = SessionManager()
        usage_cb = TokenUsageCallbackHandler()
        sessions.touch("t1")

        payload = {"type": "message", "thread_id": "t1", "content": " "}
        await dispatch_message(websocket, graph, sessions, usage_cb, json.dumps(payload))

        assert websocket.send_json.call_args[0][0]["type"] == "error"
|
|
|
|
|
|
@pytest.mark.unit
class TestOversizedMessageHandling:
    """Size limits: 10,000 chars of content, 32 KiB for the raw frame."""

    @pytest.mark.asyncio
    async def test_content_over_10000_chars_returns_error(self) -> None:
        """Content one char over the limit is rejected as too long."""
        websocket = _make_ws()
        graph = _make_graph()
        sessions = SessionManager()
        usage_cb = TokenUsageCallbackHandler()
        sessions.touch("t1")

        payload = {"type": "message", "thread_id": "t1", "content": "x" * 10001}
        await dispatch_message(websocket, graph, sessions, usage_cb, json.dumps(payload))

        sent = websocket.send_json.call_args[0][0]
        assert sent["type"] == "error"
        assert "too long" in sent["message"].lower()

    @pytest.mark.asyncio
    async def test_content_exactly_10000_chars_is_accepted(self) -> None:
        """Content exactly at the limit passes validation (boundary case)."""
        websocket = _make_ws()
        graph = _make_graph()
        sessions = SessionManager()
        usage_cb = TokenUsageCallbackHandler()
        sessions.touch("t1")

        payload = {"type": "message", "thread_id": "t1", "content": "x" * 10000}
        await dispatch_message(websocket, graph, sessions, usage_cb, json.dumps(payload))

        # Should be processed, not an error about length
        sent = websocket.send_json.call_args[0][0]
        sent_text = sent.get("message", "").lower()
        assert sent["type"] != "error" or "too long" not in sent_text

    @pytest.mark.asyncio
    async def test_raw_message_over_32kb_returns_error(self) -> None:
        """A raw frame bigger than 32 KiB is rejected before JSON parsing."""
        websocket = _make_ws()
        graph = _make_graph()
        sessions = SessionManager()
        usage_cb = TokenUsageCallbackHandler()

        await dispatch_message(websocket, graph, sessions, usage_cb, "x" * 40_000)

        sent = websocket.send_json.call_args[0][0]
        assert sent["type"] == "error"
        assert "too large" in sent["message"].lower()
|
|
|
|
|
|
@pytest.mark.unit
class TestInvalidJsonHandling:
    """Malformed frames must yield an error rather than raise."""

    @pytest.mark.asyncio
    async def test_invalid_json_returns_error(self) -> None:
        """Unparseable text produces an 'invalid json' error frame."""
        websocket = _make_ws()
        graph = _make_graph()
        sessions = SessionManager()
        usage_cb = TokenUsageCallbackHandler()

        await dispatch_message(websocket, graph, sessions, usage_cb, "not valid json {{")

        sent = websocket.send_json.call_args[0][0]
        assert sent["type"] == "error"
        assert "invalid json" in sent["message"].lower()

    @pytest.mark.asyncio
    async def test_empty_string_returns_json_error(self) -> None:
        """An empty frame is treated as a JSON parse failure."""
        websocket = _make_ws()
        graph = _make_graph()
        sessions = SessionManager()
        usage_cb = TokenUsageCallbackHandler()

        await dispatch_message(websocket, graph, sessions, usage_cb, "")

        assert websocket.send_json.call_args[0][0]["type"] == "error"

    @pytest.mark.asyncio
    async def test_json_array_not_object_returns_error(self) -> None:
        """Valid JSON that is not an object is still rejected."""
        websocket = _make_ws()
        graph = _make_graph()
        sessions = SessionManager()
        usage_cb = TokenUsageCallbackHandler()

        await dispatch_message(websocket, graph, sessions, usage_cb, '["not", "an", "object"]')

        assert websocket.send_json.call_args[0][0]["type"] == "error"
|
|
|
|
|
|
@pytest.mark.unit
class TestRateLimiting:
    """Per-thread rate limiting (10 messages per 10 seconds) in dispatch_message."""

    @pytest.mark.asyncio
    async def test_rapid_fire_messages_rate_limited(self) -> None:
        """The 11th rapid message on one thread must trigger a rate-limit error."""
        ws = _make_ws()
        sm = SessionManager()
        cb = TokenUsageCallbackHandler()
        sm.touch("t1")

        # Simulate 11 rapid messages (exceeds 10 per 10 seconds limit)
        rate_limit_triggered = False
        for i in range(11):
            graph = _make_graph()  # fresh graph each time so mocks are not reused
            await dispatch_message(ws, graph, sm, cb, json.dumps({
                "type": "message",
                "thread_id": "t1",
                "content": f"message {i}",
            }))
            last_call = ws.send_json.call_args[0][0]
            if last_call.get("type") == "error" and "rate" in last_call.get("message", "").lower():
                rate_limit_triggered = True
                break

        assert rate_limit_triggered, "Rate limiting should trigger after 10 rapid messages"

    @pytest.mark.asyncio
    async def test_different_threads_have_separate_rate_limits(self) -> None:
        """Five messages each on two threads stay under the per-thread limit.

        Fix: the original assertion only inspected the *last* send_json
        payload, so a rate-limit error emitted mid-run went undetected.
        Every recorded payload is now checked.
        """
        ws = _make_ws()
        sm = SessionManager()
        cb = TokenUsageCallbackHandler()
        sm.touch("t1")
        sm.touch("t2")

        # Send 5 messages on t1 and 5 on t2 -- neither should be rate limited
        for i in range(5):
            for thread_id in ("t1", "t2"):
                await dispatch_message(ws, _make_graph(), sm, cb, json.dumps({
                    "type": "message", "thread_id": thread_id, "content": f"msg {i}",
                }))

        # No payload on either thread may be a rate-limit error.
        for call in ws.send_json.call_args_list:
            payload = call[0][0]
            message = payload.get("message", "")
            assert not (
                payload.get("type") == "error" and "rate" in message.lower()
            ), f"unexpected rate-limit error: {payload}"
|