feat: complete phase 1 -- core framework with chat loop, agents, and React UI
Backend: - FastAPI WebSocket /ws endpoint with streaming via LangGraph astream - LangGraph Supervisor connecting 3 mock agents (order_lookup, order_actions, fallback) - YAML Agent Registry with Pydantic validation and immutable configs - PostgresSaver checkpoint persistence via langgraph-checkpoint-postgres - Session TTL with 30-min sliding window and interrupt extension - LLM provider abstraction (Anthropic/OpenAI/Google) - Token usage + cost tracking callback handler - Input validation: message size cap, thread_id format, content length - Security: no hardcoded defaults, startup API key validation, no input reflection Frontend: - React 19 + TypeScript + Vite chat UI - WebSocket hook with reconnect + exponential backoff - Streaming token display with agent attribution - Interrupt approval/reject UI for write operations - Collapsible tool call viewer Testing: - 87 unit tests, 87% coverage (exceeds 80% requirement) - Ruff lint + format clean Infrastructure: - Docker Compose (PostgreSQL 16 + backend) - pyproject.toml with full dependency management
This commit is contained in:
233
backend/tests/unit/test_ws_handler.py
Normal file
233
backend/tests/unit/test_ws_handler.py
Normal file
@@ -0,0 +1,233 @@
|
||||
"""Tests for app.ws_handler module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from app.callbacks import TokenUsageCallbackHandler
|
||||
from app.session_manager import SessionManager
|
||||
from app.ws_handler import (
|
||||
_extract_interrupt,
|
||||
_has_interrupt,
|
||||
dispatch_message,
|
||||
handle_interrupt_response,
|
||||
handle_user_message,
|
||||
)
|
||||
|
||||
|
||||
def _make_ws() -> AsyncMock:
|
||||
ws = AsyncMock()
|
||||
ws.send_json = AsyncMock()
|
||||
return ws
|
||||
|
||||
|
||||
def _make_graph() -> AsyncMock:
|
||||
graph = AsyncMock()
|
||||
graph.astream = MagicMock(return_value=AsyncIterHelper([]))
|
||||
state = MagicMock()
|
||||
state.tasks = ()
|
||||
graph.aget_state = AsyncMock(return_value=state)
|
||||
return graph
|
||||
|
||||
|
||||
class AsyncIterHelper:
|
||||
"""Helper to make a list behave as an async iterator."""
|
||||
|
||||
def __init__(self, items):
|
||||
self._items = items
|
||||
|
||||
def __aiter__(self):
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
if not self._items:
|
||||
raise StopAsyncIteration
|
||||
return self._items.pop(0)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestDispatchMessage:
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_json(self) -> None:
|
||||
ws = _make_ws()
|
||||
graph = _make_graph()
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
|
||||
await dispatch_message(ws, graph, sm, cb, "not json")
|
||||
ws.send_json.assert_awaited_once()
|
||||
call_data = ws.send_json.call_args[0][0]
|
||||
assert call_data["type"] == "error"
|
||||
assert "Invalid JSON" in call_data["message"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_thread_id(self) -> None:
|
||||
ws = _make_ws()
|
||||
graph = _make_graph()
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
|
||||
msg = json.dumps({"type": "message", "content": "hello"})
|
||||
await dispatch_message(ws, graph, sm, cb, msg)
|
||||
call_data = ws.send_json.call_args[0][0]
|
||||
assert call_data["type"] == "error"
|
||||
assert "thread_id" in call_data["message"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_content(self) -> None:
|
||||
ws = _make_ws()
|
||||
graph = _make_graph()
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
|
||||
msg = json.dumps({"type": "message", "thread_id": "t1"})
|
||||
await dispatch_message(ws, graph, sm, cb, msg)
|
||||
call_data = ws.send_json.call_args[0][0]
|
||||
assert call_data["type"] == "error"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unknown_message_type(self) -> None:
|
||||
ws = _make_ws()
|
||||
graph = _make_graph()
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
|
||||
msg = json.dumps({"type": "unknown", "thread_id": "t1"})
|
||||
await dispatch_message(ws, graph, sm, cb, msg)
|
||||
call_data = ws.send_json.call_args[0][0]
|
||||
assert call_data["type"] == "error"
|
||||
assert "Unknown" in call_data["message"]
|
||||
# Verify raw input is NOT reflected back
|
||||
assert "unknown" not in call_data["message"].lower().replace("unknown message type", "")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_too_large(self) -> None:
|
||||
ws = _make_ws()
|
||||
graph = _make_graph()
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
|
||||
large_msg = "x" * 40_000
|
||||
await dispatch_message(ws, graph, sm, cb, large_msg)
|
||||
call_data = ws.send_json.call_args[0][0]
|
||||
assert call_data["type"] == "error"
|
||||
assert "too large" in call_data["message"].lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_thread_id_format(self) -> None:
|
||||
ws = _make_ws()
|
||||
graph = _make_graph()
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
|
||||
msg = json.dumps({"type": "message", "thread_id": "../../../etc", "content": "hi"})
|
||||
await dispatch_message(ws, graph, sm, cb, msg)
|
||||
call_data = ws.send_json.call_args[0][0]
|
||||
assert call_data["type"] == "error"
|
||||
assert "thread_id" in call_data["message"].lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_content_too_long(self) -> None:
|
||||
ws = _make_ws()
|
||||
graph = _make_graph()
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
|
||||
msg = json.dumps({"type": "message", "thread_id": "t1", "content": "x" * 9000})
|
||||
await dispatch_message(ws, graph, sm, cb, msg)
|
||||
call_data = ws.send_json.call_args[0][0]
|
||||
assert call_data["type"] == "error"
|
||||
assert "too long" in call_data["message"].lower()
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestHandleUserMessage:
|
||||
@pytest.mark.asyncio
|
||||
async def test_expired_session(self) -> None:
|
||||
ws = _make_ws()
|
||||
graph = _make_graph()
|
||||
sm = SessionManager(session_ttl_seconds=0)
|
||||
cb = TokenUsageCallbackHandler()
|
||||
|
||||
await handle_user_message(ws, graph, sm, cb, "t1", "hello")
|
||||
call_data = ws.send_json.call_args[0][0]
|
||||
assert call_data["type"] == "error"
|
||||
assert "expired" in call_data["message"].lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_successful_message(self) -> None:
|
||||
ws = _make_ws()
|
||||
graph = _make_graph()
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
|
||||
sm.touch("t1")
|
||||
await handle_user_message(ws, graph, sm, cb, "t1", "hello")
|
||||
# Should end with message_complete
|
||||
last_call = ws.send_json.call_args[0][0]
|
||||
assert last_call["type"] == "message_complete"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graph_error_sends_error_message(self) -> None:
|
||||
ws = _make_ws()
|
||||
graph = AsyncMock()
|
||||
graph.astream = MagicMock(side_effect=RuntimeError("boom"))
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
|
||||
sm.touch("t1")
|
||||
await handle_user_message(ws, graph, sm, cb, "t1", "hello")
|
||||
call_data = ws.send_json.call_args[0][0]
|
||||
assert call_data["type"] == "error"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestHandleInterruptResponse:
|
||||
@pytest.mark.asyncio
|
||||
async def test_approved_interrupt(self) -> None:
|
||||
ws = _make_ws()
|
||||
graph = _make_graph()
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
|
||||
sm.touch("t1")
|
||||
sm.extend_for_interrupt("t1")
|
||||
await handle_interrupt_response(ws, graph, sm, cb, "t1", True)
|
||||
last_call = ws.send_json.call_args[0][0]
|
||||
assert last_call["type"] == "message_complete"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestInterruptHelpers:
|
||||
def test_has_interrupt_false_for_empty_tasks(self) -> None:
|
||||
state = MagicMock()
|
||||
state.tasks = ()
|
||||
assert not _has_interrupt(state)
|
||||
|
||||
def test_has_interrupt_true(self) -> None:
|
||||
interrupt_obj = MagicMock()
|
||||
interrupt_obj.value = {"action": "cancel"}
|
||||
task = MagicMock()
|
||||
task.interrupts = (interrupt_obj,)
|
||||
state = MagicMock()
|
||||
state.tasks = (task,)
|
||||
assert _has_interrupt(state)
|
||||
|
||||
def test_extract_interrupt_data(self) -> None:
|
||||
interrupt_obj = MagicMock()
|
||||
interrupt_obj.value = {"action": "cancel_order", "order_id": "1042"}
|
||||
task = MagicMock()
|
||||
task.interrupts = (interrupt_obj,)
|
||||
state = MagicMock()
|
||||
state.tasks = (task,)
|
||||
data = _extract_interrupt(state)
|
||||
assert data["action"] == "cancel_order"
|
||||
|
||||
def test_extract_interrupt_empty(self) -> None:
|
||||
state = MagicMock()
|
||||
state.tasks = ()
|
||||
data = _extract_interrupt(state)
|
||||
assert data["action"] == "unknown"
|
||||
Reference in New Issue
Block a user