"""Tests for async retry decorator. Written FIRST (TDD RED phase)."""

import asyncio

import pytest

from release_agent.exceptions import (
    NotFoundError,
    RateLimitError,
    ServiceError,
    ServiceUnavailableError,
)
from release_agent.tools._retry import with_retry

# NOTE(review): the async tests below carry no explicit asyncio marker; this
# presumably relies on pytest-asyncio running in auto mode — confirm in the
# project's pytest configuration, otherwise these tests may silently not run.

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _make_failing_then_succeeding(failures: int, exc_factory, result="ok"):
    """Return an async callable that fails `failures` times then returns `result`.

    Every invocation increments the returned callable's ``calls`` attribute,
    so a test can still count attempts after the callable has been wrapped
    by ``with_retry``.

    Args:
        failures: Number of initial calls that raise.
        exc_factory: Zero-argument callable that builds the exception to raise.
        result: Value returned once the initial failures are exhausted.

    Returns:
        An async, zero-argument callable with an integer ``calls`` attribute.
    """

    async def fn():
        fn.calls += 1
        if fn.calls <= failures:
            raise exc_factory()
        return result

    fn.calls = 0
    return fn


def _make_always_failing(exc_factory):
    """Return an async callable that raises ``exc_factory()`` on every call.

    Like ``_make_failing_then_succeeding``, the callable counts its
    invocations in a ``calls`` attribute.
    """

    async def fn():
        fn.calls += 1
        raise exc_factory()

    fn.calls = 0
    return fn


# ---------------------------------------------------------------------------
# with_retry tests
# ---------------------------------------------------------------------------


class TestWithRetry:
    """Tests for the with_retry decorator."""

    async def test_success_on_first_attempt(self) -> None:
        inner = _make_failing_then_succeeding(0, RuntimeError, result="done")
        fn = with_retry(max_attempts=3)(inner)

        result = await fn()

        assert result == "done"
        assert inner.calls == 1

    async def test_retries_on_rate_limit_error(self) -> None:
        inner = _make_failing_then_succeeding(
            2, lambda: RateLimitError(service="jira", retry_after=None)
        )
        fn = with_retry(max_attempts=3, base_delay=0.0)(inner)

        result = await fn()

        assert result == "ok"
        assert inner.calls == 3

    async def test_retries_on_service_unavailable_error(self) -> None:
        inner = _make_failing_then_succeeding(
            1, lambda: ServiceUnavailableError(service="azdo")
        )
        fn = with_retry(max_attempts=3, base_delay=0.0)(inner)

        result = await fn()

        assert result == "ok"
        assert inner.calls == 2

    async def test_does_not_retry_on_not_found_error(self) -> None:
        inner = _make_always_failing(
            lambda: NotFoundError(service="azdo", detail="not found")
        )
        fn = with_retry(max_attempts=3, base_delay=0.0)(inner)

        with pytest.raises(NotFoundError):
            await fn()
        assert inner.calls == 1

    async def test_does_not_retry_on_generic_service_error(self) -> None:
        inner = _make_always_failing(
            lambda: ServiceError(service="azdo", status_code=400, detail="bad request")
        )
        fn = with_retry(max_attempts=3, base_delay=0.0)(inner)

        with pytest.raises(ServiceError):
            await fn()
        assert inner.calls == 1

    async def test_raises_after_max_attempts_exceeded(self) -> None:
        inner = _make_always_failing(
            lambda: RateLimitError(service="jira", retry_after=None)
        )
        fn = with_retry(max_attempts=3, base_delay=0.0)(inner)

        with pytest.raises(RateLimitError):
            await fn()
        assert inner.calls == 3

    async def test_max_attempts_one_means_no_retry(self) -> None:
        inner = _make_always_failing(
            lambda: RateLimitError(service="jira", retry_after=None)
        )
        fn = with_retry(max_attempts=1, base_delay=0.0)(inner)

        with pytest.raises(RateLimitError):
            await fn()
        assert inner.calls == 1

    async def test_does_not_retry_on_non_release_agent_error(self) -> None:
        inner = _make_always_failing(lambda: ValueError("unexpected"))
        fn = with_retry(max_attempts=3, base_delay=0.0)(inner)

        with pytest.raises(ValueError):
            await fn()
        assert inner.calls == 1

    async def test_respects_retry_after_from_rate_limit_error(self) -> None:
        """When retry_after is set, the decorator must wait at least that long."""
        delays: list[float] = []

        async def fake_sleep(seconds: float) -> None:
            delays.append(seconds)

        inner = _make_failing_then_succeeding(
            1, lambda: RateLimitError(service="jira", retry_after=5)
        )
        fn = with_retry(max_attempts=2, base_delay=0.0, sleep_fn=fake_sleep)(inner)

        result = await fn()

        assert result == "ok"
        assert len(delays) == 1
        assert delays[0] >= 5.0

    async def test_exponential_backoff_grows(self) -> None:
        """Verify delays grow between retries (exponential)."""
        delays: list[float] = []

        async def fake_sleep(seconds: float) -> None:
            delays.append(seconds)

        inner = _make_failing_then_succeeding(
            3, lambda: ServiceUnavailableError(service="azdo")
        )
        fn = with_retry(max_attempts=4, base_delay=1.0, sleep_fn=fake_sleep)(inner)

        await fn()

        assert len(delays) == 3
        # Each subsequent delay must not be less than the previous
        assert delays[1] >= delays[0]
        assert delays[2] >= delays[1]
        # A constant delay would satisfy the checks above; require real growth
        # so a non-exponential backoff cannot pass this test.
        assert delays[2] > delays[0]

    async def test_preserves_return_value(self) -> None:
        @with_retry(max_attempts=2, base_delay=0.0)
        async def fn():
            return {"key": "value"}

        result = await fn()
        assert result == {"key": "value"}

    async def test_works_without_decorator_args_defaults(self) -> None:
        """Decorator used with defaults should still work."""

        @with_retry()
        async def fn():
            return 42

        result = await fn()
        assert result == 42