"""Tests for main FastAPI application.

Written FIRST (TDD RED phase).
Heavy startup (PostgreSQL, httpx clients, graph compilation) is mocked.

Tests verify: routes registered, lifespan hooks, exception handlers.
"""

import asyncio
import inspect
from collections.abc import Iterator
from contextlib import ExitStack, asynccontextmanager, contextmanager
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient


# ---------------------------------------------------------------------------
# Helpers / fixtures
# ---------------------------------------------------------------------------


def _make_mock_settings() -> MagicMock:
    """Return a ``Settings`` stand-in populated with the values startup reads.

    Secret fields expose ``get_secret_value()`` (presumably mirroring pydantic
    ``SecretStr``). PR polling is disabled so lifespan tests that use these
    settings do not start the poll loop.
    """
    s = MagicMock()
    s.webhook_secret.get_secret_value.return_value = "test-secret"
    s.postgres_dsn.get_secret_value.return_value = "postgresql://u:p@localhost/db"
    s.azdo_pat.get_secret_value.return_value = "pat"
    s.anthropic_api_key.get_secret_value.return_value = "key"
    s.jira_api_token.get_secret_value.return_value = "jira"
    s.slack_webhook_url.get_secret_value.return_value = "https://hooks.slack.com/x"
    s.slack_bot_token.get_secret_value.return_value = ""
    s.slack_channel_id = ""
    s.slack_signing_secret.get_secret_value.return_value = ""
    s.port = 8000
    s.pr_poll_enabled = False
    s.pr_poll_interval_seconds = 300
    s.pr_poll_target_branch = "refs/heads/develop"
    s.watched_repos_list = []
    s.default_jira_project = "ALLPOST"
    return s


def _make_mock_db_pool(pool: Any = None, execute: Any = None) -> tuple[Any, AsyncMock]:
    """Wire a mock connection pool and return ``(pool, cursor)``.

    ``pool.connection()`` returns an async-context-manager connection whose
    ``cursor()`` returns an async-context-manager cursor, so code doing
    ``async with pool.connection() ... async with conn.cursor()`` reaches
    ``cursor.execute``.

    Args:
        pool: Object to wire up; a fresh ``AsyncMock`` when omitted.
        execute: Replacement for ``cursor.execute``; an ``AsyncMock`` when omitted.

    Returns:
        The wired pool and the cursor, so tests can inspect executed SQL.
    """
    pool = AsyncMock() if pool is None else pool

    cursor = AsyncMock()
    cursor.execute = AsyncMock() if execute is None else execute
    cursor.__aenter__ = AsyncMock(return_value=cursor)
    cursor.__aexit__ = AsyncMock(return_value=False)

    conn = AsyncMock()
    conn.cursor = MagicMock(return_value=cursor)
    conn.__aenter__ = AsyncMock(return_value=conn)
    conn.__aexit__ = AsyncMock(return_value=False)

    pool.connection = MagicMock(return_value=conn)
    return pool, cursor


def _fake_create_task(task: Any):
    """Return a stand-in for ``asyncio.create_task`` that always yields *task*.

    The awaitable handed to it is closed immediately. Otherwise the coroutine
    would be garbage-collected un-awaited and emit a ``RuntimeWarning``.
    """

    def _create_task(coro: Any, *args: Any, **kwargs: Any) -> Any:
        close = getattr(coro, "close", None)
        if callable(close):
            close()
        return task

    return _create_task


@contextmanager
def _patched_startup(
    *,
    settings: Any = None,
    pool: Any = None,
    pr_completed_graph: Any = None,
    release_graph: Any = None,
    tool_clients: Any = None,
    staging_store: Any = None,
) -> Iterator[None]:
    """Patch every heavy dependency ``release_agent.main`` uses at startup.

    All patches are entered on one ``ExitStack``, so they are always undone,
    even if the body raises. Any argument left as ``None`` gets a fresh mock.

    Args:
        settings: Object returned by ``Settings()``.
        pool: Object returned by ``AsyncConnectionPool(...)``.
        pr_completed_graph: Object returned by ``build_pr_completed_graph``.
        release_graph: Object returned by ``build_release_graph``.
        tool_clients: Return value of ``_create_tool_clients``. Defaults to a
            ``(clients, http_clients)`` pair, matching the real helper's shape.
        staging_store: Return value of ``_create_staging_store``.
    """

    def _or(value: Any, default: Any) -> Any:
        return default if value is None else value

    startup_patches: dict[str, dict[str, Any]] = {
        "Settings": {"return_value": _or(settings, _make_mock_settings())},
        "build_pr_completed_graph": {"return_value": _or(pr_completed_graph, MagicMock())},
        "build_release_graph": {"return_value": _or(release_graph, MagicMock())},
        "AsyncConnectionPool": {"return_value": _or(pool, AsyncMock())},
        "_create_tool_clients": {"return_value": _or(tool_clients, (MagicMock(), []))},
        "_create_staging_store": {"return_value": _or(staging_store, MagicMock())},
        "_ensure_db_schema": {"new_callable": AsyncMock},
    }
    with ExitStack() as stack:
        for attr, patch_kwargs in startup_patches.items():
            stack.enter_context(patch(f"release_agent.main.{attr}", **patch_kwargs))
        yield


def _registered_paths() -> set[str]:
    """Build the app with startup mocked and return every registered route path."""
    from release_agent.main import create_app

    with _patched_startup():
        app = create_app()
    return {r.path for r in app.routes}  # type: ignore[attr-defined]


def _make_patched_app():
    """Return the FastAPI app with all heavy startup mocked.

    Returns:
        ``(app, mock_settings, mock_pool, mock_graphs)``. ``mock_graphs`` maps
        ``"pr_completed"`` / ``"release"`` to the mocks returned by the graph
        builders.
    """
    from release_agent.main import create_app

    mock_settings = _make_mock_settings()
    mock_pool = AsyncMock()
    mock_pool.__aenter__ = AsyncMock(return_value=mock_pool)
    mock_pool.__aexit__ = AsyncMock(return_value=False)
    mock_graphs = {
        "pr_completed": MagicMock(),
        "release": MagicMock(),
    }
    mock_clients = MagicMock()
    mock_staging_store = MagicMock()

    # Context-managed so the patches are undone even if create_app() raises.
    with _patched_startup(
        settings=mock_settings,
        pool=mock_pool,
        pr_completed_graph=mock_graphs["pr_completed"],
        release_graph=mock_graphs["release"],
        tool_clients=(mock_clients, []),
        staging_store=mock_staging_store,
    ):
        app = create_app()
    return app, mock_settings, mock_pool, mock_graphs


# ---------------------------------------------------------------------------
# Route registration tests
# ---------------------------------------------------------------------------


class TestRouteRegistration:
    def test_webhook_route_registered(self) -> None:
        assert "/webhooks/azdo" in _registered_paths()

    def test_approvals_routes_registered(self) -> None:
        routes = _registered_paths()
        assert "/approvals/pending" in routes
        assert "/approvals/{thread_id}" in routes

    def test_status_routes_registered(self) -> None:
        routes = _registered_paths()
        assert "/status" in routes
        assert "/staging" in routes


# ---------------------------------------------------------------------------
# schedule_graph / run_graph_in_background tests
# ---------------------------------------------------------------------------


class TestScheduleGraph:
    @staticmethod
    def _make_mock_app() -> MagicMock:
        """Return an app stand-in with an empty ``state.background_tasks`` set."""
        mock_app = MagicMock()
        mock_app.state.background_tasks = set()
        return mock_app

    def test_schedule_graph_returns_thread_id(self) -> None:
        from release_agent.main import schedule_graph

        with patch(
            "release_agent.main.asyncio.create_task",
            side_effect=_fake_create_task(MagicMock()),
        ):
            thread_id = schedule_graph(
                app=self._make_mock_app(),
                graph=MagicMock(),
                initial_state={"repo_name": "my-repo"},
                thread_id=None,
            )
        assert isinstance(thread_id, str)
        assert len(thread_id) > 0

    def test_schedule_graph_uses_provided_thread_id(self) -> None:
        from release_agent.main import schedule_graph

        with patch(
            "release_agent.main.asyncio.create_task",
            side_effect=_fake_create_task(MagicMock()),
        ):
            thread_id = schedule_graph(
                app=self._make_mock_app(),
                graph=MagicMock(),
                initial_state={},
                thread_id="custom-thread-id",
            )
        assert thread_id == "custom-thread-id"

    def test_schedule_graph_adds_task_to_background_tasks(self) -> None:
        from release_agent.main import schedule_graph

        mock_app = self._make_mock_app()
        mock_task = MagicMock()
        with patch(
            "release_agent.main.asyncio.create_task",
            side_effect=_fake_create_task(mock_task),
        ):
            schedule_graph(
                app=mock_app,
                graph=MagicMock(),
                initial_state={},
                thread_id=None,
            )
        assert mock_task in mock_app.state.background_tasks

    def test_run_graph_in_background_is_coroutine(self) -> None:
        from release_agent.main import run_graph_in_background

        assert inspect.iscoroutinefunction(run_graph_in_background)


# ---------------------------------------------------------------------------
# _ensure_db_schema tests
# ---------------------------------------------------------------------------


class TestEnsureDbSchema:
    @pytest.mark.asyncio
    async def test_ensure_db_schema_creates_table(self) -> None:
        from release_agent.main import _ensure_db_schema

        mock_pool, mock_cursor = _make_mock_db_pool()

        await _ensure_db_schema(mock_pool)

        # Phase 5: now executes multiple DDL statements (agent_threads +
        # staging_releases + archived_releases), so called_once no longer holds.
        assert mock_cursor.execute.call_count >= 1
        all_sql = " ".join(call.args[0] for call in mock_cursor.execute.call_args_list)
        assert "agent_threads" in all_sql


# ---------------------------------------------------------------------------
# _create_tool_clients tests
# ---------------------------------------------------------------------------


class TestCreateToolClients:
    def test_create_tool_clients_returns_tool_clients_instance(self) -> None:
        from release_agent.graph.dependencies import ToolClients
        from release_agent.main import _create_tool_clients

        with (
            patch("release_agent.main.AzDoClient"),
            patch("release_agent.main.JiraClient"),
            patch("release_agent.main.SlackClient"),
            patch("release_agent.main.ClaudeReviewer"),
            patch("release_agent.main.httpx.AsyncClient"),
        ):
            # Unpacking also pins the (clients, http_clients) return shape.
            clients, _http_clients = _create_tool_clients(_make_mock_settings())
        assert isinstance(clients, ToolClients)


# ---------------------------------------------------------------------------
# _create_staging_store tests
# ---------------------------------------------------------------------------


class TestCreateStagingStore:
    def test_create_staging_store_returns_store(self) -> None:
        from release_agent.graph.dependencies import JsonFileStagingStore
        from release_agent.main import _create_staging_store

        assert isinstance(_create_staging_store(), JsonFileStagingStore)


# ---------------------------------------------------------------------------
# Global exception handler tests
# ---------------------------------------------------------------------------


class TestExceptionHandlers:
    def test_app_has_exception_handlers(self) -> None:
        from release_agent.main import create_app

        with _patched_startup():
            app = create_app()
        # FastAPI stores exception handlers in the exception_handlers attribute.
        # NOTE(review): every FastAPI instance has this attribute, so this only
        # proves create_app() returned an app. Asserting on the specific
        # exception type main.py registers would make the test meaningful.
        assert hasattr(app, "exception_handlers")


# ---------------------------------------------------------------------------
# Lifespan tests
# ---------------------------------------------------------------------------


class TestGracefulShutdown:
    @pytest.mark.asyncio
    async def test_lifespan_cancels_timed_out_tasks(self) -> None:
        """Verify the lifespan waits for tasks and cancels timed-out ones."""
        from release_agent.main import lifespan

        mock_pool, _ = _make_mock_db_pool()
        mock_pool.open = AsyncMock()
        mock_pool.close = AsyncMock()

        app = FastAPI()
        app.state.background_tasks = set()

        mock_task = MagicMock()
        mock_task.cancel = MagicMock()

        with (
            _patched_startup(pool=mock_pool),
            # Report the task as still pending, i.e. the shutdown wait timed out.
            patch(
                "release_agent.main.asyncio.wait",
                new_callable=AsyncMock,
                return_value=(set(), {mock_task}),
            ),
        ):
            ctx = lifespan(app)
            await ctx.__aenter__()
            # Add a fake task to background_tasks after startup.
            app.state.background_tasks.add(mock_task)
            await ctx.__aexit__(None, None, None)

        # The pending task should have been cancelled.
        mock_task.cancel.assert_called_once()


class TestLifespan:
    def test_app_state_set_after_lifespan(self) -> None:
        """Smoke-test a full startup/shutdown cycle through ``TestClient``.

        Only checks that ``GET /status`` answers without crashing while the
        lifespan is active. It does not inspect ``app.state`` contents.
        """
        from release_agent.main import create_app

        mock_pool, _ = _make_mock_db_pool()
        mock_pool.__aenter__ = AsyncMock(return_value=mock_pool)
        mock_pool.__aexit__ = AsyncMock(return_value=False)
        mock_pool.open = AsyncMock()
        mock_pool.close = AsyncMock()

        with _patched_startup(
            pool=mock_pool,
            pr_completed_graph=MagicMock(),
            release_graph=MagicMock(),
            tool_clients=(MagicMock(), []),
            staging_store=MagicMock(),
        ):
            app = create_app()
            with TestClient(app) as client:
                # App is started; state should be accessible.
                response = client.get("/status")

        # We just verify no crash. A 500 is tolerated, presumably because
        # /status may touch mocked dependencies.
        assert response.status_code in (200, 500)


# ---------------------------------------------------------------------------
# Phase 5: Slack interactions route + new config tests
# ---------------------------------------------------------------------------


class TestPhase5Routes:
    """Tests for Phase 5 additions to main.py."""

    def test_slack_interactions_route_registered(self) -> None:
        assert "/slack/interactions" in _registered_paths()

    def test_create_tool_clients_uses_bot_token(self) -> None:
        from release_agent.main import _create_tool_clients

        mock_settings = _make_mock_settings()
        mock_settings.slack_bot_token.get_secret_value.return_value = "xoxb-test"
        mock_settings.slack_channel_id = "C12345"
        mock_settings.slack_webhook_url.get_secret_value.return_value = ""
        mock_settings.azdo_api_url = "https://dev.azure.com/org/proj/_apis"
        mock_settings.azdo_vsrm_api_url = "https://vsrm.dev.azure.com/org/proj/_apis"
        mock_settings.jira_base_url = "https://example.atlassian.net"
        mock_settings.jira_email = "test@example.com"

        # Should not raise.
        clients, http_clients = _create_tool_clients(mock_settings)

        async def _close_all() -> None:
            for hc in http_clients:
                await hc.aclose()

        try:
            assert clients is not None
        finally:
            # asyncio.run replaces the deprecated get_event_loop().run_until_complete
            # pattern. try/finally ensures cleanup even if the assertion fails.
            asyncio.run(_close_all())


class TestPhase5DbSchema:
    """Tests that _ensure_db_schema adds the slack_message_ts column."""

    @staticmethod
    async def _capture_schema_sql() -> list[str]:
        """Run ``_ensure_db_schema`` against a mock pool; return stripped SQL strings."""
        from release_agent.main import _ensure_db_schema

        executed_sql: list[str] = []

        async def capture_execute(sql, *args):
            executed_sql.append(sql.strip())

        mock_pool, _ = _make_mock_db_pool(pool=MagicMock(), execute=capture_execute)
        await _ensure_db_schema(mock_pool)
        return executed_sql

    @pytest.mark.asyncio
    async def test_ensure_db_schema_executes_sql_statements(self) -> None:
        executed_sql = await self._capture_schema_sql()

        # Should have executed CREATE TABLE statements.
        assert len(executed_sql) >= 3
        combined = " ".join(executed_sql)
        assert "agent_threads" in combined
        assert "staging_releases" in combined

    @pytest.mark.asyncio
    async def test_ensure_db_schema_includes_slack_message_ts_column(self) -> None:
        combined = " ".join(await self._capture_schema_sql())
        assert "slack_message_ts" in combined


# ---------------------------------------------------------------------------
# PR polling lifespan integration tests
# ---------------------------------------------------------------------------


class TestPrPollingLifespan:
    """Tests for PR polling startup in the lifespan handler."""

    def _make_polling_settings(self, *, pr_poll_enabled: bool = True) -> MagicMock:
        s = _make_mock_settings()
        s.pr_poll_enabled = pr_poll_enabled
        s.pr_poll_interval_seconds = 30
        s.pr_poll_target_branch = "refs/heads/develop"
        s.watched_repos_list = ["repo-a"]
        s.default_jira_project = "ALLPOST"
        return s

    @staticmethod
    def _make_pool() -> AsyncMock:
        """Return a pool mock with awaitable ``open``/``close``.

        ``_ensure_db_schema`` is patched, so ``connection`` is never awaited.
        """
        pool = AsyncMock()
        pool.open = AsyncMock()
        pool.close = AsyncMock()
        pool.connection = MagicMock()
        return pool

    @pytest.mark.asyncio
    async def test_poll_loop_started_when_pr_poll_enabled(self) -> None:
        """When pr_poll_enabled=True, a background task for polling is created."""
        from release_agent.main import create_app

        poll_loop_started: list[bool] = []

        async def fake_run_poll_loop(**kwargs):
            poll_loop_started.append(True)
            # Simulate an immediate cancellation to avoid an infinite loop.
            raise asyncio.CancelledError

        with (
            _patched_startup(
                settings=self._make_polling_settings(pr_poll_enabled=True),
                pool=self._make_pool(),
            ),
            patch("release_agent.main.run_pr_poll_loop", new=fake_run_poll_loop),
        ):
            app = create_app()
            async with app.router.lifespan_context(app):
                # Give the event loop a chance to start background tasks.
                await asyncio.sleep(0)

        assert len(poll_loop_started) > 0

    @pytest.mark.asyncio
    async def test_poll_loop_not_started_when_pr_poll_disabled(self) -> None:
        """When pr_poll_enabled=False, no polling background task is created."""
        from release_agent.main import create_app

        poll_loop_started: list[bool] = []

        async def fake_run_poll_loop(**kwargs):
            poll_loop_started.append(True)

        with (
            _patched_startup(
                settings=self._make_polling_settings(pr_poll_enabled=False),
                pool=self._make_pool(),
            ),
            patch("release_agent.main.run_pr_poll_loop", new=fake_run_poll_loop),
        ):
            app = create_app()
            async with app.router.lifespan_context(app):
                await asyncio.sleep(0)

        assert len(poll_loop_started) == 0


# ---------------------------------------------------------------------------
# _run_graph default_jira_project injection tests
# ---------------------------------------------------------------------------


class TestRunGraphJiraProjectInjection:
    """Tests that _run_graph passes default_jira_project into the graph config."""

    @staticmethod
    def _make_capturing_graph() -> tuple[MagicMock, list[dict]]:
        """Return a graph mock whose ``ainvoke`` records each ``config`` it receives."""
        captured_configs: list[dict] = []

        async def fake_ainvoke(state, config=None):
            captured_configs.append(config or {})
            return {}

        mock_graph = MagicMock()
        mock_graph.ainvoke = fake_ainvoke
        return mock_graph, captured_configs

    @pytest.mark.asyncio
    async def test_default_jira_project_passed_to_graph_config(self) -> None:
        from release_agent.api.webhooks import _run_graph

        mock_graph, captured_configs = self._make_capturing_graph()
        mock_pool, _ = _make_mock_db_pool(pool=MagicMock())

        await _run_graph(
            graph=mock_graph,
            initial_state={"pr_id": "1", "repo_name": "r"},
            thread_id="tid-1",
            tool_clients=MagicMock(),
            db_pool=mock_pool,
            repos_base_dir="",
            graph_name="pr_completed",
            default_jira_project="MYPROJ",
        )

        assert len(captured_configs) == 1
        configurable = captured_configs[0].get("configurable", {})
        assert configurable.get("default_jira_project") == "MYPROJ"

    @pytest.mark.asyncio
    async def test_default_jira_project_defaults_to_allpost(self) -> None:
        from release_agent.api.webhooks import _run_graph

        mock_graph, captured_configs = self._make_capturing_graph()
        mock_pool, _ = _make_mock_db_pool(pool=MagicMock())

        await _run_graph(
            graph=mock_graph,
            initial_state={"pr_id": "1", "repo_name": "r"},
            thread_id="tid-2",
            tool_clients=MagicMock(),
            db_pool=mock_pool,
        )

        # Explicit length check gives a clear failure instead of an IndexError.
        assert len(captured_configs) == 1
        configurable = captured_configs[0].get("configurable", {})
        assert configurable.get("default_jira_project") == "ALLPOST"