feat: initial commit — Billo Release Agent (LangGraph)

LangGraph-based release automation agent with:
- PR discovery (webhook + polling)
- AI code review via Claude Code CLI (subscription-based)
- Auto-create Jira tickets for PRs without ticket ID
- Jira ticket lifecycle management (code review -> staging -> done)
- CI/CD pipeline trigger, polling, and approval gates
- Slack interactive messages with approval buttons
- Per-repo semantic versioning
- PostgreSQL persistence (threads, staging, releases)
- FastAPI API (webhooks, approvals, status, manual triggers)
- Docker Compose deployment

1069 tests, 95%+ coverage.
This commit is contained in:
Yaojia Wang
2026-03-24 17:38:23 +01:00
commit f5c2733cfb
104 changed files with 19721 additions and 0 deletions

0
tests/__init__.py Normal file
View File

0
tests/api/__init__.py Normal file
View File

259
tests/api/test_approvals.py Normal file
View File

@@ -0,0 +1,259 @@
"""Tests for approvals endpoint. Written FIRST (TDD RED phase)."""
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from release_agent.api.approvals import router as approvals_router
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_test_app(
*,
interrupted_threads: list[dict] | None = None,
graph_resume_result: dict | None = None,
) -> FastAPI:
"""Return a FastAPI app with mocked state for approvals tests."""
app = FastAPI()
app.include_router(approvals_router)
if interrupted_threads is None:
interrupted_threads = []
mock_settings = MagicMock()
mock_settings.operator_token.get_secret_value.return_value = ""
mock_graphs = {
"pr_completed": MagicMock(),
"release": MagicMock(),
}
mock_clients = MagicMock()
# Mock pool that returns interrupted threads from DB
mock_pool = MagicMock()
mock_conn = AsyncMock()
mock_cursor = AsyncMock()
rows = [
(
t["thread_id"],
t.get("graph_name", "pr_completed"),
t.get("interrupt_value", "Confirm?"),
t.get("created_at", datetime.now(tz=timezone.utc)),
t.get("repo_name"),
t.get("pr_id"),
t.get("version"),
)
for t in interrupted_threads
]
mock_cursor.fetchall = AsyncMock(return_value=rows)
mock_cursor.fetchone = AsyncMock(return_value=("pr_completed",))
mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor)
mock_cursor.__aexit__ = AsyncMock(return_value=False)
mock_conn.cursor = MagicMock(return_value=mock_cursor)
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
mock_conn.__aexit__ = AsyncMock(return_value=False)
mock_pool.connection = MagicMock(return_value=mock_conn)
app.state.settings = mock_settings
app.state.graphs = mock_graphs
app.state.tool_clients = mock_clients
app.state.db_pool = mock_pool
app.state.background_tasks = set()
return app
# ---------------------------------------------------------------------------
# POST /approvals/{thread_id}
# ---------------------------------------------------------------------------
class TestPostApproval:
def test_valid_merge_decision_returns_200(self) -> None:
app = _make_test_app()
mock_graph = MagicMock()
mock_graph.ainvoke = AsyncMock(return_value={"messages": ["done"]})
app.state.graphs["pr_completed"] = mock_graph
with patch("release_agent.api.approvals._resume_graph", new_callable=AsyncMock) as mock_resume:
mock_resume.return_value = {"messages": ["resumed"]}
with TestClient(app) as client:
response = client.post(
"/approvals/thread-123",
json={"decision": "merge"},
)
assert response.status_code == 200
data = response.json()
assert data["thread_id"] == "thread-123"
assert "status" in data
assert "message" in data
def test_valid_cancel_decision_returns_200(self) -> None:
app = _make_test_app()
with patch("release_agent.api.approvals._resume_graph", new_callable=AsyncMock) as mock_resume:
mock_resume.return_value = {"messages": ["cancelled"]}
with TestClient(app) as client:
response = client.post(
"/approvals/thread-456",
json={"decision": "cancel"},
)
assert response.status_code == 200
def test_invalid_decision_returns_422(self) -> None:
app = _make_test_app()
with TestClient(app) as client:
response = client.post(
"/approvals/thread-123",
json={"decision": "invalid_decision"},
)
assert response.status_code == 422
def test_missing_decision_returns_422(self) -> None:
app = _make_test_app()
with TestClient(app) as client:
response = client.post(
"/approvals/thread-123",
json={},
)
assert response.status_code == 422
def test_response_contains_thread_id(self) -> None:
app = _make_test_app()
with patch("release_agent.api.approvals._resume_graph", new_callable=AsyncMock) as mock_resume:
mock_resume.return_value = {}
with TestClient(app) as client:
response = client.post(
"/approvals/my-thread-id",
json={"decision": "approve"},
)
assert response.json()["thread_id"] == "my-thread-id"
def test_approve_decision_returns_200(self) -> None:
app = _make_test_app()
with patch("release_agent.api.approvals._resume_graph", new_callable=AsyncMock) as mock_resume:
mock_resume.return_value = {}
with TestClient(app) as client:
response = client.post(
"/approvals/t1",
json={"decision": "approve"},
)
assert response.status_code == 200
def test_skip_decision_returns_200(self) -> None:
app = _make_test_app()
with patch("release_agent.api.approvals._resume_graph", new_callable=AsyncMock) as mock_resume:
mock_resume.return_value = {}
with TestClient(app) as client:
response = client.post(
"/approvals/t1",
json={"decision": "skip"},
)
assert response.status_code == 200
def test_trigger_decision_returns_200(self) -> None:
app = _make_test_app()
with patch("release_agent.api.approvals._resume_graph", new_callable=AsyncMock) as mock_resume:
mock_resume.return_value = {}
with TestClient(app) as client:
response = client.post(
"/approvals/t1",
json={"decision": "trigger"},
)
assert response.status_code == 200
# ---------------------------------------------------------------------------
# GET /approvals/pending
# ---------------------------------------------------------------------------
class TestGetPendingApprovals:
def test_empty_pending_returns_200(self) -> None:
app = _make_test_app(interrupted_threads=[])
with TestClient(app) as client:
response = client.get("/approvals/pending")
assert response.status_code == 200
data = response.json()
assert data["count"] == 0
assert data["items"] == []
def test_pending_approvals_list_structure(self) -> None:
now = datetime.now(tz=timezone.utc)
threads = [
{
"thread_id": "t1",
"graph_name": "pr_completed",
"interrupt_value": "Confirm merge?",
"created_at": now,
"repo_name": "my-repo",
"pr_id": "42",
"version": "v1.0.0",
}
]
app = _make_test_app(interrupted_threads=threads)
with TestClient(app) as client:
response = client.get("/approvals/pending")
assert response.status_code == 200
data = response.json()
assert data["count"] == 1
assert data["items"][0]["thread_id"] == "t1"
assert data["items"][0]["graph_name"] == "pr_completed"
def test_multiple_pending_approvals(self) -> None:
now = datetime.now(tz=timezone.utc)
threads = [
{
"thread_id": f"t{i}",
"graph_name": "pr_completed",
"interrupt_value": "Confirm?",
"created_at": now,
"repo_name": None,
"pr_id": None,
"version": None,
}
for i in range(3)
]
app = _make_test_app(interrupted_threads=threads)
with TestClient(app) as client:
response = client.get("/approvals/pending")
assert response.status_code == 200
data = response.json()
assert data["count"] == 3
assert len(data["items"]) == 3
def test_pending_approval_optional_fields_nullable(self) -> None:
now = datetime.now(tz=timezone.utc)
threads = [
{
"thread_id": "t1",
"graph_name": "release",
"interrupt_value": "Run release?",
"created_at": now,
"repo_name": None,
"pr_id": None,
"version": None,
}
]
app = _make_test_app(interrupted_threads=threads)
with TestClient(app) as client:
response = client.get("/approvals/pending")
item = response.json()["items"][0]
assert item["repo_name"] is None
assert item["pr_id"] is None
assert item["version"] is None
# ---------------------------------------------------------------------------
# _resume_graph helper function tests
# ---------------------------------------------------------------------------
class TestResumeGraph:
def test_resume_graph_callable(self) -> None:
from release_agent.api.approvals import _resume_graph
import inspect
assert inspect.iscoroutinefunction(_resume_graph)

View File

@@ -0,0 +1,139 @@
"""Tests for approvals endpoints with operator token authentication.
Phase 5 - Step 3: Verifies that POST /approvals/{thread_id} and
GET /approvals/pending require operator token when configured.
Written FIRST (TDD RED phase).
"""
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from release_agent.api.approvals import router as approvals_router
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_test_app(operator_token: str = "") -> FastAPI:
"""Return a FastAPI app with approvals router and configurable operator token."""
app = FastAPI()
app.include_router(approvals_router)
mock_settings = MagicMock()
mock_settings.operator_token.get_secret_value.return_value = operator_token
mock_pool = MagicMock()
mock_conn = AsyncMock()
mock_cursor = AsyncMock()
mock_cursor.fetchall = AsyncMock(return_value=[])
mock_cursor.fetchone = AsyncMock(return_value=("pr_completed",))
mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor)
mock_cursor.__aexit__ = AsyncMock(return_value=False)
mock_conn.cursor = MagicMock(return_value=mock_cursor)
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
mock_conn.__aexit__ = AsyncMock(return_value=False)
mock_pool.connection = MagicMock(return_value=mock_conn)
app.state.settings = mock_settings
app.state.graphs = {
"pr_completed": MagicMock(),
"release": MagicMock(),
}
app.state.tool_clients = MagicMock()
app.state.db_pool = mock_pool
app.state.background_tasks = set()
return app
# ---------------------------------------------------------------------------
# POST /approvals/{thread_id} with auth
# ---------------------------------------------------------------------------
class TestPostApprovalWithAuth:
def test_valid_token_allows_post_approval(self) -> None:
app = _make_test_app(operator_token="secret-token")
with patch(
"release_agent.api.approvals._resume_graph", new_callable=AsyncMock
) as mock_resume:
mock_resume.return_value = {}
with TestClient(app) as client:
response = client.post(
"/approvals/thread-123",
json={"decision": "merge"},
headers={"X-Operator-Token": "secret-token"},
)
assert response.status_code == 200
def test_missing_token_rejects_post_approval(self) -> None:
app = _make_test_app(operator_token="secret-token")
with TestClient(app) as client:
response = client.post(
"/approvals/thread-123",
json={"decision": "merge"},
)
assert response.status_code == 401
def test_wrong_token_rejects_post_approval(self) -> None:
app = _make_test_app(operator_token="secret-token")
with TestClient(app) as client:
response = client.post(
"/approvals/thread-123",
json={"decision": "merge"},
headers={"X-Operator-Token": "wrong-token"},
)
assert response.status_code == 401
def test_no_auth_required_when_token_not_configured(self) -> None:
app = _make_test_app(operator_token="")
with patch(
"release_agent.api.approvals._resume_graph", new_callable=AsyncMock
) as mock_resume:
mock_resume.return_value = {}
with TestClient(app) as client:
response = client.post(
"/approvals/thread-123",
json={"decision": "merge"},
)
assert response.status_code == 200
# ---------------------------------------------------------------------------
# GET /approvals/pending with auth
# ---------------------------------------------------------------------------
class TestGetPendingApprovalsWithAuth:
def test_valid_token_allows_get_pending(self) -> None:
app = _make_test_app(operator_token="secret-token")
with TestClient(app) as client:
response = client.get(
"/approvals/pending",
headers={"X-Operator-Token": "secret-token"},
)
assert response.status_code == 200
def test_missing_token_rejects_get_pending(self) -> None:
app = _make_test_app(operator_token="secret-token")
with TestClient(app) as client:
response = client.get("/approvals/pending")
assert response.status_code == 401
def test_wrong_token_rejects_get_pending(self) -> None:
app = _make_test_app(operator_token="secret-token")
with TestClient(app) as client:
response = client.get(
"/approvals/pending",
headers={"X-Operator-Token": "wrong"},
)
assert response.status_code == 401
def test_no_auth_required_when_token_not_configured(self) -> None:
app = _make_test_app(operator_token="")
with TestClient(app) as client:
response = client.get("/approvals/pending")
assert response.status_code == 200

View File

@@ -0,0 +1,149 @@
"""Tests for API FastAPI dependencies. Written FIRST (TDD RED phase)."""
from unittest.mock import MagicMock
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from release_agent.api.dependencies import (
get_db_pool,
get_graphs,
get_settings,
get_staging_store,
get_tool_clients,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_app_with_state(**state_kwargs) -> FastAPI:
"""Return a minimal FastAPI app with app.state attributes set."""
app = FastAPI()
for key, value in state_kwargs.items():
setattr(app.state, key, value)
return app
# ---------------------------------------------------------------------------
# get_settings
# ---------------------------------------------------------------------------
class TestGetSettings:
def test_returns_settings_from_state(self) -> None:
mock_settings = MagicMock()
app = _make_app_with_state(settings=mock_settings)
with TestClient(app) as client:
# We test the dependency directly by simulating a request
request = MagicMock()
request.app = app
result = get_settings(request)
assert result is mock_settings
def test_raises_when_settings_missing(self) -> None:
app = FastAPI() # no state.settings
request = MagicMock()
request.app = app
with pytest.raises(AttributeError):
get_settings(request)
# ---------------------------------------------------------------------------
# get_graphs
# ---------------------------------------------------------------------------
class TestGetGraphs:
def test_returns_graphs_from_state(self) -> None:
mock_graphs = {"pr_completed": MagicMock(), "release": MagicMock()}
app = _make_app_with_state(graphs=mock_graphs)
request = MagicMock()
request.app = app
result = get_graphs(request)
assert result is mock_graphs
def test_raises_when_graphs_missing(self) -> None:
app = FastAPI()
request = MagicMock()
request.app = app
with pytest.raises(AttributeError):
get_graphs(request)
# ---------------------------------------------------------------------------
# get_tool_clients
# ---------------------------------------------------------------------------
class TestGetToolClients:
def test_returns_tool_clients_from_state(self) -> None:
mock_clients = MagicMock()
app = _make_app_with_state(tool_clients=mock_clients)
request = MagicMock()
request.app = app
result = get_tool_clients(request)
assert result is mock_clients
def test_raises_when_tool_clients_missing(self) -> None:
app = FastAPI()
request = MagicMock()
request.app = app
with pytest.raises(AttributeError):
get_tool_clients(request)
# ---------------------------------------------------------------------------
# get_staging_store
# ---------------------------------------------------------------------------
class TestGetStagingStore:
def test_returns_staging_store_from_state(self) -> None:
mock_store = MagicMock()
app = _make_app_with_state(staging_store=mock_store)
request = MagicMock()
request.app = app
result = get_staging_store(request)
assert result is mock_store
def test_raises_when_staging_store_missing(self) -> None:
app = FastAPI()
request = MagicMock()
request.app = app
with pytest.raises(AttributeError):
get_staging_store(request)
# ---------------------------------------------------------------------------
# get_db_pool
# ---------------------------------------------------------------------------
class TestGetDbPool:
def test_returns_db_pool_from_state(self) -> None:
mock_pool = MagicMock()
app = _make_app_with_state(db_pool=mock_pool)
request = MagicMock()
request.app = app
result = get_db_pool(request)
assert result is mock_pool
def test_raises_when_db_pool_missing(self) -> None:
app = FastAPI()
request = MagicMock()
request.app = app
with pytest.raises(AttributeError):
get_db_pool(request)

446
tests/api/test_internals.py Normal file
View File

@@ -0,0 +1,446 @@
"""Tests for internal async helper functions.
Tests _run_graph, _upsert_thread, _resume_graph, and exception handlers.
Written FIRST then verified (TDD GREEN phase for internal helpers).
"""
import json
from unittest.mock import AsyncMock, MagicMock, call, patch
import pytest
# ---------------------------------------------------------------------------
# _upsert_thread tests
# ---------------------------------------------------------------------------
class TestUpsertThread:
@pytest.mark.asyncio
async def test_upsert_thread_executes_sql(self) -> None:
from release_agent.api.webhooks import _upsert_thread
mock_pool = MagicMock()
mock_conn = AsyncMock()
mock_cursor = AsyncMock()
mock_cursor.execute = AsyncMock()
mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor)
mock_cursor.__aexit__ = AsyncMock(return_value=False)
mock_conn.cursor = MagicMock(return_value=mock_cursor)
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
mock_conn.__aexit__ = AsyncMock(return_value=False)
mock_pool.connection = MagicMock(return_value=mock_conn)
await _upsert_thread(
mock_pool,
thread_id="t1",
thread_status="running",
state={"repo_name": "my-repo"},
)
mock_cursor.execute.assert_called_once()
args = mock_cursor.execute.call_args[0]
assert "agent_threads" in args[0]
assert args[1][0] == "t1"
assert args[1][4] == "running"
# state is JSON-encoded
state_json = json.loads(args[1][5])
assert state_json["repo_name"] == "my-repo"
@pytest.mark.asyncio
async def test_upsert_thread_completed_status(self) -> None:
from release_agent.api.webhooks import _upsert_thread
mock_pool = MagicMock()
mock_conn = AsyncMock()
mock_cursor = AsyncMock()
mock_cursor.execute = AsyncMock()
mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor)
mock_cursor.__aexit__ = AsyncMock(return_value=False)
mock_conn.cursor = MagicMock(return_value=mock_cursor)
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
mock_conn.__aexit__ = AsyncMock(return_value=False)
mock_pool.connection = MagicMock(return_value=mock_conn)
await _upsert_thread(
mock_pool,
thread_id="t2",
thread_status="completed",
state={},
)
args = mock_cursor.execute.call_args[0]
assert args[1][4] == "completed"
@pytest.mark.asyncio
async def test_upsert_thread_failed_status(self) -> None:
from release_agent.api.webhooks import _upsert_thread
mock_pool = MagicMock()
mock_conn = AsyncMock()
mock_cursor = AsyncMock()
mock_cursor.execute = AsyncMock()
mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor)
mock_cursor.__aexit__ = AsyncMock(return_value=False)
mock_conn.cursor = MagicMock(return_value=mock_cursor)
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
mock_conn.__aexit__ = AsyncMock(return_value=False)
mock_pool.connection = MagicMock(return_value=mock_conn)
await _upsert_thread(
mock_pool,
thread_id="t3",
thread_status="failed",
state={"errors": ["something went wrong"]},
)
args = mock_cursor.execute.call_args[0]
assert args[1][4] == "failed"
state_json = json.loads(args[1][5])
assert state_json["errors"] == ["something went wrong"]
# ---------------------------------------------------------------------------
# _run_graph tests
# ---------------------------------------------------------------------------
class TestRunGraph:
@pytest.mark.asyncio
async def test_run_graph_success_upserts_completed(self) -> None:
from release_agent.api.webhooks import _run_graph
mock_graph = AsyncMock()
mock_graph.ainvoke = AsyncMock(return_value={"messages": ["done"]})
mock_pool = MagicMock()
mock_conn = AsyncMock()
mock_cursor = AsyncMock()
mock_cursor.execute = AsyncMock()
mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor)
mock_cursor.__aexit__ = AsyncMock(return_value=False)
mock_conn.cursor = MagicMock(return_value=mock_cursor)
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
mock_conn.__aexit__ = AsyncMock(return_value=False)
mock_pool.connection = MagicMock(return_value=mock_conn)
await _run_graph(
graph=mock_graph,
initial_state={"repo_name": "test"},
thread_id="t1",
tool_clients=MagicMock(),
db_pool=mock_pool,
)
# Should have been called with "running" then "completed"
calls = mock_cursor.execute.call_args_list
assert len(calls) == 2
# First call: "running", second call: "completed"
assert calls[0][0][1][4] == "running"
assert calls[1][0][1][4] == "completed"
@pytest.mark.asyncio
async def test_run_graph_failure_upserts_failed(self) -> None:
from release_agent.api.webhooks import _run_graph
mock_graph = AsyncMock()
mock_graph.ainvoke = AsyncMock(side_effect=RuntimeError("graph crashed"))
mock_pool = MagicMock()
mock_conn = AsyncMock()
mock_cursor = AsyncMock()
mock_cursor.execute = AsyncMock()
mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor)
mock_cursor.__aexit__ = AsyncMock(return_value=False)
mock_conn.cursor = MagicMock(return_value=mock_cursor)
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
mock_conn.__aexit__ = AsyncMock(return_value=False)
mock_pool.connection = MagicMock(return_value=mock_conn)
await _run_graph(
graph=mock_graph,
initial_state={"repo_name": "test"},
thread_id="t-fail",
tool_clients=MagicMock(),
db_pool=mock_pool,
)
calls = mock_cursor.execute.call_args_list
# First call: "running", second call: "failed"
assert calls[0][0][1][4] == "running"
assert calls[1][0][1][4] == "failed"
# State should contain errors
failed_state = json.loads(calls[1][0][1][5])
assert "errors" in failed_state
@pytest.mark.asyncio
async def test_run_graph_invokes_with_correct_config(self) -> None:
from release_agent.api.webhooks import _run_graph
mock_graph = AsyncMock()
mock_graph.ainvoke = AsyncMock(return_value={})
mock_clients = MagicMock()
mock_pool = MagicMock()
mock_conn = AsyncMock()
mock_cursor = AsyncMock()
mock_cursor.execute = AsyncMock()
mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor)
mock_cursor.__aexit__ = AsyncMock(return_value=False)
mock_conn.cursor = MagicMock(return_value=mock_cursor)
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
mock_conn.__aexit__ = AsyncMock(return_value=False)
mock_pool.connection = MagicMock(return_value=mock_conn)
await _run_graph(
graph=mock_graph,
initial_state={"repo_name": "test"},
thread_id="t-config",
tool_clients=mock_clients,
db_pool=mock_pool,
)
call_args = mock_graph.ainvoke.call_args
config = call_args[1]["config"]
assert config["configurable"]["thread_id"] == "t-config"
assert config["configurable"]["clients"] is mock_clients
# ---------------------------------------------------------------------------
# _resume_graph tests
# ---------------------------------------------------------------------------
class TestResumeGraphInternal:
@pytest.mark.asyncio
async def test_resume_graph_success(self) -> None:
from release_agent.api.approvals import _resume_graph
mock_graph = AsyncMock()
mock_graph.ainvoke = AsyncMock(return_value={"result": "ok"})
mock_pool = MagicMock()
mock_conn = AsyncMock()
mock_cursor = AsyncMock()
mock_cursor.execute = AsyncMock()
mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor)
mock_cursor.__aexit__ = AsyncMock(return_value=False)
mock_conn.cursor = MagicMock(return_value=mock_cursor)
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
mock_conn.__aexit__ = AsyncMock(return_value=False)
mock_pool.connection = MagicMock(return_value=mock_conn)
result = await _resume_graph(
graph=mock_graph,
thread_id="t1",
decision="merge",
tool_clients=MagicMock(),
db_pool=mock_pool,
)
assert result == {"result": "ok"}
mock_graph.ainvoke.assert_called_once()
# Verify the decision was passed
call_args = mock_graph.ainvoke.call_args
assert call_args[0][0]["decision"] == "merge"
@pytest.mark.asyncio
async def test_resume_graph_failure_re_raises(self) -> None:
from release_agent.api.approvals import _resume_graph
mock_graph = AsyncMock()
mock_graph.ainvoke = AsyncMock(side_effect=RuntimeError("resume failed"))
mock_pool = MagicMock()
mock_conn = AsyncMock()
mock_cursor = AsyncMock()
mock_cursor.execute = AsyncMock()
mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor)
mock_cursor.__aexit__ = AsyncMock(return_value=False)
mock_conn.cursor = MagicMock(return_value=mock_cursor)
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
mock_conn.__aexit__ = AsyncMock(return_value=False)
mock_pool.connection = MagicMock(return_value=mock_conn)
with pytest.raises(RuntimeError, match="resume failed"):
await _resume_graph(
graph=mock_graph,
thread_id="t1",
decision="cancel",
tool_clients=MagicMock(),
db_pool=mock_pool,
)
@pytest.mark.asyncio
async def test_resume_graph_upserts_completed_on_success(self) -> None:
from release_agent.api.approvals import _resume_graph
mock_graph = AsyncMock()
mock_graph.ainvoke = AsyncMock(return_value={"messages": ["done"]})
mock_pool = MagicMock()
mock_conn = AsyncMock()
mock_cursor = AsyncMock()
mock_cursor.execute = AsyncMock()
mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor)
mock_cursor.__aexit__ = AsyncMock(return_value=False)
mock_conn.cursor = MagicMock(return_value=mock_cursor)
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
mock_conn.__aexit__ = AsyncMock(return_value=False)
mock_pool.connection = MagicMock(return_value=mock_conn)
await _resume_graph(
graph=mock_graph,
thread_id="t-success",
decision="approve",
tool_clients=MagicMock(),
db_pool=mock_pool,
)
# The last execute call should be "completed"
last_call = mock_cursor.execute.call_args_list[-1]
assert last_call[0][1][4] == "completed"
@pytest.mark.asyncio
async def test_resume_graph_upserts_failed_on_exception(self) -> None:
from release_agent.api.approvals import _resume_graph
mock_graph = AsyncMock()
mock_graph.ainvoke = AsyncMock(side_effect=ValueError("bad"))
mock_pool = MagicMock()
mock_conn = AsyncMock()
mock_cursor = AsyncMock()
mock_cursor.execute = AsyncMock()
mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor)
mock_cursor.__aexit__ = AsyncMock(return_value=False)
mock_conn.cursor = MagicMock(return_value=mock_cursor)
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
mock_conn.__aexit__ = AsyncMock(return_value=False)
mock_pool.connection = MagicMock(return_value=mock_conn)
with pytest.raises(ValueError):
await _resume_graph(
graph=mock_graph,
thread_id="t-fail",
decision="skip",
tool_clients=MagicMock(),
db_pool=mock_pool,
)
last_call = mock_cursor.execute.call_args_list[-1]
assert last_call[0][1][4] == "failed"
# ---------------------------------------------------------------------------
# run_graph_in_background tests (main.py)
# ---------------------------------------------------------------------------
class TestRunGraphInBackground:
@pytest.mark.asyncio
async def test_success_with_db_pool(self) -> None:
from release_agent.main import run_graph_in_background
mock_graph = AsyncMock()
mock_graph.ainvoke = AsyncMock(return_value={"done": True})
mock_pool = MagicMock()
mock_conn = AsyncMock()
mock_cursor = AsyncMock()
mock_cursor.execute = AsyncMock()
mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor)
mock_cursor.__aexit__ = AsyncMock(return_value=False)
mock_conn.cursor = MagicMock(return_value=mock_cursor)
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
mock_conn.__aexit__ = AsyncMock(return_value=False)
mock_pool.connection = MagicMock(return_value=mock_conn)
await run_graph_in_background(
graph=mock_graph,
initial_state={"repo_name": "test"},
thread_id="t-bg",
db_pool=mock_pool,
)
calls = mock_cursor.execute.call_args_list
assert calls[0][0][1][4] == "running"
assert calls[1][0][1][4] == "completed"
@pytest.mark.asyncio
async def test_failure_with_db_pool(self) -> None:
from release_agent.main import run_graph_in_background
mock_graph = AsyncMock()
mock_graph.ainvoke = AsyncMock(side_effect=RuntimeError("bg failed"))
mock_pool = MagicMock()
mock_conn = AsyncMock()
mock_cursor = AsyncMock()
mock_cursor.execute = AsyncMock()
mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor)
mock_cursor.__aexit__ = AsyncMock(return_value=False)
mock_conn.cursor = MagicMock(return_value=mock_cursor)
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
mock_conn.__aexit__ = AsyncMock(return_value=False)
mock_pool.connection = MagicMock(return_value=mock_conn)
await run_graph_in_background(
graph=mock_graph,
initial_state={},
thread_id="t-bg-fail",
db_pool=mock_pool,
)
last_call = mock_cursor.execute.call_args_list[-1]
assert last_call[0][1][4] == "failed"
@pytest.mark.asyncio
async def test_success_without_db_pool(self) -> None:
"""run_graph_in_background works even without a db_pool."""
from release_agent.main import run_graph_in_background
mock_graph = AsyncMock()
mock_graph.ainvoke = AsyncMock(return_value={})
# Should not raise even with no db_pool
await run_graph_in_background(
graph=mock_graph,
initial_state={},
thread_id="t-no-pool",
db_pool=None,
)
mock_graph.ainvoke.assert_called_once()
# ---------------------------------------------------------------------------
# Exception handler tests (main.py)
# ---------------------------------------------------------------------------
class TestExceptionHandlerFunctions:
@pytest.mark.asyncio
async def test_release_agent_error_handler_returns_500(self) -> None:
from release_agent.main import _release_agent_error_handler
from release_agent.exceptions import ServiceError
request = MagicMock()
exc = ServiceError(service="azdo", status_code=503, detail="unavailable")
response = await _release_agent_error_handler(request, exc)
assert response.status_code == 500
body = json.loads(response.body)
assert body["error"] == "ServiceError"
assert "unavailable" in body["detail"]
@pytest.mark.asyncio
async def test_generic_error_handler_returns_500(self) -> None:
from release_agent.main import _generic_error_handler
request = MagicMock()
exc = ValueError("something generic")
response = await _generic_error_handler(request, exc)
assert response.status_code == 500
body = json.loads(response.body)
assert body["error"] == "InternalServerError"
assert "An unexpected error occurred" in body["detail"]

294
tests/api/test_models.py Normal file
View File

@@ -0,0 +1,294 @@
"""Tests for API request/response models. Written FIRST (TDD RED phase)."""
from datetime import datetime, timezone
import pytest
from pydantic import ValidationError
from release_agent.api.models import (
ApprovalDecision,
ApprovalResponse,
ErrorResponse,
HealthResponse,
ManualReleaseRequest,
ManualTriggerResponse,
PendingApproval,
PendingApprovalsResponse,
ReleaseVersionListResponse,
StagingResponse,
WebhookResponse,
)
# ---------------------------------------------------------------------------
# WebhookResponse
# ---------------------------------------------------------------------------
class TestWebhookResponse:
def test_valid_construction(self) -> None:
resp = WebhookResponse(thread_id="thread-123", message="scheduled")
assert resp.thread_id == "thread-123"
assert resp.message == "scheduled"
def test_frozen_immutable(self) -> None:
resp = WebhookResponse(thread_id="t1", message="ok")
with pytest.raises((TypeError, ValidationError)):
resp.thread_id = "other" # type: ignore[misc]
def test_missing_thread_id_raises(self) -> None:
with pytest.raises(ValidationError):
WebhookResponse(message="ok") # type: ignore[call-arg]
def test_missing_message_raises(self) -> None:
with pytest.raises(ValidationError):
WebhookResponse(thread_id="t1") # type: ignore[call-arg]
# ---------------------------------------------------------------------------
# ApprovalDecision
# ---------------------------------------------------------------------------
class TestApprovalDecision:
def test_merge_decision(self) -> None:
d = ApprovalDecision(decision="merge")
assert d.decision == "merge"
def test_cancel_decision(self) -> None:
d = ApprovalDecision(decision="cancel")
assert d.decision == "cancel"
def test_approve_decision(self) -> None:
d = ApprovalDecision(decision="approve")
assert d.decision == "approve"
def test_skip_decision(self) -> None:
d = ApprovalDecision(decision="skip")
assert d.decision == "skip"
def test_trigger_decision(self) -> None:
d = ApprovalDecision(decision="trigger")
assert d.decision == "trigger"
def test_invalid_decision_raises(self) -> None:
with pytest.raises(ValidationError):
ApprovalDecision(decision="invalid") # type: ignore[arg-type]
def test_frozen_immutable(self) -> None:
d = ApprovalDecision(decision="merge")
with pytest.raises((TypeError, ValidationError)):
d.decision = "cancel" # type: ignore[misc]
# ---------------------------------------------------------------------------
# ApprovalResponse
# ---------------------------------------------------------------------------
class TestApprovalResponse:
def test_valid_construction(self) -> None:
resp = ApprovalResponse(
thread_id="t1", status="resumed", message="Graph resumed"
)
assert resp.thread_id == "t1"
assert resp.status == "resumed"
assert resp.message == "Graph resumed"
def test_frozen_immutable(self) -> None:
resp = ApprovalResponse(thread_id="t1", status="ok", message="m")
with pytest.raises((TypeError, ValidationError)):
resp.status = "bad" # type: ignore[misc]
# ---------------------------------------------------------------------------
# PendingApproval
# ---------------------------------------------------------------------------
class TestPendingApproval:
def test_full_construction(self) -> None:
now = datetime.now(tz=timezone.utc)
pa = PendingApproval(
thread_id="t1",
graph_name="pr_completed",
interrupt_value="Confirm merge?",
created_at=now,
repo_name="my-repo",
pr_id="42",
version="v1.2.3",
)
assert pa.thread_id == "t1"
assert pa.graph_name == "pr_completed"
assert pa.repo_name == "my-repo"
assert pa.pr_id == "42"
assert pa.version == "v1.2.3"
def test_optional_fields_none(self) -> None:
now = datetime.now(tz=timezone.utc)
pa = PendingApproval(
thread_id="t1",
graph_name="release",
interrupt_value="Confirm?",
created_at=now,
)
assert pa.repo_name is None
assert pa.pr_id is None
assert pa.version is None
def test_frozen_immutable(self) -> None:
now = datetime.now(tz=timezone.utc)
pa = PendingApproval(
thread_id="t1",
graph_name="g",
interrupt_value="v",
created_at=now,
)
with pytest.raises((TypeError, ValidationError)):
pa.thread_id = "other" # type: ignore[misc]
# ---------------------------------------------------------------------------
# PendingApprovalsResponse
# ---------------------------------------------------------------------------
class TestPendingApprovalsResponse:
def test_empty_list(self) -> None:
resp = PendingApprovalsResponse(items=[], count=0)
assert resp.items == []
assert resp.count == 0
def test_with_items(self) -> None:
now = datetime.now(tz=timezone.utc)
item = PendingApproval(
thread_id="t1",
graph_name="g",
interrupt_value="v",
created_at=now,
)
resp = PendingApprovalsResponse(items=[item], count=1)
assert resp.count == 1
assert len(resp.items) == 1
def test_frozen_immutable(self) -> None:
resp = PendingApprovalsResponse(items=[], count=0)
with pytest.raises((TypeError, ValidationError)):
resp.count = 5 # type: ignore[misc]
# ---------------------------------------------------------------------------
# HealthResponse
# ---------------------------------------------------------------------------
class TestHealthResponse:
def test_ok_status(self) -> None:
resp = HealthResponse(status="ok", version="0.1.0", uptime_seconds=123.4)
assert resp.status == "ok"
assert resp.version == "0.1.0"
assert resp.uptime_seconds == pytest.approx(123.4)
def test_degraded_status(self) -> None:
resp = HealthResponse(status="degraded", version="0.1.0", uptime_seconds=0.0)
assert resp.status == "degraded"
def test_invalid_status_raises(self) -> None:
with pytest.raises(ValidationError):
HealthResponse(status="unknown", version="0.1.0", uptime_seconds=0.0) # type: ignore[arg-type]
def test_frozen_immutable(self) -> None:
resp = HealthResponse(status="ok", version="0.1.0", uptime_seconds=1.0)
with pytest.raises((TypeError, ValidationError)):
resp.status = "degraded" # type: ignore[misc]
# ---------------------------------------------------------------------------
# ReleaseVersionListResponse
# ---------------------------------------------------------------------------
class TestReleaseVersionListResponse:
def test_valid_construction(self) -> None:
resp = ReleaseVersionListResponse(repo="my-repo", versions=["v1.0.0", "v1.1.0"])
assert resp.repo == "my-repo"
assert resp.versions == ["v1.0.0", "v1.1.0"]
def test_empty_versions(self) -> None:
resp = ReleaseVersionListResponse(repo="r", versions=[])
assert resp.versions == []
def test_frozen_immutable(self) -> None:
resp = ReleaseVersionListResponse(repo="r", versions=[])
with pytest.raises((TypeError, ValidationError)):
resp.repo = "other" # type: ignore[misc]
# ---------------------------------------------------------------------------
# StagingResponse
# ---------------------------------------------------------------------------
class TestStagingResponse:
def test_with_staging(self) -> None:
staging_data = {"version": "v1.0.0", "repo": "my-repo", "tickets": []}
resp = StagingResponse(repo="my-repo", staging=staging_data)
assert resp.repo == "my-repo"
assert resp.staging is not None
assert resp.staging["version"] == "v1.0.0"
def test_without_staging(self) -> None:
resp = StagingResponse(repo="my-repo", staging=None)
assert resp.staging is None
def test_frozen_immutable(self) -> None:
resp = StagingResponse(repo="r", staging=None)
with pytest.raises((TypeError, ValidationError)):
resp.repo = "other" # type: ignore[misc]
# ---------------------------------------------------------------------------
# ManualTriggerResponse
# ---------------------------------------------------------------------------
class TestManualTriggerResponse:
def test_valid_construction(self) -> None:
resp = ManualTriggerResponse(thread_id="t1", message="triggered")
assert resp.thread_id == "t1"
assert resp.message == "triggered"
def test_frozen_immutable(self) -> None:
resp = ManualTriggerResponse(thread_id="t1", message="m")
with pytest.raises((TypeError, ValidationError)):
resp.thread_id = "other" # type: ignore[misc]
# ---------------------------------------------------------------------------
# ManualReleaseRequest
# ---------------------------------------------------------------------------
class TestManualReleaseRequest:
def test_valid_construction(self) -> None:
req = ManualReleaseRequest(repo="my-repo")
assert req.repo == "my-repo"
def test_missing_repo_raises(self) -> None:
with pytest.raises(ValidationError):
ManualReleaseRequest() # type: ignore[call-arg]
def test_frozen_immutable(self) -> None:
req = ManualReleaseRequest(repo="r")
with pytest.raises((TypeError, ValidationError)):
req.repo = "other" # type: ignore[misc]
# ---------------------------------------------------------------------------
# ErrorResponse
# ---------------------------------------------------------------------------
class TestErrorResponse:
def test_error_only(self) -> None:
resp = ErrorResponse(error="Something went wrong")
assert resp.error == "Something went wrong"
assert resp.detail is None
def test_error_with_detail(self) -> None:
resp = ErrorResponse(error="Not found", detail="Thread t1 not found")
assert resp.detail == "Thread t1 not found"
def test_frozen_immutable(self) -> None:
resp = ErrorResponse(error="e")
with pytest.raises((TypeError, ValidationError)):
resp.error = "other" # type: ignore[misc]

View File

@@ -0,0 +1,111 @@
"""Tests for operator token authentication dependency.
Phase 5 - Step 3: require_operator_token FastAPI dependency.
Written FIRST (TDD RED phase).
"""
from unittest.mock import MagicMock
import pytest
from fastapi import FastAPI, Depends, HTTPException
from fastapi.testclient import TestClient
from release_agent.api.dependencies import require_operator_token
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_app_with_token(operator_token: str = "") -> FastAPI:
"""Return a minimal app with a protected route and the given token config."""
app = FastAPI()
mock_settings = MagicMock()
mock_settings.operator_token.get_secret_value.return_value = operator_token
app.state.settings = mock_settings
@app.get("/protected")
async def protected_route(_: None = Depends(require_operator_token)):
return {"ok": True}
return app
# ---------------------------------------------------------------------------
# require_operator_token tests
# ---------------------------------------------------------------------------
class TestRequireOperatorToken:
def test_valid_token_allows_access(self) -> None:
app = _make_app_with_token("super-secret-token")
with TestClient(app) as client:
response = client.get(
"/protected",
headers={"X-Operator-Token": "super-secret-token"},
)
assert response.status_code == 200
def test_missing_token_header_returns_401_when_token_configured(self) -> None:
app = _make_app_with_token("super-secret-token")
with TestClient(app) as client:
response = client.get("/protected")
assert response.status_code == 401
def test_wrong_token_returns_401(self) -> None:
app = _make_app_with_token("super-secret-token")
with TestClient(app) as client:
response = client.get(
"/protected",
headers={"X-Operator-Token": "wrong-token"},
)
assert response.status_code == 401
def test_empty_operator_token_config_skips_auth(self) -> None:
"""When operator_token is not configured (empty), all requests pass."""
app = _make_app_with_token("")
with TestClient(app) as client:
response = client.get("/protected")
assert response.status_code == 200
def test_empty_operator_token_config_passes_even_without_header(self) -> None:
app = _make_app_with_token("")
with TestClient(app) as client:
response = client.get("/protected", headers={})
assert response.status_code == 200
def test_token_comparison_is_constant_time(self) -> None:
"""Verify hmac.compare_digest is used (not == operator) — tested by checking
that the function still works correctly, not timing (which we can't test here)."""
app = _make_app_with_token("my-token")
with TestClient(app) as client:
response = client.get(
"/protected",
headers={"X-Operator-Token": "my-token"},
)
assert response.status_code == 200
def test_empty_string_token_header_rejected_when_token_configured(self) -> None:
app = _make_app_with_token("configured-token")
with TestClient(app) as client:
response = client.get(
"/protected",
headers={"X-Operator-Token": ""},
)
assert response.status_code == 401
def test_401_response_has_detail_field(self) -> None:
app = _make_app_with_token("secret")
with TestClient(app) as client:
response = client.get("/protected")
data = response.json()
assert "detail" in data
def test_valid_token_returns_correct_response_body(self) -> None:
app = _make_app_with_token("token123")
with TestClient(app) as client:
response = client.get(
"/protected",
headers={"X-Operator-Token": "token123"},
)
assert response.json() == {"ok": True}

View File

@@ -0,0 +1,473 @@
"""Tests for api/slack_interactions.py endpoint.
Written FIRST (TDD RED phase).
Tests cover:
- Signature verification (HMAC-SHA256)
- Payload parsing
- Button routing
- _resume_graph invocation
- Error handling
"""
import hashlib
import hmac
import json
import time
import urllib.parse
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from release_agent.api.slack_interactions import router as slack_interactions_router
from release_agent.api.slack_interactions import _verify_slack_signature
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
_TEST_SIGNING_SECRET = "test-signing-secret-abc"
def _make_slack_signature(*, signing_secret: str, timestamp: str, body: str) -> str:
"""Compute a valid Slack signing signature."""
base_string = f"v0:{timestamp}:{body}"
sig = hmac.new(
signing_secret.encode(),
base_string.encode(),
hashlib.sha256,
).hexdigest()
return f"v0={sig}"
def _make_test_app(
*,
signing_secret: str = _TEST_SIGNING_SECRET,
thread_graph_name: str | None = "release",
graph_result: dict | None = None,
) -> FastAPI:
"""Return a FastAPI test app with mocked state for slack interactions."""
app = FastAPI()
app.include_router(slack_interactions_router)
mock_settings = MagicMock()
mock_settings.slack_signing_secret.get_secret_value.return_value = signing_secret
mock_settings.operator_token.get_secret_value.return_value = ""
mock_graph = MagicMock()
mock_graph.ainvoke = AsyncMock(return_value=graph_result or {"messages": ["done"]})
mock_graphs = {
"release": mock_graph,
"pr_completed": MagicMock(),
}
mock_clients = MagicMock()
mock_pool = MagicMock()
mock_conn = AsyncMock()
mock_cursor = AsyncMock()
mock_cursor.fetchone = AsyncMock(
return_value=(thread_graph_name,) if thread_graph_name else None
)
mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor)
mock_cursor.__aexit__ = AsyncMock(return_value=False)
mock_conn.cursor = MagicMock(return_value=mock_cursor)
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
mock_conn.__aexit__ = AsyncMock(return_value=False)
mock_pool.connection = MagicMock(return_value=mock_conn)
app.state.settings = mock_settings
app.state.graphs = mock_graphs
app.state.tool_clients = mock_clients
app.state.db_pool = mock_pool
app.state.background_tasks = set()
return app
def _make_button_payload(
*,
thread_id: str = "test-thread-123",
value: str = "approve",
user_id: str = "U12345",
user_name: str = "alice",
) -> str:
"""Build a URL-encoded Slack button action payload."""
payload = {
"type": "block_actions",
"user": {"id": user_id, "name": user_name},
"actions": [
{
"type": "button",
"value": f"{thread_id}:{value}",
"action_id": f"approval_{value}_{thread_id}",
}
],
}
return urllib.parse.urlencode({"payload": json.dumps(payload)})
# ---------------------------------------------------------------------------
# _verify_slack_signature pure function tests
# ---------------------------------------------------------------------------
class TestVerifySlackSignature:
"""Tests for the _verify_slack_signature pure function."""
def test_returns_true_for_valid_signature(self) -> None:
timestamp = str(int(time.time()))
body = "test=body&data=here"
sig = _make_slack_signature(
signing_secret=_TEST_SIGNING_SECRET,
timestamp=timestamp,
body=body,
)
assert _verify_slack_signature(
signing_secret=_TEST_SIGNING_SECRET,
timestamp=timestamp,
body=body,
signature=sig,
) is True
def test_returns_false_for_wrong_secret(self) -> None:
timestamp = str(int(time.time()))
body = "test=body"
sig = _make_slack_signature(
signing_secret="wrong-secret",
timestamp=timestamp,
body=body,
)
assert _verify_slack_signature(
signing_secret=_TEST_SIGNING_SECRET,
timestamp=timestamp,
body=body,
signature=sig,
) is False
def test_returns_false_for_tampered_body(self) -> None:
timestamp = str(int(time.time()))
original_body = "original=body"
sig = _make_slack_signature(
signing_secret=_TEST_SIGNING_SECRET,
timestamp=timestamp,
body=original_body,
)
assert _verify_slack_signature(
signing_secret=_TEST_SIGNING_SECRET,
timestamp=timestamp,
body="tampered=body",
signature=sig,
) is False
def test_returns_false_for_wrong_timestamp(self) -> None:
body = "test=body"
sig = _make_slack_signature(
signing_secret=_TEST_SIGNING_SECRET,
timestamp="1000000",
body=body,
)
assert _verify_slack_signature(
signing_secret=_TEST_SIGNING_SECRET,
timestamp="9999999",
body=body,
signature=sig,
) is False
def test_returns_false_for_malformed_signature(self) -> None:
timestamp = str(int(time.time()))
assert _verify_slack_signature(
signing_secret=_TEST_SIGNING_SECRET,
timestamp=timestamp,
body="body",
signature="not-a-valid-sig",
) is False
def test_returns_false_for_empty_signature(self) -> None:
timestamp = str(int(time.time()))
assert _verify_slack_signature(
signing_secret=_TEST_SIGNING_SECRET,
timestamp=timestamp,
body="body",
signature="",
) is False
def test_uses_hmac_sha256(self) -> None:
timestamp = "1234567890"
body = "payload=data"
base = f"v0:{timestamp}:{body}"
expected_hash = hmac.new(
_TEST_SIGNING_SECRET.encode(),
base.encode(),
hashlib.sha256,
).hexdigest()
sig = f"v0={expected_hash}"
# Inject current_time matching timestamp to bypass replay prevention
assert _verify_slack_signature(
signing_secret=_TEST_SIGNING_SECRET,
timestamp=timestamp,
body=body,
signature=sig,
current_time=1234567890.0,
) is True
def test_rejects_stale_timestamp(self) -> None:
old_timestamp = "1000000000" # year 2001
body = "payload=data"
base = f"v0:{old_timestamp}:{body}"
expected_hash = hmac.new(
_TEST_SIGNING_SECRET.encode(),
base.encode(),
hashlib.sha256,
).hexdigest()
sig = f"v0={expected_hash}"
# Valid signature but timestamp too old
assert _verify_slack_signature(
signing_secret=_TEST_SIGNING_SECRET,
timestamp=old_timestamp,
body=body,
signature=sig,
) is False
def test_rejects_non_integer_timestamp(self) -> None:
assert _verify_slack_signature(
signing_secret=_TEST_SIGNING_SECRET,
timestamp="not-a-number",
body="body",
signature="v0=abc",
) is False
def test_signature_prefix_must_be_v0(self) -> None:
timestamp = "1234567890"
body = "payload=data"
base = f"v0:{timestamp}:{body}"
hash_val = hmac.new(
_TEST_SIGNING_SECRET.encode(),
base.encode(),
hashlib.sha256,
).hexdigest()
wrong_prefix_sig = f"v1={hash_val}"
assert _verify_slack_signature(
signing_secret=_TEST_SIGNING_SECRET,
timestamp=timestamp,
body=body,
signature=wrong_prefix_sig,
) is False
# ---------------------------------------------------------------------------
# POST /slack/interactions endpoint tests
# ---------------------------------------------------------------------------
class TestSlackInteractionsEndpoint:
"""Tests for POST /slack/interactions."""
def test_returns_200_for_valid_request(self) -> None:
app = _make_test_app()
client = TestClient(app, raise_server_exceptions=False)
timestamp = str(int(time.time()))
body = _make_button_payload(thread_id="t-abc", value="approve")
sig = _make_slack_signature(
signing_secret=_TEST_SIGNING_SECRET,
timestamp=timestamp,
body=body,
)
response = client.post(
"/slack/interactions",
content=body,
headers={
"Content-Type": "application/x-www-form-urlencoded",
"X-Slack-Request-Timestamp": timestamp,
"X-Slack-Signature": sig,
},
)
assert response.status_code == 200
def test_returns_403_for_invalid_signature(self) -> None:
app = _make_test_app()
client = TestClient(app, raise_server_exceptions=False)
timestamp = str(int(time.time()))
body = _make_button_payload()
response = client.post(
"/slack/interactions",
content=body,
headers={
"Content-Type": "application/x-www-form-urlencoded",
"X-Slack-Request-Timestamp": timestamp,
"X-Slack-Signature": "v0=invalid_signature",
},
)
assert response.status_code == 403
def test_returns_400_when_missing_timestamp_header(self) -> None:
app = _make_test_app()
client = TestClient(app, raise_server_exceptions=False)
body = _make_button_payload()
response = client.post(
"/slack/interactions",
content=body,
headers={
"Content-Type": "application/x-www-form-urlencoded",
"X-Slack-Signature": "v0=something",
},
)
assert response.status_code in (400, 403, 422)
def test_rejects_when_signing_secret_not_configured(self) -> None:
app = _make_test_app(signing_secret="")
client = TestClient(app, raise_server_exceptions=False)
timestamp = str(int(time.time()))
body = _make_button_payload(thread_id="t-abc", value="approve")
response = client.post(
"/slack/interactions",
content=body,
headers={
"Content-Type": "application/x-www-form-urlencoded",
"X-Slack-Request-Timestamp": timestamp,
"X-Slack-Signature": "v0=any_sig",
},
)
assert response.status_code == 503
def test_returns_200_with_approve_action(self) -> None:
app = _make_test_app()
client = TestClient(app, raise_server_exceptions=False)
timestamp = str(int(time.time()))
body = _make_button_payload(thread_id="thread-1", value="approve")
sig = _make_slack_signature(
signing_secret=_TEST_SIGNING_SECRET,
timestamp=timestamp,
body=body,
)
response = client.post(
"/slack/interactions",
content=body,
headers={
"Content-Type": "application/x-www-form-urlencoded",
"X-Slack-Request-Timestamp": timestamp,
"X-Slack-Signature": sig,
},
)
assert response.status_code == 200
def test_returns_200_with_reject_action(self) -> None:
app = _make_test_app()
client = TestClient(app, raise_server_exceptions=False)
timestamp = str(int(time.time()))
body = _make_button_payload(thread_id="thread-2", value="reject")
sig = _make_slack_signature(
signing_secret=_TEST_SIGNING_SECRET,
timestamp=timestamp,
body=body,
)
response = client.post(
"/slack/interactions",
content=body,
headers={
"Content-Type": "application/x-www-form-urlencoded",
"X-Slack-Request-Timestamp": timestamp,
"X-Slack-Signature": sig,
},
)
assert response.status_code == 200
def test_schedules_graph_resume_in_background(self) -> None:
app = _make_test_app()
client = TestClient(app, raise_server_exceptions=False)
timestamp = str(int(time.time()))
body = _make_button_payload(thread_id="t-bg", value="approve")
sig = _make_slack_signature(
signing_secret=_TEST_SIGNING_SECRET,
timestamp=timestamp,
body=body,
)
response = client.post(
"/slack/interactions",
content=body,
headers={
"Content-Type": "application/x-www-form-urlencoded",
"X-Slack-Request-Timestamp": timestamp,
"X-Slack-Signature": sig,
},
)
assert response.status_code == 200
def test_returns_404_for_unknown_thread(self) -> None:
app = _make_test_app(thread_graph_name=None)
client = TestClient(app, raise_server_exceptions=False)
timestamp = str(int(time.time()))
body = _make_button_payload(thread_id="unknown-thread", value="approve")
sig = _make_slack_signature(
signing_secret=_TEST_SIGNING_SECRET,
timestamp=timestamp,
body=body,
)
response = client.post(
"/slack/interactions",
content=body,
headers={
"Content-Type": "application/x-www-form-urlencoded",
"X-Slack-Request-Timestamp": timestamp,
"X-Slack-Signature": sig,
},
)
# Should return 200 immediately (Slack requires immediate 200)
# but the background task may log an error
assert response.status_code == 200
def test_response_body_is_empty_or_ok(self) -> None:
app = _make_test_app()
client = TestClient(app, raise_server_exceptions=False)
timestamp = str(int(time.time()))
body = _make_button_payload(thread_id="t-ok", value="approve")
sig = _make_slack_signature(
signing_secret=_TEST_SIGNING_SECRET,
timestamp=timestamp,
body=body,
)
response = client.post(
"/slack/interactions",
content=body,
headers={
"Content-Type": "application/x-www-form-urlencoded",
"X-Slack-Request-Timestamp": timestamp,
"X-Slack-Signature": sig,
},
)
assert response.status_code == 200
# Body may be empty or a simple JSON with ok=True
if response.content:
data = response.json()
assert data.get("ok") is True or "ok" not in data

270
tests/api/test_status.py Normal file
View File

@@ -0,0 +1,270 @@
"""Tests for status, releases, staging, and manual trigger endpoints.
Written FIRST (TDD RED phase).
"""
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from release_agent.api.status import router as status_router
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_test_app(
*,
versions: list[str] | None = None,
staging_data: dict | None = None,
) -> FastAPI:
"""Return a FastAPI app with mocked state for status tests."""
app = FastAPI()
app.include_router(status_router)
mock_settings = MagicMock()
mock_settings.operator_token.get_secret_value.return_value = ""
mock_graphs = {
"pr_completed": MagicMock(),
"release": MagicMock(),
}
mock_clients = MagicMock()
mock_staging_store = MagicMock()
mock_staging_store.list_versions = AsyncMock(return_value=versions or [])
# staging store returns StagingRelease-like or None
if staging_data is not None:
mock_staging_obj = MagicMock()
mock_staging_obj.model_dump = MagicMock(return_value=staging_data)
mock_staging_store.load = AsyncMock(return_value=mock_staging_obj)
else:
mock_staging_store.load = AsyncMock(return_value=None)
mock_pool = MagicMock()
app.state.settings = mock_settings
app.state.graphs = mock_graphs
app.state.tool_clients = mock_clients
app.state.staging_store = mock_staging_store
app.state.db_pool = mock_pool
app.state.background_tasks = set()
app.state.started_at = datetime.now(tz=timezone.utc)
return app
# ---------------------------------------------------------------------------
# GET /status
# ---------------------------------------------------------------------------
class TestGetStatus:
def test_returns_200(self) -> None:
app = _make_test_app()
with TestClient(app) as client:
response = client.get("/status")
assert response.status_code == 200
def test_response_has_status_field(self) -> None:
app = _make_test_app()
with TestClient(app) as client:
response = client.get("/status")
data = response.json()
assert "status" in data
assert data["status"] in ("ok", "degraded")
def test_response_has_version_field(self) -> None:
app = _make_test_app()
with TestClient(app) as client:
response = client.get("/status")
data = response.json()
assert "version" in data
assert isinstance(data["version"], str)
def test_response_has_uptime_seconds(self) -> None:
app = _make_test_app()
with TestClient(app) as client:
response = client.get("/status")
data = response.json()
assert "uptime_seconds" in data
assert data["uptime_seconds"] >= 0.0
def test_status_is_ok_when_healthy(self) -> None:
app = _make_test_app()
with TestClient(app) as client:
response = client.get("/status")
assert response.json()["status"] == "ok"
# ---------------------------------------------------------------------------
# GET /releases/{repo}
# ---------------------------------------------------------------------------
class TestGetReleaseVersions:
def test_returns_200(self) -> None:
app = _make_test_app(versions=["v1.0.0", "v1.1.0"])
with TestClient(app) as client:
response = client.get("/releases/my-repo")
assert response.status_code == 200
def test_response_has_repo_and_versions(self) -> None:
app = _make_test_app(versions=["v1.0.0", "v1.1.0"])
with TestClient(app) as client:
response = client.get("/releases/my-repo")
data = response.json()
assert data["repo"] == "my-repo"
assert data["versions"] == ["v1.0.0", "v1.1.0"]
def test_empty_versions_list(self) -> None:
app = _make_test_app(versions=[])
with TestClient(app) as client:
response = client.get("/releases/unknown-repo")
data = response.json()
assert data["versions"] == []
def test_repo_name_in_path_used(self) -> None:
mock_staging_store = MagicMock()
mock_staging_store.list_versions = AsyncMock(return_value=[])
app = _make_test_app()
app.state.staging_store = mock_staging_store
with TestClient(app) as client:
client.get("/releases/specific-repo")
mock_staging_store.list_versions.assert_called_once_with("specific-repo")
# ---------------------------------------------------------------------------
# GET /staging
# ---------------------------------------------------------------------------
class TestGetStaging:
def test_returns_200_with_staging(self) -> None:
staging_data = {"version": "v1.0.0", "repo": "my-repo", "tickets": []}
app = _make_test_app(staging_data=staging_data)
with TestClient(app) as client:
response = client.get("/staging?repo=my-repo")
assert response.status_code == 200
def test_response_has_repo_and_staging(self) -> None:
staging_data = {"version": "v1.0.0", "repo": "my-repo", "tickets": []}
app = _make_test_app(staging_data=staging_data)
with TestClient(app) as client:
response = client.get("/staging?repo=my-repo")
data = response.json()
assert data["repo"] == "my-repo"
assert data["staging"] is not None
assert data["staging"]["version"] == "v1.0.0"
def test_returns_null_staging_when_not_found(self) -> None:
app = _make_test_app(staging_data=None)
with TestClient(app) as client:
response = client.get("/staging?repo=no-staging-repo")
assert response.status_code == 200
data = response.json()
assert data["staging"] is None
def test_missing_repo_query_returns_422(self) -> None:
app = _make_test_app()
with TestClient(app) as client:
response = client.get("/staging")
assert response.status_code == 422
# ---------------------------------------------------------------------------
# POST /manual/pr/{pr_id}
# ---------------------------------------------------------------------------
class TestManualPrTrigger:
def test_returns_202(self) -> None:
app = _make_test_app()
with patch(
"release_agent.api.status.asyncio.create_task", return_value=MagicMock()
):
with TestClient(app) as client:
response = client.post("/manual/pr/42")
assert response.status_code == 202
def test_response_has_thread_id(self) -> None:
app = _make_test_app()
with patch(
"release_agent.api.status.asyncio.create_task", return_value=MagicMock()
):
with TestClient(app) as client:
response = client.post("/manual/pr/42")
data = response.json()
assert "thread_id" in data
assert isinstance(data["thread_id"], str)
def test_response_has_message(self) -> None:
app = _make_test_app()
with patch(
"release_agent.api.status.asyncio.create_task", return_value=MagicMock()
):
with TestClient(app) as client:
response = client.post("/manual/pr/42")
assert "message" in response.json()
def test_schedules_background_task(self) -> None:
app = _make_test_app()
with patch(
"release_agent.api.status.asyncio.create_task", return_value=MagicMock()
) as mock_create:
with TestClient(app) as client:
client.post("/manual/pr/99")
mock_create.assert_called_once()
# ---------------------------------------------------------------------------
# POST /manual/release
# ---------------------------------------------------------------------------
class TestManualReleaseTrigger:
def test_returns_202(self) -> None:
app = _make_test_app()
with patch(
"release_agent.api.status.asyncio.create_task", return_value=MagicMock()
):
with TestClient(app) as client:
response = client.post(
"/manual/release",
json={"repo": "my-repo"},
)
assert response.status_code == 202
def test_response_has_thread_id(self) -> None:
app = _make_test_app()
with patch(
"release_agent.api.status.asyncio.create_task", return_value=MagicMock()
):
with TestClient(app) as client:
response = client.post(
"/manual/release",
json={"repo": "my-repo"},
)
data = response.json()
assert "thread_id" in data
def test_missing_repo_returns_422(self) -> None:
app = _make_test_app()
with TestClient(app) as client:
response = client.post(
"/manual/release",
json={},
)
assert response.status_code == 422
def test_schedules_background_task(self) -> None:
app = _make_test_app()
with patch(
"release_agent.api.status.asyncio.create_task", return_value=MagicMock()
) as mock_create:
with TestClient(app) as client:
client.post(
"/manual/release",
json={"repo": "my-repo"},
)
mock_create.assert_called_once()

View File

@@ -0,0 +1,166 @@
"""Tests for status/manual endpoints with operator token authentication.
Phase 5 - Step 3: Verifies that POST /manual/* require operator token
when configured. GET endpoints are not protected.
Written FIRST (TDD RED phase).
"""
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from release_agent.api.status import router as status_router
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_test_app(operator_token: str = "") -> FastAPI:
"""Return a FastAPI app with status router and configurable operator token."""
app = FastAPI()
app.include_router(status_router)
mock_settings = MagicMock()
mock_settings.operator_token.get_secret_value.return_value = operator_token
mock_staging_store = MagicMock()
mock_staging_store.list_versions = AsyncMock(return_value=[])
mock_staging_store.load = AsyncMock(return_value=None)
mock_pool = MagicMock()
app.state.settings = mock_settings
app.state.graphs = {
"pr_completed": MagicMock(),
"release": MagicMock(),
}
app.state.tool_clients = MagicMock()
app.state.staging_store = mock_staging_store
app.state.db_pool = mock_pool
app.state.background_tasks = set()
app.state.started_at = datetime.now(tz=timezone.utc)
return app
# ---------------------------------------------------------------------------
# POST /manual/pr/{pr_id} with auth
# ---------------------------------------------------------------------------
class TestManualPrTriggerWithAuth:
def test_valid_token_allows_manual_pr(self) -> None:
app = _make_test_app(operator_token="secure-token")
with patch(
"release_agent.api.status.asyncio.create_task", return_value=MagicMock()
):
with TestClient(app) as client:
response = client.post(
"/manual/pr/42",
headers={"X-Operator-Token": "secure-token"},
)
assert response.status_code == 202
def test_missing_token_rejects_manual_pr(self) -> None:
app = _make_test_app(operator_token="secure-token")
with TestClient(app) as client:
response = client.post("/manual/pr/42")
assert response.status_code == 401
def test_wrong_token_rejects_manual_pr(self) -> None:
app = _make_test_app(operator_token="secure-token")
with TestClient(app) as client:
response = client.post(
"/manual/pr/42",
headers={"X-Operator-Token": "bad-token"},
)
assert response.status_code == 401
def test_no_auth_required_when_token_not_configured(self) -> None:
app = _make_test_app(operator_token="")
with patch(
"release_agent.api.status.asyncio.create_task", return_value=MagicMock()
):
with TestClient(app) as client:
response = client.post("/manual/pr/42")
assert response.status_code == 202
# ---------------------------------------------------------------------------
# POST /manual/release with auth
# ---------------------------------------------------------------------------
class TestManualReleaseTriggerWithAuth:
def test_valid_token_allows_manual_release(self) -> None:
app = _make_test_app(operator_token="secure-token")
with patch(
"release_agent.api.status.asyncio.create_task", return_value=MagicMock()
):
with TestClient(app) as client:
response = client.post(
"/manual/release",
json={"repo": "my-repo"},
headers={"X-Operator-Token": "secure-token"},
)
assert response.status_code == 202
def test_missing_token_rejects_manual_release(self) -> None:
app = _make_test_app(operator_token="secure-token")
with TestClient(app) as client:
response = client.post(
"/manual/release",
json={"repo": "my-repo"},
)
assert response.status_code == 401
def test_wrong_token_rejects_manual_release(self) -> None:
app = _make_test_app(operator_token="secure-token")
with TestClient(app) as client:
response = client.post(
"/manual/release",
json={"repo": "my-repo"},
headers={"X-Operator-Token": "wrong"},
)
assert response.status_code == 401
def test_no_auth_required_when_token_not_configured(self) -> None:
app = _make_test_app(operator_token="")
with patch(
"release_agent.api.status.asyncio.create_task", return_value=MagicMock()
):
with TestClient(app) as client:
response = client.post(
"/manual/release",
json={"repo": "my-repo"},
)
assert response.status_code == 202
# ---------------------------------------------------------------------------
# GET /status, /releases, /staging do NOT require auth
# ---------------------------------------------------------------------------
class TestReadEndpointsNoAuth:
def test_get_status_no_token_needed(self) -> None:
"""GET /status should never require auth."""
app = _make_test_app(operator_token="super-secret")
with TestClient(app) as client:
response = client.get("/status")
assert response.status_code == 200
def test_get_releases_no_token_needed(self) -> None:
app = _make_test_app(operator_token="super-secret")
with TestClient(app) as client:
response = client.get("/releases/my-repo")
assert response.status_code == 200
def test_get_staging_no_token_needed(self) -> None:
app = _make_test_app(operator_token="super-secret")
with TestClient(app) as client:
response = client.get("/staging?repo=my-repo")
assert response.status_code == 200

218
tests/api/test_webhooks.py Normal file
View File

@@ -0,0 +1,218 @@
"""Tests for webhook endpoint. Written FIRST (TDD RED phase)."""
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from release_agent.api.webhooks import (
_validate_webhook_secret,
router as webhook_router,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
VALID_SECRET = "super-secret-webhook-key"
_COMPLETED_PR_PAYLOAD = {
"subscription_id": "sub-1",
"event_type": "git.pullrequest.updated",
"resource": {
"repository": {
"id": "repo-1",
"name": "my-repo",
"web_url": "https://dev.azure.com/org/project/_git/my-repo",
},
"pull_request_id": 42,
"title": "feat: add feature",
"source_ref_name": "refs/heads/feature/BILL-123-add-feature",
"target_ref_name": "refs/heads/main",
"status": "completed",
"closed_date": "2024-01-15T10:00:00Z",
},
}
_ACTIVE_PR_PAYLOAD = {
"subscription_id": "sub-2",
"event_type": "git.pullrequest.updated",
"resource": {
"repository": {
"id": "repo-1",
"name": "my-repo",
"web_url": "https://dev.azure.com/org/project/_git/my-repo",
},
"pull_request_id": 43,
"title": "WIP: work in progress",
"source_ref_name": "refs/heads/feature/BILL-456",
"target_ref_name": "refs/heads/main",
"status": "active",
"closed_date": None,
},
}
def _make_test_app(webhook_secret: str = VALID_SECRET) -> FastAPI:
"""Return a FastAPI app with mocked state for webhook tests."""
app = FastAPI()
app.include_router(webhook_router)
mock_settings = MagicMock()
mock_settings.webhook_secret.get_secret_value.return_value = webhook_secret
mock_graphs = {
"pr_completed": MagicMock(),
"release": MagicMock(),
}
mock_clients = MagicMock()
mock_pool = MagicMock()
# background_tasks set tracked on state
app.state.settings = mock_settings
app.state.graphs = mock_graphs
app.state.tool_clients = mock_clients
app.state.db_pool = mock_pool
app.state.background_tasks = set()
return app
# ---------------------------------------------------------------------------
# _validate_webhook_secret (unit tests, pure function)
# ---------------------------------------------------------------------------
class TestValidateWebhookSecret:
def test_valid_secret_returns_true(self) -> None:
assert _validate_webhook_secret("mysecret", "mysecret") is True
def test_wrong_secret_returns_false(self) -> None:
assert _validate_webhook_secret("wrong", "mysecret") is False
def test_empty_header_returns_false(self) -> None:
assert _validate_webhook_secret("", "mysecret") is False
def test_none_header_returns_false(self) -> None:
assert _validate_webhook_secret(None, "mysecret") is False # type: ignore[arg-type]
def test_uses_constant_time_comparison(self) -> None:
# Should not raise even for very different lengths
assert _validate_webhook_secret("a", "very-long-secret-value") is False
def test_empty_expected_rejects_all(self) -> None:
# Empty expected secret = auth misconfigured, reject everything
assert _validate_webhook_secret("", "") is False
assert _validate_webhook_secret("any-value", "") is False
assert _validate_webhook_secret(None, "") is False
# ---------------------------------------------------------------------------
# POST /webhooks/azdo — integration tests via TestClient
# ---------------------------------------------------------------------------
class TestWebhookEndpoint:
def test_valid_completed_pr_returns_202(self) -> None:
app = _make_test_app()
with patch(
"release_agent.api.webhooks.asyncio.create_task", return_value=MagicMock()
):
with TestClient(app, raise_server_exceptions=True) as client:
response = client.post(
"/webhooks/azdo",
json=_COMPLETED_PR_PAYLOAD,
headers={"X-Webhook-Secret": VALID_SECRET},
)
assert response.status_code == 202
data = response.json()
assert "thread_id" in data
assert "message" in data
def test_missing_secret_header_returns_401(self) -> None:
app = _make_test_app()
with TestClient(app) as client:
response = client.post(
"/webhooks/azdo",
json=_COMPLETED_PR_PAYLOAD,
)
assert response.status_code == 401
def test_wrong_secret_header_returns_401(self) -> None:
app = _make_test_app()
with TestClient(app) as client:
response = client.post(
"/webhooks/azdo",
json=_COMPLETED_PR_PAYLOAD,
headers={"X-Webhook-Secret": "wrong-secret"},
)
assert response.status_code == 401
def test_invalid_payload_returns_422(self) -> None:
app = _make_test_app()
with TestClient(app) as client:
response = client.post(
"/webhooks/azdo",
json={"invalid": "payload"},
headers={"X-Webhook-Secret": VALID_SECRET},
)
assert response.status_code == 422
def test_active_pr_event_returns_200_ignored(self) -> None:
app = _make_test_app()
with TestClient(app) as client:
response = client.post(
"/webhooks/azdo",
json=_ACTIVE_PR_PAYLOAD,
headers={"X-Webhook-Secret": VALID_SECRET},
)
assert response.status_code == 200
data = response.json()
assert "ignored" in data.get("message", "").lower() or "ignored" in str(data).lower()
def test_completed_pr_thread_id_is_string(self) -> None:
app = _make_test_app()
with patch(
"release_agent.api.webhooks.asyncio.create_task", return_value=MagicMock()
):
with TestClient(app) as client:
response = client.post(
"/webhooks/azdo",
json=_COMPLETED_PR_PAYLOAD,
headers={"X-Webhook-Secret": VALID_SECRET},
)
assert response.status_code == 202
assert isinstance(response.json()["thread_id"], str)
def test_completed_pr_schedules_background_task(self) -> None:
app = _make_test_app()
task_mock = MagicMock()
with patch(
"release_agent.api.webhooks.asyncio.create_task", return_value=task_mock
) as mock_create:
with TestClient(app) as client:
client.post(
"/webhooks/azdo",
json=_COMPLETED_PR_PAYLOAD,
headers={"X-Webhook-Secret": VALID_SECRET},
)
mock_create.assert_called_once()
# ---------------------------------------------------------------------------
# Error response shape
# ---------------------------------------------------------------------------
class TestWebhookErrorShape:
def test_401_has_detail_field(self) -> None:
app = _make_test_app()
with TestClient(app) as client:
response = client.post(
"/webhooks/azdo",
json=_COMPLETED_PR_PAYLOAD,
)
assert response.status_code == 401
# FastAPI HTTPException returns {"detail": ...}
assert "detail" in response.json()

0
tests/graph/__init__.py Normal file
View File

44
tests/graph/conftest.py Normal file
View File

@@ -0,0 +1,44 @@
"""Shared fixtures for graph tests.
Provides build_mock_clients() to create ToolClients with AsyncMock fields
so individual node functions can be tested without compiling the full graph.
"""
from unittest.mock import AsyncMock, MagicMock
import pytest
from release_agent.graph.dependencies import ToolClients
def build_mock_clients() -> ToolClients:
"""Return a ToolClients instance whose fields are all AsyncMock/MagicMock."""
azdo = AsyncMock()
jira = AsyncMock()
slack = AsyncMock()
reviewer = AsyncMock()
return ToolClients(azdo=azdo, jira=jira, slack=slack, reviewer=reviewer)
def build_config(clients: ToolClients | None = None, staging_store=None) -> dict:
"""Return a LangGraph-style config dict with clients and staging_store."""
if clients is None:
clients = build_mock_clients()
return {
"configurable": {
"clients": clients,
"staging_store": staging_store,
}
}
@pytest.fixture()
def mock_clients() -> ToolClients:
"""Pytest fixture returning fresh mock ToolClients."""
return build_mock_clients()
@pytest.fixture()
def config(mock_clients: ToolClients):
"""Pytest fixture returning a config dict with mock clients."""
return build_config(mock_clients)

View File

@@ -0,0 +1,294 @@
"""Tests for graph/ci_nodes.py.
Written FIRST (TDD RED phase).
All external calls (azdo, slack, poll_until) are mocked.
"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from release_agent.graph.ci_nodes import notify_ci_result, poll_ci_build, trigger_ci_build
from release_agent.models.build import BuildStatus
from release_agent.models.pipeline import PipelineInfo
from tests.graph.conftest import build_config, build_mock_clients
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_pipeline(pipeline_id: int = 10, name: str = "CI-build") -> dict:
return {"id": pipeline_id, "name": name, "repo": "my-repo"}
# ---------------------------------------------------------------------------
# trigger_ci_build
# ---------------------------------------------------------------------------
class TestTriggerCiBuild:
"""Tests for trigger_ci_build node."""
async def test_triggers_pipeline_on_branch(self) -> None:
clients = build_mock_clients()
clients.azdo.list_build_pipelines.return_value = [
PipelineInfo(id=10, name="CI", repo="my-repo")
]
clients.azdo.trigger_pipeline.return_value = {"id": 555, "state": "inProgress"}
config = build_config(clients)
state = {"repo_name": "my-repo", "version": "v1.0.0"}
result = await trigger_ci_build(state, config)
clients.azdo.trigger_pipeline.assert_called_once()
assert "ci_build_id" in result
assert result["ci_build_id"] == 555
async def test_returns_ci_build_id(self) -> None:
clients = build_mock_clients()
clients.azdo.list_build_pipelines.return_value = [
PipelineInfo(id=20, name="build-and-test", repo="my-repo")
]
clients.azdo.trigger_pipeline.return_value = {"id": 999}
config = build_config(clients)
state = {"repo_name": "my-repo", "version": "v2.0.0"}
result = await trigger_ci_build(state, config)
assert result["ci_build_id"] == 999
async def test_appends_error_when_no_pipelines_found(self) -> None:
clients = build_mock_clients()
clients.azdo.list_build_pipelines.return_value = []
config = build_config(clients)
state = {"repo_name": "my-repo", "version": "v1.0.0"}
result = await trigger_ci_build(state, config)
assert "errors" in result
assert len(result["errors"]) >= 1
async def test_appends_error_on_trigger_failure(self) -> None:
from release_agent.exceptions import ServiceError
clients = build_mock_clients()
clients.azdo.list_build_pipelines.return_value = [
PipelineInfo(id=10, name="CI", repo="my-repo")
]
clients.azdo.trigger_pipeline.side_effect = ServiceError(
service="azdo", status_code=500, detail="Internal error"
)
config = build_config(clients)
state = {"repo_name": "my-repo", "version": "v1.0.0"}
result = await trigger_ci_build(state, config)
assert "errors" in result
async def test_uses_main_branch_when_no_version(self) -> None:
clients = build_mock_clients()
clients.azdo.list_build_pipelines.return_value = [
PipelineInfo(id=10, name="CI", repo="my-repo")
]
clients.azdo.trigger_pipeline.return_value = {"id": 1}
config = build_config(clients)
state = {"repo_name": "my-repo"}
result = await trigger_ci_build(state, config)
call_kwargs = clients.azdo.trigger_pipeline.call_args[1]
branch = call_kwargs.get("branch", "")
assert "main" in branch or "refs/heads" in branch
async def test_appends_message_on_success(self) -> None:
clients = build_mock_clients()
clients.azdo.list_build_pipelines.return_value = [
PipelineInfo(id=10, name="CI", repo="my-repo")
]
clients.azdo.trigger_pipeline.return_value = {"id": 123}
config = build_config(clients)
state = {"repo_name": "my-repo", "version": "v1.0.0"}
result = await trigger_ci_build(state, config)
assert "messages" in result
assert len(result["messages"]) >= 1
# ---------------------------------------------------------------------------
# poll_ci_build
# ---------------------------------------------------------------------------
class TestPollCiBuild:
"""Tests for poll_ci_build node."""
async def test_returns_ci_build_status_and_result_on_completion(self) -> None:
clients = build_mock_clients()
completed_status = BuildStatus(status="completed", result="succeeded", build_url="https://build/1")
config = build_config(clients)
state = {"ci_build_id": 42, "repo_name": "my-repo"}
with patch(
"release_agent.graph.ci_nodes.poll_until",
return_value=(completed_status, True),
):
result = await poll_ci_build(state, config)
assert result["ci_build_status"] == "completed"
assert result["ci_build_result"] == "succeeded"
async def test_returns_build_url(self) -> None:
clients = build_mock_clients()
completed_status = BuildStatus(
status="completed",
result="succeeded",
build_url="https://dev.azure.com/build/42",
)
config = build_config(clients)
state = {"ci_build_id": 42, "repo_name": "my-repo"}
with patch(
"release_agent.graph.ci_nodes.poll_until",
return_value=(completed_status, True),
):
result = await poll_ci_build(state, config)
assert result.get("ci_build_url") == "https://dev.azure.com/build/42"
async def test_appends_error_on_timeout(self) -> None:
clients = build_mock_clients()
running_status = BuildStatus(status="inProgress", result=None, build_url=None)
config = build_config(clients)
state = {"ci_build_id": 42, "repo_name": "my-repo"}
with patch(
"release_agent.graph.ci_nodes.poll_until",
return_value=(running_status, False),
):
result = await poll_ci_build(state, config)
assert "errors" in result
async def test_appends_error_when_build_id_missing(self) -> None:
clients = build_mock_clients()
config = build_config(clients)
state = {"repo_name": "my-repo"} # no ci_build_id
result = await poll_ci_build(state, config)
assert "errors" in result
async def test_passes_correct_build_id_to_poll_fn(self) -> None:
clients = build_mock_clients()
clients.azdo.get_build_status.return_value = BuildStatus(
status="completed", result="succeeded", build_url=None
)
config = build_config(clients)
state = {"ci_build_id": 77, "repo_name": "my-repo"}
async def fake_poll_until(*, poll_fn, is_done, interval_seconds, max_wait_seconds, sleep_fn=None):
result = await poll_fn()
return result, True
with patch("release_agent.graph.ci_nodes.poll_until", side_effect=fake_poll_until):
await poll_ci_build(state, config)
clients.azdo.get_build_status.assert_called_once_with(build_id=77)
async def test_result_none_when_poll_returns_none(self) -> None:
clients = build_mock_clients()
config = build_config(clients)
state = {"ci_build_id": 42, "repo_name": "my-repo"}
with patch(
"release_agent.graph.ci_nodes.poll_until",
return_value=(None, False),
):
result = await poll_ci_build(state, config)
assert "errors" in result
# ---------------------------------------------------------------------------
# notify_ci_result
# ---------------------------------------------------------------------------
class TestNotifyCiResult:
"""Tests for notify_ci_result node."""
async def test_sends_notification_on_success(self) -> None:
clients = build_mock_clients()
clients.slack.send_notification.return_value = True
config = build_config(clients)
state = {
"repo_name": "my-repo",
"ci_build_status": "completed",
"ci_build_result": "succeeded",
"ci_build_url": "https://build/99",
}
result = await notify_ci_result(state, config)
clients.slack.send_notification.assert_called_once()
assert "messages" in result
async def test_sends_notification_on_failure(self) -> None:
clients = build_mock_clients()
clients.slack.send_notification.return_value = True
config = build_config(clients)
state = {
"repo_name": "my-repo",
"ci_build_status": "completed",
"ci_build_result": "failed",
"ci_build_url": None,
}
result = await notify_ci_result(state, config)
clients.slack.send_notification.assert_called_once()
async def test_handles_slack_error_gracefully(self) -> None:
from release_agent.exceptions import ServiceError
clients = build_mock_clients()
clients.slack.send_notification.side_effect = ServiceError(
service="slack", status_code=500, detail="Slack error"
)
config = build_config(clients)
state = {
"repo_name": "my-repo",
"ci_build_result": "succeeded",
"ci_build_url": None,
}
result = await notify_ci_result(state, config)
# Should not re-raise; should append error
assert "errors" in result
async def test_includes_repo_name_in_message(self) -> None:
clients = build_mock_clients()
clients.slack.send_notification.return_value = True
config = build_config(clients)
state = {
"repo_name": "super-service",
"ci_build_result": "succeeded",
"ci_build_url": None,
}
await notify_ci_result(state, config)
call_kwargs = clients.slack.send_notification.call_args[1]
text_or_blocks = str(call_kwargs)
assert "super-service" in text_or_blocks
async def test_returns_empty_dict_when_state_has_no_data(self) -> None:
clients = build_mock_clients()
clients.slack.send_notification.return_value = True
config = build_config(clients)
state = {}
result = await notify_ci_result(state, config)
# Should not crash; may return messages or empty dict
assert isinstance(result, dict)

View File

@@ -0,0 +1,283 @@
"""Tests for graph/dependencies.py. Written FIRST (TDD RED phase).
Covers:
- ToolClients frozen dataclass
- StagingStore Protocol (structural check)
- JsonFileStagingStore file I/O operations
"""
import json
from datetime import date
from pathlib import Path
from unittest.mock import AsyncMock
import pytest
from release_agent.graph.dependencies import JsonFileStagingStore, StagingStore, ToolClients
from release_agent.models.release import ArchivedRelease, StagingRelease
from release_agent.models.ticket import TicketEntry
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_ticket(ticket_id: str = "ALLPOST-1") -> TicketEntry:
return TicketEntry(
id=ticket_id,
summary="Fix something",
pr_id="42",
pr_url="https://dev.azure.com/org/proj/_git/repo/pullrequest/42",
pr_title="Fix: something",
branch=f"feature/{ticket_id}-fix",
merged_at=date(2025, 1, 15),
)
def _make_staging(repo: str = "my-repo", version: str = "v1.0.0") -> StagingRelease:
return StagingRelease(
version=version,
repo=repo,
started_at=date(2025, 1, 1),
tickets=[],
)
# ---------------------------------------------------------------------------
# ToolClients tests
# ---------------------------------------------------------------------------
class TestToolClients:
"""Tests for the ToolClients frozen dataclass."""
def test_can_be_constructed_with_all_fields(self) -> None:
azdo = AsyncMock()
jira = AsyncMock()
slack = AsyncMock()
reviewer = AsyncMock()
clients = ToolClients(azdo=azdo, jira=jira, slack=slack, reviewer=reviewer)
assert clients.azdo is azdo
assert clients.jira is jira
assert clients.slack is slack
assert clients.reviewer is reviewer
def test_is_frozen_cannot_reassign_field(self) -> None:
clients = ToolClients(
azdo=AsyncMock(), jira=AsyncMock(), slack=AsyncMock(), reviewer=AsyncMock()
)
with pytest.raises((AttributeError, TypeError)):
clients.azdo = AsyncMock() # type: ignore[misc]
def test_fields_are_accessible_by_name(self) -> None:
azdo = object()
clients = ToolClients(
azdo=azdo, jira=object(), slack=object(), reviewer=object()
)
assert clients.azdo is azdo
def test_equality_for_same_instances(self) -> None:
azdo = AsyncMock()
jira = AsyncMock()
slack = AsyncMock()
reviewer = AsyncMock()
c1 = ToolClients(azdo=azdo, jira=jira, slack=slack, reviewer=reviewer)
c2 = ToolClients(azdo=azdo, jira=jira, slack=slack, reviewer=reviewer)
assert c1 == c2
# ---------------------------------------------------------------------------
# StagingStore Protocol structural tests
# ---------------------------------------------------------------------------
class TestStagingStoreProtocol:
"""Verify that the Protocol is structurally correct."""
def test_json_file_store_satisfies_protocol(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
# runtime_checkable would need @runtime_checkable; check duck-typing instead
assert hasattr(store, "load")
assert hasattr(store, "save")
assert hasattr(store, "archive")
assert hasattr(store, "list_versions")
def test_protocol_is_importable(self) -> None:
# Just import-level check
assert StagingStore is not None
# ---------------------------------------------------------------------------
# JsonFileStagingStore tests
# ---------------------------------------------------------------------------
class TestJsonFileStagingStore:
"""Tests for JsonFileStagingStore using tmp_path for file I/O."""
# ------------------------------------------------------------------
# load
# ------------------------------------------------------------------
async def test_load_returns_none_when_file_missing(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
result = await store.load("nonexistent-repo")
assert result is None
async def test_load_returns_staging_release_after_save(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
staging = _make_staging()
await store.save(staging)
loaded = await store.load("my-repo")
assert loaded is not None
assert loaded.version == "v1.0.0"
assert loaded.repo == "my-repo"
async def test_load_returns_staging_with_tickets(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
staging = _make_staging().add_ticket(_make_ticket("BILL-10"))
await store.save(staging)
loaded = await store.load("my-repo")
assert loaded is not None
assert len(loaded.tickets) == 1
assert loaded.tickets[0].id == "BILL-10"
async def test_load_is_read_only_does_not_mutate_stored(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
staging = _make_staging()
await store.save(staging)
loaded1 = await store.load("my-repo")
loaded2 = await store.load("my-repo")
assert loaded1 is not loaded2 # fresh objects each time
# ------------------------------------------------------------------
# save
# ------------------------------------------------------------------
async def test_save_creates_file_in_directory(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
staging = _make_staging(repo="api-service")
await store.save(staging)
expected_path = tmp_path / "api-service.json"
assert expected_path.exists()
async def test_save_overwrites_existing_file(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
staging_v1 = _make_staging(version="v1.0.0")
staging_v2 = _make_staging(version="v1.0.1")
await store.save(staging_v1)
await store.save(staging_v2)
loaded = await store.load("my-repo")
assert loaded is not None
assert loaded.version == "v1.0.1"
async def test_save_writes_valid_json(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
staging = _make_staging()
await store.save(staging)
raw = (tmp_path / "my-repo.json").read_text()
data = json.loads(raw)
assert data["version"] == "v1.0.0"
assert data["repo"] == "my-repo"
async def test_save_does_not_mutate_staging_release(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
staging = _make_staging()
original_tickets = list(staging.tickets)
await store.save(staging)
assert list(staging.tickets) == original_tickets
# ------------------------------------------------------------------
# archive
# ------------------------------------------------------------------
async def test_archive_removes_staging_file(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
staging = _make_staging()
await store.save(staging)
await store.archive(staging, date(2025, 6, 1))
assert await store.load("my-repo") is None
async def test_archive_creates_archive_file(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
staging = _make_staging(repo="my-repo", version="v1.0.0")
await store.save(staging)
await store.archive(staging, date(2025, 6, 1))
archive_path = tmp_path / "my-repo_v1.0.0_2025-06-01.json"
assert archive_path.exists()
async def test_archive_file_contains_released_at(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
staging = _make_staging()
await store.save(staging)
release_date = date(2025, 6, 1)
await store.archive(staging, release_date)
archive_path = tmp_path / "my-repo_v1.0.0_2025-06-01.json"
data = json.loads(archive_path.read_text())
assert data["released_at"] == "2025-06-01"
async def test_archive_without_prior_save_creates_archive(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
staging = _make_staging()
await store.archive(staging, date(2025, 6, 1))
archive_path = tmp_path / "my-repo_v1.0.0_2025-06-01.json"
assert archive_path.exists()
# ------------------------------------------------------------------
# list_versions
# ------------------------------------------------------------------
async def test_list_versions_empty_directory(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
versions = await store.list_versions("my-repo")
assert versions == []
async def test_list_versions_returns_version_from_staging_file(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
await store.save(_make_staging(version="v2.1.0"))
versions = await store.list_versions("my-repo")
assert "v2.1.0" in versions
async def test_list_versions_includes_archived_versions(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
staging = _make_staging(version="v1.5.0")
await store.save(staging)
await store.archive(staging, date(2025, 3, 1))
# Now save a new staging for the same repo
await store.save(_make_staging(version="v1.6.0"))
versions = await store.list_versions("my-repo")
assert "v1.5.0" in versions
assert "v1.6.0" in versions
async def test_list_versions_only_returns_versions_for_given_repo(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
await store.save(_make_staging(repo="repo-a", version="v1.0.0"))
await store.save(_make_staging(repo="repo-b", version="v2.0.0"))
versions_a = await store.list_versions("repo-a")
assert "v1.0.0" in versions_a
# repo-b version should not appear in repo-a's list
assert "v2.0.0" not in versions_a
async def test_list_versions_no_duplicates(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
await store.save(_make_staging(version="v1.0.0"))
versions = await store.list_versions("my-repo")
assert len(versions) == len(set(versions))
async def test_list_versions_multiple_archives(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
for i in range(3):
staging = _make_staging(version=f"v1.0.{i}")
await store.archive(staging, date(2025, 1, i + 1))
versions = await store.list_versions("my-repo")
assert len(versions) == 3
assert "v1.0.0" in versions
assert "v1.0.1" in versions
assert "v1.0.2" in versions
# ------------------------------------------------------------------
# directory creation
# ------------------------------------------------------------------
def test_store_creates_directory_if_not_exists(self, tmp_path: Path) -> None:
new_dir = tmp_path / "staging_data"
assert not new_dir.exists()
JsonFileStagingStore(directory=new_dir)
assert new_dir.exists()

View File

@@ -0,0 +1,177 @@
"""Tests for async StagingStore protocol and async JsonFileStagingStore.
Phase 5 - Step 1: All StagingStore methods become async def.
Written FIRST (TDD RED phase).
"""
import json
from datetime import date
from pathlib import Path
from unittest.mock import AsyncMock
import pytest
from release_agent.graph.dependencies import JsonFileStagingStore, StagingStore, ToolClients
from release_agent.models.release import StagingRelease
from release_agent.models.ticket import TicketEntry
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_ticket(ticket_id: str = "ALLPOST-1") -> TicketEntry:
return TicketEntry(
id=ticket_id,
summary="Fix something",
pr_id="42",
pr_url="https://dev.azure.com/org/proj/_git/repo/pullrequest/42",
pr_title="Fix: something",
branch=f"feature/{ticket_id}-fix",
merged_at=date(2025, 1, 15),
)
def _make_staging(repo: str = "my-repo", version: str = "v1.0.0") -> StagingRelease:
return StagingRelease(
version=version,
repo=repo,
started_at=date(2025, 1, 1),
tickets=[],
)
# ---------------------------------------------------------------------------
# Protocol: all methods must be async
# ---------------------------------------------------------------------------
class TestStagingStoreProtocolIsAsync:
"""Verify that StagingStore protocol methods are async-compatible."""
def test_protocol_has_load_method(self) -> None:
assert hasattr(StagingStore, "load")
def test_protocol_has_save_method(self) -> None:
assert hasattr(StagingStore, "save")
def test_protocol_has_archive_method(self) -> None:
assert hasattr(StagingStore, "archive")
def test_protocol_has_list_versions_method(self) -> None:
assert hasattr(StagingStore, "list_versions")
# ---------------------------------------------------------------------------
# JsonFileStagingStore async interface
# ---------------------------------------------------------------------------
class TestJsonFileStagingStoreAsync:
"""Verify that JsonFileStagingStore methods are awaitable (async def)."""
async def test_load_is_awaitable(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
result = await store.load("nonexistent-repo")
assert result is None
async def test_load_returns_staging_release_after_save(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
staging = _make_staging()
await store.save(staging)
loaded = await store.load("my-repo")
assert loaded is not None
assert loaded.version == "v1.0.0"
assert loaded.repo == "my-repo"
async def test_load_returns_staging_with_tickets(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
staging = _make_staging().add_ticket(_make_ticket("BILL-10"))
await store.save(staging)
loaded = await store.load("my-repo")
assert loaded is not None
assert len(loaded.tickets) == 1
assert loaded.tickets[0].id == "BILL-10"
async def test_load_returns_fresh_objects(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
staging = _make_staging()
await store.save(staging)
loaded1 = await store.load("my-repo")
loaded2 = await store.load("my-repo")
assert loaded1 is not loaded2
async def test_save_is_awaitable(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
staging = _make_staging(repo="api-service")
await store.save(staging)
expected_path = tmp_path / "api-service.json"
assert expected_path.exists()
async def test_save_overwrites_existing_file(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
await store.save(_make_staging(version="v1.0.0"))
await store.save(_make_staging(version="v1.0.1"))
loaded = await store.load("my-repo")
assert loaded is not None
assert loaded.version == "v1.0.1"
async def test_save_writes_valid_json(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
staging = _make_staging()
await store.save(staging)
raw = (tmp_path / "my-repo.json").read_text()
data = json.loads(raw)
assert data["version"] == "v1.0.0"
assert data["repo"] == "my-repo"
async def test_archive_is_awaitable(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
staging = _make_staging()
await store.save(staging)
await store.archive(staging, date(2025, 6, 1))
assert await store.load("my-repo") is None
async def test_archive_creates_archive_file(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
staging = _make_staging(repo="my-repo", version="v1.0.0")
await store.save(staging)
await store.archive(staging, date(2025, 6, 1))
archive_path = tmp_path / "my-repo_v1.0.0_2025-06-01.json"
assert archive_path.exists()
async def test_archive_file_contains_released_at(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
staging = _make_staging()
await store.save(staging)
await store.archive(staging, date(2025, 6, 1))
archive_path = tmp_path / "my-repo_v1.0.0_2025-06-01.json"
data = json.loads(archive_path.read_text())
assert data["released_at"] == "2025-06-01"
async def test_list_versions_is_awaitable(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
versions = await store.list_versions("my-repo")
assert versions == []
async def test_list_versions_returns_staging_version(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
await store.save(_make_staging(version="v2.1.0"))
versions = await store.list_versions("my-repo")
assert "v2.1.0" in versions
async def test_list_versions_includes_archived(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
staging = _make_staging(version="v1.5.0")
await store.save(staging)
await store.archive(staging, date(2025, 3, 1))
await store.save(_make_staging(version="v1.6.0"))
versions = await store.list_versions("my-repo")
assert "v1.5.0" in versions
assert "v1.6.0" in versions
async def test_list_versions_only_for_given_repo(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
await store.save(_make_staging(repo="repo-a", version="v1.0.0"))
await store.save(_make_staging(repo="repo-b", version="v2.0.0"))
versions_a = await store.list_versions("repo-a")
assert "v1.0.0" in versions_a
assert "v2.0.0" not in versions_a

View File

@@ -0,0 +1,53 @@
"""Tests for graph/full_cycle.py.
Tests that the full cycle graph composes pr_completed and release subgraphs
correctly, and that the routing conditional edge works as expected.
"""
from release_agent.graph.full_cycle import build_full_cycle_graph
from release_agent.graph.routing import should_continue_to_release
class TestBuildFullCycleGraph:
def test_returns_compiled_graph(self) -> None:
graph = build_full_cycle_graph()
assert graph is not None
def test_graph_can_be_built_multiple_times(self) -> None:
graph1 = build_full_cycle_graph()
graph2 = build_full_cycle_graph()
assert graph1 is not None
assert graph2 is not None
def test_graph_has_get_graph_method(self) -> None:
graph = build_full_cycle_graph()
assert hasattr(graph, "get_graph") or hasattr(graph, "nodes")
class TestFullCycleRouting:
"""Test that the routing function used by full_cycle correctly
determines whether to continue to the release subgraph."""
def test_continue_when_flag_true_and_no_errors(self) -> None:
state = {"continue_to_release": True, "errors": []}
assert should_continue_to_release(state) == "yes"
def test_stop_when_flag_false(self) -> None:
state = {"continue_to_release": False}
assert should_continue_to_release(state) == "no"
def test_stop_when_flag_missing(self) -> None:
state = {}
assert should_continue_to_release(state) == "no"
def test_stop_when_errors_present(self) -> None:
state = {"continue_to_release": True, "errors": ["some error"]}
assert should_continue_to_release(state) == "no"
def test_stop_when_flag_true_but_errors_present(self) -> None:
state = {"continue_to_release": True, "errors": ["critical failure"]}
assert should_continue_to_release(state) == "no"
def test_continue_when_errors_empty_list(self) -> None:
state = {"continue_to_release": True, "errors": []}
assert should_continue_to_release(state) == "yes"

356
tests/graph/test_polling.py Normal file
View File

@@ -0,0 +1,356 @@
"""Tests for graph/polling.py — poll_until async utility.
Written FIRST (TDD RED phase).
All tests inject a fake_sleep_fn that returns immediately to avoid real waits.
"""
import asyncio
from unittest.mock import AsyncMock, call
import pytest
from release_agent.graph.polling import poll_until
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
async def _immediate_sleep(seconds: float) -> None:
"""Drop-in replacement for asyncio.sleep that returns immediately."""
return
# ---------------------------------------------------------------------------
# Success path tests
# ---------------------------------------------------------------------------
class TestPollUntilSuccess:
"""Tests for the happy path where poll_fn succeeds before timeout."""
async def test_returns_tuple_of_result_and_completed_true(self) -> None:
calls = iter(["running", "running", "completed"])
async def poll_fn():
return next(calls)
result, completed = await poll_until(
poll_fn=poll_fn,
is_done=lambda r: r == "completed",
interval_seconds=1,
max_wait_seconds=60,
sleep_fn=_immediate_sleep,
)
assert result == "completed"
assert completed is True
async def test_returns_immediately_when_already_done(self) -> None:
async def poll_fn():
return "completed"
result, completed = await poll_until(
poll_fn=poll_fn,
is_done=lambda r: r == "completed",
interval_seconds=1,
max_wait_seconds=60,
sleep_fn=_immediate_sleep,
)
assert result == "completed"
assert completed is True
async def test_polls_multiple_times_before_done(self) -> None:
call_count = 0
async def poll_fn():
nonlocal call_count
call_count += 1
return "done" if call_count >= 3 else "pending"
result, completed = await poll_until(
poll_fn=poll_fn,
is_done=lambda r: r == "done",
interval_seconds=1,
max_wait_seconds=60,
sleep_fn=_immediate_sleep,
)
assert result == "done"
assert completed is True
assert call_count == 3
async def test_sleep_called_between_polls(self) -> None:
call_count = 0
sleep_calls: list[float] = []
async def poll_fn():
nonlocal call_count
call_count += 1
return "done" if call_count >= 2 else "pending"
async def tracking_sleep(seconds: float) -> None:
sleep_calls.append(seconds)
await poll_until(
poll_fn=poll_fn,
is_done=lambda r: r == "done",
interval_seconds=15,
max_wait_seconds=60,
sleep_fn=tracking_sleep,
)
assert len(sleep_calls) >= 1
assert all(s == 15 for s in sleep_calls)
async def test_no_sleep_on_first_successful_poll(self) -> None:
sleep_calls: list[float] = []
async def poll_fn():
return "done"
async def tracking_sleep(seconds: float) -> None:
sleep_calls.append(seconds)
await poll_until(
poll_fn=poll_fn,
is_done=lambda r: r == "done",
interval_seconds=10,
max_wait_seconds=60,
sleep_fn=tracking_sleep,
)
assert sleep_calls == []
async def test_works_with_dict_results(self) -> None:
responses = iter([
{"status": "inProgress"},
{"status": "completed", "result": "succeeded"},
])
async def poll_fn():
return next(responses)
result, completed = await poll_until(
poll_fn=poll_fn,
is_done=lambda r: r["status"] == "completed",
interval_seconds=1,
max_wait_seconds=60,
sleep_fn=_immediate_sleep,
)
assert result["result"] == "succeeded"
assert completed is True
# ---------------------------------------------------------------------------
# Timeout tests
# ---------------------------------------------------------------------------
class TestPollUntilTimeout:
"""Tests for timeout behavior."""
async def test_returns_last_result_and_completed_false_on_timeout(self) -> None:
async def poll_fn():
return "still_running"
# With interval=10, max_wait=5, it should time out after one poll
result, completed = await poll_until(
poll_fn=poll_fn,
is_done=lambda r: r == "done",
interval_seconds=10,
max_wait_seconds=5,
sleep_fn=_immediate_sleep,
)
assert result == "still_running"
assert completed is False
async def test_at_least_one_poll_happens_before_timeout(self) -> None:
call_count = 0
async def poll_fn():
nonlocal call_count
call_count += 1
return "running"
await poll_until(
poll_fn=poll_fn,
is_done=lambda r: r == "done",
interval_seconds=100,
max_wait_seconds=1,
sleep_fn=_immediate_sleep,
)
assert call_count >= 1
async def test_max_polls_bounded_by_max_wait_over_interval(self) -> None:
call_count = 0
async def poll_fn():
nonlocal call_count
call_count += 1
return "running"
await poll_until(
poll_fn=poll_fn,
is_done=lambda r: False,
interval_seconds=10,
max_wait_seconds=30,
sleep_fn=_immediate_sleep,
)
# With interval=10, max_wait=30: should poll at most ceil(30/10)+1 = 4 times
assert call_count <= 5
# ---------------------------------------------------------------------------
# Error handling tests
# ---------------------------------------------------------------------------
class TestPollUntilErrorHandling:
"""Tests for error/exception handling in poll_until."""
async def test_continues_after_transient_exception(self) -> None:
call_count = 0
async def poll_fn():
nonlocal call_count
call_count += 1
if call_count < 3:
raise RuntimeError("Transient error")
return "done"
result, completed = await poll_until(
poll_fn=poll_fn,
is_done=lambda r: r == "done",
interval_seconds=1,
max_wait_seconds=60,
sleep_fn=_immediate_sleep,
)
assert result == "done"
assert completed is True
async def test_aborts_after_three_consecutive_failures(self) -> None:
call_count = 0
async def poll_fn():
nonlocal call_count
call_count += 1
raise RuntimeError("Persistent error")
result, completed = await poll_until(
poll_fn=poll_fn,
is_done=lambda r: True,
interval_seconds=1,
max_wait_seconds=60,
sleep_fn=_immediate_sleep,
)
# Should abort after 3 consecutive failures
assert call_count == 3
assert completed is False
assert result is None
async def test_resets_consecutive_failure_count_on_success(self) -> None:
call_count = 0
async def poll_fn():
nonlocal call_count
call_count += 1
# Fail twice, succeed once, fail twice, succeed (done)
if call_count in (1, 2):
raise RuntimeError("fail")
if call_count == 3:
return "running"
if call_count in (4, 5):
raise RuntimeError("fail again")
return "done"
result, completed = await poll_until(
poll_fn=poll_fn,
is_done=lambda r: r == "done",
interval_seconds=1,
max_wait_seconds=120,
sleep_fn=_immediate_sleep,
)
assert completed is True
assert result == "done"
async def test_single_exception_does_not_abort(self) -> None:
call_count = 0
async def poll_fn():
nonlocal call_count
call_count += 1
if call_count == 1:
raise ValueError("one error")
return "done"
result, completed = await poll_until(
poll_fn=poll_fn,
is_done=lambda r: r == "done",
interval_seconds=1,
max_wait_seconds=60,
sleep_fn=_immediate_sleep,
)
assert completed is True
assert result == "done"
async def test_two_consecutive_failures_do_not_abort(self) -> None:
call_count = 0
async def poll_fn():
nonlocal call_count
call_count += 1
if call_count <= 2:
raise ConnectionError("two errors")
return "done"
result, completed = await poll_until(
poll_fn=poll_fn,
is_done=lambda r: r == "done",
interval_seconds=1,
max_wait_seconds=60,
sleep_fn=_immediate_sleep,
)
assert completed is True
assert result == "done"
# ---------------------------------------------------------------------------
# Default parameter tests
# ---------------------------------------------------------------------------
class TestPollUntilDefaults:
"""Tests that default parameters match the spec."""
async def test_default_interval_is_30_seconds(self) -> None:
sleep_calls: list[float] = []
async def poll_fn():
return "done" if len(sleep_calls) >= 1 else "running"
async def tracking_sleep(seconds: float) -> None:
sleep_calls.append(seconds)
await poll_until(
poll_fn=poll_fn,
is_done=lambda r: r == "done",
sleep_fn=tracking_sleep,
)
if sleep_calls:
assert sleep_calls[0] == 30
async def test_poll_fn_and_is_done_are_keyword_only(self) -> None:
"""poll_fn and is_done must be passed as keyword arguments."""
async def poll_fn():
return "done"
with pytest.raises(TypeError):
await poll_until(poll_fn, lambda r: r == "done") # type: ignore[call-arg]

View File

@@ -0,0 +1,414 @@
"""Tests for PostgresStagingStore.
Phase 5 - Step 2: PostgreSQL-backed StagingStore using async pool.
Written FIRST (TDD RED phase).
All tests use FakeAsyncPool — no real PostgreSQL required.
"""
import json
from datetime import date
from unittest.mock import AsyncMock, MagicMock
import pytest
from release_agent.graph.postgres_staging_store import PostgresStagingStore
from release_agent.models.release import ArchivedRelease, StagingRelease
from release_agent.models.ticket import TicketEntry
# ---------------------------------------------------------------------------
# Fake pool infrastructure
# ---------------------------------------------------------------------------
class FakeAsyncCursor:
"""Records SQL calls and returns configured results."""
def __init__(self) -> None:
self.executed: list[tuple[str, tuple]] = []
self._fetchone_result: tuple | None = None
self._fetchall_result: list[tuple] = []
def set_fetchone(self, row: tuple | None) -> None:
self._fetchone_result = row
def set_fetchall(self, rows: list[tuple]) -> None:
self._fetchall_result = rows
async def execute(self, sql: str, params: tuple = ()) -> None:
self.executed.append((sql, params))
async def fetchone(self) -> tuple | None:
return self._fetchone_result
async def fetchall(self) -> list[tuple]:
return self._fetchall_result
async def __aenter__(self):
return self
async def __aexit__(self, *args):
pass
class FakeAsyncTransaction:
"""Fake async transaction context manager (no-op)."""
async def __aenter__(self):
return self
async def __aexit__(self, *args):
pass
class FakeAsyncConnection:
"""Async context manager returning a FakeAsyncCursor."""
def __init__(self, cursor: FakeAsyncCursor) -> None:
self._cursor = cursor
def cursor(self):
return self._cursor
def transaction(self):
return FakeAsyncTransaction()
async def __aenter__(self):
return self
async def __aexit__(self, *args):
pass
class FakeAsyncPool:
"""Records all SQL executed through it."""
def __init__(self, cursor: FakeAsyncCursor) -> None:
self._cursor = cursor
self._conn = FakeAsyncConnection(cursor)
def connection(self):
return self._conn
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_ticket(ticket_id: str = "ALLPOST-1") -> TicketEntry:
return TicketEntry(
id=ticket_id,
summary="Fix something",
pr_id="42",
pr_url="https://dev.azure.com/org/proj/_git/repo/pullrequest/42",
pr_title="Fix: something",
branch=f"feature/{ticket_id}-fix",
merged_at=date(2025, 1, 15),
)
def _make_staging(
repo: str = "my-repo",
version: str = "v1.0.0",
tickets: list | None = None,
) -> StagingRelease:
t = tickets if tickets is not None else []
return StagingRelease(
version=version,
repo=repo,
started_at=date(2025, 1, 1),
tickets=t,
)
def _staging_row(staging: StagingRelease) -> tuple:
"""Return (repo, version, started_at, tickets_json) as DB would store it."""
return (
staging.repo,
staging.version,
staging.started_at.isoformat(),
json.dumps([t.model_dump(mode="json") for t in staging.tickets]),
)
# ---------------------------------------------------------------------------
# load()
# ---------------------------------------------------------------------------
class TestPostgresStagingStoreLoad:
async def test_load_returns_none_when_no_row(self) -> None:
cursor = FakeAsyncCursor()
cursor.set_fetchone(None)
pool = FakeAsyncPool(cursor)
store = PostgresStagingStore(pool=pool)
result = await store.load("nonexistent-repo")
assert result is None
async def test_load_returns_staging_release_when_row_exists(self) -> None:
staging = _make_staging(repo="api-service", version="v2.0.0")
cursor = FakeAsyncCursor()
cursor.set_fetchone((
staging.repo,
staging.version,
staging.started_at.isoformat(),
json.dumps([]),
))
pool = FakeAsyncPool(cursor)
store = PostgresStagingStore(pool=pool)
result = await store.load("api-service")
assert result is not None
assert isinstance(result, StagingRelease)
assert result.repo == "api-service"
assert result.version == "v2.0.0"
async def test_load_returns_staging_with_tickets(self) -> None:
ticket = _make_ticket("BILL-42")
staging = _make_staging(tickets=[ticket])
cursor = FakeAsyncCursor()
cursor.set_fetchone((
staging.repo,
staging.version,
staging.started_at.isoformat(),
json.dumps([ticket.model_dump(mode="json")]),
))
pool = FakeAsyncPool(cursor)
store = PostgresStagingStore(pool=pool)
result = await store.load("my-repo")
assert result is not None
assert len(result.tickets) == 1
assert result.tickets[0].id == "BILL-42"
async def test_load_executes_select_with_correct_repo(self) -> None:
cursor = FakeAsyncCursor()
cursor.set_fetchone(None)
pool = FakeAsyncPool(cursor)
store = PostgresStagingStore(pool=pool)
await store.load("target-repo")
assert len(cursor.executed) >= 1
sql, params = cursor.executed[-1]
assert "SELECT" in sql.upper()
assert "target-repo" in params
async def test_load_queries_staging_releases_table(self) -> None:
cursor = FakeAsyncCursor()
cursor.set_fetchone(None)
pool = FakeAsyncPool(cursor)
store = PostgresStagingStore(pool=pool)
await store.load("my-repo")
sql, _ = cursor.executed[-1]
assert "staging_releases" in sql
# ---------------------------------------------------------------------------
# save()
# ---------------------------------------------------------------------------
class TestPostgresStagingStoreSave:
async def test_save_executes_upsert(self) -> None:
cursor = FakeAsyncCursor()
pool = FakeAsyncPool(cursor)
store = PostgresStagingStore(pool=pool)
staging = _make_staging()
await store.save(staging)
assert len(cursor.executed) >= 1
sql, _ = cursor.executed[-1]
# Should be an INSERT ... ON CONFLICT ... or UPSERT
assert "INSERT" in sql.upper() or "UPSERT" in sql.upper()
async def test_save_passes_repo_to_upsert(self) -> None:
cursor = FakeAsyncCursor()
pool = FakeAsyncPool(cursor)
store = PostgresStagingStore(pool=pool)
staging = _make_staging(repo="payment-service")
await store.save(staging)
_, params = cursor.executed[-1]
assert "payment-service" in params
async def test_save_passes_version_to_upsert(self) -> None:
cursor = FakeAsyncCursor()
pool = FakeAsyncPool(cursor)
store = PostgresStagingStore(pool=pool)
staging = _make_staging(version="v3.1.0")
await store.save(staging)
_, params = cursor.executed[-1]
assert "v3.1.0" in params
async def test_save_targets_staging_releases_table(self) -> None:
cursor = FakeAsyncCursor()
pool = FakeAsyncPool(cursor)
store = PostgresStagingStore(pool=pool)
staging = _make_staging()
await store.save(staging)
sql, _ = cursor.executed[-1]
assert "staging_releases" in sql
async def test_save_serializes_tickets_as_json(self) -> None:
cursor = FakeAsyncCursor()
pool = FakeAsyncPool(cursor)
store = PostgresStagingStore(pool=pool)
staging = _make_staging(tickets=[_make_ticket("ALLPOST-99")])
await store.save(staging)
_, params = cursor.executed[-1]
# tickets param should be a JSON string containing the ticket id
tickets_json = next(p for p in params if isinstance(p, str) and "ALLPOST-99" in p)
parsed = json.loads(tickets_json)
assert parsed[0]["id"] == "ALLPOST-99"
# ---------------------------------------------------------------------------
# archive()
# ---------------------------------------------------------------------------
class TestPostgresStagingStoreArchive:
async def test_archive_inserts_into_archived_releases(self) -> None:
cursor = FakeAsyncCursor()
pool = FakeAsyncPool(cursor)
store = PostgresStagingStore(pool=pool)
staging = _make_staging()
await store.archive(staging, date(2025, 6, 1))
sql_statements = [sql for sql, _ in cursor.executed]
assert any("archived_releases" in sql for sql in sql_statements)
async def test_archive_deletes_from_staging_releases(self) -> None:
cursor = FakeAsyncCursor()
pool = FakeAsyncPool(cursor)
store = PostgresStagingStore(pool=pool)
staging = _make_staging(repo="my-repo")
await store.archive(staging, date(2025, 6, 1))
sql_statements = [sql for sql, _ in cursor.executed]
assert any("DELETE" in sql.upper() and "staging_releases" in sql for sql in sql_statements)
async def test_archive_passes_released_at_date(self) -> None:
cursor = FakeAsyncCursor()
pool = FakeAsyncPool(cursor)
store = PostgresStagingStore(pool=pool)
staging = _make_staging()
release_date = date(2025, 12, 31)
await store.archive(staging, release_date)
all_params = [params for _, params in cursor.executed]
all_values = [v for params in all_params for v in params]
assert "2025-12-31" in all_values or release_date.isoformat() in all_values
async def test_archive_passes_repo_to_delete(self) -> None:
cursor = FakeAsyncCursor()
pool = FakeAsyncPool(cursor)
store = PostgresStagingStore(pool=pool)
staging = _make_staging(repo="payment-service")
await store.archive(staging, date(2025, 6, 1))
delete_calls = [(sql, params) for sql, params in cursor.executed if "DELETE" in sql.upper()]
assert len(delete_calls) >= 1
_, params = delete_calls[0]
assert "payment-service" in params
# ---------------------------------------------------------------------------
# list_versions()
# ---------------------------------------------------------------------------
class TestPostgresStagingStoreListVersions:
async def test_list_versions_returns_empty_when_no_data(self) -> None:
cursor = FakeAsyncCursor()
cursor.set_fetchone(None)
cursor.set_fetchall([])
pool = FakeAsyncPool(cursor)
store = PostgresStagingStore(pool=pool)
versions = await store.list_versions("my-repo")
assert versions == []
async def test_list_versions_includes_staging_version(self) -> None:
cursor = FakeAsyncCursor()
# fetchone returns staging row
cursor.set_fetchone(("my-repo", "v1.0.0", "2025-01-01", "[]"))
# fetchall returns archived rows
cursor.set_fetchall([])
pool = FakeAsyncPool(cursor)
store = PostgresStagingStore(pool=pool)
versions = await store.list_versions("my-repo")
assert "v1.0.0" in versions
async def test_list_versions_includes_archived_versions(self) -> None:
cursor = FakeAsyncCursor()
cursor.set_fetchone(None)
cursor.set_fetchall([
("my-repo", "v1.0.0", "2025-01-01", "[]", "2025-06-01"),
("my-repo", "v1.1.0", "2025-02-01", "[]", "2025-07-01"),
])
pool = FakeAsyncPool(cursor)
store = PostgresStagingStore(pool=pool)
versions = await store.list_versions("my-repo")
assert "v1.0.0" in versions
assert "v1.1.0" in versions
async def test_list_versions_combines_staging_and_archived(self) -> None:
cursor = FakeAsyncCursor()
cursor.set_fetchone(("my-repo", "v2.0.0", "2025-03-01", "[]"))
cursor.set_fetchall([
("my-repo", "v1.0.0", "2025-01-01", "[]", "2025-02-01"),
])
pool = FakeAsyncPool(cursor)
store = PostgresStagingStore(pool=pool)
versions = await store.list_versions("my-repo")
assert "v2.0.0" in versions
assert "v1.0.0" in versions
async def test_list_versions_no_duplicates(self) -> None:
cursor = FakeAsyncCursor()
cursor.set_fetchone(("my-repo", "v1.0.0", "2025-01-01", "[]"))
cursor.set_fetchall([
("my-repo", "v1.0.0", "2025-01-01", "[]", "2025-02-01"),
])
pool = FakeAsyncPool(cursor)
store = PostgresStagingStore(pool=pool)
versions = await store.list_versions("my-repo")
assert len(versions) == len(set(versions))
async def test_list_versions_executes_queries_for_correct_repo(self) -> None:
cursor = FakeAsyncCursor()
cursor.set_fetchone(None)
cursor.set_fetchall([])
pool = FakeAsyncPool(cursor)
store = PostgresStagingStore(pool=pool)
await store.list_versions("target-repo")
all_params = [params for _, params in cursor.executed]
all_values = [v for params in all_params for v in params]
assert "target-repo" in all_values

View File

@@ -0,0 +1,956 @@
"""Tests for graph/pr_completed.py node functions. Written FIRST (TDD RED phase).
Each node is an async function (state, config) -> dict.
Tests call nodes directly with a state dict and config dict — no graph compilation.
"""
from datetime import date, datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from release_agent.graph.dependencies import JsonFileStagingStore, ToolClients
from release_agent.graph.pr_completed import (
_post_review_to_pr,
add_jira_pr_link,
auto_create_ticket,
calculate_version,
evaluate_review,
fetch_pr_details,
interrupt_confirm_merge,
merge_pr_node,
move_jira_code_review,
move_jira_ready_for_stage,
notify_request_changes,
parse_webhook,
run_code_review,
update_staging,
build_pr_completed_graph,
)
from release_agent.models.review import ReviewIssue
from release_agent.models.jira import JiraIssue
from release_agent.models.pr import PRInfo
from release_agent.models.review import ReviewResult
from tests.graph.conftest import build_config, build_mock_clients
# ---------------------------------------------------------------------------
# Webhook payload fixtures
# ---------------------------------------------------------------------------
def _make_webhook_payload(
*,
repo_name: str = "my-repo",
pr_id: int = 42,
source_ref: str = "refs/heads/feature/ALLPOST-100_fix-bug",
target_ref: str = "refs/heads/main",
status: str = "completed",
title: str = "Fix: bug",
closed_date: str | None = "2025-01-15T10:00:00Z",
) -> dict:
# Uses snake_case keys to match WebhookPayload Pydantic model field names
return {
"subscription_id": "sub-1",
"event_type": "git.pullrequest.merged",
"resource": {
"repository": {
"id": "repo-id-1",
"name": repo_name,
"web_url": "https://dev.azure.com/org/proj/_git/my-repo",
},
"pull_request_id": pr_id,
"title": title,
"source_ref_name": source_ref,
"target_ref_name": target_ref,
"status": status,
"closed_date": closed_date,
},
}
def _make_pr_info(
*,
pr_id: str = "42",
repo_name: str = "my-repo",
branch: str = "refs/heads/feature/ALLPOST-100-fix-bug",
status: str = "completed",
) -> PRInfo:
return PRInfo(
pr_id=pr_id,
pr_url="https://dev.azure.com/org/proj/_git/my-repo/pullrequest/42",
repo_name=repo_name,
branch=branch,
pr_title="Fix: bug",
pr_status=status,
)
def _make_approve_review() -> dict:
return {
"verdict": "approve",
"summary": "Looks good",
"issues": [],
"has_blockers": False,
}
def _make_request_changes_review() -> dict:
return {
"verdict": "request_changes",
"summary": "Needs work",
"issues": [{"severity": "blocker", "description": "Missing tests"}],
"has_blockers": True,
}
# ---------------------------------------------------------------------------
# parse_webhook
# ---------------------------------------------------------------------------
class TestParseWebhook:
async def test_extracts_pr_info_from_payload(self) -> None:
state = {"webhook_payload": _make_webhook_payload()}
config = build_config()
result = await parse_webhook(state, config)
assert "pr_info" in result
pr = result["pr_info"]
assert pr["pr_id"] == "42"
assert pr["repo_name"] == "my-repo"
async def test_extracts_ticket_from_branch(self) -> None:
state = {"webhook_payload": _make_webhook_payload(
source_ref="refs/heads/feature/ALLPOST-100_fix-bug"
)}
config = build_config()
result = await parse_webhook(state, config)
assert result["ticket_id"] == "ALLPOST-100"
assert result["has_ticket"] is True
async def test_no_ticket_when_branch_has_none(self) -> None:
state = {"webhook_payload": _make_webhook_payload(
source_ref="refs/heads/bugfix/generic_fix"
)}
config = build_config()
result = await parse_webhook(state, config)
assert result["has_ticket"] is False
assert result["ticket_id"] is None
async def test_sets_repo_name(self) -> None:
state = {"webhook_payload": _make_webhook_payload(repo_name="backend-api")}
config = build_config()
result = await parse_webhook(state, config)
assert result["repo_name"] == "backend-api"
async def test_sets_pr_id_as_string(self) -> None:
state = {"webhook_payload": _make_webhook_payload(pr_id=99)}
config = build_config()
result = await parse_webhook(state, config)
assert result["pr_info"]["pr_id"] == "99"
async def test_invalid_payload_adds_error(self) -> None:
state = {"webhook_payload": {"bad": "data"}}
config = build_config()
result = await parse_webhook(state, config)
assert "errors" in result
assert len(result["errors"]) > 0
# ---------------------------------------------------------------------------
# fetch_pr_details
# ---------------------------------------------------------------------------
class TestFetchPrDetails:
async def test_fetches_pr_and_sets_pr_already_merged_false(self) -> None:
clients = build_mock_clients()
pr = _make_pr_info(status="active")
clients.azdo.get_pr = AsyncMock(return_value=pr)
clients.azdo.get_pr_diff = AsyncMock(return_value="edit: main.py")
config = build_config(clients)
state = {"pr_id": "42", "pr_info": {"pr_id": "42", "pr_status": "active"}}
result = await fetch_pr_details(state, config)
assert result["pr_already_merged"] is False
assert result["pr_diff"] == "edit: main.py"
async def test_sets_pr_already_merged_true_when_completed(self) -> None:
clients = build_mock_clients()
pr = _make_pr_info(status="completed")
clients.azdo.get_pr = AsyncMock(return_value=pr)
clients.azdo.get_pr_diff = AsyncMock(return_value="")
config = build_config(clients)
state = {"pr_id": "42", "pr_info": {"pr_id": "42", "pr_status": "completed"}}
result = await fetch_pr_details(state, config)
assert result["pr_already_merged"] is True
async def test_stores_last_merge_source_commit(self) -> None:
clients = build_mock_clients()
pr = _make_pr_info(status="active")
clients.azdo.get_pr = AsyncMock(return_value=pr)
clients.azdo.get_pr_diff = AsyncMock(return_value="edit: main.py")
config = build_config(clients)
state = {"pr_id": "42", "pr_info": {"pr_id": "42"}}
result = await fetch_pr_details(state, config)
# last_merge_source_commit may be None if pr doesn't have it, but key must be present
assert "last_merge_source_commit" in result
async def test_adds_error_on_service_failure(self) -> None:
from release_agent.exceptions import ServiceError
clients = build_mock_clients()
clients.azdo.get_pr = AsyncMock(side_effect=ServiceError(
service="azdo", status_code=500, detail="Server error"
))
config = build_config(clients)
state = {"pr_id": "42"}
result = await fetch_pr_details(state, config)
assert "errors" in result
assert len(result["errors"]) > 0
# ---------------------------------------------------------------------------
# move_jira_code_review
# ---------------------------------------------------------------------------
class TestMoveJiraCodeReview:
async def test_transitions_ticket_when_has_ticket(self) -> None:
clients = build_mock_clients()
clients.jira.transition_issue = AsyncMock(return_value=True)
config = build_config(clients)
state = {"ticket_id": "ALLPOST-100", "has_ticket": True}
result = await move_jira_code_review(state, config)
clients.jira.transition_issue.assert_called_once_with("ALLPOST-100", "code review")
assert result == {} or "messages" in result
async def test_skips_when_no_ticket(self) -> None:
clients = build_mock_clients()
clients.jira.transition_issue = AsyncMock(return_value=True)
config = build_config(clients)
state = {"has_ticket": False, "ticket_id": None}
result = await move_jira_code_review(state, config)
clients.jira.transition_issue.assert_not_called()
async def test_appends_error_on_jira_failure(self) -> None:
from release_agent.exceptions import ServiceError
clients = build_mock_clients()
clients.jira.transition_issue = AsyncMock(side_effect=ServiceError(
service="jira", status_code=500, detail="Jira down"
))
config = build_config(clients)
state = {"ticket_id": "ALLPOST-100", "has_ticket": True}
result = await move_jira_code_review(state, config)
assert "errors" in result
assert len(result["errors"]) > 0
# ---------------------------------------------------------------------------
# run_code_review
# ---------------------------------------------------------------------------
class TestRunCodeReview:
async def test_calls_reviewer_with_diff(self) -> None:
clients = build_mock_clients()
review = ReviewResult(verdict="approve", summary="LGTM", issues=())
clients.reviewer.review_pr = AsyncMock(return_value=review)
config = build_config(clients)
state = {
"pr_diff": "edit: main.py",
"pr_info": {"pr_title": "Fix: bug", "repo_name": "my-repo"},
}
result = await run_code_review(state, config)
clients.reviewer.review_pr.assert_called_once()
assert "review_result" in result
async def test_stores_review_result_as_dict(self) -> None:
clients = build_mock_clients()
review = ReviewResult(verdict="approve", summary="Clean code", issues=())
clients.reviewer.review_pr = AsyncMock(return_value=review)
config = build_config(clients)
state = {
"pr_diff": "edit: main.py",
"pr_info": {"pr_title": "Fix", "repo_name": "repo"},
}
result = await run_code_review(state, config)
assert result["review_result"]["verdict"] == "approve"
async def test_adds_error_on_reviewer_failure(self) -> None:
clients = build_mock_clients()
clients.reviewer.review_pr = AsyncMock(side_effect=Exception("API error"))
config = build_config(clients)
state = {
"pr_diff": "edit: main.py",
"pr_info": {"pr_title": "Fix", "repo_name": "repo"},
}
result = await run_code_review(state, config)
assert "errors" in result
# ---------------------------------------------------------------------------
# _post_review_to_pr
# ---------------------------------------------------------------------------
class TestPostReviewToPr:
async def test_posts_summary_comment(self) -> None:
clients = build_mock_clients()
clients.azdo.add_pr_comment = AsyncMock()
clients.azdo.add_pr_inline_comment = AsyncMock()
review = ReviewResult(verdict="approve", summary="LGTM", issues=())
await _post_review_to_pr(clients, "my-repo", 42, review)
clients.azdo.add_pr_comment.assert_called_once()
call_kwargs = clients.azdo.add_pr_comment.call_args
assert "APPROVE" in call_kwargs.kwargs["content"]
async def test_posts_inline_comment_for_issue_with_file_and_line(self) -> None:
clients = build_mock_clients()
clients.azdo.add_pr_comment = AsyncMock()
clients.azdo.add_pr_inline_comment = AsyncMock()
issue = ReviewIssue(
severity="error", description="Null check missing",
file_path="src/Foo.cs", line_start=42, suggestion="Add null guard",
)
review = ReviewResult(verdict="request_changes", summary="Issues", issues=(issue,))
await _post_review_to_pr(clients, "my-repo", 42, review)
clients.azdo.add_pr_inline_comment.assert_called_once()
call_kwargs = clients.azdo.add_pr_inline_comment.call_args.kwargs
assert call_kwargs["file_path"] == "src/Foo.cs"
assert call_kwargs["line_start"] == 42
assert "Null check missing" in call_kwargs["content"]
assert "Add null guard" in call_kwargs["content"]
async def test_skips_inline_for_issue_without_line(self) -> None:
clients = build_mock_clients()
clients.azdo.add_pr_comment = AsyncMock()
clients.azdo.add_pr_inline_comment = AsyncMock()
issue = ReviewIssue(severity="warning", description="Style issue", file_path="src/Foo.cs")
review = ReviewResult(verdict="approve", summary="OK", issues=(issue,))
await _post_review_to_pr(clients, "my-repo", 42, review)
clients.azdo.add_pr_inline_comment.assert_not_called()
async def test_skips_inline_for_issue_without_file(self) -> None:
clients = build_mock_clients()
clients.azdo.add_pr_comment = AsyncMock()
clients.azdo.add_pr_inline_comment = AsyncMock()
issue = ReviewIssue(severity="info", description="General note", line_start=10)
review = ReviewResult(verdict="approve", summary="OK", issues=(issue,))
await _post_review_to_pr(clients, "my-repo", 42, review)
clients.azdo.add_pr_inline_comment.assert_not_called()
async def test_inline_failure_does_not_prevent_summary(self) -> None:
clients = build_mock_clients()
clients.azdo.add_pr_comment = AsyncMock()
clients.azdo.add_pr_inline_comment = AsyncMock(side_effect=Exception("API error"))
issue = ReviewIssue(
severity="blocker", description="Critical", file_path="a.cs", line_start=1
)
review = ReviewResult(verdict="request_changes", summary="Bad", issues=(issue,))
await _post_review_to_pr(clients, "my-repo", 42, review)
# Summary should still be posted even though inline failed
clients.azdo.add_pr_comment.assert_called_once()
async def test_summary_failure_does_not_raise(self) -> None:
clients = build_mock_clients()
clients.azdo.add_pr_comment = AsyncMock(side_effect=Exception("Network error"))
clients.azdo.add_pr_inline_comment = AsyncMock()
review = ReviewResult(verdict="approve", summary="LGTM", issues=())
# Should not raise
await _post_review_to_pr(clients, "my-repo", 42, review)
async def test_summary_contains_issue_count(self) -> None:
clients = build_mock_clients()
clients.azdo.add_pr_comment = AsyncMock()
clients.azdo.add_pr_inline_comment = AsyncMock()
issues = (
ReviewIssue(severity="warning", description="Issue 1"),
ReviewIssue(severity="error", description="Issue 2"),
)
review = ReviewResult(verdict="request_changes", summary="Problems", issues=issues)
await _post_review_to_pr(clients, "my-repo", 42, review)
content = clients.azdo.add_pr_comment.call_args.kwargs["content"]
assert "2 issue(s)" in content
async def test_run_code_review_calls_post_review(self) -> None:
"""Integration: run_code_review posts comments when pr_id and repo_name present."""
clients = build_mock_clients()
review = ReviewResult(verdict="approve", summary="LGTM", issues=())
clients.reviewer.review_pr = AsyncMock(return_value=review)
clients.azdo.add_pr_comment = AsyncMock()
clients.azdo.add_pr_inline_comment = AsyncMock()
config = build_config(clients)
state = {
"pr_diff": "edit: main.py",
"pr_info": {"pr_id": "42", "pr_title": "Fix", "repo_name": "my-repo"},
}
await run_code_review(state, config)
clients.azdo.add_pr_comment.assert_called_once()
# ---------------------------------------------------------------------------
# evaluate_review
# ---------------------------------------------------------------------------
class TestEvaluateReview:
async def test_sets_review_approved_true_for_approve_verdict(self) -> None:
config = build_config()
state = {"review_result": _make_approve_review()}
result = await evaluate_review(state, config)
assert result["review_approved"] is True
async def test_sets_review_approved_false_for_request_changes(self) -> None:
config = build_config()
state = {"review_result": _make_request_changes_review()}
result = await evaluate_review(state, config)
assert result["review_approved"] is False
async def test_sets_false_when_review_result_missing(self) -> None:
config = build_config()
state = {}
result = await evaluate_review(state, config)
assert result["review_approved"] is False
async def test_sets_false_when_has_blockers(self) -> None:
config = build_config()
state = {
"review_result": {
"verdict": "approve",
"summary": "Approve with blocker?",
"issues": [{"severity": "blocker", "description": "Problem"}],
"has_blockers": True,
}
}
result = await evaluate_review(state, config)
assert result["review_approved"] is False
# ---------------------------------------------------------------------------
# interrupt_confirm_merge
# ---------------------------------------------------------------------------
class TestInterruptConfirmMerge:
async def test_calls_interrupt_with_summary_string(self) -> None:
config = build_config()
state = {
"pr_info": {"pr_id": "42", "pr_title": "Fix: bug", "repo_name": "my-repo"},
"review_result": {"summary": "LGTM"},
}
with patch("release_agent.graph.pr_completed.interrupt") as mock_interrupt:
mock_interrupt.return_value = "confirm"
await interrupt_confirm_merge(state, config)
mock_interrupt.assert_called_once()
call_arg = mock_interrupt.call_args[0][0]
assert isinstance(call_arg, str)
assert len(call_arg) > 0
async def test_interrupt_value_contains_pr_info(self) -> None:
config = build_config()
state = {
"pr_info": {"pr_id": "42", "pr_title": "Fix: auth bug", "repo_name": "backend"},
"review_result": {"summary": "All good"},
}
with patch("release_agent.graph.pr_completed.interrupt") as mock_interrupt:
mock_interrupt.return_value = "confirm"
await interrupt_confirm_merge(state, config)
call_arg = mock_interrupt.call_args[0][0]
assert "42" in call_arg or "Fix: auth bug" in call_arg or "backend" in call_arg
# ---------------------------------------------------------------------------
# merge_pr_node
# ---------------------------------------------------------------------------
class TestMergePrNode:
async def test_calls_azdo_merge_pr(self) -> None:
clients = build_mock_clients()
clients.azdo.merge_pr = AsyncMock(return_value=True)
config = build_config(clients)
state = {
"pr_info": {"pr_id": "42"},
"last_merge_source_commit": "abc123",
}
result = await merge_pr_node(state, config)
clients.azdo.merge_pr.assert_called_once_with(
pr_id=42, last_merge_source_commit="abc123"
)
async def test_returns_message_on_success(self) -> None:
clients = build_mock_clients()
clients.azdo.merge_pr = AsyncMock(return_value=True)
config = build_config(clients)
state = {
"pr_info": {"pr_id": "42"},
"last_merge_source_commit": "abc123",
}
result = await merge_pr_node(state, config)
assert "messages" in result
async def test_re_raises_on_service_error(self) -> None:
from release_agent.exceptions import ServiceError
clients = build_mock_clients()
clients.azdo.merge_pr = AsyncMock(side_effect=ServiceError(
service="azdo", status_code=409, detail="Conflict"
))
config = build_config(clients)
state = {
"pr_info": {"pr_id": "42"},
"last_merge_source_commit": "abc123",
}
with pytest.raises(ServiceError):
await merge_pr_node(state, config)
# ---------------------------------------------------------------------------
# move_jira_ready_for_stage
# ---------------------------------------------------------------------------
class TestMoveJiraReadyForStage:
async def test_transitions_ticket(self) -> None:
clients = build_mock_clients()
clients.jira.transition_issue = AsyncMock(return_value=True)
config = build_config(clients)
state = {"ticket_id": "ALLPOST-100", "has_ticket": True}
result = await move_jira_ready_for_stage(state, config)
clients.jira.transition_issue.assert_called_once_with(
"ALLPOST-100", "Ready for stage (2)"
)
async def test_skips_when_no_ticket(self) -> None:
clients = build_mock_clients()
clients.jira.transition_issue = AsyncMock()
config = build_config(clients)
state = {"has_ticket": False}
await move_jira_ready_for_stage(state, config)
clients.jira.transition_issue.assert_not_called()
async def test_appends_error_on_failure(self) -> None:
from release_agent.exceptions import ServiceError
clients = build_mock_clients()
clients.jira.transition_issue = AsyncMock(side_effect=ServiceError(
service="jira", status_code=500, detail="Error"
))
config = build_config(clients)
state = {"ticket_id": "ALLPOST-100", "has_ticket": True}
result = await move_jira_ready_for_stage(state, config)
assert "errors" in result
# ---------------------------------------------------------------------------
# add_jira_pr_link
# ---------------------------------------------------------------------------
class TestAddJiraPrLink:
async def test_calls_add_remote_link(self) -> None:
clients = build_mock_clients()
clients.jira.add_remote_link = AsyncMock(return_value=True)
config = build_config(clients)
state = {
"ticket_id": "ALLPOST-100",
"has_ticket": True,
"pr_info": {
"pr_id": "42",
"pr_url": "https://dev.azure.com/org/proj/_git/repo/pullrequest/42",
"pr_title": "Fix: bug",
},
}
result = await add_jira_pr_link(state, config)
clients.jira.add_remote_link.assert_called_once()
async def test_skips_when_no_ticket(self) -> None:
clients = build_mock_clients()
clients.jira.add_remote_link = AsyncMock()
config = build_config(clients)
state = {"has_ticket": False}
await add_jira_pr_link(state, config)
clients.jira.add_remote_link.assert_not_called()
async def test_appends_error_on_failure(self) -> None:
from release_agent.exceptions import ServiceError
clients = build_mock_clients()
clients.jira.add_remote_link = AsyncMock(side_effect=ServiceError(
service="jira", status_code=500, detail="Error"
))
config = build_config(clients)
state = {
"ticket_id": "ALLPOST-100",
"has_ticket": True,
"pr_info": {
"pr_id": "42",
"pr_url": "https://dev.azure.com/org/proj/_git/repo/pullrequest/42",
"pr_title": "Fix",
},
}
result = await add_jira_pr_link(state, config)
assert "errors" in result
# ---------------------------------------------------------------------------
# calculate_version
# ---------------------------------------------------------------------------
class TestCalculateVersion:
async def test_returns_v1_0_0_for_empty_store(self, tmp_path) -> None:
staging_store = JsonFileStagingStore(directory=tmp_path)
clients = build_mock_clients()
config = build_config(clients, staging_store=staging_store)
state = {"repo_name": "my-repo"}
result = await calculate_version(state, config)
assert result["version"] == "v1.0.0"
async def test_increments_patch_version(self, tmp_path) -> None:
from release_agent.models.release import StagingRelease
staging_store = JsonFileStagingStore(directory=tmp_path)
# Pre-populate with an existing version
staging = StagingRelease(
version="v1.0.5",
repo="my-repo",
started_at=date(2025, 1, 1),
tickets=[],
)
await staging_store.archive(staging, date(2025, 1, 10))
clients = build_mock_clients()
config = build_config(clients, staging_store=staging_store)
state = {"repo_name": "my-repo"}
result = await calculate_version(state, config)
assert result["version"] == "v1.0.6"
async def test_sets_version_in_state(self, tmp_path) -> None:
staging_store = JsonFileStagingStore(directory=tmp_path)
clients = build_mock_clients()
config = build_config(clients, staging_store=staging_store)
state = {"repo_name": "new-repo"}
result = await calculate_version(state, config)
assert "version" in result
assert result["version"].startswith("v")
# ---------------------------------------------------------------------------
# update_staging
# ---------------------------------------------------------------------------
class TestUpdateStaging:
async def test_creates_new_staging_when_none_exists(self, tmp_path) -> None:
staging_store = JsonFileStagingStore(directory=tmp_path)
clients = build_mock_clients()
# Jira get_issue returns a summary
from release_agent.models.jira import JiraIssue
clients.jira.get_issue = AsyncMock(return_value=JiraIssue(
key="ALLPOST-100", summary="Fix auth bug", status="Ready for stage (2)"
))
config = build_config(clients, staging_store=staging_store)
state = {
"repo_name": "my-repo",
"version": "v1.0.0",
"ticket_id": "ALLPOST-100",
"has_ticket": True,
"pr_info": {
"pr_id": "42",
"pr_url": "https://dev.azure.com/org/proj/_git/repo/pullrequest/42",
"pr_title": "Fix: auth bug",
"branch": "feature/ALLPOST-100-fix",
},
}
result = await update_staging(state, config)
loaded = await staging_store.load("my-repo")
assert loaded is not None
assert loaded.has_ticket("ALLPOST-100")
async def test_appends_ticket_to_existing_staging(self, tmp_path) -> None:
from datetime import date
from release_agent.models.release import StagingRelease
from release_agent.models.jira import JiraIssue
staging_store = JsonFileStagingStore(directory=tmp_path)
existing = StagingRelease(
version="v1.0.0", repo="my-repo",
started_at=date(2025, 1, 1), tickets=[]
)
await staging_store.save(existing)
clients = build_mock_clients()
clients.jira.get_issue = AsyncMock(return_value=JiraIssue(
key="BILL-99", summary="New feature", status="Ready"
))
config = build_config(clients, staging_store=staging_store)
state = {
"repo_name": "my-repo",
"version": "v1.0.0",
"ticket_id": "BILL-99",
"has_ticket": True,
"pr_info": {
"pr_id": "55",
"pr_url": "https://dev.azure.com/org/proj/_git/repo/pullrequest/55",
"pr_title": "Feat: new feature",
"branch": "feature/BILL-99-feat",
},
}
await update_staging(state, config)
loaded = await staging_store.load("my-repo")
assert loaded is not None
assert loaded.has_ticket("BILL-99")
async def test_skips_ticket_add_when_no_ticket(self, tmp_path) -> None:
staging_store = JsonFileStagingStore(directory=tmp_path)
clients = build_mock_clients()
config = build_config(clients, staging_store=staging_store)
state = {
"repo_name": "my-repo",
"version": "v1.0.0",
"has_ticket": False,
}
await update_staging(state, config)
# No staging file should be created for ticket-less PR if no existing staging
# (or staging exists without new ticket added)
clients.jira.get_issue.assert_not_called()
async def test_returns_empty_dict_when_no_staging_store(self) -> None:
from release_agent.models.jira import JiraIssue
clients = build_mock_clients()
clients.jira.get_issue = AsyncMock(return_value=JiraIssue(
key="ALLPOST-1", summary="Fix", status="Ready"
))
config = build_config(clients, staging_store=None)
state = {
"repo_name": "my-repo",
"version": "v1.0.0",
"ticket_id": "ALLPOST-1",
"has_ticket": True,
"pr_info": {
"pr_id": "1",
"pr_url": "https://dev.azure.com/org/proj/_git/repo/pullrequest/1",
"pr_title": "Fix",
"branch": "feature/ALLPOST-1",
},
}
result = await update_staging(state, config)
assert result == {}
async def test_uses_ticket_id_as_summary_on_jira_failure(self, tmp_path) -> None:
staging_store = JsonFileStagingStore(directory=tmp_path)
clients = build_mock_clients()
clients.jira.get_issue = AsyncMock(side_effect=Exception("Jira unavailable"))
config = build_config(clients, staging_store=staging_store)
state = {
"repo_name": "my-repo",
"version": "v1.0.0",
"ticket_id": "ALLPOST-99",
"has_ticket": True,
"pr_info": {
"pr_id": "5",
"pr_url": "https://dev.azure.com/org/proj/_git/repo/pullrequest/5",
"pr_title": "Fix something",
"branch": "feature/ALLPOST-99_fix",
},
}
result = await update_staging(state, config)
loaded = await staging_store.load("my-repo")
assert loaded is not None
assert loaded.tickets[0].id == "ALLPOST-99"
assert loaded.tickets[0].summary == "ALLPOST-99"
async def test_sets_staging_dict_in_result(self, tmp_path) -> None:
from release_agent.models.jira import JiraIssue
staging_store = JsonFileStagingStore(directory=tmp_path)
clients = build_mock_clients()
clients.jira.get_issue = AsyncMock(return_value=JiraIssue(
key="ALLPOST-1", summary="S", status="Ready"
))
config = build_config(clients, staging_store=staging_store)
state = {
"repo_name": "my-repo",
"version": "v1.0.0",
"ticket_id": "ALLPOST-1",
"has_ticket": True,
"pr_info": {
"pr_id": "1",
"pr_url": "https://dev.azure.com/org/proj/_git/repo/pullrequest/1",
"pr_title": "Fix",
"branch": "feature/ALLPOST-1",
},
}
result = await update_staging(state, config)
assert "staging" in result
assert isinstance(result["staging"], dict)
# ---------------------------------------------------------------------------
# notify_request_changes
# ---------------------------------------------------------------------------
class TestNotifyRequestChanges:
async def test_calls_slack_send_approval_request(self) -> None:
clients = build_mock_clients()
clients.slack.send_approval_request = AsyncMock(return_value=True)
config = build_config(clients)
state = {
"pr_info": {"pr_id": "42", "pr_title": "Fix: bug", "repo_name": "my-repo"},
"review_result": {
"verdict": "request_changes",
"summary": "Too many issues",
"issues": [{"severity": "blocker", "description": "No tests"}],
},
}
result = await notify_request_changes(state, config)
clients.slack.send_approval_request.assert_called_once()
async def test_appends_error_on_slack_failure(self) -> None:
from release_agent.exceptions import ServiceError
clients = build_mock_clients()
clients.slack.send_approval_request = AsyncMock(side_effect=ServiceError(
service="slack", status_code=500, detail="Webhook error"
))
config = build_config(clients)
state = {
"pr_info": {"pr_id": "42", "pr_title": "Fix", "repo_name": "repo"},
"review_result": {"summary": "Issues found", "issues": []},
}
result = await notify_request_changes(state, config)
assert "errors" in result
# ---------------------------------------------------------------------------
# build_pr_completed_graph
# ---------------------------------------------------------------------------
# ---------------------------------------------------------------------------
# auto_create_ticket node
# ---------------------------------------------------------------------------
class TestAutoCreateTicket:
"""Tests for the auto_create_ticket node."""
def _make_config_with_jira_project(
self, jira_project: str = "ALLPOST"
):
clients = build_mock_clients()
clients.jira.create_issue = AsyncMock(return_value="ALLPOST-99")
clients.reviewer.generate_ticket_content = AsyncMock(
return_value=("My summary", "My description")
)
config = build_config(clients)
config["configurable"]["default_jira_project"] = jira_project
return config, clients
async def test_creates_jira_issue_and_returns_ticket_id(self) -> None:
config, clients = self._make_config_with_jira_project()
state = {
"pr_diff": "edit: main.py",
"pr_info": {"pr_title": "Fix bug", "repo_name": "my-repo"},
}
result = await auto_create_ticket(state, config)
assert result.get("ticket_id") == "ALLPOST-99"
async def test_sets_has_ticket_true(self) -> None:
config, clients = self._make_config_with_jira_project()
state = {
"pr_diff": "edit: main.py",
"pr_info": {"pr_title": "Fix bug", "repo_name": "my-repo"},
}
result = await auto_create_ticket(state, config)
assert result.get("has_ticket") is True
async def test_calls_generate_ticket_content(self) -> None:
config, clients = self._make_config_with_jira_project()
state = {
"pr_diff": "edit: main.py",
"pr_info": {"pr_title": "Fix login", "repo_name": "auth-service"},
}
await auto_create_ticket(state, config)
clients.reviewer.generate_ticket_content.assert_awaited_once()
async def test_calls_create_issue_with_project_key(self) -> None:
config, clients = self._make_config_with_jira_project(jira_project="MYPROJ")
clients.jira.create_issue = AsyncMock(return_value="MYPROJ-5")
config["configurable"]["default_jira_project"] = "MYPROJ"
state = {
"pr_diff": "d",
"pr_info": {"pr_title": "t", "repo_name": "r"},
}
await auto_create_ticket(state, config)
call_kwargs = clients.jira.create_issue.call_args.kwargs
assert call_kwargs["project"] == "MYPROJ"
async def test_appends_message_on_success(self) -> None:
config, _ = self._make_config_with_jira_project()
state = {
"pr_diff": "d",
"pr_info": {"pr_title": "t", "repo_name": "r"},
}
result = await auto_create_ticket(state, config)
assert "messages" in result
assert len(result["messages"]) > 0
async def test_appends_error_on_create_issue_failure(self) -> None:
config, clients = self._make_config_with_jira_project()
clients.jira.create_issue = AsyncMock(side_effect=Exception("Jira down"))
state = {
"pr_diff": "d",
"pr_info": {"pr_title": "t", "repo_name": "r"},
}
result = await auto_create_ticket(state, config)
assert "errors" in result
assert len(result["errors"]) > 0
async def test_appends_error_on_generate_content_failure(self) -> None:
config, clients = self._make_config_with_jira_project()
clients.reviewer.generate_ticket_content = AsyncMock(side_effect=RuntimeError("CLI fail"))
state = {
"pr_diff": "d",
"pr_info": {"pr_title": "t", "repo_name": "r"},
}
result = await auto_create_ticket(state, config)
assert "errors" in result
async def test_uses_default_project_from_config(self) -> None:
config, clients = self._make_config_with_jira_project(jira_project="TEAM")
clients.jira.create_issue = AsyncMock(return_value="TEAM-1")
state = {
"pr_diff": "d",
"pr_info": {"pr_title": "t", "repo_name": "r"},
}
result = await auto_create_ticket(state, config)
assert result["ticket_id"] == "TEAM-1"
# ---------------------------------------------------------------------------
# build_pr_completed_graph
# ---------------------------------------------------------------------------
class TestBuildPrCompletedGraph:
def test_returns_compiled_graph(self) -> None:
graph = build_pr_completed_graph()
assert graph is not None
def test_graph_has_nodes(self) -> None:
graph = build_pr_completed_graph()
# The compiled graph object should be truthy
assert graph is not None
def test_graph_includes_trigger_ci_build_node(self) -> None:
graph = build_pr_completed_graph()
# Graph nodes should include CI pipeline nodes
graph_nodes = graph.get_graph().nodes
assert "trigger_ci_build" in graph_nodes
def test_graph_includes_poll_ci_build_node(self) -> None:
graph = build_pr_completed_graph()
graph_nodes = graph.get_graph().nodes
assert "poll_ci_build" in graph_nodes
def test_graph_includes_notify_ci_result_node(self) -> None:
graph = build_pr_completed_graph()
graph_nodes = graph.get_graph().nodes
assert "notify_ci_result" in graph_nodes
def test_graph_includes_auto_create_ticket_node(self) -> None:
graph = build_pr_completed_graph()
graph_nodes = graph.get_graph().nodes
assert "auto_create_ticket" in graph_nodes

866
tests/graph/test_release.py Normal file
View File

@@ -0,0 +1,866 @@
"""Tests for graph/release.py node functions. Written FIRST (TDD RED phase).
Each node is an async function (state, config) -> dict.
Tests call nodes directly — no graph compilation required.
"""
from datetime import date
from unittest.mock import AsyncMock, patch
import pytest
from release_agent.graph.dependencies import JsonFileStagingStore
from release_agent.graph.release import (
approve_stage,
archive_release,
check_release_approvals,
create_release_pr,
interrupt_confirm_approve,
interrupt_confirm_merge_release,
interrupt_confirm_release,
interrupt_confirm_trigger,
list_pipelines,
load_staging,
merge_release_pr,
move_tickets_to_done,
send_slack_notification,
trigger_pipelines,
build_release_graph,
)
from release_agent.models.pipeline import PipelineInfo, ReleasePipelineStage
from release_agent.models.release import StagingRelease
from release_agent.models.ticket import TicketEntry
from tests.graph.conftest import build_config, build_mock_clients
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_ticket(ticket_id: str = "ALLPOST-1") -> TicketEntry:
return TicketEntry(
id=ticket_id,
summary="Fix something",
pr_id="42",
pr_url="https://dev.azure.com/org/proj/_git/repo/pullrequest/42",
pr_title="Fix: something",
branch=f"feature/{ticket_id}-fix",
merged_at=date(2025, 1, 15),
)
def _make_staging(
*,
repo: str = "my-repo",
version: str = "v1.0.0",
tickets: list | None = None,
) -> StagingRelease:
t = tickets if tickets is not None else [_make_ticket()]
return StagingRelease(
version=version,
repo=repo,
started_at=date(2025, 1, 1),
tickets=t,
)
def _staging_dict(staging: StagingRelease) -> dict:
return staging.model_dump(mode="json")
# ---------------------------------------------------------------------------
# load_staging
# ---------------------------------------------------------------------------
class TestLoadStaging:
async def test_loads_staging_from_store(self, tmp_path) -> None:
staging_store = JsonFileStagingStore(directory=tmp_path)
staging = _make_staging()
await staging_store.save(staging)
clients = build_mock_clients()
config = build_config(clients, staging_store=staging_store)
state = {"repo_name": "my-repo"}
result = await load_staging(state, config)
assert "staging" in result
assert result["staging"]["version"] == "v1.0.0"
async def test_returns_none_when_no_staging(self, tmp_path) -> None:
staging_store = JsonFileStagingStore(directory=tmp_path)
clients = build_mock_clients()
config = build_config(clients, staging_store=staging_store)
state = {"repo_name": "nonexistent"}
result = await load_staging(state, config)
assert result.get("staging") is None
async def test_staging_includes_tickets(self, tmp_path) -> None:
staging_store = JsonFileStagingStore(directory=tmp_path)
staging = _make_staging(tickets=[_make_ticket("BILL-10"), _make_ticket("BILL-11")])
await staging_store.save(staging)
clients = build_mock_clients()
config = build_config(clients, staging_store=staging_store)
state = {"repo_name": "my-repo"}
result = await load_staging(state, config)
assert len(result["staging"]["tickets"]) == 2
# ---------------------------------------------------------------------------
# interrupt_confirm_release
# ---------------------------------------------------------------------------
class TestInterruptConfirmRelease:
async def test_calls_interrupt_with_staging_summary(self) -> None:
config = build_config()
staging = _make_staging()
state = {
"repo_name": "my-repo",
"staging": _staging_dict(staging),
}
with patch("release_agent.graph.release.interrupt") as mock_interrupt:
mock_interrupt.return_value = "confirm"
await interrupt_confirm_release(state, config)
mock_interrupt.assert_called_once()
call_arg = mock_interrupt.call_args[0][0]
assert isinstance(call_arg, str)
async def test_interrupt_contains_version_and_repo(self) -> None:
config = build_config()
staging = _make_staging(version="v2.5.0", repo="backend")
state = {
"repo_name": "backend",
"staging": _staging_dict(staging),
}
with patch("release_agent.graph.release.interrupt") as mock_interrupt:
mock_interrupt.return_value = "confirm"
await interrupt_confirm_release(state, config)
call_arg = mock_interrupt.call_args[0][0]
assert "v2.5.0" in call_arg or "backend" in call_arg
# ---------------------------------------------------------------------------
# create_release_pr
# ---------------------------------------------------------------------------
class TestCreateReleasePr:
async def test_calls_azdo_create_pr(self) -> None:
clients = build_mock_clients()
clients.azdo.create_pr = AsyncMock(return_value={
"pullRequestId": 99,
"lastMergeSourceCommit": {"commitId": "deadbeef"},
})
config = build_config(clients)
staging = _make_staging(version="v1.2.0")
state = {
"repo_name": "my-repo",
"version": "v1.2.0",
"staging": _staging_dict(staging),
}
result = await create_release_pr(state, config)
clients.azdo.create_pr.assert_called_once()
call_kwargs = clients.azdo.create_pr.call_args.kwargs
assert call_kwargs["repo"] == "my-repo"
async def test_sets_release_pr_id(self) -> None:
clients = build_mock_clients()
clients.azdo.create_pr = AsyncMock(return_value={
"pullRequestId": 77,
"lastMergeSourceCommit": {"commitId": "cafe1234"},
})
config = build_config(clients)
staging = _make_staging(version="v1.0.3")
state = {
"repo_name": "my-repo",
"version": "v1.0.3",
"staging": _staging_dict(staging),
}
result = await create_release_pr(state, config)
assert result["release_pr_id"] == "77"
async def test_sets_release_pr_commit(self) -> None:
clients = build_mock_clients()
clients.azdo.create_pr = AsyncMock(return_value={
"pullRequestId": 77,
"lastMergeSourceCommit": {"commitId": "cafe1234"},
})
config = build_config(clients)
staging = _make_staging()
state = {
"repo_name": "my-repo",
"version": "v1.0.0",
"staging": _staging_dict(staging),
}
result = await create_release_pr(state, config)
assert result["release_pr_commit"] == "cafe1234"
async def test_re_raises_on_service_error(self) -> None:
from release_agent.exceptions import ServiceError
clients = build_mock_clients()
clients.azdo.create_pr = AsyncMock(side_effect=ServiceError(
service="azdo", status_code=422, detail="Invalid branch"
))
config = build_config(clients)
staging = _make_staging()
state = {
"repo_name": "my-repo",
"version": "v1.0.0",
"staging": _staging_dict(staging),
}
with pytest.raises(ServiceError):
await create_release_pr(state, config)
# ---------------------------------------------------------------------------
# interrupt_confirm_merge_release
# ---------------------------------------------------------------------------
class TestInterruptConfirmMergeRelease:
async def test_calls_interrupt_with_pr_info(self) -> None:
config = build_config()
state = {
"release_pr_id": "99",
"version": "v1.0.0",
"repo_name": "my-repo",
}
with patch("release_agent.graph.release.interrupt") as mock_interrupt:
mock_interrupt.return_value = "confirm"
await interrupt_confirm_merge_release(state, config)
mock_interrupt.assert_called_once()
call_arg = mock_interrupt.call_args[0][0]
assert isinstance(call_arg, str)
assert len(call_arg) > 0
# ---------------------------------------------------------------------------
# merge_release_pr
# ---------------------------------------------------------------------------
class TestMergeReleasePr:
async def test_calls_azdo_merge_pr(self) -> None:
clients = build_mock_clients()
clients.azdo.merge_pr = AsyncMock(return_value=True)
config = build_config(clients)
state = {
"release_pr_id": "99",
"release_pr_commit": "abc123",
}
await merge_release_pr(state, config)
clients.azdo.merge_pr.assert_called_once_with(
pr_id=99, last_merge_source_commit="abc123"
)
async def test_re_raises_on_service_error(self) -> None:
from release_agent.exceptions import ServiceError
clients = build_mock_clients()
clients.azdo.merge_pr = AsyncMock(side_effect=ServiceError(
service="azdo", status_code=409, detail="Conflict"
))
config = build_config(clients)
state = {"release_pr_id": "99", "release_pr_commit": "abc"}
with pytest.raises(ServiceError):
await merge_release_pr(state, config)
# ---------------------------------------------------------------------------
# move_tickets_to_done
# ---------------------------------------------------------------------------
class TestMoveTicketsToDone:
async def test_transitions_all_tickets(self) -> None:
clients = build_mock_clients()
clients.jira.transition_issue = AsyncMock(return_value=True)
config = build_config(clients)
staging = _make_staging(tickets=[_make_ticket("BILL-1"), _make_ticket("BILL-2")])
state = {"staging": _staging_dict(staging)}
await move_tickets_to_done(state, config)
assert clients.jira.transition_issue.call_count == 2
async def test_calls_transition_with_done_name(self) -> None:
clients = build_mock_clients()
clients.jira.transition_issue = AsyncMock(return_value=True)
config = build_config(clients)
staging = _make_staging(tickets=[_make_ticket("BILL-1")])
state = {"staging": _staging_dict(staging)}
await move_tickets_to_done(state, config)
call_args = clients.jira.transition_issue.call_args_list[0]
ticket_id, transition = call_args[0]
assert ticket_id == "BILL-1"
assert "done" in transition.lower() or "released" in transition.lower()
async def test_appends_error_on_jira_failure(self) -> None:
from release_agent.exceptions import ServiceError
clients = build_mock_clients()
clients.jira.transition_issue = AsyncMock(side_effect=ServiceError(
service="jira", status_code=500, detail="Error"
))
config = build_config(clients)
staging = _make_staging(tickets=[_make_ticket()])
state = {"staging": _staging_dict(staging)}
result = await move_tickets_to_done(state, config)
assert "errors" in result
async def test_empty_tickets_no_calls(self) -> None:
clients = build_mock_clients()
clients.jira.transition_issue = AsyncMock()
config = build_config(clients)
staging = _make_staging(tickets=[])
state = {"staging": _staging_dict(staging)}
await move_tickets_to_done(state, config)
clients.jira.transition_issue.assert_not_called()
# ---------------------------------------------------------------------------
# send_slack_notification
# ---------------------------------------------------------------------------
class TestSendSlackNotification:
async def test_calls_slack_send_release_notification(self) -> None:
clients = build_mock_clients()
clients.slack.send_release_notification = AsyncMock(return_value=True)
config = build_config(clients)
staging = _make_staging()
state = {
"repo_name": "my-repo",
"version": "v1.0.0",
"staging": _staging_dict(staging),
}
result = await send_slack_notification(state, config)
clients.slack.send_release_notification.assert_called_once()
async def test_appends_error_on_slack_failure(self) -> None:
from release_agent.exceptions import ServiceError
clients = build_mock_clients()
clients.slack.send_release_notification = AsyncMock(side_effect=ServiceError(
service="slack", status_code=500, detail="Webhook error"
))
config = build_config(clients)
staging = _make_staging()
state = {
"repo_name": "my-repo",
"version": "v1.0.0",
"staging": _staging_dict(staging),
}
result = await send_slack_notification(state, config)
assert "errors" in result
# ---------------------------------------------------------------------------
# archive_release
# ---------------------------------------------------------------------------
class TestArchiveRelease:
async def test_archives_staging_to_store(self, tmp_path) -> None:
staging_store = JsonFileStagingStore(directory=tmp_path)
staging = _make_staging()
await staging_store.save(staging)
clients = build_mock_clients()
config = build_config(clients, staging_store=staging_store)
state = {
"repo_name": "my-repo",
"staging": _staging_dict(staging),
}
await archive_release(state, config)
# Staging should be gone now
assert await staging_store.load("my-repo") is None
async def test_archive_file_created_in_store(self, tmp_path) -> None:
staging_store = JsonFileStagingStore(directory=tmp_path)
staging = _make_staging(version="v3.0.0")
await staging_store.save(staging)
clients = build_mock_clients()
config = build_config(clients, staging_store=staging_store)
state = {
"repo_name": "my-repo",
"staging": _staging_dict(staging),
}
await archive_release(state, config)
versions = await staging_store.list_versions("my-repo")
assert "v3.0.0" in versions
# ---------------------------------------------------------------------------
# list_pipelines
# ---------------------------------------------------------------------------
class TestListPipelines:
async def test_fetches_pipelines_from_azdo(self) -> None:
clients = build_mock_clients()
pipelines = [PipelineInfo(id=1, name="build", repo="my-repo")]
clients.azdo.list_build_pipelines = AsyncMock(return_value=pipelines)
config = build_config(clients)
state = {"repo_name": "my-repo"}
result = await list_pipelines(state, config)
clients.azdo.list_build_pipelines.assert_called_once_with(repo="my-repo")
assert "pipelines" in result
assert len(result["pipelines"]) == 1
async def test_stores_pipelines_as_list_of_dicts(self) -> None:
clients = build_mock_clients()
pipelines = [
PipelineInfo(id=1, name="build", repo="my-repo"),
PipelineInfo(id=2, name="deploy", repo="my-repo"),
]
clients.azdo.list_build_pipelines = AsyncMock(return_value=pipelines)
config = build_config(clients)
state = {"repo_name": "my-repo"}
result = await list_pipelines(state, config)
assert len(result["pipelines"]) == 2
assert result["pipelines"][0]["id"] == 1
async def test_empty_pipelines_stored_as_empty_list(self) -> None:
clients = build_mock_clients()
clients.azdo.list_build_pipelines = AsyncMock(return_value=[])
config = build_config(clients)
state = {"repo_name": "my-repo"}
result = await list_pipelines(state, config)
assert result["pipelines"] == []
async def test_appends_error_on_service_failure(self) -> None:
from release_agent.exceptions import ServiceError
clients = build_mock_clients()
clients.azdo.list_build_pipelines = AsyncMock(side_effect=ServiceError(
service="azdo", status_code=500, detail="Error"
))
config = build_config(clients)
state = {"repo_name": "my-repo"}
result = await list_pipelines(state, config)
assert "errors" in result
# ---------------------------------------------------------------------------
# interrupt_confirm_trigger
# ---------------------------------------------------------------------------
class TestInterruptConfirmTrigger:
async def test_calls_interrupt_with_pipelines_summary(self) -> None:
config = build_config()
state = {
"repo_name": "my-repo",
"version": "v1.0.0",
"pipelines": [{"id": 1, "name": "build", "repo": "my-repo"}],
}
with patch("release_agent.graph.release.interrupt") as mock_interrupt:
mock_interrupt.return_value = "confirm"
await interrupt_confirm_trigger(state, config)
mock_interrupt.assert_called_once()
call_arg = mock_interrupt.call_args[0][0]
assert isinstance(call_arg, str)
# ---------------------------------------------------------------------------
# trigger_pipelines
# ---------------------------------------------------------------------------
class TestTriggerPipelines:
async def test_triggers_each_pipeline(self) -> None:
clients = build_mock_clients()
clients.azdo.trigger_pipeline = AsyncMock(return_value={"id": 1001})
config = build_config(clients)
state = {
"repo_name": "my-repo",
"version": "v1.0.0",
"pipelines": [
{"id": 1, "name": "build", "repo": "my-repo"},
{"id": 2, "name": "deploy", "repo": "my-repo"},
],
}
result = await trigger_pipelines(state, config)
assert clients.azdo.trigger_pipeline.call_count == 2
assert "triggered_builds" in result
assert len(result["triggered_builds"]) == 2
async def test_no_pipelines_no_calls(self) -> None:
clients = build_mock_clients()
clients.azdo.trigger_pipeline = AsyncMock()
config = build_config(clients)
state = {
"repo_name": "my-repo",
"version": "v1.0.0",
"pipelines": [],
}
result = await trigger_pipelines(state, config)
clients.azdo.trigger_pipeline.assert_not_called()
assert result["triggered_builds"] == []
async def test_appends_error_on_trigger_failure(self) -> None:
from release_agent.exceptions import ServiceError
clients = build_mock_clients()
clients.azdo.trigger_pipeline = AsyncMock(side_effect=ServiceError(
service="azdo", status_code=500, detail="Error"
))
config = build_config(clients)
state = {
"repo_name": "my-repo",
"version": "v1.0.0",
"pipelines": [{"id": 1, "name": "build", "repo": "my-repo"}],
}
result = await trigger_pipelines(state, config)
assert "errors" in result
# ---------------------------------------------------------------------------
# check_release_approvals
# ---------------------------------------------------------------------------
class TestCheckReleaseApprovals:
async def test_fetches_pending_approvals_from_builds(self) -> None:
clients = build_mock_clients()
clients.azdo.get_build_status = AsyncMock(return_value="completed")
config = build_config(clients)
state = {
"triggered_builds": [{"id": 1001}],
}
result = await check_release_approvals(state, config)
assert "pending_approvals" in result
async def test_empty_builds_means_no_approvals(self) -> None:
clients = build_mock_clients()
config = build_config(clients)
state = {"triggered_builds": []}
result = await check_release_approvals(state, config)
assert result["pending_approvals"] == []
async def test_appends_error_on_failure(self) -> None:
from release_agent.exceptions import ServiceError
clients = build_mock_clients()
clients.azdo.get_build_status = AsyncMock(side_effect=ServiceError(
service="azdo", status_code=500, detail="Error"
))
config = build_config(clients)
state = {"triggered_builds": [{"id": 1001}]}
result = await check_release_approvals(state, config)
assert "errors" in result
# ---------------------------------------------------------------------------
# interrupt_confirm_approve
# ---------------------------------------------------------------------------
class TestInterruptConfirmApprove:
async def test_calls_interrupt_with_approvals_summary(self) -> None:
config = build_config()
state = {
"pending_approvals": [{"approval_id": "aaa", "stage_name": "Production"}],
"version": "v1.0.0",
}
with patch("release_agent.graph.release.interrupt") as mock_interrupt:
mock_interrupt.return_value = "confirm"
await interrupt_confirm_approve(state, config)
mock_interrupt.assert_called_once()
# ---------------------------------------------------------------------------
# approve_stage
# ---------------------------------------------------------------------------
class TestApproveStage:
async def test_approves_each_pending_approval(self) -> None:
clients = build_mock_clients()
clients.azdo.approve_release = AsyncMock(return_value={"status": "approved"})
config = build_config(clients)
state = {
"pending_approvals": [
{"approval_id": "aaa"},
{"approval_id": "bbb"},
],
}
result = await approve_stage(state, config)
assert clients.azdo.approve_release.call_count == 2
async def test_no_approvals_no_calls(self) -> None:
clients = build_mock_clients()
clients.azdo.approve_release = AsyncMock()
config = build_config(clients)
state = {"pending_approvals": []}
await approve_stage(state, config)
clients.azdo.approve_release.assert_not_called()
async def test_appends_error_on_failure(self) -> None:
from release_agent.exceptions import ServiceError
clients = build_mock_clients()
clients.azdo.approve_release = AsyncMock(side_effect=ServiceError(
service="azdo", status_code=500, detail="Error"
))
config = build_config(clients)
state = {"pending_approvals": [{"approval_id": "aaa"}]}
result = await approve_stage(state, config)
assert "errors" in result
# ---------------------------------------------------------------------------
# build_release_graph
# ---------------------------------------------------------------------------
class TestBuildReleaseGraph:
def test_returns_compiled_graph(self) -> None:
graph = build_release_graph()
assert graph is not None
def test_graph_includes_trigger_ci_build_main_node(self) -> None:
graph = build_release_graph()
graph_nodes = graph.get_graph().nodes
assert "trigger_ci_build_main" in graph_nodes
def test_graph_includes_poll_ci_build_main_node(self) -> None:
graph = build_release_graph()
graph_nodes = graph.get_graph().nodes
assert "poll_ci_build_main" in graph_nodes
def test_graph_includes_wait_for_cd_release_node(self) -> None:
graph = build_release_graph()
graph_nodes = graph.get_graph().nodes
assert "wait_for_cd_release" in graph_nodes
def test_graph_includes_poll_release_approvals_node(self) -> None:
graph = build_release_graph()
graph_nodes = graph.get_graph().nodes
assert "poll_release_approvals" in graph_nodes
def test_graph_includes_interrupt_sandbox_approval_node(self) -> None:
graph = build_release_graph()
graph_nodes = graph.get_graph().nodes
assert "interrupt_sandbox_approval" in graph_nodes
def test_graph_includes_interrupt_prod_approval_node(self) -> None:
graph = build_release_graph()
graph_nodes = graph.get_graph().nodes
assert "interrupt_prod_approval" in graph_nodes
def test_graph_includes_execute_sandbox_approval_node(self) -> None:
graph = build_release_graph()
graph_nodes = graph.get_graph().nodes
assert "execute_sandbox_approval" in graph_nodes
def test_graph_includes_execute_prod_approval_node(self) -> None:
graph = build_release_graph()
graph_nodes = graph.get_graph().nodes
assert "execute_prod_approval" in graph_nodes
def test_graph_includes_notify_ci_failure_node(self) -> None:
graph = build_release_graph()
graph_nodes = graph.get_graph().nodes
assert "notify_ci_failure" in graph_nodes
# ---------------------------------------------------------------------------
# New release graph node: wait_for_cd_release
# ---------------------------------------------------------------------------
class TestWaitForCdRelease:
"""Tests for wait_for_cd_release node."""
async def test_sets_release_id_when_found(self) -> None:
from release_agent.graph.release import wait_for_cd_release
clients = build_mock_clients()
clients.azdo.get_latest_release.return_value = {"id": 100, "name": "Release-100"}
config = build_config(clients)
state = {"release_definition_id": 5, "repo_name": "my-repo"}
result = await wait_for_cd_release(state, config)
assert "release_id" in result
assert result["release_id"] == 100
async def test_appends_error_when_no_release(self) -> None:
from release_agent.graph.release import wait_for_cd_release
clients = build_mock_clients()
clients.azdo.get_latest_release.return_value = {}
config = build_config(clients)
state = {"release_definition_id": 5, "repo_name": "my-repo"}
result = await wait_for_cd_release(state, config)
assert "errors" in result
async def test_works_without_release_definition_id(self) -> None:
from release_agent.graph.release import wait_for_cd_release
clients = build_mock_clients()
config = build_config(clients)
state = {"repo_name": "my-repo"}
result = await wait_for_cd_release(state, config)
assert isinstance(result, dict)
# ---------------------------------------------------------------------------
# New release graph node: poll_release_approvals
# ---------------------------------------------------------------------------
class TestPollReleaseApprovals:
"""Tests for poll_release_approvals node."""
async def test_sets_pending_approvals_from_azdo(self) -> None:
from release_agent.graph.release import poll_release_approvals
from release_agent.models.build import ApprovalRecord
clients = build_mock_clients()
clients.azdo.get_release_approvals.return_value = [
ApprovalRecord(approval_id="a1", stage_name="Sandbox", status="pending", release_id=10),
]
config = build_config(clients)
state = {"release_id": 10}
result = await poll_release_approvals(state, config)
assert "pending_approvals" in result
assert len(result["pending_approvals"]) == 1
async def test_returns_empty_list_when_no_approvals(self) -> None:
from release_agent.graph.release import poll_release_approvals
clients = build_mock_clients()
clients.azdo.get_release_approvals.return_value = []
config = build_config(clients)
state = {"release_id": 10}
result = await poll_release_approvals(state, config)
assert result.get("pending_approvals") == []
async def test_appends_error_on_failure(self) -> None:
from release_agent.exceptions import ServiceError
from release_agent.graph.release import poll_release_approvals
clients = build_mock_clients()
clients.azdo.get_release_approvals.side_effect = ServiceError(
service="azdo", status_code=500, detail="error"
)
config = build_config(clients)
state = {"release_id": 10}
result = await poll_release_approvals(state, config)
assert "errors" in result
# ---------------------------------------------------------------------------
# New release graph node: interrupt_sandbox_approval
# ---------------------------------------------------------------------------
class TestInterruptSandboxApproval:
async def test_calls_interrupt(self) -> None:
from release_agent.graph.release import interrupt_sandbox_approval
config = build_config()
state = {
"pending_approvals": [{"approval_id": "x", "stage_name": "Sandbox"}],
"version": "v1.0.0",
}
with patch("release_agent.graph.release.interrupt") as mock_interrupt:
mock_interrupt.return_value = "confirm"
await interrupt_sandbox_approval(state, config)
mock_interrupt.assert_called_once()
async def test_sets_current_stage_to_sandbox_pending(self) -> None:
from release_agent.graph.release import interrupt_sandbox_approval
config = build_config()
state = {
"pending_approvals": [{"approval_id": "x", "stage_name": "Sandbox"}],
}
with patch("release_agent.graph.release.interrupt", return_value="yes"):
result = await interrupt_sandbox_approval(state, config)
assert result.get("current_stage") == "sandbox_pending"
# ---------------------------------------------------------------------------
# New release graph node: interrupt_prod_approval
# ---------------------------------------------------------------------------
class TestInterruptProdApproval:
async def test_calls_interrupt(self) -> None:
from release_agent.graph.release import interrupt_prod_approval
config = build_config()
state = {
"pending_approvals": [{"approval_id": "y", "stage_name": "Production"}],
"version": "v1.0.0",
}
with patch("release_agent.graph.release.interrupt") as mock_interrupt:
mock_interrupt.return_value = "confirm"
await interrupt_prod_approval(state, config)
mock_interrupt.assert_called_once()
async def test_sets_current_stage_to_prod_pending(self) -> None:
from release_agent.graph.release import interrupt_prod_approval
config = build_config()
state = {
"pending_approvals": [{"approval_id": "y", "stage_name": "Production"}],
}
with patch("release_agent.graph.release.interrupt", return_value="yes"):
result = await interrupt_prod_approval(state, config)
assert result.get("current_stage") == "prod_pending"
# ---------------------------------------------------------------------------
# New release graph node: execute_sandbox_approval
# ---------------------------------------------------------------------------
class TestExecuteSandboxApproval:
async def test_approves_sandbox_approvals(self) -> None:
from release_agent.graph.release import execute_sandbox_approval
clients = build_mock_clients()
clients.azdo.approve_release.return_value = {"status": "approved"}
config = build_config(clients)
state = {
"pending_approvals": [{"approval_id": "sb1", "stage_name": "Sandbox"}],
}
result = await execute_sandbox_approval(state, config)
clients.azdo.approve_release.assert_called()
async def test_returns_empty_dict_on_success(self) -> None:
from release_agent.graph.release import execute_sandbox_approval
clients = build_mock_clients()
clients.azdo.approve_release.return_value = {"status": "approved"}
config = build_config(clients)
state = {"pending_approvals": [{"approval_id": "sb1"}]}
result = await execute_sandbox_approval(state, config)
assert "errors" not in result or result["errors"] == []
# ---------------------------------------------------------------------------
# New release graph node: execute_prod_approval
# ---------------------------------------------------------------------------
class TestExecuteProdApproval:
async def test_approves_prod_approvals(self) -> None:
from release_agent.graph.release import execute_prod_approval
clients = build_mock_clients()
clients.azdo.approve_release.return_value = {"status": "approved"}
config = build_config(clients)
state = {
"pending_approvals": [{"approval_id": "pd1", "stage_name": "Production"}],
}
result = await execute_prod_approval(state, config)
clients.azdo.approve_release.assert_called()
# ---------------------------------------------------------------------------
# New release graph node: notify_ci_failure
# ---------------------------------------------------------------------------
class TestNotifyCiFailure:
async def test_sends_slack_notification(self) -> None:
from release_agent.graph.release import notify_ci_failure
clients = build_mock_clients()
clients.slack.send_notification.return_value = True
config = build_config(clients)
state = {
"repo_name": "my-repo",
"ci_build_result": "failed",
"ci_build_url": "https://build/1",
}
result = await notify_ci_failure(state, config)
clients.slack.send_notification.assert_called_once()
async def test_appends_message_on_success(self) -> None:
from release_agent.graph.release import notify_ci_failure
clients = build_mock_clients()
clients.slack.send_notification.return_value = True
config = build_config(clients)
state = {"repo_name": "my-repo", "ci_build_result": "failed"}
result = await notify_ci_failure(state, config)
assert "messages" in result or isinstance(result, dict)

302
tests/graph/test_routing.py Normal file
View File

@@ -0,0 +1,302 @@
"""Tests for graph/routing.py. Written FIRST (TDD RED phase).
All routing functions are pure — they take a state dict and return a string.
Every branch is tested, including missing state fields (defaults to falsy).
"""
import pytest
from release_agent.graph.routing import (
has_pending_approvals,
has_pipelines,
has_ticket,
is_pr_already_merged,
is_review_approved,
route_after_fetch,
route_approval_stage,
route_ci_result,
should_continue_to_release,
)
# ---------------------------------------------------------------------------
# is_pr_already_merged
# ---------------------------------------------------------------------------
class TestIsPrAlreadyMerged:
def test_returns_merged_when_true(self) -> None:
state = {"pr_already_merged": True}
assert is_pr_already_merged(state) == "merged"
def test_returns_active_when_false(self) -> None:
state = {"pr_already_merged": False}
assert is_pr_already_merged(state) == "active"
def test_returns_active_when_field_missing(self) -> None:
state = {}
assert is_pr_already_merged(state) == "active"
def test_returns_active_when_none(self) -> None:
state = {"pr_already_merged": None}
assert is_pr_already_merged(state) == "active"
# ---------------------------------------------------------------------------
# is_review_approved
# ---------------------------------------------------------------------------
class TestIsReviewApproved:
def test_returns_approve_when_true(self) -> None:
state = {"review_approved": True}
assert is_review_approved(state) == "approve"
def test_returns_request_changes_when_false(self) -> None:
state = {"review_approved": False}
assert is_review_approved(state) == "request_changes"
def test_returns_request_changes_when_field_missing(self) -> None:
state = {}
assert is_review_approved(state) == "request_changes"
def test_returns_request_changes_when_none(self) -> None:
state = {"review_approved": None}
assert is_review_approved(state) == "request_changes"
# ---------------------------------------------------------------------------
# has_ticket
# ---------------------------------------------------------------------------
class TestHasTicket:
def test_returns_yes_when_true(self) -> None:
state = {"has_ticket": True}
assert has_ticket(state) == "yes"
def test_returns_no_when_false(self) -> None:
state = {"has_ticket": False}
assert has_ticket(state) == "no"
def test_returns_no_when_field_missing(self) -> None:
state = {}
assert has_ticket(state) == "no"
def test_returns_no_when_none(self) -> None:
state = {"has_ticket": None}
assert has_ticket(state) == "no"
# ---------------------------------------------------------------------------
# should_continue_to_release
# ---------------------------------------------------------------------------
class TestShouldContinueToRelease:
def test_returns_yes_when_true(self) -> None:
state = {"continue_to_release": True}
assert should_continue_to_release(state) == "yes"
def test_returns_no_when_false(self) -> None:
state = {"continue_to_release": False}
assert should_continue_to_release(state) == "no"
def test_returns_no_when_field_missing(self) -> None:
state = {}
assert should_continue_to_release(state) == "no"
def test_returns_no_when_none(self) -> None:
state = {"continue_to_release": None}
assert should_continue_to_release(state) == "no"
# ---------------------------------------------------------------------------
# has_pipelines
# ---------------------------------------------------------------------------
class TestHasPipelines:
def test_returns_yes_when_non_empty_list(self) -> None:
state = {"pipelines": [{"id": 1}]}
assert has_pipelines(state) == "yes"
def test_returns_no_when_empty_list(self) -> None:
state = {"pipelines": []}
assert has_pipelines(state) == "no"
def test_returns_no_when_field_missing(self) -> None:
state = {}
assert has_pipelines(state) == "no"
def test_returns_no_when_none(self) -> None:
state = {"pipelines": None}
assert has_pipelines(state) == "no"
def test_returns_yes_with_multiple_pipelines(self) -> None:
state = {"pipelines": [{"id": 1}, {"id": 2}]}
assert has_pipelines(state) == "yes"
# ---------------------------------------------------------------------------
# has_pending_approvals
# ---------------------------------------------------------------------------
class TestHasPendingApprovals:
def test_returns_yes_when_non_empty_list(self) -> None:
state = {"pending_approvals": [{"approval_id": "abc"}]}
assert has_pending_approvals(state) == "yes"
def test_returns_no_when_empty_list(self) -> None:
state = {"pending_approvals": []}
assert has_pending_approvals(state) == "no"
def test_returns_no_when_field_missing(self) -> None:
state = {}
assert has_pending_approvals(state) == "no"
def test_returns_no_when_none(self) -> None:
state = {"pending_approvals": None}
assert has_pending_approvals(state) == "no"
def test_returns_yes_with_multiple_approvals(self) -> None:
state = {"pending_approvals": [{"approval_id": "a"}, {"approval_id": "b"}]}
assert has_pending_approvals(state) == "yes"
# ---------------------------------------------------------------------------
# route_ci_result
# ---------------------------------------------------------------------------
class TestRouteCiResult:
"""Tests for route_ci_result routing function."""
def test_returns_ci_passed_when_succeeded(self) -> None:
state = {"ci_build_result": "succeeded"}
assert route_ci_result(state) == "ci_passed"
def test_returns_ci_failed_when_failed(self) -> None:
state = {"ci_build_result": "failed"}
assert route_ci_result(state) == "ci_failed"
def test_returns_ci_failed_when_canceled(self) -> None:
state = {"ci_build_result": "canceled"}
assert route_ci_result(state) == "ci_failed"
def test_returns_ci_failed_when_partially_succeeded(self) -> None:
state = {"ci_build_result": "partiallySucceeded"}
assert route_ci_result(state) == "ci_failed"
def test_returns_ci_failed_when_field_missing(self) -> None:
state = {}
assert route_ci_result(state) == "ci_failed"
def test_returns_ci_failed_when_none(self) -> None:
state = {"ci_build_result": None}
assert route_ci_result(state) == "ci_failed"
def test_returns_ci_failed_when_empty_string(self) -> None:
state = {"ci_build_result": ""}
assert route_ci_result(state) == "ci_failed"
def test_case_sensitive_succeeded(self) -> None:
# AzDo returns "succeeded" (lowercase)
state = {"ci_build_result": "succeeded"}
assert route_ci_result(state) == "ci_passed"
# ---------------------------------------------------------------------------
# route_approval_stage
# ---------------------------------------------------------------------------
class TestRouteApprovalStage:
"""Tests for route_approval_stage routing function."""
def test_returns_all_deployed_when_no_pending_approvals(self) -> None:
state = {"pending_approvals": []}
assert route_approval_stage(state) == "all_deployed"
def test_returns_all_deployed_when_field_missing(self) -> None:
state = {}
assert route_approval_stage(state) == "all_deployed"
def test_returns_all_deployed_when_none(self) -> None:
state = {"pending_approvals": None}
assert route_approval_stage(state) == "all_deployed"
def test_returns_sandbox_pending_when_sandbox_approval_exists(self) -> None:
state = {
"current_stage": "sandbox_pending",
"pending_approvals": [{"approval_id": "x", "stage_name": "Sandbox"}],
}
assert route_approval_stage(state) == "sandbox_pending"
def test_returns_prod_pending_when_prod_approval_exists(self) -> None:
state = {
"current_stage": "prod_pending",
"pending_approvals": [{"approval_id": "y", "stage_name": "Production"}],
}
assert route_approval_stage(state) == "prod_pending"
def test_uses_current_stage_field_when_present(self) -> None:
state = {
"current_stage": "sandbox_pending",
"pending_approvals": [{"approval_id": "z"}],
}
assert route_approval_stage(state) == "sandbox_pending"
def test_returns_all_deployed_when_no_current_stage_and_has_approvals(self) -> None:
# When current_stage is missing but approvals exist, stage is unknown
# so we treat as sandbox by default (first stage)
state = {
"pending_approvals": [{"approval_id": "a"}],
}
# Must return either sandbox_pending or prod_pending (not all_deployed)
result = route_approval_stage(state)
assert result in ("sandbox_pending", "prod_pending")
def test_sandbox_pending_from_current_stage(self) -> None:
state = {"current_stage": "sandbox_pending", "pending_approvals": [{"approval_id": "x"}]}
assert route_approval_stage(state) == "sandbox_pending"
def test_prod_pending_from_current_stage(self) -> None:
state = {"current_stage": "prod_pending", "pending_approvals": [{"approval_id": "x"}]}
assert route_approval_stage(state) == "prod_pending"
# ---------------------------------------------------------------------------
# route_after_fetch
# ---------------------------------------------------------------------------
class TestRouteAfterFetch:
"""Tests for route_after_fetch — 3-way routing replacing is_pr_already_merged."""
def test_returns_merged_when_pr_already_merged(self) -> None:
state = {"pr_already_merged": True}
assert route_after_fetch(state) == "merged"
def test_returns_active_with_ticket_when_active_and_has_ticket(self) -> None:
state = {"pr_already_merged": False, "has_ticket": True}
assert route_after_fetch(state) == "active_with_ticket"
def test_returns_active_no_ticket_when_active_and_no_ticket(self) -> None:
state = {"pr_already_merged": False, "has_ticket": False}
assert route_after_fetch(state) == "active_no_ticket"
def test_returns_active_no_ticket_when_has_ticket_missing(self) -> None:
state = {"pr_already_merged": False}
assert route_after_fetch(state) == "active_no_ticket"
def test_returns_active_no_ticket_when_has_ticket_none(self) -> None:
state = {"pr_already_merged": False, "has_ticket": None}
assert route_after_fetch(state) == "active_no_ticket"
def test_returns_active_no_ticket_when_all_fields_missing(self) -> None:
state = {}
assert route_after_fetch(state) == "active_no_ticket"
def test_merged_takes_precedence_over_has_ticket(self) -> None:
# Even if has_ticket is True, merged PR should route to "merged"
state = {"pr_already_merged": True, "has_ticket": True}
assert route_after_fetch(state) == "merged"
def test_returns_active_with_ticket_ignores_merged_false(self) -> None:
state = {"pr_already_merged": False, "has_ticket": True}
result = route_after_fetch(state)
assert result != "merged"
assert result == "active_with_ticket"

View File

View File

@@ -0,0 +1,332 @@
"""Tests for scripts/migrate_json_to_db.py.
Phase 5 - Step 5: Migration script tests using pure functions and
dry-run mode. No real database required.
Written FIRST (TDD RED phase).
"""
import json
from datetime import date
from pathlib import Path
import pytest
from scripts.migrate_json_to_db import (
collect_json_files,
parse_staging_json,
parse_archived_json,
build_staging_insert_sql,
build_archived_insert_sql,
is_archived_filename,
is_staging_filename,
MigrationRecord,
)
# ---------------------------------------------------------------------------
# Fixture data (mirrors real JSON structure from release-workflow/releases/)
# ---------------------------------------------------------------------------
STAGING_JSON = {
"version": "v1.0.0",
"repo": "Billo.Platform.Document",
"started_at": "2026-03-17",
"tickets": [
{
"id": "ALLPOST-4219",
"summary": "Test release bot",
"pr_id": "10460",
"pr_url": "https://dev.azure.com/billodev/Billo%20App%20Platform/_git/Billo.Platform.Document/pullrequest/10460",
"pr_title": "chore: trigger release bot test",
"branch": "feature_ALLPOST-4219_test_release_bot",
"merged_at": "2026-03-17",
}
],
}
PAYMENT_STAGING_JSON = {
"version": "v1.0.1",
"repo": "Billo.Platform.Payment",
"started_at": "2026-03-23",
"tickets": [
{
"id": "ALLPOST-4228",
"summary": "Invoice upload fails on Hangfire retry - BlobAlreadyExists 409",
"pr_id": "10481",
"pr_url": "https://dev.azure.com/billodev/Billo%20App%20Platform/_git/Billo.Platform.Payment/pullrequest/10481",
"pr_title": "Invoice upload fails on Hangfire retry - BlobAlreadyExists 409",
"branch": "bug/ALLPOST-4228_fix-invoice-upload-blob-already-exists",
"merged_at": "2026-03-23",
}
],
}
# Archived JSON has an additional released_at field
ARCHIVED_JSON = {
"version": "v1.0.0",
"repo": "Billo.Platform.Payment",
"started_at": "2026-01-01",
"tickets": [],
"released_at": "2026-01-15",
}
# ---------------------------------------------------------------------------
# is_staging_filename / is_archived_filename
# ---------------------------------------------------------------------------
class TestFileNameClassification:
def test_staging_filename_identified(self) -> None:
assert is_staging_filename("Billo.Platform.Document.json") is True
def test_archived_filename_identified(self) -> None:
assert is_archived_filename("Billo.Platform.Payment_v1.0.1_2026-03-23.json") is True
def test_staging_filename_not_archived(self) -> None:
assert is_archived_filename("Billo.Platform.Document.json") is False
def test_archived_filename_not_staging(self) -> None:
assert is_staging_filename("Billo.Platform.Payment_v1.0.1_2026-03-23.json") is False
def test_non_json_file_is_not_staging(self) -> None:
assert is_staging_filename("README.md") is False
def test_non_json_file_is_not_archived(self) -> None:
assert is_archived_filename("README.md") is False
# ---------------------------------------------------------------------------
# parse_staging_json
# ---------------------------------------------------------------------------
class TestParseStagingJson:
def test_parse_returns_migration_record(self) -> None:
record = parse_staging_json(STAGING_JSON)
assert isinstance(record, MigrationRecord)
def test_parse_extracts_repo(self) -> None:
record = parse_staging_json(STAGING_JSON)
assert record.repo == "Billo.Platform.Document"
def test_parse_extracts_version(self) -> None:
record = parse_staging_json(STAGING_JSON)
assert record.version == "v1.0.0"
def test_parse_extracts_started_at(self) -> None:
record = parse_staging_json(STAGING_JSON)
assert record.started_at == date(2026, 3, 17)
def test_parse_extracts_tickets(self) -> None:
record = parse_staging_json(STAGING_JSON)
assert len(record.tickets) == 1
assert record.tickets[0]["id"] == "ALLPOST-4219"
def test_parse_staging_has_no_released_at(self) -> None:
record = parse_staging_json(STAGING_JSON)
assert record.released_at is None
def test_parse_staging_with_multiple_tickets(self) -> None:
data = {
**STAGING_JSON,
"tickets": [
{**STAGING_JSON["tickets"][0], "id": "ALLPOST-1"},
{**STAGING_JSON["tickets"][0], "id": "ALLPOST-2"},
],
}
record = parse_staging_json(data)
assert len(record.tickets) == 2
def test_parse_staging_with_empty_tickets(self) -> None:
data = {**STAGING_JSON, "tickets": []}
record = parse_staging_json(data)
assert record.tickets == []
# ---------------------------------------------------------------------------
# parse_archived_json
# ---------------------------------------------------------------------------
class TestParseArchivedJson:
def test_parse_returns_migration_record(self) -> None:
record = parse_archived_json(ARCHIVED_JSON)
assert isinstance(record, MigrationRecord)
def test_parse_extracts_released_at(self) -> None:
record = parse_archived_json(ARCHIVED_JSON)
assert record.released_at == date(2026, 1, 15)
def test_parse_extracts_repo(self) -> None:
record = parse_archived_json(ARCHIVED_JSON)
assert record.repo == "Billo.Platform.Payment"
def test_parse_extracts_version(self) -> None:
record = parse_archived_json(ARCHIVED_JSON)
assert record.version == "v1.0.0"
# ---------------------------------------------------------------------------
# build_staging_insert_sql
# ---------------------------------------------------------------------------
class TestBuildStagingInsertSql:
def test_returns_tuple_of_sql_and_params(self) -> None:
record = parse_staging_json(STAGING_JSON)
sql, params = build_staging_insert_sql(record)
assert isinstance(sql, str)
assert isinstance(params, tuple)
def test_sql_inserts_into_staging_releases(self) -> None:
record = parse_staging_json(STAGING_JSON)
sql, _ = build_staging_insert_sql(record)
assert "staging_releases" in sql
def test_sql_is_insert_statement(self) -> None:
record = parse_staging_json(STAGING_JSON)
sql, _ = build_staging_insert_sql(record)
assert "INSERT" in sql.upper()
def test_params_include_repo(self) -> None:
record = parse_staging_json(STAGING_JSON)
_, params = build_staging_insert_sql(record)
assert "Billo.Platform.Document" in params
def test_params_include_version(self) -> None:
record = parse_staging_json(STAGING_JSON)
_, params = build_staging_insert_sql(record)
assert "v1.0.0" in params
def test_params_include_started_at(self) -> None:
record = parse_staging_json(STAGING_JSON)
_, params = build_staging_insert_sql(record)
assert "2026-03-17" in params
def test_params_include_tickets_json(self) -> None:
record = parse_staging_json(STAGING_JSON)
_, params = build_staging_insert_sql(record)
# tickets should be serialized as JSON string
tickets_json = next(p for p in params if isinstance(p, str) and "ALLPOST-4219" in p)
parsed = json.loads(tickets_json)
assert parsed[0]["id"] == "ALLPOST-4219"
def test_sql_uses_on_conflict_do_nothing_or_update(self) -> None:
record = parse_staging_json(STAGING_JSON)
sql, _ = build_staging_insert_sql(record)
assert "ON CONFLICT" in sql.upper() or "INSERT" in sql.upper()
# ---------------------------------------------------------------------------
# build_archived_insert_sql
# ---------------------------------------------------------------------------
class TestBuildArchivedInsertSql:
def test_returns_tuple_of_sql_and_params(self) -> None:
record = parse_archived_json(ARCHIVED_JSON)
sql, params = build_archived_insert_sql(record)
assert isinstance(sql, str)
assert isinstance(params, tuple)
def test_sql_inserts_into_archived_releases(self) -> None:
record = parse_archived_json(ARCHIVED_JSON)
sql, _ = build_archived_insert_sql(record)
assert "archived_releases" in sql
def test_sql_is_insert_statement(self) -> None:
record = parse_archived_json(ARCHIVED_JSON)
sql, _ = build_archived_insert_sql(record)
assert "INSERT" in sql.upper()
def test_params_include_released_at(self) -> None:
record = parse_archived_json(ARCHIVED_JSON)
_, params = build_archived_insert_sql(record)
assert "2026-01-15" in params
def test_params_include_repo(self) -> None:
record = parse_archived_json(ARCHIVED_JSON)
_, params = build_archived_insert_sql(record)
assert "Billo.Platform.Payment" in params
# ---------------------------------------------------------------------------
# collect_json_files
# ---------------------------------------------------------------------------
class TestCollectJsonFiles:
def test_returns_empty_list_for_empty_directory(self, tmp_path: Path) -> None:
result = collect_json_files(tmp_path)
assert result == []
def test_finds_staging_json_files(self, tmp_path: Path) -> None:
(tmp_path / "my-repo.json").write_text(json.dumps(STAGING_JSON))
result = collect_json_files(tmp_path)
assert len(result) == 1
def test_finds_archived_json_files(self, tmp_path: Path) -> None:
(tmp_path / "my-repo_v1.0.0_2025-06-01.json").write_text(
json.dumps(ARCHIVED_JSON)
)
result = collect_json_files(tmp_path)
assert len(result) == 1
def test_ignores_non_json_files(self, tmp_path: Path) -> None:
(tmp_path / "README.md").write_text("readme")
(tmp_path / "my-repo.json").write_text(json.dumps(STAGING_JSON))
result = collect_json_files(tmp_path)
assert len(result) == 1
def test_collects_from_nested_directories(self, tmp_path: Path) -> None:
repo_dir = tmp_path / "Billo.Platform.Document"
repo_dir.mkdir()
(repo_dir / "v1.0.0.json").write_text(json.dumps(STAGING_JSON))
result = collect_json_files(tmp_path)
assert len(result) == 1
def test_returns_path_objects(self, tmp_path: Path) -> None:
(tmp_path / "my-repo.json").write_text(json.dumps(STAGING_JSON))
result = collect_json_files(tmp_path)
assert all(isinstance(p, Path) for p in result)
def test_collects_multiple_files(self, tmp_path: Path) -> None:
for i in range(3):
(tmp_path / f"repo-{i}.json").write_text(json.dumps(STAGING_JSON))
result = collect_json_files(tmp_path)
assert len(result) == 3
# ---------------------------------------------------------------------------
# Dry-run mode (integration of pure functions)
# ---------------------------------------------------------------------------
class TestDryRunMode:
def test_dry_run_collects_records_without_db_access(self, tmp_path: Path) -> None:
"""Dry run processes files and returns SQL/params without executing."""
repo_dir = tmp_path / "Billo.Platform.Document"
repo_dir.mkdir()
(repo_dir / "v1.0.0.json").write_text(json.dumps(STAGING_JSON))
files = collect_json_files(tmp_path)
assert len(files) == 1
# Parse and build SQL — no DB connection needed
record = parse_staging_json(json.loads(files[0].read_text()))
sql, params = build_staging_insert_sql(record)
assert "INSERT" in sql.upper()
assert "Billo.Platform.Document" in params
def test_payment_staging_file_parses_correctly(self, tmp_path: Path) -> None:
(tmp_path / "Billo.Platform.Payment.json").write_text(
json.dumps(PAYMENT_STAGING_JSON)
)
files = collect_json_files(tmp_path)
record = parse_staging_json(json.loads(files[0].read_text()))
assert record.repo == "Billo.Platform.Payment"
assert record.version == "v1.0.1"
assert len(record.tickets) == 1
assert record.tickets[0]["id"] == "ALLPOST-4228"
def test_archived_file_parses_correctly(self, tmp_path: Path) -> None:
(tmp_path / "my-repo_v1.0.0_2026-01-15.json").write_text(
json.dumps(ARCHIVED_JSON)
)
files = collect_json_files(tmp_path)
record = parse_archived_json(json.loads(files[0].read_text()))
assert record.released_at == date(2026, 1, 15)

View File

View File

@@ -0,0 +1,141 @@
"""Tests for services/pr_dedup.py. Written FIRST (TDD RED phase).
find_unprocessed_prs queries agent_threads to find which PRs have not yet
been processed (no existing thread for that repo+pr_id combination).
"""
import pytest
from release_agent.models.pr import PRInfo
from release_agent.services.pr_dedup import find_unprocessed_prs
# ---------------------------------------------------------------------------
# Helpers — fake async pool
# ---------------------------------------------------------------------------
def _make_pr(pr_id: str, repo_name: str = "my-repo") -> PRInfo:
return PRInfo(
pr_id=pr_id,
pr_url=f"https://dev.azure.com/org/proj/_git/{repo_name}/pullrequest/{pr_id}",
repo_name=repo_name,
branch="refs/heads/feature/ALLPOST-100-fix",
pr_title=f"PR {pr_id}",
pr_status="active",
)
def _make_pool(existing_rows: list[tuple[str, str]]):
"""Return a fake async connection pool.
existing_rows: list of (pr_id, repo_name) tuples representing already-processed PRs.
"""
class FakeCursor:
def __init__(self, rows):
self._rows = rows
async def execute(self, sql, params=None):
pass
async def fetchall(self):
return self._rows
async def __aenter__(self):
return self
async def __aexit__(self, *args):
pass
class FakeConn:
def __init__(self, rows):
self._rows = rows
def cursor(self):
return FakeCursor(self._rows)
async def __aenter__(self):
return self
async def __aexit__(self, *args):
pass
class FakePool:
def __init__(self, rows):
self._rows = rows
def connection(self):
return FakeConn(self._rows)
return FakePool(existing_rows)
# ---------------------------------------------------------------------------
# find_unprocessed_prs tests
# ---------------------------------------------------------------------------
class TestFindUnprocessedPrs:
async def test_returns_all_when_none_processed(self) -> None:
prs = [_make_pr("10"), _make_pr("20")]
pool = _make_pool([])
result = await find_unprocessed_prs(pool, prs)
assert len(result) == 2
async def test_returns_empty_when_all_processed(self) -> None:
prs = [_make_pr("10"), _make_pr("20")]
# existing rows: (pr_id, repo_name)
pool = _make_pool([("10", "my-repo"), ("20", "my-repo")])
result = await find_unprocessed_prs(pool, prs)
assert result == []
async def test_returns_only_unprocessed(self) -> None:
prs = [_make_pr("10"), _make_pr("20"), _make_pr("30")]
pool = _make_pool([("10", "my-repo")])
result = await find_unprocessed_prs(pool, prs)
pr_ids = [p.pr_id for p in result]
assert "10" not in pr_ids
assert "20" in pr_ids
assert "30" in pr_ids
async def test_empty_input_returns_empty(self) -> None:
pool = _make_pool([])
result = await find_unprocessed_prs(pool, [])
assert result == []
async def test_different_repos_not_confused(self) -> None:
pr_repo_a = _make_pr("10", repo_name="repo-a")
pr_repo_b = _make_pr("10", repo_name="repo-b")
# Only repo-a/10 is processed
pool = _make_pool([("10", "repo-a")])
result = await find_unprocessed_prs(pool, [pr_repo_a, pr_repo_b])
# repo-b/10 should still be returned (different repo)
assert len(result) == 1
assert result[0].repo_name == "repo-b"
async def test_returns_list_of_pr_info(self) -> None:
prs = [_make_pr("42")]
pool = _make_pool([])
result = await find_unprocessed_prs(pool, prs)
assert all(isinstance(p, PRInfo) for p in result)
async def test_preserves_pr_info_objects(self) -> None:
pr = _make_pr("77")
pool = _make_pool([])
result = await find_unprocessed_prs(pool, [pr])
assert result[0].pr_id == "77"
assert result[0].repo_name == "my-repo"

View File

@@ -0,0 +1,309 @@
"""Tests for services/pr_poller.py. Written FIRST (TDD RED phase).
Tests verify:
- _synthesize_webhook_payload produces a valid payload dict
- run_pr_poll_loop calls list_active_prs, dedup, then schedules graph for each unprocessed PR
- Fake sleep is injected to avoid real waits
"""
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from release_agent.models.pr import PRInfo
from release_agent.services.pr_poller import _synthesize_webhook_payload, run_pr_poll_loop
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_pr(
pr_id: str = "10",
repo_name: str = "my-repo",
branch: str = "refs/heads/feature/ALLPOST-100-fix",
title: str = "Test PR",
status: str = "active",
) -> PRInfo:
return PRInfo(
pr_id=pr_id,
pr_url=f"https://dev.azure.com/org/proj/_git/{repo_name}/pullrequest/{pr_id}",
repo_name=repo_name,
branch=branch,
pr_title=title,
pr_status=status,
)
# ---------------------------------------------------------------------------
# _synthesize_webhook_payload tests
# ---------------------------------------------------------------------------
class TestSynthesizeWebhookPayload:
def test_returns_dict(self) -> None:
pr = _make_pr()
result = _synthesize_webhook_payload(pr)
assert isinstance(result, dict)
def test_has_resource_key(self) -> None:
pr = _make_pr()
result = _synthesize_webhook_payload(pr)
assert "resource" in result
def test_resource_contains_pull_request_id(self) -> None:
pr = _make_pr(pr_id="42")
result = _synthesize_webhook_payload(pr)
assert result["resource"]["pull_request_id"] == 42
def test_resource_contains_repository_name(self) -> None:
pr = _make_pr(repo_name="backend-api")
result = _synthesize_webhook_payload(pr)
assert result["resource"]["repository"]["name"] == "backend-api"
def test_resource_contains_title(self) -> None:
pr = _make_pr(title="My PR Title")
result = _synthesize_webhook_payload(pr)
assert result["resource"]["title"] == "My PR Title"
def test_resource_contains_source_ref_name(self) -> None:
pr = _make_pr(branch="refs/heads/feature/ALLPOST-200-test")
result = _synthesize_webhook_payload(pr)
assert result["resource"]["source_ref_name"] == "refs/heads/feature/ALLPOST-200-test"
def test_resource_status_is_active(self) -> None:
pr = _make_pr(status="active")
result = _synthesize_webhook_payload(pr)
assert result["resource"]["status"] == "active"
def test_event_type_is_pr_updated(self) -> None:
pr = _make_pr()
result = _synthesize_webhook_payload(pr)
assert "event_type" in result
def test_subscription_id_present(self) -> None:
pr = _make_pr()
result = _synthesize_webhook_payload(pr)
assert "subscription_id" in result
def test_different_prs_produce_different_payloads(self) -> None:
pr1 = _make_pr(pr_id="1")
pr2 = _make_pr(pr_id="2")
r1 = _synthesize_webhook_payload(pr1)
r2 = _synthesize_webhook_payload(pr2)
assert r1["resource"]["pull_request_id"] != r2["resource"]["pull_request_id"]
# ---------------------------------------------------------------------------
# run_pr_poll_loop tests
# ---------------------------------------------------------------------------
class TestRunPrPollLoop:
async def test_calls_list_active_prs_for_each_repo(self) -> None:
azdo = AsyncMock()
azdo.list_active_prs = AsyncMock(return_value=[])
sleep_calls: list[float] = []
async def fake_sleep(seconds: float) -> None:
sleep_calls.append(seconds)
raise asyncio.CancelledError
with patch("release_agent.services.pr_poller.find_unprocessed_prs", new=AsyncMock(return_value=[])):
with pytest.raises(asyncio.CancelledError):
await run_pr_poll_loop(
azdo_client=azdo,
db_pool=MagicMock(),
watched_repos=["repo-a", "repo-b"],
target_branch="refs/heads/develop",
interval_seconds=30,
schedule_fn=MagicMock(),
sleep_fn=fake_sleep,
)
assert azdo.list_active_prs.call_count == 2
async def test_calls_find_unprocessed_prs(self) -> None:
pr = _make_pr(pr_id="10")
azdo = AsyncMock()
azdo.list_active_prs = AsyncMock(return_value=[pr])
find_mock = AsyncMock(return_value=[])
async def fake_sleep(seconds: float) -> None:
raise asyncio.CancelledError
with patch("release_agent.services.pr_poller.find_unprocessed_prs", new=find_mock):
with pytest.raises(asyncio.CancelledError):
await run_pr_poll_loop(
azdo_client=azdo,
db_pool=MagicMock(),
watched_repos=["my-repo"],
target_branch="refs/heads/develop",
interval_seconds=30,
schedule_fn=MagicMock(),
sleep_fn=fake_sleep,
)
find_mock.assert_called_once()
async def test_schedules_graph_for_each_unprocessed_pr(self) -> None:
pr1 = _make_pr(pr_id="10")
pr2 = _make_pr(pr_id="20")
azdo = AsyncMock()
azdo.list_active_prs = AsyncMock(return_value=[pr1, pr2])
schedule_mock = MagicMock()
async def fake_sleep(seconds: float) -> None:
raise asyncio.CancelledError
with patch(
"release_agent.services.pr_poller.find_unprocessed_prs",
new=AsyncMock(return_value=[pr1, pr2]),
):
with pytest.raises(asyncio.CancelledError):
await run_pr_poll_loop(
azdo_client=azdo,
db_pool=MagicMock(),
watched_repos=["my-repo"],
target_branch="refs/heads/develop",
interval_seconds=30,
schedule_fn=schedule_mock,
sleep_fn=fake_sleep,
)
assert schedule_mock.call_count == 2
async def test_does_not_schedule_already_processed_prs(self) -> None:
pr = _make_pr(pr_id="10")
azdo = AsyncMock()
azdo.list_active_prs = AsyncMock(return_value=[pr])
schedule_mock = MagicMock()
async def fake_sleep(seconds: float) -> None:
raise asyncio.CancelledError
# All PRs already processed
with patch(
"release_agent.services.pr_poller.find_unprocessed_prs",
new=AsyncMock(return_value=[]),
):
with pytest.raises(asyncio.CancelledError):
await run_pr_poll_loop(
azdo_client=azdo,
db_pool=MagicMock(),
watched_repos=["my-repo"],
target_branch="refs/heads/develop",
interval_seconds=30,
schedule_fn=schedule_mock,
sleep_fn=fake_sleep,
)
schedule_mock.assert_not_called()
async def test_sleeps_for_configured_interval(self) -> None:
azdo = AsyncMock()
azdo.list_active_prs = AsyncMock(return_value=[])
sleep_calls: list[float] = []
async def fake_sleep(seconds: float) -> None:
sleep_calls.append(seconds)
raise asyncio.CancelledError
with patch("release_agent.services.pr_poller.find_unprocessed_prs", new=AsyncMock(return_value=[])):
with pytest.raises(asyncio.CancelledError):
await run_pr_poll_loop(
azdo_client=azdo,
db_pool=MagicMock(),
watched_repos=["my-repo"],
target_branch="refs/heads/develop",
interval_seconds=123,
schedule_fn=MagicMock(),
sleep_fn=fake_sleep,
)
assert sleep_calls[0] == 123
async def test_handles_empty_watched_repos(self) -> None:
azdo = AsyncMock()
async def fake_sleep(seconds: float) -> None:
raise asyncio.CancelledError
with patch("release_agent.services.pr_poller.find_unprocessed_prs", new=AsyncMock(return_value=[])):
with pytest.raises(asyncio.CancelledError):
await run_pr_poll_loop(
azdo_client=azdo,
db_pool=MagicMock(),
watched_repos=[],
target_branch="refs/heads/develop",
interval_seconds=30,
schedule_fn=MagicMock(),
sleep_fn=fake_sleep,
)
azdo.list_active_prs.assert_not_called()
async def test_schedule_fn_receives_synthesized_payload(self) -> None:
pr = _make_pr(pr_id="55", repo_name="test-repo")
azdo = AsyncMock()
azdo.list_active_prs = AsyncMock(return_value=[pr])
schedule_calls: list[dict] = []
def schedule_mock(**kwargs) -> None:
schedule_calls.append(kwargs)
async def fake_sleep(seconds: float) -> None:
raise asyncio.CancelledError
with patch(
"release_agent.services.pr_poller.find_unprocessed_prs",
new=AsyncMock(return_value=[pr]),
):
with pytest.raises(asyncio.CancelledError):
await run_pr_poll_loop(
azdo_client=azdo,
db_pool=MagicMock(),
watched_repos=["test-repo"],
target_branch="refs/heads/develop",
interval_seconds=30,
schedule_fn=schedule_mock,
sleep_fn=fake_sleep,
)
assert len(schedule_calls) == 1
initial_state = schedule_calls[0]["initial_state"]
assert initial_state["webhook_payload"]["resource"]["pull_request_id"] == 55
assert initial_state["pr_id"] == "55"
assert initial_state["repo_name"] == "test-repo"
async def test_continues_after_list_active_prs_error(self) -> None:
azdo = AsyncMock()
# First repo raises, second succeeds
azdo.list_active_prs = AsyncMock(side_effect=[Exception("API error"), []])
sleep_calls: list[float] = []
async def fake_sleep(seconds: float) -> None:
sleep_calls.append(seconds)
raise asyncio.CancelledError
with patch("release_agent.services.pr_poller.find_unprocessed_prs", new=AsyncMock(return_value=[])):
with pytest.raises(asyncio.CancelledError):
await run_pr_poll_loop(
azdo_client=azdo,
db_pool=MagicMock(),
watched_repos=["repo-a", "repo-b"],
target_branch="refs/heads/develop",
interval_seconds=30,
schedule_fn=MagicMock(),
sleep_fn=fake_sleep,
)
# Should still sleep (loop iteration completed despite error)
assert len(sleep_calls) == 1

122
tests/test_branch_parser.py Normal file
View File

@@ -0,0 +1,122 @@
"""Tests for branch_parser module. Written FIRST (TDD RED phase)."""
from release_agent.branch_parser import parse_branch, strip_refs_prefix
class TestStripRefsPrefix:
"""Tests for strip_refs_prefix function."""
def test_strips_refs_heads_prefix(self) -> None:
assert strip_refs_prefix("refs/heads/fix/BILL-42_something") == "fix/BILL-42_something"
def test_strips_refs_heads_prefix_feature(self) -> None:
assert strip_refs_prefix("refs/heads/feature/ALLPOST-100_add-feature") == "feature/ALLPOST-100_add-feature"
def test_no_refs_prefix_unchanged(self) -> None:
assert strip_refs_prefix("bug/ALLPOST-4229_fix-review") == "bug/ALLPOST-4229_fix-review"
def test_main_unchanged(self) -> None:
assert strip_refs_prefix("main") == "main"
def test_develop_unchanged(self) -> None:
assert strip_refs_prefix("develop") == "develop"
def test_empty_string(self) -> None:
assert strip_refs_prefix("") == ""
def test_only_refs_heads(self) -> None:
assert strip_refs_prefix("refs/heads/") == ""
def test_refs_tags_not_stripped(self) -> None:
assert strip_refs_prefix("refs/tags/v1.0.0") == "refs/tags/v1.0.0"
class TestParseBranch:
"""Tests for parse_branch function."""
def test_bug_branch_with_ticket(self) -> None:
ticket_id, has_ticket = parse_branch("bug/ALLPOST-4229_fix-review")
assert ticket_id == "ALLPOST-4229"
assert has_ticket is True
def test_feature_branch_with_ticket(self) -> None:
ticket_id, has_ticket = parse_branch("feature/ALLPOST-100_add-feature")
assert ticket_id == "ALLPOST-100"
assert has_ticket is True
def test_refs_heads_fix_branch(self) -> None:
ticket_id, has_ticket = parse_branch("refs/heads/fix/BILL-42_something")
assert ticket_id == "BILL-42"
assert has_ticket is True
def test_feat_branch_short(self) -> None:
ticket_id, has_ticket = parse_branch("feat/MY-1_x")
assert ticket_id == "MY-1"
assert has_ticket is True
def test_chore_without_ticket(self) -> None:
ticket_id, has_ticket = parse_branch("chore/update-dependencies")
assert ticket_id is None
assert has_ticket is False
def test_main_branch(self) -> None:
ticket_id, has_ticket = parse_branch("main")
assert ticket_id is None
assert has_ticket is False
def test_develop_branch(self) -> None:
ticket_id, has_ticket = parse_branch("develop")
assert ticket_id is None
assert has_ticket is False
def test_release_branch(self) -> None:
ticket_id, has_ticket = parse_branch("release/v1.0.3")
assert ticket_id is None
assert has_ticket is False
def test_returns_tuple(self) -> None:
result = parse_branch("main")
assert isinstance(result, tuple)
assert len(result) == 2
def test_ticket_id_type_when_present(self) -> None:
ticket_id, has_ticket = parse_branch("bug/ALLPOST-4229_fix-review")
assert isinstance(ticket_id, str)
assert isinstance(has_ticket, bool)
def test_ticket_id_type_when_absent(self) -> None:
ticket_id, has_ticket = parse_branch("main")
assert ticket_id is None
assert isinstance(has_ticket, bool)
def test_fix_prefix(self) -> None:
ticket_id, has_ticket = parse_branch("fix/PROJ-999_some-fix")
assert ticket_id == "PROJ-999"
assert has_ticket is True
def test_refs_heads_feature_branch(self) -> None:
ticket_id, has_ticket = parse_branch("refs/heads/feature/ALLPOST-100_add-feature")
assert ticket_id == "ALLPOST-100"
assert has_ticket is True
def test_ticket_with_multiple_digits(self) -> None:
ticket_id, has_ticket = parse_branch("feature/ABC-12345_some-long-feature")
assert ticket_id == "ABC-12345"
assert has_ticket is True
def test_branch_without_underscore_separator(self) -> None:
# Branch has ticket pattern but no underscore - still detects ticket
ticket_id, has_ticket = parse_branch("feature/PROJ-100")
assert ticket_id == "PROJ-100"
assert has_ticket is True
def test_empty_string(self) -> None:
ticket_id, has_ticket = parse_branch("")
assert ticket_id is None
assert has_ticket is False
def test_ticket_with_numeric_project_prefix(self) -> None:
ticket_id, has_ticket = parse_branch("feature/AB2-100_feature")
assert ticket_id == "AB2-100"
assert has_ticket is True

450
tests/test_config.py Normal file
View File

@@ -0,0 +1,450 @@
"""Tests for config module. Written FIRST (TDD RED phase)."""
import os
from unittest.mock import patch
import pytest
from pydantic import SecretStr, ValidationError
from release_agent.config import Settings
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _base_env() -> dict[str, str]:
"""Return minimal valid environment variables."""
return {
"AZDO_ORGANIZATION": "my-org",
"AZDO_PROJECT": "my-project",
"AZDO_PAT": "super-secret-pat",
"ANTHROPIC_API_KEY": "sk-ant-key",
"POSTGRES_DSN": "postgresql://user:pass@localhost:5432/db",
"JIRA_EMAIL": "user@example.com",
"JIRA_API_TOKEN": "jira-token-abc",
"SLACK_WEBHOOK_URL": "https://hooks.slack.com/services/T000/B000/xxxx",
"WEBHOOK_SECRET": "test-webhook-secret",
}
# ---------------------------------------------------------------------------
# Settings tests
# ---------------------------------------------------------------------------
class TestSettings:
"""Tests for Settings config class."""
def test_loads_from_env_vars(self) -> None:
with patch.dict(os.environ, _base_env(), clear=True):
settings = Settings()
assert settings.azdo_organization == "my-org"
assert settings.azdo_project == "my-project"
def test_pat_is_secret_str(self) -> None:
with patch.dict(os.environ, _base_env(), clear=True):
settings = Settings()
assert isinstance(settings.azdo_pat, SecretStr)
def test_anthropic_key_is_secret_str(self) -> None:
with patch.dict(os.environ, _base_env(), clear=True):
settings = Settings()
assert isinstance(settings.anthropic_api_key, SecretStr)
def test_secret_str_not_leaked_in_repr(self) -> None:
with patch.dict(os.environ, _base_env(), clear=True):
settings = Settings()
repr_str = repr(settings)
assert "super-secret-pat" not in repr_str
assert "sk-ant-key" not in repr_str
def test_secret_str_not_leaked_in_str(self) -> None:
with patch.dict(os.environ, _base_env(), clear=True):
settings = Settings()
str_repr = str(settings)
assert "super-secret-pat" not in str_repr
def test_missing_required_azdo_org_raises(self) -> None:
env = _base_env()
del env["AZDO_ORGANIZATION"]
with patch.dict(os.environ, env, clear=True), pytest.raises(ValidationError):
Settings()
def test_missing_required_azdo_project_raises(self) -> None:
env = _base_env()
del env["AZDO_PROJECT"]
with patch.dict(os.environ, env, clear=True), pytest.raises(ValidationError):
Settings()
def test_missing_required_pat_raises(self) -> None:
env = _base_env()
del env["AZDO_PAT"]
with patch.dict(os.environ, env, clear=True), pytest.raises(ValidationError):
Settings()
def test_missing_anthropic_key_is_optional(self) -> None:
env = _base_env()
del env["ANTHROPIC_API_KEY"]
with patch.dict(os.environ, env, clear=True):
settings = Settings()
assert settings.anthropic_api_key.get_secret_value() == ""
def test_missing_postgres_dsn_raises(self) -> None:
env = _base_env()
del env["POSTGRES_DSN"]
with patch.dict(os.environ, env, clear=True), pytest.raises(ValidationError):
Settings()
def test_azdo_base_url_computed(self) -> None:
with patch.dict(os.environ, _base_env(), clear=True):
settings = Settings()
expected = "https://dev.azure.com/my-org"
assert settings.azdo_base_url == expected
def test_azdo_api_url_computed(self) -> None:
with patch.dict(os.environ, _base_env(), clear=True):
settings = Settings()
expected = "https://dev.azure.com/my-org/my-project/_apis"
assert settings.azdo_api_url == expected
def test_default_port(self) -> None:
with patch.dict(os.environ, _base_env(), clear=True):
settings = Settings()
assert settings.port == 8000
def test_custom_port_from_env(self) -> None:
env = {**_base_env(), "PORT": "9000"}
with patch.dict(os.environ, env, clear=True):
settings = Settings()
assert settings.port == 9000
def test_port_below_minimum_raises(self) -> None:
env = {**_base_env(), "PORT": "0"}
with patch.dict(os.environ, env, clear=True), pytest.raises(ValidationError):
Settings()
def test_port_above_maximum_raises(self) -> None:
env = {**_base_env(), "PORT": "65536"}
with patch.dict(os.environ, env, clear=True), pytest.raises(ValidationError):
Settings()
def test_port_minimum_valid(self) -> None:
env = {**_base_env(), "PORT": "1"}
with patch.dict(os.environ, env, clear=True):
settings = Settings()
assert settings.port == 1
def test_port_maximum_valid(self) -> None:
env = {**_base_env(), "PORT": "65535"}
with patch.dict(os.environ, env, clear=True):
settings = Settings()
assert settings.port == 65535
def test_get_pat_value(self) -> None:
with patch.dict(os.environ, _base_env(), clear=True):
settings = Settings()
assert settings.azdo_pat.get_secret_value() == "super-secret-pat"
def test_get_anthropic_key_value(self) -> None:
with patch.dict(os.environ, _base_env(), clear=True):
settings = Settings()
assert settings.anthropic_api_key.get_secret_value() == "sk-ant-key"
def test_postgres_dsn_is_secret_str(self) -> None:
with patch.dict(os.environ, _base_env(), clear=True):
settings = Settings()
assert isinstance(settings.postgres_dsn, SecretStr)
assert "localhost" in settings.postgres_dsn.get_secret_value()
def test_postgres_dsn_not_leaked_in_repr(self) -> None:
with patch.dict(os.environ, _base_env(), clear=True):
settings = Settings()
assert "user:pass" not in repr(settings)
class TestSettingsPhase2:
"""Tests for Phase 2 settings fields."""
def test_jira_base_url_default(self) -> None:
with patch.dict(os.environ, _base_env(), clear=True):
settings = Settings()
assert settings.jira_base_url == "https://billolife.atlassian.net"
def test_jira_base_url_custom(self) -> None:
env = {**_base_env(), "JIRA_BASE_URL": "https://custom.atlassian.net"}
with patch.dict(os.environ, env, clear=True):
settings = Settings()
assert settings.jira_base_url == "https://custom.atlassian.net"
def test_jira_email_required(self) -> None:
env = _base_env()
del env["JIRA_EMAIL"]
with patch.dict(os.environ, env, clear=True), pytest.raises(ValidationError):
Settings()
def test_jira_email_stored(self) -> None:
with patch.dict(os.environ, _base_env(), clear=True):
settings = Settings()
assert settings.jira_email == "user@example.com"
def test_jira_api_token_required(self) -> None:
env = _base_env()
del env["JIRA_API_TOKEN"]
with patch.dict(os.environ, env, clear=True), pytest.raises(ValidationError):
Settings()
def test_jira_api_token_is_secret_str(self) -> None:
with patch.dict(os.environ, _base_env(), clear=True):
settings = Settings()
assert isinstance(settings.jira_api_token, SecretStr)
def test_jira_api_token_not_leaked_in_repr(self) -> None:
with patch.dict(os.environ, _base_env(), clear=True):
settings = Settings()
assert "jira-token-abc" not in repr(settings)
def test_slack_webhook_url_optional_defaults_empty(self) -> None:
env = _base_env()
del env["SLACK_WEBHOOK_URL"]
with patch.dict(os.environ, env, clear=True):
settings = Settings()
assert settings.slack_webhook_url.get_secret_value() == ""
def test_slack_webhook_url_is_secret_str(self) -> None:
with patch.dict(os.environ, _base_env(), clear=True):
settings = Settings()
assert isinstance(settings.slack_webhook_url, SecretStr)
def test_slack_webhook_url_not_leaked_in_repr(self) -> None:
with patch.dict(os.environ, _base_env(), clear=True):
settings = Settings()
assert "xxxx" not in repr(settings)
def test_claude_review_model_default(self) -> None:
with patch.dict(os.environ, _base_env(), clear=True):
settings = Settings()
assert settings.claude_review_model == "claude-sonnet-4-20250514"
def test_claude_review_model_custom(self) -> None:
env = {**_base_env(), "CLAUDE_REVIEW_MODEL": "claude-opus-4-20250514"}
with patch.dict(os.environ, env, clear=True):
settings = Settings()
assert settings.claude_review_model == "claude-opus-4-20250514"
def test_azdo_vsrm_api_url_computed(self) -> None:
with patch.dict(os.environ, _base_env(), clear=True):
settings = Settings()
expected = "https://vsrm.dev.azure.com/my-org/my-project/_apis"
assert settings.azdo_vsrm_api_url == expected
class TestSettingsPhase4:
"""Tests for Phase 4 settings fields (webhook secret)."""
def test_webhook_secret_optional_defaults_empty(self) -> None:
with patch.dict(os.environ, _base_env(), clear=True):
settings = Settings()
# When not provided, defaults to empty string or None
assert settings.webhook_secret is not None or settings.webhook_secret == ""
def test_webhook_secret_custom_value(self) -> None:
env = {**_base_env(), "WEBHOOK_SECRET": "my-super-secret"}
with patch.dict(os.environ, env, clear=True):
settings = Settings()
assert settings.webhook_secret.get_secret_value() == "my-super-secret"
def test_webhook_secret_is_secret_str(self) -> None:
env = {**_base_env(), "WEBHOOK_SECRET": "secret-value"}
with patch.dict(os.environ, env, clear=True):
settings = Settings()
assert isinstance(settings.webhook_secret, SecretStr)
def test_webhook_secret_not_leaked_in_repr(self) -> None:
env = {**_base_env(), "WEBHOOK_SECRET": "super-private-secret"}
with patch.dict(os.environ, env, clear=True):
settings = Settings()
assert "super-private-secret" not in repr(settings)
class TestSettingsPhase5:
"""Tests for Phase 5 settings fields (Slack Web API + CI polling)."""
def test_slack_webhook_url_optional_when_bot_token_provided(self) -> None:
env = {k: v for k, v in _base_env().items() if k != "SLACK_WEBHOOK_URL"}
env["SLACK_BOT_TOKEN"] = "xoxb-test-token"
env["SLACK_CHANNEL_ID"] = "C12345678"
with patch.dict(os.environ, env, clear=True):
settings = Settings()
assert settings.slack_bot_token is not None
assert settings.slack_bot_token.get_secret_value() == "xoxb-test-token"
def test_slack_bot_token_optional_defaults_empty(self) -> None:
with patch.dict(os.environ, _base_env(), clear=True):
settings = Settings()
assert settings.slack_bot_token.get_secret_value() == ""
def test_slack_bot_token_is_secret_str(self) -> None:
env = {**_base_env(), "SLACK_BOT_TOKEN": "xoxb-abc-123"}
with patch.dict(os.environ, env, clear=True):
settings = Settings()
assert isinstance(settings.slack_bot_token, SecretStr)
def test_slack_bot_token_not_leaked_in_repr(self) -> None:
env = {**_base_env(), "SLACK_BOT_TOKEN": "xoxb-super-secret"}
with patch.dict(os.environ, env, clear=True):
settings = Settings()
assert "xoxb-super-secret" not in repr(settings)
def test_slack_signing_secret_optional_defaults_empty(self) -> None:
with patch.dict(os.environ, _base_env(), clear=True):
settings = Settings()
assert settings.slack_signing_secret.get_secret_value() == ""
def test_slack_signing_secret_custom_value(self) -> None:
env = {**_base_env(), "SLACK_SIGNING_SECRET": "signing-secret-xyz"}
with patch.dict(os.environ, env, clear=True):
settings = Settings()
assert settings.slack_signing_secret.get_secret_value() == "signing-secret-xyz"
def test_slack_signing_secret_is_secret_str(self) -> None:
env = {**_base_env(), "SLACK_SIGNING_SECRET": "some-secret"}
with patch.dict(os.environ, env, clear=True):
settings = Settings()
assert isinstance(settings.slack_signing_secret, SecretStr)
def test_slack_signing_secret_not_leaked_in_repr(self) -> None:
env = {**_base_env(), "SLACK_SIGNING_SECRET": "private-signing-secret"}
with patch.dict(os.environ, env, clear=True):
settings = Settings()
assert "private-signing-secret" not in repr(settings)
def test_slack_channel_id_optional_defaults_empty(self) -> None:
with patch.dict(os.environ, _base_env(), clear=True):
settings = Settings()
assert settings.slack_channel_id == ""
def test_slack_channel_id_custom_value(self) -> None:
env = {**_base_env(), "SLACK_CHANNEL_ID": "C0987654321"}
with patch.dict(os.environ, env, clear=True):
settings = Settings()
assert settings.slack_channel_id == "C0987654321"
def test_ci_poll_interval_seconds_default(self) -> None:
with patch.dict(os.environ, _base_env(), clear=True):
settings = Settings()
assert settings.ci_poll_interval_seconds == 30
def test_ci_poll_interval_seconds_custom(self) -> None:
env = {**_base_env(), "CI_POLL_INTERVAL_SECONDS": "60"}
with patch.dict(os.environ, env, clear=True):
settings = Settings()
assert settings.ci_poll_interval_seconds == 60
def test_ci_poll_max_wait_seconds_default(self) -> None:
with patch.dict(os.environ, _base_env(), clear=True):
settings = Settings()
assert settings.ci_poll_max_wait_seconds == 1800
def test_ci_poll_max_wait_seconds_custom(self) -> None:
env = {**_base_env(), "CI_POLL_MAX_WAIT_SECONDS": "3600"}
with patch.dict(os.environ, env, clear=True):
settings = Settings()
assert settings.ci_poll_max_wait_seconds == 3600
def test_slack_webhook_url_still_optional(self) -> None:
env = {k: v for k, v in _base_env().items() if k != "SLACK_WEBHOOK_URL"}
with patch.dict(os.environ, env, clear=True):
settings = Settings()
assert settings.slack_webhook_url.get_secret_value() == ""
class TestSettingsPrPolling:
"""Tests for PR polling config fields (Step 1)."""
def test_watched_repos_defaults_empty(self) -> None:
with patch.dict(os.environ, _base_env(), clear=True):
settings = Settings()
assert settings.watched_repos == ""
def test_watched_repos_custom_value(self) -> None:
env = {**_base_env(), "WATCHED_REPOS": "repo-a,repo-b"}
with patch.dict(os.environ, env, clear=True):
settings = Settings()
assert settings.watched_repos == "repo-a,repo-b"
def test_watched_repos_list_empty_when_blank(self) -> None:
with patch.dict(os.environ, _base_env(), clear=True):
settings = Settings()
assert settings.watched_repos_list == []
def test_watched_repos_list_splits_comma_separated(self) -> None:
env = {**_base_env(), "WATCHED_REPOS": "repo-a,repo-b,repo-c"}
with patch.dict(os.environ, env, clear=True):
settings = Settings()
assert settings.watched_repos_list == ["repo-a", "repo-b", "repo-c"]
def test_watched_repos_list_strips_whitespace(self) -> None:
env = {**_base_env(), "WATCHED_REPOS": " repo-a , repo-b "}
with patch.dict(os.environ, env, clear=True):
settings = Settings()
assert settings.watched_repos_list == ["repo-a", "repo-b"]
def test_watched_repos_list_ignores_empty_entries(self) -> None:
env = {**_base_env(), "WATCHED_REPOS": "repo-a,,repo-b"}
with patch.dict(os.environ, env, clear=True):
settings = Settings()
assert settings.watched_repos_list == ["repo-a", "repo-b"]
def test_pr_poll_interval_seconds_default(self) -> None:
with patch.dict(os.environ, _base_env(), clear=True):
settings = Settings()
assert settings.pr_poll_interval_seconds == 300
def test_pr_poll_interval_seconds_custom(self) -> None:
env = {**_base_env(), "PR_POLL_INTERVAL_SECONDS": "60"}
with patch.dict(os.environ, env, clear=True):
settings = Settings()
assert settings.pr_poll_interval_seconds == 60
def test_pr_poll_target_branch_default(self) -> None:
with patch.dict(os.environ, _base_env(), clear=True):
settings = Settings()
assert settings.pr_poll_target_branch == "refs/heads/develop"
def test_pr_poll_target_branch_custom(self) -> None:
env = {**_base_env(), "PR_POLL_TARGET_BRANCH": "refs/heads/main"}
with patch.dict(os.environ, env, clear=True):
settings = Settings()
assert settings.pr_poll_target_branch == "refs/heads/main"
def test_pr_poll_enabled_defaults_false(self) -> None:
with patch.dict(os.environ, _base_env(), clear=True):
settings = Settings()
assert settings.pr_poll_enabled is False
def test_pr_poll_enabled_true_from_env(self) -> None:
env = {**_base_env(), "PR_POLL_ENABLED": "true"}
with patch.dict(os.environ, env, clear=True):
settings = Settings()
assert settings.pr_poll_enabled is True
def test_default_jira_project_default(self) -> None:
with patch.dict(os.environ, _base_env(), clear=True):
settings = Settings()
assert settings.default_jira_project == "ALLPOST"
def test_default_jira_project_custom(self) -> None:
env = {**_base_env(), "DEFAULT_JIRA_PROJECT": "MYPROJ"}
with patch.dict(os.environ, env, clear=True):
settings = Settings()
assert settings.default_jira_project == "MYPROJ"
def test_auto_create_ticket_enabled_defaults_true(self) -> None:
with patch.dict(os.environ, _base_env(), clear=True):
settings = Settings()
assert settings.auto_create_ticket_enabled is True
def test_auto_create_ticket_enabled_false_from_env(self) -> None:
env = {**_base_env(), "AUTO_CREATE_TICKET_ENABLED": "false"}
with patch.dict(os.environ, env, clear=True):
settings = Settings()
assert settings.auto_create_ticket_enabled is False

198
tests/test_exceptions.py Normal file
View File

@@ -0,0 +1,198 @@
"""Tests for custom exception hierarchy. Written FIRST (TDD RED phase)."""
import pytest
from release_agent.exceptions import (
AuthenticationError,
NotFoundError,
RateLimitError,
ReleaseAgentError,
ServiceError,
ServiceUnavailableError,
)
class TestReleaseAgentError:
"""Tests for the base exception class."""
def test_is_exception(self) -> None:
err = ReleaseAgentError("something went wrong")
assert isinstance(err, Exception)
def test_message_stored(self) -> None:
err = ReleaseAgentError("something went wrong")
assert str(err) == "something went wrong"
def test_can_be_raised(self) -> None:
with pytest.raises(ReleaseAgentError):
raise ReleaseAgentError("boom")
class TestServiceError:
"""Tests for ServiceError with service name and status code."""
def test_is_release_agent_error(self) -> None:
err = ServiceError(service="jira", status_code=500, detail="Internal error")
assert isinstance(err, ReleaseAgentError)
def test_stores_service(self) -> None:
err = ServiceError(service="jira", status_code=500, detail="Internal error")
assert err.service == "jira"
def test_stores_status_code(self) -> None:
err = ServiceError(service="azdo", status_code=422, detail="Unprocessable")
assert err.status_code == 422
def test_stores_detail(self) -> None:
err = ServiceError(service="slack", status_code=400, detail="Bad payload")
assert err.detail == "Bad payload"
def test_str_includes_service_and_status(self) -> None:
err = ServiceError(service="jira", status_code=500, detail="Server error")
text = str(err)
assert "jira" in text
assert "500" in text
def test_can_be_raised(self) -> None:
with pytest.raises(ServiceError):
raise ServiceError(service="azdo", status_code=400, detail="bad request")
def test_detail_none_allowed(self) -> None:
err = ServiceError(service="jira", status_code=404, detail=None)
assert err.detail is None
class TestAuthenticationError:
"""Tests for AuthenticationError (401/403)."""
def test_is_service_error(self) -> None:
err = AuthenticationError(service="azdo")
assert isinstance(err, ServiceError)
def test_default_status_code_401(self) -> None:
err = AuthenticationError(service="azdo")
assert err.status_code == 401
def test_service_stored(self) -> None:
err = AuthenticationError(service="jira")
assert err.service == "jira"
def test_custom_status_code(self) -> None:
err = AuthenticationError(service="azdo", status_code=403)
assert err.status_code == 403
def test_str_contains_service(self) -> None:
err = AuthenticationError(service="slack")
assert "slack" in str(err)
def test_can_be_raised(self) -> None:
with pytest.raises(AuthenticationError):
raise AuthenticationError(service="azdo")
def test_caught_as_service_error(self) -> None:
with pytest.raises(ServiceError):
raise AuthenticationError(service="azdo")
class TestNotFoundError:
"""Tests for NotFoundError (404)."""
def test_is_service_error(self) -> None:
err = NotFoundError(service="azdo", detail="PR not found")
assert isinstance(err, ServiceError)
def test_status_code_is_404(self) -> None:
err = NotFoundError(service="jira", detail="Issue not found")
assert err.status_code == 404
def test_detail_stored(self) -> None:
err = NotFoundError(service="azdo", detail="PR 999 not found")
assert "PR 999" in err.detail
def test_can_be_raised(self) -> None:
with pytest.raises(NotFoundError):
raise NotFoundError(service="azdo", detail="not found")
def test_caught_as_release_agent_error(self) -> None:
with pytest.raises(ReleaseAgentError):
raise NotFoundError(service="jira", detail="issue missing")
class TestRateLimitError:
"""Tests for RateLimitError (429) with retry_after."""
def test_is_service_error(self) -> None:
err = RateLimitError(service="jira", retry_after=30)
assert isinstance(err, ServiceError)
def test_status_code_is_429(self) -> None:
err = RateLimitError(service="jira", retry_after=30)
assert err.status_code == 429
def test_stores_retry_after(self) -> None:
err = RateLimitError(service="slack", retry_after=60)
assert err.retry_after == 60
def test_retry_after_none_allowed(self) -> None:
err = RateLimitError(service="azdo", retry_after=None)
assert err.retry_after is None
def test_str_contains_service(self) -> None:
err = RateLimitError(service="jira", retry_after=5)
assert "jira" in str(err)
def test_can_be_raised(self) -> None:
with pytest.raises(RateLimitError):
raise RateLimitError(service="jira", retry_after=30)
class TestServiceUnavailableError:
"""Tests for ServiceUnavailableError (503)."""
def test_is_service_error(self) -> None:
err = ServiceUnavailableError(service="azdo")
assert isinstance(err, ServiceError)
def test_status_code_is_503(self) -> None:
err = ServiceUnavailableError(service="azdo")
assert err.status_code == 503
def test_service_stored(self) -> None:
err = ServiceUnavailableError(service="slack")
assert err.service == "slack"
def test_custom_detail(self) -> None:
err = ServiceUnavailableError(service="azdo", detail="Maintenance window")
assert "Maintenance" in err.detail
def test_can_be_raised(self) -> None:
with pytest.raises(ServiceUnavailableError):
raise ServiceUnavailableError(service="azdo")
def test_caught_as_service_error(self) -> None:
with pytest.raises(ServiceError):
raise ServiceUnavailableError(service="jira")
class TestExceptionHierarchyInheritance:
"""Tests verifying the full exception hierarchy is correct."""
def test_all_are_release_agent_errors(self) -> None:
errors = [
AuthenticationError(service="svc"),
NotFoundError(service="svc", detail="x"),
RateLimitError(service="svc", retry_after=1),
ServiceUnavailableError(service="svc"),
]
for err in errors:
assert isinstance(err, ReleaseAgentError), f"{type(err)} not ReleaseAgentError"
def test_all_are_service_errors(self) -> None:
errors = [
AuthenticationError(service="svc"),
NotFoundError(service="svc", detail="x"),
RateLimitError(service="svc", retry_after=1),
ServiceUnavailableError(service="svc"),
]
for err in errors:
assert isinstance(err, ServiceError), f"{type(err)} not ServiceError"

666
tests/test_main.py Normal file
View File

@@ -0,0 +1,666 @@
"""Tests for main FastAPI application. Written FIRST (TDD RED phase).
Heavy startup (PostgreSQL, httpx clients, graph compilation) is mocked.
Tests verify: routes registered, lifespan hooks, exception handlers.
"""
import asyncio
from contextlib import asynccontextmanager
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
# ---------------------------------------------------------------------------
# Helpers / fixtures
# ---------------------------------------------------------------------------
def _make_mock_settings():
s = MagicMock()
s.webhook_secret.get_secret_value.return_value = "test-secret"
s.postgres_dsn.get_secret_value.return_value = "postgresql://u:p@localhost/db"
s.azdo_pat.get_secret_value.return_value = "pat"
s.anthropic_api_key.get_secret_value.return_value = "key"
s.jira_api_token.get_secret_value.return_value = "jira"
s.slack_webhook_url.get_secret_value.return_value = "https://hooks.slack.com/x"
s.slack_bot_token.get_secret_value.return_value = ""
s.slack_channel_id = ""
s.slack_signing_secret.get_secret_value.return_value = ""
s.port = 8000
s.pr_poll_enabled = False
s.pr_poll_interval_seconds = 300
s.pr_poll_target_branch = "refs/heads/develop"
s.watched_repos_list = []
s.default_jira_project = "ALLPOST"
return s
def _make_patched_app():
"""Return the FastAPI app with all heavy startup mocked."""
mock_settings = _make_mock_settings()
mock_pool = AsyncMock()
mock_pool.__aenter__ = AsyncMock(return_value=mock_pool)
mock_pool.__aexit__ = AsyncMock(return_value=False)
mock_graphs = {
"pr_completed": MagicMock(),
"release": MagicMock(),
}
mock_clients = MagicMock()
mock_staging_store = MagicMock()
patches = [
patch("release_agent.main.Settings", return_value=mock_settings),
patch("release_agent.main.build_pr_completed_graph", return_value=mock_graphs["pr_completed"]),
patch("release_agent.main.build_release_graph", return_value=mock_graphs["release"]),
patch("release_agent.main.AsyncConnectionPool", return_value=mock_pool),
patch("release_agent.main._create_tool_clients", return_value=mock_clients),
patch("release_agent.main._create_staging_store", return_value=mock_staging_store),
patch("release_agent.main._ensure_db_schema", new_callable=AsyncMock),
]
for p in patches:
p.start()
from release_agent.main import create_app
app = create_app()
for p in patches:
p.stop()
return app, mock_settings, mock_pool, mock_graphs
# ---------------------------------------------------------------------------
# Route registration tests
# ---------------------------------------------------------------------------
class TestRouteRegistration:
def test_webhook_route_registered(self) -> None:
from release_agent.main import create_app
with (
patch("release_agent.main.Settings", return_value=_make_mock_settings()),
patch("release_agent.main.build_pr_completed_graph", return_value=MagicMock()),
patch("release_agent.main.build_release_graph", return_value=MagicMock()),
patch("release_agent.main.AsyncConnectionPool", return_value=AsyncMock()),
patch("release_agent.main._create_tool_clients", return_value=MagicMock()),
patch("release_agent.main._create_staging_store", return_value=MagicMock()),
patch("release_agent.main._ensure_db_schema", new_callable=AsyncMock),
):
app = create_app()
routes = {r.path for r in app.routes} # type: ignore[attr-defined]
assert "/webhooks/azdo" in routes
def test_approvals_routes_registered(self) -> None:
from release_agent.main import create_app
with (
patch("release_agent.main.Settings", return_value=_make_mock_settings()),
patch("release_agent.main.build_pr_completed_graph", return_value=MagicMock()),
patch("release_agent.main.build_release_graph", return_value=MagicMock()),
patch("release_agent.main.AsyncConnectionPool", return_value=AsyncMock()),
patch("release_agent.main._create_tool_clients", return_value=MagicMock()),
patch("release_agent.main._create_staging_store", return_value=MagicMock()),
patch("release_agent.main._ensure_db_schema", new_callable=AsyncMock),
):
app = create_app()
routes = {r.path for r in app.routes} # type: ignore[attr-defined]
assert "/approvals/pending" in routes
assert "/approvals/{thread_id}" in routes
def test_status_routes_registered(self) -> None:
from release_agent.main import create_app
with (
patch("release_agent.main.Settings", return_value=_make_mock_settings()),
patch("release_agent.main.build_pr_completed_graph", return_value=MagicMock()),
patch("release_agent.main.build_release_graph", return_value=MagicMock()),
patch("release_agent.main.AsyncConnectionPool", return_value=AsyncMock()),
patch("release_agent.main._create_tool_clients", return_value=MagicMock()),
patch("release_agent.main._create_staging_store", return_value=MagicMock()),
patch("release_agent.main._ensure_db_schema", new_callable=AsyncMock),
):
app = create_app()
routes = {r.path for r in app.routes} # type: ignore[attr-defined]
assert "/status" in routes
assert "/staging" in routes
# ---------------------------------------------------------------------------
# schedule_graph / run_graph_in_background tests
# ---------------------------------------------------------------------------
class TestScheduleGraph:
def test_schedule_graph_returns_thread_id(self) -> None:
from release_agent.main import schedule_graph
mock_app = MagicMock()
mock_app.state.background_tasks = set()
mock_graph = MagicMock()
with patch("release_agent.main.asyncio.create_task", return_value=MagicMock()):
thread_id = schedule_graph(
app=mock_app,
graph=mock_graph,
initial_state={"repo_name": "my-repo"},
thread_id=None,
)
assert isinstance(thread_id, str)
assert len(thread_id) > 0
def test_schedule_graph_uses_provided_thread_id(self) -> None:
from release_agent.main import schedule_graph
mock_app = MagicMock()
mock_app.state.background_tasks = set()
mock_graph = MagicMock()
with patch("release_agent.main.asyncio.create_task", return_value=MagicMock()):
thread_id = schedule_graph(
app=mock_app,
graph=mock_graph,
initial_state={},
thread_id="custom-thread-id",
)
assert thread_id == "custom-thread-id"
def test_schedule_graph_adds_task_to_background_tasks(self) -> None:
from release_agent.main import schedule_graph
mock_app = MagicMock()
mock_app.state.background_tasks = set()
mock_graph = MagicMock()
mock_task = MagicMock()
with patch("release_agent.main.asyncio.create_task", return_value=mock_task):
schedule_graph(
app=mock_app,
graph=mock_graph,
initial_state={},
thread_id=None,
)
assert mock_task in mock_app.state.background_tasks
def test_run_graph_in_background_is_coroutine(self) -> None:
from release_agent.main import run_graph_in_background
import inspect
assert inspect.iscoroutinefunction(run_graph_in_background)
# ---------------------------------------------------------------------------
# _ensure_db_schema tests
# ---------------------------------------------------------------------------
class TestEnsureDbSchema:
@pytest.mark.asyncio
async def test_ensure_db_schema_creates_table(self) -> None:
from release_agent.main import _ensure_db_schema
mock_pool = AsyncMock()
mock_conn = AsyncMock()
mock_cursor = AsyncMock()
mock_cursor.execute = AsyncMock()
mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor)
mock_cursor.__aexit__ = AsyncMock(return_value=False)
mock_conn.cursor = MagicMock(return_value=mock_cursor)
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
mock_conn.__aexit__ = AsyncMock(return_value=False)
mock_pool.connection = MagicMock(return_value=mock_conn)
await _ensure_db_schema(mock_pool)
# Phase 5: now executes multiple DDL statements (agent_threads +
# staging_releases + archived_releases), so called_once no longer holds.
assert mock_cursor.execute.call_count >= 1
all_sql = " ".join(
call.args[0] for call in mock_cursor.execute.call_args_list
)
assert "agent_threads" in all_sql
# ---------------------------------------------------------------------------
# _create_tool_clients tests
# ---------------------------------------------------------------------------
class TestCreateToolClients:
def test_create_tool_clients_returns_tool_clients_instance(self) -> None:
from release_agent.main import _create_tool_clients
from release_agent.graph.dependencies import ToolClients
mock_settings = _make_mock_settings()
with (
patch("release_agent.main.AzDoClient") as mock_azdo,
patch("release_agent.main.JiraClient") as mock_jira,
patch("release_agent.main.SlackClient") as mock_slack,
patch("release_agent.main.ClaudeReviewer") as mock_reviewer,
patch("release_agent.main.httpx.AsyncClient") as mock_httpx,
):
clients, http_clients = _create_tool_clients(mock_settings)
assert isinstance(clients, ToolClients)
# ---------------------------------------------------------------------------
# _create_staging_store tests
# ---------------------------------------------------------------------------
class TestCreateStagingStore:
def test_create_staging_store_returns_store(self) -> None:
from release_agent.main import _create_staging_store
from release_agent.graph.dependencies import JsonFileStagingStore
result = _create_staging_store()
assert isinstance(result, JsonFileStagingStore)
# ---------------------------------------------------------------------------
# Global exception handler tests
# ---------------------------------------------------------------------------
class TestExceptionHandlers:
def test_app_has_exception_handlers(self) -> None:
from release_agent.main import create_app
with (
patch("release_agent.main.Settings", return_value=_make_mock_settings()),
patch("release_agent.main.build_pr_completed_graph", return_value=MagicMock()),
patch("release_agent.main.build_release_graph", return_value=MagicMock()),
patch("release_agent.main.AsyncConnectionPool", return_value=AsyncMock()),
patch("release_agent.main._create_tool_clients", return_value=MagicMock()),
patch("release_agent.main._create_staging_store", return_value=MagicMock()),
patch("release_agent.main._ensure_db_schema", new_callable=AsyncMock),
):
app = create_app()
# FastAPI stores exception handlers in exception_handlers attribute
assert hasattr(app, "exception_handlers")
# ---------------------------------------------------------------------------
# Lifespan tests
# ---------------------------------------------------------------------------
class TestGracefulShutdown:
@pytest.mark.asyncio
async def test_lifespan_cancels_timed_out_tasks(self) -> None:
"""Verify the lifespan waits for tasks and cancels timed-out ones."""
from release_agent.main import lifespan
mock_pool = AsyncMock()
mock_pool.open = AsyncMock()
mock_pool.close = AsyncMock()
mock_conn = AsyncMock()
mock_cursor = AsyncMock()
mock_cursor.execute = AsyncMock()
mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor)
mock_cursor.__aexit__ = AsyncMock(return_value=False)
mock_conn.cursor = MagicMock(return_value=mock_cursor)
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
mock_conn.__aexit__ = AsyncMock(return_value=False)
mock_pool.connection = MagicMock(return_value=mock_conn)
from fastapi import FastAPI
app = FastAPI()
app.state.background_tasks = set()
mock_settings = _make_mock_settings()
mock_task = MagicMock()
mock_task.cancel = MagicMock()
with (
patch("release_agent.main.Settings", return_value=mock_settings),
patch("release_agent.main.build_pr_completed_graph", return_value=MagicMock()),
patch("release_agent.main.build_release_graph", return_value=MagicMock()),
patch("release_agent.main.AsyncConnectionPool", return_value=mock_pool),
patch("release_agent.main._create_tool_clients", return_value=(MagicMock(), [])),
patch("release_agent.main._create_staging_store", return_value=MagicMock()),
patch("release_agent.main._ensure_db_schema", new_callable=AsyncMock),
patch(
"release_agent.main.asyncio.wait",
new_callable=AsyncMock,
return_value=(set(), {mock_task}),
),
):
ctx = lifespan(app)
await ctx.__aenter__()
# Add a fake task to background_tasks after startup
app.state.background_tasks.add(mock_task)
await ctx.__aexit__(None, None, None)
# The pending task should have been cancelled
mock_task.cancel.assert_called_once()
class TestLifespan:
def test_app_state_set_after_lifespan(self) -> None:
"""Verify app.state.graphs and app.state.settings are set during lifespan."""
from release_agent.main import create_app
mock_pool = AsyncMock()
mock_pool.__aenter__ = AsyncMock(return_value=mock_pool)
mock_pool.__aexit__ = AsyncMock(return_value=False)
mock_pool.open = AsyncMock()
mock_pool.close = AsyncMock()
mock_conn = AsyncMock()
mock_cursor = AsyncMock()
mock_cursor.execute = AsyncMock()
mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor)
mock_cursor.__aexit__ = AsyncMock(return_value=False)
mock_conn.cursor = MagicMock(return_value=mock_cursor)
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
mock_conn.__aexit__ = AsyncMock(return_value=False)
mock_pool.connection = MagicMock(return_value=mock_conn)
mock_settings = _make_mock_settings()
mock_graphs = {"pr_completed": MagicMock(), "release": MagicMock()}
mock_clients = MagicMock()
mock_staging_store = MagicMock()
with (
patch("release_agent.main.Settings", return_value=mock_settings),
patch("release_agent.main.build_pr_completed_graph", return_value=mock_graphs["pr_completed"]),
patch("release_agent.main.build_release_graph", return_value=mock_graphs["release"]),
patch("release_agent.main.AsyncConnectionPool", return_value=mock_pool),
patch("release_agent.main._create_tool_clients", return_value=(mock_clients, [])),
patch("release_agent.main._create_staging_store", return_value=mock_staging_store),
patch("release_agent.main._ensure_db_schema", new_callable=AsyncMock),
):
app = create_app()
with TestClient(app) as client:
# App is started; state should be accessible
response = client.get("/status")
# We just verify no crash
assert response.status_code in (200, 500)
# ---------------------------------------------------------------------------
# Phase 5: Slack interactions route + new config tests
# ---------------------------------------------------------------------------
class TestPhase5Routes:
"""Tests for Phase 5 additions to main.py."""
def _make_patches(self):
mock_settings = _make_mock_settings()
mock_pool = AsyncMock()
mock_pool.open = AsyncMock()
mock_pool.close = AsyncMock()
mock_pool.__aenter__ = AsyncMock(return_value=mock_pool)
mock_pool.__aexit__ = AsyncMock(return_value=False)
return [
patch("release_agent.main.Settings", return_value=mock_settings),
patch("release_agent.main.build_pr_completed_graph", return_value=MagicMock()),
patch("release_agent.main.build_release_graph", return_value=MagicMock()),
patch("release_agent.main.AsyncConnectionPool", return_value=mock_pool),
patch("release_agent.main._create_tool_clients", return_value=MagicMock()),
patch("release_agent.main._create_staging_store", return_value=MagicMock()),
patch("release_agent.main._ensure_db_schema", new_callable=AsyncMock),
]
def test_slack_interactions_route_registered(self) -> None:
from release_agent.main import create_app
patches = self._make_patches()
for p in patches:
p.start()
try:
app = create_app()
finally:
for p in patches:
p.stop()
routes = {r.path for r in app.routes} # type: ignore[attr-defined]
assert "/slack/interactions" in routes
def test_create_tool_clients_uses_bot_token(self) -> None:
from release_agent.main import _create_tool_clients
mock_settings = _make_mock_settings()
mock_settings.slack_bot_token.get_secret_value.return_value = "xoxb-test"
mock_settings.slack_channel_id = "C12345"
mock_settings.slack_webhook_url.get_secret_value.return_value = ""
mock_settings.azdo_api_url = "https://dev.azure.com/org/proj/_apis"
mock_settings.azdo_vsrm_api_url = "https://vsrm.dev.azure.com/org/proj/_apis"
mock_settings.jira_base_url = "https://example.atlassian.net"
mock_settings.jira_email = "test@example.com"
# Should not raise
clients, http_clients = _create_tool_clients(mock_settings)
assert clients is not None
# Clean up
for hc in http_clients:
asyncio.get_event_loop().run_until_complete(hc.aclose())
class TestPhase5DbSchema:
"""Tests that _ensure_db_schema adds the slack_message_ts column."""
async def test_ensure_db_schema_executes_sql_statements(self) -> None:
from release_agent.main import _ensure_db_schema
mock_pool = MagicMock()
mock_conn = AsyncMock()
mock_cursor = AsyncMock()
executed_sql: list[str] = []
async def capture_execute(sql, *args):
executed_sql.append(sql.strip())
mock_cursor.execute = capture_execute
mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor)
mock_cursor.__aexit__ = AsyncMock(return_value=False)
mock_conn.cursor = MagicMock(return_value=mock_cursor)
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
mock_conn.__aexit__ = AsyncMock(return_value=False)
mock_pool.connection = MagicMock(return_value=mock_conn)
await _ensure_db_schema(mock_pool)
# Should have executed CREATE TABLE statements
assert len(executed_sql) >= 3
combined = " ".join(executed_sql)
assert "agent_threads" in combined
assert "staging_releases" in combined
async def test_ensure_db_schema_includes_slack_message_ts_column(self) -> None:
from release_agent.main import _ensure_db_schema
mock_pool = MagicMock()
mock_conn = AsyncMock()
mock_cursor = AsyncMock()
executed_sql: list[str] = []
async def capture_execute(sql, *args):
executed_sql.append(sql.strip())
mock_cursor.execute = capture_execute
mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor)
mock_cursor.__aexit__ = AsyncMock(return_value=False)
mock_conn.cursor = MagicMock(return_value=mock_cursor)
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
mock_conn.__aexit__ = AsyncMock(return_value=False)
mock_pool.connection = MagicMock(return_value=mock_conn)
await _ensure_db_schema(mock_pool)
combined = " ".join(executed_sql)
assert "slack_message_ts" in combined
# ---------------------------------------------------------------------------
# PR polling lifespan integration tests
# ---------------------------------------------------------------------------
class TestPrPollingLifespan:
"""Tests for PR polling startup in the lifespan handler."""
def _make_polling_settings(self, *, pr_poll_enabled: bool = True) -> MagicMock:
s = _make_mock_settings()
s.pr_poll_enabled = pr_poll_enabled
s.pr_poll_interval_seconds = 30
s.pr_poll_target_branch = "refs/heads/develop"
s.watched_repos_list = ["repo-a"]
s.default_jira_project = "ALLPOST"
return s
async def test_poll_loop_started_when_pr_poll_enabled(self) -> None:
"""When pr_poll_enabled=True, a background task for polling is created."""
from release_agent.main import create_app
mock_settings = self._make_polling_settings(pr_poll_enabled=True)
mock_pool = AsyncMock()
mock_pool.open = AsyncMock()
mock_pool.close = AsyncMock()
mock_pool.connection = MagicMock()
poll_loop_started = []
async def fake_run_poll_loop(**kwargs):
poll_loop_started.append(True)
# Simulate an immediate cancellation to avoid infinite loop
raise asyncio.CancelledError
with (
patch("release_agent.main.Settings", return_value=mock_settings),
patch("release_agent.main.build_pr_completed_graph", return_value=MagicMock()),
patch("release_agent.main.build_release_graph", return_value=MagicMock()),
patch("release_agent.main.AsyncConnectionPool", return_value=mock_pool),
patch("release_agent.main._create_tool_clients", return_value=(MagicMock(), [])),
patch("release_agent.main._create_staging_store", return_value=MagicMock()),
patch("release_agent.main._ensure_db_schema", new_callable=AsyncMock),
patch("release_agent.main.run_pr_poll_loop", new=fake_run_poll_loop),
):
app = create_app()
async with app.router.lifespan_context(app):
# Give the event loop a chance to start background tasks
await asyncio.sleep(0)
assert len(poll_loop_started) > 0
async def test_poll_loop_not_started_when_pr_poll_disabled(self) -> None:
"""When pr_poll_enabled=False, no polling background task is created."""
from release_agent.main import create_app
mock_settings = self._make_polling_settings(pr_poll_enabled=False)
mock_pool = AsyncMock()
mock_pool.open = AsyncMock()
mock_pool.close = AsyncMock()
mock_pool.connection = MagicMock()
poll_loop_started = []
async def fake_run_poll_loop(**kwargs):
poll_loop_started.append(True)
with (
patch("release_agent.main.Settings", return_value=mock_settings),
patch("release_agent.main.build_pr_completed_graph", return_value=MagicMock()),
patch("release_agent.main.build_release_graph", return_value=MagicMock()),
patch("release_agent.main.AsyncConnectionPool", return_value=mock_pool),
patch("release_agent.main._create_tool_clients", return_value=(MagicMock(), [])),
patch("release_agent.main._create_staging_store", return_value=MagicMock()),
patch("release_agent.main._ensure_db_schema", new_callable=AsyncMock),
patch("release_agent.main.run_pr_poll_loop", new=fake_run_poll_loop),
):
app = create_app()
async with app.router.lifespan_context(app):
await asyncio.sleep(0)
assert len(poll_loop_started) == 0
# ---------------------------------------------------------------------------
# _run_graph default_jira_project injection tests
# ---------------------------------------------------------------------------
class TestRunGraphJiraProjectInjection:
"""Tests that _run_graph passes default_jira_project into the graph config."""
async def test_default_jira_project_passed_to_graph_config(self) -> None:
from release_agent.api.webhooks import _run_graph
captured_configs: list[dict] = []
mock_graph = MagicMock()
async def fake_ainvoke(state, config=None):
captured_configs.append(config or {})
return {}
mock_graph.ainvoke = fake_ainvoke
mock_pool = MagicMock()
mock_conn = AsyncMock()
mock_cursor = AsyncMock()
mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor)
mock_cursor.__aexit__ = AsyncMock(return_value=False)
mock_conn.cursor = MagicMock(return_value=mock_cursor)
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
mock_conn.__aexit__ = AsyncMock(return_value=False)
mock_pool.connection = MagicMock(return_value=mock_conn)
await _run_graph(
graph=mock_graph,
initial_state={"pr_id": "1", "repo_name": "r"},
thread_id="tid-1",
tool_clients=MagicMock(),
db_pool=mock_pool,
repos_base_dir="",
graph_name="pr_completed",
default_jira_project="MYPROJ",
)
assert len(captured_configs) == 1
configurable = captured_configs[0].get("configurable", {})
assert configurable.get("default_jira_project") == "MYPROJ"
async def test_default_jira_project_defaults_to_allpost(self) -> None:
from release_agent.api.webhooks import _run_graph
captured_configs: list[dict] = []
mock_graph = MagicMock()
async def fake_ainvoke(state, config=None):
captured_configs.append(config or {})
return {}
mock_graph.ainvoke = fake_ainvoke
mock_pool = MagicMock()
mock_conn = AsyncMock()
mock_cursor = AsyncMock()
mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor)
mock_cursor.__aexit__ = AsyncMock(return_value=False)
mock_conn.cursor = MagicMock(return_value=mock_cursor)
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
mock_conn.__aexit__ = AsyncMock(return_value=False)
mock_pool.connection = MagicMock(return_value=mock_conn)
await _run_graph(
graph=mock_graph,
initial_state={"pr_id": "1", "repo_name": "r"},
thread_id="tid-2",
tool_clients=MagicMock(),
db_pool=mock_pool,
)
configurable = captured_configs[0].get("configurable", {})
assert configurable.get("default_jira_project") == "ALLPOST"

147
tests/test_main_phase5.py Normal file
View File

@@ -0,0 +1,147 @@
"""Tests for main.py Phase 5 changes.
Phase 5 - Step 4: _ensure_db_schema creates staging/archived tables,
and lifespan uses PostgresStagingStore.
Written FIRST (TDD RED phase).
"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
# ---------------------------------------------------------------------------
# _ensure_db_schema includes staging DDL
# ---------------------------------------------------------------------------
class TestEnsureDbSchemaPhase5:
async def test_schema_creates_staging_releases_table(self) -> None:
from release_agent.main import _ensure_db_schema
executed_sqls: list[str] = []
mock_cursor = AsyncMock()
mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor)
mock_cursor.__aexit__ = AsyncMock(return_value=False)
async def capture_execute(sql: str, *args) -> None:
executed_sqls.append(sql)
mock_cursor.execute = capture_execute
mock_conn = AsyncMock()
mock_conn.cursor = MagicMock(return_value=mock_cursor)
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
mock_conn.__aexit__ = AsyncMock(return_value=False)
mock_pool = MagicMock()
mock_pool.connection = MagicMock(return_value=mock_conn)
await _ensure_db_schema(mock_pool)
all_sql = " ".join(executed_sqls)
assert "staging_releases" in all_sql
async def test_schema_creates_archived_releases_table(self) -> None:
from release_agent.main import _ensure_db_schema
executed_sqls: list[str] = []
mock_cursor = AsyncMock()
mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor)
mock_cursor.__aexit__ = AsyncMock(return_value=False)
async def capture_execute(sql: str, *args) -> None:
executed_sqls.append(sql)
mock_cursor.execute = capture_execute
mock_conn = AsyncMock()
mock_conn.cursor = MagicMock(return_value=mock_cursor)
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
mock_conn.__aexit__ = AsyncMock(return_value=False)
mock_pool = MagicMock()
mock_pool.connection = MagicMock(return_value=mock_conn)
await _ensure_db_schema(mock_pool)
all_sql = " ".join(executed_sqls)
assert "archived_releases" in all_sql
async def test_schema_still_creates_agent_threads_table(self) -> None:
from release_agent.main import _ensure_db_schema
executed_sqls: list[str] = []
mock_cursor = AsyncMock()
mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor)
mock_cursor.__aexit__ = AsyncMock(return_value=False)
async def capture_execute(sql: str, *args) -> None:
executed_sqls.append(sql)
mock_cursor.execute = capture_execute
mock_conn = AsyncMock()
mock_conn.cursor = MagicMock(return_value=mock_cursor)
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
mock_conn.__aexit__ = AsyncMock(return_value=False)
mock_pool = MagicMock()
mock_pool.connection = MagicMock(return_value=mock_conn)
await _ensure_db_schema(mock_pool)
all_sql = " ".join(executed_sqls)
assert "agent_threads" in all_sql
async def test_schema_uses_if_not_exists(self) -> None:
from release_agent.main import _ensure_db_schema
executed_sqls: list[str] = []
mock_cursor = AsyncMock()
mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor)
mock_cursor.__aexit__ = AsyncMock(return_value=False)
async def capture_execute(sql: str, *args) -> None:
executed_sqls.append(sql)
mock_cursor.execute = capture_execute
mock_conn = AsyncMock()
mock_conn.cursor = MagicMock(return_value=mock_cursor)
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
mock_conn.__aexit__ = AsyncMock(return_value=False)
mock_pool = MagicMock()
mock_pool.connection = MagicMock(return_value=mock_conn)
await _ensure_db_schema(mock_pool)
all_sql = " ".join(executed_sqls)
assert "IF NOT EXISTS" in all_sql.upper()
# ---------------------------------------------------------------------------
# Lifespan: PostgresStagingStore wired in
# ---------------------------------------------------------------------------
class TestLifespanUsesPostgresStagingStore:
def test_lifespan_creates_postgres_staging_store(self) -> None:
"""When PostgresStagingStore is imported in main, it is used in lifespan."""
from release_agent.main import _create_staging_store
from release_agent.graph.postgres_staging_store import PostgresStagingStore
mock_pool = MagicMock()
result = _create_staging_store(pool=mock_pool)
assert isinstance(result, PostgresStagingStore)
def test_create_staging_store_without_pool_falls_back_to_json(self) -> None:
"""Without a pool, falls back to JsonFileStagingStore for local dev."""
from release_agent.main import _create_staging_store
from release_agent.graph.dependencies import JsonFileStagingStore
result = _create_staging_store(pool=None)
assert isinstance(result, JsonFileStagingStore)

635
tests/test_models.py Normal file
View File

@@ -0,0 +1,635 @@
"""Tests for Pydantic models. Written FIRST (TDD RED phase)."""
from datetime import date, datetime
import pytest
from pydantic import ValidationError
from release_agent.models.jira import JiraIssue, JiraTransition
from release_agent.models.pipeline import PipelineInfo, ReleasePipelineStage
from release_agent.models.pr import PRInfo
from release_agent.models.release import ArchivedRelease, StagingRelease
from release_agent.models.review import ReviewIssue, ReviewResult
from release_agent.models.ticket import TicketEntry
from release_agent.models.webhook import WebhookPayload, WebhookRepository, WebhookResource
# ---------------------------------------------------------------------------
# PRInfo tests
# ---------------------------------------------------------------------------
class TestPRInfo:
"""Tests for PRInfo model."""
def _make_pr(self, **kwargs) -> PRInfo:
defaults = {
"pr_id": "PR-1",
"pr_url": "https://dev.azure.com/org/project/_git/repo/pullrequest/1",
"repo_name": "my-repo",
"branch": "feature/ALLPOST-100_add-feature",
"pr_title": "Add new feature",
"pr_status": "active",
}
defaults.update(kwargs)
return PRInfo(**defaults)
def test_ticket_id_extracted_from_branch(self) -> None:
pr = self._make_pr(branch="feature/ALLPOST-100_add-feature")
assert pr.ticket_id == "ALLPOST-100"
assert pr.has_ticket is True
def test_branch_without_ticket(self) -> None:
pr = self._make_pr(branch="chore/update-dependencies")
assert pr.ticket_id is None
assert pr.has_ticket is False
def test_main_branch_no_ticket(self) -> None:
pr = self._make_pr(branch="main")
assert pr.ticket_id is None
assert pr.has_ticket is False
def test_refs_heads_branch_parsed(self) -> None:
pr = self._make_pr(branch="refs/heads/fix/BILL-42_fix-bug")
assert pr.ticket_id == "BILL-42"
assert pr.has_ticket is True
def test_pr_status_active(self) -> None:
pr = self._make_pr(pr_status="active")
assert pr.pr_status == "active"
def test_pr_status_completed(self) -> None:
pr = self._make_pr(pr_status="completed")
assert pr.pr_status == "completed"
def test_pr_status_abandoned(self) -> None:
pr = self._make_pr(pr_status="abandoned")
assert pr.pr_status == "abandoned"
def test_invalid_pr_status_raises(self) -> None:
with pytest.raises(ValidationError):
self._make_pr(pr_status="unknown")
def test_model_is_frozen(self) -> None:
pr = self._make_pr()
with pytest.raises(ValidationError):
pr.pr_id = "modified" # type: ignore[misc]
def test_pr_url_is_valid_url(self) -> None:
pr = self._make_pr()
# HttpUrl should have been validated
assert "dev.azure.com" in str(pr.pr_url)
def test_invalid_url_raises(self) -> None:
with pytest.raises(ValidationError):
self._make_pr(pr_url="not-a-url")
# ---------------------------------------------------------------------------
# TicketEntry tests
# ---------------------------------------------------------------------------
class TestTicketEntry:
"""Tests for TicketEntry model."""
def _make_ticket(self, **kwargs) -> TicketEntry:
defaults = {
"id": "ALLPOST-4229",
"summary": "Fix review bug",
"pr_id": "PR-42",
"pr_url": "https://dev.azure.com/org/project/_git/repo/pullrequest/42",
"pr_title": "Fix review",
"branch": "bug/ALLPOST-4229_fix-review",
"merged_at": date(2024, 1, 15),
}
defaults.update(kwargs)
return TicketEntry(**defaults)
def test_valid_ticket_entry(self) -> None:
ticket = self._make_ticket()
assert ticket.id == "ALLPOST-4229"
assert ticket.summary == "Fix review bug"
def test_valid_jira_id_format(self) -> None:
ticket = self._make_ticket(id="BILL-42")
assert ticket.id == "BILL-42"
def test_invalid_id_lowercase_raises(self) -> None:
with pytest.raises(ValidationError):
self._make_ticket(id="allpost-4229")
def test_invalid_id_no_number_raises(self) -> None:
with pytest.raises(ValidationError):
self._make_ticket(id="ALLPOST-")
def test_invalid_id_no_dash_raises(self) -> None:
with pytest.raises(ValidationError):
self._make_ticket(id="ALLPOST4229")
def test_invalid_id_starts_with_number_raises(self) -> None:
with pytest.raises(ValidationError):
self._make_ticket(id="4ALLPOST-4229")
def test_merged_at_is_date(self) -> None:
ticket = self._make_ticket()
assert isinstance(ticket.merged_at, date)
def test_model_is_frozen(self) -> None:
ticket = self._make_ticket()
with pytest.raises(ValidationError):
ticket.id = "OTHER-1" # type: ignore[misc]
def test_minimum_valid_id(self) -> None:
# Single uppercase letter prefix followed by dash and digits
ticket = self._make_ticket(id="A-1")
assert ticket.id == "A-1"
def test_numeric_in_project_key(self) -> None:
ticket = self._make_ticket(id="AB2-100")
assert ticket.id == "AB2-100"
# ---------------------------------------------------------------------------
# StagingRelease tests
# ---------------------------------------------------------------------------
class TestStagingRelease:
"""Tests for StagingRelease model."""
def _make_ticket(self, ticket_id: str = "ALLPOST-1") -> TicketEntry:
return TicketEntry(
id=ticket_id,
summary="Some ticket",
pr_id="PR-1",
pr_url="https://dev.azure.com/org/project/_git/repo/pullrequest/1",
pr_title="Some PR",
branch=f"feature/{ticket_id}_some-feature",
merged_at=date(2024, 1, 15),
)
def _make_release(self, **kwargs) -> StagingRelease:
defaults = {
"version": "v1.0.0",
"repo": "my-repo",
"started_at": date(2024, 1, 1),
"tickets": [],
}
defaults.update(kwargs)
return StagingRelease(**defaults)
def test_valid_release(self) -> None:
release = self._make_release()
assert release.version == "v1.0.0"
def test_version_must_match_pattern(self) -> None:
with pytest.raises(ValidationError):
self._make_release(version="1.0.0")
def test_version_missing_patch_raises(self) -> None:
with pytest.raises(ValidationError):
self._make_release(version="v1.0")
def test_version_extra_segments_raises(self) -> None:
with pytest.raises(ValidationError):
self._make_release(version="v1.0.0.1")
def test_version_letters_in_numbers_raises(self) -> None:
with pytest.raises(ValidationError):
self._make_release(version="v1.a.0")
def test_add_ticket_returns_new_instance(self) -> None:
release = self._make_release()
ticket = self._make_ticket("ALLPOST-1")
new_release = release.add_ticket(ticket)
assert new_release is not release
def test_add_ticket_immutability(self) -> None:
release = self._make_release()
ticket = self._make_ticket("ALLPOST-1")
new_release = release.add_ticket(ticket)
assert len(release.tickets) == 0
assert len(new_release.tickets) == 1
def test_add_ticket_contains_ticket(self) -> None:
release = self._make_release()
ticket = self._make_ticket("ALLPOST-1")
new_release = release.add_ticket(ticket)
assert ticket in new_release.tickets
def test_has_ticket_true(self) -> None:
ticket = self._make_ticket("ALLPOST-1")
release = self._make_release(tickets=[ticket])
assert release.has_ticket("ALLPOST-1") is True
def test_has_ticket_false(self) -> None:
release = self._make_release()
assert release.has_ticket("ALLPOST-99") is False
def test_has_ticket_after_add(self) -> None:
release = self._make_release()
ticket = self._make_ticket("ALLPOST-5")
new_release = release.add_ticket(ticket)
assert new_release.has_ticket("ALLPOST-5") is True
def test_model_is_frozen(self) -> None:
release = self._make_release()
with pytest.raises(ValidationError):
release.version = "v2.0.0" # type: ignore[misc]
def test_multiple_tickets(self) -> None:
t1 = self._make_ticket("ALLPOST-1")
t2 = self._make_ticket("ALLPOST-2")
release = self._make_release(tickets=[t1, t2])
assert len(release.tickets) == 2
# ---------------------------------------------------------------------------
# ArchivedRelease tests
# ---------------------------------------------------------------------------
class TestArchivedRelease:
"""Tests for ArchivedRelease model."""
def _make_archived(self, **kwargs) -> ArchivedRelease:
defaults = {
"version": "v1.0.0",
"repo": "my-repo",
"started_at": date(2024, 1, 1),
"tickets": [],
"released_at": date(2024, 1, 10),
}
defaults.update(kwargs)
return ArchivedRelease(**defaults)
def test_valid_archived_release(self) -> None:
release = self._make_archived()
assert release.released_at == date(2024, 1, 10)
def test_released_at_same_as_started_at_is_valid(self) -> None:
release = self._make_archived(started_at=date(2024, 1, 1), released_at=date(2024, 1, 1))
assert release.released_at == release.started_at
def test_released_at_before_started_at_raises(self) -> None:
with pytest.raises(ValidationError):
self._make_archived(
started_at=date(2024, 1, 10),
released_at=date(2024, 1, 1),
)
def test_model_is_frozen(self) -> None:
release = self._make_archived()
with pytest.raises(ValidationError):
release.released_at = date(2024, 12, 31) # type: ignore[misc]
def test_inherits_version_validation(self) -> None:
with pytest.raises(ValidationError):
self._make_archived(version="1.0.0")
# ---------------------------------------------------------------------------
# PipelineInfo tests
# ---------------------------------------------------------------------------
class TestPipelineInfo:
"""Tests for PipelineInfo model."""
def test_valid_pipeline_info(self) -> None:
pipeline = PipelineInfo(id=42, name="Release Pipeline", repo="my-repo")
assert pipeline.id == 42
assert pipeline.name == "Release Pipeline"
assert pipeline.repo == "my-repo"
def test_model_is_frozen(self) -> None:
pipeline = PipelineInfo(id=1, name="Test", repo="repo")
with pytest.raises(ValidationError):
pipeline.id = 2 # type: ignore[misc]
# ---------------------------------------------------------------------------
# ReleasePipelineStage tests
# ---------------------------------------------------------------------------
class TestReleasePipelineStage:
"""Tests for ReleasePipelineStage model."""
def test_valid_stage_without_approval(self) -> None:
stage = ReleasePipelineStage(
name="Build", rank=0, requires_approval=False, approval_id=None
)
assert stage.name == "Build"
assert stage.rank == 0
def test_valid_stage_with_approval(self) -> None:
stage = ReleasePipelineStage(
name="Production", rank=2, requires_approval=True, approval_id="approval-uuid-123"
)
assert stage.requires_approval is True
assert stage.approval_id == "approval-uuid-123"
def test_negative_rank_raises(self) -> None:
with pytest.raises(ValidationError):
ReleasePipelineStage(
name="Bad", rank=-1, requires_approval=False, approval_id=None
)
def test_requires_approval_false_with_approval_id_raises(self) -> None:
with pytest.raises(ValidationError):
ReleasePipelineStage(
name="Bad", rank=0, requires_approval=False, approval_id="some-id"
)
def test_requires_approval_true_without_approval_id_is_valid(self) -> None:
stage = ReleasePipelineStage(
name="Production", rank=2, requires_approval=True, approval_id=None
)
assert stage.requires_approval is True
assert stage.approval_id is None
def test_model_is_frozen(self) -> None:
stage = ReleasePipelineStage(name="Build", rank=0, requires_approval=False, approval_id=None)
with pytest.raises(ValidationError):
stage.name = "Changed" # type: ignore[misc]
# ---------------------------------------------------------------------------
# WebhookPayload tests
# ---------------------------------------------------------------------------
class TestWebhookPayload:
"""Tests for WebhookPayload and nested models."""
def _make_payload(self, **kwargs) -> WebhookPayload:
defaults = {
"subscription_id": "sub-123",
"event_type": "git.pullrequest.merged",
"resource": {
"repository": {
"id": "repo-uuid-456",
"name": "my-repo",
"web_url": "https://dev.azure.com/org/project/_git/my-repo",
},
"pull_request_id": 42,
"title": "Fix the bug",
"source_ref_name": "refs/heads/bug/ALLPOST-4229_fix-review",
"target_ref_name": "refs/heads/main",
"status": "completed",
"closed_date": None,
},
}
defaults.update(kwargs)
return WebhookPayload(**defaults)
def test_valid_payload(self) -> None:
payload = self._make_payload()
assert payload.subscription_id == "sub-123"
assert payload.event_type == "git.pullrequest.merged"
def test_resource_parsed(self) -> None:
payload = self._make_payload()
assert isinstance(payload.resource, WebhookResource)
assert payload.resource.pull_request_id == 42
assert payload.resource.title == "Fix the bug"
def test_repository_parsed(self) -> None:
payload = self._make_payload()
repo = payload.resource.repository
assert isinstance(repo, WebhookRepository)
assert repo.name == "my-repo"
def test_repository_web_url(self) -> None:
payload = self._make_payload()
assert "dev.azure.com" in str(payload.resource.repository.web_url)
def test_closed_date_none(self) -> None:
payload = self._make_payload()
assert payload.resource.closed_date is None
def test_closed_date_populated(self) -> None:
payload_data = {
"subscription_id": "sub-123",
"event_type": "git.pullrequest.merged",
"resource": {
"repository": {
"id": "repo-uuid-456",
"name": "my-repo",
"web_url": "https://dev.azure.com/org/project/_git/my-repo",
},
"pull_request_id": 42,
"title": "Fix the bug",
"source_ref_name": "refs/heads/bug/ALLPOST-4229_fix-review",
"target_ref_name": "refs/heads/main",
"status": "completed",
"closed_date": "2024-01-15T10:30:00Z",
},
}
payload = WebhookPayload(**payload_data)
assert payload.resource.closed_date is not None
assert isinstance(payload.resource.closed_date, datetime)
def test_model_is_frozen(self) -> None:
payload = self._make_payload()
with pytest.raises(ValidationError):
payload.subscription_id = "changed" # type: ignore[misc]
def test_source_ref_name_preserved(self) -> None:
payload = self._make_payload()
assert payload.resource.source_ref_name == "refs/heads/bug/ALLPOST-4229_fix-review"
# ---------------------------------------------------------------------------
# ReviewIssue tests
# ---------------------------------------------------------------------------
class TestReviewIssue:
"""Tests for ReviewIssue model."""
def _make_issue(self, **kwargs) -> ReviewIssue:
defaults = {
"severity": "warning",
"description": "Variable name is not descriptive",
}
defaults.update(kwargs)
return ReviewIssue(**defaults)
def test_valid_warning_issue(self) -> None:
issue = self._make_issue(severity="warning", description="Unclear variable")
assert issue.severity == "warning"
assert issue.description == "Unclear variable"
def test_valid_error_issue(self) -> None:
issue = self._make_issue(severity="error", description="Null pointer risk")
assert issue.severity == "error"
def test_valid_info_issue(self) -> None:
issue = self._make_issue(severity="info", description="Minor style note")
assert issue.severity == "info"
def test_valid_blocker_issue(self) -> None:
issue = self._make_issue(severity="blocker", description="Security vulnerability")
assert issue.severity == "blocker"
def test_invalid_severity_raises(self) -> None:
with pytest.raises(ValidationError):
self._make_issue(severity="critical")
def test_file_path_optional_none_by_default(self) -> None:
issue = self._make_issue()
assert issue.file_path is None
def test_file_path_can_be_set(self) -> None:
issue = self._make_issue(file_path="src/foo.py")
assert issue.file_path == "src/foo.py"
def test_suggestion_optional_none_by_default(self) -> None:
issue = self._make_issue()
assert issue.suggestion is None
def test_suggestion_can_be_set(self) -> None:
issue = self._make_issue(suggestion="Rename to `user_count`")
assert issue.suggestion == "Rename to `user_count`"
def test_model_is_frozen(self) -> None:
issue = self._make_issue()
with pytest.raises(ValidationError):
issue.severity = "error" # type: ignore[misc]
def test_description_required(self) -> None:
with pytest.raises(ValidationError):
ReviewIssue(severity="warning") # type: ignore[call-arg]
# ---------------------------------------------------------------------------
# ReviewResult tests
# ---------------------------------------------------------------------------
class TestReviewResult:
"""Tests for ReviewResult model."""
def _make_blocker_issue(self) -> ReviewIssue:
return ReviewIssue(severity="blocker", description="Must fix this")
def _make_warning_issue(self) -> ReviewIssue:
return ReviewIssue(severity="warning", description="Minor issue")
def _make_result(self, **kwargs) -> ReviewResult:
defaults = {
"verdict": "approve",
"summary": "Looks good overall",
"issues": [],
}
defaults.update(kwargs)
return ReviewResult(**defaults)
def test_valid_approve_verdict(self) -> None:
result = self._make_result(verdict="approve")
assert result.verdict == "approve"
def test_valid_request_changes_verdict(self) -> None:
result = self._make_result(verdict="request_changes")
assert result.verdict == "request_changes"
def test_invalid_verdict_raises(self) -> None:
with pytest.raises(ValidationError):
self._make_result(verdict="reject")
def test_summary_stored(self) -> None:
result = self._make_result(summary="Great PR")
assert result.summary == "Great PR"
def test_issues_empty_by_default(self) -> None:
result = self._make_result()
assert len(result.issues) == 0
def test_has_blockers_false_with_no_issues(self) -> None:
result = self._make_result(issues=[])
assert result.has_blockers is False
def test_has_blockers_false_with_only_warnings(self) -> None:
result = self._make_result(issues=[self._make_warning_issue()])
assert result.has_blockers is False
def test_has_blockers_true_with_blocker_issue(self) -> None:
result = self._make_result(issues=[self._make_blocker_issue()])
assert result.has_blockers is True
def test_has_blockers_true_mixed_issues(self) -> None:
result = self._make_result(
issues=[self._make_warning_issue(), self._make_blocker_issue()]
)
assert result.has_blockers is True
def test_model_is_frozen(self) -> None:
result = self._make_result()
with pytest.raises(ValidationError):
result.verdict = "request_changes" # type: ignore[misc]
def test_multiple_issues_stored(self) -> None:
issues = [self._make_warning_issue(), self._make_blocker_issue()]
result = self._make_result(issues=issues)
assert len(result.issues) == 2
def test_has_blockers_is_computed(self) -> None:
# Verify has_blockers cannot be set directly (it's computed)
result = self._make_result(issues=[self._make_blocker_issue()])
assert result.has_blockers is True
# ---------------------------------------------------------------------------
# JiraTransition tests
# ---------------------------------------------------------------------------
class TestJiraTransition:
"""Tests for JiraTransition model."""
def test_valid_transition(self) -> None:
transition = JiraTransition(id="11", name="To Do")
assert transition.id == "11"
assert transition.name == "To Do"
def test_model_is_frozen(self) -> None:
transition = JiraTransition(id="11", name="To Do")
with pytest.raises(ValidationError):
transition.id = "22" # type: ignore[misc]
def test_id_required(self) -> None:
with pytest.raises(ValidationError):
JiraTransition(name="To Do") # type: ignore[call-arg]
def test_name_required(self) -> None:
with pytest.raises(ValidationError):
JiraTransition(id="11") # type: ignore[call-arg]
# ---------------------------------------------------------------------------
# JiraIssue tests
# ---------------------------------------------------------------------------
class TestJiraIssue:
"""Tests for JiraIssue model."""
def test_valid_issue(self) -> None:
issue = JiraIssue(key="ALLPOST-100", summary="Fix the bug", status="In Progress")
assert issue.key == "ALLPOST-100"
assert issue.summary == "Fix the bug"
assert issue.status == "In Progress"
def test_model_is_frozen(self) -> None:
issue = JiraIssue(key="ALLPOST-100", summary="Fix the bug", status="In Progress")
with pytest.raises(ValidationError):
issue.key = "ALLPOST-200" # type: ignore[misc]
def test_key_required(self) -> None:
with pytest.raises(ValidationError):
JiraIssue(summary="Fix the bug", status="In Progress") # type: ignore[call-arg]
def test_summary_required(self) -> None:
with pytest.raises(ValidationError):
JiraIssue(key="ALLPOST-100", status="In Progress") # type: ignore[call-arg]
def test_status_required(self) -> None:
with pytest.raises(ValidationError):
JiraIssue(key="ALLPOST-100", summary="Fix the bug") # type: ignore[call-arg]
def test_various_statuses(self) -> None:
statuses = ["To Do", "In Progress", "Done", "Released"]
for status in statuses:
issue = JiraIssue(key="ALLPOST-1", summary="Test", status=status)
assert issue.status == status

148
tests/test_models_build.py Normal file
View File

@@ -0,0 +1,148 @@
"""Tests for models/build.py — BuildStatus and ApprovalRecord.
Written FIRST (TDD RED phase).
"""
import pytest
from dataclasses import FrozenInstanceError
from release_agent.models.build import ApprovalRecord, BuildStatus
# ---------------------------------------------------------------------------
# BuildStatus tests
# ---------------------------------------------------------------------------
class TestBuildStatus:
"""Tests for BuildStatus frozen dataclass."""
def test_can_be_created_with_all_fields(self) -> None:
bs = BuildStatus(
status="completed",
result="succeeded",
build_url="https://dev.azure.com/org/proj/_build/results?buildId=42",
)
assert bs.status == "completed"
assert bs.result == "succeeded"
assert bs.build_url == "https://dev.azure.com/org/proj/_build/results?buildId=42"
def test_result_can_be_none(self) -> None:
bs = BuildStatus(
status="inProgress",
result=None,
build_url="https://dev.azure.com/org/proj/_build/results?buildId=99",
)
assert bs.result is None
def test_build_url_can_be_none(self) -> None:
bs = BuildStatus(status="notStarted", result=None, build_url=None)
assert bs.build_url is None
def test_is_frozen_status(self) -> None:
bs = BuildStatus(status="completed", result="succeeded", build_url=None)
with pytest.raises((FrozenInstanceError, AttributeError)):
bs.status = "inProgress" # type: ignore[misc]
def test_is_frozen_result(self) -> None:
bs = BuildStatus(status="completed", result="succeeded", build_url=None)
with pytest.raises((FrozenInstanceError, AttributeError)):
bs.result = "failed" # type: ignore[misc]
def test_equality(self) -> None:
a = BuildStatus(status="completed", result="succeeded", build_url="http://x")
b = BuildStatus(status="completed", result="succeeded", build_url="http://x")
assert a == b
def test_inequality_on_status(self) -> None:
a = BuildStatus(status="completed", result="succeeded", build_url=None)
b = BuildStatus(status="inProgress", result="succeeded", build_url=None)
assert a != b
def test_inequality_on_result(self) -> None:
a = BuildStatus(status="completed", result="succeeded", build_url=None)
b = BuildStatus(status="completed", result="failed", build_url=None)
assert a != b
def test_repr_contains_status(self) -> None:
bs = BuildStatus(status="completed", result="succeeded", build_url=None)
assert "completed" in repr(bs)
def test_status_values_typical(self) -> None:
for s in ("notStarted", "inProgress", "completed", "cancelling"):
bs = BuildStatus(status=s, result=None, build_url=None)
assert bs.status == s
def test_result_values_typical(self) -> None:
for r in ("succeeded", "failed", "canceled", "partiallySucceeded"):
bs = BuildStatus(status="completed", result=r, build_url=None)
assert bs.result == r
# ---------------------------------------------------------------------------
# ApprovalRecord tests
# ---------------------------------------------------------------------------
class TestApprovalRecord:
"""Tests for ApprovalRecord frozen dataclass."""
def test_can_be_created_with_all_fields(self) -> None:
ar = ApprovalRecord(
approval_id="approval-abc-123",
stage_name="Sandbox",
status="pending",
release_id=42,
)
assert ar.approval_id == "approval-abc-123"
assert ar.stage_name == "Sandbox"
assert ar.status == "pending"
assert ar.release_id == 42
def test_is_frozen_approval_id(self) -> None:
ar = ApprovalRecord(
approval_id="abc",
stage_name="Sandbox",
status="pending",
release_id=1,
)
with pytest.raises((FrozenInstanceError, AttributeError)):
ar.approval_id = "xyz" # type: ignore[misc]
def test_is_frozen_stage_name(self) -> None:
ar = ApprovalRecord(
approval_id="abc",
stage_name="Sandbox",
status="pending",
release_id=1,
)
with pytest.raises((FrozenInstanceError, AttributeError)):
ar.stage_name = "Production" # type: ignore[misc]
def test_equality(self) -> None:
a = ApprovalRecord(approval_id="x", stage_name="S", status="pending", release_id=1)
b = ApprovalRecord(approval_id="x", stage_name="S", status="pending", release_id=1)
assert a == b
def test_inequality_on_approval_id(self) -> None:
a = ApprovalRecord(approval_id="x", stage_name="S", status="pending", release_id=1)
b = ApprovalRecord(approval_id="y", stage_name="S", status="pending", release_id=1)
assert a != b
def test_status_pending(self) -> None:
ar = ApprovalRecord(approval_id="a", stage_name="Stage", status="pending", release_id=10)
assert ar.status == "pending"
def test_status_approved(self) -> None:
ar = ApprovalRecord(approval_id="a", stage_name="Stage", status="approved", release_id=10)
assert ar.status == "approved"
def test_status_rejected(self) -> None:
ar = ApprovalRecord(approval_id="a", stage_name="Stage", status="rejected", release_id=10)
assert ar.status == "rejected"
def test_repr_contains_stage_name(self) -> None:
ar = ApprovalRecord(approval_id="a", stage_name="Production", status="pending", release_id=5)
assert "Production" in repr(ar)
def test_release_id_is_int(self) -> None:
ar = ApprovalRecord(approval_id="a", stage_name="S", status="pending", release_id=999)
assert isinstance(ar.release_id, int)

241
tests/test_state.py Normal file
View File

@@ -0,0 +1,241 @@
"""Tests for LangGraph state module. Written FIRST (TDD RED phase)."""
import json
from release_agent.state import ReleaseState, add_errors, add_messages
# ---------------------------------------------------------------------------
# ReleaseState tests
# ---------------------------------------------------------------------------
class TestReleaseState:
"""Tests for ReleaseState TypedDict."""
def test_empty_state_is_valid(self) -> None:
# total=False means all fields are optional
state: ReleaseState = {}
assert state == {}
def test_partial_state_with_repo(self) -> None:
state: ReleaseState = {"repo_name": "my-repo"}
assert state["repo_name"] == "my-repo"
def test_partial_state_with_messages(self) -> None:
state: ReleaseState = {"messages": ["Hello"]}
assert state["messages"] == ["Hello"]
def test_partial_state_with_errors(self) -> None:
state: ReleaseState = {"errors": ["Something went wrong"]}
assert state["errors"] == ["Something went wrong"]
def test_state_with_pr_id(self) -> None:
state: ReleaseState = {"pr_id": "PR-42"}
assert state["pr_id"] == "PR-42"
def test_state_with_ticket_id(self) -> None:
state: ReleaseState = {"ticket_id": "ALLPOST-100"}
assert state["ticket_id"] == "ALLPOST-100"
def test_state_with_version(self) -> None:
state: ReleaseState = {"version": "v1.0.1"}
assert state["version"] == "v1.0.1"
# Phase 3 new fields
def test_state_with_webhook_payload(self) -> None:
state: ReleaseState = {"webhook_payload": {"event_type": "git.pullrequest.merged"}}
assert state["webhook_payload"]["event_type"] == "git.pullrequest.merged"
def test_state_with_pr_info(self) -> None:
state: ReleaseState = {"pr_info": {"pr_id": "42", "repo_name": "my-repo"}}
assert state["pr_info"]["repo_name"] == "my-repo"
def test_state_with_pr_diff(self) -> None:
state: ReleaseState = {"pr_diff": "edit: src/main.py"}
assert state["pr_diff"] == "edit: src/main.py"
def test_state_with_last_merge_source_commit(self) -> None:
state: ReleaseState = {"last_merge_source_commit": "abc123"}
assert state["last_merge_source_commit"] == "abc123"
def test_state_with_ticket_summary(self) -> None:
state: ReleaseState = {"ticket_summary": "Fix login bug"}
assert state["ticket_summary"] == "Fix login bug"
def test_state_with_has_ticket(self) -> None:
state: ReleaseState = {"has_ticket": True}
assert state["has_ticket"] is True
def test_state_with_review_result(self) -> None:
state: ReleaseState = {"review_result": {"verdict": "approve", "summary": "LGTM"}}
assert state["review_result"]["verdict"] == "approve"
def test_state_with_review_approved(self) -> None:
state: ReleaseState = {"review_approved": True}
assert state["review_approved"] is True
def test_state_with_staging(self) -> None:
state: ReleaseState = {"staging": {"version": "v1.0.0", "tickets": []}}
assert state["staging"]["version"] == "v1.0.0"
def test_state_with_pr_already_merged(self) -> None:
state: ReleaseState = {"pr_already_merged": False}
assert state["pr_already_merged"] is False
def test_state_with_release_pr_id(self) -> None:
state: ReleaseState = {"release_pr_id": "123"}
assert state["release_pr_id"] == "123"
def test_state_with_release_pr_commit(self) -> None:
state: ReleaseState = {"release_pr_commit": "deadbeef"}
assert state["release_pr_commit"] == "deadbeef"
def test_state_with_pipelines(self) -> None:
state: ReleaseState = {"pipelines": [{"id": 1, "name": "build"}]}
assert len(state["pipelines"]) == 1
def test_state_with_triggered_builds(self) -> None:
state: ReleaseState = {"triggered_builds": [{"id": 99}]}
assert state["triggered_builds"][0]["id"] == 99
def test_state_with_pending_approvals(self) -> None:
state: ReleaseState = {"pending_approvals": [{"approval_id": "aaa"}]}
assert state["pending_approvals"][0]["approval_id"] == "aaa"
def test_state_with_continue_to_release(self) -> None:
state: ReleaseState = {"continue_to_release": True}
assert state["continue_to_release"] is True
# Phase 5: CI/CD and approval fields
def test_state_with_ci_build_id(self) -> None:
state: ReleaseState = {"ci_build_id": 12345}
assert state["ci_build_id"] == 12345
def test_state_with_ci_build_status(self) -> None:
state: ReleaseState = {"ci_build_status": "inProgress"}
assert state["ci_build_status"] == "inProgress"
def test_state_with_ci_build_result(self) -> None:
state: ReleaseState = {"ci_build_result": "succeeded"}
assert state["ci_build_result"] == "succeeded"
def test_state_with_ci_build_url(self) -> None:
state: ReleaseState = {"ci_build_url": "https://dev.azure.com/org/proj/_build/results?buildId=99"}
assert "buildId=99" in state["ci_build_url"]
def test_state_with_release_definition_id(self) -> None:
state: ReleaseState = {"release_definition_id": 7}
assert state["release_definition_id"] == 7
def test_state_with_release_id(self) -> None:
state: ReleaseState = {"release_id": 456}
assert state["release_id"] == 456
def test_state_with_current_stage(self) -> None:
state: ReleaseState = {"current_stage": "sandbox_pending"}
assert state["current_stage"] == "sandbox_pending"
def test_state_with_approval_message_ts(self) -> None:
state: ReleaseState = {"approval_message_ts": "1234567890.123456"}
assert state["approval_message_ts"] == "1234567890.123456"
def test_state_with_slack_message_ts(self) -> None:
state: ReleaseState = {"slack_message_ts": "9876543210.000001"}
assert state["slack_message_ts"] == "9876543210.000001"
def test_state_json_serializable_empty(self) -> None:
state: ReleaseState = {}
serialized = json.dumps(state)
assert json.loads(serialized) == {}
def test_state_json_serializable_with_strings(self) -> None:
state: ReleaseState = {
"repo_name": "my-repo",
"pr_id": "PR-1",
"ticket_id": "ALLPOST-1",
"version": "v1.0.0",
}
serialized = json.dumps(state)
loaded = json.loads(serialized)
assert loaded["repo_name"] == "my-repo"
assert loaded["pr_id"] == "PR-1"
def test_state_json_serializable_with_lists(self) -> None:
state: ReleaseState = {
"messages": ["msg1", "msg2"],
"errors": ["err1"],
}
serialized = json.dumps(state)
loaded = json.loads(serialized)
assert loaded["messages"] == ["msg1", "msg2"]
assert loaded["errors"] == ["err1"]
# ---------------------------------------------------------------------------
# Reducer tests
# ---------------------------------------------------------------------------
class TestAddMessages:
"""Tests for add_messages reducer."""
def test_accumulates_to_empty(self) -> None:
result = add_messages([], ["Hello"])
assert result == ["Hello"]
def test_accumulates_to_existing(self) -> None:
result = add_messages(["Hello"], ["World"])
assert result == ["Hello", "World"]
def test_accumulates_multiple(self) -> None:
result = add_messages(["A", "B"], ["C", "D"])
assert result == ["A", "B", "C", "D"]
def test_existing_unchanged(self) -> None:
existing = ["Hello"]
add_messages(existing, ["World"])
# Original should not be mutated
assert existing == ["Hello"]
def test_empty_new_messages(self) -> None:
result = add_messages(["Hello"], [])
assert result == ["Hello"]
def test_both_empty(self) -> None:
result = add_messages([], [])
assert result == []
def test_returns_new_list(self) -> None:
existing = ["Hello"]
new_msgs = ["World"]
result = add_messages(existing, new_msgs)
assert result is not existing
assert result is not new_msgs
class TestAddErrors:
"""Tests for add_errors reducer."""
def test_accumulates_to_empty(self) -> None:
result = add_errors([], ["Error occurred"])
assert result == ["Error occurred"]
def test_accumulates_to_existing(self) -> None:
result = add_errors(["First error"], ["Second error"])
assert result == ["First error", "Second error"]
def test_existing_unchanged(self) -> None:
existing = ["First error"]
add_errors(existing, ["Second error"])
assert existing == ["First error"]
def test_empty_new_errors(self) -> None:
result = add_errors(["Existing"], [])
assert result == ["Existing"]
def test_both_empty(self) -> None:
result = add_errors([], [])
assert result == []
def test_returns_new_list(self) -> None:
existing = ["Error"]
result = add_errors(existing, ["New error"])
assert result is not existing

124
tests/test_versioning.py Normal file
View File

@@ -0,0 +1,124 @@
"""Tests for versioning module. Written FIRST (TDD RED phase)."""
import pytest
from release_agent.versioning import (
calculate_next_version,
format_version,
parse_version,
)
class TestParseVersion:
"""Tests for parse_version function."""
def test_parse_with_v_prefix(self) -> None:
assert parse_version("v1.2.3") == (1, 2, 3)
def test_parse_without_v_prefix(self) -> None:
assert parse_version("1.2.3") == (1, 2, 3)
def test_parse_zeros(self) -> None:
assert parse_version("v0.0.0") == (0, 0, 0)
def test_parse_large_numbers(self) -> None:
assert parse_version("v10.20.300") == (10, 20, 300)
def test_parse_returns_tuple_of_ints(self) -> None:
result = parse_version("v1.2.3")
assert isinstance(result, tuple)
assert len(result) == 3
assert all(isinstance(x, int) for x in result)
def test_parse_invalid_raises_value_error(self) -> None:
with pytest.raises(ValueError):
parse_version("invalid")
def test_parse_partial_version_raises_value_error(self) -> None:
with pytest.raises(ValueError):
parse_version("v1.2")
def test_parse_non_numeric_raises_value_error(self) -> None:
with pytest.raises(ValueError):
parse_version("va.b.c")
class TestFormatVersion:
"""Tests for format_version function."""
def test_format_basic(self) -> None:
assert format_version(1, 0, 3) == "v1.0.3"
def test_format_zeros(self) -> None:
assert format_version(0, 0, 0) == "v0.0.0"
def test_format_large_numbers(self) -> None:
assert format_version(10, 20, 300) == "v10.20.300"
def test_format_returns_string(self) -> None:
result = format_version(1, 2, 3)
assert isinstance(result, str)
def test_format_starts_with_v(self) -> None:
result = format_version(1, 2, 3)
assert result.startswith("v")
class TestCalculateNextVersion:
"""Tests for calculate_next_version function."""
def test_empty_list_returns_v1_0_0(self) -> None:
assert calculate_next_version("my-repo", []) == "v1.0.0"
def test_single_version_increments_patch(self) -> None:
assert calculate_next_version("my-repo", ["v1.0.0"]) == "v1.0.1"
def test_multiple_versions_uses_highest(self) -> None:
assert calculate_next_version("my-repo", ["v1.0.3", "v1.0.1"]) == "v1.0.4"
def test_different_major_versions(self) -> None:
assert calculate_next_version("my-repo", ["v2.1.0", "v1.9.9"]) == "v2.1.1"
def test_skips_malformed_versions(self) -> None:
assert calculate_next_version("my-repo", ["invalid", "v1.0.0"]) == "v1.0.1"
def test_all_malformed_versions_returns_v1_0_0(self) -> None:
assert calculate_next_version("my-repo", ["invalid", "bad", "nope"]) == "v1.0.0"
def test_repo_name_does_not_affect_result(self) -> None:
result_a = calculate_next_version("repo-a", ["v1.0.0"])
result_b = calculate_next_version("repo-b", ["v1.0.0"])
assert result_a == result_b
def test_versions_out_of_order(self) -> None:
assert calculate_next_version("my-repo", ["v1.0.1", "v1.0.3", "v1.0.2"]) == "v1.0.4"
def test_patch_overflow_does_not_occur(self) -> None:
# Just increments patch - no overflow logic required
result = calculate_next_version("my-repo", ["v1.0.99"])
assert result == "v1.0.100"
def test_versions_without_v_prefix_skipped(self) -> None:
# Versions without 'v' prefix are treated as malformed per spec
result = calculate_next_version("my-repo", ["1.0.0", "v2.0.0"])
assert result == "v2.0.1"
def test_result_format_starts_with_v(self) -> None:
result = calculate_next_version("my-repo", ["v1.0.0"])
assert result.startswith("v")
def test_result_has_three_parts(self) -> None:
result = calculate_next_version("my-repo", ["v1.0.0"])
parts = result[1:].split(".")
assert len(parts) == 3
assert all(p.isdigit() for p in parts)
def test_v_prefix_with_nonnumeric_parts_skipped(self) -> None:
# Starts with 'v' but is malformed - should be skipped gracefully
result = calculate_next_version("my-repo", ["va.b.c", "v1.0.0"])
assert result == "v1.0.1"
def test_v_prefix_partial_version_skipped(self) -> None:
# Starts with 'v' but only has two parts - should be skipped
result = calculate_next_version("my-repo", ["v1.0", "v2.0.0"])
assert result == "v2.0.1"

0
tests/tools/__init__.py Normal file
View File

View File

@@ -0,0 +1,9 @@
{
"id": "approval-uuid-123",
"status": "approved",
"approver": {
"id": "user-uuid-456",
"displayName": "Release Bot"
},
"comments": "Approved via release agent"
}

View File

@@ -0,0 +1,9 @@
{
"id": 1001,
"buildNumber": "20240115.1",
"status": "completed",
"result": "succeeded",
"queueTime": "2024-01-15T10:00:00Z",
"startTime": "2024-01-15T10:01:00Z",
"finishTime": "2024-01-15T10:10:00Z"
}

View File

@@ -0,0 +1,8 @@
{
"pullRequestId": 99,
"title": "Release v1.2.0",
"status": "active",
"sourceRefName": "refs/heads/release/v1.2.0",
"targetRefName": "refs/heads/main",
"url": "https://dev.azure.com/my-org/my-project/_apis/git/repositories/my-repo/pullRequests/99"
}

View File

@@ -0,0 +1,8 @@
{
"pullRequestId": 42,
"status": "completed",
"title": "Fix the auth bug",
"completionOptions": {
"mergeStrategy": "squash"
}
}

View File

@@ -0,0 +1,15 @@
{
"value": [
{
"id": 10,
"name": "Release Pipeline",
"folder": "\\"
},
{
"id": 20,
"name": "Build Pipeline",
"folder": "\\"
}
],
"count": 2
}

View File

@@ -0,0 +1,16 @@
{
"pullRequestId": 42,
"title": "Fix the auth bug",
"status": "active",
"sourceRefName": "refs/heads/bug/ALLPOST-999_fix-auth",
"targetRefName": "refs/heads/main",
"url": "https://dev.azure.com/my-org/my-project/_apis/git/repositories/my-repo/pullRequests/42",
"repository": {
"id": "repo-uuid-123",
"name": "my-repo",
"remoteUrl": "https://dev.azure.com/my-org/my-project/_git/my-repo"
},
"lastMergeSourceCommit": {
"commitId": "abc123def456"
}
}

View File

@@ -0,0 +1,11 @@
diff --git a/src/auth.py b/src/auth.py
index 1234567..abcdefg 100644
--- a/src/auth.py
+++ b/src/auth.py
@@ -10,6 +10,10 @@ class AuthService:
def authenticate(self, token: str) -> bool:
- return token == "hardcoded"
+ return self._validate_token(token)
+
+ def _validate_token(self, token: str) -> bool:
+ return len(token) > 0 and token.startswith("Bearer ")

View File

@@ -0,0 +1,11 @@
{
"id": 1001,
"buildNumber": "20240115.1",
"status": "notStarted",
"queueTime": "2024-01-15T10:00:00Z",
"definition": {
"id": 10,
"name": "Release Pipeline"
},
"sourceBranch": "refs/heads/main"
}

View File

@@ -0,0 +1,10 @@
{
"id": "12345",
"key": "ALLPOST-100",
"fields": {
"summary": "Fix the authentication bug",
"status": {
"name": "In Progress"
}
}
}

View File

@@ -0,0 +1,20 @@
{
"transitions": [
{
"id": "11",
"name": "To Do"
},
{
"id": "21",
"name": "In Progress"
},
{
"id": "31",
"name": "Done"
},
{
"id": "41",
"name": "Released"
}
]
}

819
tests/tools/test_azdo.py Normal file
View File

@@ -0,0 +1,819 @@
"""Tests for AzDoClient. Written FIRST (TDD RED phase)."""
import json
from pathlib import Path
import httpx
import pytest
from release_agent.exceptions import AuthenticationError, NotFoundError, ServiceError
from release_agent.models.build import ApprovalRecord, BuildStatus
from release_agent.models.pipeline import PipelineInfo
from release_agent.models.pr import PRInfo
from release_agent.tools.azdo import AzDoClient
# ---------------------------------------------------------------------------
# Fixture helpers
# ---------------------------------------------------------------------------
FIXTURES = Path(__file__).parent / "fixtures"
def _load_json(name: str) -> dict:
return json.loads((FIXTURES / name).read_text())
def _load_text(name: str) -> str:
return (FIXTURES / name).read_text()
def _make_transport(routes: dict[tuple[str, str], tuple[int, bytes | str]]) -> httpx.MockTransport:
"""Build a MockTransport that dispatches based on (method, url_substring)."""
def handler(request: httpx.Request) -> httpx.Response:
url = str(request.url)
method = request.method
for (m, url_fragment), (status, body) in routes.items():
if m == method and url_fragment in url:
content = body if isinstance(body, bytes) else body.encode()
return httpx.Response(status_code=status, content=content)
return httpx.Response(status_code=404, content=b'{"message": "Not found"}')
return httpx.MockTransport(handler)
def _make_client(routes: dict) -> AzDoClient:
"""Create an AzDoClient with mocked HTTP transport."""
transport = _make_transport(routes)
http_client = httpx.AsyncClient(transport=transport)
vsrm_client = httpx.AsyncClient(transport=transport)
return AzDoClient(
base_url="https://dev.azure.com/my-org/my-project/_apis",
vsrm_base_url="https://vsrm.dev.azure.com/my-org/my-project/_apis",
pat="test-pat",
http_client=http_client,
vsrm_http_client=vsrm_client,
)
# ---------------------------------------------------------------------------
# AzDoClient construction tests
# ---------------------------------------------------------------------------
class TestAzDoClientConstruction:
"""Tests for AzDoClient initialization."""
def test_can_be_instantiated_with_injected_clients(self) -> None:
transport = httpx.MockTransport(lambda r: httpx.Response(200, content=b"{}"))
http_client = httpx.AsyncClient(transport=transport)
vsrm_client = httpx.AsyncClient(transport=transport)
client = AzDoClient(
base_url="https://dev.azure.com/org/proj/_apis",
vsrm_base_url="https://vsrm.dev.azure.com/org/proj/_apis",
pat="my-pat",
http_client=http_client,
vsrm_http_client=vsrm_client,
)
assert client is not None
async def test_context_manager_closes_clients(self) -> None:
transport = httpx.MockTransport(lambda r: httpx.Response(200, content=b"{}"))
http_client = httpx.AsyncClient(transport=transport)
vsrm_client = httpx.AsyncClient(transport=transport)
async with AzDoClient(
base_url="https://dev.azure.com/org/proj/_apis",
vsrm_base_url="https://vsrm.dev.azure.com/org/proj/_apis",
pat="my-pat",
http_client=http_client,
vsrm_http_client=vsrm_client,
) as client:
assert client is not None
# After context manager exits, clients should be closed
assert http_client.is_closed
assert vsrm_client.is_closed
# ---------------------------------------------------------------------------
# get_pr tests
# ---------------------------------------------------------------------------
class TestGetPr:
"""Tests for AzDoClient.get_pr."""
async def test_returns_pr_info(self) -> None:
pr_data = _load_json("azdo_pr.json")
routes = {("GET", "pullRequests/42"): (200, json.dumps(pr_data))}
client = _make_client(routes)
result = await client.get_pr(42)
assert isinstance(result, PRInfo)
assert result.pr_id == "42"
async def test_pr_title_extracted(self) -> None:
pr_data = _load_json("azdo_pr.json")
routes = {("GET", "pullRequests/42"): (200, json.dumps(pr_data))}
client = _make_client(routes)
result = await client.get_pr(42)
assert result.pr_title == "Fix the auth bug"
async def test_pr_branch_extracted(self) -> None:
pr_data = _load_json("azdo_pr.json")
routes = {("GET", "pullRequests/42"): (200, json.dumps(pr_data))}
client = _make_client(routes)
result = await client.get_pr(42)
assert "ALLPOST-999" in result.branch or "bug" in result.branch
async def test_pr_status_extracted(self) -> None:
pr_data = _load_json("azdo_pr.json")
routes = {("GET", "pullRequests/42"): (200, json.dumps(pr_data))}
client = _make_client(routes)
result = await client.get_pr(42)
assert result.pr_status == "active"
async def test_404_raises_not_found(self) -> None:
routes = {("GET", "pullRequests/999"): (404, b'{"message": "PR not found"}')}
client = _make_client(routes)
with pytest.raises(NotFoundError):
await client.get_pr(999)
async def test_401_raises_authentication_error(self) -> None:
routes = {("GET", "pullRequests/42"): (401, b'{"message": "Unauthorized"}')}
client = _make_client(routes)
with pytest.raises(AuthenticationError):
await client.get_pr(42)
async def test_500_raises_service_error(self) -> None:
routes = {("GET", "pullRequests/42"): (500, b'{"message": "Internal error"}')}
client = _make_client(routes)
with pytest.raises(ServiceError):
await client.get_pr(42)
# ---------------------------------------------------------------------------
# get_pr_diff tests
# ---------------------------------------------------------------------------
class TestGetPrDiff:
"""Tests for AzDoClient.get_pr_diff."""
async def test_returns_diff_string(self) -> None:
pr_data = _load_json("azdo_pr.json")
routes = {
("GET", "pullRequests/42"): (200, json.dumps(pr_data)),
("GET", "diffs"): (200, json.dumps({
"changes": [
{
"item": {"path": "/src/auth.py"},
"changeType": "edit",
}
]
})),
}
client = _make_client(routes)
result = await client.get_pr_diff(42)
assert isinstance(result, str)
async def test_diff_includes_file_paths(self) -> None:
pr_data = _load_json("azdo_pr.json")
diff_data = {
"changes": [
{"item": {"path": "/src/auth.py"}, "changeType": "edit"},
{"item": {"path": "/src/util.py"}, "changeType": "add"},
]
}
def handler(request: httpx.Request) -> httpx.Response:
url = str(request.url)
if "diffs" in url:
return httpx.Response(200, content=json.dumps(diff_data).encode())
if "pullRequests/42" in url:
return httpx.Response(200, content=json.dumps(pr_data).encode())
return httpx.Response(404, content=b"{}")
transport = httpx.MockTransport(handler)
http_client = httpx.AsyncClient(transport=transport)
vsrm_client = httpx.AsyncClient(transport=transport)
client = AzDoClient(
base_url="https://dev.azure.com/my-org/my-project/_apis",
vsrm_base_url="https://vsrm.dev.azure.com/my-org/my-project/_apis",
pat="test-pat",
http_client=http_client,
vsrm_http_client=vsrm_client,
)
result = await client.get_pr_diff(42)
assert "/src/auth.py" in result
assert "/src/util.py" in result
async def test_empty_changes_returns_empty_string(self) -> None:
pr_data = _load_json("azdo_pr.json")
diff_data: dict = {"changes": []}
def handler(request: httpx.Request) -> httpx.Response:
url = str(request.url)
if "diffs" in url:
return httpx.Response(200, content=json.dumps(diff_data).encode())
if "pullRequests/42" in url:
return httpx.Response(200, content=json.dumps(pr_data).encode())
return httpx.Response(404, content=b"{}")
transport = httpx.MockTransport(handler)
http_client = httpx.AsyncClient(transport=transport)
vsrm_client = httpx.AsyncClient(transport=transport)
client = AzDoClient(
base_url="https://dev.azure.com/my-org/my-project/_apis",
vsrm_base_url="https://vsrm.dev.azure.com/my-org/my-project/_apis",
pat="test-pat",
http_client=http_client,
vsrm_http_client=vsrm_client,
)
result = await client.get_pr_diff(42)
assert result == ""
async def test_404_raises_not_found(self) -> None:
routes = {("GET", "pullRequests/999"): (404, b'{"message": "PR not found"}')}
client = _make_client(routes)
with pytest.raises(NotFoundError):
await client.get_pr_diff(999)
async def test_pr_without_url_field_uses_remote_url(self) -> None:
"""When the API response omits the 'url' field, fallback URL is built."""
pr_data = {
"pullRequestId": 42,
"title": "Fix bug",
"status": "active",
"sourceRefName": "refs/heads/fix/ALLPOST-1_fix",
"repository": {
"id": "repo-uuid",
"name": "my-repo",
"remoteUrl": "https://dev.azure.com/org/proj/_git/my-repo",
},
# NOTE: 'url' field is intentionally omitted
}
routes = {
("GET", "pullRequests/42"): (200, json.dumps(pr_data)),
("GET", "diffs"): (200, json.dumps({"changes": []})),
}
client = _make_client(routes)
result = await client.get_pr(42)
assert "42" in str(result.pr_url)
# ---------------------------------------------------------------------------
# merge_pr tests
# ---------------------------------------------------------------------------
class TestMergePr:
"""Tests for AzDoClient.merge_pr."""
async def test_returns_true_on_success(self) -> None:
merge_data = _load_json("azdo_merge_pr.json")
routes = {("PATCH", "pullRequests/42"): (200, json.dumps(merge_data))}
client = _make_client(routes)
result = await client.merge_pr(pr_id=42, last_merge_source_commit="abc123def456")
assert result is True
async def test_404_raises_not_found(self) -> None:
routes = {("PATCH", "pullRequests/999"): (404, b'{"message": "Not found"}')}
client = _make_client(routes)
with pytest.raises(NotFoundError):
await client.merge_pr(pr_id=999, last_merge_source_commit="abc123")
async def test_409_raises_service_error(self) -> None:
routes = {("PATCH", "pullRequests/42"): (409, b'{"message": "Conflict"}')}
client = _make_client(routes)
with pytest.raises(ServiceError):
await client.merge_pr(pr_id=42, last_merge_source_commit="abc123")
# ---------------------------------------------------------------------------
# create_pr tests
# ---------------------------------------------------------------------------
class TestCreatePr:
"""Tests for AzDoClient.create_pr."""
async def test_returns_dict_with_pr_id(self) -> None:
create_data = _load_json("azdo_create_pr.json")
routes = {("POST", "pullRequests"): (201, json.dumps(create_data))}
client = _make_client(routes)
result = await client.create_pr(
repo="my-repo",
source="refs/heads/release/v1.2.0",
target="refs/heads/main",
title="Release v1.2.0",
description="Release notes",
)
assert isinstance(result, dict)
assert result["pullRequestId"] == 99
async def test_400_raises_service_error(self) -> None:
routes = {("POST", "pullRequests"): (400, b'{"message": "Bad request"}')}
client = _make_client(routes)
with pytest.raises(ServiceError):
await client.create_pr(
repo="my-repo",
source="refs/heads/release/v1.2.0",
target="refs/heads/main",
title="Release",
description="",
)
# ---------------------------------------------------------------------------
# list_build_pipelines tests
# ---------------------------------------------------------------------------
class TestListBuildPipelines:
"""Tests for AzDoClient.list_build_pipelines."""
async def test_returns_list_of_pipeline_info(self) -> None:
pipeline_data = _load_json("azdo_pipelines.json")
routes = {("GET", "pipelines"): (200, json.dumps(pipeline_data))}
client = _make_client(routes)
result = await client.list_build_pipelines(repo="my-repo")
assert isinstance(result, list)
assert len(result) == 2
assert all(isinstance(p, PipelineInfo) for p in result)
async def test_pipeline_ids_extracted(self) -> None:
pipeline_data = _load_json("azdo_pipelines.json")
routes = {("GET", "pipelines"): (200, json.dumps(pipeline_data))}
client = _make_client(routes)
result = await client.list_build_pipelines(repo="my-repo")
ids = [p.id for p in result]
assert 10 in ids
assert 20 in ids
async def test_empty_list_on_no_pipelines(self) -> None:
routes = {("GET", "pipelines"): (200, json.dumps({"value": [], "count": 0}))}
client = _make_client(routes)
result = await client.list_build_pipelines(repo="my-repo")
assert result == []
# ---------------------------------------------------------------------------
# trigger_pipeline tests
# ---------------------------------------------------------------------------
class TestTriggerPipeline:
"""Tests for AzDoClient.trigger_pipeline."""
async def test_returns_dict_with_build_id(self) -> None:
trigger_data = _load_json("azdo_trigger_pipeline.json")
routes = {("POST", "pipelines/10/runs"): (200, json.dumps(trigger_data))}
client = _make_client(routes)
result = await client.trigger_pipeline(pipeline_id=10, branch="refs/heads/main")
assert isinstance(result, dict)
assert result["id"] == 1001
async def test_404_raises_not_found(self) -> None:
routes = {("POST", "pipelines/999/runs"): (404, b'{"message": "Pipeline not found"}')}
client = _make_client(routes)
with pytest.raises(NotFoundError):
await client.trigger_pipeline(pipeline_id=999, branch="refs/heads/main")
# ---------------------------------------------------------------------------
# get_build_status tests
# ---------------------------------------------------------------------------
class TestGetBuildStatus:
"""Tests for AzDoClient.get_build_status."""
async def test_returns_build_status_object(self) -> None:
build_data = _load_json("azdo_build_status.json")
routes = {("GET", "build/builds/1001"): (200, json.dumps(build_data))}
client = _make_client(routes)
result = await client.get_build_status(build_id=1001)
assert isinstance(result, BuildStatus)
async def test_status_field_populated(self) -> None:
build_data = _load_json("azdo_build_status.json")
routes = {("GET", "build/builds/1001"): (200, json.dumps(build_data))}
client = _make_client(routes)
result = await client.get_build_status(build_id=1001)
assert result.status == "completed"
async def test_result_field_populated(self) -> None:
build_data = _load_json("azdo_build_status.json")
routes = {("GET", "build/builds/1001"): (200, json.dumps(build_data))}
client = _make_client(routes)
result = await client.get_build_status(build_id=1001)
assert result.result == "succeeded"
async def test_build_url_present(self) -> None:
build_data = _load_json("azdo_build_status.json")
routes = {("GET", "build/builds/1001"): (200, json.dumps(build_data))}
client = _make_client(routes)
result = await client.get_build_status(build_id=1001)
# build_url may be None if not in fixture, but field must exist
assert hasattr(result, "build_url")
async def test_result_none_when_not_completed(self) -> None:
build_data = {"id": 99, "status": "inProgress", "buildNumber": "20240101.1"}
routes = {("GET", "build/builds/99"): (200, json.dumps(build_data))}
client = _make_client(routes)
result = await client.get_build_status(build_id=99)
assert result.status == "inProgress"
assert result.result is None
async def test_404_raises_not_found(self) -> None:
routes = {("GET", "build/builds/9999"): (404, b'{"message": "Build not found"}')}
client = _make_client(routes)
with pytest.raises(NotFoundError):
await client.get_build_status(build_id=9999)
# ---------------------------------------------------------------------------
# get_release_approvals tests
# ---------------------------------------------------------------------------
class TestGetReleaseApprovals:
"""Tests for AzDoClient.get_release_approvals."""
async def test_returns_list_of_approval_records(self) -> None:
approvals_data = {
"value": [
{
"id": 101,
"status": "pending",
"releaseEnvironment": {"name": "Sandbox", "release": {"id": 55}},
},
{
"id": 102,
"status": "approved",
"releaseEnvironment": {"name": "Production", "release": {"id": 55}},
},
],
"count": 2,
}
routes = {("GET", "release/approvals"): (200, json.dumps(approvals_data))}
client = _make_client(routes)
result = await client.get_release_approvals(release_id=55)
assert isinstance(result, list)
assert len(result) == 2
assert all(isinstance(a, ApprovalRecord) for a in result)
async def test_approval_id_populated(self) -> None:
approvals_data = {
"value": [
{
"id": 201,
"status": "pending",
"releaseEnvironment": {"name": "Sandbox", "release": {"id": 10}},
}
],
"count": 1,
}
routes = {("GET", "release/approvals"): (200, json.dumps(approvals_data))}
client = _make_client(routes)
result = await client.get_release_approvals(release_id=10)
assert result[0].approval_id == "201"
async def test_stage_name_populated(self) -> None:
approvals_data = {
"value": [
{
"id": 300,
"status": "pending",
"releaseEnvironment": {"name": "Production", "release": {"id": 20}},
}
],
"count": 1,
}
routes = {("GET", "release/approvals"): (200, json.dumps(approvals_data))}
client = _make_client(routes)
result = await client.get_release_approvals(release_id=20)
assert result[0].stage_name == "Production"
async def test_release_id_populated(self) -> None:
approvals_data = {
"value": [
{
"id": 400,
"status": "pending",
"releaseEnvironment": {"name": "Stage", "release": {"id": 99}},
}
],
"count": 1,
}
routes = {("GET", "release/approvals"): (200, json.dumps(approvals_data))}
client = _make_client(routes)
result = await client.get_release_approvals(release_id=99)
assert result[0].release_id == 99
async def test_empty_list_when_no_approvals(self) -> None:
approvals_data = {"value": [], "count": 0}
routes = {("GET", "release/approvals"): (200, json.dumps(approvals_data))}
client = _make_client(routes)
result = await client.get_release_approvals(release_id=77)
assert result == []
async def test_filters_by_release_id_in_query(self) -> None:
captured_urls: list[str] = []
def handler(request: httpx.Request) -> httpx.Response:
captured_urls.append(str(request.url))
return httpx.Response(200, content=b'{"value": [], "count": 0}')
transport = httpx.MockTransport(handler)
http_client = httpx.AsyncClient(transport=transport)
vsrm_client = httpx.AsyncClient(transport=transport)
client = AzDoClient(
base_url="https://dev.azure.com/my-org/my-project/_apis",
vsrm_base_url="https://vsrm.dev.azure.com/my-org/my-project/_apis",
pat="test-pat",
http_client=http_client,
vsrm_http_client=vsrm_client,
)
await client.get_release_approvals(release_id=42)
assert any("approvals" in url for url in captured_urls)
async def test_404_raises_not_found(self) -> None:
routes = {("GET", "release/approvals"): (404, b'{"message": "Not found"}')}
client = _make_client(routes)
with pytest.raises(NotFoundError):
await client.get_release_approvals(release_id=999)
# ---------------------------------------------------------------------------
# get_latest_release tests
# ---------------------------------------------------------------------------
class TestGetLatestRelease:
"""Tests for AzDoClient.get_latest_release."""
async def test_returns_dict(self) -> None:
release_data = {
"value": [{"id": 55, "name": "Release-55", "status": "active"}],
"count": 1,
}
routes = {("GET", "release/releases"): (200, json.dumps(release_data))}
client = _make_client(routes)
result = await client.get_latest_release(definition_id=7)
assert isinstance(result, dict)
assert result["id"] == 55
async def test_returns_empty_dict_when_no_releases(self) -> None:
release_data = {"value": [], "count": 0}
routes = {("GET", "release/releases"): (200, json.dumps(release_data))}
client = _make_client(routes)
result = await client.get_latest_release(definition_id=99)
assert result == {}
async def test_404_raises_not_found(self) -> None:
routes = {("GET", "release/releases"): (404, b'{"message": "Not found"}')}
client = _make_client(routes)
with pytest.raises(NotFoundError):
await client.get_latest_release(definition_id=999)
async def test_passes_definition_id_as_filter(self) -> None:
captured_urls: list[str] = []
def handler(request: httpx.Request) -> httpx.Response:
captured_urls.append(str(request.url))
return httpx.Response(200, content=b'{"value": [], "count": 0}')
transport = httpx.MockTransport(handler)
http_client = httpx.AsyncClient(transport=transport)
vsrm_client = httpx.AsyncClient(transport=transport)
client = AzDoClient(
base_url="https://dev.azure.com/my-org/my-project/_apis",
vsrm_base_url="https://vsrm.dev.azure.com/my-org/my-project/_apis",
pat="test-pat",
http_client=http_client,
vsrm_http_client=vsrm_client,
)
await client.get_latest_release(definition_id=13)
assert any("releases" in url for url in captured_urls)
# ---------------------------------------------------------------------------
# approve_release tests
# ---------------------------------------------------------------------------
class TestApproveRelease:
"""Tests for AzDoClient.approve_release."""
async def test_returns_dict_with_status(self) -> None:
approve_data = _load_json("azdo_approve_release.json")
routes = {("PATCH", "release/approvals"): (200, json.dumps(approve_data))}
client = _make_client(routes)
result = await client.approve_release(
approval_id="approval-uuid-123", comment="Approved"
)
assert isinstance(result, dict)
assert result["status"] == "approved"
async def test_404_raises_not_found(self) -> None:
routes = {("PATCH", "release/approvals"): (404, b'{"message": "Approval not found"}')}
client = _make_client(routes)
with pytest.raises(NotFoundError):
await client.approve_release(approval_id="bad-id", comment="Approve")
# ---------------------------------------------------------------------------
# close() lifecycle tests
# ---------------------------------------------------------------------------
class TestAzDoClientLifecycle:
"""Tests for AzDoClient close() method."""
async def test_close_closes_both_clients(self) -> None:
transport = httpx.MockTransport(lambda r: httpx.Response(200, content=b"{}"))
http_client = httpx.AsyncClient(transport=transport)
vsrm_client = httpx.AsyncClient(transport=transport)
client = AzDoClient(
base_url="https://dev.azure.com/org/proj/_apis",
vsrm_base_url="https://vsrm.dev.azure.com/org/proj/_apis",
pat="my-pat",
http_client=http_client,
vsrm_http_client=vsrm_client,
)
await client.close()
assert http_client.is_closed
assert vsrm_client.is_closed
# ---------------------------------------------------------------------------
# list_active_prs tests
# ---------------------------------------------------------------------------
def _make_pr_list_response(prs: list[dict]) -> str:
return json.dumps({"value": prs, "count": len(prs)})
def _make_active_pr_item(
pr_id: int = 10,
title: str = "Test PR",
branch: str = "refs/heads/feature/ALLPOST-100_fix",
status: str = "active",
repo_name: str = "my-repo",
) -> dict:
return {
"pullRequestId": pr_id,
"title": title,
"status": status,
"sourceRefName": branch,
"targetRefName": "refs/heads/develop",
"url": f"https://dev.azure.com/org/proj/_apis/git/repositories/{repo_name}/pullRequests/{pr_id}",
"repository": {
"id": "repo-uuid",
"name": repo_name,
"remoteUrl": f"https://dev.azure.com/org/proj/_git/{repo_name}",
},
}
class TestListActivePrs:
"""Tests for AzDoClient.list_active_prs."""
async def test_returns_list_of_pr_info(self) -> None:
pr_item = _make_active_pr_item()
routes = {
("GET", "git/repositories/my-repo/pullRequests"): (
200,
_make_pr_list_response([pr_item]),
)
}
client = _make_client(routes)
result = await client.list_active_prs("my-repo", "refs/heads/develop")
assert isinstance(result, list)
assert len(result) == 1
assert isinstance(result[0], PRInfo)
async def test_pr_id_extracted(self) -> None:
pr_item = _make_active_pr_item(pr_id=55)
routes = {
("GET", "git/repositories/my-repo/pullRequests"): (
200,
_make_pr_list_response([pr_item]),
)
}
client = _make_client(routes)
result = await client.list_active_prs("my-repo", "refs/heads/develop")
assert result[0].pr_id == "55"
async def test_pr_title_extracted(self) -> None:
pr_item = _make_active_pr_item(title="My Feature")
routes = {
("GET", "git/repositories/my-repo/pullRequests"): (
200,
_make_pr_list_response([pr_item]),
)
}
client = _make_client(routes)
result = await client.list_active_prs("my-repo", "refs/heads/develop")
assert result[0].pr_title == "My Feature"
async def test_empty_list_when_no_prs(self) -> None:
routes = {
("GET", "git/repositories/my-repo/pullRequests"): (
200,
_make_pr_list_response([]),
)
}
client = _make_client(routes)
result = await client.list_active_prs("my-repo", "refs/heads/develop")
assert result == []
async def test_multiple_prs_returned(self) -> None:
prs = [
_make_active_pr_item(pr_id=10, title="PR 10"),
_make_active_pr_item(pr_id=20, title="PR 20"),
]
routes = {
("GET", "git/repositories/my-repo/pullRequests"): (
200,
_make_pr_list_response(prs),
)
}
client = _make_client(routes)
result = await client.list_active_prs("my-repo", "refs/heads/develop")
assert len(result) == 2
assert {r.pr_id for r in result} == {"10", "20"}
async def test_404_raises_not_found(self) -> None:
routes = {
("GET", "git/repositories/missing-repo/pullRequests"): (
404,
b'{"message": "Repo not found"}',
)
}
client = _make_client(routes)
with pytest.raises(NotFoundError):
await client.list_active_prs("missing-repo", "refs/heads/develop")
async def test_401_raises_authentication_error(self) -> None:
routes = {
("GET", "git/repositories/my-repo/pullRequests"): (
401,
b'{"message": "Unauthorized"}',
)
}
client = _make_client(routes)
with pytest.raises(AuthenticationError):
await client.list_active_prs("my-repo", "refs/heads/develop")
async def test_500_raises_service_error(self) -> None:
routes = {
("GET", "git/repositories/my-repo/pullRequests"): (
500,
b'{"message": "Internal error"}',
)
}
client = _make_client(routes)
with pytest.raises(ServiceError):
await client.list_active_prs("my-repo", "refs/heads/develop")

View File

@@ -0,0 +1,454 @@
"""Tests for ClaudeReviewer using Claude Code CLI subprocess."""
import json
import pytest
from release_agent.models.review import ReviewResult
from release_agent.tools.claude_review import (
ClaudeReviewer,
_build_prompt,
_parse_cli_output,
_truncate_diff,
)
MAX_DIFF_CHARS = 100_000
# ---------------------------------------------------------------------------
# Helpers — fake subprocess runner
# ---------------------------------------------------------------------------
def _make_cli_output(
verdict: str = "approve",
summary: str = "LGTM",
issues: list | None = None,
) -> str:
"""Build a JSON string mimicking Claude Code CLI --output-format json."""
structured = {
"verdict": verdict,
"summary": summary,
"issues": issues or [],
}
return json.dumps({"result": "", "structured_output": structured})
def _make_subprocess_runner(
stdout: str = "",
stderr: str = "",
returncode: int = 0,
):
"""Return a fake run_subprocess callable that records calls."""
calls: list[dict] = []
async def fake_run(*, cmd, cwd, timeout):
calls.append({"cmd": cmd, "cwd": cwd, "timeout": timeout})
return (stdout, stderr, returncode)
return fake_run, calls
# ---------------------------------------------------------------------------
# _truncate_diff tests
# ---------------------------------------------------------------------------
class TestTruncateDiff:
def test_short_diff_not_truncated(self) -> None:
diff = "short diff"
assert _truncate_diff(diff) == diff
def test_exact_limit_not_truncated(self) -> None:
diff = "x" * MAX_DIFF_CHARS
assert _truncate_diff(diff) == diff
def test_over_limit_truncated(self) -> None:
diff = "x" * (MAX_DIFF_CHARS + 1000)
result = _truncate_diff(diff)
assert len(result) < len(diff)
assert "TRUNCATED" in result
# ---------------------------------------------------------------------------
# _build_prompt tests
# ---------------------------------------------------------------------------
class TestBuildPrompt:
def test_contains_pr_title(self) -> None:
prompt = _build_prompt(diff="d", pr_title="My Title", repo_name="repo")
assert "My Title" in prompt
def test_contains_repo_name(self) -> None:
prompt = _build_prompt(diff="d", pr_title="t", repo_name="my-repo")
assert "my-repo" in prompt
def test_contains_diff(self) -> None:
prompt = _build_prompt(diff="UNIQUE_DIFF", pr_title="t", repo_name="r")
assert "UNIQUE_DIFF" in prompt
# ---------------------------------------------------------------------------
# _parse_cli_output tests
# ---------------------------------------------------------------------------
class TestParseCliOutput:
def test_parses_structured_output(self) -> None:
stdout = _make_cli_output(verdict="approve", summary="Good")
result = _parse_cli_output(stdout)
assert isinstance(result, ReviewResult)
assert result.verdict == "approve"
assert result.summary == "Good"
def test_parses_request_changes(self) -> None:
stdout = _make_cli_output(
verdict="request_changes",
summary="Has issues",
issues=[{"severity": "blocker", "description": "SQL injection"}],
)
result = _parse_cli_output(stdout)
assert result.verdict == "request_changes"
assert len(result.issues) == 1
assert result.has_blockers is True
def test_parses_issues_with_optional_fields(self) -> None:
stdout = _make_cli_output(
verdict="request_changes",
summary="Issues found",
issues=[{
"severity": "warning",
"description": "Style issue",
"file_path": "src/foo.py",
"suggestion": "Fix it",
}],
)
result = _parse_cli_output(stdout)
assert result.issues[0].file_path == "src/foo.py"
assert result.issues[0].suggestion == "Fix it"
def test_empty_issues_no_blockers(self) -> None:
stdout = _make_cli_output(verdict="approve", summary="Clean", issues=[])
result = _parse_cli_output(stdout)
assert result.has_blockers is False
assert len(result.issues) == 0
def test_result_field_as_json_string(self) -> None:
"""When structured_output is absent, falls back to parsing result as JSON."""
inner = {"verdict": "approve", "summary": "OK", "issues": []}
stdout = json.dumps({"result": json.dumps(inner)})
result = _parse_cli_output(stdout)
assert result.verdict == "approve"
def test_invalid_json_raises(self) -> None:
with pytest.raises(ValueError, match="Failed to parse"):
_parse_cli_output("not json at all")
def test_missing_structured_output_and_result_raises(self) -> None:
with pytest.raises(ValueError, match="No structured_output"):
_parse_cli_output(json.dumps({"other": "data"}))
def test_non_dict_structured_output_raises(self) -> None:
stdout = json.dumps({"structured_output": ["not", "a", "dict"]})
with pytest.raises(ValueError, match="Expected dict"):
_parse_cli_output(stdout)
def test_result_is_non_json_string_raises(self) -> None:
stdout = json.dumps({"result": "just plain text, not json"})
with pytest.raises(ValueError, match="not valid JSON"):
_parse_cli_output(stdout)
# ---------------------------------------------------------------------------
# ClaudeReviewer construction tests
# ---------------------------------------------------------------------------
class TestClaudeReviewerConstruction:
def test_can_be_instantiated(self) -> None:
reviewer = ClaudeReviewer()
assert reviewer is not None
def test_custom_claude_cmd(self) -> None:
reviewer = ClaudeReviewer(claude_cmd="/usr/local/bin/claude")
assert reviewer._claude_cmd == "/usr/local/bin/claude"
def test_custom_timeout(self) -> None:
reviewer = ClaudeReviewer(timeout=60)
assert reviewer._timeout == 60
# ---------------------------------------------------------------------------
# review_pr tests
# ---------------------------------------------------------------------------
class TestReviewPr:
async def test_returns_review_result(self) -> None:
stdout = _make_cli_output(verdict="approve", summary="Looks good")
runner, _ = _make_subprocess_runner(stdout=stdout)
reviewer = ClaudeReviewer(run_subprocess=runner)
result = await reviewer.review_pr(
diff="diff --git a/foo.py ...",
pr_title="Fix bug",
repo_name="my-repo",
)
assert isinstance(result, ReviewResult)
assert result.verdict == "approve"
async def test_passes_cwd_to_subprocess(self) -> None:
stdout = _make_cli_output()
runner, calls = _make_subprocess_runner(stdout=stdout)
reviewer = ClaudeReviewer(run_subprocess=runner)
await reviewer.review_pr(
diff="diff",
pr_title="PR",
repo_name="repo",
cwd="/path/to/worktree",
)
assert calls[0]["cwd"] == "/path/to/worktree"
async def test_cmd_includes_claude_p(self) -> None:
stdout = _make_cli_output()
runner, calls = _make_subprocess_runner(stdout=stdout)
reviewer = ClaudeReviewer(run_subprocess=runner)
await reviewer.review_pr(diff="d", pr_title="t", repo_name="r")
cmd = calls[0]["cmd"]
assert cmd[0] == "claude"
assert "-p" in cmd
async def test_cmd_includes_output_format_json(self) -> None:
stdout = _make_cli_output()
runner, calls = _make_subprocess_runner(stdout=stdout)
reviewer = ClaudeReviewer(run_subprocess=runner)
await reviewer.review_pr(diff="d", pr_title="t", repo_name="r")
cmd = calls[0]["cmd"]
idx = cmd.index("--output-format")
assert cmd[idx + 1] == "json"
async def test_cmd_includes_json_schema(self) -> None:
stdout = _make_cli_output()
runner, calls = _make_subprocess_runner(stdout=stdout)
reviewer = ClaudeReviewer(run_subprocess=runner)
await reviewer.review_pr(diff="d", pr_title="t", repo_name="r")
cmd = calls[0]["cmd"]
assert "--json-schema" in cmd
async def test_cmd_includes_allowed_tools(self) -> None:
stdout = _make_cli_output()
runner, calls = _make_subprocess_runner(stdout=stdout)
reviewer = ClaudeReviewer(run_subprocess=runner)
await reviewer.review_pr(diff="d", pr_title="t", repo_name="r")
cmd = calls[0]["cmd"]
idx = cmd.index("--allowedTools")
assert "Read" in cmd[idx + 1]
async def test_cmd_includes_system_prompt(self) -> None:
stdout = _make_cli_output()
runner, calls = _make_subprocess_runner(stdout=stdout)
reviewer = ClaudeReviewer(run_subprocess=runner)
await reviewer.review_pr(diff="d", pr_title="t", repo_name="r")
cmd = calls[0]["cmd"]
assert "--system-prompt" in cmd
async def test_nonzero_exit_raises(self) -> None:
runner, _ = _make_subprocess_runner(
stdout="", stderr="error occurred", returncode=1
)
reviewer = ClaudeReviewer(run_subprocess=runner)
with pytest.raises(RuntimeError, match="exited with code 1"):
await reviewer.review_pr(diff="d", pr_title="t", repo_name="r")
async def test_timeout_passed_to_subprocess(self) -> None:
stdout = _make_cli_output()
runner, calls = _make_subprocess_runner(stdout=stdout)
reviewer = ClaudeReviewer(run_subprocess=runner, timeout=120)
await reviewer.review_pr(diff="d", pr_title="t", repo_name="r")
assert calls[0]["timeout"] == 120
async def test_pr_title_in_prompt(self) -> None:
stdout = _make_cli_output()
runner, calls = _make_subprocess_runner(stdout=stdout)
reviewer = ClaudeReviewer(run_subprocess=runner)
await reviewer.review_pr(
diff="d", pr_title="Specific Title", repo_name="r"
)
cmd = calls[0]["cmd"]
prompt = cmd[cmd.index("-p") + 1]
assert "Specific Title" in prompt
async def test_repo_name_in_prompt(self) -> None:
stdout = _make_cli_output()
runner, calls = _make_subprocess_runner(stdout=stdout)
reviewer = ClaudeReviewer(run_subprocess=runner)
await reviewer.review_pr(
diff="d", pr_title="t", repo_name="special-repo"
)
cmd = calls[0]["cmd"]
prompt = cmd[cmd.index("-p") + 1]
assert "special-repo" in prompt
async def test_cwd_none_when_not_provided(self) -> None:
stdout = _make_cli_output()
runner, calls = _make_subprocess_runner(stdout=stdout)
reviewer = ClaudeReviewer(run_subprocess=runner)
await reviewer.review_pr(diff="d", pr_title="t", repo_name="r")
assert calls[0]["cwd"] is None
async def test_request_changes_with_issues(self) -> None:
stdout = _make_cli_output(
verdict="request_changes",
summary="Problems found",
issues=[
{"severity": "blocker", "description": "Security flaw"},
{"severity": "warning", "description": "Missing docs"},
],
)
runner, _ = _make_subprocess_runner(stdout=stdout)
reviewer = ClaudeReviewer(run_subprocess=runner)
result = await reviewer.review_pr(diff="d", pr_title="t", repo_name="r")
assert result.verdict == "request_changes"
assert len(result.issues) == 2
assert result.has_blockers is True
# ---------------------------------------------------------------------------
# ClaudeReviewer.generate_ticket_content tests
# ---------------------------------------------------------------------------
def _make_ticket_cli_output(summary: str = "My summary", description: str = "My desc") -> str:
"""Build a JSON string mimicking Claude Code CLI output for ticket generation."""
structured = {"summary": summary, "description": description}
return json.dumps({"result": "", "structured_output": structured})
class TestGenerateTicketContent:
"""Tests for ClaudeReviewer.generate_ticket_content."""
async def test_returns_tuple_of_summary_and_description(self) -> None:
stdout = _make_ticket_cli_output(summary="Fix login bug", description="Detailed desc")
runner, _ = _make_subprocess_runner(stdout=stdout)
reviewer = ClaudeReviewer(run_subprocess=runner)
result = await reviewer.generate_ticket_content(
diff="edit: main.py", pr_title="Fix login", repo_name="backend"
)
assert isinstance(result, tuple)
assert len(result) == 2
async def test_returns_correct_summary(self) -> None:
stdout = _make_ticket_cli_output(summary="Implement OAuth2 login")
runner, _ = _make_subprocess_runner(stdout=stdout)
reviewer = ClaudeReviewer(run_subprocess=runner)
summary, _ = await reviewer.generate_ticket_content(
diff="d", pr_title="Add OAuth", repo_name="auth-service"
)
assert summary == "Implement OAuth2 login"
async def test_returns_correct_description(self) -> None:
stdout = _make_ticket_cli_output(description="This adds OAuth2 support for the login flow")
runner, _ = _make_subprocess_runner(stdout=stdout)
reviewer = ClaudeReviewer(run_subprocess=runner)
_, description = await reviewer.generate_ticket_content(
diff="d", pr_title="Add OAuth", repo_name="auth-service"
)
assert description == "This adds OAuth2 support for the login flow"
async def test_uses_json_schema_with_summary_and_description_fields(self) -> None:
stdout = _make_ticket_cli_output()
runner, calls = _make_subprocess_runner(stdout=stdout)
reviewer = ClaudeReviewer(run_subprocess=runner)
await reviewer.generate_ticket_content(diff="d", pr_title="t", repo_name="r")
cmd = calls[0]["cmd"]
# Verify --json-schema flag was used
assert "--json-schema" in cmd
schema_idx = cmd.index("--json-schema")
schema_json = cmd[schema_idx + 1]
schema = json.loads(schema_json)
assert "summary" in schema["properties"]
assert "description" in schema["properties"]
async def test_passes_pr_title_in_prompt(self) -> None:
stdout = _make_ticket_cli_output()
runner, calls = _make_subprocess_runner(stdout=stdout)
reviewer = ClaudeReviewer(run_subprocess=runner)
await reviewer.generate_ticket_content(
diff="d", pr_title="My Unique PR Title", repo_name="r"
)
cmd_str = " ".join(calls[0]["cmd"])
assert "My Unique PR Title" in cmd_str
async def test_passes_repo_name_in_prompt(self) -> None:
stdout = _make_ticket_cli_output()
runner, calls = _make_subprocess_runner(stdout=stdout)
reviewer = ClaudeReviewer(run_subprocess=runner)
await reviewer.generate_ticket_content(
diff="d", pr_title="t", repo_name="my-special-repo"
)
cmd_str = " ".join(calls[0]["cmd"])
assert "my-special-repo" in cmd_str
async def test_passes_cwd_to_subprocess(self) -> None:
stdout = _make_ticket_cli_output()
runner, calls = _make_subprocess_runner(stdout=stdout)
reviewer = ClaudeReviewer(run_subprocess=runner)
await reviewer.generate_ticket_content(
diff="d", pr_title="t", repo_name="r", cwd="/some/path"
)
assert calls[0]["cwd"] == "/some/path"
async def test_cwd_none_by_default(self) -> None:
stdout = _make_ticket_cli_output()
runner, calls = _make_subprocess_runner(stdout=stdout)
reviewer = ClaudeReviewer(run_subprocess=runner)
await reviewer.generate_ticket_content(diff="d", pr_title="t", repo_name="r")
assert calls[0]["cwd"] is None
async def test_raises_on_nonzero_exit_code(self) -> None:
runner, _ = _make_subprocess_runner(stdout="", stderr="Error", returncode=1)
reviewer = ClaudeReviewer(run_subprocess=runner)
with pytest.raises(RuntimeError, match="Claude CLI"):
await reviewer.generate_ticket_content(diff="d", pr_title="t", repo_name="r")
async def test_raises_on_invalid_json_output(self) -> None:
runner, _ = _make_subprocess_runner(stdout="not json at all")
reviewer = ClaudeReviewer(run_subprocess=runner)
with pytest.raises((ValueError, Exception)):
await reviewer.generate_ticket_content(diff="d", pr_title="t", repo_name="r")

205
tests/tools/test_http.py Normal file
View File

@@ -0,0 +1,205 @@
"""Tests for shared HTTP helpers. Written FIRST (TDD RED phase)."""
import base64
import httpx
import pytest
from release_agent.exceptions import (
AuthenticationError,
NotFoundError,
RateLimitError,
ServiceError,
ServiceUnavailableError,
)
from release_agent.tools._http import build_auth_header, raise_for_status
# ---------------------------------------------------------------------------
# raise_for_status tests
# ---------------------------------------------------------------------------
def _make_response(status_code: int, headers: dict | None = None) -> httpx.Response:
"""Build a minimal httpx.Response with the given status code."""
return httpx.Response(
status_code=status_code,
headers=headers or {},
content=b"{}",
request=httpx.Request("GET", "https://example.com"),
)
class TestRaiseForStatus:
"""Tests for raise_for_status helper."""
def test_2xx_does_not_raise(self) -> None:
for code in [200, 201, 204]:
response = _make_response(code)
# Should not raise anything
raise_for_status(response, service="test")
def test_401_raises_authentication_error(self) -> None:
response = _make_response(401)
with pytest.raises(AuthenticationError) as exc_info:
raise_for_status(response, service="azdo")
assert exc_info.value.service == "azdo"
assert exc_info.value.status_code == 401
def test_403_raises_authentication_error(self) -> None:
response = _make_response(403)
with pytest.raises(AuthenticationError) as exc_info:
raise_for_status(response, service="jira")
assert exc_info.value.status_code == 403
def test_404_raises_not_found_error(self) -> None:
response = _make_response(404)
with pytest.raises(NotFoundError) as exc_info:
raise_for_status(response, service="azdo")
assert exc_info.value.service == "azdo"
assert exc_info.value.status_code == 404
def test_429_raises_rate_limit_error(self) -> None:
response = _make_response(429)
with pytest.raises(RateLimitError) as exc_info:
raise_for_status(response, service="jira")
assert exc_info.value.status_code == 429
def test_429_with_retry_after_header_populates_retry_after(self) -> None:
response = _make_response(429, headers={"Retry-After": "60"})
with pytest.raises(RateLimitError) as exc_info:
raise_for_status(response, service="jira")
assert exc_info.value.retry_after == 60
def test_429_without_retry_after_header_retry_after_is_none(self) -> None:
response = _make_response(429)
with pytest.raises(RateLimitError) as exc_info:
raise_for_status(response, service="jira")
assert exc_info.value.retry_after is None
def test_503_raises_service_unavailable(self) -> None:
response = _make_response(503)
with pytest.raises(ServiceUnavailableError) as exc_info:
raise_for_status(response, service="slack")
assert exc_info.value.status_code == 503
def test_500_raises_service_error(self) -> None:
response = _make_response(500)
with pytest.raises(ServiceError) as exc_info:
raise_for_status(response, service="azdo")
assert exc_info.value.status_code == 500
assert exc_info.value.service == "azdo"
def test_400_raises_service_error(self) -> None:
response = _make_response(400)
with pytest.raises(ServiceError) as exc_info:
raise_for_status(response, service="jira")
assert exc_info.value.status_code == 400
def test_422_raises_service_error(self) -> None:
response = _make_response(422)
with pytest.raises(ServiceError):
raise_for_status(response, service="azdo")
def test_service_name_propagated_in_all_errors(self) -> None:
"""Each error type must carry the service name."""
cases = [
(401, AuthenticationError),
(404, NotFoundError),
(429, RateLimitError),
(503, ServiceUnavailableError),
(500, ServiceError),
]
for code, exc_type in cases:
response = _make_response(code)
with pytest.raises(exc_type) as exc_info:
raise_for_status(response, service="my-service")
assert exc_info.value.service == "my-service"
def test_3xx_does_not_raise(self) -> None:
"""Redirects are not errors (httpx follows them)."""
response = _make_response(301)
raise_for_status(response, service="test")
# ---------------------------------------------------------------------------
# build_auth_header tests
# ---------------------------------------------------------------------------
class TestBuildAuthHeader:
"""Tests for build_auth_header helper."""
def test_returns_authorization_key(self) -> None:
header = build_auth_header("user", "pass")
assert "Authorization" in header
def test_returns_basic_scheme(self) -> None:
header = build_auth_header("user", "pass")
assert header["Authorization"].startswith("Basic ")
def test_value_is_base64_encoded(self) -> None:
header = build_auth_header("user", "pass")
encoded_part = header["Authorization"].removeprefix("Basic ")
decoded = base64.b64decode(encoded_part).decode()
assert decoded == "user:pass"
def test_empty_username(self) -> None:
# PAT auth uses empty username with token as password
header = build_auth_header("", "my-token")
encoded_part = header["Authorization"].removeprefix("Basic ")
decoded = base64.b64decode(encoded_part).decode()
assert decoded == ":my-token"
def test_special_characters_in_password(self) -> None:
header = build_auth_header("user", "p@ss!#$%")
encoded_part = header["Authorization"].removeprefix("Basic ")
decoded = base64.b64decode(encoded_part).decode()
assert decoded == "user:p@ss!#$%"
def test_returns_dict(self) -> None:
result = build_auth_header("u", "p")
assert isinstance(result, dict)
def test_result_is_immutable_dict(self) -> None:
result = build_auth_header("u", "p")
# Ensure only the Authorization key is present
assert list(result.keys()) == ["Authorization"]
# ---------------------------------------------------------------------------
# Edge case coverage for _extract_detail and _parse_retry_after
# ---------------------------------------------------------------------------
class TestExtractDetailEdgeCases:
"""Tests for the private _extract_detail helper via raise_for_status."""
def test_non_dict_body_still_raises_service_error(self) -> None:
"""A JSON array body (non-dict) should still raise ServiceError."""
response = httpx.Response(
status_code=500,
content=b'["error", "list"]',
request=httpx.Request("GET", "https://example.com"),
)
with pytest.raises(ServiceError):
raise_for_status(response, service="test")
def test_invalid_json_body_still_raises(self) -> None:
"""A non-JSON response body should still raise ServiceError."""
response = httpx.Response(
status_code=500,
content=b"Internal Server Error (plain text)",
request=httpx.Request("GET", "https://example.com"),
)
with pytest.raises(ServiceError):
raise_for_status(response, service="test")
def test_429_with_non_integer_retry_after_retry_after_is_none(self) -> None:
"""A non-integer Retry-After value should result in retry_after=None."""
response = httpx.Response(
status_code=429,
headers={"Retry-After": "Wed, 21 Oct 2015 07:28:00 GMT"},
content=b"{}",
request=httpx.Request("GET", "https://example.com"),
)
with pytest.raises(RateLimitError) as exc_info:
raise_for_status(response, service="test")
assert exc_info.value.retry_after is None

572
tests/tools/test_jira.py Normal file
View File

@@ -0,0 +1,572 @@
"""Tests for JiraClient. Written FIRST (TDD RED phase)."""
import json
from pathlib import Path
import httpx
import pytest
from release_agent.exceptions import AuthenticationError, NotFoundError, ServiceError
from release_agent.models.jira import JiraIssue, JiraTransition
from release_agent.tools.jira import JiraClient
# ---------------------------------------------------------------------------
# Fixture helpers
# ---------------------------------------------------------------------------
FIXTURES = Path(__file__).parent / "fixtures"
def _load_json(name: str) -> dict:
return json.loads((FIXTURES / name).read_text())
def _make_transport(routes: dict[tuple[str, str], tuple[int, bytes | str]]) -> httpx.MockTransport:
"""Build a MockTransport dispatching by (method, url_substring)."""
def handler(request: httpx.Request) -> httpx.Response:
url = str(request.url)
method = request.method
for (m, url_fragment), (status, body) in routes.items():
if m == method and url_fragment in url:
content = body if isinstance(body, bytes) else body.encode()
return httpx.Response(status_code=status, content=content)
return httpx.Response(status_code=404, content=b'{"errorMessages": ["Not found"]}')
return httpx.MockTransport(handler)
def _make_client(routes: dict) -> JiraClient:
transport = _make_transport(routes)
http_client = httpx.AsyncClient(transport=transport)
return JiraClient(
base_url="https://billolife.atlassian.net",
email="user@example.com",
api_token="test-token",
http_client=http_client,
)
# ---------------------------------------------------------------------------
# Construction tests
# ---------------------------------------------------------------------------
class TestJiraClientConstruction:
"""Tests for JiraClient initialization."""
def test_can_be_instantiated(self) -> None:
transport = httpx.MockTransport(lambda r: httpx.Response(200, content=b"{}"))
http_client = httpx.AsyncClient(transport=transport)
client = JiraClient(
base_url="https://billolife.atlassian.net",
email="u@example.com",
api_token="token",
http_client=http_client,
)
assert client is not None
async def test_context_manager_closes_client(self) -> None:
transport = httpx.MockTransport(lambda r: httpx.Response(200, content=b"{}"))
http_client = httpx.AsyncClient(transport=transport)
async with JiraClient(
base_url="https://billolife.atlassian.net",
email="u@example.com",
api_token="token",
http_client=http_client,
) as client:
assert client is not None
assert http_client.is_closed
# ---------------------------------------------------------------------------
# get_issue tests
# ---------------------------------------------------------------------------
class TestGetIssue:
"""Tests for JiraClient.get_issue."""
async def test_returns_jira_issue(self) -> None:
issue_data = _load_json("jira_issue.json")
routes = {("GET", "ALLPOST-100"): (200, json.dumps(issue_data))}
client = _make_client(routes)
result = await client.get_issue("ALLPOST-100")
assert isinstance(result, JiraIssue)
async def test_key_extracted(self) -> None:
issue_data = _load_json("jira_issue.json")
routes = {("GET", "ALLPOST-100"): (200, json.dumps(issue_data))}
client = _make_client(routes)
result = await client.get_issue("ALLPOST-100")
assert result.key == "ALLPOST-100"
async def test_summary_extracted(self) -> None:
issue_data = _load_json("jira_issue.json")
routes = {("GET", "ALLPOST-100"): (200, json.dumps(issue_data))}
client = _make_client(routes)
result = await client.get_issue("ALLPOST-100")
assert result.summary == "Fix the authentication bug"
async def test_status_extracted(self) -> None:
issue_data = _load_json("jira_issue.json")
routes = {("GET", "ALLPOST-100"): (200, json.dumps(issue_data))}
client = _make_client(routes)
result = await client.get_issue("ALLPOST-100")
assert result.status == "In Progress"
async def test_404_raises_not_found(self) -> None:
routes = {("GET", "ALLPOST-999"): (404, b'{"errorMessages": ["Issue not found"]}')}
client = _make_client(routes)
with pytest.raises(NotFoundError):
await client.get_issue("ALLPOST-999")
async def test_401_raises_authentication_error(self) -> None:
routes = {("GET", "ALLPOST-100"): (401, b'{"errorMessages": ["Unauthorized"]}')}
client = _make_client(routes)
with pytest.raises(AuthenticationError):
await client.get_issue("ALLPOST-100")
# ---------------------------------------------------------------------------
# get_transitions tests
# ---------------------------------------------------------------------------
class TestGetTransitions:
"""Tests for JiraClient.get_transitions."""
async def test_returns_list_of_transitions(self) -> None:
transition_data = _load_json("jira_transitions.json")
routes = {("GET", "transitions"): (200, json.dumps(transition_data))}
client = _make_client(routes)
result = await client.get_transitions("ALLPOST-100")
assert isinstance(result, list)
assert all(isinstance(t, JiraTransition) for t in result)
async def test_transition_names_extracted(self) -> None:
transition_data = _load_json("jira_transitions.json")
routes = {("GET", "transitions"): (200, json.dumps(transition_data))}
client = _make_client(routes)
result = await client.get_transitions("ALLPOST-100")
names = [t.name for t in result]
assert "Released" in names
assert "In Progress" in names
async def test_transition_ids_extracted(self) -> None:
transition_data = _load_json("jira_transitions.json")
routes = {("GET", "transitions"): (200, json.dumps(transition_data))}
client = _make_client(routes)
result = await client.get_transitions("ALLPOST-100")
ids = [t.id for t in result]
assert "11" in ids
async def test_empty_transitions_returned(self) -> None:
routes = {("GET", "transitions"): (200, json.dumps({"transitions": []}))}
client = _make_client(routes)
result = await client.get_transitions("ALLPOST-100")
assert result == []
async def test_404_raises_not_found(self) -> None:
routes = {("GET", "transitions"): (404, b'{"errorMessages": ["Not found"]}')}
client = _make_client(routes)
with pytest.raises(NotFoundError):
await client.get_transitions("ALLPOST-999")
# ---------------------------------------------------------------------------
# transition_issue tests
# ---------------------------------------------------------------------------
class TestTransitionIssue:
"""Tests for JiraClient.transition_issue."""
async def test_returns_true_on_success(self) -> None:
transition_data = _load_json("jira_transitions.json")
routes = {
("GET", "transitions"): (200, json.dumps(transition_data)),
("POST", "transitions"): (204, b""),
}
client = _make_client(routes)
result = await client.transition_issue("ALLPOST-100", "Released")
assert result is True
async def test_returns_false_when_transition_not_found(self) -> None:
transition_data = _load_json("jira_transitions.json")
routes = {
("GET", "transitions"): (200, json.dumps(transition_data)),
("POST", "transitions"): (204, b""),
}
client = _make_client(routes)
# "QA Review" is not in the fixture transitions
result = await client.transition_issue("ALLPOST-100", "QA Review")
assert result is False
async def test_fallback_to_dev_in_progress_then_retries(self) -> None:
"""Two-step fallback: if target unavailable, try 'Dev in Progress' first."""
get_call_count = {"n": 0}
# First GET: only "Dev in Progress" available (no "Released" yet)
transition_data_first = {
"transitions": [{"id": "21", "name": "Dev in Progress"}]
}
# Second GET (after Dev in Progress transition): "Released" now available
transition_data_second = {
"transitions": [
{"id": "21", "name": "Dev in Progress"},
{"id": "41", "name": "Released"},
]
}
def handler(request: httpx.Request) -> httpx.Response:
url = str(request.url)
method = request.method
if method == "GET" and "transitions" in url:
get_call_count["n"] += 1
if get_call_count["n"] <= 1:
return httpx.Response(200, content=json.dumps(transition_data_first).encode())
return httpx.Response(200, content=json.dumps(transition_data_second).encode())
if method == "POST" and "transitions" in url:
return httpx.Response(204, content=b"")
return httpx.Response(404, content=b'{"errorMessages": ["Not found"]}')
transport = httpx.MockTransport(handler)
http_client = httpx.AsyncClient(transport=transport)
client = JiraClient(
base_url="https://billolife.atlassian.net",
email="user@example.com",
api_token="test-token",
http_client=http_client,
)
result = await client.transition_issue("ALLPOST-100", "Released")
assert result is True
async def test_returns_false_when_still_unavailable_after_fallback(self) -> None:
"""Return False when target transition is unavailable even after fallback."""
get_call_count = {"n": 0}
# First GET: only "Dev in Progress" available
transition_data_first = {
"transitions": [{"id": "21", "name": "Dev in Progress"}]
}
# Second GET (after fallback): target STILL not available
transition_data_second = {
"transitions": [{"id": "21", "name": "Dev in Progress"}]
}
def handler(request: httpx.Request) -> httpx.Response:
url = str(request.url)
method = request.method
if method == "GET" and "transitions" in url:
get_call_count["n"] += 1
if get_call_count["n"] <= 1:
return httpx.Response(200, content=json.dumps(transition_data_first).encode())
return httpx.Response(200, content=json.dumps(transition_data_second).encode())
if method == "POST" and "transitions" in url:
return httpx.Response(204, content=b"")
return httpx.Response(404, content=b'{"errorMessages": ["Not found"]}')
transport = httpx.MockTransport(handler)
http_client = httpx.AsyncClient(transport=transport)
client = JiraClient(
base_url="https://billolife.atlassian.net",
email="user@example.com",
api_token="test-token",
http_client=http_client,
)
result = await client.transition_issue("ALLPOST-100", "Released")
assert result is False
async def test_404_on_transition_raises_not_found(self) -> None:
transition_data = _load_json("jira_transitions.json")
routes = {
("GET", "transitions"): (200, json.dumps(transition_data)),
("POST", "transitions"): (404, b'{"errorMessages": ["Not found"]}'),
}
client = _make_client(routes)
with pytest.raises(NotFoundError):
await client.transition_issue("ALLPOST-100", "Released")
# ---------------------------------------------------------------------------
# add_remote_link tests
# ---------------------------------------------------------------------------
class TestAddRemoteLink:
"""Tests for JiraClient.add_remote_link."""
async def test_returns_true_on_success(self) -> None:
routes = {
("POST", "remotelink"): (200, json.dumps({"id": 1, "self": "..."})),
}
client = _make_client(routes)
result = await client.add_remote_link(
ticket_id="ALLPOST-100",
url="https://dev.azure.com/org/proj/_git/repo/pullrequest/42",
title="PR #42: Fix auth",
)
assert result is True
async def test_404_raises_not_found(self) -> None:
routes = {("POST", "remotelink"): (404, b'{"errorMessages": ["Not found"]}')}
client = _make_client(routes)
with pytest.raises(NotFoundError):
await client.add_remote_link(
ticket_id="ALLPOST-999",
url="https://example.com",
title="Some PR",
)
async def test_400_raises_service_error(self) -> None:
routes = {("POST", "remotelink"): (400, b'{"errorMessages": ["Bad request"]}')}
client = _make_client(routes)
with pytest.raises(ServiceError):
await client.add_remote_link(
ticket_id="ALLPOST-100",
url="not-a-url",
title="Bad link",
)
# ---------------------------------------------------------------------------
# Lifecycle tests
# ---------------------------------------------------------------------------
class TestJiraClientLifecycle:
"""Tests for JiraClient close() method."""
async def test_close_closes_http_client(self) -> None:
transport = httpx.MockTransport(lambda r: httpx.Response(200, content=b"{}"))
http_client = httpx.AsyncClient(transport=transport)
client = JiraClient(
base_url="https://billolife.atlassian.net",
email="u@example.com",
api_token="token",
http_client=http_client,
)
await client.close()
assert http_client.is_closed
# ---------------------------------------------------------------------------
# _text_to_adf tests
# ---------------------------------------------------------------------------
class TestTextToAdf:
"""Tests for the _text_to_adf helper."""
def test_returns_dict(self) -> None:
from release_agent.tools.jira import _text_to_adf
result = _text_to_adf("Hello world")
assert isinstance(result, dict)
def test_version_is_1(self) -> None:
from release_agent.tools.jira import _text_to_adf
result = _text_to_adf("Hello world")
assert result["version"] == 1
def test_type_is_doc(self) -> None:
from release_agent.tools.jira import _text_to_adf
result = _text_to_adf("Hello world")
assert result["type"] == "doc"
def test_content_is_list(self) -> None:
from release_agent.tools.jira import _text_to_adf
result = _text_to_adf("Hello world")
assert isinstance(result["content"], list)
def test_single_line_produces_one_paragraph(self) -> None:
from release_agent.tools.jira import _text_to_adf
result = _text_to_adf("Hello world")
assert len(result["content"]) == 1
assert result["content"][0]["type"] == "paragraph"
def test_multiline_produces_multiple_paragraphs(self) -> None:
from release_agent.tools.jira import _text_to_adf
result = _text_to_adf("Line one\n\nLine two")
# Each non-empty line becomes a paragraph
paragraphs = [c for c in result["content"] if c["type"] == "paragraph"]
assert len(paragraphs) == 2
def test_paragraph_contains_text_node(self) -> None:
from release_agent.tools.jira import _text_to_adf
result = _text_to_adf("Hello")
paragraph = result["content"][0]
assert "content" in paragraph
text_node = paragraph["content"][0]
assert text_node["type"] == "text"
assert text_node["text"] == "Hello"
def test_empty_string_produces_empty_doc(self) -> None:
from release_agent.tools.jira import _text_to_adf
result = _text_to_adf("")
assert result["content"] == []
# ---------------------------------------------------------------------------
# JiraClient.create_issue tests
# ---------------------------------------------------------------------------
class TestCreateIssue:
"""Tests for JiraClient.create_issue."""
async def test_returns_ticket_key(self) -> None:
response_body = json.dumps({"id": "10001", "key": "ALLPOST-42", "self": "https://..."})
routes = {("POST", "/rest/api/3/issue"): (201, response_body)}
client = _make_client(routes)
result = await client.create_issue(
project="ALLPOST",
summary="New feature",
description="Some description",
)
assert result == "ALLPOST-42"
async def test_default_issue_type_is_story(self) -> None:
captured_bodies: list[dict] = []
def handler(request: httpx.Request) -> httpx.Response:
captured_bodies.append(json.loads(request.content))
return httpx.Response(
201,
content=json.dumps({"key": "ALLPOST-1"}).encode(),
)
transport = httpx.MockTransport(handler)
http_client = httpx.AsyncClient(transport=transport)
client = JiraClient(
base_url="https://billolife.atlassian.net",
email="u@example.com",
api_token="token",
http_client=http_client,
)
await client.create_issue(project="ALLPOST", summary="S", description="D")
assert captured_bodies[0]["fields"]["issuetype"]["name"] == "Story"
async def test_custom_issue_type_sent(self) -> None:
captured_bodies: list[dict] = []
def handler(request: httpx.Request) -> httpx.Response:
captured_bodies.append(json.loads(request.content))
return httpx.Response(
201,
content=json.dumps({"key": "ALLPOST-2"}).encode(),
)
transport = httpx.MockTransport(handler)
http_client = httpx.AsyncClient(transport=transport)
client = JiraClient(
base_url="https://billolife.atlassian.net",
email="u@example.com",
api_token="token",
http_client=http_client,
)
await client.create_issue(
project="ALLPOST", summary="S", description="D", issue_type="Bug"
)
assert captured_bodies[0]["fields"]["issuetype"]["name"] == "Bug"
async def test_project_key_sent_in_body(self) -> None:
captured_bodies: list[dict] = []
def handler(request: httpx.Request) -> httpx.Response:
captured_bodies.append(json.loads(request.content))
return httpx.Response(
201,
content=json.dumps({"key": "MYPROJ-3"}).encode(),
)
transport = httpx.MockTransport(handler)
http_client = httpx.AsyncClient(transport=transport)
client = JiraClient(
base_url="https://billolife.atlassian.net",
email="u@example.com",
api_token="token",
http_client=http_client,
)
await client.create_issue(project="MYPROJ", summary="S", description="D")
assert captured_bodies[0]["fields"]["project"]["key"] == "MYPROJ"
async def test_summary_sent_in_body(self) -> None:
captured_bodies: list[dict] = []
def handler(request: httpx.Request) -> httpx.Response:
captured_bodies.append(json.loads(request.content))
return httpx.Response(
201,
content=json.dumps({"key": "ALLPOST-5"}).encode(),
)
transport = httpx.MockTransport(handler)
http_client = httpx.AsyncClient(transport=transport)
client = JiraClient(
base_url="https://billolife.atlassian.net",
email="u@example.com",
api_token="token",
http_client=http_client,
)
await client.create_issue(project="ALLPOST", summary="My Summary", description="D")
assert captured_bodies[0]["fields"]["summary"] == "My Summary"
async def test_description_is_adf_format(self) -> None:
captured_bodies: list[dict] = []
def handler(request: httpx.Request) -> httpx.Response:
captured_bodies.append(json.loads(request.content))
return httpx.Response(
201,
content=json.dumps({"key": "ALLPOST-6"}).encode(),
)
transport = httpx.MockTransport(handler)
http_client = httpx.AsyncClient(transport=transport)
client = JiraClient(
base_url="https://billolife.atlassian.net",
email="u@example.com",
api_token="token",
http_client=http_client,
)
await client.create_issue(project="ALLPOST", summary="S", description="My desc")
desc = captured_bodies[0]["fields"]["description"]
assert desc["type"] == "doc"
assert desc["version"] == 1
async def test_401_raises_authentication_error(self) -> None:
routes = {("POST", "/rest/api/3/issue"): (401, b'{"errorMessages": ["Unauthorized"]}')}
client = _make_client(routes)
with pytest.raises(AuthenticationError):
await client.create_issue(project="ALLPOST", summary="S", description="D")
async def test_400_raises_service_error(self) -> None:
routes = {
("POST", "/rest/api/3/issue"): (
400,
json.dumps({"errorMessages": ["Bad request"], "errors": {}}).encode(),
)
}
client = _make_client(routes)
with pytest.raises(ServiceError):
await client.create_issue(project="ALLPOST", summary="S", description="D")

198
tests/tools/test_retry.py Normal file
View File

@@ -0,0 +1,198 @@
"""Tests for async retry decorator. Written FIRST (TDD RED phase)."""
import asyncio
import pytest
from release_agent.exceptions import (
NotFoundError,
RateLimitError,
ServiceError,
ServiceUnavailableError,
)
from release_agent.tools._retry import with_retry
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_failing_then_succeeding(failures: int, exc_factory, result="ok"):
"""Return an async callable that fails `failures` times then returns `result`."""
call_count = {"n": 0}
async def fn():
call_count["n"] += 1
if call_count["n"] <= failures:
raise exc_factory()
return result
return fn
# ---------------------------------------------------------------------------
# with_retry tests
# ---------------------------------------------------------------------------
class TestWithRetry:
"""Tests for the with_retry decorator."""
async def test_success_on_first_attempt(self) -> None:
call_count = {"n": 0}
@with_retry(max_attempts=3)
async def fn():
call_count["n"] += 1
return "done"
result = await fn()
assert result == "done"
assert call_count["n"] == 1
async def test_retries_on_rate_limit_error(self) -> None:
call_count = {"n": 0}
@with_retry(max_attempts=3, base_delay=0.0)
async def fn():
call_count["n"] += 1
if call_count["n"] < 3:
raise RateLimitError(service="jira", retry_after=None)
return "ok"
result = await fn()
assert result == "ok"
assert call_count["n"] == 3
async def test_retries_on_service_unavailable_error(self) -> None:
call_count = {"n": 0}
@with_retry(max_attempts=3, base_delay=0.0)
async def fn():
call_count["n"] += 1
if call_count["n"] < 2:
raise ServiceUnavailableError(service="azdo")
return "ok"
result = await fn()
assert result == "ok"
assert call_count["n"] == 2
async def test_does_not_retry_on_not_found_error(self) -> None:
call_count = {"n": 0}
@with_retry(max_attempts=3, base_delay=0.0)
async def fn():
call_count["n"] += 1
raise NotFoundError(service="azdo", detail="not found")
with pytest.raises(NotFoundError):
await fn()
assert call_count["n"] == 1
async def test_does_not_retry_on_generic_service_error(self) -> None:
call_count = {"n": 0}
@with_retry(max_attempts=3, base_delay=0.0)
async def fn():
call_count["n"] += 1
raise ServiceError(service="azdo", status_code=400, detail="bad request")
with pytest.raises(ServiceError):
await fn()
assert call_count["n"] == 1
async def test_raises_after_max_attempts_exceeded(self) -> None:
call_count = {"n": 0}
@with_retry(max_attempts=3, base_delay=0.0)
async def fn():
call_count["n"] += 1
raise RateLimitError(service="jira", retry_after=None)
with pytest.raises(RateLimitError):
await fn()
assert call_count["n"] == 3
async def test_max_attempts_one_means_no_retry(self) -> None:
call_count = {"n": 0}
@with_retry(max_attempts=1, base_delay=0.0)
async def fn():
call_count["n"] += 1
raise RateLimitError(service="jira", retry_after=None)
with pytest.raises(RateLimitError):
await fn()
assert call_count["n"] == 1
async def test_does_not_retry_on_non_release_agent_error(self) -> None:
call_count = {"n": 0}
@with_retry(max_attempts=3, base_delay=0.0)
async def fn():
call_count["n"] += 1
raise ValueError("unexpected")
with pytest.raises(ValueError):
await fn()
assert call_count["n"] == 1
async def test_respects_retry_after_from_rate_limit_error(self) -> None:
"""When retry_after is set, the decorator must wait at least that long."""
delays: list[float] = []
async def fake_sleep(seconds: float) -> None:
delays.append(seconds)
call_count = {"n": 0}
@with_retry(max_attempts=2, base_delay=0.0, sleep_fn=fake_sleep)
async def fn():
call_count["n"] += 1
if call_count["n"] < 2:
raise RateLimitError(service="jira", retry_after=5)
return "ok"
result = await fn()
assert result == "ok"
assert len(delays) == 1
assert delays[0] >= 5.0
async def test_exponential_backoff_grows(self) -> None:
"""Verify delays grow between retries (exponential)."""
delays: list[float] = []
async def fake_sleep(seconds: float) -> None:
delays.append(seconds)
call_count = {"n": 0}
@with_retry(max_attempts=4, base_delay=1.0, sleep_fn=fake_sleep)
async def fn():
call_count["n"] += 1
if call_count["n"] < 4:
raise ServiceUnavailableError(service="azdo")
return "ok"
await fn()
assert len(delays) == 3
# Each subsequent delay must not be less than the previous
assert delays[1] >= delays[0]
assert delays[2] >= delays[1]
async def test_preserves_return_value(self) -> None:
@with_retry(max_attempts=2, base_delay=0.0)
async def fn():
return {"key": "value"}
result = await fn()
assert result == {"key": "value"}
async def test_works_without_decorator_args_defaults(self) -> None:
"""Decorator used with defaults should still work."""
@with_retry()
async def fn():
return 42
result = await fn()
assert result == 42

755
tests/tools/test_slack.py Normal file
View File

@@ -0,0 +1,755 @@
"""Tests for SlackClient and Block Kit builders. Written FIRST (TDD RED phase)."""
import json
from datetime import date
import httpx
import pytest
from release_agent.exceptions import ServiceError
from release_agent.models.ticket import TicketEntry
from release_agent.tools.slack import (
SlackClient,
_build_approval_blocks,
_build_ci_status_blocks,
_build_interactive_approval_blocks,
_build_release_blocks,
_build_resolved_approval_blocks,
)
# ---------------------------------------------------------------------------
# Fixture helpers
# ---------------------------------------------------------------------------
def _make_ticket(ticket_id: str = "ALLPOST-100", summary: str = "Fix bug") -> TicketEntry:
return TicketEntry(
id=ticket_id,
summary=summary,
pr_id="PR-42",
pr_url="https://dev.azure.com/org/project/_git/repo/pullrequest/42",
pr_title="Fix bug PR",
branch=f"bug/{ticket_id}_fix-bug",
merged_at=date(2024, 1, 15),
)
def _make_transport(status: int = 200, body: bytes = b'{"ok": true}') -> httpx.MockTransport:
return httpx.MockTransport(lambda r: httpx.Response(status_code=status, content=body))
def _make_client(status: int = 200) -> SlackClient:
transport = _make_transport(status)
http_client = httpx.AsyncClient(transport=transport)
return SlackClient(
webhook_url="https://hooks.slack.com/services/T000/B000/xxxx",
http_client=http_client,
)
def _make_web_api_client(
status: int = 200,
body: bytes = b'{"ok": true, "ts": "1234567890.123456"}',
) -> SlackClient:
transport = httpx.MockTransport(lambda r: httpx.Response(status_code=status, content=body))
http_client = httpx.AsyncClient(transport=transport)
return SlackClient(
bot_token="xoxb-test-token",
channel_id="C12345678",
http_client=http_client,
)
# ---------------------------------------------------------------------------
# _build_release_blocks tests (pure function)
# ---------------------------------------------------------------------------
class TestBuildReleaseBlocks:
"""Tests for the _build_release_blocks pure function."""
def test_returns_list(self) -> None:
blocks = _build_release_blocks(
repo="my-repo",
version="v1.2.0",
release_date=date(2024, 1, 15),
tickets=[],
)
assert isinstance(blocks, list)
def test_has_at_least_one_block(self) -> None:
blocks = _build_release_blocks(
repo="my-repo",
version="v1.2.0",
release_date=date(2024, 1, 15),
tickets=[],
)
assert len(blocks) >= 1
def test_repo_name_present_in_blocks(self) -> None:
blocks = _build_release_blocks(
repo="my-repo",
version="v1.2.0",
release_date=date(2024, 1, 15),
tickets=[],
)
text = json.dumps(blocks)
assert "my-repo" in text
def test_version_present_in_blocks(self) -> None:
blocks = _build_release_blocks(
repo="my-repo",
version="v1.2.0",
release_date=date(2024, 1, 15),
tickets=[],
)
text = json.dumps(blocks)
assert "v1.2.0" in text
def test_release_date_present_in_blocks(self) -> None:
blocks = _build_release_blocks(
repo="my-repo",
version="v1.2.0",
release_date=date(2024, 1, 15),
tickets=[],
)
text = json.dumps(blocks)
assert "2024" in text
def test_ticket_ids_present_in_blocks(self) -> None:
tickets = [_make_ticket("ALLPOST-100"), _make_ticket("ALLPOST-200")]
blocks = _build_release_blocks(
repo="my-repo",
version="v1.2.0",
release_date=date(2024, 1, 15),
tickets=tickets,
)
text = json.dumps(blocks)
assert "ALLPOST-100" in text
assert "ALLPOST-200" in text
def test_empty_tickets_still_valid(self) -> None:
blocks = _build_release_blocks(
repo="my-repo",
version="v1.2.0",
release_date=date(2024, 1, 15),
tickets=[],
)
assert len(blocks) >= 1
def test_blocks_are_dicts(self) -> None:
blocks = _build_release_blocks(
repo="my-repo",
version="v1.2.0",
release_date=date(2024, 1, 15),
tickets=[_make_ticket()],
)
assert all(isinstance(b, dict) for b in blocks)
def test_each_block_has_type_key(self) -> None:
blocks = _build_release_blocks(
repo="my-repo",
version="v1.2.0",
release_date=date(2024, 1, 15),
tickets=[],
)
for block in blocks:
assert "type" in block
def test_ticket_summaries_included(self) -> None:
tickets = [_make_ticket("ALLPOST-100", "Fix the auth bug")]
blocks = _build_release_blocks(
repo="my-repo",
version="v1.2.0",
release_date=date(2024, 1, 15),
tickets=tickets,
)
text = json.dumps(blocks)
assert "Fix the auth bug" in text
# ---------------------------------------------------------------------------
# _build_approval_blocks tests (pure function)
# ---------------------------------------------------------------------------
class TestBuildApprovalBlocks:
"""Tests for the _build_approval_blocks pure function."""
def test_returns_list(self) -> None:
blocks = _build_approval_blocks(
action="Deploy to Production",
details="v1.2.0 for my-repo",
approval_url="https://dev.azure.com/approve/123",
)
assert isinstance(blocks, list)
def test_has_at_least_one_block(self) -> None:
blocks = _build_approval_blocks(
action="Deploy",
details="v1.0.0",
approval_url="https://example.com",
)
assert len(blocks) >= 1
def test_action_present_in_blocks(self) -> None:
blocks = _build_approval_blocks(
action="Deploy to Production",
details="v1.2.0",
approval_url="https://example.com",
)
text = json.dumps(blocks)
assert "Deploy to Production" in text
def test_details_present_in_blocks(self) -> None:
blocks = _build_approval_blocks(
action="Deploy",
details="version v1.2.0 of my-repo",
approval_url="https://example.com",
)
text = json.dumps(blocks)
assert "version v1.2.0 of my-repo" in text
def test_approval_url_present_in_blocks(self) -> None:
blocks = _build_approval_blocks(
action="Deploy",
details="details",
approval_url="https://dev.azure.com/approve/abc",
)
text = json.dumps(blocks)
assert "https://dev.azure.com/approve/abc" in text
def test_blocks_are_dicts(self) -> None:
blocks = _build_approval_blocks(
action="Deploy",
details="details",
approval_url="https://example.com",
)
assert all(isinstance(b, dict) for b in blocks)
def test_each_block_has_type_key(self) -> None:
blocks = _build_approval_blocks(
action="Deploy",
details="details",
approval_url="https://example.com",
)
for block in blocks:
assert "type" in block
# ---------------------------------------------------------------------------
# SlackClient.send_release_notification tests
# ---------------------------------------------------------------------------
class TestSendReleaseNotification:
"""Tests for SlackClient.send_release_notification."""
async def test_returns_true_on_success(self) -> None:
client = _make_client(status=200)
result = await client.send_release_notification(
repo="my-repo",
version="v1.2.0",
release_date=date(2024, 1, 15),
tickets=[_make_ticket()],
)
assert result is True
async def test_returns_true_with_empty_tickets(self) -> None:
client = _make_client(status=200)
result = await client.send_release_notification(
repo="my-repo",
version="v1.2.0",
release_date=date(2024, 1, 15),
tickets=[],
)
assert result is True
async def test_500_raises_service_error(self) -> None:
client = _make_client(status=500)
with pytest.raises(ServiceError):
await client.send_release_notification(
repo="my-repo",
version="v1.2.0",
release_date=date(2024, 1, 15),
tickets=[],
)
async def test_sends_post_request(self) -> None:
requests_captured: list[httpx.Request] = []
def handler(request: httpx.Request) -> httpx.Response:
requests_captured.append(request)
return httpx.Response(200, content=b'{"ok": true}')
transport = httpx.MockTransport(handler)
http_client = httpx.AsyncClient(transport=transport)
client = SlackClient(
webhook_url="https://hooks.slack.com/services/T000/B000/xxxx",
http_client=http_client,
)
await client.send_release_notification(
repo="my-repo",
version="v1.2.0",
release_date=date(2024, 1, 15),
tickets=[],
)
assert len(requests_captured) == 1
assert requests_captured[0].method == "POST"
# ---------------------------------------------------------------------------
# SlackClient.send_approval_request tests
# ---------------------------------------------------------------------------
class TestSendApprovalRequest:
"""Tests for SlackClient.send_approval_request."""
async def test_returns_true_on_success(self) -> None:
client = _make_client(status=200)
result = await client.send_approval_request(
action="Deploy to Production",
details="v1.2.0 for my-repo",
approval_url="https://dev.azure.com/approve/123",
)
assert result is True
async def test_500_raises_service_error(self) -> None:
client = _make_client(status=500)
with pytest.raises(ServiceError):
await client.send_approval_request(
action="Deploy",
details="v1.0.0",
approval_url="https://example.com",
)
# ---------------------------------------------------------------------------
# SlackClient lifecycle tests
# ---------------------------------------------------------------------------
class TestSlackClientLifecycle:
"""Tests for SlackClient close() and context manager."""
async def test_close_closes_http_client(self) -> None:
transport = _make_transport()
http_client = httpx.AsyncClient(transport=transport)
client = SlackClient(
webhook_url="https://hooks.slack.com/services/T000/B000/xxxx",
http_client=http_client,
)
await client.close()
assert http_client.is_closed
async def test_context_manager_closes_client(self) -> None:
transport = _make_transport()
http_client = httpx.AsyncClient(transport=transport)
async with SlackClient(
webhook_url="https://hooks.slack.com/services/T000/B000/xxxx",
http_client=http_client,
) as client:
assert client is not None
assert http_client.is_closed
# ---------------------------------------------------------------------------
# SlackClient dual-mode construction tests
# ---------------------------------------------------------------------------
class TestSlackClientDualMode:
"""Tests for dual-mode SlackClient (webhook vs Web API)."""
def test_can_be_created_with_webhook_only(self) -> None:
transport = _make_transport()
http_client = httpx.AsyncClient(transport=transport)
client = SlackClient(
webhook_url="https://hooks.slack.com/services/T000/B000/xxxx",
http_client=http_client,
)
assert client is not None
def test_can_be_created_with_bot_token_and_channel(self) -> None:
transport = _make_transport()
http_client = httpx.AsyncClient(transport=transport)
client = SlackClient(
bot_token="xoxb-test",
channel_id="C12345",
http_client=http_client,
)
assert client is not None
def test_can_be_created_with_all_params(self) -> None:
transport = _make_transport()
http_client = httpx.AsyncClient(transport=transport)
client = SlackClient(
webhook_url="https://hooks.slack.com/services/T000/B000/xxxx",
bot_token="xoxb-test",
channel_id="C12345",
http_client=http_client,
)
assert client is not None
def test_can_be_created_with_no_url_params(self) -> None:
transport = _make_transport()
http_client = httpx.AsyncClient(transport=transport)
client = SlackClient(http_client=http_client)
assert client is not None
# ---------------------------------------------------------------------------
# SlackClient.send_interactive_approval tests
# ---------------------------------------------------------------------------
class TestSendInteractiveApproval:
"""Tests for SlackClient.send_interactive_approval."""
async def test_returns_message_ts_on_success(self) -> None:
client = _make_web_api_client()
result = await client.send_interactive_approval(
thread_id="thread-abc",
action="Deploy to Sandbox",
details="Release v1.0.0 of my-repo",
buttons=[{"text": "Approve", "value": "approve"}, {"text": "Reject", "value": "reject"}],
)
assert isinstance(result, str)
assert result == "1234567890.123456"
async def test_returns_empty_string_on_api_error(self) -> None:
client = _make_web_api_client(status=200, body=b'{"ok": false, "error": "channel_not_found"}')
result = await client.send_interactive_approval(
thread_id="thread-abc",
action="Deploy",
details="v1.0.0",
buttons=[],
)
assert result == ""
async def test_posts_to_chat_postmessage(self) -> None:
captured_urls: list[str] = []
captured_bodies: list[dict] = []
def handler(request: httpx.Request) -> httpx.Response:
captured_urls.append(str(request.url))
captured_bodies.append(json.loads(request.content))
return httpx.Response(200, content=b'{"ok": true, "ts": "111.222"}')
transport = httpx.MockTransport(handler)
http_client = httpx.AsyncClient(transport=transport)
client = SlackClient(
bot_token="xoxb-test",
channel_id="C99999",
http_client=http_client,
)
await client.send_interactive_approval(
thread_id="thread-xyz",
action="Deploy",
details="details",
buttons=[],
)
assert any("chat.postMessage" in url for url in captured_urls)
assert captured_bodies[0]["channel"] == "C99999"
async def test_includes_thread_id_in_blocks(self) -> None:
captured_bodies: list[dict] = []
def handler(request: httpx.Request) -> httpx.Response:
captured_bodies.append(json.loads(request.content))
return httpx.Response(200, content=b'{"ok": true, "ts": "111.222"}')
transport = httpx.MockTransport(handler)
http_client = httpx.AsyncClient(transport=transport)
client = SlackClient(bot_token="xoxb-test", channel_id="C1", http_client=http_client)
await client.send_interactive_approval(
thread_id="my-thread-id",
action="Deploy",
details="v1.0.0",
buttons=[{"text": "Approve", "value": "approve"}],
)
body_str = json.dumps(captured_bodies)
assert "my-thread-id" in body_str
# ---------------------------------------------------------------------------
# SlackClient.update_message tests
# ---------------------------------------------------------------------------
class TestUpdateMessage:
"""Tests for SlackClient.update_message."""
async def test_returns_true_on_success(self) -> None:
client = _make_web_api_client(body=b'{"ok": true}')
result = await client.update_message(
message_ts="1234567890.123456",
text="Updated message",
blocks=[],
)
assert result is True
async def test_returns_false_on_api_error(self) -> None:
client = _make_web_api_client(body=b'{"ok": false, "error": "message_not_found"}')
result = await client.update_message(
message_ts="bad-ts",
text="Update",
blocks=[],
)
assert result is False
async def test_posts_to_chat_update(self) -> None:
captured_urls: list[str] = []
def handler(request: httpx.Request) -> httpx.Response:
captured_urls.append(str(request.url))
return httpx.Response(200, content=b'{"ok": true}')
transport = httpx.MockTransport(handler)
http_client = httpx.AsyncClient(transport=transport)
client = SlackClient(bot_token="xoxb-test", channel_id="C1", http_client=http_client)
await client.update_message(message_ts="ts-abc", text="Hello", blocks=[])
assert any("chat.update" in url for url in captured_urls)
# ---------------------------------------------------------------------------
# SlackClient.send_notification tests
# ---------------------------------------------------------------------------
class TestSendNotification:
"""Tests for SlackClient.send_notification."""
async def test_returns_true_via_web_api(self) -> None:
client = _make_web_api_client(body=b'{"ok": true, "ts": "111.222"}')
result = await client.send_notification(text="Build passed", blocks=[])
assert result is True
async def test_returns_true_via_webhook(self) -> None:
client = _make_client(status=200)
result = await client.send_notification(text="Build passed", blocks=[])
assert result is True
async def test_returns_false_on_web_api_error(self) -> None:
client = _make_web_api_client(body=b'{"ok": false, "error": "invalid_auth"}')
result = await client.send_notification(text="Build passed", blocks=[])
assert result is False
# ---------------------------------------------------------------------------
# _build_interactive_approval_blocks tests (pure function)
# ---------------------------------------------------------------------------
class TestBuildInteractiveApprovalBlocks:
"""Tests for _build_interactive_approval_blocks pure function."""
def test_returns_list(self) -> None:
blocks = _build_interactive_approval_blocks(
thread_id="t1",
action="Deploy to Sandbox",
details="v1.0.0",
buttons=[{"text": "Approve", "value": "approve"}],
)
assert isinstance(blocks, list)
def test_has_at_least_one_block(self) -> None:
blocks = _build_interactive_approval_blocks(
thread_id="t1",
action="Deploy",
details="details",
buttons=[],
)
assert len(blocks) >= 1
def test_action_present_in_blocks(self) -> None:
blocks = _build_interactive_approval_blocks(
thread_id="t1",
action="Deploy to Production",
details="v1.2.0",
buttons=[],
)
text = json.dumps(blocks)
assert "Deploy to Production" in text
def test_thread_id_in_button_value(self) -> None:
blocks = _build_interactive_approval_blocks(
thread_id="my-unique-thread",
action="Deploy",
details="details",
buttons=[{"text": "Approve", "value": "approve"}],
)
text = json.dumps(blocks)
assert "my-unique-thread" in text
def test_buttons_render_as_actions_block(self) -> None:
blocks = _build_interactive_approval_blocks(
thread_id="t1",
action="Deploy",
details="details",
buttons=[
{"text": "Approve", "value": "approve"},
{"text": "Reject", "value": "reject"},
],
)
block_types = [b["type"] for b in blocks]
assert "actions" in block_types
def test_empty_buttons_still_valid(self) -> None:
blocks = _build_interactive_approval_blocks(
thread_id="t1",
action="Deploy",
details="details",
buttons=[],
)
assert isinstance(blocks, list)
def test_details_present_in_blocks(self) -> None:
blocks = _build_interactive_approval_blocks(
thread_id="t1",
action="Deploy",
details="Release v2.0.0 of my-service",
buttons=[],
)
text = json.dumps(blocks)
assert "Release v2.0.0 of my-service" in text
# ---------------------------------------------------------------------------
# _build_ci_status_blocks tests (pure function)
# ---------------------------------------------------------------------------
class TestBuildCiStatusBlocks:
"""Tests for _build_ci_status_blocks pure function."""
def test_returns_list(self) -> None:
blocks = _build_ci_status_blocks(
repo="my-repo",
branch="main",
status="succeeded",
build_url="https://dev.azure.com/org/proj/_build/results?buildId=42",
)
assert isinstance(blocks, list)
def test_repo_present(self) -> None:
blocks = _build_ci_status_blocks(
repo="my-service",
branch="main",
status="succeeded",
build_url=None,
)
text = json.dumps(blocks)
assert "my-service" in text
def test_branch_present(self) -> None:
blocks = _build_ci_status_blocks(
repo="my-repo",
branch="release/v1.0.0",
status="succeeded",
build_url=None,
)
text = json.dumps(blocks)
assert "release/v1.0.0" in text
def test_status_present(self) -> None:
blocks = _build_ci_status_blocks(
repo="my-repo",
branch="main",
status="failed",
build_url=None,
)
text = json.dumps(blocks)
assert "failed" in text
def test_build_url_present_when_provided(self) -> None:
url = "https://dev.azure.com/org/proj/_build/results?buildId=99"
blocks = _build_ci_status_blocks(
repo="my-repo",
branch="main",
status="succeeded",
build_url=url,
)
text = json.dumps(blocks)
assert url in text
def test_build_url_none_does_not_crash(self) -> None:
blocks = _build_ci_status_blocks(
repo="my-repo",
branch="main",
status="succeeded",
build_url=None,
)
assert isinstance(blocks, list)
def test_all_blocks_are_dicts(self) -> None:
blocks = _build_ci_status_blocks(
repo="my-repo",
branch="main",
status="succeeded",
build_url=None,
)
assert all(isinstance(b, dict) for b in blocks)
# ---------------------------------------------------------------------------
# _build_resolved_approval_blocks tests (pure function)
# ---------------------------------------------------------------------------
class TestBuildResolvedApprovalBlocks:
"""Tests for _build_resolved_approval_blocks pure function."""
def test_returns_list(self) -> None:
blocks = _build_resolved_approval_blocks(
action="Deploy to Sandbox",
outcome="approved",
user="alice",
)
assert isinstance(blocks, list)
def test_action_present(self) -> None:
blocks = _build_resolved_approval_blocks(
action="Deploy to Production",
outcome="approved",
user="alice",
)
text = json.dumps(blocks)
assert "Deploy to Production" in text
def test_outcome_present(self) -> None:
blocks = _build_resolved_approval_blocks(
action="Deploy",
outcome="rejected",
user="bob",
)
text = json.dumps(blocks)
assert "rejected" in text
def test_user_present(self) -> None:
blocks = _build_resolved_approval_blocks(
action="Deploy",
outcome="approved",
user="charlie",
)
text = json.dumps(blocks)
assert "charlie" in text
def test_all_blocks_are_dicts(self) -> None:
blocks = _build_resolved_approval_blocks(
action="Deploy",
outcome="approved",
user="dave",
)
assert all(isinstance(b, dict) for b in blocks)