"""Tests for services/pr_dedup.py.

Written FIRST (TDD RED phase).

find_unprocessed_prs queries agent_threads to find which PRs have not yet
been processed (no existing thread for that repo+pr_id combination).
"""

import pytest

from release_agent.models.pr import PRInfo
from release_agent.services.pr_dedup import find_unprocessed_prs


# ---------------------------------------------------------------------------
# Helpers — fake async pool
# ---------------------------------------------------------------------------


def _make_pr(pr_id: str, repo_name: str = "my-repo") -> PRInfo:
    """Build an active ``PRInfo`` for *pr_id* in *repo_name*.

    All other fields are fixed placeholder values; only ``pr_id`` and
    ``repo_name`` vary between calls (the URL and title embed them).
    """
    return PRInfo(
        pr_id=pr_id,
        pr_url=f"https://dev.azure.com/org/proj/_git/{repo_name}/pullrequest/{pr_id}",
        repo_name=repo_name,
        branch="refs/heads/feature/ALLPOST-100-fix",
        pr_title=f"PR {pr_id}",
        pr_status="active",
    )


class _FakeCursor:
    """Async cursor stub whose ``fetchall()`` always returns a fixed row list.

    ``execute()`` ignores its SQL and params, so every query "returns" the
    same rows. Any filtering by repo/pr_id must therefore happen either in
    the code under test or be reflected in the rows each test supplies.
    """

    def __init__(self, rows: list[tuple[str, str]]) -> None:
        self._rows = rows

    async def execute(self, sql: str, params: object = None) -> None:
        pass

    async def fetchall(self) -> list[tuple[str, str]]:
        return self._rows

    async def __aenter__(self) -> "_FakeCursor":
        return self

    async def __aexit__(self, *args: object) -> None:
        pass


class _FakeConnection:
    """Async connection stub; ``cursor()`` hands out a ``_FakeCursor``."""

    def __init__(self, rows: list[tuple[str, str]]) -> None:
        self._rows = rows

    def cursor(self) -> _FakeCursor:
        return _FakeCursor(self._rows)

    async def __aenter__(self) -> "_FakeConnection":
        return self

    async def __aexit__(self, *args: object) -> None:
        pass


class _FakePool:
    """Pool stub; ``connection()`` returns an async-context-managed connection."""

    def __init__(self, rows: list[tuple[str, str]]) -> None:
        self._rows = rows

    def connection(self) -> _FakeConnection:
        return _FakeConnection(self._rows)


def _make_pool(existing_rows: list[tuple[str, str]]) -> _FakePool:
    """Return a fake async connection pool.

    existing_rows: list of (pr_id, repo_name) tuples representing
    already-processed PRs. Every query issued through the pool fetches
    exactly these rows.
    """
    return _FakePool(existing_rows)


# ---------------------------------------------------------------------------
# find_unprocessed_prs tests
# ---------------------------------------------------------------------------


class TestFindUnprocessedPrs:
    """Behavioral tests for ``find_unprocessed_prs(pool, prs)``.

    NOTE(review): these are coroutine tests with no asyncio marker; they
    presumably rely on pytest-asyncio running in auto mode (configured
    outside this file) — confirm, otherwise they will not be awaited.
    """

    async def test_returns_all_when_none_processed(self) -> None:
        prs = [_make_pr("10"), _make_pr("20")]
        pool = _make_pool([])
        result = await find_unprocessed_prs(pool, prs)
        assert len(result) == 2
        # Pin *which* PRs come back, not just how many (order not assumed).
        assert {p.pr_id for p in result} == {"10", "20"}

    async def test_returns_empty_when_all_processed(self) -> None:
        prs = [_make_pr("10"), _make_pr("20")]
        # existing rows: (pr_id, repo_name)
        pool = _make_pool([("10", "my-repo"), ("20", "my-repo")])
        result = await find_unprocessed_prs(pool, prs)
        assert result == []

    async def test_returns_only_unprocessed(self) -> None:
        prs = [_make_pr("10"), _make_pr("20"), _make_pr("30")]
        pool = _make_pool([("10", "my-repo")])
        result = await find_unprocessed_prs(pool, prs)
        pr_ids = [p.pr_id for p in result]
        assert "10" not in pr_ids
        assert "20" in pr_ids
        assert "30" in pr_ids

    async def test_empty_input_returns_empty(self) -> None:
        pool = _make_pool([])
        result = await find_unprocessed_prs(pool, [])
        assert result == []

    async def test_different_repos_not_confused(self) -> None:
        pr_repo_a = _make_pr("10", repo_name="repo-a")
        pr_repo_b = _make_pr("10", repo_name="repo-b")
        # Only repo-a/10 is processed
        pool = _make_pool([("10", "repo-a")])
        result = await find_unprocessed_prs(pool, [pr_repo_a, pr_repo_b])
        # repo-b/10 should still be returned (different repo)
        assert len(result) == 1
        assert result[0].repo_name == "repo-b"
        assert result[0].pr_id == "10"

    async def test_returns_list_of_pr_info(self) -> None:
        prs = [_make_pr("42")]
        pool = _make_pool([])
        result = await find_unprocessed_prs(pool, prs)
        assert all(isinstance(p, PRInfo) for p in result)

    async def test_preserves_pr_info_objects(self) -> None:
        pr = _make_pr("77")
        pool = _make_pool([])
        result = await find_unprocessed_prs(pool, [pr])
        assert result[0].pr_id == "77"
        assert result[0].repo_name == "my-repo"