feat: complete phase 2 -- multi-agent routing, interrupt TTL, escalation, templates
- Intent classification with LLM structured output (single/multi/ambiguous) - Discount agent with apply_discount and generate_coupon tools - Interrupt manager with 30-min TTL auto-expiration and retry prompts - Webhook escalation module with exponential backoff retry (max 3) - Three vertical industry templates (e-commerce, SaaS, fintech) - Template loading in AgentRegistry - Enhanced supervisor prompt with dynamic agent descriptions - 153 tests passing, 90.18% coverage
This commit is contained in:
79
backend/tests/unit/test_discount.py
Normal file
79
backend/tests/unit/test_discount.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""Tests for app.agents.discount module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agents.discount import apply_discount, generate_coupon
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestApplyDiscount:
|
||||
def test_invalid_discount_zero(self) -> None:
|
||||
result = apply_discount.invoke({"order_id": "1042", "discount_percent": 0})
|
||||
assert result["status"] == "error"
|
||||
assert "Invalid" in result["message"]
|
||||
|
||||
def test_invalid_discount_over_100(self) -> None:
|
||||
result = apply_discount.invoke({"order_id": "1042", "discount_percent": 101})
|
||||
assert result["status"] == "error"
|
||||
|
||||
def test_invalid_discount_negative(self) -> None:
|
||||
result = apply_discount.invoke({"order_id": "1042", "discount_percent": -5})
|
||||
assert result["status"] == "error"
|
||||
|
||||
@patch("app.agents.discount.interrupt", return_value=True)
|
||||
def test_approved_discount(self, mock_interrupt) -> None:
|
||||
result = apply_discount.invoke({"order_id": "1042", "discount_percent": 10})
|
||||
assert result["status"] == "applied"
|
||||
assert result["discount_percent"] == 10
|
||||
assert "1042" in result["message"]
|
||||
|
||||
@patch("app.agents.discount.interrupt", return_value=False)
|
||||
def test_rejected_discount(self, mock_interrupt) -> None:
|
||||
result = apply_discount.invoke({"order_id": "1042", "discount_percent": 10})
|
||||
assert result["status"] == "declined"
|
||||
|
||||
@patch("app.agents.discount.interrupt", return_value={"approved": True})
|
||||
def test_approved_via_dict(self, mock_interrupt) -> None:
|
||||
result = apply_discount.invoke({"order_id": "1042", "discount_percent": 50})
|
||||
assert result["status"] == "applied"
|
||||
|
||||
@patch("app.agents.discount.interrupt", return_value={"approved": False})
|
||||
def test_rejected_via_dict(self, mock_interrupt) -> None:
|
||||
result = apply_discount.invoke({"order_id": "1042", "discount_percent": 50})
|
||||
assert result["status"] == "declined"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestGenerateCoupon:
|
||||
def test_valid_coupon(self) -> None:
|
||||
result = generate_coupon.invoke({"discount_percent": 15, "expiry_days": 7})
|
||||
assert result["status"] == "generated"
|
||||
assert result["discount_percent"] == 15
|
||||
assert result["expiry_days"] == 7
|
||||
assert result["coupon_code"].startswith("SAVE15-")
|
||||
|
||||
def test_default_expiry(self) -> None:
|
||||
result = generate_coupon.invoke({"discount_percent": 20})
|
||||
assert result["status"] == "generated"
|
||||
assert result["expiry_days"] == 30
|
||||
|
||||
def test_invalid_discount_zero(self) -> None:
|
||||
result = generate_coupon.invoke({"discount_percent": 0})
|
||||
assert result["status"] == "error"
|
||||
|
||||
def test_invalid_discount_over_100(self) -> None:
|
||||
result = generate_coupon.invoke({"discount_percent": 101})
|
||||
assert result["status"] == "error"
|
||||
|
||||
def test_invalid_expiry(self) -> None:
|
||||
result = generate_coupon.invoke({"discount_percent": 10, "expiry_days": 0})
|
||||
assert result["status"] == "error"
|
||||
|
||||
def test_coupon_codes_unique(self) -> None:
|
||||
r1 = generate_coupon.invoke({"discount_percent": 10})
|
||||
r2 = generate_coupon.invoke({"discount_percent": 10})
|
||||
assert r1["coupon_code"] != r2["coupon_code"]
|
||||
169
backend/tests/unit/test_escalation.py
Normal file
169
backend/tests/unit/test_escalation.py
Normal file
@@ -0,0 +1,169 @@
|
||||
"""Tests for app.escalation module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from app.escalation import (
|
||||
EscalationPayload,
|
||||
EscalationResult,
|
||||
NoOpEscalator,
|
||||
WebhookEscalator,
|
||||
)
|
||||
|
||||
|
||||
def _make_payload(**kwargs) -> EscalationPayload:
|
||||
defaults = {
|
||||
"thread_id": "t1",
|
||||
"reason": "Agent cannot resolve",
|
||||
"conversation_summary": "User asked about refund policy",
|
||||
}
|
||||
defaults.update(kwargs)
|
||||
return EscalationPayload(**defaults)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestEscalationPayload:
|
||||
def test_frozen(self) -> None:
|
||||
payload = _make_payload()
|
||||
with pytest.raises(Exception):
|
||||
payload.thread_id = "t2" # type: ignore[misc]
|
||||
|
||||
def test_default_metadata(self) -> None:
|
||||
payload = _make_payload()
|
||||
assert payload.metadata == {}
|
||||
|
||||
def test_model_dump(self) -> None:
|
||||
payload = _make_payload(metadata={"key": "val"})
|
||||
data = payload.model_dump()
|
||||
assert data["thread_id"] == "t1"
|
||||
assert data["metadata"] == {"key": "val"}
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestEscalationResult:
|
||||
def test_frozen(self) -> None:
|
||||
result = EscalationResult(success=True, status_code=200, attempts=1, error=None)
|
||||
assert result.success
|
||||
assert result.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestWebhookEscalator:
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_url_returns_failure(self) -> None:
|
||||
escalator = WebhookEscalator(url="", max_retries=3)
|
||||
result = await escalator.escalate(_make_payload())
|
||||
assert not result.success
|
||||
assert result.attempts == 0
|
||||
assert "not configured" in result.error
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_successful_post(self) -> None:
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status_code = 200
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post = AsyncMock(return_value=mock_response)
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("app.escalation.httpx.AsyncClient", return_value=mock_client):
|
||||
escalator = WebhookEscalator(url="https://example.com/hook")
|
||||
result = await escalator.escalate(_make_payload())
|
||||
|
||||
assert result.success
|
||||
assert result.status_code == 200
|
||||
assert result.attempts == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retry_on_server_error(self) -> None:
|
||||
fail_response = AsyncMock()
|
||||
fail_response.status_code = 500
|
||||
success_response = AsyncMock()
|
||||
success_response.status_code = 200
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post = AsyncMock(side_effect=[fail_response, fail_response, success_response])
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with (
|
||||
patch("app.escalation.httpx.AsyncClient", return_value=mock_client),
|
||||
patch("app.escalation.asyncio.sleep", new_callable=AsyncMock),
|
||||
):
|
||||
escalator = WebhookEscalator(url="https://example.com/hook", max_retries=3)
|
||||
result = await escalator.escalate(_make_payload())
|
||||
|
||||
assert result.success
|
||||
assert result.attempts == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_all_retries_exhausted(self) -> None:
|
||||
fail_response = AsyncMock()
|
||||
fail_response.status_code = 500
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post = AsyncMock(return_value=fail_response)
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with (
|
||||
patch("app.escalation.httpx.AsyncClient", return_value=mock_client),
|
||||
patch("app.escalation.asyncio.sleep", new_callable=AsyncMock),
|
||||
):
|
||||
escalator = WebhookEscalator(url="https://example.com/hook", max_retries=3)
|
||||
result = await escalator.escalate(_make_payload())
|
||||
|
||||
assert not result.success
|
||||
assert result.attempts == 3
|
||||
assert "500" in result.error
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_timeout_error(self) -> None:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post = AsyncMock(side_effect=httpx.TimeoutException("timeout"))
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with (
|
||||
patch("app.escalation.httpx.AsyncClient", return_value=mock_client),
|
||||
patch("app.escalation.asyncio.sleep", new_callable=AsyncMock),
|
||||
):
|
||||
escalator = WebhookEscalator(url="https://example.com/hook", max_retries=2)
|
||||
result = await escalator.escalate(_make_payload())
|
||||
|
||||
assert not result.success
|
||||
assert "timed out" in result.error
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_error(self) -> None:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post = AsyncMock(
|
||||
side_effect=httpx.RequestError("connection refused")
|
||||
)
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with (
|
||||
patch("app.escalation.httpx.AsyncClient", return_value=mock_client),
|
||||
patch("app.escalation.asyncio.sleep", new_callable=AsyncMock),
|
||||
):
|
||||
escalator = WebhookEscalator(url="https://example.com/hook", max_retries=1)
|
||||
result = await escalator.escalate(_make_payload())
|
||||
|
||||
assert not result.success
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestNoOpEscalator:
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_disabled(self) -> None:
|
||||
escalator = NoOpEscalator()
|
||||
result = await escalator.escalate(_make_payload())
|
||||
assert not result.success
|
||||
assert result.attempts == 0
|
||||
assert "disabled" in result.error.lower()
|
||||
@@ -7,7 +7,8 @@ from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from app.graph import SUPERVISOR_PROMPT, build_agent_nodes, build_graph
|
||||
from app.graph import build_agent_nodes, build_graph, classify_intent
|
||||
from app.intent import ClassificationResult, IntentTarget
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.registry import AgentRegistry
|
||||
@@ -38,7 +39,51 @@ class TestBuildGraph:
|
||||
graph = build_graph(sample_registry, mock_llm, mock_checkpointer)
|
||||
assert graph is not None
|
||||
|
||||
def test_supervisor_prompt_contains_routing_info(self) -> None:
|
||||
assert "order_lookup" in SUPERVISOR_PROMPT
|
||||
assert "order_actions" in SUPERVISOR_PROMPT
|
||||
assert "fallback" in SUPERVISOR_PROMPT
|
||||
def test_graph_has_classifier_attached(self, sample_registry: AgentRegistry) -> None:
|
||||
mock_llm = MagicMock()
|
||||
mock_llm.bind_tools = MagicMock(return_value=mock_llm)
|
||||
mock_llm.with_structured_output = MagicMock(return_value=mock_llm)
|
||||
mock_checkpointer = AsyncMock()
|
||||
mock_classifier = MagicMock()
|
||||
|
||||
graph = build_graph(
|
||||
sample_registry, mock_llm, mock_checkpointer, intent_classifier=mock_classifier
|
||||
)
|
||||
assert graph.intent_classifier is mock_classifier
|
||||
assert graph.agent_registry is sample_registry
|
||||
|
||||
def test_graph_without_classifier(self, sample_registry: AgentRegistry) -> None:
|
||||
mock_llm = MagicMock()
|
||||
mock_llm.bind_tools = MagicMock(return_value=mock_llm)
|
||||
mock_llm.with_structured_output = MagicMock(return_value=mock_llm)
|
||||
mock_checkpointer = AsyncMock()
|
||||
|
||||
graph = build_graph(sample_registry, mock_llm, mock_checkpointer)
|
||||
assert graph.intent_classifier is None
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestClassifyIntent:
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_none_without_classifier(self) -> None:
|
||||
graph = MagicMock()
|
||||
graph.intent_classifier = None
|
||||
result = await classify_intent(graph, "hello")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calls_classifier(self) -> None:
|
||||
expected = ClassificationResult(
|
||||
intents=(IntentTarget(agent_name="order_lookup", confidence=0.9, reasoning="test"),),
|
||||
)
|
||||
mock_classifier = AsyncMock()
|
||||
mock_classifier.classify = AsyncMock(return_value=expected)
|
||||
|
||||
graph = MagicMock()
|
||||
graph.intent_classifier = mock_classifier
|
||||
graph.agent_registry = MagicMock()
|
||||
graph.agent_registry.list_agents = MagicMock(return_value=())
|
||||
|
||||
result = await classify_intent(graph, "check order")
|
||||
assert result is not None
|
||||
assert result.intents[0].agent_name == "order_lookup"
|
||||
|
||||
175
backend/tests/unit/test_intent.py
Normal file
175
backend/tests/unit/test_intent.py
Normal file
@@ -0,0 +1,175 @@
|
||||
"""Tests for app.intent module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from app.intent import (
|
||||
AMBIGUITY_THRESHOLD,
|
||||
ClassificationResult,
|
||||
IntentTarget,
|
||||
LLMIntentClassifier,
|
||||
_build_agent_list,
|
||||
)
|
||||
from app.registry import AgentConfig
|
||||
|
||||
|
||||
def _make_agent(name: str, desc: str = "test", permission: str = "read") -> AgentConfig:
|
||||
return AgentConfig(name=name, description=desc, permission=permission, tools=["fallback_respond"])
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestIntentModels:
|
||||
def test_intent_target_frozen(self) -> None:
|
||||
target = IntentTarget(agent_name="order_lookup", confidence=0.9, reasoning="order query")
|
||||
with pytest.raises(Exception):
|
||||
target.agent_name = "other" # type: ignore[misc]
|
||||
|
||||
def test_classification_result_frozen(self) -> None:
|
||||
result = ClassificationResult(
|
||||
intents=(IntentTarget(agent_name="a", confidence=0.9, reasoning="r"),),
|
||||
)
|
||||
assert not result.is_ambiguous
|
||||
assert result.clarification_question is None
|
||||
|
||||
def test_classification_result_ambiguous(self) -> None:
|
||||
result = ClassificationResult(
|
||||
intents=(),
|
||||
is_ambiguous=True,
|
||||
clarification_question="What do you mean?",
|
||||
)
|
||||
assert result.is_ambiguous
|
||||
|
||||
def test_multi_intent(self) -> None:
|
||||
result = ClassificationResult(
|
||||
intents=(
|
||||
IntentTarget(agent_name="order_actions", confidence=0.85, reasoning="cancel"),
|
||||
IntentTarget(agent_name="discount", confidence=0.8, reasoning="discount"),
|
||||
),
|
||||
)
|
||||
assert len(result.intents) == 2
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestBuildAgentList:
|
||||
def test_formats_agents(self) -> None:
|
||||
agents = (
|
||||
_make_agent("order_lookup", "Looks up orders", "read"),
|
||||
_make_agent("order_actions", "Modifies orders", "write"),
|
||||
)
|
||||
text = _build_agent_list(agents)
|
||||
assert "order_lookup" in text
|
||||
assert "order_actions" in text
|
||||
assert "read" in text
|
||||
assert "write" in text
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestLLMIntentClassifier:
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_intent_classification(self) -> None:
|
||||
expected = ClassificationResult(
|
||||
intents=(IntentTarget(agent_name="order_lookup", confidence=0.95, reasoning="query"),),
|
||||
)
|
||||
mock_structured = MagicMock()
|
||||
mock_structured.ainvoke = AsyncMock(return_value=expected)
|
||||
mock_llm = MagicMock()
|
||||
mock_llm.with_structured_output = MagicMock(return_value=mock_structured)
|
||||
|
||||
classifier = LLMIntentClassifier(mock_llm)
|
||||
agents = (_make_agent("order_lookup"), _make_agent("fallback"))
|
||||
|
||||
result = await classifier.classify("What is order 1042 status?", agents)
|
||||
assert len(result.intents) == 1
|
||||
assert result.intents[0].agent_name == "order_lookup"
|
||||
assert not result.is_ambiguous
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multi_intent_classification(self) -> None:
|
||||
expected = ClassificationResult(
|
||||
intents=(
|
||||
IntentTarget(agent_name="order_actions", confidence=0.9, reasoning="cancel"),
|
||||
IntentTarget(agent_name="discount", confidence=0.85, reasoning="discount"),
|
||||
),
|
||||
)
|
||||
mock_structured = MagicMock()
|
||||
mock_structured.ainvoke = AsyncMock(return_value=expected)
|
||||
mock_llm = MagicMock()
|
||||
mock_llm.with_structured_output = MagicMock(return_value=mock_structured)
|
||||
|
||||
classifier = LLMIntentClassifier(mock_llm)
|
||||
agents = (_make_agent("order_actions"), _make_agent("discount"), _make_agent("fallback"))
|
||||
|
||||
result = await classifier.classify("Cancel order 1042 and give me 10% off", agents)
|
||||
assert len(result.intents) == 2
|
||||
assert not result.is_ambiguous
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ambiguous_classification(self) -> None:
|
||||
expected = ClassificationResult(
|
||||
intents=(IntentTarget(agent_name="fallback", confidence=0.3, reasoning="unclear"),),
|
||||
is_ambiguous=False,
|
||||
)
|
||||
mock_structured = MagicMock()
|
||||
mock_structured.ainvoke = AsyncMock(return_value=expected)
|
||||
mock_llm = MagicMock()
|
||||
mock_llm.with_structured_output = MagicMock(return_value=mock_structured)
|
||||
|
||||
classifier = LLMIntentClassifier(mock_llm)
|
||||
agents = (_make_agent("order_lookup"), _make_agent("fallback"))
|
||||
|
||||
result = await classifier.classify("hmm", agents)
|
||||
# Low confidence triggers ambiguity
|
||||
assert result.is_ambiguous
|
||||
assert result.clarification_question is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_error_returns_ambiguous(self) -> None:
|
||||
mock_structured = MagicMock()
|
||||
mock_structured.ainvoke = AsyncMock(side_effect=RuntimeError("LLM error"))
|
||||
mock_llm = MagicMock()
|
||||
mock_llm.with_structured_output = MagicMock(return_value=mock_structured)
|
||||
|
||||
classifier = LLMIntentClassifier(mock_llm)
|
||||
agents = (_make_agent("fallback"),)
|
||||
|
||||
result = await classifier.classify("test", agents)
|
||||
assert result.is_ambiguous
|
||||
assert result.clarification_question is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_result_type_returns_ambiguous(self) -> None:
|
||||
mock_structured = MagicMock()
|
||||
mock_structured.ainvoke = AsyncMock(return_value="not a ClassificationResult")
|
||||
mock_llm = MagicMock()
|
||||
mock_llm.with_structured_output = MagicMock(return_value=mock_structured)
|
||||
|
||||
classifier = LLMIntentClassifier(mock_llm)
|
||||
agents = (_make_agent("fallback"),)
|
||||
|
||||
result = await classifier.classify("test", agents)
|
||||
assert result.is_ambiguous
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_high_confidence_not_ambiguous(self) -> None:
|
||||
expected = ClassificationResult(
|
||||
intents=(
|
||||
IntentTarget(
|
||||
agent_name="order_lookup",
|
||||
confidence=AMBIGUITY_THRESHOLD + 0.1,
|
||||
reasoning="clear",
|
||||
),
|
||||
),
|
||||
)
|
||||
mock_structured = MagicMock()
|
||||
mock_structured.ainvoke = AsyncMock(return_value=expected)
|
||||
mock_llm = MagicMock()
|
||||
mock_llm.with_structured_output = MagicMock(return_value=mock_structured)
|
||||
|
||||
classifier = LLMIntentClassifier(mock_llm)
|
||||
agents = (_make_agent("order_lookup"),)
|
||||
|
||||
result = await classifier.classify("order status 1042", agents)
|
||||
assert not result.is_ambiguous
|
||||
132
backend/tests/unit/test_interrupt_manager.py
Normal file
132
backend/tests/unit/test_interrupt_manager.py
Normal file
@@ -0,0 +1,132 @@
|
||||
"""Tests for app.interrupt_manager module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.interrupt_manager import InterruptManager
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestInterruptManagerRegister:
|
||||
def test_register_creates_record(self) -> None:
|
||||
mgr = InterruptManager(ttl_seconds=1800)
|
||||
record = mgr.register("t1", "cancel_order", {"order_id": "1042"})
|
||||
assert record.thread_id == "t1"
|
||||
assert record.action == "cancel_order"
|
||||
assert record.ttl_seconds == 1800
|
||||
assert record.interrupt_id
|
||||
|
||||
def test_register_overwrites_previous(self) -> None:
|
||||
mgr = InterruptManager()
|
||||
r1 = mgr.register("t1", "cancel_order", {})
|
||||
r2 = mgr.register("t1", "apply_discount", {})
|
||||
assert r1.interrupt_id != r2.interrupt_id
|
||||
status = mgr.check_status("t1")
|
||||
assert status is not None
|
||||
assert status.record.action == "apply_discount"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestInterruptManagerCheckStatus:
|
||||
def test_no_interrupt_returns_none(self) -> None:
|
||||
mgr = InterruptManager()
|
||||
assert mgr.check_status("t1") is None
|
||||
|
||||
def test_fresh_interrupt_not_expired(self) -> None:
|
||||
mgr = InterruptManager(ttl_seconds=1800)
|
||||
mgr.register("t1", "cancel_order", {})
|
||||
status = mgr.check_status("t1")
|
||||
assert status is not None
|
||||
assert not status.is_expired
|
||||
assert status.remaining_seconds > 0
|
||||
|
||||
def test_expired_interrupt(self) -> None:
|
||||
mgr = InterruptManager(ttl_seconds=10)
|
||||
mgr.register("t1", "cancel_order", {})
|
||||
# Move time forward
|
||||
with patch("app.interrupt_manager.time") as mock_time:
|
||||
mock_time.time.return_value = mgr._interrupts["t1"].created_at + 11
|
||||
status = mgr.check_status("t1")
|
||||
assert status is not None
|
||||
assert status.is_expired
|
||||
assert status.remaining_seconds == 0.0
|
||||
|
||||
def test_boundary_not_expired(self) -> None:
|
||||
mgr = InterruptManager(ttl_seconds=10)
|
||||
mgr.register("t1", "cancel_order", {})
|
||||
with patch("app.interrupt_manager.time") as mock_time:
|
||||
mock_time.time.return_value = mgr._interrupts["t1"].created_at + 9
|
||||
status = mgr.check_status("t1")
|
||||
assert status is not None
|
||||
assert not status.is_expired
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestInterruptManagerResolve:
|
||||
def test_resolve_removes_record(self) -> None:
|
||||
mgr = InterruptManager()
|
||||
mgr.register("t1", "cancel_order", {})
|
||||
mgr.resolve("t1")
|
||||
assert mgr.check_status("t1") is None
|
||||
|
||||
def test_resolve_nonexistent_is_safe(self) -> None:
|
||||
mgr = InterruptManager()
|
||||
mgr.resolve("nonexistent") # Should not raise
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestInterruptManagerCleanup:
|
||||
def test_cleanup_removes_expired(self) -> None:
|
||||
mgr = InterruptManager(ttl_seconds=5)
|
||||
mgr.register("t1", "cancel_order", {})
|
||||
mgr.register("t2", "apply_discount", {})
|
||||
|
||||
with patch("app.interrupt_manager.time") as mock_time:
|
||||
mock_time.time.return_value = mgr._interrupts["t1"].created_at + 6
|
||||
expired = mgr.cleanup_expired()
|
||||
assert len(expired) == 2
|
||||
assert mgr.check_status("t1") is None
|
||||
assert mgr.check_status("t2") is None
|
||||
|
||||
def test_cleanup_keeps_active(self) -> None:
|
||||
mgr = InterruptManager(ttl_seconds=100)
|
||||
mgr.register("t1", "cancel_order", {})
|
||||
|
||||
expired = mgr.cleanup_expired()
|
||||
assert len(expired) == 0
|
||||
assert mgr.check_status("t1") is not None
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestInterruptManagerRetryPrompt:
|
||||
def test_generates_correct_prompt(self) -> None:
|
||||
mgr = InterruptManager(ttl_seconds=1800)
|
||||
record = mgr.register("t1", "cancel_order", {"order_id": "1042"})
|
||||
prompt = mgr.generate_retry_prompt(record)
|
||||
assert prompt["type"] == "interrupt_expired"
|
||||
assert prompt["thread_id"] == "t1"
|
||||
assert prompt["action"] == "cancel_order"
|
||||
assert "30 minutes" in prompt["message"]
|
||||
assert "cancel_order" in prompt["message"]
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestInterruptManagerHasPending:
|
||||
def test_no_interrupt(self) -> None:
|
||||
mgr = InterruptManager()
|
||||
assert not mgr.has_pending("t1")
|
||||
|
||||
def test_has_active_interrupt(self) -> None:
|
||||
mgr = InterruptManager(ttl_seconds=1800)
|
||||
mgr.register("t1", "cancel_order", {})
|
||||
assert mgr.has_pending("t1")
|
||||
|
||||
def test_expired_interrupt_not_pending(self) -> None:
|
||||
mgr = InterruptManager(ttl_seconds=5)
|
||||
mgr.register("t1", "cancel_order", {})
|
||||
with patch("app.interrupt_manager.time") as mock_time:
|
||||
mock_time.time.return_value = mgr._interrupts["t1"].created_at + 6
|
||||
assert not mgr.has_pending("t1")
|
||||
@@ -13,7 +13,7 @@ class TestMainModule:
|
||||
assert app.title == "Smart Support"
|
||||
|
||||
def test_app_version(self) -> None:
|
||||
assert app.version == "0.1.0"
|
||||
assert app.version == "0.2.0"
|
||||
|
||||
def test_agents_yaml_path_exists(self) -> None:
|
||||
assert AGENTS_YAML.name == "agents.yaml"
|
||||
|
||||
70
backend/tests/unit/test_templates.py
Normal file
70
backend/tests/unit/test_templates.py
Normal file
@@ -0,0 +1,70 @@
|
||||
"""Tests for template loading in app.registry."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from app.registry import AgentRegistry
|
||||
|
||||
TEMPLATES_DIR = Path(__file__).parent.parent.parent / "templates"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestListTemplates:
|
||||
def test_lists_all_templates(self) -> None:
|
||||
templates = AgentRegistry.list_templates(TEMPLATES_DIR)
|
||||
assert "e-commerce" in templates
|
||||
assert "saas" in templates
|
||||
assert "fintech" in templates
|
||||
|
||||
def test_returns_sorted(self) -> None:
|
||||
templates = AgentRegistry.list_templates(TEMPLATES_DIR)
|
||||
assert templates == tuple(sorted(templates))
|
||||
|
||||
def test_empty_dir_returns_empty(self, tmp_path: Path) -> None:
|
||||
templates = AgentRegistry.list_templates(tmp_path)
|
||||
assert templates == ()
|
||||
|
||||
def test_nonexistent_dir_returns_empty(self) -> None:
|
||||
templates = AgentRegistry.list_templates("/nonexistent/path")
|
||||
assert templates == ()
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestLoadTemplate:
|
||||
def test_load_ecommerce(self) -> None:
|
||||
registry = AgentRegistry.load_template("e-commerce", TEMPLATES_DIR)
|
||||
assert len(registry) == 4
|
||||
agents = registry.list_agents()
|
||||
names = {a.name for a in agents}
|
||||
assert "order_lookup" in names
|
||||
assert "discount" in names
|
||||
assert "fallback" in names
|
||||
|
||||
def test_load_saas(self) -> None:
|
||||
registry = AgentRegistry.load_template("saas", TEMPLATES_DIR)
|
||||
assert len(registry) == 3
|
||||
agents = registry.list_agents()
|
||||
names = {a.name for a in agents}
|
||||
assert "account_lookup" in names
|
||||
assert "subscription_management" in names
|
||||
|
||||
def test_load_fintech(self) -> None:
|
||||
registry = AgentRegistry.load_template("fintech", TEMPLATES_DIR)
|
||||
assert len(registry) == 3
|
||||
agents = registry.list_agents()
|
||||
names = {a.name for a in agents}
|
||||
assert "transaction_lookup" in names
|
||||
assert "dispute_handler" in names
|
||||
|
||||
def test_nonexistent_template_raises(self) -> None:
|
||||
with pytest.raises(FileNotFoundError, match="not found"):
|
||||
AgentRegistry.load_template("nonexistent", TEMPLATES_DIR)
|
||||
|
||||
def test_error_message_lists_available(self) -> None:
|
||||
try:
|
||||
AgentRegistry.load_template("nonexistent", TEMPLATES_DIR)
|
||||
except FileNotFoundError as exc:
|
||||
assert "e-commerce" in str(exc)
|
||||
@@ -8,6 +8,7 @@ from unittest.mock import AsyncMock, MagicMock
|
||||
import pytest
|
||||
|
||||
from app.callbacks import TokenUsageCallbackHandler
|
||||
from app.interrupt_manager import InterruptManager
|
||||
from app.session_manager import SessionManager
|
||||
from app.ws_handler import (
|
||||
_extract_interrupt,
|
||||
@@ -30,6 +31,9 @@ def _make_graph() -> AsyncMock:
|
||||
state = MagicMock()
|
||||
state.tasks = ()
|
||||
graph.aget_state = AsyncMock(return_value=state)
|
||||
# Phase 2: graph needs intent_classifier and agent_registry attrs
|
||||
graph.intent_classifier = None
|
||||
graph.agent_registry = None
|
||||
return graph
|
||||
|
||||
|
||||
@@ -100,8 +104,6 @@ class TestDispatchMessage:
|
||||
call_data = ws.send_json.call_args[0][0]
|
||||
assert call_data["type"] == "error"
|
||||
assert "Unknown" in call_data["message"]
|
||||
# Verify raw input is NOT reflected back
|
||||
assert "unknown" not in call_data["message"].lower().replace("unknown message type", "")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_too_large(self) -> None:
|
||||
@@ -142,6 +144,20 @@ class TestDispatchMessage:
|
||||
assert call_data["type"] == "error"
|
||||
assert "too long" in call_data["message"].lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatch_with_interrupt_manager(self) -> None:
|
||||
ws = _make_ws()
|
||||
graph = _make_graph()
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
im = InterruptManager()
|
||||
|
||||
sm.touch("t1")
|
||||
msg = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"})
|
||||
await dispatch_message(ws, graph, sm, cb, msg, interrupt_manager=im)
|
||||
last_call = ws.send_json.call_args[0][0]
|
||||
assert last_call["type"] == "message_complete"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestHandleUserMessage:
|
||||
@@ -166,7 +182,6 @@ class TestHandleUserMessage:
|
||||
|
||||
sm.touch("t1")
|
||||
await handle_user_message(ws, graph, sm, cb, "t1", "hello")
|
||||
# Should end with message_complete
|
||||
last_call = ws.send_json.call_args[0][0]
|
||||
assert last_call["type"] == "message_complete"
|
||||
|
||||
@@ -175,6 +190,8 @@ class TestHandleUserMessage:
|
||||
ws = _make_ws()
|
||||
graph = AsyncMock()
|
||||
graph.astream = MagicMock(side_effect=RuntimeError("boom"))
|
||||
graph.intent_classifier = None
|
||||
graph.agent_registry = None
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
|
||||
@@ -183,6 +200,74 @@ class TestHandleUserMessage:
|
||||
call_data = ws.send_json.call_args[0][0]
|
||||
assert call_data["type"] == "error"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_interrupt_registered_with_manager(self) -> None:
|
||||
ws = _make_ws()
|
||||
graph = AsyncMock()
|
||||
graph.intent_classifier = None
|
||||
graph.agent_registry = None
|
||||
graph.astream = MagicMock(return_value=AsyncIterHelper([]))
|
||||
|
||||
# Simulate interrupt in state
|
||||
interrupt_obj = MagicMock()
|
||||
interrupt_obj.value = {"action": "cancel_order", "order_id": "1042"}
|
||||
task = MagicMock()
|
||||
task.interrupts = (interrupt_obj,)
|
||||
state = MagicMock()
|
||||
state.tasks = (task,)
|
||||
graph.aget_state = AsyncMock(return_value=state)
|
||||
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
im = InterruptManager()
|
||||
|
||||
sm.touch("t1")
|
||||
await handle_user_message(ws, graph, sm, cb, "t1", "cancel order 1042", interrupt_manager=im)
|
||||
|
||||
# Interrupt should be registered
|
||||
assert im.has_pending("t1")
|
||||
|
||||
# Should have sent interrupt message
|
||||
calls = [c[0][0] for c in ws.send_json.call_args_list]
|
||||
interrupt_msgs = [c for c in calls if c.get("type") == "interrupt"]
|
||||
assert len(interrupt_msgs) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ambiguous_intent_sends_clarification(self) -> None:
|
||||
from app.intent import ClassificationResult
|
||||
|
||||
ws = _make_ws()
|
||||
graph = AsyncMock()
|
||||
graph.astream = MagicMock(return_value=AsyncIterHelper([]))
|
||||
state = MagicMock()
|
||||
state.tasks = ()
|
||||
graph.aget_state = AsyncMock(return_value=state)
|
||||
|
||||
# Set up intent classifier that returns ambiguous
|
||||
mock_classifier = AsyncMock()
|
||||
mock_classifier.classify = AsyncMock(
|
||||
return_value=ClassificationResult(
|
||||
intents=(),
|
||||
is_ambiguous=True,
|
||||
clarification_question="What do you mean?",
|
||||
)
|
||||
)
|
||||
graph.intent_classifier = mock_classifier
|
||||
mock_registry = MagicMock()
|
||||
mock_registry.list_agents = MagicMock(return_value=())
|
||||
graph.agent_registry = mock_registry
|
||||
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
|
||||
sm.touch("t1")
|
||||
await handle_user_message(ws, graph, sm, cb, "t1", "hmm")
|
||||
|
||||
calls = [c[0][0] for c in ws.send_json.call_args_list]
|
||||
clarification_msgs = [c for c in calls if c.get("type") == "clarification"]
|
||||
assert len(clarification_msgs) == 1
|
||||
assert clarification_msgs[0]["message"] == "What do you mean?"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestHandleInterruptResponse:
|
||||
@@ -199,6 +284,52 @@ class TestHandleInterruptResponse:
|
||||
last_call = ws.send_json.call_args[0][0]
|
||||
assert last_call["type"] == "message_complete"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expired_interrupt_sends_retry_prompt(self) -> None:
|
||||
from unittest.mock import patch
|
||||
|
||||
ws = _make_ws()
|
||||
graph = _make_graph()
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
im = InterruptManager(ttl_seconds=5)
|
||||
|
||||
sm.touch("t1")
|
||||
sm.extend_for_interrupt("t1")
|
||||
im.register("t1", "cancel_order", {"order_id": "1042"})
|
||||
|
||||
# Expire the interrupt
|
||||
with patch("app.interrupt_manager.time") as mock_time:
|
||||
mock_time.time.return_value = im._interrupts["t1"].created_at + 10
|
||||
await handle_interrupt_response(
|
||||
ws, graph, sm, cb, "t1", True, interrupt_manager=im
|
||||
)
|
||||
|
||||
call_data = ws.send_json.call_args[0][0]
|
||||
assert call_data["type"] == "interrupt_expired"
|
||||
assert "cancel_order" in call_data["message"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_interrupt_resolves(self) -> None:
|
||||
ws = _make_ws()
|
||||
graph = _make_graph()
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
im = InterruptManager(ttl_seconds=1800)
|
||||
|
||||
sm.touch("t1")
|
||||
sm.extend_for_interrupt("t1")
|
||||
im.register("t1", "cancel_order", {})
|
||||
|
||||
await handle_interrupt_response(
|
||||
ws, graph, sm, cb, "t1", True, interrupt_manager=im
|
||||
)
|
||||
|
||||
# Interrupt should be resolved
|
||||
assert not im.has_pending("t1")
|
||||
last_call = ws.send_json.call_args[0][0]
|
||||
assert last_call["type"] == "message_complete"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestInterruptHelpers:
|
||||
|
||||
Reference in New Issue
Block a user