"""Tests for graph/polling.py — poll_until async utility.

Written FIRST (TDD RED phase). All tests inject a fake sleep_fn that
returns immediately to avoid real waits.
"""

import asyncio
from unittest.mock import AsyncMock, call

import pytest

from release_agent.graph.polling import poll_until


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


async def _immediate_sleep(seconds: float) -> None:
    """Drop-in replacement for asyncio.sleep that returns immediately."""
    return


# ---------------------------------------------------------------------------
# Success path tests
# ---------------------------------------------------------------------------


class TestPollUntilSuccess:
    """Tests for the happy path where poll_fn succeeds before timeout."""

    async def test_returns_tuple_of_result_and_completed_true(self) -> None:
        """The final result and ``True`` are returned once is_done matches."""
        calls = iter(["running", "running", "completed"])

        async def poll_fn():
            return next(calls)

        result, completed = await poll_until(
            poll_fn=poll_fn,
            is_done=lambda r: r == "completed",
            interval_seconds=1,
            max_wait_seconds=60,
            sleep_fn=_immediate_sleep,
        )

        assert result == "completed"
        assert completed is True

    async def test_returns_immediately_when_already_done(self) -> None:
        """A first poll that already satisfies is_done completes the call."""

        async def poll_fn():
            return "completed"

        result, completed = await poll_until(
            poll_fn=poll_fn,
            is_done=lambda r: r == "completed",
            interval_seconds=1,
            max_wait_seconds=60,
            sleep_fn=_immediate_sleep,
        )

        assert result == "completed"
        assert completed is True

    async def test_polls_multiple_times_before_done(self) -> None:
        """poll_fn is invoked exactly until the first done result."""
        call_count = 0

        async def poll_fn():
            nonlocal call_count
            call_count += 1
            return "done" if call_count >= 3 else "pending"

        result, completed = await poll_until(
            poll_fn=poll_fn,
            is_done=lambda r: r == "done",
            interval_seconds=1,
            max_wait_seconds=60,
            sleep_fn=_immediate_sleep,
        )

        assert result == "done"
        assert completed is True
        assert call_count == 3

    async def test_sleep_called_between_polls(self) -> None:
        """sleep_fn receives interval_seconds whenever another poll is needed."""
        call_count = 0
        sleep_calls: list[float] = []

        async def poll_fn():
            nonlocal call_count
            call_count += 1
            return "done" if call_count >= 2 else "pending"

        async def tracking_sleep(seconds: float) -> None:
            sleep_calls.append(seconds)

        await poll_until(
            poll_fn=poll_fn,
            is_done=lambda r: r == "done",
            interval_seconds=15,
            max_wait_seconds=60,
            sleep_fn=tracking_sleep,
        )

        assert len(sleep_calls) >= 1
        assert all(s == 15 for s in sleep_calls)

    async def test_no_sleep_on_first_successful_poll(self) -> None:
        """No sleep happens when the very first poll is already done."""
        sleep_calls: list[float] = []

        async def poll_fn():
            return "done"

        async def tracking_sleep(seconds: float) -> None:
            sleep_calls.append(seconds)

        await poll_until(
            poll_fn=poll_fn,
            is_done=lambda r: r == "done",
            interval_seconds=10,
            max_wait_seconds=60,
            sleep_fn=tracking_sleep,
        )

        assert sleep_calls == []

    async def test_works_with_dict_results(self) -> None:
        """Results are passed through untouched, so dict payloads work."""
        responses = iter(
            [
                {"status": "inProgress"},
                {"status": "completed", "result": "succeeded"},
            ]
        )

        async def poll_fn():
            return next(responses)

        result, completed = await poll_until(
            poll_fn=poll_fn,
            is_done=lambda r: r["status"] == "completed",
            interval_seconds=1,
            max_wait_seconds=60,
            sleep_fn=_immediate_sleep,
        )

        assert result["result"] == "succeeded"
        assert completed is True


# ---------------------------------------------------------------------------
# Timeout tests
# ---------------------------------------------------------------------------


class TestPollUntilTimeout:
    """Tests for timeout behavior."""

    async def test_returns_last_result_and_completed_false_on_timeout(self) -> None:
        """On timeout the last polled value is returned with ``False``."""

        async def poll_fn():
            return "still_running"

        # With interval=10, max_wait=5, it should time out after one poll
        result, completed = await poll_until(
            poll_fn=poll_fn,
            is_done=lambda r: r == "done",
            interval_seconds=10,
            max_wait_seconds=5,
            sleep_fn=_immediate_sleep,
        )

        assert result == "still_running"
        assert completed is False

    async def test_at_least_one_poll_happens_before_timeout(self) -> None:
        """Even when interval exceeds max_wait, poll_fn runs at least once."""
        call_count = 0

        async def poll_fn():
            nonlocal call_count
            call_count += 1
            return "running"

        await poll_until(
            poll_fn=poll_fn,
            is_done=lambda r: r == "done",
            interval_seconds=100,
            max_wait_seconds=1,
            sleep_fn=_immediate_sleep,
        )

        assert call_count >= 1

    async def test_max_polls_bounded_by_max_wait_over_interval(self) -> None:
        """The number of polls is bounded by max_wait / interval."""
        call_count = 0

        async def poll_fn():
            nonlocal call_count
            call_count += 1
            return "running"

        await poll_until(
            poll_fn=poll_fn,
            is_done=lambda r: False,
            interval_seconds=10,
            max_wait_seconds=30,
            sleep_fn=_immediate_sleep,
        )

        # With interval=10, max_wait=30 the nominal bound is
        # ceil(30/10)+1 = 4 polls; the assertion tolerates one extra poll.
        assert call_count <= 5


# ---------------------------------------------------------------------------
# Error handling tests
# ---------------------------------------------------------------------------


class TestPollUntilErrorHandling:
    """Tests for error/exception handling in poll_until."""

    async def test_continues_after_transient_exception(self) -> None:
        """Two failures followed by a success still complete the poll."""
        call_count = 0

        async def poll_fn():
            nonlocal call_count
            call_count += 1
            if call_count < 3:
                raise RuntimeError("Transient error")
            return "done"

        result, completed = await poll_until(
            poll_fn=poll_fn,
            is_done=lambda r: r == "done",
            interval_seconds=1,
            max_wait_seconds=60,
            sleep_fn=_immediate_sleep,
        )

        assert result == "done"
        assert completed is True

    async def test_aborts_after_three_consecutive_failures(self) -> None:
        """Three failures in a row abort with ``(None, False)``."""
        call_count = 0

        async def poll_fn():
            nonlocal call_count
            call_count += 1
            raise RuntimeError("Persistent error")

        result, completed = await poll_until(
            poll_fn=poll_fn,
            is_done=lambda r: True,
            interval_seconds=1,
            max_wait_seconds=60,
            sleep_fn=_immediate_sleep,
        )

        # Should abort after 3 consecutive failures
        assert call_count == 3
        assert completed is False
        assert result is None

    async def test_resets_consecutive_failure_count_on_success(self) -> None:
        """A successful poll between failures resets the failure streak."""
        call_count = 0

        async def poll_fn():
            nonlocal call_count
            call_count += 1
            # Fail twice, succeed once, fail twice, succeed (done)
            if call_count in (1, 2):
                raise RuntimeError("fail")
            if call_count == 3:
                return "running"
            if call_count in (4, 5):
                raise RuntimeError("fail again")
            return "done"

        result, completed = await poll_until(
            poll_fn=poll_fn,
            is_done=lambda r: r == "done",
            interval_seconds=1,
            max_wait_seconds=120,
            sleep_fn=_immediate_sleep,
        )

        assert completed is True
        assert result == "done"

    async def test_single_exception_does_not_abort(self) -> None:
        """One failure alone never aborts polling."""
        call_count = 0

        async def poll_fn():
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                raise ValueError("one error")
            return "done"

        result, completed = await poll_until(
            poll_fn=poll_fn,
            is_done=lambda r: r == "done",
            interval_seconds=1,
            max_wait_seconds=60,
            sleep_fn=_immediate_sleep,
        )

        assert completed is True
        assert result == "done"

    async def test_two_consecutive_failures_do_not_abort(self) -> None:
        """Two failures in a row are still below the abort threshold."""
        call_count = 0

        async def poll_fn():
            nonlocal call_count
            call_count += 1
            if call_count <= 2:
                raise ConnectionError("two errors")
            return "done"

        result, completed = await poll_until(
            poll_fn=poll_fn,
            is_done=lambda r: r == "done",
            interval_seconds=1,
            max_wait_seconds=60,
            sleep_fn=_immediate_sleep,
        )

        assert completed is True
        assert result == "done"


# ---------------------------------------------------------------------------
# Default parameter tests
# ---------------------------------------------------------------------------


class TestPollUntilDefaults:
    """Tests that default parameters match the spec."""

    async def test_default_interval_is_30_seconds(self) -> None:
        """Without interval_seconds, sleep_fn is called with 30."""
        sleep_calls: list[float] = []

        async def poll_fn():
            # Only report done after one sleep, so a sleep must be observed.
            return "done" if len(sleep_calls) >= 1 else "running"

        async def tracking_sleep(seconds: float) -> None:
            sleep_calls.append(seconds)

        await poll_until(
            poll_fn=poll_fn,
            is_done=lambda r: r == "done",
            sleep_fn=tracking_sleep,
        )

        # Assert unconditionally: guarding this with `if sleep_calls:` would
        # let the test pass vacuously if sleep_fn were never called.
        assert sleep_calls, "sleep_fn was never called"
        assert sleep_calls[0] == 30

    async def test_poll_fn_and_is_done_are_keyword_only(self) -> None:
        """poll_fn and is_done must be passed as keyword arguments."""

        async def poll_fn():
            return "done"

        with pytest.raises(TypeError):
            await poll_until(poll_fn, lambda r: r == "done")  # type: ignore[call-arg]