- Intent classification with LLM structured output (single/multi/ambiguous)
- Discount agent with apply_discount and generate_coupon tools
- Interrupt manager with 30-minute TTL auto-expiration and retry prompts
- Webhook escalation module with exponential backoff retry (max 3)
- Three vertical industry templates (e-commerce, SaaS, fintech)
- Template loading in AgentRegistry
- Enhanced supervisor prompt with dynamic agent descriptions
- 153 tests passing, 90.18% coverage
365 lines
12 KiB
Python
365 lines
12 KiB
Python
"""Tests for app.ws_handler module."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
|
|
import pytest
|
|
|
|
from app.callbacks import TokenUsageCallbackHandler
|
|
from app.interrupt_manager import InterruptManager
|
|
from app.session_manager import SessionManager
|
|
from app.ws_handler import (
|
|
_extract_interrupt,
|
|
_has_interrupt,
|
|
dispatch_message,
|
|
handle_interrupt_response,
|
|
handle_user_message,
|
|
)
|
|
|
|
|
|
def _make_ws() -> AsyncMock:
    """Return a fake websocket whose ``send_json`` is an awaitable mock.

    Tests inspect ``send_json.call_args`` / ``call_args_list`` on the result
    to see what the handler sent to the client.
    """
    return AsyncMock(send_json=AsyncMock())
|
|
|
|
|
|
def _make_graph() -> AsyncMock:
    """Return a mocked graph that streams nothing and reports no interrupt.

    * ``astream`` returns an async iterator over an empty list.
    * ``aget_state`` resolves to a state whose ``tasks`` tuple is empty.
    * ``intent_classifier`` and ``agent_registry`` are both ``None``.

    Individual tests override attributes on the returned mock when they need
    different behaviour.
    """
    idle_state = MagicMock()
    idle_state.tasks = ()

    return AsyncMock(
        astream=MagicMock(return_value=AsyncIterHelper([])),
        aget_state=AsyncMock(return_value=idle_state),
        # Phase 2: the handler expects these attributes to exist on the graph.
        intent_classifier=None,
        agent_registry=None,
    )
|
|
|
|
|
|
class AsyncIterHelper:
    """Async iterator that yields a fixed sequence of items exactly once.

    ``items`` may be any iterable (list, tuple, generator, ...). It is
    snapshotted at construction time. The caller's container is therefore
    never mutated, and later changes to it are not observed.
    """

    def __init__(self, items) -> None:
        # Copy up front instead of popping from the caller's list: popping
        # drained the argument in place and cost O(n) per element.
        self._iterator = iter(tuple(items))

    def __aiter__(self):
        return self

    async def __anext__(self):
        try:
            return next(self._iterator)
        except StopIteration:
            # Translate the sync end-of-iteration signal into the async one.
            # A StopIteration escaping a coroutine would become RuntimeError
            # (PEP 479).
            raise StopAsyncIteration from None
|
|
|
|
|
|
@pytest.mark.unit
class TestDispatchMessage:
    """Validation and routing performed by ``dispatch_message``."""

    @staticmethod
    async def _run(raw: str, sessions: SessionManager | None = None, **extra: object) -> AsyncMock:
        """Push *raw* through ``dispatch_message`` using fresh test doubles.

        A new ``SessionManager`` is created unless *sessions* is supplied, so a
        test can touch a thread beforehand. Extra keyword arguments are passed
        through to ``dispatch_message`` unchanged. Returns the fake websocket
        so the caller can inspect what was sent.
        """
        fake_ws = _make_ws()
        if sessions is None:
            sessions = SessionManager()
        await dispatch_message(
            fake_ws, _make_graph(), sessions, TokenUsageCallbackHandler(), raw, **extra
        )
        return fake_ws

    @staticmethod
    def _reply(fake_ws: AsyncMock) -> dict:
        """Return the payload of the most recent ``send_json`` call."""
        return fake_ws.send_json.call_args.args[0]

    @pytest.mark.asyncio
    async def test_invalid_json(self) -> None:
        fake_ws = await self._run("not json")

        fake_ws.send_json.assert_awaited_once()
        reply = self._reply(fake_ws)
        assert reply["type"] == "error"
        assert "Invalid JSON" in reply["message"]

    @pytest.mark.asyncio
    async def test_missing_thread_id(self) -> None:
        raw = json.dumps({"type": "message", "content": "hello"})

        reply = self._reply(await self._run(raw))

        assert reply["type"] == "error"
        assert "thread_id" in reply["message"]

    @pytest.mark.asyncio
    async def test_missing_content(self) -> None:
        raw = json.dumps({"type": "message", "thread_id": "t1"})

        reply = self._reply(await self._run(raw))

        assert reply["type"] == "error"

    @pytest.mark.asyncio
    async def test_unknown_message_type(self) -> None:
        raw = json.dumps({"type": "unknown", "thread_id": "t1"})

        reply = self._reply(await self._run(raw))

        assert reply["type"] == "error"
        assert "Unknown" in reply["message"]

    @pytest.mark.asyncio
    async def test_message_too_large(self) -> None:
        oversized = "x" * 40_000

        reply = self._reply(await self._run(oversized))

        assert reply["type"] == "error"
        assert "too large" in reply["message"].lower()

    @pytest.mark.asyncio
    async def test_invalid_thread_id_format(self) -> None:
        # A path-traversal-looking thread id must be rejected.
        raw = json.dumps({"type": "message", "thread_id": "../../../etc", "content": "hi"})

        reply = self._reply(await self._run(raw))

        assert reply["type"] == "error"
        assert "thread_id" in reply["message"].lower()

    @pytest.mark.asyncio
    async def test_content_too_long(self) -> None:
        raw = json.dumps({"type": "message", "thread_id": "t1", "content": "x" * 9000})

        reply = self._reply(await self._run(raw))

        assert reply["type"] == "error"
        assert "too long" in reply["message"].lower()

    @pytest.mark.asyncio
    async def test_dispatch_with_interrupt_manager(self) -> None:
        sessions = SessionManager()
        sessions.touch("t1")
        raw = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"})

        fake_ws = await self._run(raw, sessions, interrupt_manager=InterruptManager())

        assert self._reply(fake_ws)["type"] == "message_complete"
|
|
|
|
|
|
@pytest.mark.unit
class TestHandleUserMessage:
    """Behaviour of ``handle_user_message`` against mocked graphs."""

    @staticmethod
    def _payloads(fake_ws: AsyncMock) -> list[dict]:
        """Return every dict passed to ``fake_ws.send_json``, oldest first."""
        return [sent.args[0] for sent in fake_ws.send_json.call_args_list]

    @pytest.mark.asyncio
    async def test_expired_session(self) -> None:
        fake_ws = _make_ws()
        graph = _make_graph()
        # Zero-second TTL: the handler is expected to report the session as expired.
        sessions = SessionManager(session_ttl_seconds=0)
        callbacks = TokenUsageCallbackHandler()

        await handle_user_message(fake_ws, graph, sessions, callbacks, "t1", "hello")

        reply = fake_ws.send_json.call_args.args[0]
        assert reply["type"] == "error"
        assert "expired" in reply["message"].lower()

    @pytest.mark.asyncio
    async def test_successful_message(self) -> None:
        fake_ws = _make_ws()
        graph = _make_graph()
        sessions = SessionManager()
        callbacks = TokenUsageCallbackHandler()
        sessions.touch("t1")

        await handle_user_message(fake_ws, graph, sessions, callbacks, "t1", "hello")

        assert fake_ws.send_json.call_args.args[0]["type"] == "message_complete"

    @pytest.mark.asyncio
    async def test_graph_error_sends_error_message(self) -> None:
        fake_ws = _make_ws()
        # Hand-built rather than via _make_graph: calling astream raises
        # RuntimeError, and nothing else is preconfigured.
        graph = AsyncMock(
            astream=MagicMock(side_effect=RuntimeError("boom")),
            intent_classifier=None,
            agent_registry=None,
        )
        sessions = SessionManager()
        callbacks = TokenUsageCallbackHandler()
        sessions.touch("t1")

        await handle_user_message(fake_ws, graph, sessions, callbacks, "t1", "hello")

        assert fake_ws.send_json.call_args.args[0]["type"] == "error"

    @pytest.mark.asyncio
    async def test_interrupt_registered_with_manager(self) -> None:
        fake_ws = _make_ws()

        # Fake state: a single task carrying a single interrupt payload.
        pending = MagicMock(value={"action": "cancel_order", "order_id": "1042"})
        paused_state = MagicMock(tasks=(MagicMock(interrupts=(pending,)),))
        graph = _make_graph()
        graph.aget_state = AsyncMock(return_value=paused_state)

        sessions = SessionManager()
        callbacks = TokenUsageCallbackHandler()
        interrupts = InterruptManager()
        sessions.touch("t1")

        await handle_user_message(
            fake_ws, graph, sessions, callbacks, "t1", "cancel order 1042", interrupt_manager=interrupts
        )

        # The interrupt is recorded with the manager ...
        assert interrupts.has_pending("t1")
        # ... and surfaced to the client exactly once.
        interrupt_msgs = [p for p in self._payloads(fake_ws) if p.get("type") == "interrupt"]
        assert len(interrupt_msgs) == 1

    @pytest.mark.asyncio
    async def test_ambiguous_intent_sends_clarification(self) -> None:
        from app.intent import ClassificationResult

        fake_ws = _make_ws()

        # The classifier returns the same ambiguous result, with a follow-up question, on every call.
        classifier = AsyncMock()
        classifier.classify = AsyncMock(
            return_value=ClassificationResult(
                intents=(),
                is_ambiguous=True,
                clarification_question="What do you mean?",
            )
        )
        graph = _make_graph()
        graph.intent_classifier = classifier
        graph.agent_registry = MagicMock(list_agents=MagicMock(return_value=()))

        sessions = SessionManager()
        callbacks = TokenUsageCallbackHandler()
        sessions.touch("t1")

        await handle_user_message(fake_ws, graph, sessions, callbacks, "t1", "hmm")

        clarifications = [p for p in self._payloads(fake_ws) if p.get("type") == "clarification"]
        assert len(clarifications) == 1
        assert clarifications[0]["message"] == "What do you mean?"
|
|
|
|
|
|
@pytest.mark.unit
class TestHandleInterruptResponse:
    """Resuming a thread through ``handle_interrupt_response``."""

    @staticmethod
    def _paused_sessions() -> SessionManager:
        """Return a fresh manager with ``touch`` and ``extend_for_interrupt`` applied to ``t1``."""
        sessions = SessionManager()
        sessions.touch("t1")
        sessions.extend_for_interrupt("t1")
        return sessions

    @pytest.mark.asyncio
    async def test_approved_interrupt(self) -> None:
        fake_ws = _make_ws()

        await handle_interrupt_response(
            fake_ws, _make_graph(), self._paused_sessions(), TokenUsageCallbackHandler(), "t1", True
        )

        assert fake_ws.send_json.call_args.args[0]["type"] == "message_complete"

    @pytest.mark.asyncio
    async def test_expired_interrupt_sends_retry_prompt(self) -> None:
        from unittest.mock import patch

        fake_ws = _make_ws()
        graph = _make_graph()
        sessions = self._paused_sessions()
        callbacks = TokenUsageCallbackHandler()
        interrupts = InterruptManager(ttl_seconds=5)
        interrupts.register("t1", "cancel_order", {"order_id": "1042"})

        # Patch the ``time`` module as seen by app.interrupt_manager so that
        # ``time.time()`` reports 10 s after creation, past the 5 s TTL.
        with patch("app.interrupt_manager.time") as fake_clock:
            fake_clock.time.return_value = interrupts._interrupts["t1"].created_at + 10
            await handle_interrupt_response(
                fake_ws, graph, sessions, callbacks, "t1", True, interrupt_manager=interrupts
            )

        reply = fake_ws.send_json.call_args.args[0]
        assert reply["type"] == "interrupt_expired"
        assert "cancel_order" in reply["message"]

    @pytest.mark.asyncio
    async def test_valid_interrupt_resolves(self) -> None:
        fake_ws = _make_ws()
        graph = _make_graph()
        sessions = self._paused_sessions()
        callbacks = TokenUsageCallbackHandler()
        interrupts = InterruptManager(ttl_seconds=1800)
        interrupts.register("t1", "cancel_order", {})

        await handle_interrupt_response(
            fake_ws, graph, sessions, callbacks, "t1", True, interrupt_manager=interrupts
        )

        # Answering within the TTL clears the pending interrupt.
        assert not interrupts.has_pending("t1")
        assert fake_ws.send_json.call_args.args[0]["type"] == "message_complete"
|
|
|
|
|
|
@pytest.mark.unit
class TestInterruptHelpers:
    """Unit tests for the ``_has_interrupt`` and ``_extract_interrupt`` helpers."""

    @staticmethod
    def _state(value: dict | None = None) -> MagicMock:
        """Build a fake graph state.

        With no *value*, ``tasks`` is empty. Otherwise ``tasks`` holds a single
        task whose ``interrupts`` tuple contains one interrupt carrying *value*.
        """
        state = MagicMock()
        if value is None:
            state.tasks = ()
        else:
            pending = MagicMock(value=value)
            state.tasks = (MagicMock(interrupts=(pending,)),)
        return state

    def test_has_interrupt_false_for_empty_tasks(self) -> None:
        assert not _has_interrupt(self._state())

    def test_has_interrupt_true(self) -> None:
        assert _has_interrupt(self._state({"action": "cancel"}))

    def test_extract_interrupt_data(self) -> None:
        extracted = _extract_interrupt(self._state({"action": "cancel_order", "order_id": "1042"}))
        assert extracted["action"] == "cancel_order"

    def test_extract_interrupt_empty(self) -> None:
        extracted = _extract_interrupt(self._state())
        assert extracted["action"] == "unknown"
|