feat: complete phase 1 -- core framework with chat loop, agents, and React UI
Backend: - FastAPI WebSocket /ws endpoint with streaming via LangGraph astream - LangGraph Supervisor connecting 3 mock agents (order_lookup, order_actions, fallback) - YAML Agent Registry with Pydantic validation and immutable configs - PostgresSaver checkpoint persistence via langgraph-checkpoint-postgres - Session TTL with 30-min sliding window and interrupt extension - LLM provider abstraction (Anthropic/OpenAI/Google) - Token usage + cost tracking callback handler - Input validation: message size cap, thread_id format, content length - Security: no hardcoded defaults, startup API key validation, no input reflection Frontend: - React 19 + TypeScript + Vite chat UI - WebSocket hook with reconnect + exponential backoff - Streaming token display with agent attribution - Interrupt approval/reject UI for write operations - Collapsible tool call viewer Testing: - 87 unit tests, 87% coverage (exceeds 80% requirement) - Ruff lint + format clean Infrastructure: - Docker Compose (PostgreSQL 16 + backend) - pyproject.toml with full dependency management
This commit is contained in:
0
backend/tests/__init__.py
Normal file
0
backend/tests/__init__.py
Normal file
69
backend/tests/conftest.py
Normal file
69
backend/tests/conftest.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""Shared test fixtures and marker registration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from app.config import Settings
|
||||
from app.registry import AgentRegistry
|
||||
from app.session_manager import SessionManager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_settings() -> Settings:
|
||||
return Settings(
|
||||
database_url="postgresql://test:test@localhost:5432/test_db",
|
||||
llm_provider="anthropic",
|
||||
llm_model="claude-sonnet-4-6",
|
||||
anthropic_api_key="test-key",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_yaml_path(tmp_path: Path) -> Path:
|
||||
data = {
|
||||
"agents": [
|
||||
{
|
||||
"name": "test_reader",
|
||||
"description": "A test read agent",
|
||||
"permission": "read",
|
||||
"tools": ["get_order_status"],
|
||||
},
|
||||
{
|
||||
"name": "test_writer",
|
||||
"description": "A test write agent",
|
||||
"permission": "write",
|
||||
"personality": {
|
||||
"tone": "formal",
|
||||
"greeting": "Greetings.",
|
||||
"escalation_message": "Escalating now.",
|
||||
},
|
||||
"tools": ["cancel_order"],
|
||||
},
|
||||
{
|
||||
"name": "test_fallback",
|
||||
"description": "A fallback agent",
|
||||
"permission": "read",
|
||||
"tools": ["fallback_respond"],
|
||||
},
|
||||
]
|
||||
}
|
||||
path = tmp_path / "test_agents.yaml"
|
||||
path.write_text(yaml.dump(data), encoding="utf-8")
|
||||
return path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_registry(sample_yaml_path: Path) -> AgentRegistry:
|
||||
return AgentRegistry.load(sample_yaml_path)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session_manager() -> SessionManager:
|
||||
return SessionManager(session_ttl_seconds=60)
|
||||
0
backend/tests/e2e/__init__.py
Normal file
0
backend/tests/e2e/__init__.py
Normal file
0
backend/tests/integration/__init__.py
Normal file
0
backend/tests/integration/__init__.py
Normal file
0
backend/tests/unit/__init__.py
Normal file
0
backend/tests/unit/__init__.py
Normal file
82
backend/tests/unit/test_agents.py
Normal file
82
backend/tests/unit/test_agents.py
Normal file
@@ -0,0 +1,82 @@
|
||||
"""Tests for agent tools."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agents.fallback import fallback_respond
|
||||
from app.agents.order_lookup import get_order_status, get_tracking_info
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestOrderLookup:
|
||||
def test_get_order_status_existing(self) -> None:
|
||||
result = get_order_status.invoke({"order_id": "1042"})
|
||||
assert result["order_id"] == "1042"
|
||||
assert result["status"] == "shipped"
|
||||
|
||||
def test_get_order_status_not_found(self) -> None:
|
||||
result = get_order_status.invoke({"order_id": "9999"})
|
||||
assert "error" in result
|
||||
assert "9999" in result["error"]
|
||||
|
||||
def test_get_tracking_info_existing(self) -> None:
|
||||
result = get_tracking_info.invoke({"order_id": "1042"})
|
||||
assert result["carrier"] == "FedEx"
|
||||
assert result["tracking_number"] == "FX-9876543210"
|
||||
|
||||
def test_get_tracking_info_not_found(self) -> None:
|
||||
result = get_tracking_info.invoke({"order_id": "1043"})
|
||||
assert "error" in result
|
||||
|
||||
def test_all_mock_orders_have_required_fields(self) -> None:
|
||||
from app.agents.order_lookup import MOCK_ORDERS
|
||||
|
||||
for oid, order in MOCK_ORDERS.items():
|
||||
assert "order_id" in order
|
||||
assert "status" in order
|
||||
assert order["order_id"] == oid
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestFallback:
|
||||
def test_fallback_respond_returns_help(self) -> None:
|
||||
result = fallback_respond.invoke({"query": "random question"})
|
||||
assert "order" in result.lower()
|
||||
assert "help" in result.lower() or "can do" in result.lower()
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestOrderActions:
|
||||
def test_cancel_order_approved(self) -> None:
|
||||
with patch("app.agents.order_actions.interrupt", return_value=True):
|
||||
from app.agents.order_actions import cancel_order
|
||||
|
||||
result = cancel_order.invoke({"order_id": "1042"})
|
||||
assert result["status"] == "cancelled"
|
||||
assert "1042" in result["message"]
|
||||
|
||||
def test_cancel_order_rejected(self) -> None:
|
||||
with patch("app.agents.order_actions.interrupt", return_value=False):
|
||||
from app.agents.order_actions import cancel_order
|
||||
|
||||
result = cancel_order.invoke({"order_id": "1042"})
|
||||
assert result["status"] == "kept"
|
||||
assert "declined" in result["message"]
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestToolBridge:
|
||||
def test_get_tools_by_names(self) -> None:
|
||||
from app.agents import get_tools_by_names
|
||||
|
||||
tools = get_tools_by_names(["get_order_status", "cancel_order"])
|
||||
assert len(tools) == 2
|
||||
|
||||
def test_unknown_tool_raises(self) -> None:
|
||||
from app.agents import get_tools_by_names
|
||||
|
||||
with pytest.raises(ValueError, match="Unknown tool"):
|
||||
get_tools_by_names(["nonexistent_tool"])
|
||||
102
backend/tests/unit/test_callbacks.py
Normal file
102
backend/tests/unit/test_callbacks.py
Normal file
@@ -0,0 +1,102 @@
|
||||
"""Tests for app.callbacks module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.callbacks import TokenUsageCallbackHandler
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestTokenUsageCallbackHandler:
|
||||
def test_initial_state(self) -> None:
|
||||
handler = TokenUsageCallbackHandler(model_name="claude-sonnet-4-6")
|
||||
usage = handler.get_usage()
|
||||
assert usage.prompt_tokens == 0
|
||||
assert usage.completion_tokens == 0
|
||||
assert usage.total_tokens == 0
|
||||
assert usage.total_cost_usd == 0.0
|
||||
|
||||
def test_accumulates_tokens(self) -> None:
|
||||
handler = TokenUsageCallbackHandler(model_name="claude-sonnet-4-6")
|
||||
|
||||
class FakeResult:
|
||||
llm_output = {"token_usage": {"prompt_tokens": 100, "completion_tokens": 50}}
|
||||
|
||||
handler.on_llm_end(FakeResult())
|
||||
usage = handler.get_usage()
|
||||
assert usage.prompt_tokens == 100
|
||||
assert usage.completion_tokens == 50
|
||||
assert usage.total_tokens == 150
|
||||
|
||||
def test_accumulates_across_calls(self) -> None:
|
||||
handler = TokenUsageCallbackHandler(model_name="claude-sonnet-4-6")
|
||||
|
||||
class FakeResult:
|
||||
llm_output = {"token_usage": {"prompt_tokens": 100, "completion_tokens": 50}}
|
||||
|
||||
handler.on_llm_end(FakeResult())
|
||||
handler.on_llm_end(FakeResult())
|
||||
usage = handler.get_usage()
|
||||
assert usage.prompt_tokens == 200
|
||||
assert usage.completion_tokens == 100
|
||||
assert usage.total_tokens == 300
|
||||
|
||||
def test_cost_calculation(self) -> None:
|
||||
handler = TokenUsageCallbackHandler(model_name="claude-sonnet-4-6")
|
||||
|
||||
class FakeResult:
|
||||
llm_output = {"token_usage": {"prompt_tokens": 1000, "completion_tokens": 1000}}
|
||||
|
||||
handler.on_llm_end(FakeResult())
|
||||
usage = handler.get_usage()
|
||||
# claude-sonnet-4-6: prompt $0.003/1K, completion $0.015/1K
|
||||
expected_cost = 1000 * 0.003 / 1000 + 1000 * 0.015 / 1000
|
||||
assert usage.total_cost_usd == pytest.approx(expected_cost)
|
||||
|
||||
def test_reset(self) -> None:
|
||||
handler = TokenUsageCallbackHandler(model_name="claude-sonnet-4-6")
|
||||
|
||||
class FakeResult:
|
||||
llm_output = {"token_usage": {"prompt_tokens": 100, "completion_tokens": 50}}
|
||||
|
||||
handler.on_llm_end(FakeResult())
|
||||
handler.reset()
|
||||
usage = handler.get_usage()
|
||||
assert usage.total_tokens == 0
|
||||
|
||||
def test_usage_is_immutable(self) -> None:
|
||||
handler = TokenUsageCallbackHandler()
|
||||
usage = handler.get_usage()
|
||||
with pytest.raises(Exception):
|
||||
usage.prompt_tokens = 999
|
||||
|
||||
def test_unknown_model_uses_default_cost(self) -> None:
|
||||
handler = TokenUsageCallbackHandler(model_name="unknown-model")
|
||||
|
||||
class FakeResult:
|
||||
llm_output = {"token_usage": {"prompt_tokens": 1000, "completion_tokens": 1000}}
|
||||
|
||||
handler.on_llm_end(FakeResult())
|
||||
usage = handler.get_usage()
|
||||
assert usage.total_cost_usd > 0
|
||||
|
||||
def test_handles_missing_token_usage(self) -> None:
|
||||
handler = TokenUsageCallbackHandler()
|
||||
|
||||
class FakeResult:
|
||||
llm_output = {}
|
||||
|
||||
handler.on_llm_end(FakeResult())
|
||||
usage = handler.get_usage()
|
||||
assert usage.total_tokens == 0
|
||||
|
||||
def test_handles_none_llm_output(self) -> None:
|
||||
handler = TokenUsageCallbackHandler()
|
||||
|
||||
class FakeResult:
|
||||
llm_output = None
|
||||
|
||||
handler.on_llm_end(FakeResult())
|
||||
usage = handler.get_usage()
|
||||
assert usage.total_tokens == 0
|
||||
60
backend/tests/unit/test_config.py
Normal file
60
backend/tests/unit/test_config.py
Normal file
@@ -0,0 +1,60 @@
|
||||
"""Tests for app.config module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.config import Settings
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestSettings:
|
||||
def test_default_values(self) -> None:
|
||||
settings = Settings(
|
||||
database_url="postgresql://x:x@localhost/db",
|
||||
anthropic_api_key="key",
|
||||
)
|
||||
assert settings.llm_provider == "anthropic"
|
||||
assert settings.llm_model == "claude-sonnet-4-6"
|
||||
assert settings.session_ttl_minutes == 30
|
||||
assert settings.interrupt_ttl_minutes == 30
|
||||
|
||||
def test_custom_values(self) -> None:
|
||||
settings = Settings(
|
||||
database_url="postgresql://x:x@localhost/db",
|
||||
llm_provider="openai",
|
||||
llm_model="gpt-4o",
|
||||
session_ttl_minutes=15,
|
||||
openai_api_key="sk-test",
|
||||
)
|
||||
assert settings.llm_provider == "openai"
|
||||
assert settings.llm_model == "gpt-4o"
|
||||
assert settings.session_ttl_minutes == 15
|
||||
|
||||
def test_invalid_provider_rejected(self) -> None:
|
||||
with pytest.raises(Exception):
|
||||
Settings(
|
||||
database_url="postgresql://x:x@localhost/db",
|
||||
llm_provider="invalid",
|
||||
)
|
||||
|
||||
def test_missing_database_url_rejected(self) -> None:
|
||||
with pytest.raises(Exception):
|
||||
Settings(anthropic_api_key="key")
|
||||
|
||||
def test_empty_api_key_for_provider_rejected(self) -> None:
|
||||
with pytest.raises(ValueError, match="API key"):
|
||||
Settings(
|
||||
database_url="postgresql://x:x@localhost/db",
|
||||
llm_provider="anthropic",
|
||||
anthropic_api_key="",
|
||||
)
|
||||
|
||||
def test_wrong_provider_key_rejected(self) -> None:
|
||||
with pytest.raises(ValueError, match="API key"):
|
||||
Settings(
|
||||
database_url="postgresql://x:x@localhost/db",
|
||||
llm_provider="openai",
|
||||
anthropic_api_key="key",
|
||||
openai_api_key="",
|
||||
)
|
||||
64
backend/tests/unit/test_db.py
Normal file
64
backend/tests/unit/test_db.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""Tests for app.db module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.config import Settings
|
||||
from app.db import _CONVERSATIONS_DDL, _INTERRUPTS_DDL
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestDbModule:
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_pool_sets_correct_params(self) -> None:
|
||||
settings = Settings(
|
||||
database_url="postgresql://user:pass@localhost:5432/testdb",
|
||||
anthropic_api_key="key",
|
||||
)
|
||||
with patch("app.db.AsyncConnectionPool") as MockPool:
|
||||
mock_pool = AsyncMock()
|
||||
MockPool.return_value = mock_pool
|
||||
|
||||
from app.db import create_pool
|
||||
|
||||
await create_pool(settings)
|
||||
MockPool.assert_called_once()
|
||||
call_kwargs = MockPool.call_args
|
||||
assert "postgresql://user:pass@localhost:5432/testdb" in str(call_kwargs)
|
||||
mock_pool.open.assert_awaited_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_checkpointer_calls_setup(self) -> None:
|
||||
mock_pool = AsyncMock()
|
||||
with patch("app.db.AsyncPostgresSaver") as MockSaver:
|
||||
mock_saver = AsyncMock()
|
||||
MockSaver.return_value = mock_saver
|
||||
|
||||
from app.db import create_checkpointer
|
||||
|
||||
await create_checkpointer(mock_pool)
|
||||
MockSaver.assert_called_once_with(conn=mock_pool)
|
||||
mock_saver.setup.assert_awaited_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup_app_tables_executes_ddl(self) -> None:
|
||||
mock_conn = AsyncMock()
|
||||
mock_ctx = AsyncMock()
|
||||
mock_ctx.__aenter__ = AsyncMock(return_value=mock_conn)
|
||||
mock_ctx.__aexit__ = AsyncMock(return_value=None)
|
||||
mock_pool = MagicMock()
|
||||
mock_pool.connection.return_value = mock_ctx
|
||||
|
||||
from app.db import setup_app_tables
|
||||
|
||||
await setup_app_tables(mock_pool)
|
||||
assert mock_conn.execute.await_count == 2
|
||||
|
||||
def test_ddl_statements_valid(self) -> None:
|
||||
assert "CREATE TABLE IF NOT EXISTS conversations" in _CONVERSATIONS_DDL
|
||||
assert "CREATE TABLE IF NOT EXISTS active_interrupts" in _INTERRUPTS_DDL
|
||||
assert "thread_id" in _CONVERSATIONS_DDL
|
||||
assert "interrupt_id" in _INTERRUPTS_DDL
|
||||
44
backend/tests/unit/test_graph.py
Normal file
44
backend/tests/unit/test_graph.py
Normal file
@@ -0,0 +1,44 @@
|
||||
"""Tests for app.graph module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from app.graph import SUPERVISOR_PROMPT, build_agent_nodes, build_graph
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.registry import AgentRegistry
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestBuildAgentNodes:
|
||||
def test_creates_correct_number_of_nodes(self, sample_registry: AgentRegistry) -> None:
|
||||
mock_llm = MagicMock()
|
||||
nodes = build_agent_nodes(sample_registry, mock_llm)
|
||||
assert len(nodes) == 3
|
||||
|
||||
def test_nodes_are_runnable(self, sample_registry: AgentRegistry) -> None:
|
||||
mock_llm = MagicMock()
|
||||
nodes = build_agent_nodes(sample_registry, mock_llm)
|
||||
for node in nodes:
|
||||
assert hasattr(node, "invoke") or hasattr(node, "ainvoke")
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestBuildGraph:
|
||||
def test_graph_compiles_with_mock_checkpointer(self, sample_registry: AgentRegistry) -> None:
|
||||
mock_llm = MagicMock()
|
||||
mock_llm.bind_tools = MagicMock(return_value=mock_llm)
|
||||
mock_llm.with_structured_output = MagicMock(return_value=mock_llm)
|
||||
mock_checkpointer = AsyncMock()
|
||||
|
||||
graph = build_graph(sample_registry, mock_llm, mock_checkpointer)
|
||||
assert graph is not None
|
||||
|
||||
def test_supervisor_prompt_contains_routing_info(self) -> None:
|
||||
assert "order_lookup" in SUPERVISOR_PROMPT
|
||||
assert "order_actions" in SUPERVISOR_PROMPT
|
||||
assert "fallback" in SUPERVISOR_PROMPT
|
||||
41
backend/tests/unit/test_llm.py
Normal file
41
backend/tests/unit/test_llm.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""Tests for app.llm module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.config import Settings
|
||||
from app.llm import create_llm
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestCreateLlm:
|
||||
def test_anthropic_provider(self) -> None:
|
||||
settings = Settings(
|
||||
database_url="postgresql://x:x@localhost/db",
|
||||
llm_provider="anthropic",
|
||||
llm_model="claude-sonnet-4-6",
|
||||
anthropic_api_key="test-key",
|
||||
)
|
||||
llm = create_llm(settings)
|
||||
assert type(llm).__name__ == "ChatAnthropic"
|
||||
|
||||
def test_openai_provider(self) -> None:
|
||||
settings = Settings(
|
||||
database_url="postgresql://x:x@localhost/db",
|
||||
llm_provider="openai",
|
||||
llm_model="gpt-4o",
|
||||
openai_api_key="sk-test",
|
||||
)
|
||||
llm = create_llm(settings)
|
||||
assert type(llm).__name__ == "ChatOpenAI"
|
||||
|
||||
def test_google_provider(self) -> None:
|
||||
settings = Settings(
|
||||
database_url="postgresql://x:x@localhost/db",
|
||||
llm_provider="google",
|
||||
llm_model="gemini-pro",
|
||||
google_api_key="test-key",
|
||||
)
|
||||
llm = create_llm(settings)
|
||||
assert type(llm).__name__ == "ChatGoogleGenerativeAI"
|
||||
27
backend/tests/unit/test_main.py
Normal file
27
backend/tests/unit/test_main.py
Normal file
@@ -0,0 +1,27 @@
|
||||
"""Tests for app.main module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.main import AGENTS_YAML, FRONTEND_DIST, app
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestMainModule:
|
||||
def test_app_title(self) -> None:
|
||||
assert app.title == "Smart Support"
|
||||
|
||||
def test_app_version(self) -> None:
|
||||
assert app.version == "0.1.0"
|
||||
|
||||
def test_agents_yaml_path_exists(self) -> None:
|
||||
assert AGENTS_YAML.name == "agents.yaml"
|
||||
|
||||
def test_frontend_dist_path(self) -> None:
|
||||
assert "frontend" in str(FRONTEND_DIST)
|
||||
assert "dist" in str(FRONTEND_DIST)
|
||||
|
||||
def test_websocket_route_registered(self) -> None:
|
||||
routes = [r.path for r in app.routes if hasattr(r, "path")]
|
||||
assert "/ws" in routes
|
||||
147
backend/tests/unit/test_registry.py
Normal file
147
backend/tests/unit/test_registry.py
Normal file
@@ -0,0 +1,147 @@
|
||||
"""Tests for app.registry module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from app.registry import AgentConfig, AgentRegistry, PersonalityConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestPersonalityConfig:
|
||||
def test_defaults(self) -> None:
|
||||
p = PersonalityConfig()
|
||||
assert p.tone == "professional and helpful"
|
||||
assert "Hello" in p.greeting
|
||||
assert "human agent" in p.escalation_message
|
||||
|
||||
def test_custom_values(self) -> None:
|
||||
p = PersonalityConfig(tone="casual", greeting="Hey!", escalation_message="Hold on.")
|
||||
assert p.tone == "casual"
|
||||
|
||||
def test_immutable(self) -> None:
|
||||
p = PersonalityConfig()
|
||||
with pytest.raises(Exception):
|
||||
p.tone = "new tone"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestAgentConfig:
|
||||
def test_valid_config(self) -> None:
|
||||
ac = AgentConfig(
|
||||
name="test",
|
||||
description="A test agent",
|
||||
permission="read",
|
||||
tools=["tool1"],
|
||||
)
|
||||
assert ac.name == "test"
|
||||
assert ac.permission == "read"
|
||||
|
||||
def test_empty_name_rejected(self) -> None:
|
||||
with pytest.raises(ValueError, match="must not be empty"):
|
||||
AgentConfig(name=" ", description="d", permission="read", tools=["t"])
|
||||
|
||||
def test_empty_tools_rejected(self) -> None:
|
||||
with pytest.raises(ValueError, match="at least one tool"):
|
||||
AgentConfig(name="x", description="d", permission="read", tools=[])
|
||||
|
||||
def test_invalid_permission_rejected(self) -> None:
|
||||
with pytest.raises(Exception):
|
||||
AgentConfig(name="x", description="d", permission="admin", tools=["t"])
|
||||
|
||||
def test_name_stripped(self) -> None:
|
||||
ac = AgentConfig(name=" test ", description="d", permission="read", tools=["t"])
|
||||
assert ac.name == "test"
|
||||
|
||||
def test_immutable(self) -> None:
|
||||
ac = AgentConfig(name="test", description="d", permission="read", tools=["t"])
|
||||
with pytest.raises(Exception):
|
||||
ac.name = "new"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestAgentRegistry:
|
||||
def test_load_valid_yaml(self, sample_yaml_path: Path) -> None:
|
||||
registry = AgentRegistry.load(sample_yaml_path)
|
||||
assert len(registry) == 3
|
||||
assert registry.get_agent("test_reader").permission == "read"
|
||||
assert registry.get_agent("test_writer").permission == "write"
|
||||
|
||||
def test_list_agents(self, sample_registry: AgentRegistry) -> None:
|
||||
agents = sample_registry.list_agents()
|
||||
assert len(agents) == 3
|
||||
names = {a.name for a in agents}
|
||||
assert names == {"test_reader", "test_writer", "test_fallback"}
|
||||
|
||||
def test_get_agents_by_permission(self, sample_registry: AgentRegistry) -> None:
|
||||
readers = sample_registry.get_agents_by_permission("read")
|
||||
assert len(readers) == 2
|
||||
writers = sample_registry.get_agents_by_permission("write")
|
||||
assert len(writers) == 1
|
||||
|
||||
def test_get_nonexistent_agent(self, sample_registry: AgentRegistry) -> None:
|
||||
with pytest.raises(KeyError, match="not found"):
|
||||
sample_registry.get_agent("nonexistent")
|
||||
|
||||
def test_personality_defaults_applied(self, sample_registry: AgentRegistry) -> None:
|
||||
agent = sample_registry.get_agent("test_reader")
|
||||
assert agent.personality.tone == "professional and helpful"
|
||||
|
||||
def test_personality_custom_applied(self, sample_registry: AgentRegistry) -> None:
|
||||
agent = sample_registry.get_agent("test_writer")
|
||||
assert agent.personality.tone == "formal"
|
||||
assert agent.personality.greeting == "Greetings."
|
||||
|
||||
def test_file_not_found(self) -> None:
|
||||
with pytest.raises(FileNotFoundError):
|
||||
AgentRegistry.load("/nonexistent/path.yaml")
|
||||
|
||||
def test_empty_file(self, tmp_path: Path) -> None:
|
||||
path = tmp_path / "empty.yaml"
|
||||
path.write_text("", encoding="utf-8")
|
||||
with pytest.raises(ValueError, match="empty"):
|
||||
AgentRegistry.load(path)
|
||||
|
||||
def test_invalid_yaml_syntax(self, tmp_path: Path) -> None:
|
||||
path = tmp_path / "bad.yaml"
|
||||
path.write_text("agents:\n - name: [invalid\n", encoding="utf-8")
|
||||
with pytest.raises(ValueError, match="Invalid YAML"):
|
||||
AgentRegistry.load(path)
|
||||
|
||||
def test_missing_agents_key(self, tmp_path: Path) -> None:
|
||||
path = tmp_path / "no_agents.yaml"
|
||||
path.write_text(yaml.dump({"items": []}), encoding="utf-8")
|
||||
with pytest.raises(ValueError, match="agents"):
|
||||
AgentRegistry.load(path)
|
||||
|
||||
def test_duplicate_agent_names(self, tmp_path: Path) -> None:
|
||||
data = {
|
||||
"agents": [
|
||||
{"name": "dup", "description": "a", "permission": "read", "tools": ["t1"]},
|
||||
{"name": "dup", "description": "b", "permission": "read", "tools": ["t2"]},
|
||||
]
|
||||
}
|
||||
path = tmp_path / "dups.yaml"
|
||||
path.write_text(yaml.dump(data), encoding="utf-8")
|
||||
with pytest.raises(ValueError, match="Duplicate"):
|
||||
AgentRegistry.load(path)
|
||||
|
||||
def test_missing_required_fields(self, tmp_path: Path) -> None:
|
||||
data = {"agents": [{"name": "x"}]}
|
||||
path = tmp_path / "missing.yaml"
|
||||
path.write_text(yaml.dump(data), encoding="utf-8")
|
||||
with pytest.raises(ValueError, match="Invalid agent config"):
|
||||
AgentRegistry.load(path)
|
||||
|
||||
def test_empty_agents_list(self, tmp_path: Path) -> None:
|
||||
data = {"agents": []}
|
||||
path = tmp_path / "empty_list.yaml"
|
||||
path.write_text(yaml.dump(data), encoding="utf-8")
|
||||
with pytest.raises(ValueError, match="non-empty"):
|
||||
AgentRegistry.load(path)
|
||||
70
backend/tests/unit/test_session_manager.py
Normal file
70
backend/tests/unit/test_session_manager.py
Normal file
@@ -0,0 +1,70 @@
|
||||
"""Tests for app.session_manager module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.session_manager import SessionManager
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestSessionManager:
|
||||
def test_new_session_not_expired(self, session_manager: SessionManager) -> None:
|
||||
session_manager.touch("thread-1")
|
||||
assert not session_manager.is_expired("thread-1")
|
||||
|
||||
def test_unknown_session_is_expired(self, session_manager: SessionManager) -> None:
|
||||
assert session_manager.is_expired("unknown")
|
||||
|
||||
def test_session_expires_after_ttl(self) -> None:
|
||||
mgr = SessionManager(session_ttl_seconds=1)
|
||||
mgr.touch("t1")
|
||||
with patch("app.session_manager.time") as mock_time:
|
||||
mock_time.time.return_value = time.time() + 2
|
||||
assert mgr.is_expired("t1")
|
||||
|
||||
def test_touch_resets_ttl(self) -> None:
|
||||
mgr = SessionManager(session_ttl_seconds=5)
|
||||
mgr.touch("t1")
|
||||
initial_state = mgr.get_state("t1")
|
||||
# Touch again after some time
|
||||
with patch("app.session_manager.time") as mock_time:
|
||||
mock_time.time.return_value = time.time() + 3
|
||||
mgr.touch("t1")
|
||||
new_state = mgr.get_state("t1")
|
||||
assert new_state.last_activity > initial_state.last_activity
|
||||
|
||||
def test_interrupt_suspends_expiration(self) -> None:
|
||||
mgr = SessionManager(session_ttl_seconds=1)
|
||||
mgr.touch("t1")
|
||||
mgr.extend_for_interrupt("t1")
|
||||
with patch("app.session_manager.time") as mock_time:
|
||||
mock_time.time.return_value = time.time() + 100
|
||||
assert not mgr.is_expired("t1")
|
||||
|
||||
def test_resolve_interrupt_resumes_ttl(self) -> None:
|
||||
mgr = SessionManager(session_ttl_seconds=1)
|
||||
mgr.touch("t1")
|
||||
mgr.extend_for_interrupt("t1")
|
||||
mgr.resolve_interrupt("t1")
|
||||
state = mgr.get_state("t1")
|
||||
assert not state.has_pending_interrupt
|
||||
|
||||
def test_extend_for_nonexistent_creates_session(self) -> None:
|
||||
mgr = SessionManager()
|
||||
mgr.extend_for_interrupt("new-thread")
|
||||
state = mgr.get_state("new-thread")
|
||||
assert state is not None
|
||||
|
||||
def test_remove_session(self, session_manager: SessionManager) -> None:
|
||||
session_manager.touch("t1")
|
||||
session_manager.remove("t1")
|
||||
assert session_manager.get_state("t1") is None
|
||||
|
||||
def test_session_state_is_immutable(self, session_manager: SessionManager) -> None:
|
||||
state = session_manager.touch("t1")
|
||||
with pytest.raises(Exception):
|
||||
state.thread_id = "new"
|
||||
233
backend/tests/unit/test_ws_handler.py
Normal file
233
backend/tests/unit/test_ws_handler.py
Normal file
@@ -0,0 +1,233 @@
|
||||
"""Tests for app.ws_handler module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from app.callbacks import TokenUsageCallbackHandler
|
||||
from app.session_manager import SessionManager
|
||||
from app.ws_handler import (
|
||||
_extract_interrupt,
|
||||
_has_interrupt,
|
||||
dispatch_message,
|
||||
handle_interrupt_response,
|
||||
handle_user_message,
|
||||
)
|
||||
|
||||
|
||||
def _make_ws() -> AsyncMock:
|
||||
ws = AsyncMock()
|
||||
ws.send_json = AsyncMock()
|
||||
return ws
|
||||
|
||||
|
||||
def _make_graph() -> AsyncMock:
|
||||
graph = AsyncMock()
|
||||
graph.astream = MagicMock(return_value=AsyncIterHelper([]))
|
||||
state = MagicMock()
|
||||
state.tasks = ()
|
||||
graph.aget_state = AsyncMock(return_value=state)
|
||||
return graph
|
||||
|
||||
|
||||
class AsyncIterHelper:
|
||||
"""Helper to make a list behave as an async iterator."""
|
||||
|
||||
def __init__(self, items):
|
||||
self._items = items
|
||||
|
||||
def __aiter__(self):
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
if not self._items:
|
||||
raise StopAsyncIteration
|
||||
return self._items.pop(0)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestDispatchMessage:
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_json(self) -> None:
|
||||
ws = _make_ws()
|
||||
graph = _make_graph()
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
|
||||
await dispatch_message(ws, graph, sm, cb, "not json")
|
||||
ws.send_json.assert_awaited_once()
|
||||
call_data = ws.send_json.call_args[0][0]
|
||||
assert call_data["type"] == "error"
|
||||
assert "Invalid JSON" in call_data["message"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_thread_id(self) -> None:
|
||||
ws = _make_ws()
|
||||
graph = _make_graph()
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
|
||||
msg = json.dumps({"type": "message", "content": "hello"})
|
||||
await dispatch_message(ws, graph, sm, cb, msg)
|
||||
call_data = ws.send_json.call_args[0][0]
|
||||
assert call_data["type"] == "error"
|
||||
assert "thread_id" in call_data["message"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_content(self) -> None:
|
||||
ws = _make_ws()
|
||||
graph = _make_graph()
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
|
||||
msg = json.dumps({"type": "message", "thread_id": "t1"})
|
||||
await dispatch_message(ws, graph, sm, cb, msg)
|
||||
call_data = ws.send_json.call_args[0][0]
|
||||
assert call_data["type"] == "error"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unknown_message_type(self) -> None:
|
||||
ws = _make_ws()
|
||||
graph = _make_graph()
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
|
||||
msg = json.dumps({"type": "unknown", "thread_id": "t1"})
|
||||
await dispatch_message(ws, graph, sm, cb, msg)
|
||||
call_data = ws.send_json.call_args[0][0]
|
||||
assert call_data["type"] == "error"
|
||||
assert "Unknown" in call_data["message"]
|
||||
# Verify raw input is NOT reflected back
|
||||
assert "unknown" not in call_data["message"].lower().replace("unknown message type", "")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_too_large(self) -> None:
|
||||
ws = _make_ws()
|
||||
graph = _make_graph()
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
|
||||
large_msg = "x" * 40_000
|
||||
await dispatch_message(ws, graph, sm, cb, large_msg)
|
||||
call_data = ws.send_json.call_args[0][0]
|
||||
assert call_data["type"] == "error"
|
||||
assert "too large" in call_data["message"].lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_thread_id_format(self) -> None:
|
||||
ws = _make_ws()
|
||||
graph = _make_graph()
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
|
||||
msg = json.dumps({"type": "message", "thread_id": "../../../etc", "content": "hi"})
|
||||
await dispatch_message(ws, graph, sm, cb, msg)
|
||||
call_data = ws.send_json.call_args[0][0]
|
||||
assert call_data["type"] == "error"
|
||||
assert "thread_id" in call_data["message"].lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_content_too_long(self) -> None:
|
||||
ws = _make_ws()
|
||||
graph = _make_graph()
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
|
||||
msg = json.dumps({"type": "message", "thread_id": "t1", "content": "x" * 9000})
|
||||
await dispatch_message(ws, graph, sm, cb, msg)
|
||||
call_data = ws.send_json.call_args[0][0]
|
||||
assert call_data["type"] == "error"
|
||||
assert "too long" in call_data["message"].lower()
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestHandleUserMessage:
|
||||
@pytest.mark.asyncio
|
||||
async def test_expired_session(self) -> None:
|
||||
ws = _make_ws()
|
||||
graph = _make_graph()
|
||||
sm = SessionManager(session_ttl_seconds=0)
|
||||
cb = TokenUsageCallbackHandler()
|
||||
|
||||
await handle_user_message(ws, graph, sm, cb, "t1", "hello")
|
||||
call_data = ws.send_json.call_args[0][0]
|
||||
assert call_data["type"] == "error"
|
||||
assert "expired" in call_data["message"].lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_successful_message(self) -> None:
|
||||
ws = _make_ws()
|
||||
graph = _make_graph()
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
|
||||
sm.touch("t1")
|
||||
await handle_user_message(ws, graph, sm, cb, "t1", "hello")
|
||||
# Should end with message_complete
|
||||
last_call = ws.send_json.call_args[0][0]
|
||||
assert last_call["type"] == "message_complete"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graph_error_sends_error_message(self) -> None:
|
||||
ws = _make_ws()
|
||||
graph = AsyncMock()
|
||||
graph.astream = MagicMock(side_effect=RuntimeError("boom"))
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
|
||||
sm.touch("t1")
|
||||
await handle_user_message(ws, graph, sm, cb, "t1", "hello")
|
||||
call_data = ws.send_json.call_args[0][0]
|
||||
assert call_data["type"] == "error"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestHandleInterruptResponse:
|
||||
@pytest.mark.asyncio
|
||||
async def test_approved_interrupt(self) -> None:
|
||||
ws = _make_ws()
|
||||
graph = _make_graph()
|
||||
sm = SessionManager()
|
||||
cb = TokenUsageCallbackHandler()
|
||||
|
||||
sm.touch("t1")
|
||||
sm.extend_for_interrupt("t1")
|
||||
await handle_interrupt_response(ws, graph, sm, cb, "t1", True)
|
||||
last_call = ws.send_json.call_args[0][0]
|
||||
assert last_call["type"] == "message_complete"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestInterruptHelpers:
|
||||
def test_has_interrupt_false_for_empty_tasks(self) -> None:
|
||||
state = MagicMock()
|
||||
state.tasks = ()
|
||||
assert not _has_interrupt(state)
|
||||
|
||||
def test_has_interrupt_true(self) -> None:
|
||||
interrupt_obj = MagicMock()
|
||||
interrupt_obj.value = {"action": "cancel"}
|
||||
task = MagicMock()
|
||||
task.interrupts = (interrupt_obj,)
|
||||
state = MagicMock()
|
||||
state.tasks = (task,)
|
||||
assert _has_interrupt(state)
|
||||
|
||||
def test_extract_interrupt_data(self) -> None:
|
||||
interrupt_obj = MagicMock()
|
||||
interrupt_obj.value = {"action": "cancel_order", "order_id": "1042"}
|
||||
task = MagicMock()
|
||||
task.interrupts = (interrupt_obj,)
|
||||
state = MagicMock()
|
||||
state.tasks = (task,)
|
||||
data = _extract_interrupt(state)
|
||||
assert data["action"] == "cancel_order"
|
||||
|
||||
def test_extract_interrupt_empty(self) -> None:
|
||||
state = MagicMock()
|
||||
state.tasks = ()
|
||||
data = _extract_interrupt(state)
|
||||
assert data["action"] == "unknown"
|
||||
Reference in New Issue
Block a user