refactor: fix architectural issues across frontend and backend
Address all architecture review findings: P0 fixes: - Add API key authentication for admin endpoints (analytics, replay, openapi) and WebSocket connections via ADMIN_API_KEY env var - Add PostgreSQL-backed PgSessionManager and PgInterruptManager for multi-worker production deployments (in-memory defaults preserved) P1 fixes: - Implement actual tool generation in OpenAPI approve_job endpoint using generate_tool_code() and generate_agent_yaml() - Add missing clarification, interrupt_expired, and tool_result message handlers in frontend ChatPage P2 fixes: - Replace monkey-patching on CompiledStateGraph with typed GraphContext - Replace 9-param dispatch_message with WebSocketContext dataclass - Extract duplicate _envelope() into shared app/api_utils.py - Replace mutable module-level counter with crypto.randomUUID() - Remove hardcoded mock data from ReviewPage, use api.ts wrappers - Remove `as any` type escape from ReplayPage All 516 tests passing, 0 TypeScript errors.
This commit is contained in:
@@ -13,10 +13,12 @@ from httpx import ASGITransport, AsyncClient
|
||||
|
||||
from app.analytics.api import router as analytics_router
|
||||
from app.callbacks import TokenUsageCallbackHandler
|
||||
from app.graph_context import GraphContext
|
||||
from app.interrupt_manager import InterruptManager
|
||||
from app.openapi.review_api import _job_store, router as openapi_router
|
||||
from app.replay.api import router as replay_router
|
||||
from app.session_manager import SessionManager
|
||||
from app.ws_context import WebSocketContext
|
||||
from app.ws_handler import dispatch_message
|
||||
|
||||
|
||||
@@ -74,8 +76,6 @@ def make_graph(
|
||||
) -> MagicMock:
|
||||
"""Build a mock LangGraph CompiledStateGraph."""
|
||||
g = MagicMock()
|
||||
g.intent_classifier = None
|
||||
g.agent_registry = None
|
||||
|
||||
if state is None:
|
||||
state = make_state()
|
||||
@@ -93,6 +93,14 @@ def make_graph(
|
||||
return g
|
||||
|
||||
|
||||
def make_graph_ctx(graph: MagicMock | None = None) -> GraphContext:
|
||||
"""Build a GraphContext wrapping a mock graph."""
|
||||
g = graph or make_graph()
|
||||
registry = MagicMock()
|
||||
registry.list_agents = MagicMock(return_value=())
|
||||
return GraphContext(graph=g, registry=registry, intent_classifier=None)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fake database pool
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -148,6 +156,7 @@ def create_e2e_app(
|
||||
) -> FastAPI:
|
||||
"""Create a FastAPI app wired with mocked dependencies for E2E testing."""
|
||||
g = graph or make_graph()
|
||||
graph_ctx = make_graph_ctx(g)
|
||||
p = pool or FakePool()
|
||||
sm = SessionManager(session_ttl_seconds=session_ttl)
|
||||
im = InterruptManager(ttl_seconds=interrupt_ttl)
|
||||
@@ -157,7 +166,7 @@ def create_e2e_app(
|
||||
app.include_router(replay_router)
|
||||
app.include_router(analytics_router)
|
||||
|
||||
app.state.graph = g
|
||||
app.state.graph_ctx = graph_ctx
|
||||
app.state.session_manager = sm
|
||||
app.state.interrupt_manager = im
|
||||
app.state.pool = p
|
||||
@@ -175,17 +184,16 @@ def create_e2e_app(
|
||||
try:
|
||||
while True:
|
||||
raw_data = await ws.receive_text()
|
||||
await dispatch_message(
|
||||
ws,
|
||||
app.state.graph,
|
||||
app.state.session_manager,
|
||||
TokenUsageCallbackHandler(model_name="test-model"),
|
||||
raw_data,
|
||||
ws_ctx = WebSocketContext(
|
||||
graph_ctx=app.state.graph_ctx,
|
||||
session_manager=app.state.session_manager,
|
||||
callback_handler=TokenUsageCallbackHandler(model_name="test-model"),
|
||||
interrupt_manager=app.state.interrupt_manager,
|
||||
analytics_recorder=app.state.analytics_recorder,
|
||||
conversation_tracker=app.state.conversation_tracker,
|
||||
pool=app.state.pool,
|
||||
)
|
||||
await dispatch_message(ws, ws_ctx, raw_data)
|
||||
except WebSocketDisconnect:
|
||||
pass
|
||||
|
||||
|
||||
@@ -20,10 +20,12 @@ import pytest
|
||||
|
||||
from app.callbacks import TokenUsageCallbackHandler
|
||||
from app.escalation import EscalationPayload, NoOpEscalator, WebhookEscalator
|
||||
from app.graph_context import GraphContext
|
||||
from app.intent import ClassificationResult, IntentTarget, LLMIntentClassifier
|
||||
from app.interrupt_manager import InterruptManager
|
||||
from app.registry import AgentConfig, AgentRegistry
|
||||
from app.session_manager import SessionManager
|
||||
from app.ws_context import WebSocketContext
|
||||
from app.ws_handler import dispatch_message
|
||||
|
||||
TEMPLATES_DIR = Path(__file__).parent.parent.parent / "templates"
|
||||
@@ -128,10 +130,8 @@ class TestCheckpoint1OrderQueryRouting:
|
||||
mock_classifier.classify = AsyncMock(return_value=ClassificationResult(
|
||||
intents=(IntentTarget(agent_name="order_lookup", confidence=0.95, reasoning="query"),),
|
||||
))
|
||||
graph.intent_classifier = mock_classifier
|
||||
mock_registry = MagicMock()
|
||||
mock_registry.list_agents = MagicMock(return_value=())
|
||||
graph.agent_registry = mock_registry
|
||||
|
||||
# Graph streams order_lookup response
|
||||
graph.astream = MagicMock(return_value=AsyncIterHelper([
|
||||
@@ -140,14 +140,21 @@ class TestCheckpoint1OrderQueryRouting:
|
||||
]))
|
||||
graph.aget_state = AsyncMock(return_value=_state())
|
||||
|
||||
graph_ctx = GraphContext(
|
||||
graph=graph, registry=mock_registry, intent_classifier=mock_classifier,
|
||||
)
|
||||
sm = SessionManager()
|
||||
sm.touch("t1")
|
||||
im = InterruptManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
ws = FakeWS()
|
||||
ws_ctx = WebSocketContext(
|
||||
graph_ctx=graph_ctx, session_manager=sm,
|
||||
callback_handler=cb, interrupt_manager=im,
|
||||
)
|
||||
|
||||
raw = json.dumps({"type": "message", "thread_id": "t1", "content": "查询订单 1042"})
|
||||
await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im)
|
||||
await dispatch_message(ws, ws_ctx, raw)
|
||||
|
||||
tool_msgs = [m for m in ws.sent if m["type"] == "tool_call"]
|
||||
assert any(m["tool"] == "get_order_status" for m in tool_msgs)
|
||||
@@ -201,25 +208,30 @@ class TestCheckpoint2MultiIntentSequential:
|
||||
IntentTarget(agent_name="discount", confidence=0.85, reasoning="discount"),
|
||||
),
|
||||
))
|
||||
graph.intent_classifier = mock_classifier
|
||||
mock_registry = MagicMock()
|
||||
mock_registry.list_agents = MagicMock(return_value=())
|
||||
graph.agent_registry = mock_registry
|
||||
graph.astream = MagicMock(return_value=AsyncIterHelper([]))
|
||||
graph.aget_state = AsyncMock(return_value=_state())
|
||||
|
||||
graph_ctx = GraphContext(
|
||||
graph=graph, registry=mock_registry, intent_classifier=mock_classifier,
|
||||
)
|
||||
sm = SessionManager()
|
||||
sm.touch("t1")
|
||||
im = InterruptManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
ws = FakeWS()
|
||||
ws_ctx = WebSocketContext(
|
||||
graph_ctx=graph_ctx, session_manager=sm,
|
||||
callback_handler=cb, interrupt_manager=im,
|
||||
)
|
||||
|
||||
raw = json.dumps({
|
||||
"type": "message",
|
||||
"thread_id": "t1",
|
||||
"content": "取消订单 1042 并给我一个 10% 折扣",
|
||||
})
|
||||
await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im)
|
||||
await dispatch_message(ws, ws_ctx, raw)
|
||||
|
||||
# Verify the graph was called with the routing hint in the message
|
||||
call_args = graph.astream.call_args
|
||||
@@ -267,21 +279,26 @@ class TestCheckpoint3AmbiguousClarification:
|
||||
"Could you please provide more details about what you need help with?"
|
||||
),
|
||||
))
|
||||
graph.intent_classifier = mock_classifier
|
||||
mock_registry = MagicMock()
|
||||
mock_registry.list_agents = MagicMock(return_value=())
|
||||
graph.agent_registry = mock_registry
|
||||
graph.astream = MagicMock(return_value=AsyncIterHelper([]))
|
||||
graph.aget_state = AsyncMock(return_value=_state())
|
||||
|
||||
graph_ctx = GraphContext(
|
||||
graph=graph, registry=mock_registry, intent_classifier=mock_classifier,
|
||||
)
|
||||
sm = SessionManager()
|
||||
sm.touch("t1")
|
||||
im = InterruptManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
ws = FakeWS()
|
||||
ws_ctx = WebSocketContext(
|
||||
graph_ctx=graph_ctx, session_manager=sm,
|
||||
callback_handler=cb, interrupt_manager=im,
|
||||
)
|
||||
|
||||
raw = json.dumps({"type": "message", "thread_id": "t1", "content": "嗯..."})
|
||||
await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im)
|
||||
await dispatch_message(ws, ws_ctx, raw)
|
||||
|
||||
clarifications = [m for m in ws.sent if m["type"] == "clarification"]
|
||||
assert len(clarifications) == 1
|
||||
@@ -303,20 +320,26 @@ class TestCheckpoint4InterruptTTLAutoCancel:
|
||||
async def test_30min_expired_interrupt_auto_cancels(self) -> None:
|
||||
st = _state(interrupt=True, data={"action": "cancel_order", "order_id": "1042"})
|
||||
graph = MagicMock()
|
||||
graph.intent_classifier = None
|
||||
graph.agent_registry = None
|
||||
graph.astream = MagicMock(return_value=AsyncIterHelper([]))
|
||||
graph.aget_state = AsyncMock(return_value=st)
|
||||
|
||||
mock_registry = MagicMock()
|
||||
mock_registry.list_agents = MagicMock(return_value=())
|
||||
graph_ctx = GraphContext(graph=graph, registry=mock_registry, intent_classifier=None)
|
||||
|
||||
sm = SessionManager()
|
||||
sm.touch("t1")
|
||||
im = InterruptManager(ttl_seconds=1800) # 30 minutes
|
||||
cb = TokenUsageCallbackHandler()
|
||||
ws = FakeWS()
|
||||
ws_ctx = WebSocketContext(
|
||||
graph_ctx=graph_ctx, session_manager=sm,
|
||||
callback_handler=cb, interrupt_manager=im,
|
||||
)
|
||||
|
||||
# Trigger interrupt
|
||||
raw = json.dumps({"type": "message", "thread_id": "t1", "content": "Cancel order 1042"})
|
||||
await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im)
|
||||
await dispatch_message(ws, ws_ctx, raw)
|
||||
|
||||
interrupts = [m for m in ws.sent if m["type"] == "interrupt"]
|
||||
assert len(interrupts) == 1
|
||||
@@ -333,7 +356,7 @@ class TestCheckpoint4InterruptTTLAutoCancel:
|
||||
"thread_id": "t1",
|
||||
"approved": True,
|
||||
})
|
||||
await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im)
|
||||
await dispatch_message(ws, ws_ctx, raw)
|
||||
|
||||
# Should get retry prompt, NOT resume the graph
|
||||
expired_msgs = [m for m in ws.sent if m["type"] == "interrupt_expired"]
|
||||
|
||||
@@ -18,10 +18,12 @@ from unittest.mock import AsyncMock, MagicMock
|
||||
import pytest
|
||||
|
||||
from app.callbacks import TokenUsageCallbackHandler
|
||||
from app.graph_context import GraphContext
|
||||
from app.intent import ClassificationResult, IntentTarget, LLMIntentClassifier
|
||||
from app.interrupt_manager import InterruptManager
|
||||
from app.registry import AgentConfig
|
||||
from app.session_manager import SessionManager
|
||||
from app.ws_context import WebSocketContext
|
||||
from app.ws_handler import dispatch_message
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -103,36 +105,45 @@ def _make_classifier(result: ClassificationResult) -> AsyncMock:
|
||||
return classifier
|
||||
|
||||
|
||||
def _make_graph(
|
||||
def _make_graph_and_ctx(
|
||||
classifier_result: ClassificationResult | None,
|
||||
chunks: list,
|
||||
state=None,
|
||||
) -> MagicMock:
|
||||
"""Build a graph mock with optional intent classifier."""
|
||||
) -> tuple[MagicMock, GraphContext]:
|
||||
"""Build a graph mock and GraphContext with optional intent classifier."""
|
||||
graph = MagicMock()
|
||||
|
||||
if classifier_result is not None:
|
||||
graph.intent_classifier = _make_classifier(classifier_result)
|
||||
mock_registry = MagicMock()
|
||||
mock_registry.list_agents = MagicMock(return_value=AGENTS)
|
||||
graph.agent_registry = mock_registry
|
||||
else:
|
||||
graph.intent_classifier = None
|
||||
graph.agent_registry = None
|
||||
|
||||
graph.astream = MagicMock(return_value=AsyncIterHelper(list(chunks)))
|
||||
graph.aget_state = AsyncMock(return_value=state or _state())
|
||||
return graph
|
||||
|
||||
if classifier_result is not None:
|
||||
classifier = _make_classifier(classifier_result)
|
||||
mock_registry = MagicMock()
|
||||
mock_registry.list_agents = MagicMock(return_value=AGENTS)
|
||||
graph_ctx = GraphContext(
|
||||
graph=graph, registry=mock_registry, intent_classifier=classifier,
|
||||
)
|
||||
else:
|
||||
mock_registry = MagicMock()
|
||||
mock_registry.list_agents = MagicMock(return_value=())
|
||||
graph_ctx = GraphContext(
|
||||
graph=graph, registry=mock_registry, intent_classifier=None,
|
||||
)
|
||||
|
||||
return graph, graph_ctx
|
||||
|
||||
|
||||
async def _dispatch(graph, content: str, thread_id: str = "t1") -> list[dict]:
|
||||
async def _dispatch(graph_ctx: GraphContext, content: str, thread_id: str = "t1") -> list[dict]:
|
||||
sm = SessionManager()
|
||||
sm.touch(thread_id)
|
||||
im = InterruptManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
ws = FakeWS()
|
||||
ws_ctx = WebSocketContext(
|
||||
graph_ctx=graph_ctx, session_manager=sm,
|
||||
callback_handler=cb, interrupt_manager=im,
|
||||
)
|
||||
raw = json.dumps({"type": "message", "thread_id": thread_id, "content": content})
|
||||
await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im)
|
||||
await dispatch_message(ws, ws_ctx, raw)
|
||||
return ws.sent
|
||||
|
||||
|
||||
@@ -151,12 +162,12 @@ class TestSingleIntentRouting:
|
||||
agent_name="order_lookup", confidence=0.95, reasoning="status query",
|
||||
),),
|
||||
)
|
||||
graph = _make_graph(result, [
|
||||
graph, graph_ctx = _make_graph_and_ctx(result, [
|
||||
_tool_chunk("get_order_status", {"order_id": "1042"}, "order_lookup"),
|
||||
_chunk("Order 1042 is shipped.", "order_lookup"),
|
||||
])
|
||||
|
||||
msgs = await _dispatch(graph, "What is the status of order 1042?")
|
||||
msgs = await _dispatch(graph_ctx, "What is the status of order 1042?")
|
||||
|
||||
tools = [m for m in msgs if m["type"] == "tool_call"]
|
||||
assert len(tools) == 1
|
||||
@@ -171,13 +182,13 @@ class TestSingleIntentRouting:
|
||||
result = ClassificationResult(
|
||||
intents=(IntentTarget(agent_name="order_actions", confidence=0.9, reasoning="cancel"),),
|
||||
)
|
||||
graph = _make_graph(
|
||||
graph, graph_ctx = _make_graph_and_ctx(
|
||||
result,
|
||||
[_tool_chunk("cancel_order", {"order_id": "1042"}, "order_actions")],
|
||||
state=_state(interrupt=True, data={"action": "cancel_order", "order_id": "1042"}),
|
||||
)
|
||||
|
||||
msgs = await _dispatch(graph, "Cancel order 1042")
|
||||
msgs = await _dispatch(graph_ctx, "Cancel order 1042")
|
||||
|
||||
tools = [m for m in msgs if m["type"] == "tool_call"]
|
||||
assert tools[0]["tool"] == "cancel_order"
|
||||
@@ -191,12 +202,12 @@ class TestSingleIntentRouting:
|
||||
result = ClassificationResult(
|
||||
intents=(IntentTarget(agent_name="discount", confidence=0.88, reasoning="coupon"),),
|
||||
)
|
||||
graph = _make_graph(result, [
|
||||
graph, graph_ctx = _make_graph_and_ctx(result, [
|
||||
_tool_chunk("generate_coupon", {"discount_percent": 15, "expiry_days": 30}, "discount"),
|
||||
_chunk("Here is your coupon: SAVE15-ABC12345", "discount"),
|
||||
])
|
||||
|
||||
msgs = await _dispatch(graph, "Give me a 15% coupon")
|
||||
msgs = await _dispatch(graph_ctx, "Give me a 15% coupon")
|
||||
|
||||
tools = [m for m in msgs if m["type"] == "tool_call"]
|
||||
assert tools[0]["tool"] == "generate_coupon"
|
||||
@@ -207,11 +218,11 @@ class TestSingleIntentRouting:
|
||||
result = ClassificationResult(
|
||||
intents=(IntentTarget(agent_name="fallback", confidence=0.7, reasoning="general"),),
|
||||
)
|
||||
graph = _make_graph(result, [
|
||||
graph, graph_ctx = _make_graph_and_ctx(result, [
|
||||
_chunk("I can help with order inquiries.", "fallback"),
|
||||
])
|
||||
|
||||
msgs = await _dispatch(graph, "What can you do?")
|
||||
msgs = await _dispatch(graph_ctx, "What can you do?")
|
||||
|
||||
tokens = [m for m in msgs if m["type"] == "token"]
|
||||
assert tokens[0]["agent"] == "fallback"
|
||||
@@ -233,7 +244,7 @@ class TestMultiIntentRouting:
|
||||
IntentTarget(agent_name="discount", confidence=0.85, reasoning="discount"),
|
||||
),
|
||||
)
|
||||
graph = _make_graph(result, [
|
||||
graph, graph_ctx = _make_graph_and_ctx(result, [
|
||||
_tool_chunk("cancel_order", {"order_id": "1042"}, "order_actions"),
|
||||
_tool_chunk("apply_discount", {"order_id": "1042", "discount_percent": 10}, "discount"),
|
||||
])
|
||||
@@ -243,13 +254,17 @@ class TestMultiIntentRouting:
|
||||
im = InterruptManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
ws = FakeWS()
|
||||
ws_ctx = WebSocketContext(
|
||||
graph_ctx=graph_ctx, session_manager=sm,
|
||||
callback_handler=cb, interrupt_manager=im,
|
||||
)
|
||||
|
||||
raw = json.dumps({
|
||||
"type": "message",
|
||||
"thread_id": "t1",
|
||||
"content": "取消订单 1042 并给我一个 10% 折扣",
|
||||
})
|
||||
await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im)
|
||||
await dispatch_message(ws, ws_ctx, raw)
|
||||
|
||||
# Verify routing hint was injected
|
||||
call_args = graph.astream.call_args[0][0]
|
||||
@@ -269,16 +284,20 @@ class TestMultiIntentRouting:
|
||||
result = ClassificationResult(
|
||||
intents=(IntentTarget(agent_name="order_lookup", confidence=0.95, reasoning="query"),),
|
||||
)
|
||||
graph = _make_graph(result, [_chunk("Order shipped.", "order_lookup")])
|
||||
graph, graph_ctx = _make_graph_and_ctx(result, [_chunk("Order shipped.", "order_lookup")])
|
||||
|
||||
sm = SessionManager()
|
||||
sm.touch("t1")
|
||||
im = InterruptManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
ws = FakeWS()
|
||||
ws_ctx = WebSocketContext(
|
||||
graph_ctx=graph_ctx, session_manager=sm,
|
||||
callback_handler=cb, interrupt_manager=im,
|
||||
)
|
||||
|
||||
raw = json.dumps({"type": "message", "thread_id": "t1", "content": "查询订单 1042"})
|
||||
await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im)
|
||||
await dispatch_message(ws, ws_ctx, raw)
|
||||
|
||||
msg_content = graph.astream.call_args[0][0]["messages"][0].content
|
||||
assert "[System:" not in msg_content
|
||||
@@ -299,9 +318,9 @@ class TestAmbiguityRouting:
|
||||
is_ambiguous=True,
|
||||
clarification_question="Could you please clarify what you need?",
|
||||
)
|
||||
graph = _make_graph(result, [])
|
||||
graph, graph_ctx = _make_graph_and_ctx(result, [])
|
||||
|
||||
msgs = await _dispatch(graph, "嗯...")
|
||||
msgs = await _dispatch(graph_ctx, "嗯...")
|
||||
|
||||
clarifications = [m for m in msgs if m["type"] == "clarification"]
|
||||
assert len(clarifications) == 1
|
||||
@@ -339,12 +358,12 @@ class TestNoClassifierFallback:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_classifier_routes_via_supervisor(self) -> None:
|
||||
graph = _make_graph(
|
||||
graph, graph_ctx = _make_graph_and_ctx(
|
||||
classifier_result=None,
|
||||
chunks=[_chunk("Order 1042 is shipped.", "order_lookup")],
|
||||
)
|
||||
|
||||
msgs = await _dispatch(graph, "What is order 1042 status?")
|
||||
msgs = await _dispatch(graph_ctx, "What is order 1042 status?")
|
||||
|
||||
tokens = [m for m in msgs if m["type"] == "token"]
|
||||
assert len(tokens) == 1
|
||||
|
||||
@@ -15,8 +15,10 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import pytest
|
||||
|
||||
from app.callbacks import TokenUsageCallbackHandler
|
||||
from app.graph_context import GraphContext
|
||||
from app.interrupt_manager import InterruptManager
|
||||
from app.session_manager import SessionManager
|
||||
from app.ws_context import WebSocketContext
|
||||
from app.ws_handler import dispatch_message
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -81,8 +83,6 @@ def _graph(
|
||||
resume_chunks: list | None = None,
|
||||
) -> MagicMock:
|
||||
g = MagicMock()
|
||||
g.intent_classifier = None
|
||||
g.agent_registry = None
|
||||
|
||||
if st is None:
|
||||
st = _state()
|
||||
@@ -100,6 +100,13 @@ def _graph(
|
||||
return g
|
||||
|
||||
|
||||
def _make_graph_ctx(graph: MagicMock | None = None) -> GraphContext:
|
||||
g = graph or _graph()
|
||||
registry = MagicMock()
|
||||
registry.list_agents = MagicMock(return_value=())
|
||||
return GraphContext(graph=g, registry=registry, intent_classifier=None)
|
||||
|
||||
|
||||
def _setup(
|
||||
graph=None,
|
||||
session_ttl: int = 1800,
|
||||
@@ -109,23 +116,28 @@ def _setup(
|
||||
):
|
||||
"""Create test dependencies. Pre-touches session by default."""
|
||||
g = graph or _graph()
|
||||
graph_ctx = _make_graph_ctx(g)
|
||||
sm = SessionManager(session_ttl_seconds=session_ttl)
|
||||
im = InterruptManager(ttl_seconds=interrupt_ttl)
|
||||
cb = TokenUsageCallbackHandler()
|
||||
ws = FakeWS()
|
||||
ws_ctx = WebSocketContext(
|
||||
graph_ctx=graph_ctx, session_manager=sm,
|
||||
callback_handler=cb, interrupt_manager=im,
|
||||
)
|
||||
if touch:
|
||||
sm.touch(thread_id)
|
||||
return g, sm, im, cb, ws
|
||||
return g, sm, im, cb, ws, ws_ctx
|
||||
|
||||
|
||||
async def _send(ws, g, sm, im, cb, *, thread_id="t1", content="hello", msg_type="message"):
|
||||
async def _send(ws, ws_ctx, *, thread_id="t1", content="hello", msg_type="message"):
|
||||
raw = json.dumps({"type": msg_type, "thread_id": thread_id, "content": content})
|
||||
await dispatch_message(ws, g, sm, cb, raw, interrupt_manager=im)
|
||||
await dispatch_message(ws, ws_ctx, raw)
|
||||
|
||||
|
||||
async def _respond(ws, g, sm, im, cb, *, thread_id="t1", approved=True):
|
||||
async def _respond(ws, ws_ctx, *, thread_id="t1", approved=True):
|
||||
raw = json.dumps({"type": "interrupt_response", "thread_id": thread_id, "approved": approved})
|
||||
await dispatch_message(ws, g, sm, cb, raw, interrupt_manager=im)
|
||||
await dispatch_message(ws, ws_ctx, raw)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -136,10 +148,10 @@ async def _respond(ws, g, sm, im, cb, *, thread_id="t1", approved=True):
|
||||
class TestWebSocketHappyPath:
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_message_receives_tokens_and_complete(self) -> None:
|
||||
g, sm, im, cb, ws = _setup(
|
||||
g, sm, im, cb, ws, ws_ctx = _setup(
|
||||
graph=_graph(chunks=[_chunk("Order 1042 is "), _chunk("shipped.")])
|
||||
)
|
||||
await _send(ws, g, sm, im, cb, content="What is the status of order 1042?")
|
||||
await _send(ws, ws_ctx, content="What is the status of order 1042?")
|
||||
|
||||
tokens = [m for m in ws.sent if m["type"] == "token"]
|
||||
assert len(tokens) == 2
|
||||
@@ -153,13 +165,13 @@ class TestWebSocketHappyPath:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_call_streamed(self) -> None:
|
||||
g, sm, im, cb, ws = _setup(
|
||||
g, sm, im, cb, ws, ws_ctx = _setup(
|
||||
graph=_graph(chunks=[
|
||||
_tool_chunk("get_order_status", {"order_id": "1042"}),
|
||||
_chunk("Order shipped."),
|
||||
])
|
||||
)
|
||||
await _send(ws, g, sm, im, cb, content="Check order 1042")
|
||||
await _send(ws, ws_ctx, content="Check order 1042")
|
||||
|
||||
tools = [m for m in ws.sent if m["type"] == "tool_call"]
|
||||
assert len(tools) == 1
|
||||
@@ -168,9 +180,9 @@ class TestWebSocketHappyPath:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_messages_same_session(self) -> None:
|
||||
g, sm, im, cb, ws = _setup()
|
||||
g, sm, im, cb, ws, ws_ctx = _setup()
|
||||
for i in range(3):
|
||||
await _send(ws, g, sm, im, cb, content=f"msg {i}")
|
||||
await _send(ws, ws_ctx, content=f"msg {i}")
|
||||
|
||||
completes = [m for m in ws.sent if m["type"] == "message_complete"]
|
||||
assert len(completes) == 3
|
||||
@@ -183,10 +195,10 @@ class TestWebSocketInterruptApproval:
|
||||
st_int = _state(interrupt=True, data={"action": "cancel_order", "order_id": "1042"})
|
||||
resume = [_chunk("Order 1042 cancelled.", "order_actions")]
|
||||
g = _graph(chunks=[], st=st_int, resume_chunks=resume)
|
||||
g_, sm, im, cb, ws = _setup(graph=g)
|
||||
g_, sm, im, cb, ws, ws_ctx = _setup(graph=g)
|
||||
|
||||
# Send message -> triggers interrupt
|
||||
await _send(ws, g_, sm, im, cb, content="Cancel order 1042")
|
||||
await _send(ws, ws_ctx, content="Cancel order 1042")
|
||||
|
||||
interrupts = [m for m in ws.sent if m["type"] == "interrupt"]
|
||||
assert len(interrupts) == 1
|
||||
@@ -196,7 +208,7 @@ class TestWebSocketInterruptApproval:
|
||||
|
||||
# Approve
|
||||
ws.sent.clear()
|
||||
await _respond(ws, g_, sm, im, cb, approved=True)
|
||||
await _respond(ws, ws_ctx, approved=True)
|
||||
|
||||
tokens = [m for m in ws.sent if m["type"] == "token"]
|
||||
assert len(tokens) == 1
|
||||
@@ -211,12 +223,12 @@ class TestWebSocketInterruptApproval:
|
||||
st_int = _state(interrupt=True)
|
||||
resume = [_chunk("Order remains active.", "order_actions")]
|
||||
g = _graph(chunks=[], st=st_int, resume_chunks=resume)
|
||||
g_, sm, im, cb, ws = _setup(graph=g)
|
||||
g_, sm, im, cb, ws, ws_ctx = _setup(graph=g)
|
||||
|
||||
await _send(ws, g_, sm, im, cb, content="Cancel order 1042")
|
||||
await _send(ws, ws_ctx, content="Cancel order 1042")
|
||||
ws.sent.clear()
|
||||
|
||||
await _respond(ws, g_, sm, im, cb, approved=False)
|
||||
await _respond(ws, ws_ctx, approved=False)
|
||||
|
||||
tokens = [m for m in ws.sent if m["type"] == "token"]
|
||||
assert "remains active" in tokens[0]["content"]
|
||||
@@ -226,28 +238,28 @@ class TestWebSocketInterruptApproval:
|
||||
class TestWebSocketSessionTTL:
|
||||
@pytest.mark.asyncio
|
||||
async def test_expired_session_returns_error(self) -> None:
|
||||
g, sm, im, cb, ws = _setup(session_ttl=0)
|
||||
g, sm, im, cb, ws, ws_ctx = _setup(session_ttl=0)
|
||||
# Session was touched in _setup, but TTL is 0 so it's already expired
|
||||
await _send(ws, g, sm, im, cb, content="hello")
|
||||
await _send(ws, ws_ctx, content="hello")
|
||||
assert ws.sent[0]["type"] == "error"
|
||||
assert "expired" in ws.sent[0]["message"].lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_session_not_expired(self) -> None:
|
||||
g, sm, im, cb, ws = _setup(session_ttl=3600)
|
||||
await _send(ws, g, sm, im, cb, content="hello")
|
||||
g, sm, im, cb, ws, ws_ctx = _setup(session_ttl=3600)
|
||||
await _send(ws, ws_ctx, content="hello")
|
||||
completes = [m for m in ws.sent if m["type"] == "message_complete"]
|
||||
assert len(completes) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sliding_window_resets_on_message(self) -> None:
|
||||
g, sm, im, cb, ws = _setup(session_ttl=3600)
|
||||
g, sm, im, cb, ws, ws_ctx = _setup(session_ttl=3600)
|
||||
|
||||
await _send(ws, g, sm, im, cb, content="hello")
|
||||
await _send(ws, ws_ctx, content="hello")
|
||||
first_activity = sm.get_state("t1").last_activity
|
||||
|
||||
time.sleep(0.01)
|
||||
await _send(ws, g, sm, im, cb, content="hello again")
|
||||
await _send(ws, ws_ctx, content="hello again")
|
||||
second_activity = sm.get_state("t1").last_activity
|
||||
|
||||
assert second_activity > first_activity
|
||||
@@ -256,9 +268,9 @@ class TestWebSocketSessionTTL:
|
||||
async def test_interrupt_extends_session_ttl(self) -> None:
|
||||
st_int = _state(interrupt=True)
|
||||
g = _graph(chunks=[], st=st_int)
|
||||
g_, sm, im, cb, ws = _setup(graph=g, session_ttl=3600)
|
||||
g_, sm, im, cb, ws, ws_ctx = _setup(graph=g, session_ttl=3600)
|
||||
|
||||
await _send(ws, g_, sm, im, cb, content="cancel order")
|
||||
await _send(ws, ws_ctx, content="cancel order")
|
||||
|
||||
state = sm.get_state("t1")
|
||||
assert state is not None
|
||||
@@ -270,53 +282,53 @@ class TestWebSocketSessionTTL:
|
||||
class TestWebSocketValidation:
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_json(self) -> None:
|
||||
g, sm, im, cb, ws = _setup()
|
||||
await dispatch_message(ws, g, sm, cb, "not json", interrupt_manager=im)
|
||||
g, sm, im, cb, ws, ws_ctx = _setup()
|
||||
await dispatch_message(ws, ws_ctx, "not json")
|
||||
assert ws.sent[0]["type"] == "error"
|
||||
assert "Invalid JSON" in ws.sent[0]["message"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_thread_id(self) -> None:
|
||||
g, sm, im, cb, ws = _setup()
|
||||
g, sm, im, cb, ws, ws_ctx = _setup()
|
||||
raw = json.dumps({"type": "message", "content": "hi"})
|
||||
await dispatch_message(ws, g, sm, cb, raw, interrupt_manager=im)
|
||||
await dispatch_message(ws, ws_ctx, raw)
|
||||
assert ws.sent[0]["type"] == "error"
|
||||
assert "thread_id" in ws.sent[0]["message"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_thread_id_format(self) -> None:
|
||||
g, sm, im, cb, ws = _setup()
|
||||
g, sm, im, cb, ws, ws_ctx = _setup()
|
||||
raw = json.dumps({"type": "message", "thread_id": "../../etc/passwd", "content": "hi"})
|
||||
await dispatch_message(ws, g, sm, cb, raw, interrupt_manager=im)
|
||||
await dispatch_message(ws, ws_ctx, raw)
|
||||
assert ws.sent[0]["type"] == "error"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_content(self) -> None:
|
||||
g, sm, im, cb, ws = _setup()
|
||||
g, sm, im, cb, ws, ws_ctx = _setup()
|
||||
raw = json.dumps({"type": "message", "thread_id": "t1"})
|
||||
await dispatch_message(ws, g, sm, cb, raw, interrupt_manager=im)
|
||||
await dispatch_message(ws, ws_ctx, raw)
|
||||
assert ws.sent[0]["type"] == "error"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unknown_message_type(self) -> None:
|
||||
g, sm, im, cb, ws = _setup()
|
||||
g, sm, im, cb, ws, ws_ctx = _setup()
|
||||
raw = json.dumps({"type": "foobar", "thread_id": "t1"})
|
||||
await dispatch_message(ws, g, sm, cb, raw, interrupt_manager=im)
|
||||
await dispatch_message(ws, ws_ctx, raw)
|
||||
assert ws.sent[0]["type"] == "error"
|
||||
assert "Unknown" in ws.sent[0]["message"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_too_large(self) -> None:
|
||||
g, sm, im, cb, ws = _setup()
|
||||
await dispatch_message(ws, g, sm, cb, "x" * 40_000, interrupt_manager=im)
|
||||
g, sm, im, cb, ws, ws_ctx = _setup()
|
||||
await dispatch_message(ws, ws_ctx, "x" * 40_000)
|
||||
assert ws.sent[0]["type"] == "error"
|
||||
assert "too large" in ws.sent[0]["message"].lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_content_too_long(self) -> None:
|
||||
g, sm, im, cb, ws = _setup()
|
||||
g, sm, im, cb, ws, ws_ctx = _setup()
|
||||
raw = json.dumps({"type": "message", "thread_id": "t1", "content": "x" * 10001})
|
||||
await dispatch_message(ws, g, sm, cb, raw, interrupt_manager=im)
|
||||
await dispatch_message(ws, ws_ctx, raw)
|
||||
assert ws.sent[0]["type"] == "error"
|
||||
assert "too long" in ws.sent[0]["message"].lower()
|
||||
|
||||
@@ -327,10 +339,10 @@ class TestWebSocketInterruptTTL:
|
||||
async def test_expired_interrupt_sends_retry_prompt(self) -> None:
|
||||
st_int = _state(interrupt=True, data={"action": "cancel_order", "order_id": "1042"})
|
||||
g = _graph(chunks=[], st=st_int)
|
||||
g_, sm, im, cb, ws = _setup(graph=g, interrupt_ttl=5)
|
||||
g_, sm, im, cb, ws, ws_ctx = _setup(graph=g, interrupt_ttl=5)
|
||||
|
||||
# Trigger interrupt
|
||||
await _send(ws, g_, sm, im, cb, content="Cancel order 1042")
|
||||
await _send(ws, ws_ctx, content="Cancel order 1042")
|
||||
|
||||
interrupts = [m for m in ws.sent if m["type"] == "interrupt"]
|
||||
assert len(interrupts) == 1
|
||||
@@ -341,7 +353,7 @@ class TestWebSocketInterruptTTL:
|
||||
|
||||
with patch("app.interrupt_manager.time") as mock_time:
|
||||
mock_time.time.return_value = record.created_at + 10
|
||||
await _respond(ws, g_, sm, im, cb, approved=True)
|
||||
await _respond(ws, ws_ctx, approved=True)
|
||||
|
||||
assert ws.sent[0]["type"] == "interrupt_expired"
|
||||
assert "cancel_order" in ws.sent[0]["message"]
|
||||
|
||||
@@ -55,7 +55,7 @@ class TestDbModule:
|
||||
from app.db import setup_app_tables
|
||||
|
||||
await setup_app_tables(mock_pool)
|
||||
assert mock_conn.execute.await_count == 4
|
||||
assert mock_conn.execute.await_count == 5
|
||||
|
||||
def test_ddl_statements_valid(self) -> None:
|
||||
assert "CREATE TABLE IF NOT EXISTS conversations" in _CONVERSATIONS_DDL
|
||||
|
||||
@@ -51,5 +51,5 @@ class TestAnalyticsEventsDDL:
|
||||
from app.db import setup_app_tables
|
||||
|
||||
await setup_app_tables(mock_pool)
|
||||
# Now expects 4 statements: conversations, interrupts, analytics_events, migrations
|
||||
assert mock_conn.execute.await_count == 4
|
||||
# Now expects 5 statements: conversations, interrupts, sessions, analytics_events, migrations
|
||||
assert mock_conn.execute.await_count == 5
|
||||
|
||||
@@ -8,7 +8,9 @@ from unittest.mock import AsyncMock, MagicMock
|
||||
import pytest
|
||||
|
||||
from app.callbacks import TokenUsageCallbackHandler
|
||||
from app.graph_context import GraphContext
|
||||
from app.session_manager import SessionManager
|
||||
from app.ws_context import WebSocketContext
|
||||
from app.ws_handler import dispatch_message
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
@@ -20,7 +22,7 @@ def _make_ws() -> AsyncMock:
|
||||
return ws
|
||||
|
||||
|
||||
def _make_graph() -> AsyncMock:
|
||||
def _make_graph() -> MagicMock:
|
||||
graph = AsyncMock()
|
||||
|
||||
class AsyncIterHelper:
|
||||
@@ -34,23 +36,32 @@ def _make_graph() -> AsyncMock:
|
||||
state = MagicMock()
|
||||
state.tasks = ()
|
||||
graph.aget_state = AsyncMock(return_value=state)
|
||||
graph.intent_classifier = None
|
||||
graph.agent_registry = None
|
||||
return graph
|
||||
|
||||
|
||||
def _make_ws_ctx(sm: SessionManager | None = None) -> WebSocketContext:
|
||||
graph = _make_graph()
|
||||
registry = MagicMock()
|
||||
registry.list_agents = MagicMock(return_value=())
|
||||
graph_ctx = GraphContext(graph=graph, registry=registry, intent_classifier=None)
|
||||
return WebSocketContext(
|
||||
graph_ctx=graph_ctx,
|
||||
session_manager=sm or SessionManager(),
|
||||
callback_handler=TokenUsageCallbackHandler(),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestEmptyMessageHandling:
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_message_content_returns_error(self) -> None:
|
||||
ws = _make_ws()
|
||||
graph = _make_graph()
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
ws_ctx = _make_ws_ctx(sm=sm)
|
||||
|
||||
sm.touch("t1")
|
||||
msg = json.dumps({"type": "message", "thread_id": "t1", "content": ""})
|
||||
await dispatch_message(ws, graph, sm, cb, msg)
|
||||
await dispatch_message(ws, ws_ctx, msg)
|
||||
|
||||
call_data = ws.send_json.call_args[0][0]
|
||||
assert call_data["type"] == "error"
|
||||
@@ -60,13 +71,12 @@ class TestEmptyMessageHandling:
|
||||
@pytest.mark.asyncio
|
||||
async def test_whitespace_only_message_treated_as_empty(self) -> None:
|
||||
ws = _make_ws()
|
||||
graph = _make_graph()
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
ws_ctx = _make_ws_ctx(sm=sm)
|
||||
|
||||
sm.touch("t1")
|
||||
msg = json.dumps({"type": "message", "thread_id": "t1", "content": " "})
|
||||
await dispatch_message(ws, graph, sm, cb, msg)
|
||||
await dispatch_message(ws, ws_ctx, msg)
|
||||
|
||||
call_data = ws.send_json.call_args[0][0]
|
||||
assert call_data["type"] == "error"
|
||||
@@ -77,14 +87,13 @@ class TestOversizedMessageHandling:
|
||||
@pytest.mark.asyncio
|
||||
async def test_content_over_10000_chars_returns_error(self) -> None:
|
||||
ws = _make_ws()
|
||||
graph = _make_graph()
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
ws_ctx = _make_ws_ctx(sm=sm)
|
||||
|
||||
sm.touch("t1")
|
||||
content = "x" * 10001
|
||||
msg = json.dumps({"type": "message", "thread_id": "t1", "content": content})
|
||||
await dispatch_message(ws, graph, sm, cb, msg)
|
||||
await dispatch_message(ws, ws_ctx, msg)
|
||||
|
||||
call_data = ws.send_json.call_args[0][0]
|
||||
assert call_data["type"] == "error"
|
||||
@@ -93,14 +102,13 @@ class TestOversizedMessageHandling:
|
||||
@pytest.mark.asyncio
|
||||
async def test_content_exactly_10000_chars_is_accepted(self) -> None:
|
||||
ws = _make_ws()
|
||||
graph = _make_graph()
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
ws_ctx = _make_ws_ctx(sm=sm)
|
||||
|
||||
sm.touch("t1")
|
||||
content = "x" * 10000
|
||||
msg = json.dumps({"type": "message", "thread_id": "t1", "content": content})
|
||||
await dispatch_message(ws, graph, sm, cb, msg)
|
||||
await dispatch_message(ws, ws_ctx, msg)
|
||||
|
||||
last_call = ws.send_json.call_args[0][0]
|
||||
# Should be processed, not an error about length
|
||||
@@ -110,12 +118,10 @@ class TestOversizedMessageHandling:
|
||||
@pytest.mark.asyncio
|
||||
async def test_raw_message_over_32kb_returns_error(self) -> None:
|
||||
ws = _make_ws()
|
||||
graph = _make_graph()
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
ws_ctx = _make_ws_ctx()
|
||||
|
||||
large_msg = "x" * 40_000
|
||||
await dispatch_message(ws, graph, sm, cb, large_msg)
|
||||
await dispatch_message(ws, ws_ctx, large_msg)
|
||||
|
||||
call_data = ws.send_json.call_args[0][0]
|
||||
assert call_data["type"] == "error"
|
||||
@@ -127,11 +133,9 @@ class TestInvalidJsonHandling:
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_json_returns_error(self) -> None:
|
||||
ws = _make_ws()
|
||||
graph = _make_graph()
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
ws_ctx = _make_ws_ctx()
|
||||
|
||||
await dispatch_message(ws, graph, sm, cb, "not valid json {{")
|
||||
await dispatch_message(ws, ws_ctx, "not valid json {{")
|
||||
|
||||
call_data = ws.send_json.call_args[0][0]
|
||||
assert call_data["type"] == "error"
|
||||
@@ -140,11 +144,9 @@ class TestInvalidJsonHandling:
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_string_returns_json_error(self) -> None:
|
||||
ws = _make_ws()
|
||||
graph = _make_graph()
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
ws_ctx = _make_ws_ctx()
|
||||
|
||||
await dispatch_message(ws, graph, sm, cb, "")
|
||||
await dispatch_message(ws, ws_ctx, "")
|
||||
|
||||
call_data = ws.send_json.call_args[0][0]
|
||||
assert call_data["type"] == "error"
|
||||
@@ -152,11 +154,9 @@ class TestInvalidJsonHandling:
|
||||
@pytest.mark.asyncio
|
||||
async def test_json_array_not_object_returns_error(self) -> None:
|
||||
ws = _make_ws()
|
||||
graph = _make_graph()
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
ws_ctx = _make_ws_ctx()
|
||||
|
||||
await dispatch_message(ws, graph, sm, cb, '["not", "an", "object"]')
|
||||
await dispatch_message(ws, ws_ctx, '["not", "an", "object"]')
|
||||
|
||||
call_data = ws.send_json.call_args[0][0]
|
||||
assert call_data["type"] == "error"
|
||||
@@ -167,17 +167,15 @@ class TestRateLimiting:
|
||||
@pytest.mark.asyncio
|
||||
async def test_rapid_fire_messages_rate_limited(self) -> None:
|
||||
ws = _make_ws()
|
||||
_make_graph() # ensure graph factory works, not needed directly
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
|
||||
sm.touch("t1")
|
||||
|
||||
# Simulate 11 rapid messages (exceeds 10 per 10 seconds limit)
|
||||
rate_limit_triggered = False
|
||||
for i in range(11):
|
||||
graph2 = _make_graph() # fresh graph each time
|
||||
await dispatch_message(ws, graph2, sm, cb, json.dumps({
|
||||
ws_ctx = _make_ws_ctx(sm=sm)
|
||||
await dispatch_message(ws, ws_ctx, json.dumps({
|
||||
"type": "message",
|
||||
"thread_id": "t1",
|
||||
"content": f"message {i}",
|
||||
@@ -193,19 +191,18 @@ class TestRateLimiting:
|
||||
async def test_different_threads_have_separate_rate_limits(self) -> None:
|
||||
ws = _make_ws()
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
|
||||
sm.touch("t1")
|
||||
sm.touch("t2")
|
||||
|
||||
# Send 5 messages on t1 and 5 on t2 -- neither should be rate limited
|
||||
for i in range(5):
|
||||
graph1 = _make_graph()
|
||||
graph2 = _make_graph()
|
||||
await dispatch_message(ws, graph1, sm, cb, json.dumps({
|
||||
ws_ctx1 = _make_ws_ctx(sm=sm)
|
||||
ws_ctx2 = _make_ws_ctx(sm=sm)
|
||||
await dispatch_message(ws, ws_ctx1, json.dumps({
|
||||
"type": "message", "thread_id": "t1", "content": f"msg {i}",
|
||||
}))
|
||||
await dispatch_message(ws, graph2, sm, cb, json.dumps({
|
||||
await dispatch_message(ws, ws_ctx2, json.dumps({
|
||||
"type": "message", "thread_id": "t2", "content": f"msg {i}",
|
||||
}))
|
||||
|
||||
|
||||
@@ -8,7 +8,8 @@ from unittest.mock import AsyncMock, MagicMock
|
||||
import pytest
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
from app.graph import build_agent_nodes, build_graph, classify_intent
|
||||
from app.graph import build_agent_nodes, build_graph
|
||||
from app.graph_context import GraphContext
|
||||
from app.intent import ClassificationResult, IntentTarget
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -37,8 +38,9 @@ class TestBuildGraph:
|
||||
mock_llm.with_structured_output = MagicMock(return_value=mock_llm)
|
||||
checkpointer = InMemorySaver()
|
||||
|
||||
graph = build_graph(sample_registry, mock_llm, checkpointer)
|
||||
assert graph is not None
|
||||
graph_ctx = build_graph(sample_registry, mock_llm, checkpointer)
|
||||
assert graph_ctx is not None
|
||||
assert graph_ctx.graph is not None
|
||||
|
||||
def test_graph_has_classifier_attached(self, sample_registry: AgentRegistry) -> None:
|
||||
mock_llm = MagicMock()
|
||||
@@ -47,11 +49,11 @@ class TestBuildGraph:
|
||||
checkpointer = InMemorySaver()
|
||||
mock_classifier = MagicMock()
|
||||
|
||||
graph = build_graph(
|
||||
graph_ctx = build_graph(
|
||||
sample_registry, mock_llm, checkpointer, intent_classifier=mock_classifier
|
||||
)
|
||||
assert graph.intent_classifier is mock_classifier
|
||||
assert graph.agent_registry is sample_registry
|
||||
assert graph_ctx.intent_classifier is mock_classifier
|
||||
assert graph_ctx.registry is sample_registry
|
||||
|
||||
def test_graph_without_classifier(self, sample_registry: AgentRegistry) -> None:
|
||||
mock_llm = MagicMock()
|
||||
@@ -59,17 +61,18 @@ class TestBuildGraph:
|
||||
mock_llm.with_structured_output = MagicMock(return_value=mock_llm)
|
||||
checkpointer = InMemorySaver()
|
||||
|
||||
graph = build_graph(sample_registry, mock_llm, checkpointer)
|
||||
assert graph.intent_classifier is None
|
||||
graph_ctx = build_graph(sample_registry, mock_llm, checkpointer)
|
||||
assert graph_ctx.intent_classifier is None
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestClassifyIntent:
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_none_without_classifier(self) -> None:
|
||||
graph = MagicMock()
|
||||
graph.intent_classifier = None
|
||||
result = await classify_intent(graph, "hello")
|
||||
mock_registry = MagicMock()
|
||||
mock_registry.list_agents = MagicMock(return_value=())
|
||||
graph_ctx = GraphContext(graph=MagicMock(), registry=mock_registry, intent_classifier=None)
|
||||
result = await graph_ctx.classify_intent("hello")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -80,11 +83,12 @@ class TestClassifyIntent:
|
||||
mock_classifier = AsyncMock()
|
||||
mock_classifier.classify = AsyncMock(return_value=expected)
|
||||
|
||||
graph = MagicMock()
|
||||
graph.intent_classifier = mock_classifier
|
||||
graph.agent_registry = MagicMock()
|
||||
graph.agent_registry.list_agents = MagicMock(return_value=())
|
||||
mock_registry = MagicMock()
|
||||
mock_registry.list_agents = MagicMock(return_value=())
|
||||
graph_ctx = GraphContext(
|
||||
graph=MagicMock(), registry=mock_registry, intent_classifier=mock_classifier,
|
||||
)
|
||||
|
||||
result = await classify_intent(graph, "check order")
|
||||
result = await graph_ctx.classify_intent("check order")
|
||||
assert result is not None
|
||||
assert result.intents[0].agent_name == "order_lookup"
|
||||
|
||||
@@ -13,7 +13,7 @@ class TestMainModule:
|
||||
assert app.title == "Smart Support"
|
||||
|
||||
def test_app_version(self) -> None:
|
||||
assert app.version == "0.5.0"
|
||||
assert app.version == "0.6.0"
|
||||
|
||||
def test_agents_yaml_path_exists(self) -> None:
|
||||
assert AGENTS_YAML.name == "agents.yaml"
|
||||
@@ -39,4 +39,4 @@ class TestMainModule:
|
||||
assert "/api/health" in routes
|
||||
|
||||
def test_app_version_is_0_5_0(self) -> None:
|
||||
assert app.version == "0.5.0"
|
||||
assert app.version == "0.6.0"
|
||||
|
||||
@@ -8,8 +8,10 @@ from unittest.mock import AsyncMock, MagicMock
|
||||
import pytest
|
||||
|
||||
from app.callbacks import TokenUsageCallbackHandler
|
||||
from app.graph_context import GraphContext
|
||||
from app.interrupt_manager import InterruptManager
|
||||
from app.session_manager import SessionManager
|
||||
from app.ws_context import WebSocketContext
|
||||
from app.ws_handler import (
|
||||
_extract_interrupt,
|
||||
_has_interrupt,
|
||||
@@ -25,18 +27,42 @@ def _make_ws() -> AsyncMock:
|
||||
return ws
|
||||
|
||||
|
||||
def _make_graph() -> AsyncMock:
|
||||
def _make_graph() -> MagicMock:
|
||||
graph = AsyncMock()
|
||||
graph.astream = MagicMock(return_value=AsyncIterHelper([]))
|
||||
state = MagicMock()
|
||||
state.tasks = ()
|
||||
graph.aget_state = AsyncMock(return_value=state)
|
||||
# Phase 2: graph needs intent_classifier and agent_registry attrs
|
||||
graph.intent_classifier = None
|
||||
graph.agent_registry = None
|
||||
return graph
|
||||
|
||||
|
||||
def _make_graph_ctx(graph: MagicMock | None = None) -> GraphContext:
|
||||
g = graph or _make_graph()
|
||||
registry = MagicMock()
|
||||
registry.list_agents = MagicMock(return_value=())
|
||||
return GraphContext(graph=g, registry=registry, intent_classifier=None)
|
||||
|
||||
|
||||
def _make_ws_ctx(
|
||||
graph_ctx: GraphContext | None = None,
|
||||
sm: SessionManager | None = None,
|
||||
cb: TokenUsageCallbackHandler | None = None,
|
||||
interrupt_manager: InterruptManager | None = None,
|
||||
analytics_recorder=None,
|
||||
conversation_tracker=None,
|
||||
pool=None,
|
||||
) -> WebSocketContext:
|
||||
return WebSocketContext(
|
||||
graph_ctx=graph_ctx or _make_graph_ctx(),
|
||||
session_manager=sm or SessionManager(),
|
||||
callback_handler=cb or TokenUsageCallbackHandler(),
|
||||
interrupt_manager=interrupt_manager,
|
||||
analytics_recorder=analytics_recorder,
|
||||
conversation_tracker=conversation_tracker,
|
||||
pool=pool,
|
||||
)
|
||||
|
||||
|
||||
class AsyncIterHelper:
|
||||
"""Helper to make a list behave as an async iterator."""
|
||||
|
||||
@@ -57,11 +83,9 @@ class TestDispatchMessage:
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_json(self) -> None:
|
||||
ws = _make_ws()
|
||||
graph = _make_graph()
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
ws_ctx = _make_ws_ctx()
|
||||
|
||||
await dispatch_message(ws, graph, sm, cb, "not json")
|
||||
await dispatch_message(ws, ws_ctx, "not json")
|
||||
ws.send_json.assert_awaited_once()
|
||||
call_data = ws.send_json.call_args[0][0]
|
||||
assert call_data["type"] == "error"
|
||||
@@ -70,12 +94,10 @@ class TestDispatchMessage:
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_thread_id(self) -> None:
|
||||
ws = _make_ws()
|
||||
graph = _make_graph()
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
ws_ctx = _make_ws_ctx()
|
||||
|
||||
msg = json.dumps({"type": "message", "content": "hello"})
|
||||
await dispatch_message(ws, graph, sm, cb, msg)
|
||||
await dispatch_message(ws, ws_ctx, msg)
|
||||
call_data = ws.send_json.call_args[0][0]
|
||||
assert call_data["type"] == "error"
|
||||
assert "thread_id" in call_data["message"]
|
||||
@@ -83,24 +105,20 @@ class TestDispatchMessage:
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_content(self) -> None:
|
||||
ws = _make_ws()
|
||||
graph = _make_graph()
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
ws_ctx = _make_ws_ctx()
|
||||
|
||||
msg = json.dumps({"type": "message", "thread_id": "t1"})
|
||||
await dispatch_message(ws, graph, sm, cb, msg)
|
||||
await dispatch_message(ws, ws_ctx, msg)
|
||||
call_data = ws.send_json.call_args[0][0]
|
||||
assert call_data["type"] == "error"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unknown_message_type(self) -> None:
|
||||
ws = _make_ws()
|
||||
graph = _make_graph()
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
ws_ctx = _make_ws_ctx()
|
||||
|
||||
msg = json.dumps({"type": "unknown", "thread_id": "t1"})
|
||||
await dispatch_message(ws, graph, sm, cb, msg)
|
||||
await dispatch_message(ws, ws_ctx, msg)
|
||||
call_data = ws.send_json.call_args[0][0]
|
||||
assert call_data["type"] == "error"
|
||||
assert "Unknown" in call_data["message"]
|
||||
@@ -108,12 +126,10 @@ class TestDispatchMessage:
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_too_large(self) -> None:
|
||||
ws = _make_ws()
|
||||
graph = _make_graph()
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
ws_ctx = _make_ws_ctx()
|
||||
|
||||
large_msg = "x" * 40_000
|
||||
await dispatch_message(ws, graph, sm, cb, large_msg)
|
||||
await dispatch_message(ws, ws_ctx, large_msg)
|
||||
call_data = ws.send_json.call_args[0][0]
|
||||
assert call_data["type"] == "error"
|
||||
assert "too large" in call_data["message"].lower()
|
||||
@@ -121,12 +137,10 @@ class TestDispatchMessage:
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_thread_id_format(self) -> None:
|
||||
ws = _make_ws()
|
||||
graph = _make_graph()
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
ws_ctx = _make_ws_ctx()
|
||||
|
||||
msg = json.dumps({"type": "message", "thread_id": "../../../etc", "content": "hi"})
|
||||
await dispatch_message(ws, graph, sm, cb, msg)
|
||||
await dispatch_message(ws, ws_ctx, msg)
|
||||
call_data = ws.send_json.call_args[0][0]
|
||||
assert call_data["type"] == "error"
|
||||
assert "thread_id" in call_data["message"].lower()
|
||||
@@ -134,12 +148,10 @@ class TestDispatchMessage:
|
||||
@pytest.mark.asyncio
|
||||
async def test_content_too_long(self) -> None:
|
||||
ws = _make_ws()
|
||||
graph = _make_graph()
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
ws_ctx = _make_ws_ctx()
|
||||
|
||||
msg = json.dumps({"type": "message", "thread_id": "t1", "content": "x" * 10001})
|
||||
await dispatch_message(ws, graph, sm, cb, msg)
|
||||
await dispatch_message(ws, ws_ctx, msg)
|
||||
call_data = ws.send_json.call_args[0][0]
|
||||
assert call_data["type"] == "error"
|
||||
assert "too long" in call_data["message"].lower()
|
||||
@@ -147,14 +159,13 @@ class TestDispatchMessage:
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatch_with_interrupt_manager(self) -> None:
|
||||
ws = _make_ws()
|
||||
graph = _make_graph()
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
im = InterruptManager()
|
||||
ws_ctx = _make_ws_ctx(sm=sm, interrupt_manager=im)
|
||||
|
||||
sm.touch("t1")
|
||||
msg = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"})
|
||||
await dispatch_message(ws, graph, sm, cb, msg, interrupt_manager=im)
|
||||
await dispatch_message(ws, ws_ctx, msg)
|
||||
last_call = ws.send_json.call_args[0][0]
|
||||
assert last_call["type"] == "message_complete"
|
||||
|
||||
@@ -164,14 +175,14 @@ class TestHandleUserMessage:
|
||||
@pytest.mark.asyncio
|
||||
async def test_expired_session(self) -> None:
|
||||
ws = _make_ws()
|
||||
graph = _make_graph()
|
||||
graph_ctx = _make_graph_ctx()
|
||||
sm = SessionManager(session_ttl_seconds=0)
|
||||
cb = TokenUsageCallbackHandler()
|
||||
|
||||
# First call creates the session (TTL=0)
|
||||
await handle_user_message(ws, graph, sm, cb, "t1", "hello")
|
||||
await handle_user_message(ws, graph_ctx, sm, cb, "t1", "hello")
|
||||
# Second call finds it expired
|
||||
await handle_user_message(ws, graph, sm, cb, "t1", "hello again")
|
||||
await handle_user_message(ws, graph_ctx, sm, cb, "t1", "hello again")
|
||||
call_data = ws.send_json.call_args[0][0]
|
||||
assert call_data["type"] == "error"
|
||||
assert "expired" in call_data["message"].lower()
|
||||
@@ -179,12 +190,12 @@ class TestHandleUserMessage:
|
||||
@pytest.mark.asyncio
|
||||
async def test_successful_message(self) -> None:
|
||||
ws = _make_ws()
|
||||
graph = _make_graph()
|
||||
graph_ctx = _make_graph_ctx()
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
|
||||
sm.touch("t1")
|
||||
await handle_user_message(ws, graph, sm, cb, "t1", "hello")
|
||||
await handle_user_message(ws, graph_ctx, sm, cb, "t1", "hello")
|
||||
last_call = ws.send_json.call_args[0][0]
|
||||
assert last_call["type"] == "message_complete"
|
||||
|
||||
@@ -193,13 +204,12 @@ class TestHandleUserMessage:
|
||||
ws = _make_ws()
|
||||
graph = AsyncMock()
|
||||
graph.astream = MagicMock(side_effect=RuntimeError("boom"))
|
||||
graph.intent_classifier = None
|
||||
graph.agent_registry = None
|
||||
graph_ctx = _make_graph_ctx(graph=graph)
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
|
||||
sm.touch("t1")
|
||||
await handle_user_message(ws, graph, sm, cb, "t1", "hello")
|
||||
await handle_user_message(ws, graph_ctx, sm, cb, "t1", "hello")
|
||||
call_data = ws.send_json.call_args[0][0]
|
||||
assert call_data["type"] == "error"
|
||||
|
||||
@@ -207,8 +217,6 @@ class TestHandleUserMessage:
|
||||
async def test_interrupt_registered_with_manager(self) -> None:
|
||||
ws = _make_ws()
|
||||
graph = AsyncMock()
|
||||
graph.intent_classifier = None
|
||||
graph.agent_registry = None
|
||||
graph.astream = MagicMock(return_value=AsyncIterHelper([]))
|
||||
|
||||
# Simulate interrupt in state
|
||||
@@ -220,13 +228,14 @@ class TestHandleUserMessage:
|
||||
state.tasks = (task,)
|
||||
graph.aget_state = AsyncMock(return_value=state)
|
||||
|
||||
graph_ctx = _make_graph_ctx(graph=graph)
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
im = InterruptManager()
|
||||
|
||||
sm.touch("t1")
|
||||
await handle_user_message(
|
||||
ws, graph, sm, cb, "t1", "cancel order 1042", interrupt_manager=im,
|
||||
ws, graph_ctx, sm, cb, "t1", "cancel order 1042", interrupt_manager=im,
|
||||
)
|
||||
|
||||
# Interrupt should be registered
|
||||
@@ -257,16 +266,17 @@ class TestHandleUserMessage:
|
||||
clarification_question="What do you mean?",
|
||||
)
|
||||
)
|
||||
graph.intent_classifier = mock_classifier
|
||||
mock_registry = MagicMock()
|
||||
mock_registry.list_agents = MagicMock(return_value=())
|
||||
graph.agent_registry = mock_registry
|
||||
graph_ctx = GraphContext(
|
||||
graph=graph, registry=mock_registry, intent_classifier=mock_classifier,
|
||||
)
|
||||
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
|
||||
sm.touch("t1")
|
||||
await handle_user_message(ws, graph, sm, cb, "t1", "hmm")
|
||||
await handle_user_message(ws, graph_ctx, sm, cb, "t1", "hmm")
|
||||
|
||||
calls = [c[0][0] for c in ws.send_json.call_args_list]
|
||||
clarification_msgs = [c for c in calls if c.get("type") == "clarification"]
|
||||
@@ -279,13 +289,13 @@ class TestHandleInterruptResponse:
|
||||
@pytest.mark.asyncio
|
||||
async def test_approved_interrupt(self) -> None:
|
||||
ws = _make_ws()
|
||||
graph = _make_graph()
|
||||
graph_ctx = _make_graph_ctx()
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
|
||||
sm.touch("t1")
|
||||
sm.extend_for_interrupt("t1")
|
||||
await handle_interrupt_response(ws, graph, sm, cb, "t1", True)
|
||||
await handle_interrupt_response(ws, graph_ctx, sm, cb, "t1", True)
|
||||
last_call = ws.send_json.call_args[0][0]
|
||||
assert last_call["type"] == "message_complete"
|
||||
|
||||
@@ -294,7 +304,7 @@ class TestHandleInterruptResponse:
|
||||
from unittest.mock import patch
|
||||
|
||||
ws = _make_ws()
|
||||
graph = _make_graph()
|
||||
graph_ctx = _make_graph_ctx()
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
im = InterruptManager(ttl_seconds=5)
|
||||
@@ -307,7 +317,7 @@ class TestHandleInterruptResponse:
|
||||
with patch("app.interrupt_manager.time") as mock_time:
|
||||
mock_time.time.return_value = im._interrupts["t1"].created_at + 10
|
||||
await handle_interrupt_response(
|
||||
ws, graph, sm, cb, "t1", True, interrupt_manager=im
|
||||
ws, graph_ctx, sm, cb, "t1", True, interrupt_manager=im
|
||||
)
|
||||
|
||||
call_data = ws.send_json.call_args[0][0]
|
||||
@@ -317,7 +327,7 @@ class TestHandleInterruptResponse:
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_interrupt_resolves(self) -> None:
|
||||
ws = _make_ws()
|
||||
graph = _make_graph()
|
||||
graph_ctx = _make_graph_ctx()
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
im = InterruptManager(ttl_seconds=1800)
|
||||
@@ -327,7 +337,7 @@ class TestHandleInterruptResponse:
|
||||
im.register("t1", "cancel_order", {})
|
||||
|
||||
await handle_interrupt_response(
|
||||
ws, graph, sm, cb, "t1", True, interrupt_manager=im
|
||||
ws, graph_ctx, sm, cb, "t1", True, interrupt_manager=im
|
||||
)
|
||||
|
||||
# Interrupt should be resolved
|
||||
@@ -374,19 +384,14 @@ class TestDispatchMessageWithTracking:
|
||||
@pytest.mark.asyncio
|
||||
async def test_conversation_tracker_called_on_message(self) -> None:
|
||||
ws = _make_ws()
|
||||
graph = _make_graph()
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
tracker = AsyncMock()
|
||||
pool = MagicMock()
|
||||
ws_ctx = _make_ws_ctx(sm=sm, conversation_tracker=tracker, pool=pool)
|
||||
|
||||
sm.touch("t1")
|
||||
msg = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"})
|
||||
await dispatch_message(
|
||||
ws, graph, sm, cb, msg,
|
||||
conversation_tracker=tracker,
|
||||
pool=pool,
|
||||
)
|
||||
await dispatch_message(ws, ws_ctx, msg)
|
||||
|
||||
tracker.ensure_conversation.assert_awaited_once_with(pool, "t1")
|
||||
tracker.record_turn.assert_awaited_once()
|
||||
@@ -394,53 +399,42 @@ class TestDispatchMessageWithTracking:
|
||||
@pytest.mark.asyncio
|
||||
async def test_analytics_recorder_called_on_message(self) -> None:
|
||||
ws = _make_ws()
|
||||
graph = _make_graph()
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
recorder = AsyncMock()
|
||||
pool = MagicMock()
|
||||
ws_ctx = _make_ws_ctx(sm=sm, analytics_recorder=recorder, pool=pool)
|
||||
|
||||
sm.touch("t1")
|
||||
msg = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"})
|
||||
await dispatch_message(
|
||||
ws, graph, sm, cb, msg,
|
||||
analytics_recorder=recorder,
|
||||
pool=pool,
|
||||
)
|
||||
await dispatch_message(ws, ws_ctx, msg)
|
||||
|
||||
recorder.record.assert_awaited_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tracker_failure_does_not_break_chat(self) -> None:
|
||||
ws = _make_ws()
|
||||
graph = _make_graph()
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
tracker = AsyncMock()
|
||||
tracker.ensure_conversation.side_effect = RuntimeError("DB down")
|
||||
pool = MagicMock()
|
||||
ws_ctx = _make_ws_ctx(sm=sm, conversation_tracker=tracker, pool=pool)
|
||||
|
||||
sm.touch("t1")
|
||||
msg = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"})
|
||||
# Should not raise despite tracker failure
|
||||
await dispatch_message(
|
||||
ws, graph, sm, cb, msg,
|
||||
conversation_tracker=tracker,
|
||||
pool=pool,
|
||||
)
|
||||
await dispatch_message(ws, ws_ctx, msg)
|
||||
last_call = ws.send_json.call_args[0][0]
|
||||
assert last_call["type"] == "message_complete"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_tracker_no_error(self) -> None:
|
||||
ws = _make_ws()
|
||||
graph = _make_graph()
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
ws_ctx = _make_ws_ctx(sm=sm)
|
||||
|
||||
sm.touch("t1")
|
||||
msg = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"})
|
||||
# No tracker or recorder passed -- should work fine
|
||||
await dispatch_message(ws, graph, sm, cb, msg)
|
||||
await dispatch_message(ws, ws_ctx, msg)
|
||||
last_call = ws.send_json.call_args[0][0]
|
||||
assert last_call["type"] == "message_complete"
|
||||
|
||||
Reference in New Issue
Block a user