From f5c2733cfbe7c5bdc286cce07d0b4262a0d8b9eb Mon Sep 17 00:00:00 2001 From: Yaojia Wang Date: Tue, 24 Mar 2026 17:38:23 +0100 Subject: [PATCH] =?UTF-8?q?feat:=20initial=20commit=20=E2=80=94=20Billo=20?= =?UTF-8?q?Release=20Agent=20(LangGraph)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit LangGraph-based release automation agent with: - PR discovery (webhook + polling) - AI code review via Claude Code CLI (subscription-based) - Auto-create Jira tickets for PRs without ticket ID - Jira ticket lifecycle management (code review -> staging -> done) - CI/CD pipeline trigger, polling, and approval gates - Slack interactive messages with approval buttons - Per-repo semantic versioning - PostgreSQL persistence (threads, staging, releases) - FastAPI API (webhooks, approvals, status, manual triggers) - Docker Compose deployment 1069 tests, 95%+ coverage. --- .env.example | 55 + .gitignore | 36 + Dockerfile | 32 + README.md | 341 +++++++ docker-compose.yml | 48 + pyproject.toml | 58 ++ scripts/__init__.py | 0 scripts/migrate_json_to_db.py | 284 ++++++ src/release_agent/__init__.py | 1 + src/release_agent/api/__init__.py | 0 src/release_agent/api/approvals.py | 166 +++ src/release_agent/api/dependencies.py | 66 ++ src/release_agent/api/models.py | 137 +++ src/release_agent/api/slack_interactions.py | 264 +++++ src/release_agent/api/status.py | 153 +++ src/release_agent/api/webhooks.py | 195 ++++ src/release_agent/branch_parser.py | 46 + src/release_agent/config.py | 110 ++ src/release_agent/exceptions.py | 59 ++ src/release_agent/graph/__init__.py | 0 src/release_agent/graph/ci_nodes.py | 180 ++++ src/release_agent/graph/dependencies.py | 171 ++++ src/release_agent/graph/full_cycle.py | 44 + src/release_agent/graph/polling.py | 91 ++ .../graph/postgres_staging_store.py | 148 +++ src/release_agent/graph/pr_completed.py | 573 +++++++++++ src/release_agent/graph/release.py | 627 ++++++++++++ src/release_agent/graph/routing.py | 131 +++ src/release_agent/main.py | 361 +++++++ src/release_agent/models/__init__.py | 1 + src/release_agent/models/build.py | 43 + src/release_agent/models/jira.py | 33 + src/release_agent/models/pipeline.py | 31 + src/release_agent/models/pr.py | 38 + src/release_agent/models/release.py | 65 ++ src/release_agent/models/review.py | 48 + src/release_agent/models/ticket.py | 33 + src/release_agent/models/webhook.py | 40 + src/release_agent/services/__init__.py | 0 src/release_agent/services/pr_dedup.py | 46 + src/release_agent/services/pr_poller.py | 108 ++ src/release_agent/state.py | 91 ++ src/release_agent/tools/__init__.py | 0 src/release_agent/tools/_http.py | 103 ++ src/release_agent/tools/_retry.py | 74 ++ src/release_agent/tools/azdo.py | 513 ++++++++++ src/release_agent/tools/claude_review.py | 335 ++++++ src/release_agent/tools/jira.py | 269 +++++ src/release_agent/tools/slack.py | 500 +++++++++ src/release_agent/versioning.py | 70 ++ tests/__init__.py | 0 tests/api/__init__.py | 0 tests/api/test_approvals.py | 259 +++++ tests/api/test_approvals_with_auth.py | 139 +++ tests/api/test_dependencies.py | 149 +++ tests/api/test_internals.py | 446 ++++++++ tests/api/test_models.py | 294 ++++++ tests/api/test_operator_auth.py | 111 ++ tests/api/test_slack_interactions.py | 473 +++++++++ tests/api/test_status.py | 270 +++++ tests/api/test_status_with_auth.py | 166 +++ tests/api/test_webhooks.py | 218 ++++ tests/graph/__init__.py | 0 tests/graph/conftest.py | 44 + tests/graph/test_ci_nodes.py | 294 ++++++ tests/graph/test_dependencies.py | 283 ++++++ tests/graph/test_dependencies_async.py | 177 ++++ tests/graph/test_full_cycle.py | 53 + tests/graph/test_polling.py | 356 +++++++ tests/graph/test_postgres_staging_store.py | 414 ++++++++ tests/graph/test_pr_completed.py | 956 ++++++++++++++++++ tests/graph/test_release.py | 866 ++++++++++++++++ tests/graph/test_routing.py | 302 ++++++ tests/scripts/__init__.py | 0 tests/scripts/test_migrate.py | 332 ++++++ tests/services/__init__.py | 0 tests/services/test_pr_dedup.py | 141 +++ tests/services/test_pr_poller.py | 309 ++++++ tests/test_branch_parser.py | 122 +++ tests/test_config.py | 450 +++++++++ tests/test_exceptions.py | 198 ++++ tests/test_main.py | 666 ++++++++++++ tests/test_main_phase5.py | 147 +++ tests/test_models.py | 635 ++++++++++++ tests/test_models_build.py | 148 +++ tests/test_state.py | 241 +++++ tests/test_versioning.py | 124 +++ tests/tools/__init__.py | 0 .../tools/fixtures/azdo_approve_release.json | 9 + tests/tools/fixtures/azdo_build_status.json | 9 + tests/tools/fixtures/azdo_create_pr.json | 8 + tests/tools/fixtures/azdo_merge_pr.json | 8 + tests/tools/fixtures/azdo_pipelines.json | 15 + tests/tools/fixtures/azdo_pr.json | 16 + tests/tools/fixtures/azdo_pr_diff.txt | 11 + .../tools/fixtures/azdo_trigger_pipeline.json | 11 + tests/tools/fixtures/jira_issue.json | 10 + tests/tools/fixtures/jira_transitions.json | 20 + tests/tools/test_azdo.py | 819 +++++++++++++++ tests/tools/test_claude_review.py | 454 +++++++++ tests/tools/test_http.py | 205 ++++ tests/tools/test_jira.py | 572 +++++++++++ tests/tools/test_retry.py | 198 ++++ tests/tools/test_slack.py | 755 ++++++++++++++ 104 files changed, 19721 insertions(+) create mode 100644 .env.example create mode 100644 .gitignore create mode 100644 Dockerfile create mode 100644 README.md create mode 100644 docker-compose.yml create mode 100644 pyproject.toml create mode 100644 scripts/__init__.py create mode 100644 scripts/migrate_json_to_db.py create mode 100644 src/release_agent/__init__.py create mode 100644 src/release_agent/api/__init__.py create mode 100644 src/release_agent/api/approvals.py create mode 100644 src/release_agent/api/dependencies.py create mode 100644 src/release_agent/api/models.py create mode 100644 src/release_agent/api/slack_interactions.py create mode 100644 src/release_agent/api/status.py create mode 100644 src/release_agent/api/webhooks.py create mode 100644 src/release_agent/branch_parser.py create mode 100644 src/release_agent/config.py create mode 100644 src/release_agent/exceptions.py create mode 100644 src/release_agent/graph/__init__.py create mode 100644 src/release_agent/graph/ci_nodes.py create mode 100644 src/release_agent/graph/dependencies.py create mode 100644 src/release_agent/graph/full_cycle.py create mode 100644 src/release_agent/graph/polling.py create mode 100644 src/release_agent/graph/postgres_staging_store.py create mode 100644 src/release_agent/graph/pr_completed.py create mode 100644 src/release_agent/graph/release.py create mode 100644 src/release_agent/graph/routing.py create mode 100644 src/release_agent/main.py create mode 100644 src/release_agent/models/__init__.py create mode 100644 src/release_agent/models/build.py create mode 100644 src/release_agent/models/jira.py create mode 100644 src/release_agent/models/pipeline.py create mode 100644 src/release_agent/models/pr.py create mode 100644 src/release_agent/models/release.py create mode 100644 src/release_agent/models/review.py create mode 100644 src/release_agent/models/ticket.py create mode 100644 src/release_agent/models/webhook.py create mode 100644 src/release_agent/services/__init__.py create mode 100644 src/release_agent/services/pr_dedup.py create mode 100644 src/release_agent/services/pr_poller.py create mode 100644 src/release_agent/state.py create mode 100644 src/release_agent/tools/__init__.py create mode 100644 src/release_agent/tools/_http.py create mode 100644 src/release_agent/tools/_retry.py create mode 100644 src/release_agent/tools/azdo.py create mode 100644 src/release_agent/tools/claude_review.py create mode 100644 src/release_agent/tools/jira.py create mode 100644 src/release_agent/tools/slack.py create mode 100644 src/release_agent/versioning.py create mode 100644 tests/__init__.py create mode 100644 tests/api/__init__.py create mode 100644 tests/api/test_approvals.py create mode 100644 tests/api/test_approvals_with_auth.py create mode 100644 tests/api/test_dependencies.py create mode 100644 tests/api/test_internals.py create mode 100644 tests/api/test_models.py create mode 100644 tests/api/test_operator_auth.py create mode 100644 tests/api/test_slack_interactions.py create mode 100644 tests/api/test_status.py create mode 100644 tests/api/test_status_with_auth.py create mode 100644 tests/api/test_webhooks.py create mode 100644 tests/graph/__init__.py create mode 100644 tests/graph/conftest.py create mode 100644 tests/graph/test_ci_nodes.py create mode 100644 tests/graph/test_dependencies.py create mode 100644 tests/graph/test_dependencies_async.py create mode 100644 tests/graph/test_full_cycle.py create mode 100644 tests/graph/test_polling.py create mode 100644 tests/graph/test_postgres_staging_store.py create mode 100644 tests/graph/test_pr_completed.py create mode 100644 tests/graph/test_release.py create mode 100644 tests/graph/test_routing.py create mode 100644 tests/scripts/__init__.py create mode 100644 tests/scripts/test_migrate.py create mode 100644 tests/services/__init__.py create mode 100644 tests/services/test_pr_dedup.py create mode 100644 tests/services/test_pr_poller.py create mode 100644 tests/test_branch_parser.py create mode 100644 tests/test_config.py create mode 100644 tests/test_exceptions.py create mode 100644 tests/test_main.py create mode 100644 tests/test_main_phase5.py create mode 100644 tests/test_models.py create mode 100644 tests/test_models_build.py create mode 100644 tests/test_state.py create mode 100644 tests/test_versioning.py create mode 100644 tests/tools/__init__.py create mode 100644 tests/tools/fixtures/azdo_approve_release.json create mode 100644 tests/tools/fixtures/azdo_build_status.json create mode 100644 tests/tools/fixtures/azdo_create_pr.json create mode 100644 tests/tools/fixtures/azdo_merge_pr.json create mode 100644 tests/tools/fixtures/azdo_pipelines.json create mode 100644 tests/tools/fixtures/azdo_pr.json create mode 100644 tests/tools/fixtures/azdo_pr_diff.txt create mode 100644 tests/tools/fixtures/azdo_trigger_pipeline.json create mode 100644 tests/tools/fixtures/jira_issue.json create mode 100644 tests/tools/fixtures/jira_transitions.json create mode 100644 tests/tools/test_azdo.py create mode 100644 tests/tools/test_claude_review.py create mode 100644 tests/tools/test_http.py create mode 100644 tests/tools/test_jira.py create mode 100644 tests/tools/test_retry.py create mode 100644 tests/tools/test_slack.py diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..20afede --- /dev/null +++ b/.env.example @@ -0,0 +1,55 @@ +# Billo Release Agent - Environment Variables +# Copy this file to .env and fill in your values. +# Never commit .env to source control. + +# =========================================================================== +# REQUIRED +# =========================================================================== + +# --- Azure DevOps ---------------------------------------------------------- +AZDO_ORGANIZATION=billodev +AZDO_PROJECT=Billo App Platform +AZDO_PAT= # Personal access token (Code read/write, Build read/execute) + +# --- PostgreSQL ------------------------------------------------------------ +POSTGRES_PASSWORD= # Used by docker-compose db service +POSTGRES_DSN=postgresql://agent:${POSTGRES_PASSWORD}@localhost:5432/agent + +# --- Jira ------------------------------------------------------------------ +JIRA_EMAIL= # Jira account email +JIRA_API_TOKEN= # Jira API token + +# --- Slack ----------------------------------------------------------------- +SLACK_WEBHOOK_URL= # Incoming webhook URL (for release notifications) + +# --- Webhook Security ------------------------------------------------------ +WEBHOOK_SECRET= # HMAC secret for Azure DevOps webhook validation + +# =========================================================================== +# OPTIONAL (have defaults) +# =========================================================================== + +# --- Slack App (for interactive buttons) ----------------------------------- +SLACK_BOT_TOKEN= # xoxb-... Bot token (needed for Slack buttons) +SLACK_SIGNING_SECRET= # Slack app signing secret (needed for /slack/interactions) +SLACK_CHANNEL_ID= # Channel ID to post messages (e.g., C0123456789) + +# --- CI/CD Polling --------------------------------------------------------- +CI_POLL_INTERVAL_SECONDS=30 # Seconds between CI status polls +CI_POLL_MAX_WAIT_SECONDS=1800 # Max wait for CI completion (30 min) + +# --- Claude Code Review ---------------------------------------------------- +# ANTHROPIC_API_KEY= # Not needed — uses Claude Code CLI subscription +# CLAUDE_REVIEW_MODEL=claude-sonnet-4-20250514 + +# --- Local Repo Path ------------------------------------------------------- +REPOS_BASE_DIR= # Base dir containing Billo repos (e.g., /c/Users/yaoji/git/Billo) + +# --- Jira ------------------------------------------------------------------ +JIRA_BASE_URL=https://billolife.atlassian.net + +# --- Operator API ---------------------------------------------------------- +OPERATOR_TOKEN= # When set, protects POST /approvals/* and POST /manual/* + +# --- Server ---------------------------------------------------------------- +PORT=8000 diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..670a575 --- /dev/null +++ b/.gitignore @@ -0,0 +1,36 @@ +# Environment +.env +.env.local + +# Python +__pycache__/ +*.py[cod] +*.egg-info/ +dist/ +build/ +.eggs/ + +# Virtual environment +.venv/ +venv/ + +# uv +uv.lock + +# IDE +.idea/ +.vscode/ +*.swp +*.swo + +# Testing +.coverage +htmlcov/ +.pytest_cache/ + +# Data +data/staging/ + +# OS +.DS_Store +Thumbs.db diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..6568b3e --- /dev/null +++ b/Dockerfile @@ -0,0 +1,32 @@ +FROM python:3.12-slim + +WORKDIR /app + +# Install system dependencies required for psycopg binary +RUN apt-get update && apt-get install -y --no-install-recommends \ + libpq5 \ + && rm -rf /var/lib/apt/lists/* + +# Copy dependency files first for layer caching +COPY pyproject.toml uv.lock ./ + +# Install uv for fast dependency management +RUN pip install --no-cache-dir uv + +# Install runtime dependencies (no dev extras) +RUN uv pip install --system --no-cache-dir -e . + +# Copy application source +COPY src/ ./src/ + +# Create data directory for staging store +RUN mkdir -p data/staging + +# Non-root user for security +RUN adduser --disabled-password --gecos "" appuser && \ + chown -R appuser:appuser /app +USER appuser + +EXPOSE 8000 + +CMD ["uvicorn", "release_agent.main:app", "--host", "0.0.0.0", "--port", "8000"] diff --git a/README.md b/README.md new file mode 100644 index 0000000..a77bb7b --- /dev/null +++ b/README.md @@ -0,0 +1,341 @@ +# Billo Release Agent + +A LangGraph-based release automation agent for Billo. Automates the full release +pipeline: PR discovery, code review (via Claude Code CLI), Jira ticket management, +staging release tracking, CI/CD pipeline triggering with approval gates, and Slack +interactive notifications. + +## Architecture + +``` + +--- Azure DevOps Webhook ---+ +--- PR Poller (every 5 min) ---+ + | POST /webhooks/azdo | | Scans WATCHED_REPOS for | + | (push-based) | | active PRs (pull-based) | + +------------+---------------+ +------------+------------------+ + | | + v v + +---------------------------------------------+ + | FastAPI Application | + | /webhooks/azdo /slack/interactions | + | /approvals/* /manual/* /status | + +---------------------+------------------------+ + | + v + +---------------------------------------------+ + | LangGraph Graphs | + | | + | pr_completed: | + | parse -> fetch -> [has ticket?] | + | no -> Claude generates ticket | + | yes -> Claude code review | + | -> merge -> Jira -> staging -> CI build | + | | + | release: | + | create release PR -> merge -> CI build | + | -> CD release -> [Sandbox approve] | + | -> [Production approve] -> Slack notify | + +---------------------+------------------------+ + | + +----------------+----------------+ + | | | + v v v + PostgreSQL Azure DevOps Slack (buttons) + - threads - PRs/Pipelines - Notifications + - staging - Builds/Releases - Approvals + - releases +``` + +## Key Features + +| Feature | Description | +|---------|-------------| +| **PR Discovery** | Webhook-based (push) or polling-based (pull) — or both | +| **Auto-Create Jira Ticket** | When PR branch has no ticket ID, Claude generates summary + description and creates a Jira Story | +| **AI Code Review** | Claude Code CLI reviews PRs with full repo context (Read/Glob/Grep), using your subscription | +| **CI/CD Integration** | Triggers CI builds after merge, polls for completion, handles CD release approval gates | +| **Slack Interactive** | Approval requests with [Approve]/[Cancel] buttons, CI/CD status notifications | +| **Human-in-the-loop** | 5 interrupt points where operator confirmation is required before destructive actions | +| **Per-repo Versioning** | Independent semantic versioning per repository (patch auto-increment) | + +## Prerequisites + +- Python 3.12+ +- PostgreSQL 16+ +- [uv](https://github.com/astral-sh/uv) (recommended) or pip +- Claude Code CLI installed and authenticated (`claude` in PATH) +- Slack App (for interactive buttons) or Slack Incoming Webhook (for notifications only) + +## Quick Start + +### Local Development + +```bash +# Install dependencies +uv sync --all-extras + +# Copy and configure environment +cp .env.example .env +# Edit .env -- fill in all REQUIRED variables + +# Start PostgreSQL +docker compose up -d db + +# Run the server +uv run uvicorn release_agent.main:app --reload --port 8000 + +# Verify +curl http://localhost:8000/status +``` + +### Docker Compose (Production) + +```bash +cp .env.example .env +# Edit .env -- POSTGRES_PASSWORD, WEBHOOK_SECRET, etc. are required + +docker compose up -d +``` + +Tables are created automatically on first startup. + +## Configuration + +All configuration is via environment variables. See `.env.example` for the full list. + +### Required Variables + +| Variable | Description | +|----------|-------------| +| `AZDO_ORGANIZATION` | Azure DevOps organization name | +| `AZDO_PROJECT` | Azure DevOps project name | +| `AZDO_PAT` | Azure DevOps personal access token | +| `POSTGRES_DSN` | PostgreSQL connection string | +| `POSTGRES_PASSWORD` | PostgreSQL password (used by docker-compose) | +| `JIRA_EMAIL` | Jira account email | +| `JIRA_API_TOKEN` | Jira API token | +| `SLACK_WEBHOOK_URL` | Slack incoming webhook URL | +| `WEBHOOK_SECRET` | Shared secret for validating AzDo webhooks (must be non-empty) | + +### Optional Variables + +| Variable | Default | Description | +|----------|---------|-------------| +| `REPOS_BASE_DIR` | `""` | Base dir with Billo repos (e.g., `/c/Users/yaoji/git/Billo`) | +| `WATCHED_REPOS` | `""` | Comma-separated repos to poll (e.g., `Billo.Platform.Payment,Billo.Platform.Document.DocumentAnalyser`) | +| `PR_POLL_ENABLED` | `False` | Enable periodic PR polling | +| `PR_POLL_INTERVAL_SECONDS` | `300` | Polling interval (5 min) | +| `PR_POLL_TARGET_BRANCH` | `refs/heads/develop` | Target branch filter | +| `DEFAULT_JIRA_PROJECT` | `ALLPOST` | Jira project key for auto-created tickets | +| `AUTO_CREATE_TICKET_ENABLED` | `True` | Auto-create Jira ticket when branch has no ticket ID | +| `SLACK_BOT_TOKEN` | `""` | Slack App bot token (for interactive buttons) | +| `SLACK_SIGNING_SECRET` | `""` | Slack App signing secret (required for /slack/interactions) | +| `SLACK_CHANNEL_ID` | `""` | Channel for interactive messages | +| `CI_POLL_INTERVAL_SECONDS` | `30` | CI build status poll interval | +| `CI_POLL_MAX_WAIT_SECONDS` | `1800` | Max wait for CI completion (30 min) | +| `OPERATOR_TOKEN` | `""` | Token for operator endpoints (empty = no auth) | +| `JIRA_BASE_URL` | `https://billolife.atlassian.net` | Jira instance URL | +| `PORT` | `8000` | HTTP server port | + +### Security Notes + +- `WEBHOOK_SECRET` must be non-empty; empty secret rejects all webhooks +- `POSTGRES_PASSWORD` has no default in docker-compose; fails if unset +- `SLACK_SIGNING_SECRET` must be set for `/slack/interactions` to accept requests (returns 503 if empty) +- Slack signature verification includes 5-minute replay attack prevention +- All secrets use `SecretStr` and are never logged or included in error responses +- Set `OPERATOR_TOKEN` in production to protect approval and manual trigger endpoints + +## API Endpoints + +### Webhooks + +| Method | Path | Auth | Description | +|--------|------|------|-------------| +| POST | `/webhooks/azdo` | `X-Webhook-Secret` | Receive Azure DevOps PR webhook events | +| POST | `/slack/interactions` | Slack Signing Secret | Receive Slack button click callbacks | + +### Approvals (requires `X-Operator-Token` when configured) + +| Method | Path | Description | +|--------|------|-------------| +| GET | `/approvals/pending` | List threads awaiting operator approval | +| POST | `/approvals/{thread_id}` | Submit approval decision (merge/cancel/approve/skip) | + +### Status and Manual Triggers + +| Method | Path | Auth | Description | +|--------|------|------|-------------| +| GET | `/status` | None | Health check | +| GET | `/releases/{repo}` | None | List versions for a repo | +| GET | `/staging?repo={repo}` | None | Current staging release | +| POST | `/manual/pr/{pr_id}` | `X-Operator-Token` | Manually trigger PR processing | +| POST | `/manual/release` | `X-Operator-Token` | Manually trigger a release | + +## Graph Workflows + +### PR Completed + +``` +parse_webhook -> fetch_pr_details -> route_after_fetch + |-- merged -----------------> calculate_version -> update_staging -> CI build -> END + |-- active_with_ticket -----> move_jira_code_review -+ + |-- active_no_ticket -------> auto_create_ticket ----+ + | + run_code_review -> evaluate_review + |-- approve -> [Slack: Merge?] -> merge_pr + |-- request_changes -> notify -> END + -> Jira transitions -> calculate_version + -> update_staging -> CI build -> notify -> END +``` + +### Release + +``` +load_staging -> [Slack: Create release?] -> create_release_pr + -> [Slack: Merge release?] -> merge_release_pr + -> CI build on main -> poll until complete + |-- ci_failed -> notify failure -> END + |-- ci_passed -> wait for CD release -> approval loop: + |-- [Slack: Approve Sandbox?] -> approve -> poll again + |-- [Slack: Approve Production?] -> approve -> poll again + |-- all_deployed -> move tickets to Done + -> Slack release notification -> archive -> END +``` + +### Interrupt Points (Slack buttons) + +| # | When | Slack Message | Buttons | +|---|------|--------------|---------| +| 1 | After code review approves | PR title + review summary | [Merge] [Cancel] | +| 2 | Before creating release PR | Version + ticket list | [Create] [Cancel] | +| 3 | Before merging release PR | Release PR link | [Merge] [Cancel] | +| 4 | Before triggering pipelines | Pipeline list | [Trigger] [Skip] | +| 5 | Before approving release stage | Stage name + status | [Approve] [Skip] | + +## PR Polling (Alternative to Webhooks) + +When `PR_POLL_ENABLED=True`, the agent periodically scans all `WATCHED_REPOS` for +active PRs targeting the configured branch. New PRs not yet tracked in `agent_threads` +are automatically processed through the `pr_completed` graph. + +This eliminates the need for Azure DevOps webhook configuration and works behind +firewalls without public endpoint exposure. + +## Auto-Create Jira Ticket + +When a PR branch has no ticket ID (e.g., `chore/update-dependencies` instead of +`feature/ALLPOST-4028_login-page`), the agent automatically: + +1. Sends the PR diff to Claude Code CLI +2. Claude generates a concise ticket summary and description +3. Creates a Jira Story in the `DEFAULT_JIRA_PROJECT` +4. Continues the normal workflow with the created ticket + +## Database Schema + +Tables are created automatically on startup: + +```sql +-- Thread tracking for LangGraph interrupts and PR dedup +agent_threads (thread_id, graph_name, repo_name, pr_id, status, state JSONB, + slack_message_ts, created_at, updated_at) + +-- Current in-progress releases (one per repo) +staging_releases (repo, version, started_at, tickets JSONB, updated_at) + +-- Completed releases (immutable history) +archived_releases (repo, version, started_at, tickets JSONB, released_at) +``` + +## Migrating Existing JSON Files + +If you have existing release data from the Claude Code skill: + +```bash +# Dry run +uv run python scripts/migrate_json_to_db.py \ + --source ../release-workflow/releases --dry-run + +# Execute +uv run python scripts/migrate_json_to_db.py \ + --source ../release-workflow/releases \ + --dsn "postgresql://agent:password@localhost/agent" +``` + +## Development + +### Running Tests + +```bash +# Run all tests with coverage (1061 tests, 96%+ coverage) +uv run pytest + +# Run without coverage (faster) +uv run pytest --no-cov + +# Run specific module +uv run pytest tests/graph/test_pr_completed.py -v +``` + +### Project Structure + +``` +src/release_agent/ + main.py # FastAPI app, lifespan, task management + config.py # pydantic-settings (all env vars) + state.py # LangGraph ReleaseState TypedDict + exceptions.py # Exception hierarchy + branch_parser.py # Extract ticket ID from branch name + versioning.py # Per-repo version calculation + api/ + models.py # HTTP request/response Pydantic models + dependencies.py # FastAPI Depends() + operator auth + webhooks.py # POST /webhooks/azdo + approvals.py # Approval endpoints + status.py # Status, releases, manual triggers + slack_interactions.py # POST /slack/interactions (button callbacks) + graph/ + dependencies.py # ToolClients, StagingStore Protocol + postgres_staging_store.py # PostgreSQL-backed store + routing.py # Pure routing functions (route_after_fetch, etc.) + pr_completed.py # PR graph nodes + auto_create_ticket + release.py # Release graph nodes + CI/CD approval loop + full_cycle.py # Subgraph composition + ci_nodes.py # CI trigger, poll, notify nodes + polling.py # Reusable async poll_until utility + models/ + pr.py, ticket.py, release.py, pipeline.py + webhook.py, review.py, jira.py, build.py + tools/ + azdo.py # Azure DevOps REST client + jira.py # Jira REST client (transitions + create_issue) + slack.py # Slack dual-mode (webhook + Web API) + claude_review.py # Claude Code CLI (review + ticket generation) + _http.py, _retry.py # Shared helpers + services/ + pr_poller.py # Background PR polling loop + pr_dedup.py # PR deduplication via agent_threads +scripts/ + migrate_json_to_db.py # One-time JSON -> PostgreSQL migration +tests/ # 1061 tests, 96%+ coverage +``` + +## Docker + +```bash +# Required: set POSTGRES_PASSWORD and WEBHOOK_SECRET in .env +docker compose up -d +``` + +The agent service includes a health check at `/status`. PostgreSQL uses +`pg_isready` with `service_healthy` dependency. + +## Slack App Setup + +To use interactive buttons (optional — REST API approvals still work without it): + +1. Create a Slack App at https://api.slack.com/apps +2. Enable **Interactivity** with Request URL: `https:///slack/interactions` +3. Add Bot Token Scopes: `chat:write`, `chat:update` +4. Install to workspace, get Bot Token (`xoxb-...`) +5. Set `SLACK_BOT_TOKEN`, `SLACK_SIGNING_SECRET`, `SLACK_CHANNEL_ID` in `.env` diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..f217f52 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,48 @@ +services: + agent: + build: . + ports: + - "8000:8000" + environment: + AZDO_ORGANIZATION: ${AZDO_ORGANIZATION} + AZDO_PROJECT: ${AZDO_PROJECT} + AZDO_PAT: ${AZDO_PAT} + ANTHROPIC_API_KEY: ${ANTHROPIC_API_KEY:-unused} + POSTGRES_DSN: postgresql://agent:${POSTGRES_PASSWORD:?POSTGRES_PASSWORD must be set}@db:5432/agent + JIRA_EMAIL: ${JIRA_EMAIL} + JIRA_API_TOKEN: ${JIRA_API_TOKEN} + SLACK_WEBHOOK_URL: ${SLACK_WEBHOOK_URL} + WEBHOOK_SECRET: ${WEBHOOK_SECRET:?WEBHOOK_SECRET must be set} + JIRA_BASE_URL: ${JIRA_BASE_URL:-https://billolife.atlassian.net} + CLAUDE_REVIEW_MODEL: ${CLAUDE_REVIEW_MODEL:-claude-sonnet-4-20250514} + REPOS_BASE_DIR: ${REPOS_BASE_DIR:-} + OPERATOR_TOKEN: ${OPERATOR_TOKEN:-} + PORT: 8000 + depends_on: + db: + condition: service_healthy + restart: unless-stopped + healthcheck: + test: ["CMD-SHELL", "curl -sf http://localhost:8000/status || exit 1"] + interval: 10s + timeout: 5s + retries: 3 + start_period: 15s + + db: + image: postgres:16-alpine + environment: + POSTGRES_USER: agent + POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:?POSTGRES_PASSWORD must be set} + POSTGRES_DB: agent + volumes: + - postgres_data:/var/lib/postgresql/data + healthcheck: + test: ["CMD-SHELL", "pg_isready -U agent -d agent"] + interval: 5s + timeout: 5s + retries: 5 + restart: unless-stopped + +volumes: + postgres_data: diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..83acb1c --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,58 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "billo-release-agent" +version = "0.1.0" +description = "LangGraph-based release automation agent for Billo" +requires-python = ">=3.12" +dependencies = [ + "langgraph>=0.3.30", + "langgraph-checkpoint-postgres>=2.0.15", + "fastapi>=0.115.0", + "uvicorn[standard]>=0.34.0", + "httpx>=0.28.0", + "anthropic>=0.52.0", + "pydantic>=2.11.0", + "pydantic-settings>=2.8.0", + "psycopg[binary]>=3.2.0", + "psycopg-pool>=3.2.0", +] + +[project.optional-dependencies] +dev = [ + "pytest>=8.3.0", + "pytest-asyncio>=0.25.0", + "pytest-cov>=6.1.0", + "ruff>=0.11.0", + "psycopg[binary]>=3.2.0", + "psycopg-pool>=3.2.0", +] + +[tool.hatch.build.targets.wheel] +packages = ["src/release_agent"] + +[tool.pytest.ini_options] +testpaths = ["tests"] +asyncio_mode = "auto" +pythonpath = ["."] +addopts = "--cov=src/release_agent --cov-report=term-missing --cov-fail-under=80" + +[tool.coverage.run] +source = ["src/release_agent"] + +[tool.coverage.report] +exclude_lines = [ + "pragma: no cover", + "if TYPE_CHECKING:", + "raise NotImplementedError", +] + +[tool.ruff] +line-length = 100 +target-version = "py312" + +[tool.ruff.lint] +select = ["E", "F", "I", "N", "UP", "B", "SIM"] +ignore = ["E501"] diff --git a/scripts/__init__.py b/scripts/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/scripts/migrate_json_to_db.py b/scripts/migrate_json_to_db.py new file mode 100644 index 0000000..a57860c --- /dev/null +++ b/scripts/migrate_json_to_db.py @@ -0,0 +1,284 @@ +"""Migration script: JSON files -> PostgreSQL. + +Reads staging and archived release JSON files from a directory tree and +inserts them into the staging_releases and archived_releases tables. + +All business logic is implemented as pure functions so it can be tested +without a real database. The main() entry point wires together the pure +functions with actual I/O. + +Usage: + python scripts/migrate_json_to_db.py --source /path/to/releases \\ + --dsn "postgresql://user:pass@localhost/db" [--dry-run] + +Pure functions (testable without DB): + collect_json_files(directory) -> list[Path] + is_staging_filename(name) -> bool + is_archived_filename(name) -> bool + parse_staging_json(data) -> MigrationRecord + parse_archived_json(data) -> MigrationRecord + build_staging_insert_sql(record) -> tuple[str, tuple] + build_archived_insert_sql(record) -> tuple[str, tuple] +""" + +from __future__ import annotations + +import argparse +import json +import re +import sys +from dataclasses import dataclass, field +from datetime import date +from pathlib import Path + + +# --------------------------------------------------------------------------- +# Data model +# --------------------------------------------------------------------------- + +@dataclass(frozen=True) +class MigrationRecord: + """Parsed record ready for database insertion. + + released_at is None for staging records, set for archived records. + """ + + repo: str + version: str + started_at: date + tickets: list[dict] + released_at: date | None = None + + +# --------------------------------------------------------------------------- +# File classification +# --------------------------------------------------------------------------- + +# Archived filenames match: __.json +# e.g. Billo.Platform.Payment_v1.0.1_2026-03-23.json +_ARCHIVED_PATTERN = re.compile( + r"^.+_v\d+\.\d+\.\d+_\d{4}-\d{2}-\d{2}\.json$" +) + + +def is_staging_filename(name: str) -> bool: + """Return True if the filename looks like a staging JSON file. + + Staging files end in .json and do not match the archived pattern. + """ + if not name.endswith(".json"): + return False + return not _ARCHIVED_PATTERN.match(name) + + +def is_archived_filename(name: str) -> bool: + """Return True if the filename looks like an archived release JSON file.""" + return bool(_ARCHIVED_PATTERN.match(name)) + + +# --------------------------------------------------------------------------- +# File collection +# --------------------------------------------------------------------------- + +def collect_json_files(directory: Path) -> list[Path]: + """Recursively collect all .json files under directory. + + Returns a sorted list of Path objects. + """ + return sorted(directory.rglob("*.json")) + + +# --------------------------------------------------------------------------- +# Parsing pure functions +# --------------------------------------------------------------------------- + +def parse_staging_json(data: dict) -> MigrationRecord: + """Parse a staging release JSON dict into a MigrationRecord. + + Args: + data: Parsed JSON dict with keys: version, repo, started_at, tickets. + + Returns: + MigrationRecord with released_at=None. + """ + return MigrationRecord( + repo=data["repo"], + version=data["version"], + started_at=date.fromisoformat(data["started_at"]), + tickets=list(data.get("tickets") or []), + released_at=None, + ) + + +def parse_archived_json(data: dict) -> MigrationRecord: + """Parse an archived release JSON dict into a MigrationRecord. + + Args: + data: Parsed JSON dict with keys: version, repo, started_at, tickets, + released_at. + + Returns: + MigrationRecord with released_at set. + """ + return MigrationRecord( + repo=data["repo"], + version=data["version"], + started_at=date.fromisoformat(data["started_at"]), + tickets=list(data.get("tickets") or []), + released_at=date.fromisoformat(data["released_at"]), + ) + + +# --------------------------------------------------------------------------- +# SQL builder pure functions +# --------------------------------------------------------------------------- + +_STAGING_INSERT_SQL = """ + INSERT INTO staging_releases (repo, version, started_at, tickets) + VALUES (%s, %s, %s, %s) + ON CONFLICT (repo) DO NOTHING +""".strip() + +_ARCHIVED_INSERT_SQL = """ + INSERT INTO archived_releases (repo, version, started_at, tickets, released_at) + VALUES (%s, %s, %s, %s, %s) + ON CONFLICT (repo, version) DO NOTHING +""".strip() + + +def build_staging_insert_sql(record: MigrationRecord) -> tuple[str, tuple]: + """Build the INSERT SQL and parameters for a staging release record. + + Returns: + (sql_string, params_tuple) ready for cursor.execute(). + """ + tickets_json = json.dumps(record.tickets) + params = ( + record.repo, + record.version, + record.started_at.isoformat(), + tickets_json, + ) + return _STAGING_INSERT_SQL, params + + +def build_archived_insert_sql(record: MigrationRecord) -> tuple[str, tuple]: + """Build the INSERT SQL and parameters for an archived release record. + + Returns: + (sql_string, params_tuple) ready for cursor.execute(). + """ + tickets_json = json.dumps(record.tickets) + params = ( + record.repo, + record.version, + record.started_at.isoformat(), + tickets_json, + record.released_at.isoformat() if record.released_at else None, + ) + return _ARCHIVED_INSERT_SQL, params + + +# --------------------------------------------------------------------------- +# Main entry point +# --------------------------------------------------------------------------- + +def _parse_args(argv: list[str] | None = None) -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Migrate JSON release files to PostgreSQL" + ) + parser.add_argument( + "--source", + type=Path, + required=True, + help="Root directory containing release JSON files", + ) + parser.add_argument( + "--dsn", + type=str, + default="", + help="PostgreSQL DSN (e.g. postgresql://user:pass@localhost/db)", + ) + parser.add_argument( + "--dry-run", + action="store_true", + help="Print SQL statements without executing them", + ) + return parser.parse_args(argv) + + +def main(argv: list[str] | None = None) -> int: + """Entry point for the migration script. + + Returns: + 0 on success, 1 on error. + """ + args = _parse_args(argv) + + if not args.source.exists(): + print(f"ERROR: source directory does not exist: {args.source}", file=sys.stderr) + return 1 + + files = collect_json_files(args.source) + print(f"Found {len(files)} JSON file(s) under {args.source}") + + statements: list[tuple[str, tuple]] = [] + errors: list[str] = [] + + for path in files: + try: + data = json.loads(path.read_text(encoding="utf-8")) + except (json.JSONDecodeError, OSError) as exc: + errors.append(f"Failed to read {path}: {exc}") + continue + + # Determine file type from filename + name = path.name + try: + if is_archived_filename(name) or "released_at" in data: + record = parse_archived_json(data) + sql, params = build_archived_insert_sql(record) + else: + record = parse_staging_json(data) + sql, params = build_staging_insert_sql(record) + except (KeyError, ValueError) as exc: + errors.append(f"Failed to parse {path}: {exc}") + continue + + statements.append((sql, params)) + + if errors: + for err in errors: + print(f"WARNING: {err}", file=sys.stderr) + + if args.dry_run: + print(f"\nDry run: {len(statements)} statement(s) would be executed:") + for sql, params in statements: + print(f" SQL: {sql!r}") + print(f" Params: {params}") + return 0 + + if not args.dsn: + print("ERROR: --dsn is required when not using --dry-run", file=sys.stderr) + return 1 + + try: + import psycopg # noqa: PLC0415 + except ImportError: + print("ERROR: psycopg not installed. Run: pip install psycopg[binary]", file=sys.stderr) + return 1 + + inserted = 0 + with psycopg.connect(args.dsn) as conn: + with conn.cursor() as cur: + for sql, params in statements: + cur.execute(sql, params) + inserted += 1 + conn.commit() + + print(f"Migration complete: {inserted} record(s) inserted.") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/src/release_agent/__init__.py b/src/release_agent/__init__.py new file mode 100644 index 0000000..593372f --- /dev/null +++ b/src/release_agent/__init__.py @@ -0,0 +1 @@ +"""Billo Release Agent - LangGraph-based release automation.""" diff --git a/src/release_agent/api/__init__.py b/src/release_agent/api/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/release_agent/api/approvals.py b/src/release_agent/api/approvals.py new file mode 100644 index 0000000..551b3df --- /dev/null +++ b/src/release_agent/api/approvals.py @@ -0,0 +1,166 @@ +"""Approvals endpoint for resuming interrupted graph threads. + +POST /approvals/{thread_id} — resume an interrupted graph with a decision. +GET /approvals/pending — list threads with status="interrupted". +""" + +from datetime import datetime, timezone + +from fastapi import APIRouter, Depends, HTTPException, Request + +from release_agent.api.dependencies import get_db_pool, get_graphs, get_tool_clients, require_operator_token +from release_agent.api.models import ( + ApprovalDecision, + ApprovalResponse, + PendingApproval, + PendingApprovalsResponse, +) + +router = APIRouter() + + +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- + +async def _resume_graph( + *, + graph, + thread_id: str, + decision: str, + tool_clients, + db_pool, +) -> dict: + """Resume an interrupted graph thread with the given decision. + + The decision string is passed as the resume value to the graph. + Thread status is updated in the database. + """ + from release_agent.api.webhooks import _upsert_thread + + config = { + "configurable": { + "thread_id": thread_id, + "clients": tool_clients, + } + } + try: + result = await graph.ainvoke({"decision": decision}, config=config) + await _upsert_thread(db_pool, thread_id=thread_id, thread_status="completed", state=result or {}) + return result or {} + except Exception as exc: + await _upsert_thread( + db_pool, + thread_id=thread_id, + thread_status="failed", + state={"errors": [str(exc)]}, + ) + raise + + +async def _get_thread_graph_name(db_pool, thread_id: str) -> str | None: + """Look up the graph_name for a thread_id from the database.""" + sql = "SELECT graph_name FROM agent_threads WHERE thread_id = %s" + async with db_pool.connection() as conn: + async with conn.cursor() as cur: + await cur.execute(sql, (thread_id,)) + row = await cur.fetchone() + return row[0] if row else None + + +async def _fetch_pending_threads(db_pool) -> list[tuple]: + """Query agent_threads for rows with status='interrupted'.""" + sql = """ + SELECT + thread_id, + graph_name, + interrupt_value, + created_at, + repo_name, + pr_id, + version + FROM agent_threads + WHERE status = 'interrupted' + ORDER BY created_at ASC + """ + async with db_pool.connection() as conn: + async with conn.cursor() as cur: + await cur.execute(sql) + return await cur.fetchall() + + +# --------------------------------------------------------------------------- +# GET /approvals/pending +# --------------------------------------------------------------------------- + +@router.get("/approvals/pending") +async def list_pending_approvals( + request: Request, # noqa: ARG001 + db_pool=Depends(get_db_pool), + _auth: None = Depends(require_operator_token), +) -> PendingApprovalsResponse: + """Return all graph threads currently awaiting operator approval.""" + rows = await _fetch_pending_threads(db_pool) + items = [ + PendingApproval( + thread_id=row[0], + graph_name=row[1], + interrupt_value=row[2], + created_at=row[3] if isinstance(row[3], datetime) else datetime.fromisoformat(str(row[3])), + repo_name=row[4], + pr_id=row[5], + version=row[6], + ) + for row in rows + ] + return PendingApprovalsResponse(items=items, count=len(items)) + + +# --------------------------------------------------------------------------- +# POST /approvals/{thread_id} +# --------------------------------------------------------------------------- + +@router.post("/approvals/{thread_id}") +async def submit_approval( + thread_id: str, + body: ApprovalDecision, + request: Request, # noqa: ARG001 + graphs=Depends(get_graphs), + tool_clients=Depends(get_tool_clients), + db_pool=Depends(get_db_pool), + _auth: None = Depends(require_operator_token), +) -> ApprovalResponse: + """Submit an approval decision to resume an interrupted graph thread. + + The decision is forwarded as the interrupt resume value to the graph. + Thread status is updated to 'completed' or 'failed' based on result. + """ + # Look up graph_name from DB for the correct graph + graph_name = await _get_thread_graph_name(db_pool, thread_id) + if graph_name is None: + raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found") + + graph = graphs.get(graph_name) + if graph is None: + raise HTTPException(status_code=400, detail=f"Unknown graph: {graph_name}") + + try: + await _resume_graph( + graph=graph, + thread_id=thread_id, + decision=body.decision, + tool_clients=tool_clients, + db_pool=db_pool, + ) + except Exception: + return ApprovalResponse( + thread_id=thread_id, + status="failed", + message=f"Thread {thread_id} failed during resume", + ) + + return ApprovalResponse( + thread_id=thread_id, + status="resumed", + message=f"Thread {thread_id} resumed with decision '{body.decision}'", + ) diff --git a/src/release_agent/api/dependencies.py b/src/release_agent/api/dependencies.py new file mode 100644 index 0000000..6bfd6a1 --- /dev/null +++ b/src/release_agent/api/dependencies.py @@ -0,0 +1,66 @@ +"""FastAPI dependency callables that extract objects from app.state. + +Each function accepts a Request and returns the corresponding object stored +on app.state during lifespan startup. Use with Depends() in route handlers. +""" + +import hmac + +from fastapi import Depends, HTTPException, Request, status +from fastapi.security import APIKeyHeader + +from release_agent.config import Settings +from release_agent.graph.dependencies import StagingStore, ToolClients + +_operator_token_header = APIKeyHeader(name="X-Operator-Token", auto_error=False) + + +def get_settings(request: Request) -> Settings: + """Return the Settings instance stored on app.state.""" + return request.app.state.settings + + +def get_graphs(request: Request) -> dict: + """Return the compiled graphs dict stored on app.state.""" + return request.app.state.graphs + + +def get_tool_clients(request: Request) -> ToolClients: + """Return the ToolClients instance stored on app.state.""" + return request.app.state.tool_clients + + +def get_staging_store(request: Request) -> StagingStore: + """Return the StagingStore instance stored on app.state.""" + return request.app.state.staging_store + + +def get_db_pool(request: Request): + """Return the database connection pool stored on app.state.""" + return request.app.state.db_pool + + +async def require_operator_token( + request: Request, + token: str | None = Depends(_operator_token_header), +) -> None: + """Validate the X-Operator-Token header against the configured operator token. + + If operator_token is not configured (empty string), authentication is skipped + and all requests are allowed. When configured, uses hmac.compare_digest + for constant-time comparison to prevent timing attacks. + + Raises: + HTTPException: 401 if a token is configured but the provided token is + missing or does not match. + """ + configured = request.app.state.settings.operator_token.get_secret_value() + if not configured: + # No token configured — auth disabled + return + + if not token or not hmac.compare_digest(token, configured): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid or missing operator token", + ) diff --git a/src/release_agent/api/models.py b/src/release_agent/api/models.py new file mode 100644 index 0000000..e18d065 --- /dev/null +++ b/src/release_agent/api/models.py @@ -0,0 +1,137 @@ +"""API request/response Pydantic models for the HTTP layer. + +All models are frozen (immutable) following the project immutability policy. +""" + +from datetime import datetime +from typing import Literal + +from pydantic import BaseModel, ConfigDict + + +# --------------------------------------------------------------------------- +# Webhook models +# --------------------------------------------------------------------------- + +class WebhookResponse(BaseModel): + """Response returned after a webhook event is accepted.""" + + model_config = ConfigDict(frozen=True) + + thread_id: str + message: str + + +# --------------------------------------------------------------------------- +# Approval models +# --------------------------------------------------------------------------- + +class ApprovalDecision(BaseModel): + """Body for POST /approvals/{thread_id}.""" + + model_config = ConfigDict(frozen=True) + + decision: Literal["merge", "cancel", "approve", "skip", "trigger"] + + +class ApprovalResponse(BaseModel): + """Response returned after an approval decision is processed.""" + + model_config = ConfigDict(frozen=True) + + thread_id: str + status: str + message: str + + +class PendingApproval(BaseModel): + """A single interrupted graph thread awaiting operator input.""" + + model_config = ConfigDict(frozen=True) + + thread_id: str + graph_name: str + interrupt_value: str + created_at: datetime + repo_name: str | None = None + pr_id: str | None = None + version: str | None = None + + +class PendingApprovalsResponse(BaseModel): + """Response for GET /approvals/pending.""" + + model_config = ConfigDict(frozen=True) + + items: list[PendingApproval] + count: int + + +# --------------------------------------------------------------------------- +# Health / status models +# --------------------------------------------------------------------------- + +class HealthResponse(BaseModel): + """Response for GET /status.""" + + model_config = ConfigDict(frozen=True) + + status: Literal["ok", "degraded"] + version: str + uptime_seconds: float + + +# --------------------------------------------------------------------------- +# Release / staging models +# --------------------------------------------------------------------------- + +class ReleaseVersionListResponse(BaseModel): + """Response for GET /releases/{repo}.""" + + model_config = ConfigDict(frozen=True) + + repo: str + versions: list[str] + + +class StagingResponse(BaseModel): + """Response for GET /staging.""" + + model_config = ConfigDict(frozen=True) + + repo: str + staging: dict | None + + +# --------------------------------------------------------------------------- +# Manual trigger models +# --------------------------------------------------------------------------- + +class ManualTriggerResponse(BaseModel): + """Response for manual trigger endpoints.""" + + model_config = ConfigDict(frozen=True) + + thread_id: str + message: str + + +class ManualReleaseRequest(BaseModel): + """Request body for POST /manual/release.""" + + model_config = ConfigDict(frozen=True) + + repo: str + + +# --------------------------------------------------------------------------- +# Error model +# --------------------------------------------------------------------------- + +class ErrorResponse(BaseModel): + """Standard error response envelope.""" + + model_config = ConfigDict(frozen=True) + + error: str + detail: str | None = None diff --git a/src/release_agent/api/slack_interactions.py b/src/release_agent/api/slack_interactions.py new file mode 100644 index 0000000..060f0c5 --- /dev/null +++ b/src/release_agent/api/slack_interactions.py @@ -0,0 +1,264 @@ +"""Slack interactive messages endpoint. + +POST /slack/interactions — receives Slack button click payloads, verifies +the signing secret, extracts thread_id + decision, and resumes the +appropriate graph thread in the background. + +Slack requires a 200 response within 3 seconds. The graph resume is +therefore scheduled as a background task. +""" + +import asyncio +import hashlib +import hmac +import json +import logging +import time +import urllib.parse +from typing import Any + +from fastapi import APIRouter, Depends, HTTPException, Request, status +from fastapi.responses import JSONResponse + +from release_agent.api.dependencies import get_db_pool, get_graphs, get_settings, get_tool_clients + +logger = logging.getLogger(__name__) + +router = APIRouter() + + +# --------------------------------------------------------------------------- +# Pure helper: signature verification +# --------------------------------------------------------------------------- + +_MAX_TIMESTAMP_AGE_SECONDS = 300 # 5 minutes — Slack's recommendation +_ALLOWED_DECISIONS = frozenset({"merge", "cancel", "approve", "skip", "trigger", "yes", "no"}) + + +def _verify_slack_signature( + *, + signing_secret: str, + timestamp: str, + body: str, + signature: str, + current_time: float | None = None, +) -> bool: + """Verify a Slack request signature using HMAC-SHA256. + + Also rejects requests with timestamps older than 5 minutes + to prevent replay attacks. + + Args: + signing_secret: The app's Slack signing secret. + timestamp: Value of the X-Slack-Request-Timestamp header. + body: Raw request body as a string. + signature: Value of the X-Slack-Signature header. + current_time: Current unix timestamp (injectable for testing). + + Returns: + True if the signature is valid and fresh, False otherwise. + """ + if not signature or not signature.startswith("v0="): + return False + + # Replay attack prevention + try: + request_ts = int(timestamp) + except ValueError: + return False + + now = current_time if current_time is not None else time.time() + if abs(now - request_ts) > _MAX_TIMESTAMP_AGE_SECONDS: + return False + + base_string = f"v0:{timestamp}:{body}" + computed_hash = hmac.new( + signing_secret.encode(), + base_string.encode(), + hashlib.sha256, + ).hexdigest() + expected_sig = f"v0={computed_hash}" + + return hmac.compare_digest(expected_sig, signature) + + +# --------------------------------------------------------------------------- +# Internal: parse button payload +# --------------------------------------------------------------------------- + +def _parse_button_payload(form_body: str) -> dict[str, Any] | None: + """Parse a URL-encoded Slack button action payload. + + Args: + form_body: Raw URL-encoded form body from Slack. + + Returns: + Dict with "thread_id" and "decision" keys, or None if parsing fails. + """ + try: + parsed = urllib.parse.parse_qs(form_body) + payload_json = parsed.get("payload", [None])[0] + if not payload_json: + return None + + payload = json.loads(payload_json) + actions = payload.get("actions", []) + if not actions: + return None + + action = actions[0] + value = action.get("value", "") + # value format: "{thread_id}:{decision}" + if ":" not in value: + return None + + parts = value.split(":", 1) + thread_id = parts[0] + decision = parts[1] + + user = payload.get("user") or {} + user_name = user.get("name", user.get("id", "unknown")) + + return { + "thread_id": thread_id, + "decision": decision, + "user_name": user_name, + } + except (json.JSONDecodeError, KeyError, IndexError, ValueError) as exc: + logger.warning("Failed to parse Slack button payload: %s", exc) + return None + + +# --------------------------------------------------------------------------- +# Internal: background graph resume +# --------------------------------------------------------------------------- + +async def _background_resume( + *, + thread_id: str, + decision: str, + graphs: dict, + tool_clients: Any, + db_pool: Any, +) -> None: + """Fetch the graph for thread_id from DB and resume it.""" + from release_agent.api.approvals import _get_thread_graph_name, _resume_graph + + graph_name = await _get_thread_graph_name(db_pool, thread_id) + if graph_name is None: + logger.warning("slack_interactions: thread %s not found in DB", thread_id) + return + + graph = graphs.get(graph_name) + if graph is None: + logger.warning( + "slack_interactions: unknown graph '%s' for thread %s", + graph_name, + thread_id, + ) + return + + try: + await _resume_graph( + graph=graph, + thread_id=thread_id, + decision=decision, + tool_clients=tool_clients, + db_pool=db_pool, + ) + except Exception as exc: + logger.error( + "slack_interactions: failed to resume thread %s: %s", + thread_id, + exc, + ) + + +# --------------------------------------------------------------------------- +# Endpoint: POST /slack/interactions +# --------------------------------------------------------------------------- + +@router.post("/slack/interactions") +async def slack_interactions( + request: Request, + graphs: dict = Depends(get_graphs), + tool_clients: Any = Depends(get_tool_clients), + db_pool: Any = Depends(get_db_pool), + settings: Any = Depends(get_settings), +) -> JSONResponse: + """Receive and process Slack interactive button clicks. + + 1. Verify the Slack signing secret. + 2. Parse the payload to extract thread_id and decision. + 3. Schedule the graph resume as a background task. + 4. Return 200 immediately (Slack requires a fast response). + + Returns: + 200 {"ok": True} on success. + 400 if required headers are missing. + 403 if signature verification fails. + """ + raw_body = await request.body() + body_str = raw_body.decode("utf-8") + + timestamp = request.headers.get("X-Slack-Request-Timestamp") + signature = request.headers.get("X-Slack-Signature") + + if not timestamp: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Missing X-Slack-Request-Timestamp header", + ) + + if not signature: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Missing X-Slack-Signature header", + ) + + signing_secret = settings.slack_signing_secret.get_secret_value() + if not signing_secret: + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="Slack signing secret not configured", + ) + + valid = _verify_slack_signature( + signing_secret=signing_secret, + timestamp=timestamp, + body=body_str, + signature=signature, + ) + if not valid: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Invalid Slack signature", + ) + + parsed = _parse_button_payload(body_str) + if not parsed: + logger.warning("slack_interactions: could not parse payload") + return JSONResponse({"ok": True}) + + thread_id = parsed["thread_id"] + decision = parsed["decision"] + + if decision not in _ALLOWED_DECISIONS: + logger.warning("slack_interactions: unexpected decision '%s', ignoring", decision) + return JSONResponse({"ok": True}) + + # Schedule the graph resume as a background task (Slack needs 200 fast) + task = asyncio.create_task( + _background_resume( + thread_id=thread_id, + decision=decision, + graphs=graphs, + tool_clients=tool_clients, + db_pool=db_pool, + ) + ) + if hasattr(request.app.state, "background_tasks"): + request.app.state.background_tasks.add(task) + task.add_done_callback(request.app.state.background_tasks.discard) + + return JSONResponse({"ok": True}) diff --git a/src/release_agent/api/status.py b/src/release_agent/api/status.py new file mode 100644 index 0000000..7cf7958 --- /dev/null +++ b/src/release_agent/api/status.py @@ -0,0 +1,153 @@ +"""Status, releases, staging, and manual trigger endpoints. + +GET /status — health check +GET /releases/{repo} — list release versions for a repo +GET /staging — current staging release for a repo +POST /manual/pr/{pr_id} — manually trigger PR processing +POST /manual/release — manually trigger a release run +""" + +import asyncio +import uuid +from datetime import datetime, timezone + +from fastapi import APIRouter, Depends, Request, status +from fastapi.responses import JSONResponse + +from release_agent.api.dependencies import ( + get_db_pool, + get_graphs, + get_staging_store, + get_tool_clients, + require_operator_token, +) +from release_agent.api.models import ( + HealthResponse, + ManualReleaseRequest, + ManualTriggerResponse, + ReleaseVersionListResponse, + StagingResponse, +) + +router = APIRouter() + +_VERSION = "0.1.0" + + +# --------------------------------------------------------------------------- +# GET /status +# --------------------------------------------------------------------------- + +@router.get("/status") +async def get_status(request: Request) -> HealthResponse: + """Return the health status of the agent service.""" + started_at: datetime = request.app.state.started_at + uptime = (datetime.now(tz=timezone.utc) - started_at).total_seconds() + return HealthResponse( + status="ok", + version=_VERSION, + uptime_seconds=uptime, + ) + + +# --------------------------------------------------------------------------- +# GET /releases/{repo} +# --------------------------------------------------------------------------- + +@router.get("/releases/{repo}") +async def list_release_versions( + repo: str, + staging_store=Depends(get_staging_store), +) -> ReleaseVersionListResponse: + """List all known release versions for the given repository.""" + versions = await staging_store.list_versions(repo) + return ReleaseVersionListResponse(repo=repo, versions=versions) + + +# --------------------------------------------------------------------------- +# GET /staging +# --------------------------------------------------------------------------- + +@router.get("/staging") +async def get_staging( + repo: str, + staging_store=Depends(get_staging_store), +) -> StagingResponse: + """Return the current staging release for the given repository.""" + staging_obj = await staging_store.load(repo) + staging_dict = staging_obj.model_dump(mode="json") if staging_obj is not None else None + return StagingResponse(repo=repo, staging=staging_dict) + + +# --------------------------------------------------------------------------- +# POST /manual/pr/{pr_id} +# --------------------------------------------------------------------------- + +@router.post("/manual/pr/{pr_id}", status_code=status.HTTP_202_ACCEPTED) +async def manual_pr_trigger( + pr_id: str, + request: Request, + graphs=Depends(get_graphs), + tool_clients=Depends(get_tool_clients), + db_pool=Depends(get_db_pool), + _auth: None = Depends(require_operator_token), +) -> ManualTriggerResponse: + """Manually trigger PR processing for the given PR ID.""" + from release_agent.api.webhooks import _run_graph + + thread_id = str(uuid.uuid4()) + initial_state = {"pr_id": pr_id} + + task = asyncio.create_task( + _run_graph( + graph=graphs["pr_completed"], + initial_state=initial_state, + thread_id=thread_id, + tool_clients=tool_clients, + db_pool=db_pool, + ) + ) + request.app.state.background_tasks.add(task) + task.add_done_callback(request.app.state.background_tasks.discard) + + return ManualTriggerResponse( + thread_id=thread_id, + message=f"PR {pr_id} processing scheduled as thread {thread_id}", + ) + + +# --------------------------------------------------------------------------- +# POST /manual/release +# --------------------------------------------------------------------------- + +@router.post("/manual/release", status_code=status.HTTP_202_ACCEPTED) +async def manual_release_trigger( + body: ManualReleaseRequest, + request: Request, + graphs=Depends(get_graphs), + tool_clients=Depends(get_tool_clients), + db_pool=Depends(get_db_pool), + _auth: None = Depends(require_operator_token), +) -> ManualTriggerResponse: + """Manually trigger a release run for the given repository.""" + from release_agent.api.webhooks import _run_graph + + thread_id = str(uuid.uuid4()) + initial_state = {"repo_name": body.repo} + + task = asyncio.create_task( + _run_graph( + graph=graphs["release"], + initial_state=initial_state, + thread_id=thread_id, + tool_clients=tool_clients, + db_pool=db_pool, + ) + ) + request.app.state.background_tasks.add(task) + task.add_done_callback(request.app.state.background_tasks.discard) + + return ManualTriggerResponse( + thread_id=thread_id, + message=f"Release for repo '{body.repo}' scheduled as thread {thread_id}", + ) diff --git a/src/release_agent/api/webhooks.py b/src/release_agent/api/webhooks.py new file mode 100644 index 0000000..838ce14 --- /dev/null +++ b/src/release_agent/api/webhooks.py @@ -0,0 +1,195 @@ +"""Webhook endpoint for Azure DevOps events. + +POST /webhooks/azdo — validates secret, parses payload, filters events, +schedules graph execution as a background task, returns 202. +""" + +import asyncio +import hmac +import logging +import uuid +from typing import Annotated + +logger = logging.getLogger(__name__) + +from fastapi import APIRouter, Depends, Header, HTTPException, Request, status +from fastapi.responses import JSONResponse +from pydantic import ValidationError + +from release_agent.api.dependencies import get_db_pool, get_graphs, get_settings, get_tool_clients +from release_agent.api.models import WebhookResponse +from release_agent.models.webhook import WebhookPayload + +router = APIRouter() + + +# --------------------------------------------------------------------------- +# Secret validation helper (pure function — easily unit-tested) +# --------------------------------------------------------------------------- + +def _validate_webhook_secret(header: str | None, expected: str) -> bool: + """Return True if header matches expected using constant-time comparison. + + Returns False if: + - expected is empty (auth misconfigured, reject all) + - header is None (no header sent) + - header does not match expected + """ + if not expected: + return False + if header is None: + return False + return hmac.compare_digest(header, expected) + + +# --------------------------------------------------------------------------- +# POST /webhooks/azdo +# --------------------------------------------------------------------------- + +@router.post("/webhooks/azdo") +async def receive_azdo_webhook( + request: Request, + x_webhook_secret: Annotated[str | None, Header()] = None, + settings=Depends(get_settings), + graphs=Depends(get_graphs), + tool_clients=Depends(get_tool_clients), + db_pool=Depends(get_db_pool), +): + """Receive an Azure DevOps webhook event. + + Validates the X-Webhook-Secret header, parses the payload, + filters to completed PR events only, then schedules graph execution. + + Returns 202 with thread_id for accepted events. + Returns 200 with "event ignored" for non-completed PR events. + Returns 401 for missing or invalid secret. + Returns 422 for invalid payload structure. + """ + expected_secret = settings.webhook_secret.get_secret_value() + if not _validate_webhook_secret(x_webhook_secret, expected_secret): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid or missing webhook secret", + ) + + body = await request.json() + try: + payload = WebhookPayload.model_validate(body) + except (ValidationError, Exception) as exc: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail=f"Invalid payload: {exc}", + ) from exc + + # Only process completed PRs + if payload.resource.status != "completed": + return JSONResponse( + status_code=status.HTTP_200_OK, + content={"message": "event ignored: PR not completed", "status": payload.resource.status}, + ) + + # Generate a unique thread ID for this execution + thread_id = str(uuid.uuid4()) + + # Build initial state from webhook payload + initial_state = { + "webhook_payload": body, + "pr_id": str(payload.resource.pull_request_id), + "repo_name": payload.resource.repository.name, + } + + graph = graphs["pr_completed"] + + # Schedule as a background task (non-blocking) + settings = request.app.state.settings + task = asyncio.create_task( + _run_graph( + graph=graph, + initial_state=initial_state, + thread_id=thread_id, + tool_clients=tool_clients, + db_pool=db_pool, + repos_base_dir=settings.repos_base_dir, + graph_name="pr_completed", + default_jira_project=settings.default_jira_project, + ) + ) + request.app.state.background_tasks.add(task) + task.add_done_callback(request.app.state.background_tasks.discard) + + return JSONResponse( + status_code=status.HTTP_202_ACCEPTED, + content=WebhookResponse( + thread_id=thread_id, + message=f"Webhook accepted, scheduled as thread {thread_id}", + ).model_dump(), + ) + + +# --------------------------------------------------------------------------- +# Internal graph runner +# --------------------------------------------------------------------------- + +async def _run_graph( + *, + graph, + initial_state: dict, + thread_id: str, + tool_clients, + db_pool, + repos_base_dir: str = "", + graph_name: str = "", + default_jira_project: str = "ALLPOST", +) -> None: + """Invoke the graph and persist thread status to the database.""" + repo_name = initial_state.get("repo_name", "") + pr_id = initial_state.get("pr_id", "") + config = { + "configurable": { + "thread_id": thread_id, + "clients": tool_clients, + "repos_base_dir": repos_base_dir, + "default_jira_project": default_jira_project, + } + } + try: + await _upsert_thread( + db_pool, thread_id=thread_id, thread_status="running", + state=initial_state, graph_name=graph_name, + repo_name=repo_name, pr_id=pr_id, + ) + result = await graph.ainvoke(initial_state, config=config) + await _upsert_thread(db_pool, thread_id=thread_id, thread_status="completed", state=result or {}) + except Exception as exc: + logger.exception("Graph execution failed for thread %s", thread_id) + error_state = {**initial_state, "errors": [str(exc)]} + await _upsert_thread(db_pool, thread_id=thread_id, thread_status="failed", state=error_state) + + +async def _upsert_thread( + pool, + *, + thread_id: str, + thread_status: str, + state: dict, + graph_name: str = "", + repo_name: str = "", + pr_id: str = "", +) -> None: + """Insert or update a thread record in agent_threads.""" + import json + + sql = """ + INSERT INTO agent_threads (thread_id, graph_name, repo_name, pr_id, status, state, updated_at) + VALUES (%s, %s, %s, %s, %s, %s, NOW()) + ON CONFLICT (thread_id) DO UPDATE + SET status = EXCLUDED.status, + state = EXCLUDED.state, + updated_at = EXCLUDED.updated_at + """ + async with pool.connection() as conn: + async with conn.cursor() as cur: + await cur.execute(sql, ( + thread_id, graph_name, repo_name, pr_id, + thread_status, json.dumps(state), + )) diff --git a/src/release_agent/branch_parser.py b/src/release_agent/branch_parser.py new file mode 100644 index 0000000..4e06f9b --- /dev/null +++ b/src/release_agent/branch_parser.py @@ -0,0 +1,46 @@ +"""Branch name parsing utilities for extracting Jira ticket IDs. + +Pure functions only - no side effects, no mutation. +""" + +import re + +# Matches Jira-style ticket IDs: one or more uppercase letters/digits (starting +# with a letter) followed by a dash and one or more digits. +# Examples: ALLPOST-4229, BILL-42, MY-1, AB2-100 +_TICKET_PATTERN = re.compile(r"(? str: + """Remove the 'refs/heads/' prefix from a branch name if present. + + Returns the branch name unchanged if the prefix is not present. + Does not strip 'refs/tags/' or any other prefix. + """ + if branch.startswith(_REFS_HEADS_PREFIX): + return branch[len(_REFS_HEADS_PREFIX):] + return branch + + +def parse_branch(branch: str) -> tuple[str | None, bool]: + """Parse a branch name and extract a Jira ticket ID if present. + + Handles the 'refs/heads/' prefix automatically. + + Returns: + A tuple of (ticket_id, has_ticket) where ticket_id is the Jira ID + string (e.g. "ALLPOST-4229") or None, and has_ticket is a bool. + """ + if not branch: + return None, False + + normalized = strip_refs_prefix(branch) + + match = _TICKET_PATTERN.search(normalized) + if match: + return match.group(1), True + + return None, False diff --git a/src/release_agent/config.py b/src/release_agent/config.py new file mode 100644 index 0000000..356e9f6 --- /dev/null +++ b/src/release_agent/config.py @@ -0,0 +1,110 @@ +"""Application configuration using pydantic-settings. + +All secrets are stored as SecretStr to prevent accidental logging. +""" + +from pydantic import Field, SecretStr, computed_field +from pydantic_settings import BaseSettings, SettingsConfigDict + + +class Settings(BaseSettings): + """Application settings loaded from environment variables. + + Required variables: + AZDO_ORGANIZATION - Azure DevOps organization name + AZDO_PROJECT - Azure DevOps project name + AZDO_PAT - Azure DevOps personal access token + ANTHROPIC_API_KEY - Anthropic API key for Claude + POSTGRES_DSN - PostgreSQL connection string + JIRA_EMAIL - Jira account email address + JIRA_API_TOKEN - Jira API token + SLACK_WEBHOOK_URL - Slack incoming webhook URL + + Optional variables: + PORT - HTTP server port (default: 8000, range: 1-65535) + JIRA_BASE_URL - Jira base URL (default: https://billolife.atlassian.net) + CLAUDE_REVIEW_MODEL - Anthropic model for PR review (default: claude-sonnet-4-20250514) + WATCHED_REPOS - Comma-separated repo names to poll for PRs + PR_POLL_INTERVAL_SECONDS - Seconds between PR polling runs (default: 300) + PR_POLL_TARGET_BRANCH - Target branch ref to filter PRs (default: refs/heads/develop) + PR_POLL_ENABLED - Enable PR polling background task (default: False) + DEFAULT_JIRA_PROJECT - Default Jira project key for auto-created tickets (default: ALLPOST) + AUTO_CREATE_TICKET_ENABLED - Auto-create Jira ticket when PR has no ticket (default: True) + """ + + model_config = SettingsConfigDict( + env_file=".env", + env_file_encoding="utf-8", + case_sensitive=False, + extra="ignore", + ) + + azdo_organization: str + azdo_project: str + azdo_pat: SecretStr + anthropic_api_key: SecretStr = Field(default=SecretStr("")) # Optional: only needed if using API directly + postgres_dsn: SecretStr + port: int = Field(default=8000, ge=1, le=65535) + + # Jira settings + jira_base_url: str = "https://billolife.atlassian.net" + jira_email: str + jira_api_token: SecretStr + + # Slack settings + slack_webhook_url: SecretStr = Field(default=SecretStr("")) + slack_bot_token: SecretStr = Field(default=SecretStr("")) + slack_signing_secret: SecretStr = Field(default=SecretStr("")) + slack_channel_id: str = "" + + # CI polling settings + ci_poll_interval_seconds: int = 30 + ci_poll_max_wait_seconds: int = 1800 + + # Claude settings + claude_review_model: str = "claude-sonnet-4-20250514" + + # Local repo settings + repos_base_dir: str = "" # Base directory containing Billo repos (e.g., /c/Users/yaoji/git/Billo) + + # Webhook settings + webhook_secret: SecretStr + + # Operator API settings + operator_token: SecretStr = Field(default=SecretStr("")) + + # PR polling settings + watched_repos: str = "" + pr_poll_interval_seconds: int = 300 + pr_poll_target_branch: str = "refs/heads/develop" + pr_poll_enabled: bool = False + + # Auto-create Jira ticket settings + default_jira_project: str = "ALLPOST" + auto_create_ticket_enabled: bool = True + + @computed_field # type: ignore[misc] + @property + def watched_repos_list(self) -> list[str]: + """Return watched_repos as a list, splitting on commas and stripping whitespace.""" + if not self.watched_repos: + return [] + return [r.strip() for r in self.watched_repos.split(",") if r.strip()] + + @computed_field # type: ignore[misc] + @property + def azdo_base_url(self) -> str: + """Base URL for the Azure DevOps organization.""" + return f"https://dev.azure.com/{self.azdo_organization}" + + @computed_field # type: ignore[misc] + @property + def azdo_api_url(self) -> str: + """Base API URL for the Azure DevOps project.""" + return f"https://dev.azure.com/{self.azdo_organization}/{self.azdo_project}/_apis" + + @computed_field # type: ignore[misc] + @property + def azdo_vsrm_api_url(self) -> str: + """Base API URL for the Azure DevOps VSRM release management.""" + return f"https://vsrm.dev.azure.com/{self.azdo_organization}/{self.azdo_project}/_apis" diff --git a/src/release_agent/exceptions.py b/src/release_agent/exceptions.py new file mode 100644 index 0000000..a28dcb3 --- /dev/null +++ b/src/release_agent/exceptions.py @@ -0,0 +1,59 @@ +"""Custom exception hierarchy for the release agent. + +All exceptions inherit from ReleaseAgentError so callers can catch +the entire hierarchy with a single except clause when needed. +""" + + +class ReleaseAgentError(Exception): + """Base exception for all release agent errors.""" + + +class ServiceError(ReleaseAgentError): + """An HTTP error response from an external service. + + Attributes: + service: Name of the service that returned the error (e.g. "jira"). + status_code: HTTP status code from the response. + detail: Optional human-readable detail message. + """ + + def __init__(self, *, service: str, status_code: int, detail: str | None) -> None: + self.service = service + self.status_code = status_code + self.detail = detail + super().__init__(f"[{service}] HTTP {status_code}: {detail}") + + +class AuthenticationError(ServiceError): + """Authentication or authorisation failure (401 / 403).""" + + def __init__(self, *, service: str, status_code: int = 401, detail: str | None = None) -> None: + super().__init__(service=service, status_code=status_code, detail=detail) + + +class NotFoundError(ServiceError): + """Requested resource was not found (404).""" + + def __init__(self, *, service: str, detail: str | None = None) -> None: + super().__init__(service=service, status_code=404, detail=detail) + + +class RateLimitError(ServiceError): + """Rate limit exceeded (429). + + Attributes: + retry_after: Seconds to wait before retrying, or None if not provided. + """ + + def __init__(self, *, service: str, retry_after: int | None) -> None: + self.retry_after = retry_after + detail = f"retry after {retry_after}s" if retry_after is not None else None + super().__init__(service=service, status_code=429, detail=detail) + + +class ServiceUnavailableError(ServiceError): + """Service temporarily unavailable (503).""" + + def __init__(self, *, service: str, detail: str | None = None) -> None: + super().__init__(service=service, status_code=503, detail=detail) diff --git a/src/release_agent/graph/__init__.py b/src/release_agent/graph/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/release_agent/graph/ci_nodes.py b/src/release_agent/graph/ci_nodes.py new file mode 100644 index 0000000..cc81d12 --- /dev/null +++ b/src/release_agent/graph/ci_nodes.py @@ -0,0 +1,180 @@ +"""CI/CD graph nodes for triggering, polling, and notifying CI builds. + +Each node is an async function (state, config) -> dict. +Nodes never mutate state — they return new dicts. +External clients are accessed via config["configurable"]["clients"]. +""" + +import logging +from typing import Any + +from release_agent.exceptions import ReleaseAgentError +from release_agent.graph.dependencies import ToolClients +from release_agent.graph.polling import poll_until +from release_agent.models.build import BuildStatus +from release_agent.tools.slack import _build_ci_status_blocks + +logger = logging.getLogger(__name__) + +_DEFAULT_BRANCH = "refs/heads/main" +_CI_POLL_INTERVAL = 30 +_CI_POLL_MAX_WAIT = 1800 + + +def _get_clients(config: dict) -> ToolClients: + return config["configurable"]["clients"] + + +def _get_settings(config: dict): + return config["configurable"].get("settings") + + +# --------------------------------------------------------------------------- +# Node: trigger_ci_build +# --------------------------------------------------------------------------- + +async def trigger_ci_build(state: dict[str, Any], config: dict) -> dict: + """Trigger the CI build pipeline for the repository. + + Finds the first available pipeline for the repo and triggers it on + the appropriate branch (main for post-merge, release branch otherwise). + + Args: + state: Graph state dict. + config: LangGraph config dict with "configurable" key. + + Returns: + Dict with ci_build_id on success, or errors on failure. + """ + clients = _get_clients(config) + repo_name = state.get("repo_name", "") + version = state.get("version", "") + + # Determine branch: main if version present (release), develop otherwise (PR merge) + if version: + branch = "refs/heads/main" + else: + branch = "refs/heads/develop" + + try: + pipelines = await clients.azdo.list_build_pipelines(repo=repo_name) + except ReleaseAgentError as exc: + return {"errors": [f"trigger_ci_build: failed to list pipelines: {exc}"]} + + if not pipelines: + return {"errors": [f"trigger_ci_build: no pipelines found for repo '{repo_name}'"]} + + pipeline = pipelines[0] + try: + result = await clients.azdo.trigger_pipeline( + pipeline_id=pipeline.id, + branch=branch, + ) + raw_id = result.get("id") + if not isinstance(raw_id, int): + return {"errors": [f"trigger_ci_build: unexpected build id type: {raw_id!r}"]} + return { + "ci_build_id": raw_id, + "messages": [f"CI build triggered: pipeline {pipeline.id}, build {raw_id}"], + } + except ReleaseAgentError as exc: + return {"errors": [f"trigger_ci_build: failed to trigger pipeline {pipeline.id}: {exc}"]} + + +# --------------------------------------------------------------------------- +# Node: poll_ci_build +# --------------------------------------------------------------------------- + +async def poll_ci_build(state: dict[str, Any], config: dict) -> dict: + """Poll the CI build until completion or timeout. + + Uses the polling utility with configurable interval and max wait. + On timeout, appends an error and returns the last known status. + + Args: + state: Graph state dict containing ci_build_id. + config: LangGraph config dict. + + Returns: + Dict with ci_build_status, ci_build_result, and ci_build_url on + completion, or errors on timeout/failure. + """ + clients = _get_clients(config) + build_id = state.get("ci_build_id") + + if not build_id: + return {"errors": ["poll_ci_build: ci_build_id not set in state"]} + + settings = _get_settings(config) + interval = getattr(settings, "ci_poll_interval_seconds", _CI_POLL_INTERVAL) if settings else _CI_POLL_INTERVAL + max_wait = getattr(settings, "ci_poll_max_wait_seconds", _CI_POLL_MAX_WAIT) if settings else _CI_POLL_MAX_WAIT + + async def poll_fn() -> BuildStatus: + return await clients.azdo.get_build_status(build_id=build_id) + + def is_done(bs: BuildStatus) -> bool: + return bs is not None and bs.status == "completed" + + last_status, completed = await poll_until( + poll_fn=poll_fn, + is_done=is_done, + interval_seconds=interval, + max_wait_seconds=max_wait, + ) + + if last_status is None: + return {"errors": [f"poll_ci_build: build {build_id} polling failed (no status returned)"]} + + result: dict = { + "ci_build_status": last_status.status, + "ci_build_result": last_status.result, + "ci_build_url": last_status.build_url, + } + + if not completed: + result["errors"] = [ + f"poll_ci_build: build {build_id} did not complete within timeout " + f"(last status: {last_status.status})" + ] + + return result + + +# --------------------------------------------------------------------------- +# Node: notify_ci_result +# --------------------------------------------------------------------------- + +async def notify_ci_result(state: dict[str, Any], config: dict) -> dict: + """Send a Slack notification with the CI build result. + + Non-critical: errors are appended rather than re-raised. + + Args: + state: Graph state dict. + config: LangGraph config dict. + + Returns: + Dict with messages on success, or errors on failure. + """ + clients = _get_clients(config) + repo_name = state.get("repo_name", "unknown") + build_result = state.get("ci_build_result", "unknown") + build_url = state.get("ci_build_url") + branch = state.get("version", "main") + + blocks = _build_ci_status_blocks( + repo=repo_name, + branch=branch, + status=build_result, + build_url=build_url, + ) + status_label = "passed" if build_result == "succeeded" else "failed" + text = f"CI build {status_label} for {repo_name}" + + try: + await clients.slack.send_notification(text=text, blocks=blocks) + return {"messages": [text]} + except ReleaseAgentError as exc: + return {"errors": [f"notify_ci_result: {exc}"]} + except Exception as exc: + return {"errors": [f"notify_ci_result: unexpected error: {exc}"]} diff --git a/src/release_agent/graph/dependencies.py b/src/release_agent/graph/dependencies.py new file mode 100644 index 0000000..20c373f --- /dev/null +++ b/src/release_agent/graph/dependencies.py @@ -0,0 +1,171 @@ +"""Graph dependency types: ToolClients dataclass and StagingStore protocol. + +ToolClients is a frozen dataclass injected via config["configurable"]["clients"]. +StagingStore is a Protocol with file-based implementation JsonFileStagingStore. +All StagingStore methods are async to support both file-based and DB backends. +""" + +import json +from dataclasses import dataclass +from datetime import date +from pathlib import Path +from typing import Any, Protocol + +from release_agent.models.release import ArchivedRelease, StagingRelease + + +# --------------------------------------------------------------------------- +# ToolClients +# --------------------------------------------------------------------------- + +@dataclass(frozen=True) +class ToolClients: + """Frozen container for all external service clients. + + Injected into graph nodes via config["configurable"]["clients"]. + + Attributes: + azdo: AzDoClient instance. + jira: JiraClient instance. + slack: SlackClient instance. + reviewer: ClaudeReviewer instance. + """ + + azdo: Any + jira: Any + slack: Any + reviewer: Any + + +# --------------------------------------------------------------------------- +# StagingStore Protocol +# --------------------------------------------------------------------------- + +class StagingStore(Protocol): + """Protocol for persistent staging release storage. + + Implementations handle reading and writing StagingRelease and + ArchivedRelease objects. All methods are async to support both + file-based and PostgreSQL backends. + """ + + async def load(self, repo: str) -> StagingRelease | None: + """Load the current staging release for the given repository. + + Returns None if no staging release exists. + """ + ... # pragma: no cover + + async def save(self, release: StagingRelease) -> None: + """Persist a staging release, overwriting any existing entry.""" + ... # pragma: no cover + + async def archive(self, release: StagingRelease, released_at: date) -> None: + """Archive the staging release with the given release date. + + Creates an archive entry and removes the staging entry. + """ + ... # pragma: no cover + + async def list_versions(self, repo: str) -> list[str]: + """Return all version strings known for the given repository. + + Includes both current staging version (if any) and all archived versions. + """ + ... # pragma: no cover + + +# --------------------------------------------------------------------------- +# JsonFileStagingStore +# --------------------------------------------------------------------------- + +class JsonFileStagingStore: + """File-based StagingStore implementation using JSON files. + + Files are stored in a single directory: + - Staging: /.json + - Archive: /__.json + + Args: + directory: Path to the directory for storing JSON files. + Created automatically if it does not exist. + """ + + def __init__(self, *, directory: Path) -> None: + self._dir = directory + self._dir.mkdir(parents=True, exist_ok=True) + + # ------------------------------------------------------------------ + # Public interface (async) + # ------------------------------------------------------------------ + + async def load(self, repo: str) -> StagingRelease | None: + """Load the current staging release for the given repository.""" + path = self._staging_path(repo) + if not path.exists(): + return None + data = json.loads(path.read_text(encoding="utf-8")) + return StagingRelease.model_validate(data) + + async def save(self, release: StagingRelease) -> None: + """Persist a staging release to disk, overwriting any existing file.""" + path = self._staging_path(release.repo) + path.write_text( + release.model_dump_json(), + encoding="utf-8", + ) + + async def archive(self, release: StagingRelease, released_at: date) -> None: + """Archive the staging release and remove the staging file.""" + archived = ArchivedRelease( + version=release.version, + repo=release.repo, + started_at=release.started_at, + tickets=list(release.tickets), + released_at=released_at, + ) + archive_path = self._archive_path(release.repo, release.version, released_at) + archive_path.write_text( + archived.model_dump_json(), + encoding="utf-8", + ) + # Remove staging file if it exists + staging_path = self._staging_path(release.repo) + if staging_path.exists(): + staging_path.unlink() + + async def list_versions(self, repo: str) -> list[str]: + """Return all version strings known for the given repository.""" + versions: list[str] = [] + + # Check current staging file + staging = await self.load(repo) + if staging is not None: + versions.append(staging.version) + + # Check archived files + prefix = f"{repo}_" + for path in self._dir.iterdir(): + if path.name.startswith(prefix) and path.name.endswith(".json"): + try: + data = json.loads(path.read_text(encoding="utf-8")) + v = data.get("version") + if v and v not in versions: + versions.append(v) + except (json.JSONDecodeError, KeyError): + continue + + return versions + + # ------------------------------------------------------------------ + # Private helpers + # ------------------------------------------------------------------ + + def _staging_path(self, repo: str) -> Path: + """Return the path to the staging file for the given repo.""" + return self._dir / f"{repo}.json" + + def _archive_path(self, repo: str, version: str, released_at: date) -> Path: + """Return the path to an archive file.""" + date_str = released_at.isoformat() + return self._dir / f"{repo}_{version}_{date_str}.json" diff --git a/src/release_agent/graph/full_cycle.py b/src/release_agent/graph/full_cycle.py new file mode 100644 index 0000000..ea6eb92 --- /dev/null +++ b/src/release_agent/graph/full_cycle.py @@ -0,0 +1,44 @@ +"""Full Cycle graph: composes pr_completed and release as subgraph nodes. + +The full cycle graph handles the complete lifecycle from a PR merge webhook +through to a released version. After the PR completed subgraph, a conditional +edge routes to the release subgraph if continue_to_release is True. +""" + +from langgraph.graph import END, START, StateGraph + +from release_agent.graph.pr_completed import build_pr_completed_graph +from release_agent.graph.release import build_release_graph +from release_agent.graph.routing import should_continue_to_release +from release_agent.state import ReleaseState + + +def build_full_cycle_graph(): + """Assemble and compile the Full Cycle StateGraph. + + The graph consists of: + - pr_completed subgraph: handles webhook parsing, code review, and PR merge + - A conditional edge routing to the release subgraph or END + - release subgraph: handles staging confirmation, release PR, and pipeline triggers + + Returns the compiled graph ready for invocation. + """ + pr_completed = build_pr_completed_graph() + release = build_release_graph() + + graph = StateGraph(ReleaseState) + + graph.add_node("pr_completed", pr_completed) + graph.add_node("release", release) + + graph.add_edge(START, "pr_completed") + + graph.add_conditional_edges( + "pr_completed", + should_continue_to_release, + {"yes": "release", "no": END}, + ) + + graph.add_edge("release", END) + + return graph.compile() diff --git a/src/release_agent/graph/polling.py b/src/release_agent/graph/polling.py new file mode 100644 index 0000000..8d71c8e --- /dev/null +++ b/src/release_agent/graph/polling.py @@ -0,0 +1,91 @@ +"""Reusable async polling utility for CI/CD status checks. + +poll_until runs a poll_fn repeatedly until is_done returns True, +a timeout is reached, or too many consecutive failures occur. +""" + +import asyncio +import logging +from typing import Any, Callable, TypeVar + +logger = logging.getLogger(__name__) + +T = TypeVar("T") + +_MAX_CONSECUTIVE_FAILURES = 3 + + +async def poll_until( + *, + poll_fn: Callable[[], Any], + is_done: Callable[[Any], bool], + interval_seconds: float = 30, + max_wait_seconds: float = 1800, + sleep_fn: Callable[[float], Any] | None = None, +) -> tuple[Any, bool]: + """Poll poll_fn until is_done returns True, timeout, or too many failures. + + Args: + poll_fn: Async callable that fetches the latest status. + is_done: Pure function that returns True when polling should stop. + interval_seconds: Seconds to wait between polls (default: 30). + max_wait_seconds: Maximum total seconds before giving up (default: 1800). + sleep_fn: Async sleep function to inject (default: asyncio.sleep). + Inject a no-op in tests to avoid real waits. + + Returns: + A tuple (last_result, completed_within_timeout) where: + - last_result is the last value returned by poll_fn (or None if every + call raised an exception). + - completed_within_timeout is True if is_done returned True before the + timeout and before consecutive failure abort. + """ + if sleep_fn is None: + sleep_fn = asyncio.sleep + + elapsed: float = 0.0 + consecutive_failures = 0 + last_result: Any = None + + while True: + try: + last_result = await poll_fn() + consecutive_failures = 0 + except Exception as exc: + consecutive_failures += 1 + logger.warning( + "poll_until: poll_fn raised exception (consecutive=%d): %s", + consecutive_failures, + exc, + ) + if consecutive_failures >= _MAX_CONSECUTIVE_FAILURES: + logger.error( + "poll_until: aborting after %d consecutive failures", + _MAX_CONSECUTIVE_FAILURES, + ) + return None, False + # Sleep before retry + await sleep_fn(interval_seconds) + elapsed += interval_seconds + if elapsed >= max_wait_seconds: + return last_result, False + continue + + if is_done(last_result): + return last_result, True + + # Check if we've already used the full budget + if elapsed >= max_wait_seconds: + return last_result, False + + # Sleep and advance elapsed time + remaining = max_wait_seconds - elapsed + sleep_time = min(interval_seconds, remaining) + if sleep_time <= 0: + return last_result, False + + await sleep_fn(sleep_time) + elapsed += sleep_time + + if elapsed >= max_wait_seconds: + return last_result, False diff --git a/src/release_agent/graph/postgres_staging_store.py b/src/release_agent/graph/postgres_staging_store.py new file mode 100644 index 0000000..7275bb4 --- /dev/null +++ b/src/release_agent/graph/postgres_staging_store.py @@ -0,0 +1,148 @@ +"""PostgreSQL-backed StagingStore implementation. + +Uses psycopg async connection pool. Tables: + - staging_releases: current in-progress releases per repo + - archived_releases: completed releases (immutable history) +""" + +import json +from datetime import date + +from release_agent.models.release import ArchivedRelease, StagingRelease +from release_agent.models.ticket import TicketEntry + +_INSERT_STAGING_SQL = """ + INSERT INTO staging_releases (repo, version, started_at, tickets) + VALUES (%s, %s, %s, %s) + ON CONFLICT (repo) DO UPDATE + SET version = EXCLUDED.version, + started_at = EXCLUDED.started_at, + tickets = EXCLUDED.tickets, + updated_at = NOW() +""" + +_SELECT_STAGING_SQL = """ + SELECT repo, version, started_at, tickets + FROM staging_releases + WHERE repo = %s +""" + +_INSERT_ARCHIVED_SQL = """ + INSERT INTO archived_releases (repo, version, started_at, tickets, released_at) + VALUES (%s, %s, %s, %s, %s) + ON CONFLICT (repo, version) DO NOTHING +""" + +_DELETE_STAGING_SQL = """ + DELETE FROM staging_releases WHERE repo = %s +""" + +_SELECT_ARCHIVED_SQL = """ + SELECT repo, version, started_at, tickets, released_at + FROM archived_releases + WHERE repo = %s +""" + + +class PostgresStagingStore: + """Async PostgreSQL-backed StagingStore. + + Args: + pool: A psycopg_pool.AsyncConnectionPool instance (or compatible fake). + """ + + def __init__(self, *, pool) -> None: + self._pool = pool + + # ------------------------------------------------------------------ + # Public async interface + # ------------------------------------------------------------------ + + async def load(self, repo: str) -> StagingRelease | None: + """Load the current staging release for the given repository.""" + async with self._pool.connection() as conn: + async with conn.cursor() as cur: + await cur.execute(_SELECT_STAGING_SQL, (repo,)) + row = await cur.fetchone() + + if row is None: + return None + + return self._row_to_staging(row) + + async def save(self, release: StagingRelease) -> None: + """Upsert the staging release into the database.""" + tickets_json = json.dumps( + [t.model_dump(mode="json") for t in release.tickets] + ) + async with self._pool.connection() as conn: + async with conn.cursor() as cur: + await cur.execute( + _INSERT_STAGING_SQL, + ( + release.repo, + release.version, + release.started_at.isoformat(), + tickets_json, + ), + ) + + async def archive(self, release: StagingRelease, released_at: date) -> None: + """Archive the staging release and delete it from staging_releases.""" + tickets_json = json.dumps( + [t.model_dump(mode="json") for t in release.tickets] + ) + async with self._pool.connection() as conn: + async with conn.transaction(): + async with conn.cursor() as cur: + await cur.execute( + _INSERT_ARCHIVED_SQL, + ( + release.repo, + release.version, + release.started_at.isoformat(), + tickets_json, + released_at.isoformat(), + ), + ) + await cur.execute(_DELETE_STAGING_SQL, (release.repo,)) + + async def list_versions(self, repo: str) -> list[str]: + """Return all version strings for the given repository.""" + versions: list[str] = [] + + # Check current staging + staging = await self.load(repo) + if staging is not None and staging.version not in versions: + versions.append(staging.version) + + # Check archived + async with self._pool.connection() as conn: + async with conn.cursor() as cur: + await cur.execute(_SELECT_ARCHIVED_SQL, (repo,)) + rows = await cur.fetchall() + + for row in rows: + version = row[1] # version is column index 1 + if version and version not in versions: + versions.append(version) + + return versions + + # ------------------------------------------------------------------ + # Private helpers + # ------------------------------------------------------------------ + + def _row_to_staging(self, row: tuple) -> StagingRelease: + """Convert a DB row (repo, version, started_at, tickets) to StagingRelease.""" + repo, version, started_at_str, tickets_raw = row[0], row[1], row[2], row[3] + + tickets_data = json.loads(tickets_raw) if isinstance(tickets_raw, str) else tickets_raw + tickets = [TicketEntry.model_validate(t) for t in (tickets_data or [])] + + return StagingRelease( + repo=repo, + version=version, + started_at=date.fromisoformat(str(started_at_str)), + tickets=tickets, + ) diff --git a/src/release_agent/graph/pr_completed.py b/src/release_agent/graph/pr_completed.py new file mode 100644 index 0000000..af2e88a --- /dev/null +++ b/src/release_agent/graph/pr_completed.py @@ -0,0 +1,573 @@ +"""Node functions for the PR Completed subgraph. + +Each node is an async function (state, config) -> dict. +Nodes never mutate state — they return new dicts with updated fields. +External clients are accessed via config["configurable"]["clients"]. +""" + +from datetime import date +from typing import Any + +from langgraph.graph import END, START, StateGraph +from langgraph.types import interrupt + +from release_agent.branch_parser import parse_branch +from release_agent.exceptions import ReleaseAgentError +from release_agent.graph.ci_nodes import notify_ci_result, poll_ci_build, trigger_ci_build +from release_agent.graph.dependencies import JsonFileStagingStore, ToolClients +from release_agent.graph.routing import has_ticket, is_pr_already_merged, is_review_approved, route_after_fetch +from release_agent.models.release import StagingRelease +from release_agent.models.review import ReviewResult +from release_agent.models.ticket import TicketEntry +from release_agent.models.webhook import WebhookPayload +from release_agent.state import ReleaseState +from release_agent.versioning import calculate_next_version + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _get_clients(config: dict) -> ToolClients: + return config["configurable"]["clients"] + + +def _get_staging_store(config: dict): + return config["configurable"].get("staging_store") + + +# --------------------------------------------------------------------------- +# Node: parse_webhook +# --------------------------------------------------------------------------- + +async def parse_webhook(state: dict[str, Any], config: dict) -> dict: + """Parse the webhook_payload field and extract PR info and ticket ID. + + Returns a dict with pr_info, ticket_id, has_ticket, repo_name, pr_id. + On validation error, appends to errors and returns partial state. + """ + payload_raw = state.get("webhook_payload", {}) + try: + payload = WebhookPayload.model_validate(payload_raw) + resource = payload.resource + pr_id = str(resource.pull_request_id) + repo_name = resource.repository.name + branch = resource.source_ref_name + ticket_id, ticket_found = parse_branch(branch) + + pr_info = { + "pr_id": pr_id, + "repo_name": repo_name, + "branch": branch, + "pr_title": resource.title, + "pr_status": resource.status, + "pr_url": str(resource.repository.web_url) + f"/pullrequest/{pr_id}", + "ticket_id": ticket_id, + "has_ticket": ticket_found, + } + return { + "pr_info": pr_info, + "pr_id": pr_id, + "repo_name": repo_name, + "ticket_id": ticket_id, + "has_ticket": ticket_found, + } + except Exception as exc: + return {"errors": [f"parse_webhook failed: {exc}"]} + + +# --------------------------------------------------------------------------- +# Node: fetch_pr_details +# --------------------------------------------------------------------------- + +async def fetch_pr_details(state: dict[str, Any], config: dict) -> dict: + """Fetch full PR details from AzDo and check if already merged. + + Sets pr_already_merged, pr_diff, and last_merge_source_commit. + On ReleaseAgentError, appends to errors (non-critical). + """ + clients = _get_clients(config) + pr_id_raw = state.get("pr_id") or (state.get("pr_info") or {}).get("pr_id", "") + try: + pr_id = int(pr_id_raw) + pr = await clients.azdo.get_pr(pr_id) + already_merged = pr.pr_status == "completed" + diff = "" + if not already_merged: + diff = await clients.azdo.get_pr_diff(pr_id) + + # last_merge_source_commit is not directly on PRInfo; pass None if unavailable + return { + "pr_already_merged": already_merged, + "pr_diff": diff, + "last_merge_source_commit": None, + } + except ReleaseAgentError as exc: + return {"errors": [f"fetch_pr_details failed: {exc}"]} + + +# --------------------------------------------------------------------------- +# Node: auto_create_ticket +# --------------------------------------------------------------------------- + +async def auto_create_ticket(state: dict[str, Any], config: dict) -> dict: + """Auto-create a Jira ticket for a PR that has no existing ticket. + + Uses ClaudeReviewer to generate ticket content from the PR diff, + then creates a new Jira Story in the configured default project. + + Sets ticket_id and has_ticket=True on success. + On error, appends to errors (non-critical). + """ + clients = _get_clients(config) + pr_info = state.get("pr_info") or {} + pr_title = pr_info.get("pr_title", "") + repo_name = pr_info.get("repo_name", "") + pr_diff = state.get("pr_diff", "") + default_jira_project = config.get("configurable", {}).get("default_jira_project", "ALLPOST") + + repos_base_dir = config.get("configurable", {}).get("repos_base_dir", "") + cwd = None + if repos_base_dir and repo_name: + from pathlib import Path as _Path + repo_path = _Path(repos_base_dir) / repo_name + if repo_path.is_dir(): + cwd = str(repo_path) + + try: + summary, description = await clients.reviewer.generate_ticket_content( + diff=pr_diff, + pr_title=pr_title, + repo_name=repo_name, + cwd=cwd, + ) + except Exception as exc: + return {"errors": [f"auto_create_ticket generate_ticket_content failed: {exc}"]} + + try: + ticket_id = await clients.jira.create_issue( + project=default_jira_project, + summary=summary, + description=description, + ) + return { + "ticket_id": ticket_id, + "has_ticket": True, + "messages": [f"Auto-created Jira ticket {ticket_id} for PR in {repo_name}"], + } + except Exception as exc: + return {"errors": [f"auto_create_ticket create_issue failed: {exc}"]} + + +# --------------------------------------------------------------------------- +# Node: move_jira_code_review +# --------------------------------------------------------------------------- + +async def move_jira_code_review(state: dict[str, Any], config: dict) -> dict: + """Transition the Jira ticket to Code Review status. + + Skipped if has_ticket is False. Non-critical: errors are appended. + """ + if not state.get("has_ticket"): + return {} + clients = _get_clients(config) + ticket_id = state.get("ticket_id", "") + try: + await clients.jira.transition_issue(ticket_id, "code review") + return {"messages": [f"Jira {ticket_id} moved to Code Review"]} + except ReleaseAgentError as exc: + return {"errors": [f"move_jira_code_review failed: {exc}"]} + + +# --------------------------------------------------------------------------- +# Helper: post review comments to Azure DevOps PR +# --------------------------------------------------------------------------- + +async def _post_review_to_pr( + clients: ToolClients, + repo_name: str, + pr_id: int, + review: ReviewResult, +) -> None: + """Post review results as comments on the Azure DevOps PR. + + Posts inline comments for issues with file_path + line_start, + and a summary comment for the overall review. + Non-critical: errors are logged but not raised. + """ + import logging + logger = logging.getLogger(__name__) + + # Post inline comments for issues with file path and line numbers + for issue in review.issues: + if issue.file_path and issue.line_start: + severity_icon = {"blocker": "BLOCKER", "error": "ERROR", "warning": "WARNING", "info": "INFO"} + label = severity_icon.get(issue.severity, issue.severity.upper()) + content = f"**[{label}]** {issue.description}" + if issue.suggestion: + content += f"\n\n**Suggestion:** {issue.suggestion}" + try: + await clients.azdo.add_pr_inline_comment( + repo=repo_name, + pr_id=pr_id, + content=content, + file_path=issue.file_path, + line_start=issue.line_start, + line_end=issue.line_end, + ) + except Exception as exc: + logger.warning("Failed to post inline comment on %s:%d: %s", issue.file_path, issue.line_start, exc) + + # Post summary comment + verdict_text = "APPROVE" if review.verdict == "approve" else "REQUEST CHANGES" + issue_count = len(review.issues) + summary = f"## Claude Code Review: {verdict_text}\n\n{review.summary}\n\n" + if issue_count > 0: + summary += f"**{issue_count} issue(s) found:**\n" + for issue in review.issues: + loc = f" ({issue.file_path}:{issue.line_start})" if issue.file_path and issue.line_start else "" + summary += f"- **{issue.severity.upper()}**{loc}: {issue.description}\n" + else: + summary += "No issues found." + + try: + await clients.azdo.add_pr_comment(repo=repo_name, pr_id=pr_id, content=summary) + except Exception as exc: + logger.warning("Failed to post review summary comment: %s", exc) + + +# --------------------------------------------------------------------------- +# Node: run_code_review +# --------------------------------------------------------------------------- + +async def run_code_review(state: dict[str, Any], config: dict) -> dict: + """Run Claude code review on the PR diff. + + Returns review_result as a serialisable dict. + On error, appends to errors. + """ + clients = _get_clients(config) + diff = state.get("pr_diff", "") + pr_info = state.get("pr_info") or {} + pr_title = pr_info.get("pr_title", "") + repo_name = pr_info.get("repo_name", "") + + # Build cwd from repos_base_dir + repo_name so Claude Code can read the codebase + repos_base_dir = config.get("configurable", {}).get("repos_base_dir", "") + cwd = None + if repos_base_dir and repo_name: + from pathlib import Path + repo_path = Path(repos_base_dir) / repo_name + if repo_path.is_dir(): + cwd = str(repo_path) + + try: + review: ReviewResult = await clients.reviewer.review_pr( + diff=diff, pr_title=pr_title, repo_name=repo_name, cwd=cwd + ) + # Post review comments to Azure DevOps PR + pr_id = int(pr_info.get("pr_id", 0)) + if pr_id and repo_name: + await _post_review_to_pr(clients, repo_name, pr_id, review) + return {"review_result": review.model_dump(mode="json")} + except Exception as exc: + return {"errors": [f"run_code_review failed: {exc}"]} + + +# --------------------------------------------------------------------------- +# Node: evaluate_review +# --------------------------------------------------------------------------- + +async def evaluate_review(state: dict[str, Any], config: dict) -> dict: + """Evaluate the review result and set review_approved. + + Approved only when verdict=="approve" and no blockers present. + """ + review = state.get("review_result") or {} + verdict = review.get("verdict", "request_changes") + has_blockers = review.get("has_blockers", False) + # Also check issues list for blocker severity + if not has_blockers: + issues = review.get("issues") or [] + has_blockers = any(i.get("severity") == "blocker" for i in issues) + approved = (verdict == "approve") and not has_blockers + return {"review_approved": approved} + + +# --------------------------------------------------------------------------- +# Node: interrupt_confirm_merge +# --------------------------------------------------------------------------- + +async def interrupt_confirm_merge(state: dict[str, Any], config: dict) -> dict: + """Interrupt to ask the operator to confirm the PR merge. + + Passes a human-readable summary string to interrupt(). + """ + pr_info = state.get("pr_info") or {} + review = state.get("review_result") or {} + summary = ( + f"PR #{pr_info.get('pr_id')} - {pr_info.get('pr_title')} " + f"[repo: {pr_info.get('repo_name')}]\n" + f"Review: {review.get('summary', 'No review summary')}\n" + f"Confirm merge? (yes/no)" + ) + interrupt(summary) + return {} + + +# --------------------------------------------------------------------------- +# Node: merge_pr_node +# --------------------------------------------------------------------------- + +async def merge_pr_node(state: dict[str, Any], config: dict) -> dict: + """Merge the PR via AzDo API. + + Critical node — re-raises ReleaseAgentError. + """ + clients = _get_clients(config) + pr_info = state.get("pr_info") or {} + pr_id = int(pr_info.get("pr_id", state.get("pr_id", 0))) + commit = state.get("last_merge_source_commit", "") + await clients.azdo.merge_pr(pr_id=pr_id, last_merge_source_commit=commit) + return {"messages": [f"PR #{pr_id} merged"]} + + +# --------------------------------------------------------------------------- +# Node: move_jira_ready_for_stage +# --------------------------------------------------------------------------- + +async def move_jira_ready_for_stage(state: dict[str, Any], config: dict) -> dict: + """Transition the Jira ticket to Ready for stage (2). + + Skipped if has_ticket is False. Non-critical: errors appended. + """ + if not state.get("has_ticket"): + return {} + clients = _get_clients(config) + ticket_id = state.get("ticket_id", "") + try: + await clients.jira.transition_issue(ticket_id, "Ready for stage (2)") + return {"messages": [f"Jira {ticket_id} moved to Ready for stage (2)"]} + except ReleaseAgentError as exc: + return {"errors": [f"move_jira_ready_for_stage failed: {exc}"]} + + +# --------------------------------------------------------------------------- +# Node: add_jira_pr_link +# --------------------------------------------------------------------------- + +async def add_jira_pr_link(state: dict[str, Any], config: dict) -> dict: + """Add the PR URL as a remote link on the Jira ticket. + + Skipped if has_ticket is False. Non-critical: errors appended. + """ + if not state.get("has_ticket"): + return {} + clients = _get_clients(config) + ticket_id = state.get("ticket_id", "") + pr_info = state.get("pr_info") or {} + pr_url = pr_info.get("pr_url", "") + pr_title = pr_info.get("pr_title", "PR") + pr_id = pr_info.get("pr_id", "") + try: + await clients.jira.add_remote_link( + ticket_id=ticket_id, + url=str(pr_url), + title=f"PR #{pr_id}: {pr_title}", + ) + return {"messages": [f"PR link added to Jira {ticket_id}"]} + except ReleaseAgentError as exc: + return {"errors": [f"add_jira_pr_link failed: {exc}"]} + + +# --------------------------------------------------------------------------- +# Node: calculate_version +# --------------------------------------------------------------------------- + +async def calculate_version(state: dict[str, Any], config: dict) -> dict: + """Calculate the next version for the repository using the staging store. + + Uses calculate_next_version from versioning module. + """ + repo_name = state.get("repo_name", "") + staging_store = _get_staging_store(config) + existing_versions: list[str] = [] + if staging_store is not None: + existing_versions = await staging_store.list_versions(repo_name) + version = calculate_next_version(repo_name, existing_versions) + return {"version": version} + + +# --------------------------------------------------------------------------- +# Node: update_staging +# --------------------------------------------------------------------------- + +async def update_staging(state: dict[str, Any], config: dict) -> dict: + """Add the PR's ticket to the staging release. + + If no staging exists, creates a new one. If has_ticket is False, skips + the ticket addition but still ensures a staging record exists. + """ + if not state.get("has_ticket"): + return {} + + clients = _get_clients(config) + staging_store = _get_staging_store(config) + repo_name = state.get("repo_name", "") + version = state.get("version", "v1.0.0") + ticket_id = state.get("ticket_id", "") + pr_info = state.get("pr_info") or {} + + # Fetch ticket summary from Jira + try: + issue = await clients.jira.get_issue(ticket_id) + summary = issue.summary + except Exception: + summary = ticket_id + + pr_url = pr_info.get("pr_url", "") + pr_id = pr_info.get("pr_id", "") + pr_title = pr_info.get("pr_title", "") + branch = pr_info.get("branch", "") + + ticket_entry = TicketEntry( + id=ticket_id, + summary=summary, + pr_id=str(pr_id), + pr_url=pr_url if pr_url.startswith("http") else f"https://example.com/{pr_id}", + pr_title=pr_title, + branch=branch, + merged_at=date.today(), + ) + + if staging_store is not None: + existing = await staging_store.load(repo_name) + if existing is None: + staging = StagingRelease( + version=version, + repo=repo_name, + started_at=date.today(), + tickets=[ticket_entry], + ) + else: + staging = existing.add_ticket(ticket_entry) + await staging_store.save(staging) + return {"staging": staging.model_dump(mode="json")} + + return {} + + +# --------------------------------------------------------------------------- +# Node: notify_request_changes +# --------------------------------------------------------------------------- + +async def notify_request_changes(state: dict[str, Any], config: dict) -> dict: + """Send a Slack notification when the review requests changes. + + Non-critical: errors appended. + """ + clients = _get_clients(config) + pr_info = state.get("pr_info") or {} + review = state.get("review_result") or {} + pr_id = pr_info.get("pr_id", "?") + pr_title = pr_info.get("pr_title", "") + repo = pr_info.get("repo_name", "") + summary = review.get("summary", "Review requested changes") + issues = review.get("issues") or [] + issues_text = "\n".join( + f"- [{i.get('severity', 'info')}] {i.get('description', '')}" for i in issues + ) + details = f"PR #{pr_id}: {pr_title} [{repo}]\n{summary}\n{issues_text}" + try: + await clients.slack.send_approval_request( + action="Request Changes", + details=details, + approval_url="", + ) + return {"messages": [f"Slack notified: request changes for PR #{pr_id}"]} + except ReleaseAgentError as exc: + return {"errors": [f"notify_request_changes failed: {exc}"]} + + +# --------------------------------------------------------------------------- +# Graph builder +# --------------------------------------------------------------------------- + +def build_pr_completed_graph(): + """Assemble and compile the PR Completed StateGraph. + + Returns the compiled graph ready for invocation. + + Routing after fetch_pr_details: + - "merged" -> calculate_version (PR already merged) + - "active_with_ticket" -> move_jira_code_review (PR active, has Jira ticket) + - "active_no_ticket" -> auto_create_ticket -> move_jira_code_review (no ticket) + """ + graph = StateGraph(ReleaseState) + + # Add all nodes + graph.add_node("parse_webhook", parse_webhook) + graph.add_node("fetch_pr_details", fetch_pr_details) + graph.add_node("auto_create_ticket", auto_create_ticket) + graph.add_node("move_jira_code_review", move_jira_code_review) + graph.add_node("run_code_review", run_code_review) + graph.add_node("evaluate_review", evaluate_review) + graph.add_node("interrupt_confirm_merge", interrupt_confirm_merge) + graph.add_node("merge_pr_node", merge_pr_node) + graph.add_node("move_jira_ready_for_stage", move_jira_ready_for_stage) + graph.add_node("add_jira_pr_link", add_jira_pr_link) + graph.add_node("calculate_version", calculate_version) + graph.add_node("update_staging", update_staging) + graph.add_node("notify_request_changes", notify_request_changes) + graph.add_node("trigger_ci_build", trigger_ci_build) + graph.add_node("poll_ci_build", poll_ci_build) + graph.add_node("notify_ci_result", notify_ci_result) + + # Entry + graph.add_edge(START, "parse_webhook") + graph.add_edge("parse_webhook", "fetch_pr_details") + + # Three-way route: merged, active_with_ticket, active_no_ticket + graph.add_conditional_edges( + "fetch_pr_details", + route_after_fetch, + { + "merged": "calculate_version", + "active_with_ticket": "move_jira_code_review", + "active_no_ticket": "auto_create_ticket", + }, + ) + + # auto_create_ticket -> move_jira_code_review (ticket now exists after creation) + graph.add_edge("auto_create_ticket", "move_jira_code_review") + + # Active PR path: code review + graph.add_edge("move_jira_code_review", "run_code_review") + graph.add_edge("run_code_review", "evaluate_review") + + # Route based on review outcome + graph.add_conditional_edges( + "evaluate_review", + is_review_approved, + { + "approve": "interrupt_confirm_merge", + "request_changes": "notify_request_changes", + }, + ) + + graph.add_edge("interrupt_confirm_merge", "merge_pr_node") + graph.add_edge("merge_pr_node", "move_jira_ready_for_stage") + graph.add_edge("move_jira_ready_for_stage", "add_jira_pr_link") + graph.add_edge("add_jira_pr_link", "calculate_version") + + # calculate_version -> update_staging -> trigger_ci_build -> poll_ci_build -> notify_ci_result -> END + graph.add_edge("calculate_version", "update_staging") + graph.add_edge("update_staging", "trigger_ci_build") + graph.add_edge("trigger_ci_build", "poll_ci_build") + graph.add_edge("poll_ci_build", "notify_ci_result") + graph.add_edge("notify_ci_result", END) + + # notify_request_changes -> END + graph.add_edge("notify_request_changes", END) + + return graph.compile() diff --git a/src/release_agent/graph/release.py b/src/release_agent/graph/release.py new file mode 100644 index 0000000..65ff858 --- /dev/null +++ b/src/release_agent/graph/release.py @@ -0,0 +1,627 @@ +"""Node functions for the Release subgraph. + +Each node is an async function (state, config) -> dict. +Nodes never mutate state — they return new dicts. +External clients are accessed via config["configurable"]["clients"]. + +Release flow (Phase 5): + merge_release_pr + -> trigger_ci_build_main + -> poll_ci_build_main + -> route_ci_result + -> ci_passed: wait_for_cd_release + -> poll_release_approvals + -> route_approval_stage + -> sandbox_pending: interrupt_sandbox_approval + -> execute_sandbox_approval + -> poll_release_approvals (loop) + -> prod_pending: interrupt_prod_approval + -> execute_prod_approval + -> poll_release_approvals (loop) + -> all_deployed: move_tickets_to_done + -> send_release_notification + -> archive_release + -> END + -> ci_failed: notify_ci_failure -> END +""" + +from datetime import date +from typing import Any + +from langgraph.graph import END, START, StateGraph +from langgraph.types import interrupt + +from release_agent.exceptions import ReleaseAgentError +from release_agent.graph.ci_nodes import poll_ci_build, trigger_ci_build +from release_agent.graph.dependencies import ToolClients +from release_agent.graph.routing import route_approval_stage, route_ci_result +from release_agent.models.release import StagingRelease +from release_agent.models.ticket import TicketEntry +from release_agent.state import ReleaseState +from release_agent.tools.slack import _build_ci_status_blocks + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _get_clients(config: dict) -> ToolClients: + return config["configurable"]["clients"] + + +def _get_staging_store(config: dict): + return config["configurable"].get("staging_store") + + +# --------------------------------------------------------------------------- +# Node: load_staging +# --------------------------------------------------------------------------- + +async def load_staging(state: dict[str, Any], config: dict) -> dict: + """Load the current staging release from the store.""" + repo_name = state.get("repo_name", "") + staging_store = _get_staging_store(config) + if staging_store is None: + return {"staging": None} + staging = await staging_store.load(repo_name) + if staging is None: + return {"staging": None} + return {"staging": staging.model_dump(mode="json")} + + +# --------------------------------------------------------------------------- +# Node: interrupt_confirm_release +# --------------------------------------------------------------------------- + +async def interrupt_confirm_release(state: dict[str, Any], config: dict) -> dict: + """Interrupt to ask the operator to confirm starting the release.""" + repo_name = state.get("repo_name", "") + staging_dict = state.get("staging") or {} + version = staging_dict.get("version", "unknown") + tickets = staging_dict.get("tickets") or [] + ticket_lines = "\n".join(f" - {t.get('id')}: {t.get('summary', '')}" for t in tickets) + summary = ( + f"Release {version} for repo '{repo_name}'\n" + f"Tickets ({len(tickets)}):\n{ticket_lines}\n" + f"Confirm release? (yes/no)" + ) + interrupt(summary) + return {} + + +# --------------------------------------------------------------------------- +# Node: create_release_pr +# --------------------------------------------------------------------------- + +async def create_release_pr(state: dict[str, Any], config: dict) -> dict: + """Create a release PR in AzDo from a release branch to main.""" + clients = _get_clients(config) + repo_name = state.get("repo_name", "") + version = state.get("version", "") + staging_dict = state.get("staging") or {} + tickets = staging_dict.get("tickets") or [] + ticket_ids = ", ".join(t.get("id", "") for t in tickets) + + source_branch = f"refs/heads/release/{version}" + target_branch = "refs/heads/main" + title = f"Release {version}" + description = f"Release {version} for {repo_name}\n\nTickets: {ticket_ids}" + + data = await clients.azdo.create_pr( + repo=repo_name, + source=source_branch, + target=target_branch, + title=title, + description=description, + ) + + pr_id = str(data.get("pullRequestId", "")) + commit = (data.get("lastMergeSourceCommit") or {}).get("commitId", "") + + return { + "release_pr_id": pr_id, + "release_pr_commit": commit, + "messages": [f"Release PR #{pr_id} created for {version}"], + } + + +# --------------------------------------------------------------------------- +# Node: interrupt_confirm_merge_release +# --------------------------------------------------------------------------- + +async def interrupt_confirm_merge_release(state: dict[str, Any], config: dict) -> dict: + """Interrupt to ask the operator to confirm merging the release PR.""" + release_pr_id = state.get("release_pr_id", "?") + version = state.get("version", "") + repo_name = state.get("repo_name", "") + summary = ( + f"Merge release PR #{release_pr_id} ({version}) into main for '{repo_name}'?\n" + f"Confirm merge? (yes/no)" + ) + interrupt(summary) + return {} + + +# --------------------------------------------------------------------------- +# Node: merge_release_pr +# --------------------------------------------------------------------------- + +async def merge_release_pr(state: dict[str, Any], config: dict) -> dict: + """Merge the release PR via AzDo API.""" + clients = _get_clients(config) + pr_id = int(state.get("release_pr_id", 0)) + commit = state.get("release_pr_commit", "") + await clients.azdo.merge_pr(pr_id=pr_id, last_merge_source_commit=commit) + return {"messages": [f"Release PR #{pr_id} merged"]} + + +# --------------------------------------------------------------------------- +# Node: trigger_ci_build_main (delegates to ci_nodes.trigger_ci_build) +# --------------------------------------------------------------------------- + +async def trigger_ci_build_main(state: dict[str, Any], config: dict) -> dict: + """Trigger CI build on main after the release PR is merged.""" + return await trigger_ci_build(state, config) + + +# --------------------------------------------------------------------------- +# Node: poll_ci_build_main (delegates to ci_nodes.poll_ci_build) +# --------------------------------------------------------------------------- + +async def poll_ci_build_main(state: dict[str, Any], config: dict) -> dict: + """Poll the main branch CI build until completion.""" + return await poll_ci_build(state, config) + + +# --------------------------------------------------------------------------- +# Node: notify_ci_failure +# --------------------------------------------------------------------------- + +async def notify_ci_failure(state: dict[str, Any], config: dict) -> dict: + """Send a Slack notification that the CI build failed on main. + + Non-critical: errors appended. + """ + clients = _get_clients(config) + repo_name = state.get("repo_name", "unknown") + build_result = state.get("ci_build_result", "failed") + build_url = state.get("ci_build_url") + version = state.get("version", "") + + blocks = _build_ci_status_blocks( + repo=repo_name, + branch=f"main (release {version})" if version else "main", + status=build_result, + build_url=build_url, + ) + text = f"CI build failed on main for {repo_name} release {version}" + + try: + await clients.slack.send_notification(text=text, blocks=blocks) + return {"messages": [text]} + except Exception as exc: + return {"errors": [f"notify_ci_failure: {exc}"]} + + +# --------------------------------------------------------------------------- +# Node: wait_for_cd_release +# --------------------------------------------------------------------------- + +async def wait_for_cd_release(state: dict[str, Any], config: dict) -> dict: + """Wait for the CD pipeline to create a release after CI passes. + + Fetches the latest release for the configured release definition. + Non-critical: errors appended if not found. + """ + clients = _get_clients(config) + definition_id = state.get("release_definition_id") + + if not definition_id: + return {"errors": ["wait_for_cd_release: release_definition_id not set in state"]} + + try: + release = await clients.azdo.get_latest_release(definition_id=definition_id) + if not release: + return {"errors": ["wait_for_cd_release: no release found for definition"]} + return { + "release_id": release["id"], + "messages": [f"CD release found: {release.get('name', release['id'])}"], + } + except ReleaseAgentError as exc: + return {"errors": [f"wait_for_cd_release: {exc}"]} + + +# --------------------------------------------------------------------------- +# Node: poll_release_approvals +# --------------------------------------------------------------------------- + +async def poll_release_approvals(state: dict[str, Any], config: dict) -> dict: + """Poll AzDo for pending release environment approvals. + + Non-critical: errors appended on failure. + """ + clients = _get_clients(config) + release_id = state.get("release_id") + + if not release_id: + return {"pending_approvals": [], "errors": ["poll_release_approvals: release_id not set"]} + + try: + approvals = await clients.azdo.get_release_approvals(release_id=release_id) + pending = [ + { + "approval_id": a.approval_id, + "stage_name": a.stage_name, + "status": a.status, + "release_id": a.release_id, + } + for a in approvals + if a.status == "pending" + ] + return {"pending_approvals": pending} + except ReleaseAgentError as exc: + return {"errors": [f"poll_release_approvals: {exc}"], "pending_approvals": []} + + +# --------------------------------------------------------------------------- +# Node: interrupt_sandbox_approval +# --------------------------------------------------------------------------- + +async def interrupt_sandbox_approval(state: dict[str, Any], config: dict) -> dict: + """Interrupt to ask the operator to approve the sandbox deployment.""" + approvals = state.get("pending_approvals") or [] + version = state.get("version", "") + stage_names = ", ".join(a.get("stage_name", "?") for a in approvals) + summary = ( + f"Approve sandbox deployment for release {version}?\n" + f"Stages: {stage_names}\n" + f"Approve? (yes/no)" + ) + interrupt(summary) + return {"current_stage": "sandbox_pending"} + + +# --------------------------------------------------------------------------- +# Node: execute_sandbox_approval +# --------------------------------------------------------------------------- + +async def execute_sandbox_approval(state: dict[str, Any], config: dict) -> dict: + """Approve all pending sandbox stage approvals via AzDo VSRM. + + Non-critical per approval: errors appended on individual failures. + """ + clients = _get_clients(config) + approvals = state.get("pending_approvals") or [] + errors: list[str] = [] + + for approval in approvals: + approval_id = approval.get("approval_id", "") + try: + await clients.azdo.approve_release( + approval_id=approval_id, + comment="Sandbox approved by release agent via Slack", + ) + except ReleaseAgentError as exc: + errors.append(f"execute_sandbox_approval failed for {approval_id}: {exc}") + + if errors: + return {"errors": errors} + return {} + + +# --------------------------------------------------------------------------- +# Node: interrupt_prod_approval +# --------------------------------------------------------------------------- + +async def interrupt_prod_approval(state: dict[str, Any], config: dict) -> dict: + """Interrupt to ask the operator to approve the production deployment.""" + approvals = state.get("pending_approvals") or [] + version = state.get("version", "") + stage_names = ", ".join(a.get("stage_name", "?") for a in approvals) + summary = ( + f"Approve production deployment for release {version}?\n" + f"Stages: {stage_names}\n" + f"Approve? (yes/no)" + ) + interrupt(summary) + return {"current_stage": "prod_pending"} + + +# --------------------------------------------------------------------------- +# Node: execute_prod_approval +# --------------------------------------------------------------------------- + +async def execute_prod_approval(state: dict[str, Any], config: dict) -> dict: + """Approve all pending production stage approvals via AzDo VSRM. + + Non-critical per approval: errors appended on individual failures. + """ + clients = _get_clients(config) + approvals = state.get("pending_approvals") or [] + errors: list[str] = [] + + for approval in approvals: + approval_id = approval.get("approval_id", "") + try: + await clients.azdo.approve_release( + approval_id=approval_id, + comment="Production approved by release agent via Slack", + ) + except ReleaseAgentError as exc: + errors.append(f"execute_prod_approval failed for {approval_id}: {exc}") + + if errors: + return {"errors": errors} + return {} + + +# --------------------------------------------------------------------------- +# Node: move_tickets_to_done +# --------------------------------------------------------------------------- + +async def move_tickets_to_done(state: dict[str, Any], config: dict) -> dict: + """Transition all staging tickets to Done/Released in Jira.""" + clients = _get_clients(config) + staging_dict = state.get("staging") or {} + tickets = staging_dict.get("tickets") or [] + errors: list[str] = [] + + for ticket_raw in tickets: + ticket_id = ticket_raw.get("id", "") + try: + await clients.jira.transition_issue(ticket_id, "Done") + except ReleaseAgentError as exc: + errors.append(f"move_tickets_to_done failed for {ticket_id}: {exc}") + + if errors: + return {"errors": errors} + return {} + + +# --------------------------------------------------------------------------- +# Node: send_slack_notification (send_release_notification) +# --------------------------------------------------------------------------- + +async def send_slack_notification(state: dict[str, Any], config: dict) -> dict: + """Send a release notification to Slack.""" + clients = _get_clients(config) + repo_name = state.get("repo_name", "") + version = state.get("version", "") + staging_dict = state.get("staging") or {} + tickets_raw = staging_dict.get("tickets") or [] + + try: + ticket_entries = [TicketEntry.model_validate(t) for t in tickets_raw] + await clients.slack.send_release_notification( + repo=repo_name, + version=version, + release_date=date.today(), + tickets=ticket_entries, + ) + return {"messages": [f"Slack notified: release {version} for {repo_name}"]} + except ReleaseAgentError as exc: + return {"errors": [f"send_slack_notification failed: {exc}"]} + + +# --------------------------------------------------------------------------- +# Node: archive_release +# --------------------------------------------------------------------------- + +async def archive_release(state: dict[str, Any], config: dict) -> dict: + """Archive the staging release in the store.""" + staging_store = _get_staging_store(config) + staging_dict = state.get("staging") or {} + if staging_store is None or not staging_dict: + return {} + staging = StagingRelease.model_validate(staging_dict) + await staging_store.archive(staging, date.today()) + return {"messages": [f"Release {staging.version} archived"]} + + +# --------------------------------------------------------------------------- +# Legacy nodes kept for backward compatibility +# --------------------------------------------------------------------------- + +async def list_pipelines(state: dict[str, Any], config: dict) -> dict: + """List build pipelines for the repository via AzDo.""" + clients = _get_clients(config) + repo_name = state.get("repo_name", "") + try: + pipelines = await clients.azdo.list_build_pipelines(repo=repo_name) + return {"pipelines": [p.model_dump() for p in pipelines]} + except ReleaseAgentError as exc: + return {"errors": [f"list_pipelines failed: {exc}"], "pipelines": []} + + +async def interrupt_confirm_trigger(state: dict[str, Any], config: dict) -> dict: + """Interrupt to ask the operator to confirm triggering pipelines.""" + repo_name = state.get("repo_name", "") + version = state.get("version", "") + pipelines = state.get("pipelines") or [] + pipeline_names = ", ".join(p.get("name", str(p.get("id", ""))) for p in pipelines) + summary = ( + f"Trigger pipelines for {repo_name} {version}?\n" + f"Pipelines: {pipeline_names}\n" + f"Confirm? (yes/no)" + ) + interrupt(summary) + return {} + + +async def trigger_pipelines(state: dict[str, Any], config: dict) -> dict: + """Trigger all listed pipelines for the release branch.""" + clients = _get_clients(config) + repo_name = state.get("repo_name", "") + version = state.get("version", "") + pipelines = state.get("pipelines") or [] + branch = f"refs/heads/release/{version}" + + triggered: list[dict] = [] + errors: list[str] = [] + + for pipeline in pipelines: + pipeline_id = pipeline.get("id") + try: + result = await clients.azdo.trigger_pipeline( + pipeline_id=pipeline_id, branch=branch + ) + triggered.append(result) + except ReleaseAgentError as exc: + errors.append(f"trigger_pipelines failed for pipeline {pipeline_id}: {exc}") + + result_dict: dict = {"triggered_builds": triggered} + if errors: + result_dict["errors"] = errors + return result_dict + + +async def check_release_approvals(state: dict[str, Any], config: dict) -> dict: + """Check triggered builds for pending stage approvals.""" + clients = _get_clients(config) + triggered_builds = state.get("triggered_builds") or [] + pending: list[dict] = [] + errors: list[str] = [] + + for build in triggered_builds: + build_id = build.get("id") + try: + await clients.azdo.get_build_status(build_id=build_id) + except ReleaseAgentError as exc: + errors.append(f"check_release_approvals failed for build {build_id}: {exc}") + + result_dict: dict = {"pending_approvals": pending} + if errors: + result_dict["errors"] = errors + return result_dict + + +async def interrupt_confirm_approve(state: dict[str, Any], config: dict) -> dict: + """Interrupt to ask the operator to confirm approving pipeline stages.""" + approvals = state.get("pending_approvals") or [] + version = state.get("version", "") + stage_names = ", ".join(a.get("stage_name", a.get("approval_id", "?")) for a in approvals) + summary = ( + f"Approve pipeline stages for release {version}?\n" + f"Stages: {stage_names}\n" + f"Confirm? (yes/no)" + ) + interrupt(summary) + return {} + + +async def approve_stage(state: dict[str, Any], config: dict) -> dict: + """Approve all pending pipeline stage approvals via AzDo VSRM.""" + clients = _get_clients(config) + approvals = state.get("pending_approvals") or [] + errors: list[str] = [] + + for approval in approvals: + approval_id = approval.get("approval_id", "") + try: + await clients.azdo.approve_release( + approval_id=approval_id, + comment="Approved by release agent", + ) + except ReleaseAgentError as exc: + errors.append(f"approve_stage failed for {approval_id}: {exc}") + + if errors: + return {"errors": errors} + return {} + + +# --------------------------------------------------------------------------- +# Graph builder +# --------------------------------------------------------------------------- + +def build_release_graph(): + """Assemble and compile the Release StateGraph (Phase 5 architecture). + + Flow: + load_staging -> interrupt_confirm_release -> create_release_pr + -> interrupt_confirm_merge_release -> merge_release_pr + -> trigger_ci_build_main -> poll_ci_build_main + -> route_ci_result: + ci_passed: wait_for_cd_release -> poll_release_approvals + -> route_approval_stage: + sandbox_pending: interrupt_sandbox_approval + -> execute_sandbox_approval + -> poll_release_approvals + prod_pending: interrupt_prod_approval + -> execute_prod_approval + -> poll_release_approvals + all_deployed: move_tickets_to_done + -> send_slack_notification + -> archive_release -> END + ci_failed: notify_ci_failure -> END + """ + graph = StateGraph(ReleaseState) + + # Core release PR nodes + graph.add_node("load_staging", load_staging) + graph.add_node("interrupt_confirm_release", interrupt_confirm_release) + graph.add_node("create_release_pr", create_release_pr) + graph.add_node("interrupt_confirm_merge_release", interrupt_confirm_merge_release) + graph.add_node("merge_release_pr", merge_release_pr) + + # CI nodes + graph.add_node("trigger_ci_build_main", trigger_ci_build_main) + graph.add_node("poll_ci_build_main", poll_ci_build_main) + graph.add_node("notify_ci_failure", notify_ci_failure) + + # CD approval nodes + graph.add_node("wait_for_cd_release", wait_for_cd_release) + graph.add_node("poll_release_approvals", poll_release_approvals) + graph.add_node("interrupt_sandbox_approval", interrupt_sandbox_approval) + graph.add_node("execute_sandbox_approval", execute_sandbox_approval) + graph.add_node("interrupt_prod_approval", interrupt_prod_approval) + graph.add_node("execute_prod_approval", execute_prod_approval) + + # Completion nodes + graph.add_node("move_tickets_to_done", move_tickets_to_done) + graph.add_node("send_release_notification", send_slack_notification) + graph.add_node("archive_release", archive_release) + + # Main release flow up to merge + graph.add_edge(START, "load_staging") + graph.add_edge("load_staging", "interrupt_confirm_release") + graph.add_edge("interrupt_confirm_release", "create_release_pr") + graph.add_edge("create_release_pr", "interrupt_confirm_merge_release") + graph.add_edge("interrupt_confirm_merge_release", "merge_release_pr") + + # CI pipeline after merge + graph.add_edge("merge_release_pr", "trigger_ci_build_main") + graph.add_edge("trigger_ci_build_main", "poll_ci_build_main") + graph.add_conditional_edges( + "poll_ci_build_main", + route_ci_result, + {"ci_passed": "wait_for_cd_release", "ci_failed": "notify_ci_failure"}, + ) + graph.add_edge("notify_ci_failure", END) + + # CD approval flow + graph.add_edge("wait_for_cd_release", "poll_release_approvals") + graph.add_conditional_edges( + "poll_release_approvals", + route_approval_stage, + { + "sandbox_pending": "interrupt_sandbox_approval", + "prod_pending": "interrupt_prod_approval", + "all_deployed": "move_tickets_to_done", + }, + ) + + # Sandbox approval loop + graph.add_edge("interrupt_sandbox_approval", "execute_sandbox_approval") + graph.add_edge("execute_sandbox_approval", "poll_release_approvals") + + # Prod approval loop + graph.add_edge("interrupt_prod_approval", "execute_prod_approval") + graph.add_edge("execute_prod_approval", "poll_release_approvals") + + # Completion + graph.add_edge("move_tickets_to_done", "send_release_notification") + graph.add_edge("send_release_notification", "archive_release") + graph.add_edge("archive_release", END) + + return graph.compile() diff --git a/src/release_agent/graph/routing.py b/src/release_agent/graph/routing.py new file mode 100644 index 0000000..616cc31 --- /dev/null +++ b/src/release_agent/graph/routing.py @@ -0,0 +1,131 @@ +"""Pure routing functions for LangGraph conditional edges. + +Each function takes a state dict and returns a routing string. +No side effects, no mutation, no I/O. +""" + +from typing import Any + + +def is_pr_already_merged(state: dict[str, Any]) -> str: + """Route based on whether the PR was already merged before processing. + + Returns: + "merged" if state["pr_already_merged"] is True. + "active" otherwise (including missing or None). + """ + if state.get("pr_already_merged"): + return "merged" + return "active" + + +def route_after_fetch(state: dict[str, Any]) -> str: + """Three-way route after fetch_pr_details based on merge status and ticket presence. + + Returns: + "merged" if state["pr_already_merged"] is True. + "active_with_ticket" if PR is active and state["has_ticket"] is True. + "active_no_ticket" if PR is active and state["has_ticket"] is falsy. + """ + if state.get("pr_already_merged"): + return "merged" + if state.get("has_ticket"): + return "active_with_ticket" + return "active_no_ticket" + + +def is_review_approved(state: dict[str, Any]) -> str: + """Route based on the review approval result. + + Returns: + "approve" if state["review_approved"] is True. + "request_changes" otherwise (including missing or None). + """ + if state.get("review_approved"): + return "approve" + return "request_changes" + + +def has_ticket(state: dict[str, Any]) -> str: + """Route based on whether the PR branch contains a Jira ticket ID. + + Returns: + "yes" if state["has_ticket"] is True. + "no" otherwise (including missing or None). + """ + if state.get("has_ticket"): + return "yes" + return "no" + + +def should_continue_to_release(state: dict[str, Any]) -> str: + """Route based on whether the operator chose to proceed to the release flow. + + Returns: + "yes" if state["continue_to_release"] is True and no errors accumulated. + "no" otherwise (including missing, None, or if errors are present). + """ + if state.get("continue_to_release") and not state.get("errors"): + return "yes" + return "no" + + +def has_pipelines(state: dict[str, Any]) -> str: + """Route based on whether any pipelines are defined for the repository. + + Returns: + "yes" if state["pipelines"] is a non-empty list. + "no" otherwise (including missing, None, or empty list). + """ + pipelines = state.get("pipelines") + if pipelines: + return "yes" + return "no" + + +def has_pending_approvals(state: dict[str, Any]) -> str: + """Route based on whether there are pipeline stage approvals pending. + + Returns: + "yes" if state["pending_approvals"] is a non-empty list. + "no" otherwise (including missing, None, or empty list). + """ + approvals = state.get("pending_approvals") + if approvals: + return "yes" + return "no" + + +def route_ci_result(state: dict[str, Any]) -> str: + """Route based on the CI build result. + + Returns: + "ci_passed" if state["ci_build_result"] == "succeeded". + "ci_failed" otherwise (including missing, None, or any other value). + """ + if state.get("ci_build_result") == "succeeded": + return "ci_passed" + return "ci_failed" + + +def route_approval_stage(state: dict[str, Any]) -> str: + """Route based on the current CD approval stage. + + Uses state["current_stage"] if present. Falls back to "sandbox_pending" + when approvals exist but no stage is set (first stage assumption). + + Returns: + "sandbox_pending" if state["current_stage"] == "sandbox_pending" or + approvals exist with no explicit stage (default first stage). + "prod_pending" if state["current_stage"] == "prod_pending". + "all_deployed" if state["pending_approvals"] is empty or missing. + """ + approvals = state.get("pending_approvals") + if not approvals: + return "all_deployed" + + current_stage = state.get("current_stage", "") + if current_stage == "prod_pending": + return "prod_pending" + # Default: sandbox_pending (covers both explicit sandbox_pending and unknown) + return "sandbox_pending" diff --git a/src/release_agent/main.py b/src/release_agent/main.py new file mode 100644 index 0000000..ace4e5c --- /dev/null +++ b/src/release_agent/main.py @@ -0,0 +1,361 @@ +"""FastAPI application entry point. + +Responsibilities: +- Create the FastAPI app with lifespan handler +- Startup: compile graphs, open DB pool, build tool clients, ensure schema +- Shutdown: wait for background tasks (30s timeout) +- Register routers and global exception handlers +- Expose schedule_graph and run_graph_in_background helpers +""" + +import asyncio +import logging +import uuid +from contextlib import asynccontextmanager +from datetime import datetime, timezone +from pathlib import Path + +import httpx +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse + +from release_agent.api.approvals import router as approvals_router +from release_agent.api.models import ErrorResponse +from release_agent.api.slack_interactions import router as slack_interactions_router +from release_agent.api.status import router as status_router +from release_agent.api.webhooks import router as webhook_router +from release_agent.config import Settings +from release_agent.exceptions import ReleaseAgentError +from release_agent.graph.dependencies import JsonFileStagingStore, ToolClients +from release_agent.graph.postgres_staging_store import PostgresStagingStore +from release_agent.graph.pr_completed import build_pr_completed_graph +from release_agent.graph.release import build_release_graph +from release_agent.services.pr_poller import run_pr_poll_loop +from release_agent.tools.azdo import AzDoClient +from release_agent.tools.claude_review import ClaudeReviewer +from release_agent.tools.jira import JiraClient +from release_agent.tools.slack import SlackClient + +try: + from psycopg_pool import AsyncConnectionPool +except ImportError: # pragma: no cover + AsyncConnectionPool = None # type: ignore[assignment,misc] + +logger = logging.getLogger(__name__) + +_SHUTDOWN_TIMEOUT_SECONDS = 30 + + +# --------------------------------------------------------------------------- +# Startup helpers +# --------------------------------------------------------------------------- + +def _create_tool_clients(settings: Settings) -> tuple[ToolClients, list]: + """Build and return a ToolClients instance and closeable HTTP clients.""" + http_client = httpx.AsyncClient(timeout=30.0) + vsrm_http_client = httpx.AsyncClient(timeout=30.0) + azdo = AzDoClient( + http_client=http_client, + vsrm_http_client=vsrm_http_client, + base_url=settings.azdo_api_url, + vsrm_base_url=settings.azdo_vsrm_api_url, + pat=settings.azdo_pat.get_secret_value(), + ) + jira = JiraClient( + http_client=http_client, + base_url=settings.jira_base_url, + email=settings.jira_email, + api_token=settings.jira_api_token.get_secret_value(), + ) + slack = SlackClient( + http_client=http_client, + webhook_url=settings.slack_webhook_url.get_secret_value(), + bot_token=settings.slack_bot_token.get_secret_value(), + channel_id=settings.slack_channel_id, + ) + reviewer = ClaudeReviewer() + clients = ToolClients(azdo=azdo, jira=jira, slack=slack, reviewer=reviewer) + return clients, [http_client, vsrm_http_client] + + +def _create_staging_store(pool=None) -> PostgresStagingStore | JsonFileStagingStore: + """Return a StagingStore instance. + + When a PostgreSQL pool is provided, returns a PostgresStagingStore. + Falls back to JsonFileStagingStore for local development without a DB. + """ + if pool is not None: + return PostgresStagingStore(pool=pool) + return JsonFileStagingStore(directory=Path("data/staging")) + + +async def _ensure_db_schema(pool) -> None: + """Create all required tables if they do not already exist.""" + statements = [ + """ + CREATE TABLE IF NOT EXISTS agent_threads ( + thread_id TEXT PRIMARY KEY, + graph_name TEXT, + status TEXT NOT NULL DEFAULT 'running', + interrupt_value TEXT, + state JSONB, + repo_name TEXT, + pr_id TEXT, + version TEXT, + slack_message_ts TEXT, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() + ) + """, + """ + ALTER TABLE agent_threads + ADD COLUMN IF NOT EXISTS slack_message_ts TEXT + """, + """ + CREATE TABLE IF NOT EXISTS staging_releases ( + repo TEXT PRIMARY KEY, + version TEXT NOT NULL, + started_at DATE NOT NULL, + tickets JSONB NOT NULL DEFAULT '[]', + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() + ) + """, + """ + CREATE TABLE IF NOT EXISTS archived_releases ( + repo TEXT NOT NULL, + version TEXT NOT NULL, + started_at DATE NOT NULL, + tickets JSONB NOT NULL DEFAULT '[]', + released_at DATE NOT NULL, + PRIMARY KEY (repo, version) + ) + """, + ] + async with pool.connection() as conn: + async with conn.cursor() as cur: + for sql in statements: + await cur.execute(sql) + + +# --------------------------------------------------------------------------- +# Background task helpers +# --------------------------------------------------------------------------- + +def schedule_graph( + *, + app: FastAPI, + graph, + initial_state: dict, + thread_id: str | None = None, + tool_clients: ToolClients | None = None, + db_pool=None, +) -> str: + """Schedule a graph run as a background asyncio task. + + Creates a new thread_id if not provided. Adds the task to + app.state.background_tasks for graceful shutdown tracking. + + Returns the thread_id used for this execution. + """ + if thread_id is None: + thread_id = str(uuid.uuid4()) + + if tool_clients is None and hasattr(app.state, "tool_clients"): + tool_clients = app.state.tool_clients + + if db_pool is None and hasattr(app.state, "db_pool"): + db_pool = app.state.db_pool + + # Extract settings for config propagation + settings = getattr(app.state, "settings", None) + repos_base_dir = getattr(settings, "repos_base_dir", "") if settings else "" + default_jira_project = getattr(settings, "default_jira_project", "ALLPOST") if settings else "ALLPOST" + + task = asyncio.create_task( + run_graph_in_background( + graph=graph, + initial_state=initial_state, + thread_id=thread_id, + tool_clients=tool_clients, + db_pool=db_pool, + repos_base_dir=repos_base_dir, + default_jira_project=default_jira_project, + ) + ) + app.state.background_tasks.add(task) + task.add_done_callback(app.state.background_tasks.discard) + return thread_id + + +async def run_graph_in_background( + *, + graph, + initial_state: dict, + thread_id: str, + tool_clients: ToolClients | None = None, + db_pool=None, + repos_base_dir: str = "", + default_jira_project: str = "ALLPOST", +) -> None: + """Execute the graph and update thread status in the database.""" + from release_agent.api.webhooks import _upsert_thread + + config = { + "configurable": { + "thread_id": thread_id, + "clients": tool_clients, + "repos_base_dir": repos_base_dir, + "default_jira_project": default_jira_project, + } + } + try: + if db_pool is not None: + await _upsert_thread(db_pool, thread_id=thread_id, thread_status="running", state=initial_state) + result = await graph.ainvoke(initial_state, config=config) + if db_pool is not None: + await _upsert_thread(db_pool, thread_id=thread_id, thread_status="completed", state=result or {}) + except Exception as exc: + logger.exception("Graph execution failed for thread %s: %s", thread_id, exc) + if db_pool is not None: + await _upsert_thread( + db_pool, + thread_id=thread_id, + thread_status="failed", + state={"errors": [str(exc)]}, + ) + + +# --------------------------------------------------------------------------- +# Lifespan +# --------------------------------------------------------------------------- + +@asynccontextmanager +async def lifespan(app: FastAPI): + """Manage startup and graceful shutdown of shared resources.""" + settings = Settings() + app.state.settings = settings + app.state.background_tasks = set() + app.state.started_at = datetime.now(tz=timezone.utc) + + # Compile graphs once at startup + logger.info("Compiling graphs...") + app.state.graphs = { + "pr_completed": build_pr_completed_graph(), + "release": build_release_graph(), + } + + # Build tool clients + logger.info("Building tool clients...") + app.state.tool_clients, app.state._http_clients = _create_tool_clients(settings) + + # Open PostgreSQL connection pool + logger.info("Opening DB connection pool...") + pool = AsyncConnectionPool( + conninfo=settings.postgres_dsn.get_secret_value(), + open=False, + ) + await pool.open() + app.state.db_pool = pool + + # Ensure schema exists + await _ensure_db_schema(pool) + + # Build staging store (backed by PostgreSQL) + app.state.staging_store = _create_staging_store(pool=pool) + + # Start PR polling background task if enabled + if settings.pr_poll_enabled: + logger.info("Starting PR polling background task...") + poll_task = asyncio.create_task( + run_pr_poll_loop( + azdo_client=app.state.tool_clients.azdo, + db_pool=pool, + watched_repos=settings.watched_repos_list, + target_branch=settings.pr_poll_target_branch, + interval_seconds=settings.pr_poll_interval_seconds, + schedule_fn=lambda *, initial_state: schedule_graph( + app=app, + graph=app.state.graphs["pr_completed"], + initial_state=initial_state, + ), + ) + ) + app.state.background_tasks.add(poll_task) + poll_task.add_done_callback(app.state.background_tasks.discard) + + logger.info("Startup complete.") + yield + + # Graceful shutdown: wait for background tasks + tasks = list(app.state.background_tasks) + if tasks: + logger.info("Waiting for %d background task(s) to complete...", len(tasks)) + done, pending = await asyncio.wait(tasks, timeout=_SHUTDOWN_TIMEOUT_SECONDS) + for task in pending: + logger.warning("Cancelling timed-out background task %s", task) + task.cancel() + + # Close HTTP clients + for client in getattr(app.state, "_http_clients", []): + await client.aclose() + + await pool.close() + logger.info("Shutdown complete.") + + +# --------------------------------------------------------------------------- +# Exception handlers +# --------------------------------------------------------------------------- + +async def _release_agent_error_handler(request: Request, exc: ReleaseAgentError) -> JSONResponse: + return JSONResponse( + status_code=500, + content=ErrorResponse( + error=type(exc).__name__, + detail=str(exc), + ).model_dump(), + ) + + +async def _generic_error_handler(request: Request, exc: Exception) -> JSONResponse: + logger.exception("Unhandled exception: %s", exc) + return JSONResponse( + status_code=500, + content=ErrorResponse( + error="InternalServerError", + detail="An unexpected error occurred", + ).model_dump(), + ) + + +# --------------------------------------------------------------------------- +# App factory +# --------------------------------------------------------------------------- + +def create_app() -> FastAPI: + """Create and configure the FastAPI application.""" + app = FastAPI( + title="Billo Release Agent", + version="0.1.0", + description="LangGraph-based release automation agent", + lifespan=lifespan, + ) + + # Register routers + app.include_router(webhook_router) + app.include_router(approvals_router) + app.include_router(status_router) + app.include_router(slack_interactions_router) + + # Register exception handlers + app.add_exception_handler(ReleaseAgentError, _release_agent_error_handler) # type: ignore[arg-type] + app.add_exception_handler(Exception, _generic_error_handler) + + return app + + +# --------------------------------------------------------------------------- +# ASGI entry point +# --------------------------------------------------------------------------- + +app = create_app() diff --git a/src/release_agent/models/__init__.py b/src/release_agent/models/__init__.py new file mode 100644 index 0000000..5fd015a --- /dev/null +++ b/src/release_agent/models/__init__.py @@ -0,0 +1 @@ +"""Pydantic models for the release agent.""" diff --git a/src/release_agent/models/build.py b/src/release_agent/models/build.py new file mode 100644 index 0000000..073735f --- /dev/null +++ b/src/release_agent/models/build.py @@ -0,0 +1,43 @@ +"""Build and approval dataclasses for CI/CD tracking. + +Immutable (frozen) dataclasses used across graph nodes, routing functions, +and API endpoints. +""" + +from dataclasses import dataclass + + +@dataclass(frozen=True) +class BuildStatus: + """Immutable snapshot of an Azure DevOps build's current state. + + Attributes: + status: AzDo build status string (e.g. "notStarted", "inProgress", + "completed", "cancelling"). + result: AzDo build result string (e.g. "succeeded", "failed", + "canceled", "partiallySucceeded"). None when not completed. + build_url: Direct URL to the build results page, or None if unknown. + """ + + status: str + result: str | None + build_url: str | None + + +@dataclass(frozen=True) +class ApprovalRecord: + """Immutable record representing a single release pipeline approval gate. + + Attributes: + approval_id: Unique identifier of the approval record in AzDo. + stage_name: Human-readable name of the deployment stage (e.g. + "Sandbox", "Production"). + status: Current approval status (e.g. "pending", "approved", + "rejected"). + release_id: Numeric ID of the AzDo release containing this approval. + """ + + approval_id: str + stage_name: str + status: str + release_id: int diff --git a/src/release_agent/models/jira.py b/src/release_agent/models/jira.py new file mode 100644 index 0000000..0569702 --- /dev/null +++ b/src/release_agent/models/jira.py @@ -0,0 +1,33 @@ +"""Jira domain models.""" + +from pydantic import BaseModel, ConfigDict + + +class JiraTransition(BaseModel): + """A Jira workflow transition. + + Attributes: + id: Transition identifier as returned by the Jira API. + name: Human-readable transition name (e.g. "Done", "Released"). + """ + + model_config = ConfigDict(frozen=True) + + id: str + name: str + + +class JiraIssue(BaseModel): + """A Jira issue. + + Attributes: + key: Issue key (e.g. "ALLPOST-100"). + summary: Issue summary / title. + status: Current workflow status name (e.g. "In Progress"). + """ + + model_config = ConfigDict(frozen=True) + + key: str + summary: str + status: str diff --git a/src/release_agent/models/pipeline.py b/src/release_agent/models/pipeline.py new file mode 100644 index 0000000..3ac7434 --- /dev/null +++ b/src/release_agent/models/pipeline.py @@ -0,0 +1,31 @@ +"""Pipeline-related models for Azure DevOps pipeline tracking.""" + +from pydantic import BaseModel, ConfigDict, Field, model_validator + + +class PipelineInfo(BaseModel): + """An immutable record identifying an Azure DevOps pipeline.""" + + model_config = ConfigDict(frozen=True) + + id: int + name: str + repo: str + + +class ReleasePipelineStage(BaseModel): + """An immutable record representing a single stage in a release pipeline.""" + + model_config = ConfigDict(frozen=True) + + name: str + rank: int = Field(ge=0) + requires_approval: bool + approval_id: str | None = None + + @model_validator(mode="after") + def validate_approval_consistency(self) -> "ReleasePipelineStage": + """Ensure approval_id is consistent with requires_approval.""" + if not self.requires_approval and self.approval_id is not None: + raise ValueError("approval_id must be None when requires_approval is False") + return self diff --git a/src/release_agent/models/pr.py b/src/release_agent/models/pr.py new file mode 100644 index 0000000..5ab7cef --- /dev/null +++ b/src/release_agent/models/pr.py @@ -0,0 +1,38 @@ +"""PRInfo model representing a pull request with ticket extraction.""" + +from typing import Literal + +from pydantic import BaseModel, ConfigDict, HttpUrl, model_validator + +from release_agent.branch_parser import parse_branch + + +class PRInfo(BaseModel): + """An immutable representation of a pull request. + + The ticket_id and has_ticket fields are automatically derived + from the branch name during model construction. + """ + + model_config = ConfigDict(frozen=True) + + pr_id: str + pr_url: HttpUrl + repo_name: str + branch: str + pr_title: str + pr_status: Literal["active", "completed", "abandoned"] + ticket_id: str | None = None + has_ticket: bool = False + + @model_validator(mode="before") + @classmethod + def extract_ticket_from_branch(cls, values: dict) -> dict: + """Derive ticket_id and has_ticket from the branch name.""" + branch = values.get("branch", "") + ticket_id, has_ticket = parse_branch(branch) + # ticket_id and has_ticket are always derived from branch parsing. + # Any caller-provided values are overridden to keep the model consistent. + values["ticket_id"] = ticket_id + values["has_ticket"] = has_ticket + return values diff --git a/src/release_agent/models/release.py b/src/release_agent/models/release.py new file mode 100644 index 0000000..0316bb3 --- /dev/null +++ b/src/release_agent/models/release.py @@ -0,0 +1,65 @@ +"""StagingRelease and ArchivedRelease models for tracking release state.""" + +import re +from datetime import date + +from pydantic import BaseModel, ConfigDict, field_validator, model_validator + +from release_agent.models.ticket import TicketEntry + +_VERSION_PATTERN = re.compile(r"^v\d+\.\d+\.\d+$") + + +class StagingRelease(BaseModel): + """An immutable model representing a release currently in staging. + + All mutation operations return new instances rather than modifying + the existing one. + """ + + model_config = ConfigDict(frozen=True) + + version: str + repo: str + started_at: date + tickets: list[TicketEntry] + + @field_validator("version") + @classmethod + def validate_version_format(cls, value: str) -> str: + """Validate version string matches vMAJOR.MINOR.PATCH.""" + if not _VERSION_PATTERN.match(value): + raise ValueError( + f"Invalid version format: {value!r}. Must match ^v\\d+\\.\\d+\\.\\d+$" + ) + return value + + def add_ticket(self, ticket: TicketEntry) -> "StagingRelease": + """Return a new StagingRelease with the given ticket appended. + + Does not mutate this instance. + """ + return self.model_copy(update={"tickets": [*self.tickets, ticket]}) + + def has_ticket(self, ticket_id: str) -> bool: + """Return True if a ticket with the given ID exists in this release.""" + return any(t.id == ticket_id for t in self.tickets) + + +class ArchivedRelease(StagingRelease): + """An immutable model representing a completed/released release. + + Extends StagingRelease with a released_at date that must be + on or after started_at. + """ + + released_at: date + + @model_validator(mode="after") + def validate_released_after_started(self) -> "ArchivedRelease": + """Validate that released_at is not before started_at.""" + if self.released_at < self.started_at: + raise ValueError( + f"released_at ({self.released_at}) must be >= started_at ({self.started_at})" + ) + return self diff --git a/src/release_agent/models/review.py b/src/release_agent/models/review.py new file mode 100644 index 0000000..433bd1b --- /dev/null +++ b/src/release_agent/models/review.py @@ -0,0 +1,48 @@ +"""Review models for Claude PR review structured output.""" + +from typing import Literal + +from pydantic import BaseModel, ConfigDict, computed_field + + +class ReviewIssue(BaseModel): + """A single issue identified during PR review. + + Attributes: + severity: One of "blocker", "error", "warning", or "info". + description: Human-readable description of the issue. + file_path: Optional path to the affected file. + suggestion: Optional remediation suggestion. + """ + + model_config = ConfigDict(frozen=True) + + severity: Literal["blocker", "error", "warning", "info"] + description: str + file_path: str | None = None + line_start: int | None = None + line_end: int | None = None + suggestion: str | None = None + + +class ReviewResult(BaseModel): + """Structured output from a Claude PR review. + + Attributes: + verdict: Either "approve" or "request_changes". + summary: A human-readable summary of the review. + issues: List of issues identified during review. + has_blockers: Computed field - True if any issue has severity "blocker". + """ + + model_config = ConfigDict(frozen=True) + + verdict: Literal["approve", "request_changes"] + summary: str + issues: tuple[ReviewIssue, ...] = () + + @computed_field # type: ignore[misc] + @property + def has_blockers(self) -> bool: + """Return True if any issue has blocker severity.""" + return any(issue.severity == "blocker" for issue in self.issues) diff --git a/src/release_agent/models/ticket.py b/src/release_agent/models/ticket.py new file mode 100644 index 0000000..a10927c --- /dev/null +++ b/src/release_agent/models/ticket.py @@ -0,0 +1,33 @@ +"""TicketEntry model representing a Jira ticket linked to a merged PR.""" + +import re +from datetime import date + +from pydantic import BaseModel, ConfigDict, HttpUrl, field_validator + +_JIRA_ID_PATTERN = re.compile(r"^[A-Z][A-Z0-9]*-\d+$") + + +class TicketEntry(BaseModel): + """An immutable record of a Jira ticket merged into a release.""" + + model_config = ConfigDict(frozen=True) + + id: str + summary: str + pr_id: str + pr_url: HttpUrl + pr_title: str + branch: str + merged_at: date + + @field_validator("id") + @classmethod + def validate_jira_id(cls, value: str) -> str: + """Validate that the ID matches the Jira ticket format.""" + if not _JIRA_ID_PATTERN.match(value): + raise ValueError( + f"Invalid Jira ticket ID: {value!r}. " + "Must match pattern ^[A-Z][A-Z0-9]*-\\d+$" + ) + return value diff --git a/src/release_agent/models/webhook.py b/src/release_agent/models/webhook.py new file mode 100644 index 0000000..472124f --- /dev/null +++ b/src/release_agent/models/webhook.py @@ -0,0 +1,40 @@ +"""WebhookPayload models for Azure DevOps webhook events.""" + +from datetime import datetime +from typing import Literal + +from pydantic import BaseModel, ConfigDict, HttpUrl + + +class WebhookRepository(BaseModel): + """An immutable record of a repository as reported in a webhook event.""" + + model_config = ConfigDict(frozen=True) + + id: str + name: str + web_url: HttpUrl + + +class WebhookResource(BaseModel): + """An immutable record of the resource object in a webhook payload.""" + + model_config = ConfigDict(frozen=True) + + repository: WebhookRepository + pull_request_id: int + title: str + source_ref_name: str + target_ref_name: str + status: Literal["active", "completed", "abandoned"] + closed_date: datetime | None + + +class WebhookPayload(BaseModel): + """An immutable model representing an Azure DevOps webhook event payload.""" + + model_config = ConfigDict(frozen=True) + + subscription_id: str + event_type: str + resource: WebhookResource diff --git a/src/release_agent/services/__init__.py b/src/release_agent/services/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/release_agent/services/pr_dedup.py b/src/release_agent/services/pr_dedup.py new file mode 100644 index 0000000..8e1d7a9 --- /dev/null +++ b/src/release_agent/services/pr_dedup.py @@ -0,0 +1,46 @@ +"""PR deduplication service. + +Queries the agent_threads table to find which PRs from a given list have +not yet been processed. This prevents the PR poller from re-triggering +graph runs for PRs that already have an existing thread. +""" + +from release_agent.models.pr import PRInfo + +_QUERY = """ + SELECT pr_id, repo_name + FROM agent_threads + WHERE (pr_id, repo_name) IN ( + SELECT unnest(%s::text[]), unnest(%s::text[]) + ) +""" + + +async def find_unprocessed_prs(pool, prs: list[PRInfo]) -> list[PRInfo]: + """Return the subset of prs that have no existing agent_threads record. + + A PR is considered already-processed if there exists a row in agent_threads + with matching (pr_id, repo_name) pair. The SQL uses unnest to enforce + pair-wise matching (not independent ANY on each column). + + Args: + pool: Async psycopg connection pool. + prs: List of PRInfo objects to check. + + Returns: + A new list containing only PRs with no existing thread. + Original list is not mutated. + """ + if not prs: + return [] + + pr_ids = [p.pr_id for p in prs] + repo_names = [p.repo_name for p in prs] + + async with pool.connection() as conn: + async with conn.cursor() as cur: + await cur.execute(_QUERY, (pr_ids, repo_names)) + rows = await cur.fetchall() + + processed = {(str(row[0]), str(row[1])) for row in rows} + return [p for p in prs if (p.pr_id, p.repo_name) not in processed] diff --git a/src/release_agent/services/pr_poller.py b/src/release_agent/services/pr_poller.py new file mode 100644 index 0000000..4961ae2 --- /dev/null +++ b/src/release_agent/services/pr_poller.py @@ -0,0 +1,108 @@ +"""PR polling service. + +Polls Azure DevOps for active pull requests targeting a configured branch, +deduplicates against already-processed PRs, and schedules graph runs for +newly discovered PRs by synthesizing a fake webhook payload. +""" + +import asyncio +import logging +from collections.abc import Callable, Coroutine +from typing import Any + +from release_agent.models.pr import PRInfo +from release_agent.services.pr_dedup import find_unprocessed_prs + +logger = logging.getLogger(__name__) + + +def _synthesize_webhook_payload(pr: PRInfo) -> dict: + """Build a fake webhook payload dict from a PRInfo. + + Produces a dict that matches the structure expected by the parse_webhook + node so the PR can be processed as if it arrived via the real webhook. + + Args: + pr: The PRInfo object representing the active pull request. + + Returns: + A dict with event_type, subscription_id, and resource fields. + """ + return { + "event_type": "git.pullrequest.updated", + "subscription_id": f"polled-{pr.repo_name}-{pr.pr_id}", + "resource": { + "pull_request_id": int(pr.pr_id), + "title": pr.pr_title, + "status": pr.pr_status, + "source_ref_name": pr.branch, + "target_ref_name": "", + "repository": { + "id": f"{pr.repo_name}-id", + "name": pr.repo_name, + "web_url": str(pr.pr_url).rsplit("/pullrequest/", 1)[0], + }, + }, + } + + +async def run_pr_poll_loop( + *, + azdo_client, + db_pool, + watched_repos: list[str], + target_branch: str, + interval_seconds: int, + schedule_fn: Callable[..., Any], + sleep_fn: Callable[[float], Coroutine] | None = None, +) -> None: + """Poll Azure DevOps for active PRs and schedule unprocessed ones. + + Runs indefinitely until cancelled. For each polling interval: + 1. For each watched repo, list active PRs targeting target_branch. + 2. Find PRs not yet recorded in agent_threads. + 3. For each unprocessed PR, call schedule_fn with the synthesized payload. + 4. Sleep for interval_seconds. + + Args: + azdo_client: AzDoClient instance. + db_pool: Async psycopg connection pool. + watched_repos: List of repository names to poll. + target_branch: Target branch ref to filter PRs (e.g. "refs/heads/develop"). + interval_seconds: Seconds to sleep between polling iterations. + schedule_fn: Callable to schedule a graph run; called with keyword arg: + initial_state (dict). + sleep_fn: Async sleep callable (default: asyncio.sleep). Injectable for testing. + """ + if sleep_fn is None: + sleep_fn = asyncio.sleep # type: ignore[assignment] + + while True: + all_prs: list[PRInfo] = [] + + for repo in watched_repos: + try: + prs = await azdo_client.list_active_prs(repo, target_branch) + all_prs.extend(prs) + except Exception as exc: + logger.warning("PR polling failed for repo %s: %s", repo, exc) + + try: + unprocessed = await find_unprocessed_prs(db_pool, all_prs) + except Exception as exc: + logger.warning("PR dedup query failed: %s", exc) + unprocessed = [] + + for pr in unprocessed: + webhook_payload = _synthesize_webhook_payload(pr) + initial_state = { + "webhook_payload": webhook_payload, + "pr_id": pr.pr_id, + "repo_name": pr.repo_name, + } + try: + schedule_fn(initial_state=initial_state) + except Exception as exc: + logger.warning("Failed to schedule PR %s/%s: %s", pr.repo_name, pr.pr_id, exc) + + await sleep_fn(interval_seconds) # type: ignore[misc] diff --git a/src/release_agent/state.py b/src/release_agent/state.py new file mode 100644 index 0000000..920eefb --- /dev/null +++ b/src/release_agent/state.py @@ -0,0 +1,91 @@ +"""LangGraph state definition for the release agent. + +Uses TypedDict with total=False so all fields are optional, enabling +partial state updates throughout the graph execution. + +Reducers follow immutable patterns - they always return new lists. +""" + +from typing import Annotated + +from typing_extensions import TypedDict + + +def add_messages(existing: list[str], new: list[str]) -> list[str]: + """Accumulate messages without mutating the existing list. + + Returns a new list containing all existing messages followed by new ones. + """ + return [*existing, *new] + + +def add_errors(existing: list[str], new: list[str]) -> list[str]: + """Accumulate error messages without mutating the existing list. + + Returns a new list containing all existing errors followed by new ones. + """ + return [*existing, *new] + + +class ReleaseState(TypedDict, total=False): + """The mutable state passed between nodes in the release agent graph. + + All fields are optional (total=False). Each field can be updated + independently by graph nodes. List fields use reducers that accumulate + values rather than replacing them. + """ + + # Core identifiers + repo_name: str + pr_id: str + ticket_id: str + version: str + + # Accumulated lists + messages: Annotated[list[str], add_messages] + errors: Annotated[list[str], add_errors] + + # Phase 3: Webhook / PR fields + webhook_payload: dict + pr_info: dict + pr_diff: str + last_merge_source_commit: str + + # Phase 3: Jira / ticket fields + ticket_summary: str + has_ticket: bool + + # Phase 3: Review fields + review_result: dict + review_approved: bool + + # Phase 3: Staging fields + staging: dict + pr_already_merged: bool + + # Phase 3: Release PR fields + release_pr_id: str + release_pr_commit: str + + # Phase 3: Pipeline / approval fields + pipelines: list[dict] + triggered_builds: list[dict] + pending_approvals: list[dict] + + # Phase 3: Flow control + continue_to_release: bool + + # Phase 5: CI build tracking + ci_build_id: int + ci_build_status: str + ci_build_result: str + ci_build_url: str + + # Phase 5: CD release tracking + release_definition_id: int + release_id: int + current_stage: str + + # Phase 5: Slack interactive message tracking + approval_message_ts: str + slack_message_ts: str diff --git a/src/release_agent/tools/__init__.py b/src/release_agent/tools/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/release_agent/tools/_http.py b/src/release_agent/tools/_http.py new file mode 100644 index 0000000..03e2d96 --- /dev/null +++ b/src/release_agent/tools/_http.py @@ -0,0 +1,103 @@ +"""Shared HTTP helpers for service clients. + +Provides a unified response-to-exception mapper and a Basic auth header builder. +""" + +import base64 + +import httpx + +from release_agent.exceptions import ( + AuthenticationError, + NotFoundError, + RateLimitError, + ServiceError, + ServiceUnavailableError, +) + +# Status codes that should not raise exceptions +_OK_RANGE_START = 200 +_OK_RANGE_END = 299 +_REDIRECT_RANGE_END = 399 + + +def raise_for_status(response: httpx.Response, *, service: str) -> None: + """Raise an appropriate exception based on the HTTP status code. + + 2xx and 3xx responses do not raise. 4xx/5xx responses raise typed exceptions + that carry the service name for debugging. + + Args: + response: The httpx response to check. + service: Name of the service that returned this response. + + Raises: + AuthenticationError: For 401 or 403 status codes. + NotFoundError: For 404 status codes. + RateLimitError: For 429 status codes. + ServiceUnavailableError: For 503 status codes. + ServiceError: For all other 4xx/5xx status codes. + """ + code = response.status_code + + if code <= _REDIRECT_RANGE_END: + return + + if code in (401, 403): + raise AuthenticationError( + service=service, status_code=code, detail=_extract_detail(response) + ) + + if code == 404: + raise NotFoundError(service=service, detail=_extract_detail(response)) + + if code == 429: + retry_after = _parse_retry_after(response) + raise RateLimitError(service=service, retry_after=retry_after) + + if code == 503: + raise ServiceUnavailableError(service=service, detail=_extract_detail(response)) + + raise ServiceError(service=service, status_code=code, detail=_extract_detail(response)) + + +def build_auth_header(username: str, password: str) -> dict[str, str]: + """Build a Basic authentication header. + + Args: + username: The username (may be empty string for PAT-only auth). + password: The password or token. + + Returns: + A dict with a single "Authorization" key containing the Basic auth value. + """ + credentials = f"{username}:{password}" + encoded = base64.b64encode(credentials.encode()).decode() + return {"Authorization": f"Basic {encoded}"} + + +def _extract_detail(response: httpx.Response) -> str | None: + """Try to extract a human-readable detail string from the response body.""" + try: + body = response.json() + if isinstance(body, dict): + return ( + body.get("message") + or body.get("detail") + or ("; ".join(body["errorMessages"]) if body.get("errorMessages") else None) + or str(body) + ) + return str(body) + except (ValueError, KeyError): + return response.text or None + + +def _parse_retry_after(response: httpx.Response) -> int | None: + """Parse the Retry-After header value as an integer number of seconds.""" + value = response.headers.get("Retry-After") + if value is None: + return None + try: + return int(value) + except ValueError: + return None diff --git a/src/release_agent/tools/_retry.py b/src/release_agent/tools/_retry.py new file mode 100644 index 0000000..46e7e03 --- /dev/null +++ b/src/release_agent/tools/_retry.py @@ -0,0 +1,74 @@ +"""Async retry decorator with exponential backoff. + +Retries only on RateLimitError and ServiceUnavailableError. +Respects the Retry-After header when present in RateLimitError. +""" + +import asyncio +import functools +from collections.abc import Awaitable, Callable +from typing import Any, TypeVar + +from release_agent.exceptions import RateLimitError, ServiceUnavailableError + +_T = TypeVar("_T") + +_DEFAULT_MAX_ATTEMPTS = 3 +_DEFAULT_BASE_DELAY = 1.0 +_BACKOFF_MULTIPLIER = 2.0 + +# Types that trigger a retry +_RETRYABLE = (RateLimitError, ServiceUnavailableError) + + +def with_retry( + max_attempts: int = _DEFAULT_MAX_ATTEMPTS, + base_delay: float = _DEFAULT_BASE_DELAY, + sleep_fn: Callable[[float], Awaitable[None]] | None = None, +) -> Callable[[Callable[..., Awaitable[_T]]], Callable[..., Awaitable[_T]]]: + """Decorator factory that retries an async function on retryable errors. + + Args: + max_attempts: Maximum total attempts (including the first call). + base_delay: Initial delay in seconds between retries. Doubles each retry. + sleep_fn: Async sleep function (defaults to asyncio.sleep). Injected for testing. + + Returns: + A decorator that wraps an async function with retry logic. + """ + if max_attempts < 1: + raise ValueError(f"max_attempts must be >= 1, got {max_attempts}") + + _sleep = sleep_fn if sleep_fn is not None else asyncio.sleep + + def decorator(fn: Callable[..., Awaitable[_T]]) -> Callable[..., Awaitable[_T]]: + @functools.wraps(fn) + async def wrapper(*args: Any, **kwargs: Any) -> _T: + delay = base_delay + last_exc: Exception | None = None + for attempt in range(1, max_attempts + 1): + try: + return await fn(*args, **kwargs) + except _RETRYABLE as exc: + last_exc = exc + if attempt >= max_attempts: + raise + wait = _resolve_wait(exc, delay) + await _sleep(wait) + delay *= _BACKOFF_MULTIPLIER + raise RuntimeError("Retry loop exited unexpectedly") from last_exc + + return wrapper + + return decorator + + +def _resolve_wait(exc: Exception, fallback_delay: float) -> float: + """Return the number of seconds to wait before the next attempt. + + For RateLimitError with a retry_after value, that value takes precedence. + Otherwise the exponential backoff delay is used. + """ + if isinstance(exc, RateLimitError) and exc.retry_after is not None: + return float(exc.retry_after) + return fallback_delay diff --git a/src/release_agent/tools/azdo.py b/src/release_agent/tools/azdo.py new file mode 100644 index 0000000..d9c8ee3 --- /dev/null +++ b/src/release_agent/tools/azdo.py @@ -0,0 +1,513 @@ +"""Azure DevOps service client. + +Uses two httpx.AsyncClient instances: + - http_client: main AzDo REST API (dev.azure.com) + - vsrm_http_client: VSRM release management API (vsrm.dev.azure.com) + +Both clients are injected via constructor for testability. +""" + +from types import TracebackType +from typing import Self + +import httpx + +from release_agent.models.build import ApprovalRecord, BuildStatus +from release_agent.models.pipeline import PipelineInfo +from release_agent.models.pr import PRInfo +from release_agent.tools._http import build_auth_header, raise_for_status + +_API_VERSION = "7.1" + + +class AzDoClient: + """Client for the Azure DevOps REST API. + + Args: + base_url: Main AzDo project API base URL. + vsrm_base_url: VSRM release management API base URL. + pat: Personal Access Token for authentication. + http_client: Injected httpx.AsyncClient for the main API. + vsrm_http_client: Injected httpx.AsyncClient for the VSRM API. + """ + + def __init__( + self, + *, + base_url: str, + vsrm_base_url: str, + pat: str, + http_client: httpx.AsyncClient, + vsrm_http_client: httpx.AsyncClient, + ) -> None: + self._base_url = base_url.rstrip("/") + self._vsrm_base_url = vsrm_base_url.rstrip("/") + self._auth = build_auth_header("", pat) + self._http = http_client + self._vsrm = vsrm_http_client + + async def close(self) -> None: + """Close both underlying HTTP clients.""" + await self._http.aclose() + await self._vsrm.aclose() + + async def __aenter__(self) -> Self: + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + await self.close() + + # ------------------------------------------------------------------ + # Pull requests + # ------------------------------------------------------------------ + + async def get_pr(self, pr_id: int) -> PRInfo: + """Fetch a pull request by ID. + + Args: + pr_id: Numeric pull request ID. + + Returns: + A PRInfo model populated from the API response. + + Raises: + NotFoundError: If the PR does not exist. + AuthenticationError: If authentication fails. + ServiceError: For other HTTP errors. + """ + url = f"{self._base_url}/git/pullRequests/{pr_id}" + response = await self._http.get(url, headers=self._auth, params={"api-version": _API_VERSION}) + raise_for_status(response, service="azdo") + data = response.json() + return _parse_pr(data) + + async def list_active_prs(self, repo: str, target_branch: str) -> list[PRInfo]: + """List active pull requests for a repository filtered by target branch. + + Args: + repo: Repository name. + target_branch: Target branch ref name (e.g. "refs/heads/develop"). + + Returns: + List of PRInfo models for active PRs targeting the given branch. + + Raises: + NotFoundError: If the repository does not exist. + AuthenticationError: If authentication fails. + ServiceError: For other HTTP errors. + """ + url = f"{self._base_url}/git/repositories/{repo}/pullRequests" + response = await self._http.get( + url, + headers=self._auth, + params={ + "api-version": _API_VERSION, + "status": "active", + "targetRefName": target_branch, + }, + ) + raise_for_status(response, service="azdo") + data = response.json() + return [_parse_pr(item) for item in data.get("value", [])] + + async def get_pr_diff(self, pr_id: int) -> str: + """Return a diff-like string for the given pull request. + + Fetches the PR to determine the repository, then retrieves the + diffs endpoint and formats a summary string suitable for code review. + + Args: + pr_id: Numeric pull request ID. + + Returns: + A text string describing changed files. + + Raises: + NotFoundError: If the PR does not exist. + ServiceError: For other HTTP errors. + """ + pr = await self.get_pr(pr_id) + repo = pr.repo_name + url = f"{self._base_url}/git/repositories/{repo}/pullRequests/{pr_id}/diffs" + response = await self._http.get( + url, + headers=self._auth, + params={"api-version": _API_VERSION, "baseVersionDescriptor.versionType": "commit"}, + ) + raise_for_status(response, service="azdo") + data = response.json() + return _format_diff(data) + + async def merge_pr(self, *, pr_id: int, last_merge_source_commit: str) -> bool: + """Complete (merge) a pull request. + + Args: + pr_id: Numeric pull request ID. + last_merge_source_commit: The commit ID to use as the merge source. + + Returns: + True on success. + + Raises: + NotFoundError: If the PR does not exist. + ServiceError: For other HTTP errors including merge conflicts. + """ + url = f"{self._base_url}/git/pullRequests/{pr_id}" + payload = { + "status": "completed", + "lastMergeSourceCommit": {"commitId": last_merge_source_commit}, + "completionOptions": {"mergeStrategy": "squash"}, + } + response = await self._http.patch( + url, + headers={**self._auth, "Content-Type": "application/json"}, + json=payload, + params={"api-version": _API_VERSION}, + ) + raise_for_status(response, service="azdo") + return True + + async def create_pr( + self, + *, + repo: str, + source: str, + target: str, + title: str, + description: str, + ) -> dict: + """Create a new pull request. + + Args: + repo: Repository name. + source: Source branch ref name (e.g. "refs/heads/release/v1.2.0"). + target: Target branch ref name (e.g. "refs/heads/main"). + title: PR title. + description: PR description. + + Returns: + Raw dict from the API response. + + Raises: + ServiceError: For HTTP errors. + """ + url = f"{self._base_url}/git/repositories/{repo}/pullRequests" + payload = { + "sourceRefName": source, + "targetRefName": target, + "title": title, + "description": description, + } + response = await self._http.post( + url, + headers={**self._auth, "Content-Type": "application/json"}, + json=payload, + params={"api-version": _API_VERSION}, + ) + raise_for_status(response, service="azdo") + return response.json() + + # ------------------------------------------------------------------ + # Pipelines + # ------------------------------------------------------------------ + + async def list_build_pipelines(self, *, repo: str) -> list[PipelineInfo]: + """List build pipelines for a repository. + + Args: + repo: Repository name (used to filter pipelines). + + Returns: + List of PipelineInfo models. + + Raises: + ServiceError: For HTTP errors. + """ + url = f"{self._base_url}/pipelines" + response = await self._http.get( + url, + headers=self._auth, + params={"api-version": _API_VERSION}, + ) + raise_for_status(response, service="azdo") + data = response.json() + return [ + PipelineInfo(id=item["id"], name=item["name"], repo=repo) + for item in data.get("value", []) + ] + + async def trigger_pipeline(self, *, pipeline_id: int, branch: str) -> dict: + """Trigger a pipeline run. + + Args: + pipeline_id: Numeric pipeline definition ID. + branch: Branch ref to build (e.g. "refs/heads/main"). + + Returns: + Raw dict from the API response containing the build/run ID. + + Raises: + NotFoundError: If the pipeline does not exist. + ServiceError: For other HTTP errors. + """ + url = f"{self._base_url}/pipelines/{pipeline_id}/runs" + payload = {"resources": {"repositories": {"self": {"refName": branch}}}} + response = await self._http.post( + url, + headers={**self._auth, "Content-Type": "application/json"}, + json=payload, + params={"api-version": _API_VERSION}, + ) + raise_for_status(response, service="azdo") + return response.json() + + async def get_build_status(self, *, build_id: int) -> BuildStatus: + """Get the status of a build. + + Args: + build_id: Numeric build ID. + + Returns: + BuildStatus dataclass with status, result, and build_url fields. + + Raises: + NotFoundError: If the build does not exist. + ServiceError: For other HTTP errors. + """ + url = f"{self._base_url}/build/builds/{build_id}" + response = await self._http.get( + url, + headers=self._auth, + params={"api-version": _API_VERSION}, + ) + raise_for_status(response, service="azdo") + data = response.json() + links = data.get("_links") or {} + web_link = links.get("web") or {} + build_url = web_link.get("href") or data.get("url") + return BuildStatus( + status=data["status"], + result=data.get("result"), + build_url=build_url, + ) + + async def get_release_approvals(self, *, release_id: int) -> list[ApprovalRecord]: + """Get pending approvals for a release. + + Args: + release_id: Numeric release ID. + + Returns: + List of ApprovalRecord dataclasses for this release. + + Raises: + NotFoundError: If the release does not exist. + ServiceError: For other HTTP errors. + """ + url = f"{self._vsrm_base_url}/release/approvals" + response = await self._vsrm.get( + url, + headers=self._auth, + params={"api-version": _API_VERSION, "releaseId": release_id}, + ) + raise_for_status(response, service="azdo") + data = response.json() + records: list[ApprovalRecord] = [] + for item in data.get("value", []): + env = item.get("releaseEnvironment") or {} + rel = env.get("release") or {} + records.append( + ApprovalRecord( + approval_id=str(item["id"]), + stage_name=env.get("name", ""), + status=item.get("status", "pending"), + release_id=rel.get("id", release_id), + ) + ) + return records + + async def get_latest_release(self, *, definition_id: int) -> dict: + """Get the latest release for a release definition. + + Args: + definition_id: Numeric release definition ID. + + Returns: + Raw dict from the API response for the latest release, + or an empty dict if no releases exist. + + Raises: + NotFoundError: If the definition does not exist. + ServiceError: For other HTTP errors. + """ + url = f"{self._vsrm_base_url}/release/releases" + response = await self._vsrm.get( + url, + headers=self._auth, + params={ + "api-version": _API_VERSION, + "definitionId": definition_id, + "$top": 1, + "$orderby": "id desc", + }, + ) + raise_for_status(response, service="azdo") + data = response.json() + values = data.get("value", []) + return values[0] if values else {} + + # ------------------------------------------------------------------ + # Release approvals (VSRM) + # ------------------------------------------------------------------ + + async def approve_release(self, *, approval_id: str, comment: str) -> dict: + """Approve a release pipeline stage approval. + + Uses the VSRM API endpoint. + + Args: + approval_id: The approval record ID. + comment: Comment to attach to the approval. + + Returns: + Raw dict from the API response. + + Raises: + NotFoundError: If the approval does not exist. + ServiceError: For other HTTP errors. + """ + url = f"{self._vsrm_base_url}/release/approvals/{approval_id}" + payload = {"status": "approved", "comments": comment} + response = await self._vsrm.patch( + url, + headers={**self._auth, "Content-Type": "application/json"}, + json=payload, + params={"api-version": _API_VERSION}, + ) + raise_for_status(response, service="azdo") + return response.json() + + + # ------------------------------------------------------------------- + # PR Comment methods + # ------------------------------------------------------------------- + + async def add_pr_comment( + self, + *, + repo: str, + pr_id: int, + content: str, + ) -> dict: + """Add a general comment thread to a PR (not file-specific). + + Args: + repo: Repository name. + pr_id: Pull request ID. + content: Markdown content for the comment. + + Returns: + Raw dict from the API response. + """ + url = f"{self._base_url}/git/repositories/{repo}/pullRequests/{pr_id}/threads" + payload = { + "comments": [{"parentCommentId": 0, "content": content, "commentType": 1}], + "status": "active", + } + response = await self._http.post( + url, + headers={**self._auth, "Content-Type": "application/json"}, + json=payload, + params={"api-version": _API_VERSION}, + ) + raise_for_status(response, service="azdo") + return response.json() + + async def add_pr_inline_comment( + self, + *, + repo: str, + pr_id: int, + content: str, + file_path: str, + line_start: int, + line_end: int | None = None, + ) -> dict: + """Add an inline comment thread to a specific file and line in a PR. + + Args: + repo: Repository name. + pr_id: Pull request ID. + content: Markdown content for the comment. + file_path: Path to the file (e.g., "/src/Handlers/InvoiceHandler.cs"). + line_start: Starting line number. + line_end: Ending line number (defaults to line_start). + + Returns: + Raw dict from the API response. + """ + effective_end = line_end if line_end is not None else line_start + url = f"{self._base_url}/git/repositories/{repo}/pullRequests/{pr_id}/threads" + payload = { + "comments": [{"parentCommentId": 0, "content": content, "commentType": 1}], + "threadContext": { + "filePath": file_path if file_path.startswith("/") else f"/{file_path}", + "rightFileStart": {"line": line_start, "offset": 1}, + "rightFileEnd": {"line": effective_end, "offset": 1}, + }, + "status": "active", + } + response = await self._http.post( + url, + headers={**self._auth, "Content-Type": "application/json"}, + json=payload, + params={"api-version": _API_VERSION}, + ) + raise_for_status(response, service="azdo") + return response.json() + + +# --------------------------------------------------------------------------- +# Private parsing helpers +# --------------------------------------------------------------------------- + +def _parse_pr(data: dict) -> PRInfo: + """Map a raw AzDo PR API response to a PRInfo model.""" + pr_id = str(data["pullRequestId"]) + repo = data["repository"]["name"] + branch = data.get("sourceRefName", "") + url = data.get("url", "") + if not url: + repo_url = data["repository"].get("remoteUrl", "") + url = f"{repo_url}/pullrequest/{pr_id}" + return PRInfo( + pr_id=pr_id, + pr_url=url, + repo_name=repo, + branch=branch, + pr_title=data.get("title", ""), + pr_status=_map_pr_status(data.get("status", "active")), + ) + + +def _map_pr_status(raw: str) -> str: + """Normalise AzDo PR status to a value accepted by PRInfo.""" + mapping = {"active": "active", "completed": "completed", "abandoned": "abandoned"} + return mapping.get(raw, "active") + + +def _format_diff(data: dict) -> str: + """Format the diffs API response as a text string.""" + changes = data.get("changes", []) + lines = [] + for change in changes: + item = change.get("item", {}) + path = item.get("path", "unknown") + change_type = change.get("changeType", "edit") + lines.append(f"{change_type}: {path}") + return "\n".join(lines) diff --git a/src/release_agent/tools/claude_review.py b/src/release_agent/tools/claude_review.py new file mode 100644 index 0000000..89d91a0 --- /dev/null +++ b/src/release_agent/tools/claude_review.py @@ -0,0 +1,335 @@ +"""Claude PR reviewer using Claude Code CLI for code review. + +Uses `claude -p` (print mode) with `--output-format json` to leverage the +Claude Code subscription instead of API key billing. Claude Code can +autonomously read files, grep code, and explore the codebase for deeper reviews. +""" + +import asyncio +import json +import logging +from pathlib import Path + +from release_agent.models.review import ReviewIssue, ReviewResult + +logger = logging.getLogger(__name__) + +_MAX_DIFF_CHARS = 100_000 +_TRUNCATION_NOTE = "\n\n[DIFF TRUNCATED: content exceeded 100,000 characters]" +_CLI_TIMEOUT_SECONDS = 300 + +# JSON schema for structured output from Claude Code CLI +_REVIEW_JSON_SCHEMA = json.dumps({ + "type": "object", + "properties": { + "verdict": { + "type": "string", + "enum": ["approve", "request_changes"], + }, + "summary": { + "type": "string", + }, + "issues": { + "type": "array", + "items": { + "type": "object", + "properties": { + "severity": { + "type": "string", + "enum": ["blocker", "error", "warning", "info"], + }, + "description": {"type": "string"}, + "file_path": {"type": "string"}, + "line_start": {"type": "integer", "description": "Starting line number in the file"}, + "line_end": {"type": "integer", "description": "Ending line number in the file"}, + "suggestion": {"type": "string"}, + }, + "required": ["severity", "description"], + }, + }, + }, + "required": ["verdict", "summary", "issues"], +}) + +_SYSTEM_PROMPT = ( + "You are a senior code reviewer for .NET C# projects. " + "Review the PR for: backward compatibility (DB migration, serialization, integration events), " + "null handling, business logic correctness, test coverage, hardcoded values, and security. " + "You have access to the full codebase via Read, Glob, and Grep tools. " + "Use them to understand the context around the changed files. " + "For each issue, include the file_path and line_start/line_end where the issue occurs." +) + +# JSON schema for structured ticket content output +_TICKET_JSON_SCHEMA = json.dumps({ + "type": "object", + "properties": { + "summary": { + "type": "string", + "description": "Short one-line summary of the work done (max 100 chars)", + }, + "description": { + "type": "string", + "description": "Detailed description of the changes and business value", + }, + }, + "required": ["summary", "description"], +}) + +_TICKET_SYSTEM_PROMPT = ( + "You are a Jira ticket writer for a software development team. " + "Given a PR diff and title, produce a concise Jira Story ticket. " + "The summary should be a clear, action-oriented title (max 100 chars). " + "The description should explain what was changed and why, suitable for a product backlog." +) + + +class ClaudeReviewer: + """Reviews pull request diffs using Claude Code CLI. + + Uses the user's Claude Code subscription (not API key) by invoking + `claude -p` as a subprocess. This allows Claude to autonomously read + files and explore the codebase for more thorough reviews. + + Args: + claude_cmd: Path to the claude CLI binary (default: "claude"). + timeout: Maximum seconds to wait for the review (default: 300). + run_subprocess: Async callable for running subprocesses (injected for testing). + """ + + def __init__( + self, + *, + claude_cmd: str = "claude", + timeout: int = _CLI_TIMEOUT_SECONDS, + run_subprocess: object | None = None, + ) -> None: + self._claude_cmd = claude_cmd + self._timeout = timeout + self._run_subprocess = run_subprocess or _default_run_subprocess + + async def review_pr( + self, + *, + diff: str, + pr_title: str, + repo_name: str, + cwd: str | Path | None = None, + ) -> ReviewResult: + """Review a pull request and return a structured result. + + Args: + diff: The raw diff text of the pull request. + pr_title: Title of the pull request. + repo_name: Name of the repository. + cwd: Working directory for Claude Code (e.g., git worktree path). + If provided, Claude can read files in this directory. + + Returns: + A ReviewResult with verdict, summary, and list of issues. + """ + truncated_diff = _truncate_diff(diff) + prompt = _build_prompt( + diff=truncated_diff, + pr_title=pr_title, + repo_name=repo_name, + ) + + cmd = [ + self._claude_cmd, "-p", prompt, + "--output-format", "json", + "--json-schema", _REVIEW_JSON_SCHEMA, + "--allowedTools", "Read,Glob,Grep", + "--system-prompt", _SYSTEM_PROMPT, + ] + + stdout, stderr, returncode = await self._run_subprocess( + cmd=cmd, + cwd=str(cwd) if cwd else None, + timeout=self._timeout, + ) + + if returncode != 0: + logger.error( + "Claude CLI failed (exit %d): %s", returncode, stderr[:500] + ) + raise RuntimeError(f"Claude CLI exited with code {returncode}: {stderr[:200]}") + + return _parse_cli_output(stdout) + + async def generate_ticket_content( + self, + *, + diff: str, + pr_title: str, + repo_name: str, + cwd: str | Path | None = None, + ) -> tuple[str, str]: + """Generate Jira ticket summary and description from a PR diff. + + Args: + diff: The raw diff text of the pull request. + pr_title: Title of the pull request. + repo_name: Name of the repository. + cwd: Working directory for Claude Code. + + Returns: + A tuple of (summary, description) strings suitable for Jira. + + Raises: + RuntimeError: If the Claude CLI fails or times out. + ValueError: If the output cannot be parsed. + """ + truncated_diff = _truncate_diff(diff) + prompt = ( + f"Generate a Jira Story ticket for this pull request from repository '{repo_name}'.\n\n" + f"PR Title: {pr_title}\n\n" + f"Diff:\n```\n{truncated_diff}\n```\n\n" + "Produce a concise summary and description for a Jira Story ticket." + ) + + cmd = [ + self._claude_cmd, "-p", prompt, + "--output-format", "json", + "--json-schema", _TICKET_JSON_SCHEMA, + "--allowedTools", "Read,Glob,Grep", + "--system-prompt", _TICKET_SYSTEM_PROMPT, + ] + + stdout, stderr, returncode = await self._run_subprocess( + cmd=cmd, + cwd=str(cwd) if cwd else None, + timeout=self._timeout, + ) + + if returncode != 0: + logger.error( + "Claude CLI failed (exit %d): %s", returncode, stderr[:500] + ) + raise RuntimeError(f"Claude CLI exited with code {returncode}: {stderr[:200]}") + + return _parse_ticket_output(stdout) + + +# --------------------------------------------------------------------------- +# Private helpers +# --------------------------------------------------------------------------- + +def _truncate_diff(diff: str) -> str: + """Truncate diff to at most _MAX_DIFF_CHARS characters.""" + if len(diff) <= _MAX_DIFF_CHARS: + return diff + return diff[:_MAX_DIFF_CHARS] + _TRUNCATION_NOTE + + +def _build_prompt(*, diff: str, pr_title: str, repo_name: str) -> str: + """Build the user prompt for the review request.""" + return ( + f"Review this pull request from repository '{repo_name}'.\n\n" + f"PR Title: {pr_title}\n\n" + f"Diff:\n```\n{diff}\n```\n\n" + "Read the related source files to understand context. " + "Provide your review as structured JSON output." + ) + + +def _parse_cli_output(stdout: str) -> ReviewResult: + """Parse the JSON output from Claude Code CLI into a ReviewResult.""" + try: + data = json.loads(stdout) + except json.JSONDecodeError as exc: + raise ValueError(f"Failed to parse Claude CLI output as JSON: {exc}") from exc + + # Claude Code --output-format json wraps result in {"result": "...", ...} + # The structured_output field contains our schema-conforming data + structured = data.get("structured_output") or data.get("result") + + if structured is None: + raise ValueError("No structured_output or result in Claude CLI response") + + # If structured is a string (result field), try to parse it as JSON + if isinstance(structured, str): + try: + structured = json.loads(structured) + except json.JSONDecodeError: + raise ValueError( + "Claude CLI result is not valid JSON for review schema" + ) + + if not isinstance(structured, dict): + raise ValueError(f"Expected dict, got {type(structured).__name__}") + + issues = [ + ReviewIssue( + severity=issue["severity"], + description=issue["description"], + file_path=issue.get("file_path"), + line_start=issue.get("line_start"), + line_end=issue.get("line_end"), + suggestion=issue.get("suggestion"), + ) + for issue in structured.get("issues", []) + ] + + return ReviewResult( + verdict=structured["verdict"], + summary=structured["summary"], + issues=tuple(issues), + ) + + +def _parse_ticket_output(stdout: str) -> tuple[str, str]: + """Parse the JSON output from Claude Code CLI into (summary, description).""" + try: + data = json.loads(stdout) + except json.JSONDecodeError as exc: + raise ValueError(f"Failed to parse Claude CLI ticket output as JSON: {exc}") from exc + + structured = data.get("structured_output") or data.get("result") + + if structured is None: + raise ValueError("No structured_output or result in Claude CLI response") + + if isinstance(structured, str): + try: + structured = json.loads(structured) + except json.JSONDecodeError: + raise ValueError("Claude CLI result is not valid JSON for ticket schema") + + if not isinstance(structured, dict): + raise ValueError(f"Expected dict, got {type(structured).__name__}") + + summary = structured.get("summary", "") + description = structured.get("description", "") + if not summary: + raise ValueError("Claude ticket output missing required 'summary' field") + return (summary, description) + + +async def _default_run_subprocess( + *, + cmd: list[str], + cwd: str | None, + timeout: int, +) -> tuple[str, str, int]: + """Run a subprocess and return (stdout, stderr, returncode).""" + process = await asyncio.create_subprocess_exec( + *cmd, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + cwd=cwd, + ) + try: + stdout_bytes, stderr_bytes = await asyncio.wait_for( + process.communicate(), timeout=timeout + ) + except asyncio.TimeoutError: + process.kill() + await process.wait() + raise RuntimeError(f"Claude CLI timed out after {timeout} seconds") + + return ( + stdout_bytes.decode("utf-8", errors="replace"), + stderr_bytes.decode("utf-8", errors="replace"), + process.returncode or 0, + ) diff --git a/src/release_agent/tools/jira.py b/src/release_agent/tools/jira.py new file mode 100644 index 0000000..6fc6612 --- /dev/null +++ b/src/release_agent/tools/jira.py @@ -0,0 +1,269 @@ +"""Jira service client. + +Uses Basic auth (email:api_token) and the Jira REST API v3. +The httpx.AsyncClient is injected via constructor for testability. +""" + +from types import TracebackType +from typing import Self + +import httpx + +from release_agent.models.jira import JiraIssue, JiraTransition +from release_agent.tools._http import build_auth_header, raise_for_status + +_API_VERSION = "rest/api/3" +_DEV_IN_PROGRESS_TRANSITION = "Dev in Progress" + + +class JiraClient: + """Client for the Jira REST API. + + Args: + base_url: Jira instance base URL (e.g. "https://billolife.atlassian.net"). + email: Account email address for Basic auth. + api_token: Jira API token for Basic auth. + http_client: Injected httpx.AsyncClient. + """ + + def __init__( + self, + *, + base_url: str, + email: str, + api_token: str, + http_client: httpx.AsyncClient, + ) -> None: + self._base_url = base_url.rstrip("/") + self._auth = build_auth_header(email, api_token) + self._http = http_client + + async def close(self) -> None: + """Close the underlying HTTP client.""" + await self._http.aclose() + + async def __aenter__(self) -> Self: + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + await self.close() + + # ------------------------------------------------------------------ + # Issue operations + # ------------------------------------------------------------------ + + async def get_issue(self, ticket_id: str) -> JiraIssue: + """Fetch a Jira issue by key. + + Args: + ticket_id: Jira issue key (e.g. "ALLPOST-100"). + + Returns: + A JiraIssue model populated from the API response. + + Raises: + NotFoundError: If the issue does not exist. + AuthenticationError: If authentication fails. + ServiceError: For other HTTP errors. + """ + url = f"{self._base_url}/{_API_VERSION}/issue/{ticket_id}" + response = await self._http.get(url, headers=self._auth) + raise_for_status(response, service="jira") + data = response.json() + return _parse_issue(data) + + async def get_transitions(self, ticket_id: str) -> list[JiraTransition]: + """Fetch available workflow transitions for a Jira issue. + + Args: + ticket_id: Jira issue key. + + Returns: + List of available JiraTransition models. + + Raises: + NotFoundError: If the issue does not exist. + ServiceError: For other HTTP errors. + """ + url = f"{self._base_url}/{_API_VERSION}/issue/{ticket_id}/transitions" + response = await self._http.get(url, headers=self._auth) + raise_for_status(response, service="jira") + data = response.json() + return [ + JiraTransition(id=t["id"], name=t["name"]) + for t in data.get("transitions", []) + ] + + async def transition_issue(self, ticket_id: str, transition_name: str) -> bool: + """Move a Jira issue through a named workflow transition. + + If the target transition is not currently available, a two-step + fallback is attempted: first transition to "Dev in Progress", then + retry the target transition. Returns False if the transition is still + unavailable after the fallback attempt. + + Args: + ticket_id: Jira issue key. + transition_name: Name of the target transition (e.g. "Released"). + + Returns: + True if the transition succeeded, False if unavailable. + + Raises: + ServiceError: For HTTP errors during transition execution. + """ + transitions = await self.get_transitions(ticket_id) + transition_id = _find_transition_id(transitions, transition_name) + + if transition_id is None: + # Two-step fallback: move to "Dev in Progress" first, then retry. + dev_id = _find_transition_id(transitions, _DEV_IN_PROGRESS_TRANSITION) + if dev_id is None: + return False + await self._do_transition(ticket_id, dev_id) + + # Retry target transition after the fallback step. + transitions = await self.get_transitions(ticket_id) + transition_id = _find_transition_id(transitions, transition_name) + if transition_id is None: + return False + + await self._do_transition(ticket_id, transition_id) + return True + + async def create_issue( + self, + *, + project: str, + summary: str, + description: str, + issue_type: str = "Story", + ) -> str: + """Create a new Jira issue and return its ticket key. + + Args: + project: Jira project key (e.g. "ALLPOST"). + summary: Issue summary / title. + description: Plain-text description, converted to ADF format. + issue_type: Issue type name (default: "Story"). + + Returns: + The ticket key of the created issue (e.g. "ALLPOST-42"). + + Raises: + AuthenticationError: If authentication fails. + ServiceError: For other HTTP errors. + """ + url = f"{self._base_url}/{_API_VERSION}/issue" + payload = { + "fields": { + "project": {"key": project}, + "summary": summary, + "description": _text_to_adf(description), + "issuetype": {"name": issue_type}, + } + } + response = await self._http.post( + url, + headers={**self._auth, "Content-Type": "application/json"}, + json=payload, + ) + raise_for_status(response, service="jira") + data = response.json() + return data["key"] + + async def add_remote_link(self, *, ticket_id: str, url: str, title: str) -> bool: + """Add a remote link (e.g. a PR URL) to a Jira issue. + + Args: + ticket_id: Jira issue key. + url: URL of the remote resource. + title: Display title of the link. + + Returns: + True on success. + + Raises: + NotFoundError: If the issue does not exist. + ServiceError: For other HTTP errors. + """ + endpoint = f"{self._base_url}/{_API_VERSION}/issue/{ticket_id}/remotelink" + payload = { + "object": { + "url": url, + "title": title, + } + } + response = await self._http.post( + endpoint, + headers={**self._auth, "Content-Type": "application/json"}, + json=payload, + ) + raise_for_status(response, service="jira") + return True + + # ------------------------------------------------------------------ + # Private helpers + # ------------------------------------------------------------------ + + async def _do_transition(self, ticket_id: str, transition_id: str) -> None: + """Execute a transition by ID.""" + url = f"{self._base_url}/{_API_VERSION}/issue/{ticket_id}/transitions" + payload = {"transition": {"id": transition_id}} + response = await self._http.post( + url, + headers={**self._auth, "Content-Type": "application/json"}, + json=payload, + ) + raise_for_status(response, service="jira") + + +# --------------------------------------------------------------------------- +# Private parsing helpers +# --------------------------------------------------------------------------- + +def _text_to_adf(text: str) -> dict: + """Convert plain text to Atlassian Document Format (ADF). + + Each non-empty line becomes a paragraph node. Empty string produces + an empty document (no content nodes). + + Args: + text: Plain text to convert. + + Returns: + ADF dict conforming to the Jira REST API v3 description format. + """ + content = [] + for line in text.splitlines(): + stripped = line.strip() + if stripped: + content.append({ + "type": "paragraph", + "content": [{"type": "text", "text": stripped}], + }) + return {"version": 1, "type": "doc", "content": content} + + +def _parse_issue(data: dict) -> JiraIssue: + """Map a raw Jira issue API response to a JiraIssue model.""" + fields = data.get("fields", {}) + status = fields.get("status", {}).get("name", "Unknown") + return JiraIssue( + key=data["key"], + summary=fields.get("summary", ""), + status=status, + ) + + +def _find_transition_id(transitions: list[JiraTransition], name: str) -> str | None: + """Return the ID of the first transition matching the given name, or None.""" + for transition in transitions: + if transition.name == name: + return transition.id + return None diff --git a/src/release_agent/tools/slack.py b/src/release_agent/tools/slack.py new file mode 100644 index 0000000..0754f95 --- /dev/null +++ b/src/release_agent/tools/slack.py @@ -0,0 +1,500 @@ +"""Slack service client — dual-mode: webhook fallback + Web API. + +Supports two modes: +1. Webhook mode: uses webhook_url for fire-and-forget notifications. +2. Web API mode: uses bot_token + channel_id for interactive messages, + message updates, and retrieving message timestamps. + +Block Kit payloads are built by pure functions that can be tested independently. +The httpx.AsyncClient is injected via constructor for testability. +""" + +from datetime import date +from types import TracebackType +from typing import Self + +import httpx + +from release_agent.models.ticket import TicketEntry +from release_agent.tools._http import raise_for_status + +_SLACK_API_BASE = "https://slack.com/api" + + +class SlackClient: + """Client for sending messages to Slack. + + Supports two modes determined by which parameters are provided: + - Webhook mode: provide webhook_url. + - Web API mode: provide bot_token and channel_id. + + Both modes can be configured simultaneously; Web API takes priority + for methods that require it (send_interactive_approval, update_message). + Webhook is used as fallback for simple notifications. + + Args: + webhook_url: The Slack incoming webhook URL (optional). + bot_token: Slack bot OAuth token (xoxb-...) for Web API (optional). + channel_id: Slack channel ID for Web API messages (optional). + http_client: Injected httpx.AsyncClient. + """ + + def __init__( + self, + *, + webhook_url: str = "", + bot_token: str = "", + channel_id: str = "", + http_client: httpx.AsyncClient, + ) -> None: + self._webhook_url = webhook_url + self._bot_token = bot_token + self._channel_id = channel_id + self._http = http_client + + async def close(self) -> None: + """Close the underlying HTTP client.""" + await self._http.aclose() + + async def __aenter__(self) -> Self: + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + await self.close() + + # ------------------------------------------------------------------ + # Public methods: fire-and-forget notifications + # ------------------------------------------------------------------ + + async def send_release_notification( + self, + *, + repo: str, + version: str, + release_date: date, + tickets: list[TicketEntry], + ) -> bool: + """Send a release notification to Slack. + + Args: + repo: Repository name. + version: Release version string (e.g. "v1.2.0"). + release_date: Date of the release. + tickets: List of tickets included in the release. + + Returns: + True on success. + + Raises: + ServiceError: If the Slack webhook returns an error response. + """ + blocks = _build_release_blocks( + repo=repo, + version=version, + release_date=release_date, + tickets=tickets, + ) + return await self._post_blocks(blocks) + + async def send_approval_request( + self, + *, + action: str, + details: str, + approval_url: str, + ) -> bool: + """Send an approval request notification to Slack. + + Args: + action: Description of the action requiring approval. + details: Additional context for the approval. + approval_url: URL where the approver can approve the action. + + Returns: + True on success. + + Raises: + ServiceError: If the Slack webhook returns an error response. + """ + blocks = _build_approval_blocks( + action=action, + details=details, + approval_url=approval_url, + ) + return await self._post_blocks(blocks) + + async def send_notification(self, *, text: str, blocks: list[dict]) -> bool: + """Send a plain notification to Slack. + + Uses Web API if bot_token is configured, otherwise falls back to + the webhook URL. + + Args: + text: Fallback text for notifications. + blocks: Block Kit blocks to include in the message. + + Returns: + True on success, False on error. + """ + if self._bot_token and self._channel_id: + ts = await self._post_via_web_api(text=text, blocks=blocks) + return ts != "" + if self._webhook_url: + try: + return await self._post_blocks(blocks or [{"type": "section", "text": {"type": "mrkdwn", "text": text}}]) + except Exception: + return False + return False + + # ------------------------------------------------------------------ + # Public methods: interactive messages (Web API only) + # ------------------------------------------------------------------ + + async def send_interactive_approval( + self, + *, + thread_id: str, + action: str, + details: str, + buttons: list[dict], + ) -> str: + """Send an interactive approval message with clickable buttons. + + Requires bot_token and channel_id. Returns the message timestamp + (ts) which can be used to update the message later. + + Args: + thread_id: The graph thread_id, embedded in button payloads + so the interactions endpoint can resume the correct thread. + action: Human-readable description of the action (e.g. "Deploy to Sandbox"). + details: Additional context shown in the message body. + buttons: List of button dicts with "text" and "value" keys. + + Returns: + Message timestamp string (e.g. "1234567890.123456") on success, + or empty string on failure. + """ + blocks = _build_interactive_approval_blocks( + thread_id=thread_id, + action=action, + details=details, + buttons=buttons, + ) + return await self._post_via_web_api(text=action, blocks=blocks) + + async def update_message( + self, + *, + message_ts: str, + text: str, + blocks: list[dict], + ) -> bool: + """Update an existing Slack message (Web API chat.update). + + Args: + message_ts: Timestamp of the message to update. + text: New fallback text. + blocks: New Block Kit blocks. + + Returns: + True on success, False on failure. + """ + url = f"{_SLACK_API_BASE}/chat.update" + payload: dict = { + "channel": self._channel_id, + "ts": message_ts, + "text": text, + } + if blocks: + payload["blocks"] = blocks + + response = await self._http.post( + url, + json=payload, + headers=self._web_api_headers(), + ) + if response.status_code != 200: + return False + data = response.json() + return bool(data.get("ok")) + + # ------------------------------------------------------------------ + # Private helpers + # ------------------------------------------------------------------ + + async def _post_blocks(self, blocks: list[dict]) -> bool: + """POST a Block Kit payload to the webhook URL.""" + response = await self._http.post( + self._webhook_url, + json={"blocks": blocks}, + headers={"Content-Type": "application/json"}, + ) + raise_for_status(response, service="slack") + return True + + async def _post_via_web_api(self, *, text: str, blocks: list[dict]) -> str: + """POST a message via Slack Web API chat.postMessage. + + Returns the message ts on success, or empty string on failure. + """ + url = f"{_SLACK_API_BASE}/chat.postMessage" + payload: dict = { + "channel": self._channel_id, + "text": text, + } + if blocks: + payload["blocks"] = blocks + + response = await self._http.post( + url, + json=payload, + headers=self._web_api_headers(), + ) + if response.status_code != 200: + return "" + data = response.json() + if not data.get("ok"): + return "" + return data.get("ts", "") + + def _web_api_headers(self) -> dict: + """Return headers for Slack Web API requests.""" + return { + "Content-Type": "application/json; charset=utf-8", + "Authorization": f"Bearer {self._bot_token}", + } + + +# --------------------------------------------------------------------------- +# Pure Block Kit builder functions +# --------------------------------------------------------------------------- + +def _build_release_blocks( + *, + repo: str, + version: str, + release_date: date, + tickets: list[TicketEntry], +) -> list[dict]: + """Build Slack Block Kit blocks for a release notification. + + This is a pure function with no side effects, allowing it to be tested + independently of HTTP calls. + + Args: + repo: Repository name. + version: Release version string. + release_date: Date of the release. + tickets: Tickets included in the release. + + Returns: + List of Slack Block Kit block dicts. + """ + header_block = { + "type": "header", + "text": { + "type": "plain_text", + "text": f"Release {version} - {repo}", + }, + } + date_block = { + "type": "section", + "text": { + "type": "mrkdwn", + "text": f"*Release date:* {release_date.isoformat()}", + }, + } + blocks: list[dict] = [header_block, date_block] + + if tickets: + ticket_lines = [f"- *{t.id}*: {t.summary}" for t in tickets] + ticket_text = "\n".join(ticket_lines) + ticket_block = { + "type": "section", + "text": { + "type": "mrkdwn", + "text": f"*Tickets:*\n{ticket_text}", + }, + } + blocks.append(ticket_block) + + return blocks + + +def _build_approval_blocks( + *, + action: str, + details: str, + approval_url: str, +) -> list[dict]: + """Build Slack Block Kit blocks for an approval request. + + This is a pure function with no side effects, allowing it to be tested + independently of HTTP calls. + + Args: + action: Description of the action requiring approval. + details: Additional context for the approval. + approval_url: URL to the approval page. + + Returns: + List of Slack Block Kit block dicts. + """ + header_block = { + "type": "header", + "text": { + "type": "plain_text", + "text": f"Approval Required: {action}", + }, + } + details_block = { + "type": "section", + "text": { + "type": "mrkdwn", + "text": details, + }, + } + actions_block = { + "type": "section", + "text": { + "type": "mrkdwn", + "text": f"<{approval_url}|Review and Approve>", + }, + } + return [header_block, details_block, actions_block] + + +def _build_interactive_approval_blocks( + *, + thread_id: str, + action: str, + details: str, + buttons: list[dict], +) -> list[dict]: + """Build Slack Block Kit blocks with interactive approval buttons. + + The thread_id is embedded in each button's action_id and value so the + /slack/interactions endpoint can resume the correct graph thread when + a button is clicked. + + Args: + thread_id: Graph thread ID to embed in button payloads. + action: Human-readable action description for the header. + details: Contextual details shown in the message body. + buttons: List of dicts with "text" (str) and "value" (str) keys. + + Returns: + List of Slack Block Kit block dicts. + """ + header_block = { + "type": "header", + "text": { + "type": "plain_text", + "text": f"Approval Required: {action}", + }, + } + details_block = { + "type": "section", + "text": { + "type": "mrkdwn", + "text": details, + }, + } + blocks: list[dict] = [header_block, details_block] + + if buttons: + button_elements = [ + { + "type": "button", + "text": {"type": "plain_text", "text": btn["text"]}, + "value": f"{thread_id}:{btn['value']}", + "action_id": f"approval_{btn['value']}_{thread_id}", + } + for btn in buttons + ] + actions_block = { + "type": "actions", + "elements": button_elements, + } + blocks.append(actions_block) + + return blocks + + +def _build_ci_status_blocks( + *, + repo: str, + branch: str, + status: str, + build_url: str | None, +) -> list[dict]: + """Build Slack Block Kit blocks for a CI build status notification. + + Args: + repo: Repository name. + branch: Branch that was built. + status: Build result status (e.g. "succeeded", "failed"). + build_url: Direct URL to the build results page, or None. + + Returns: + List of Slack Block Kit block dicts. + """ + header_block = { + "type": "header", + "text": { + "type": "plain_text", + "text": f"CI Build: {repo}", + }, + } + detail_text = f"*Branch:* {branch}\n*Status:* {status}" + if build_url: + detail_text += f"\n<{build_url}|View Build>" + + details_block = { + "type": "section", + "text": { + "type": "mrkdwn", + "text": detail_text, + }, + } + return [header_block, details_block] + + +def _build_resolved_approval_blocks( + *, + action: str, + outcome: str, + user: str, +) -> list[dict]: + """Build Slack Block Kit blocks for a resolved approval message. + + Replaces the interactive approval message after the decision is made. + + Args: + action: The original action that was approved or rejected. + outcome: The approval outcome (e.g. "approved", "rejected"). + user: The Slack user who made the decision. + + Returns: + List of Slack Block Kit block dicts. + """ + header_block = { + "type": "header", + "text": { + "type": "plain_text", + "text": f"Approval {outcome.capitalize()}: {action}", + }, + } + details_block = { + "type": "section", + "text": { + "type": "mrkdwn", + "text": f"Decision: *{outcome}* by {user}", + }, + } + return [header_block, details_block] diff --git a/src/release_agent/versioning.py b/src/release_agent/versioning.py new file mode 100644 index 0000000..02f78fb --- /dev/null +++ b/src/release_agent/versioning.py @@ -0,0 +1,70 @@ +"""Version calculation utilities for release management. + +Pure functions only - no side effects, no mutation. +All version strings must include a 'v' prefix (e.g. "v1.2.3"). +""" + +import re + +_VERSION_PATTERN = re.compile(r"^v(\d+)\.(\d+)\.(\d+)$") + + +def parse_version(version_str: str) -> tuple[int, int, int]: + """Parse a version string into a (major, minor, patch) tuple. + + Accepts strings with or without a leading 'v' prefix. + + Raises: + ValueError: If the string does not match the expected format. + """ + normalized = version_str if version_str.startswith("v") else f"v{version_str}" + match = _VERSION_PATTERN.match(normalized) + if not match: + raise ValueError(f"Invalid version string: {version_str!r}") + return int(match.group(1)), int(match.group(2)), int(match.group(3)) + + +def format_version(major: int, minor: int, patch: int) -> str: + """Format a (major, minor, patch) tuple into a version string with 'v' prefix. + + Example: format_version(1, 0, 3) -> "v1.0.3" + """ + return f"v{major}.{minor}.{patch}" + + +def calculate_next_version(repo_name: str, existing_versions: list[str]) -> str: + """Calculate the next patch version given a list of existing version strings. + + Filters out any malformed version strings (those that do not match + vMAJOR.MINOR.PATCH). Finds the highest valid version and increments + the patch component by one. If no valid versions exist, returns "v1.0.0". + + The repo_name parameter is accepted for extensibility but does not + affect the calculation. + + Args: + repo_name: The repository name (unused in calculation). + existing_versions: A list of version strings, may include malformed ones. + + Returns: + The next version string in "vX.Y.Z" format. + """ + valid_versions: list[tuple[int, int, int]] = [] + + for version_str in existing_versions: + # Only consider strings that already have the 'v' prefix to match spec: + # "Versions without 'v' prefix are treated as malformed." + if not version_str.startswith("v"): + continue + try: + parsed = parse_version(version_str) + valid_versions.append(parsed) + except ValueError: + continue + + if not valid_versions: + return "v1.0.0" + + highest = max(valid_versions) + major, minor, patch = highest + return format_version(major, minor, patch + 1) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/api/__init__.py b/tests/api/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/api/test_approvals.py b/tests/api/test_approvals.py new file mode 100644 index 0000000..4af5c61 --- /dev/null +++ b/tests/api/test_approvals.py @@ -0,0 +1,259 @@ +"""Tests for approvals endpoint. Written FIRST (TDD RED phase).""" + +from datetime import datetime, timezone +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from release_agent.api.approvals import router as approvals_router + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_test_app( + *, + interrupted_threads: list[dict] | None = None, + graph_resume_result: dict | None = None, +) -> FastAPI: + """Return a FastAPI app with mocked state for approvals tests.""" + app = FastAPI() + app.include_router(approvals_router) + + if interrupted_threads is None: + interrupted_threads = [] + + mock_settings = MagicMock() + mock_settings.operator_token.get_secret_value.return_value = "" + mock_graphs = { + "pr_completed": MagicMock(), + "release": MagicMock(), + } + mock_clients = MagicMock() + + # Mock pool that returns interrupted threads from DB + mock_pool = MagicMock() + mock_conn = AsyncMock() + mock_cursor = AsyncMock() + + rows = [ + ( + t["thread_id"], + t.get("graph_name", "pr_completed"), + t.get("interrupt_value", "Confirm?"), + t.get("created_at", datetime.now(tz=timezone.utc)), + t.get("repo_name"), + t.get("pr_id"), + t.get("version"), + ) + for t in interrupted_threads + ] + mock_cursor.fetchall = AsyncMock(return_value=rows) + mock_cursor.fetchone = AsyncMock(return_value=("pr_completed",)) + mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor) + mock_cursor.__aexit__ = AsyncMock(return_value=False) + mock_conn.cursor = MagicMock(return_value=mock_cursor) + mock_conn.__aenter__ = AsyncMock(return_value=mock_conn) + mock_conn.__aexit__ = AsyncMock(return_value=False) + mock_pool.connection = MagicMock(return_value=mock_conn) + + app.state.settings = mock_settings + app.state.graphs = mock_graphs + app.state.tool_clients = mock_clients + app.state.db_pool = mock_pool + app.state.background_tasks = set() + + return app + + +# --------------------------------------------------------------------------- +# POST /approvals/{thread_id} +# --------------------------------------------------------------------------- + +class TestPostApproval: + def test_valid_merge_decision_returns_200(self) -> None: + app = _make_test_app() + mock_graph = MagicMock() + mock_graph.ainvoke = AsyncMock(return_value={"messages": ["done"]}) + app.state.graphs["pr_completed"] = mock_graph + + with patch("release_agent.api.approvals._resume_graph", new_callable=AsyncMock) as mock_resume: + mock_resume.return_value = {"messages": ["resumed"]} + with TestClient(app) as client: + response = client.post( + "/approvals/thread-123", + json={"decision": "merge"}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["thread_id"] == "thread-123" + assert "status" in data + assert "message" in data + + def test_valid_cancel_decision_returns_200(self) -> None: + app = _make_test_app() + with patch("release_agent.api.approvals._resume_graph", new_callable=AsyncMock) as mock_resume: + mock_resume.return_value = {"messages": ["cancelled"]} + with TestClient(app) as client: + response = client.post( + "/approvals/thread-456", + json={"decision": "cancel"}, + ) + assert response.status_code == 200 + + def test_invalid_decision_returns_422(self) -> None: + app = _make_test_app() + with TestClient(app) as client: + response = client.post( + "/approvals/thread-123", + json={"decision": "invalid_decision"}, + ) + assert response.status_code == 422 + + def test_missing_decision_returns_422(self) -> None: + app = _make_test_app() + with TestClient(app) as client: + response = client.post( + "/approvals/thread-123", + json={}, + ) + assert response.status_code == 422 + + def test_response_contains_thread_id(self) -> None: + app = _make_test_app() + with patch("release_agent.api.approvals._resume_graph", new_callable=AsyncMock) as mock_resume: + mock_resume.return_value = {} + with TestClient(app) as client: + response = client.post( + "/approvals/my-thread-id", + json={"decision": "approve"}, + ) + assert response.json()["thread_id"] == "my-thread-id" + + def test_approve_decision_returns_200(self) -> None: + app = _make_test_app() + with patch("release_agent.api.approvals._resume_graph", new_callable=AsyncMock) as mock_resume: + mock_resume.return_value = {} + with TestClient(app) as client: + response = client.post( + "/approvals/t1", + json={"decision": "approve"}, + ) + assert response.status_code == 200 + + def test_skip_decision_returns_200(self) -> None: + app = _make_test_app() + with patch("release_agent.api.approvals._resume_graph", new_callable=AsyncMock) as mock_resume: + mock_resume.return_value = {} + with TestClient(app) as client: + response = client.post( + "/approvals/t1", + json={"decision": "skip"}, + ) + assert response.status_code == 200 + + def test_trigger_decision_returns_200(self) -> None: + app = _make_test_app() + with patch("release_agent.api.approvals._resume_graph", new_callable=AsyncMock) as mock_resume: + mock_resume.return_value = {} + with TestClient(app) as client: + response = client.post( + "/approvals/t1", + json={"decision": "trigger"}, + ) + assert response.status_code == 200 + + +# --------------------------------------------------------------------------- +# GET /approvals/pending +# --------------------------------------------------------------------------- + +class TestGetPendingApprovals: + def test_empty_pending_returns_200(self) -> None: + app = _make_test_app(interrupted_threads=[]) + with TestClient(app) as client: + response = client.get("/approvals/pending") + assert response.status_code == 200 + data = response.json() + assert data["count"] == 0 + assert data["items"] == [] + + def test_pending_approvals_list_structure(self) -> None: + now = datetime.now(tz=timezone.utc) + threads = [ + { + "thread_id": "t1", + "graph_name": "pr_completed", + "interrupt_value": "Confirm merge?", + "created_at": now, + "repo_name": "my-repo", + "pr_id": "42", + "version": "v1.0.0", + } + ] + app = _make_test_app(interrupted_threads=threads) + with TestClient(app) as client: + response = client.get("/approvals/pending") + assert response.status_code == 200 + data = response.json() + assert data["count"] == 1 + assert data["items"][0]["thread_id"] == "t1" + assert data["items"][0]["graph_name"] == "pr_completed" + + def test_multiple_pending_approvals(self) -> None: + now = datetime.now(tz=timezone.utc) + threads = [ + { + "thread_id": f"t{i}", + "graph_name": "pr_completed", + "interrupt_value": "Confirm?", + "created_at": now, + "repo_name": None, + "pr_id": None, + "version": None, + } + for i in range(3) + ] + app = _make_test_app(interrupted_threads=threads) + with TestClient(app) as client: + response = client.get("/approvals/pending") + assert response.status_code == 200 + data = response.json() + assert data["count"] == 3 + assert len(data["items"]) == 3 + + def test_pending_approval_optional_fields_nullable(self) -> None: + now = datetime.now(tz=timezone.utc) + threads = [ + { + "thread_id": "t1", + "graph_name": "release", + "interrupt_value": "Run release?", + "created_at": now, + "repo_name": None, + "pr_id": None, + "version": None, + } + ] + app = _make_test_app(interrupted_threads=threads) + with TestClient(app) as client: + response = client.get("/approvals/pending") + item = response.json()["items"][0] + assert item["repo_name"] is None + assert item["pr_id"] is None + assert item["version"] is None + + +# --------------------------------------------------------------------------- +# _resume_graph helper function tests +# --------------------------------------------------------------------------- + +class TestResumeGraph: + def test_resume_graph_callable(self) -> None: + from release_agent.api.approvals import _resume_graph + import inspect + assert inspect.iscoroutinefunction(_resume_graph) diff --git a/tests/api/test_approvals_with_auth.py b/tests/api/test_approvals_with_auth.py new file mode 100644 index 0000000..888f473 --- /dev/null +++ b/tests/api/test_approvals_with_auth.py @@ -0,0 +1,139 @@ +"""Tests for approvals endpoints with operator token authentication. + +Phase 5 - Step 3: Verifies that POST /approvals/{thread_id} and +GET /approvals/pending require operator token when configured. +Written FIRST (TDD RED phase). +""" + +from datetime import datetime, timezone +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from release_agent.api.approvals import router as approvals_router + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_test_app(operator_token: str = "") -> FastAPI: + """Return a FastAPI app with approvals router and configurable operator token.""" + app = FastAPI() + app.include_router(approvals_router) + + mock_settings = MagicMock() + mock_settings.operator_token.get_secret_value.return_value = operator_token + + mock_pool = MagicMock() + mock_conn = AsyncMock() + mock_cursor = AsyncMock() + mock_cursor.fetchall = AsyncMock(return_value=[]) + mock_cursor.fetchone = AsyncMock(return_value=("pr_completed",)) + mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor) + mock_cursor.__aexit__ = AsyncMock(return_value=False) + mock_conn.cursor = MagicMock(return_value=mock_cursor) + mock_conn.__aenter__ = AsyncMock(return_value=mock_conn) + mock_conn.__aexit__ = AsyncMock(return_value=False) + mock_pool.connection = MagicMock(return_value=mock_conn) + + app.state.settings = mock_settings + app.state.graphs = { + "pr_completed": MagicMock(), + "release": MagicMock(), + } + app.state.tool_clients = MagicMock() + app.state.db_pool = mock_pool + app.state.background_tasks = set() + + return app + + +# --------------------------------------------------------------------------- +# POST /approvals/{thread_id} with auth +# --------------------------------------------------------------------------- + +class TestPostApprovalWithAuth: + def test_valid_token_allows_post_approval(self) -> None: + app = _make_test_app(operator_token="secret-token") + with patch( + "release_agent.api.approvals._resume_graph", new_callable=AsyncMock + ) as mock_resume: + mock_resume.return_value = {} + with TestClient(app) as client: + response = client.post( + "/approvals/thread-123", + json={"decision": "merge"}, + headers={"X-Operator-Token": "secret-token"}, + ) + assert response.status_code == 200 + + def test_missing_token_rejects_post_approval(self) -> None: + app = _make_test_app(operator_token="secret-token") + with TestClient(app) as client: + response = client.post( + "/approvals/thread-123", + json={"decision": "merge"}, + ) + assert response.status_code == 401 + + def test_wrong_token_rejects_post_approval(self) -> None: + app = _make_test_app(operator_token="secret-token") + with TestClient(app) as client: + response = client.post( + "/approvals/thread-123", + json={"decision": "merge"}, + headers={"X-Operator-Token": "wrong-token"}, + ) + assert response.status_code == 401 + + def test_no_auth_required_when_token_not_configured(self) -> None: + app = _make_test_app(operator_token="") + with patch( + "release_agent.api.approvals._resume_graph", new_callable=AsyncMock + ) as mock_resume: + mock_resume.return_value = {} + with TestClient(app) as client: + response = client.post( + "/approvals/thread-123", + json={"decision": "merge"}, + ) + assert response.status_code == 200 + + +# --------------------------------------------------------------------------- +# GET /approvals/pending with auth +# --------------------------------------------------------------------------- + +class TestGetPendingApprovalsWithAuth: + def test_valid_token_allows_get_pending(self) -> None: + app = _make_test_app(operator_token="secret-token") + with TestClient(app) as client: + response = client.get( + "/approvals/pending", + headers={"X-Operator-Token": "secret-token"}, + ) + assert response.status_code == 200 + + def test_missing_token_rejects_get_pending(self) -> None: + app = _make_test_app(operator_token="secret-token") + with TestClient(app) as client: + response = client.get("/approvals/pending") + assert response.status_code == 401 + + def test_wrong_token_rejects_get_pending(self) -> None: + app = _make_test_app(operator_token="secret-token") + with TestClient(app) as client: + response = client.get( + "/approvals/pending", + headers={"X-Operator-Token": "wrong"}, + ) + assert response.status_code == 401 + + def test_no_auth_required_when_token_not_configured(self) -> None: + app = _make_test_app(operator_token="") + with TestClient(app) as client: + response = client.get("/approvals/pending") + assert response.status_code == 200 diff --git a/tests/api/test_dependencies.py b/tests/api/test_dependencies.py new file mode 100644 index 0000000..71fbbbb --- /dev/null +++ b/tests/api/test_dependencies.py @@ -0,0 +1,149 @@ +"""Tests for API FastAPI dependencies. Written FIRST (TDD RED phase).""" + +from unittest.mock import MagicMock + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from release_agent.api.dependencies import ( + get_db_pool, + get_graphs, + get_settings, + get_staging_store, + get_tool_clients, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_app_with_state(**state_kwargs) -> FastAPI: + """Return a minimal FastAPI app with app.state attributes set.""" + app = FastAPI() + for key, value in state_kwargs.items(): + setattr(app.state, key, value) + return app + + +# --------------------------------------------------------------------------- +# get_settings +# --------------------------------------------------------------------------- + +class TestGetSettings: + def test_returns_settings_from_state(self) -> None: + mock_settings = MagicMock() + app = _make_app_with_state(settings=mock_settings) + + with TestClient(app) as client: + # We test the dependency directly by simulating a request + request = MagicMock() + request.app = app + result = get_settings(request) + assert result is mock_settings + + def test_raises_when_settings_missing(self) -> None: + app = FastAPI() # no state.settings + + request = MagicMock() + request.app = app + + with pytest.raises(AttributeError): + get_settings(request) + + +# --------------------------------------------------------------------------- +# get_graphs +# --------------------------------------------------------------------------- + +class TestGetGraphs: + def test_returns_graphs_from_state(self) -> None: + mock_graphs = {"pr_completed": MagicMock(), "release": MagicMock()} + app = _make_app_with_state(graphs=mock_graphs) + + request = MagicMock() + request.app = app + result = get_graphs(request) + assert result is mock_graphs + + def test_raises_when_graphs_missing(self) -> None: + app = FastAPI() + + request = MagicMock() + request.app = app + + with pytest.raises(AttributeError): + get_graphs(request) + + +# --------------------------------------------------------------------------- +# get_tool_clients +# --------------------------------------------------------------------------- + +class TestGetToolClients: + def test_returns_tool_clients_from_state(self) -> None: + mock_clients = MagicMock() + app = _make_app_with_state(tool_clients=mock_clients) + + request = MagicMock() + request.app = app + result = get_tool_clients(request) + assert result is mock_clients + + def test_raises_when_tool_clients_missing(self) -> None: + app = FastAPI() + + request = MagicMock() + request.app = app + + with pytest.raises(AttributeError): + get_tool_clients(request) + + +# --------------------------------------------------------------------------- +# get_staging_store +# --------------------------------------------------------------------------- + +class TestGetStagingStore: + def test_returns_staging_store_from_state(self) -> None: + mock_store = MagicMock() + app = _make_app_with_state(staging_store=mock_store) + + request = MagicMock() + request.app = app + result = get_staging_store(request) + assert result is mock_store + + def test_raises_when_staging_store_missing(self) -> None: + app = FastAPI() + + request = MagicMock() + request.app = app + + with pytest.raises(AttributeError): + get_staging_store(request) + + +# --------------------------------------------------------------------------- +# get_db_pool +# --------------------------------------------------------------------------- + +class TestGetDbPool: + def test_returns_db_pool_from_state(self) -> None: + mock_pool = MagicMock() + app = _make_app_with_state(db_pool=mock_pool) + + request = MagicMock() + request.app = app + result = get_db_pool(request) + assert result is mock_pool + + def test_raises_when_db_pool_missing(self) -> None: + app = FastAPI() + + request = MagicMock() + request.app = app + + with pytest.raises(AttributeError): + get_db_pool(request) diff --git a/tests/api/test_internals.py b/tests/api/test_internals.py new file mode 100644 index 0000000..fd92268 --- /dev/null +++ b/tests/api/test_internals.py @@ -0,0 +1,446 @@ +"""Tests for internal async helper functions. + +Tests _run_graph, _upsert_thread, _resume_graph, and exception handlers. +Written FIRST then verified (TDD GREEN phase for internal helpers). +""" + +import json +from unittest.mock import AsyncMock, MagicMock, call, patch + +import pytest + + +# --------------------------------------------------------------------------- +# _upsert_thread tests +# --------------------------------------------------------------------------- + +class TestUpsertThread: + @pytest.mark.asyncio + async def test_upsert_thread_executes_sql(self) -> None: + from release_agent.api.webhooks import _upsert_thread + + mock_pool = MagicMock() + mock_conn = AsyncMock() + mock_cursor = AsyncMock() + + mock_cursor.execute = AsyncMock() + mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor) + mock_cursor.__aexit__ = AsyncMock(return_value=False) + mock_conn.cursor = MagicMock(return_value=mock_cursor) + mock_conn.__aenter__ = AsyncMock(return_value=mock_conn) + mock_conn.__aexit__ = AsyncMock(return_value=False) + mock_pool.connection = MagicMock(return_value=mock_conn) + + await _upsert_thread( + mock_pool, + thread_id="t1", + thread_status="running", + state={"repo_name": "my-repo"}, + ) + + mock_cursor.execute.assert_called_once() + args = mock_cursor.execute.call_args[0] + assert "agent_threads" in args[0] + assert args[1][0] == "t1" + assert args[1][4] == "running" + # state is JSON-encoded + state_json = json.loads(args[1][5]) + assert state_json["repo_name"] == "my-repo" + + @pytest.mark.asyncio + async def test_upsert_thread_completed_status(self) -> None: + from release_agent.api.webhooks import _upsert_thread + + mock_pool = MagicMock() + mock_conn = AsyncMock() + mock_cursor = AsyncMock() + mock_cursor.execute = AsyncMock() + mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor) + mock_cursor.__aexit__ = AsyncMock(return_value=False) + mock_conn.cursor = MagicMock(return_value=mock_cursor) + mock_conn.__aenter__ = AsyncMock(return_value=mock_conn) + mock_conn.__aexit__ = AsyncMock(return_value=False) + mock_pool.connection = MagicMock(return_value=mock_conn) + + await _upsert_thread( + mock_pool, + thread_id="t2", + thread_status="completed", + state={}, + ) + + args = mock_cursor.execute.call_args[0] + assert args[1][4] == "completed" + + @pytest.mark.asyncio + async def test_upsert_thread_failed_status(self) -> None: + from release_agent.api.webhooks import _upsert_thread + + mock_pool = MagicMock() + mock_conn = AsyncMock() + mock_cursor = AsyncMock() + mock_cursor.execute = AsyncMock() + mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor) + mock_cursor.__aexit__ = AsyncMock(return_value=False) + mock_conn.cursor = MagicMock(return_value=mock_cursor) + mock_conn.__aenter__ = AsyncMock(return_value=mock_conn) + mock_conn.__aexit__ = AsyncMock(return_value=False) + mock_pool.connection = MagicMock(return_value=mock_conn) + + await _upsert_thread( + mock_pool, + thread_id="t3", + thread_status="failed", + state={"errors": ["something went wrong"]}, + ) + + args = mock_cursor.execute.call_args[0] + assert args[1][4] == "failed" + state_json = json.loads(args[1][5]) + assert state_json["errors"] == ["something went wrong"] + + +# --------------------------------------------------------------------------- +# _run_graph tests +# --------------------------------------------------------------------------- + +class TestRunGraph: + @pytest.mark.asyncio + async def test_run_graph_success_upserts_completed(self) -> None: + from release_agent.api.webhooks import _run_graph + + mock_graph = AsyncMock() + mock_graph.ainvoke = AsyncMock(return_value={"messages": ["done"]}) + + mock_pool = MagicMock() + mock_conn = AsyncMock() + mock_cursor = AsyncMock() + mock_cursor.execute = AsyncMock() + mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor) + mock_cursor.__aexit__ = AsyncMock(return_value=False) + mock_conn.cursor = MagicMock(return_value=mock_cursor) + mock_conn.__aenter__ = AsyncMock(return_value=mock_conn) + mock_conn.__aexit__ = AsyncMock(return_value=False) + mock_pool.connection = MagicMock(return_value=mock_conn) + + await _run_graph( + graph=mock_graph, + initial_state={"repo_name": "test"}, + thread_id="t1", + tool_clients=MagicMock(), + db_pool=mock_pool, + ) + + # Should have been called with "running" then "completed" + calls = mock_cursor.execute.call_args_list + assert len(calls) == 2 + # First call: "running", second call: "completed" + assert calls[0][0][1][4] == "running" + assert calls[1][0][1][4] == "completed" + + @pytest.mark.asyncio + async def test_run_graph_failure_upserts_failed(self) -> None: + from release_agent.api.webhooks import _run_graph + + mock_graph = AsyncMock() + mock_graph.ainvoke = AsyncMock(side_effect=RuntimeError("graph crashed")) + + mock_pool = MagicMock() + mock_conn = AsyncMock() + mock_cursor = AsyncMock() + mock_cursor.execute = AsyncMock() + mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor) + mock_cursor.__aexit__ = AsyncMock(return_value=False) + mock_conn.cursor = MagicMock(return_value=mock_cursor) + mock_conn.__aenter__ = AsyncMock(return_value=mock_conn) + mock_conn.__aexit__ = AsyncMock(return_value=False) + mock_pool.connection = MagicMock(return_value=mock_conn) + + await _run_graph( + graph=mock_graph, + initial_state={"repo_name": "test"}, + thread_id="t-fail", + tool_clients=MagicMock(), + db_pool=mock_pool, + ) + + calls = mock_cursor.execute.call_args_list + # First call: "running", second call: "failed" + assert calls[0][0][1][4] == "running" + assert calls[1][0][1][4] == "failed" + # State should contain errors + failed_state = json.loads(calls[1][0][1][5]) + assert "errors" in failed_state + + @pytest.mark.asyncio + async def test_run_graph_invokes_with_correct_config(self) -> None: + from release_agent.api.webhooks import _run_graph + + mock_graph = AsyncMock() + mock_graph.ainvoke = AsyncMock(return_value={}) + mock_clients = MagicMock() + + mock_pool = MagicMock() + mock_conn = AsyncMock() + mock_cursor = AsyncMock() + mock_cursor.execute = AsyncMock() + mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor) + mock_cursor.__aexit__ = AsyncMock(return_value=False) + mock_conn.cursor = MagicMock(return_value=mock_cursor) + mock_conn.__aenter__ = AsyncMock(return_value=mock_conn) + mock_conn.__aexit__ = AsyncMock(return_value=False) + mock_pool.connection = MagicMock(return_value=mock_conn) + + await _run_graph( + graph=mock_graph, + initial_state={"repo_name": "test"}, + thread_id="t-config", + tool_clients=mock_clients, + db_pool=mock_pool, + ) + + call_args = mock_graph.ainvoke.call_args + config = call_args[1]["config"] + assert config["configurable"]["thread_id"] == "t-config" + assert config["configurable"]["clients"] is mock_clients + + +# --------------------------------------------------------------------------- +# _resume_graph tests +# --------------------------------------------------------------------------- + +class TestResumeGraphInternal: + @pytest.mark.asyncio + async def test_resume_graph_success(self) -> None: + from release_agent.api.approvals import _resume_graph + + mock_graph = AsyncMock() + mock_graph.ainvoke = AsyncMock(return_value={"result": "ok"}) + + mock_pool = MagicMock() + mock_conn = AsyncMock() + mock_cursor = AsyncMock() + mock_cursor.execute = AsyncMock() + mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor) + mock_cursor.__aexit__ = AsyncMock(return_value=False) + mock_conn.cursor = MagicMock(return_value=mock_cursor) + mock_conn.__aenter__ = AsyncMock(return_value=mock_conn) + mock_conn.__aexit__ = AsyncMock(return_value=False) + mock_pool.connection = MagicMock(return_value=mock_conn) + + result = await _resume_graph( + graph=mock_graph, + thread_id="t1", + decision="merge", + tool_clients=MagicMock(), + db_pool=mock_pool, + ) + + assert result == {"result": "ok"} + mock_graph.ainvoke.assert_called_once() + # Verify the decision was passed + call_args = mock_graph.ainvoke.call_args + assert call_args[0][0]["decision"] == "merge" + + @pytest.mark.asyncio + async def test_resume_graph_failure_re_raises(self) -> None: + from release_agent.api.approvals import _resume_graph + + mock_graph = AsyncMock() + mock_graph.ainvoke = AsyncMock(side_effect=RuntimeError("resume failed")) + + mock_pool = MagicMock() + mock_conn = AsyncMock() + mock_cursor = AsyncMock() + mock_cursor.execute = AsyncMock() + mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor) + mock_cursor.__aexit__ = AsyncMock(return_value=False) + mock_conn.cursor = MagicMock(return_value=mock_cursor) + mock_conn.__aenter__ = AsyncMock(return_value=mock_conn) + mock_conn.__aexit__ = AsyncMock(return_value=False) + mock_pool.connection = MagicMock(return_value=mock_conn) + + with pytest.raises(RuntimeError, match="resume failed"): + await _resume_graph( + graph=mock_graph, + thread_id="t1", + decision="cancel", + tool_clients=MagicMock(), + db_pool=mock_pool, + ) + + @pytest.mark.asyncio + async def test_resume_graph_upserts_completed_on_success(self) -> None: + from release_agent.api.approvals import _resume_graph + + mock_graph = AsyncMock() + mock_graph.ainvoke = AsyncMock(return_value={"messages": ["done"]}) + + mock_pool = MagicMock() + mock_conn = AsyncMock() + mock_cursor = AsyncMock() + mock_cursor.execute = AsyncMock() + mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor) + mock_cursor.__aexit__ = AsyncMock(return_value=False) + mock_conn.cursor = MagicMock(return_value=mock_cursor) + mock_conn.__aenter__ = AsyncMock(return_value=mock_conn) + mock_conn.__aexit__ = AsyncMock(return_value=False) + mock_pool.connection = MagicMock(return_value=mock_conn) + + await _resume_graph( + graph=mock_graph, + thread_id="t-success", + decision="approve", + tool_clients=MagicMock(), + db_pool=mock_pool, + ) + + # The last execute call should be "completed" + last_call = mock_cursor.execute.call_args_list[-1] + assert last_call[0][1][4] == "completed" + + @pytest.mark.asyncio + async def test_resume_graph_upserts_failed_on_exception(self) -> None: + from release_agent.api.approvals import _resume_graph + + mock_graph = AsyncMock() + mock_graph.ainvoke = AsyncMock(side_effect=ValueError("bad")) + + mock_pool = MagicMock() + mock_conn = AsyncMock() + mock_cursor = AsyncMock() + mock_cursor.execute = AsyncMock() + mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor) + mock_cursor.__aexit__ = AsyncMock(return_value=False) + mock_conn.cursor = MagicMock(return_value=mock_cursor) + mock_conn.__aenter__ = AsyncMock(return_value=mock_conn) + mock_conn.__aexit__ = AsyncMock(return_value=False) + mock_pool.connection = MagicMock(return_value=mock_conn) + + with pytest.raises(ValueError): + await _resume_graph( + graph=mock_graph, + thread_id="t-fail", + decision="skip", + tool_clients=MagicMock(), + db_pool=mock_pool, + ) + + last_call = mock_cursor.execute.call_args_list[-1] + assert last_call[0][1][4] == "failed" + + +# --------------------------------------------------------------------------- +# run_graph_in_background tests (main.py) +# --------------------------------------------------------------------------- + +class TestRunGraphInBackground: + @pytest.mark.asyncio + async def test_success_with_db_pool(self) -> None: + from release_agent.main import run_graph_in_background + + mock_graph = AsyncMock() + mock_graph.ainvoke = AsyncMock(return_value={"done": True}) + + mock_pool = MagicMock() + mock_conn = AsyncMock() + mock_cursor = AsyncMock() + mock_cursor.execute = AsyncMock() + mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor) + mock_cursor.__aexit__ = AsyncMock(return_value=False) + mock_conn.cursor = MagicMock(return_value=mock_cursor) + mock_conn.__aenter__ = AsyncMock(return_value=mock_conn) + mock_conn.__aexit__ = AsyncMock(return_value=False) + mock_pool.connection = MagicMock(return_value=mock_conn) + + await run_graph_in_background( + graph=mock_graph, + initial_state={"repo_name": "test"}, + thread_id="t-bg", + db_pool=mock_pool, + ) + + calls = mock_cursor.execute.call_args_list + assert calls[0][0][1][4] == "running" + assert calls[1][0][1][4] == "completed" + + @pytest.mark.asyncio + async def test_failure_with_db_pool(self) -> None: + from release_agent.main import run_graph_in_background + + mock_graph = AsyncMock() + mock_graph.ainvoke = AsyncMock(side_effect=RuntimeError("bg failed")) + + mock_pool = MagicMock() + mock_conn = AsyncMock() + mock_cursor = AsyncMock() + mock_cursor.execute = AsyncMock() + mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor) + mock_cursor.__aexit__ = AsyncMock(return_value=False) + mock_conn.cursor = MagicMock(return_value=mock_cursor) + mock_conn.__aenter__ = AsyncMock(return_value=mock_conn) + mock_conn.__aexit__ = AsyncMock(return_value=False) + mock_pool.connection = MagicMock(return_value=mock_conn) + + await run_graph_in_background( + graph=mock_graph, + initial_state={}, + thread_id="t-bg-fail", + db_pool=mock_pool, + ) + + last_call = mock_cursor.execute.call_args_list[-1] + assert last_call[0][1][4] == "failed" + + @pytest.mark.asyncio + async def test_success_without_db_pool(self) -> None: + """run_graph_in_background works even without a db_pool.""" + from release_agent.main import run_graph_in_background + + mock_graph = AsyncMock() + mock_graph.ainvoke = AsyncMock(return_value={}) + + # Should not raise even with no db_pool + await run_graph_in_background( + graph=mock_graph, + initial_state={}, + thread_id="t-no-pool", + db_pool=None, + ) + + mock_graph.ainvoke.assert_called_once() + + +# --------------------------------------------------------------------------- +# Exception handler tests (main.py) +# --------------------------------------------------------------------------- + +class TestExceptionHandlerFunctions: + @pytest.mark.asyncio + async def test_release_agent_error_handler_returns_500(self) -> None: + from release_agent.main import _release_agent_error_handler + from release_agent.exceptions import ServiceError + + request = MagicMock() + exc = ServiceError(service="azdo", status_code=503, detail="unavailable") + + response = await _release_agent_error_handler(request, exc) + + assert response.status_code == 500 + body = json.loads(response.body) + assert body["error"] == "ServiceError" + assert "unavailable" in body["detail"] + + @pytest.mark.asyncio + async def test_generic_error_handler_returns_500(self) -> None: + from release_agent.main import _generic_error_handler + + request = MagicMock() + exc = ValueError("something generic") + + response = await _generic_error_handler(request, exc) + + assert response.status_code == 500 + body = json.loads(response.body) + assert body["error"] == "InternalServerError" + assert "An unexpected error occurred" in body["detail"] diff --git a/tests/api/test_models.py b/tests/api/test_models.py new file mode 100644 index 0000000..36bd270 --- /dev/null +++ b/tests/api/test_models.py @@ -0,0 +1,294 @@ +"""Tests for API request/response models. Written FIRST (TDD RED phase).""" + +from datetime import datetime, timezone + +import pytest +from pydantic import ValidationError + +from release_agent.api.models import ( + ApprovalDecision, + ApprovalResponse, + ErrorResponse, + HealthResponse, + ManualReleaseRequest, + ManualTriggerResponse, + PendingApproval, + PendingApprovalsResponse, + ReleaseVersionListResponse, + StagingResponse, + WebhookResponse, +) + + +# --------------------------------------------------------------------------- +# WebhookResponse +# --------------------------------------------------------------------------- + +class TestWebhookResponse: + def test_valid_construction(self) -> None: + resp = WebhookResponse(thread_id="thread-123", message="scheduled") + assert resp.thread_id == "thread-123" + assert resp.message == "scheduled" + + def test_frozen_immutable(self) -> None: + resp = WebhookResponse(thread_id="t1", message="ok") + with pytest.raises((TypeError, ValidationError)): + resp.thread_id = "other" # type: ignore[misc] + + def test_missing_thread_id_raises(self) -> None: + with pytest.raises(ValidationError): + WebhookResponse(message="ok") # type: ignore[call-arg] + + def test_missing_message_raises(self) -> None: + with pytest.raises(ValidationError): + WebhookResponse(thread_id="t1") # type: ignore[call-arg] + + +# --------------------------------------------------------------------------- +# ApprovalDecision +# --------------------------------------------------------------------------- + +class TestApprovalDecision: + def test_merge_decision(self) -> None: + d = ApprovalDecision(decision="merge") + assert d.decision == "merge" + + def test_cancel_decision(self) -> None: + d = ApprovalDecision(decision="cancel") + assert d.decision == "cancel" + + def test_approve_decision(self) -> None: + d = ApprovalDecision(decision="approve") + assert d.decision == "approve" + + def test_skip_decision(self) -> None: + d = ApprovalDecision(decision="skip") + assert d.decision == "skip" + + def test_trigger_decision(self) -> None: + d = ApprovalDecision(decision="trigger") + assert d.decision == "trigger" + + def test_invalid_decision_raises(self) -> None: + with pytest.raises(ValidationError): + ApprovalDecision(decision="invalid") # type: ignore[arg-type] + + def test_frozen_immutable(self) -> None: + d = ApprovalDecision(decision="merge") + with pytest.raises((TypeError, ValidationError)): + d.decision = "cancel" # type: ignore[misc] + + +# --------------------------------------------------------------------------- +# ApprovalResponse +# --------------------------------------------------------------------------- + +class TestApprovalResponse: + def test_valid_construction(self) -> None: + resp = ApprovalResponse( + thread_id="t1", status="resumed", message="Graph resumed" + ) + assert resp.thread_id == "t1" + assert resp.status == "resumed" + assert resp.message == "Graph resumed" + + def test_frozen_immutable(self) -> None: + resp = ApprovalResponse(thread_id="t1", status="ok", message="m") + with pytest.raises((TypeError, ValidationError)): + resp.status = "bad" # type: ignore[misc] + + +# --------------------------------------------------------------------------- +# PendingApproval +# --------------------------------------------------------------------------- + +class TestPendingApproval: + def test_full_construction(self) -> None: + now = datetime.now(tz=timezone.utc) + pa = PendingApproval( + thread_id="t1", + graph_name="pr_completed", + interrupt_value="Confirm merge?", + created_at=now, + repo_name="my-repo", + pr_id="42", + version="v1.2.3", + ) + assert pa.thread_id == "t1" + assert pa.graph_name == "pr_completed" + assert pa.repo_name == "my-repo" + assert pa.pr_id == "42" + assert pa.version == "v1.2.3" + + def test_optional_fields_none(self) -> None: + now = datetime.now(tz=timezone.utc) + pa = PendingApproval( + thread_id="t1", + graph_name="release", + interrupt_value="Confirm?", + created_at=now, + ) + assert pa.repo_name is None + assert pa.pr_id is None + assert pa.version is None + + def test_frozen_immutable(self) -> None: + now = datetime.now(tz=timezone.utc) + pa = PendingApproval( + thread_id="t1", + graph_name="g", + interrupt_value="v", + created_at=now, + ) + with pytest.raises((TypeError, ValidationError)): + pa.thread_id = "other" # type: ignore[misc] + + +# --------------------------------------------------------------------------- +# PendingApprovalsResponse +# --------------------------------------------------------------------------- + +class TestPendingApprovalsResponse: + def test_empty_list(self) -> None: + resp = PendingApprovalsResponse(items=[], count=0) + assert resp.items == [] + assert resp.count == 0 + + def test_with_items(self) -> None: + now = datetime.now(tz=timezone.utc) + item = PendingApproval( + thread_id="t1", + graph_name="g", + interrupt_value="v", + created_at=now, + ) + resp = PendingApprovalsResponse(items=[item], count=1) + assert resp.count == 1 + assert len(resp.items) == 1 + + def test_frozen_immutable(self) -> None: + resp = PendingApprovalsResponse(items=[], count=0) + with pytest.raises((TypeError, ValidationError)): + resp.count = 5 # type: ignore[misc] + + +# --------------------------------------------------------------------------- +# HealthResponse +# --------------------------------------------------------------------------- + +class TestHealthResponse: + def test_ok_status(self) -> None: + resp = HealthResponse(status="ok", version="0.1.0", uptime_seconds=123.4) + assert resp.status == "ok" + assert resp.version == "0.1.0" + assert resp.uptime_seconds == pytest.approx(123.4) + + def test_degraded_status(self) -> None: + resp = HealthResponse(status="degraded", version="0.1.0", uptime_seconds=0.0) + assert resp.status == "degraded" + + def test_invalid_status_raises(self) -> None: + with pytest.raises(ValidationError): + HealthResponse(status="unknown", version="0.1.0", uptime_seconds=0.0) # type: ignore[arg-type] + + def test_frozen_immutable(self) -> None: + resp = HealthResponse(status="ok", version="0.1.0", uptime_seconds=1.0) + with pytest.raises((TypeError, ValidationError)): + resp.status = "degraded" # type: ignore[misc] + + +# --------------------------------------------------------------------------- +# ReleaseVersionListResponse +# --------------------------------------------------------------------------- + +class TestReleaseVersionListResponse: + def test_valid_construction(self) -> None: + resp = ReleaseVersionListResponse(repo="my-repo", versions=["v1.0.0", "v1.1.0"]) + assert resp.repo == "my-repo" + assert resp.versions == ["v1.0.0", "v1.1.0"] + + def test_empty_versions(self) -> None: + resp = ReleaseVersionListResponse(repo="r", versions=[]) + assert resp.versions == [] + + def test_frozen_immutable(self) -> None: + resp = ReleaseVersionListResponse(repo="r", versions=[]) + with pytest.raises((TypeError, ValidationError)): + resp.repo = "other" # type: ignore[misc] + + +# --------------------------------------------------------------------------- +# StagingResponse +# --------------------------------------------------------------------------- + +class TestStagingResponse: + def test_with_staging(self) -> None: + staging_data = {"version": "v1.0.0", "repo": "my-repo", "tickets": []} + resp = StagingResponse(repo="my-repo", staging=staging_data) + assert resp.repo == "my-repo" + assert resp.staging is not None + assert resp.staging["version"] == "v1.0.0" + + def test_without_staging(self) -> None: + resp = StagingResponse(repo="my-repo", staging=None) + assert resp.staging is None + + def test_frozen_immutable(self) -> None: + resp = StagingResponse(repo="r", staging=None) + with pytest.raises((TypeError, ValidationError)): + resp.repo = "other" # type: ignore[misc] + + +# --------------------------------------------------------------------------- +# ManualTriggerResponse +# --------------------------------------------------------------------------- + +class TestManualTriggerResponse: + def test_valid_construction(self) -> None: + resp = ManualTriggerResponse(thread_id="t1", message="triggered") + assert resp.thread_id == "t1" + assert resp.message == "triggered" + + def test_frozen_immutable(self) -> None: + resp = ManualTriggerResponse(thread_id="t1", message="m") + with pytest.raises((TypeError, ValidationError)): + resp.thread_id = "other" # type: ignore[misc] + + +# --------------------------------------------------------------------------- +# ManualReleaseRequest +# --------------------------------------------------------------------------- + +class TestManualReleaseRequest: + def test_valid_construction(self) -> None: + req = ManualReleaseRequest(repo="my-repo") + assert req.repo == "my-repo" + + def test_missing_repo_raises(self) -> None: + with pytest.raises(ValidationError): + ManualReleaseRequest() # type: ignore[call-arg] + + def test_frozen_immutable(self) -> None: + req = ManualReleaseRequest(repo="r") + with pytest.raises((TypeError, ValidationError)): + req.repo = "other" # type: ignore[misc] + + +# --------------------------------------------------------------------------- +# ErrorResponse +# --------------------------------------------------------------------------- + +class TestErrorResponse: + def test_error_only(self) -> None: + resp = ErrorResponse(error="Something went wrong") + assert resp.error == "Something went wrong" + assert resp.detail is None + + def test_error_with_detail(self) -> None: + resp = ErrorResponse(error="Not found", detail="Thread t1 not found") + assert resp.detail == "Thread t1 not found" + + def test_frozen_immutable(self) -> None: + resp = ErrorResponse(error="e") + with pytest.raises((TypeError, ValidationError)): + resp.error = "other" # type: ignore[misc] diff --git a/tests/api/test_operator_auth.py b/tests/api/test_operator_auth.py new file mode 100644 index 0000000..f4514dc --- /dev/null +++ b/tests/api/test_operator_auth.py @@ -0,0 +1,111 @@ +"""Tests for operator token authentication dependency. + +Phase 5 - Step 3: require_operator_token FastAPI dependency. +Written FIRST (TDD RED phase). +""" + +from unittest.mock import MagicMock + +import pytest +from fastapi import FastAPI, Depends, HTTPException +from fastapi.testclient import TestClient + +from release_agent.api.dependencies import require_operator_token + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_app_with_token(operator_token: str = "") -> FastAPI: + """Return a minimal app with a protected route and the given token config.""" + app = FastAPI() + + mock_settings = MagicMock() + mock_settings.operator_token.get_secret_value.return_value = operator_token + app.state.settings = mock_settings + + @app.get("/protected") + async def protected_route(_: None = Depends(require_operator_token)): + return {"ok": True} + + return app + + +# --------------------------------------------------------------------------- +# require_operator_token tests +# --------------------------------------------------------------------------- + +class TestRequireOperatorToken: + def test_valid_token_allows_access(self) -> None: + app = _make_app_with_token("super-secret-token") + with TestClient(app) as client: + response = client.get( + "/protected", + headers={"X-Operator-Token": "super-secret-token"}, + ) + assert response.status_code == 200 + + def test_missing_token_header_returns_401_when_token_configured(self) -> None: + app = _make_app_with_token("super-secret-token") + with TestClient(app) as client: + response = client.get("/protected") + assert response.status_code == 401 + + def test_wrong_token_returns_401(self) -> None: + app = _make_app_with_token("super-secret-token") + with TestClient(app) as client: + response = client.get( + "/protected", + headers={"X-Operator-Token": "wrong-token"}, + ) + assert response.status_code == 401 + + def test_empty_operator_token_config_skips_auth(self) -> None: + """When operator_token is not configured (empty), all requests pass.""" + app = _make_app_with_token("") + with TestClient(app) as client: + response = client.get("/protected") + assert response.status_code == 200 + + def test_empty_operator_token_config_passes_even_without_header(self) -> None: + app = _make_app_with_token("") + with TestClient(app) as client: + response = client.get("/protected", headers={}) + assert response.status_code == 200 + + def test_token_comparison_is_constant_time(self) -> None: + """Verify hmac.compare_digest is used (not == operator) — tested by checking + that the function still works correctly, not timing (which we can't test here).""" + app = _make_app_with_token("my-token") + with TestClient(app) as client: + response = client.get( + "/protected", + headers={"X-Operator-Token": "my-token"}, + ) + assert response.status_code == 200 + + def test_empty_string_token_header_rejected_when_token_configured(self) -> None: + app = _make_app_with_token("configured-token") + with TestClient(app) as client: + response = client.get( + "/protected", + headers={"X-Operator-Token": ""}, + ) + assert response.status_code == 401 + + def test_401_response_has_detail_field(self) -> None: + app = _make_app_with_token("secret") + with TestClient(app) as client: + response = client.get("/protected") + data = response.json() + assert "detail" in data + + def test_valid_token_returns_correct_response_body(self) -> None: + app = _make_app_with_token("token123") + with TestClient(app) as client: + response = client.get( + "/protected", + headers={"X-Operator-Token": "token123"}, + ) + assert response.json() == {"ok": True} diff --git a/tests/api/test_slack_interactions.py b/tests/api/test_slack_interactions.py new file mode 100644 index 0000000..f61b076 --- /dev/null +++ b/tests/api/test_slack_interactions.py @@ -0,0 +1,473 @@ +"""Tests for api/slack_interactions.py endpoint. + +Written FIRST (TDD RED phase). + +Tests cover: +- Signature verification (HMAC-SHA256) +- Payload parsing +- Button routing +- _resume_graph invocation +- Error handling +""" + +import hashlib +import hmac +import json +import time +import urllib.parse +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from release_agent.api.slack_interactions import router as slack_interactions_router +from release_agent.api.slack_interactions import _verify_slack_signature + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +_TEST_SIGNING_SECRET = "test-signing-secret-abc" + + +def _make_slack_signature(*, signing_secret: str, timestamp: str, body: str) -> str: + """Compute a valid Slack signing signature.""" + base_string = f"v0:{timestamp}:{body}" + sig = hmac.new( + signing_secret.encode(), + base_string.encode(), + hashlib.sha256, + ).hexdigest() + return f"v0={sig}" + + +def _make_test_app( + *, + signing_secret: str = _TEST_SIGNING_SECRET, + thread_graph_name: str | None = "release", + graph_result: dict | None = None, +) -> FastAPI: + """Return a FastAPI test app with mocked state for slack interactions.""" + app = FastAPI() + app.include_router(slack_interactions_router) + + mock_settings = MagicMock() + mock_settings.slack_signing_secret.get_secret_value.return_value = signing_secret + mock_settings.operator_token.get_secret_value.return_value = "" + + mock_graph = MagicMock() + mock_graph.ainvoke = AsyncMock(return_value=graph_result or {"messages": ["done"]}) + + mock_graphs = { + "release": mock_graph, + "pr_completed": MagicMock(), + } + mock_clients = MagicMock() + + mock_pool = MagicMock() + mock_conn = AsyncMock() + mock_cursor = AsyncMock() + mock_cursor.fetchone = AsyncMock( + return_value=(thread_graph_name,) if thread_graph_name else None + ) + mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor) + mock_cursor.__aexit__ = AsyncMock(return_value=False) + mock_conn.cursor = MagicMock(return_value=mock_cursor) + mock_conn.__aenter__ = AsyncMock(return_value=mock_conn) + mock_conn.__aexit__ = AsyncMock(return_value=False) + mock_pool.connection = MagicMock(return_value=mock_conn) + + app.state.settings = mock_settings + app.state.graphs = mock_graphs + app.state.tool_clients = mock_clients + app.state.db_pool = mock_pool + app.state.background_tasks = set() + + return app + + +def _make_button_payload( + *, + thread_id: str = "test-thread-123", + value: str = "approve", + user_id: str = "U12345", + user_name: str = "alice", +) -> str: + """Build a URL-encoded Slack button action payload.""" + payload = { + "type": "block_actions", + "user": {"id": user_id, "name": user_name}, + "actions": [ + { + "type": "button", + "value": f"{thread_id}:{value}", + "action_id": f"approval_{value}_{thread_id}", + } + ], + } + return urllib.parse.urlencode({"payload": json.dumps(payload)}) + + +# --------------------------------------------------------------------------- +# _verify_slack_signature pure function tests +# --------------------------------------------------------------------------- + +class TestVerifySlackSignature: + """Tests for the _verify_slack_signature pure function.""" + + def test_returns_true_for_valid_signature(self) -> None: + timestamp = str(int(time.time())) + body = "test=body&data=here" + sig = _make_slack_signature( + signing_secret=_TEST_SIGNING_SECRET, + timestamp=timestamp, + body=body, + ) + assert _verify_slack_signature( + signing_secret=_TEST_SIGNING_SECRET, + timestamp=timestamp, + body=body, + signature=sig, + ) is True + + def test_returns_false_for_wrong_secret(self) -> None: + timestamp = str(int(time.time())) + body = "test=body" + sig = _make_slack_signature( + signing_secret="wrong-secret", + timestamp=timestamp, + body=body, + ) + assert _verify_slack_signature( + signing_secret=_TEST_SIGNING_SECRET, + timestamp=timestamp, + body=body, + signature=sig, + ) is False + + def test_returns_false_for_tampered_body(self) -> None: + timestamp = str(int(time.time())) + original_body = "original=body" + sig = _make_slack_signature( + signing_secret=_TEST_SIGNING_SECRET, + timestamp=timestamp, + body=original_body, + ) + assert _verify_slack_signature( + signing_secret=_TEST_SIGNING_SECRET, + timestamp=timestamp, + body="tampered=body", + signature=sig, + ) is False + + def test_returns_false_for_wrong_timestamp(self) -> None: + body = "test=body" + sig = _make_slack_signature( + signing_secret=_TEST_SIGNING_SECRET, + timestamp="1000000", + body=body, + ) + assert _verify_slack_signature( + signing_secret=_TEST_SIGNING_SECRET, + timestamp="9999999", + body=body, + signature=sig, + ) is False + + def test_returns_false_for_malformed_signature(self) -> None: + timestamp = str(int(time.time())) + assert _verify_slack_signature( + signing_secret=_TEST_SIGNING_SECRET, + timestamp=timestamp, + body="body", + signature="not-a-valid-sig", + ) is False + + def test_returns_false_for_empty_signature(self) -> None: + timestamp = str(int(time.time())) + assert _verify_slack_signature( + signing_secret=_TEST_SIGNING_SECRET, + timestamp=timestamp, + body="body", + signature="", + ) is False + + def test_uses_hmac_sha256(self) -> None: + timestamp = "1234567890" + body = "payload=data" + base = f"v0:{timestamp}:{body}" + expected_hash = hmac.new( + _TEST_SIGNING_SECRET.encode(), + base.encode(), + hashlib.sha256, + ).hexdigest() + sig = f"v0={expected_hash}" + + # Inject current_time matching timestamp to bypass replay prevention + assert _verify_slack_signature( + signing_secret=_TEST_SIGNING_SECRET, + timestamp=timestamp, + body=body, + signature=sig, + current_time=1234567890.0, + ) is True + + def test_rejects_stale_timestamp(self) -> None: + old_timestamp = "1000000000" # year 2001 + body = "payload=data" + base = f"v0:{old_timestamp}:{body}" + expected_hash = hmac.new( + _TEST_SIGNING_SECRET.encode(), + base.encode(), + hashlib.sha256, + ).hexdigest() + sig = f"v0={expected_hash}" + + # Valid signature but timestamp too old + assert _verify_slack_signature( + signing_secret=_TEST_SIGNING_SECRET, + timestamp=old_timestamp, + body=body, + signature=sig, + ) is False + + def test_rejects_non_integer_timestamp(self) -> None: + assert _verify_slack_signature( + signing_secret=_TEST_SIGNING_SECRET, + timestamp="not-a-number", + body="body", + signature="v0=abc", + ) is False + + def test_signature_prefix_must_be_v0(self) -> None: + timestamp = "1234567890" + body = "payload=data" + base = f"v0:{timestamp}:{body}" + hash_val = hmac.new( + _TEST_SIGNING_SECRET.encode(), + base.encode(), + hashlib.sha256, + ).hexdigest() + wrong_prefix_sig = f"v1={hash_val}" + + assert _verify_slack_signature( + signing_secret=_TEST_SIGNING_SECRET, + timestamp=timestamp, + body=body, + signature=wrong_prefix_sig, + ) is False + + +# --------------------------------------------------------------------------- +# POST /slack/interactions endpoint tests +# --------------------------------------------------------------------------- + +class TestSlackInteractionsEndpoint: + """Tests for POST /slack/interactions.""" + + def test_returns_200_for_valid_request(self) -> None: + app = _make_test_app() + client = TestClient(app, raise_server_exceptions=False) + + timestamp = str(int(time.time())) + body = _make_button_payload(thread_id="t-abc", value="approve") + sig = _make_slack_signature( + signing_secret=_TEST_SIGNING_SECRET, + timestamp=timestamp, + body=body, + ) + + response = client.post( + "/slack/interactions", + content=body, + headers={ + "Content-Type": "application/x-www-form-urlencoded", + "X-Slack-Request-Timestamp": timestamp, + "X-Slack-Signature": sig, + }, + ) + + assert response.status_code == 200 + + def test_returns_403_for_invalid_signature(self) -> None: + app = _make_test_app() + client = TestClient(app, raise_server_exceptions=False) + + timestamp = str(int(time.time())) + body = _make_button_payload() + + response = client.post( + "/slack/interactions", + content=body, + headers={ + "Content-Type": "application/x-www-form-urlencoded", + "X-Slack-Request-Timestamp": timestamp, + "X-Slack-Signature": "v0=invalid_signature", + }, + ) + + assert response.status_code == 403 + + def test_returns_400_when_missing_timestamp_header(self) -> None: + app = _make_test_app() + client = TestClient(app, raise_server_exceptions=False) + + body = _make_button_payload() + + response = client.post( + "/slack/interactions", + content=body, + headers={ + "Content-Type": "application/x-www-form-urlencoded", + "X-Slack-Signature": "v0=something", + }, + ) + + assert response.status_code in (400, 403, 422) + + def test_rejects_when_signing_secret_not_configured(self) -> None: + app = _make_test_app(signing_secret="") + client = TestClient(app, raise_server_exceptions=False) + + timestamp = str(int(time.time())) + body = _make_button_payload(thread_id="t-abc", value="approve") + + response = client.post( + "/slack/interactions", + content=body, + headers={ + "Content-Type": "application/x-www-form-urlencoded", + "X-Slack-Request-Timestamp": timestamp, + "X-Slack-Signature": "v0=any_sig", + }, + ) + + assert response.status_code == 503 + + def test_returns_200_with_approve_action(self) -> None: + app = _make_test_app() + client = TestClient(app, raise_server_exceptions=False) + + timestamp = str(int(time.time())) + body = _make_button_payload(thread_id="thread-1", value="approve") + sig = _make_slack_signature( + signing_secret=_TEST_SIGNING_SECRET, + timestamp=timestamp, + body=body, + ) + + response = client.post( + "/slack/interactions", + content=body, + headers={ + "Content-Type": "application/x-www-form-urlencoded", + "X-Slack-Request-Timestamp": timestamp, + "X-Slack-Signature": sig, + }, + ) + + assert response.status_code == 200 + + def test_returns_200_with_reject_action(self) -> None: + app = _make_test_app() + client = TestClient(app, raise_server_exceptions=False) + + timestamp = str(int(time.time())) + body = _make_button_payload(thread_id="thread-2", value="reject") + sig = _make_slack_signature( + signing_secret=_TEST_SIGNING_SECRET, + timestamp=timestamp, + body=body, + ) + + response = client.post( + "/slack/interactions", + content=body, + headers={ + "Content-Type": "application/x-www-form-urlencoded", + "X-Slack-Request-Timestamp": timestamp, + "X-Slack-Signature": sig, + }, + ) + + assert response.status_code == 200 + + def test_schedules_graph_resume_in_background(self) -> None: + app = _make_test_app() + client = TestClient(app, raise_server_exceptions=False) + + timestamp = str(int(time.time())) + body = _make_button_payload(thread_id="t-bg", value="approve") + sig = _make_slack_signature( + signing_secret=_TEST_SIGNING_SECRET, + timestamp=timestamp, + body=body, + ) + + response = client.post( + "/slack/interactions", + content=body, + headers={ + "Content-Type": "application/x-www-form-urlencoded", + "X-Slack-Request-Timestamp": timestamp, + "X-Slack-Signature": sig, + }, + ) + + assert response.status_code == 200 + + def test_returns_404_for_unknown_thread(self) -> None: + app = _make_test_app(thread_graph_name=None) + client = TestClient(app, raise_server_exceptions=False) + + timestamp = str(int(time.time())) + body = _make_button_payload(thread_id="unknown-thread", value="approve") + sig = _make_slack_signature( + signing_secret=_TEST_SIGNING_SECRET, + timestamp=timestamp, + body=body, + ) + + response = client.post( + "/slack/interactions", + content=body, + headers={ + "Content-Type": "application/x-www-form-urlencoded", + "X-Slack-Request-Timestamp": timestamp, + "X-Slack-Signature": sig, + }, + ) + + # Should return 200 immediately (Slack requires immediate 200) + # but the background task may log an error + assert response.status_code == 200 + + def test_response_body_is_empty_or_ok(self) -> None: + app = _make_test_app() + client = TestClient(app, raise_server_exceptions=False) + + timestamp = str(int(time.time())) + body = _make_button_payload(thread_id="t-ok", value="approve") + sig = _make_slack_signature( + signing_secret=_TEST_SIGNING_SECRET, + timestamp=timestamp, + body=body, + ) + + response = client.post( + "/slack/interactions", + content=body, + headers={ + "Content-Type": "application/x-www-form-urlencoded", + "X-Slack-Request-Timestamp": timestamp, + "X-Slack-Signature": sig, + }, + ) + + assert response.status_code == 200 + # Body may be empty or a simple JSON with ok=True + if response.content: + data = response.json() + assert data.get("ok") is True or "ok" not in data diff --git a/tests/api/test_status.py b/tests/api/test_status.py new file mode 100644 index 0000000..6f55783 --- /dev/null +++ b/tests/api/test_status.py @@ -0,0 +1,270 @@ +"""Tests for status, releases, staging, and manual trigger endpoints. +Written FIRST (TDD RED phase). +""" + +from datetime import datetime, timezone +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from release_agent.api.status import router as status_router + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_test_app( + *, + versions: list[str] | None = None, + staging_data: dict | None = None, +) -> FastAPI: + """Return a FastAPI app with mocked state for status tests.""" + app = FastAPI() + app.include_router(status_router) + + mock_settings = MagicMock() + mock_settings.operator_token.get_secret_value.return_value = "" + mock_graphs = { + "pr_completed": MagicMock(), + "release": MagicMock(), + } + mock_clients = MagicMock() + + mock_staging_store = MagicMock() + mock_staging_store.list_versions = AsyncMock(return_value=versions or []) + + # staging store returns StagingRelease-like or None + if staging_data is not None: + mock_staging_obj = MagicMock() + mock_staging_obj.model_dump = MagicMock(return_value=staging_data) + mock_staging_store.load = AsyncMock(return_value=mock_staging_obj) + else: + mock_staging_store.load = AsyncMock(return_value=None) + + mock_pool = MagicMock() + + app.state.settings = mock_settings + app.state.graphs = mock_graphs + app.state.tool_clients = mock_clients + app.state.staging_store = mock_staging_store + app.state.db_pool = mock_pool + app.state.background_tasks = set() + app.state.started_at = datetime.now(tz=timezone.utc) + + return app + + +# --------------------------------------------------------------------------- +# GET /status +# --------------------------------------------------------------------------- + +class TestGetStatus: + def test_returns_200(self) -> None: + app = _make_test_app() + with TestClient(app) as client: + response = client.get("/status") + assert response.status_code == 200 + + def test_response_has_status_field(self) -> None: + app = _make_test_app() + with TestClient(app) as client: + response = client.get("/status") + data = response.json() + assert "status" in data + assert data["status"] in ("ok", "degraded") + + def test_response_has_version_field(self) -> None: + app = _make_test_app() + with TestClient(app) as client: + response = client.get("/status") + data = response.json() + assert "version" in data + assert isinstance(data["version"], str) + + def test_response_has_uptime_seconds(self) -> None: + app = _make_test_app() + with TestClient(app) as client: + response = client.get("/status") + data = response.json() + assert "uptime_seconds" in data + assert data["uptime_seconds"] >= 0.0 + + def test_status_is_ok_when_healthy(self) -> None: + app = _make_test_app() + with TestClient(app) as client: + response = client.get("/status") + assert response.json()["status"] == "ok" + + +# --------------------------------------------------------------------------- +# GET /releases/{repo} +# --------------------------------------------------------------------------- + +class TestGetReleaseVersions: + def test_returns_200(self) -> None: + app = _make_test_app(versions=["v1.0.0", "v1.1.0"]) + with TestClient(app) as client: + response = client.get("/releases/my-repo") + assert response.status_code == 200 + + def test_response_has_repo_and_versions(self) -> None: + app = _make_test_app(versions=["v1.0.0", "v1.1.0"]) + with TestClient(app) as client: + response = client.get("/releases/my-repo") + data = response.json() + assert data["repo"] == "my-repo" + assert data["versions"] == ["v1.0.0", "v1.1.0"] + + def test_empty_versions_list(self) -> None: + app = _make_test_app(versions=[]) + with TestClient(app) as client: + response = client.get("/releases/unknown-repo") + data = response.json() + assert data["versions"] == [] + + def test_repo_name_in_path_used(self) -> None: + mock_staging_store = MagicMock() + mock_staging_store.list_versions = AsyncMock(return_value=[]) + app = _make_test_app() + app.state.staging_store = mock_staging_store + + with TestClient(app) as client: + client.get("/releases/specific-repo") + + mock_staging_store.list_versions.assert_called_once_with("specific-repo") + + +# --------------------------------------------------------------------------- +# GET /staging +# --------------------------------------------------------------------------- + +class TestGetStaging: + def test_returns_200_with_staging(self) -> None: + staging_data = {"version": "v1.0.0", "repo": "my-repo", "tickets": []} + app = _make_test_app(staging_data=staging_data) + with TestClient(app) as client: + response = client.get("/staging?repo=my-repo") + assert response.status_code == 200 + + def test_response_has_repo_and_staging(self) -> None: + staging_data = {"version": "v1.0.0", "repo": "my-repo", "tickets": []} + app = _make_test_app(staging_data=staging_data) + with TestClient(app) as client: + response = client.get("/staging?repo=my-repo") + data = response.json() + assert data["repo"] == "my-repo" + assert data["staging"] is not None + assert data["staging"]["version"] == "v1.0.0" + + def test_returns_null_staging_when_not_found(self) -> None: + app = _make_test_app(staging_data=None) + with TestClient(app) as client: + response = client.get("/staging?repo=no-staging-repo") + assert response.status_code == 200 + data = response.json() + assert data["staging"] is None + + def test_missing_repo_query_returns_422(self) -> None: + app = _make_test_app() + with TestClient(app) as client: + response = client.get("/staging") + assert response.status_code == 422 + + +# --------------------------------------------------------------------------- +# POST /manual/pr/{pr_id} +# --------------------------------------------------------------------------- + +class TestManualPrTrigger: + def test_returns_202(self) -> None: + app = _make_test_app() + with patch( + "release_agent.api.status.asyncio.create_task", return_value=MagicMock() + ): + with TestClient(app) as client: + response = client.post("/manual/pr/42") + assert response.status_code == 202 + + def test_response_has_thread_id(self) -> None: + app = _make_test_app() + with patch( + "release_agent.api.status.asyncio.create_task", return_value=MagicMock() + ): + with TestClient(app) as client: + response = client.post("/manual/pr/42") + data = response.json() + assert "thread_id" in data + assert isinstance(data["thread_id"], str) + + def test_response_has_message(self) -> None: + app = _make_test_app() + with patch( + "release_agent.api.status.asyncio.create_task", return_value=MagicMock() + ): + with TestClient(app) as client: + response = client.post("/manual/pr/42") + assert "message" in response.json() + + def test_schedules_background_task(self) -> None: + app = _make_test_app() + with patch( + "release_agent.api.status.asyncio.create_task", return_value=MagicMock() + ) as mock_create: + with TestClient(app) as client: + client.post("/manual/pr/99") + mock_create.assert_called_once() + + +# --------------------------------------------------------------------------- +# POST /manual/release +# --------------------------------------------------------------------------- + +class TestManualReleaseTrigger: + def test_returns_202(self) -> None: + app = _make_test_app() + with patch( + "release_agent.api.status.asyncio.create_task", return_value=MagicMock() + ): + with TestClient(app) as client: + response = client.post( + "/manual/release", + json={"repo": "my-repo"}, + ) + assert response.status_code == 202 + + def test_response_has_thread_id(self) -> None: + app = _make_test_app() + with patch( + "release_agent.api.status.asyncio.create_task", return_value=MagicMock() + ): + with TestClient(app) as client: + response = client.post( + "/manual/release", + json={"repo": "my-repo"}, + ) + data = response.json() + assert "thread_id" in data + + def test_missing_repo_returns_422(self) -> None: + app = _make_test_app() + with TestClient(app) as client: + response = client.post( + "/manual/release", + json={}, + ) + assert response.status_code == 422 + + def test_schedules_background_task(self) -> None: + app = _make_test_app() + with patch( + "release_agent.api.status.asyncio.create_task", return_value=MagicMock() + ) as mock_create: + with TestClient(app) as client: + client.post( + "/manual/release", + json={"repo": "my-repo"}, + ) + mock_create.assert_called_once() diff --git a/tests/api/test_status_with_auth.py b/tests/api/test_status_with_auth.py new file mode 100644 index 0000000..ebeca7e --- /dev/null +++ b/tests/api/test_status_with_auth.py @@ -0,0 +1,166 @@ +"""Tests for status/manual endpoints with operator token authentication. + +Phase 5 - Step 3: Verifies that POST /manual/* require operator token +when configured. GET endpoints are not protected. +Written FIRST (TDD RED phase). +""" + +from datetime import datetime, timezone +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from release_agent.api.status import router as status_router + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_test_app(operator_token: str = "") -> FastAPI: + """Return a FastAPI app with status router and configurable operator token.""" + app = FastAPI() + app.include_router(status_router) + + mock_settings = MagicMock() + mock_settings.operator_token.get_secret_value.return_value = operator_token + + mock_staging_store = MagicMock() + mock_staging_store.list_versions = AsyncMock(return_value=[]) + mock_staging_store.load = AsyncMock(return_value=None) + + mock_pool = MagicMock() + + app.state.settings = mock_settings + app.state.graphs = { + "pr_completed": MagicMock(), + "release": MagicMock(), + } + app.state.tool_clients = MagicMock() + app.state.staging_store = mock_staging_store + app.state.db_pool = mock_pool + app.state.background_tasks = set() + app.state.started_at = datetime.now(tz=timezone.utc) + + return app + + +# --------------------------------------------------------------------------- +# POST /manual/pr/{pr_id} with auth +# --------------------------------------------------------------------------- + +class TestManualPrTriggerWithAuth: + def test_valid_token_allows_manual_pr(self) -> None: + app = _make_test_app(operator_token="secure-token") + with patch( + "release_agent.api.status.asyncio.create_task", return_value=MagicMock() + ): + with TestClient(app) as client: + response = client.post( + "/manual/pr/42", + headers={"X-Operator-Token": "secure-token"}, + ) + assert response.status_code == 202 + + def test_missing_token_rejects_manual_pr(self) -> None: + app = _make_test_app(operator_token="secure-token") + with TestClient(app) as client: + response = client.post("/manual/pr/42") + assert response.status_code == 401 + + def test_wrong_token_rejects_manual_pr(self) -> None: + app = _make_test_app(operator_token="secure-token") + with TestClient(app) as client: + response = client.post( + "/manual/pr/42", + headers={"X-Operator-Token": "bad-token"}, + ) + assert response.status_code == 401 + + def test_no_auth_required_when_token_not_configured(self) -> None: + app = _make_test_app(operator_token="") + with patch( + "release_agent.api.status.asyncio.create_task", return_value=MagicMock() + ): + with TestClient(app) as client: + response = client.post("/manual/pr/42") + assert response.status_code == 202 + + +# --------------------------------------------------------------------------- +# POST /manual/release with auth +# --------------------------------------------------------------------------- + +class TestManualReleaseTriggerWithAuth: + def test_valid_token_allows_manual_release(self) -> None: + app = _make_test_app(operator_token="secure-token") + with patch( + "release_agent.api.status.asyncio.create_task", return_value=MagicMock() + ): + with TestClient(app) as client: + response = client.post( + "/manual/release", + json={"repo": "my-repo"}, + headers={"X-Operator-Token": "secure-token"}, + ) + assert response.status_code == 202 + + def test_missing_token_rejects_manual_release(self) -> None: + app = _make_test_app(operator_token="secure-token") + with TestClient(app) as client: + response = client.post( + "/manual/release", + json={"repo": "my-repo"}, + ) + assert response.status_code == 401 + + def test_wrong_token_rejects_manual_release(self) -> None: + app = _make_test_app(operator_token="secure-token") + with TestClient(app) as client: + response = client.post( + "/manual/release", + json={"repo": "my-repo"}, + headers={"X-Operator-Token": "wrong"}, + ) + assert response.status_code == 401 + + def test_no_auth_required_when_token_not_configured(self) -> None: + app = _make_test_app(operator_token="") + with patch( + "release_agent.api.status.asyncio.create_task", return_value=MagicMock() + ): + with TestClient(app) as client: + response = client.post( + "/manual/release", + json={"repo": "my-repo"}, + ) + assert response.status_code == 202 + + +# --------------------------------------------------------------------------- +# GET /status, /releases, /staging do NOT require auth +# --------------------------------------------------------------------------- + +class TestReadEndpointsNoAuth: + def test_get_status_no_token_needed(self) -> None: + """GET /status should never require auth.""" + app = _make_test_app(operator_token="super-secret") + with TestClient(app) as client: + response = client.get("/status") + assert response.status_code == 200 + + def test_get_releases_no_token_needed(self) -> None: + app = _make_test_app(operator_token="super-secret") + with TestClient(app) as client: + response = client.get("/releases/my-repo") + assert response.status_code == 200 + + def test_get_staging_no_token_needed(self) -> None: + app = _make_test_app(operator_token="super-secret") + with TestClient(app) as client: + response = client.get("/staging?repo=my-repo") + assert response.status_code == 200 diff --git a/tests/api/test_webhooks.py b/tests/api/test_webhooks.py new file mode 100644 index 0000000..eb5c8a0 --- /dev/null +++ b/tests/api/test_webhooks.py @@ -0,0 +1,218 @@ +"""Tests for webhook endpoint. Written FIRST (TDD RED phase).""" + +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from release_agent.api.webhooks import ( + _validate_webhook_secret, + router as webhook_router, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +VALID_SECRET = "super-secret-webhook-key" + +_COMPLETED_PR_PAYLOAD = { + "subscription_id": "sub-1", + "event_type": "git.pullrequest.updated", + "resource": { + "repository": { + "id": "repo-1", + "name": "my-repo", + "web_url": "https://dev.azure.com/org/project/_git/my-repo", + }, + "pull_request_id": 42, + "title": "feat: add feature", + "source_ref_name": "refs/heads/feature/BILL-123-add-feature", + "target_ref_name": "refs/heads/main", + "status": "completed", + "closed_date": "2024-01-15T10:00:00Z", + }, +} + +_ACTIVE_PR_PAYLOAD = { + "subscription_id": "sub-2", + "event_type": "git.pullrequest.updated", + "resource": { + "repository": { + "id": "repo-1", + "name": "my-repo", + "web_url": "https://dev.azure.com/org/project/_git/my-repo", + }, + "pull_request_id": 43, + "title": "WIP: work in progress", + "source_ref_name": "refs/heads/feature/BILL-456", + "target_ref_name": "refs/heads/main", + "status": "active", + "closed_date": None, + }, +} + + +def _make_test_app(webhook_secret: str = VALID_SECRET) -> FastAPI: + """Return a FastAPI app with mocked state for webhook tests.""" + app = FastAPI() + app.include_router(webhook_router) + + mock_settings = MagicMock() + mock_settings.webhook_secret.get_secret_value.return_value = webhook_secret + + mock_graphs = { + "pr_completed": MagicMock(), + "release": MagicMock(), + } + + mock_clients = MagicMock() + mock_pool = MagicMock() + + # background_tasks set tracked on state + app.state.settings = mock_settings + app.state.graphs = mock_graphs + app.state.tool_clients = mock_clients + app.state.db_pool = mock_pool + app.state.background_tasks = set() + + return app + + +# --------------------------------------------------------------------------- +# _validate_webhook_secret (unit tests, pure function) +# --------------------------------------------------------------------------- + +class TestValidateWebhookSecret: + def test_valid_secret_returns_true(self) -> None: + assert _validate_webhook_secret("mysecret", "mysecret") is True + + def test_wrong_secret_returns_false(self) -> None: + assert _validate_webhook_secret("wrong", "mysecret") is False + + def test_empty_header_returns_false(self) -> None: + assert _validate_webhook_secret("", "mysecret") is False + + def test_none_header_returns_false(self) -> None: + assert _validate_webhook_secret(None, "mysecret") is False # type: ignore[arg-type] + + def test_uses_constant_time_comparison(self) -> None: + # Should not raise even for very different lengths + assert _validate_webhook_secret("a", "very-long-secret-value") is False + + def test_empty_expected_rejects_all(self) -> None: + # Empty expected secret = auth misconfigured, reject everything + assert _validate_webhook_secret("", "") is False + assert _validate_webhook_secret("any-value", "") is False + assert _validate_webhook_secret(None, "") is False + + +# --------------------------------------------------------------------------- +# POST /webhooks/azdo — integration tests via TestClient +# --------------------------------------------------------------------------- + +class TestWebhookEndpoint: + def test_valid_completed_pr_returns_202(self) -> None: + app = _make_test_app() + with patch( + "release_agent.api.webhooks.asyncio.create_task", return_value=MagicMock() + ): + with TestClient(app, raise_server_exceptions=True) as client: + response = client.post( + "/webhooks/azdo", + json=_COMPLETED_PR_PAYLOAD, + headers={"X-Webhook-Secret": VALID_SECRET}, + ) + assert response.status_code == 202 + data = response.json() + assert "thread_id" in data + assert "message" in data + + def test_missing_secret_header_returns_401(self) -> None: + app = _make_test_app() + with TestClient(app) as client: + response = client.post( + "/webhooks/azdo", + json=_COMPLETED_PR_PAYLOAD, + ) + assert response.status_code == 401 + + def test_wrong_secret_header_returns_401(self) -> None: + app = _make_test_app() + with TestClient(app) as client: + response = client.post( + "/webhooks/azdo", + json=_COMPLETED_PR_PAYLOAD, + headers={"X-Webhook-Secret": "wrong-secret"}, + ) + assert response.status_code == 401 + + def test_invalid_payload_returns_422(self) -> None: + app = _make_test_app() + with TestClient(app) as client: + response = client.post( + "/webhooks/azdo", + json={"invalid": "payload"}, + headers={"X-Webhook-Secret": VALID_SECRET}, + ) + assert response.status_code == 422 + + def test_active_pr_event_returns_200_ignored(self) -> None: + app = _make_test_app() + with TestClient(app) as client: + response = client.post( + "/webhooks/azdo", + json=_ACTIVE_PR_PAYLOAD, + headers={"X-Webhook-Secret": VALID_SECRET}, + ) + assert response.status_code == 200 + data = response.json() + assert "ignored" in data.get("message", "").lower() or "ignored" in str(data).lower() + + def test_completed_pr_thread_id_is_string(self) -> None: + app = _make_test_app() + with patch( + "release_agent.api.webhooks.asyncio.create_task", return_value=MagicMock() + ): + with TestClient(app) as client: + response = client.post( + "/webhooks/azdo", + json=_COMPLETED_PR_PAYLOAD, + headers={"X-Webhook-Secret": VALID_SECRET}, + ) + assert response.status_code == 202 + assert isinstance(response.json()["thread_id"], str) + + def test_completed_pr_schedules_background_task(self) -> None: + app = _make_test_app() + task_mock = MagicMock() + with patch( + "release_agent.api.webhooks.asyncio.create_task", return_value=task_mock + ) as mock_create: + with TestClient(app) as client: + client.post( + "/webhooks/azdo", + json=_COMPLETED_PR_PAYLOAD, + headers={"X-Webhook-Secret": VALID_SECRET}, + ) + mock_create.assert_called_once() + + +# --------------------------------------------------------------------------- +# Error response shape +# --------------------------------------------------------------------------- + +class TestWebhookErrorShape: + def test_401_has_detail_field(self) -> None: + app = _make_test_app() + with TestClient(app) as client: + response = client.post( + "/webhooks/azdo", + json=_COMPLETED_PR_PAYLOAD, + ) + assert response.status_code == 401 + # FastAPI HTTPException returns {"detail": ...} + assert "detail" in response.json() diff --git a/tests/graph/__init__.py b/tests/graph/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/graph/conftest.py b/tests/graph/conftest.py new file mode 100644 index 0000000..13e1f89 --- /dev/null +++ b/tests/graph/conftest.py @@ -0,0 +1,44 @@ +"""Shared fixtures for graph tests. + +Provides build_mock_clients() to create ToolClients with AsyncMock fields +so individual node functions can be tested without compiling the full graph. +""" + +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from release_agent.graph.dependencies import ToolClients + + +def build_mock_clients() -> ToolClients: + """Return a ToolClients instance whose fields are all AsyncMock/MagicMock.""" + azdo = AsyncMock() + jira = AsyncMock() + slack = AsyncMock() + reviewer = AsyncMock() + return ToolClients(azdo=azdo, jira=jira, slack=slack, reviewer=reviewer) + + +def build_config(clients: ToolClients | None = None, staging_store=None) -> dict: + """Return a LangGraph-style config dict with clients and staging_store.""" + if clients is None: + clients = build_mock_clients() + return { + "configurable": { + "clients": clients, + "staging_store": staging_store, + } + } + + +@pytest.fixture() +def mock_clients() -> ToolClients: + """Pytest fixture returning fresh mock ToolClients.""" + return build_mock_clients() + + +@pytest.fixture() +def config(mock_clients: ToolClients): + """Pytest fixture returning a config dict with mock clients.""" + return build_config(mock_clients) diff --git a/tests/graph/test_ci_nodes.py b/tests/graph/test_ci_nodes.py new file mode 100644 index 0000000..42b79dd --- /dev/null +++ b/tests/graph/test_ci_nodes.py @@ -0,0 +1,294 @@ +"""Tests for graph/ci_nodes.py. + +Written FIRST (TDD RED phase). + +All external calls (azdo, slack, poll_until) are mocked. +""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from release_agent.graph.ci_nodes import notify_ci_result, poll_ci_build, trigger_ci_build +from release_agent.models.build import BuildStatus +from release_agent.models.pipeline import PipelineInfo +from tests.graph.conftest import build_config, build_mock_clients + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_pipeline(pipeline_id: int = 10, name: str = "CI-build") -> dict: + return {"id": pipeline_id, "name": name, "repo": "my-repo"} + + +# --------------------------------------------------------------------------- +# trigger_ci_build +# --------------------------------------------------------------------------- + +class TestTriggerCiBuild: + """Tests for trigger_ci_build node.""" + + async def test_triggers_pipeline_on_branch(self) -> None: + clients = build_mock_clients() + clients.azdo.list_build_pipelines.return_value = [ + PipelineInfo(id=10, name="CI", repo="my-repo") + ] + clients.azdo.trigger_pipeline.return_value = {"id": 555, "state": "inProgress"} + config = build_config(clients) + state = {"repo_name": "my-repo", "version": "v1.0.0"} + + result = await trigger_ci_build(state, config) + + clients.azdo.trigger_pipeline.assert_called_once() + assert "ci_build_id" in result + assert result["ci_build_id"] == 555 + + async def test_returns_ci_build_id(self) -> None: + clients = build_mock_clients() + clients.azdo.list_build_pipelines.return_value = [ + PipelineInfo(id=20, name="build-and-test", repo="my-repo") + ] + clients.azdo.trigger_pipeline.return_value = {"id": 999} + config = build_config(clients) + state = {"repo_name": "my-repo", "version": "v2.0.0"} + + result = await trigger_ci_build(state, config) + assert result["ci_build_id"] == 999 + + async def test_appends_error_when_no_pipelines_found(self) -> None: + clients = build_mock_clients() + clients.azdo.list_build_pipelines.return_value = [] + config = build_config(clients) + state = {"repo_name": "my-repo", "version": "v1.0.0"} + + result = await trigger_ci_build(state, config) + + assert "errors" in result + assert len(result["errors"]) >= 1 + + async def test_appends_error_on_trigger_failure(self) -> None: + from release_agent.exceptions import ServiceError + + clients = build_mock_clients() + clients.azdo.list_build_pipelines.return_value = [ + PipelineInfo(id=10, name="CI", repo="my-repo") + ] + clients.azdo.trigger_pipeline.side_effect = ServiceError( + service="azdo", status_code=500, detail="Internal error" + ) + config = build_config(clients) + state = {"repo_name": "my-repo", "version": "v1.0.0"} + + result = await trigger_ci_build(state, config) + + assert "errors" in result + + async def test_uses_main_branch_when_no_version(self) -> None: + clients = build_mock_clients() + clients.azdo.list_build_pipelines.return_value = [ + PipelineInfo(id=10, name="CI", repo="my-repo") + ] + clients.azdo.trigger_pipeline.return_value = {"id": 1} + config = build_config(clients) + state = {"repo_name": "my-repo"} + + result = await trigger_ci_build(state, config) + + call_kwargs = clients.azdo.trigger_pipeline.call_args[1] + branch = call_kwargs.get("branch", "") + assert "main" in branch or "refs/heads" in branch + + async def test_appends_message_on_success(self) -> None: + clients = build_mock_clients() + clients.azdo.list_build_pipelines.return_value = [ + PipelineInfo(id=10, name="CI", repo="my-repo") + ] + clients.azdo.trigger_pipeline.return_value = {"id": 123} + config = build_config(clients) + state = {"repo_name": "my-repo", "version": "v1.0.0"} + + result = await trigger_ci_build(state, config) + + assert "messages" in result + assert len(result["messages"]) >= 1 + + +# --------------------------------------------------------------------------- +# poll_ci_build +# --------------------------------------------------------------------------- + +class TestPollCiBuild: + """Tests for poll_ci_build node.""" + + async def test_returns_ci_build_status_and_result_on_completion(self) -> None: + clients = build_mock_clients() + completed_status = BuildStatus(status="completed", result="succeeded", build_url="https://build/1") + config = build_config(clients) + state = {"ci_build_id": 42, "repo_name": "my-repo"} + + with patch( + "release_agent.graph.ci_nodes.poll_until", + return_value=(completed_status, True), + ): + result = await poll_ci_build(state, config) + + assert result["ci_build_status"] == "completed" + assert result["ci_build_result"] == "succeeded" + + async def test_returns_build_url(self) -> None: + clients = build_mock_clients() + completed_status = BuildStatus( + status="completed", + result="succeeded", + build_url="https://dev.azure.com/build/42", + ) + config = build_config(clients) + state = {"ci_build_id": 42, "repo_name": "my-repo"} + + with patch( + "release_agent.graph.ci_nodes.poll_until", + return_value=(completed_status, True), + ): + result = await poll_ci_build(state, config) + + assert result.get("ci_build_url") == "https://dev.azure.com/build/42" + + async def test_appends_error_on_timeout(self) -> None: + clients = build_mock_clients() + running_status = BuildStatus(status="inProgress", result=None, build_url=None) + config = build_config(clients) + state = {"ci_build_id": 42, "repo_name": "my-repo"} + + with patch( + "release_agent.graph.ci_nodes.poll_until", + return_value=(running_status, False), + ): + result = await poll_ci_build(state, config) + + assert "errors" in result + + async def test_appends_error_when_build_id_missing(self) -> None: + clients = build_mock_clients() + config = build_config(clients) + state = {"repo_name": "my-repo"} # no ci_build_id + + result = await poll_ci_build(state, config) + + assert "errors" in result + + async def test_passes_correct_build_id_to_poll_fn(self) -> None: + clients = build_mock_clients() + clients.azdo.get_build_status.return_value = BuildStatus( + status="completed", result="succeeded", build_url=None + ) + config = build_config(clients) + state = {"ci_build_id": 77, "repo_name": "my-repo"} + + async def fake_poll_until(*, poll_fn, is_done, interval_seconds, max_wait_seconds, sleep_fn=None): + result = await poll_fn() + return result, True + + with patch("release_agent.graph.ci_nodes.poll_until", side_effect=fake_poll_until): + await poll_ci_build(state, config) + + clients.azdo.get_build_status.assert_called_once_with(build_id=77) + + async def test_result_none_when_poll_returns_none(self) -> None: + clients = build_mock_clients() + config = build_config(clients) + state = {"ci_build_id": 42, "repo_name": "my-repo"} + + with patch( + "release_agent.graph.ci_nodes.poll_until", + return_value=(None, False), + ): + result = await poll_ci_build(state, config) + + assert "errors" in result + + +# --------------------------------------------------------------------------- +# notify_ci_result +# --------------------------------------------------------------------------- + +class TestNotifyCiResult: + """Tests for notify_ci_result node.""" + + async def test_sends_notification_on_success(self) -> None: + clients = build_mock_clients() + clients.slack.send_notification.return_value = True + config = build_config(clients) + state = { + "repo_name": "my-repo", + "ci_build_status": "completed", + "ci_build_result": "succeeded", + "ci_build_url": "https://build/99", + } + + result = await notify_ci_result(state, config) + + clients.slack.send_notification.assert_called_once() + assert "messages" in result + + async def test_sends_notification_on_failure(self) -> None: + clients = build_mock_clients() + clients.slack.send_notification.return_value = True + config = build_config(clients) + state = { + "repo_name": "my-repo", + "ci_build_status": "completed", + "ci_build_result": "failed", + "ci_build_url": None, + } + + result = await notify_ci_result(state, config) + + clients.slack.send_notification.assert_called_once() + + async def test_handles_slack_error_gracefully(self) -> None: + from release_agent.exceptions import ServiceError + + clients = build_mock_clients() + clients.slack.send_notification.side_effect = ServiceError( + service="slack", status_code=500, detail="Slack error" + ) + config = build_config(clients) + state = { + "repo_name": "my-repo", + "ci_build_result": "succeeded", + "ci_build_url": None, + } + + result = await notify_ci_result(state, config) + + # Should not re-raise; should append error + assert "errors" in result + + async def test_includes_repo_name_in_message(self) -> None: + clients = build_mock_clients() + clients.slack.send_notification.return_value = True + config = build_config(clients) + state = { + "repo_name": "super-service", + "ci_build_result": "succeeded", + "ci_build_url": None, + } + + await notify_ci_result(state, config) + + call_kwargs = clients.slack.send_notification.call_args[1] + text_or_blocks = str(call_kwargs) + assert "super-service" in text_or_blocks + + async def test_returns_empty_dict_when_state_has_no_data(self) -> None: + clients = build_mock_clients() + clients.slack.send_notification.return_value = True + config = build_config(clients) + state = {} + + result = await notify_ci_result(state, config) + + # Should not crash; may return messages or empty dict + assert isinstance(result, dict) diff --git a/tests/graph/test_dependencies.py b/tests/graph/test_dependencies.py new file mode 100644 index 0000000..0c6dcc3 --- /dev/null +++ b/tests/graph/test_dependencies.py @@ -0,0 +1,283 @@ +"""Tests for graph/dependencies.py. Written FIRST (TDD RED phase). + +Covers: +- ToolClients frozen dataclass +- StagingStore Protocol (structural check) +- JsonFileStagingStore file I/O operations +""" + +import json +from datetime import date +from pathlib import Path +from unittest.mock import AsyncMock + +import pytest + +from release_agent.graph.dependencies import JsonFileStagingStore, StagingStore, ToolClients +from release_agent.models.release import ArchivedRelease, StagingRelease +from release_agent.models.ticket import TicketEntry + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_ticket(ticket_id: str = "ALLPOST-1") -> TicketEntry: + return TicketEntry( + id=ticket_id, + summary="Fix something", + pr_id="42", + pr_url="https://dev.azure.com/org/proj/_git/repo/pullrequest/42", + pr_title="Fix: something", + branch=f"feature/{ticket_id}-fix", + merged_at=date(2025, 1, 15), + ) + + +def _make_staging(repo: str = "my-repo", version: str = "v1.0.0") -> StagingRelease: + return StagingRelease( + version=version, + repo=repo, + started_at=date(2025, 1, 1), + tickets=[], + ) + + +# --------------------------------------------------------------------------- +# ToolClients tests +# --------------------------------------------------------------------------- + +class TestToolClients: + """Tests for the ToolClients frozen dataclass.""" + + def test_can_be_constructed_with_all_fields(self) -> None: + azdo = AsyncMock() + jira = AsyncMock() + slack = AsyncMock() + reviewer = AsyncMock() + clients = ToolClients(azdo=azdo, jira=jira, slack=slack, reviewer=reviewer) + assert clients.azdo is azdo + assert clients.jira is jira + assert clients.slack is slack + assert clients.reviewer is reviewer + + def test_is_frozen_cannot_reassign_field(self) -> None: + clients = ToolClients( + azdo=AsyncMock(), jira=AsyncMock(), slack=AsyncMock(), reviewer=AsyncMock() + ) + with pytest.raises((AttributeError, TypeError)): + clients.azdo = AsyncMock() # type: ignore[misc] + + def test_fields_are_accessible_by_name(self) -> None: + azdo = object() + clients = ToolClients( + azdo=azdo, jira=object(), slack=object(), reviewer=object() + ) + assert clients.azdo is azdo + + def test_equality_for_same_instances(self) -> None: + azdo = AsyncMock() + jira = AsyncMock() + slack = AsyncMock() + reviewer = AsyncMock() + c1 = ToolClients(azdo=azdo, jira=jira, slack=slack, reviewer=reviewer) + c2 = ToolClients(azdo=azdo, jira=jira, slack=slack, reviewer=reviewer) + assert c1 == c2 + + +# --------------------------------------------------------------------------- +# StagingStore Protocol structural tests +# --------------------------------------------------------------------------- + +class TestStagingStoreProtocol: + """Verify that the Protocol is structurally correct.""" + + def test_json_file_store_satisfies_protocol(self, tmp_path: Path) -> None: + store = JsonFileStagingStore(directory=tmp_path) + # runtime_checkable would need @runtime_checkable; check duck-typing instead + assert hasattr(store, "load") + assert hasattr(store, "save") + assert hasattr(store, "archive") + assert hasattr(store, "list_versions") + + def test_protocol_is_importable(self) -> None: + # Just import-level check + assert StagingStore is not None + + +# --------------------------------------------------------------------------- +# JsonFileStagingStore tests +# --------------------------------------------------------------------------- + +class TestJsonFileStagingStore: + """Tests for JsonFileStagingStore using tmp_path for file I/O.""" + + # ------------------------------------------------------------------ + # load + # ------------------------------------------------------------------ + + async def test_load_returns_none_when_file_missing(self, tmp_path: Path) -> None: + store = JsonFileStagingStore(directory=tmp_path) + result = await store.load("nonexistent-repo") + assert result is None + + async def test_load_returns_staging_release_after_save(self, tmp_path: Path) -> None: + store = JsonFileStagingStore(directory=tmp_path) + staging = _make_staging() + await store.save(staging) + loaded = await store.load("my-repo") + assert loaded is not None + assert loaded.version == "v1.0.0" + assert loaded.repo == "my-repo" + + async def test_load_returns_staging_with_tickets(self, tmp_path: Path) -> None: + store = JsonFileStagingStore(directory=tmp_path) + staging = _make_staging().add_ticket(_make_ticket("BILL-10")) + await store.save(staging) + loaded = await store.load("my-repo") + assert loaded is not None + assert len(loaded.tickets) == 1 + assert loaded.tickets[0].id == "BILL-10" + + async def test_load_is_read_only_does_not_mutate_stored(self, tmp_path: Path) -> None: + store = JsonFileStagingStore(directory=tmp_path) + staging = _make_staging() + await store.save(staging) + loaded1 = await store.load("my-repo") + loaded2 = await store.load("my-repo") + assert loaded1 is not loaded2 # fresh objects each time + + # ------------------------------------------------------------------ + # save + # ------------------------------------------------------------------ + + async def test_save_creates_file_in_directory(self, tmp_path: Path) -> None: + store = JsonFileStagingStore(directory=tmp_path) + staging = _make_staging(repo="api-service") + await store.save(staging) + expected_path = tmp_path / "api-service.json" + assert expected_path.exists() + + async def test_save_overwrites_existing_file(self, tmp_path: Path) -> None: + store = JsonFileStagingStore(directory=tmp_path) + staging_v1 = _make_staging(version="v1.0.0") + staging_v2 = _make_staging(version="v1.0.1") + await store.save(staging_v1) + await store.save(staging_v2) + loaded = await store.load("my-repo") + assert loaded is not None + assert loaded.version == "v1.0.1" + + async def test_save_writes_valid_json(self, tmp_path: Path) -> None: + store = JsonFileStagingStore(directory=tmp_path) + staging = _make_staging() + await store.save(staging) + raw = (tmp_path / "my-repo.json").read_text() + data = json.loads(raw) + assert data["version"] == "v1.0.0" + assert data["repo"] == "my-repo" + + async def test_save_does_not_mutate_staging_release(self, tmp_path: Path) -> None: + store = JsonFileStagingStore(directory=tmp_path) + staging = _make_staging() + original_tickets = list(staging.tickets) + await store.save(staging) + assert list(staging.tickets) == original_tickets + + # ------------------------------------------------------------------ + # archive + # ------------------------------------------------------------------ + + async def test_archive_removes_staging_file(self, tmp_path: Path) -> None: + store = JsonFileStagingStore(directory=tmp_path) + staging = _make_staging() + await store.save(staging) + await store.archive(staging, date(2025, 6, 1)) + assert await store.load("my-repo") is None + + async def test_archive_creates_archive_file(self, tmp_path: Path) -> None: + store = JsonFileStagingStore(directory=tmp_path) + staging = _make_staging(repo="my-repo", version="v1.0.0") + await store.save(staging) + await store.archive(staging, date(2025, 6, 1)) + archive_path = tmp_path / "my-repo_v1.0.0_2025-06-01.json" + assert archive_path.exists() + + async def test_archive_file_contains_released_at(self, tmp_path: Path) -> None: + store = JsonFileStagingStore(directory=tmp_path) + staging = _make_staging() + await store.save(staging) + release_date = date(2025, 6, 1) + await store.archive(staging, release_date) + archive_path = tmp_path / "my-repo_v1.0.0_2025-06-01.json" + data = json.loads(archive_path.read_text()) + assert data["released_at"] == "2025-06-01" + + async def test_archive_without_prior_save_creates_archive(self, tmp_path: Path) -> None: + store = JsonFileStagingStore(directory=tmp_path) + staging = _make_staging() + await store.archive(staging, date(2025, 6, 1)) + archive_path = tmp_path / "my-repo_v1.0.0_2025-06-01.json" + assert archive_path.exists() + + # ------------------------------------------------------------------ + # list_versions + # ------------------------------------------------------------------ + + async def test_list_versions_empty_directory(self, tmp_path: Path) -> None: + store = JsonFileStagingStore(directory=tmp_path) + versions = await store.list_versions("my-repo") + assert versions == [] + + async def test_list_versions_returns_version_from_staging_file(self, tmp_path: Path) -> None: + store = JsonFileStagingStore(directory=tmp_path) + await store.save(_make_staging(version="v2.1.0")) + versions = await store.list_versions("my-repo") + assert "v2.1.0" in versions + + async def test_list_versions_includes_archived_versions(self, tmp_path: Path) -> None: + store = JsonFileStagingStore(directory=tmp_path) + staging = _make_staging(version="v1.5.0") + await store.save(staging) + await store.archive(staging, date(2025, 3, 1)) + # Now save a new staging for the same repo + await store.save(_make_staging(version="v1.6.0")) + versions = await store.list_versions("my-repo") + assert "v1.5.0" in versions + assert "v1.6.0" in versions + + async def test_list_versions_only_returns_versions_for_given_repo(self, tmp_path: Path) -> None: + store = JsonFileStagingStore(directory=tmp_path) + await store.save(_make_staging(repo="repo-a", version="v1.0.0")) + await store.save(_make_staging(repo="repo-b", version="v2.0.0")) + versions_a = await store.list_versions("repo-a") + assert "v1.0.0" in versions_a + # repo-b version should not appear in repo-a's list + assert "v2.0.0" not in versions_a + + async def test_list_versions_no_duplicates(self, tmp_path: Path) -> None: + store = JsonFileStagingStore(directory=tmp_path) + await store.save(_make_staging(version="v1.0.0")) + versions = await store.list_versions("my-repo") + assert len(versions) == len(set(versions)) + + async def test_list_versions_multiple_archives(self, tmp_path: Path) -> None: + store = JsonFileStagingStore(directory=tmp_path) + for i in range(3): + staging = _make_staging(version=f"v1.0.{i}") + await store.archive(staging, date(2025, 1, i + 1)) + versions = await store.list_versions("my-repo") + assert len(versions) == 3 + assert "v1.0.0" in versions + assert "v1.0.1" in versions + assert "v1.0.2" in versions + + # ------------------------------------------------------------------ + # directory creation + # ------------------------------------------------------------------ + + def test_store_creates_directory_if_not_exists(self, tmp_path: Path) -> None: + new_dir = tmp_path / "staging_data" + assert not new_dir.exists() + JsonFileStagingStore(directory=new_dir) + assert new_dir.exists() diff --git a/tests/graph/test_dependencies_async.py b/tests/graph/test_dependencies_async.py new file mode 100644 index 0000000..f4b2353 --- /dev/null +++ b/tests/graph/test_dependencies_async.py @@ -0,0 +1,177 @@ +"""Tests for async StagingStore protocol and async JsonFileStagingStore. + +Phase 5 - Step 1: All StagingStore methods become async def. +Written FIRST (TDD RED phase). +""" + +import json +from datetime import date +from pathlib import Path +from unittest.mock import AsyncMock + +import pytest + +from release_agent.graph.dependencies import JsonFileStagingStore, StagingStore, ToolClients +from release_agent.models.release import StagingRelease +from release_agent.models.ticket import TicketEntry + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_ticket(ticket_id: str = "ALLPOST-1") -> TicketEntry: + return TicketEntry( + id=ticket_id, + summary="Fix something", + pr_id="42", + pr_url="https://dev.azure.com/org/proj/_git/repo/pullrequest/42", + pr_title="Fix: something", + branch=f"feature/{ticket_id}-fix", + merged_at=date(2025, 1, 15), + ) + + +def _make_staging(repo: str = "my-repo", version: str = "v1.0.0") -> StagingRelease: + return StagingRelease( + version=version, + repo=repo, + started_at=date(2025, 1, 1), + tickets=[], + ) + + +# --------------------------------------------------------------------------- +# Protocol: all methods must be async +# --------------------------------------------------------------------------- + +class TestStagingStoreProtocolIsAsync: + """Verify that StagingStore protocol methods are async-compatible.""" + + def test_protocol_has_load_method(self) -> None: + assert hasattr(StagingStore, "load") + + def test_protocol_has_save_method(self) -> None: + assert hasattr(StagingStore, "save") + + def test_protocol_has_archive_method(self) -> None: + assert hasattr(StagingStore, "archive") + + def test_protocol_has_list_versions_method(self) -> None: + assert hasattr(StagingStore, "list_versions") + + +# --------------------------------------------------------------------------- +# JsonFileStagingStore async interface +# --------------------------------------------------------------------------- + +class TestJsonFileStagingStoreAsync: + """Verify that JsonFileStagingStore methods are awaitable (async def).""" + + async def test_load_is_awaitable(self, tmp_path: Path) -> None: + store = JsonFileStagingStore(directory=tmp_path) + result = await store.load("nonexistent-repo") + assert result is None + + async def test_load_returns_staging_release_after_save(self, tmp_path: Path) -> None: + store = JsonFileStagingStore(directory=tmp_path) + staging = _make_staging() + await store.save(staging) + loaded = await store.load("my-repo") + assert loaded is not None + assert loaded.version == "v1.0.0" + assert loaded.repo == "my-repo" + + async def test_load_returns_staging_with_tickets(self, tmp_path: Path) -> None: + store = JsonFileStagingStore(directory=tmp_path) + staging = _make_staging().add_ticket(_make_ticket("BILL-10")) + await store.save(staging) + loaded = await store.load("my-repo") + assert loaded is not None + assert len(loaded.tickets) == 1 + assert loaded.tickets[0].id == "BILL-10" + + async def test_load_returns_fresh_objects(self, tmp_path: Path) -> None: + store = JsonFileStagingStore(directory=tmp_path) + staging = _make_staging() + await store.save(staging) + loaded1 = await store.load("my-repo") + loaded2 = await store.load("my-repo") + assert loaded1 is not loaded2 + + async def test_save_is_awaitable(self, tmp_path: Path) -> None: + store = JsonFileStagingStore(directory=tmp_path) + staging = _make_staging(repo="api-service") + await store.save(staging) + expected_path = tmp_path / "api-service.json" + assert expected_path.exists() + + async def test_save_overwrites_existing_file(self, tmp_path: Path) -> None: + store = JsonFileStagingStore(directory=tmp_path) + await store.save(_make_staging(version="v1.0.0")) + await store.save(_make_staging(version="v1.0.1")) + loaded = await store.load("my-repo") + assert loaded is not None + assert loaded.version == "v1.0.1" + + async def test_save_writes_valid_json(self, tmp_path: Path) -> None: + store = JsonFileStagingStore(directory=tmp_path) + staging = _make_staging() + await store.save(staging) + raw = (tmp_path / "my-repo.json").read_text() + data = json.loads(raw) + assert data["version"] == "v1.0.0" + assert data["repo"] == "my-repo" + + async def test_archive_is_awaitable(self, tmp_path: Path) -> None: + store = JsonFileStagingStore(directory=tmp_path) + staging = _make_staging() + await store.save(staging) + await store.archive(staging, date(2025, 6, 1)) + assert await store.load("my-repo") is None + + async def test_archive_creates_archive_file(self, tmp_path: Path) -> None: + store = JsonFileStagingStore(directory=tmp_path) + staging = _make_staging(repo="my-repo", version="v1.0.0") + await store.save(staging) + await store.archive(staging, date(2025, 6, 1)) + archive_path = tmp_path / "my-repo_v1.0.0_2025-06-01.json" + assert archive_path.exists() + + async def test_archive_file_contains_released_at(self, tmp_path: Path) -> None: + store = JsonFileStagingStore(directory=tmp_path) + staging = _make_staging() + await store.save(staging) + await store.archive(staging, date(2025, 6, 1)) + archive_path = tmp_path / "my-repo_v1.0.0_2025-06-01.json" + data = json.loads(archive_path.read_text()) + assert data["released_at"] == "2025-06-01" + + async def test_list_versions_is_awaitable(self, tmp_path: Path) -> None: + store = JsonFileStagingStore(directory=tmp_path) + versions = await store.list_versions("my-repo") + assert versions == [] + + async def test_list_versions_returns_staging_version(self, tmp_path: Path) -> None: + store = JsonFileStagingStore(directory=tmp_path) + await store.save(_make_staging(version="v2.1.0")) + versions = await store.list_versions("my-repo") + assert "v2.1.0" in versions + + async def test_list_versions_includes_archived(self, tmp_path: Path) -> None: + store = JsonFileStagingStore(directory=tmp_path) + staging = _make_staging(version="v1.5.0") + await store.save(staging) + await store.archive(staging, date(2025, 3, 1)) + await store.save(_make_staging(version="v1.6.0")) + versions = await store.list_versions("my-repo") + assert "v1.5.0" in versions + assert "v1.6.0" in versions + + async def test_list_versions_only_for_given_repo(self, tmp_path: Path) -> None: + store = JsonFileStagingStore(directory=tmp_path) + await store.save(_make_staging(repo="repo-a", version="v1.0.0")) + await store.save(_make_staging(repo="repo-b", version="v2.0.0")) + versions_a = await store.list_versions("repo-a") + assert "v1.0.0" in versions_a + assert "v2.0.0" not in versions_a diff --git a/tests/graph/test_full_cycle.py b/tests/graph/test_full_cycle.py new file mode 100644 index 0000000..28436f6 --- /dev/null +++ b/tests/graph/test_full_cycle.py @@ -0,0 +1,53 @@ +"""Tests for graph/full_cycle.py. + +Tests that the full cycle graph composes pr_completed and release subgraphs +correctly, and that the routing conditional edge works as expected. +""" + +from release_agent.graph.full_cycle import build_full_cycle_graph +from release_agent.graph.routing import should_continue_to_release + + +class TestBuildFullCycleGraph: + def test_returns_compiled_graph(self) -> None: + graph = build_full_cycle_graph() + assert graph is not None + + def test_graph_can_be_built_multiple_times(self) -> None: + graph1 = build_full_cycle_graph() + graph2 = build_full_cycle_graph() + assert graph1 is not None + assert graph2 is not None + + def test_graph_has_get_graph_method(self) -> None: + graph = build_full_cycle_graph() + assert hasattr(graph, "get_graph") or hasattr(graph, "nodes") + + +class TestFullCycleRouting: + """Test that the routing function used by full_cycle correctly + determines whether to continue to the release subgraph.""" + + def test_continue_when_flag_true_and_no_errors(self) -> None: + state = {"continue_to_release": True, "errors": []} + assert should_continue_to_release(state) == "yes" + + def test_stop_when_flag_false(self) -> None: + state = {"continue_to_release": False} + assert should_continue_to_release(state) == "no" + + def test_stop_when_flag_missing(self) -> None: + state = {} + assert should_continue_to_release(state) == "no" + + def test_stop_when_errors_present(self) -> None: + state = {"continue_to_release": True, "errors": ["some error"]} + assert should_continue_to_release(state) == "no" + + def test_stop_when_flag_true_but_errors_present(self) -> None: + state = {"continue_to_release": True, "errors": ["critical failure"]} + assert should_continue_to_release(state) == "no" + + def test_continue_when_errors_empty_list(self) -> None: + state = {"continue_to_release": True, "errors": []} + assert should_continue_to_release(state) == "yes" diff --git a/tests/graph/test_polling.py b/tests/graph/test_polling.py new file mode 100644 index 0000000..191dd67 --- /dev/null +++ b/tests/graph/test_polling.py @@ -0,0 +1,356 @@ +"""Tests for graph/polling.py — poll_until async utility. + +Written FIRST (TDD RED phase). + +All tests inject a fake_sleep_fn that returns immediately to avoid real waits. +""" + +import asyncio +from unittest.mock import AsyncMock, call + +import pytest + +from release_agent.graph.polling import poll_until + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +async def _immediate_sleep(seconds: float) -> None: + """Drop-in replacement for asyncio.sleep that returns immediately.""" + return + + +# --------------------------------------------------------------------------- +# Success path tests +# --------------------------------------------------------------------------- + +class TestPollUntilSuccess: + """Tests for the happy path where poll_fn succeeds before timeout.""" + + async def test_returns_tuple_of_result_and_completed_true(self) -> None: + calls = iter(["running", "running", "completed"]) + async def poll_fn(): + return next(calls) + + result, completed = await poll_until( + poll_fn=poll_fn, + is_done=lambda r: r == "completed", + interval_seconds=1, + max_wait_seconds=60, + sleep_fn=_immediate_sleep, + ) + + assert result == "completed" + assert completed is True + + async def test_returns_immediately_when_already_done(self) -> None: + async def poll_fn(): + return "completed" + + result, completed = await poll_until( + poll_fn=poll_fn, + is_done=lambda r: r == "completed", + interval_seconds=1, + max_wait_seconds=60, + sleep_fn=_immediate_sleep, + ) + + assert result == "completed" + assert completed is True + + async def test_polls_multiple_times_before_done(self) -> None: + call_count = 0 + + async def poll_fn(): + nonlocal call_count + call_count += 1 + return "done" if call_count >= 3 else "pending" + + result, completed = await poll_until( + poll_fn=poll_fn, + is_done=lambda r: r == "done", + interval_seconds=1, + max_wait_seconds=60, + sleep_fn=_immediate_sleep, + ) + + assert result == "done" + assert completed is True + assert call_count == 3 + + async def test_sleep_called_between_polls(self) -> None: + call_count = 0 + sleep_calls: list[float] = [] + + async def poll_fn(): + nonlocal call_count + call_count += 1 + return "done" if call_count >= 2 else "pending" + + async def tracking_sleep(seconds: float) -> None: + sleep_calls.append(seconds) + + await poll_until( + poll_fn=poll_fn, + is_done=lambda r: r == "done", + interval_seconds=15, + max_wait_seconds=60, + sleep_fn=tracking_sleep, + ) + + assert len(sleep_calls) >= 1 + assert all(s == 15 for s in sleep_calls) + + async def test_no_sleep_on_first_successful_poll(self) -> None: + sleep_calls: list[float] = [] + + async def poll_fn(): + return "done" + + async def tracking_sleep(seconds: float) -> None: + sleep_calls.append(seconds) + + await poll_until( + poll_fn=poll_fn, + is_done=lambda r: r == "done", + interval_seconds=10, + max_wait_seconds=60, + sleep_fn=tracking_sleep, + ) + + assert sleep_calls == [] + + async def test_works_with_dict_results(self) -> None: + responses = iter([ + {"status": "inProgress"}, + {"status": "completed", "result": "succeeded"}, + ]) + + async def poll_fn(): + return next(responses) + + result, completed = await poll_until( + poll_fn=poll_fn, + is_done=lambda r: r["status"] == "completed", + interval_seconds=1, + max_wait_seconds=60, + sleep_fn=_immediate_sleep, + ) + + assert result["result"] == "succeeded" + assert completed is True + + +# --------------------------------------------------------------------------- +# Timeout tests +# --------------------------------------------------------------------------- + +class TestPollUntilTimeout: + """Tests for timeout behavior.""" + + async def test_returns_last_result_and_completed_false_on_timeout(self) -> None: + async def poll_fn(): + return "still_running" + + # With interval=10, max_wait=5, it should time out after one poll + result, completed = await poll_until( + poll_fn=poll_fn, + is_done=lambda r: r == "done", + interval_seconds=10, + max_wait_seconds=5, + sleep_fn=_immediate_sleep, + ) + + assert result == "still_running" + assert completed is False + + async def test_at_least_one_poll_happens_before_timeout(self) -> None: + call_count = 0 + + async def poll_fn(): + nonlocal call_count + call_count += 1 + return "running" + + await poll_until( + poll_fn=poll_fn, + is_done=lambda r: r == "done", + interval_seconds=100, + max_wait_seconds=1, + sleep_fn=_immediate_sleep, + ) + + assert call_count >= 1 + + async def test_max_polls_bounded_by_max_wait_over_interval(self) -> None: + call_count = 0 + + async def poll_fn(): + nonlocal call_count + call_count += 1 + return "running" + + await poll_until( + poll_fn=poll_fn, + is_done=lambda r: False, + interval_seconds=10, + max_wait_seconds=30, + sleep_fn=_immediate_sleep, + ) + + # With interval=10, max_wait=30: should poll at most ceil(30/10)+1 = 4 times + assert call_count <= 5 + + +# --------------------------------------------------------------------------- +# Error handling tests +# --------------------------------------------------------------------------- + +class TestPollUntilErrorHandling: + """Tests for error/exception handling in poll_until.""" + + async def test_continues_after_transient_exception(self) -> None: + call_count = 0 + + async def poll_fn(): + nonlocal call_count + call_count += 1 + if call_count < 3: + raise RuntimeError("Transient error") + return "done" + + result, completed = await poll_until( + poll_fn=poll_fn, + is_done=lambda r: r == "done", + interval_seconds=1, + max_wait_seconds=60, + sleep_fn=_immediate_sleep, + ) + + assert result == "done" + assert completed is True + + async def test_aborts_after_three_consecutive_failures(self) -> None: + call_count = 0 + + async def poll_fn(): + nonlocal call_count + call_count += 1 + raise RuntimeError("Persistent error") + + result, completed = await poll_until( + poll_fn=poll_fn, + is_done=lambda r: True, + interval_seconds=1, + max_wait_seconds=60, + sleep_fn=_immediate_sleep, + ) + + # Should abort after 3 consecutive failures + assert call_count == 3 + assert completed is False + assert result is None + + async def test_resets_consecutive_failure_count_on_success(self) -> None: + call_count = 0 + + async def poll_fn(): + nonlocal call_count + call_count += 1 + # Fail twice, succeed once, fail twice, succeed (done) + if call_count in (1, 2): + raise RuntimeError("fail") + if call_count == 3: + return "running" + if call_count in (4, 5): + raise RuntimeError("fail again") + return "done" + + result, completed = await poll_until( + poll_fn=poll_fn, + is_done=lambda r: r == "done", + interval_seconds=1, + max_wait_seconds=120, + sleep_fn=_immediate_sleep, + ) + + assert completed is True + assert result == "done" + + async def test_single_exception_does_not_abort(self) -> None: + call_count = 0 + + async def poll_fn(): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise ValueError("one error") + return "done" + + result, completed = await poll_until( + poll_fn=poll_fn, + is_done=lambda r: r == "done", + interval_seconds=1, + max_wait_seconds=60, + sleep_fn=_immediate_sleep, + ) + + assert completed is True + assert result == "done" + + async def test_two_consecutive_failures_do_not_abort(self) -> None: + call_count = 0 + + async def poll_fn(): + nonlocal call_count + call_count += 1 + if call_count <= 2: + raise ConnectionError("two errors") + return "done" + + result, completed = await poll_until( + poll_fn=poll_fn, + is_done=lambda r: r == "done", + interval_seconds=1, + max_wait_seconds=60, + sleep_fn=_immediate_sleep, + ) + + assert completed is True + assert result == "done" + + +# --------------------------------------------------------------------------- +# Default parameter tests +# --------------------------------------------------------------------------- + +class TestPollUntilDefaults: + """Tests that default parameters match the spec.""" + + async def test_default_interval_is_30_seconds(self) -> None: + sleep_calls: list[float] = [] + + async def poll_fn(): + return "done" if len(sleep_calls) >= 1 else "running" + + async def tracking_sleep(seconds: float) -> None: + sleep_calls.append(seconds) + + await poll_until( + poll_fn=poll_fn, + is_done=lambda r: r == "done", + sleep_fn=tracking_sleep, + ) + + if sleep_calls: + assert sleep_calls[0] == 30 + + async def test_poll_fn_and_is_done_are_keyword_only(self) -> None: + """poll_fn and is_done must be passed as keyword arguments.""" + async def poll_fn(): + return "done" + + with pytest.raises(TypeError): + await poll_until(poll_fn, lambda r: r == "done") # type: ignore[call-arg] diff --git a/tests/graph/test_postgres_staging_store.py b/tests/graph/test_postgres_staging_store.py new file mode 100644 index 0000000..f3a334b --- /dev/null +++ b/tests/graph/test_postgres_staging_store.py @@ -0,0 +1,414 @@ +"""Tests for PostgresStagingStore. + +Phase 5 - Step 2: PostgreSQL-backed StagingStore using async pool. +Written FIRST (TDD RED phase). + +All tests use FakeAsyncPool — no real PostgreSQL required. +""" + +import json +from datetime import date +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from release_agent.graph.postgres_staging_store import PostgresStagingStore +from release_agent.models.release import ArchivedRelease, StagingRelease +from release_agent.models.ticket import TicketEntry + + +# --------------------------------------------------------------------------- +# Fake pool infrastructure +# --------------------------------------------------------------------------- + +class FakeAsyncCursor: + """Records SQL calls and returns configured results.""" + + def __init__(self) -> None: + self.executed: list[tuple[str, tuple]] = [] + self._fetchone_result: tuple | None = None + self._fetchall_result: list[tuple] = [] + + def set_fetchone(self, row: tuple | None) -> None: + self._fetchone_result = row + + def set_fetchall(self, rows: list[tuple]) -> None: + self._fetchall_result = rows + + async def execute(self, sql: str, params: tuple = ()) -> None: + self.executed.append((sql, params)) + + async def fetchone(self) -> tuple | None: + return self._fetchone_result + + async def fetchall(self) -> list[tuple]: + return self._fetchall_result + + async def __aenter__(self): + return self + + async def __aexit__(self, *args): + pass + + +class FakeAsyncTransaction: + """Fake async transaction context manager (no-op).""" + + async def __aenter__(self): + return self + + async def __aexit__(self, *args): + pass + + +class FakeAsyncConnection: + """Async context manager returning a FakeAsyncCursor.""" + + def __init__(self, cursor: FakeAsyncCursor) -> None: + self._cursor = cursor + + def cursor(self): + return self._cursor + + def transaction(self): + return FakeAsyncTransaction() + + async def __aenter__(self): + return self + + async def __aexit__(self, *args): + pass + + +class FakeAsyncPool: + """Records all SQL executed through it.""" + + def __init__(self, cursor: FakeAsyncCursor) -> None: + self._cursor = cursor + self._conn = FakeAsyncConnection(cursor) + + def connection(self): + return self._conn + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_ticket(ticket_id: str = "ALLPOST-1") -> TicketEntry: + return TicketEntry( + id=ticket_id, + summary="Fix something", + pr_id="42", + pr_url="https://dev.azure.com/org/proj/_git/repo/pullrequest/42", + pr_title="Fix: something", + branch=f"feature/{ticket_id}-fix", + merged_at=date(2025, 1, 15), + ) + + +def _make_staging( + repo: str = "my-repo", + version: str = "v1.0.0", + tickets: list | None = None, +) -> StagingRelease: + t = tickets if tickets is not None else [] + return StagingRelease( + version=version, + repo=repo, + started_at=date(2025, 1, 1), + tickets=t, + ) + + +def _staging_row(staging: StagingRelease) -> tuple: + """Return (repo, version, started_at, tickets_json) as DB would store it.""" + return ( + staging.repo, + staging.version, + staging.started_at.isoformat(), + json.dumps([t.model_dump(mode="json") for t in staging.tickets]), + ) + + +# --------------------------------------------------------------------------- +# load() +# --------------------------------------------------------------------------- + +class TestPostgresStagingStoreLoad: + async def test_load_returns_none_when_no_row(self) -> None: + cursor = FakeAsyncCursor() + cursor.set_fetchone(None) + pool = FakeAsyncPool(cursor) + store = PostgresStagingStore(pool=pool) + + result = await store.load("nonexistent-repo") + + assert result is None + + async def test_load_returns_staging_release_when_row_exists(self) -> None: + staging = _make_staging(repo="api-service", version="v2.0.0") + cursor = FakeAsyncCursor() + cursor.set_fetchone(( + staging.repo, + staging.version, + staging.started_at.isoformat(), + json.dumps([]), + )) + pool = FakeAsyncPool(cursor) + store = PostgresStagingStore(pool=pool) + + result = await store.load("api-service") + + assert result is not None + assert isinstance(result, StagingRelease) + assert result.repo == "api-service" + assert result.version == "v2.0.0" + + async def test_load_returns_staging_with_tickets(self) -> None: + ticket = _make_ticket("BILL-42") + staging = _make_staging(tickets=[ticket]) + cursor = FakeAsyncCursor() + cursor.set_fetchone(( + staging.repo, + staging.version, + staging.started_at.isoformat(), + json.dumps([ticket.model_dump(mode="json")]), + )) + pool = FakeAsyncPool(cursor) + store = PostgresStagingStore(pool=pool) + + result = await store.load("my-repo") + + assert result is not None + assert len(result.tickets) == 1 + assert result.tickets[0].id == "BILL-42" + + async def test_load_executes_select_with_correct_repo(self) -> None: + cursor = FakeAsyncCursor() + cursor.set_fetchone(None) + pool = FakeAsyncPool(cursor) + store = PostgresStagingStore(pool=pool) + + await store.load("target-repo") + + assert len(cursor.executed) >= 1 + sql, params = cursor.executed[-1] + assert "SELECT" in sql.upper() + assert "target-repo" in params + + async def test_load_queries_staging_releases_table(self) -> None: + cursor = FakeAsyncCursor() + cursor.set_fetchone(None) + pool = FakeAsyncPool(cursor) + store = PostgresStagingStore(pool=pool) + + await store.load("my-repo") + + sql, _ = cursor.executed[-1] + assert "staging_releases" in sql + + +# --------------------------------------------------------------------------- +# save() +# --------------------------------------------------------------------------- + +class TestPostgresStagingStoreSave: + async def test_save_executes_upsert(self) -> None: + cursor = FakeAsyncCursor() + pool = FakeAsyncPool(cursor) + store = PostgresStagingStore(pool=pool) + staging = _make_staging() + + await store.save(staging) + + assert len(cursor.executed) >= 1 + sql, _ = cursor.executed[-1] + # Should be an INSERT ... ON CONFLICT ... or UPSERT + assert "INSERT" in sql.upper() or "UPSERT" in sql.upper() + + async def test_save_passes_repo_to_upsert(self) -> None: + cursor = FakeAsyncCursor() + pool = FakeAsyncPool(cursor) + store = PostgresStagingStore(pool=pool) + staging = _make_staging(repo="payment-service") + + await store.save(staging) + + _, params = cursor.executed[-1] + assert "payment-service" in params + + async def test_save_passes_version_to_upsert(self) -> None: + cursor = FakeAsyncCursor() + pool = FakeAsyncPool(cursor) + store = PostgresStagingStore(pool=pool) + staging = _make_staging(version="v3.1.0") + + await store.save(staging) + + _, params = cursor.executed[-1] + assert "v3.1.0" in params + + async def test_save_targets_staging_releases_table(self) -> None: + cursor = FakeAsyncCursor() + pool = FakeAsyncPool(cursor) + store = PostgresStagingStore(pool=pool) + staging = _make_staging() + + await store.save(staging) + + sql, _ = cursor.executed[-1] + assert "staging_releases" in sql + + async def test_save_serializes_tickets_as_json(self) -> None: + cursor = FakeAsyncCursor() + pool = FakeAsyncPool(cursor) + store = PostgresStagingStore(pool=pool) + staging = _make_staging(tickets=[_make_ticket("ALLPOST-99")]) + + await store.save(staging) + + _, params = cursor.executed[-1] + # tickets param should be a JSON string containing the ticket id + tickets_json = next(p for p in params if isinstance(p, str) and "ALLPOST-99" in p) + parsed = json.loads(tickets_json) + assert parsed[0]["id"] == "ALLPOST-99" + + +# --------------------------------------------------------------------------- +# archive() +# --------------------------------------------------------------------------- + +class TestPostgresStagingStoreArchive: + async def test_archive_inserts_into_archived_releases(self) -> None: + cursor = FakeAsyncCursor() + pool = FakeAsyncPool(cursor) + store = PostgresStagingStore(pool=pool) + staging = _make_staging() + + await store.archive(staging, date(2025, 6, 1)) + + sql_statements = [sql for sql, _ in cursor.executed] + assert any("archived_releases" in sql for sql in sql_statements) + + async def test_archive_deletes_from_staging_releases(self) -> None: + cursor = FakeAsyncCursor() + pool = FakeAsyncPool(cursor) + store = PostgresStagingStore(pool=pool) + staging = _make_staging(repo="my-repo") + + await store.archive(staging, date(2025, 6, 1)) + + sql_statements = [sql for sql, _ in cursor.executed] + assert any("DELETE" in sql.upper() and "staging_releases" in sql for sql in sql_statements) + + async def test_archive_passes_released_at_date(self) -> None: + cursor = FakeAsyncCursor() + pool = FakeAsyncPool(cursor) + store = PostgresStagingStore(pool=pool) + staging = _make_staging() + release_date = date(2025, 12, 31) + + await store.archive(staging, release_date) + + all_params = [params for _, params in cursor.executed] + all_values = [v for params in all_params for v in params] + assert "2025-12-31" in all_values or release_date.isoformat() in all_values + + async def test_archive_passes_repo_to_delete(self) -> None: + cursor = FakeAsyncCursor() + pool = FakeAsyncPool(cursor) + store = PostgresStagingStore(pool=pool) + staging = _make_staging(repo="payment-service") + + await store.archive(staging, date(2025, 6, 1)) + + delete_calls = [(sql, params) for sql, params in cursor.executed if "DELETE" in sql.upper()] + assert len(delete_calls) >= 1 + _, params = delete_calls[0] + assert "payment-service" in params + + +# --------------------------------------------------------------------------- +# list_versions() +# --------------------------------------------------------------------------- + +class TestPostgresStagingStoreListVersions: + async def test_list_versions_returns_empty_when_no_data(self) -> None: + cursor = FakeAsyncCursor() + cursor.set_fetchone(None) + cursor.set_fetchall([]) + pool = FakeAsyncPool(cursor) + store = PostgresStagingStore(pool=pool) + + versions = await store.list_versions("my-repo") + + assert versions == [] + + async def test_list_versions_includes_staging_version(self) -> None: + cursor = FakeAsyncCursor() + # fetchone returns staging row + cursor.set_fetchone(("my-repo", "v1.0.0", "2025-01-01", "[]")) + # fetchall returns archived rows + cursor.set_fetchall([]) + pool = FakeAsyncPool(cursor) + store = PostgresStagingStore(pool=pool) + + versions = await store.list_versions("my-repo") + + assert "v1.0.0" in versions + + async def test_list_versions_includes_archived_versions(self) -> None: + cursor = FakeAsyncCursor() + cursor.set_fetchone(None) + cursor.set_fetchall([ + ("my-repo", "v1.0.0", "2025-01-01", "[]", "2025-06-01"), + ("my-repo", "v1.1.0", "2025-02-01", "[]", "2025-07-01"), + ]) + pool = FakeAsyncPool(cursor) + store = PostgresStagingStore(pool=pool) + + versions = await store.list_versions("my-repo") + + assert "v1.0.0" in versions + assert "v1.1.0" in versions + + async def test_list_versions_combines_staging_and_archived(self) -> None: + cursor = FakeAsyncCursor() + cursor.set_fetchone(("my-repo", "v2.0.0", "2025-03-01", "[]")) + cursor.set_fetchall([ + ("my-repo", "v1.0.0", "2025-01-01", "[]", "2025-02-01"), + ]) + pool = FakeAsyncPool(cursor) + store = PostgresStagingStore(pool=pool) + + versions = await store.list_versions("my-repo") + + assert "v2.0.0" in versions + assert "v1.0.0" in versions + + async def test_list_versions_no_duplicates(self) -> None: + cursor = FakeAsyncCursor() + cursor.set_fetchone(("my-repo", "v1.0.0", "2025-01-01", "[]")) + cursor.set_fetchall([ + ("my-repo", "v1.0.0", "2025-01-01", "[]", "2025-02-01"), + ]) + pool = FakeAsyncPool(cursor) + store = PostgresStagingStore(pool=pool) + + versions = await store.list_versions("my-repo") + + assert len(versions) == len(set(versions)) + + async def test_list_versions_executes_queries_for_correct_repo(self) -> None: + cursor = FakeAsyncCursor() + cursor.set_fetchone(None) + cursor.set_fetchall([]) + pool = FakeAsyncPool(cursor) + store = PostgresStagingStore(pool=pool) + + await store.list_versions("target-repo") + + all_params = [params for _, params in cursor.executed] + all_values = [v for params in all_params for v in params] + assert "target-repo" in all_values diff --git a/tests/graph/test_pr_completed.py b/tests/graph/test_pr_completed.py new file mode 100644 index 0000000..e2c5dff --- /dev/null +++ b/tests/graph/test_pr_completed.py @@ -0,0 +1,956 @@ +"""Tests for graph/pr_completed.py node functions. Written FIRST (TDD RED phase). + +Each node is an async function (state, config) -> dict. +Tests call nodes directly with a state dict and config dict — no graph compilation. +""" + +from datetime import date, datetime, timezone +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from release_agent.graph.dependencies import JsonFileStagingStore, ToolClients +from release_agent.graph.pr_completed import ( + _post_review_to_pr, + add_jira_pr_link, + auto_create_ticket, + calculate_version, + evaluate_review, + fetch_pr_details, + interrupt_confirm_merge, + merge_pr_node, + move_jira_code_review, + move_jira_ready_for_stage, + notify_request_changes, + parse_webhook, + run_code_review, + update_staging, + build_pr_completed_graph, +) +from release_agent.models.review import ReviewIssue +from release_agent.models.jira import JiraIssue +from release_agent.models.pr import PRInfo +from release_agent.models.review import ReviewResult +from tests.graph.conftest import build_config, build_mock_clients + + +# --------------------------------------------------------------------------- +# Webhook payload fixtures +# --------------------------------------------------------------------------- + +def _make_webhook_payload( + *, + repo_name: str = "my-repo", + pr_id: int = 42, + source_ref: str = "refs/heads/feature/ALLPOST-100_fix-bug", + target_ref: str = "refs/heads/main", + status: str = "completed", + title: str = "Fix: bug", + closed_date: str | None = "2025-01-15T10:00:00Z", +) -> dict: + # Uses snake_case keys to match WebhookPayload Pydantic model field names + return { + "subscription_id": "sub-1", + "event_type": "git.pullrequest.merged", + "resource": { + "repository": { + "id": "repo-id-1", + "name": repo_name, + "web_url": "https://dev.azure.com/org/proj/_git/my-repo", + }, + "pull_request_id": pr_id, + "title": title, + "source_ref_name": source_ref, + "target_ref_name": target_ref, + "status": status, + "closed_date": closed_date, + }, + } + + +def _make_pr_info( + *, + pr_id: str = "42", + repo_name: str = "my-repo", + branch: str = "refs/heads/feature/ALLPOST-100-fix-bug", + status: str = "completed", +) -> PRInfo: + return PRInfo( + pr_id=pr_id, + pr_url="https://dev.azure.com/org/proj/_git/my-repo/pullrequest/42", + repo_name=repo_name, + branch=branch, + pr_title="Fix: bug", + pr_status=status, + ) + + +def _make_approve_review() -> dict: + return { + "verdict": "approve", + "summary": "Looks good", + "issues": [], + "has_blockers": False, + } + + +def _make_request_changes_review() -> dict: + return { + "verdict": "request_changes", + "summary": "Needs work", + "issues": [{"severity": "blocker", "description": "Missing tests"}], + "has_blockers": True, + } + + +# --------------------------------------------------------------------------- +# parse_webhook +# --------------------------------------------------------------------------- + +class TestParseWebhook: + async def test_extracts_pr_info_from_payload(self) -> None: + state = {"webhook_payload": _make_webhook_payload()} + config = build_config() + result = await parse_webhook(state, config) + assert "pr_info" in result + pr = result["pr_info"] + assert pr["pr_id"] == "42" + assert pr["repo_name"] == "my-repo" + + async def test_extracts_ticket_from_branch(self) -> None: + state = {"webhook_payload": _make_webhook_payload( + source_ref="refs/heads/feature/ALLPOST-100_fix-bug" + )} + config = build_config() + result = await parse_webhook(state, config) + assert result["ticket_id"] == "ALLPOST-100" + assert result["has_ticket"] is True + + async def test_no_ticket_when_branch_has_none(self) -> None: + state = {"webhook_payload": _make_webhook_payload( + source_ref="refs/heads/bugfix/generic_fix" + )} + config = build_config() + result = await parse_webhook(state, config) + assert result["has_ticket"] is False + assert result["ticket_id"] is None + + async def test_sets_repo_name(self) -> None: + state = {"webhook_payload": _make_webhook_payload(repo_name="backend-api")} + config = build_config() + result = await parse_webhook(state, config) + assert result["repo_name"] == "backend-api" + + async def test_sets_pr_id_as_string(self) -> None: + state = {"webhook_payload": _make_webhook_payload(pr_id=99)} + config = build_config() + result = await parse_webhook(state, config) + assert result["pr_info"]["pr_id"] == "99" + + async def test_invalid_payload_adds_error(self) -> None: + state = {"webhook_payload": {"bad": "data"}} + config = build_config() + result = await parse_webhook(state, config) + assert "errors" in result + assert len(result["errors"]) > 0 + + +# --------------------------------------------------------------------------- +# fetch_pr_details +# --------------------------------------------------------------------------- + +class TestFetchPrDetails: + async def test_fetches_pr_and_sets_pr_already_merged_false(self) -> None: + clients = build_mock_clients() + pr = _make_pr_info(status="active") + clients.azdo.get_pr = AsyncMock(return_value=pr) + clients.azdo.get_pr_diff = AsyncMock(return_value="edit: main.py") + config = build_config(clients) + state = {"pr_id": "42", "pr_info": {"pr_id": "42", "pr_status": "active"}} + result = await fetch_pr_details(state, config) + assert result["pr_already_merged"] is False + assert result["pr_diff"] == "edit: main.py" + + async def test_sets_pr_already_merged_true_when_completed(self) -> None: + clients = build_mock_clients() + pr = _make_pr_info(status="completed") + clients.azdo.get_pr = AsyncMock(return_value=pr) + clients.azdo.get_pr_diff = AsyncMock(return_value="") + config = build_config(clients) + state = {"pr_id": "42", "pr_info": {"pr_id": "42", "pr_status": "completed"}} + result = await fetch_pr_details(state, config) + assert result["pr_already_merged"] is True + + async def test_stores_last_merge_source_commit(self) -> None: + clients = build_mock_clients() + pr = _make_pr_info(status="active") + clients.azdo.get_pr = AsyncMock(return_value=pr) + clients.azdo.get_pr_diff = AsyncMock(return_value="edit: main.py") + config = build_config(clients) + state = {"pr_id": "42", "pr_info": {"pr_id": "42"}} + result = await fetch_pr_details(state, config) + # last_merge_source_commit may be None if pr doesn't have it, but key must be present + assert "last_merge_source_commit" in result + + async def test_adds_error_on_service_failure(self) -> None: + from release_agent.exceptions import ServiceError + clients = build_mock_clients() + clients.azdo.get_pr = AsyncMock(side_effect=ServiceError( + service="azdo", status_code=500, detail="Server error" + )) + config = build_config(clients) + state = {"pr_id": "42"} + result = await fetch_pr_details(state, config) + assert "errors" in result + assert len(result["errors"]) > 0 + + +# --------------------------------------------------------------------------- +# move_jira_code_review +# --------------------------------------------------------------------------- + +class TestMoveJiraCodeReview: + async def test_transitions_ticket_when_has_ticket(self) -> None: + clients = build_mock_clients() + clients.jira.transition_issue = AsyncMock(return_value=True) + config = build_config(clients) + state = {"ticket_id": "ALLPOST-100", "has_ticket": True} + result = await move_jira_code_review(state, config) + clients.jira.transition_issue.assert_called_once_with("ALLPOST-100", "code review") + assert result == {} or "messages" in result + + async def test_skips_when_no_ticket(self) -> None: + clients = build_mock_clients() + clients.jira.transition_issue = AsyncMock(return_value=True) + config = build_config(clients) + state = {"has_ticket": False, "ticket_id": None} + result = await move_jira_code_review(state, config) + clients.jira.transition_issue.assert_not_called() + + async def test_appends_error_on_jira_failure(self) -> None: + from release_agent.exceptions import ServiceError + clients = build_mock_clients() + clients.jira.transition_issue = AsyncMock(side_effect=ServiceError( + service="jira", status_code=500, detail="Jira down" + )) + config = build_config(clients) + state = {"ticket_id": "ALLPOST-100", "has_ticket": True} + result = await move_jira_code_review(state, config) + assert "errors" in result + assert len(result["errors"]) > 0 + + +# --------------------------------------------------------------------------- +# run_code_review +# --------------------------------------------------------------------------- + +class TestRunCodeReview: + async def test_calls_reviewer_with_diff(self) -> None: + clients = build_mock_clients() + review = ReviewResult(verdict="approve", summary="LGTM", issues=()) + clients.reviewer.review_pr = AsyncMock(return_value=review) + config = build_config(clients) + state = { + "pr_diff": "edit: main.py", + "pr_info": {"pr_title": "Fix: bug", "repo_name": "my-repo"}, + } + result = await run_code_review(state, config) + clients.reviewer.review_pr.assert_called_once() + assert "review_result" in result + + async def test_stores_review_result_as_dict(self) -> None: + clients = build_mock_clients() + review = ReviewResult(verdict="approve", summary="Clean code", issues=()) + clients.reviewer.review_pr = AsyncMock(return_value=review) + config = build_config(clients) + state = { + "pr_diff": "edit: main.py", + "pr_info": {"pr_title": "Fix", "repo_name": "repo"}, + } + result = await run_code_review(state, config) + assert result["review_result"]["verdict"] == "approve" + + async def test_adds_error_on_reviewer_failure(self) -> None: + clients = build_mock_clients() + clients.reviewer.review_pr = AsyncMock(side_effect=Exception("API error")) + config = build_config(clients) + state = { + "pr_diff": "edit: main.py", + "pr_info": {"pr_title": "Fix", "repo_name": "repo"}, + } + result = await run_code_review(state, config) + assert "errors" in result + + +# --------------------------------------------------------------------------- +# _post_review_to_pr +# --------------------------------------------------------------------------- + +class TestPostReviewToPr: + async def test_posts_summary_comment(self) -> None: + clients = build_mock_clients() + clients.azdo.add_pr_comment = AsyncMock() + clients.azdo.add_pr_inline_comment = AsyncMock() + review = ReviewResult(verdict="approve", summary="LGTM", issues=()) + await _post_review_to_pr(clients, "my-repo", 42, review) + clients.azdo.add_pr_comment.assert_called_once() + call_kwargs = clients.azdo.add_pr_comment.call_args + assert "APPROVE" in call_kwargs.kwargs["content"] + + async def test_posts_inline_comment_for_issue_with_file_and_line(self) -> None: + clients = build_mock_clients() + clients.azdo.add_pr_comment = AsyncMock() + clients.azdo.add_pr_inline_comment = AsyncMock() + issue = ReviewIssue( + severity="error", description="Null check missing", + file_path="src/Foo.cs", line_start=42, suggestion="Add null guard", + ) + review = ReviewResult(verdict="request_changes", summary="Issues", issues=(issue,)) + await _post_review_to_pr(clients, "my-repo", 42, review) + clients.azdo.add_pr_inline_comment.assert_called_once() + call_kwargs = clients.azdo.add_pr_inline_comment.call_args.kwargs + assert call_kwargs["file_path"] == "src/Foo.cs" + assert call_kwargs["line_start"] == 42 + assert "Null check missing" in call_kwargs["content"] + assert "Add null guard" in call_kwargs["content"] + + async def test_skips_inline_for_issue_without_line(self) -> None: + clients = build_mock_clients() + clients.azdo.add_pr_comment = AsyncMock() + clients.azdo.add_pr_inline_comment = AsyncMock() + issue = ReviewIssue(severity="warning", description="Style issue", file_path="src/Foo.cs") + review = ReviewResult(verdict="approve", summary="OK", issues=(issue,)) + await _post_review_to_pr(clients, "my-repo", 42, review) + clients.azdo.add_pr_inline_comment.assert_not_called() + + async def test_skips_inline_for_issue_without_file(self) -> None: + clients = build_mock_clients() + clients.azdo.add_pr_comment = AsyncMock() + clients.azdo.add_pr_inline_comment = AsyncMock() + issue = ReviewIssue(severity="info", description="General note", line_start=10) + review = ReviewResult(verdict="approve", summary="OK", issues=(issue,)) + await _post_review_to_pr(clients, "my-repo", 42, review) + clients.azdo.add_pr_inline_comment.assert_not_called() + + async def test_inline_failure_does_not_prevent_summary(self) -> None: + clients = build_mock_clients() + clients.azdo.add_pr_comment = AsyncMock() + clients.azdo.add_pr_inline_comment = AsyncMock(side_effect=Exception("API error")) + issue = ReviewIssue( + severity="blocker", description="Critical", file_path="a.cs", line_start=1 + ) + review = ReviewResult(verdict="request_changes", summary="Bad", issues=(issue,)) + await _post_review_to_pr(clients, "my-repo", 42, review) + # Summary should still be posted even though inline failed + clients.azdo.add_pr_comment.assert_called_once() + + async def test_summary_failure_does_not_raise(self) -> None: + clients = build_mock_clients() + clients.azdo.add_pr_comment = AsyncMock(side_effect=Exception("Network error")) + clients.azdo.add_pr_inline_comment = AsyncMock() + review = ReviewResult(verdict="approve", summary="LGTM", issues=()) + # Should not raise + await _post_review_to_pr(clients, "my-repo", 42, review) + + async def test_summary_contains_issue_count(self) -> None: + clients = build_mock_clients() + clients.azdo.add_pr_comment = AsyncMock() + clients.azdo.add_pr_inline_comment = AsyncMock() + issues = ( + ReviewIssue(severity="warning", description="Issue 1"), + ReviewIssue(severity="error", description="Issue 2"), + ) + review = ReviewResult(verdict="request_changes", summary="Problems", issues=issues) + await _post_review_to_pr(clients, "my-repo", 42, review) + content = clients.azdo.add_pr_comment.call_args.kwargs["content"] + assert "2 issue(s)" in content + + async def test_run_code_review_calls_post_review(self) -> None: + """Integration: run_code_review posts comments when pr_id and repo_name present.""" + clients = build_mock_clients() + review = ReviewResult(verdict="approve", summary="LGTM", issues=()) + clients.reviewer.review_pr = AsyncMock(return_value=review) + clients.azdo.add_pr_comment = AsyncMock() + clients.azdo.add_pr_inline_comment = AsyncMock() + config = build_config(clients) + state = { + "pr_diff": "edit: main.py", + "pr_info": {"pr_id": "42", "pr_title": "Fix", "repo_name": "my-repo"}, + } + await run_code_review(state, config) + clients.azdo.add_pr_comment.assert_called_once() + + +# --------------------------------------------------------------------------- +# evaluate_review +# --------------------------------------------------------------------------- + +class TestEvaluateReview: + async def test_sets_review_approved_true_for_approve_verdict(self) -> None: + config = build_config() + state = {"review_result": _make_approve_review()} + result = await evaluate_review(state, config) + assert result["review_approved"] is True + + async def test_sets_review_approved_false_for_request_changes(self) -> None: + config = build_config() + state = {"review_result": _make_request_changes_review()} + result = await evaluate_review(state, config) + assert result["review_approved"] is False + + async def test_sets_false_when_review_result_missing(self) -> None: + config = build_config() + state = {} + result = await evaluate_review(state, config) + assert result["review_approved"] is False + + async def test_sets_false_when_has_blockers(self) -> None: + config = build_config() + state = { + "review_result": { + "verdict": "approve", + "summary": "Approve with blocker?", + "issues": [{"severity": "blocker", "description": "Problem"}], + "has_blockers": True, + } + } + result = await evaluate_review(state, config) + assert result["review_approved"] is False + + +# --------------------------------------------------------------------------- +# interrupt_confirm_merge +# --------------------------------------------------------------------------- + +class TestInterruptConfirmMerge: + async def test_calls_interrupt_with_summary_string(self) -> None: + config = build_config() + state = { + "pr_info": {"pr_id": "42", "pr_title": "Fix: bug", "repo_name": "my-repo"}, + "review_result": {"summary": "LGTM"}, + } + with patch("release_agent.graph.pr_completed.interrupt") as mock_interrupt: + mock_interrupt.return_value = "confirm" + await interrupt_confirm_merge(state, config) + mock_interrupt.assert_called_once() + call_arg = mock_interrupt.call_args[0][0] + assert isinstance(call_arg, str) + assert len(call_arg) > 0 + + async def test_interrupt_value_contains_pr_info(self) -> None: + config = build_config() + state = { + "pr_info": {"pr_id": "42", "pr_title": "Fix: auth bug", "repo_name": "backend"}, + "review_result": {"summary": "All good"}, + } + with patch("release_agent.graph.pr_completed.interrupt") as mock_interrupt: + mock_interrupt.return_value = "confirm" + await interrupt_confirm_merge(state, config) + call_arg = mock_interrupt.call_args[0][0] + assert "42" in call_arg or "Fix: auth bug" in call_arg or "backend" in call_arg + + +# --------------------------------------------------------------------------- +# merge_pr_node +# --------------------------------------------------------------------------- + +class TestMergePrNode: + async def test_calls_azdo_merge_pr(self) -> None: + clients = build_mock_clients() + clients.azdo.merge_pr = AsyncMock(return_value=True) + config = build_config(clients) + state = { + "pr_info": {"pr_id": "42"}, + "last_merge_source_commit": "abc123", + } + result = await merge_pr_node(state, config) + clients.azdo.merge_pr.assert_called_once_with( + pr_id=42, last_merge_source_commit="abc123" + ) + + async def test_returns_message_on_success(self) -> None: + clients = build_mock_clients() + clients.azdo.merge_pr = AsyncMock(return_value=True) + config = build_config(clients) + state = { + "pr_info": {"pr_id": "42"}, + "last_merge_source_commit": "abc123", + } + result = await merge_pr_node(state, config) + assert "messages" in result + + async def test_re_raises_on_service_error(self) -> None: + from release_agent.exceptions import ServiceError + clients = build_mock_clients() + clients.azdo.merge_pr = AsyncMock(side_effect=ServiceError( + service="azdo", status_code=409, detail="Conflict" + )) + config = build_config(clients) + state = { + "pr_info": {"pr_id": "42"}, + "last_merge_source_commit": "abc123", + } + with pytest.raises(ServiceError): + await merge_pr_node(state, config) + + +# --------------------------------------------------------------------------- +# move_jira_ready_for_stage +# --------------------------------------------------------------------------- + +class TestMoveJiraReadyForStage: + async def test_transitions_ticket(self) -> None: + clients = build_mock_clients() + clients.jira.transition_issue = AsyncMock(return_value=True) + config = build_config(clients) + state = {"ticket_id": "ALLPOST-100", "has_ticket": True} + result = await move_jira_ready_for_stage(state, config) + clients.jira.transition_issue.assert_called_once_with( + "ALLPOST-100", "Ready for stage (2)" + ) + + async def test_skips_when_no_ticket(self) -> None: + clients = build_mock_clients() + clients.jira.transition_issue = AsyncMock() + config = build_config(clients) + state = {"has_ticket": False} + await move_jira_ready_for_stage(state, config) + clients.jira.transition_issue.assert_not_called() + + async def test_appends_error_on_failure(self) -> None: + from release_agent.exceptions import ServiceError + clients = build_mock_clients() + clients.jira.transition_issue = AsyncMock(side_effect=ServiceError( + service="jira", status_code=500, detail="Error" + )) + config = build_config(clients) + state = {"ticket_id": "ALLPOST-100", "has_ticket": True} + result = await move_jira_ready_for_stage(state, config) + assert "errors" in result + + +# --------------------------------------------------------------------------- +# add_jira_pr_link +# --------------------------------------------------------------------------- + +class TestAddJiraPrLink: + async def test_calls_add_remote_link(self) -> None: + clients = build_mock_clients() + clients.jira.add_remote_link = AsyncMock(return_value=True) + config = build_config(clients) + state = { + "ticket_id": "ALLPOST-100", + "has_ticket": True, + "pr_info": { + "pr_id": "42", + "pr_url": "https://dev.azure.com/org/proj/_git/repo/pullrequest/42", + "pr_title": "Fix: bug", + }, + } + result = await add_jira_pr_link(state, config) + clients.jira.add_remote_link.assert_called_once() + + async def test_skips_when_no_ticket(self) -> None: + clients = build_mock_clients() + clients.jira.add_remote_link = AsyncMock() + config = build_config(clients) + state = {"has_ticket": False} + await add_jira_pr_link(state, config) + clients.jira.add_remote_link.assert_not_called() + + async def test_appends_error_on_failure(self) -> None: + from release_agent.exceptions import ServiceError + clients = build_mock_clients() + clients.jira.add_remote_link = AsyncMock(side_effect=ServiceError( + service="jira", status_code=500, detail="Error" + )) + config = build_config(clients) + state = { + "ticket_id": "ALLPOST-100", + "has_ticket": True, + "pr_info": { + "pr_id": "42", + "pr_url": "https://dev.azure.com/org/proj/_git/repo/pullrequest/42", + "pr_title": "Fix", + }, + } + result = await add_jira_pr_link(state, config) + assert "errors" in result + + +# --------------------------------------------------------------------------- +# calculate_version +# --------------------------------------------------------------------------- + +class TestCalculateVersion: + async def test_returns_v1_0_0_for_empty_store(self, tmp_path) -> None: + staging_store = JsonFileStagingStore(directory=tmp_path) + clients = build_mock_clients() + config = build_config(clients, staging_store=staging_store) + state = {"repo_name": "my-repo"} + result = await calculate_version(state, config) + assert result["version"] == "v1.0.0" + + async def test_increments_patch_version(self, tmp_path) -> None: + from release_agent.models.release import StagingRelease + staging_store = JsonFileStagingStore(directory=tmp_path) + # Pre-populate with an existing version + staging = StagingRelease( + version="v1.0.5", + repo="my-repo", + started_at=date(2025, 1, 1), + tickets=[], + ) + await staging_store.archive(staging, date(2025, 1, 10)) + clients = build_mock_clients() + config = build_config(clients, staging_store=staging_store) + state = {"repo_name": "my-repo"} + result = await calculate_version(state, config) + assert result["version"] == "v1.0.6" + + async def test_sets_version_in_state(self, tmp_path) -> None: + staging_store = JsonFileStagingStore(directory=tmp_path) + clients = build_mock_clients() + config = build_config(clients, staging_store=staging_store) + state = {"repo_name": "new-repo"} + result = await calculate_version(state, config) + assert "version" in result + assert result["version"].startswith("v") + + +# --------------------------------------------------------------------------- +# update_staging +# --------------------------------------------------------------------------- + +class TestUpdateStaging: + async def test_creates_new_staging_when_none_exists(self, tmp_path) -> None: + staging_store = JsonFileStagingStore(directory=tmp_path) + clients = build_mock_clients() + # Jira get_issue returns a summary + from release_agent.models.jira import JiraIssue + clients.jira.get_issue = AsyncMock(return_value=JiraIssue( + key="ALLPOST-100", summary="Fix auth bug", status="Ready for stage (2)" + )) + config = build_config(clients, staging_store=staging_store) + state = { + "repo_name": "my-repo", + "version": "v1.0.0", + "ticket_id": "ALLPOST-100", + "has_ticket": True, + "pr_info": { + "pr_id": "42", + "pr_url": "https://dev.azure.com/org/proj/_git/repo/pullrequest/42", + "pr_title": "Fix: auth bug", + "branch": "feature/ALLPOST-100-fix", + }, + } + result = await update_staging(state, config) + loaded = await staging_store.load("my-repo") + assert loaded is not None + assert loaded.has_ticket("ALLPOST-100") + + async def test_appends_ticket_to_existing_staging(self, tmp_path) -> None: + from datetime import date + from release_agent.models.release import StagingRelease + from release_agent.models.jira import JiraIssue + staging_store = JsonFileStagingStore(directory=tmp_path) + existing = StagingRelease( + version="v1.0.0", repo="my-repo", + started_at=date(2025, 1, 1), tickets=[] + ) + await staging_store.save(existing) + clients = build_mock_clients() + clients.jira.get_issue = AsyncMock(return_value=JiraIssue( + key="BILL-99", summary="New feature", status="Ready" + )) + config = build_config(clients, staging_store=staging_store) + state = { + "repo_name": "my-repo", + "version": "v1.0.0", + "ticket_id": "BILL-99", + "has_ticket": True, + "pr_info": { + "pr_id": "55", + "pr_url": "https://dev.azure.com/org/proj/_git/repo/pullrequest/55", + "pr_title": "Feat: new feature", + "branch": "feature/BILL-99-feat", + }, + } + await update_staging(state, config) + loaded = await staging_store.load("my-repo") + assert loaded is not None + assert loaded.has_ticket("BILL-99") + + async def test_skips_ticket_add_when_no_ticket(self, tmp_path) -> None: + staging_store = JsonFileStagingStore(directory=tmp_path) + clients = build_mock_clients() + config = build_config(clients, staging_store=staging_store) + state = { + "repo_name": "my-repo", + "version": "v1.0.0", + "has_ticket": False, + } + await update_staging(state, config) + # No staging file should be created for ticket-less PR if no existing staging + # (or staging exists without new ticket added) + clients.jira.get_issue.assert_not_called() + + async def test_returns_empty_dict_when_no_staging_store(self) -> None: + from release_agent.models.jira import JiraIssue + clients = build_mock_clients() + clients.jira.get_issue = AsyncMock(return_value=JiraIssue( + key="ALLPOST-1", summary="Fix", status="Ready" + )) + config = build_config(clients, staging_store=None) + state = { + "repo_name": "my-repo", + "version": "v1.0.0", + "ticket_id": "ALLPOST-1", + "has_ticket": True, + "pr_info": { + "pr_id": "1", + "pr_url": "https://dev.azure.com/org/proj/_git/repo/pullrequest/1", + "pr_title": "Fix", + "branch": "feature/ALLPOST-1", + }, + } + result = await update_staging(state, config) + assert result == {} + + async def test_uses_ticket_id_as_summary_on_jira_failure(self, tmp_path) -> None: + staging_store = JsonFileStagingStore(directory=tmp_path) + clients = build_mock_clients() + clients.jira.get_issue = AsyncMock(side_effect=Exception("Jira unavailable")) + config = build_config(clients, staging_store=staging_store) + state = { + "repo_name": "my-repo", + "version": "v1.0.0", + "ticket_id": "ALLPOST-99", + "has_ticket": True, + "pr_info": { + "pr_id": "5", + "pr_url": "https://dev.azure.com/org/proj/_git/repo/pullrequest/5", + "pr_title": "Fix something", + "branch": "feature/ALLPOST-99_fix", + }, + } + result = await update_staging(state, config) + loaded = await staging_store.load("my-repo") + assert loaded is not None + assert loaded.tickets[0].id == "ALLPOST-99" + assert loaded.tickets[0].summary == "ALLPOST-99" + + async def test_sets_staging_dict_in_result(self, tmp_path) -> None: + from release_agent.models.jira import JiraIssue + staging_store = JsonFileStagingStore(directory=tmp_path) + clients = build_mock_clients() + clients.jira.get_issue = AsyncMock(return_value=JiraIssue( + key="ALLPOST-1", summary="S", status="Ready" + )) + config = build_config(clients, staging_store=staging_store) + state = { + "repo_name": "my-repo", + "version": "v1.0.0", + "ticket_id": "ALLPOST-1", + "has_ticket": True, + "pr_info": { + "pr_id": "1", + "pr_url": "https://dev.azure.com/org/proj/_git/repo/pullrequest/1", + "pr_title": "Fix", + "branch": "feature/ALLPOST-1", + }, + } + result = await update_staging(state, config) + assert "staging" in result + assert isinstance(result["staging"], dict) + + +# --------------------------------------------------------------------------- +# notify_request_changes +# --------------------------------------------------------------------------- + +class TestNotifyRequestChanges: + async def test_calls_slack_send_approval_request(self) -> None: + clients = build_mock_clients() + clients.slack.send_approval_request = AsyncMock(return_value=True) + config = build_config(clients) + state = { + "pr_info": {"pr_id": "42", "pr_title": "Fix: bug", "repo_name": "my-repo"}, + "review_result": { + "verdict": "request_changes", + "summary": "Too many issues", + "issues": [{"severity": "blocker", "description": "No tests"}], + }, + } + result = await notify_request_changes(state, config) + clients.slack.send_approval_request.assert_called_once() + + async def test_appends_error_on_slack_failure(self) -> None: + from release_agent.exceptions import ServiceError + clients = build_mock_clients() + clients.slack.send_approval_request = AsyncMock(side_effect=ServiceError( + service="slack", status_code=500, detail="Webhook error" + )) + config = build_config(clients) + state = { + "pr_info": {"pr_id": "42", "pr_title": "Fix", "repo_name": "repo"}, + "review_result": {"summary": "Issues found", "issues": []}, + } + result = await notify_request_changes(state, config) + assert "errors" in result + + +# --------------------------------------------------------------------------- +# build_pr_completed_graph +# --------------------------------------------------------------------------- + +# --------------------------------------------------------------------------- +# auto_create_ticket node +# --------------------------------------------------------------------------- + +class TestAutoCreateTicket: + """Tests for the auto_create_ticket node.""" + + def _make_config_with_jira_project( + self, jira_project: str = "ALLPOST" + ): + clients = build_mock_clients() + clients.jira.create_issue = AsyncMock(return_value="ALLPOST-99") + clients.reviewer.generate_ticket_content = AsyncMock( + return_value=("My summary", "My description") + ) + config = build_config(clients) + config["configurable"]["default_jira_project"] = jira_project + return config, clients + + async def test_creates_jira_issue_and_returns_ticket_id(self) -> None: + config, clients = self._make_config_with_jira_project() + state = { + "pr_diff": "edit: main.py", + "pr_info": {"pr_title": "Fix bug", "repo_name": "my-repo"}, + } + + result = await auto_create_ticket(state, config) + + assert result.get("ticket_id") == "ALLPOST-99" + + async def test_sets_has_ticket_true(self) -> None: + config, clients = self._make_config_with_jira_project() + state = { + "pr_diff": "edit: main.py", + "pr_info": {"pr_title": "Fix bug", "repo_name": "my-repo"}, + } + + result = await auto_create_ticket(state, config) + + assert result.get("has_ticket") is True + + async def test_calls_generate_ticket_content(self) -> None: + config, clients = self._make_config_with_jira_project() + state = { + "pr_diff": "edit: main.py", + "pr_info": {"pr_title": "Fix login", "repo_name": "auth-service"}, + } + + await auto_create_ticket(state, config) + + clients.reviewer.generate_ticket_content.assert_awaited_once() + + async def test_calls_create_issue_with_project_key(self) -> None: + config, clients = self._make_config_with_jira_project(jira_project="MYPROJ") + clients.jira.create_issue = AsyncMock(return_value="MYPROJ-5") + config["configurable"]["default_jira_project"] = "MYPROJ" + state = { + "pr_diff": "d", + "pr_info": {"pr_title": "t", "repo_name": "r"}, + } + + await auto_create_ticket(state, config) + + call_kwargs = clients.jira.create_issue.call_args.kwargs + assert call_kwargs["project"] == "MYPROJ" + + async def test_appends_message_on_success(self) -> None: + config, _ = self._make_config_with_jira_project() + state = { + "pr_diff": "d", + "pr_info": {"pr_title": "t", "repo_name": "r"}, + } + + result = await auto_create_ticket(state, config) + + assert "messages" in result + assert len(result["messages"]) > 0 + + async def test_appends_error_on_create_issue_failure(self) -> None: + config, clients = self._make_config_with_jira_project() + clients.jira.create_issue = AsyncMock(side_effect=Exception("Jira down")) + state = { + "pr_diff": "d", + "pr_info": {"pr_title": "t", "repo_name": "r"}, + } + + result = await auto_create_ticket(state, config) + + assert "errors" in result + assert len(result["errors"]) > 0 + + async def test_appends_error_on_generate_content_failure(self) -> None: + config, clients = self._make_config_with_jira_project() + clients.reviewer.generate_ticket_content = AsyncMock(side_effect=RuntimeError("CLI fail")) + state = { + "pr_diff": "d", + "pr_info": {"pr_title": "t", "repo_name": "r"}, + } + + result = await auto_create_ticket(state, config) + + assert "errors" in result + + async def test_uses_default_project_from_config(self) -> None: + config, clients = self._make_config_with_jira_project(jira_project="TEAM") + clients.jira.create_issue = AsyncMock(return_value="TEAM-1") + state = { + "pr_diff": "d", + "pr_info": {"pr_title": "t", "repo_name": "r"}, + } + + result = await auto_create_ticket(state, config) + + assert result["ticket_id"] == "TEAM-1" + + +# --------------------------------------------------------------------------- +# build_pr_completed_graph +# --------------------------------------------------------------------------- + +class TestBuildPrCompletedGraph: + def test_returns_compiled_graph(self) -> None: + graph = build_pr_completed_graph() + assert graph is not None + + def test_graph_has_nodes(self) -> None: + graph = build_pr_completed_graph() + # The compiled graph object should be truthy + assert graph is not None + + def test_graph_includes_trigger_ci_build_node(self) -> None: + graph = build_pr_completed_graph() + # Graph nodes should include CI pipeline nodes + graph_nodes = graph.get_graph().nodes + assert "trigger_ci_build" in graph_nodes + + def test_graph_includes_poll_ci_build_node(self) -> None: + graph = build_pr_completed_graph() + graph_nodes = graph.get_graph().nodes + assert "poll_ci_build" in graph_nodes + + def test_graph_includes_notify_ci_result_node(self) -> None: + graph = build_pr_completed_graph() + graph_nodes = graph.get_graph().nodes + assert "notify_ci_result" in graph_nodes + + def test_graph_includes_auto_create_ticket_node(self) -> None: + graph = build_pr_completed_graph() + graph_nodes = graph.get_graph().nodes + assert "auto_create_ticket" in graph_nodes diff --git a/tests/graph/test_release.py b/tests/graph/test_release.py new file mode 100644 index 0000000..f631838 --- /dev/null +++ b/tests/graph/test_release.py @@ -0,0 +1,866 @@ +"""Tests for graph/release.py node functions. Written FIRST (TDD RED phase). + +Each node is an async function (state, config) -> dict. +Tests call nodes directly — no graph compilation required. +""" + +from datetime import date +from unittest.mock import AsyncMock, patch + +import pytest + +from release_agent.graph.dependencies import JsonFileStagingStore +from release_agent.graph.release import ( + approve_stage, + archive_release, + check_release_approvals, + create_release_pr, + interrupt_confirm_approve, + interrupt_confirm_merge_release, + interrupt_confirm_release, + interrupt_confirm_trigger, + list_pipelines, + load_staging, + merge_release_pr, + move_tickets_to_done, + send_slack_notification, + trigger_pipelines, + build_release_graph, +) +from release_agent.models.pipeline import PipelineInfo, ReleasePipelineStage +from release_agent.models.release import StagingRelease +from release_agent.models.ticket import TicketEntry +from tests.graph.conftest import build_config, build_mock_clients + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_ticket(ticket_id: str = "ALLPOST-1") -> TicketEntry: + return TicketEntry( + id=ticket_id, + summary="Fix something", + pr_id="42", + pr_url="https://dev.azure.com/org/proj/_git/repo/pullrequest/42", + pr_title="Fix: something", + branch=f"feature/{ticket_id}-fix", + merged_at=date(2025, 1, 15), + ) + + +def _make_staging( + *, + repo: str = "my-repo", + version: str = "v1.0.0", + tickets: list | None = None, +) -> StagingRelease: + t = tickets if tickets is not None else [_make_ticket()] + return StagingRelease( + version=version, + repo=repo, + started_at=date(2025, 1, 1), + tickets=t, + ) + + +def _staging_dict(staging: StagingRelease) -> dict: + return staging.model_dump(mode="json") + + +# --------------------------------------------------------------------------- +# load_staging +# --------------------------------------------------------------------------- + +class TestLoadStaging: + async def test_loads_staging_from_store(self, tmp_path) -> None: + staging_store = JsonFileStagingStore(directory=tmp_path) + staging = _make_staging() + await staging_store.save(staging) + clients = build_mock_clients() + config = build_config(clients, staging_store=staging_store) + state = {"repo_name": "my-repo"} + result = await load_staging(state, config) + assert "staging" in result + assert result["staging"]["version"] == "v1.0.0" + + async def test_returns_none_when_no_staging(self, tmp_path) -> None: + staging_store = JsonFileStagingStore(directory=tmp_path) + clients = build_mock_clients() + config = build_config(clients, staging_store=staging_store) + state = {"repo_name": "nonexistent"} + result = await load_staging(state, config) + assert result.get("staging") is None + + async def test_staging_includes_tickets(self, tmp_path) -> None: + staging_store = JsonFileStagingStore(directory=tmp_path) + staging = _make_staging(tickets=[_make_ticket("BILL-10"), _make_ticket("BILL-11")]) + await staging_store.save(staging) + clients = build_mock_clients() + config = build_config(clients, staging_store=staging_store) + state = {"repo_name": "my-repo"} + result = await load_staging(state, config) + assert len(result["staging"]["tickets"]) == 2 + + +# --------------------------------------------------------------------------- +# interrupt_confirm_release +# --------------------------------------------------------------------------- + +class TestInterruptConfirmRelease: + async def test_calls_interrupt_with_staging_summary(self) -> None: + config = build_config() + staging = _make_staging() + state = { + "repo_name": "my-repo", + "staging": _staging_dict(staging), + } + with patch("release_agent.graph.release.interrupt") as mock_interrupt: + mock_interrupt.return_value = "confirm" + await interrupt_confirm_release(state, config) + mock_interrupt.assert_called_once() + call_arg = mock_interrupt.call_args[0][0] + assert isinstance(call_arg, str) + + async def test_interrupt_contains_version_and_repo(self) -> None: + config = build_config() + staging = _make_staging(version="v2.5.0", repo="backend") + state = { + "repo_name": "backend", + "staging": _staging_dict(staging), + } + with patch("release_agent.graph.release.interrupt") as mock_interrupt: + mock_interrupt.return_value = "confirm" + await interrupt_confirm_release(state, config) + call_arg = mock_interrupt.call_args[0][0] + assert "v2.5.0" in call_arg or "backend" in call_arg + + +# --------------------------------------------------------------------------- +# create_release_pr +# --------------------------------------------------------------------------- + +class TestCreateReleasePr: + async def test_calls_azdo_create_pr(self) -> None: + clients = build_mock_clients() + clients.azdo.create_pr = AsyncMock(return_value={ + "pullRequestId": 99, + "lastMergeSourceCommit": {"commitId": "deadbeef"}, + }) + config = build_config(clients) + staging = _make_staging(version="v1.2.0") + state = { + "repo_name": "my-repo", + "version": "v1.2.0", + "staging": _staging_dict(staging), + } + result = await create_release_pr(state, config) + clients.azdo.create_pr.assert_called_once() + call_kwargs = clients.azdo.create_pr.call_args.kwargs + assert call_kwargs["repo"] == "my-repo" + + async def test_sets_release_pr_id(self) -> None: + clients = build_mock_clients() + clients.azdo.create_pr = AsyncMock(return_value={ + "pullRequestId": 77, + "lastMergeSourceCommit": {"commitId": "cafe1234"}, + }) + config = build_config(clients) + staging = _make_staging(version="v1.0.3") + state = { + "repo_name": "my-repo", + "version": "v1.0.3", + "staging": _staging_dict(staging), + } + result = await create_release_pr(state, config) + assert result["release_pr_id"] == "77" + + async def test_sets_release_pr_commit(self) -> None: + clients = build_mock_clients() + clients.azdo.create_pr = AsyncMock(return_value={ + "pullRequestId": 77, + "lastMergeSourceCommit": {"commitId": "cafe1234"}, + }) + config = build_config(clients) + staging = _make_staging() + state = { + "repo_name": "my-repo", + "version": "v1.0.0", + "staging": _staging_dict(staging), + } + result = await create_release_pr(state, config) + assert result["release_pr_commit"] == "cafe1234" + + async def test_re_raises_on_service_error(self) -> None: + from release_agent.exceptions import ServiceError + clients = build_mock_clients() + clients.azdo.create_pr = AsyncMock(side_effect=ServiceError( + service="azdo", status_code=422, detail="Invalid branch" + )) + config = build_config(clients) + staging = _make_staging() + state = { + "repo_name": "my-repo", + "version": "v1.0.0", + "staging": _staging_dict(staging), + } + with pytest.raises(ServiceError): + await create_release_pr(state, config) + + +# --------------------------------------------------------------------------- +# interrupt_confirm_merge_release +# --------------------------------------------------------------------------- + +class TestInterruptConfirmMergeRelease: + async def test_calls_interrupt_with_pr_info(self) -> None: + config = build_config() + state = { + "release_pr_id": "99", + "version": "v1.0.0", + "repo_name": "my-repo", + } + with patch("release_agent.graph.release.interrupt") as mock_interrupt: + mock_interrupt.return_value = "confirm" + await interrupt_confirm_merge_release(state, config) + mock_interrupt.assert_called_once() + call_arg = mock_interrupt.call_args[0][0] + assert isinstance(call_arg, str) + assert len(call_arg) > 0 + + +# --------------------------------------------------------------------------- +# merge_release_pr +# --------------------------------------------------------------------------- + +class TestMergeReleasePr: + async def test_calls_azdo_merge_pr(self) -> None: + clients = build_mock_clients() + clients.azdo.merge_pr = AsyncMock(return_value=True) + config = build_config(clients) + state = { + "release_pr_id": "99", + "release_pr_commit": "abc123", + } + await merge_release_pr(state, config) + clients.azdo.merge_pr.assert_called_once_with( + pr_id=99, last_merge_source_commit="abc123" + ) + + async def test_re_raises_on_service_error(self) -> None: + from release_agent.exceptions import ServiceError + clients = build_mock_clients() + clients.azdo.merge_pr = AsyncMock(side_effect=ServiceError( + service="azdo", status_code=409, detail="Conflict" + )) + config = build_config(clients) + state = {"release_pr_id": "99", "release_pr_commit": "abc"} + with pytest.raises(ServiceError): + await merge_release_pr(state, config) + + +# --------------------------------------------------------------------------- +# move_tickets_to_done +# --------------------------------------------------------------------------- + +class TestMoveTicketsToDone: + async def test_transitions_all_tickets(self) -> None: + clients = build_mock_clients() + clients.jira.transition_issue = AsyncMock(return_value=True) + config = build_config(clients) + staging = _make_staging(tickets=[_make_ticket("BILL-1"), _make_ticket("BILL-2")]) + state = {"staging": _staging_dict(staging)} + await move_tickets_to_done(state, config) + assert clients.jira.transition_issue.call_count == 2 + + async def test_calls_transition_with_done_name(self) -> None: + clients = build_mock_clients() + clients.jira.transition_issue = AsyncMock(return_value=True) + config = build_config(clients) + staging = _make_staging(tickets=[_make_ticket("BILL-1")]) + state = {"staging": _staging_dict(staging)} + await move_tickets_to_done(state, config) + call_args = clients.jira.transition_issue.call_args_list[0] + ticket_id, transition = call_args[0] + assert ticket_id == "BILL-1" + assert "done" in transition.lower() or "released" in transition.lower() + + async def test_appends_error_on_jira_failure(self) -> None: + from release_agent.exceptions import ServiceError + clients = build_mock_clients() + clients.jira.transition_issue = AsyncMock(side_effect=ServiceError( + service="jira", status_code=500, detail="Error" + )) + config = build_config(clients) + staging = _make_staging(tickets=[_make_ticket()]) + state = {"staging": _staging_dict(staging)} + result = await move_tickets_to_done(state, config) + assert "errors" in result + + async def test_empty_tickets_no_calls(self) -> None: + clients = build_mock_clients() + clients.jira.transition_issue = AsyncMock() + config = build_config(clients) + staging = _make_staging(tickets=[]) + state = {"staging": _staging_dict(staging)} + await move_tickets_to_done(state, config) + clients.jira.transition_issue.assert_not_called() + + +# --------------------------------------------------------------------------- +# send_slack_notification +# --------------------------------------------------------------------------- + +class TestSendSlackNotification: + async def test_calls_slack_send_release_notification(self) -> None: + clients = build_mock_clients() + clients.slack.send_release_notification = AsyncMock(return_value=True) + config = build_config(clients) + staging = _make_staging() + state = { + "repo_name": "my-repo", + "version": "v1.0.0", + "staging": _staging_dict(staging), + } + result = await send_slack_notification(state, config) + clients.slack.send_release_notification.assert_called_once() + + async def test_appends_error_on_slack_failure(self) -> None: + from release_agent.exceptions import ServiceError + clients = build_mock_clients() + clients.slack.send_release_notification = AsyncMock(side_effect=ServiceError( + service="slack", status_code=500, detail="Webhook error" + )) + config = build_config(clients) + staging = _make_staging() + state = { + "repo_name": "my-repo", + "version": "v1.0.0", + "staging": _staging_dict(staging), + } + result = await send_slack_notification(state, config) + assert "errors" in result + + +# --------------------------------------------------------------------------- +# archive_release +# --------------------------------------------------------------------------- + +class TestArchiveRelease: + async def test_archives_staging_to_store(self, tmp_path) -> None: + staging_store = JsonFileStagingStore(directory=tmp_path) + staging = _make_staging() + await staging_store.save(staging) + clients = build_mock_clients() + config = build_config(clients, staging_store=staging_store) + state = { + "repo_name": "my-repo", + "staging": _staging_dict(staging), + } + await archive_release(state, config) + # Staging should be gone now + assert await staging_store.load("my-repo") is None + + async def test_archive_file_created_in_store(self, tmp_path) -> None: + staging_store = JsonFileStagingStore(directory=tmp_path) + staging = _make_staging(version="v3.0.0") + await staging_store.save(staging) + clients = build_mock_clients() + config = build_config(clients, staging_store=staging_store) + state = { + "repo_name": "my-repo", + "staging": _staging_dict(staging), + } + await archive_release(state, config) + versions = await staging_store.list_versions("my-repo") + assert "v3.0.0" in versions + + +# --------------------------------------------------------------------------- +# list_pipelines +# --------------------------------------------------------------------------- + +class TestListPipelines: + async def test_fetches_pipelines_from_azdo(self) -> None: + clients = build_mock_clients() + pipelines = [PipelineInfo(id=1, name="build", repo="my-repo")] + clients.azdo.list_build_pipelines = AsyncMock(return_value=pipelines) + config = build_config(clients) + state = {"repo_name": "my-repo"} + result = await list_pipelines(state, config) + clients.azdo.list_build_pipelines.assert_called_once_with(repo="my-repo") + assert "pipelines" in result + assert len(result["pipelines"]) == 1 + + async def test_stores_pipelines_as_list_of_dicts(self) -> None: + clients = build_mock_clients() + pipelines = [ + PipelineInfo(id=1, name="build", repo="my-repo"), + PipelineInfo(id=2, name="deploy", repo="my-repo"), + ] + clients.azdo.list_build_pipelines = AsyncMock(return_value=pipelines) + config = build_config(clients) + state = {"repo_name": "my-repo"} + result = await list_pipelines(state, config) + assert len(result["pipelines"]) == 2 + assert result["pipelines"][0]["id"] == 1 + + async def test_empty_pipelines_stored_as_empty_list(self) -> None: + clients = build_mock_clients() + clients.azdo.list_build_pipelines = AsyncMock(return_value=[]) + config = build_config(clients) + state = {"repo_name": "my-repo"} + result = await list_pipelines(state, config) + assert result["pipelines"] == [] + + async def test_appends_error_on_service_failure(self) -> None: + from release_agent.exceptions import ServiceError + clients = build_mock_clients() + clients.azdo.list_build_pipelines = AsyncMock(side_effect=ServiceError( + service="azdo", status_code=500, detail="Error" + )) + config = build_config(clients) + state = {"repo_name": "my-repo"} + result = await list_pipelines(state, config) + assert "errors" in result + + +# --------------------------------------------------------------------------- +# interrupt_confirm_trigger +# --------------------------------------------------------------------------- + +class TestInterruptConfirmTrigger: + async def test_calls_interrupt_with_pipelines_summary(self) -> None: + config = build_config() + state = { + "repo_name": "my-repo", + "version": "v1.0.0", + "pipelines": [{"id": 1, "name": "build", "repo": "my-repo"}], + } + with patch("release_agent.graph.release.interrupt") as mock_interrupt: + mock_interrupt.return_value = "confirm" + await interrupt_confirm_trigger(state, config) + mock_interrupt.assert_called_once() + call_arg = mock_interrupt.call_args[0][0] + assert isinstance(call_arg, str) + + +# --------------------------------------------------------------------------- +# trigger_pipelines +# --------------------------------------------------------------------------- + +class TestTriggerPipelines: + async def test_triggers_each_pipeline(self) -> None: + clients = build_mock_clients() + clients.azdo.trigger_pipeline = AsyncMock(return_value={"id": 1001}) + config = build_config(clients) + state = { + "repo_name": "my-repo", + "version": "v1.0.0", + "pipelines": [ + {"id": 1, "name": "build", "repo": "my-repo"}, + {"id": 2, "name": "deploy", "repo": "my-repo"}, + ], + } + result = await trigger_pipelines(state, config) + assert clients.azdo.trigger_pipeline.call_count == 2 + assert "triggered_builds" in result + assert len(result["triggered_builds"]) == 2 + + async def test_no_pipelines_no_calls(self) -> None: + clients = build_mock_clients() + clients.azdo.trigger_pipeline = AsyncMock() + config = build_config(clients) + state = { + "repo_name": "my-repo", + "version": "v1.0.0", + "pipelines": [], + } + result = await trigger_pipelines(state, config) + clients.azdo.trigger_pipeline.assert_not_called() + assert result["triggered_builds"] == [] + + async def test_appends_error_on_trigger_failure(self) -> None: + from release_agent.exceptions import ServiceError + clients = build_mock_clients() + clients.azdo.trigger_pipeline = AsyncMock(side_effect=ServiceError( + service="azdo", status_code=500, detail="Error" + )) + config = build_config(clients) + state = { + "repo_name": "my-repo", + "version": "v1.0.0", + "pipelines": [{"id": 1, "name": "build", "repo": "my-repo"}], + } + result = await trigger_pipelines(state, config) + assert "errors" in result + + +# --------------------------------------------------------------------------- +# check_release_approvals +# --------------------------------------------------------------------------- + +class TestCheckReleaseApprovals: + async def test_fetches_pending_approvals_from_builds(self) -> None: + clients = build_mock_clients() + clients.azdo.get_build_status = AsyncMock(return_value="completed") + config = build_config(clients) + state = { + "triggered_builds": [{"id": 1001}], + } + result = await check_release_approvals(state, config) + assert "pending_approvals" in result + + async def test_empty_builds_means_no_approvals(self) -> None: + clients = build_mock_clients() + config = build_config(clients) + state = {"triggered_builds": []} + result = await check_release_approvals(state, config) + assert result["pending_approvals"] == [] + + async def test_appends_error_on_failure(self) -> None: + from release_agent.exceptions import ServiceError + clients = build_mock_clients() + clients.azdo.get_build_status = AsyncMock(side_effect=ServiceError( + service="azdo", status_code=500, detail="Error" + )) + config = build_config(clients) + state = {"triggered_builds": [{"id": 1001}]} + result = await check_release_approvals(state, config) + assert "errors" in result + + +# --------------------------------------------------------------------------- +# interrupt_confirm_approve +# --------------------------------------------------------------------------- + +class TestInterruptConfirmApprove: + async def test_calls_interrupt_with_approvals_summary(self) -> None: + config = build_config() + state = { + "pending_approvals": [{"approval_id": "aaa", "stage_name": "Production"}], + "version": "v1.0.0", + } + with patch("release_agent.graph.release.interrupt") as mock_interrupt: + mock_interrupt.return_value = "confirm" + await interrupt_confirm_approve(state, config) + mock_interrupt.assert_called_once() + + +# --------------------------------------------------------------------------- +# approve_stage +# --------------------------------------------------------------------------- + +class TestApproveStage: + async def test_approves_each_pending_approval(self) -> None: + clients = build_mock_clients() + clients.azdo.approve_release = AsyncMock(return_value={"status": "approved"}) + config = build_config(clients) + state = { + "pending_approvals": [ + {"approval_id": "aaa"}, + {"approval_id": "bbb"}, + ], + } + result = await approve_stage(state, config) + assert clients.azdo.approve_release.call_count == 2 + + async def test_no_approvals_no_calls(self) -> None: + clients = build_mock_clients() + clients.azdo.approve_release = AsyncMock() + config = build_config(clients) + state = {"pending_approvals": []} + await approve_stage(state, config) + clients.azdo.approve_release.assert_not_called() + + async def test_appends_error_on_failure(self) -> None: + from release_agent.exceptions import ServiceError + clients = build_mock_clients() + clients.azdo.approve_release = AsyncMock(side_effect=ServiceError( + service="azdo", status_code=500, detail="Error" + )) + config = build_config(clients) + state = {"pending_approvals": [{"approval_id": "aaa"}]} + result = await approve_stage(state, config) + assert "errors" in result + + +# --------------------------------------------------------------------------- +# build_release_graph +# --------------------------------------------------------------------------- + +class TestBuildReleaseGraph: + def test_returns_compiled_graph(self) -> None: + graph = build_release_graph() + assert graph is not None + + def test_graph_includes_trigger_ci_build_main_node(self) -> None: + graph = build_release_graph() + graph_nodes = graph.get_graph().nodes + assert "trigger_ci_build_main" in graph_nodes + + def test_graph_includes_poll_ci_build_main_node(self) -> None: + graph = build_release_graph() + graph_nodes = graph.get_graph().nodes + assert "poll_ci_build_main" in graph_nodes + + def test_graph_includes_wait_for_cd_release_node(self) -> None: + graph = build_release_graph() + graph_nodes = graph.get_graph().nodes + assert "wait_for_cd_release" in graph_nodes + + def test_graph_includes_poll_release_approvals_node(self) -> None: + graph = build_release_graph() + graph_nodes = graph.get_graph().nodes + assert "poll_release_approvals" in graph_nodes + + def test_graph_includes_interrupt_sandbox_approval_node(self) -> None: + graph = build_release_graph() + graph_nodes = graph.get_graph().nodes + assert "interrupt_sandbox_approval" in graph_nodes + + def test_graph_includes_interrupt_prod_approval_node(self) -> None: + graph = build_release_graph() + graph_nodes = graph.get_graph().nodes + assert "interrupt_prod_approval" in graph_nodes + + def test_graph_includes_execute_sandbox_approval_node(self) -> None: + graph = build_release_graph() + graph_nodes = graph.get_graph().nodes + assert "execute_sandbox_approval" in graph_nodes + + def test_graph_includes_execute_prod_approval_node(self) -> None: + graph = build_release_graph() + graph_nodes = graph.get_graph().nodes + assert "execute_prod_approval" in graph_nodes + + def test_graph_includes_notify_ci_failure_node(self) -> None: + graph = build_release_graph() + graph_nodes = graph.get_graph().nodes + assert "notify_ci_failure" in graph_nodes + + +# --------------------------------------------------------------------------- +# New release graph node: wait_for_cd_release +# --------------------------------------------------------------------------- + +class TestWaitForCdRelease: + """Tests for wait_for_cd_release node.""" + + async def test_sets_release_id_when_found(self) -> None: + from release_agent.graph.release import wait_for_cd_release + clients = build_mock_clients() + clients.azdo.get_latest_release.return_value = {"id": 100, "name": "Release-100"} + config = build_config(clients) + state = {"release_definition_id": 5, "repo_name": "my-repo"} + + result = await wait_for_cd_release(state, config) + + assert "release_id" in result + assert result["release_id"] == 100 + + async def test_appends_error_when_no_release(self) -> None: + from release_agent.graph.release import wait_for_cd_release + clients = build_mock_clients() + clients.azdo.get_latest_release.return_value = {} + config = build_config(clients) + state = {"release_definition_id": 5, "repo_name": "my-repo"} + + result = await wait_for_cd_release(state, config) + + assert "errors" in result + + async def test_works_without_release_definition_id(self) -> None: + from release_agent.graph.release import wait_for_cd_release + clients = build_mock_clients() + config = build_config(clients) + state = {"repo_name": "my-repo"} + + result = await wait_for_cd_release(state, config) + + assert isinstance(result, dict) + + +# --------------------------------------------------------------------------- +# New release graph node: poll_release_approvals +# --------------------------------------------------------------------------- + +class TestPollReleaseApprovals: + """Tests for poll_release_approvals node.""" + + async def test_sets_pending_approvals_from_azdo(self) -> None: + from release_agent.graph.release import poll_release_approvals + from release_agent.models.build import ApprovalRecord + clients = build_mock_clients() + clients.azdo.get_release_approvals.return_value = [ + ApprovalRecord(approval_id="a1", stage_name="Sandbox", status="pending", release_id=10), + ] + config = build_config(clients) + state = {"release_id": 10} + + result = await poll_release_approvals(state, config) + + assert "pending_approvals" in result + assert len(result["pending_approvals"]) == 1 + + async def test_returns_empty_list_when_no_approvals(self) -> None: + from release_agent.graph.release import poll_release_approvals + clients = build_mock_clients() + clients.azdo.get_release_approvals.return_value = [] + config = build_config(clients) + state = {"release_id": 10} + + result = await poll_release_approvals(state, config) + + assert result.get("pending_approvals") == [] + + async def test_appends_error_on_failure(self) -> None: + from release_agent.exceptions import ServiceError + from release_agent.graph.release import poll_release_approvals + clients = build_mock_clients() + clients.azdo.get_release_approvals.side_effect = ServiceError( + service="azdo", status_code=500, detail="error" + ) + config = build_config(clients) + state = {"release_id": 10} + + result = await poll_release_approvals(state, config) + + assert "errors" in result + + +# --------------------------------------------------------------------------- +# New release graph node: interrupt_sandbox_approval +# --------------------------------------------------------------------------- + +class TestInterruptSandboxApproval: + async def test_calls_interrupt(self) -> None: + from release_agent.graph.release import interrupt_sandbox_approval + config = build_config() + state = { + "pending_approvals": [{"approval_id": "x", "stage_name": "Sandbox"}], + "version": "v1.0.0", + } + with patch("release_agent.graph.release.interrupt") as mock_interrupt: + mock_interrupt.return_value = "confirm" + await interrupt_sandbox_approval(state, config) + mock_interrupt.assert_called_once() + + async def test_sets_current_stage_to_sandbox_pending(self) -> None: + from release_agent.graph.release import interrupt_sandbox_approval + config = build_config() + state = { + "pending_approvals": [{"approval_id": "x", "stage_name": "Sandbox"}], + } + with patch("release_agent.graph.release.interrupt", return_value="yes"): + result = await interrupt_sandbox_approval(state, config) + assert result.get("current_stage") == "sandbox_pending" + + +# --------------------------------------------------------------------------- +# New release graph node: interrupt_prod_approval +# --------------------------------------------------------------------------- + +class TestInterruptProdApproval: + async def test_calls_interrupt(self) -> None: + from release_agent.graph.release import interrupt_prod_approval + config = build_config() + state = { + "pending_approvals": [{"approval_id": "y", "stage_name": "Production"}], + "version": "v1.0.0", + } + with patch("release_agent.graph.release.interrupt") as mock_interrupt: + mock_interrupt.return_value = "confirm" + await interrupt_prod_approval(state, config) + mock_interrupt.assert_called_once() + + async def test_sets_current_stage_to_prod_pending(self) -> None: + from release_agent.graph.release import interrupt_prod_approval + config = build_config() + state = { + "pending_approvals": [{"approval_id": "y", "stage_name": "Production"}], + } + with patch("release_agent.graph.release.interrupt", return_value="yes"): + result = await interrupt_prod_approval(state, config) + assert result.get("current_stage") == "prod_pending" + + +# --------------------------------------------------------------------------- +# New release graph node: execute_sandbox_approval +# --------------------------------------------------------------------------- + +class TestExecuteSandboxApproval: + async def test_approves_sandbox_approvals(self) -> None: + from release_agent.graph.release import execute_sandbox_approval + clients = build_mock_clients() + clients.azdo.approve_release.return_value = {"status": "approved"} + config = build_config(clients) + state = { + "pending_approvals": [{"approval_id": "sb1", "stage_name": "Sandbox"}], + } + + result = await execute_sandbox_approval(state, config) + + clients.azdo.approve_release.assert_called() + + async def test_returns_empty_dict_on_success(self) -> None: + from release_agent.graph.release import execute_sandbox_approval + clients = build_mock_clients() + clients.azdo.approve_release.return_value = {"status": "approved"} + config = build_config(clients) + state = {"pending_approvals": [{"approval_id": "sb1"}]} + + result = await execute_sandbox_approval(state, config) + + assert "errors" not in result or result["errors"] == [] + + +# --------------------------------------------------------------------------- +# New release graph node: execute_prod_approval +# --------------------------------------------------------------------------- + +class TestExecuteProdApproval: + async def test_approves_prod_approvals(self) -> None: + from release_agent.graph.release import execute_prod_approval + clients = build_mock_clients() + clients.azdo.approve_release.return_value = {"status": "approved"} + config = build_config(clients) + state = { + "pending_approvals": [{"approval_id": "pd1", "stage_name": "Production"}], + } + + result = await execute_prod_approval(state, config) + + clients.azdo.approve_release.assert_called() + + +# --------------------------------------------------------------------------- +# New release graph node: notify_ci_failure +# --------------------------------------------------------------------------- + +class TestNotifyCiFailure: + async def test_sends_slack_notification(self) -> None: + from release_agent.graph.release import notify_ci_failure + clients = build_mock_clients() + clients.slack.send_notification.return_value = True + config = build_config(clients) + state = { + "repo_name": "my-repo", + "ci_build_result": "failed", + "ci_build_url": "https://build/1", + } + + result = await notify_ci_failure(state, config) + + clients.slack.send_notification.assert_called_once() + + async def test_appends_message_on_success(self) -> None: + from release_agent.graph.release import notify_ci_failure + clients = build_mock_clients() + clients.slack.send_notification.return_value = True + config = build_config(clients) + state = {"repo_name": "my-repo", "ci_build_result": "failed"} + + result = await notify_ci_failure(state, config) + + assert "messages" in result or isinstance(result, dict) diff --git a/tests/graph/test_routing.py b/tests/graph/test_routing.py new file mode 100644 index 0000000..e64cb57 --- /dev/null +++ b/tests/graph/test_routing.py @@ -0,0 +1,302 @@ +"""Tests for graph/routing.py. Written FIRST (TDD RED phase). + +All routing functions are pure — they take a state dict and return a string. +Every branch is tested, including missing state fields (defaults to falsy). +""" + +import pytest + +from release_agent.graph.routing import ( + has_pending_approvals, + has_pipelines, + has_ticket, + is_pr_already_merged, + is_review_approved, + route_after_fetch, + route_approval_stage, + route_ci_result, + should_continue_to_release, +) + + +# --------------------------------------------------------------------------- +# is_pr_already_merged +# --------------------------------------------------------------------------- + +class TestIsPrAlreadyMerged: + def test_returns_merged_when_true(self) -> None: + state = {"pr_already_merged": True} + assert is_pr_already_merged(state) == "merged" + + def test_returns_active_when_false(self) -> None: + state = {"pr_already_merged": False} + assert is_pr_already_merged(state) == "active" + + def test_returns_active_when_field_missing(self) -> None: + state = {} + assert is_pr_already_merged(state) == "active" + + def test_returns_active_when_none(self) -> None: + state = {"pr_already_merged": None} + assert is_pr_already_merged(state) == "active" + + +# --------------------------------------------------------------------------- +# is_review_approved +# --------------------------------------------------------------------------- + +class TestIsReviewApproved: + def test_returns_approve_when_true(self) -> None: + state = {"review_approved": True} + assert is_review_approved(state) == "approve" + + def test_returns_request_changes_when_false(self) -> None: + state = {"review_approved": False} + assert is_review_approved(state) == "request_changes" + + def test_returns_request_changes_when_field_missing(self) -> None: + state = {} + assert is_review_approved(state) == "request_changes" + + def test_returns_request_changes_when_none(self) -> None: + state = {"review_approved": None} + assert is_review_approved(state) == "request_changes" + + +# --------------------------------------------------------------------------- +# has_ticket +# --------------------------------------------------------------------------- + +class TestHasTicket: + def test_returns_yes_when_true(self) -> None: + state = {"has_ticket": True} + assert has_ticket(state) == "yes" + + def test_returns_no_when_false(self) -> None: + state = {"has_ticket": False} + assert has_ticket(state) == "no" + + def test_returns_no_when_field_missing(self) -> None: + state = {} + assert has_ticket(state) == "no" + + def test_returns_no_when_none(self) -> None: + state = {"has_ticket": None} + assert has_ticket(state) == "no" + + +# --------------------------------------------------------------------------- +# should_continue_to_release +# --------------------------------------------------------------------------- + +class TestShouldContinueToRelease: + def test_returns_yes_when_true(self) -> None: + state = {"continue_to_release": True} + assert should_continue_to_release(state) == "yes" + + def test_returns_no_when_false(self) -> None: + state = {"continue_to_release": False} + assert should_continue_to_release(state) == "no" + + def test_returns_no_when_field_missing(self) -> None: + state = {} + assert should_continue_to_release(state) == "no" + + def test_returns_no_when_none(self) -> None: + state = {"continue_to_release": None} + assert should_continue_to_release(state) == "no" + + +# --------------------------------------------------------------------------- +# has_pipelines +# --------------------------------------------------------------------------- + +class TestHasPipelines: + def test_returns_yes_when_non_empty_list(self) -> None: + state = {"pipelines": [{"id": 1}]} + assert has_pipelines(state) == "yes" + + def test_returns_no_when_empty_list(self) -> None: + state = {"pipelines": []} + assert has_pipelines(state) == "no" + + def test_returns_no_when_field_missing(self) -> None: + state = {} + assert has_pipelines(state) == "no" + + def test_returns_no_when_none(self) -> None: + state = {"pipelines": None} + assert has_pipelines(state) == "no" + + def test_returns_yes_with_multiple_pipelines(self) -> None: + state = {"pipelines": [{"id": 1}, {"id": 2}]} + assert has_pipelines(state) == "yes" + + +# --------------------------------------------------------------------------- +# has_pending_approvals +# --------------------------------------------------------------------------- + +class TestHasPendingApprovals: + def test_returns_yes_when_non_empty_list(self) -> None: + state = {"pending_approvals": [{"approval_id": "abc"}]} + assert has_pending_approvals(state) == "yes" + + def test_returns_no_when_empty_list(self) -> None: + state = {"pending_approvals": []} + assert has_pending_approvals(state) == "no" + + def test_returns_no_when_field_missing(self) -> None: + state = {} + assert has_pending_approvals(state) == "no" + + def test_returns_no_when_none(self) -> None: + state = {"pending_approvals": None} + assert has_pending_approvals(state) == "no" + + def test_returns_yes_with_multiple_approvals(self) -> None: + state = {"pending_approvals": [{"approval_id": "a"}, {"approval_id": "b"}]} + assert has_pending_approvals(state) == "yes" + + +# --------------------------------------------------------------------------- +# route_ci_result +# --------------------------------------------------------------------------- + +class TestRouteCiResult: + """Tests for route_ci_result routing function.""" + + def test_returns_ci_passed_when_succeeded(self) -> None: + state = {"ci_build_result": "succeeded"} + assert route_ci_result(state) == "ci_passed" + + def test_returns_ci_failed_when_failed(self) -> None: + state = {"ci_build_result": "failed"} + assert route_ci_result(state) == "ci_failed" + + def test_returns_ci_failed_when_canceled(self) -> None: + state = {"ci_build_result": "canceled"} + assert route_ci_result(state) == "ci_failed" + + def test_returns_ci_failed_when_partially_succeeded(self) -> None: + state = {"ci_build_result": "partiallySucceeded"} + assert route_ci_result(state) == "ci_failed" + + def test_returns_ci_failed_when_field_missing(self) -> None: + state = {} + assert route_ci_result(state) == "ci_failed" + + def test_returns_ci_failed_when_none(self) -> None: + state = {"ci_build_result": None} + assert route_ci_result(state) == "ci_failed" + + def test_returns_ci_failed_when_empty_string(self) -> None: + state = {"ci_build_result": ""} + assert route_ci_result(state) == "ci_failed" + + def test_case_sensitive_succeeded(self) -> None: + # AzDo returns "succeeded" (lowercase) + state = {"ci_build_result": "succeeded"} + assert route_ci_result(state) == "ci_passed" + + +# --------------------------------------------------------------------------- +# route_approval_stage +# --------------------------------------------------------------------------- + +class TestRouteApprovalStage: + """Tests for route_approval_stage routing function.""" + + def test_returns_all_deployed_when_no_pending_approvals(self) -> None: + state = {"pending_approvals": []} + assert route_approval_stage(state) == "all_deployed" + + def test_returns_all_deployed_when_field_missing(self) -> None: + state = {} + assert route_approval_stage(state) == "all_deployed" + + def test_returns_all_deployed_when_none(self) -> None: + state = {"pending_approvals": None} + assert route_approval_stage(state) == "all_deployed" + + def test_returns_sandbox_pending_when_sandbox_approval_exists(self) -> None: + state = { + "current_stage": "sandbox_pending", + "pending_approvals": [{"approval_id": "x", "stage_name": "Sandbox"}], + } + assert route_approval_stage(state) == "sandbox_pending" + + def test_returns_prod_pending_when_prod_approval_exists(self) -> None: + state = { + "current_stage": "prod_pending", + "pending_approvals": [{"approval_id": "y", "stage_name": "Production"}], + } + assert route_approval_stage(state) == "prod_pending" + + def test_uses_current_stage_field_when_present(self) -> None: + state = { + "current_stage": "sandbox_pending", + "pending_approvals": [{"approval_id": "z"}], + } + assert route_approval_stage(state) == "sandbox_pending" + + def test_returns_all_deployed_when_no_current_stage_and_has_approvals(self) -> None: + # When current_stage is missing but approvals exist, stage is unknown + # so we treat as sandbox by default (first stage) + state = { + "pending_approvals": [{"approval_id": "a"}], + } + # Must return either sandbox_pending or prod_pending (not all_deployed) + result = route_approval_stage(state) + assert result in ("sandbox_pending", "prod_pending") + + def test_sandbox_pending_from_current_stage(self) -> None: + state = {"current_stage": "sandbox_pending", "pending_approvals": [{"approval_id": "x"}]} + assert route_approval_stage(state) == "sandbox_pending" + + def test_prod_pending_from_current_stage(self) -> None: + state = {"current_stage": "prod_pending", "pending_approvals": [{"approval_id": "x"}]} + assert route_approval_stage(state) == "prod_pending" + + +# --------------------------------------------------------------------------- +# route_after_fetch +# --------------------------------------------------------------------------- + +class TestRouteAfterFetch: + """Tests for route_after_fetch — 3-way routing replacing is_pr_already_merged.""" + + def test_returns_merged_when_pr_already_merged(self) -> None: + state = {"pr_already_merged": True} + assert route_after_fetch(state) == "merged" + + def test_returns_active_with_ticket_when_active_and_has_ticket(self) -> None: + state = {"pr_already_merged": False, "has_ticket": True} + assert route_after_fetch(state) == "active_with_ticket" + + def test_returns_active_no_ticket_when_active_and_no_ticket(self) -> None: + state = {"pr_already_merged": False, "has_ticket": False} + assert route_after_fetch(state) == "active_no_ticket" + + def test_returns_active_no_ticket_when_has_ticket_missing(self) -> None: + state = {"pr_already_merged": False} + assert route_after_fetch(state) == "active_no_ticket" + + def test_returns_active_no_ticket_when_has_ticket_none(self) -> None: + state = {"pr_already_merged": False, "has_ticket": None} + assert route_after_fetch(state) == "active_no_ticket" + + def test_returns_active_no_ticket_when_all_fields_missing(self) -> None: + state = {} + assert route_after_fetch(state) == "active_no_ticket" + + def test_merged_takes_precedence_over_has_ticket(self) -> None: + # Even if has_ticket is True, merged PR should route to "merged" + state = {"pr_already_merged": True, "has_ticket": True} + assert route_after_fetch(state) == "merged" + + def test_returns_active_with_ticket_ignores_merged_false(self) -> None: + state = {"pr_already_merged": False, "has_ticket": True} + result = route_after_fetch(state) + assert result != "merged" + assert result == "active_with_ticket" diff --git a/tests/scripts/__init__.py b/tests/scripts/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/scripts/test_migrate.py b/tests/scripts/test_migrate.py new file mode 100644 index 0000000..43f55f8 --- /dev/null +++ b/tests/scripts/test_migrate.py @@ -0,0 +1,332 @@ +"""Tests for scripts/migrate_json_to_db.py. + +Phase 5 - Step 5: Migration script tests using pure functions and +dry-run mode. No real database required. +Written FIRST (TDD RED phase). +""" + +import json +from datetime import date +from pathlib import Path + +import pytest + +from scripts.migrate_json_to_db import ( + collect_json_files, + parse_staging_json, + parse_archived_json, + build_staging_insert_sql, + build_archived_insert_sql, + is_archived_filename, + is_staging_filename, + MigrationRecord, +) + + +# --------------------------------------------------------------------------- +# Fixture data (mirrors real JSON structure from release-workflow/releases/) +# --------------------------------------------------------------------------- + +STAGING_JSON = { + "version": "v1.0.0", + "repo": "Billo.Platform.Document", + "started_at": "2026-03-17", + "tickets": [ + { + "id": "ALLPOST-4219", + "summary": "Test release bot", + "pr_id": "10460", + "pr_url": "https://dev.azure.com/billodev/Billo%20App%20Platform/_git/Billo.Platform.Document/pullrequest/10460", + "pr_title": "chore: trigger release bot test", + "branch": "feature_ALLPOST-4219_test_release_bot", + "merged_at": "2026-03-17", + } + ], +} + +PAYMENT_STAGING_JSON = { + "version": "v1.0.1", + "repo": "Billo.Platform.Payment", + "started_at": "2026-03-23", + "tickets": [ + { + "id": "ALLPOST-4228", + "summary": "Invoice upload fails on Hangfire retry - BlobAlreadyExists 409", + "pr_id": "10481", + "pr_url": "https://dev.azure.com/billodev/Billo%20App%20Platform/_git/Billo.Platform.Payment/pullrequest/10481", + "pr_title": "Invoice upload fails on Hangfire retry - BlobAlreadyExists 409", + "branch": "bug/ALLPOST-4228_fix-invoice-upload-blob-already-exists", + "merged_at": "2026-03-23", + } + ], +} + +# Archived JSON has an additional released_at field +ARCHIVED_JSON = { + "version": "v1.0.0", + "repo": "Billo.Platform.Payment", + "started_at": "2026-01-01", + "tickets": [], + "released_at": "2026-01-15", +} + + +# --------------------------------------------------------------------------- +# is_staging_filename / is_archived_filename +# --------------------------------------------------------------------------- + +class TestFileNameClassification: + def test_staging_filename_identified(self) -> None: + assert is_staging_filename("Billo.Platform.Document.json") is True + + def test_archived_filename_identified(self) -> None: + assert is_archived_filename("Billo.Platform.Payment_v1.0.1_2026-03-23.json") is True + + def test_staging_filename_not_archived(self) -> None: + assert is_archived_filename("Billo.Platform.Document.json") is False + + def test_archived_filename_not_staging(self) -> None: + assert is_staging_filename("Billo.Platform.Payment_v1.0.1_2026-03-23.json") is False + + def test_non_json_file_is_not_staging(self) -> None: + assert is_staging_filename("README.md") is False + + def test_non_json_file_is_not_archived(self) -> None: + assert is_archived_filename("README.md") is False + + +# --------------------------------------------------------------------------- +# parse_staging_json +# --------------------------------------------------------------------------- + +class TestParseStagingJson: + def test_parse_returns_migration_record(self) -> None: + record = parse_staging_json(STAGING_JSON) + assert isinstance(record, MigrationRecord) + + def test_parse_extracts_repo(self) -> None: + record = parse_staging_json(STAGING_JSON) + assert record.repo == "Billo.Platform.Document" + + def test_parse_extracts_version(self) -> None: + record = parse_staging_json(STAGING_JSON) + assert record.version == "v1.0.0" + + def test_parse_extracts_started_at(self) -> None: + record = parse_staging_json(STAGING_JSON) + assert record.started_at == date(2026, 3, 17) + + def test_parse_extracts_tickets(self) -> None: + record = parse_staging_json(STAGING_JSON) + assert len(record.tickets) == 1 + assert record.tickets[0]["id"] == "ALLPOST-4219" + + def test_parse_staging_has_no_released_at(self) -> None: + record = parse_staging_json(STAGING_JSON) + assert record.released_at is None + + def test_parse_staging_with_multiple_tickets(self) -> None: + data = { + **STAGING_JSON, + "tickets": [ + {**STAGING_JSON["tickets"][0], "id": "ALLPOST-1"}, + {**STAGING_JSON["tickets"][0], "id": "ALLPOST-2"}, + ], + } + record = parse_staging_json(data) + assert len(record.tickets) == 2 + + def test_parse_staging_with_empty_tickets(self) -> None: + data = {**STAGING_JSON, "tickets": []} + record = parse_staging_json(data) + assert record.tickets == [] + + +# --------------------------------------------------------------------------- +# parse_archived_json +# --------------------------------------------------------------------------- + +class TestParseArchivedJson: + def test_parse_returns_migration_record(self) -> None: + record = parse_archived_json(ARCHIVED_JSON) + assert isinstance(record, MigrationRecord) + + def test_parse_extracts_released_at(self) -> None: + record = parse_archived_json(ARCHIVED_JSON) + assert record.released_at == date(2026, 1, 15) + + def test_parse_extracts_repo(self) -> None: + record = parse_archived_json(ARCHIVED_JSON) + assert record.repo == "Billo.Platform.Payment" + + def test_parse_extracts_version(self) -> None: + record = parse_archived_json(ARCHIVED_JSON) + assert record.version == "v1.0.0" + + +# --------------------------------------------------------------------------- +# build_staging_insert_sql +# --------------------------------------------------------------------------- + +class TestBuildStagingInsertSql: + def test_returns_tuple_of_sql_and_params(self) -> None: + record = parse_staging_json(STAGING_JSON) + sql, params = build_staging_insert_sql(record) + assert isinstance(sql, str) + assert isinstance(params, tuple) + + def test_sql_inserts_into_staging_releases(self) -> None: + record = parse_staging_json(STAGING_JSON) + sql, _ = build_staging_insert_sql(record) + assert "staging_releases" in sql + + def test_sql_is_insert_statement(self) -> None: + record = parse_staging_json(STAGING_JSON) + sql, _ = build_staging_insert_sql(record) + assert "INSERT" in sql.upper() + + def test_params_include_repo(self) -> None: + record = parse_staging_json(STAGING_JSON) + _, params = build_staging_insert_sql(record) + assert "Billo.Platform.Document" in params + + def test_params_include_version(self) -> None: + record = parse_staging_json(STAGING_JSON) + _, params = build_staging_insert_sql(record) + assert "v1.0.0" in params + + def test_params_include_started_at(self) -> None: + record = parse_staging_json(STAGING_JSON) + _, params = build_staging_insert_sql(record) + assert "2026-03-17" in params + + def test_params_include_tickets_json(self) -> None: + record = parse_staging_json(STAGING_JSON) + _, params = build_staging_insert_sql(record) + # tickets should be serialized as JSON string + tickets_json = next(p for p in params if isinstance(p, str) and "ALLPOST-4219" in p) + parsed = json.loads(tickets_json) + assert parsed[0]["id"] == "ALLPOST-4219" + + def test_sql_uses_on_conflict_do_nothing_or_update(self) -> None: + record = parse_staging_json(STAGING_JSON) + sql, _ = build_staging_insert_sql(record) + assert "ON CONFLICT" in sql.upper() or "INSERT" in sql.upper() + + +# --------------------------------------------------------------------------- +# build_archived_insert_sql +# --------------------------------------------------------------------------- + +class TestBuildArchivedInsertSql: + def test_returns_tuple_of_sql_and_params(self) -> None: + record = parse_archived_json(ARCHIVED_JSON) + sql, params = build_archived_insert_sql(record) + assert isinstance(sql, str) + assert isinstance(params, tuple) + + def test_sql_inserts_into_archived_releases(self) -> None: + record = parse_archived_json(ARCHIVED_JSON) + sql, _ = build_archived_insert_sql(record) + assert "archived_releases" in sql + + def test_sql_is_insert_statement(self) -> None: + record = parse_archived_json(ARCHIVED_JSON) + sql, _ = build_archived_insert_sql(record) + assert "INSERT" in sql.upper() + + def test_params_include_released_at(self) -> None: + record = parse_archived_json(ARCHIVED_JSON) + _, params = build_archived_insert_sql(record) + assert "2026-01-15" in params + + def test_params_include_repo(self) -> None: + record = parse_archived_json(ARCHIVED_JSON) + _, params = build_archived_insert_sql(record) + assert "Billo.Platform.Payment" in params + + +# --------------------------------------------------------------------------- +# collect_json_files +# --------------------------------------------------------------------------- + +class TestCollectJsonFiles: + def test_returns_empty_list_for_empty_directory(self, tmp_path: Path) -> None: + result = collect_json_files(tmp_path) + assert result == [] + + def test_finds_staging_json_files(self, tmp_path: Path) -> None: + (tmp_path / "my-repo.json").write_text(json.dumps(STAGING_JSON)) + result = collect_json_files(tmp_path) + assert len(result) == 1 + + def test_finds_archived_json_files(self, tmp_path: Path) -> None: + (tmp_path / "my-repo_v1.0.0_2025-06-01.json").write_text( + json.dumps(ARCHIVED_JSON) + ) + result = collect_json_files(tmp_path) + assert len(result) == 1 + + def test_ignores_non_json_files(self, tmp_path: Path) -> None: + (tmp_path / "README.md").write_text("readme") + (tmp_path / "my-repo.json").write_text(json.dumps(STAGING_JSON)) + result = collect_json_files(tmp_path) + assert len(result) == 1 + + def test_collects_from_nested_directories(self, tmp_path: Path) -> None: + repo_dir = tmp_path / "Billo.Platform.Document" + repo_dir.mkdir() + (repo_dir / "v1.0.0.json").write_text(json.dumps(STAGING_JSON)) + result = collect_json_files(tmp_path) + assert len(result) == 1 + + def test_returns_path_objects(self, tmp_path: Path) -> None: + (tmp_path / "my-repo.json").write_text(json.dumps(STAGING_JSON)) + result = collect_json_files(tmp_path) + assert all(isinstance(p, Path) for p in result) + + def test_collects_multiple_files(self, tmp_path: Path) -> None: + for i in range(3): + (tmp_path / f"repo-{i}.json").write_text(json.dumps(STAGING_JSON)) + result = collect_json_files(tmp_path) + assert len(result) == 3 + + +# --------------------------------------------------------------------------- +# Dry-run mode (integration of pure functions) +# --------------------------------------------------------------------------- + +class TestDryRunMode: + def test_dry_run_collects_records_without_db_access(self, tmp_path: Path) -> None: + """Dry run processes files and returns SQL/params without executing.""" + repo_dir = tmp_path / "Billo.Platform.Document" + repo_dir.mkdir() + (repo_dir / "v1.0.0.json").write_text(json.dumps(STAGING_JSON)) + + files = collect_json_files(tmp_path) + assert len(files) == 1 + + # Parse and build SQL — no DB connection needed + record = parse_staging_json(json.loads(files[0].read_text())) + sql, params = build_staging_insert_sql(record) + assert "INSERT" in sql.upper() + assert "Billo.Platform.Document" in params + + def test_payment_staging_file_parses_correctly(self, tmp_path: Path) -> None: + (tmp_path / "Billo.Platform.Payment.json").write_text( + json.dumps(PAYMENT_STAGING_JSON) + ) + files = collect_json_files(tmp_path) + record = parse_staging_json(json.loads(files[0].read_text())) + assert record.repo == "Billo.Platform.Payment" + assert record.version == "v1.0.1" + assert len(record.tickets) == 1 + assert record.tickets[0]["id"] == "ALLPOST-4228" + + def test_archived_file_parses_correctly(self, tmp_path: Path) -> None: + (tmp_path / "my-repo_v1.0.0_2026-01-15.json").write_text( + json.dumps(ARCHIVED_JSON) + ) + files = collect_json_files(tmp_path) + record = parse_archived_json(json.loads(files[0].read_text())) + assert record.released_at == date(2026, 1, 15) diff --git a/tests/services/__init__.py b/tests/services/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/services/test_pr_dedup.py b/tests/services/test_pr_dedup.py new file mode 100644 index 0000000..8f0aa52 --- /dev/null +++ b/tests/services/test_pr_dedup.py @@ -0,0 +1,141 @@ +"""Tests for services/pr_dedup.py. Written FIRST (TDD RED phase). + +find_unprocessed_prs queries agent_threads to find which PRs have not yet +been processed (no existing thread for that repo+pr_id combination). +""" + +import pytest + +from release_agent.models.pr import PRInfo +from release_agent.services.pr_dedup import find_unprocessed_prs + + +# --------------------------------------------------------------------------- +# Helpers — fake async pool +# --------------------------------------------------------------------------- + +def _make_pr(pr_id: str, repo_name: str = "my-repo") -> PRInfo: + return PRInfo( + pr_id=pr_id, + pr_url=f"https://dev.azure.com/org/proj/_git/{repo_name}/pullrequest/{pr_id}", + repo_name=repo_name, + branch="refs/heads/feature/ALLPOST-100-fix", + pr_title=f"PR {pr_id}", + pr_status="active", + ) + + +def _make_pool(existing_rows: list[tuple[str, str]]): + """Return a fake async connection pool. + + existing_rows: list of (pr_id, repo_name) tuples representing already-processed PRs. + """ + + class FakeCursor: + def __init__(self, rows): + self._rows = rows + + async def execute(self, sql, params=None): + pass + + async def fetchall(self): + return self._rows + + async def __aenter__(self): + return self + + async def __aexit__(self, *args): + pass + + class FakeConn: + def __init__(self, rows): + self._rows = rows + + def cursor(self): + return FakeCursor(self._rows) + + async def __aenter__(self): + return self + + async def __aexit__(self, *args): + pass + + class FakePool: + def __init__(self, rows): + self._rows = rows + + def connection(self): + return FakeConn(self._rows) + + return FakePool(existing_rows) + + +# --------------------------------------------------------------------------- +# find_unprocessed_prs tests +# --------------------------------------------------------------------------- + +class TestFindUnprocessedPrs: + + async def test_returns_all_when_none_processed(self) -> None: + prs = [_make_pr("10"), _make_pr("20")] + pool = _make_pool([]) + + result = await find_unprocessed_prs(pool, prs) + + assert len(result) == 2 + + async def test_returns_empty_when_all_processed(self) -> None: + prs = [_make_pr("10"), _make_pr("20")] + # existing rows: (pr_id, repo_name) + pool = _make_pool([("10", "my-repo"), ("20", "my-repo")]) + + result = await find_unprocessed_prs(pool, prs) + + assert result == [] + + async def test_returns_only_unprocessed(self) -> None: + prs = [_make_pr("10"), _make_pr("20"), _make_pr("30")] + pool = _make_pool([("10", "my-repo")]) + + result = await find_unprocessed_prs(pool, prs) + + pr_ids = [p.pr_id for p in result] + assert "10" not in pr_ids + assert "20" in pr_ids + assert "30" in pr_ids + + async def test_empty_input_returns_empty(self) -> None: + pool = _make_pool([]) + + result = await find_unprocessed_prs(pool, []) + + assert result == [] + + async def test_different_repos_not_confused(self) -> None: + pr_repo_a = _make_pr("10", repo_name="repo-a") + pr_repo_b = _make_pr("10", repo_name="repo-b") + # Only repo-a/10 is processed + pool = _make_pool([("10", "repo-a")]) + + result = await find_unprocessed_prs(pool, [pr_repo_a, pr_repo_b]) + + # repo-b/10 should still be returned (different repo) + assert len(result) == 1 + assert result[0].repo_name == "repo-b" + + async def test_returns_list_of_pr_info(self) -> None: + prs = [_make_pr("42")] + pool = _make_pool([]) + + result = await find_unprocessed_prs(pool, prs) + + assert all(isinstance(p, PRInfo) for p in result) + + async def test_preserves_pr_info_objects(self) -> None: + pr = _make_pr("77") + pool = _make_pool([]) + + result = await find_unprocessed_prs(pool, [pr]) + + assert result[0].pr_id == "77" + assert result[0].repo_name == "my-repo" diff --git a/tests/services/test_pr_poller.py b/tests/services/test_pr_poller.py new file mode 100644 index 0000000..3b11510 --- /dev/null +++ b/tests/services/test_pr_poller.py @@ -0,0 +1,309 @@ +"""Tests for services/pr_poller.py. Written FIRST (TDD RED phase). + +Tests verify: +- _synthesize_webhook_payload produces a valid payload dict +- run_pr_poll_loop calls list_active_prs, dedup, then schedules graph for each unprocessed PR +- Fake sleep is injected to avoid real waits +""" + +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from release_agent.models.pr import PRInfo +from release_agent.services.pr_poller import _synthesize_webhook_payload, run_pr_poll_loop + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_pr( + pr_id: str = "10", + repo_name: str = "my-repo", + branch: str = "refs/heads/feature/ALLPOST-100-fix", + title: str = "Test PR", + status: str = "active", +) -> PRInfo: + return PRInfo( + pr_id=pr_id, + pr_url=f"https://dev.azure.com/org/proj/_git/{repo_name}/pullrequest/{pr_id}", + repo_name=repo_name, + branch=branch, + pr_title=title, + pr_status=status, + ) + + +# --------------------------------------------------------------------------- +# _synthesize_webhook_payload tests +# --------------------------------------------------------------------------- + +class TestSynthesizeWebhookPayload: + def test_returns_dict(self) -> None: + pr = _make_pr() + result = _synthesize_webhook_payload(pr) + assert isinstance(result, dict) + + def test_has_resource_key(self) -> None: + pr = _make_pr() + result = _synthesize_webhook_payload(pr) + assert "resource" in result + + def test_resource_contains_pull_request_id(self) -> None: + pr = _make_pr(pr_id="42") + result = _synthesize_webhook_payload(pr) + assert result["resource"]["pull_request_id"] == 42 + + def test_resource_contains_repository_name(self) -> None: + pr = _make_pr(repo_name="backend-api") + result = _synthesize_webhook_payload(pr) + assert result["resource"]["repository"]["name"] == "backend-api" + + def test_resource_contains_title(self) -> None: + pr = _make_pr(title="My PR Title") + result = _synthesize_webhook_payload(pr) + assert result["resource"]["title"] == "My PR Title" + + def test_resource_contains_source_ref_name(self) -> None: + pr = _make_pr(branch="refs/heads/feature/ALLPOST-200-test") + result = _synthesize_webhook_payload(pr) + assert result["resource"]["source_ref_name"] == "refs/heads/feature/ALLPOST-200-test" + + def test_resource_status_is_active(self) -> None: + pr = _make_pr(status="active") + result = _synthesize_webhook_payload(pr) + assert result["resource"]["status"] == "active" + + def test_event_type_is_pr_updated(self) -> None: + pr = _make_pr() + result = _synthesize_webhook_payload(pr) + assert "event_type" in result + + def test_subscription_id_present(self) -> None: + pr = _make_pr() + result = _synthesize_webhook_payload(pr) + assert "subscription_id" in result + + def test_different_prs_produce_different_payloads(self) -> None: + pr1 = _make_pr(pr_id="1") + pr2 = _make_pr(pr_id="2") + r1 = _synthesize_webhook_payload(pr1) + r2 = _synthesize_webhook_payload(pr2) + assert r1["resource"]["pull_request_id"] != r2["resource"]["pull_request_id"] + + +# --------------------------------------------------------------------------- +# run_pr_poll_loop tests +# --------------------------------------------------------------------------- + +class TestRunPrPollLoop: + + async def test_calls_list_active_prs_for_each_repo(self) -> None: + azdo = AsyncMock() + azdo.list_active_prs = AsyncMock(return_value=[]) + + sleep_calls: list[float] = [] + + async def fake_sleep(seconds: float) -> None: + sleep_calls.append(seconds) + raise asyncio.CancelledError + + with patch("release_agent.services.pr_poller.find_unprocessed_prs", new=AsyncMock(return_value=[])): + with pytest.raises(asyncio.CancelledError): + await run_pr_poll_loop( + azdo_client=azdo, + db_pool=MagicMock(), + watched_repos=["repo-a", "repo-b"], + target_branch="refs/heads/develop", + interval_seconds=30, + schedule_fn=MagicMock(), + sleep_fn=fake_sleep, + ) + + assert azdo.list_active_prs.call_count == 2 + + async def test_calls_find_unprocessed_prs(self) -> None: + pr = _make_pr(pr_id="10") + azdo = AsyncMock() + azdo.list_active_prs = AsyncMock(return_value=[pr]) + + find_mock = AsyncMock(return_value=[]) + + async def fake_sleep(seconds: float) -> None: + raise asyncio.CancelledError + + with patch("release_agent.services.pr_poller.find_unprocessed_prs", new=find_mock): + with pytest.raises(asyncio.CancelledError): + await run_pr_poll_loop( + azdo_client=azdo, + db_pool=MagicMock(), + watched_repos=["my-repo"], + target_branch="refs/heads/develop", + interval_seconds=30, + schedule_fn=MagicMock(), + sleep_fn=fake_sleep, + ) + + find_mock.assert_called_once() + + async def test_schedules_graph_for_each_unprocessed_pr(self) -> None: + pr1 = _make_pr(pr_id="10") + pr2 = _make_pr(pr_id="20") + azdo = AsyncMock() + azdo.list_active_prs = AsyncMock(return_value=[pr1, pr2]) + + schedule_mock = MagicMock() + + async def fake_sleep(seconds: float) -> None: + raise asyncio.CancelledError + + with patch( + "release_agent.services.pr_poller.find_unprocessed_prs", + new=AsyncMock(return_value=[pr1, pr2]), + ): + with pytest.raises(asyncio.CancelledError): + await run_pr_poll_loop( + azdo_client=azdo, + db_pool=MagicMock(), + watched_repos=["my-repo"], + target_branch="refs/heads/develop", + interval_seconds=30, + schedule_fn=schedule_mock, + sleep_fn=fake_sleep, + ) + + assert schedule_mock.call_count == 2 + + async def test_does_not_schedule_already_processed_prs(self) -> None: + pr = _make_pr(pr_id="10") + azdo = AsyncMock() + azdo.list_active_prs = AsyncMock(return_value=[pr]) + + schedule_mock = MagicMock() + + async def fake_sleep(seconds: float) -> None: + raise asyncio.CancelledError + + # All PRs already processed + with patch( + "release_agent.services.pr_poller.find_unprocessed_prs", + new=AsyncMock(return_value=[]), + ): + with pytest.raises(asyncio.CancelledError): + await run_pr_poll_loop( + azdo_client=azdo, + db_pool=MagicMock(), + watched_repos=["my-repo"], + target_branch="refs/heads/develop", + interval_seconds=30, + schedule_fn=schedule_mock, + sleep_fn=fake_sleep, + ) + + schedule_mock.assert_not_called() + + async def test_sleeps_for_configured_interval(self) -> None: + azdo = AsyncMock() + azdo.list_active_prs = AsyncMock(return_value=[]) + + sleep_calls: list[float] = [] + + async def fake_sleep(seconds: float) -> None: + sleep_calls.append(seconds) + raise asyncio.CancelledError + + with patch("release_agent.services.pr_poller.find_unprocessed_prs", new=AsyncMock(return_value=[])): + with pytest.raises(asyncio.CancelledError): + await run_pr_poll_loop( + azdo_client=azdo, + db_pool=MagicMock(), + watched_repos=["my-repo"], + target_branch="refs/heads/develop", + interval_seconds=123, + schedule_fn=MagicMock(), + sleep_fn=fake_sleep, + ) + + assert sleep_calls[0] == 123 + + async def test_handles_empty_watched_repos(self) -> None: + azdo = AsyncMock() + + async def fake_sleep(seconds: float) -> None: + raise asyncio.CancelledError + + with patch("release_agent.services.pr_poller.find_unprocessed_prs", new=AsyncMock(return_value=[])): + with pytest.raises(asyncio.CancelledError): + await run_pr_poll_loop( + azdo_client=azdo, + db_pool=MagicMock(), + watched_repos=[], + target_branch="refs/heads/develop", + interval_seconds=30, + schedule_fn=MagicMock(), + sleep_fn=fake_sleep, + ) + + azdo.list_active_prs.assert_not_called() + + async def test_schedule_fn_receives_synthesized_payload(self) -> None: + pr = _make_pr(pr_id="55", repo_name="test-repo") + azdo = AsyncMock() + azdo.list_active_prs = AsyncMock(return_value=[pr]) + + schedule_calls: list[dict] = [] + + def schedule_mock(**kwargs) -> None: + schedule_calls.append(kwargs) + + async def fake_sleep(seconds: float) -> None: + raise asyncio.CancelledError + + with patch( + "release_agent.services.pr_poller.find_unprocessed_prs", + new=AsyncMock(return_value=[pr]), + ): + with pytest.raises(asyncio.CancelledError): + await run_pr_poll_loop( + azdo_client=azdo, + db_pool=MagicMock(), + watched_repos=["test-repo"], + target_branch="refs/heads/develop", + interval_seconds=30, + schedule_fn=schedule_mock, + sleep_fn=fake_sleep, + ) + + assert len(schedule_calls) == 1 + initial_state = schedule_calls[0]["initial_state"] + assert initial_state["webhook_payload"]["resource"]["pull_request_id"] == 55 + assert initial_state["pr_id"] == "55" + assert initial_state["repo_name"] == "test-repo" + + async def test_continues_after_list_active_prs_error(self) -> None: + azdo = AsyncMock() + # First repo raises, second succeeds + azdo.list_active_prs = AsyncMock(side_effect=[Exception("API error"), []]) + + sleep_calls: list[float] = [] + + async def fake_sleep(seconds: float) -> None: + sleep_calls.append(seconds) + raise asyncio.CancelledError + + with patch("release_agent.services.pr_poller.find_unprocessed_prs", new=AsyncMock(return_value=[])): + with pytest.raises(asyncio.CancelledError): + await run_pr_poll_loop( + azdo_client=azdo, + db_pool=MagicMock(), + watched_repos=["repo-a", "repo-b"], + target_branch="refs/heads/develop", + interval_seconds=30, + schedule_fn=MagicMock(), + sleep_fn=fake_sleep, + ) + + # Should still sleep (loop iteration completed despite error) + assert len(sleep_calls) == 1 diff --git a/tests/test_branch_parser.py b/tests/test_branch_parser.py new file mode 100644 index 0000000..38dd7ba --- /dev/null +++ b/tests/test_branch_parser.py @@ -0,0 +1,122 @@ +"""Tests for branch_parser module. Written FIRST (TDD RED phase).""" + + +from release_agent.branch_parser import parse_branch, strip_refs_prefix + + +class TestStripRefsPrefix: + """Tests for strip_refs_prefix function.""" + + def test_strips_refs_heads_prefix(self) -> None: + assert strip_refs_prefix("refs/heads/fix/BILL-42_something") == "fix/BILL-42_something" + + def test_strips_refs_heads_prefix_feature(self) -> None: + assert strip_refs_prefix("refs/heads/feature/ALLPOST-100_add-feature") == "feature/ALLPOST-100_add-feature" + + def test_no_refs_prefix_unchanged(self) -> None: + assert strip_refs_prefix("bug/ALLPOST-4229_fix-review") == "bug/ALLPOST-4229_fix-review" + + def test_main_unchanged(self) -> None: + assert strip_refs_prefix("main") == "main" + + def test_develop_unchanged(self) -> None: + assert strip_refs_prefix("develop") == "develop" + + def test_empty_string(self) -> None: + assert strip_refs_prefix("") == "" + + def test_only_refs_heads(self) -> None: + assert strip_refs_prefix("refs/heads/") == "" + + def test_refs_tags_not_stripped(self) -> None: + assert strip_refs_prefix("refs/tags/v1.0.0") == "refs/tags/v1.0.0" + + +class TestParseBranch: + """Tests for parse_branch function.""" + + def test_bug_branch_with_ticket(self) -> None: + ticket_id, has_ticket = parse_branch("bug/ALLPOST-4229_fix-review") + assert ticket_id == "ALLPOST-4229" + assert has_ticket is True + + def test_feature_branch_with_ticket(self) -> None: + ticket_id, has_ticket = parse_branch("feature/ALLPOST-100_add-feature") + assert ticket_id == "ALLPOST-100" + assert has_ticket is True + + def test_refs_heads_fix_branch(self) -> None: + ticket_id, has_ticket = parse_branch("refs/heads/fix/BILL-42_something") + assert ticket_id == "BILL-42" + assert has_ticket is True + + def test_feat_branch_short(self) -> None: + ticket_id, has_ticket = parse_branch("feat/MY-1_x") + assert ticket_id == "MY-1" + assert has_ticket is True + + def test_chore_without_ticket(self) -> None: + ticket_id, has_ticket = parse_branch("chore/update-dependencies") + assert ticket_id is None + assert has_ticket is False + + def test_main_branch(self) -> None: + ticket_id, has_ticket = parse_branch("main") + assert ticket_id is None + assert has_ticket is False + + def test_develop_branch(self) -> None: + ticket_id, has_ticket = parse_branch("develop") + assert ticket_id is None + assert has_ticket is False + + def test_release_branch(self) -> None: + ticket_id, has_ticket = parse_branch("release/v1.0.3") + assert ticket_id is None + assert has_ticket is False + + def test_returns_tuple(self) -> None: + result = parse_branch("main") + assert isinstance(result, tuple) + assert len(result) == 2 + + def test_ticket_id_type_when_present(self) -> None: + ticket_id, has_ticket = parse_branch("bug/ALLPOST-4229_fix-review") + assert isinstance(ticket_id, str) + assert isinstance(has_ticket, bool) + + def test_ticket_id_type_when_absent(self) -> None: + ticket_id, has_ticket = parse_branch("main") + assert ticket_id is None + assert isinstance(has_ticket, bool) + + def test_fix_prefix(self) -> None: + ticket_id, has_ticket = parse_branch("fix/PROJ-999_some-fix") + assert ticket_id == "PROJ-999" + assert has_ticket is True + + def test_refs_heads_feature_branch(self) -> None: + ticket_id, has_ticket = parse_branch("refs/heads/feature/ALLPOST-100_add-feature") + assert ticket_id == "ALLPOST-100" + assert has_ticket is True + + def test_ticket_with_multiple_digits(self) -> None: + ticket_id, has_ticket = parse_branch("feature/ABC-12345_some-long-feature") + assert ticket_id == "ABC-12345" + assert has_ticket is True + + def test_branch_without_underscore_separator(self) -> None: + # Branch has ticket pattern but no underscore - still detects ticket + ticket_id, has_ticket = parse_branch("feature/PROJ-100") + assert ticket_id == "PROJ-100" + assert has_ticket is True + + def test_empty_string(self) -> None: + ticket_id, has_ticket = parse_branch("") + assert ticket_id is None + assert has_ticket is False + + def test_ticket_with_numeric_project_prefix(self) -> None: + ticket_id, has_ticket = parse_branch("feature/AB2-100_feature") + assert ticket_id == "AB2-100" + assert has_ticket is True diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 0000000..f61c329 --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,450 @@ +"""Tests for config module. Written FIRST (TDD RED phase).""" + +import os +from unittest.mock import patch + +import pytest +from pydantic import SecretStr, ValidationError + +from release_agent.config import Settings + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _base_env() -> dict[str, str]: + """Return minimal valid environment variables.""" + return { + "AZDO_ORGANIZATION": "my-org", + "AZDO_PROJECT": "my-project", + "AZDO_PAT": "super-secret-pat", + "ANTHROPIC_API_KEY": "sk-ant-key", + "POSTGRES_DSN": "postgresql://user:pass@localhost:5432/db", + "JIRA_EMAIL": "user@example.com", + "JIRA_API_TOKEN": "jira-token-abc", + "SLACK_WEBHOOK_URL": "https://hooks.slack.com/services/T000/B000/xxxx", + "WEBHOOK_SECRET": "test-webhook-secret", + } + + +# --------------------------------------------------------------------------- +# Settings tests +# --------------------------------------------------------------------------- + +class TestSettings: + """Tests for Settings config class.""" + + def test_loads_from_env_vars(self) -> None: + with patch.dict(os.environ, _base_env(), clear=True): + settings = Settings() + assert settings.azdo_organization == "my-org" + assert settings.azdo_project == "my-project" + + def test_pat_is_secret_str(self) -> None: + with patch.dict(os.environ, _base_env(), clear=True): + settings = Settings() + assert isinstance(settings.azdo_pat, SecretStr) + + def test_anthropic_key_is_secret_str(self) -> None: + with patch.dict(os.environ, _base_env(), clear=True): + settings = Settings() + assert isinstance(settings.anthropic_api_key, SecretStr) + + def test_secret_str_not_leaked_in_repr(self) -> None: + with patch.dict(os.environ, _base_env(), clear=True): + settings = Settings() + repr_str = repr(settings) + assert "super-secret-pat" not in repr_str + assert "sk-ant-key" not in repr_str + + def test_secret_str_not_leaked_in_str(self) -> None: + with patch.dict(os.environ, _base_env(), clear=True): + settings = Settings() + str_repr = str(settings) + assert "super-secret-pat" not in str_repr + + def test_missing_required_azdo_org_raises(self) -> None: + env = _base_env() + del env["AZDO_ORGANIZATION"] + with patch.dict(os.environ, env, clear=True), pytest.raises(ValidationError): + Settings() + + def test_missing_required_azdo_project_raises(self) -> None: + env = _base_env() + del env["AZDO_PROJECT"] + with patch.dict(os.environ, env, clear=True), pytest.raises(ValidationError): + Settings() + + def test_missing_required_pat_raises(self) -> None: + env = _base_env() + del env["AZDO_PAT"] + with patch.dict(os.environ, env, clear=True), pytest.raises(ValidationError): + Settings() + + def test_missing_anthropic_key_is_optional(self) -> None: + env = _base_env() + del env["ANTHROPIC_API_KEY"] + with patch.dict(os.environ, env, clear=True): + settings = Settings() + assert settings.anthropic_api_key.get_secret_value() == "" + + def test_missing_postgres_dsn_raises(self) -> None: + env = _base_env() + del env["POSTGRES_DSN"] + with patch.dict(os.environ, env, clear=True), pytest.raises(ValidationError): + Settings() + + def test_azdo_base_url_computed(self) -> None: + with patch.dict(os.environ, _base_env(), clear=True): + settings = Settings() + expected = "https://dev.azure.com/my-org" + assert settings.azdo_base_url == expected + + def test_azdo_api_url_computed(self) -> None: + with patch.dict(os.environ, _base_env(), clear=True): + settings = Settings() + expected = "https://dev.azure.com/my-org/my-project/_apis" + assert settings.azdo_api_url == expected + + def test_default_port(self) -> None: + with patch.dict(os.environ, _base_env(), clear=True): + settings = Settings() + assert settings.port == 8000 + + def test_custom_port_from_env(self) -> None: + env = {**_base_env(), "PORT": "9000"} + with patch.dict(os.environ, env, clear=True): + settings = Settings() + assert settings.port == 9000 + + def test_port_below_minimum_raises(self) -> None: + env = {**_base_env(), "PORT": "0"} + with patch.dict(os.environ, env, clear=True), pytest.raises(ValidationError): + Settings() + + def test_port_above_maximum_raises(self) -> None: + env = {**_base_env(), "PORT": "65536"} + with patch.dict(os.environ, env, clear=True), pytest.raises(ValidationError): + Settings() + + def test_port_minimum_valid(self) -> None: + env = {**_base_env(), "PORT": "1"} + with patch.dict(os.environ, env, clear=True): + settings = Settings() + assert settings.port == 1 + + def test_port_maximum_valid(self) -> None: + env = {**_base_env(), "PORT": "65535"} + with patch.dict(os.environ, env, clear=True): + settings = Settings() + assert settings.port == 65535 + + def test_get_pat_value(self) -> None: + with patch.dict(os.environ, _base_env(), clear=True): + settings = Settings() + assert settings.azdo_pat.get_secret_value() == "super-secret-pat" + + def test_get_anthropic_key_value(self) -> None: + with patch.dict(os.environ, _base_env(), clear=True): + settings = Settings() + assert settings.anthropic_api_key.get_secret_value() == "sk-ant-key" + + def test_postgres_dsn_is_secret_str(self) -> None: + with patch.dict(os.environ, _base_env(), clear=True): + settings = Settings() + assert isinstance(settings.postgres_dsn, SecretStr) + assert "localhost" in settings.postgres_dsn.get_secret_value() + + def test_postgres_dsn_not_leaked_in_repr(self) -> None: + with patch.dict(os.environ, _base_env(), clear=True): + settings = Settings() + assert "user:pass" not in repr(settings) + + +class TestSettingsPhase2: + """Tests for Phase 2 settings fields.""" + + def test_jira_base_url_default(self) -> None: + with patch.dict(os.environ, _base_env(), clear=True): + settings = Settings() + assert settings.jira_base_url == "https://billolife.atlassian.net" + + def test_jira_base_url_custom(self) -> None: + env = {**_base_env(), "JIRA_BASE_URL": "https://custom.atlassian.net"} + with patch.dict(os.environ, env, clear=True): + settings = Settings() + assert settings.jira_base_url == "https://custom.atlassian.net" + + def test_jira_email_required(self) -> None: + env = _base_env() + del env["JIRA_EMAIL"] + with patch.dict(os.environ, env, clear=True), pytest.raises(ValidationError): + Settings() + + def test_jira_email_stored(self) -> None: + with patch.dict(os.environ, _base_env(), clear=True): + settings = Settings() + assert settings.jira_email == "user@example.com" + + def test_jira_api_token_required(self) -> None: + env = _base_env() + del env["JIRA_API_TOKEN"] + with patch.dict(os.environ, env, clear=True), pytest.raises(ValidationError): + Settings() + + def test_jira_api_token_is_secret_str(self) -> None: + with patch.dict(os.environ, _base_env(), clear=True): + settings = Settings() + assert isinstance(settings.jira_api_token, SecretStr) + + def test_jira_api_token_not_leaked_in_repr(self) -> None: + with patch.dict(os.environ, _base_env(), clear=True): + settings = Settings() + assert "jira-token-abc" not in repr(settings) + + def test_slack_webhook_url_optional_defaults_empty(self) -> None: + env = _base_env() + del env["SLACK_WEBHOOK_URL"] + with patch.dict(os.environ, env, clear=True): + settings = Settings() + assert settings.slack_webhook_url.get_secret_value() == "" + + def test_slack_webhook_url_is_secret_str(self) -> None: + with patch.dict(os.environ, _base_env(), clear=True): + settings = Settings() + assert isinstance(settings.slack_webhook_url, SecretStr) + + def test_slack_webhook_url_not_leaked_in_repr(self) -> None: + with patch.dict(os.environ, _base_env(), clear=True): + settings = Settings() + assert "xxxx" not in repr(settings) + + def test_claude_review_model_default(self) -> None: + with patch.dict(os.environ, _base_env(), clear=True): + settings = Settings() + assert settings.claude_review_model == "claude-sonnet-4-20250514" + + def test_claude_review_model_custom(self) -> None: + env = {**_base_env(), "CLAUDE_REVIEW_MODEL": "claude-opus-4-20250514"} + with patch.dict(os.environ, env, clear=True): + settings = Settings() + assert settings.claude_review_model == "claude-opus-4-20250514" + + def test_azdo_vsrm_api_url_computed(self) -> None: + with patch.dict(os.environ, _base_env(), clear=True): + settings = Settings() + expected = "https://vsrm.dev.azure.com/my-org/my-project/_apis" + assert settings.azdo_vsrm_api_url == expected + + +class TestSettingsPhase4: + """Tests for Phase 4 settings fields (webhook secret).""" + + def test_webhook_secret_optional_defaults_empty(self) -> None: + with patch.dict(os.environ, _base_env(), clear=True): + settings = Settings() + # When not provided, defaults to empty string or None + assert settings.webhook_secret is not None or settings.webhook_secret == "" + + def test_webhook_secret_custom_value(self) -> None: + env = {**_base_env(), "WEBHOOK_SECRET": "my-super-secret"} + with patch.dict(os.environ, env, clear=True): + settings = Settings() + assert settings.webhook_secret.get_secret_value() == "my-super-secret" + + def test_webhook_secret_is_secret_str(self) -> None: + env = {**_base_env(), "WEBHOOK_SECRET": "secret-value"} + with patch.dict(os.environ, env, clear=True): + settings = Settings() + assert isinstance(settings.webhook_secret, SecretStr) + + def test_webhook_secret_not_leaked_in_repr(self) -> None: + env = {**_base_env(), "WEBHOOK_SECRET": "super-private-secret"} + with patch.dict(os.environ, env, clear=True): + settings = Settings() + assert "super-private-secret" not in repr(settings) + + +class TestSettingsPhase5: + """Tests for Phase 5 settings fields (Slack Web API + CI polling).""" + + def test_slack_webhook_url_optional_when_bot_token_provided(self) -> None: + env = {k: v for k, v in _base_env().items() if k != "SLACK_WEBHOOK_URL"} + env["SLACK_BOT_TOKEN"] = "xoxb-test-token" + env["SLACK_CHANNEL_ID"] = "C12345678" + with patch.dict(os.environ, env, clear=True): + settings = Settings() + assert settings.slack_bot_token is not None + assert settings.slack_bot_token.get_secret_value() == "xoxb-test-token" + + def test_slack_bot_token_optional_defaults_empty(self) -> None: + with patch.dict(os.environ, _base_env(), clear=True): + settings = Settings() + assert settings.slack_bot_token.get_secret_value() == "" + + def test_slack_bot_token_is_secret_str(self) -> None: + env = {**_base_env(), "SLACK_BOT_TOKEN": "xoxb-abc-123"} + with patch.dict(os.environ, env, clear=True): + settings = Settings() + assert isinstance(settings.slack_bot_token, SecretStr) + + def test_slack_bot_token_not_leaked_in_repr(self) -> None: + env = {**_base_env(), "SLACK_BOT_TOKEN": "xoxb-super-secret"} + with patch.dict(os.environ, env, clear=True): + settings = Settings() + assert "xoxb-super-secret" not in repr(settings) + + def test_slack_signing_secret_optional_defaults_empty(self) -> None: + with patch.dict(os.environ, _base_env(), clear=True): + settings = Settings() + assert settings.slack_signing_secret.get_secret_value() == "" + + def test_slack_signing_secret_custom_value(self) -> None: + env = {**_base_env(), "SLACK_SIGNING_SECRET": "signing-secret-xyz"} + with patch.dict(os.environ, env, clear=True): + settings = Settings() + assert settings.slack_signing_secret.get_secret_value() == "signing-secret-xyz" + + def test_slack_signing_secret_is_secret_str(self) -> None: + env = {**_base_env(), "SLACK_SIGNING_SECRET": "some-secret"} + with patch.dict(os.environ, env, clear=True): + settings = Settings() + assert isinstance(settings.slack_signing_secret, SecretStr) + + def test_slack_signing_secret_not_leaked_in_repr(self) -> None: + env = {**_base_env(), "SLACK_SIGNING_SECRET": "private-signing-secret"} + with patch.dict(os.environ, env, clear=True): + settings = Settings() + assert "private-signing-secret" not in repr(settings) + + def test_slack_channel_id_optional_defaults_empty(self) -> None: + with patch.dict(os.environ, _base_env(), clear=True): + settings = Settings() + assert settings.slack_channel_id == "" + + def test_slack_channel_id_custom_value(self) -> None: + env = {**_base_env(), "SLACK_CHANNEL_ID": "C0987654321"} + with patch.dict(os.environ, env, clear=True): + settings = Settings() + assert settings.slack_channel_id == "C0987654321" + + def test_ci_poll_interval_seconds_default(self) -> None: + with patch.dict(os.environ, _base_env(), clear=True): + settings = Settings() + assert settings.ci_poll_interval_seconds == 30 + + def test_ci_poll_interval_seconds_custom(self) -> None: + env = {**_base_env(), "CI_POLL_INTERVAL_SECONDS": "60"} + with patch.dict(os.environ, env, clear=True): + settings = Settings() + assert settings.ci_poll_interval_seconds == 60 + + def test_ci_poll_max_wait_seconds_default(self) -> None: + with patch.dict(os.environ, _base_env(), clear=True): + settings = Settings() + assert settings.ci_poll_max_wait_seconds == 1800 + + def test_ci_poll_max_wait_seconds_custom(self) -> None: + env = {**_base_env(), "CI_POLL_MAX_WAIT_SECONDS": "3600"} + with patch.dict(os.environ, env, clear=True): + settings = Settings() + assert settings.ci_poll_max_wait_seconds == 3600 + + def test_slack_webhook_url_still_optional(self) -> None: + env = {k: v for k, v in _base_env().items() if k != "SLACK_WEBHOOK_URL"} + with patch.dict(os.environ, env, clear=True): + settings = Settings() + assert settings.slack_webhook_url.get_secret_value() == "" + + +class TestSettingsPrPolling: + """Tests for PR polling config fields (Step 1).""" + + def test_watched_repos_defaults_empty(self) -> None: + with patch.dict(os.environ, _base_env(), clear=True): + settings = Settings() + assert settings.watched_repos == "" + + def test_watched_repos_custom_value(self) -> None: + env = {**_base_env(), "WATCHED_REPOS": "repo-a,repo-b"} + with patch.dict(os.environ, env, clear=True): + settings = Settings() + assert settings.watched_repos == "repo-a,repo-b" + + def test_watched_repos_list_empty_when_blank(self) -> None: + with patch.dict(os.environ, _base_env(), clear=True): + settings = Settings() + assert settings.watched_repos_list == [] + + def test_watched_repos_list_splits_comma_separated(self) -> None: + env = {**_base_env(), "WATCHED_REPOS": "repo-a,repo-b,repo-c"} + with patch.dict(os.environ, env, clear=True): + settings = Settings() + assert settings.watched_repos_list == ["repo-a", "repo-b", "repo-c"] + + def test_watched_repos_list_strips_whitespace(self) -> None: + env = {**_base_env(), "WATCHED_REPOS": " repo-a , repo-b "} + with patch.dict(os.environ, env, clear=True): + settings = Settings() + assert settings.watched_repos_list == ["repo-a", "repo-b"] + + def test_watched_repos_list_ignores_empty_entries(self) -> None: + env = {**_base_env(), "WATCHED_REPOS": "repo-a,,repo-b"} + with patch.dict(os.environ, env, clear=True): + settings = Settings() + assert settings.watched_repos_list == ["repo-a", "repo-b"] + + def test_pr_poll_interval_seconds_default(self) -> None: + with patch.dict(os.environ, _base_env(), clear=True): + settings = Settings() + assert settings.pr_poll_interval_seconds == 300 + + def test_pr_poll_interval_seconds_custom(self) -> None: + env = {**_base_env(), "PR_POLL_INTERVAL_SECONDS": "60"} + with patch.dict(os.environ, env, clear=True): + settings = Settings() + assert settings.pr_poll_interval_seconds == 60 + + def test_pr_poll_target_branch_default(self) -> None: + with patch.dict(os.environ, _base_env(), clear=True): + settings = Settings() + assert settings.pr_poll_target_branch == "refs/heads/develop" + + def test_pr_poll_target_branch_custom(self) -> None: + env = {**_base_env(), "PR_POLL_TARGET_BRANCH": "refs/heads/main"} + with patch.dict(os.environ, env, clear=True): + settings = Settings() + assert settings.pr_poll_target_branch == "refs/heads/main" + + def test_pr_poll_enabled_defaults_false(self) -> None: + with patch.dict(os.environ, _base_env(), clear=True): + settings = Settings() + assert settings.pr_poll_enabled is False + + def test_pr_poll_enabled_true_from_env(self) -> None: + env = {**_base_env(), "PR_POLL_ENABLED": "true"} + with patch.dict(os.environ, env, clear=True): + settings = Settings() + assert settings.pr_poll_enabled is True + + def test_default_jira_project_default(self) -> None: + with patch.dict(os.environ, _base_env(), clear=True): + settings = Settings() + assert settings.default_jira_project == "ALLPOST" + + def test_default_jira_project_custom(self) -> None: + env = {**_base_env(), "DEFAULT_JIRA_PROJECT": "MYPROJ"} + with patch.dict(os.environ, env, clear=True): + settings = Settings() + assert settings.default_jira_project == "MYPROJ" + + def test_auto_create_ticket_enabled_defaults_true(self) -> None: + with patch.dict(os.environ, _base_env(), clear=True): + settings = Settings() + assert settings.auto_create_ticket_enabled is True + + def test_auto_create_ticket_enabled_false_from_env(self) -> None: + env = {**_base_env(), "AUTO_CREATE_TICKET_ENABLED": "false"} + with patch.dict(os.environ, env, clear=True): + settings = Settings() + assert settings.auto_create_ticket_enabled is False diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py new file mode 100644 index 0000000..bc86363 --- /dev/null +++ b/tests/test_exceptions.py @@ -0,0 +1,198 @@ +"""Tests for custom exception hierarchy. Written FIRST (TDD RED phase).""" + +import pytest + +from release_agent.exceptions import ( + AuthenticationError, + NotFoundError, + RateLimitError, + ReleaseAgentError, + ServiceError, + ServiceUnavailableError, +) + + +class TestReleaseAgentError: + """Tests for the base exception class.""" + + def test_is_exception(self) -> None: + err = ReleaseAgentError("something went wrong") + assert isinstance(err, Exception) + + def test_message_stored(self) -> None: + err = ReleaseAgentError("something went wrong") + assert str(err) == "something went wrong" + + def test_can_be_raised(self) -> None: + with pytest.raises(ReleaseAgentError): + raise ReleaseAgentError("boom") + + +class TestServiceError: + """Tests for ServiceError with service name and status code.""" + + def test_is_release_agent_error(self) -> None: + err = ServiceError(service="jira", status_code=500, detail="Internal error") + assert isinstance(err, ReleaseAgentError) + + def test_stores_service(self) -> None: + err = ServiceError(service="jira", status_code=500, detail="Internal error") + assert err.service == "jira" + + def test_stores_status_code(self) -> None: + err = ServiceError(service="azdo", status_code=422, detail="Unprocessable") + assert err.status_code == 422 + + def test_stores_detail(self) -> None: + err = ServiceError(service="slack", status_code=400, detail="Bad payload") + assert err.detail == "Bad payload" + + def test_str_includes_service_and_status(self) -> None: + err = ServiceError(service="jira", status_code=500, detail="Server error") + text = str(err) + assert "jira" in text + assert "500" in text + + def test_can_be_raised(self) -> None: + with pytest.raises(ServiceError): + raise ServiceError(service="azdo", status_code=400, detail="bad request") + + def test_detail_none_allowed(self) -> None: + err = ServiceError(service="jira", status_code=404, detail=None) + assert err.detail is None + + +class TestAuthenticationError: + """Tests for AuthenticationError (401/403).""" + + def test_is_service_error(self) -> None: + err = AuthenticationError(service="azdo") + assert isinstance(err, ServiceError) + + def test_default_status_code_401(self) -> None: + err = AuthenticationError(service="azdo") + assert err.status_code == 401 + + def test_service_stored(self) -> None: + err = AuthenticationError(service="jira") + assert err.service == "jira" + + def test_custom_status_code(self) -> None: + err = AuthenticationError(service="azdo", status_code=403) + assert err.status_code == 403 + + def test_str_contains_service(self) -> None: + err = AuthenticationError(service="slack") + assert "slack" in str(err) + + def test_can_be_raised(self) -> None: + with pytest.raises(AuthenticationError): + raise AuthenticationError(service="azdo") + + def test_caught_as_service_error(self) -> None: + with pytest.raises(ServiceError): + raise AuthenticationError(service="azdo") + + +class TestNotFoundError: + """Tests for NotFoundError (404).""" + + def test_is_service_error(self) -> None: + err = NotFoundError(service="azdo", detail="PR not found") + assert isinstance(err, ServiceError) + + def test_status_code_is_404(self) -> None: + err = NotFoundError(service="jira", detail="Issue not found") + assert err.status_code == 404 + + def test_detail_stored(self) -> None: + err = NotFoundError(service="azdo", detail="PR 999 not found") + assert "PR 999" in err.detail + + def test_can_be_raised(self) -> None: + with pytest.raises(NotFoundError): + raise NotFoundError(service="azdo", detail="not found") + + def test_caught_as_release_agent_error(self) -> None: + with pytest.raises(ReleaseAgentError): + raise NotFoundError(service="jira", detail="issue missing") + + +class TestRateLimitError: + """Tests for RateLimitError (429) with retry_after.""" + + def test_is_service_error(self) -> None: + err = RateLimitError(service="jira", retry_after=30) + assert isinstance(err, ServiceError) + + def test_status_code_is_429(self) -> None: + err = RateLimitError(service="jira", retry_after=30) + assert err.status_code == 429 + + def test_stores_retry_after(self) -> None: + err = RateLimitError(service="slack", retry_after=60) + assert err.retry_after == 60 + + def test_retry_after_none_allowed(self) -> None: + err = RateLimitError(service="azdo", retry_after=None) + assert err.retry_after is None + + def test_str_contains_service(self) -> None: + err = RateLimitError(service="jira", retry_after=5) + assert "jira" in str(err) + + def test_can_be_raised(self) -> None: + with pytest.raises(RateLimitError): + raise RateLimitError(service="jira", retry_after=30) + + +class TestServiceUnavailableError: + """Tests for ServiceUnavailableError (503).""" + + def test_is_service_error(self) -> None: + err = ServiceUnavailableError(service="azdo") + assert isinstance(err, ServiceError) + + def test_status_code_is_503(self) -> None: + err = ServiceUnavailableError(service="azdo") + assert err.status_code == 503 + + def test_service_stored(self) -> None: + err = ServiceUnavailableError(service="slack") + assert err.service == "slack" + + def test_custom_detail(self) -> None: + err = ServiceUnavailableError(service="azdo", detail="Maintenance window") + assert "Maintenance" in err.detail + + def test_can_be_raised(self) -> None: + with pytest.raises(ServiceUnavailableError): + raise ServiceUnavailableError(service="azdo") + + def test_caught_as_service_error(self) -> None: + with pytest.raises(ServiceError): + raise ServiceUnavailableError(service="jira") + + +class TestExceptionHierarchyInheritance: + """Tests verifying the full exception hierarchy is correct.""" + + def test_all_are_release_agent_errors(self) -> None: + errors = [ + AuthenticationError(service="svc"), + NotFoundError(service="svc", detail="x"), + RateLimitError(service="svc", retry_after=1), + ServiceUnavailableError(service="svc"), + ] + for err in errors: + assert isinstance(err, ReleaseAgentError), f"{type(err)} not ReleaseAgentError" + + def test_all_are_service_errors(self) -> None: + errors = [ + AuthenticationError(service="svc"), + NotFoundError(service="svc", detail="x"), + RateLimitError(service="svc", retry_after=1), + ServiceUnavailableError(service="svc"), + ] + for err in errors: + assert isinstance(err, ServiceError), f"{type(err)} not ServiceError" diff --git a/tests/test_main.py b/tests/test_main.py new file mode 100644 index 0000000..2e95523 --- /dev/null +++ b/tests/test_main.py @@ -0,0 +1,666 @@ +"""Tests for main FastAPI application. Written FIRST (TDD RED phase). + +Heavy startup (PostgreSQL, httpx clients, graph compilation) is mocked. +Tests verify: routes registered, lifespan hooks, exception handlers. +""" + +import asyncio +from contextlib import asynccontextmanager +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + + +# --------------------------------------------------------------------------- +# Helpers / fixtures +# --------------------------------------------------------------------------- + +def _make_mock_settings(): + s = MagicMock() + s.webhook_secret.get_secret_value.return_value = "test-secret" + s.postgres_dsn.get_secret_value.return_value = "postgresql://u:p@localhost/db" + s.azdo_pat.get_secret_value.return_value = "pat" + s.anthropic_api_key.get_secret_value.return_value = "key" + s.jira_api_token.get_secret_value.return_value = "jira" + s.slack_webhook_url.get_secret_value.return_value = "https://hooks.slack.com/x" + s.slack_bot_token.get_secret_value.return_value = "" + s.slack_channel_id = "" + s.slack_signing_secret.get_secret_value.return_value = "" + s.port = 8000 + s.pr_poll_enabled = False + s.pr_poll_interval_seconds = 300 + s.pr_poll_target_branch = "refs/heads/develop" + s.watched_repos_list = [] + s.default_jira_project = "ALLPOST" + return s + + +def _make_patched_app(): + """Return the FastAPI app with all heavy startup mocked.""" + mock_settings = _make_mock_settings() + mock_pool = AsyncMock() + mock_pool.__aenter__ = AsyncMock(return_value=mock_pool) + mock_pool.__aexit__ = AsyncMock(return_value=False) + + mock_graphs = { + "pr_completed": MagicMock(), + "release": MagicMock(), + } + mock_clients = MagicMock() + mock_staging_store = MagicMock() + + patches = [ + patch("release_agent.main.Settings", return_value=mock_settings), + patch("release_agent.main.build_pr_completed_graph", return_value=mock_graphs["pr_completed"]), + patch("release_agent.main.build_release_graph", return_value=mock_graphs["release"]), + patch("release_agent.main.AsyncConnectionPool", return_value=mock_pool), + patch("release_agent.main._create_tool_clients", return_value=mock_clients), + patch("release_agent.main._create_staging_store", return_value=mock_staging_store), + patch("release_agent.main._ensure_db_schema", new_callable=AsyncMock), + ] + + for p in patches: + p.start() + + from release_agent.main import create_app + + app = create_app() + + for p in patches: + p.stop() + + return app, mock_settings, mock_pool, mock_graphs + + +# --------------------------------------------------------------------------- +# Route registration tests +# --------------------------------------------------------------------------- + +class TestRouteRegistration: + def test_webhook_route_registered(self) -> None: + from release_agent.main import create_app + + with ( + patch("release_agent.main.Settings", return_value=_make_mock_settings()), + patch("release_agent.main.build_pr_completed_graph", return_value=MagicMock()), + patch("release_agent.main.build_release_graph", return_value=MagicMock()), + patch("release_agent.main.AsyncConnectionPool", return_value=AsyncMock()), + patch("release_agent.main._create_tool_clients", return_value=MagicMock()), + patch("release_agent.main._create_staging_store", return_value=MagicMock()), + patch("release_agent.main._ensure_db_schema", new_callable=AsyncMock), + ): + app = create_app() + + routes = {r.path for r in app.routes} # type: ignore[attr-defined] + assert "/webhooks/azdo" in routes + + def test_approvals_routes_registered(self) -> None: + from release_agent.main import create_app + + with ( + patch("release_agent.main.Settings", return_value=_make_mock_settings()), + patch("release_agent.main.build_pr_completed_graph", return_value=MagicMock()), + patch("release_agent.main.build_release_graph", return_value=MagicMock()), + patch("release_agent.main.AsyncConnectionPool", return_value=AsyncMock()), + patch("release_agent.main._create_tool_clients", return_value=MagicMock()), + patch("release_agent.main._create_staging_store", return_value=MagicMock()), + patch("release_agent.main._ensure_db_schema", new_callable=AsyncMock), + ): + app = create_app() + + routes = {r.path for r in app.routes} # type: ignore[attr-defined] + assert "/approvals/pending" in routes + assert "/approvals/{thread_id}" in routes + + def test_status_routes_registered(self) -> None: + from release_agent.main import create_app + + with ( + patch("release_agent.main.Settings", return_value=_make_mock_settings()), + patch("release_agent.main.build_pr_completed_graph", return_value=MagicMock()), + patch("release_agent.main.build_release_graph", return_value=MagicMock()), + patch("release_agent.main.AsyncConnectionPool", return_value=AsyncMock()), + patch("release_agent.main._create_tool_clients", return_value=MagicMock()), + patch("release_agent.main._create_staging_store", return_value=MagicMock()), + patch("release_agent.main._ensure_db_schema", new_callable=AsyncMock), + ): + app = create_app() + + routes = {r.path for r in app.routes} # type: ignore[attr-defined] + assert "/status" in routes + assert "/staging" in routes + + +# --------------------------------------------------------------------------- +# schedule_graph / run_graph_in_background tests +# --------------------------------------------------------------------------- + +class TestScheduleGraph: + def test_schedule_graph_returns_thread_id(self) -> None: + from release_agent.main import schedule_graph + + mock_app = MagicMock() + mock_app.state.background_tasks = set() + mock_graph = MagicMock() + + with patch("release_agent.main.asyncio.create_task", return_value=MagicMock()): + thread_id = schedule_graph( + app=mock_app, + graph=mock_graph, + initial_state={"repo_name": "my-repo"}, + thread_id=None, + ) + + assert isinstance(thread_id, str) + assert len(thread_id) > 0 + + def test_schedule_graph_uses_provided_thread_id(self) -> None: + from release_agent.main import schedule_graph + + mock_app = MagicMock() + mock_app.state.background_tasks = set() + mock_graph = MagicMock() + + with patch("release_agent.main.asyncio.create_task", return_value=MagicMock()): + thread_id = schedule_graph( + app=mock_app, + graph=mock_graph, + initial_state={}, + thread_id="custom-thread-id", + ) + + assert thread_id == "custom-thread-id" + + def test_schedule_graph_adds_task_to_background_tasks(self) -> None: + from release_agent.main import schedule_graph + + mock_app = MagicMock() + mock_app.state.background_tasks = set() + mock_graph = MagicMock() + mock_task = MagicMock() + + with patch("release_agent.main.asyncio.create_task", return_value=mock_task): + schedule_graph( + app=mock_app, + graph=mock_graph, + initial_state={}, + thread_id=None, + ) + + assert mock_task in mock_app.state.background_tasks + + def test_run_graph_in_background_is_coroutine(self) -> None: + from release_agent.main import run_graph_in_background + import inspect + + assert inspect.iscoroutinefunction(run_graph_in_background) + + +# --------------------------------------------------------------------------- +# _ensure_db_schema tests +# --------------------------------------------------------------------------- + +class TestEnsureDbSchema: + @pytest.mark.asyncio + async def test_ensure_db_schema_creates_table(self) -> None: + from release_agent.main import _ensure_db_schema + + mock_pool = AsyncMock() + mock_conn = AsyncMock() + mock_cursor = AsyncMock() + + mock_cursor.execute = AsyncMock() + mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor) + mock_cursor.__aexit__ = AsyncMock(return_value=False) + mock_conn.cursor = MagicMock(return_value=mock_cursor) + mock_conn.__aenter__ = AsyncMock(return_value=mock_conn) + mock_conn.__aexit__ = AsyncMock(return_value=False) + mock_pool.connection = MagicMock(return_value=mock_conn) + + await _ensure_db_schema(mock_pool) + + # Phase 5: now executes multiple DDL statements (agent_threads + + # staging_releases + archived_releases), so called_once no longer holds. + assert mock_cursor.execute.call_count >= 1 + all_sql = " ".join( + call.args[0] for call in mock_cursor.execute.call_args_list + ) + assert "agent_threads" in all_sql + + +# --------------------------------------------------------------------------- +# _create_tool_clients tests +# --------------------------------------------------------------------------- + +class TestCreateToolClients: + def test_create_tool_clients_returns_tool_clients_instance(self) -> None: + from release_agent.main import _create_tool_clients + from release_agent.graph.dependencies import ToolClients + + mock_settings = _make_mock_settings() + + with ( + patch("release_agent.main.AzDoClient") as mock_azdo, + patch("release_agent.main.JiraClient") as mock_jira, + patch("release_agent.main.SlackClient") as mock_slack, + patch("release_agent.main.ClaudeReviewer") as mock_reviewer, + patch("release_agent.main.httpx.AsyncClient") as mock_httpx, + ): + clients, http_clients = _create_tool_clients(mock_settings) + + assert isinstance(clients, ToolClients) + + +# --------------------------------------------------------------------------- +# _create_staging_store tests +# --------------------------------------------------------------------------- + +class TestCreateStagingStore: + def test_create_staging_store_returns_store(self) -> None: + from release_agent.main import _create_staging_store + from release_agent.graph.dependencies import JsonFileStagingStore + + result = _create_staging_store() + assert isinstance(result, JsonFileStagingStore) + + +# --------------------------------------------------------------------------- +# Global exception handler tests +# --------------------------------------------------------------------------- + +class TestExceptionHandlers: + def test_app_has_exception_handlers(self) -> None: + from release_agent.main import create_app + + with ( + patch("release_agent.main.Settings", return_value=_make_mock_settings()), + patch("release_agent.main.build_pr_completed_graph", return_value=MagicMock()), + patch("release_agent.main.build_release_graph", return_value=MagicMock()), + patch("release_agent.main.AsyncConnectionPool", return_value=AsyncMock()), + patch("release_agent.main._create_tool_clients", return_value=MagicMock()), + patch("release_agent.main._create_staging_store", return_value=MagicMock()), + patch("release_agent.main._ensure_db_schema", new_callable=AsyncMock), + ): + app = create_app() + + # FastAPI stores exception handlers in exception_handlers attribute + assert hasattr(app, "exception_handlers") + + +# --------------------------------------------------------------------------- +# Lifespan tests +# --------------------------------------------------------------------------- + +class TestGracefulShutdown: + @pytest.mark.asyncio + async def test_lifespan_cancels_timed_out_tasks(self) -> None: + """Verify the lifespan waits for tasks and cancels timed-out ones.""" + from release_agent.main import lifespan + + mock_pool = AsyncMock() + mock_pool.open = AsyncMock() + mock_pool.close = AsyncMock() + mock_conn = AsyncMock() + mock_cursor = AsyncMock() + mock_cursor.execute = AsyncMock() + mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor) + mock_cursor.__aexit__ = AsyncMock(return_value=False) + mock_conn.cursor = MagicMock(return_value=mock_cursor) + mock_conn.__aenter__ = AsyncMock(return_value=mock_conn) + mock_conn.__aexit__ = AsyncMock(return_value=False) + mock_pool.connection = MagicMock(return_value=mock_conn) + + from fastapi import FastAPI + + app = FastAPI() + app.state.background_tasks = set() + + mock_settings = _make_mock_settings() + mock_task = MagicMock() + mock_task.cancel = MagicMock() + + with ( + patch("release_agent.main.Settings", return_value=mock_settings), + patch("release_agent.main.build_pr_completed_graph", return_value=MagicMock()), + patch("release_agent.main.build_release_graph", return_value=MagicMock()), + patch("release_agent.main.AsyncConnectionPool", return_value=mock_pool), + patch("release_agent.main._create_tool_clients", return_value=(MagicMock(), [])), + patch("release_agent.main._create_staging_store", return_value=MagicMock()), + patch("release_agent.main._ensure_db_schema", new_callable=AsyncMock), + patch( + "release_agent.main.asyncio.wait", + new_callable=AsyncMock, + return_value=(set(), {mock_task}), + ), + ): + ctx = lifespan(app) + await ctx.__aenter__() + # Add a fake task to background_tasks after startup + app.state.background_tasks.add(mock_task) + await ctx.__aexit__(None, None, None) + + # The pending task should have been cancelled + mock_task.cancel.assert_called_once() + + +class TestLifespan: + def test_app_state_set_after_lifespan(self) -> None: + """Verify app.state.graphs and app.state.settings are set during lifespan.""" + from release_agent.main import create_app + + mock_pool = AsyncMock() + mock_pool.__aenter__ = AsyncMock(return_value=mock_pool) + mock_pool.__aexit__ = AsyncMock(return_value=False) + mock_pool.open = AsyncMock() + mock_pool.close = AsyncMock() + + mock_conn = AsyncMock() + mock_cursor = AsyncMock() + mock_cursor.execute = AsyncMock() + mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor) + mock_cursor.__aexit__ = AsyncMock(return_value=False) + mock_conn.cursor = MagicMock(return_value=mock_cursor) + mock_conn.__aenter__ = AsyncMock(return_value=mock_conn) + mock_conn.__aexit__ = AsyncMock(return_value=False) + mock_pool.connection = MagicMock(return_value=mock_conn) + + mock_settings = _make_mock_settings() + mock_graphs = {"pr_completed": MagicMock(), "release": MagicMock()} + mock_clients = MagicMock() + mock_staging_store = MagicMock() + + with ( + patch("release_agent.main.Settings", return_value=mock_settings), + patch("release_agent.main.build_pr_completed_graph", return_value=mock_graphs["pr_completed"]), + patch("release_agent.main.build_release_graph", return_value=mock_graphs["release"]), + patch("release_agent.main.AsyncConnectionPool", return_value=mock_pool), + patch("release_agent.main._create_tool_clients", return_value=(mock_clients, [])), + patch("release_agent.main._create_staging_store", return_value=mock_staging_store), + patch("release_agent.main._ensure_db_schema", new_callable=AsyncMock), + ): + app = create_app() + with TestClient(app) as client: + # App is started; state should be accessible + response = client.get("/status") + # We just verify no crash + assert response.status_code in (200, 500) + + +# --------------------------------------------------------------------------- +# Phase 5: Slack interactions route + new config tests +# --------------------------------------------------------------------------- + +class TestPhase5Routes: + """Tests for Phase 5 additions to main.py.""" + + def _make_patches(self): + mock_settings = _make_mock_settings() + mock_pool = AsyncMock() + mock_pool.open = AsyncMock() + mock_pool.close = AsyncMock() + mock_pool.__aenter__ = AsyncMock(return_value=mock_pool) + mock_pool.__aexit__ = AsyncMock(return_value=False) + + return [ + patch("release_agent.main.Settings", return_value=mock_settings), + patch("release_agent.main.build_pr_completed_graph", return_value=MagicMock()), + patch("release_agent.main.build_release_graph", return_value=MagicMock()), + patch("release_agent.main.AsyncConnectionPool", return_value=mock_pool), + patch("release_agent.main._create_tool_clients", return_value=MagicMock()), + patch("release_agent.main._create_staging_store", return_value=MagicMock()), + patch("release_agent.main._ensure_db_schema", new_callable=AsyncMock), + ] + + def test_slack_interactions_route_registered(self) -> None: + from release_agent.main import create_app + + patches = self._make_patches() + for p in patches: + p.start() + + try: + app = create_app() + finally: + for p in patches: + p.stop() + + routes = {r.path for r in app.routes} # type: ignore[attr-defined] + assert "/slack/interactions" in routes + + def test_create_tool_clients_uses_bot_token(self) -> None: + from release_agent.main import _create_tool_clients + + mock_settings = _make_mock_settings() + mock_settings.slack_bot_token.get_secret_value.return_value = "xoxb-test" + mock_settings.slack_channel_id = "C12345" + mock_settings.slack_webhook_url.get_secret_value.return_value = "" + mock_settings.azdo_api_url = "https://dev.azure.com/org/proj/_apis" + mock_settings.azdo_vsrm_api_url = "https://vsrm.dev.azure.com/org/proj/_apis" + mock_settings.jira_base_url = "https://example.atlassian.net" + mock_settings.jira_email = "test@example.com" + + # Should not raise + clients, http_clients = _create_tool_clients(mock_settings) + assert clients is not None + # Clean up + for hc in http_clients: + asyncio.get_event_loop().run_until_complete(hc.aclose()) + + +class TestPhase5DbSchema: + """Tests that _ensure_db_schema adds the slack_message_ts column.""" + + async def test_ensure_db_schema_executes_sql_statements(self) -> None: + from release_agent.main import _ensure_db_schema + + mock_pool = MagicMock() + mock_conn = AsyncMock() + mock_cursor = AsyncMock() + executed_sql: list[str] = [] + + async def capture_execute(sql, *args): + executed_sql.append(sql.strip()) + + mock_cursor.execute = capture_execute + mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor) + mock_cursor.__aexit__ = AsyncMock(return_value=False) + mock_conn.cursor = MagicMock(return_value=mock_cursor) + mock_conn.__aenter__ = AsyncMock(return_value=mock_conn) + mock_conn.__aexit__ = AsyncMock(return_value=False) + mock_pool.connection = MagicMock(return_value=mock_conn) + + await _ensure_db_schema(mock_pool) + + # Should have executed CREATE TABLE statements + assert len(executed_sql) >= 3 + combined = " ".join(executed_sql) + assert "agent_threads" in combined + assert "staging_releases" in combined + + async def test_ensure_db_schema_includes_slack_message_ts_column(self) -> None: + from release_agent.main import _ensure_db_schema + + mock_pool = MagicMock() + mock_conn = AsyncMock() + mock_cursor = AsyncMock() + executed_sql: list[str] = [] + + async def capture_execute(sql, *args): + executed_sql.append(sql.strip()) + + mock_cursor.execute = capture_execute + mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor) + mock_cursor.__aexit__ = AsyncMock(return_value=False) + mock_conn.cursor = MagicMock(return_value=mock_cursor) + mock_conn.__aenter__ = AsyncMock(return_value=mock_conn) + mock_conn.__aexit__ = AsyncMock(return_value=False) + mock_pool.connection = MagicMock(return_value=mock_conn) + + await _ensure_db_schema(mock_pool) + + combined = " ".join(executed_sql) + assert "slack_message_ts" in combined + + +# --------------------------------------------------------------------------- +# PR polling lifespan integration tests +# --------------------------------------------------------------------------- + +class TestPrPollingLifespan: + """Tests for PR polling startup in the lifespan handler.""" + + def _make_polling_settings(self, *, pr_poll_enabled: bool = True) -> MagicMock: + s = _make_mock_settings() + s.pr_poll_enabled = pr_poll_enabled + s.pr_poll_interval_seconds = 30 + s.pr_poll_target_branch = "refs/heads/develop" + s.watched_repos_list = ["repo-a"] + s.default_jira_project = "ALLPOST" + return s + + async def test_poll_loop_started_when_pr_poll_enabled(self) -> None: + """When pr_poll_enabled=True, a background task for polling is created.""" + from release_agent.main import create_app + + mock_settings = self._make_polling_settings(pr_poll_enabled=True) + mock_pool = AsyncMock() + mock_pool.open = AsyncMock() + mock_pool.close = AsyncMock() + mock_pool.connection = MagicMock() + + poll_loop_started = [] + + async def fake_run_poll_loop(**kwargs): + poll_loop_started.append(True) + # Simulate an immediate cancellation to avoid infinite loop + raise asyncio.CancelledError + + with ( + patch("release_agent.main.Settings", return_value=mock_settings), + patch("release_agent.main.build_pr_completed_graph", return_value=MagicMock()), + patch("release_agent.main.build_release_graph", return_value=MagicMock()), + patch("release_agent.main.AsyncConnectionPool", return_value=mock_pool), + patch("release_agent.main._create_tool_clients", return_value=(MagicMock(), [])), + patch("release_agent.main._create_staging_store", return_value=MagicMock()), + patch("release_agent.main._ensure_db_schema", new_callable=AsyncMock), + patch("release_agent.main.run_pr_poll_loop", new=fake_run_poll_loop), + ): + app = create_app() + async with app.router.lifespan_context(app): + # Give the event loop a chance to start background tasks + await asyncio.sleep(0) + + assert len(poll_loop_started) > 0 + + async def test_poll_loop_not_started_when_pr_poll_disabled(self) -> None: + """When pr_poll_enabled=False, no polling background task is created.""" + from release_agent.main import create_app + + mock_settings = self._make_polling_settings(pr_poll_enabled=False) + mock_pool = AsyncMock() + mock_pool.open = AsyncMock() + mock_pool.close = AsyncMock() + mock_pool.connection = MagicMock() + + poll_loop_started = [] + + async def fake_run_poll_loop(**kwargs): + poll_loop_started.append(True) + + with ( + patch("release_agent.main.Settings", return_value=mock_settings), + patch("release_agent.main.build_pr_completed_graph", return_value=MagicMock()), + patch("release_agent.main.build_release_graph", return_value=MagicMock()), + patch("release_agent.main.AsyncConnectionPool", return_value=mock_pool), + patch("release_agent.main._create_tool_clients", return_value=(MagicMock(), [])), + patch("release_agent.main._create_staging_store", return_value=MagicMock()), + patch("release_agent.main._ensure_db_schema", new_callable=AsyncMock), + patch("release_agent.main.run_pr_poll_loop", new=fake_run_poll_loop), + ): + app = create_app() + async with app.router.lifespan_context(app): + await asyncio.sleep(0) + + assert len(poll_loop_started) == 0 + + +# --------------------------------------------------------------------------- +# _run_graph default_jira_project injection tests +# --------------------------------------------------------------------------- + +class TestRunGraphJiraProjectInjection: + """Tests that _run_graph passes default_jira_project into the graph config.""" + + async def test_default_jira_project_passed_to_graph_config(self) -> None: + from release_agent.api.webhooks import _run_graph + + captured_configs: list[dict] = [] + + mock_graph = MagicMock() + + async def fake_ainvoke(state, config=None): + captured_configs.append(config or {}) + return {} + + mock_graph.ainvoke = fake_ainvoke + + mock_pool = MagicMock() + mock_conn = AsyncMock() + mock_cursor = AsyncMock() + mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor) + mock_cursor.__aexit__ = AsyncMock(return_value=False) + mock_conn.cursor = MagicMock(return_value=mock_cursor) + mock_conn.__aenter__ = AsyncMock(return_value=mock_conn) + mock_conn.__aexit__ = AsyncMock(return_value=False) + mock_pool.connection = MagicMock(return_value=mock_conn) + + await _run_graph( + graph=mock_graph, + initial_state={"pr_id": "1", "repo_name": "r"}, + thread_id="tid-1", + tool_clients=MagicMock(), + db_pool=mock_pool, + repos_base_dir="", + graph_name="pr_completed", + default_jira_project="MYPROJ", + ) + + assert len(captured_configs) == 1 + configurable = captured_configs[0].get("configurable", {}) + assert configurable.get("default_jira_project") == "MYPROJ" + + async def test_default_jira_project_defaults_to_allpost(self) -> None: + from release_agent.api.webhooks import _run_graph + + captured_configs: list[dict] = [] + + mock_graph = MagicMock() + + async def fake_ainvoke(state, config=None): + captured_configs.append(config or {}) + return {} + + mock_graph.ainvoke = fake_ainvoke + + mock_pool = MagicMock() + mock_conn = AsyncMock() + mock_cursor = AsyncMock() + mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor) + mock_cursor.__aexit__ = AsyncMock(return_value=False) + mock_conn.cursor = MagicMock(return_value=mock_cursor) + mock_conn.__aenter__ = AsyncMock(return_value=mock_conn) + mock_conn.__aexit__ = AsyncMock(return_value=False) + mock_pool.connection = MagicMock(return_value=mock_conn) + + await _run_graph( + graph=mock_graph, + initial_state={"pr_id": "1", "repo_name": "r"}, + thread_id="tid-2", + tool_clients=MagicMock(), + db_pool=mock_pool, + ) + + configurable = captured_configs[0].get("configurable", {}) + assert configurable.get("default_jira_project") == "ALLPOST" diff --git a/tests/test_main_phase5.py b/tests/test_main_phase5.py new file mode 100644 index 0000000..17aba26 --- /dev/null +++ b/tests/test_main_phase5.py @@ -0,0 +1,147 @@ +"""Tests for main.py Phase 5 changes. + +Phase 5 - Step 4: _ensure_db_schema creates staging/archived tables, +and lifespan uses PostgresStagingStore. +Written FIRST (TDD RED phase). +""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + + +# --------------------------------------------------------------------------- +# _ensure_db_schema includes staging DDL +# --------------------------------------------------------------------------- + +class TestEnsureDbSchemaPhase5: + async def test_schema_creates_staging_releases_table(self) -> None: + from release_agent.main import _ensure_db_schema + + executed_sqls: list[str] = [] + + mock_cursor = AsyncMock() + mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor) + mock_cursor.__aexit__ = AsyncMock(return_value=False) + + async def capture_execute(sql: str, *args) -> None: + executed_sqls.append(sql) + + mock_cursor.execute = capture_execute + + mock_conn = AsyncMock() + mock_conn.cursor = MagicMock(return_value=mock_cursor) + mock_conn.__aenter__ = AsyncMock(return_value=mock_conn) + mock_conn.__aexit__ = AsyncMock(return_value=False) + + mock_pool = MagicMock() + mock_pool.connection = MagicMock(return_value=mock_conn) + + await _ensure_db_schema(mock_pool) + + all_sql = " ".join(executed_sqls) + assert "staging_releases" in all_sql + + async def test_schema_creates_archived_releases_table(self) -> None: + from release_agent.main import _ensure_db_schema + + executed_sqls: list[str] = [] + + mock_cursor = AsyncMock() + mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor) + mock_cursor.__aexit__ = AsyncMock(return_value=False) + + async def capture_execute(sql: str, *args) -> None: + executed_sqls.append(sql) + + mock_cursor.execute = capture_execute + + mock_conn = AsyncMock() + mock_conn.cursor = MagicMock(return_value=mock_cursor) + mock_conn.__aenter__ = AsyncMock(return_value=mock_conn) + mock_conn.__aexit__ = AsyncMock(return_value=False) + + mock_pool = MagicMock() + mock_pool.connection = MagicMock(return_value=mock_conn) + + await _ensure_db_schema(mock_pool) + + all_sql = " ".join(executed_sqls) + assert "archived_releases" in all_sql + + async def test_schema_still_creates_agent_threads_table(self) -> None: + from release_agent.main import _ensure_db_schema + + executed_sqls: list[str] = [] + + mock_cursor = AsyncMock() + mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor) + mock_cursor.__aexit__ = AsyncMock(return_value=False) + + async def capture_execute(sql: str, *args) -> None: + executed_sqls.append(sql) + + mock_cursor.execute = capture_execute + + mock_conn = AsyncMock() + mock_conn.cursor = MagicMock(return_value=mock_cursor) + mock_conn.__aenter__ = AsyncMock(return_value=mock_conn) + mock_conn.__aexit__ = AsyncMock(return_value=False) + + mock_pool = MagicMock() + mock_pool.connection = MagicMock(return_value=mock_conn) + + await _ensure_db_schema(mock_pool) + + all_sql = " ".join(executed_sqls) + assert "agent_threads" in all_sql + + async def test_schema_uses_if_not_exists(self) -> None: + from release_agent.main import _ensure_db_schema + + executed_sqls: list[str] = [] + + mock_cursor = AsyncMock() + mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor) + mock_cursor.__aexit__ = AsyncMock(return_value=False) + + async def capture_execute(sql: str, *args) -> None: + executed_sqls.append(sql) + + mock_cursor.execute = capture_execute + + mock_conn = AsyncMock() + mock_conn.cursor = MagicMock(return_value=mock_cursor) + mock_conn.__aenter__ = AsyncMock(return_value=mock_conn) + mock_conn.__aexit__ = AsyncMock(return_value=False) + + mock_pool = MagicMock() + mock_pool.connection = MagicMock(return_value=mock_conn) + + await _ensure_db_schema(mock_pool) + + all_sql = " ".join(executed_sqls) + assert "IF NOT EXISTS" in all_sql.upper() + + +# --------------------------------------------------------------------------- +# Lifespan: PostgresStagingStore wired in +# --------------------------------------------------------------------------- + +class TestLifespanUsesPostgresStagingStore: + def test_lifespan_creates_postgres_staging_store(self) -> None: + """When PostgresStagingStore is imported in main, it is used in lifespan.""" + from release_agent.main import _create_staging_store + from release_agent.graph.postgres_staging_store import PostgresStagingStore + + mock_pool = MagicMock() + result = _create_staging_store(pool=mock_pool) + assert isinstance(result, PostgresStagingStore) + + def test_create_staging_store_without_pool_falls_back_to_json(self) -> None: + """Without a pool, falls back to JsonFileStagingStore for local dev.""" + from release_agent.main import _create_staging_store + from release_agent.graph.dependencies import JsonFileStagingStore + + result = _create_staging_store(pool=None) + assert isinstance(result, JsonFileStagingStore) diff --git a/tests/test_models.py b/tests/test_models.py new file mode 100644 index 0000000..a330267 --- /dev/null +++ b/tests/test_models.py @@ -0,0 +1,635 @@ +"""Tests for Pydantic models. Written FIRST (TDD RED phase).""" + +from datetime import date, datetime + +import pytest +from pydantic import ValidationError + +from release_agent.models.jira import JiraIssue, JiraTransition +from release_agent.models.pipeline import PipelineInfo, ReleasePipelineStage +from release_agent.models.pr import PRInfo +from release_agent.models.release import ArchivedRelease, StagingRelease +from release_agent.models.review import ReviewIssue, ReviewResult +from release_agent.models.ticket import TicketEntry +from release_agent.models.webhook import WebhookPayload, WebhookRepository, WebhookResource + +# --------------------------------------------------------------------------- +# PRInfo tests +# --------------------------------------------------------------------------- + +class TestPRInfo: + """Tests for PRInfo model.""" + + def _make_pr(self, **kwargs) -> PRInfo: + defaults = { + "pr_id": "PR-1", + "pr_url": "https://dev.azure.com/org/project/_git/repo/pullrequest/1", + "repo_name": "my-repo", + "branch": "feature/ALLPOST-100_add-feature", + "pr_title": "Add new feature", + "pr_status": "active", + } + defaults.update(kwargs) + return PRInfo(**defaults) + + def test_ticket_id_extracted_from_branch(self) -> None: + pr = self._make_pr(branch="feature/ALLPOST-100_add-feature") + assert pr.ticket_id == "ALLPOST-100" + assert pr.has_ticket is True + + def test_branch_without_ticket(self) -> None: + pr = self._make_pr(branch="chore/update-dependencies") + assert pr.ticket_id is None + assert pr.has_ticket is False + + def test_main_branch_no_ticket(self) -> None: + pr = self._make_pr(branch="main") + assert pr.ticket_id is None + assert pr.has_ticket is False + + def test_refs_heads_branch_parsed(self) -> None: + pr = self._make_pr(branch="refs/heads/fix/BILL-42_fix-bug") + assert pr.ticket_id == "BILL-42" + assert pr.has_ticket is True + + def test_pr_status_active(self) -> None: + pr = self._make_pr(pr_status="active") + assert pr.pr_status == "active" + + def test_pr_status_completed(self) -> None: + pr = self._make_pr(pr_status="completed") + assert pr.pr_status == "completed" + + def test_pr_status_abandoned(self) -> None: + pr = self._make_pr(pr_status="abandoned") + assert pr.pr_status == "abandoned" + + def test_invalid_pr_status_raises(self) -> None: + with pytest.raises(ValidationError): + self._make_pr(pr_status="unknown") + + def test_model_is_frozen(self) -> None: + pr = self._make_pr() + with pytest.raises(ValidationError): + pr.pr_id = "modified" # type: ignore[misc] + + def test_pr_url_is_valid_url(self) -> None: + pr = self._make_pr() + # HttpUrl should have been validated + assert "dev.azure.com" in str(pr.pr_url) + + def test_invalid_url_raises(self) -> None: + with pytest.raises(ValidationError): + self._make_pr(pr_url="not-a-url") + + +# --------------------------------------------------------------------------- +# TicketEntry tests +# --------------------------------------------------------------------------- + +class TestTicketEntry: + """Tests for TicketEntry model.""" + + def _make_ticket(self, **kwargs) -> TicketEntry: + defaults = { + "id": "ALLPOST-4229", + "summary": "Fix review bug", + "pr_id": "PR-42", + "pr_url": "https://dev.azure.com/org/project/_git/repo/pullrequest/42", + "pr_title": "Fix review", + "branch": "bug/ALLPOST-4229_fix-review", + "merged_at": date(2024, 1, 15), + } + defaults.update(kwargs) + return TicketEntry(**defaults) + + def test_valid_ticket_entry(self) -> None: + ticket = self._make_ticket() + assert ticket.id == "ALLPOST-4229" + assert ticket.summary == "Fix review bug" + + def test_valid_jira_id_format(self) -> None: + ticket = self._make_ticket(id="BILL-42") + assert ticket.id == "BILL-42" + + def test_invalid_id_lowercase_raises(self) -> None: + with pytest.raises(ValidationError): + self._make_ticket(id="allpost-4229") + + def test_invalid_id_no_number_raises(self) -> None: + with pytest.raises(ValidationError): + self._make_ticket(id="ALLPOST-") + + def test_invalid_id_no_dash_raises(self) -> None: + with pytest.raises(ValidationError): + self._make_ticket(id="ALLPOST4229") + + def test_invalid_id_starts_with_number_raises(self) -> None: + with pytest.raises(ValidationError): + self._make_ticket(id="4ALLPOST-4229") + + def test_merged_at_is_date(self) -> None: + ticket = self._make_ticket() + assert isinstance(ticket.merged_at, date) + + def test_model_is_frozen(self) -> None: + ticket = self._make_ticket() + with pytest.raises(ValidationError): + ticket.id = "OTHER-1" # type: ignore[misc] + + def test_minimum_valid_id(self) -> None: + # Single uppercase letter prefix followed by dash and digits + ticket = self._make_ticket(id="A-1") + assert ticket.id == "A-1" + + def test_numeric_in_project_key(self) -> None: + ticket = self._make_ticket(id="AB2-100") + assert ticket.id == "AB2-100" + + +# --------------------------------------------------------------------------- +# StagingRelease tests +# --------------------------------------------------------------------------- + +class TestStagingRelease: + """Tests for StagingRelease model.""" + + def _make_ticket(self, ticket_id: str = "ALLPOST-1") -> TicketEntry: + return TicketEntry( + id=ticket_id, + summary="Some ticket", + pr_id="PR-1", + pr_url="https://dev.azure.com/org/project/_git/repo/pullrequest/1", + pr_title="Some PR", + branch=f"feature/{ticket_id}_some-feature", + merged_at=date(2024, 1, 15), + ) + + def _make_release(self, **kwargs) -> StagingRelease: + defaults = { + "version": "v1.0.0", + "repo": "my-repo", + "started_at": date(2024, 1, 1), + "tickets": [], + } + defaults.update(kwargs) + return StagingRelease(**defaults) + + def test_valid_release(self) -> None: + release = self._make_release() + assert release.version == "v1.0.0" + + def test_version_must_match_pattern(self) -> None: + with pytest.raises(ValidationError): + self._make_release(version="1.0.0") + + def test_version_missing_patch_raises(self) -> None: + with pytest.raises(ValidationError): + self._make_release(version="v1.0") + + def test_version_extra_segments_raises(self) -> None: + with pytest.raises(ValidationError): + self._make_release(version="v1.0.0.1") + + def test_version_letters_in_numbers_raises(self) -> None: + with pytest.raises(ValidationError): + self._make_release(version="v1.a.0") + + def test_add_ticket_returns_new_instance(self) -> None: + release = self._make_release() + ticket = self._make_ticket("ALLPOST-1") + new_release = release.add_ticket(ticket) + assert new_release is not release + + def test_add_ticket_immutability(self) -> None: + release = self._make_release() + ticket = self._make_ticket("ALLPOST-1") + new_release = release.add_ticket(ticket) + assert len(release.tickets) == 0 + assert len(new_release.tickets) == 1 + + def test_add_ticket_contains_ticket(self) -> None: + release = self._make_release() + ticket = self._make_ticket("ALLPOST-1") + new_release = release.add_ticket(ticket) + assert ticket in new_release.tickets + + def test_has_ticket_true(self) -> None: + ticket = self._make_ticket("ALLPOST-1") + release = self._make_release(tickets=[ticket]) + assert release.has_ticket("ALLPOST-1") is True + + def test_has_ticket_false(self) -> None: + release = self._make_release() + assert release.has_ticket("ALLPOST-99") is False + + def test_has_ticket_after_add(self) -> None: + release = self._make_release() + ticket = self._make_ticket("ALLPOST-5") + new_release = release.add_ticket(ticket) + assert new_release.has_ticket("ALLPOST-5") is True + + def test_model_is_frozen(self) -> None: + release = self._make_release() + with pytest.raises(ValidationError): + release.version = "v2.0.0" # type: ignore[misc] + + def test_multiple_tickets(self) -> None: + t1 = self._make_ticket("ALLPOST-1") + t2 = self._make_ticket("ALLPOST-2") + release = self._make_release(tickets=[t1, t2]) + assert len(release.tickets) == 2 + + +# --------------------------------------------------------------------------- +# ArchivedRelease tests +# --------------------------------------------------------------------------- + +class TestArchivedRelease: + """Tests for ArchivedRelease model.""" + + def _make_archived(self, **kwargs) -> ArchivedRelease: + defaults = { + "version": "v1.0.0", + "repo": "my-repo", + "started_at": date(2024, 1, 1), + "tickets": [], + "released_at": date(2024, 1, 10), + } + defaults.update(kwargs) + return ArchivedRelease(**defaults) + + def test_valid_archived_release(self) -> None: + release = self._make_archived() + assert release.released_at == date(2024, 1, 10) + + def test_released_at_same_as_started_at_is_valid(self) -> None: + release = self._make_archived(started_at=date(2024, 1, 1), released_at=date(2024, 1, 1)) + assert release.released_at == release.started_at + + def test_released_at_before_started_at_raises(self) -> None: + with pytest.raises(ValidationError): + self._make_archived( + started_at=date(2024, 1, 10), + released_at=date(2024, 1, 1), + ) + + def test_model_is_frozen(self) -> None: + release = self._make_archived() + with pytest.raises(ValidationError): + release.released_at = date(2024, 12, 31) # type: ignore[misc] + + def test_inherits_version_validation(self) -> None: + with pytest.raises(ValidationError): + self._make_archived(version="1.0.0") + + +# --------------------------------------------------------------------------- +# PipelineInfo tests +# --------------------------------------------------------------------------- + +class TestPipelineInfo: + """Tests for PipelineInfo model.""" + + def test_valid_pipeline_info(self) -> None: + pipeline = PipelineInfo(id=42, name="Release Pipeline", repo="my-repo") + assert pipeline.id == 42 + assert pipeline.name == "Release Pipeline" + assert pipeline.repo == "my-repo" + + def test_model_is_frozen(self) -> None: + pipeline = PipelineInfo(id=1, name="Test", repo="repo") + with pytest.raises(ValidationError): + pipeline.id = 2 # type: ignore[misc] + + +# --------------------------------------------------------------------------- +# ReleasePipelineStage tests +# --------------------------------------------------------------------------- + +class TestReleasePipelineStage: + """Tests for ReleasePipelineStage model.""" + + def test_valid_stage_without_approval(self) -> None: + stage = ReleasePipelineStage( + name="Build", rank=0, requires_approval=False, approval_id=None + ) + assert stage.name == "Build" + assert stage.rank == 0 + + def test_valid_stage_with_approval(self) -> None: + stage = ReleasePipelineStage( + name="Production", rank=2, requires_approval=True, approval_id="approval-uuid-123" + ) + assert stage.requires_approval is True + assert stage.approval_id == "approval-uuid-123" + + def test_negative_rank_raises(self) -> None: + with pytest.raises(ValidationError): + ReleasePipelineStage( + name="Bad", rank=-1, requires_approval=False, approval_id=None + ) + + def test_requires_approval_false_with_approval_id_raises(self) -> None: + with pytest.raises(ValidationError): + ReleasePipelineStage( + name="Bad", rank=0, requires_approval=False, approval_id="some-id" + ) + + def test_requires_approval_true_without_approval_id_is_valid(self) -> None: + stage = ReleasePipelineStage( + name="Production", rank=2, requires_approval=True, approval_id=None + ) + assert stage.requires_approval is True + assert stage.approval_id is None + + def test_model_is_frozen(self) -> None: + stage = ReleasePipelineStage(name="Build", rank=0, requires_approval=False, approval_id=None) + with pytest.raises(ValidationError): + stage.name = "Changed" # type: ignore[misc] + + +# --------------------------------------------------------------------------- +# WebhookPayload tests +# --------------------------------------------------------------------------- + +class TestWebhookPayload: + """Tests for WebhookPayload and nested models.""" + + def _make_payload(self, **kwargs) -> WebhookPayload: + defaults = { + "subscription_id": "sub-123", + "event_type": "git.pullrequest.merged", + "resource": { + "repository": { + "id": "repo-uuid-456", + "name": "my-repo", + "web_url": "https://dev.azure.com/org/project/_git/my-repo", + }, + "pull_request_id": 42, + "title": "Fix the bug", + "source_ref_name": "refs/heads/bug/ALLPOST-4229_fix-review", + "target_ref_name": "refs/heads/main", + "status": "completed", + "closed_date": None, + }, + } + defaults.update(kwargs) + return WebhookPayload(**defaults) + + def test_valid_payload(self) -> None: + payload = self._make_payload() + assert payload.subscription_id == "sub-123" + assert payload.event_type == "git.pullrequest.merged" + + def test_resource_parsed(self) -> None: + payload = self._make_payload() + assert isinstance(payload.resource, WebhookResource) + assert payload.resource.pull_request_id == 42 + assert payload.resource.title == "Fix the bug" + + def test_repository_parsed(self) -> None: + payload = self._make_payload() + repo = payload.resource.repository + assert isinstance(repo, WebhookRepository) + assert repo.name == "my-repo" + + def test_repository_web_url(self) -> None: + payload = self._make_payload() + assert "dev.azure.com" in str(payload.resource.repository.web_url) + + def test_closed_date_none(self) -> None: + payload = self._make_payload() + assert payload.resource.closed_date is None + + def test_closed_date_populated(self) -> None: + payload_data = { + "subscription_id": "sub-123", + "event_type": "git.pullrequest.merged", + "resource": { + "repository": { + "id": "repo-uuid-456", + "name": "my-repo", + "web_url": "https://dev.azure.com/org/project/_git/my-repo", + }, + "pull_request_id": 42, + "title": "Fix the bug", + "source_ref_name": "refs/heads/bug/ALLPOST-4229_fix-review", + "target_ref_name": "refs/heads/main", + "status": "completed", + "closed_date": "2024-01-15T10:30:00Z", + }, + } + payload = WebhookPayload(**payload_data) + assert payload.resource.closed_date is not None + assert isinstance(payload.resource.closed_date, datetime) + + def test_model_is_frozen(self) -> None: + payload = self._make_payload() + with pytest.raises(ValidationError): + payload.subscription_id = "changed" # type: ignore[misc] + + def test_source_ref_name_preserved(self) -> None: + payload = self._make_payload() + assert payload.resource.source_ref_name == "refs/heads/bug/ALLPOST-4229_fix-review" + + +# --------------------------------------------------------------------------- +# ReviewIssue tests +# --------------------------------------------------------------------------- + +class TestReviewIssue: + """Tests for ReviewIssue model.""" + + def _make_issue(self, **kwargs) -> ReviewIssue: + defaults = { + "severity": "warning", + "description": "Variable name is not descriptive", + } + defaults.update(kwargs) + return ReviewIssue(**defaults) + + def test_valid_warning_issue(self) -> None: + issue = self._make_issue(severity="warning", description="Unclear variable") + assert issue.severity == "warning" + assert issue.description == "Unclear variable" + + def test_valid_error_issue(self) -> None: + issue = self._make_issue(severity="error", description="Null pointer risk") + assert issue.severity == "error" + + def test_valid_info_issue(self) -> None: + issue = self._make_issue(severity="info", description="Minor style note") + assert issue.severity == "info" + + def test_valid_blocker_issue(self) -> None: + issue = self._make_issue(severity="blocker", description="Security vulnerability") + assert issue.severity == "blocker" + + def test_invalid_severity_raises(self) -> None: + with pytest.raises(ValidationError): + self._make_issue(severity="critical") + + def test_file_path_optional_none_by_default(self) -> None: + issue = self._make_issue() + assert issue.file_path is None + + def test_file_path_can_be_set(self) -> None: + issue = self._make_issue(file_path="src/foo.py") + assert issue.file_path == "src/foo.py" + + def test_suggestion_optional_none_by_default(self) -> None: + issue = self._make_issue() + assert issue.suggestion is None + + def test_suggestion_can_be_set(self) -> None: + issue = self._make_issue(suggestion="Rename to `user_count`") + assert issue.suggestion == "Rename to `user_count`" + + def test_model_is_frozen(self) -> None: + issue = self._make_issue() + with pytest.raises(ValidationError): + issue.severity = "error" # type: ignore[misc] + + def test_description_required(self) -> None: + with pytest.raises(ValidationError): + ReviewIssue(severity="warning") # type: ignore[call-arg] + + +# --------------------------------------------------------------------------- +# ReviewResult tests +# --------------------------------------------------------------------------- + +class TestReviewResult: + """Tests for ReviewResult model.""" + + def _make_blocker_issue(self) -> ReviewIssue: + return ReviewIssue(severity="blocker", description="Must fix this") + + def _make_warning_issue(self) -> ReviewIssue: + return ReviewIssue(severity="warning", description="Minor issue") + + def _make_result(self, **kwargs) -> ReviewResult: + defaults = { + "verdict": "approve", + "summary": "Looks good overall", + "issues": [], + } + defaults.update(kwargs) + return ReviewResult(**defaults) + + def test_valid_approve_verdict(self) -> None: + result = self._make_result(verdict="approve") + assert result.verdict == "approve" + + def test_valid_request_changes_verdict(self) -> None: + result = self._make_result(verdict="request_changes") + assert result.verdict == "request_changes" + + def test_invalid_verdict_raises(self) -> None: + with pytest.raises(ValidationError): + self._make_result(verdict="reject") + + def test_summary_stored(self) -> None: + result = self._make_result(summary="Great PR") + assert result.summary == "Great PR" + + def test_issues_empty_by_default(self) -> None: + result = self._make_result() + assert len(result.issues) == 0 + + def test_has_blockers_false_with_no_issues(self) -> None: + result = self._make_result(issues=[]) + assert result.has_blockers is False + + def test_has_blockers_false_with_only_warnings(self) -> None: + result = self._make_result(issues=[self._make_warning_issue()]) + assert result.has_blockers is False + + def test_has_blockers_true_with_blocker_issue(self) -> None: + result = self._make_result(issues=[self._make_blocker_issue()]) + assert result.has_blockers is True + + def test_has_blockers_true_mixed_issues(self) -> None: + result = self._make_result( + issues=[self._make_warning_issue(), self._make_blocker_issue()] + ) + assert result.has_blockers is True + + def test_model_is_frozen(self) -> None: + result = self._make_result() + with pytest.raises(ValidationError): + result.verdict = "request_changes" # type: ignore[misc] + + def test_multiple_issues_stored(self) -> None: + issues = [self._make_warning_issue(), self._make_blocker_issue()] + result = self._make_result(issues=issues) + assert len(result.issues) == 2 + + def test_has_blockers_is_computed(self) -> None: + # Verify has_blockers cannot be set directly (it's computed) + result = self._make_result(issues=[self._make_blocker_issue()]) + assert result.has_blockers is True + + +# --------------------------------------------------------------------------- +# JiraTransition tests +# --------------------------------------------------------------------------- + +class TestJiraTransition: + """Tests for JiraTransition model.""" + + def test_valid_transition(self) -> None: + transition = JiraTransition(id="11", name="To Do") + assert transition.id == "11" + assert transition.name == "To Do" + + def test_model_is_frozen(self) -> None: + transition = JiraTransition(id="11", name="To Do") + with pytest.raises(ValidationError): + transition.id = "22" # type: ignore[misc] + + def test_id_required(self) -> None: + with pytest.raises(ValidationError): + JiraTransition(name="To Do") # type: ignore[call-arg] + + def test_name_required(self) -> None: + with pytest.raises(ValidationError): + JiraTransition(id="11") # type: ignore[call-arg] + + +# --------------------------------------------------------------------------- +# JiraIssue tests +# --------------------------------------------------------------------------- + +class TestJiraIssue: + """Tests for JiraIssue model.""" + + def test_valid_issue(self) -> None: + issue = JiraIssue(key="ALLPOST-100", summary="Fix the bug", status="In Progress") + assert issue.key == "ALLPOST-100" + assert issue.summary == "Fix the bug" + assert issue.status == "In Progress" + + def test_model_is_frozen(self) -> None: + issue = JiraIssue(key="ALLPOST-100", summary="Fix the bug", status="In Progress") + with pytest.raises(ValidationError): + issue.key = "ALLPOST-200" # type: ignore[misc] + + def test_key_required(self) -> None: + with pytest.raises(ValidationError): + JiraIssue(summary="Fix the bug", status="In Progress") # type: ignore[call-arg] + + def test_summary_required(self) -> None: + with pytest.raises(ValidationError): + JiraIssue(key="ALLPOST-100", status="In Progress") # type: ignore[call-arg] + + def test_status_required(self) -> None: + with pytest.raises(ValidationError): + JiraIssue(key="ALLPOST-100", summary="Fix the bug") # type: ignore[call-arg] + + def test_various_statuses(self) -> None: + statuses = ["To Do", "In Progress", "Done", "Released"] + for status in statuses: + issue = JiraIssue(key="ALLPOST-1", summary="Test", status=status) + assert issue.status == status diff --git a/tests/test_models_build.py b/tests/test_models_build.py new file mode 100644 index 0000000..526c9fa --- /dev/null +++ b/tests/test_models_build.py @@ -0,0 +1,148 @@ +"""Tests for models/build.py — BuildStatus and ApprovalRecord. + +Written FIRST (TDD RED phase). +""" + +import pytest +from dataclasses import FrozenInstanceError + +from release_agent.models.build import ApprovalRecord, BuildStatus + + +# --------------------------------------------------------------------------- +# BuildStatus tests +# --------------------------------------------------------------------------- + +class TestBuildStatus: + """Tests for BuildStatus frozen dataclass.""" + + def test_can_be_created_with_all_fields(self) -> None: + bs = BuildStatus( + status="completed", + result="succeeded", + build_url="https://dev.azure.com/org/proj/_build/results?buildId=42", + ) + assert bs.status == "completed" + assert bs.result == "succeeded" + assert bs.build_url == "https://dev.azure.com/org/proj/_build/results?buildId=42" + + def test_result_can_be_none(self) -> None: + bs = BuildStatus( + status="inProgress", + result=None, + build_url="https://dev.azure.com/org/proj/_build/results?buildId=99", + ) + assert bs.result is None + + def test_build_url_can_be_none(self) -> None: + bs = BuildStatus(status="notStarted", result=None, build_url=None) + assert bs.build_url is None + + def test_is_frozen_status(self) -> None: + bs = BuildStatus(status="completed", result="succeeded", build_url=None) + with pytest.raises((FrozenInstanceError, AttributeError)): + bs.status = "inProgress" # type: ignore[misc] + + def test_is_frozen_result(self) -> None: + bs = BuildStatus(status="completed", result="succeeded", build_url=None) + with pytest.raises((FrozenInstanceError, AttributeError)): + bs.result = "failed" # type: ignore[misc] + + def test_equality(self) -> None: + a = BuildStatus(status="completed", result="succeeded", build_url="http://x") + b = BuildStatus(status="completed", result="succeeded", build_url="http://x") + assert a == b + + def test_inequality_on_status(self) -> None: + a = BuildStatus(status="completed", result="succeeded", build_url=None) + b = BuildStatus(status="inProgress", result="succeeded", build_url=None) + assert a != b + + def test_inequality_on_result(self) -> None: + a = BuildStatus(status="completed", result="succeeded", build_url=None) + b = BuildStatus(status="completed", result="failed", build_url=None) + assert a != b + + def test_repr_contains_status(self) -> None: + bs = BuildStatus(status="completed", result="succeeded", build_url=None) + assert "completed" in repr(bs) + + def test_status_values_typical(self) -> None: + for s in ("notStarted", "inProgress", "completed", "cancelling"): + bs = BuildStatus(status=s, result=None, build_url=None) + assert bs.status == s + + def test_result_values_typical(self) -> None: + for r in ("succeeded", "failed", "canceled", "partiallySucceeded"): + bs = BuildStatus(status="completed", result=r, build_url=None) + assert bs.result == r + + +# --------------------------------------------------------------------------- +# ApprovalRecord tests +# --------------------------------------------------------------------------- + +class TestApprovalRecord: + """Tests for ApprovalRecord frozen dataclass.""" + + def test_can_be_created_with_all_fields(self) -> None: + ar = ApprovalRecord( + approval_id="approval-abc-123", + stage_name="Sandbox", + status="pending", + release_id=42, + ) + assert ar.approval_id == "approval-abc-123" + assert ar.stage_name == "Sandbox" + assert ar.status == "pending" + assert ar.release_id == 42 + + def test_is_frozen_approval_id(self) -> None: + ar = ApprovalRecord( + approval_id="abc", + stage_name="Sandbox", + status="pending", + release_id=1, + ) + with pytest.raises((FrozenInstanceError, AttributeError)): + ar.approval_id = "xyz" # type: ignore[misc] + + def test_is_frozen_stage_name(self) -> None: + ar = ApprovalRecord( + approval_id="abc", + stage_name="Sandbox", + status="pending", + release_id=1, + ) + with pytest.raises((FrozenInstanceError, AttributeError)): + ar.stage_name = "Production" # type: ignore[misc] + + def test_equality(self) -> None: + a = ApprovalRecord(approval_id="x", stage_name="S", status="pending", release_id=1) + b = ApprovalRecord(approval_id="x", stage_name="S", status="pending", release_id=1) + assert a == b + + def test_inequality_on_approval_id(self) -> None: + a = ApprovalRecord(approval_id="x", stage_name="S", status="pending", release_id=1) + b = ApprovalRecord(approval_id="y", stage_name="S", status="pending", release_id=1) + assert a != b + + def test_status_pending(self) -> None: + ar = ApprovalRecord(approval_id="a", stage_name="Stage", status="pending", release_id=10) + assert ar.status == "pending" + + def test_status_approved(self) -> None: + ar = ApprovalRecord(approval_id="a", stage_name="Stage", status="approved", release_id=10) + assert ar.status == "approved" + + def test_status_rejected(self) -> None: + ar = ApprovalRecord(approval_id="a", stage_name="Stage", status="rejected", release_id=10) + assert ar.status == "rejected" + + def test_repr_contains_stage_name(self) -> None: + ar = ApprovalRecord(approval_id="a", stage_name="Production", status="pending", release_id=5) + assert "Production" in repr(ar) + + def test_release_id_is_int(self) -> None: + ar = ApprovalRecord(approval_id="a", stage_name="S", status="pending", release_id=999) + assert isinstance(ar.release_id, int) diff --git a/tests/test_state.py b/tests/test_state.py new file mode 100644 index 0000000..d8c17c6 --- /dev/null +++ b/tests/test_state.py @@ -0,0 +1,241 @@ +"""Tests for LangGraph state module. Written FIRST (TDD RED phase).""" + +import json + +from release_agent.state import ReleaseState, add_errors, add_messages + +# --------------------------------------------------------------------------- +# ReleaseState tests +# --------------------------------------------------------------------------- + +class TestReleaseState: + """Tests for ReleaseState TypedDict.""" + + def test_empty_state_is_valid(self) -> None: + # total=False means all fields are optional + state: ReleaseState = {} + assert state == {} + + def test_partial_state_with_repo(self) -> None: + state: ReleaseState = {"repo_name": "my-repo"} + assert state["repo_name"] == "my-repo" + + def test_partial_state_with_messages(self) -> None: + state: ReleaseState = {"messages": ["Hello"]} + assert state["messages"] == ["Hello"] + + def test_partial_state_with_errors(self) -> None: + state: ReleaseState = {"errors": ["Something went wrong"]} + assert state["errors"] == ["Something went wrong"] + + def test_state_with_pr_id(self) -> None: + state: ReleaseState = {"pr_id": "PR-42"} + assert state["pr_id"] == "PR-42" + + def test_state_with_ticket_id(self) -> None: + state: ReleaseState = {"ticket_id": "ALLPOST-100"} + assert state["ticket_id"] == "ALLPOST-100" + + def test_state_with_version(self) -> None: + state: ReleaseState = {"version": "v1.0.1"} + assert state["version"] == "v1.0.1" + + # Phase 3 new fields + def test_state_with_webhook_payload(self) -> None: + state: ReleaseState = {"webhook_payload": {"event_type": "git.pullrequest.merged"}} + assert state["webhook_payload"]["event_type"] == "git.pullrequest.merged" + + def test_state_with_pr_info(self) -> None: + state: ReleaseState = {"pr_info": {"pr_id": "42", "repo_name": "my-repo"}} + assert state["pr_info"]["repo_name"] == "my-repo" + + def test_state_with_pr_diff(self) -> None: + state: ReleaseState = {"pr_diff": "edit: src/main.py"} + assert state["pr_diff"] == "edit: src/main.py" + + def test_state_with_last_merge_source_commit(self) -> None: + state: ReleaseState = {"last_merge_source_commit": "abc123"} + assert state["last_merge_source_commit"] == "abc123" + + def test_state_with_ticket_summary(self) -> None: + state: ReleaseState = {"ticket_summary": "Fix login bug"} + assert state["ticket_summary"] == "Fix login bug" + + def test_state_with_has_ticket(self) -> None: + state: ReleaseState = {"has_ticket": True} + assert state["has_ticket"] is True + + def test_state_with_review_result(self) -> None: + state: ReleaseState = {"review_result": {"verdict": "approve", "summary": "LGTM"}} + assert state["review_result"]["verdict"] == "approve" + + def test_state_with_review_approved(self) -> None: + state: ReleaseState = {"review_approved": True} + assert state["review_approved"] is True + + def test_state_with_staging(self) -> None: + state: ReleaseState = {"staging": {"version": "v1.0.0", "tickets": []}} + assert state["staging"]["version"] == "v1.0.0" + + def test_state_with_pr_already_merged(self) -> None: + state: ReleaseState = {"pr_already_merged": False} + assert state["pr_already_merged"] is False + + def test_state_with_release_pr_id(self) -> None: + state: ReleaseState = {"release_pr_id": "123"} + assert state["release_pr_id"] == "123" + + def test_state_with_release_pr_commit(self) -> None: + state: ReleaseState = {"release_pr_commit": "deadbeef"} + assert state["release_pr_commit"] == "deadbeef" + + def test_state_with_pipelines(self) -> None: + state: ReleaseState = {"pipelines": [{"id": 1, "name": "build"}]} + assert len(state["pipelines"]) == 1 + + def test_state_with_triggered_builds(self) -> None: + state: ReleaseState = {"triggered_builds": [{"id": 99}]} + assert state["triggered_builds"][0]["id"] == 99 + + def test_state_with_pending_approvals(self) -> None: + state: ReleaseState = {"pending_approvals": [{"approval_id": "aaa"}]} + assert state["pending_approvals"][0]["approval_id"] == "aaa" + + def test_state_with_continue_to_release(self) -> None: + state: ReleaseState = {"continue_to_release": True} + assert state["continue_to_release"] is True + + # Phase 5: CI/CD and approval fields + def test_state_with_ci_build_id(self) -> None: + state: ReleaseState = {"ci_build_id": 12345} + assert state["ci_build_id"] == 12345 + + def test_state_with_ci_build_status(self) -> None: + state: ReleaseState = {"ci_build_status": "inProgress"} + assert state["ci_build_status"] == "inProgress" + + def test_state_with_ci_build_result(self) -> None: + state: ReleaseState = {"ci_build_result": "succeeded"} + assert state["ci_build_result"] == "succeeded" + + def test_state_with_ci_build_url(self) -> None: + state: ReleaseState = {"ci_build_url": "https://dev.azure.com/org/proj/_build/results?buildId=99"} + assert "buildId=99" in state["ci_build_url"] + + def test_state_with_release_definition_id(self) -> None: + state: ReleaseState = {"release_definition_id": 7} + assert state["release_definition_id"] == 7 + + def test_state_with_release_id(self) -> None: + state: ReleaseState = {"release_id": 456} + assert state["release_id"] == 456 + + def test_state_with_current_stage(self) -> None: + state: ReleaseState = {"current_stage": "sandbox_pending"} + assert state["current_stage"] == "sandbox_pending" + + def test_state_with_approval_message_ts(self) -> None: + state: ReleaseState = {"approval_message_ts": "1234567890.123456"} + assert state["approval_message_ts"] == "1234567890.123456" + + def test_state_with_slack_message_ts(self) -> None: + state: ReleaseState = {"slack_message_ts": "9876543210.000001"} + assert state["slack_message_ts"] == "9876543210.000001" + + def test_state_json_serializable_empty(self) -> None: + state: ReleaseState = {} + serialized = json.dumps(state) + assert json.loads(serialized) == {} + + def test_state_json_serializable_with_strings(self) -> None: + state: ReleaseState = { + "repo_name": "my-repo", + "pr_id": "PR-1", + "ticket_id": "ALLPOST-1", + "version": "v1.0.0", + } + serialized = json.dumps(state) + loaded = json.loads(serialized) + assert loaded["repo_name"] == "my-repo" + assert loaded["pr_id"] == "PR-1" + + def test_state_json_serializable_with_lists(self) -> None: + state: ReleaseState = { + "messages": ["msg1", "msg2"], + "errors": ["err1"], + } + serialized = json.dumps(state) + loaded = json.loads(serialized) + assert loaded["messages"] == ["msg1", "msg2"] + assert loaded["errors"] == ["err1"] + + +# --------------------------------------------------------------------------- +# Reducer tests +# --------------------------------------------------------------------------- + +class TestAddMessages: + """Tests for add_messages reducer.""" + + def test_accumulates_to_empty(self) -> None: + result = add_messages([], ["Hello"]) + assert result == ["Hello"] + + def test_accumulates_to_existing(self) -> None: + result = add_messages(["Hello"], ["World"]) + assert result == ["Hello", "World"] + + def test_accumulates_multiple(self) -> None: + result = add_messages(["A", "B"], ["C", "D"]) + assert result == ["A", "B", "C", "D"] + + def test_existing_unchanged(self) -> None: + existing = ["Hello"] + add_messages(existing, ["World"]) + # Original should not be mutated + assert existing == ["Hello"] + + def test_empty_new_messages(self) -> None: + result = add_messages(["Hello"], []) + assert result == ["Hello"] + + def test_both_empty(self) -> None: + result = add_messages([], []) + assert result == [] + + def test_returns_new_list(self) -> None: + existing = ["Hello"] + new_msgs = ["World"] + result = add_messages(existing, new_msgs) + assert result is not existing + assert result is not new_msgs + + +class TestAddErrors: + """Tests for add_errors reducer.""" + + def test_accumulates_to_empty(self) -> None: + result = add_errors([], ["Error occurred"]) + assert result == ["Error occurred"] + + def test_accumulates_to_existing(self) -> None: + result = add_errors(["First error"], ["Second error"]) + assert result == ["First error", "Second error"] + + def test_existing_unchanged(self) -> None: + existing = ["First error"] + add_errors(existing, ["Second error"]) + assert existing == ["First error"] + + def test_empty_new_errors(self) -> None: + result = add_errors(["Existing"], []) + assert result == ["Existing"] + + def test_both_empty(self) -> None: + result = add_errors([], []) + assert result == [] + + def test_returns_new_list(self) -> None: + existing = ["Error"] + result = add_errors(existing, ["New error"]) + assert result is not existing diff --git a/tests/test_versioning.py b/tests/test_versioning.py new file mode 100644 index 0000000..1a443c3 --- /dev/null +++ b/tests/test_versioning.py @@ -0,0 +1,124 @@ +"""Tests for versioning module. Written FIRST (TDD RED phase).""" + +import pytest + +from release_agent.versioning import ( + calculate_next_version, + format_version, + parse_version, +) + + +class TestParseVersion: + """Tests for parse_version function.""" + + def test_parse_with_v_prefix(self) -> None: + assert parse_version("v1.2.3") == (1, 2, 3) + + def test_parse_without_v_prefix(self) -> None: + assert parse_version("1.2.3") == (1, 2, 3) + + def test_parse_zeros(self) -> None: + assert parse_version("v0.0.0") == (0, 0, 0) + + def test_parse_large_numbers(self) -> None: + assert parse_version("v10.20.300") == (10, 20, 300) + + def test_parse_returns_tuple_of_ints(self) -> None: + result = parse_version("v1.2.3") + assert isinstance(result, tuple) + assert len(result) == 3 + assert all(isinstance(x, int) for x in result) + + def test_parse_invalid_raises_value_error(self) -> None: + with pytest.raises(ValueError): + parse_version("invalid") + + def test_parse_partial_version_raises_value_error(self) -> None: + with pytest.raises(ValueError): + parse_version("v1.2") + + def test_parse_non_numeric_raises_value_error(self) -> None: + with pytest.raises(ValueError): + parse_version("va.b.c") + + +class TestFormatVersion: + """Tests for format_version function.""" + + def test_format_basic(self) -> None: + assert format_version(1, 0, 3) == "v1.0.3" + + def test_format_zeros(self) -> None: + assert format_version(0, 0, 0) == "v0.0.0" + + def test_format_large_numbers(self) -> None: + assert format_version(10, 20, 300) == "v10.20.300" + + def test_format_returns_string(self) -> None: + result = format_version(1, 2, 3) + assert isinstance(result, str) + + def test_format_starts_with_v(self) -> None: + result = format_version(1, 2, 3) + assert result.startswith("v") + + +class TestCalculateNextVersion: + """Tests for calculate_next_version function.""" + + def test_empty_list_returns_v1_0_0(self) -> None: + assert calculate_next_version("my-repo", []) == "v1.0.0" + + def test_single_version_increments_patch(self) -> None: + assert calculate_next_version("my-repo", ["v1.0.0"]) == "v1.0.1" + + def test_multiple_versions_uses_highest(self) -> None: + assert calculate_next_version("my-repo", ["v1.0.3", "v1.0.1"]) == "v1.0.4" + + def test_different_major_versions(self) -> None: + assert calculate_next_version("my-repo", ["v2.1.0", "v1.9.9"]) == "v2.1.1" + + def test_skips_malformed_versions(self) -> None: + assert calculate_next_version("my-repo", ["invalid", "v1.0.0"]) == "v1.0.1" + + def test_all_malformed_versions_returns_v1_0_0(self) -> None: + assert calculate_next_version("my-repo", ["invalid", "bad", "nope"]) == "v1.0.0" + + def test_repo_name_does_not_affect_result(self) -> None: + result_a = calculate_next_version("repo-a", ["v1.0.0"]) + result_b = calculate_next_version("repo-b", ["v1.0.0"]) + assert result_a == result_b + + def test_versions_out_of_order(self) -> None: + assert calculate_next_version("my-repo", ["v1.0.1", "v1.0.3", "v1.0.2"]) == "v1.0.4" + + def test_patch_overflow_does_not_occur(self) -> None: + # Just increments patch - no overflow logic required + result = calculate_next_version("my-repo", ["v1.0.99"]) + assert result == "v1.0.100" + + def test_versions_without_v_prefix_skipped(self) -> None: + # Versions without 'v' prefix are treated as malformed per spec + result = calculate_next_version("my-repo", ["1.0.0", "v2.0.0"]) + assert result == "v2.0.1" + + def test_result_format_starts_with_v(self) -> None: + result = calculate_next_version("my-repo", ["v1.0.0"]) + assert result.startswith("v") + + def test_result_has_three_parts(self) -> None: + result = calculate_next_version("my-repo", ["v1.0.0"]) + parts = result[1:].split(".") + assert len(parts) == 3 + assert all(p.isdigit() for p in parts) + + def test_v_prefix_with_nonnumeric_parts_skipped(self) -> None: + # Starts with 'v' but is malformed - should be skipped gracefully + result = calculate_next_version("my-repo", ["va.b.c", "v1.0.0"]) + assert result == "v1.0.1" + + def test_v_prefix_partial_version_skipped(self) -> None: + # Starts with 'v' but only has two parts - should be skipped + result = calculate_next_version("my-repo", ["v1.0", "v2.0.0"]) + assert result == "v2.0.1" diff --git a/tests/tools/__init__.py b/tests/tools/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/tools/fixtures/azdo_approve_release.json b/tests/tools/fixtures/azdo_approve_release.json new file mode 100644 index 0000000..49b3a2e --- /dev/null +++ b/tests/tools/fixtures/azdo_approve_release.json @@ -0,0 +1,9 @@ +{ + "id": "approval-uuid-123", + "status": "approved", + "approver": { + "id": "user-uuid-456", + "displayName": "Release Bot" + }, + "comments": "Approved via release agent" +} diff --git a/tests/tools/fixtures/azdo_build_status.json b/tests/tools/fixtures/azdo_build_status.json new file mode 100644 index 0000000..430c4ab --- /dev/null +++ b/tests/tools/fixtures/azdo_build_status.json @@ -0,0 +1,9 @@ +{ + "id": 1001, + "buildNumber": "20240115.1", + "status": "completed", + "result": "succeeded", + "queueTime": "2024-01-15T10:00:00Z", + "startTime": "2024-01-15T10:01:00Z", + "finishTime": "2024-01-15T10:10:00Z" +} diff --git a/tests/tools/fixtures/azdo_create_pr.json b/tests/tools/fixtures/azdo_create_pr.json new file mode 100644 index 0000000..e392e5c --- /dev/null +++ b/tests/tools/fixtures/azdo_create_pr.json @@ -0,0 +1,8 @@ +{ + "pullRequestId": 99, + "title": "Release v1.2.0", + "status": "active", + "sourceRefName": "refs/heads/release/v1.2.0", + "targetRefName": "refs/heads/main", + "url": "https://dev.azure.com/my-org/my-project/_apis/git/repositories/my-repo/pullRequests/99" +} diff --git a/tests/tools/fixtures/azdo_merge_pr.json b/tests/tools/fixtures/azdo_merge_pr.json new file mode 100644 index 0000000..c6e2a5e --- /dev/null +++ b/tests/tools/fixtures/azdo_merge_pr.json @@ -0,0 +1,8 @@ +{ + "pullRequestId": 42, + "status": "completed", + "title": "Fix the auth bug", + "completionOptions": { + "mergeStrategy": "squash" + } +} diff --git a/tests/tools/fixtures/azdo_pipelines.json b/tests/tools/fixtures/azdo_pipelines.json new file mode 100644 index 0000000..1a3a71b --- /dev/null +++ b/tests/tools/fixtures/azdo_pipelines.json @@ -0,0 +1,15 @@ +{ + "value": [ + { + "id": 10, + "name": "Release Pipeline", + "folder": "\\" + }, + { + "id": 20, + "name": "Build Pipeline", + "folder": "\\" + } + ], + "count": 2 +} diff --git a/tests/tools/fixtures/azdo_pr.json b/tests/tools/fixtures/azdo_pr.json new file mode 100644 index 0000000..c0e97ec --- /dev/null +++ b/tests/tools/fixtures/azdo_pr.json @@ -0,0 +1,16 @@ +{ + "pullRequestId": 42, + "title": "Fix the auth bug", + "status": "active", + "sourceRefName": "refs/heads/bug/ALLPOST-999_fix-auth", + "targetRefName": "refs/heads/main", + "url": "https://dev.azure.com/my-org/my-project/_apis/git/repositories/my-repo/pullRequests/42", + "repository": { + "id": "repo-uuid-123", + "name": "my-repo", + "remoteUrl": "https://dev.azure.com/my-org/my-project/_git/my-repo" + }, + "lastMergeSourceCommit": { + "commitId": "abc123def456" + } +} diff --git a/tests/tools/fixtures/azdo_pr_diff.txt b/tests/tools/fixtures/azdo_pr_diff.txt new file mode 100644 index 0000000..2ef8b50 --- /dev/null +++ b/tests/tools/fixtures/azdo_pr_diff.txt @@ -0,0 +1,11 @@ +diff --git a/src/auth.py b/src/auth.py +index 1234567..abcdefg 100644 +--- a/src/auth.py ++++ b/src/auth.py +@@ -10,6 +10,10 @@ class AuthService: + def authenticate(self, token: str) -> bool: +- return token == "hardcoded" ++ return self._validate_token(token) ++ ++ def _validate_token(self, token: str) -> bool: ++ return len(token) > 0 and token.startswith("Bearer ") diff --git a/tests/tools/fixtures/azdo_trigger_pipeline.json b/tests/tools/fixtures/azdo_trigger_pipeline.json new file mode 100644 index 0000000..ed9dfa3 --- /dev/null +++ b/tests/tools/fixtures/azdo_trigger_pipeline.json @@ -0,0 +1,11 @@ +{ + "id": 1001, + "buildNumber": "20240115.1", + "status": "notStarted", + "queueTime": "2024-01-15T10:00:00Z", + "definition": { + "id": 10, + "name": "Release Pipeline" + }, + "sourceBranch": "refs/heads/main" +} diff --git a/tests/tools/fixtures/jira_issue.json b/tests/tools/fixtures/jira_issue.json new file mode 100644 index 0000000..0e60a18 --- /dev/null +++ b/tests/tools/fixtures/jira_issue.json @@ -0,0 +1,10 @@ +{ + "id": "12345", + "key": "ALLPOST-100", + "fields": { + "summary": "Fix the authentication bug", + "status": { + "name": "In Progress" + } + } +} diff --git a/tests/tools/fixtures/jira_transitions.json b/tests/tools/fixtures/jira_transitions.json new file mode 100644 index 0000000..d75cc9d --- /dev/null +++ b/tests/tools/fixtures/jira_transitions.json @@ -0,0 +1,20 @@ +{ + "transitions": [ + { + "id": "11", + "name": "To Do" + }, + { + "id": "21", + "name": "In Progress" + }, + { + "id": "31", + "name": "Done" + }, + { + "id": "41", + "name": "Released" + } + ] +} diff --git a/tests/tools/test_azdo.py b/tests/tools/test_azdo.py new file mode 100644 index 0000000..4b914da --- /dev/null +++ b/tests/tools/test_azdo.py @@ -0,0 +1,819 @@ +"""Tests for AzDoClient. Written FIRST (TDD RED phase).""" + +import json +from pathlib import Path + +import httpx +import pytest + +from release_agent.exceptions import AuthenticationError, NotFoundError, ServiceError +from release_agent.models.build import ApprovalRecord, BuildStatus +from release_agent.models.pipeline import PipelineInfo +from release_agent.models.pr import PRInfo +from release_agent.tools.azdo import AzDoClient + +# --------------------------------------------------------------------------- +# Fixture helpers +# --------------------------------------------------------------------------- + +FIXTURES = Path(__file__).parent / "fixtures" + + +def _load_json(name: str) -> dict: + return json.loads((FIXTURES / name).read_text()) + + +def _load_text(name: str) -> str: + return (FIXTURES / name).read_text() + + +def _make_transport(routes: dict[tuple[str, str], tuple[int, bytes | str]]) -> httpx.MockTransport: + """Build a MockTransport that dispatches based on (method, url_substring).""" + + def handler(request: httpx.Request) -> httpx.Response: + url = str(request.url) + method = request.method + for (m, url_fragment), (status, body) in routes.items(): + if m == method and url_fragment in url: + content = body if isinstance(body, bytes) else body.encode() + return httpx.Response(status_code=status, content=content) + return httpx.Response(status_code=404, content=b'{"message": "Not found"}') + + return httpx.MockTransport(handler) + + +def _make_client(routes: dict) -> AzDoClient: + """Create an AzDoClient with mocked HTTP transport.""" + transport = _make_transport(routes) + http_client = httpx.AsyncClient(transport=transport) + vsrm_client = httpx.AsyncClient(transport=transport) + return AzDoClient( + base_url="https://dev.azure.com/my-org/my-project/_apis", + vsrm_base_url="https://vsrm.dev.azure.com/my-org/my-project/_apis", + pat="test-pat", + http_client=http_client, + vsrm_http_client=vsrm_client, + ) + + +# --------------------------------------------------------------------------- +# AzDoClient construction tests +# --------------------------------------------------------------------------- + +class TestAzDoClientConstruction: + """Tests for AzDoClient initialization.""" + + def test_can_be_instantiated_with_injected_clients(self) -> None: + transport = httpx.MockTransport(lambda r: httpx.Response(200, content=b"{}")) + http_client = httpx.AsyncClient(transport=transport) + vsrm_client = httpx.AsyncClient(transport=transport) + client = AzDoClient( + base_url="https://dev.azure.com/org/proj/_apis", + vsrm_base_url="https://vsrm.dev.azure.com/org/proj/_apis", + pat="my-pat", + http_client=http_client, + vsrm_http_client=vsrm_client, + ) + assert client is not None + + async def test_context_manager_closes_clients(self) -> None: + transport = httpx.MockTransport(lambda r: httpx.Response(200, content=b"{}")) + http_client = httpx.AsyncClient(transport=transport) + vsrm_client = httpx.AsyncClient(transport=transport) + async with AzDoClient( + base_url="https://dev.azure.com/org/proj/_apis", + vsrm_base_url="https://vsrm.dev.azure.com/org/proj/_apis", + pat="my-pat", + http_client=http_client, + vsrm_http_client=vsrm_client, + ) as client: + assert client is not None + # After context manager exits, clients should be closed + assert http_client.is_closed + assert vsrm_client.is_closed + + +# --------------------------------------------------------------------------- +# get_pr tests +# --------------------------------------------------------------------------- + +class TestGetPr: + """Tests for AzDoClient.get_pr.""" + + async def test_returns_pr_info(self) -> None: + pr_data = _load_json("azdo_pr.json") + routes = {("GET", "pullRequests/42"): (200, json.dumps(pr_data))} + client = _make_client(routes) + + result = await client.get_pr(42) + + assert isinstance(result, PRInfo) + assert result.pr_id == "42" + + async def test_pr_title_extracted(self) -> None: + pr_data = _load_json("azdo_pr.json") + routes = {("GET", "pullRequests/42"): (200, json.dumps(pr_data))} + client = _make_client(routes) + + result = await client.get_pr(42) + assert result.pr_title == "Fix the auth bug" + + async def test_pr_branch_extracted(self) -> None: + pr_data = _load_json("azdo_pr.json") + routes = {("GET", "pullRequests/42"): (200, json.dumps(pr_data))} + client = _make_client(routes) + + result = await client.get_pr(42) + assert "ALLPOST-999" in result.branch or "bug" in result.branch + + async def test_pr_status_extracted(self) -> None: + pr_data = _load_json("azdo_pr.json") + routes = {("GET", "pullRequests/42"): (200, json.dumps(pr_data))} + client = _make_client(routes) + + result = await client.get_pr(42) + assert result.pr_status == "active" + + async def test_404_raises_not_found(self) -> None: + routes = {("GET", "pullRequests/999"): (404, b'{"message": "PR not found"}')} + client = _make_client(routes) + + with pytest.raises(NotFoundError): + await client.get_pr(999) + + async def test_401_raises_authentication_error(self) -> None: + routes = {("GET", "pullRequests/42"): (401, b'{"message": "Unauthorized"}')} + client = _make_client(routes) + + with pytest.raises(AuthenticationError): + await client.get_pr(42) + + async def test_500_raises_service_error(self) -> None: + routes = {("GET", "pullRequests/42"): (500, b'{"message": "Internal error"}')} + client = _make_client(routes) + + with pytest.raises(ServiceError): + await client.get_pr(42) + + +# --------------------------------------------------------------------------- +# get_pr_diff tests +# --------------------------------------------------------------------------- + +class TestGetPrDiff: + """Tests for AzDoClient.get_pr_diff.""" + + async def test_returns_diff_string(self) -> None: + pr_data = _load_json("azdo_pr.json") + routes = { + ("GET", "pullRequests/42"): (200, json.dumps(pr_data)), + ("GET", "diffs"): (200, json.dumps({ + "changes": [ + { + "item": {"path": "/src/auth.py"}, + "changeType": "edit", + } + ] + })), + } + client = _make_client(routes) + + result = await client.get_pr_diff(42) + assert isinstance(result, str) + + async def test_diff_includes_file_paths(self) -> None: + pr_data = _load_json("azdo_pr.json") + diff_data = { + "changes": [ + {"item": {"path": "/src/auth.py"}, "changeType": "edit"}, + {"item": {"path": "/src/util.py"}, "changeType": "add"}, + ] + } + + def handler(request: httpx.Request) -> httpx.Response: + url = str(request.url) + if "diffs" in url: + return httpx.Response(200, content=json.dumps(diff_data).encode()) + if "pullRequests/42" in url: + return httpx.Response(200, content=json.dumps(pr_data).encode()) + return httpx.Response(404, content=b"{}") + + transport = httpx.MockTransport(handler) + http_client = httpx.AsyncClient(transport=transport) + vsrm_client = httpx.AsyncClient(transport=transport) + client = AzDoClient( + base_url="https://dev.azure.com/my-org/my-project/_apis", + vsrm_base_url="https://vsrm.dev.azure.com/my-org/my-project/_apis", + pat="test-pat", + http_client=http_client, + vsrm_http_client=vsrm_client, + ) + + result = await client.get_pr_diff(42) + assert "/src/auth.py" in result + assert "/src/util.py" in result + + async def test_empty_changes_returns_empty_string(self) -> None: + pr_data = _load_json("azdo_pr.json") + diff_data: dict = {"changes": []} + + def handler(request: httpx.Request) -> httpx.Response: + url = str(request.url) + if "diffs" in url: + return httpx.Response(200, content=json.dumps(diff_data).encode()) + if "pullRequests/42" in url: + return httpx.Response(200, content=json.dumps(pr_data).encode()) + return httpx.Response(404, content=b"{}") + + transport = httpx.MockTransport(handler) + http_client = httpx.AsyncClient(transport=transport) + vsrm_client = httpx.AsyncClient(transport=transport) + client = AzDoClient( + base_url="https://dev.azure.com/my-org/my-project/_apis", + vsrm_base_url="https://vsrm.dev.azure.com/my-org/my-project/_apis", + pat="test-pat", + http_client=http_client, + vsrm_http_client=vsrm_client, + ) + + result = await client.get_pr_diff(42) + assert result == "" + + async def test_404_raises_not_found(self) -> None: + routes = {("GET", "pullRequests/999"): (404, b'{"message": "PR not found"}')} + client = _make_client(routes) + + with pytest.raises(NotFoundError): + await client.get_pr_diff(999) + + async def test_pr_without_url_field_uses_remote_url(self) -> None: + """When the API response omits the 'url' field, fallback URL is built.""" + pr_data = { + "pullRequestId": 42, + "title": "Fix bug", + "status": "active", + "sourceRefName": "refs/heads/fix/ALLPOST-1_fix", + "repository": { + "id": "repo-uuid", + "name": "my-repo", + "remoteUrl": "https://dev.azure.com/org/proj/_git/my-repo", + }, + # NOTE: 'url' field is intentionally omitted + } + routes = { + ("GET", "pullRequests/42"): (200, json.dumps(pr_data)), + ("GET", "diffs"): (200, json.dumps({"changes": []})), + } + client = _make_client(routes) + + result = await client.get_pr(42) + assert "42" in str(result.pr_url) + + +# --------------------------------------------------------------------------- +# merge_pr tests +# --------------------------------------------------------------------------- + +class TestMergePr: + """Tests for AzDoClient.merge_pr.""" + + async def test_returns_true_on_success(self) -> None: + merge_data = _load_json("azdo_merge_pr.json") + routes = {("PATCH", "pullRequests/42"): (200, json.dumps(merge_data))} + client = _make_client(routes) + + result = await client.merge_pr(pr_id=42, last_merge_source_commit="abc123def456") + assert result is True + + async def test_404_raises_not_found(self) -> None: + routes = {("PATCH", "pullRequests/999"): (404, b'{"message": "Not found"}')} + client = _make_client(routes) + + with pytest.raises(NotFoundError): + await client.merge_pr(pr_id=999, last_merge_source_commit="abc123") + + async def test_409_raises_service_error(self) -> None: + routes = {("PATCH", "pullRequests/42"): (409, b'{"message": "Conflict"}')} + client = _make_client(routes) + + with pytest.raises(ServiceError): + await client.merge_pr(pr_id=42, last_merge_source_commit="abc123") + + +# --------------------------------------------------------------------------- +# create_pr tests +# --------------------------------------------------------------------------- + +class TestCreatePr: + """Tests for AzDoClient.create_pr.""" + + async def test_returns_dict_with_pr_id(self) -> None: + create_data = _load_json("azdo_create_pr.json") + routes = {("POST", "pullRequests"): (201, json.dumps(create_data))} + client = _make_client(routes) + + result = await client.create_pr( + repo="my-repo", + source="refs/heads/release/v1.2.0", + target="refs/heads/main", + title="Release v1.2.0", + description="Release notes", + ) + assert isinstance(result, dict) + assert result["pullRequestId"] == 99 + + async def test_400_raises_service_error(self) -> None: + routes = {("POST", "pullRequests"): (400, b'{"message": "Bad request"}')} + client = _make_client(routes) + + with pytest.raises(ServiceError): + await client.create_pr( + repo="my-repo", + source="refs/heads/release/v1.2.0", + target="refs/heads/main", + title="Release", + description="", + ) + + +# --------------------------------------------------------------------------- +# list_build_pipelines tests +# --------------------------------------------------------------------------- + +class TestListBuildPipelines: + """Tests for AzDoClient.list_build_pipelines.""" + + async def test_returns_list_of_pipeline_info(self) -> None: + pipeline_data = _load_json("azdo_pipelines.json") + routes = {("GET", "pipelines"): (200, json.dumps(pipeline_data))} + client = _make_client(routes) + + result = await client.list_build_pipelines(repo="my-repo") + assert isinstance(result, list) + assert len(result) == 2 + assert all(isinstance(p, PipelineInfo) for p in result) + + async def test_pipeline_ids_extracted(self) -> None: + pipeline_data = _load_json("azdo_pipelines.json") + routes = {("GET", "pipelines"): (200, json.dumps(pipeline_data))} + client = _make_client(routes) + + result = await client.list_build_pipelines(repo="my-repo") + ids = [p.id for p in result] + assert 10 in ids + assert 20 in ids + + async def test_empty_list_on_no_pipelines(self) -> None: + routes = {("GET", "pipelines"): (200, json.dumps({"value": [], "count": 0}))} + client = _make_client(routes) + + result = await client.list_build_pipelines(repo="my-repo") + assert result == [] + + +# --------------------------------------------------------------------------- +# trigger_pipeline tests +# --------------------------------------------------------------------------- + +class TestTriggerPipeline: + """Tests for AzDoClient.trigger_pipeline.""" + + async def test_returns_dict_with_build_id(self) -> None: + trigger_data = _load_json("azdo_trigger_pipeline.json") + routes = {("POST", "pipelines/10/runs"): (200, json.dumps(trigger_data))} + client = _make_client(routes) + + result = await client.trigger_pipeline(pipeline_id=10, branch="refs/heads/main") + assert isinstance(result, dict) + assert result["id"] == 1001 + + async def test_404_raises_not_found(self) -> None: + routes = {("POST", "pipelines/999/runs"): (404, b'{"message": "Pipeline not found"}')} + client = _make_client(routes) + + with pytest.raises(NotFoundError): + await client.trigger_pipeline(pipeline_id=999, branch="refs/heads/main") + + +# --------------------------------------------------------------------------- +# get_build_status tests +# --------------------------------------------------------------------------- + +class TestGetBuildStatus: + """Tests for AzDoClient.get_build_status.""" + + async def test_returns_build_status_object(self) -> None: + build_data = _load_json("azdo_build_status.json") + routes = {("GET", "build/builds/1001"): (200, json.dumps(build_data))} + client = _make_client(routes) + + result = await client.get_build_status(build_id=1001) + assert isinstance(result, BuildStatus) + + async def test_status_field_populated(self) -> None: + build_data = _load_json("azdo_build_status.json") + routes = {("GET", "build/builds/1001"): (200, json.dumps(build_data))} + client = _make_client(routes) + + result = await client.get_build_status(build_id=1001) + assert result.status == "completed" + + async def test_result_field_populated(self) -> None: + build_data = _load_json("azdo_build_status.json") + routes = {("GET", "build/builds/1001"): (200, json.dumps(build_data))} + client = _make_client(routes) + + result = await client.get_build_status(build_id=1001) + assert result.result == "succeeded" + + async def test_build_url_present(self) -> None: + build_data = _load_json("azdo_build_status.json") + routes = {("GET", "build/builds/1001"): (200, json.dumps(build_data))} + client = _make_client(routes) + + result = await client.get_build_status(build_id=1001) + # build_url may be None if not in fixture, but field must exist + assert hasattr(result, "build_url") + + async def test_result_none_when_not_completed(self) -> None: + build_data = {"id": 99, "status": "inProgress", "buildNumber": "20240101.1"} + routes = {("GET", "build/builds/99"): (200, json.dumps(build_data))} + client = _make_client(routes) + + result = await client.get_build_status(build_id=99) + assert result.status == "inProgress" + assert result.result is None + + async def test_404_raises_not_found(self) -> None: + routes = {("GET", "build/builds/9999"): (404, b'{"message": "Build not found"}')} + client = _make_client(routes) + + with pytest.raises(NotFoundError): + await client.get_build_status(build_id=9999) + + +# --------------------------------------------------------------------------- +# get_release_approvals tests +# --------------------------------------------------------------------------- + +class TestGetReleaseApprovals: + """Tests for AzDoClient.get_release_approvals.""" + + async def test_returns_list_of_approval_records(self) -> None: + approvals_data = { + "value": [ + { + "id": 101, + "status": "pending", + "releaseEnvironment": {"name": "Sandbox", "release": {"id": 55}}, + }, + { + "id": 102, + "status": "approved", + "releaseEnvironment": {"name": "Production", "release": {"id": 55}}, + }, + ], + "count": 2, + } + routes = {("GET", "release/approvals"): (200, json.dumps(approvals_data))} + client = _make_client(routes) + + result = await client.get_release_approvals(release_id=55) + assert isinstance(result, list) + assert len(result) == 2 + assert all(isinstance(a, ApprovalRecord) for a in result) + + async def test_approval_id_populated(self) -> None: + approvals_data = { + "value": [ + { + "id": 201, + "status": "pending", + "releaseEnvironment": {"name": "Sandbox", "release": {"id": 10}}, + } + ], + "count": 1, + } + routes = {("GET", "release/approvals"): (200, json.dumps(approvals_data))} + client = _make_client(routes) + + result = await client.get_release_approvals(release_id=10) + assert result[0].approval_id == "201" + + async def test_stage_name_populated(self) -> None: + approvals_data = { + "value": [ + { + "id": 300, + "status": "pending", + "releaseEnvironment": {"name": "Production", "release": {"id": 20}}, + } + ], + "count": 1, + } + routes = {("GET", "release/approvals"): (200, json.dumps(approvals_data))} + client = _make_client(routes) + + result = await client.get_release_approvals(release_id=20) + assert result[0].stage_name == "Production" + + async def test_release_id_populated(self) -> None: + approvals_data = { + "value": [ + { + "id": 400, + "status": "pending", + "releaseEnvironment": {"name": "Stage", "release": {"id": 99}}, + } + ], + "count": 1, + } + routes = {("GET", "release/approvals"): (200, json.dumps(approvals_data))} + client = _make_client(routes) + + result = await client.get_release_approvals(release_id=99) + assert result[0].release_id == 99 + + async def test_empty_list_when_no_approvals(self) -> None: + approvals_data = {"value": [], "count": 0} + routes = {("GET", "release/approvals"): (200, json.dumps(approvals_data))} + client = _make_client(routes) + + result = await client.get_release_approvals(release_id=77) + assert result == [] + + async def test_filters_by_release_id_in_query(self) -> None: + captured_urls: list[str] = [] + + def handler(request: httpx.Request) -> httpx.Response: + captured_urls.append(str(request.url)) + return httpx.Response(200, content=b'{"value": [], "count": 0}') + + transport = httpx.MockTransport(handler) + http_client = httpx.AsyncClient(transport=transport) + vsrm_client = httpx.AsyncClient(transport=transport) + client = AzDoClient( + base_url="https://dev.azure.com/my-org/my-project/_apis", + vsrm_base_url="https://vsrm.dev.azure.com/my-org/my-project/_apis", + pat="test-pat", + http_client=http_client, + vsrm_http_client=vsrm_client, + ) + + await client.get_release_approvals(release_id=42) + assert any("approvals" in url for url in captured_urls) + + async def test_404_raises_not_found(self) -> None: + routes = {("GET", "release/approvals"): (404, b'{"message": "Not found"}')} + client = _make_client(routes) + + with pytest.raises(NotFoundError): + await client.get_release_approvals(release_id=999) + + +# --------------------------------------------------------------------------- +# get_latest_release tests +# --------------------------------------------------------------------------- + +class TestGetLatestRelease: + """Tests for AzDoClient.get_latest_release.""" + + async def test_returns_dict(self) -> None: + release_data = { + "value": [{"id": 55, "name": "Release-55", "status": "active"}], + "count": 1, + } + routes = {("GET", "release/releases"): (200, json.dumps(release_data))} + client = _make_client(routes) + + result = await client.get_latest_release(definition_id=7) + assert isinstance(result, dict) + assert result["id"] == 55 + + async def test_returns_empty_dict_when_no_releases(self) -> None: + release_data = {"value": [], "count": 0} + routes = {("GET", "release/releases"): (200, json.dumps(release_data))} + client = _make_client(routes) + + result = await client.get_latest_release(definition_id=99) + assert result == {} + + async def test_404_raises_not_found(self) -> None: + routes = {("GET", "release/releases"): (404, b'{"message": "Not found"}')} + client = _make_client(routes) + + with pytest.raises(NotFoundError): + await client.get_latest_release(definition_id=999) + + async def test_passes_definition_id_as_filter(self) -> None: + captured_urls: list[str] = [] + + def handler(request: httpx.Request) -> httpx.Response: + captured_urls.append(str(request.url)) + return httpx.Response(200, content=b'{"value": [], "count": 0}') + + transport = httpx.MockTransport(handler) + http_client = httpx.AsyncClient(transport=transport) + vsrm_client = httpx.AsyncClient(transport=transport) + client = AzDoClient( + base_url="https://dev.azure.com/my-org/my-project/_apis", + vsrm_base_url="https://vsrm.dev.azure.com/my-org/my-project/_apis", + pat="test-pat", + http_client=http_client, + vsrm_http_client=vsrm_client, + ) + + await client.get_latest_release(definition_id=13) + assert any("releases" in url for url in captured_urls) + + +# --------------------------------------------------------------------------- +# approve_release tests +# --------------------------------------------------------------------------- + +class TestApproveRelease: + """Tests for AzDoClient.approve_release.""" + + async def test_returns_dict_with_status(self) -> None: + approve_data = _load_json("azdo_approve_release.json") + routes = {("PATCH", "release/approvals"): (200, json.dumps(approve_data))} + client = _make_client(routes) + + result = await client.approve_release( + approval_id="approval-uuid-123", comment="Approved" + ) + assert isinstance(result, dict) + assert result["status"] == "approved" + + async def test_404_raises_not_found(self) -> None: + routes = {("PATCH", "release/approvals"): (404, b'{"message": "Approval not found"}')} + client = _make_client(routes) + + with pytest.raises(NotFoundError): + await client.approve_release(approval_id="bad-id", comment="Approve") + + +# --------------------------------------------------------------------------- +# close() lifecycle tests +# --------------------------------------------------------------------------- + +class TestAzDoClientLifecycle: + """Tests for AzDoClient close() method.""" + + async def test_close_closes_both_clients(self) -> None: + transport = httpx.MockTransport(lambda r: httpx.Response(200, content=b"{}")) + http_client = httpx.AsyncClient(transport=transport) + vsrm_client = httpx.AsyncClient(transport=transport) + client = AzDoClient( + base_url="https://dev.azure.com/org/proj/_apis", + vsrm_base_url="https://vsrm.dev.azure.com/org/proj/_apis", + pat="my-pat", + http_client=http_client, + vsrm_http_client=vsrm_client, + ) + + await client.close() + + assert http_client.is_closed + assert vsrm_client.is_closed + + +# --------------------------------------------------------------------------- +# list_active_prs tests +# --------------------------------------------------------------------------- + +def _make_pr_list_response(prs: list[dict]) -> str: + return json.dumps({"value": prs, "count": len(prs)}) + + +def _make_active_pr_item( + pr_id: int = 10, + title: str = "Test PR", + branch: str = "refs/heads/feature/ALLPOST-100_fix", + status: str = "active", + repo_name: str = "my-repo", +) -> dict: + return { + "pullRequestId": pr_id, + "title": title, + "status": status, + "sourceRefName": branch, + "targetRefName": "refs/heads/develop", + "url": f"https://dev.azure.com/org/proj/_apis/git/repositories/{repo_name}/pullRequests/{pr_id}", + "repository": { + "id": "repo-uuid", + "name": repo_name, + "remoteUrl": f"https://dev.azure.com/org/proj/_git/{repo_name}", + }, + } + + +class TestListActivePrs: + """Tests for AzDoClient.list_active_prs.""" + + async def test_returns_list_of_pr_info(self) -> None: + pr_item = _make_active_pr_item() + routes = { + ("GET", "git/repositories/my-repo/pullRequests"): ( + 200, + _make_pr_list_response([pr_item]), + ) + } + client = _make_client(routes) + + result = await client.list_active_prs("my-repo", "refs/heads/develop") + + assert isinstance(result, list) + assert len(result) == 1 + assert isinstance(result[0], PRInfo) + + async def test_pr_id_extracted(self) -> None: + pr_item = _make_active_pr_item(pr_id=55) + routes = { + ("GET", "git/repositories/my-repo/pullRequests"): ( + 200, + _make_pr_list_response([pr_item]), + ) + } + client = _make_client(routes) + + result = await client.list_active_prs("my-repo", "refs/heads/develop") + assert result[0].pr_id == "55" + + async def test_pr_title_extracted(self) -> None: + pr_item = _make_active_pr_item(title="My Feature") + routes = { + ("GET", "git/repositories/my-repo/pullRequests"): ( + 200, + _make_pr_list_response([pr_item]), + ) + } + client = _make_client(routes) + + result = await client.list_active_prs("my-repo", "refs/heads/develop") + assert result[0].pr_title == "My Feature" + + async def test_empty_list_when_no_prs(self) -> None: + routes = { + ("GET", "git/repositories/my-repo/pullRequests"): ( + 200, + _make_pr_list_response([]), + ) + } + client = _make_client(routes) + + result = await client.list_active_prs("my-repo", "refs/heads/develop") + assert result == [] + + async def test_multiple_prs_returned(self) -> None: + prs = [ + _make_active_pr_item(pr_id=10, title="PR 10"), + _make_active_pr_item(pr_id=20, title="PR 20"), + ] + routes = { + ("GET", "git/repositories/my-repo/pullRequests"): ( + 200, + _make_pr_list_response(prs), + ) + } + client = _make_client(routes) + + result = await client.list_active_prs("my-repo", "refs/heads/develop") + assert len(result) == 2 + assert {r.pr_id for r in result} == {"10", "20"} + + async def test_404_raises_not_found(self) -> None: + routes = { + ("GET", "git/repositories/missing-repo/pullRequests"): ( + 404, + b'{"message": "Repo not found"}', + ) + } + client = _make_client(routes) + + with pytest.raises(NotFoundError): + await client.list_active_prs("missing-repo", "refs/heads/develop") + + async def test_401_raises_authentication_error(self) -> None: + routes = { + ("GET", "git/repositories/my-repo/pullRequests"): ( + 401, + b'{"message": "Unauthorized"}', + ) + } + client = _make_client(routes) + + with pytest.raises(AuthenticationError): + await client.list_active_prs("my-repo", "refs/heads/develop") + + async def test_500_raises_service_error(self) -> None: + routes = { + ("GET", "git/repositories/my-repo/pullRequests"): ( + 500, + b'{"message": "Internal error"}', + ) + } + client = _make_client(routes) + + with pytest.raises(ServiceError): + await client.list_active_prs("my-repo", "refs/heads/develop") diff --git a/tests/tools/test_claude_review.py b/tests/tools/test_claude_review.py new file mode 100644 index 0000000..bc9af8d --- /dev/null +++ b/tests/tools/test_claude_review.py @@ -0,0 +1,454 @@ +"""Tests for ClaudeReviewer using Claude Code CLI subprocess.""" + +import json + +import pytest + +from release_agent.models.review import ReviewResult +from release_agent.tools.claude_review import ( + ClaudeReviewer, + _build_prompt, + _parse_cli_output, + _truncate_diff, +) + +MAX_DIFF_CHARS = 100_000 + + +# --------------------------------------------------------------------------- +# Helpers — fake subprocess runner +# --------------------------------------------------------------------------- + +def _make_cli_output( + verdict: str = "approve", + summary: str = "LGTM", + issues: list | None = None, +) -> str: + """Build a JSON string mimicking Claude Code CLI --output-format json.""" + structured = { + "verdict": verdict, + "summary": summary, + "issues": issues or [], + } + return json.dumps({"result": "", "structured_output": structured}) + + +def _make_subprocess_runner( + stdout: str = "", + stderr: str = "", + returncode: int = 0, +): + """Return a fake run_subprocess callable that records calls.""" + calls: list[dict] = [] + + async def fake_run(*, cmd, cwd, timeout): + calls.append({"cmd": cmd, "cwd": cwd, "timeout": timeout}) + return (stdout, stderr, returncode) + + return fake_run, calls + + +# --------------------------------------------------------------------------- +# _truncate_diff tests +# --------------------------------------------------------------------------- + +class TestTruncateDiff: + def test_short_diff_not_truncated(self) -> None: + diff = "short diff" + assert _truncate_diff(diff) == diff + + def test_exact_limit_not_truncated(self) -> None: + diff = "x" * MAX_DIFF_CHARS + assert _truncate_diff(diff) == diff + + def test_over_limit_truncated(self) -> None: + diff = "x" * (MAX_DIFF_CHARS + 1000) + result = _truncate_diff(diff) + assert len(result) < len(diff) + assert "TRUNCATED" in result + + +# --------------------------------------------------------------------------- +# _build_prompt tests +# --------------------------------------------------------------------------- + +class TestBuildPrompt: + def test_contains_pr_title(self) -> None: + prompt = _build_prompt(diff="d", pr_title="My Title", repo_name="repo") + assert "My Title" in prompt + + def test_contains_repo_name(self) -> None: + prompt = _build_prompt(diff="d", pr_title="t", repo_name="my-repo") + assert "my-repo" in prompt + + def test_contains_diff(self) -> None: + prompt = _build_prompt(diff="UNIQUE_DIFF", pr_title="t", repo_name="r") + assert "UNIQUE_DIFF" in prompt + + +# --------------------------------------------------------------------------- +# _parse_cli_output tests +# --------------------------------------------------------------------------- + +class TestParseCliOutput: + def test_parses_structured_output(self) -> None: + stdout = _make_cli_output(verdict="approve", summary="Good") + result = _parse_cli_output(stdout) + assert isinstance(result, ReviewResult) + assert result.verdict == "approve" + assert result.summary == "Good" + + def test_parses_request_changes(self) -> None: + stdout = _make_cli_output( + verdict="request_changes", + summary="Has issues", + issues=[{"severity": "blocker", "description": "SQL injection"}], + ) + result = _parse_cli_output(stdout) + assert result.verdict == "request_changes" + assert len(result.issues) == 1 + assert result.has_blockers is True + + def test_parses_issues_with_optional_fields(self) -> None: + stdout = _make_cli_output( + verdict="request_changes", + summary="Issues found", + issues=[{ + "severity": "warning", + "description": "Style issue", + "file_path": "src/foo.py", + "suggestion": "Fix it", + }], + ) + result = _parse_cli_output(stdout) + assert result.issues[0].file_path == "src/foo.py" + assert result.issues[0].suggestion == "Fix it" + + def test_empty_issues_no_blockers(self) -> None: + stdout = _make_cli_output(verdict="approve", summary="Clean", issues=[]) + result = _parse_cli_output(stdout) + assert result.has_blockers is False + assert len(result.issues) == 0 + + def test_result_field_as_json_string(self) -> None: + """When structured_output is absent, falls back to parsing result as JSON.""" + inner = {"verdict": "approve", "summary": "OK", "issues": []} + stdout = json.dumps({"result": json.dumps(inner)}) + result = _parse_cli_output(stdout) + assert result.verdict == "approve" + + def test_invalid_json_raises(self) -> None: + with pytest.raises(ValueError, match="Failed to parse"): + _parse_cli_output("not json at all") + + def test_missing_structured_output_and_result_raises(self) -> None: + with pytest.raises(ValueError, match="No structured_output"): + _parse_cli_output(json.dumps({"other": "data"})) + + def test_non_dict_structured_output_raises(self) -> None: + stdout = json.dumps({"structured_output": ["not", "a", "dict"]}) + with pytest.raises(ValueError, match="Expected dict"): + _parse_cli_output(stdout) + + def test_result_is_non_json_string_raises(self) -> None: + stdout = json.dumps({"result": "just plain text, not json"}) + with pytest.raises(ValueError, match="not valid JSON"): + _parse_cli_output(stdout) + + +# --------------------------------------------------------------------------- +# ClaudeReviewer construction tests +# --------------------------------------------------------------------------- + +class TestClaudeReviewerConstruction: + def test_can_be_instantiated(self) -> None: + reviewer = ClaudeReviewer() + assert reviewer is not None + + def test_custom_claude_cmd(self) -> None: + reviewer = ClaudeReviewer(claude_cmd="/usr/local/bin/claude") + assert reviewer._claude_cmd == "/usr/local/bin/claude" + + def test_custom_timeout(self) -> None: + reviewer = ClaudeReviewer(timeout=60) + assert reviewer._timeout == 60 + + +# --------------------------------------------------------------------------- +# review_pr tests +# --------------------------------------------------------------------------- + +class TestReviewPr: + async def test_returns_review_result(self) -> None: + stdout = _make_cli_output(verdict="approve", summary="Looks good") + runner, _ = _make_subprocess_runner(stdout=stdout) + reviewer = ClaudeReviewer(run_subprocess=runner) + + result = await reviewer.review_pr( + diff="diff --git a/foo.py ...", + pr_title="Fix bug", + repo_name="my-repo", + ) + + assert isinstance(result, ReviewResult) + assert result.verdict == "approve" + + async def test_passes_cwd_to_subprocess(self) -> None: + stdout = _make_cli_output() + runner, calls = _make_subprocess_runner(stdout=stdout) + reviewer = ClaudeReviewer(run_subprocess=runner) + + await reviewer.review_pr( + diff="diff", + pr_title="PR", + repo_name="repo", + cwd="/path/to/worktree", + ) + + assert calls[0]["cwd"] == "/path/to/worktree" + + async def test_cmd_includes_claude_p(self) -> None: + stdout = _make_cli_output() + runner, calls = _make_subprocess_runner(stdout=stdout) + reviewer = ClaudeReviewer(run_subprocess=runner) + + await reviewer.review_pr(diff="d", pr_title="t", repo_name="r") + + cmd = calls[0]["cmd"] + assert cmd[0] == "claude" + assert "-p" in cmd + + async def test_cmd_includes_output_format_json(self) -> None: + stdout = _make_cli_output() + runner, calls = _make_subprocess_runner(stdout=stdout) + reviewer = ClaudeReviewer(run_subprocess=runner) + + await reviewer.review_pr(diff="d", pr_title="t", repo_name="r") + + cmd = calls[0]["cmd"] + idx = cmd.index("--output-format") + assert cmd[idx + 1] == "json" + + async def test_cmd_includes_json_schema(self) -> None: + stdout = _make_cli_output() + runner, calls = _make_subprocess_runner(stdout=stdout) + reviewer = ClaudeReviewer(run_subprocess=runner) + + await reviewer.review_pr(diff="d", pr_title="t", repo_name="r") + + cmd = calls[0]["cmd"] + assert "--json-schema" in cmd + + async def test_cmd_includes_allowed_tools(self) -> None: + stdout = _make_cli_output() + runner, calls = _make_subprocess_runner(stdout=stdout) + reviewer = ClaudeReviewer(run_subprocess=runner) + + await reviewer.review_pr(diff="d", pr_title="t", repo_name="r") + + cmd = calls[0]["cmd"] + idx = cmd.index("--allowedTools") + assert "Read" in cmd[idx + 1] + + async def test_cmd_includes_system_prompt(self) -> None: + stdout = _make_cli_output() + runner, calls = _make_subprocess_runner(stdout=stdout) + reviewer = ClaudeReviewer(run_subprocess=runner) + + await reviewer.review_pr(diff="d", pr_title="t", repo_name="r") + + cmd = calls[0]["cmd"] + assert "--system-prompt" in cmd + + async def test_nonzero_exit_raises(self) -> None: + runner, _ = _make_subprocess_runner( + stdout="", stderr="error occurred", returncode=1 + ) + reviewer = ClaudeReviewer(run_subprocess=runner) + + with pytest.raises(RuntimeError, match="exited with code 1"): + await reviewer.review_pr(diff="d", pr_title="t", repo_name="r") + + async def test_timeout_passed_to_subprocess(self) -> None: + stdout = _make_cli_output() + runner, calls = _make_subprocess_runner(stdout=stdout) + reviewer = ClaudeReviewer(run_subprocess=runner, timeout=120) + + await reviewer.review_pr(diff="d", pr_title="t", repo_name="r") + + assert calls[0]["timeout"] == 120 + + async def test_pr_title_in_prompt(self) -> None: + stdout = _make_cli_output() + runner, calls = _make_subprocess_runner(stdout=stdout) + reviewer = ClaudeReviewer(run_subprocess=runner) + + await reviewer.review_pr( + diff="d", pr_title="Specific Title", repo_name="r" + ) + + cmd = calls[0]["cmd"] + prompt = cmd[cmd.index("-p") + 1] + assert "Specific Title" in prompt + + async def test_repo_name_in_prompt(self) -> None: + stdout = _make_cli_output() + runner, calls = _make_subprocess_runner(stdout=stdout) + reviewer = ClaudeReviewer(run_subprocess=runner) + + await reviewer.review_pr( + diff="d", pr_title="t", repo_name="special-repo" + ) + + cmd = calls[0]["cmd"] + prompt = cmd[cmd.index("-p") + 1] + assert "special-repo" in prompt + + async def test_cwd_none_when_not_provided(self) -> None: + stdout = _make_cli_output() + runner, calls = _make_subprocess_runner(stdout=stdout) + reviewer = ClaudeReviewer(run_subprocess=runner) + + await reviewer.review_pr(diff="d", pr_title="t", repo_name="r") + + assert calls[0]["cwd"] is None + + async def test_request_changes_with_issues(self) -> None: + stdout = _make_cli_output( + verdict="request_changes", + summary="Problems found", + issues=[ + {"severity": "blocker", "description": "Security flaw"}, + {"severity": "warning", "description": "Missing docs"}, + ], + ) + runner, _ = _make_subprocess_runner(stdout=stdout) + reviewer = ClaudeReviewer(run_subprocess=runner) + + result = await reviewer.review_pr(diff="d", pr_title="t", repo_name="r") + + assert result.verdict == "request_changes" + assert len(result.issues) == 2 + assert result.has_blockers is True + + +# --------------------------------------------------------------------------- +# ClaudeReviewer.generate_ticket_content tests +# --------------------------------------------------------------------------- + +def _make_ticket_cli_output(summary: str = "My summary", description: str = "My desc") -> str: + """Build a JSON string mimicking Claude Code CLI output for ticket generation.""" + structured = {"summary": summary, "description": description} + return json.dumps({"result": "", "structured_output": structured}) + + +class TestGenerateTicketContent: + """Tests for ClaudeReviewer.generate_ticket_content.""" + + async def test_returns_tuple_of_summary_and_description(self) -> None: + stdout = _make_ticket_cli_output(summary="Fix login bug", description="Detailed desc") + runner, _ = _make_subprocess_runner(stdout=stdout) + reviewer = ClaudeReviewer(run_subprocess=runner) + + result = await reviewer.generate_ticket_content( + diff="edit: main.py", pr_title="Fix login", repo_name="backend" + ) + + assert isinstance(result, tuple) + assert len(result) == 2 + + async def test_returns_correct_summary(self) -> None: + stdout = _make_ticket_cli_output(summary="Implement OAuth2 login") + runner, _ = _make_subprocess_runner(stdout=stdout) + reviewer = ClaudeReviewer(run_subprocess=runner) + + summary, _ = await reviewer.generate_ticket_content( + diff="d", pr_title="Add OAuth", repo_name="auth-service" + ) + + assert summary == "Implement OAuth2 login" + + async def test_returns_correct_description(self) -> None: + stdout = _make_ticket_cli_output(description="This adds OAuth2 support for the login flow") + runner, _ = _make_subprocess_runner(stdout=stdout) + reviewer = ClaudeReviewer(run_subprocess=runner) + + _, description = await reviewer.generate_ticket_content( + diff="d", pr_title="Add OAuth", repo_name="auth-service" + ) + + assert description == "This adds OAuth2 support for the login flow" + + async def test_uses_json_schema_with_summary_and_description_fields(self) -> None: + stdout = _make_ticket_cli_output() + runner, calls = _make_subprocess_runner(stdout=stdout) + reviewer = ClaudeReviewer(run_subprocess=runner) + + await reviewer.generate_ticket_content(diff="d", pr_title="t", repo_name="r") + + cmd = calls[0]["cmd"] + # Verify --json-schema flag was used + assert "--json-schema" in cmd + schema_idx = cmd.index("--json-schema") + schema_json = cmd[schema_idx + 1] + schema = json.loads(schema_json) + assert "summary" in schema["properties"] + assert "description" in schema["properties"] + + async def test_passes_pr_title_in_prompt(self) -> None: + stdout = _make_ticket_cli_output() + runner, calls = _make_subprocess_runner(stdout=stdout) + reviewer = ClaudeReviewer(run_subprocess=runner) + + await reviewer.generate_ticket_content( + diff="d", pr_title="My Unique PR Title", repo_name="r" + ) + + cmd_str = " ".join(calls[0]["cmd"]) + assert "My Unique PR Title" in cmd_str + + async def test_passes_repo_name_in_prompt(self) -> None: + stdout = _make_ticket_cli_output() + runner, calls = _make_subprocess_runner(stdout=stdout) + reviewer = ClaudeReviewer(run_subprocess=runner) + + await reviewer.generate_ticket_content( + diff="d", pr_title="t", repo_name="my-special-repo" + ) + + cmd_str = " ".join(calls[0]["cmd"]) + assert "my-special-repo" in cmd_str + + async def test_passes_cwd_to_subprocess(self) -> None: + stdout = _make_ticket_cli_output() + runner, calls = _make_subprocess_runner(stdout=stdout) + reviewer = ClaudeReviewer(run_subprocess=runner) + + await reviewer.generate_ticket_content( + diff="d", pr_title="t", repo_name="r", cwd="/some/path" + ) + + assert calls[0]["cwd"] == "/some/path" + + async def test_cwd_none_by_default(self) -> None: + stdout = _make_ticket_cli_output() + runner, calls = _make_subprocess_runner(stdout=stdout) + reviewer = ClaudeReviewer(run_subprocess=runner) + + await reviewer.generate_ticket_content(diff="d", pr_title="t", repo_name="r") + + assert calls[0]["cwd"] is None + + async def test_raises_on_nonzero_exit_code(self) -> None: + runner, _ = _make_subprocess_runner(stdout="", stderr="Error", returncode=1) + reviewer = ClaudeReviewer(run_subprocess=runner) + + with pytest.raises(RuntimeError, match="Claude CLI"): + await reviewer.generate_ticket_content(diff="d", pr_title="t", repo_name="r") + + async def test_raises_on_invalid_json_output(self) -> None: + runner, _ = _make_subprocess_runner(stdout="not json at all") + reviewer = ClaudeReviewer(run_subprocess=runner) + + with pytest.raises((ValueError, Exception)): + await reviewer.generate_ticket_content(diff="d", pr_title="t", repo_name="r") diff --git a/tests/tools/test_http.py b/tests/tools/test_http.py new file mode 100644 index 0000000..ebdace1 --- /dev/null +++ b/tests/tools/test_http.py @@ -0,0 +1,205 @@ +"""Tests for shared HTTP helpers. Written FIRST (TDD RED phase).""" + +import base64 + +import httpx +import pytest + +from release_agent.exceptions import ( + AuthenticationError, + NotFoundError, + RateLimitError, + ServiceError, + ServiceUnavailableError, +) +from release_agent.tools._http import build_auth_header, raise_for_status + + +# --------------------------------------------------------------------------- +# raise_for_status tests +# --------------------------------------------------------------------------- + +def _make_response(status_code: int, headers: dict | None = None) -> httpx.Response: + """Build a minimal httpx.Response with the given status code.""" + return httpx.Response( + status_code=status_code, + headers=headers or {}, + content=b"{}", + request=httpx.Request("GET", "https://example.com"), + ) + + +class TestRaiseForStatus: + """Tests for raise_for_status helper.""" + + def test_2xx_does_not_raise(self) -> None: + for code in [200, 201, 204]: + response = _make_response(code) + # Should not raise anything + raise_for_status(response, service="test") + + def test_401_raises_authentication_error(self) -> None: + response = _make_response(401) + with pytest.raises(AuthenticationError) as exc_info: + raise_for_status(response, service="azdo") + assert exc_info.value.service == "azdo" + assert exc_info.value.status_code == 401 + + def test_403_raises_authentication_error(self) -> None: + response = _make_response(403) + with pytest.raises(AuthenticationError) as exc_info: + raise_for_status(response, service="jira") + assert exc_info.value.status_code == 403 + + def test_404_raises_not_found_error(self) -> None: + response = _make_response(404) + with pytest.raises(NotFoundError) as exc_info: + raise_for_status(response, service="azdo") + assert exc_info.value.service == "azdo" + assert exc_info.value.status_code == 404 + + def test_429_raises_rate_limit_error(self) -> None: + response = _make_response(429) + with pytest.raises(RateLimitError) as exc_info: + raise_for_status(response, service="jira") + assert exc_info.value.status_code == 429 + + def test_429_with_retry_after_header_populates_retry_after(self) -> None: + response = _make_response(429, headers={"Retry-After": "60"}) + with pytest.raises(RateLimitError) as exc_info: + raise_for_status(response, service="jira") + assert exc_info.value.retry_after == 60 + + def test_429_without_retry_after_header_retry_after_is_none(self) -> None: + response = _make_response(429) + with pytest.raises(RateLimitError) as exc_info: + raise_for_status(response, service="jira") + assert exc_info.value.retry_after is None + + def test_503_raises_service_unavailable(self) -> None: + response = _make_response(503) + with pytest.raises(ServiceUnavailableError) as exc_info: + raise_for_status(response, service="slack") + assert exc_info.value.status_code == 503 + + def test_500_raises_service_error(self) -> None: + response = _make_response(500) + with pytest.raises(ServiceError) as exc_info: + raise_for_status(response, service="azdo") + assert exc_info.value.status_code == 500 + assert exc_info.value.service == "azdo" + + def test_400_raises_service_error(self) -> None: + response = _make_response(400) + with pytest.raises(ServiceError) as exc_info: + raise_for_status(response, service="jira") + assert exc_info.value.status_code == 400 + + def test_422_raises_service_error(self) -> None: + response = _make_response(422) + with pytest.raises(ServiceError): + raise_for_status(response, service="azdo") + + def test_service_name_propagated_in_all_errors(self) -> None: + """Each error type must carry the service name.""" + cases = [ + (401, AuthenticationError), + (404, NotFoundError), + (429, RateLimitError), + (503, ServiceUnavailableError), + (500, ServiceError), + ] + for code, exc_type in cases: + response = _make_response(code) + with pytest.raises(exc_type) as exc_info: + raise_for_status(response, service="my-service") + assert exc_info.value.service == "my-service" + + def test_3xx_does_not_raise(self) -> None: + """Redirects are not errors (httpx follows them).""" + response = _make_response(301) + raise_for_status(response, service="test") + + +# --------------------------------------------------------------------------- +# build_auth_header tests +# --------------------------------------------------------------------------- + +class TestBuildAuthHeader: + """Tests for build_auth_header helper.""" + + def test_returns_authorization_key(self) -> None: + header = build_auth_header("user", "pass") + assert "Authorization" in header + + def test_returns_basic_scheme(self) -> None: + header = build_auth_header("user", "pass") + assert header["Authorization"].startswith("Basic ") + + def test_value_is_base64_encoded(self) -> None: + header = build_auth_header("user", "pass") + encoded_part = header["Authorization"].removeprefix("Basic ") + decoded = base64.b64decode(encoded_part).decode() + assert decoded == "user:pass" + + def test_empty_username(self) -> None: + # PAT auth uses empty username with token as password + header = build_auth_header("", "my-token") + encoded_part = header["Authorization"].removeprefix("Basic ") + decoded = base64.b64decode(encoded_part).decode() + assert decoded == ":my-token" + + def test_special_characters_in_password(self) -> None: + header = build_auth_header("user", "p@ss!#$%") + encoded_part = header["Authorization"].removeprefix("Basic ") + decoded = base64.b64decode(encoded_part).decode() + assert decoded == "user:p@ss!#$%" + + def test_returns_dict(self) -> None: + result = build_auth_header("u", "p") + assert isinstance(result, dict) + + def test_result_is_immutable_dict(self) -> None: + result = build_auth_header("u", "p") + # Ensure only the Authorization key is present + assert list(result.keys()) == ["Authorization"] + + +# --------------------------------------------------------------------------- +# Edge case coverage for _extract_detail and _parse_retry_after +# --------------------------------------------------------------------------- + +class TestExtractDetailEdgeCases: + """Tests for the private _extract_detail helper via raise_for_status.""" + + def test_non_dict_body_still_raises_service_error(self) -> None: + """A JSON array body (non-dict) should still raise ServiceError.""" + response = httpx.Response( + status_code=500, + content=b'["error", "list"]', + request=httpx.Request("GET", "https://example.com"), + ) + with pytest.raises(ServiceError): + raise_for_status(response, service="test") + + def test_invalid_json_body_still_raises(self) -> None: + """A non-JSON response body should still raise ServiceError.""" + response = httpx.Response( + status_code=500, + content=b"Internal Server Error (plain text)", + request=httpx.Request("GET", "https://example.com"), + ) + with pytest.raises(ServiceError): + raise_for_status(response, service="test") + + def test_429_with_non_integer_retry_after_retry_after_is_none(self) -> None: + """A non-integer Retry-After value should result in retry_after=None.""" + response = httpx.Response( + status_code=429, + headers={"Retry-After": "Wed, 21 Oct 2015 07:28:00 GMT"}, + content=b"{}", + request=httpx.Request("GET", "https://example.com"), + ) + with pytest.raises(RateLimitError) as exc_info: + raise_for_status(response, service="test") + assert exc_info.value.retry_after is None diff --git a/tests/tools/test_jira.py b/tests/tools/test_jira.py new file mode 100644 index 0000000..604ef0a --- /dev/null +++ b/tests/tools/test_jira.py @@ -0,0 +1,572 @@ +"""Tests for JiraClient. Written FIRST (TDD RED phase).""" + +import json +from pathlib import Path + +import httpx +import pytest + +from release_agent.exceptions import AuthenticationError, NotFoundError, ServiceError +from release_agent.models.jira import JiraIssue, JiraTransition +from release_agent.tools.jira import JiraClient + +# --------------------------------------------------------------------------- +# Fixture helpers +# --------------------------------------------------------------------------- + +FIXTURES = Path(__file__).parent / "fixtures" + + +def _load_json(name: str) -> dict: + return json.loads((FIXTURES / name).read_text()) + + +def _make_transport(routes: dict[tuple[str, str], tuple[int, bytes | str]]) -> httpx.MockTransport: + """Build a MockTransport dispatching by (method, url_substring).""" + + def handler(request: httpx.Request) -> httpx.Response: + url = str(request.url) + method = request.method + for (m, url_fragment), (status, body) in routes.items(): + if m == method and url_fragment in url: + content = body if isinstance(body, bytes) else body.encode() + return httpx.Response(status_code=status, content=content) + return httpx.Response(status_code=404, content=b'{"errorMessages": ["Not found"]}') + + return httpx.MockTransport(handler) + + +def _make_client(routes: dict) -> JiraClient: + transport = _make_transport(routes) + http_client = httpx.AsyncClient(transport=transport) + return JiraClient( + base_url="https://billolife.atlassian.net", + email="user@example.com", + api_token="test-token", + http_client=http_client, + ) + + +# --------------------------------------------------------------------------- +# Construction tests +# --------------------------------------------------------------------------- + +class TestJiraClientConstruction: + """Tests for JiraClient initialization.""" + + def test_can_be_instantiated(self) -> None: + transport = httpx.MockTransport(lambda r: httpx.Response(200, content=b"{}")) + http_client = httpx.AsyncClient(transport=transport) + client = JiraClient( + base_url="https://billolife.atlassian.net", + email="u@example.com", + api_token="token", + http_client=http_client, + ) + assert client is not None + + async def test_context_manager_closes_client(self) -> None: + transport = httpx.MockTransport(lambda r: httpx.Response(200, content=b"{}")) + http_client = httpx.AsyncClient(transport=transport) + async with JiraClient( + base_url="https://billolife.atlassian.net", + email="u@example.com", + api_token="token", + http_client=http_client, + ) as client: + assert client is not None + assert http_client.is_closed + + +# --------------------------------------------------------------------------- +# get_issue tests +# --------------------------------------------------------------------------- + +class TestGetIssue: + """Tests for JiraClient.get_issue.""" + + async def test_returns_jira_issue(self) -> None: + issue_data = _load_json("jira_issue.json") + routes = {("GET", "ALLPOST-100"): (200, json.dumps(issue_data))} + client = _make_client(routes) + + result = await client.get_issue("ALLPOST-100") + assert isinstance(result, JiraIssue) + + async def test_key_extracted(self) -> None: + issue_data = _load_json("jira_issue.json") + routes = {("GET", "ALLPOST-100"): (200, json.dumps(issue_data))} + client = _make_client(routes) + + result = await client.get_issue("ALLPOST-100") + assert result.key == "ALLPOST-100" + + async def test_summary_extracted(self) -> None: + issue_data = _load_json("jira_issue.json") + routes = {("GET", "ALLPOST-100"): (200, json.dumps(issue_data))} + client = _make_client(routes) + + result = await client.get_issue("ALLPOST-100") + assert result.summary == "Fix the authentication bug" + + async def test_status_extracted(self) -> None: + issue_data = _load_json("jira_issue.json") + routes = {("GET", "ALLPOST-100"): (200, json.dumps(issue_data))} + client = _make_client(routes) + + result = await client.get_issue("ALLPOST-100") + assert result.status == "In Progress" + + async def test_404_raises_not_found(self) -> None: + routes = {("GET", "ALLPOST-999"): (404, b'{"errorMessages": ["Issue not found"]}')} + client = _make_client(routes) + + with pytest.raises(NotFoundError): + await client.get_issue("ALLPOST-999") + + async def test_401_raises_authentication_error(self) -> None: + routes = {("GET", "ALLPOST-100"): (401, b'{"errorMessages": ["Unauthorized"]}')} + client = _make_client(routes) + + with pytest.raises(AuthenticationError): + await client.get_issue("ALLPOST-100") + + +# --------------------------------------------------------------------------- +# get_transitions tests +# --------------------------------------------------------------------------- + +class TestGetTransitions: + """Tests for JiraClient.get_transitions.""" + + async def test_returns_list_of_transitions(self) -> None: + transition_data = _load_json("jira_transitions.json") + routes = {("GET", "transitions"): (200, json.dumps(transition_data))} + client = _make_client(routes) + + result = await client.get_transitions("ALLPOST-100") + assert isinstance(result, list) + assert all(isinstance(t, JiraTransition) for t in result) + + async def test_transition_names_extracted(self) -> None: + transition_data = _load_json("jira_transitions.json") + routes = {("GET", "transitions"): (200, json.dumps(transition_data))} + client = _make_client(routes) + + result = await client.get_transitions("ALLPOST-100") + names = [t.name for t in result] + assert "Released" in names + assert "In Progress" in names + + async def test_transition_ids_extracted(self) -> None: + transition_data = _load_json("jira_transitions.json") + routes = {("GET", "transitions"): (200, json.dumps(transition_data))} + client = _make_client(routes) + + result = await client.get_transitions("ALLPOST-100") + ids = [t.id for t in result] + assert "11" in ids + + async def test_empty_transitions_returned(self) -> None: + routes = {("GET", "transitions"): (200, json.dumps({"transitions": []}))} + client = _make_client(routes) + + result = await client.get_transitions("ALLPOST-100") + assert result == [] + + async def test_404_raises_not_found(self) -> None: + routes = {("GET", "transitions"): (404, b'{"errorMessages": ["Not found"]}')} + client = _make_client(routes) + + with pytest.raises(NotFoundError): + await client.get_transitions("ALLPOST-999") + + +# --------------------------------------------------------------------------- +# transition_issue tests +# --------------------------------------------------------------------------- + +class TestTransitionIssue: + """Tests for JiraClient.transition_issue.""" + + async def test_returns_true_on_success(self) -> None: + transition_data = _load_json("jira_transitions.json") + routes = { + ("GET", "transitions"): (200, json.dumps(transition_data)), + ("POST", "transitions"): (204, b""), + } + client = _make_client(routes) + + result = await client.transition_issue("ALLPOST-100", "Released") + assert result is True + + async def test_returns_false_when_transition_not_found(self) -> None: + transition_data = _load_json("jira_transitions.json") + routes = { + ("GET", "transitions"): (200, json.dumps(transition_data)), + ("POST", "transitions"): (204, b""), + } + client = _make_client(routes) + + # "QA Review" is not in the fixture transitions + result = await client.transition_issue("ALLPOST-100", "QA Review") + assert result is False + + async def test_fallback_to_dev_in_progress_then_retries(self) -> None: + """Two-step fallback: if target unavailable, try 'Dev in Progress' first.""" + get_call_count = {"n": 0} + # First GET: only "Dev in Progress" available (no "Released" yet) + transition_data_first = { + "transitions": [{"id": "21", "name": "Dev in Progress"}] + } + # Second GET (after Dev in Progress transition): "Released" now available + transition_data_second = { + "transitions": [ + {"id": "21", "name": "Dev in Progress"}, + {"id": "41", "name": "Released"}, + ] + } + + def handler(request: httpx.Request) -> httpx.Response: + url = str(request.url) + method = request.method + if method == "GET" and "transitions" in url: + get_call_count["n"] += 1 + if get_call_count["n"] <= 1: + return httpx.Response(200, content=json.dumps(transition_data_first).encode()) + return httpx.Response(200, content=json.dumps(transition_data_second).encode()) + if method == "POST" and "transitions" in url: + return httpx.Response(204, content=b"") + return httpx.Response(404, content=b'{"errorMessages": ["Not found"]}') + + transport = httpx.MockTransport(handler) + http_client = httpx.AsyncClient(transport=transport) + client = JiraClient( + base_url="https://billolife.atlassian.net", + email="user@example.com", + api_token="test-token", + http_client=http_client, + ) + + result = await client.transition_issue("ALLPOST-100", "Released") + assert result is True + + async def test_returns_false_when_still_unavailable_after_fallback(self) -> None: + """Return False when target transition is unavailable even after fallback.""" + get_call_count = {"n": 0} + # First GET: only "Dev in Progress" available + transition_data_first = { + "transitions": [{"id": "21", "name": "Dev in Progress"}] + } + # Second GET (after fallback): target STILL not available + transition_data_second = { + "transitions": [{"id": "21", "name": "Dev in Progress"}] + } + + def handler(request: httpx.Request) -> httpx.Response: + url = str(request.url) + method = request.method + if method == "GET" and "transitions" in url: + get_call_count["n"] += 1 + if get_call_count["n"] <= 1: + return httpx.Response(200, content=json.dumps(transition_data_first).encode()) + return httpx.Response(200, content=json.dumps(transition_data_second).encode()) + if method == "POST" and "transitions" in url: + return httpx.Response(204, content=b"") + return httpx.Response(404, content=b'{"errorMessages": ["Not found"]}') + + transport = httpx.MockTransport(handler) + http_client = httpx.AsyncClient(transport=transport) + client = JiraClient( + base_url="https://billolife.atlassian.net", + email="user@example.com", + api_token="test-token", + http_client=http_client, + ) + + result = await client.transition_issue("ALLPOST-100", "Released") + assert result is False + + async def test_404_on_transition_raises_not_found(self) -> None: + transition_data = _load_json("jira_transitions.json") + routes = { + ("GET", "transitions"): (200, json.dumps(transition_data)), + ("POST", "transitions"): (404, b'{"errorMessages": ["Not found"]}'), + } + client = _make_client(routes) + + with pytest.raises(NotFoundError): + await client.transition_issue("ALLPOST-100", "Released") + + +# --------------------------------------------------------------------------- +# add_remote_link tests +# --------------------------------------------------------------------------- + +class TestAddRemoteLink: + """Tests for JiraClient.add_remote_link.""" + + async def test_returns_true_on_success(self) -> None: + routes = { + ("POST", "remotelink"): (200, json.dumps({"id": 1, "self": "..."})), + } + client = _make_client(routes) + + result = await client.add_remote_link( + ticket_id="ALLPOST-100", + url="https://dev.azure.com/org/proj/_git/repo/pullrequest/42", + title="PR #42: Fix auth", + ) + assert result is True + + async def test_404_raises_not_found(self) -> None: + routes = {("POST", "remotelink"): (404, b'{"errorMessages": ["Not found"]}')} + client = _make_client(routes) + + with pytest.raises(NotFoundError): + await client.add_remote_link( + ticket_id="ALLPOST-999", + url="https://example.com", + title="Some PR", + ) + + async def test_400_raises_service_error(self) -> None: + routes = {("POST", "remotelink"): (400, b'{"errorMessages": ["Bad request"]}')} + client = _make_client(routes) + + with pytest.raises(ServiceError): + await client.add_remote_link( + ticket_id="ALLPOST-100", + url="not-a-url", + title="Bad link", + ) + + +# --------------------------------------------------------------------------- +# Lifecycle tests +# --------------------------------------------------------------------------- + +class TestJiraClientLifecycle: + """Tests for JiraClient close() method.""" + + async def test_close_closes_http_client(self) -> None: + transport = httpx.MockTransport(lambda r: httpx.Response(200, content=b"{}")) + http_client = httpx.AsyncClient(transport=transport) + client = JiraClient( + base_url="https://billolife.atlassian.net", + email="u@example.com", + api_token="token", + http_client=http_client, + ) + + await client.close() + + assert http_client.is_closed + + +# --------------------------------------------------------------------------- +# _text_to_adf tests +# --------------------------------------------------------------------------- + +class TestTextToAdf: + """Tests for the _text_to_adf helper.""" + + def test_returns_dict(self) -> None: + from release_agent.tools.jira import _text_to_adf + result = _text_to_adf("Hello world") + assert isinstance(result, dict) + + def test_version_is_1(self) -> None: + from release_agent.tools.jira import _text_to_adf + result = _text_to_adf("Hello world") + assert result["version"] == 1 + + def test_type_is_doc(self) -> None: + from release_agent.tools.jira import _text_to_adf + result = _text_to_adf("Hello world") + assert result["type"] == "doc" + + def test_content_is_list(self) -> None: + from release_agent.tools.jira import _text_to_adf + result = _text_to_adf("Hello world") + assert isinstance(result["content"], list) + + def test_single_line_produces_one_paragraph(self) -> None: + from release_agent.tools.jira import _text_to_adf + result = _text_to_adf("Hello world") + assert len(result["content"]) == 1 + assert result["content"][0]["type"] == "paragraph" + + def test_multiline_produces_multiple_paragraphs(self) -> None: + from release_agent.tools.jira import _text_to_adf + result = _text_to_adf("Line one\n\nLine two") + # Each non-empty line becomes a paragraph + paragraphs = [c for c in result["content"] if c["type"] == "paragraph"] + assert len(paragraphs) == 2 + + def test_paragraph_contains_text_node(self) -> None: + from release_agent.tools.jira import _text_to_adf + result = _text_to_adf("Hello") + paragraph = result["content"][0] + assert "content" in paragraph + text_node = paragraph["content"][0] + assert text_node["type"] == "text" + assert text_node["text"] == "Hello" + + def test_empty_string_produces_empty_doc(self) -> None: + from release_agent.tools.jira import _text_to_adf + result = _text_to_adf("") + assert result["content"] == [] + + +# --------------------------------------------------------------------------- +# JiraClient.create_issue tests +# --------------------------------------------------------------------------- + +class TestCreateIssue: + """Tests for JiraClient.create_issue.""" + + async def test_returns_ticket_key(self) -> None: + response_body = json.dumps({"id": "10001", "key": "ALLPOST-42", "self": "https://..."}) + routes = {("POST", "/rest/api/3/issue"): (201, response_body)} + client = _make_client(routes) + + result = await client.create_issue( + project="ALLPOST", + summary="New feature", + description="Some description", + ) + assert result == "ALLPOST-42" + + async def test_default_issue_type_is_story(self) -> None: + captured_bodies: list[dict] = [] + + def handler(request: httpx.Request) -> httpx.Response: + captured_bodies.append(json.loads(request.content)) + return httpx.Response( + 201, + content=json.dumps({"key": "ALLPOST-1"}).encode(), + ) + + transport = httpx.MockTransport(handler) + http_client = httpx.AsyncClient(transport=transport) + client = JiraClient( + base_url="https://billolife.atlassian.net", + email="u@example.com", + api_token="token", + http_client=http_client, + ) + + await client.create_issue(project="ALLPOST", summary="S", description="D") + assert captured_bodies[0]["fields"]["issuetype"]["name"] == "Story" + + async def test_custom_issue_type_sent(self) -> None: + captured_bodies: list[dict] = [] + + def handler(request: httpx.Request) -> httpx.Response: + captured_bodies.append(json.loads(request.content)) + return httpx.Response( + 201, + content=json.dumps({"key": "ALLPOST-2"}).encode(), + ) + + transport = httpx.MockTransport(handler) + http_client = httpx.AsyncClient(transport=transport) + client = JiraClient( + base_url="https://billolife.atlassian.net", + email="u@example.com", + api_token="token", + http_client=http_client, + ) + + await client.create_issue( + project="ALLPOST", summary="S", description="D", issue_type="Bug" + ) + assert captured_bodies[0]["fields"]["issuetype"]["name"] == "Bug" + + async def test_project_key_sent_in_body(self) -> None: + captured_bodies: list[dict] = [] + + def handler(request: httpx.Request) -> httpx.Response: + captured_bodies.append(json.loads(request.content)) + return httpx.Response( + 201, + content=json.dumps({"key": "MYPROJ-3"}).encode(), + ) + + transport = httpx.MockTransport(handler) + http_client = httpx.AsyncClient(transport=transport) + client = JiraClient( + base_url="https://billolife.atlassian.net", + email="u@example.com", + api_token="token", + http_client=http_client, + ) + + await client.create_issue(project="MYPROJ", summary="S", description="D") + assert captured_bodies[0]["fields"]["project"]["key"] == "MYPROJ" + + async def test_summary_sent_in_body(self) -> None: + captured_bodies: list[dict] = [] + + def handler(request: httpx.Request) -> httpx.Response: + captured_bodies.append(json.loads(request.content)) + return httpx.Response( + 201, + content=json.dumps({"key": "ALLPOST-5"}).encode(), + ) + + transport = httpx.MockTransport(handler) + http_client = httpx.AsyncClient(transport=transport) + client = JiraClient( + base_url="https://billolife.atlassian.net", + email="u@example.com", + api_token="token", + http_client=http_client, + ) + + await client.create_issue(project="ALLPOST", summary="My Summary", description="D") + assert captured_bodies[0]["fields"]["summary"] == "My Summary" + + async def test_description_is_adf_format(self) -> None: + captured_bodies: list[dict] = [] + + def handler(request: httpx.Request) -> httpx.Response: + captured_bodies.append(json.loads(request.content)) + return httpx.Response( + 201, + content=json.dumps({"key": "ALLPOST-6"}).encode(), + ) + + transport = httpx.MockTransport(handler) + http_client = httpx.AsyncClient(transport=transport) + client = JiraClient( + base_url="https://billolife.atlassian.net", + email="u@example.com", + api_token="token", + http_client=http_client, + ) + + await client.create_issue(project="ALLPOST", summary="S", description="My desc") + desc = captured_bodies[0]["fields"]["description"] + assert desc["type"] == "doc" + assert desc["version"] == 1 + + async def test_401_raises_authentication_error(self) -> None: + routes = {("POST", "/rest/api/3/issue"): (401, b'{"errorMessages": ["Unauthorized"]}')} + client = _make_client(routes) + + with pytest.raises(AuthenticationError): + await client.create_issue(project="ALLPOST", summary="S", description="D") + + async def test_400_raises_service_error(self) -> None: + routes = { + ("POST", "/rest/api/3/issue"): ( + 400, + json.dumps({"errorMessages": ["Bad request"], "errors": {}}).encode(), + ) + } + client = _make_client(routes) + + with pytest.raises(ServiceError): + await client.create_issue(project="ALLPOST", summary="S", description="D") diff --git a/tests/tools/test_retry.py b/tests/tools/test_retry.py new file mode 100644 index 0000000..c7cc8ce --- /dev/null +++ b/tests/tools/test_retry.py @@ -0,0 +1,198 @@ +"""Tests for async retry decorator. Written FIRST (TDD RED phase).""" + +import asyncio + +import pytest + +from release_agent.exceptions import ( + NotFoundError, + RateLimitError, + ServiceError, + ServiceUnavailableError, +) +from release_agent.tools._retry import with_retry + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_failing_then_succeeding(failures: int, exc_factory, result="ok"): + """Return an async callable that fails `failures` times then returns `result`.""" + call_count = {"n": 0} + + async def fn(): + call_count["n"] += 1 + if call_count["n"] <= failures: + raise exc_factory() + return result + + return fn + + +# --------------------------------------------------------------------------- +# with_retry tests +# --------------------------------------------------------------------------- + +class TestWithRetry: + """Tests for the with_retry decorator.""" + + async def test_success_on_first_attempt(self) -> None: + call_count = {"n": 0} + + @with_retry(max_attempts=3) + async def fn(): + call_count["n"] += 1 + return "done" + + result = await fn() + assert result == "done" + assert call_count["n"] == 1 + + async def test_retries_on_rate_limit_error(self) -> None: + call_count = {"n": 0} + + @with_retry(max_attempts=3, base_delay=0.0) + async def fn(): + call_count["n"] += 1 + if call_count["n"] < 3: + raise RateLimitError(service="jira", retry_after=None) + return "ok" + + result = await fn() + assert result == "ok" + assert call_count["n"] == 3 + + async def test_retries_on_service_unavailable_error(self) -> None: + call_count = {"n": 0} + + @with_retry(max_attempts=3, base_delay=0.0) + async def fn(): + call_count["n"] += 1 + if call_count["n"] < 2: + raise ServiceUnavailableError(service="azdo") + return "ok" + + result = await fn() + assert result == "ok" + assert call_count["n"] == 2 + + async def test_does_not_retry_on_not_found_error(self) -> None: + call_count = {"n": 0} + + @with_retry(max_attempts=3, base_delay=0.0) + async def fn(): + call_count["n"] += 1 + raise NotFoundError(service="azdo", detail="not found") + + with pytest.raises(NotFoundError): + await fn() + assert call_count["n"] == 1 + + async def test_does_not_retry_on_generic_service_error(self) -> None: + call_count = {"n": 0} + + @with_retry(max_attempts=3, base_delay=0.0) + async def fn(): + call_count["n"] += 1 + raise ServiceError(service="azdo", status_code=400, detail="bad request") + + with pytest.raises(ServiceError): + await fn() + assert call_count["n"] == 1 + + async def test_raises_after_max_attempts_exceeded(self) -> None: + call_count = {"n": 0} + + @with_retry(max_attempts=3, base_delay=0.0) + async def fn(): + call_count["n"] += 1 + raise RateLimitError(service="jira", retry_after=None) + + with pytest.raises(RateLimitError): + await fn() + assert call_count["n"] == 3 + + async def test_max_attempts_one_means_no_retry(self) -> None: + call_count = {"n": 0} + + @with_retry(max_attempts=1, base_delay=0.0) + async def fn(): + call_count["n"] += 1 + raise RateLimitError(service="jira", retry_after=None) + + with pytest.raises(RateLimitError): + await fn() + assert call_count["n"] == 1 + + async def test_does_not_retry_on_non_release_agent_error(self) -> None: + call_count = {"n": 0} + + @with_retry(max_attempts=3, base_delay=0.0) + async def fn(): + call_count["n"] += 1 + raise ValueError("unexpected") + + with pytest.raises(ValueError): + await fn() + assert call_count["n"] == 1 + + async def test_respects_retry_after_from_rate_limit_error(self) -> None: + """When retry_after is set, the decorator must wait at least that long.""" + delays: list[float] = [] + + async def fake_sleep(seconds: float) -> None: + delays.append(seconds) + + call_count = {"n": 0} + + @with_retry(max_attempts=2, base_delay=0.0, sleep_fn=fake_sleep) + async def fn(): + call_count["n"] += 1 + if call_count["n"] < 2: + raise RateLimitError(service="jira", retry_after=5) + return "ok" + + result = await fn() + assert result == "ok" + assert len(delays) == 1 + assert delays[0] >= 5.0 + + async def test_exponential_backoff_grows(self) -> None: + """Verify delays grow between retries (exponential).""" + delays: list[float] = [] + + async def fake_sleep(seconds: float) -> None: + delays.append(seconds) + + call_count = {"n": 0} + + @with_retry(max_attempts=4, base_delay=1.0, sleep_fn=fake_sleep) + async def fn(): + call_count["n"] += 1 + if call_count["n"] < 4: + raise ServiceUnavailableError(service="azdo") + return "ok" + + await fn() + assert len(delays) == 3 + # Each subsequent delay must not be less than the previous + assert delays[1] >= delays[0] + assert delays[2] >= delays[1] + + async def test_preserves_return_value(self) -> None: + @with_retry(max_attempts=2, base_delay=0.0) + async def fn(): + return {"key": "value"} + + result = await fn() + assert result == {"key": "value"} + + async def test_works_without_decorator_args_defaults(self) -> None: + """Decorator used with defaults should still work.""" + @with_retry() + async def fn(): + return 42 + + result = await fn() + assert result == 42 diff --git a/tests/tools/test_slack.py b/tests/tools/test_slack.py new file mode 100644 index 0000000..337e8f2 --- /dev/null +++ b/tests/tools/test_slack.py @@ -0,0 +1,755 @@ +"""Tests for SlackClient and Block Kit builders. Written FIRST (TDD RED phase).""" + +import json +from datetime import date + +import httpx +import pytest + +from release_agent.exceptions import ServiceError +from release_agent.models.ticket import TicketEntry +from release_agent.tools.slack import ( + SlackClient, + _build_approval_blocks, + _build_ci_status_blocks, + _build_interactive_approval_blocks, + _build_release_blocks, + _build_resolved_approval_blocks, +) + + +# --------------------------------------------------------------------------- +# Fixture helpers +# --------------------------------------------------------------------------- + +def _make_ticket(ticket_id: str = "ALLPOST-100", summary: str = "Fix bug") -> TicketEntry: + return TicketEntry( + id=ticket_id, + summary=summary, + pr_id="PR-42", + pr_url="https://dev.azure.com/org/project/_git/repo/pullrequest/42", + pr_title="Fix bug PR", + branch=f"bug/{ticket_id}_fix-bug", + merged_at=date(2024, 1, 15), + ) + + +def _make_transport(status: int = 200, body: bytes = b'{"ok": true}') -> httpx.MockTransport: + return httpx.MockTransport(lambda r: httpx.Response(status_code=status, content=body)) + + +def _make_client(status: int = 200) -> SlackClient: + transport = _make_transport(status) + http_client = httpx.AsyncClient(transport=transport) + return SlackClient( + webhook_url="https://hooks.slack.com/services/T000/B000/xxxx", + http_client=http_client, + ) + + +def _make_web_api_client( + status: int = 200, + body: bytes = b'{"ok": true, "ts": "1234567890.123456"}', +) -> SlackClient: + transport = httpx.MockTransport(lambda r: httpx.Response(status_code=status, content=body)) + http_client = httpx.AsyncClient(transport=transport) + return SlackClient( + bot_token="xoxb-test-token", + channel_id="C12345678", + http_client=http_client, + ) + + +# --------------------------------------------------------------------------- +# _build_release_blocks tests (pure function) +# --------------------------------------------------------------------------- + +class TestBuildReleaseBlocks: + """Tests for the _build_release_blocks pure function.""" + + def test_returns_list(self) -> None: + blocks = _build_release_blocks( + repo="my-repo", + version="v1.2.0", + release_date=date(2024, 1, 15), + tickets=[], + ) + assert isinstance(blocks, list) + + def test_has_at_least_one_block(self) -> None: + blocks = _build_release_blocks( + repo="my-repo", + version="v1.2.0", + release_date=date(2024, 1, 15), + tickets=[], + ) + assert len(blocks) >= 1 + + def test_repo_name_present_in_blocks(self) -> None: + blocks = _build_release_blocks( + repo="my-repo", + version="v1.2.0", + release_date=date(2024, 1, 15), + tickets=[], + ) + text = json.dumps(blocks) + assert "my-repo" in text + + def test_version_present_in_blocks(self) -> None: + blocks = _build_release_blocks( + repo="my-repo", + version="v1.2.0", + release_date=date(2024, 1, 15), + tickets=[], + ) + text = json.dumps(blocks) + assert "v1.2.0" in text + + def test_release_date_present_in_blocks(self) -> None: + blocks = _build_release_blocks( + repo="my-repo", + version="v1.2.0", + release_date=date(2024, 1, 15), + tickets=[], + ) + text = json.dumps(blocks) + assert "2024" in text + + def test_ticket_ids_present_in_blocks(self) -> None: + tickets = [_make_ticket("ALLPOST-100"), _make_ticket("ALLPOST-200")] + blocks = _build_release_blocks( + repo="my-repo", + version="v1.2.0", + release_date=date(2024, 1, 15), + tickets=tickets, + ) + text = json.dumps(blocks) + assert "ALLPOST-100" in text + assert "ALLPOST-200" in text + + def test_empty_tickets_still_valid(self) -> None: + blocks = _build_release_blocks( + repo="my-repo", + version="v1.2.0", + release_date=date(2024, 1, 15), + tickets=[], + ) + assert len(blocks) >= 1 + + def test_blocks_are_dicts(self) -> None: + blocks = _build_release_blocks( + repo="my-repo", + version="v1.2.0", + release_date=date(2024, 1, 15), + tickets=[_make_ticket()], + ) + assert all(isinstance(b, dict) for b in blocks) + + def test_each_block_has_type_key(self) -> None: + blocks = _build_release_blocks( + repo="my-repo", + version="v1.2.0", + release_date=date(2024, 1, 15), + tickets=[], + ) + for block in blocks: + assert "type" in block + + def test_ticket_summaries_included(self) -> None: + tickets = [_make_ticket("ALLPOST-100", "Fix the auth bug")] + blocks = _build_release_blocks( + repo="my-repo", + version="v1.2.0", + release_date=date(2024, 1, 15), + tickets=tickets, + ) + text = json.dumps(blocks) + assert "Fix the auth bug" in text + + +# --------------------------------------------------------------------------- +# _build_approval_blocks tests (pure function) +# --------------------------------------------------------------------------- + +class TestBuildApprovalBlocks: + """Tests for the _build_approval_blocks pure function.""" + + def test_returns_list(self) -> None: + blocks = _build_approval_blocks( + action="Deploy to Production", + details="v1.2.0 for my-repo", + approval_url="https://dev.azure.com/approve/123", + ) + assert isinstance(blocks, list) + + def test_has_at_least_one_block(self) -> None: + blocks = _build_approval_blocks( + action="Deploy", + details="v1.0.0", + approval_url="https://example.com", + ) + assert len(blocks) >= 1 + + def test_action_present_in_blocks(self) -> None: + blocks = _build_approval_blocks( + action="Deploy to Production", + details="v1.2.0", + approval_url="https://example.com", + ) + text = json.dumps(blocks) + assert "Deploy to Production" in text + + def test_details_present_in_blocks(self) -> None: + blocks = _build_approval_blocks( + action="Deploy", + details="version v1.2.0 of my-repo", + approval_url="https://example.com", + ) + text = json.dumps(blocks) + assert "version v1.2.0 of my-repo" in text + + def test_approval_url_present_in_blocks(self) -> None: + blocks = _build_approval_blocks( + action="Deploy", + details="details", + approval_url="https://dev.azure.com/approve/abc", + ) + text = json.dumps(blocks) + assert "https://dev.azure.com/approve/abc" in text + + def test_blocks_are_dicts(self) -> None: + blocks = _build_approval_blocks( + action="Deploy", + details="details", + approval_url="https://example.com", + ) + assert all(isinstance(b, dict) for b in blocks) + + def test_each_block_has_type_key(self) -> None: + blocks = _build_approval_blocks( + action="Deploy", + details="details", + approval_url="https://example.com", + ) + for block in blocks: + assert "type" in block + + +# --------------------------------------------------------------------------- +# SlackClient.send_release_notification tests +# --------------------------------------------------------------------------- + +class TestSendReleaseNotification: + """Tests for SlackClient.send_release_notification.""" + + async def test_returns_true_on_success(self) -> None: + client = _make_client(status=200) + + result = await client.send_release_notification( + repo="my-repo", + version="v1.2.0", + release_date=date(2024, 1, 15), + tickets=[_make_ticket()], + ) + assert result is True + + async def test_returns_true_with_empty_tickets(self) -> None: + client = _make_client(status=200) + + result = await client.send_release_notification( + repo="my-repo", + version="v1.2.0", + release_date=date(2024, 1, 15), + tickets=[], + ) + assert result is True + + async def test_500_raises_service_error(self) -> None: + client = _make_client(status=500) + + with pytest.raises(ServiceError): + await client.send_release_notification( + repo="my-repo", + version="v1.2.0", + release_date=date(2024, 1, 15), + tickets=[], + ) + + async def test_sends_post_request(self) -> None: + requests_captured: list[httpx.Request] = [] + + def handler(request: httpx.Request) -> httpx.Response: + requests_captured.append(request) + return httpx.Response(200, content=b'{"ok": true}') + + transport = httpx.MockTransport(handler) + http_client = httpx.AsyncClient(transport=transport) + client = SlackClient( + webhook_url="https://hooks.slack.com/services/T000/B000/xxxx", + http_client=http_client, + ) + + await client.send_release_notification( + repo="my-repo", + version="v1.2.0", + release_date=date(2024, 1, 15), + tickets=[], + ) + + assert len(requests_captured) == 1 + assert requests_captured[0].method == "POST" + + +# --------------------------------------------------------------------------- +# SlackClient.send_approval_request tests +# --------------------------------------------------------------------------- + +class TestSendApprovalRequest: + """Tests for SlackClient.send_approval_request.""" + + async def test_returns_true_on_success(self) -> None: + client = _make_client(status=200) + + result = await client.send_approval_request( + action="Deploy to Production", + details="v1.2.0 for my-repo", + approval_url="https://dev.azure.com/approve/123", + ) + assert result is True + + async def test_500_raises_service_error(self) -> None: + client = _make_client(status=500) + + with pytest.raises(ServiceError): + await client.send_approval_request( + action="Deploy", + details="v1.0.0", + approval_url="https://example.com", + ) + + +# --------------------------------------------------------------------------- +# SlackClient lifecycle tests +# --------------------------------------------------------------------------- + +class TestSlackClientLifecycle: + """Tests for SlackClient close() and context manager.""" + + async def test_close_closes_http_client(self) -> None: + transport = _make_transport() + http_client = httpx.AsyncClient(transport=transport) + client = SlackClient( + webhook_url="https://hooks.slack.com/services/T000/B000/xxxx", + http_client=http_client, + ) + + await client.close() + + assert http_client.is_closed + + async def test_context_manager_closes_client(self) -> None: + transport = _make_transport() + http_client = httpx.AsyncClient(transport=transport) + async with SlackClient( + webhook_url="https://hooks.slack.com/services/T000/B000/xxxx", + http_client=http_client, + ) as client: + assert client is not None + assert http_client.is_closed + + +# --------------------------------------------------------------------------- +# SlackClient dual-mode construction tests +# --------------------------------------------------------------------------- + +class TestSlackClientDualMode: + """Tests for dual-mode SlackClient (webhook vs Web API).""" + + def test_can_be_created_with_webhook_only(self) -> None: + transport = _make_transport() + http_client = httpx.AsyncClient(transport=transport) + client = SlackClient( + webhook_url="https://hooks.slack.com/services/T000/B000/xxxx", + http_client=http_client, + ) + assert client is not None + + def test_can_be_created_with_bot_token_and_channel(self) -> None: + transport = _make_transport() + http_client = httpx.AsyncClient(transport=transport) + client = SlackClient( + bot_token="xoxb-test", + channel_id="C12345", + http_client=http_client, + ) + assert client is not None + + def test_can_be_created_with_all_params(self) -> None: + transport = _make_transport() + http_client = httpx.AsyncClient(transport=transport) + client = SlackClient( + webhook_url="https://hooks.slack.com/services/T000/B000/xxxx", + bot_token="xoxb-test", + channel_id="C12345", + http_client=http_client, + ) + assert client is not None + + def test_can_be_created_with_no_url_params(self) -> None: + transport = _make_transport() + http_client = httpx.AsyncClient(transport=transport) + client = SlackClient(http_client=http_client) + assert client is not None + + +# --------------------------------------------------------------------------- +# SlackClient.send_interactive_approval tests +# --------------------------------------------------------------------------- + +class TestSendInteractiveApproval: + """Tests for SlackClient.send_interactive_approval.""" + + async def test_returns_message_ts_on_success(self) -> None: + client = _make_web_api_client() + + result = await client.send_interactive_approval( + thread_id="thread-abc", + action="Deploy to Sandbox", + details="Release v1.0.0 of my-repo", + buttons=[{"text": "Approve", "value": "approve"}, {"text": "Reject", "value": "reject"}], + ) + assert isinstance(result, str) + assert result == "1234567890.123456" + + async def test_returns_empty_string_on_api_error(self) -> None: + client = _make_web_api_client(status=200, body=b'{"ok": false, "error": "channel_not_found"}') + + result = await client.send_interactive_approval( + thread_id="thread-abc", + action="Deploy", + details="v1.0.0", + buttons=[], + ) + assert result == "" + + async def test_posts_to_chat_postmessage(self) -> None: + captured_urls: list[str] = [] + captured_bodies: list[dict] = [] + + def handler(request: httpx.Request) -> httpx.Response: + captured_urls.append(str(request.url)) + captured_bodies.append(json.loads(request.content)) + return httpx.Response(200, content=b'{"ok": true, "ts": "111.222"}') + + transport = httpx.MockTransport(handler) + http_client = httpx.AsyncClient(transport=transport) + client = SlackClient( + bot_token="xoxb-test", + channel_id="C99999", + http_client=http_client, + ) + + await client.send_interactive_approval( + thread_id="thread-xyz", + action="Deploy", + details="details", + buttons=[], + ) + + assert any("chat.postMessage" in url for url in captured_urls) + assert captured_bodies[0]["channel"] == "C99999" + + async def test_includes_thread_id_in_blocks(self) -> None: + captured_bodies: list[dict] = [] + + def handler(request: httpx.Request) -> httpx.Response: + captured_bodies.append(json.loads(request.content)) + return httpx.Response(200, content=b'{"ok": true, "ts": "111.222"}') + + transport = httpx.MockTransport(handler) + http_client = httpx.AsyncClient(transport=transport) + client = SlackClient(bot_token="xoxb-test", channel_id="C1", http_client=http_client) + + await client.send_interactive_approval( + thread_id="my-thread-id", + action="Deploy", + details="v1.0.0", + buttons=[{"text": "Approve", "value": "approve"}], + ) + + body_str = json.dumps(captured_bodies) + assert "my-thread-id" in body_str + + +# --------------------------------------------------------------------------- +# SlackClient.update_message tests +# --------------------------------------------------------------------------- + +class TestUpdateMessage: + """Tests for SlackClient.update_message.""" + + async def test_returns_true_on_success(self) -> None: + client = _make_web_api_client(body=b'{"ok": true}') + + result = await client.update_message( + message_ts="1234567890.123456", + text="Updated message", + blocks=[], + ) + assert result is True + + async def test_returns_false_on_api_error(self) -> None: + client = _make_web_api_client(body=b'{"ok": false, "error": "message_not_found"}') + + result = await client.update_message( + message_ts="bad-ts", + text="Update", + blocks=[], + ) + assert result is False + + async def test_posts_to_chat_update(self) -> None: + captured_urls: list[str] = [] + + def handler(request: httpx.Request) -> httpx.Response: + captured_urls.append(str(request.url)) + return httpx.Response(200, content=b'{"ok": true}') + + transport = httpx.MockTransport(handler) + http_client = httpx.AsyncClient(transport=transport) + client = SlackClient(bot_token="xoxb-test", channel_id="C1", http_client=http_client) + + await client.update_message(message_ts="ts-abc", text="Hello", blocks=[]) + + assert any("chat.update" in url for url in captured_urls) + + +# --------------------------------------------------------------------------- +# SlackClient.send_notification tests +# --------------------------------------------------------------------------- + +class TestSendNotification: + """Tests for SlackClient.send_notification.""" + + async def test_returns_true_via_web_api(self) -> None: + client = _make_web_api_client(body=b'{"ok": true, "ts": "111.222"}') + + result = await client.send_notification(text="Build passed", blocks=[]) + assert result is True + + async def test_returns_true_via_webhook(self) -> None: + client = _make_client(status=200) + + result = await client.send_notification(text="Build passed", blocks=[]) + assert result is True + + async def test_returns_false_on_web_api_error(self) -> None: + client = _make_web_api_client(body=b'{"ok": false, "error": "invalid_auth"}') + + result = await client.send_notification(text="Build passed", blocks=[]) + assert result is False + + +# --------------------------------------------------------------------------- +# _build_interactive_approval_blocks tests (pure function) +# --------------------------------------------------------------------------- + +class TestBuildInteractiveApprovalBlocks: + """Tests for _build_interactive_approval_blocks pure function.""" + + def test_returns_list(self) -> None: + blocks = _build_interactive_approval_blocks( + thread_id="t1", + action="Deploy to Sandbox", + details="v1.0.0", + buttons=[{"text": "Approve", "value": "approve"}], + ) + assert isinstance(blocks, list) + + def test_has_at_least_one_block(self) -> None: + blocks = _build_interactive_approval_blocks( + thread_id="t1", + action="Deploy", + details="details", + buttons=[], + ) + assert len(blocks) >= 1 + + def test_action_present_in_blocks(self) -> None: + blocks = _build_interactive_approval_blocks( + thread_id="t1", + action="Deploy to Production", + details="v1.2.0", + buttons=[], + ) + text = json.dumps(blocks) + assert "Deploy to Production" in text + + def test_thread_id_in_button_value(self) -> None: + blocks = _build_interactive_approval_blocks( + thread_id="my-unique-thread", + action="Deploy", + details="details", + buttons=[{"text": "Approve", "value": "approve"}], + ) + text = json.dumps(blocks) + assert "my-unique-thread" in text + + def test_buttons_render_as_actions_block(self) -> None: + blocks = _build_interactive_approval_blocks( + thread_id="t1", + action="Deploy", + details="details", + buttons=[ + {"text": "Approve", "value": "approve"}, + {"text": "Reject", "value": "reject"}, + ], + ) + block_types = [b["type"] for b in blocks] + assert "actions" in block_types + + def test_empty_buttons_still_valid(self) -> None: + blocks = _build_interactive_approval_blocks( + thread_id="t1", + action="Deploy", + details="details", + buttons=[], + ) + assert isinstance(blocks, list) + + def test_details_present_in_blocks(self) -> None: + blocks = _build_interactive_approval_blocks( + thread_id="t1", + action="Deploy", + details="Release v2.0.0 of my-service", + buttons=[], + ) + text = json.dumps(blocks) + assert "Release v2.0.0 of my-service" in text + + +# --------------------------------------------------------------------------- +# _build_ci_status_blocks tests (pure function) +# --------------------------------------------------------------------------- + +class TestBuildCiStatusBlocks: + """Tests for _build_ci_status_blocks pure function.""" + + def test_returns_list(self) -> None: + blocks = _build_ci_status_blocks( + repo="my-repo", + branch="main", + status="succeeded", + build_url="https://dev.azure.com/org/proj/_build/results?buildId=42", + ) + assert isinstance(blocks, list) + + def test_repo_present(self) -> None: + blocks = _build_ci_status_blocks( + repo="my-service", + branch="main", + status="succeeded", + build_url=None, + ) + text = json.dumps(blocks) + assert "my-service" in text + + def test_branch_present(self) -> None: + blocks = _build_ci_status_blocks( + repo="my-repo", + branch="release/v1.0.0", + status="succeeded", + build_url=None, + ) + text = json.dumps(blocks) + assert "release/v1.0.0" in text + + def test_status_present(self) -> None: + blocks = _build_ci_status_blocks( + repo="my-repo", + branch="main", + status="failed", + build_url=None, + ) + text = json.dumps(blocks) + assert "failed" in text + + def test_build_url_present_when_provided(self) -> None: + url = "https://dev.azure.com/org/proj/_build/results?buildId=99" + blocks = _build_ci_status_blocks( + repo="my-repo", + branch="main", + status="succeeded", + build_url=url, + ) + text = json.dumps(blocks) + assert url in text + + def test_build_url_none_does_not_crash(self) -> None: + blocks = _build_ci_status_blocks( + repo="my-repo", + branch="main", + status="succeeded", + build_url=None, + ) + assert isinstance(blocks, list) + + def test_all_blocks_are_dicts(self) -> None: + blocks = _build_ci_status_blocks( + repo="my-repo", + branch="main", + status="succeeded", + build_url=None, + ) + assert all(isinstance(b, dict) for b in blocks) + + +# --------------------------------------------------------------------------- +# _build_resolved_approval_blocks tests (pure function) +# --------------------------------------------------------------------------- + +class TestBuildResolvedApprovalBlocks: + """Tests for _build_resolved_approval_blocks pure function.""" + + def test_returns_list(self) -> None: + blocks = _build_resolved_approval_blocks( + action="Deploy to Sandbox", + outcome="approved", + user="alice", + ) + assert isinstance(blocks, list) + + def test_action_present(self) -> None: + blocks = _build_resolved_approval_blocks( + action="Deploy to Production", + outcome="approved", + user="alice", + ) + text = json.dumps(blocks) + assert "Deploy to Production" in text + + def test_outcome_present(self) -> None: + blocks = _build_resolved_approval_blocks( + action="Deploy", + outcome="rejected", + user="bob", + ) + text = json.dumps(blocks) + assert "rejected" in text + + def test_user_present(self) -> None: + blocks = _build_resolved_approval_blocks( + action="Deploy", + outcome="approved", + user="charlie", + ) + text = json.dumps(blocks) + assert "charlie" in text + + def test_all_blocks_are_dicts(self) -> None: + blocks = _build_resolved_approval_blocks( + action="Deploy", + outcome="approved", + user="dave", + ) + assert all(isinstance(b, dict) for b in blocks)