"""Tests for approvals endpoint. Written FIRST (TDD RED phase)."""

from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient

from release_agent.api.approvals import router as approvals_router

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _make_test_app(
    *,
    interrupted_threads: list[dict] | None = None,
    graph_resume_result: dict | None = None,
) -> FastAPI:
    """Return a FastAPI app with mocked state for approvals tests.

    The app mounts the approvals router and populates ``app.state`` with
    mocks: ``settings``, ``graphs`` (``"pr_completed"`` and ``"release"``),
    ``tool_clients``, ``db_pool`` and an empty ``background_tasks`` set.

    Args:
        interrupted_threads: Thread descriptions that the mocked DB cursor's
            ``fetchall`` returns as rows. Each dict must have ``thread_id``.
            Optional keys and their defaults are ``graph_name``
            (``"pr_completed"``), ``interrupt_value`` (``"Confirm?"``),
            ``created_at`` (now, UTC), ``repo_name``, ``pr_id`` and
            ``version`` (all ``None``).
        graph_resume_result: When given, every mock graph's ``ainvoke`` is
            replaced by an ``AsyncMock`` returning this value. When omitted,
            the graphs are left as plain ``MagicMock`` objects.

    Returns:
        The configured ``FastAPI`` application.
    """
    app = FastAPI()
    app.include_router(approvals_router)

    if interrupted_threads is None:
        interrupted_threads = []

    mock_settings = MagicMock()
    # NOTE(review): an empty operator token presumably disables the
    # operator-auth check in the router -- confirm against approvals.py.
    mock_settings.operator_token.get_secret_value.return_value = ""

    mock_graphs = {
        "pr_completed": MagicMock(),
        "release": MagicMock(),
    }
    if graph_resume_result is not None:
        # A plain MagicMock attribute is not awaitable, so give each graph an
        # async ``ainvoke`` that yields the requested result.
        for mock_graph in mock_graphs.values():
            mock_graph.ainvoke = AsyncMock(return_value=graph_resume_result)

    mock_clients = MagicMock()

    # Mock pool that returns interrupted threads from DB.
    # Supports ``async with pool.connection() as conn`` and
    # ``async with conn.cursor() as cur``.
    mock_pool = MagicMock()
    mock_conn = AsyncMock()
    mock_cursor = AsyncMock()

    # Tuple order presumably mirrors the column order of the router's SELECT
    # -- keep in sync with approvals.py.
    rows = [
        (
            t["thread_id"],
            t.get("graph_name", "pr_completed"),
            t.get("interrupt_value", "Confirm?"),
            t.get("created_at", datetime.now(tz=timezone.utc)),
            t.get("repo_name"),
            t.get("pr_id"),
            t.get("version"),
        )
        for t in interrupted_threads
    ]
    mock_cursor.fetchall = AsyncMock(return_value=rows)
    # Single-row lookups always report the "pr_completed" graph
    # (presumably the thread -> graph-name lookup; verify in approvals.py).
    mock_cursor.fetchone = AsyncMock(return_value=("pr_completed",))
    mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor)
    mock_cursor.__aexit__ = AsyncMock(return_value=False)

    mock_conn.cursor = MagicMock(return_value=mock_cursor)
    mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
    mock_conn.__aexit__ = AsyncMock(return_value=False)

    mock_pool.connection = MagicMock(return_value=mock_conn)

    app.state.settings = mock_settings
    app.state.graphs = mock_graphs
    app.state.tool_clients = mock_clients
    app.state.db_pool = mock_pool
    app.state.background_tasks = set()

    return app
# ---------------------------------------------------------------------------
# POST /approvals/{thread_id}
# ---------------------------------------------------------------------------

# Dotted path of the graph-resume helper patched out by the POST tests.
_RESUME_GRAPH_TARGET = "release_agent.api.approvals._resume_graph"


def _post_decision(
    thread_id: str,
    decision: str,
    *,
    app: FastAPI | None = None,
    resume_result: dict | None = None,
):
    """POST ``{"decision": decision}`` to ``/approvals/{thread_id}``.

    ``_resume_graph`` is patched with an ``AsyncMock`` for the duration of
    the request, so the real graph-resume implementation is never run.

    Args:
        thread_id: Path parameter of the approval endpoint.
        decision: Value sent in the JSON body's ``decision`` field.
        app: App to send the request to. A fresh ``_make_test_app()`` is
            used when None.
        resume_result: Value the patched ``_resume_graph`` returns. Defaults
            to ``{}``.

    Returns:
        The ``TestClient`` response.
    """
    if app is None:
        app = _make_test_app()
    with patch(_RESUME_GRAPH_TARGET, new_callable=AsyncMock) as mock_resume:
        mock_resume.return_value = {} if resume_result is None else resume_result
        with TestClient(app) as client:
            return client.post(
                f"/approvals/{thread_id}",
                json={"decision": decision},
            )


class TestPostApproval:
    """Request validation and response shape of ``POST /approvals/{id}``."""

    def test_valid_merge_decision_returns_200(self) -> None:
        app = _make_test_app()
        # Also install an awaitable graph, in case the endpoint touches the
        # graph outside ``_resume_graph`` (unverified; kept from the original).
        mock_graph = MagicMock()
        mock_graph.ainvoke = AsyncMock(return_value={"messages": ["done"]})
        app.state.graphs["pr_completed"] = mock_graph

        response = _post_decision(
            "thread-123",
            "merge",
            app=app,
            resume_result={"messages": ["resumed"]},
        )

        assert response.status_code == 200
        data = response.json()
        assert data["thread_id"] == "thread-123"
        assert "status" in data
        assert "message" in data

    def test_valid_cancel_decision_returns_200(self) -> None:
        response = _post_decision(
            "thread-456",
            "cancel",
            resume_result={"messages": ["cancelled"]},
        )
        assert response.status_code == 200

    @pytest.mark.parametrize("decision", ["approve", "skip", "trigger"])
    def test_other_valid_decisions_return_200(self, decision: str) -> None:
        response = _post_decision("t1", decision)
        assert response.status_code == 200

    def test_invalid_decision_returns_422(self) -> None:
        app = _make_test_app()
        with TestClient(app) as client:
            response = client.post(
                "/approvals/thread-123",
                json={"decision": "invalid_decision"},
            )
        assert response.status_code == 422

    def test_missing_decision_returns_422(self) -> None:
        app = _make_test_app()
        with TestClient(app) as client:
            response = client.post(
                "/approvals/thread-123",
                json={},
            )
        assert response.status_code == 422

    def test_response_contains_thread_id(self) -> None:
        response = _post_decision("my-thread-id", "approve")
        # Check the status first so a failure reports the HTTP error rather
        # than a KeyError on the body.
        assert response.status_code == 200
        assert response.json()["thread_id"] == "my-thread-id"


# ---------------------------------------------------------------------------
# GET /approvals/pending
# ---------------------------------------------------------------------------


class TestGetPendingApprovals:
    """Listing of interrupted threads via ``GET /approvals/pending``."""

    def test_empty_pending_returns_200(self) -> None:
        app = _make_test_app(interrupted_threads=[])
        with TestClient(app) as client:
            response = client.get("/approvals/pending")
        assert response.status_code == 200
        data = response.json()
        assert data["count"] == 0
        assert data["items"] == []

    def test_pending_approvals_list_structure(self) -> None:
        now = datetime.now(tz=timezone.utc)
        threads = [
            {
                "thread_id": "t1",
                "graph_name": "pr_completed",
                "interrupt_value": "Confirm merge?",
                "created_at": now,
                "repo_name": "my-repo",
                "pr_id": "42",
                "version": "v1.0.0",
            }
        ]
        app = _make_test_app(interrupted_threads=threads)
        with TestClient(app) as client:
            response = client.get("/approvals/pending")
        assert response.status_code == 200
        data = response.json()
        assert data["count"] == 1
        item = data["items"][0]
        assert item["thread_id"] == "t1"
        assert item["graph_name"] == "pr_completed"
        assert item["repo_name"] == "my-repo"

    def test_multiple_pending_approvals(self) -> None:
        now = datetime.now(tz=timezone.utc)
        threads = [
            {
                "thread_id": f"t{i}",
                "graph_name": "pr_completed",
                "interrupt_value": "Confirm?",
                "created_at": now,
                "repo_name": None,
                "pr_id": None,
                "version": None,
            }
            for i in range(3)
        ]
        app = _make_test_app(interrupted_threads=threads)
        with TestClient(app) as client:
            response = client.get("/approvals/pending")
        assert response.status_code == 200
        data = response.json()
        assert data["count"] == 3
        assert len(data["items"]) == 3

    def test_pending_approval_optional_fields_nullable(self) -> None:
        now = datetime.now(tz=timezone.utc)
        threads = [
            {
                "thread_id": "t1",
                "graph_name": "release",
                "interrupt_value": "Run release?",
                "created_at": now,
                "repo_name": None,
                "pr_id": None,
                "version": None,
            }
        ]
        app = _make_test_app(interrupted_threads=threads)
        with TestClient(app) as client:
            response = client.get("/approvals/pending")
        # Without this, an error response would surface as a KeyError below.
        assert response.status_code == 200
        item = response.json()["items"][0]
        assert item["repo_name"] is None
        assert item["pr_id"] is None
        assert item["version"] is None


# ---------------------------------------------------------------------------
# _resume_graph helper function tests
# ---------------------------------------------------------------------------


class TestResumeGraph:
    """Contract checks on the ``_resume_graph`` helper itself."""

    def test_resume_graph_callable(self) -> None:
        import inspect

        from release_agent.api.approvals import _resume_graph

        # The POST tests patch this helper with an AsyncMock, so the real
        # implementation is expected to be a coroutine function.
        assert inspect.iscoroutinefunction(_resume_graph)