Files
billo-release-agent/tests/test_main.py
Yaojia Wang f5c2733cfb feat: initial commit — Billo Release Agent (LangGraph)
LangGraph-based release automation agent with:
- PR discovery (webhook + polling)
- AI code review via Claude Code CLI (subscription-based)
- Auto-create Jira tickets for PRs without ticket ID
- Jira ticket lifecycle management (code review -> staging -> done)
- CI/CD pipeline trigger, polling, and approval gates
- Slack interactive messages with approval buttons
- Per-repo semantic versioning
- PostgreSQL persistence (threads, staging, releases)
- FastAPI API (webhooks, approvals, status, manual triggers)
- Docker Compose deployment

1069 tests, 95%+ coverage.
2026-03-24 17:38:23 +01:00

667 lines
27 KiB
Python

"""Tests for main FastAPI application. Written FIRST (TDD RED phase).
Heavy startup (PostgreSQL, httpx clients, graph compilation) is mocked.
Tests verify: routes registered, lifespan hooks, exception handlers.
"""
import asyncio
from contextlib import asynccontextmanager
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
# ---------------------------------------------------------------------------
# Helpers / fixtures
# ---------------------------------------------------------------------------
def _make_mock_settings():
s = MagicMock()
s.webhook_secret.get_secret_value.return_value = "test-secret"
s.postgres_dsn.get_secret_value.return_value = "postgresql://u:p@localhost/db"
s.azdo_pat.get_secret_value.return_value = "pat"
s.anthropic_api_key.get_secret_value.return_value = "key"
s.jira_api_token.get_secret_value.return_value = "jira"
s.slack_webhook_url.get_secret_value.return_value = "https://hooks.slack.com/x"
s.slack_bot_token.get_secret_value.return_value = ""
s.slack_channel_id = ""
s.slack_signing_secret.get_secret_value.return_value = ""
s.port = 8000
s.pr_poll_enabled = False
s.pr_poll_interval_seconds = 300
s.pr_poll_target_branch = "refs/heads/develop"
s.watched_repos_list = []
s.default_jira_project = "ALLPOST"
return s
def _make_patched_app():
    """Return the FastAPI app with all heavy startup mocked.

    The patches are active only while ``create_app()`` runs and are always
    stopped before this function returns, even if app creation raises.

    Returns:
        A tuple ``(app, mock_settings, mock_pool, mock_graphs)``.
    """
    mock_settings = _make_mock_settings()
    mock_pool = AsyncMock()
    mock_pool.__aenter__ = AsyncMock(return_value=mock_pool)
    mock_pool.__aexit__ = AsyncMock(return_value=False)
    mock_graphs = {
        "pr_completed": MagicMock(),
        "release": MagicMock(),
    }
    mock_clients = MagicMock()
    mock_staging_store = MagicMock()
    patches = [
        patch("release_agent.main.Settings", return_value=mock_settings),
        patch("release_agent.main.build_pr_completed_graph", return_value=mock_graphs["pr_completed"]),
        patch("release_agent.main.build_release_graph", return_value=mock_graphs["release"]),
        patch("release_agent.main.AsyncConnectionPool", return_value=mock_pool),
        # Two-element tuple: other tests in this file unpack the result of
        # _create_tool_clients as ``clients, http_clients``.
        patch("release_agent.main._create_tool_clients", return_value=(mock_clients, [])),
        patch("release_agent.main._create_staging_store", return_value=mock_staging_store),
        patch("release_agent.main._ensure_db_schema", new_callable=AsyncMock),
    ]
    started = []
    try:
        for p in patches:
            p.start()
            # Track only what actually started so a failing start() does not
            # leave earlier patches active or stop one that never began.
            started.append(p)
        from release_agent.main import create_app
        app = create_app()
    finally:
        for p in reversed(started):
            p.stop()
    return app, mock_settings, mock_pool, mock_graphs
# ---------------------------------------------------------------------------
# Route registration tests
# ---------------------------------------------------------------------------
class TestRouteRegistration:
    """Checks that ``create_app()`` registers the expected route paths."""

    @staticmethod
    def _registered_paths() -> set:
        """Create the app with heavy startup patched out and return its route paths."""
        from release_agent.main import create_app

        with (
            patch("release_agent.main.Settings", return_value=_make_mock_settings()),
            patch("release_agent.main.build_pr_completed_graph", return_value=MagicMock()),
            patch("release_agent.main.build_release_graph", return_value=MagicMock()),
            patch("release_agent.main.AsyncConnectionPool", return_value=AsyncMock()),
            patch("release_agent.main._create_tool_clients", return_value=MagicMock()),
            patch("release_agent.main._create_staging_store", return_value=MagicMock()),
            patch("release_agent.main._ensure_db_schema", new_callable=AsyncMock),
        ):
            app = create_app()
        return {r.path for r in app.routes}  # type: ignore[attr-defined]

    def test_webhook_route_registered(self) -> None:
        routes = self._registered_paths()
        assert "/webhooks/azdo" in routes

    def test_approvals_routes_registered(self) -> None:
        routes = self._registered_paths()
        assert "/approvals/pending" in routes
        assert "/approvals/{thread_id}" in routes

    def test_status_routes_registered(self) -> None:
        routes = self._registered_paths()
        assert "/status" in routes
        assert "/staging" in routes
# ---------------------------------------------------------------------------
# schedule_graph / run_graph_in_background tests
# ---------------------------------------------------------------------------
class TestScheduleGraph:
    """Tests for ``schedule_graph`` and ``run_graph_in_background``."""

    @staticmethod
    def _schedule(initial_state, thread_id, task=None):
        """Call ``schedule_graph`` with ``asyncio.create_task`` replaced by a stub.

        Args:
            initial_state: Passed through as ``initial_state``.
            thread_id: Passed through as ``thread_id`` (``None`` or a string).
            task: Object the stubbed ``create_task`` returns; a fresh
                ``MagicMock`` when omitted.

        Returns:
            ``(returned_thread_id, mock_app, task)``.
        """
        import inspect

        from release_agent.main import schedule_graph

        mock_app = MagicMock()
        mock_app.state.background_tasks = set()
        stub_task = MagicMock() if task is None else task

        def fake_create_task(*args, **kwargs):
            # The stub never runs what it receives. Presumably schedule_graph
            # hands create_task a coroutine object (TODO confirm); closing it
            # keeps Python from warning that it was never awaited.
            for candidate in (*args, *kwargs.values()):
                if inspect.iscoroutine(candidate):
                    candidate.close()
            return stub_task

        with patch("release_agent.main.asyncio.create_task", side_effect=fake_create_task):
            returned = schedule_graph(
                app=mock_app,
                graph=MagicMock(),
                initial_state=initial_state,
                thread_id=thread_id,
            )
        return returned, mock_app, stub_task

    def test_schedule_graph_returns_thread_id(self) -> None:
        thread_id, _, _ = self._schedule({"repo_name": "my-repo"}, None)
        assert isinstance(thread_id, str)
        assert len(thread_id) > 0

    def test_schedule_graph_uses_provided_thread_id(self) -> None:
        thread_id, _, _ = self._schedule({}, "custom-thread-id")
        assert thread_id == "custom-thread-id"

    def test_schedule_graph_adds_task_to_background_tasks(self) -> None:
        mock_task = MagicMock()
        _, mock_app, _ = self._schedule({}, None, task=mock_task)
        assert mock_task in mock_app.state.background_tasks

    def test_run_graph_in_background_is_coroutine(self) -> None:
        import inspect

        from release_agent.main import run_graph_in_background

        assert inspect.iscoroutinefunction(run_graph_in_background)
# ---------------------------------------------------------------------------
# _ensure_db_schema tests
# ---------------------------------------------------------------------------
class TestEnsureDbSchema:
    """Tests for ``_ensure_db_schema`` against a mocked connection pool."""

    @pytest.mark.asyncio
    async def test_ensure_db_schema_creates_table(self) -> None:
        """The SQL passed to ``cursor.execute`` must mention ``agent_threads``."""
        from release_agent.main import _ensure_db_schema

        # Wire pool -> connection -> cursor; each level is an async context
        # manager that yields itself.
        cursor = AsyncMock()
        cursor.execute = AsyncMock()
        cursor.__aenter__ = AsyncMock(return_value=cursor)
        cursor.__aexit__ = AsyncMock(return_value=False)

        connection = AsyncMock()
        connection.cursor = MagicMock(return_value=cursor)
        connection.__aenter__ = AsyncMock(return_value=connection)
        connection.__aexit__ = AsyncMock(return_value=False)

        pool = AsyncMock()
        pool.connection = MagicMock(return_value=connection)

        await _ensure_db_schema(pool)

        # Several DDL statements may be executed, so only require at least one
        # call and inspect the concatenation of all first positional arguments.
        assert cursor.execute.call_count >= 1
        statements = [recorded.args[0] for recorded in cursor.execute.call_args_list]
        all_sql = " ".join(statements)
        assert "agent_threads" in all_sql
# ---------------------------------------------------------------------------
# _create_tool_clients tests
# ---------------------------------------------------------------------------
class TestCreateToolClients:
    """Tests for ``_create_tool_clients`` with every client class patched."""

    def test_create_tool_clients_returns_tool_clients_instance(self) -> None:
        """The first element of the returned pair is a ``ToolClients``."""
        from release_agent.main import _create_tool_clients
        from release_agent.graph.dependencies import ToolClients

        settings = _make_mock_settings()
        with (
            patch("release_agent.main.AzDoClient"),
            patch("release_agent.main.JiraClient"),
            patch("release_agent.main.SlackClient"),
            patch("release_agent.main.ClaudeReviewer"),
            patch("release_agent.main.httpx.AsyncClient"),
        ):
            # The second element (the HTTP clients) is not inspected here.
            clients, _ = _create_tool_clients(settings)
        assert isinstance(clients, ToolClients)
# ---------------------------------------------------------------------------
# _create_staging_store tests
# ---------------------------------------------------------------------------
class TestCreateStagingStore:
    """Tests for ``_create_staging_store``."""

    def test_create_staging_store_returns_store(self) -> None:
        """Calling the factory with no arguments yields a ``JsonFileStagingStore``."""
        from release_agent.main import _create_staging_store
        from release_agent.graph.dependencies import JsonFileStagingStore

        store = _create_staging_store()
        assert isinstance(store, JsonFileStagingStore)
# ---------------------------------------------------------------------------
# Global exception handler tests
# ---------------------------------------------------------------------------
class TestExceptionHandlers:
    """Tests around the exception handlers of the app built by ``create_app()``."""

    def test_app_has_exception_handlers(self) -> None:
        """The created app exposes an ``exception_handlers`` attribute."""
        from release_agent.main import create_app

        with (
            patch("release_agent.main.Settings", return_value=_make_mock_settings()),
            patch("release_agent.main.build_pr_completed_graph", return_value=MagicMock()),
            patch("release_agent.main.build_release_graph", return_value=MagicMock()),
            patch("release_agent.main.AsyncConnectionPool", return_value=AsyncMock()),
            patch("release_agent.main._create_tool_clients", return_value=MagicMock()),
            patch("release_agent.main._create_staging_store", return_value=MagicMock()),
            patch("release_agent.main._ensure_db_schema", new_callable=AsyncMock),
        ):
            app = create_app()

        # FastAPI keeps its handlers on the ``exception_handlers`` attribute.
        # NOTE(review): this only checks that the attribute exists, not that
        # any project-specific handler was registered - consider tightening.
        has_handlers_attr = hasattr(app, "exception_handlers")
        assert has_handlers_attr
# ---------------------------------------------------------------------------
# Lifespan tests
# ---------------------------------------------------------------------------
class TestGracefulShutdown:
    """Shutdown behaviour of the ``lifespan`` context manager."""

    @pytest.mark.asyncio
    async def test_lifespan_cancels_timed_out_tasks(self) -> None:
        """Verify the lifespan waits for tasks and cancels timed-out ones.

        ``asyncio.wait`` is patched to return an empty first set and a second
        set containing one fake task; the test expects ``cancel()`` to be
        called on that task exactly once during shutdown.
        """
        from release_agent.main import lifespan

        mock_pool = AsyncMock()
        mock_pool.open = AsyncMock()
        mock_pool.close = AsyncMock()
        mock_conn = AsyncMock()
        mock_cursor = AsyncMock()
        mock_cursor.execute = AsyncMock()
        mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor)
        mock_cursor.__aexit__ = AsyncMock(return_value=False)
        mock_conn.cursor = MagicMock(return_value=mock_cursor)
        mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
        mock_conn.__aexit__ = AsyncMock(return_value=False)
        mock_pool.connection = MagicMock(return_value=mock_conn)

        # FastAPI is already imported at module level.
        app = FastAPI()
        app.state.background_tasks = set()
        mock_settings = _make_mock_settings()
        mock_task = MagicMock()
        mock_task.cancel = MagicMock()

        with (
            patch("release_agent.main.Settings", return_value=mock_settings),
            patch("release_agent.main.build_pr_completed_graph", return_value=MagicMock()),
            patch("release_agent.main.build_release_graph", return_value=MagicMock()),
            patch("release_agent.main.AsyncConnectionPool", return_value=mock_pool),
            patch("release_agent.main._create_tool_clients", return_value=(MagicMock(), [])),
            patch("release_agent.main._create_staging_store", return_value=MagicMock()),
            patch("release_agent.main._ensure_db_schema", new_callable=AsyncMock),
            patch(
                "release_agent.main.asyncio.wait",
                new_callable=AsyncMock,
                return_value=(set(), {mock_task}),
            ),
        ):
            # ``async with`` runs shutdown even if the body raises, unlike
            # calling __aenter__/__aexit__ by hand.
            async with lifespan(app):
                # Add a fake task to background_tasks after startup.
                app.state.background_tasks.add(mock_task)

        # The pending task should have been cancelled.
        mock_task.cancel.assert_called_once()
class TestLifespan:
    """Smoke test that the app can be started and queried under ``TestClient``."""

    def test_app_state_set_after_lifespan(self) -> None:
        """Start the app with startup mocked and issue ``GET /status``.

        Only the absence of a crash is checked: the response status must be
        either 200 or 500.
        NOTE(review): despite the test name, nothing on ``app.state`` is
        asserted here.
        """
        from release_agent.main import create_app

        # Cursor and connection are async context managers yielding themselves.
        cursor = AsyncMock()
        cursor.execute = AsyncMock()
        cursor.__aenter__ = AsyncMock(return_value=cursor)
        cursor.__aexit__ = AsyncMock(return_value=False)

        connection = AsyncMock()
        connection.cursor = MagicMock(return_value=cursor)
        connection.__aenter__ = AsyncMock(return_value=connection)
        connection.__aexit__ = AsyncMock(return_value=False)

        pool = AsyncMock()
        pool.__aenter__ = AsyncMock(return_value=pool)
        pool.__aexit__ = AsyncMock(return_value=False)
        pool.open = AsyncMock()
        pool.close = AsyncMock()
        pool.connection = MagicMock(return_value=connection)

        settings = _make_mock_settings()
        pr_completed_graph = MagicMock()
        release_graph = MagicMock()
        tool_clients = MagicMock()
        staging_store = MagicMock()

        with (
            patch("release_agent.main.Settings", return_value=settings),
            patch("release_agent.main.build_pr_completed_graph", return_value=pr_completed_graph),
            patch("release_agent.main.build_release_graph", return_value=release_graph),
            patch("release_agent.main.AsyncConnectionPool", return_value=pool),
            patch("release_agent.main._create_tool_clients", return_value=(tool_clients, [])),
            patch("release_agent.main._create_staging_store", return_value=staging_store),
            patch("release_agent.main._ensure_db_schema", new_callable=AsyncMock),
        ):
            app = create_app()
            # Entering TestClient as a context manager starts the app.
            with TestClient(app) as client:
                status_response = client.get("/status")
                acceptable = (200, 500)
                assert status_response.status_code in acceptable
# ---------------------------------------------------------------------------
# Phase 5: Slack interactions route + new config tests
# ---------------------------------------------------------------------------
class TestPhase5Routes:
    """Tests for Phase 5 additions to main.py."""

    def _make_patches(self):
        """Return unstarted patchers that replace the heavy startup pieces of ``release_agent.main``."""
        mock_settings = _make_mock_settings()
        mock_pool = AsyncMock()
        mock_pool.open = AsyncMock()
        mock_pool.close = AsyncMock()
        mock_pool.__aenter__ = AsyncMock(return_value=mock_pool)
        mock_pool.__aexit__ = AsyncMock(return_value=False)
        return [
            patch("release_agent.main.Settings", return_value=mock_settings),
            patch("release_agent.main.build_pr_completed_graph", return_value=MagicMock()),
            patch("release_agent.main.build_release_graph", return_value=MagicMock()),
            patch("release_agent.main.AsyncConnectionPool", return_value=mock_pool),
            # Two-element tuple, matching how the result is unpacked below.
            patch("release_agent.main._create_tool_clients", return_value=(MagicMock(), [])),
            patch("release_agent.main._create_staging_store", return_value=MagicMock()),
            patch("release_agent.main._ensure_db_schema", new_callable=AsyncMock),
        ]

    def test_slack_interactions_route_registered(self) -> None:
        """``create_app()`` registers the ``/slack/interactions`` path."""
        from release_agent.main import create_app

        patches = self._make_patches()
        started = []
        try:
            for p in patches:
                p.start()
                # Track what actually started so a failing start() cannot
                # leave earlier patches active.
                started.append(p)
            app = create_app()
        finally:
            for p in reversed(started):
                p.stop()
        routes = {r.path for r in app.routes}  # type: ignore[attr-defined]
        assert "/slack/interactions" in routes

    def test_create_tool_clients_uses_bot_token(self) -> None:
        """``_create_tool_clients`` accepts settings that carry a bot token and channel id."""
        from release_agent.main import _create_tool_clients

        mock_settings = _make_mock_settings()
        mock_settings.slack_bot_token.get_secret_value.return_value = "xoxb-test"
        mock_settings.slack_channel_id = "C12345"
        mock_settings.slack_webhook_url.get_secret_value.return_value = ""
        mock_settings.azdo_api_url = "https://dev.azure.com/org/proj/_apis"
        mock_settings.azdo_vsrm_api_url = "https://vsrm.dev.azure.com/org/proj/_apis"
        mock_settings.jira_base_url = "https://example.atlassian.net"
        mock_settings.jira_email = "test@example.com"

        # Should not raise.
        clients, http_clients = _create_tool_clients(mock_settings)
        try:
            assert clients is not None
        finally:
            # Close the returned HTTP clients even when the assertion fails.
            # asyncio.run() creates its own loop, so this does not depend on
            # a usable "current" event loop being left over from other tests.
            async def _close_all() -> None:
                for hc in http_clients:
                    await hc.aclose()

            asyncio.run(_close_all())
class TestPhase5DbSchema:
    """Tests that _ensure_db_schema adds the slack_message_ts column."""

    @staticmethod
    async def _capture_schema_sql() -> list[str]:
        """Run ``_ensure_db_schema`` against a mocked pool and return the executed SQL.

        Each entry is the stripped first argument of one ``cursor.execute`` call.
        """
        from release_agent.main import _ensure_db_schema

        mock_pool = MagicMock()
        mock_conn = AsyncMock()
        mock_cursor = AsyncMock()
        executed_sql: list[str] = []

        async def capture_execute(sql, *args):
            executed_sql.append(sql.strip())

        mock_cursor.execute = capture_execute
        mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor)
        mock_cursor.__aexit__ = AsyncMock(return_value=False)
        mock_conn.cursor = MagicMock(return_value=mock_cursor)
        mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
        mock_conn.__aexit__ = AsyncMock(return_value=False)
        mock_pool.connection = MagicMock(return_value=mock_conn)

        await _ensure_db_schema(mock_pool)
        return executed_sql

    # The explicit marker matches the other async tests in this module and
    # keeps these tests running when pytest-asyncio is in strict mode.
    @pytest.mark.asyncio
    async def test_ensure_db_schema_executes_sql_statements(self) -> None:
        executed_sql = await self._capture_schema_sql()
        # Should have executed CREATE TABLE statements.
        assert len(executed_sql) >= 3
        combined = " ".join(executed_sql)
        assert "agent_threads" in combined
        assert "staging_releases" in combined

    @pytest.mark.asyncio
    async def test_ensure_db_schema_includes_slack_message_ts_column(self) -> None:
        executed_sql = await self._capture_schema_sql()
        combined = " ".join(executed_sql)
        assert "slack_message_ts" in combined
# ---------------------------------------------------------------------------
# PR polling lifespan integration tests
# ---------------------------------------------------------------------------
class TestPrPollingLifespan:
    """Tests for PR polling startup in the lifespan handler."""

    def _make_polling_settings(self, *, pr_poll_enabled: bool = True) -> MagicMock:
        """Return mock settings with polling fields set and one watched repo."""
        s = _make_mock_settings()
        s.pr_poll_enabled = pr_poll_enabled
        s.pr_poll_interval_seconds = 30
        s.pr_poll_target_branch = "refs/heads/develop"
        s.watched_repos_list = ["repo-a"]
        s.default_jira_project = "ALLPOST"
        return s

    async def _poll_starts_during_lifespan(self, *, pr_poll_enabled: bool) -> int:
        """Run the app lifespan once and return how often the poll loop was entered.

        ``run_pr_poll_loop`` is replaced by a fake that records the call and
        then raises ``asyncio.CancelledError``.
        """
        from release_agent.main import create_app

        mock_settings = self._make_polling_settings(pr_poll_enabled=pr_poll_enabled)
        mock_pool = AsyncMock()
        mock_pool.open = AsyncMock()
        mock_pool.close = AsyncMock()
        mock_pool.connection = MagicMock()
        poll_loop_started = []

        async def fake_run_poll_loop(**kwargs):
            poll_loop_started.append(True)
            # Simulate an immediate cancellation to avoid infinite loop.
            raise asyncio.CancelledError

        with (
            patch("release_agent.main.Settings", return_value=mock_settings),
            patch("release_agent.main.build_pr_completed_graph", return_value=MagicMock()),
            patch("release_agent.main.build_release_graph", return_value=MagicMock()),
            patch("release_agent.main.AsyncConnectionPool", return_value=mock_pool),
            patch("release_agent.main._create_tool_clients", return_value=(MagicMock(), [])),
            patch("release_agent.main._create_staging_store", return_value=MagicMock()),
            patch("release_agent.main._ensure_db_schema", new_callable=AsyncMock),
            patch("release_agent.main.run_pr_poll_loop", new=fake_run_poll_loop),
        ):
            app = create_app()
            async with app.router.lifespan_context(app):
                # Give the event loop a chance to start background tasks.
                await asyncio.sleep(0)
        return len(poll_loop_started)

    # The explicit marker matches the other async tests in this module and
    # keeps these tests running when pytest-asyncio is in strict mode.
    @pytest.mark.asyncio
    async def test_poll_loop_started_when_pr_poll_enabled(self) -> None:
        """When pr_poll_enabled=True, a background task for polling is created."""
        starts = await self._poll_starts_during_lifespan(pr_poll_enabled=True)
        assert starts > 0

    @pytest.mark.asyncio
    async def test_poll_loop_not_started_when_pr_poll_disabled(self) -> None:
        """When pr_poll_enabled=False, no polling background task is created."""
        starts = await self._poll_starts_during_lifespan(pr_poll_enabled=False)
        assert starts == 0
# ---------------------------------------------------------------------------
# _run_graph default_jira_project injection tests
# ---------------------------------------------------------------------------
class TestRunGraphJiraProjectInjection:
    """Tests that _run_graph passes default_jira_project into the graph config."""

    @staticmethod
    async def _captured_configs(thread_id: str, **run_graph_kwargs) -> list[dict]:
        """Invoke ``_run_graph`` with a fake graph and return every config it was given.

        Args:
            thread_id: Forwarded to ``_run_graph`` as ``thread_id``.
            **run_graph_kwargs: Extra keyword arguments forwarded unchanged,
                so a test can leave optional parameters at their defaults.

        Returns:
            One entry per ``ainvoke`` call: the ``config`` argument, or ``{}``
            when it was falsy.
        """
        from release_agent.api.webhooks import _run_graph

        captured_configs: list[dict] = []
        mock_graph = MagicMock()

        async def fake_ainvoke(state, config=None):
            captured_configs.append(config or {})
            return {}

        mock_graph.ainvoke = fake_ainvoke
        mock_pool = MagicMock()
        mock_conn = AsyncMock()
        mock_cursor = AsyncMock()
        mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor)
        mock_cursor.__aexit__ = AsyncMock(return_value=False)
        mock_conn.cursor = MagicMock(return_value=mock_cursor)
        mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
        mock_conn.__aexit__ = AsyncMock(return_value=False)
        mock_pool.connection = MagicMock(return_value=mock_conn)

        await _run_graph(
            graph=mock_graph,
            initial_state={"pr_id": "1", "repo_name": "r"},
            thread_id=thread_id,
            tool_clients=MagicMock(),
            db_pool=mock_pool,
            **run_graph_kwargs,
        )
        return captured_configs

    # The explicit marker matches the other async tests in this module and
    # keeps these tests running when pytest-asyncio is in strict mode.
    @pytest.mark.asyncio
    async def test_default_jira_project_passed_to_graph_config(self) -> None:
        captured_configs = await self._captured_configs(
            "tid-1",
            repos_base_dir="",
            graph_name="pr_completed",
            default_jira_project="MYPROJ",
        )
        assert len(captured_configs) == 1
        configurable = captured_configs[0].get("configurable", {})
        assert configurable.get("default_jira_project") == "MYPROJ"

    @pytest.mark.asyncio
    async def test_default_jira_project_defaults_to_allpost(self) -> None:
        captured_configs = await self._captured_configs("tid-2")
        # Fail with a clear assertion instead of an IndexError if the graph
        # was never invoked.
        assert len(captured_configs) == 1
        configurable = captured_configs[0].get("configurable", {})
        assert configurable.get("default_jira_project") == "ALLPOST"