"""Tests for app.safety module -- confirmation rules and MCP error taxonomy."""

from __future__ import annotations

import pytest

from app.safety import (
    classify_mcp_error,
    is_retryable,
    max_retries,
    requires_confirmation,
)

pytestmark = pytest.mark.unit


class TestRequiresConfirmation:
    """Cases for ``requires_confirmation`` with and without ``needs_interrupt``."""

    @pytest.mark.parametrize(
        ("permission", "expected"),
        [
            ("read", False),
            ("write", True),
        ],
        ids=["read_agent_no_override", "write_agent_no_override"],
    )
    def test_permission_without_override(
        self, permission: str, expected: bool
    ) -> None:
        """Only ``agent_permission`` is supplied; no override is passed."""
        outcome = requires_confirmation(agent_permission=permission)
        assert outcome.requires_confirmation is expected

    @pytest.mark.parametrize(
        ("permission", "interrupt", "expected"),
        [
            # A "read" agent forced to confirm by an explicit override.
            ("read", True, True),
            # A "write" agent whose confirmation is explicitly disabled.
            ("write", False, False),
        ],
        ids=["interrupt_override_true", "interrupt_override_false"],
    )
    def test_interrupt_override(
        self, permission: str, interrupt: bool, expected: bool
    ) -> None:
        """An explicit ``needs_interrupt`` decides the outcome in these cases."""
        outcome = requires_confirmation(
            agent_permission=permission,
            needs_interrupt=interrupt,
        )
        assert outcome.requires_confirmation is expected


class TestClassifyMcpError:
    """Mapping of HTTP status codes and error messages to error categories."""

    @pytest.mark.parametrize(
        ("code", "expected"),
        [
            (408, "transient"),
            (429, "transient"),
            (500, "transient"),
            (502, "transient"),
            (503, "transient"),
            (504, "transient"),
            (401, "auth"),
            (403, "auth"),
            (400, "validation"),
            (404, "validation"),
            (422, "validation"),
            # A success code is not in any error category.
            (200, "unknown"),
        ],
    )
    def test_status_code_classification(self, code: int, expected: str) -> None:
        """Each status code maps to the listed category."""
        assert classify_mcp_error(status_code=code) == expected

    @pytest.mark.parametrize(
        ("message", "expected"),
        [
            ("Connection timed out", "transient"),
            ("Rate limit exceeded", "transient"),
            ("Unauthorized access", "auth"),
            ("Invalid parameter", "validation"),
            ("Something happened", "unknown"),
        ],
        ids=["timeout", "rate_limit", "unauthorized", "invalid", "unknown"],
    )
    def test_message_classification(self, message: str, expected: str) -> None:
        """Free-text error messages map to the listed category."""
        assert classify_mcp_error(error_message=message) == expected


class TestRetryPolicy:
    """Which error categories are retried, and the retry budget."""

    def test_transient_is_retryable(self) -> None:
        assert is_retryable("transient") is True

    @pytest.mark.parametrize("category", ["validation", "auth", "unknown"])
    def test_non_transient_not_retryable(self, category: str) -> None:
        """Every category other than "transient" is not retried."""
        assert is_retryable(category) is False

    def test_max_retries_value(self) -> None:
        assert max_retries() == 3