feat: initial commit — Billo Release Agent (LangGraph)

LangGraph-based release automation agent with:
- PR discovery (webhook + polling)
- AI code review via Claude Code CLI (subscription-based)
- Auto-create Jira tickets for PRs without ticket ID
- Jira ticket lifecycle management (code review -> staging -> done)
- CI/CD pipeline trigger, polling, and approval gates
- Slack interactive messages with approval buttons
- Per-repo semantic versioning
- PostgreSQL persistence (threads, staging, releases)
- FastAPI API (webhooks, approvals, status, manual triggers)
- Docker Compose deployment

1069 tests, 95%+ coverage.
This commit is contained in:
Yaojia Wang
2026-03-24 17:38:23 +01:00
commit f5c2733cfb
104 changed files with 19721 additions and 0 deletions

0
tests/graph/__init__.py Normal file
View File

44
tests/graph/conftest.py Normal file
View File

@@ -0,0 +1,44 @@
"""Shared fixtures for graph tests.
Provides build_mock_clients() to create ToolClients with AsyncMock fields
so individual node functions can be tested without compiling the full graph.
"""
from unittest.mock import AsyncMock, MagicMock
import pytest
from release_agent.graph.dependencies import ToolClients
def build_mock_clients() -> ToolClients:
"""Return a ToolClients instance whose fields are all AsyncMock/MagicMock."""
azdo = AsyncMock()
jira = AsyncMock()
slack = AsyncMock()
reviewer = AsyncMock()
return ToolClients(azdo=azdo, jira=jira, slack=slack, reviewer=reviewer)
def build_config(clients: ToolClients | None = None, staging_store=None) -> dict:
"""Return a LangGraph-style config dict with clients and staging_store."""
if clients is None:
clients = build_mock_clients()
return {
"configurable": {
"clients": clients,
"staging_store": staging_store,
}
}
@pytest.fixture()
def mock_clients() -> ToolClients:
"""Pytest fixture returning fresh mock ToolClients."""
return build_mock_clients()
@pytest.fixture()
def config(mock_clients: ToolClients):
"""Pytest fixture returning a config dict with mock clients."""
return build_config(mock_clients)

View File

@@ -0,0 +1,294 @@
"""Tests for graph/ci_nodes.py.
Written FIRST (TDD RED phase).
All external calls (azdo, slack, poll_until) are mocked.
"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from release_agent.graph.ci_nodes import notify_ci_result, poll_ci_build, trigger_ci_build
from release_agent.models.build import BuildStatus
from release_agent.models.pipeline import PipelineInfo
from tests.graph.conftest import build_config, build_mock_clients
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_pipeline(pipeline_id: int = 10, name: str = "CI-build") -> dict:
return {"id": pipeline_id, "name": name, "repo": "my-repo"}
# ---------------------------------------------------------------------------
# trigger_ci_build
# ---------------------------------------------------------------------------
class TestTriggerCiBuild:
"""Tests for trigger_ci_build node."""
async def test_triggers_pipeline_on_branch(self) -> None:
clients = build_mock_clients()
clients.azdo.list_build_pipelines.return_value = [
PipelineInfo(id=10, name="CI", repo="my-repo")
]
clients.azdo.trigger_pipeline.return_value = {"id": 555, "state": "inProgress"}
config = build_config(clients)
state = {"repo_name": "my-repo", "version": "v1.0.0"}
result = await trigger_ci_build(state, config)
clients.azdo.trigger_pipeline.assert_called_once()
assert "ci_build_id" in result
assert result["ci_build_id"] == 555
async def test_returns_ci_build_id(self) -> None:
clients = build_mock_clients()
clients.azdo.list_build_pipelines.return_value = [
PipelineInfo(id=20, name="build-and-test", repo="my-repo")
]
clients.azdo.trigger_pipeline.return_value = {"id": 999}
config = build_config(clients)
state = {"repo_name": "my-repo", "version": "v2.0.0"}
result = await trigger_ci_build(state, config)
assert result["ci_build_id"] == 999
async def test_appends_error_when_no_pipelines_found(self) -> None:
clients = build_mock_clients()
clients.azdo.list_build_pipelines.return_value = []
config = build_config(clients)
state = {"repo_name": "my-repo", "version": "v1.0.0"}
result = await trigger_ci_build(state, config)
assert "errors" in result
assert len(result["errors"]) >= 1
async def test_appends_error_on_trigger_failure(self) -> None:
from release_agent.exceptions import ServiceError
clients = build_mock_clients()
clients.azdo.list_build_pipelines.return_value = [
PipelineInfo(id=10, name="CI", repo="my-repo")
]
clients.azdo.trigger_pipeline.side_effect = ServiceError(
service="azdo", status_code=500, detail="Internal error"
)
config = build_config(clients)
state = {"repo_name": "my-repo", "version": "v1.0.0"}
result = await trigger_ci_build(state, config)
assert "errors" in result
async def test_uses_main_branch_when_no_version(self) -> None:
clients = build_mock_clients()
clients.azdo.list_build_pipelines.return_value = [
PipelineInfo(id=10, name="CI", repo="my-repo")
]
clients.azdo.trigger_pipeline.return_value = {"id": 1}
config = build_config(clients)
state = {"repo_name": "my-repo"}
result = await trigger_ci_build(state, config)
call_kwargs = clients.azdo.trigger_pipeline.call_args[1]
branch = call_kwargs.get("branch", "")
assert "main" in branch or "refs/heads" in branch
async def test_appends_message_on_success(self) -> None:
clients = build_mock_clients()
clients.azdo.list_build_pipelines.return_value = [
PipelineInfo(id=10, name="CI", repo="my-repo")
]
clients.azdo.trigger_pipeline.return_value = {"id": 123}
config = build_config(clients)
state = {"repo_name": "my-repo", "version": "v1.0.0"}
result = await trigger_ci_build(state, config)
assert "messages" in result
assert len(result["messages"]) >= 1
# ---------------------------------------------------------------------------
# poll_ci_build
# ---------------------------------------------------------------------------
class TestPollCiBuild:
"""Tests for poll_ci_build node."""
async def test_returns_ci_build_status_and_result_on_completion(self) -> None:
clients = build_mock_clients()
completed_status = BuildStatus(status="completed", result="succeeded", build_url="https://build/1")
config = build_config(clients)
state = {"ci_build_id": 42, "repo_name": "my-repo"}
with patch(
"release_agent.graph.ci_nodes.poll_until",
return_value=(completed_status, True),
):
result = await poll_ci_build(state, config)
assert result["ci_build_status"] == "completed"
assert result["ci_build_result"] == "succeeded"
async def test_returns_build_url(self) -> None:
clients = build_mock_clients()
completed_status = BuildStatus(
status="completed",
result="succeeded",
build_url="https://dev.azure.com/build/42",
)
config = build_config(clients)
state = {"ci_build_id": 42, "repo_name": "my-repo"}
with patch(
"release_agent.graph.ci_nodes.poll_until",
return_value=(completed_status, True),
):
result = await poll_ci_build(state, config)
assert result.get("ci_build_url") == "https://dev.azure.com/build/42"
async def test_appends_error_on_timeout(self) -> None:
clients = build_mock_clients()
running_status = BuildStatus(status="inProgress", result=None, build_url=None)
config = build_config(clients)
state = {"ci_build_id": 42, "repo_name": "my-repo"}
with patch(
"release_agent.graph.ci_nodes.poll_until",
return_value=(running_status, False),
):
result = await poll_ci_build(state, config)
assert "errors" in result
async def test_appends_error_when_build_id_missing(self) -> None:
clients = build_mock_clients()
config = build_config(clients)
state = {"repo_name": "my-repo"} # no ci_build_id
result = await poll_ci_build(state, config)
assert "errors" in result
async def test_passes_correct_build_id_to_poll_fn(self) -> None:
clients = build_mock_clients()
clients.azdo.get_build_status.return_value = BuildStatus(
status="completed", result="succeeded", build_url=None
)
config = build_config(clients)
state = {"ci_build_id": 77, "repo_name": "my-repo"}
async def fake_poll_until(*, poll_fn, is_done, interval_seconds, max_wait_seconds, sleep_fn=None):
result = await poll_fn()
return result, True
with patch("release_agent.graph.ci_nodes.poll_until", side_effect=fake_poll_until):
await poll_ci_build(state, config)
clients.azdo.get_build_status.assert_called_once_with(build_id=77)
async def test_result_none_when_poll_returns_none(self) -> None:
clients = build_mock_clients()
config = build_config(clients)
state = {"ci_build_id": 42, "repo_name": "my-repo"}
with patch(
"release_agent.graph.ci_nodes.poll_until",
return_value=(None, False),
):
result = await poll_ci_build(state, config)
assert "errors" in result
# ---------------------------------------------------------------------------
# notify_ci_result
# ---------------------------------------------------------------------------
class TestNotifyCiResult:
"""Tests for notify_ci_result node."""
async def test_sends_notification_on_success(self) -> None:
clients = build_mock_clients()
clients.slack.send_notification.return_value = True
config = build_config(clients)
state = {
"repo_name": "my-repo",
"ci_build_status": "completed",
"ci_build_result": "succeeded",
"ci_build_url": "https://build/99",
}
result = await notify_ci_result(state, config)
clients.slack.send_notification.assert_called_once()
assert "messages" in result
async def test_sends_notification_on_failure(self) -> None:
clients = build_mock_clients()
clients.slack.send_notification.return_value = True
config = build_config(clients)
state = {
"repo_name": "my-repo",
"ci_build_status": "completed",
"ci_build_result": "failed",
"ci_build_url": None,
}
result = await notify_ci_result(state, config)
clients.slack.send_notification.assert_called_once()
async def test_handles_slack_error_gracefully(self) -> None:
from release_agent.exceptions import ServiceError
clients = build_mock_clients()
clients.slack.send_notification.side_effect = ServiceError(
service="slack", status_code=500, detail="Slack error"
)
config = build_config(clients)
state = {
"repo_name": "my-repo",
"ci_build_result": "succeeded",
"ci_build_url": None,
}
result = await notify_ci_result(state, config)
# Should not re-raise; should append error
assert "errors" in result
async def test_includes_repo_name_in_message(self) -> None:
clients = build_mock_clients()
clients.slack.send_notification.return_value = True
config = build_config(clients)
state = {
"repo_name": "super-service",
"ci_build_result": "succeeded",
"ci_build_url": None,
}
await notify_ci_result(state, config)
call_kwargs = clients.slack.send_notification.call_args[1]
text_or_blocks = str(call_kwargs)
assert "super-service" in text_or_blocks
async def test_returns_empty_dict_when_state_has_no_data(self) -> None:
clients = build_mock_clients()
clients.slack.send_notification.return_value = True
config = build_config(clients)
state = {}
result = await notify_ci_result(state, config)
# Should not crash; may return messages or empty dict
assert isinstance(result, dict)

View File

@@ -0,0 +1,283 @@
"""Tests for graph/dependencies.py. Written FIRST (TDD RED phase).
Covers:
- ToolClients frozen dataclass
- StagingStore Protocol (structural check)
- JsonFileStagingStore file I/O operations
"""
import json
from datetime import date
from pathlib import Path
from unittest.mock import AsyncMock
import pytest
from release_agent.graph.dependencies import JsonFileStagingStore, StagingStore, ToolClients
from release_agent.models.release import ArchivedRelease, StagingRelease
from release_agent.models.ticket import TicketEntry
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_ticket(ticket_id: str = "ALLPOST-1") -> TicketEntry:
return TicketEntry(
id=ticket_id,
summary="Fix something",
pr_id="42",
pr_url="https://dev.azure.com/org/proj/_git/repo/pullrequest/42",
pr_title="Fix: something",
branch=f"feature/{ticket_id}-fix",
merged_at=date(2025, 1, 15),
)
def _make_staging(repo: str = "my-repo", version: str = "v1.0.0") -> StagingRelease:
return StagingRelease(
version=version,
repo=repo,
started_at=date(2025, 1, 1),
tickets=[],
)
# ---------------------------------------------------------------------------
# ToolClients tests
# ---------------------------------------------------------------------------
class TestToolClients:
"""Tests for the ToolClients frozen dataclass."""
def test_can_be_constructed_with_all_fields(self) -> None:
azdo = AsyncMock()
jira = AsyncMock()
slack = AsyncMock()
reviewer = AsyncMock()
clients = ToolClients(azdo=azdo, jira=jira, slack=slack, reviewer=reviewer)
assert clients.azdo is azdo
assert clients.jira is jira
assert clients.slack is slack
assert clients.reviewer is reviewer
def test_is_frozen_cannot_reassign_field(self) -> None:
clients = ToolClients(
azdo=AsyncMock(), jira=AsyncMock(), slack=AsyncMock(), reviewer=AsyncMock()
)
with pytest.raises((AttributeError, TypeError)):
clients.azdo = AsyncMock() # type: ignore[misc]
def test_fields_are_accessible_by_name(self) -> None:
azdo = object()
clients = ToolClients(
azdo=azdo, jira=object(), slack=object(), reviewer=object()
)
assert clients.azdo is azdo
def test_equality_for_same_instances(self) -> None:
azdo = AsyncMock()
jira = AsyncMock()
slack = AsyncMock()
reviewer = AsyncMock()
c1 = ToolClients(azdo=azdo, jira=jira, slack=slack, reviewer=reviewer)
c2 = ToolClients(azdo=azdo, jira=jira, slack=slack, reviewer=reviewer)
assert c1 == c2
# ---------------------------------------------------------------------------
# StagingStore Protocol structural tests
# ---------------------------------------------------------------------------
class TestStagingStoreProtocol:
"""Verify that the Protocol is structurally correct."""
def test_json_file_store_satisfies_protocol(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
# runtime_checkable would need @runtime_checkable; check duck-typing instead
assert hasattr(store, "load")
assert hasattr(store, "save")
assert hasattr(store, "archive")
assert hasattr(store, "list_versions")
def test_protocol_is_importable(self) -> None:
# Just import-level check
assert StagingStore is not None
# ---------------------------------------------------------------------------
# JsonFileStagingStore tests
# ---------------------------------------------------------------------------
class TestJsonFileStagingStore:
"""Tests for JsonFileStagingStore using tmp_path for file I/O."""
# ------------------------------------------------------------------
# load
# ------------------------------------------------------------------
async def test_load_returns_none_when_file_missing(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
result = await store.load("nonexistent-repo")
assert result is None
async def test_load_returns_staging_release_after_save(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
staging = _make_staging()
await store.save(staging)
loaded = await store.load("my-repo")
assert loaded is not None
assert loaded.version == "v1.0.0"
assert loaded.repo == "my-repo"
async def test_load_returns_staging_with_tickets(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
staging = _make_staging().add_ticket(_make_ticket("BILL-10"))
await store.save(staging)
loaded = await store.load("my-repo")
assert loaded is not None
assert len(loaded.tickets) == 1
assert loaded.tickets[0].id == "BILL-10"
async def test_load_is_read_only_does_not_mutate_stored(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
staging = _make_staging()
await store.save(staging)
loaded1 = await store.load("my-repo")
loaded2 = await store.load("my-repo")
assert loaded1 is not loaded2 # fresh objects each time
# ------------------------------------------------------------------
# save
# ------------------------------------------------------------------
async def test_save_creates_file_in_directory(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
staging = _make_staging(repo="api-service")
await store.save(staging)
expected_path = tmp_path / "api-service.json"
assert expected_path.exists()
async def test_save_overwrites_existing_file(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
staging_v1 = _make_staging(version="v1.0.0")
staging_v2 = _make_staging(version="v1.0.1")
await store.save(staging_v1)
await store.save(staging_v2)
loaded = await store.load("my-repo")
assert loaded is not None
assert loaded.version == "v1.0.1"
async def test_save_writes_valid_json(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
staging = _make_staging()
await store.save(staging)
raw = (tmp_path / "my-repo.json").read_text()
data = json.loads(raw)
assert data["version"] == "v1.0.0"
assert data["repo"] == "my-repo"
async def test_save_does_not_mutate_staging_release(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
staging = _make_staging()
original_tickets = list(staging.tickets)
await store.save(staging)
assert list(staging.tickets) == original_tickets
# ------------------------------------------------------------------
# archive
# ------------------------------------------------------------------
async def test_archive_removes_staging_file(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
staging = _make_staging()
await store.save(staging)
await store.archive(staging, date(2025, 6, 1))
assert await store.load("my-repo") is None
async def test_archive_creates_archive_file(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
staging = _make_staging(repo="my-repo", version="v1.0.0")
await store.save(staging)
await store.archive(staging, date(2025, 6, 1))
archive_path = tmp_path / "my-repo_v1.0.0_2025-06-01.json"
assert archive_path.exists()
async def test_archive_file_contains_released_at(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
staging = _make_staging()
await store.save(staging)
release_date = date(2025, 6, 1)
await store.archive(staging, release_date)
archive_path = tmp_path / "my-repo_v1.0.0_2025-06-01.json"
data = json.loads(archive_path.read_text())
assert data["released_at"] == "2025-06-01"
async def test_archive_without_prior_save_creates_archive(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
staging = _make_staging()
await store.archive(staging, date(2025, 6, 1))
archive_path = tmp_path / "my-repo_v1.0.0_2025-06-01.json"
assert archive_path.exists()
# ------------------------------------------------------------------
# list_versions
# ------------------------------------------------------------------
async def test_list_versions_empty_directory(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
versions = await store.list_versions("my-repo")
assert versions == []
async def test_list_versions_returns_version_from_staging_file(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
await store.save(_make_staging(version="v2.1.0"))
versions = await store.list_versions("my-repo")
assert "v2.1.0" in versions
async def test_list_versions_includes_archived_versions(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
staging = _make_staging(version="v1.5.0")
await store.save(staging)
await store.archive(staging, date(2025, 3, 1))
# Now save a new staging for the same repo
await store.save(_make_staging(version="v1.6.0"))
versions = await store.list_versions("my-repo")
assert "v1.5.0" in versions
assert "v1.6.0" in versions
async def test_list_versions_only_returns_versions_for_given_repo(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
await store.save(_make_staging(repo="repo-a", version="v1.0.0"))
await store.save(_make_staging(repo="repo-b", version="v2.0.0"))
versions_a = await store.list_versions("repo-a")
assert "v1.0.0" in versions_a
# repo-b version should not appear in repo-a's list
assert "v2.0.0" not in versions_a
async def test_list_versions_no_duplicates(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
await store.save(_make_staging(version="v1.0.0"))
versions = await store.list_versions("my-repo")
assert len(versions) == len(set(versions))
async def test_list_versions_multiple_archives(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
for i in range(3):
staging = _make_staging(version=f"v1.0.{i}")
await store.archive(staging, date(2025, 1, i + 1))
versions = await store.list_versions("my-repo")
assert len(versions) == 3
assert "v1.0.0" in versions
assert "v1.0.1" in versions
assert "v1.0.2" in versions
# ------------------------------------------------------------------
# directory creation
# ------------------------------------------------------------------
def test_store_creates_directory_if_not_exists(self, tmp_path: Path) -> None:
new_dir = tmp_path / "staging_data"
assert not new_dir.exists()
JsonFileStagingStore(directory=new_dir)
assert new_dir.exists()

View File

@@ -0,0 +1,177 @@
"""Tests for async StagingStore protocol and async JsonFileStagingStore.
Phase 5 - Step 1: All StagingStore methods become async def.
Written FIRST (TDD RED phase).
"""
import json
from datetime import date
from pathlib import Path
from unittest.mock import AsyncMock
import pytest
from release_agent.graph.dependencies import JsonFileStagingStore, StagingStore, ToolClients
from release_agent.models.release import StagingRelease
from release_agent.models.ticket import TicketEntry
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_ticket(ticket_id: str = "ALLPOST-1") -> TicketEntry:
return TicketEntry(
id=ticket_id,
summary="Fix something",
pr_id="42",
pr_url="https://dev.azure.com/org/proj/_git/repo/pullrequest/42",
pr_title="Fix: something",
branch=f"feature/{ticket_id}-fix",
merged_at=date(2025, 1, 15),
)
def _make_staging(repo: str = "my-repo", version: str = "v1.0.0") -> StagingRelease:
return StagingRelease(
version=version,
repo=repo,
started_at=date(2025, 1, 1),
tickets=[],
)
# ---------------------------------------------------------------------------
# Protocol: all methods must be async
# ---------------------------------------------------------------------------
class TestStagingStoreProtocolIsAsync:
"""Verify that StagingStore protocol methods are async-compatible."""
def test_protocol_has_load_method(self) -> None:
assert hasattr(StagingStore, "load")
def test_protocol_has_save_method(self) -> None:
assert hasattr(StagingStore, "save")
def test_protocol_has_archive_method(self) -> None:
assert hasattr(StagingStore, "archive")
def test_protocol_has_list_versions_method(self) -> None:
assert hasattr(StagingStore, "list_versions")
# ---------------------------------------------------------------------------
# JsonFileStagingStore async interface
# ---------------------------------------------------------------------------
class TestJsonFileStagingStoreAsync:
"""Verify that JsonFileStagingStore methods are awaitable (async def)."""
async def test_load_is_awaitable(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
result = await store.load("nonexistent-repo")
assert result is None
async def test_load_returns_staging_release_after_save(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
staging = _make_staging()
await store.save(staging)
loaded = await store.load("my-repo")
assert loaded is not None
assert loaded.version == "v1.0.0"
assert loaded.repo == "my-repo"
async def test_load_returns_staging_with_tickets(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
staging = _make_staging().add_ticket(_make_ticket("BILL-10"))
await store.save(staging)
loaded = await store.load("my-repo")
assert loaded is not None
assert len(loaded.tickets) == 1
assert loaded.tickets[0].id == "BILL-10"
async def test_load_returns_fresh_objects(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
staging = _make_staging()
await store.save(staging)
loaded1 = await store.load("my-repo")
loaded2 = await store.load("my-repo")
assert loaded1 is not loaded2
async def test_save_is_awaitable(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
staging = _make_staging(repo="api-service")
await store.save(staging)
expected_path = tmp_path / "api-service.json"
assert expected_path.exists()
async def test_save_overwrites_existing_file(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
await store.save(_make_staging(version="v1.0.0"))
await store.save(_make_staging(version="v1.0.1"))
loaded = await store.load("my-repo")
assert loaded is not None
assert loaded.version == "v1.0.1"
async def test_save_writes_valid_json(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
staging = _make_staging()
await store.save(staging)
raw = (tmp_path / "my-repo.json").read_text()
data = json.loads(raw)
assert data["version"] == "v1.0.0"
assert data["repo"] == "my-repo"
async def test_archive_is_awaitable(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
staging = _make_staging()
await store.save(staging)
await store.archive(staging, date(2025, 6, 1))
assert await store.load("my-repo") is None
async def test_archive_creates_archive_file(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
staging = _make_staging(repo="my-repo", version="v1.0.0")
await store.save(staging)
await store.archive(staging, date(2025, 6, 1))
archive_path = tmp_path / "my-repo_v1.0.0_2025-06-01.json"
assert archive_path.exists()
async def test_archive_file_contains_released_at(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
staging = _make_staging()
await store.save(staging)
await store.archive(staging, date(2025, 6, 1))
archive_path = tmp_path / "my-repo_v1.0.0_2025-06-01.json"
data = json.loads(archive_path.read_text())
assert data["released_at"] == "2025-06-01"
async def test_list_versions_is_awaitable(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
versions = await store.list_versions("my-repo")
assert versions == []
async def test_list_versions_returns_staging_version(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
await store.save(_make_staging(version="v2.1.0"))
versions = await store.list_versions("my-repo")
assert "v2.1.0" in versions
async def test_list_versions_includes_archived(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
staging = _make_staging(version="v1.5.0")
await store.save(staging)
await store.archive(staging, date(2025, 3, 1))
await store.save(_make_staging(version="v1.6.0"))
versions = await store.list_versions("my-repo")
assert "v1.5.0" in versions
assert "v1.6.0" in versions
async def test_list_versions_only_for_given_repo(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
await store.save(_make_staging(repo="repo-a", version="v1.0.0"))
await store.save(_make_staging(repo="repo-b", version="v2.0.0"))
versions_a = await store.list_versions("repo-a")
assert "v1.0.0" in versions_a
assert "v2.0.0" not in versions_a

View File

@@ -0,0 +1,53 @@
"""Tests for graph/full_cycle.py.
Tests that the full cycle graph composes pr_completed and release subgraphs
correctly, and that the routing conditional edge works as expected.
"""
from release_agent.graph.full_cycle import build_full_cycle_graph
from release_agent.graph.routing import should_continue_to_release
class TestBuildFullCycleGraph:
def test_returns_compiled_graph(self) -> None:
graph = build_full_cycle_graph()
assert graph is not None
def test_graph_can_be_built_multiple_times(self) -> None:
graph1 = build_full_cycle_graph()
graph2 = build_full_cycle_graph()
assert graph1 is not None
assert graph2 is not None
def test_graph_has_get_graph_method(self) -> None:
graph = build_full_cycle_graph()
assert hasattr(graph, "get_graph") or hasattr(graph, "nodes")
class TestFullCycleRouting:
"""Test that the routing function used by full_cycle correctly
determines whether to continue to the release subgraph."""
def test_continue_when_flag_true_and_no_errors(self) -> None:
state = {"continue_to_release": True, "errors": []}
assert should_continue_to_release(state) == "yes"
def test_stop_when_flag_false(self) -> None:
state = {"continue_to_release": False}
assert should_continue_to_release(state) == "no"
def test_stop_when_flag_missing(self) -> None:
state = {}
assert should_continue_to_release(state) == "no"
def test_stop_when_errors_present(self) -> None:
state = {"continue_to_release": True, "errors": ["some error"]}
assert should_continue_to_release(state) == "no"
def test_stop_when_flag_true_but_errors_present(self) -> None:
state = {"continue_to_release": True, "errors": ["critical failure"]}
assert should_continue_to_release(state) == "no"
def test_continue_when_errors_empty_list(self) -> None:
state = {"continue_to_release": True, "errors": []}
assert should_continue_to_release(state) == "yes"

356
tests/graph/test_polling.py Normal file
View File

@@ -0,0 +1,356 @@
"""Tests for graph/polling.py — poll_until async utility.
Written FIRST (TDD RED phase).
All tests inject a fake_sleep_fn that returns immediately to avoid real waits.
"""
import asyncio
from unittest.mock import AsyncMock, call
import pytest
from release_agent.graph.polling import poll_until
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
async def _immediate_sleep(seconds: float) -> None:
"""Drop-in replacement for asyncio.sleep that returns immediately."""
return
# ---------------------------------------------------------------------------
# Success path tests
# ---------------------------------------------------------------------------
class TestPollUntilSuccess:
"""Tests for the happy path where poll_fn succeeds before timeout."""
async def test_returns_tuple_of_result_and_completed_true(self) -> None:
calls = iter(["running", "running", "completed"])
async def poll_fn():
return next(calls)
result, completed = await poll_until(
poll_fn=poll_fn,
is_done=lambda r: r == "completed",
interval_seconds=1,
max_wait_seconds=60,
sleep_fn=_immediate_sleep,
)
assert result == "completed"
assert completed is True
async def test_returns_immediately_when_already_done(self) -> None:
async def poll_fn():
return "completed"
result, completed = await poll_until(
poll_fn=poll_fn,
is_done=lambda r: r == "completed",
interval_seconds=1,
max_wait_seconds=60,
sleep_fn=_immediate_sleep,
)
assert result == "completed"
assert completed is True
async def test_polls_multiple_times_before_done(self) -> None:
call_count = 0
async def poll_fn():
nonlocal call_count
call_count += 1
return "done" if call_count >= 3 else "pending"
result, completed = await poll_until(
poll_fn=poll_fn,
is_done=lambda r: r == "done",
interval_seconds=1,
max_wait_seconds=60,
sleep_fn=_immediate_sleep,
)
assert result == "done"
assert completed is True
assert call_count == 3
async def test_sleep_called_between_polls(self) -> None:
call_count = 0
sleep_calls: list[float] = []
async def poll_fn():
nonlocal call_count
call_count += 1
return "done" if call_count >= 2 else "pending"
async def tracking_sleep(seconds: float) -> None:
sleep_calls.append(seconds)
await poll_until(
poll_fn=poll_fn,
is_done=lambda r: r == "done",
interval_seconds=15,
max_wait_seconds=60,
sleep_fn=tracking_sleep,
)
assert len(sleep_calls) >= 1
assert all(s == 15 for s in sleep_calls)
async def test_no_sleep_on_first_successful_poll(self) -> None:
sleep_calls: list[float] = []
async def poll_fn():
return "done"
async def tracking_sleep(seconds: float) -> None:
sleep_calls.append(seconds)
await poll_until(
poll_fn=poll_fn,
is_done=lambda r: r == "done",
interval_seconds=10,
max_wait_seconds=60,
sleep_fn=tracking_sleep,
)
assert sleep_calls == []
async def test_works_with_dict_results(self) -> None:
responses = iter([
{"status": "inProgress"},
{"status": "completed", "result": "succeeded"},
])
async def poll_fn():
return next(responses)
result, completed = await poll_until(
poll_fn=poll_fn,
is_done=lambda r: r["status"] == "completed",
interval_seconds=1,
max_wait_seconds=60,
sleep_fn=_immediate_sleep,
)
assert result["result"] == "succeeded"
assert completed is True
# ---------------------------------------------------------------------------
# Timeout tests
# ---------------------------------------------------------------------------
class TestPollUntilTimeout:
"""Tests for timeout behavior."""
async def test_returns_last_result_and_completed_false_on_timeout(self) -> None:
async def poll_fn():
return "still_running"
# With interval=10, max_wait=5, it should time out after one poll
result, completed = await poll_until(
poll_fn=poll_fn,
is_done=lambda r: r == "done",
interval_seconds=10,
max_wait_seconds=5,
sleep_fn=_immediate_sleep,
)
assert result == "still_running"
assert completed is False
async def test_at_least_one_poll_happens_before_timeout(self) -> None:
call_count = 0
async def poll_fn():
nonlocal call_count
call_count += 1
return "running"
await poll_until(
poll_fn=poll_fn,
is_done=lambda r: r == "done",
interval_seconds=100,
max_wait_seconds=1,
sleep_fn=_immediate_sleep,
)
assert call_count >= 1
async def test_max_polls_bounded_by_max_wait_over_interval(self) -> None:
call_count = 0
async def poll_fn():
nonlocal call_count
call_count += 1
return "running"
await poll_until(
poll_fn=poll_fn,
is_done=lambda r: False,
interval_seconds=10,
max_wait_seconds=30,
sleep_fn=_immediate_sleep,
)
# With interval=10, max_wait=30: should poll at most ceil(30/10)+1 = 4 times
assert call_count <= 5
# ---------------------------------------------------------------------------
# Error handling tests
# ---------------------------------------------------------------------------
class TestPollUntilErrorHandling:
"""Tests for error/exception handling in poll_until."""
async def test_continues_after_transient_exception(self) -> None:
call_count = 0
async def poll_fn():
nonlocal call_count
call_count += 1
if call_count < 3:
raise RuntimeError("Transient error")
return "done"
result, completed = await poll_until(
poll_fn=poll_fn,
is_done=lambda r: r == "done",
interval_seconds=1,
max_wait_seconds=60,
sleep_fn=_immediate_sleep,
)
assert result == "done"
assert completed is True
async def test_aborts_after_three_consecutive_failures(self) -> None:
call_count = 0
async def poll_fn():
nonlocal call_count
call_count += 1
raise RuntimeError("Persistent error")
result, completed = await poll_until(
poll_fn=poll_fn,
is_done=lambda r: True,
interval_seconds=1,
max_wait_seconds=60,
sleep_fn=_immediate_sleep,
)
# Should abort after 3 consecutive failures
assert call_count == 3
assert completed is False
assert result is None
async def test_resets_consecutive_failure_count_on_success(self) -> None:
call_count = 0
async def poll_fn():
nonlocal call_count
call_count += 1
# Fail twice, succeed once, fail twice, succeed (done)
if call_count in (1, 2):
raise RuntimeError("fail")
if call_count == 3:
return "running"
if call_count in (4, 5):
raise RuntimeError("fail again")
return "done"
result, completed = await poll_until(
poll_fn=poll_fn,
is_done=lambda r: r == "done",
interval_seconds=1,
max_wait_seconds=120,
sleep_fn=_immediate_sleep,
)
assert completed is True
assert result == "done"
async def test_single_exception_does_not_abort(self) -> None:
call_count = 0
async def poll_fn():
nonlocal call_count
call_count += 1
if call_count == 1:
raise ValueError("one error")
return "done"
result, completed = await poll_until(
poll_fn=poll_fn,
is_done=lambda r: r == "done",
interval_seconds=1,
max_wait_seconds=60,
sleep_fn=_immediate_sleep,
)
assert completed is True
assert result == "done"
async def test_two_consecutive_failures_do_not_abort(self) -> None:
call_count = 0
async def poll_fn():
nonlocal call_count
call_count += 1
if call_count <= 2:
raise ConnectionError("two errors")
return "done"
result, completed = await poll_until(
poll_fn=poll_fn,
is_done=lambda r: r == "done",
interval_seconds=1,
max_wait_seconds=60,
sleep_fn=_immediate_sleep,
)
assert completed is True
assert result == "done"
# ---------------------------------------------------------------------------
# Default parameter tests
# ---------------------------------------------------------------------------
class TestPollUntilDefaults:
"""Tests that default parameters match the spec."""
async def test_default_interval_is_30_seconds(self) -> None:
sleep_calls: list[float] = []
async def poll_fn():
return "done" if len(sleep_calls) >= 1 else "running"
async def tracking_sleep(seconds: float) -> None:
sleep_calls.append(seconds)
await poll_until(
poll_fn=poll_fn,
is_done=lambda r: r == "done",
sleep_fn=tracking_sleep,
)
if sleep_calls:
assert sleep_calls[0] == 30
async def test_poll_fn_and_is_done_are_keyword_only(self) -> None:
"""poll_fn and is_done must be passed as keyword arguments."""
async def poll_fn():
return "done"
with pytest.raises(TypeError):
await poll_until(poll_fn, lambda r: r == "done") # type: ignore[call-arg]

View File

@@ -0,0 +1,414 @@
"""Tests for PostgresStagingStore.
Phase 5 - Step 2: PostgreSQL-backed StagingStore using async pool.
Written FIRST (TDD RED phase).
All tests use FakeAsyncPool — no real PostgreSQL required.
"""
import json
from datetime import date
from unittest.mock import AsyncMock, MagicMock
import pytest
from release_agent.graph.postgres_staging_store import PostgresStagingStore
from release_agent.models.release import ArchivedRelease, StagingRelease
from release_agent.models.ticket import TicketEntry
# ---------------------------------------------------------------------------
# Fake pool infrastructure
# ---------------------------------------------------------------------------
class FakeAsyncCursor:
"""Records SQL calls and returns configured results."""
def __init__(self) -> None:
self.executed: list[tuple[str, tuple]] = []
self._fetchone_result: tuple | None = None
self._fetchall_result: list[tuple] = []
def set_fetchone(self, row: tuple | None) -> None:
self._fetchone_result = row
def set_fetchall(self, rows: list[tuple]) -> None:
self._fetchall_result = rows
async def execute(self, sql: str, params: tuple = ()) -> None:
self.executed.append((sql, params))
async def fetchone(self) -> tuple | None:
return self._fetchone_result
async def fetchall(self) -> list[tuple]:
return self._fetchall_result
async def __aenter__(self):
return self
async def __aexit__(self, *args):
pass
class FakeAsyncTransaction:
"""Fake async transaction context manager (no-op)."""
async def __aenter__(self):
return self
async def __aexit__(self, *args):
pass
class FakeAsyncConnection:
"""Async context manager returning a FakeAsyncCursor."""
def __init__(self, cursor: FakeAsyncCursor) -> None:
self._cursor = cursor
def cursor(self):
return self._cursor
def transaction(self):
return FakeAsyncTransaction()
async def __aenter__(self):
return self
async def __aexit__(self, *args):
pass
class FakeAsyncPool:
"""Records all SQL executed through it."""
def __init__(self, cursor: FakeAsyncCursor) -> None:
self._cursor = cursor
self._conn = FakeAsyncConnection(cursor)
def connection(self):
return self._conn
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_ticket(ticket_id: str = "ALLPOST-1") -> TicketEntry:
return TicketEntry(
id=ticket_id,
summary="Fix something",
pr_id="42",
pr_url="https://dev.azure.com/org/proj/_git/repo/pullrequest/42",
pr_title="Fix: something",
branch=f"feature/{ticket_id}-fix",
merged_at=date(2025, 1, 15),
)
def _make_staging(
repo: str = "my-repo",
version: str = "v1.0.0",
tickets: list | None = None,
) -> StagingRelease:
t = tickets if tickets is not None else []
return StagingRelease(
version=version,
repo=repo,
started_at=date(2025, 1, 1),
tickets=t,
)
def _staging_row(staging: StagingRelease) -> tuple:
"""Return (repo, version, started_at, tickets_json) as DB would store it."""
return (
staging.repo,
staging.version,
staging.started_at.isoformat(),
json.dumps([t.model_dump(mode="json") for t in staging.tickets]),
)
# ---------------------------------------------------------------------------
# load()
# ---------------------------------------------------------------------------
class TestPostgresStagingStoreLoad:
async def test_load_returns_none_when_no_row(self) -> None:
cursor = FakeAsyncCursor()
cursor.set_fetchone(None)
pool = FakeAsyncPool(cursor)
store = PostgresStagingStore(pool=pool)
result = await store.load("nonexistent-repo")
assert result is None
async def test_load_returns_staging_release_when_row_exists(self) -> None:
staging = _make_staging(repo="api-service", version="v2.0.0")
cursor = FakeAsyncCursor()
cursor.set_fetchone((
staging.repo,
staging.version,
staging.started_at.isoformat(),
json.dumps([]),
))
pool = FakeAsyncPool(cursor)
store = PostgresStagingStore(pool=pool)
result = await store.load("api-service")
assert result is not None
assert isinstance(result, StagingRelease)
assert result.repo == "api-service"
assert result.version == "v2.0.0"
async def test_load_returns_staging_with_tickets(self) -> None:
ticket = _make_ticket("BILL-42")
staging = _make_staging(tickets=[ticket])
cursor = FakeAsyncCursor()
cursor.set_fetchone((
staging.repo,
staging.version,
staging.started_at.isoformat(),
json.dumps([ticket.model_dump(mode="json")]),
))
pool = FakeAsyncPool(cursor)
store = PostgresStagingStore(pool=pool)
result = await store.load("my-repo")
assert result is not None
assert len(result.tickets) == 1
assert result.tickets[0].id == "BILL-42"
async def test_load_executes_select_with_correct_repo(self) -> None:
cursor = FakeAsyncCursor()
cursor.set_fetchone(None)
pool = FakeAsyncPool(cursor)
store = PostgresStagingStore(pool=pool)
await store.load("target-repo")
assert len(cursor.executed) >= 1
sql, params = cursor.executed[-1]
assert "SELECT" in sql.upper()
assert "target-repo" in params
async def test_load_queries_staging_releases_table(self) -> None:
cursor = FakeAsyncCursor()
cursor.set_fetchone(None)
pool = FakeAsyncPool(cursor)
store = PostgresStagingStore(pool=pool)
await store.load("my-repo")
sql, _ = cursor.executed[-1]
assert "staging_releases" in sql
# ---------------------------------------------------------------------------
# save()
# ---------------------------------------------------------------------------
class TestPostgresStagingStoreSave:
async def test_save_executes_upsert(self) -> None:
cursor = FakeAsyncCursor()
pool = FakeAsyncPool(cursor)
store = PostgresStagingStore(pool=pool)
staging = _make_staging()
await store.save(staging)
assert len(cursor.executed) >= 1
sql, _ = cursor.executed[-1]
# Should be an INSERT ... ON CONFLICT ... or UPSERT
assert "INSERT" in sql.upper() or "UPSERT" in sql.upper()
async def test_save_passes_repo_to_upsert(self) -> None:
cursor = FakeAsyncCursor()
pool = FakeAsyncPool(cursor)
store = PostgresStagingStore(pool=pool)
staging = _make_staging(repo="payment-service")
await store.save(staging)
_, params = cursor.executed[-1]
assert "payment-service" in params
async def test_save_passes_version_to_upsert(self) -> None:
cursor = FakeAsyncCursor()
pool = FakeAsyncPool(cursor)
store = PostgresStagingStore(pool=pool)
staging = _make_staging(version="v3.1.0")
await store.save(staging)
_, params = cursor.executed[-1]
assert "v3.1.0" in params
async def test_save_targets_staging_releases_table(self) -> None:
cursor = FakeAsyncCursor()
pool = FakeAsyncPool(cursor)
store = PostgresStagingStore(pool=pool)
staging = _make_staging()
await store.save(staging)
sql, _ = cursor.executed[-1]
assert "staging_releases" in sql
async def test_save_serializes_tickets_as_json(self) -> None:
cursor = FakeAsyncCursor()
pool = FakeAsyncPool(cursor)
store = PostgresStagingStore(pool=pool)
staging = _make_staging(tickets=[_make_ticket("ALLPOST-99")])
await store.save(staging)
_, params = cursor.executed[-1]
# tickets param should be a JSON string containing the ticket id
tickets_json = next(p for p in params if isinstance(p, str) and "ALLPOST-99" in p)
parsed = json.loads(tickets_json)
assert parsed[0]["id"] == "ALLPOST-99"
# ---------------------------------------------------------------------------
# archive()
# ---------------------------------------------------------------------------
class TestPostgresStagingStoreArchive:
async def test_archive_inserts_into_archived_releases(self) -> None:
cursor = FakeAsyncCursor()
pool = FakeAsyncPool(cursor)
store = PostgresStagingStore(pool=pool)
staging = _make_staging()
await store.archive(staging, date(2025, 6, 1))
sql_statements = [sql for sql, _ in cursor.executed]
assert any("archived_releases" in sql for sql in sql_statements)
async def test_archive_deletes_from_staging_releases(self) -> None:
cursor = FakeAsyncCursor()
pool = FakeAsyncPool(cursor)
store = PostgresStagingStore(pool=pool)
staging = _make_staging(repo="my-repo")
await store.archive(staging, date(2025, 6, 1))
sql_statements = [sql for sql, _ in cursor.executed]
assert any("DELETE" in sql.upper() and "staging_releases" in sql for sql in sql_statements)
async def test_archive_passes_released_at_date(self) -> None:
cursor = FakeAsyncCursor()
pool = FakeAsyncPool(cursor)
store = PostgresStagingStore(pool=pool)
staging = _make_staging()
release_date = date(2025, 12, 31)
await store.archive(staging, release_date)
all_params = [params for _, params in cursor.executed]
all_values = [v for params in all_params for v in params]
assert "2025-12-31" in all_values or release_date.isoformat() in all_values
async def test_archive_passes_repo_to_delete(self) -> None:
cursor = FakeAsyncCursor()
pool = FakeAsyncPool(cursor)
store = PostgresStagingStore(pool=pool)
staging = _make_staging(repo="payment-service")
await store.archive(staging, date(2025, 6, 1))
delete_calls = [(sql, params) for sql, params in cursor.executed if "DELETE" in sql.upper()]
assert len(delete_calls) >= 1
_, params = delete_calls[0]
assert "payment-service" in params
# ---------------------------------------------------------------------------
# list_versions()
# ---------------------------------------------------------------------------
class TestPostgresStagingStoreListVersions:
async def test_list_versions_returns_empty_when_no_data(self) -> None:
cursor = FakeAsyncCursor()
cursor.set_fetchone(None)
cursor.set_fetchall([])
pool = FakeAsyncPool(cursor)
store = PostgresStagingStore(pool=pool)
versions = await store.list_versions("my-repo")
assert versions == []
async def test_list_versions_includes_staging_version(self) -> None:
cursor = FakeAsyncCursor()
# fetchone returns staging row
cursor.set_fetchone(("my-repo", "v1.0.0", "2025-01-01", "[]"))
# fetchall returns archived rows
cursor.set_fetchall([])
pool = FakeAsyncPool(cursor)
store = PostgresStagingStore(pool=pool)
versions = await store.list_versions("my-repo")
assert "v1.0.0" in versions
async def test_list_versions_includes_archived_versions(self) -> None:
cursor = FakeAsyncCursor()
cursor.set_fetchone(None)
cursor.set_fetchall([
("my-repo", "v1.0.0", "2025-01-01", "[]", "2025-06-01"),
("my-repo", "v1.1.0", "2025-02-01", "[]", "2025-07-01"),
])
pool = FakeAsyncPool(cursor)
store = PostgresStagingStore(pool=pool)
versions = await store.list_versions("my-repo")
assert "v1.0.0" in versions
assert "v1.1.0" in versions
async def test_list_versions_combines_staging_and_archived(self) -> None:
cursor = FakeAsyncCursor()
cursor.set_fetchone(("my-repo", "v2.0.0", "2025-03-01", "[]"))
cursor.set_fetchall([
("my-repo", "v1.0.0", "2025-01-01", "[]", "2025-02-01"),
])
pool = FakeAsyncPool(cursor)
store = PostgresStagingStore(pool=pool)
versions = await store.list_versions("my-repo")
assert "v2.0.0" in versions
assert "v1.0.0" in versions
async def test_list_versions_no_duplicates(self) -> None:
cursor = FakeAsyncCursor()
cursor.set_fetchone(("my-repo", "v1.0.0", "2025-01-01", "[]"))
cursor.set_fetchall([
("my-repo", "v1.0.0", "2025-01-01", "[]", "2025-02-01"),
])
pool = FakeAsyncPool(cursor)
store = PostgresStagingStore(pool=pool)
versions = await store.list_versions("my-repo")
assert len(versions) == len(set(versions))
async def test_list_versions_executes_queries_for_correct_repo(self) -> None:
cursor = FakeAsyncCursor()
cursor.set_fetchone(None)
cursor.set_fetchall([])
pool = FakeAsyncPool(cursor)
store = PostgresStagingStore(pool=pool)
await store.list_versions("target-repo")
all_params = [params for _, params in cursor.executed]
all_values = [v for params in all_params for v in params]
assert "target-repo" in all_values

View File

@@ -0,0 +1,956 @@
"""Tests for graph/pr_completed.py node functions. Written FIRST (TDD RED phase).
Each node is an async function (state, config) -> dict.
Tests call nodes directly with a state dict and config dict — no graph compilation.
"""
from datetime import date, datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from release_agent.graph.dependencies import JsonFileStagingStore, ToolClients
from release_agent.graph.pr_completed import (
_post_review_to_pr,
add_jira_pr_link,
auto_create_ticket,
calculate_version,
evaluate_review,
fetch_pr_details,
interrupt_confirm_merge,
merge_pr_node,
move_jira_code_review,
move_jira_ready_for_stage,
notify_request_changes,
parse_webhook,
run_code_review,
update_staging,
build_pr_completed_graph,
)
from release_agent.models.review import ReviewIssue
from release_agent.models.jira import JiraIssue
from release_agent.models.pr import PRInfo
from release_agent.models.review import ReviewResult
from tests.graph.conftest import build_config, build_mock_clients
# ---------------------------------------------------------------------------
# Webhook payload fixtures
# ---------------------------------------------------------------------------
def _make_webhook_payload(
*,
repo_name: str = "my-repo",
pr_id: int = 42,
source_ref: str = "refs/heads/feature/ALLPOST-100_fix-bug",
target_ref: str = "refs/heads/main",
status: str = "completed",
title: str = "Fix: bug",
closed_date: str | None = "2025-01-15T10:00:00Z",
) -> dict:
# Uses snake_case keys to match WebhookPayload Pydantic model field names
return {
"subscription_id": "sub-1",
"event_type": "git.pullrequest.merged",
"resource": {
"repository": {
"id": "repo-id-1",
"name": repo_name,
"web_url": "https://dev.azure.com/org/proj/_git/my-repo",
},
"pull_request_id": pr_id,
"title": title,
"source_ref_name": source_ref,
"target_ref_name": target_ref,
"status": status,
"closed_date": closed_date,
},
}
def _make_pr_info(
*,
pr_id: str = "42",
repo_name: str = "my-repo",
branch: str = "refs/heads/feature/ALLPOST-100-fix-bug",
status: str = "completed",
) -> PRInfo:
return PRInfo(
pr_id=pr_id,
pr_url="https://dev.azure.com/org/proj/_git/my-repo/pullrequest/42",
repo_name=repo_name,
branch=branch,
pr_title="Fix: bug",
pr_status=status,
)
def _make_approve_review() -> dict:
return {
"verdict": "approve",
"summary": "Looks good",
"issues": [],
"has_blockers": False,
}
def _make_request_changes_review() -> dict:
return {
"verdict": "request_changes",
"summary": "Needs work",
"issues": [{"severity": "blocker", "description": "Missing tests"}],
"has_blockers": True,
}
# ---------------------------------------------------------------------------
# parse_webhook
# ---------------------------------------------------------------------------
class TestParseWebhook:
async def test_extracts_pr_info_from_payload(self) -> None:
state = {"webhook_payload": _make_webhook_payload()}
config = build_config()
result = await parse_webhook(state, config)
assert "pr_info" in result
pr = result["pr_info"]
assert pr["pr_id"] == "42"
assert pr["repo_name"] == "my-repo"
async def test_extracts_ticket_from_branch(self) -> None:
state = {"webhook_payload": _make_webhook_payload(
source_ref="refs/heads/feature/ALLPOST-100_fix-bug"
)}
config = build_config()
result = await parse_webhook(state, config)
assert result["ticket_id"] == "ALLPOST-100"
assert result["has_ticket"] is True
async def test_no_ticket_when_branch_has_none(self) -> None:
state = {"webhook_payload": _make_webhook_payload(
source_ref="refs/heads/bugfix/generic_fix"
)}
config = build_config()
result = await parse_webhook(state, config)
assert result["has_ticket"] is False
assert result["ticket_id"] is None
async def test_sets_repo_name(self) -> None:
state = {"webhook_payload": _make_webhook_payload(repo_name="backend-api")}
config = build_config()
result = await parse_webhook(state, config)
assert result["repo_name"] == "backend-api"
async def test_sets_pr_id_as_string(self) -> None:
state = {"webhook_payload": _make_webhook_payload(pr_id=99)}
config = build_config()
result = await parse_webhook(state, config)
assert result["pr_info"]["pr_id"] == "99"
async def test_invalid_payload_adds_error(self) -> None:
state = {"webhook_payload": {"bad": "data"}}
config = build_config()
result = await parse_webhook(state, config)
assert "errors" in result
assert len(result["errors"]) > 0
# ---------------------------------------------------------------------------
# fetch_pr_details
# ---------------------------------------------------------------------------
class TestFetchPrDetails:
async def test_fetches_pr_and_sets_pr_already_merged_false(self) -> None:
clients = build_mock_clients()
pr = _make_pr_info(status="active")
clients.azdo.get_pr = AsyncMock(return_value=pr)
clients.azdo.get_pr_diff = AsyncMock(return_value="edit: main.py")
config = build_config(clients)
state = {"pr_id": "42", "pr_info": {"pr_id": "42", "pr_status": "active"}}
result = await fetch_pr_details(state, config)
assert result["pr_already_merged"] is False
assert result["pr_diff"] == "edit: main.py"
async def test_sets_pr_already_merged_true_when_completed(self) -> None:
clients = build_mock_clients()
pr = _make_pr_info(status="completed")
clients.azdo.get_pr = AsyncMock(return_value=pr)
clients.azdo.get_pr_diff = AsyncMock(return_value="")
config = build_config(clients)
state = {"pr_id": "42", "pr_info": {"pr_id": "42", "pr_status": "completed"}}
result = await fetch_pr_details(state, config)
assert result["pr_already_merged"] is True
async def test_stores_last_merge_source_commit(self) -> None:
clients = build_mock_clients()
pr = _make_pr_info(status="active")
clients.azdo.get_pr = AsyncMock(return_value=pr)
clients.azdo.get_pr_diff = AsyncMock(return_value="edit: main.py")
config = build_config(clients)
state = {"pr_id": "42", "pr_info": {"pr_id": "42"}}
result = await fetch_pr_details(state, config)
# last_merge_source_commit may be None if pr doesn't have it, but key must be present
assert "last_merge_source_commit" in result
async def test_adds_error_on_service_failure(self) -> None:
from release_agent.exceptions import ServiceError
clients = build_mock_clients()
clients.azdo.get_pr = AsyncMock(side_effect=ServiceError(
service="azdo", status_code=500, detail="Server error"
))
config = build_config(clients)
state = {"pr_id": "42"}
result = await fetch_pr_details(state, config)
assert "errors" in result
assert len(result["errors"]) > 0
# ---------------------------------------------------------------------------
# move_jira_code_review
# ---------------------------------------------------------------------------
class TestMoveJiraCodeReview:
async def test_transitions_ticket_when_has_ticket(self) -> None:
clients = build_mock_clients()
clients.jira.transition_issue = AsyncMock(return_value=True)
config = build_config(clients)
state = {"ticket_id": "ALLPOST-100", "has_ticket": True}
result = await move_jira_code_review(state, config)
clients.jira.transition_issue.assert_called_once_with("ALLPOST-100", "code review")
assert result == {} or "messages" in result
async def test_skips_when_no_ticket(self) -> None:
clients = build_mock_clients()
clients.jira.transition_issue = AsyncMock(return_value=True)
config = build_config(clients)
state = {"has_ticket": False, "ticket_id": None}
result = await move_jira_code_review(state, config)
clients.jira.transition_issue.assert_not_called()
async def test_appends_error_on_jira_failure(self) -> None:
from release_agent.exceptions import ServiceError
clients = build_mock_clients()
clients.jira.transition_issue = AsyncMock(side_effect=ServiceError(
service="jira", status_code=500, detail="Jira down"
))
config = build_config(clients)
state = {"ticket_id": "ALLPOST-100", "has_ticket": True}
result = await move_jira_code_review(state, config)
assert "errors" in result
assert len(result["errors"]) > 0
# ---------------------------------------------------------------------------
# run_code_review
# ---------------------------------------------------------------------------
class TestRunCodeReview:
async def test_calls_reviewer_with_diff(self) -> None:
clients = build_mock_clients()
review = ReviewResult(verdict="approve", summary="LGTM", issues=())
clients.reviewer.review_pr = AsyncMock(return_value=review)
config = build_config(clients)
state = {
"pr_diff": "edit: main.py",
"pr_info": {"pr_title": "Fix: bug", "repo_name": "my-repo"},
}
result = await run_code_review(state, config)
clients.reviewer.review_pr.assert_called_once()
assert "review_result" in result
async def test_stores_review_result_as_dict(self) -> None:
clients = build_mock_clients()
review = ReviewResult(verdict="approve", summary="Clean code", issues=())
clients.reviewer.review_pr = AsyncMock(return_value=review)
config = build_config(clients)
state = {
"pr_diff": "edit: main.py",
"pr_info": {"pr_title": "Fix", "repo_name": "repo"},
}
result = await run_code_review(state, config)
assert result["review_result"]["verdict"] == "approve"
async def test_adds_error_on_reviewer_failure(self) -> None:
clients = build_mock_clients()
clients.reviewer.review_pr = AsyncMock(side_effect=Exception("API error"))
config = build_config(clients)
state = {
"pr_diff": "edit: main.py",
"pr_info": {"pr_title": "Fix", "repo_name": "repo"},
}
result = await run_code_review(state, config)
assert "errors" in result
# ---------------------------------------------------------------------------
# _post_review_to_pr
# ---------------------------------------------------------------------------
class TestPostReviewToPr:
async def test_posts_summary_comment(self) -> None:
clients = build_mock_clients()
clients.azdo.add_pr_comment = AsyncMock()
clients.azdo.add_pr_inline_comment = AsyncMock()
review = ReviewResult(verdict="approve", summary="LGTM", issues=())
await _post_review_to_pr(clients, "my-repo", 42, review)
clients.azdo.add_pr_comment.assert_called_once()
call_kwargs = clients.azdo.add_pr_comment.call_args
assert "APPROVE" in call_kwargs.kwargs["content"]
async def test_posts_inline_comment_for_issue_with_file_and_line(self) -> None:
clients = build_mock_clients()
clients.azdo.add_pr_comment = AsyncMock()
clients.azdo.add_pr_inline_comment = AsyncMock()
issue = ReviewIssue(
severity="error", description="Null check missing",
file_path="src/Foo.cs", line_start=42, suggestion="Add null guard",
)
review = ReviewResult(verdict="request_changes", summary="Issues", issues=(issue,))
await _post_review_to_pr(clients, "my-repo", 42, review)
clients.azdo.add_pr_inline_comment.assert_called_once()
call_kwargs = clients.azdo.add_pr_inline_comment.call_args.kwargs
assert call_kwargs["file_path"] == "src/Foo.cs"
assert call_kwargs["line_start"] == 42
assert "Null check missing" in call_kwargs["content"]
assert "Add null guard" in call_kwargs["content"]
async def test_skips_inline_for_issue_without_line(self) -> None:
clients = build_mock_clients()
clients.azdo.add_pr_comment = AsyncMock()
clients.azdo.add_pr_inline_comment = AsyncMock()
issue = ReviewIssue(severity="warning", description="Style issue", file_path="src/Foo.cs")
review = ReviewResult(verdict="approve", summary="OK", issues=(issue,))
await _post_review_to_pr(clients, "my-repo", 42, review)
clients.azdo.add_pr_inline_comment.assert_not_called()
async def test_skips_inline_for_issue_without_file(self) -> None:
clients = build_mock_clients()
clients.azdo.add_pr_comment = AsyncMock()
clients.azdo.add_pr_inline_comment = AsyncMock()
issue = ReviewIssue(severity="info", description="General note", line_start=10)
review = ReviewResult(verdict="approve", summary="OK", issues=(issue,))
await _post_review_to_pr(clients, "my-repo", 42, review)
clients.azdo.add_pr_inline_comment.assert_not_called()
async def test_inline_failure_does_not_prevent_summary(self) -> None:
clients = build_mock_clients()
clients.azdo.add_pr_comment = AsyncMock()
clients.azdo.add_pr_inline_comment = AsyncMock(side_effect=Exception("API error"))
issue = ReviewIssue(
severity="blocker", description="Critical", file_path="a.cs", line_start=1
)
review = ReviewResult(verdict="request_changes", summary="Bad", issues=(issue,))
await _post_review_to_pr(clients, "my-repo", 42, review)
# Summary should still be posted even though inline failed
clients.azdo.add_pr_comment.assert_called_once()
async def test_summary_failure_does_not_raise(self) -> None:
clients = build_mock_clients()
clients.azdo.add_pr_comment = AsyncMock(side_effect=Exception("Network error"))
clients.azdo.add_pr_inline_comment = AsyncMock()
review = ReviewResult(verdict="approve", summary="LGTM", issues=())
# Should not raise
await _post_review_to_pr(clients, "my-repo", 42, review)
async def test_summary_contains_issue_count(self) -> None:
clients = build_mock_clients()
clients.azdo.add_pr_comment = AsyncMock()
clients.azdo.add_pr_inline_comment = AsyncMock()
issues = (
ReviewIssue(severity="warning", description="Issue 1"),
ReviewIssue(severity="error", description="Issue 2"),
)
review = ReviewResult(verdict="request_changes", summary="Problems", issues=issues)
await _post_review_to_pr(clients, "my-repo", 42, review)
content = clients.azdo.add_pr_comment.call_args.kwargs["content"]
assert "2 issue(s)" in content
async def test_run_code_review_calls_post_review(self) -> None:
"""Integration: run_code_review posts comments when pr_id and repo_name present."""
clients = build_mock_clients()
review = ReviewResult(verdict="approve", summary="LGTM", issues=())
clients.reviewer.review_pr = AsyncMock(return_value=review)
clients.azdo.add_pr_comment = AsyncMock()
clients.azdo.add_pr_inline_comment = AsyncMock()
config = build_config(clients)
state = {
"pr_diff": "edit: main.py",
"pr_info": {"pr_id": "42", "pr_title": "Fix", "repo_name": "my-repo"},
}
await run_code_review(state, config)
clients.azdo.add_pr_comment.assert_called_once()
# ---------------------------------------------------------------------------
# evaluate_review
# ---------------------------------------------------------------------------
class TestEvaluateReview:
async def test_sets_review_approved_true_for_approve_verdict(self) -> None:
config = build_config()
state = {"review_result": _make_approve_review()}
result = await evaluate_review(state, config)
assert result["review_approved"] is True
async def test_sets_review_approved_false_for_request_changes(self) -> None:
config = build_config()
state = {"review_result": _make_request_changes_review()}
result = await evaluate_review(state, config)
assert result["review_approved"] is False
async def test_sets_false_when_review_result_missing(self) -> None:
config = build_config()
state = {}
result = await evaluate_review(state, config)
assert result["review_approved"] is False
async def test_sets_false_when_has_blockers(self) -> None:
config = build_config()
state = {
"review_result": {
"verdict": "approve",
"summary": "Approve with blocker?",
"issues": [{"severity": "blocker", "description": "Problem"}],
"has_blockers": True,
}
}
result = await evaluate_review(state, config)
assert result["review_approved"] is False
# ---------------------------------------------------------------------------
# interrupt_confirm_merge
# ---------------------------------------------------------------------------
class TestInterruptConfirmMerge:
async def test_calls_interrupt_with_summary_string(self) -> None:
config = build_config()
state = {
"pr_info": {"pr_id": "42", "pr_title": "Fix: bug", "repo_name": "my-repo"},
"review_result": {"summary": "LGTM"},
}
with patch("release_agent.graph.pr_completed.interrupt") as mock_interrupt:
mock_interrupt.return_value = "confirm"
await interrupt_confirm_merge(state, config)
mock_interrupt.assert_called_once()
call_arg = mock_interrupt.call_args[0][0]
assert isinstance(call_arg, str)
assert len(call_arg) > 0
async def test_interrupt_value_contains_pr_info(self) -> None:
config = build_config()
state = {
"pr_info": {"pr_id": "42", "pr_title": "Fix: auth bug", "repo_name": "backend"},
"review_result": {"summary": "All good"},
}
with patch("release_agent.graph.pr_completed.interrupt") as mock_interrupt:
mock_interrupt.return_value = "confirm"
await interrupt_confirm_merge(state, config)
call_arg = mock_interrupt.call_args[0][0]
assert "42" in call_arg or "Fix: auth bug" in call_arg or "backend" in call_arg
# ---------------------------------------------------------------------------
# merge_pr_node
# ---------------------------------------------------------------------------
class TestMergePrNode:
async def test_calls_azdo_merge_pr(self) -> None:
clients = build_mock_clients()
clients.azdo.merge_pr = AsyncMock(return_value=True)
config = build_config(clients)
state = {
"pr_info": {"pr_id": "42"},
"last_merge_source_commit": "abc123",
}
result = await merge_pr_node(state, config)
clients.azdo.merge_pr.assert_called_once_with(
pr_id=42, last_merge_source_commit="abc123"
)
async def test_returns_message_on_success(self) -> None:
clients = build_mock_clients()
clients.azdo.merge_pr = AsyncMock(return_value=True)
config = build_config(clients)
state = {
"pr_info": {"pr_id": "42"},
"last_merge_source_commit": "abc123",
}
result = await merge_pr_node(state, config)
assert "messages" in result
async def test_re_raises_on_service_error(self) -> None:
from release_agent.exceptions import ServiceError
clients = build_mock_clients()
clients.azdo.merge_pr = AsyncMock(side_effect=ServiceError(
service="azdo", status_code=409, detail="Conflict"
))
config = build_config(clients)
state = {
"pr_info": {"pr_id": "42"},
"last_merge_source_commit": "abc123",
}
with pytest.raises(ServiceError):
await merge_pr_node(state, config)
# ---------------------------------------------------------------------------
# move_jira_ready_for_stage
# ---------------------------------------------------------------------------
class TestMoveJiraReadyForStage:
async def test_transitions_ticket(self) -> None:
clients = build_mock_clients()
clients.jira.transition_issue = AsyncMock(return_value=True)
config = build_config(clients)
state = {"ticket_id": "ALLPOST-100", "has_ticket": True}
result = await move_jira_ready_for_stage(state, config)
clients.jira.transition_issue.assert_called_once_with(
"ALLPOST-100", "Ready for stage (2)"
)
async def test_skips_when_no_ticket(self) -> None:
clients = build_mock_clients()
clients.jira.transition_issue = AsyncMock()
config = build_config(clients)
state = {"has_ticket": False}
await move_jira_ready_for_stage(state, config)
clients.jira.transition_issue.assert_not_called()
async def test_appends_error_on_failure(self) -> None:
from release_agent.exceptions import ServiceError
clients = build_mock_clients()
clients.jira.transition_issue = AsyncMock(side_effect=ServiceError(
service="jira", status_code=500, detail="Error"
))
config = build_config(clients)
state = {"ticket_id": "ALLPOST-100", "has_ticket": True}
result = await move_jira_ready_for_stage(state, config)
assert "errors" in result
# ---------------------------------------------------------------------------
# add_jira_pr_link
# ---------------------------------------------------------------------------
class TestAddJiraPrLink:
async def test_calls_add_remote_link(self) -> None:
clients = build_mock_clients()
clients.jira.add_remote_link = AsyncMock(return_value=True)
config = build_config(clients)
state = {
"ticket_id": "ALLPOST-100",
"has_ticket": True,
"pr_info": {
"pr_id": "42",
"pr_url": "https://dev.azure.com/org/proj/_git/repo/pullrequest/42",
"pr_title": "Fix: bug",
},
}
result = await add_jira_pr_link(state, config)
clients.jira.add_remote_link.assert_called_once()
async def test_skips_when_no_ticket(self) -> None:
clients = build_mock_clients()
clients.jira.add_remote_link = AsyncMock()
config = build_config(clients)
state = {"has_ticket": False}
await add_jira_pr_link(state, config)
clients.jira.add_remote_link.assert_not_called()
async def test_appends_error_on_failure(self) -> None:
from release_agent.exceptions import ServiceError
clients = build_mock_clients()
clients.jira.add_remote_link = AsyncMock(side_effect=ServiceError(
service="jira", status_code=500, detail="Error"
))
config = build_config(clients)
state = {
"ticket_id": "ALLPOST-100",
"has_ticket": True,
"pr_info": {
"pr_id": "42",
"pr_url": "https://dev.azure.com/org/proj/_git/repo/pullrequest/42",
"pr_title": "Fix",
},
}
result = await add_jira_pr_link(state, config)
assert "errors" in result
# ---------------------------------------------------------------------------
# calculate_version
# ---------------------------------------------------------------------------
class TestCalculateVersion:
async def test_returns_v1_0_0_for_empty_store(self, tmp_path) -> None:
staging_store = JsonFileStagingStore(directory=tmp_path)
clients = build_mock_clients()
config = build_config(clients, staging_store=staging_store)
state = {"repo_name": "my-repo"}
result = await calculate_version(state, config)
assert result["version"] == "v1.0.0"
async def test_increments_patch_version(self, tmp_path) -> None:
from release_agent.models.release import StagingRelease
staging_store = JsonFileStagingStore(directory=tmp_path)
# Pre-populate with an existing version
staging = StagingRelease(
version="v1.0.5",
repo="my-repo",
started_at=date(2025, 1, 1),
tickets=[],
)
await staging_store.archive(staging, date(2025, 1, 10))
clients = build_mock_clients()
config = build_config(clients, staging_store=staging_store)
state = {"repo_name": "my-repo"}
result = await calculate_version(state, config)
assert result["version"] == "v1.0.6"
async def test_sets_version_in_state(self, tmp_path) -> None:
staging_store = JsonFileStagingStore(directory=tmp_path)
clients = build_mock_clients()
config = build_config(clients, staging_store=staging_store)
state = {"repo_name": "new-repo"}
result = await calculate_version(state, config)
assert "version" in result
assert result["version"].startswith("v")
# ---------------------------------------------------------------------------
# update_staging
# ---------------------------------------------------------------------------
class TestUpdateStaging:
async def test_creates_new_staging_when_none_exists(self, tmp_path) -> None:
staging_store = JsonFileStagingStore(directory=tmp_path)
clients = build_mock_clients()
# Jira get_issue returns a summary
from release_agent.models.jira import JiraIssue
clients.jira.get_issue = AsyncMock(return_value=JiraIssue(
key="ALLPOST-100", summary="Fix auth bug", status="Ready for stage (2)"
))
config = build_config(clients, staging_store=staging_store)
state = {
"repo_name": "my-repo",
"version": "v1.0.0",
"ticket_id": "ALLPOST-100",
"has_ticket": True,
"pr_info": {
"pr_id": "42",
"pr_url": "https://dev.azure.com/org/proj/_git/repo/pullrequest/42",
"pr_title": "Fix: auth bug",
"branch": "feature/ALLPOST-100-fix",
},
}
result = await update_staging(state, config)
loaded = await staging_store.load("my-repo")
assert loaded is not None
assert loaded.has_ticket("ALLPOST-100")
async def test_appends_ticket_to_existing_staging(self, tmp_path) -> None:
from datetime import date
from release_agent.models.release import StagingRelease
from release_agent.models.jira import JiraIssue
staging_store = JsonFileStagingStore(directory=tmp_path)
existing = StagingRelease(
version="v1.0.0", repo="my-repo",
started_at=date(2025, 1, 1), tickets=[]
)
await staging_store.save(existing)
clients = build_mock_clients()
clients.jira.get_issue = AsyncMock(return_value=JiraIssue(
key="BILL-99", summary="New feature", status="Ready"
))
config = build_config(clients, staging_store=staging_store)
state = {
"repo_name": "my-repo",
"version": "v1.0.0",
"ticket_id": "BILL-99",
"has_ticket": True,
"pr_info": {
"pr_id": "55",
"pr_url": "https://dev.azure.com/org/proj/_git/repo/pullrequest/55",
"pr_title": "Feat: new feature",
"branch": "feature/BILL-99-feat",
},
}
await update_staging(state, config)
loaded = await staging_store.load("my-repo")
assert loaded is not None
assert loaded.has_ticket("BILL-99")
async def test_skips_ticket_add_when_no_ticket(self, tmp_path) -> None:
staging_store = JsonFileStagingStore(directory=tmp_path)
clients = build_mock_clients()
config = build_config(clients, staging_store=staging_store)
state = {
"repo_name": "my-repo",
"version": "v1.0.0",
"has_ticket": False,
}
await update_staging(state, config)
# No staging file should be created for ticket-less PR if no existing staging
# (or staging exists without new ticket added)
clients.jira.get_issue.assert_not_called()
async def test_returns_empty_dict_when_no_staging_store(self) -> None:
from release_agent.models.jira import JiraIssue
clients = build_mock_clients()
clients.jira.get_issue = AsyncMock(return_value=JiraIssue(
key="ALLPOST-1", summary="Fix", status="Ready"
))
config = build_config(clients, staging_store=None)
state = {
"repo_name": "my-repo",
"version": "v1.0.0",
"ticket_id": "ALLPOST-1",
"has_ticket": True,
"pr_info": {
"pr_id": "1",
"pr_url": "https://dev.azure.com/org/proj/_git/repo/pullrequest/1",
"pr_title": "Fix",
"branch": "feature/ALLPOST-1",
},
}
result = await update_staging(state, config)
assert result == {}
async def test_uses_ticket_id_as_summary_on_jira_failure(self, tmp_path) -> None:
staging_store = JsonFileStagingStore(directory=tmp_path)
clients = build_mock_clients()
clients.jira.get_issue = AsyncMock(side_effect=Exception("Jira unavailable"))
config = build_config(clients, staging_store=staging_store)
state = {
"repo_name": "my-repo",
"version": "v1.0.0",
"ticket_id": "ALLPOST-99",
"has_ticket": True,
"pr_info": {
"pr_id": "5",
"pr_url": "https://dev.azure.com/org/proj/_git/repo/pullrequest/5",
"pr_title": "Fix something",
"branch": "feature/ALLPOST-99_fix",
},
}
result = await update_staging(state, config)
loaded = await staging_store.load("my-repo")
assert loaded is not None
assert loaded.tickets[0].id == "ALLPOST-99"
assert loaded.tickets[0].summary == "ALLPOST-99"
async def test_sets_staging_dict_in_result(self, tmp_path) -> None:
from release_agent.models.jira import JiraIssue
staging_store = JsonFileStagingStore(directory=tmp_path)
clients = build_mock_clients()
clients.jira.get_issue = AsyncMock(return_value=JiraIssue(
key="ALLPOST-1", summary="S", status="Ready"
))
config = build_config(clients, staging_store=staging_store)
state = {
"repo_name": "my-repo",
"version": "v1.0.0",
"ticket_id": "ALLPOST-1",
"has_ticket": True,
"pr_info": {
"pr_id": "1",
"pr_url": "https://dev.azure.com/org/proj/_git/repo/pullrequest/1",
"pr_title": "Fix",
"branch": "feature/ALLPOST-1",
},
}
result = await update_staging(state, config)
assert "staging" in result
assert isinstance(result["staging"], dict)
# ---------------------------------------------------------------------------
# notify_request_changes
# ---------------------------------------------------------------------------
class TestNotifyRequestChanges:
async def test_calls_slack_send_approval_request(self) -> None:
clients = build_mock_clients()
clients.slack.send_approval_request = AsyncMock(return_value=True)
config = build_config(clients)
state = {
"pr_info": {"pr_id": "42", "pr_title": "Fix: bug", "repo_name": "my-repo"},
"review_result": {
"verdict": "request_changes",
"summary": "Too many issues",
"issues": [{"severity": "blocker", "description": "No tests"}],
},
}
result = await notify_request_changes(state, config)
clients.slack.send_approval_request.assert_called_once()
async def test_appends_error_on_slack_failure(self) -> None:
from release_agent.exceptions import ServiceError
clients = build_mock_clients()
clients.slack.send_approval_request = AsyncMock(side_effect=ServiceError(
service="slack", status_code=500, detail="Webhook error"
))
config = build_config(clients)
state = {
"pr_info": {"pr_id": "42", "pr_title": "Fix", "repo_name": "repo"},
"review_result": {"summary": "Issues found", "issues": []},
}
result = await notify_request_changes(state, config)
assert "errors" in result
# ---------------------------------------------------------------------------
# build_pr_completed_graph
# ---------------------------------------------------------------------------
# ---------------------------------------------------------------------------
# auto_create_ticket node
# ---------------------------------------------------------------------------
class TestAutoCreateTicket:
"""Tests for the auto_create_ticket node."""
def _make_config_with_jira_project(
self, jira_project: str = "ALLPOST"
):
clients = build_mock_clients()
clients.jira.create_issue = AsyncMock(return_value="ALLPOST-99")
clients.reviewer.generate_ticket_content = AsyncMock(
return_value=("My summary", "My description")
)
config = build_config(clients)
config["configurable"]["default_jira_project"] = jira_project
return config, clients
async def test_creates_jira_issue_and_returns_ticket_id(self) -> None:
config, clients = self._make_config_with_jira_project()
state = {
"pr_diff": "edit: main.py",
"pr_info": {"pr_title": "Fix bug", "repo_name": "my-repo"},
}
result = await auto_create_ticket(state, config)
assert result.get("ticket_id") == "ALLPOST-99"
async def test_sets_has_ticket_true(self) -> None:
config, clients = self._make_config_with_jira_project()
state = {
"pr_diff": "edit: main.py",
"pr_info": {"pr_title": "Fix bug", "repo_name": "my-repo"},
}
result = await auto_create_ticket(state, config)
assert result.get("has_ticket") is True
async def test_calls_generate_ticket_content(self) -> None:
config, clients = self._make_config_with_jira_project()
state = {
"pr_diff": "edit: main.py",
"pr_info": {"pr_title": "Fix login", "repo_name": "auth-service"},
}
await auto_create_ticket(state, config)
clients.reviewer.generate_ticket_content.assert_awaited_once()
async def test_calls_create_issue_with_project_key(self) -> None:
config, clients = self._make_config_with_jira_project(jira_project="MYPROJ")
clients.jira.create_issue = AsyncMock(return_value="MYPROJ-5")
config["configurable"]["default_jira_project"] = "MYPROJ"
state = {
"pr_diff": "d",
"pr_info": {"pr_title": "t", "repo_name": "r"},
}
await auto_create_ticket(state, config)
call_kwargs = clients.jira.create_issue.call_args.kwargs
assert call_kwargs["project"] == "MYPROJ"
async def test_appends_message_on_success(self) -> None:
config, _ = self._make_config_with_jira_project()
state = {
"pr_diff": "d",
"pr_info": {"pr_title": "t", "repo_name": "r"},
}
result = await auto_create_ticket(state, config)
assert "messages" in result
assert len(result["messages"]) > 0
async def test_appends_error_on_create_issue_failure(self) -> None:
config, clients = self._make_config_with_jira_project()
clients.jira.create_issue = AsyncMock(side_effect=Exception("Jira down"))
state = {
"pr_diff": "d",
"pr_info": {"pr_title": "t", "repo_name": "r"},
}
result = await auto_create_ticket(state, config)
assert "errors" in result
assert len(result["errors"]) > 0
async def test_appends_error_on_generate_content_failure(self) -> None:
config, clients = self._make_config_with_jira_project()
clients.reviewer.generate_ticket_content = AsyncMock(side_effect=RuntimeError("CLI fail"))
state = {
"pr_diff": "d",
"pr_info": {"pr_title": "t", "repo_name": "r"},
}
result = await auto_create_ticket(state, config)
assert "errors" in result
async def test_uses_default_project_from_config(self) -> None:
config, clients = self._make_config_with_jira_project(jira_project="TEAM")
clients.jira.create_issue = AsyncMock(return_value="TEAM-1")
state = {
"pr_diff": "d",
"pr_info": {"pr_title": "t", "repo_name": "r"},
}
result = await auto_create_ticket(state, config)
assert result["ticket_id"] == "TEAM-1"
# ---------------------------------------------------------------------------
# build_pr_completed_graph
# ---------------------------------------------------------------------------
class TestBuildPrCompletedGraph:
def test_returns_compiled_graph(self) -> None:
graph = build_pr_completed_graph()
assert graph is not None
def test_graph_has_nodes(self) -> None:
graph = build_pr_completed_graph()
# The compiled graph object should be truthy
assert graph is not None
def test_graph_includes_trigger_ci_build_node(self) -> None:
graph = build_pr_completed_graph()
# Graph nodes should include CI pipeline nodes
graph_nodes = graph.get_graph().nodes
assert "trigger_ci_build" in graph_nodes
def test_graph_includes_poll_ci_build_node(self) -> None:
graph = build_pr_completed_graph()
graph_nodes = graph.get_graph().nodes
assert "poll_ci_build" in graph_nodes
def test_graph_includes_notify_ci_result_node(self) -> None:
graph = build_pr_completed_graph()
graph_nodes = graph.get_graph().nodes
assert "notify_ci_result" in graph_nodes
def test_graph_includes_auto_create_ticket_node(self) -> None:
graph = build_pr_completed_graph()
graph_nodes = graph.get_graph().nodes
assert "auto_create_ticket" in graph_nodes

866
tests/graph/test_release.py Normal file
View File

@@ -0,0 +1,866 @@
"""Tests for graph/release.py node functions. Written FIRST (TDD RED phase).
Each node is an async function (state, config) -> dict.
Tests call nodes directly — no graph compilation required.
"""
from datetime import date
from unittest.mock import AsyncMock, patch
import pytest
from release_agent.graph.dependencies import JsonFileStagingStore
from release_agent.graph.release import (
approve_stage,
archive_release,
check_release_approvals,
create_release_pr,
interrupt_confirm_approve,
interrupt_confirm_merge_release,
interrupt_confirm_release,
interrupt_confirm_trigger,
list_pipelines,
load_staging,
merge_release_pr,
move_tickets_to_done,
send_slack_notification,
trigger_pipelines,
build_release_graph,
)
from release_agent.models.pipeline import PipelineInfo, ReleasePipelineStage
from release_agent.models.release import StagingRelease
from release_agent.models.ticket import TicketEntry
from tests.graph.conftest import build_config, build_mock_clients
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_ticket(ticket_id: str = "ALLPOST-1") -> TicketEntry:
return TicketEntry(
id=ticket_id,
summary="Fix something",
pr_id="42",
pr_url="https://dev.azure.com/org/proj/_git/repo/pullrequest/42",
pr_title="Fix: something",
branch=f"feature/{ticket_id}-fix",
merged_at=date(2025, 1, 15),
)
def _make_staging(
*,
repo: str = "my-repo",
version: str = "v1.0.0",
tickets: list | None = None,
) -> StagingRelease:
t = tickets if tickets is not None else [_make_ticket()]
return StagingRelease(
version=version,
repo=repo,
started_at=date(2025, 1, 1),
tickets=t,
)
def _staging_dict(staging: StagingRelease) -> dict:
return staging.model_dump(mode="json")
# ---------------------------------------------------------------------------
# load_staging
# ---------------------------------------------------------------------------
class TestLoadStaging:
async def test_loads_staging_from_store(self, tmp_path) -> None:
staging_store = JsonFileStagingStore(directory=tmp_path)
staging = _make_staging()
await staging_store.save(staging)
clients = build_mock_clients()
config = build_config(clients, staging_store=staging_store)
state = {"repo_name": "my-repo"}
result = await load_staging(state, config)
assert "staging" in result
assert result["staging"]["version"] == "v1.0.0"
async def test_returns_none_when_no_staging(self, tmp_path) -> None:
staging_store = JsonFileStagingStore(directory=tmp_path)
clients = build_mock_clients()
config = build_config(clients, staging_store=staging_store)
state = {"repo_name": "nonexistent"}
result = await load_staging(state, config)
assert result.get("staging") is None
async def test_staging_includes_tickets(self, tmp_path) -> None:
staging_store = JsonFileStagingStore(directory=tmp_path)
staging = _make_staging(tickets=[_make_ticket("BILL-10"), _make_ticket("BILL-11")])
await staging_store.save(staging)
clients = build_mock_clients()
config = build_config(clients, staging_store=staging_store)
state = {"repo_name": "my-repo"}
result = await load_staging(state, config)
assert len(result["staging"]["tickets"]) == 2
# ---------------------------------------------------------------------------
# interrupt_confirm_release
# ---------------------------------------------------------------------------
class TestInterruptConfirmRelease:
async def test_calls_interrupt_with_staging_summary(self) -> None:
config = build_config()
staging = _make_staging()
state = {
"repo_name": "my-repo",
"staging": _staging_dict(staging),
}
with patch("release_agent.graph.release.interrupt") as mock_interrupt:
mock_interrupt.return_value = "confirm"
await interrupt_confirm_release(state, config)
mock_interrupt.assert_called_once()
call_arg = mock_interrupt.call_args[0][0]
assert isinstance(call_arg, str)
async def test_interrupt_contains_version_and_repo(self) -> None:
config = build_config()
staging = _make_staging(version="v2.5.0", repo="backend")
state = {
"repo_name": "backend",
"staging": _staging_dict(staging),
}
with patch("release_agent.graph.release.interrupt") as mock_interrupt:
mock_interrupt.return_value = "confirm"
await interrupt_confirm_release(state, config)
call_arg = mock_interrupt.call_args[0][0]
assert "v2.5.0" in call_arg or "backend" in call_arg
# ---------------------------------------------------------------------------
# create_release_pr
# ---------------------------------------------------------------------------
class TestCreateReleasePr:
async def test_calls_azdo_create_pr(self) -> None:
clients = build_mock_clients()
clients.azdo.create_pr = AsyncMock(return_value={
"pullRequestId": 99,
"lastMergeSourceCommit": {"commitId": "deadbeef"},
})
config = build_config(clients)
staging = _make_staging(version="v1.2.0")
state = {
"repo_name": "my-repo",
"version": "v1.2.0",
"staging": _staging_dict(staging),
}
result = await create_release_pr(state, config)
clients.azdo.create_pr.assert_called_once()
call_kwargs = clients.azdo.create_pr.call_args.kwargs
assert call_kwargs["repo"] == "my-repo"
async def test_sets_release_pr_id(self) -> None:
clients = build_mock_clients()
clients.azdo.create_pr = AsyncMock(return_value={
"pullRequestId": 77,
"lastMergeSourceCommit": {"commitId": "cafe1234"},
})
config = build_config(clients)
staging = _make_staging(version="v1.0.3")
state = {
"repo_name": "my-repo",
"version": "v1.0.3",
"staging": _staging_dict(staging),
}
result = await create_release_pr(state, config)
assert result["release_pr_id"] == "77"
async def test_sets_release_pr_commit(self) -> None:
clients = build_mock_clients()
clients.azdo.create_pr = AsyncMock(return_value={
"pullRequestId": 77,
"lastMergeSourceCommit": {"commitId": "cafe1234"},
})
config = build_config(clients)
staging = _make_staging()
state = {
"repo_name": "my-repo",
"version": "v1.0.0",
"staging": _staging_dict(staging),
}
result = await create_release_pr(state, config)
assert result["release_pr_commit"] == "cafe1234"
async def test_re_raises_on_service_error(self) -> None:
from release_agent.exceptions import ServiceError
clients = build_mock_clients()
clients.azdo.create_pr = AsyncMock(side_effect=ServiceError(
service="azdo", status_code=422, detail="Invalid branch"
))
config = build_config(clients)
staging = _make_staging()
state = {
"repo_name": "my-repo",
"version": "v1.0.0",
"staging": _staging_dict(staging),
}
with pytest.raises(ServiceError):
await create_release_pr(state, config)
# ---------------------------------------------------------------------------
# interrupt_confirm_merge_release
# ---------------------------------------------------------------------------
class TestInterruptConfirmMergeRelease:
async def test_calls_interrupt_with_pr_info(self) -> None:
config = build_config()
state = {
"release_pr_id": "99",
"version": "v1.0.0",
"repo_name": "my-repo",
}
with patch("release_agent.graph.release.interrupt") as mock_interrupt:
mock_interrupt.return_value = "confirm"
await interrupt_confirm_merge_release(state, config)
mock_interrupt.assert_called_once()
call_arg = mock_interrupt.call_args[0][0]
assert isinstance(call_arg, str)
assert len(call_arg) > 0
# ---------------------------------------------------------------------------
# merge_release_pr
# ---------------------------------------------------------------------------
class TestMergeReleasePr:
async def test_calls_azdo_merge_pr(self) -> None:
clients = build_mock_clients()
clients.azdo.merge_pr = AsyncMock(return_value=True)
config = build_config(clients)
state = {
"release_pr_id": "99",
"release_pr_commit": "abc123",
}
await merge_release_pr(state, config)
clients.azdo.merge_pr.assert_called_once_with(
pr_id=99, last_merge_source_commit="abc123"
)
async def test_re_raises_on_service_error(self) -> None:
from release_agent.exceptions import ServiceError
clients = build_mock_clients()
clients.azdo.merge_pr = AsyncMock(side_effect=ServiceError(
service="azdo", status_code=409, detail="Conflict"
))
config = build_config(clients)
state = {"release_pr_id": "99", "release_pr_commit": "abc"}
with pytest.raises(ServiceError):
await merge_release_pr(state, config)
# ---------------------------------------------------------------------------
# move_tickets_to_done
# ---------------------------------------------------------------------------
class TestMoveTicketsToDone:
async def test_transitions_all_tickets(self) -> None:
clients = build_mock_clients()
clients.jira.transition_issue = AsyncMock(return_value=True)
config = build_config(clients)
staging = _make_staging(tickets=[_make_ticket("BILL-1"), _make_ticket("BILL-2")])
state = {"staging": _staging_dict(staging)}
await move_tickets_to_done(state, config)
assert clients.jira.transition_issue.call_count == 2
async def test_calls_transition_with_done_name(self) -> None:
clients = build_mock_clients()
clients.jira.transition_issue = AsyncMock(return_value=True)
config = build_config(clients)
staging = _make_staging(tickets=[_make_ticket("BILL-1")])
state = {"staging": _staging_dict(staging)}
await move_tickets_to_done(state, config)
call_args = clients.jira.transition_issue.call_args_list[0]
ticket_id, transition = call_args[0]
assert ticket_id == "BILL-1"
assert "done" in transition.lower() or "released" in transition.lower()
async def test_appends_error_on_jira_failure(self) -> None:
from release_agent.exceptions import ServiceError
clients = build_mock_clients()
clients.jira.transition_issue = AsyncMock(side_effect=ServiceError(
service="jira", status_code=500, detail="Error"
))
config = build_config(clients)
staging = _make_staging(tickets=[_make_ticket()])
state = {"staging": _staging_dict(staging)}
result = await move_tickets_to_done(state, config)
assert "errors" in result
async def test_empty_tickets_no_calls(self) -> None:
clients = build_mock_clients()
clients.jira.transition_issue = AsyncMock()
config = build_config(clients)
staging = _make_staging(tickets=[])
state = {"staging": _staging_dict(staging)}
await move_tickets_to_done(state, config)
clients.jira.transition_issue.assert_not_called()
# ---------------------------------------------------------------------------
# send_slack_notification
# ---------------------------------------------------------------------------
class TestSendSlackNotification:
async def test_calls_slack_send_release_notification(self) -> None:
clients = build_mock_clients()
clients.slack.send_release_notification = AsyncMock(return_value=True)
config = build_config(clients)
staging = _make_staging()
state = {
"repo_name": "my-repo",
"version": "v1.0.0",
"staging": _staging_dict(staging),
}
result = await send_slack_notification(state, config)
clients.slack.send_release_notification.assert_called_once()
async def test_appends_error_on_slack_failure(self) -> None:
from release_agent.exceptions import ServiceError
clients = build_mock_clients()
clients.slack.send_release_notification = AsyncMock(side_effect=ServiceError(
service="slack", status_code=500, detail="Webhook error"
))
config = build_config(clients)
staging = _make_staging()
state = {
"repo_name": "my-repo",
"version": "v1.0.0",
"staging": _staging_dict(staging),
}
result = await send_slack_notification(state, config)
assert "errors" in result
# ---------------------------------------------------------------------------
# archive_release
# ---------------------------------------------------------------------------
class TestArchiveRelease:
async def test_archives_staging_to_store(self, tmp_path) -> None:
staging_store = JsonFileStagingStore(directory=tmp_path)
staging = _make_staging()
await staging_store.save(staging)
clients = build_mock_clients()
config = build_config(clients, staging_store=staging_store)
state = {
"repo_name": "my-repo",
"staging": _staging_dict(staging),
}
await archive_release(state, config)
# Staging should be gone now
assert await staging_store.load("my-repo") is None
async def test_archive_file_created_in_store(self, tmp_path) -> None:
staging_store = JsonFileStagingStore(directory=tmp_path)
staging = _make_staging(version="v3.0.0")
await staging_store.save(staging)
clients = build_mock_clients()
config = build_config(clients, staging_store=staging_store)
state = {
"repo_name": "my-repo",
"staging": _staging_dict(staging),
}
await archive_release(state, config)
versions = await staging_store.list_versions("my-repo")
assert "v3.0.0" in versions
# ---------------------------------------------------------------------------
# list_pipelines
# ---------------------------------------------------------------------------
class TestListPipelines:
async def test_fetches_pipelines_from_azdo(self) -> None:
clients = build_mock_clients()
pipelines = [PipelineInfo(id=1, name="build", repo="my-repo")]
clients.azdo.list_build_pipelines = AsyncMock(return_value=pipelines)
config = build_config(clients)
state = {"repo_name": "my-repo"}
result = await list_pipelines(state, config)
clients.azdo.list_build_pipelines.assert_called_once_with(repo="my-repo")
assert "pipelines" in result
assert len(result["pipelines"]) == 1
async def test_stores_pipelines_as_list_of_dicts(self) -> None:
clients = build_mock_clients()
pipelines = [
PipelineInfo(id=1, name="build", repo="my-repo"),
PipelineInfo(id=2, name="deploy", repo="my-repo"),
]
clients.azdo.list_build_pipelines = AsyncMock(return_value=pipelines)
config = build_config(clients)
state = {"repo_name": "my-repo"}
result = await list_pipelines(state, config)
assert len(result["pipelines"]) == 2
assert result["pipelines"][0]["id"] == 1
async def test_empty_pipelines_stored_as_empty_list(self) -> None:
clients = build_mock_clients()
clients.azdo.list_build_pipelines = AsyncMock(return_value=[])
config = build_config(clients)
state = {"repo_name": "my-repo"}
result = await list_pipelines(state, config)
assert result["pipelines"] == []
async def test_appends_error_on_service_failure(self) -> None:
from release_agent.exceptions import ServiceError
clients = build_mock_clients()
clients.azdo.list_build_pipelines = AsyncMock(side_effect=ServiceError(
service="azdo", status_code=500, detail="Error"
))
config = build_config(clients)
state = {"repo_name": "my-repo"}
result = await list_pipelines(state, config)
assert "errors" in result
# ---------------------------------------------------------------------------
# interrupt_confirm_trigger
# ---------------------------------------------------------------------------
class TestInterruptConfirmTrigger:
async def test_calls_interrupt_with_pipelines_summary(self) -> None:
config = build_config()
state = {
"repo_name": "my-repo",
"version": "v1.0.0",
"pipelines": [{"id": 1, "name": "build", "repo": "my-repo"}],
}
with patch("release_agent.graph.release.interrupt") as mock_interrupt:
mock_interrupt.return_value = "confirm"
await interrupt_confirm_trigger(state, config)
mock_interrupt.assert_called_once()
call_arg = mock_interrupt.call_args[0][0]
assert isinstance(call_arg, str)
# ---------------------------------------------------------------------------
# trigger_pipelines
# ---------------------------------------------------------------------------
class TestTriggerPipelines:
async def test_triggers_each_pipeline(self) -> None:
clients = build_mock_clients()
clients.azdo.trigger_pipeline = AsyncMock(return_value={"id": 1001})
config = build_config(clients)
state = {
"repo_name": "my-repo",
"version": "v1.0.0",
"pipelines": [
{"id": 1, "name": "build", "repo": "my-repo"},
{"id": 2, "name": "deploy", "repo": "my-repo"},
],
}
result = await trigger_pipelines(state, config)
assert clients.azdo.trigger_pipeline.call_count == 2
assert "triggered_builds" in result
assert len(result["triggered_builds"]) == 2
async def test_no_pipelines_no_calls(self) -> None:
clients = build_mock_clients()
clients.azdo.trigger_pipeline = AsyncMock()
config = build_config(clients)
state = {
"repo_name": "my-repo",
"version": "v1.0.0",
"pipelines": [],
}
result = await trigger_pipelines(state, config)
clients.azdo.trigger_pipeline.assert_not_called()
assert result["triggered_builds"] == []
async def test_appends_error_on_trigger_failure(self) -> None:
from release_agent.exceptions import ServiceError
clients = build_mock_clients()
clients.azdo.trigger_pipeline = AsyncMock(side_effect=ServiceError(
service="azdo", status_code=500, detail="Error"
))
config = build_config(clients)
state = {
"repo_name": "my-repo",
"version": "v1.0.0",
"pipelines": [{"id": 1, "name": "build", "repo": "my-repo"}],
}
result = await trigger_pipelines(state, config)
assert "errors" in result
# ---------------------------------------------------------------------------
# check_release_approvals
# ---------------------------------------------------------------------------
class TestCheckReleaseApprovals:
async def test_fetches_pending_approvals_from_builds(self) -> None:
clients = build_mock_clients()
clients.azdo.get_build_status = AsyncMock(return_value="completed")
config = build_config(clients)
state = {
"triggered_builds": [{"id": 1001}],
}
result = await check_release_approvals(state, config)
assert "pending_approvals" in result
async def test_empty_builds_means_no_approvals(self) -> None:
clients = build_mock_clients()
config = build_config(clients)
state = {"triggered_builds": []}
result = await check_release_approvals(state, config)
assert result["pending_approvals"] == []
async def test_appends_error_on_failure(self) -> None:
from release_agent.exceptions import ServiceError
clients = build_mock_clients()
clients.azdo.get_build_status = AsyncMock(side_effect=ServiceError(
service="azdo", status_code=500, detail="Error"
))
config = build_config(clients)
state = {"triggered_builds": [{"id": 1001}]}
result = await check_release_approvals(state, config)
assert "errors" in result
# ---------------------------------------------------------------------------
# interrupt_confirm_approve
# ---------------------------------------------------------------------------
class TestInterruptConfirmApprove:
async def test_calls_interrupt_with_approvals_summary(self) -> None:
config = build_config()
state = {
"pending_approvals": [{"approval_id": "aaa", "stage_name": "Production"}],
"version": "v1.0.0",
}
with patch("release_agent.graph.release.interrupt") as mock_interrupt:
mock_interrupt.return_value = "confirm"
await interrupt_confirm_approve(state, config)
mock_interrupt.assert_called_once()
# ---------------------------------------------------------------------------
# approve_stage
# ---------------------------------------------------------------------------
class TestApproveStage:
async def test_approves_each_pending_approval(self) -> None:
clients = build_mock_clients()
clients.azdo.approve_release = AsyncMock(return_value={"status": "approved"})
config = build_config(clients)
state = {
"pending_approvals": [
{"approval_id": "aaa"},
{"approval_id": "bbb"},
],
}
result = await approve_stage(state, config)
assert clients.azdo.approve_release.call_count == 2
async def test_no_approvals_no_calls(self) -> None:
clients = build_mock_clients()
clients.azdo.approve_release = AsyncMock()
config = build_config(clients)
state = {"pending_approvals": []}
await approve_stage(state, config)
clients.azdo.approve_release.assert_not_called()
async def test_appends_error_on_failure(self) -> None:
from release_agent.exceptions import ServiceError
clients = build_mock_clients()
clients.azdo.approve_release = AsyncMock(side_effect=ServiceError(
service="azdo", status_code=500, detail="Error"
))
config = build_config(clients)
state = {"pending_approvals": [{"approval_id": "aaa"}]}
result = await approve_stage(state, config)
assert "errors" in result
# ---------------------------------------------------------------------------
# build_release_graph
# ---------------------------------------------------------------------------
class TestBuildReleaseGraph:
def test_returns_compiled_graph(self) -> None:
graph = build_release_graph()
assert graph is not None
def test_graph_includes_trigger_ci_build_main_node(self) -> None:
graph = build_release_graph()
graph_nodes = graph.get_graph().nodes
assert "trigger_ci_build_main" in graph_nodes
def test_graph_includes_poll_ci_build_main_node(self) -> None:
graph = build_release_graph()
graph_nodes = graph.get_graph().nodes
assert "poll_ci_build_main" in graph_nodes
def test_graph_includes_wait_for_cd_release_node(self) -> None:
graph = build_release_graph()
graph_nodes = graph.get_graph().nodes
assert "wait_for_cd_release" in graph_nodes
def test_graph_includes_poll_release_approvals_node(self) -> None:
graph = build_release_graph()
graph_nodes = graph.get_graph().nodes
assert "poll_release_approvals" in graph_nodes
def test_graph_includes_interrupt_sandbox_approval_node(self) -> None:
graph = build_release_graph()
graph_nodes = graph.get_graph().nodes
assert "interrupt_sandbox_approval" in graph_nodes
def test_graph_includes_interrupt_prod_approval_node(self) -> None:
graph = build_release_graph()
graph_nodes = graph.get_graph().nodes
assert "interrupt_prod_approval" in graph_nodes
def test_graph_includes_execute_sandbox_approval_node(self) -> None:
graph = build_release_graph()
graph_nodes = graph.get_graph().nodes
assert "execute_sandbox_approval" in graph_nodes
def test_graph_includes_execute_prod_approval_node(self) -> None:
graph = build_release_graph()
graph_nodes = graph.get_graph().nodes
assert "execute_prod_approval" in graph_nodes
def test_graph_includes_notify_ci_failure_node(self) -> None:
graph = build_release_graph()
graph_nodes = graph.get_graph().nodes
assert "notify_ci_failure" in graph_nodes
# ---------------------------------------------------------------------------
# New release graph node: wait_for_cd_release
# ---------------------------------------------------------------------------
class TestWaitForCdRelease:
"""Tests for wait_for_cd_release node."""
async def test_sets_release_id_when_found(self) -> None:
from release_agent.graph.release import wait_for_cd_release
clients = build_mock_clients()
clients.azdo.get_latest_release.return_value = {"id": 100, "name": "Release-100"}
config = build_config(clients)
state = {"release_definition_id": 5, "repo_name": "my-repo"}
result = await wait_for_cd_release(state, config)
assert "release_id" in result
assert result["release_id"] == 100
async def test_appends_error_when_no_release(self) -> None:
from release_agent.graph.release import wait_for_cd_release
clients = build_mock_clients()
clients.azdo.get_latest_release.return_value = {}
config = build_config(clients)
state = {"release_definition_id": 5, "repo_name": "my-repo"}
result = await wait_for_cd_release(state, config)
assert "errors" in result
async def test_works_without_release_definition_id(self) -> None:
from release_agent.graph.release import wait_for_cd_release
clients = build_mock_clients()
config = build_config(clients)
state = {"repo_name": "my-repo"}
result = await wait_for_cd_release(state, config)
assert isinstance(result, dict)
# ---------------------------------------------------------------------------
# New release graph node: poll_release_approvals
# ---------------------------------------------------------------------------
class TestPollReleaseApprovals:
"""Tests for poll_release_approvals node."""
async def test_sets_pending_approvals_from_azdo(self) -> None:
from release_agent.graph.release import poll_release_approvals
from release_agent.models.build import ApprovalRecord
clients = build_mock_clients()
clients.azdo.get_release_approvals.return_value = [
ApprovalRecord(approval_id="a1", stage_name="Sandbox", status="pending", release_id=10),
]
config = build_config(clients)
state = {"release_id": 10}
result = await poll_release_approvals(state, config)
assert "pending_approvals" in result
assert len(result["pending_approvals"]) == 1
async def test_returns_empty_list_when_no_approvals(self) -> None:
from release_agent.graph.release import poll_release_approvals
clients = build_mock_clients()
clients.azdo.get_release_approvals.return_value = []
config = build_config(clients)
state = {"release_id": 10}
result = await poll_release_approvals(state, config)
assert result.get("pending_approvals") == []
async def test_appends_error_on_failure(self) -> None:
from release_agent.exceptions import ServiceError
from release_agent.graph.release import poll_release_approvals
clients = build_mock_clients()
clients.azdo.get_release_approvals.side_effect = ServiceError(
service="azdo", status_code=500, detail="error"
)
config = build_config(clients)
state = {"release_id": 10}
result = await poll_release_approvals(state, config)
assert "errors" in result
# ---------------------------------------------------------------------------
# New release graph node: interrupt_sandbox_approval
# ---------------------------------------------------------------------------
class TestInterruptSandboxApproval:
async def test_calls_interrupt(self) -> None:
from release_agent.graph.release import interrupt_sandbox_approval
config = build_config()
state = {
"pending_approvals": [{"approval_id": "x", "stage_name": "Sandbox"}],
"version": "v1.0.0",
}
with patch("release_agent.graph.release.interrupt") as mock_interrupt:
mock_interrupt.return_value = "confirm"
await interrupt_sandbox_approval(state, config)
mock_interrupt.assert_called_once()
async def test_sets_current_stage_to_sandbox_pending(self) -> None:
from release_agent.graph.release import interrupt_sandbox_approval
config = build_config()
state = {
"pending_approvals": [{"approval_id": "x", "stage_name": "Sandbox"}],
}
with patch("release_agent.graph.release.interrupt", return_value="yes"):
result = await interrupt_sandbox_approval(state, config)
assert result.get("current_stage") == "sandbox_pending"
# ---------------------------------------------------------------------------
# New release graph node: interrupt_prod_approval
# ---------------------------------------------------------------------------
class TestInterruptProdApproval:
async def test_calls_interrupt(self) -> None:
from release_agent.graph.release import interrupt_prod_approval
config = build_config()
state = {
"pending_approvals": [{"approval_id": "y", "stage_name": "Production"}],
"version": "v1.0.0",
}
with patch("release_agent.graph.release.interrupt") as mock_interrupt:
mock_interrupt.return_value = "confirm"
await interrupt_prod_approval(state, config)
mock_interrupt.assert_called_once()
async def test_sets_current_stage_to_prod_pending(self) -> None:
from release_agent.graph.release import interrupt_prod_approval
config = build_config()
state = {
"pending_approvals": [{"approval_id": "y", "stage_name": "Production"}],
}
with patch("release_agent.graph.release.interrupt", return_value="yes"):
result = await interrupt_prod_approval(state, config)
assert result.get("current_stage") == "prod_pending"
# ---------------------------------------------------------------------------
# New release graph node: execute_sandbox_approval
# ---------------------------------------------------------------------------
class TestExecuteSandboxApproval:
async def test_approves_sandbox_approvals(self) -> None:
from release_agent.graph.release import execute_sandbox_approval
clients = build_mock_clients()
clients.azdo.approve_release.return_value = {"status": "approved"}
config = build_config(clients)
state = {
"pending_approvals": [{"approval_id": "sb1", "stage_name": "Sandbox"}],
}
result = await execute_sandbox_approval(state, config)
clients.azdo.approve_release.assert_called()
async def test_returns_empty_dict_on_success(self) -> None:
from release_agent.graph.release import execute_sandbox_approval
clients = build_mock_clients()
clients.azdo.approve_release.return_value = {"status": "approved"}
config = build_config(clients)
state = {"pending_approvals": [{"approval_id": "sb1"}]}
result = await execute_sandbox_approval(state, config)
assert "errors" not in result or result["errors"] == []
# ---------------------------------------------------------------------------
# New release graph node: execute_prod_approval
# ---------------------------------------------------------------------------
class TestExecuteProdApproval:
async def test_approves_prod_approvals(self) -> None:
from release_agent.graph.release import execute_prod_approval
clients = build_mock_clients()
clients.azdo.approve_release.return_value = {"status": "approved"}
config = build_config(clients)
state = {
"pending_approvals": [{"approval_id": "pd1", "stage_name": "Production"}],
}
result = await execute_prod_approval(state, config)
clients.azdo.approve_release.assert_called()
# ---------------------------------------------------------------------------
# New release graph node: notify_ci_failure
# ---------------------------------------------------------------------------
class TestNotifyCiFailure:
async def test_sends_slack_notification(self) -> None:
from release_agent.graph.release import notify_ci_failure
clients = build_mock_clients()
clients.slack.send_notification.return_value = True
config = build_config(clients)
state = {
"repo_name": "my-repo",
"ci_build_result": "failed",
"ci_build_url": "https://build/1",
}
result = await notify_ci_failure(state, config)
clients.slack.send_notification.assert_called_once()
async def test_appends_message_on_success(self) -> None:
from release_agent.graph.release import notify_ci_failure
clients = build_mock_clients()
clients.slack.send_notification.return_value = True
config = build_config(clients)
state = {"repo_name": "my-repo", "ci_build_result": "failed"}
result = await notify_ci_failure(state, config)
assert "messages" in result or isinstance(result, dict)

302
tests/graph/test_routing.py Normal file
View File

@@ -0,0 +1,302 @@
"""Tests for graph/routing.py. Written FIRST (TDD RED phase).
All routing functions are pure — they take a state dict and return a string.
Every branch is tested, including missing state fields (defaults to falsy).
"""
import pytest
from release_agent.graph.routing import (
has_pending_approvals,
has_pipelines,
has_ticket,
is_pr_already_merged,
is_review_approved,
route_after_fetch,
route_approval_stage,
route_ci_result,
should_continue_to_release,
)
# ---------------------------------------------------------------------------
# is_pr_already_merged
# ---------------------------------------------------------------------------
class TestIsPrAlreadyMerged:
def test_returns_merged_when_true(self) -> None:
state = {"pr_already_merged": True}
assert is_pr_already_merged(state) == "merged"
def test_returns_active_when_false(self) -> None:
state = {"pr_already_merged": False}
assert is_pr_already_merged(state) == "active"
def test_returns_active_when_field_missing(self) -> None:
state = {}
assert is_pr_already_merged(state) == "active"
def test_returns_active_when_none(self) -> None:
state = {"pr_already_merged": None}
assert is_pr_already_merged(state) == "active"
# ---------------------------------------------------------------------------
# is_review_approved
# ---------------------------------------------------------------------------
class TestIsReviewApproved:
def test_returns_approve_when_true(self) -> None:
state = {"review_approved": True}
assert is_review_approved(state) == "approve"
def test_returns_request_changes_when_false(self) -> None:
state = {"review_approved": False}
assert is_review_approved(state) == "request_changes"
def test_returns_request_changes_when_field_missing(self) -> None:
state = {}
assert is_review_approved(state) == "request_changes"
def test_returns_request_changes_when_none(self) -> None:
state = {"review_approved": None}
assert is_review_approved(state) == "request_changes"
# ---------------------------------------------------------------------------
# has_ticket
# ---------------------------------------------------------------------------
class TestHasTicket:
def test_returns_yes_when_true(self) -> None:
state = {"has_ticket": True}
assert has_ticket(state) == "yes"
def test_returns_no_when_false(self) -> None:
state = {"has_ticket": False}
assert has_ticket(state) == "no"
def test_returns_no_when_field_missing(self) -> None:
state = {}
assert has_ticket(state) == "no"
def test_returns_no_when_none(self) -> None:
state = {"has_ticket": None}
assert has_ticket(state) == "no"
# ---------------------------------------------------------------------------
# should_continue_to_release
# ---------------------------------------------------------------------------
class TestShouldContinueToRelease:
def test_returns_yes_when_true(self) -> None:
state = {"continue_to_release": True}
assert should_continue_to_release(state) == "yes"
def test_returns_no_when_false(self) -> None:
state = {"continue_to_release": False}
assert should_continue_to_release(state) == "no"
def test_returns_no_when_field_missing(self) -> None:
state = {}
assert should_continue_to_release(state) == "no"
def test_returns_no_when_none(self) -> None:
state = {"continue_to_release": None}
assert should_continue_to_release(state) == "no"
# ---------------------------------------------------------------------------
# has_pipelines
# ---------------------------------------------------------------------------
class TestHasPipelines:
def test_returns_yes_when_non_empty_list(self) -> None:
state = {"pipelines": [{"id": 1}]}
assert has_pipelines(state) == "yes"
def test_returns_no_when_empty_list(self) -> None:
state = {"pipelines": []}
assert has_pipelines(state) == "no"
def test_returns_no_when_field_missing(self) -> None:
state = {}
assert has_pipelines(state) == "no"
def test_returns_no_when_none(self) -> None:
state = {"pipelines": None}
assert has_pipelines(state) == "no"
def test_returns_yes_with_multiple_pipelines(self) -> None:
state = {"pipelines": [{"id": 1}, {"id": 2}]}
assert has_pipelines(state) == "yes"
# ---------------------------------------------------------------------------
# has_pending_approvals
# ---------------------------------------------------------------------------
class TestHasPendingApprovals:
def test_returns_yes_when_non_empty_list(self) -> None:
state = {"pending_approvals": [{"approval_id": "abc"}]}
assert has_pending_approvals(state) == "yes"
def test_returns_no_when_empty_list(self) -> None:
state = {"pending_approvals": []}
assert has_pending_approvals(state) == "no"
def test_returns_no_when_field_missing(self) -> None:
state = {}
assert has_pending_approvals(state) == "no"
def test_returns_no_when_none(self) -> None:
state = {"pending_approvals": None}
assert has_pending_approvals(state) == "no"
def test_returns_yes_with_multiple_approvals(self) -> None:
state = {"pending_approvals": [{"approval_id": "a"}, {"approval_id": "b"}]}
assert has_pending_approvals(state) == "yes"
# ---------------------------------------------------------------------------
# route_ci_result
# ---------------------------------------------------------------------------
class TestRouteCiResult:
"""Tests for route_ci_result routing function."""
def test_returns_ci_passed_when_succeeded(self) -> None:
state = {"ci_build_result": "succeeded"}
assert route_ci_result(state) == "ci_passed"
def test_returns_ci_failed_when_failed(self) -> None:
state = {"ci_build_result": "failed"}
assert route_ci_result(state) == "ci_failed"
def test_returns_ci_failed_when_canceled(self) -> None:
state = {"ci_build_result": "canceled"}
assert route_ci_result(state) == "ci_failed"
def test_returns_ci_failed_when_partially_succeeded(self) -> None:
state = {"ci_build_result": "partiallySucceeded"}
assert route_ci_result(state) == "ci_failed"
def test_returns_ci_failed_when_field_missing(self) -> None:
state = {}
assert route_ci_result(state) == "ci_failed"
def test_returns_ci_failed_when_none(self) -> None:
state = {"ci_build_result": None}
assert route_ci_result(state) == "ci_failed"
def test_returns_ci_failed_when_empty_string(self) -> None:
state = {"ci_build_result": ""}
assert route_ci_result(state) == "ci_failed"
def test_case_sensitive_succeeded(self) -> None:
# AzDo returns "succeeded" (lowercase)
state = {"ci_build_result": "succeeded"}
assert route_ci_result(state) == "ci_passed"
# ---------------------------------------------------------------------------
# route_approval_stage
# ---------------------------------------------------------------------------
class TestRouteApprovalStage:
"""Tests for route_approval_stage routing function."""
def test_returns_all_deployed_when_no_pending_approvals(self) -> None:
state = {"pending_approvals": []}
assert route_approval_stage(state) == "all_deployed"
def test_returns_all_deployed_when_field_missing(self) -> None:
state = {}
assert route_approval_stage(state) == "all_deployed"
def test_returns_all_deployed_when_none(self) -> None:
state = {"pending_approvals": None}
assert route_approval_stage(state) == "all_deployed"
def test_returns_sandbox_pending_when_sandbox_approval_exists(self) -> None:
state = {
"current_stage": "sandbox_pending",
"pending_approvals": [{"approval_id": "x", "stage_name": "Sandbox"}],
}
assert route_approval_stage(state) == "sandbox_pending"
def test_returns_prod_pending_when_prod_approval_exists(self) -> None:
state = {
"current_stage": "prod_pending",
"pending_approvals": [{"approval_id": "y", "stage_name": "Production"}],
}
assert route_approval_stage(state) == "prod_pending"
def test_uses_current_stage_field_when_present(self) -> None:
state = {
"current_stage": "sandbox_pending",
"pending_approvals": [{"approval_id": "z"}],
}
assert route_approval_stage(state) == "sandbox_pending"
def test_returns_all_deployed_when_no_current_stage_and_has_approvals(self) -> None:
# When current_stage is missing but approvals exist, stage is unknown
# so we treat as sandbox by default (first stage)
state = {
"pending_approvals": [{"approval_id": "a"}],
}
# Must return either sandbox_pending or prod_pending (not all_deployed)
result = route_approval_stage(state)
assert result in ("sandbox_pending", "prod_pending")
def test_sandbox_pending_from_current_stage(self) -> None:
state = {"current_stage": "sandbox_pending", "pending_approvals": [{"approval_id": "x"}]}
assert route_approval_stage(state) == "sandbox_pending"
def test_prod_pending_from_current_stage(self) -> None:
state = {"current_stage": "prod_pending", "pending_approvals": [{"approval_id": "x"}]}
assert route_approval_stage(state) == "prod_pending"
# ---------------------------------------------------------------------------
# route_after_fetch
# ---------------------------------------------------------------------------
class TestRouteAfterFetch:
"""Tests for route_after_fetch — 3-way routing replacing is_pr_already_merged."""
def test_returns_merged_when_pr_already_merged(self) -> None:
state = {"pr_already_merged": True}
assert route_after_fetch(state) == "merged"
def test_returns_active_with_ticket_when_active_and_has_ticket(self) -> None:
state = {"pr_already_merged": False, "has_ticket": True}
assert route_after_fetch(state) == "active_with_ticket"
def test_returns_active_no_ticket_when_active_and_no_ticket(self) -> None:
state = {"pr_already_merged": False, "has_ticket": False}
assert route_after_fetch(state) == "active_no_ticket"
def test_returns_active_no_ticket_when_has_ticket_missing(self) -> None:
state = {"pr_already_merged": False}
assert route_after_fetch(state) == "active_no_ticket"
def test_returns_active_no_ticket_when_has_ticket_none(self) -> None:
state = {"pr_already_merged": False, "has_ticket": None}
assert route_after_fetch(state) == "active_no_ticket"
def test_returns_active_no_ticket_when_all_fields_missing(self) -> None:
state = {}
assert route_after_fetch(state) == "active_no_ticket"
def test_merged_takes_precedence_over_has_ticket(self) -> None:
# Even if has_ticket is True, merged PR should route to "merged"
state = {"pr_already_merged": True, "has_ticket": True}
assert route_after_fetch(state) == "merged"
def test_returns_active_with_ticket_ignores_merged_false(self) -> None:
state = {"pr_already_merged": False, "has_ticket": True}
result = route_after_fetch(state)
assert result != "merged"
assert result == "active_with_ticket"