refactor: fix architectural issues across frontend and backend

Address all architecture review findings:

P0 fixes:
- Add API key authentication for admin endpoints (analytics, replay, openapi)
  and WebSocket connections via ADMIN_API_KEY env var
- Add PostgreSQL-backed PgSessionManager and PgInterruptManager for
  multi-worker production deployments (in-memory defaults preserved)

P1 fixes:
- Implement actual tool generation in OpenAPI approve_job endpoint
  using generate_tool_code() and generate_agent_yaml()
- Add missing clarification, interrupt_expired, and tool_result message
  handlers in frontend ChatPage

P2 fixes:
- Replace monkey-patching on CompiledStateGraph with typed GraphContext
- Replace 9-param dispatch_message with WebSocketContext dataclass
- Extract duplicate _envelope() into shared app/api_utils.py
- Replace mutable module-level counter with crypto.randomUUID()
- Remove hardcoded mock data from ReviewPage, use api.ts wrappers
- Remove `as any` type escape from ReplayPage

All 516 tests passing, 0 TypeScript errors.
This commit is contained in:
Yaojia Wang
2026-04-06 15:59:14 +02:00
parent b8654aa31f
commit af53111928
29 changed files with 1183 additions and 473 deletions

View File

@@ -55,7 +55,7 @@ class TestDbModule:
from app.db import setup_app_tables
await setup_app_tables(mock_pool)
assert mock_conn.execute.await_count == 4
assert mock_conn.execute.await_count == 5
def test_ddl_statements_valid(self) -> None:
assert "CREATE TABLE IF NOT EXISTS conversations" in _CONVERSATIONS_DDL

View File

@@ -51,5 +51,5 @@ class TestAnalyticsEventsDDL:
from app.db import setup_app_tables
await setup_app_tables(mock_pool)
# Now expects 4 statements: conversations, interrupts, analytics_events, migrations
assert mock_conn.execute.await_count == 4
# Now expects 5 statements: conversations, interrupts, sessions, analytics_events, migrations
assert mock_conn.execute.await_count == 5

View File

@@ -8,7 +8,9 @@ from unittest.mock import AsyncMock, MagicMock
import pytest
from app.callbacks import TokenUsageCallbackHandler
from app.graph_context import GraphContext
from app.session_manager import SessionManager
from app.ws_context import WebSocketContext
from app.ws_handler import dispatch_message
pytestmark = pytest.mark.unit
@@ -20,7 +22,7 @@ def _make_ws() -> AsyncMock:
return ws
def _make_graph() -> AsyncMock:
def _make_graph() -> MagicMock:
graph = AsyncMock()
class AsyncIterHelper:
@@ -34,23 +36,32 @@ def _make_graph() -> AsyncMock:
state = MagicMock()
state.tasks = ()
graph.aget_state = AsyncMock(return_value=state)
graph.intent_classifier = None
graph.agent_registry = None
return graph
def _make_ws_ctx(sm: SessionManager | None = None) -> WebSocketContext:
graph = _make_graph()
registry = MagicMock()
registry.list_agents = MagicMock(return_value=())
graph_ctx = GraphContext(graph=graph, registry=registry, intent_classifier=None)
return WebSocketContext(
graph_ctx=graph_ctx,
session_manager=sm or SessionManager(),
callback_handler=TokenUsageCallbackHandler(),
)
@pytest.mark.unit
class TestEmptyMessageHandling:
@pytest.mark.asyncio
async def test_empty_message_content_returns_error(self) -> None:
ws = _make_ws()
graph = _make_graph()
sm = SessionManager()
cb = TokenUsageCallbackHandler()
ws_ctx = _make_ws_ctx(sm=sm)
sm.touch("t1")
msg = json.dumps({"type": "message", "thread_id": "t1", "content": ""})
await dispatch_message(ws, graph, sm, cb, msg)
await dispatch_message(ws, ws_ctx, msg)
call_data = ws.send_json.call_args[0][0]
assert call_data["type"] == "error"
@@ -60,13 +71,12 @@ class TestEmptyMessageHandling:
@pytest.mark.asyncio
async def test_whitespace_only_message_treated_as_empty(self) -> None:
ws = _make_ws()
graph = _make_graph()
sm = SessionManager()
cb = TokenUsageCallbackHandler()
ws_ctx = _make_ws_ctx(sm=sm)
sm.touch("t1")
msg = json.dumps({"type": "message", "thread_id": "t1", "content": " "})
await dispatch_message(ws, graph, sm, cb, msg)
await dispatch_message(ws, ws_ctx, msg)
call_data = ws.send_json.call_args[0][0]
assert call_data["type"] == "error"
@@ -77,14 +87,13 @@ class TestOversizedMessageHandling:
@pytest.mark.asyncio
async def test_content_over_10000_chars_returns_error(self) -> None:
ws = _make_ws()
graph = _make_graph()
sm = SessionManager()
cb = TokenUsageCallbackHandler()
ws_ctx = _make_ws_ctx(sm=sm)
sm.touch("t1")
content = "x" * 10001
msg = json.dumps({"type": "message", "thread_id": "t1", "content": content})
await dispatch_message(ws, graph, sm, cb, msg)
await dispatch_message(ws, ws_ctx, msg)
call_data = ws.send_json.call_args[0][0]
assert call_data["type"] == "error"
@@ -93,14 +102,13 @@ class TestOversizedMessageHandling:
@pytest.mark.asyncio
async def test_content_exactly_10000_chars_is_accepted(self) -> None:
ws = _make_ws()
graph = _make_graph()
sm = SessionManager()
cb = TokenUsageCallbackHandler()
ws_ctx = _make_ws_ctx(sm=sm)
sm.touch("t1")
content = "x" * 10000
msg = json.dumps({"type": "message", "thread_id": "t1", "content": content})
await dispatch_message(ws, graph, sm, cb, msg)
await dispatch_message(ws, ws_ctx, msg)
last_call = ws.send_json.call_args[0][0]
# Should be processed, not an error about length
@@ -110,12 +118,10 @@ class TestOversizedMessageHandling:
@pytest.mark.asyncio
async def test_raw_message_over_32kb_returns_error(self) -> None:
ws = _make_ws()
graph = _make_graph()
sm = SessionManager()
cb = TokenUsageCallbackHandler()
ws_ctx = _make_ws_ctx()
large_msg = "x" * 40_000
await dispatch_message(ws, graph, sm, cb, large_msg)
await dispatch_message(ws, ws_ctx, large_msg)
call_data = ws.send_json.call_args[0][0]
assert call_data["type"] == "error"
@@ -127,11 +133,9 @@ class TestInvalidJsonHandling:
@pytest.mark.asyncio
async def test_invalid_json_returns_error(self) -> None:
ws = _make_ws()
graph = _make_graph()
sm = SessionManager()
cb = TokenUsageCallbackHandler()
ws_ctx = _make_ws_ctx()
await dispatch_message(ws, graph, sm, cb, "not valid json {{")
await dispatch_message(ws, ws_ctx, "not valid json {{")
call_data = ws.send_json.call_args[0][0]
assert call_data["type"] == "error"
@@ -140,11 +144,9 @@ class TestInvalidJsonHandling:
@pytest.mark.asyncio
async def test_empty_string_returns_json_error(self) -> None:
ws = _make_ws()
graph = _make_graph()
sm = SessionManager()
cb = TokenUsageCallbackHandler()
ws_ctx = _make_ws_ctx()
await dispatch_message(ws, graph, sm, cb, "")
await dispatch_message(ws, ws_ctx, "")
call_data = ws.send_json.call_args[0][0]
assert call_data["type"] == "error"
@@ -152,11 +154,9 @@ class TestInvalidJsonHandling:
@pytest.mark.asyncio
async def test_json_array_not_object_returns_error(self) -> None:
ws = _make_ws()
graph = _make_graph()
sm = SessionManager()
cb = TokenUsageCallbackHandler()
ws_ctx = _make_ws_ctx()
await dispatch_message(ws, graph, sm, cb, '["not", "an", "object"]')
await dispatch_message(ws, ws_ctx, '["not", "an", "object"]')
call_data = ws.send_json.call_args[0][0]
assert call_data["type"] == "error"
@@ -167,17 +167,15 @@ class TestRateLimiting:
@pytest.mark.asyncio
async def test_rapid_fire_messages_rate_limited(self) -> None:
ws = _make_ws()
_make_graph() # ensure graph factory works, not needed directly
sm = SessionManager()
cb = TokenUsageCallbackHandler()
sm.touch("t1")
# Simulate 11 rapid messages (exceeds 10 per 10 seconds limit)
rate_limit_triggered = False
for i in range(11):
graph2 = _make_graph() # fresh graph each time
await dispatch_message(ws, graph2, sm, cb, json.dumps({
ws_ctx = _make_ws_ctx(sm=sm)
await dispatch_message(ws, ws_ctx, json.dumps({
"type": "message",
"thread_id": "t1",
"content": f"message {i}",
@@ -193,19 +191,18 @@ class TestRateLimiting:
async def test_different_threads_have_separate_rate_limits(self) -> None:
ws = _make_ws()
sm = SessionManager()
cb = TokenUsageCallbackHandler()
sm.touch("t1")
sm.touch("t2")
# Send 5 messages on t1 and 5 on t2 -- neither should be rate limited
for i in range(5):
graph1 = _make_graph()
graph2 = _make_graph()
await dispatch_message(ws, graph1, sm, cb, json.dumps({
ws_ctx1 = _make_ws_ctx(sm=sm)
ws_ctx2 = _make_ws_ctx(sm=sm)
await dispatch_message(ws, ws_ctx1, json.dumps({
"type": "message", "thread_id": "t1", "content": f"msg {i}",
}))
await dispatch_message(ws, graph2, sm, cb, json.dumps({
await dispatch_message(ws, ws_ctx2, json.dumps({
"type": "message", "thread_id": "t2", "content": f"msg {i}",
}))

View File

@@ -8,7 +8,8 @@ from unittest.mock import AsyncMock, MagicMock
import pytest
from langgraph.checkpoint.memory import InMemorySaver
from app.graph import build_agent_nodes, build_graph, classify_intent
from app.graph import build_agent_nodes, build_graph
from app.graph_context import GraphContext
from app.intent import ClassificationResult, IntentTarget
if TYPE_CHECKING:
@@ -37,8 +38,9 @@ class TestBuildGraph:
mock_llm.with_structured_output = MagicMock(return_value=mock_llm)
checkpointer = InMemorySaver()
graph = build_graph(sample_registry, mock_llm, checkpointer)
assert graph is not None
graph_ctx = build_graph(sample_registry, mock_llm, checkpointer)
assert graph_ctx is not None
assert graph_ctx.graph is not None
def test_graph_has_classifier_attached(self, sample_registry: AgentRegistry) -> None:
mock_llm = MagicMock()
@@ -47,11 +49,11 @@ class TestBuildGraph:
checkpointer = InMemorySaver()
mock_classifier = MagicMock()
graph = build_graph(
graph_ctx = build_graph(
sample_registry, mock_llm, checkpointer, intent_classifier=mock_classifier
)
assert graph.intent_classifier is mock_classifier
assert graph.agent_registry is sample_registry
assert graph_ctx.intent_classifier is mock_classifier
assert graph_ctx.registry is sample_registry
def test_graph_without_classifier(self, sample_registry: AgentRegistry) -> None:
mock_llm = MagicMock()
@@ -59,17 +61,18 @@ class TestBuildGraph:
mock_llm.with_structured_output = MagicMock(return_value=mock_llm)
checkpointer = InMemorySaver()
graph = build_graph(sample_registry, mock_llm, checkpointer)
assert graph.intent_classifier is None
graph_ctx = build_graph(sample_registry, mock_llm, checkpointer)
assert graph_ctx.intent_classifier is None
@pytest.mark.unit
class TestClassifyIntent:
@pytest.mark.asyncio
async def test_returns_none_without_classifier(self) -> None:
graph = MagicMock()
graph.intent_classifier = None
result = await classify_intent(graph, "hello")
mock_registry = MagicMock()
mock_registry.list_agents = MagicMock(return_value=())
graph_ctx = GraphContext(graph=MagicMock(), registry=mock_registry, intent_classifier=None)
result = await graph_ctx.classify_intent("hello")
assert result is None
@pytest.mark.asyncio
@@ -80,11 +83,12 @@ class TestClassifyIntent:
mock_classifier = AsyncMock()
mock_classifier.classify = AsyncMock(return_value=expected)
graph = MagicMock()
graph.intent_classifier = mock_classifier
graph.agent_registry = MagicMock()
graph.agent_registry.list_agents = MagicMock(return_value=())
mock_registry = MagicMock()
mock_registry.list_agents = MagicMock(return_value=())
graph_ctx = GraphContext(
graph=MagicMock(), registry=mock_registry, intent_classifier=mock_classifier,
)
result = await classify_intent(graph, "check order")
result = await graph_ctx.classify_intent("check order")
assert result is not None
assert result.intents[0].agent_name == "order_lookup"

View File

@@ -13,7 +13,7 @@ class TestMainModule:
assert app.title == "Smart Support"
def test_app_version(self) -> None:
assert app.version == "0.5.0"
assert app.version == "0.6.0"
def test_agents_yaml_path_exists(self) -> None:
assert AGENTS_YAML.name == "agents.yaml"
@@ -39,4 +39,4 @@ class TestMainModule:
assert "/api/health" in routes
def test_app_version_is_0_5_0(self) -> None:
assert app.version == "0.5.0"
assert app.version == "0.6.0"

View File

@@ -8,8 +8,10 @@ from unittest.mock import AsyncMock, MagicMock
import pytest
from app.callbacks import TokenUsageCallbackHandler
from app.graph_context import GraphContext
from app.interrupt_manager import InterruptManager
from app.session_manager import SessionManager
from app.ws_context import WebSocketContext
from app.ws_handler import (
_extract_interrupt,
_has_interrupt,
@@ -25,18 +27,42 @@ def _make_ws() -> AsyncMock:
return ws
def _make_graph() -> AsyncMock:
def _make_graph() -> MagicMock:
graph = AsyncMock()
graph.astream = MagicMock(return_value=AsyncIterHelper([]))
state = MagicMock()
state.tasks = ()
graph.aget_state = AsyncMock(return_value=state)
# Phase 2: graph needs intent_classifier and agent_registry attrs
graph.intent_classifier = None
graph.agent_registry = None
return graph
def _make_graph_ctx(graph: MagicMock | None = None) -> GraphContext:
g = graph or _make_graph()
registry = MagicMock()
registry.list_agents = MagicMock(return_value=())
return GraphContext(graph=g, registry=registry, intent_classifier=None)
def _make_ws_ctx(
graph_ctx: GraphContext | None = None,
sm: SessionManager | None = None,
cb: TokenUsageCallbackHandler | None = None,
interrupt_manager: InterruptManager | None = None,
analytics_recorder=None,
conversation_tracker=None,
pool=None,
) -> WebSocketContext:
return WebSocketContext(
graph_ctx=graph_ctx or _make_graph_ctx(),
session_manager=sm or SessionManager(),
callback_handler=cb or TokenUsageCallbackHandler(),
interrupt_manager=interrupt_manager,
analytics_recorder=analytics_recorder,
conversation_tracker=conversation_tracker,
pool=pool,
)
class AsyncIterHelper:
"""Helper to make a list behave as an async iterator."""
@@ -57,11 +83,9 @@ class TestDispatchMessage:
@pytest.mark.asyncio
async def test_invalid_json(self) -> None:
ws = _make_ws()
graph = _make_graph()
sm = SessionManager()
cb = TokenUsageCallbackHandler()
ws_ctx = _make_ws_ctx()
await dispatch_message(ws, graph, sm, cb, "not json")
await dispatch_message(ws, ws_ctx, "not json")
ws.send_json.assert_awaited_once()
call_data = ws.send_json.call_args[0][0]
assert call_data["type"] == "error"
@@ -70,12 +94,10 @@ class TestDispatchMessage:
@pytest.mark.asyncio
async def test_missing_thread_id(self) -> None:
ws = _make_ws()
graph = _make_graph()
sm = SessionManager()
cb = TokenUsageCallbackHandler()
ws_ctx = _make_ws_ctx()
msg = json.dumps({"type": "message", "content": "hello"})
await dispatch_message(ws, graph, sm, cb, msg)
await dispatch_message(ws, ws_ctx, msg)
call_data = ws.send_json.call_args[0][0]
assert call_data["type"] == "error"
assert "thread_id" in call_data["message"]
@@ -83,24 +105,20 @@ class TestDispatchMessage:
@pytest.mark.asyncio
async def test_missing_content(self) -> None:
ws = _make_ws()
graph = _make_graph()
sm = SessionManager()
cb = TokenUsageCallbackHandler()
ws_ctx = _make_ws_ctx()
msg = json.dumps({"type": "message", "thread_id": "t1"})
await dispatch_message(ws, graph, sm, cb, msg)
await dispatch_message(ws, ws_ctx, msg)
call_data = ws.send_json.call_args[0][0]
assert call_data["type"] == "error"
@pytest.mark.asyncio
async def test_unknown_message_type(self) -> None:
ws = _make_ws()
graph = _make_graph()
sm = SessionManager()
cb = TokenUsageCallbackHandler()
ws_ctx = _make_ws_ctx()
msg = json.dumps({"type": "unknown", "thread_id": "t1"})
await dispatch_message(ws, graph, sm, cb, msg)
await dispatch_message(ws, ws_ctx, msg)
call_data = ws.send_json.call_args[0][0]
assert call_data["type"] == "error"
assert "Unknown" in call_data["message"]
@@ -108,12 +126,10 @@ class TestDispatchMessage:
@pytest.mark.asyncio
async def test_message_too_large(self) -> None:
ws = _make_ws()
graph = _make_graph()
sm = SessionManager()
cb = TokenUsageCallbackHandler()
ws_ctx = _make_ws_ctx()
large_msg = "x" * 40_000
await dispatch_message(ws, graph, sm, cb, large_msg)
await dispatch_message(ws, ws_ctx, large_msg)
call_data = ws.send_json.call_args[0][0]
assert call_data["type"] == "error"
assert "too large" in call_data["message"].lower()
@@ -121,12 +137,10 @@ class TestDispatchMessage:
@pytest.mark.asyncio
async def test_invalid_thread_id_format(self) -> None:
ws = _make_ws()
graph = _make_graph()
sm = SessionManager()
cb = TokenUsageCallbackHandler()
ws_ctx = _make_ws_ctx()
msg = json.dumps({"type": "message", "thread_id": "../../../etc", "content": "hi"})
await dispatch_message(ws, graph, sm, cb, msg)
await dispatch_message(ws, ws_ctx, msg)
call_data = ws.send_json.call_args[0][0]
assert call_data["type"] == "error"
assert "thread_id" in call_data["message"].lower()
@@ -134,12 +148,10 @@ class TestDispatchMessage:
@pytest.mark.asyncio
async def test_content_too_long(self) -> None:
ws = _make_ws()
graph = _make_graph()
sm = SessionManager()
cb = TokenUsageCallbackHandler()
ws_ctx = _make_ws_ctx()
msg = json.dumps({"type": "message", "thread_id": "t1", "content": "x" * 10001})
await dispatch_message(ws, graph, sm, cb, msg)
await dispatch_message(ws, ws_ctx, msg)
call_data = ws.send_json.call_args[0][0]
assert call_data["type"] == "error"
assert "too long" in call_data["message"].lower()
@@ -147,14 +159,13 @@ class TestDispatchMessage:
@pytest.mark.asyncio
async def test_dispatch_with_interrupt_manager(self) -> None:
ws = _make_ws()
graph = _make_graph()
sm = SessionManager()
cb = TokenUsageCallbackHandler()
im = InterruptManager()
ws_ctx = _make_ws_ctx(sm=sm, interrupt_manager=im)
sm.touch("t1")
msg = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"})
await dispatch_message(ws, graph, sm, cb, msg, interrupt_manager=im)
await dispatch_message(ws, ws_ctx, msg)
last_call = ws.send_json.call_args[0][0]
assert last_call["type"] == "message_complete"
@@ -164,14 +175,14 @@ class TestHandleUserMessage:
@pytest.mark.asyncio
async def test_expired_session(self) -> None:
ws = _make_ws()
graph = _make_graph()
graph_ctx = _make_graph_ctx()
sm = SessionManager(session_ttl_seconds=0)
cb = TokenUsageCallbackHandler()
# First call creates the session (TTL=0)
await handle_user_message(ws, graph, sm, cb, "t1", "hello")
await handle_user_message(ws, graph_ctx, sm, cb, "t1", "hello")
# Second call finds it expired
await handle_user_message(ws, graph, sm, cb, "t1", "hello again")
await handle_user_message(ws, graph_ctx, sm, cb, "t1", "hello again")
call_data = ws.send_json.call_args[0][0]
assert call_data["type"] == "error"
assert "expired" in call_data["message"].lower()
@@ -179,12 +190,12 @@ class TestHandleUserMessage:
@pytest.mark.asyncio
async def test_successful_message(self) -> None:
ws = _make_ws()
graph = _make_graph()
graph_ctx = _make_graph_ctx()
sm = SessionManager()
cb = TokenUsageCallbackHandler()
sm.touch("t1")
await handle_user_message(ws, graph, sm, cb, "t1", "hello")
await handle_user_message(ws, graph_ctx, sm, cb, "t1", "hello")
last_call = ws.send_json.call_args[0][0]
assert last_call["type"] == "message_complete"
@@ -193,13 +204,12 @@ class TestHandleUserMessage:
ws = _make_ws()
graph = AsyncMock()
graph.astream = MagicMock(side_effect=RuntimeError("boom"))
graph.intent_classifier = None
graph.agent_registry = None
graph_ctx = _make_graph_ctx(graph=graph)
sm = SessionManager()
cb = TokenUsageCallbackHandler()
sm.touch("t1")
await handle_user_message(ws, graph, sm, cb, "t1", "hello")
await handle_user_message(ws, graph_ctx, sm, cb, "t1", "hello")
call_data = ws.send_json.call_args[0][0]
assert call_data["type"] == "error"
@@ -207,8 +217,6 @@ class TestHandleUserMessage:
async def test_interrupt_registered_with_manager(self) -> None:
ws = _make_ws()
graph = AsyncMock()
graph.intent_classifier = None
graph.agent_registry = None
graph.astream = MagicMock(return_value=AsyncIterHelper([]))
# Simulate interrupt in state
@@ -220,13 +228,14 @@ class TestHandleUserMessage:
state.tasks = (task,)
graph.aget_state = AsyncMock(return_value=state)
graph_ctx = _make_graph_ctx(graph=graph)
sm = SessionManager()
cb = TokenUsageCallbackHandler()
im = InterruptManager()
sm.touch("t1")
await handle_user_message(
ws, graph, sm, cb, "t1", "cancel order 1042", interrupt_manager=im,
ws, graph_ctx, sm, cb, "t1", "cancel order 1042", interrupt_manager=im,
)
# Interrupt should be registered
@@ -257,16 +266,17 @@ class TestHandleUserMessage:
clarification_question="What do you mean?",
)
)
graph.intent_classifier = mock_classifier
mock_registry = MagicMock()
mock_registry.list_agents = MagicMock(return_value=())
graph.agent_registry = mock_registry
graph_ctx = GraphContext(
graph=graph, registry=mock_registry, intent_classifier=mock_classifier,
)
sm = SessionManager()
cb = TokenUsageCallbackHandler()
sm.touch("t1")
await handle_user_message(ws, graph, sm, cb, "t1", "hmm")
await handle_user_message(ws, graph_ctx, sm, cb, "t1", "hmm")
calls = [c[0][0] for c in ws.send_json.call_args_list]
clarification_msgs = [c for c in calls if c.get("type") == "clarification"]
@@ -279,13 +289,13 @@ class TestHandleInterruptResponse:
@pytest.mark.asyncio
async def test_approved_interrupt(self) -> None:
ws = _make_ws()
graph = _make_graph()
graph_ctx = _make_graph_ctx()
sm = SessionManager()
cb = TokenUsageCallbackHandler()
sm.touch("t1")
sm.extend_for_interrupt("t1")
await handle_interrupt_response(ws, graph, sm, cb, "t1", True)
await handle_interrupt_response(ws, graph_ctx, sm, cb, "t1", True)
last_call = ws.send_json.call_args[0][0]
assert last_call["type"] == "message_complete"
@@ -294,7 +304,7 @@ class TestHandleInterruptResponse:
from unittest.mock import patch
ws = _make_ws()
graph = _make_graph()
graph_ctx = _make_graph_ctx()
sm = SessionManager()
cb = TokenUsageCallbackHandler()
im = InterruptManager(ttl_seconds=5)
@@ -307,7 +317,7 @@ class TestHandleInterruptResponse:
with patch("app.interrupt_manager.time") as mock_time:
mock_time.time.return_value = im._interrupts["t1"].created_at + 10
await handle_interrupt_response(
ws, graph, sm, cb, "t1", True, interrupt_manager=im
ws, graph_ctx, sm, cb, "t1", True, interrupt_manager=im
)
call_data = ws.send_json.call_args[0][0]
@@ -317,7 +327,7 @@ class TestHandleInterruptResponse:
@pytest.mark.asyncio
async def test_valid_interrupt_resolves(self) -> None:
ws = _make_ws()
graph = _make_graph()
graph_ctx = _make_graph_ctx()
sm = SessionManager()
cb = TokenUsageCallbackHandler()
im = InterruptManager(ttl_seconds=1800)
@@ -327,7 +337,7 @@ class TestHandleInterruptResponse:
im.register("t1", "cancel_order", {})
await handle_interrupt_response(
ws, graph, sm, cb, "t1", True, interrupt_manager=im
ws, graph_ctx, sm, cb, "t1", True, interrupt_manager=im
)
# Interrupt should be resolved
@@ -374,19 +384,14 @@ class TestDispatchMessageWithTracking:
@pytest.mark.asyncio
async def test_conversation_tracker_called_on_message(self) -> None:
ws = _make_ws()
graph = _make_graph()
sm = SessionManager()
cb = TokenUsageCallbackHandler()
tracker = AsyncMock()
pool = MagicMock()
ws_ctx = _make_ws_ctx(sm=sm, conversation_tracker=tracker, pool=pool)
sm.touch("t1")
msg = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"})
await dispatch_message(
ws, graph, sm, cb, msg,
conversation_tracker=tracker,
pool=pool,
)
await dispatch_message(ws, ws_ctx, msg)
tracker.ensure_conversation.assert_awaited_once_with(pool, "t1")
tracker.record_turn.assert_awaited_once()
@@ -394,53 +399,42 @@ class TestDispatchMessageWithTracking:
@pytest.mark.asyncio
async def test_analytics_recorder_called_on_message(self) -> None:
ws = _make_ws()
graph = _make_graph()
sm = SessionManager()
cb = TokenUsageCallbackHandler()
recorder = AsyncMock()
pool = MagicMock()
ws_ctx = _make_ws_ctx(sm=sm, analytics_recorder=recorder, pool=pool)
sm.touch("t1")
msg = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"})
await dispatch_message(
ws, graph, sm, cb, msg,
analytics_recorder=recorder,
pool=pool,
)
await dispatch_message(ws, ws_ctx, msg)
recorder.record.assert_awaited_once()
@pytest.mark.asyncio
async def test_tracker_failure_does_not_break_chat(self) -> None:
ws = _make_ws()
graph = _make_graph()
sm = SessionManager()
cb = TokenUsageCallbackHandler()
tracker = AsyncMock()
tracker.ensure_conversation.side_effect = RuntimeError("DB down")
pool = MagicMock()
ws_ctx = _make_ws_ctx(sm=sm, conversation_tracker=tracker, pool=pool)
sm.touch("t1")
msg = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"})
# Should not raise despite tracker failure
await dispatch_message(
ws, graph, sm, cb, msg,
conversation_tracker=tracker,
pool=pool,
)
await dispatch_message(ws, ws_ctx, msg)
last_call = ws.send_json.call_args[0][0]
assert last_call["type"] == "message_complete"
@pytest.mark.asyncio
async def test_no_tracker_no_error(self) -> None:
ws = _make_ws()
graph = _make_graph()
sm = SessionManager()
cb = TokenUsageCallbackHandler()
ws_ctx = _make_ws_ctx(sm=sm)
sm.touch("t1")
msg = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"})
# No tracker or recorder passed -- should work fine
await dispatch_message(ws, graph, sm, cb, msg)
await dispatch_message(ws, ws_ctx, msg)
last_call = ws.send_json.call_args[0][0]
assert last_call["type"] == "message_complete"