feat: initial commit — Billo Release Agent (LangGraph)
LangGraph-based release automation agent with: - PR discovery (webhook + polling) - AI code review via Claude Code CLI (subscription-based) - Auto-create Jira tickets for PRs without ticket ID - Jira ticket lifecycle management (code review -> staging -> done) - CI/CD pipeline trigger, polling, and approval gates - Slack interactive messages with approval buttons - Per-repo semantic versioning - PostgreSQL persistence (threads, staging, releases) - FastAPI API (webhooks, approvals, status, manual triggers) - Docker Compose deployment 1061 tests, 96%+ coverage.
This commit is contained in:
55
.env.example
Normal file
55
.env.example
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
# Billo Release Agent - Environment Variables
|
||||||
|
# Copy this file to .env and fill in your values.
|
||||||
|
# Never commit .env to source control.
|
||||||
|
|
||||||
|
# ===========================================================================
|
||||||
|
# REQUIRED
|
||||||
|
# ===========================================================================
|
||||||
|
|
||||||
|
# --- Azure DevOps ----------------------------------------------------------
|
||||||
|
AZDO_ORGANIZATION=billodev
|
||||||
|
AZDO_PROJECT=Billo App Platform
|
||||||
|
AZDO_PAT= # Personal access token (Code read/write, Build read/execute)
|
||||||
|
|
||||||
|
# --- PostgreSQL ------------------------------------------------------------
|
||||||
|
POSTGRES_PASSWORD= # Used by docker-compose db service
|
||||||
|
POSTGRES_DSN=postgresql://agent:${POSTGRES_PASSWORD}@localhost:5432/agent
|
||||||
|
|
||||||
|
# --- Jira ------------------------------------------------------------------
|
||||||
|
JIRA_EMAIL= # Jira account email
|
||||||
|
JIRA_API_TOKEN= # Jira API token
|
||||||
|
|
||||||
|
# --- Slack -----------------------------------------------------------------
|
||||||
|
SLACK_WEBHOOK_URL= # Incoming webhook URL (for release notifications)
|
||||||
|
|
||||||
|
# --- Webhook Security ------------------------------------------------------
|
||||||
|
WEBHOOK_SECRET= # HMAC secret for Azure DevOps webhook validation
|
||||||
|
|
||||||
|
# ===========================================================================
|
||||||
|
# OPTIONAL (have defaults)
|
||||||
|
# ===========================================================================
|
||||||
|
|
||||||
|
# --- Slack App (for interactive buttons) -----------------------------------
|
||||||
|
SLACK_BOT_TOKEN= # xoxb-... Bot token (needed for Slack buttons)
|
||||||
|
SLACK_SIGNING_SECRET= # Slack app signing secret (needed for /slack/interactions)
|
||||||
|
SLACK_CHANNEL_ID= # Channel ID to post messages (e.g., C0123456789)
|
||||||
|
|
||||||
|
# --- CI/CD Polling ---------------------------------------------------------
|
||||||
|
CI_POLL_INTERVAL_SECONDS=30 # Seconds between CI status polls
|
||||||
|
CI_POLL_MAX_WAIT_SECONDS=1800 # Max wait for CI completion (30 min)
|
||||||
|
|
||||||
|
# --- Claude Code Review ----------------------------------------------------
|
||||||
|
# ANTHROPIC_API_KEY= # Not needed — uses Claude Code CLI subscription
|
||||||
|
# CLAUDE_REVIEW_MODEL=claude-sonnet-4-20250514
|
||||||
|
|
||||||
|
# --- Local Repo Path -------------------------------------------------------
|
||||||
|
REPOS_BASE_DIR= # Base dir containing Billo repos (e.g., /c/Users/yaoji/git/Billo)
|
||||||
|
|
||||||
|
# --- Jira ------------------------------------------------------------------
|
||||||
|
JIRA_BASE_URL=https://billolife.atlassian.net
|
||||||
|
|
||||||
|
# --- Operator API ----------------------------------------------------------
|
||||||
|
OPERATOR_TOKEN= # When set, protects POST /approvals/* and POST /manual/*
|
||||||
|
|
||||||
|
# --- Server ----------------------------------------------------------------
|
||||||
|
PORT=8000
|
||||||
36
.gitignore
vendored
Normal file
36
.gitignore
vendored
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
# Environment
|
||||||
|
.env
|
||||||
|
.env.local
|
||||||
|
|
||||||
|
# Python
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*.egg-info/
|
||||||
|
dist/
|
||||||
|
build/
|
||||||
|
.eggs/
|
||||||
|
|
||||||
|
# Virtual environment
|
||||||
|
.venv/
|
||||||
|
venv/
|
||||||
|
|
||||||
|
# uv
|
||||||
|
uv.lock
|
||||||
|
|
||||||
|
# IDE
|
||||||
|
.idea/
|
||||||
|
.vscode/
|
||||||
|
*.swp
|
||||||
|
*.swo
|
||||||
|
|
||||||
|
# Testing
|
||||||
|
.coverage
|
||||||
|
htmlcov/
|
||||||
|
.pytest_cache/
|
||||||
|
|
||||||
|
# Data
|
||||||
|
data/staging/
|
||||||
|
|
||||||
|
# OS
|
||||||
|
.DS_Store
|
||||||
|
Thumbs.db
|
||||||
32
Dockerfile
Normal file
32
Dockerfile
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
FROM python:3.12-slim
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
# Install system dependencies required for psycopg binary
|
||||||
|
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||||
|
libpq5 \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
# Copy dependency files first for layer caching
|
||||||
|
COPY pyproject.toml uv.lock ./
|
||||||
|
|
||||||
|
# Install uv for fast dependency management
|
||||||
|
RUN pip install --no-cache-dir uv
|
||||||
|
|
||||||
|
# Install runtime dependencies (no dev extras)
|
||||||
|
RUN uv pip install --system --no-cache-dir -e .
|
||||||
|
|
||||||
|
# Copy application source
|
||||||
|
COPY src/ ./src/
|
||||||
|
|
||||||
|
# Create data directory for staging store
|
||||||
|
RUN mkdir -p data/staging
|
||||||
|
|
||||||
|
# Non-root user for security
|
||||||
|
RUN adduser --disabled-password --gecos "" appuser && \
|
||||||
|
chown -R appuser:appuser /app
|
||||||
|
USER appuser
|
||||||
|
|
||||||
|
EXPOSE 8000
|
||||||
|
|
||||||
|
CMD ["uvicorn", "release_agent.main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||||
341
README.md
Normal file
341
README.md
Normal file
@@ -0,0 +1,341 @@
|
|||||||
|
# Billo Release Agent
|
||||||
|
|
||||||
|
A LangGraph-based release automation agent for Billo. Automates the full release
|
||||||
|
pipeline: PR discovery, code review (via Claude Code CLI), Jira ticket management,
|
||||||
|
staging release tracking, CI/CD pipeline triggering with approval gates, and Slack
|
||||||
|
interactive notifications.
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
```
|
||||||
|
+--- Azure DevOps Webhook ---+ +--- PR Poller (every 5 min) ---+
|
||||||
|
| POST /webhooks/azdo | | Scans WATCHED_REPOS for |
|
||||||
|
| (push-based) | | active PRs (pull-based) |
|
||||||
|
+------------+---------------+ +------------+------------------+
|
||||||
|
| |
|
||||||
|
v v
|
||||||
|
+---------------------------------------------+
|
||||||
|
| FastAPI Application |
|
||||||
|
| /webhooks/azdo /slack/interactions |
|
||||||
|
| /approvals/* /manual/* /status |
|
||||||
|
+---------------------+------------------------+
|
||||||
|
|
|
||||||
|
v
|
||||||
|
+---------------------------------------------+
|
||||||
|
| LangGraph Graphs |
|
||||||
|
| |
|
||||||
|
| pr_completed: |
|
||||||
|
| parse -> fetch -> [has ticket?] |
|
||||||
|
| no -> Claude generates ticket |
|
||||||
|
| yes -> Claude code review |
|
||||||
|
| -> merge -> Jira -> staging -> CI build |
|
||||||
|
| |
|
||||||
|
| release: |
|
||||||
|
| create release PR -> merge -> CI build |
|
||||||
|
| -> CD release -> [Sandbox approve] |
|
||||||
|
| -> [Production approve] -> Slack notify |
|
||||||
|
+---------------------+------------------------+
|
||||||
|
|
|
||||||
|
+----------------+----------------+
|
||||||
|
| | |
|
||||||
|
v v v
|
||||||
|
PostgreSQL Azure DevOps Slack (buttons)
|
||||||
|
- threads - PRs/Pipelines - Notifications
|
||||||
|
- staging - Builds/Releases - Approvals
|
||||||
|
- releases
|
||||||
|
```
|
||||||
|
|
||||||
|
## Key Features
|
||||||
|
|
||||||
|
| Feature | Description |
|
||||||
|
|---------|-------------|
|
||||||
|
| **PR Discovery** | Webhook-based (push) or polling-based (pull) — or both |
|
||||||
|
| **Auto-Create Jira Ticket** | When PR branch has no ticket ID, Claude generates summary + description and creates a Jira Story |
|
||||||
|
| **AI Code Review** | Claude Code CLI reviews PRs with full repo context (Read/Glob/Grep), using your subscription |
|
||||||
|
| **CI/CD Integration** | Triggers CI builds after merge, polls for completion, handles CD release approval gates |
|
||||||
|
| **Slack Interactive** | Approval requests with [Approve]/[Cancel] buttons, CI/CD status notifications |
|
||||||
|
| **Human-in-the-loop** | 5 interrupt points where operator confirmation is required before destructive actions |
|
||||||
|
| **Per-repo Versioning** | Independent semantic versioning per repository (patch auto-increment) |
|
||||||
|
|
||||||
|
## Prerequisites
|
||||||
|
|
||||||
|
- Python 3.12+
|
||||||
|
- PostgreSQL 16+
|
||||||
|
- [uv](https://github.com/astral-sh/uv) (recommended) or pip
|
||||||
|
- Claude Code CLI installed and authenticated (`claude` in PATH)
|
||||||
|
- Slack App (for interactive buttons) or Slack Incoming Webhook (for notifications only)
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
### Local Development
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Install dependencies
|
||||||
|
uv sync --all-extras
|
||||||
|
|
||||||
|
# Copy and configure environment
|
||||||
|
cp .env.example .env
|
||||||
|
# Edit .env -- fill in all REQUIRED variables
|
||||||
|
|
||||||
|
# Start PostgreSQL
|
||||||
|
docker compose up -d db
|
||||||
|
|
||||||
|
# Run the server
|
||||||
|
uv run uvicorn release_agent.main:app --reload --port 8000
|
||||||
|
|
||||||
|
# Verify
|
||||||
|
curl http://localhost:8000/status
|
||||||
|
```
|
||||||
|
|
||||||
|
### Docker Compose (Production)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cp .env.example .env
|
||||||
|
# Edit .env -- POSTGRES_PASSWORD, WEBHOOK_SECRET, etc. are required
|
||||||
|
|
||||||
|
docker compose up -d
|
||||||
|
```
|
||||||
|
|
||||||
|
Tables are created automatically on first startup.
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
All configuration is via environment variables. See `.env.example` for the full list.
|
||||||
|
|
||||||
|
### Required Variables
|
||||||
|
|
||||||
|
| Variable | Description |
|
||||||
|
|----------|-------------|
|
||||||
|
| `AZDO_ORGANIZATION` | Azure DevOps organization name |
|
||||||
|
| `AZDO_PROJECT` | Azure DevOps project name |
|
||||||
|
| `AZDO_PAT` | Azure DevOps personal access token |
|
||||||
|
| `POSTGRES_DSN` | PostgreSQL connection string |
|
||||||
|
| `POSTGRES_PASSWORD` | PostgreSQL password (used by docker-compose) |
|
||||||
|
| `JIRA_EMAIL` | Jira account email |
|
||||||
|
| `JIRA_API_TOKEN` | Jira API token |
|
||||||
|
| `SLACK_WEBHOOK_URL` | Slack incoming webhook URL |
|
||||||
|
| `WEBHOOK_SECRET` | Shared secret for validating AzDo webhooks (must be non-empty) |
|
||||||
|
|
||||||
|
### Optional Variables
|
||||||
|
|
||||||
|
| Variable | Default | Description |
|
||||||
|
|----------|---------|-------------|
|
||||||
|
| `REPOS_BASE_DIR` | `""` | Base dir with Billo repos (e.g., `/c/Users/yaoji/git/Billo`) |
|
||||||
|
| `WATCHED_REPOS` | `""` | Comma-separated repos to poll (e.g., `Billo.Platform.Payment,Billo.Platform.Document.DocumentAnalyser`) |
|
||||||
|
| `PR_POLL_ENABLED` | `False` | Enable periodic PR polling |
|
||||||
|
| `PR_POLL_INTERVAL_SECONDS` | `300` | Polling interval (5 min) |
|
||||||
|
| `PR_POLL_TARGET_BRANCH` | `refs/heads/develop` | Target branch filter |
|
||||||
|
| `DEFAULT_JIRA_PROJECT` | `ALLPOST` | Jira project key for auto-created tickets |
|
||||||
|
| `AUTO_CREATE_TICKET_ENABLED` | `True` | Auto-create Jira ticket when branch has no ticket ID |
|
||||||
|
| `SLACK_BOT_TOKEN` | `""` | Slack App bot token (for interactive buttons) |
|
||||||
|
| `SLACK_SIGNING_SECRET` | `""` | Slack App signing secret (required for /slack/interactions) |
|
||||||
|
| `SLACK_CHANNEL_ID` | `""` | Channel for interactive messages |
|
||||||
|
| `CI_POLL_INTERVAL_SECONDS` | `30` | CI build status poll interval |
|
||||||
|
| `CI_POLL_MAX_WAIT_SECONDS` | `1800` | Max wait for CI completion (30 min) |
|
||||||
|
| `OPERATOR_TOKEN` | `""` | Token for operator endpoints (empty = no auth) |
|
||||||
|
| `JIRA_BASE_URL` | `https://billolife.atlassian.net` | Jira instance URL |
|
||||||
|
| `PORT` | `8000` | HTTP server port |
|
||||||
|
|
||||||
|
### Security Notes
|
||||||
|
|
||||||
|
- `WEBHOOK_SECRET` must be non-empty; empty secret rejects all webhooks
|
||||||
|
- `POSTGRES_PASSWORD` has no default in docker-compose; fails if unset
|
||||||
|
- `SLACK_SIGNING_SECRET` must be set for `/slack/interactions` to accept requests (returns 503 if empty)
|
||||||
|
- Slack signature verification includes 5-minute replay attack prevention
|
||||||
|
- All secrets use `SecretStr` and are never logged or included in error responses
|
||||||
|
- Set `OPERATOR_TOKEN` in production to protect approval and manual trigger endpoints
|
||||||
|
|
||||||
|
## API Endpoints
|
||||||
|
|
||||||
|
### Webhooks
|
||||||
|
|
||||||
|
| Method | Path | Auth | Description |
|
||||||
|
|--------|------|------|-------------|
|
||||||
|
| POST | `/webhooks/azdo` | `X-Webhook-Secret` | Receive Azure DevOps PR webhook events |
|
||||||
|
| POST | `/slack/interactions` | Slack Signing Secret | Receive Slack button click callbacks |
|
||||||
|
|
||||||
|
### Approvals (requires `X-Operator-Token` when configured)
|
||||||
|
|
||||||
|
| Method | Path | Description |
|
||||||
|
|--------|------|-------------|
|
||||||
|
| GET | `/approvals/pending` | List threads awaiting operator approval |
|
||||||
|
| POST | `/approvals/{thread_id}` | Submit approval decision (merge/cancel/approve/skip) |
|
||||||
|
|
||||||
|
### Status and Manual Triggers
|
||||||
|
|
||||||
|
| Method | Path | Auth | Description |
|
||||||
|
|--------|------|------|-------------|
|
||||||
|
| GET | `/status` | None | Health check |
|
||||||
|
| GET | `/releases/{repo}` | None | List versions for a repo |
|
||||||
|
| GET | `/staging?repo={repo}` | None | Current staging release |
|
||||||
|
| POST | `/manual/pr/{pr_id}` | `X-Operator-Token` | Manually trigger PR processing |
|
||||||
|
| POST | `/manual/release` | `X-Operator-Token` | Manually trigger a release |
|
||||||
|
|
||||||
|
## Graph Workflows
|
||||||
|
|
||||||
|
### PR Completed
|
||||||
|
|
||||||
|
```
|
||||||
|
parse_webhook -> fetch_pr_details -> route_after_fetch
|
||||||
|
|-- merged -----------------> calculate_version -> update_staging -> CI build -> END
|
||||||
|
|-- active_with_ticket -----> move_jira_code_review -+
|
||||||
|
|-- active_no_ticket -------> auto_create_ticket ----+
|
||||||
|
|
|
||||||
|
run_code_review -> evaluate_review
|
||||||
|
|-- approve -> [Slack: Merge?] -> merge_pr
|
||||||
|
|-- request_changes -> notify -> END
|
||||||
|
-> Jira transitions -> calculate_version
|
||||||
|
-> update_staging -> CI build -> notify -> END
|
||||||
|
```
|
||||||
|
|
||||||
|
### Release
|
||||||
|
|
||||||
|
```
|
||||||
|
load_staging -> [Slack: Create release?] -> create_release_pr
|
||||||
|
-> [Slack: Merge release?] -> merge_release_pr
|
||||||
|
-> CI build on main -> poll until complete
|
||||||
|
|-- ci_failed -> notify failure -> END
|
||||||
|
|-- ci_passed -> wait for CD release -> approval loop:
|
||||||
|
|-- [Slack: Approve Sandbox?] -> approve -> poll again
|
||||||
|
|-- [Slack: Approve Production?] -> approve -> poll again
|
||||||
|
|-- all_deployed -> move tickets to Done
|
||||||
|
-> Slack release notification -> archive -> END
|
||||||
|
```
|
||||||
|
|
||||||
|
### Interrupt Points (Slack buttons)
|
||||||
|
|
||||||
|
| # | When | Slack Message | Buttons |
|
||||||
|
|---|------|--------------|---------|
|
||||||
|
| 1 | After code review approves | PR title + review summary | [Merge] [Cancel] |
|
||||||
|
| 2 | Before creating release PR | Version + ticket list | [Create] [Cancel] |
|
||||||
|
| 3 | Before merging release PR | Release PR link | [Merge] [Cancel] |
|
||||||
|
| 4 | Before triggering pipelines | Pipeline list | [Trigger] [Skip] |
|
||||||
|
| 5 | Before approving release stage | Stage name + status | [Approve] [Skip] |
|
||||||
|
|
||||||
|
## PR Polling (Alternative to Webhooks)
|
||||||
|
|
||||||
|
When `PR_POLL_ENABLED=True`, the agent periodically scans all `WATCHED_REPOS` for
|
||||||
|
active PRs targeting the configured branch. New PRs not yet tracked in `agent_threads`
|
||||||
|
are automatically processed through the `pr_completed` graph.
|
||||||
|
|
||||||
|
This eliminates the need for Azure DevOps webhook configuration and works behind
|
||||||
|
firewalls without public endpoint exposure.
|
||||||
|
|
||||||
|
## Auto-Create Jira Ticket
|
||||||
|
|
||||||
|
When a PR branch has no ticket ID (e.g., `chore/update-dependencies` instead of
|
||||||
|
`feature/ALLPOST-4028_login-page`), the agent automatically:
|
||||||
|
|
||||||
|
1. Sends the PR diff to Claude Code CLI
|
||||||
|
2. Claude generates a concise ticket summary and description
|
||||||
|
3. Creates a Jira Story in the `DEFAULT_JIRA_PROJECT`
|
||||||
|
4. Continues the normal workflow with the created ticket
|
||||||
|
|
||||||
|
## Database Schema
|
||||||
|
|
||||||
|
Tables are created automatically on startup:
|
||||||
|
|
||||||
|
```sql
|
||||||
|
-- Thread tracking for LangGraph interrupts and PR dedup
|
||||||
|
agent_threads (thread_id, graph_name, repo_name, pr_id, status, state JSONB,
|
||||||
|
slack_message_ts, created_at, updated_at)
|
||||||
|
|
||||||
|
-- Current in-progress releases (one per repo)
|
||||||
|
staging_releases (repo, version, started_at, tickets JSONB, updated_at)
|
||||||
|
|
||||||
|
-- Completed releases (immutable history)
|
||||||
|
archived_releases (repo, version, started_at, tickets JSONB, released_at)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Migrating Existing JSON Files
|
||||||
|
|
||||||
|
If you have existing release data from the Claude Code skill:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Dry run
|
||||||
|
uv run python scripts/migrate_json_to_db.py \
|
||||||
|
--source ../release-workflow/releases --dry-run
|
||||||
|
|
||||||
|
# Execute
|
||||||
|
uv run python scripts/migrate_json_to_db.py \
|
||||||
|
--source ../release-workflow/releases \
|
||||||
|
--dsn "postgresql://agent:password@localhost/agent"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Development
|
||||||
|
|
||||||
|
### Running Tests
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Run all tests with coverage (1061 tests, 96%+ coverage)
|
||||||
|
uv run pytest
|
||||||
|
|
||||||
|
# Run without coverage (faster)
|
||||||
|
uv run pytest --no-cov
|
||||||
|
|
||||||
|
# Run specific module
|
||||||
|
uv run pytest tests/graph/test_pr_completed.py -v
|
||||||
|
```
|
||||||
|
|
||||||
|
### Project Structure
|
||||||
|
|
||||||
|
```
|
||||||
|
src/release_agent/
|
||||||
|
main.py # FastAPI app, lifespan, task management
|
||||||
|
config.py # pydantic-settings (all env vars)
|
||||||
|
state.py # LangGraph ReleaseState TypedDict
|
||||||
|
exceptions.py # Exception hierarchy
|
||||||
|
branch_parser.py # Extract ticket ID from branch name
|
||||||
|
versioning.py # Per-repo version calculation
|
||||||
|
api/
|
||||||
|
models.py # HTTP request/response Pydantic models
|
||||||
|
dependencies.py # FastAPI Depends() + operator auth
|
||||||
|
webhooks.py # POST /webhooks/azdo
|
||||||
|
approvals.py # Approval endpoints
|
||||||
|
status.py # Status, releases, manual triggers
|
||||||
|
slack_interactions.py # POST /slack/interactions (button callbacks)
|
||||||
|
graph/
|
||||||
|
dependencies.py # ToolClients, StagingStore Protocol
|
||||||
|
postgres_staging_store.py # PostgreSQL-backed store
|
||||||
|
routing.py # Pure routing functions (route_after_fetch, etc.)
|
||||||
|
pr_completed.py # PR graph nodes + auto_create_ticket
|
||||||
|
release.py # Release graph nodes + CI/CD approval loop
|
||||||
|
full_cycle.py # Subgraph composition
|
||||||
|
ci_nodes.py # CI trigger, poll, notify nodes
|
||||||
|
polling.py # Reusable async poll_until utility
|
||||||
|
models/
|
||||||
|
pr.py, ticket.py, release.py, pipeline.py
|
||||||
|
webhook.py, review.py, jira.py, build.py
|
||||||
|
tools/
|
||||||
|
azdo.py # Azure DevOps REST client
|
||||||
|
jira.py # Jira REST client (transitions + create_issue)
|
||||||
|
slack.py # Slack dual-mode (webhook + Web API)
|
||||||
|
claude_review.py # Claude Code CLI (review + ticket generation)
|
||||||
|
_http.py, _retry.py # Shared helpers
|
||||||
|
services/
|
||||||
|
pr_poller.py # Background PR polling loop
|
||||||
|
pr_dedup.py # PR deduplication via agent_threads
|
||||||
|
scripts/
|
||||||
|
migrate_json_to_db.py # One-time JSON -> PostgreSQL migration
|
||||||
|
tests/ # 1061 tests, 96%+ coverage
|
||||||
|
```
|
||||||
|
|
||||||
|
## Docker
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Required: set POSTGRES_PASSWORD and WEBHOOK_SECRET in .env
|
||||||
|
docker compose up -d
|
||||||
|
```
|
||||||
|
|
||||||
|
The agent service includes a health check at `/status`. PostgreSQL uses
|
||||||
|
`pg_isready` with `service_healthy` dependency.
|
||||||
|
|
||||||
|
## Slack App Setup
|
||||||
|
|
||||||
|
To use interactive buttons (optional — REST API approvals still work without it):
|
||||||
|
|
||||||
|
1. Create a Slack App at https://api.slack.com/apps
|
||||||
|
2. Enable **Interactivity** with Request URL: `https://<your-domain>/slack/interactions`
|
||||||
|
3. Add Bot Token Scopes: `chat:write`, `chat:update`
|
||||||
|
4. Install to workspace, get Bot Token (`xoxb-...`)
|
||||||
|
5. Set `SLACK_BOT_TOKEN`, `SLACK_SIGNING_SECRET`, `SLACK_CHANNEL_ID` in `.env`
|
||||||
48
docker-compose.yml
Normal file
48
docker-compose.yml
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
services:
|
||||||
|
agent:
|
||||||
|
build: .
|
||||||
|
ports:
|
||||||
|
- "8000:8000"
|
||||||
|
environment:
|
||||||
|
AZDO_ORGANIZATION: ${AZDO_ORGANIZATION}
|
||||||
|
AZDO_PROJECT: ${AZDO_PROJECT}
|
||||||
|
AZDO_PAT: ${AZDO_PAT}
|
||||||
|
ANTHROPIC_API_KEY: ${ANTHROPIC_API_KEY:-unused}
|
||||||
|
POSTGRES_DSN: postgresql://agent:${POSTGRES_PASSWORD:?POSTGRES_PASSWORD must be set}@db:5432/agent
|
||||||
|
JIRA_EMAIL: ${JIRA_EMAIL}
|
||||||
|
JIRA_API_TOKEN: ${JIRA_API_TOKEN}
|
||||||
|
SLACK_WEBHOOK_URL: ${SLACK_WEBHOOK_URL}
|
||||||
|
WEBHOOK_SECRET: ${WEBHOOK_SECRET:?WEBHOOK_SECRET must be set}
|
||||||
|
JIRA_BASE_URL: ${JIRA_BASE_URL:-https://billolife.atlassian.net}
|
||||||
|
CLAUDE_REVIEW_MODEL: ${CLAUDE_REVIEW_MODEL:-claude-sonnet-4-20250514}
|
||||||
|
REPOS_BASE_DIR: ${REPOS_BASE_DIR:-}
|
||||||
|
OPERATOR_TOKEN: ${OPERATOR_TOKEN:-}
|
||||||
|
PORT: 8000
|
||||||
|
depends_on:
|
||||||
|
db:
|
||||||
|
condition: service_healthy
|
||||||
|
restart: unless-stopped
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD-SHELL", "curl -sf http://localhost:8000/status || exit 1"]
|
||||||
|
interval: 10s
|
||||||
|
timeout: 5s
|
||||||
|
retries: 3
|
||||||
|
start_period: 15s
|
||||||
|
|
||||||
|
db:
|
||||||
|
image: postgres:16-alpine
|
||||||
|
environment:
|
||||||
|
POSTGRES_USER: agent
|
||||||
|
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:?POSTGRES_PASSWORD must be set}
|
||||||
|
POSTGRES_DB: agent
|
||||||
|
volumes:
|
||||||
|
- postgres_data:/var/lib/postgresql/data
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD-SHELL", "pg_isready -U agent -d agent"]
|
||||||
|
interval: 5s
|
||||||
|
timeout: 5s
|
||||||
|
retries: 5
|
||||||
|
restart: unless-stopped
|
||||||
|
|
||||||
|
volumes:
|
||||||
|
postgres_data:
|
||||||
58
pyproject.toml
Normal file
58
pyproject.toml
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
[build-system]
|
||||||
|
requires = ["hatchling"]
|
||||||
|
build-backend = "hatchling.build"
|
||||||
|
|
||||||
|
[project]
|
||||||
|
name = "billo-release-agent"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "LangGraph-based release automation agent for Billo"
|
||||||
|
requires-python = ">=3.12"
|
||||||
|
dependencies = [
|
||||||
|
"langgraph>=0.3.30",
|
||||||
|
"langgraph-checkpoint-postgres>=2.0.15",
|
||||||
|
"fastapi>=0.115.0",
|
||||||
|
"uvicorn[standard]>=0.34.0",
|
||||||
|
"httpx>=0.28.0",
|
||||||
|
"anthropic>=0.52.0",
|
||||||
|
"pydantic>=2.11.0",
|
||||||
|
"pydantic-settings>=2.8.0",
|
||||||
|
"psycopg[binary]>=3.2.0",
|
||||||
|
"psycopg-pool>=3.2.0",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.optional-dependencies]
|
||||||
|
dev = [
|
||||||
|
"pytest>=8.3.0",
|
||||||
|
"pytest-asyncio>=0.25.0",
|
||||||
|
"pytest-cov>=6.1.0",
|
||||||
|
"ruff>=0.11.0",
|
||||||
|
"psycopg[binary]>=3.2.0",
|
||||||
|
"psycopg-pool>=3.2.0",
|
||||||
|
]
|
||||||
|
|
||||||
|
[tool.hatch.build.targets.wheel]
|
||||||
|
packages = ["src/release_agent"]
|
||||||
|
|
||||||
|
[tool.pytest.ini_options]
|
||||||
|
testpaths = ["tests"]
|
||||||
|
asyncio_mode = "auto"
|
||||||
|
pythonpath = ["."]
|
||||||
|
addopts = "--cov=src/release_agent --cov-report=term-missing --cov-fail-under=80"
|
||||||
|
|
||||||
|
[tool.coverage.run]
|
||||||
|
source = ["src/release_agent"]
|
||||||
|
|
||||||
|
[tool.coverage.report]
|
||||||
|
exclude_lines = [
|
||||||
|
"pragma: no cover",
|
||||||
|
"if TYPE_CHECKING:",
|
||||||
|
"raise NotImplementedError",
|
||||||
|
]
|
||||||
|
|
||||||
|
[tool.ruff]
|
||||||
|
line-length = 100
|
||||||
|
target-version = "py312"
|
||||||
|
|
||||||
|
[tool.ruff.lint]
|
||||||
|
select = ["E", "F", "I", "N", "UP", "B", "SIM"]
|
||||||
|
ignore = ["E501"]
|
||||||
0
scripts/__init__.py
Normal file
0
scripts/__init__.py
Normal file
284
scripts/migrate_json_to_db.py
Normal file
284
scripts/migrate_json_to_db.py
Normal file
@@ -0,0 +1,284 @@
|
|||||||
|
"""Migration script: JSON files -> PostgreSQL.
|
||||||
|
|
||||||
|
Reads staging and archived release JSON files from a directory tree and
|
||||||
|
inserts them into the staging_releases and archived_releases tables.
|
||||||
|
|
||||||
|
All business logic is implemented as pure functions so it can be tested
|
||||||
|
without a real database. The main() entry point wires together the pure
|
||||||
|
functions with actual I/O.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python scripts/migrate_json_to_db.py --source /path/to/releases \\
|
||||||
|
--dsn "postgresql://user:pass@localhost/db" [--dry-run]
|
||||||
|
|
||||||
|
Pure functions (testable without DB):
|
||||||
|
collect_json_files(directory) -> list[Path]
|
||||||
|
is_staging_filename(name) -> bool
|
||||||
|
is_archived_filename(name) -> bool
|
||||||
|
parse_staging_json(data) -> MigrationRecord
|
||||||
|
parse_archived_json(data) -> MigrationRecord
|
||||||
|
build_staging_insert_sql(record) -> tuple[str, tuple]
|
||||||
|
build_archived_insert_sql(record) -> tuple[str, tuple]
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import date
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Data model
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class MigrationRecord:
    """Parsed record ready for database insertion.

    released_at is None for staging records, set for archived records.
    """

    repo: str  # repository name, e.g. "Billo.Platform.Payment"
    version: str  # semantic version string, e.g. "v1.0.1"
    started_at: date  # date the release cycle started (parsed from JSON "started_at")
    tickets: list[dict]  # raw ticket dicts carried over from the source JSON
    released_at: date | None = None  # release date for archived records; None for staging
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# File classification
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
# Archived filenames match: <repo>_<version>_<date>.json
# e.g. Billo.Platform.Payment_v1.0.1_2026-03-23.json
_ARCHIVED_PATTERN = re.compile(r"^.+_v\d+\.\d+\.\d+_\d{4}-\d{2}-\d{2}\.json$")


def is_staging_filename(name: str) -> bool:
    """Return True if the filename looks like a staging JSON file.

    Staging files end in .json and do not match the archived pattern.
    """
    # A staging file is any .json file that is NOT an archived release file.
    return name.endswith(".json") and _ARCHIVED_PATTERN.match(name) is None


def is_archived_filename(name: str) -> bool:
    """Return True if the filename looks like an archived release JSON file."""
    return _ARCHIVED_PATTERN.match(name) is not None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# File collection
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def collect_json_files(directory: Path) -> list[Path]:
    """Recursively collect all .json files under directory.

    Returns a sorted list of Path objects.
    """
    matches = list(directory.rglob("*.json"))
    matches.sort()
    return matches
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Parsing pure functions
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def parse_staging_json(data: dict) -> MigrationRecord:
    """Parse a staging release JSON dict into a MigrationRecord.

    Args:
        data: Parsed JSON dict with keys: version, repo, started_at, tickets.

    Returns:
        MigrationRecord with released_at=None.
    """
    # "tickets" may be absent or null in the source file; normalize to [].
    raw_tickets = data.get("tickets") or []
    return MigrationRecord(
        repo=data["repo"],
        version=data["version"],
        started_at=date.fromisoformat(data["started_at"]),
        tickets=list(raw_tickets),
        released_at=None,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def parse_archived_json(data: dict) -> MigrationRecord:
    """Parse an archived release JSON dict into a MigrationRecord.

    Args:
        data: Parsed JSON dict with keys: version, repo, started_at, tickets,
            released_at.

    Returns:
        MigrationRecord with released_at set.
    """
    started = date.fromisoformat(data["started_at"])
    released = date.fromisoformat(data["released_at"])
    # "tickets" may be absent or null in the source file; normalize to [].
    return MigrationRecord(
        repo=data["repo"],
        version=data["version"],
        started_at=started,
        tickets=list(data.get("tickets") or []),
        released_at=released,
    )
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# SQL builder pure functions
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
_STAGING_INSERT_SQL = """
|
||||||
|
INSERT INTO staging_releases (repo, version, started_at, tickets)
|
||||||
|
VALUES (%s, %s, %s, %s)
|
||||||
|
ON CONFLICT (repo) DO NOTHING
|
||||||
|
""".strip()
|
||||||
|
|
||||||
|
_ARCHIVED_INSERT_SQL = """
|
||||||
|
INSERT INTO archived_releases (repo, version, started_at, tickets, released_at)
|
||||||
|
VALUES (%s, %s, %s, %s, %s)
|
||||||
|
ON CONFLICT (repo, version) DO NOTHING
|
||||||
|
""".strip()
|
||||||
|
|
||||||
|
|
||||||
|
def build_staging_insert_sql(record: MigrationRecord) -> tuple[str, tuple]:
|
||||||
|
"""Build the INSERT SQL and parameters for a staging release record.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(sql_string, params_tuple) ready for cursor.execute().
|
||||||
|
"""
|
||||||
|
tickets_json = json.dumps(record.tickets)
|
||||||
|
params = (
|
||||||
|
record.repo,
|
||||||
|
record.version,
|
||||||
|
record.started_at.isoformat(),
|
||||||
|
tickets_json,
|
||||||
|
)
|
||||||
|
return _STAGING_INSERT_SQL, params
|
||||||
|
|
||||||
|
|
||||||
|
def build_archived_insert_sql(record: MigrationRecord) -> tuple[str, tuple]:
    """Build the INSERT SQL and parameters for an archived release record.

    Returns:
        (sql_string, params_tuple) ready for cursor.execute().
    """
    serialized_tickets = json.dumps(record.tickets)
    # released_at may be absent; NULL is inserted in that case.
    released = record.released_at.isoformat() if record.released_at else None
    return _ARCHIVED_INSERT_SQL, (
        record.repo,
        record.version,
        record.started_at.isoformat(),
        serialized_tickets,
        released,
    )
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Main entry point
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Migrate JSON release files to PostgreSQL"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--source",
|
||||||
|
type=Path,
|
||||||
|
required=True,
|
||||||
|
help="Root directory containing release JSON files",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--dsn",
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
help="PostgreSQL DSN (e.g. postgresql://user:pass@localhost/db)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--dry-run",
|
||||||
|
action="store_true",
|
||||||
|
help="Print SQL statements without executing them",
|
||||||
|
)
|
||||||
|
return parser.parse_args(argv)
|
||||||
|
|
||||||
|
|
||||||
|
def main(argv: list[str] | None = None) -> int:
    """Entry point for the migration script.

    Reads every release JSON file under --source, converts each into an
    INSERT statement, and either prints the statements (--dry-run) or
    executes them against the PostgreSQL database given by --dsn.

    Returns:
        0 on success, 1 on error.
    """
    args = _parse_args(argv)

    if not args.source.exists():
        print(f"ERROR: source directory does not exist: {args.source}", file=sys.stderr)
        return 1

    files = collect_json_files(args.source)
    print(f"Found {len(files)} JSON file(s) under {args.source}")

    statements: list[tuple[str, tuple]] = []
    errors: list[str] = []

    for path in files:
        try:
            data = json.loads(path.read_text(encoding="utf-8"))
        except (json.JSONDecodeError, OSError) as exc:
            errors.append(f"Failed to read {path}: {exc}")
            continue

        # Determine file type from filename, falling back to the presence of
        # a released_at key in the payload.
        name = path.name
        try:
            if is_archived_filename(name) or "released_at" in data:
                record = parse_archived_json(data)
                sql, params = build_archived_insert_sql(record)
            else:
                record = parse_staging_json(data)
                sql, params = build_staging_insert_sql(record)
        except (KeyError, ValueError) as exc:
            errors.append(f"Failed to parse {path}: {exc}")
            continue

        statements.append((sql, params))

    # Bad files are warnings, not fatal: migrate everything that parsed.
    if errors:
        for err in errors:
            print(f"WARNING: {err}", file=sys.stderr)

    if args.dry_run:
        print(f"\nDry run: {len(statements)} statement(s) would be executed:")
        for sql, params in statements:
            print(f"  SQL: {sql!r}")
            print(f"  Params: {params}")
        return 0

    if not args.dsn:
        print("ERROR: --dsn is required when not using --dry-run", file=sys.stderr)
        return 1

    try:
        import psycopg  # noqa: PLC0415
    except ImportError:
        print("ERROR: psycopg not installed. Run: pip install psycopg[binary]", file=sys.stderr)
        return 1

    inserted = 0
    with psycopg.connect(args.dsn) as conn:
        with conn.cursor() as cur:
            for sql, params in statements:
                cur.execute(sql, params)
                # Fix: the statements use ON CONFLICT DO NOTHING, so counting
                # executed statements over-reports. rowcount is 0 for a
                # skipped (conflicting) row, giving the true inserted count.
                inserted += cur.rowcount
        conn.commit()

    print(f"Migration complete: {inserted} record(s) inserted.")
    return 0
|
||||||
|
|
||||||
|
|
||||||
|
# Script entry point: propagate main()'s return code as the process exit code.
if __name__ == "__main__":
    sys.exit(main())
|
||||||
1
src/release_agent/__init__.py
Normal file
1
src/release_agent/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Billo Release Agent - LangGraph-based release automation."""
|
||||||
0
src/release_agent/api/__init__.py
Normal file
0
src/release_agent/api/__init__.py
Normal file
166
src/release_agent/api/approvals.py
Normal file
166
src/release_agent/api/approvals.py
Normal file
@@ -0,0 +1,166 @@
|
|||||||
|
"""Approvals endpoint for resuming interrupted graph threads.
|
||||||
|
|
||||||
|
POST /approvals/{thread_id} — resume an interrupted graph with a decision.
|
||||||
|
GET /approvals/pending — list threads with status="interrupted".
|
||||||
|
"""
|
||||||
|
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||||
|
|
||||||
|
from release_agent.api.dependencies import get_db_pool, get_graphs, get_tool_clients, require_operator_token
|
||||||
|
from release_agent.api.models import (
|
||||||
|
ApprovalDecision,
|
||||||
|
ApprovalResponse,
|
||||||
|
PendingApproval,
|
||||||
|
PendingApprovalsResponse,
|
||||||
|
)
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Internal helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def _resume_graph(
    *,
    graph,
    thread_id: str,
    decision: str,
    tool_clients,
    db_pool,
) -> dict:
    """Resume an interrupted graph thread with the given decision.

    The decision string is passed as the resume value to the graph.
    Thread status is persisted as 'completed' on success or 'failed'
    (with the error message) before the exception is re-raised.
    """
    # Imported inside the function — presumably to avoid an import cycle
    # with the webhooks module; confirm before moving to module level.
    from release_agent.api.webhooks import _upsert_thread

    run_config = {
        "configurable": {
            "thread_id": thread_id,
            "clients": tool_clients,
        }
    }
    try:
        raw_result = await graph.ainvoke({"decision": decision}, config=run_config)
        final_state = raw_result or {}
        await _upsert_thread(db_pool, thread_id=thread_id, thread_status="completed", state=final_state)
        return final_state
    except Exception as exc:
        # Record the failure so operators can see it, then propagate.
        await _upsert_thread(
            db_pool,
            thread_id=thread_id,
            thread_status="failed",
            state={"errors": [str(exc)]},
        )
        raise
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_thread_graph_name(db_pool, thread_id: str) -> str | None:
    """Return the graph_name recorded for thread_id, or None if unknown."""
    query = "SELECT graph_name FROM agent_threads WHERE thread_id = %s"
    async with db_pool.connection() as conn, conn.cursor() as cur:
        await cur.execute(query, (thread_id,))
        row = await cur.fetchone()
    return row[0] if row else None
|
||||||
|
|
||||||
|
|
||||||
|
async def _fetch_pending_threads(db_pool) -> list[tuple]:
    """Query agent_threads for rows with status='interrupted'."""
    # Oldest-first so operators see the longest-waiting approvals at the top.
    query = """
        SELECT
            thread_id,
            graph_name,
            interrupt_value,
            created_at,
            repo_name,
            pr_id,
            version
        FROM agent_threads
        WHERE status = 'interrupted'
        ORDER BY created_at ASC
    """
    async with db_pool.connection() as conn, conn.cursor() as cur:
        await cur.execute(query)
        return await cur.fetchall()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# GET /approvals/pending
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@router.get("/approvals/pending")
async def list_pending_approvals(
    request: Request,  # noqa: ARG001
    db_pool=Depends(get_db_pool),
    _auth: None = Depends(require_operator_token),
) -> PendingApprovalsResponse:
    """Return all graph threads currently awaiting operator approval."""
    rows = await _fetch_pending_threads(db_pool)
    pending: list[PendingApproval] = []
    for thread_id, graph_name, interrupt_value, created, repo_name, pr_id, version in rows:
        # Some drivers hand back timestamps as strings; normalize to datetime.
        if not isinstance(created, datetime):
            created = datetime.fromisoformat(str(created))
        pending.append(
            PendingApproval(
                thread_id=thread_id,
                graph_name=graph_name,
                interrupt_value=interrupt_value,
                created_at=created,
                repo_name=repo_name,
                pr_id=pr_id,
                version=version,
            )
        )
    return PendingApprovalsResponse(items=pending, count=len(pending))
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# POST /approvals/{thread_id}
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@router.post("/approvals/{thread_id}")
async def submit_approval(
    thread_id: str,
    body: ApprovalDecision,
    request: Request,  # noqa: ARG001
    graphs=Depends(get_graphs),
    tool_clients=Depends(get_tool_clients),
    db_pool=Depends(get_db_pool),
    _auth: None = Depends(require_operator_token),
) -> ApprovalResponse:
    """Submit an approval decision to resume an interrupted graph thread.

    The decision is forwarded as the interrupt resume value to the graph.
    Thread status is updated to 'completed' or 'failed' based on result.
    """
    # The thread row records which graph produced the interrupt.
    graph_name = await _get_thread_graph_name(db_pool, thread_id)
    if graph_name is None:
        raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")

    graph = graphs.get(graph_name)
    if graph is None:
        raise HTTPException(status_code=400, detail=f"Unknown graph: {graph_name}")

    try:
        await _resume_graph(
            graph=graph,
            thread_id=thread_id,
            decision=body.decision,
            tool_clients=tool_clients,
            db_pool=db_pool,
        )
    except Exception:
        # _resume_graph already persisted the failure; report it to the caller.
        outcome = ("failed", f"Thread {thread_id} failed during resume")
    else:
        outcome = ("resumed", f"Thread {thread_id} resumed with decision '{body.decision}'")

    final_status, note = outcome
    return ApprovalResponse(thread_id=thread_id, status=final_status, message=note)
|
||||||
66
src/release_agent/api/dependencies.py
Normal file
66
src/release_agent/api/dependencies.py
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
"""FastAPI dependency callables that extract objects from app.state.
|
||||||
|
|
||||||
|
Each function accepts a Request and returns the corresponding object stored
|
||||||
|
on app.state during lifespan startup. Use with Depends() in route handlers.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import hmac
|
||||||
|
|
||||||
|
from fastapi import Depends, HTTPException, Request, status
|
||||||
|
from fastapi.security import APIKeyHeader
|
||||||
|
|
||||||
|
from release_agent.config import Settings
|
||||||
|
from release_agent.graph.dependencies import StagingStore, ToolClients
|
||||||
|
|
||||||
|
_operator_token_header = APIKeyHeader(name="X-Operator-Token", auto_error=False)
|
||||||
|
|
||||||
|
|
||||||
|
def get_settings(request: Request) -> Settings:
    """Return the Settings instance stored on app.state."""
    app_state = request.app.state
    return app_state.settings
|
||||||
|
|
||||||
|
|
||||||
|
def get_graphs(request: Request) -> dict:
    """Return the compiled graphs dict stored on app.state."""
    app_state = request.app.state
    return app_state.graphs
|
||||||
|
|
||||||
|
|
||||||
|
def get_tool_clients(request: Request) -> ToolClients:
    """Return the ToolClients instance stored on app.state."""
    app_state = request.app.state
    return app_state.tool_clients
|
||||||
|
|
||||||
|
|
||||||
|
def get_staging_store(request: Request) -> StagingStore:
    """Return the StagingStore instance stored on app.state."""
    app_state = request.app.state
    return app_state.staging_store
|
||||||
|
|
||||||
|
|
||||||
|
def get_db_pool(request: Request):
    """Return the database connection pool stored on app.state."""
    app_state = request.app.state
    return app_state.db_pool
|
||||||
|
|
||||||
|
|
||||||
|
async def require_operator_token(
    request: Request,
    token: str | None = Depends(_operator_token_header),
) -> None:
    """Validate the X-Operator-Token header against the configured operator token.

    If operator_token is not configured (empty string), authentication is skipped
    and all requests are allowed. When configured, uses hmac.compare_digest
    for constant-time comparison to prevent timing attacks.

    Raises:
        HTTPException: 401 if a token is configured but the provided token is
            missing or does not match.
    """
    expected = request.app.state.settings.operator_token.get_secret_value()
    if not expected:
        # No token configured — auth disabled
        return

    supplied = token or ""
    if supplied and hmac.compare_digest(supplied, expected):
        return
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid or missing operator token",
    )
|
||||||
137
src/release_agent/api/models.py
Normal file
137
src/release_agent/api/models.py
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
"""API request/response Pydantic models for the HTTP layer.
|
||||||
|
|
||||||
|
All models are frozen (immutable) following the project immutability policy.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Webhook models
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class WebhookResponse(BaseModel):
    """Response returned after a webhook event is accepted."""

    model_config = ConfigDict(frozen=True)  # immutable per project policy

    thread_id: str  # graph thread created to process the event
    message: str  # human-readable acknowledgement
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Approval models
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class ApprovalDecision(BaseModel):
    """Body for POST /approvals/{thread_id}."""

    model_config = ConfigDict(frozen=True)  # immutable per project policy

    # Resume value forwarded to the interrupted graph; only these literals
    # pass request validation.
    decision: Literal["merge", "cancel", "approve", "skip", "trigger"]
|
||||||
|
|
||||||
|
|
||||||
|
class ApprovalResponse(BaseModel):
    """Response returned after an approval decision is processed."""

    model_config = ConfigDict(frozen=True)  # immutable per project policy

    thread_id: str  # thread the decision was applied to
    status: str  # outcome, e.g. "resumed" or "failed"
    message: str  # human-readable summary of the outcome
|
||||||
|
|
||||||
|
|
||||||
|
class PendingApproval(BaseModel):
    """A single interrupted graph thread awaiting operator input."""

    model_config = ConfigDict(frozen=True)  # immutable per project policy

    thread_id: str  # resumable thread identifier
    graph_name: str  # which compiled graph produced the interrupt
    interrupt_value: str  # the value the graph interrupted with
    created_at: datetime  # when the thread row was created
    # Optional context fields; not every graph populates them.
    repo_name: str | None = None
    pr_id: str | None = None
    version: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class PendingApprovalsResponse(BaseModel):
    """Response for GET /approvals/pending."""

    model_config = ConfigDict(frozen=True)  # immutable per project policy

    items: list[PendingApproval]  # interrupted threads, oldest first
    count: int  # len(items), provided for caller convenience
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Health / status models
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class HealthResponse(BaseModel):
    """Response for GET /status."""

    model_config = ConfigDict(frozen=True)  # immutable per project policy

    status: Literal["ok", "degraded"]  # overall service health
    version: str  # service version string
    uptime_seconds: float  # seconds since application startup
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Release / staging models
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class ReleaseVersionListResponse(BaseModel):
    """Response for GET /releases/{repo}."""

    model_config = ConfigDict(frozen=True)  # immutable per project policy

    repo: str  # repository the versions belong to
    versions: list[str]  # known release version strings
|
||||||
|
|
||||||
|
|
||||||
|
class StagingResponse(BaseModel):
    """Response for GET /staging."""

    model_config = ConfigDict(frozen=True)  # immutable per project policy

    repo: str  # repository queried
    staging: dict | None  # serialized staging release, or None if none exists
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Manual trigger models
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class ManualTriggerResponse(BaseModel):
    """Response for manual trigger endpoints."""

    model_config = ConfigDict(frozen=True)  # immutable per project policy

    thread_id: str  # thread scheduled to run in the background
    message: str  # human-readable confirmation
|
||||||
|
|
||||||
|
|
||||||
|
class ManualReleaseRequest(BaseModel):
    """Request body for POST /manual/release."""

    model_config = ConfigDict(frozen=True)  # immutable per project policy

    repo: str  # repository to run a release for
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Error model
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class ErrorResponse(BaseModel):
    """Standard error response envelope."""

    model_config = ConfigDict(frozen=True)  # immutable per project policy

    error: str  # short machine-friendly error label
    detail: str | None = None  # optional human-readable elaboration
|
||||||
264
src/release_agent/api/slack_interactions.py
Normal file
264
src/release_agent/api/slack_interactions.py
Normal file
@@ -0,0 +1,264 @@
|
|||||||
|
"""Slack interactive messages endpoint.
|
||||||
|
|
||||||
|
POST /slack/interactions — receives Slack button click payloads, verifies
|
||||||
|
the signing secret, extracts thread_id + decision, and resumes the
|
||||||
|
appropriate graph thread in the background.
|
||||||
|
|
||||||
|
Slack requires a 200 response within 3 seconds. The graph resume is
|
||||||
|
therefore scheduled as a background task.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import hashlib
|
||||||
|
import hmac
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import urllib.parse
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
|
from release_agent.api.dependencies import get_db_pool, get_graphs, get_settings, get_tool_clients
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Pure helper: signature verification
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
# Maximum allowed skew between Slack's request timestamp and our clock.
_MAX_TIMESTAMP_AGE_SECONDS = 300  # 5 minutes — Slack's recommendation
# Button values we act on; anything else is logged and ignored.
_ALLOWED_DECISIONS = frozenset({"merge", "cancel", "approve", "skip", "trigger", "yes", "no"})
|
||||||
|
|
||||||
|
|
||||||
|
def _verify_slack_signature(
|
||||||
|
*,
|
||||||
|
signing_secret: str,
|
||||||
|
timestamp: str,
|
||||||
|
body: str,
|
||||||
|
signature: str,
|
||||||
|
current_time: float | None = None,
|
||||||
|
) -> bool:
|
||||||
|
"""Verify a Slack request signature using HMAC-SHA256.
|
||||||
|
|
||||||
|
Also rejects requests with timestamps older than 5 minutes
|
||||||
|
to prevent replay attacks.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
signing_secret: The app's Slack signing secret.
|
||||||
|
timestamp: Value of the X-Slack-Request-Timestamp header.
|
||||||
|
body: Raw request body as a string.
|
||||||
|
signature: Value of the X-Slack-Signature header.
|
||||||
|
current_time: Current unix timestamp (injectable for testing).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the signature is valid and fresh, False otherwise.
|
||||||
|
"""
|
||||||
|
if not signature or not signature.startswith("v0="):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Replay attack prevention
|
||||||
|
try:
|
||||||
|
request_ts = int(timestamp)
|
||||||
|
except ValueError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
now = current_time if current_time is not None else time.time()
|
||||||
|
if abs(now - request_ts) > _MAX_TIMESTAMP_AGE_SECONDS:
|
||||||
|
return False
|
||||||
|
|
||||||
|
base_string = f"v0:{timestamp}:{body}"
|
||||||
|
computed_hash = hmac.new(
|
||||||
|
signing_secret.encode(),
|
||||||
|
base_string.encode(),
|
||||||
|
hashlib.sha256,
|
||||||
|
).hexdigest()
|
||||||
|
expected_sig = f"v0={computed_hash}"
|
||||||
|
|
||||||
|
return hmac.compare_digest(expected_sig, signature)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Internal: parse button payload
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _parse_button_payload(form_body: str) -> dict[str, Any] | None:
|
||||||
|
"""Parse a URL-encoded Slack button action payload.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
form_body: Raw URL-encoded form body from Slack.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict with "thread_id" and "decision" keys, or None if parsing fails.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
parsed = urllib.parse.parse_qs(form_body)
|
||||||
|
payload_json = parsed.get("payload", [None])[0]
|
||||||
|
if not payload_json:
|
||||||
|
return None
|
||||||
|
|
||||||
|
payload = json.loads(payload_json)
|
||||||
|
actions = payload.get("actions", [])
|
||||||
|
if not actions:
|
||||||
|
return None
|
||||||
|
|
||||||
|
action = actions[0]
|
||||||
|
value = action.get("value", "")
|
||||||
|
# value format: "{thread_id}:{decision}"
|
||||||
|
if ":" not in value:
|
||||||
|
return None
|
||||||
|
|
||||||
|
parts = value.split(":", 1)
|
||||||
|
thread_id = parts[0]
|
||||||
|
decision = parts[1]
|
||||||
|
|
||||||
|
user = payload.get("user") or {}
|
||||||
|
user_name = user.get("name", user.get("id", "unknown"))
|
||||||
|
|
||||||
|
return {
|
||||||
|
"thread_id": thread_id,
|
||||||
|
"decision": decision,
|
||||||
|
"user_name": user_name,
|
||||||
|
}
|
||||||
|
except (json.JSONDecodeError, KeyError, IndexError, ValueError) as exc:
|
||||||
|
logger.warning("Failed to parse Slack button payload: %s", exc)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Internal: background graph resume
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def _background_resume(
    *,
    thread_id: str,
    decision: str,
    graphs: dict,
    tool_clients: Any,
    db_pool: Any,
) -> None:
    """Fetch the graph for thread_id from DB and resume it.

    Runs as a fire-and-forget task: every failure mode is logged rather
    than raised, since there is no Slack response left to report it on.
    """
    from release_agent.api.approvals import _get_thread_graph_name, _resume_graph

    graph_name = await _get_thread_graph_name(db_pool, thread_id)
    if graph_name is None:
        logger.warning("slack_interactions: thread %s not found in DB", thread_id)
        return

    target_graph = graphs.get(graph_name)
    if target_graph is None:
        logger.warning(
            "slack_interactions: unknown graph '%s' for thread %s",
            graph_name,
            thread_id,
        )
        return

    try:
        await _resume_graph(
            graph=target_graph,
            thread_id=thread_id,
            decision=decision,
            tool_clients=tool_clients,
            db_pool=db_pool,
        )
    except Exception as exc:
        logger.error(
            "slack_interactions: failed to resume thread %s: %s",
            thread_id,
            exc,
        )
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Endpoint: POST /slack/interactions
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@router.post("/slack/interactions")
async def slack_interactions(
    request: Request,
    graphs: dict = Depends(get_graphs),
    tool_clients: Any = Depends(get_tool_clients),
    db_pool: Any = Depends(get_db_pool),
    settings: Any = Depends(get_settings),
) -> JSONResponse:
    """Receive and process Slack interactive button clicks.

    1. Verify the Slack signing secret.
    2. Parse the payload to extract thread_id and decision.
    3. Schedule the graph resume as a background task.
    4. Return 200 immediately (Slack requires a fast response).

    Returns:
        200 {"ok": True} on success.
        400 if required headers are missing.
        403 if signature verification fails.
    """
    body_str = (await request.body()).decode("utf-8")

    headers = request.headers
    timestamp = headers.get("X-Slack-Request-Timestamp")
    signature = headers.get("X-Slack-Signature")

    if not timestamp:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Missing X-Slack-Request-Timestamp header",
        )
    if not signature:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Missing X-Slack-Signature header",
        )

    signing_secret = settings.slack_signing_secret.get_secret_value()
    if not signing_secret:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Slack signing secret not configured",
        )

    if not _verify_slack_signature(
        signing_secret=signing_secret,
        timestamp=timestamp,
        body=body_str,
        signature=signature,
    ):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Invalid Slack signature",
        )

    parsed = _parse_button_payload(body_str)
    if parsed is None:
        # Unparseable payloads are acknowledged, not errored: Slack would
        # otherwise retry them forever.
        logger.warning("slack_interactions: could not parse payload")
        return JSONResponse({"ok": True})

    decision = parsed["decision"]
    if decision not in _ALLOWED_DECISIONS:
        logger.warning("slack_interactions: unexpected decision '%s', ignoring", decision)
        return JSONResponse({"ok": True})

    # Schedule the graph resume as a background task (Slack needs 200 fast)
    resume_task = asyncio.create_task(
        _background_resume(
            thread_id=parsed["thread_id"],
            decision=decision,
            graphs=graphs,
            tool_clients=tool_clients,
            db_pool=db_pool,
        )
    )
    app_state = request.app.state
    if hasattr(app_state, "background_tasks"):
        # Keep a strong reference so the task isn't garbage-collected mid-run.
        app_state.background_tasks.add(resume_task)
        resume_task.add_done_callback(app_state.background_tasks.discard)

    return JSONResponse({"ok": True})
|
||||||
153
src/release_agent/api/status.py
Normal file
153
src/release_agent/api/status.py
Normal file
@@ -0,0 +1,153 @@
|
|||||||
|
"""Status, releases, staging, and manual trigger endpoints.
|
||||||
|
|
||||||
|
GET /status — health check
|
||||||
|
GET /releases/{repo} — list release versions for a repo
|
||||||
|
GET /staging — current staging release for a repo
|
||||||
|
POST /manual/pr/{pr_id} — manually trigger PR processing
|
||||||
|
POST /manual/release — manually trigger a release run
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, Request, status
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
|
from release_agent.api.dependencies import (
|
||||||
|
get_db_pool,
|
||||||
|
get_graphs,
|
||||||
|
get_staging_store,
|
||||||
|
get_tool_clients,
|
||||||
|
require_operator_token,
|
||||||
|
)
|
||||||
|
from release_agent.api.models import (
|
||||||
|
HealthResponse,
|
||||||
|
ManualReleaseRequest,
|
||||||
|
ManualTriggerResponse,
|
||||||
|
ReleaseVersionListResponse,
|
||||||
|
StagingResponse,
|
||||||
|
)
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
_VERSION = "0.1.0"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# GET /status
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@router.get("/status")
async def get_status(request: Request) -> HealthResponse:
    """Return the health status of the agent service."""
    launched: datetime = request.app.state.started_at
    elapsed = (datetime.now(tz=timezone.utc) - launched).total_seconds()
    return HealthResponse(status="ok", version=_VERSION, uptime_seconds=elapsed)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# GET /releases/{repo}
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@router.get("/releases/{repo}")
async def list_release_versions(
    repo: str,
    staging_store=Depends(get_staging_store),
) -> ReleaseVersionListResponse:
    """List all known release versions for the given repository."""
    known_versions = await staging_store.list_versions(repo)
    return ReleaseVersionListResponse(
        repo=repo,
        versions=known_versions,
    )
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# GET /staging
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@router.get("/staging")
async def get_staging(
    repo: str,
    staging_store=Depends(get_staging_store),
) -> StagingResponse:
    """Return the current staging release for the given repository."""
    current = await staging_store.load(repo)
    if current is None:
        payload = None
    else:
        payload = current.model_dump(mode="json")
    return StagingResponse(repo=repo, staging=payload)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# POST /manual/pr/{pr_id}
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@router.post("/manual/pr/{pr_id}", status_code=status.HTTP_202_ACCEPTED)
async def manual_pr_trigger(
    pr_id: str,
    request: Request,
    graphs=Depends(get_graphs),
    tool_clients=Depends(get_tool_clients),
    db_pool=Depends(get_db_pool),
    _auth: None = Depends(require_operator_token),
) -> ManualTriggerResponse:
    """Manually trigger PR processing for the given PR ID.

    Schedules the pr_completed graph as a background task and returns
    immediately with the thread ID used to track execution.
    """
    # Imported lazily to avoid a circular import between routes and webhooks.
    from release_agent.api.webhooks import _run_graph

    thread_id = str(uuid.uuid4())
    initial_state = {"pr_id": pr_id}

    # Pass the same configuration the webhook path supplies so manually
    # triggered runs are persisted with a graph name and the nodes see the
    # configured repos directory / default Jira project (previously these
    # defaulted to empty strings on this path).
    settings = request.app.state.settings
    task = asyncio.create_task(
        _run_graph(
            graph=graphs["pr_completed"],
            initial_state=initial_state,
            thread_id=thread_id,
            tool_clients=tool_clients,
            db_pool=db_pool,
            repos_base_dir=settings.repos_base_dir,
            graph_name="pr_completed",
            default_jira_project=settings.default_jira_project,
        )
    )
    # Keep a strong reference until the task finishes; asyncio only holds
    # weak references to running tasks.
    request.app.state.background_tasks.add(task)
    task.add_done_callback(request.app.state.background_tasks.discard)

    return ManualTriggerResponse(
        thread_id=thread_id,
        message=f"PR {pr_id} processing scheduled as thread {thread_id}",
    )
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# POST /manual/release
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@router.post("/manual/release", status_code=status.HTTP_202_ACCEPTED)
async def manual_release_trigger(
    body: ManualReleaseRequest,
    request: Request,
    graphs=Depends(get_graphs),
    tool_clients=Depends(get_tool_clients),
    db_pool=Depends(get_db_pool),
    _auth: None = Depends(require_operator_token),
) -> ManualTriggerResponse:
    """Manually trigger a release run for the given repository.

    Schedules the release graph as a background task and returns
    immediately with the thread ID used to track execution.
    """
    # Imported lazily to avoid a circular import between routes and webhooks.
    from release_agent.api.webhooks import _run_graph

    thread_id = str(uuid.uuid4())
    initial_state = {"repo_name": body.repo}

    # Pass the same configuration the webhook path supplies so manually
    # triggered runs are persisted with a graph name and the nodes see the
    # configured repos directory / default Jira project (previously these
    # defaulted to empty strings on this path).
    settings = request.app.state.settings
    task = asyncio.create_task(
        _run_graph(
            graph=graphs["release"],
            initial_state=initial_state,
            thread_id=thread_id,
            tool_clients=tool_clients,
            db_pool=db_pool,
            repos_base_dir=settings.repos_base_dir,
            graph_name="release",
            default_jira_project=settings.default_jira_project,
        )
    )
    # Keep a strong reference until the task finishes; asyncio only holds
    # weak references to running tasks.
    request.app.state.background_tasks.add(task)
    task.add_done_callback(request.app.state.background_tasks.discard)

    return ManualTriggerResponse(
        thread_id=thread_id,
        message=f"Release for repo '{body.repo}' scheduled as thread {thread_id}",
    )
|
||||||
195
src/release_agent/api/webhooks.py
Normal file
195
src/release_agent/api/webhooks.py
Normal file
@@ -0,0 +1,195 @@
|
|||||||
|
"""Webhook endpoint for Azure DevOps events.
|
||||||
|
|
||||||
|
POST /webhooks/azdo — validates secret, parses payload, filters events,
|
||||||
|
schedules graph execution as a background task, returns 202.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import hmac
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, Header, HTTPException, Request, status
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
from release_agent.api.dependencies import get_db_pool, get_graphs, get_settings, get_tool_clients
|
||||||
|
from release_agent.api.models import WebhookResponse
|
||||||
|
from release_agent.models.webhook import WebhookPayload
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Secret validation helper (pure function — easily unit-tested)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _validate_webhook_secret(header: str | None, expected: str) -> bool:
|
||||||
|
"""Return True if header matches expected using constant-time comparison.
|
||||||
|
|
||||||
|
Returns False if:
|
||||||
|
- expected is empty (auth misconfigured, reject all)
|
||||||
|
- header is None (no header sent)
|
||||||
|
- header does not match expected
|
||||||
|
"""
|
||||||
|
if not expected:
|
||||||
|
return False
|
||||||
|
if header is None:
|
||||||
|
return False
|
||||||
|
return hmac.compare_digest(header, expected)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# POST /webhooks/azdo
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@router.post("/webhooks/azdo")
async def receive_azdo_webhook(
    request: Request,
    x_webhook_secret: Annotated[str | None, Header()] = None,
    settings=Depends(get_settings),
    graphs=Depends(get_graphs),
    tool_clients=Depends(get_tool_clients),
    db_pool=Depends(get_db_pool),
):
    """Receive an Azure DevOps webhook event.

    Validates the X-Webhook-Secret header, parses the payload,
    filters to completed PR events only, then schedules graph execution.

    Returns 202 with thread_id for accepted events.
    Returns 200 with "event ignored" for non-completed PR events.
    Returns 401 for missing or invalid secret.
    Returns 422 for a non-JSON body or invalid payload structure.
    """
    expected_secret = settings.webhook_secret.get_secret_value()
    if not _validate_webhook_secret(x_webhook_secret, expected_secret):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or missing webhook secret",
        )

    # A malformed (non-JSON) body would otherwise escape as a 500.
    try:
        body = await request.json()
    except ValueError as exc:
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=f"Invalid payload: {exc}",
        ) from exc

    try:
        payload = WebhookPayload.model_validate(body)
    except ValidationError as exc:
        # Only schema errors map to 422. The previous
        # `except (ValidationError, Exception)` swallowed every exception
        # (including genuine bugs) as a 422 — keep the catch narrow.
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=f"Invalid payload: {exc}",
        ) from exc

    # Only process completed PRs
    if payload.resource.status != "completed":
        return JSONResponse(
            status_code=status.HTTP_200_OK,
            content={"message": "event ignored: PR not completed", "status": payload.resource.status},
        )

    # Generate a unique thread ID for this execution
    thread_id = str(uuid.uuid4())

    # Build initial state from webhook payload
    initial_state = {
        "webhook_payload": body,
        "pr_id": str(payload.resource.pull_request_id),
        "repo_name": payload.resource.repository.name,
    }

    graph = graphs["pr_completed"]

    # Schedule as a background task (non-blocking).
    # NOTE(review): this rebinds `settings` from app state, shadowing the
    # Depends-injected value — presumably the same object; confirm.
    settings = request.app.state.settings
    task = asyncio.create_task(
        _run_graph(
            graph=graph,
            initial_state=initial_state,
            thread_id=thread_id,
            tool_clients=tool_clients,
            db_pool=db_pool,
            repos_base_dir=settings.repos_base_dir,
            graph_name="pr_completed",
            default_jira_project=settings.default_jira_project,
        )
    )
    # Keep a strong reference until the task finishes; asyncio only holds
    # weak references to running tasks.
    request.app.state.background_tasks.add(task)
    task.add_done_callback(request.app.state.background_tasks.discard)

    return JSONResponse(
        status_code=status.HTTP_202_ACCEPTED,
        content=WebhookResponse(
            thread_id=thread_id,
            message=f"Webhook accepted, scheduled as thread {thread_id}",
        ).model_dump(),
    )
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Internal graph runner
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def _run_graph(
    *,
    graph,
    initial_state: dict,
    thread_id: str,
    tool_clients,
    db_pool,
    repos_base_dir: str = "",
    graph_name: str = "",
    default_jira_project: str = "ALLPOST",
) -> None:
    """Invoke the graph and persist thread status to the database.

    Records the thread as "running" before invocation, "completed" with the
    final state on success, or "failed" with the error message merged into
    the initial state on any exception.

    Args:
        graph: Compiled graph object invoked via ``ainvoke``.
        initial_state: Starting state dict; may carry repo_name / pr_id.
        thread_id: Unique ID used both as the graph thread ID and DB key.
        tool_clients: External service clients injected into graph config.
        db_pool: Async connection pool used for thread persistence.
        repos_base_dir: Base directory for local repos, passed to nodes.
        graph_name: Name stored with the thread record (e.g. "pr_completed").
        default_jira_project: Jira project key for auto-created tickets.
    """
    repo_name = initial_state.get("repo_name", "")
    pr_id = initial_state.get("pr_id", "")
    # Values under "configurable" are how graph nodes receive their
    # per-invocation dependencies.
    config = {
        "configurable": {
            "thread_id": thread_id,
            "clients": tool_clients,
            "repos_base_dir": repos_base_dir,
            "default_jira_project": default_jira_project,
        }
    }
    try:
        await _upsert_thread(
            db_pool, thread_id=thread_id, thread_status="running",
            state=initial_state, graph_name=graph_name,
            repo_name=repo_name, pr_id=pr_id,
        )
        result = await graph.ainvoke(initial_state, config=config)
        # ainvoke may return a falsy result; store an empty dict rather than None.
        await _upsert_thread(db_pool, thread_id=thread_id, thread_status="completed", state=result or {})
    except Exception as exc:
        # Broad catch is deliberate: this runs as a fire-and-forget
        # background task, so failures must be logged and persisted
        # instead of disappearing with the task.
        logger.exception("Graph execution failed for thread %s", thread_id)
        error_state = {**initial_state, "errors": [str(exc)]}
        await _upsert_thread(db_pool, thread_id=thread_id, thread_status="failed", state=error_state)
|
||||||
|
|
||||||
|
|
||||||
|
async def _upsert_thread(
    pool,
    *,
    thread_id: str,
    thread_status: str,
    state: dict,
    graph_name: str = "",
    repo_name: str = "",
    pr_id: str = "",
) -> None:
    """Insert or update a thread record in agent_threads.

    On conflict only status, state, and updated_at are refreshed — the
    original graph_name / repo_name / pr_id columns are preserved, so
    callers may pass empty strings for those on subsequent updates.

    Args:
        pool: Async connection pool exposing ``connection()``.
        thread_id: Primary key of the thread row.
        thread_status: One of "running" / "completed" / "failed".
        state: Graph state serialized to JSON for storage.
        graph_name: Graph identifier, only used on first insert.
        repo_name: Repository name, only used on first insert.
        pr_id: Pull request ID, only used on first insert.
    """
    # Local import keeps module import time minimal; json is stdlib.
    import json

    sql = """
    INSERT INTO agent_threads (thread_id, graph_name, repo_name, pr_id, status, state, updated_at)
    VALUES (%s, %s, %s, %s, %s, %s, NOW())
    ON CONFLICT (thread_id) DO UPDATE
    SET status = EXCLUDED.status,
        state = EXCLUDED.state,
        updated_at = EXCLUDED.updated_at
    """
    async with pool.connection() as conn:
        async with conn.cursor() as cur:
            await cur.execute(sql, (
                thread_id, graph_name, repo_name, pr_id,
                thread_status, json.dumps(state),
            ))
|
||||||
46
src/release_agent/branch_parser.py
Normal file
46
src/release_agent/branch_parser.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
"""Branch name parsing utilities for extracting Jira ticket IDs.
|
||||||
|
|
||||||
|
Pure functions only - no side effects, no mutation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import re
|
||||||
|
|
||||||
|
# Matches Jira-style ticket IDs: one or more uppercase letters/digits (starting
|
||||||
|
# with a letter) followed by a dash and one or more digits.
|
||||||
|
# Examples: ALLPOST-4229, BILL-42, MY-1, AB2-100
|
||||||
|
_TICKET_PATTERN = re.compile(r"(?<![A-Z0-9])([A-Z][A-Z0-9]+-\d+)(?![A-Z0-9-])")
|
||||||
|
|
||||||
|
# The refs/heads/ prefix that Azure DevOps sometimes includes in branch names.
|
||||||
|
_REFS_HEADS_PREFIX = "refs/heads/"
|
||||||
|
|
||||||
|
|
||||||
|
def strip_refs_prefix(branch: str) -> str:
|
||||||
|
"""Remove the 'refs/heads/' prefix from a branch name if present.
|
||||||
|
|
||||||
|
Returns the branch name unchanged if the prefix is not present.
|
||||||
|
Does not strip 'refs/tags/' or any other prefix.
|
||||||
|
"""
|
||||||
|
if branch.startswith(_REFS_HEADS_PREFIX):
|
||||||
|
return branch[len(_REFS_HEADS_PREFIX):]
|
||||||
|
return branch
|
||||||
|
|
||||||
|
|
||||||
|
def parse_branch(branch: str) -> tuple[str | None, bool]:
|
||||||
|
"""Parse a branch name and extract a Jira ticket ID if present.
|
||||||
|
|
||||||
|
Handles the 'refs/heads/' prefix automatically.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple of (ticket_id, has_ticket) where ticket_id is the Jira ID
|
||||||
|
string (e.g. "ALLPOST-4229") or None, and has_ticket is a bool.
|
||||||
|
"""
|
||||||
|
if not branch:
|
||||||
|
return None, False
|
||||||
|
|
||||||
|
normalized = strip_refs_prefix(branch)
|
||||||
|
|
||||||
|
match = _TICKET_PATTERN.search(normalized)
|
||||||
|
if match:
|
||||||
|
return match.group(1), True
|
||||||
|
|
||||||
|
return None, False
|
||||||
110
src/release_agent/config.py
Normal file
110
src/release_agent/config.py
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
"""Application configuration using pydantic-settings.
|
||||||
|
|
||||||
|
All secrets are stored as SecretStr to prevent accidental logging.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from pydantic import Field, SecretStr, computed_field
|
||||||
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
|
|
||||||
|
|
||||||
|
class Settings(BaseSettings):
    """Application settings loaded from environment variables.

    Required variables:
        AZDO_ORGANIZATION - Azure DevOps organization name
        AZDO_PROJECT - Azure DevOps project name
        AZDO_PAT - Azure DevOps personal access token
        POSTGRES_DSN - PostgreSQL connection string
        JIRA_EMAIL - Jira account email address
        JIRA_API_TOKEN - Jira API token
        WEBHOOK_SECRET - Shared secret checked on incoming webhooks

    Optional variables:
        ANTHROPIC_API_KEY - Anthropic API key (defaults to ""; only needed when calling the API directly)
        SLACK_WEBHOOK_URL - Slack incoming webhook URL (defaults to "")
        PORT - HTTP server port (default: 8000, range: 1-65535)
        JIRA_BASE_URL - Jira base URL (default: https://billolife.atlassian.net)
        CLAUDE_REVIEW_MODEL - Anthropic model for PR review (default: claude-sonnet-4-20250514)
        WATCHED_REPOS - Comma-separated repo names to poll for PRs
        PR_POLL_INTERVAL_SECONDS - Seconds between PR polling runs (default: 300)
        PR_POLL_TARGET_BRANCH - Target branch ref to filter PRs (default: refs/heads/develop)
        PR_POLL_ENABLED - Enable PR polling background task (default: False)
        DEFAULT_JIRA_PROJECT - Default Jira project key for auto-created tickets (default: ALLPOST)
        AUTO_CREATE_TICKET_ENABLED - Auto-create Jira ticket when PR has no ticket (default: True)
    """

    # Values come from the environment (case-insensitive), with .env as a
    # fallback; unknown environment variables are ignored.
    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        case_sensitive=False,
        extra="ignore",
    )

    # Azure DevOps / core settings
    azdo_organization: str
    azdo_project: str
    azdo_pat: SecretStr
    anthropic_api_key: SecretStr = Field(default=SecretStr(""))  # Optional: only needed if using API directly
    postgres_dsn: SecretStr
    port: int = Field(default=8000, ge=1, le=65535)

    # Jira settings
    jira_base_url: str = "https://billolife.atlassian.net"
    jira_email: str
    jira_api_token: SecretStr

    # Slack settings (all optional — empty values disable the integration)
    slack_webhook_url: SecretStr = Field(default=SecretStr(""))
    slack_bot_token: SecretStr = Field(default=SecretStr(""))
    slack_signing_secret: SecretStr = Field(default=SecretStr(""))
    slack_channel_id: str = ""

    # CI polling settings
    ci_poll_interval_seconds: int = 30
    ci_poll_max_wait_seconds: int = 1800

    # Claude settings
    claude_review_model: str = "claude-sonnet-4-20250514"

    # Local repo settings
    repos_base_dir: str = ""  # Base directory containing Billo repos (e.g., /c/Users/yaoji/git/Billo)

    # Webhook settings
    webhook_secret: SecretStr

    # Operator API settings
    operator_token: SecretStr = Field(default=SecretStr(""))

    # PR polling settings
    watched_repos: str = ""
    pr_poll_interval_seconds: int = 300
    pr_poll_target_branch: str = "refs/heads/develop"
    pr_poll_enabled: bool = False

    # Auto-create Jira ticket settings
    default_jira_project: str = "ALLPOST"
    auto_create_ticket_enabled: bool = True

    @computed_field  # type: ignore[misc]
    @property
    def watched_repos_list(self) -> list[str]:
        """Return watched_repos as a list, splitting on commas and stripping whitespace."""
        if not self.watched_repos:
            return []
        return [r.strip() for r in self.watched_repos.split(",") if r.strip()]

    @computed_field  # type: ignore[misc]
    @property
    def azdo_base_url(self) -> str:
        """Base URL for the Azure DevOps organization."""
        return f"https://dev.azure.com/{self.azdo_organization}"

    @computed_field  # type: ignore[misc]
    @property
    def azdo_api_url(self) -> str:
        """Base API URL for the Azure DevOps project."""
        return f"https://dev.azure.com/{self.azdo_organization}/{self.azdo_project}/_apis"

    @computed_field  # type: ignore[misc]
    @property
    def azdo_vsrm_api_url(self) -> str:
        """Base API URL for the Azure DevOps VSRM release management."""
        return f"https://vsrm.dev.azure.com/{self.azdo_organization}/{self.azdo_project}/_apis"
|
||||||
59
src/release_agent/exceptions.py
Normal file
59
src/release_agent/exceptions.py
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
"""Custom exception hierarchy for the release agent.
|
||||||
|
|
||||||
|
All exceptions inherit from ReleaseAgentError so callers can catch
|
||||||
|
the entire hierarchy with a single except clause when needed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class ReleaseAgentError(Exception):
|
||||||
|
"""Base exception for all release agent errors."""
|
||||||
|
|
||||||
|
|
||||||
|
class ServiceError(ReleaseAgentError):
|
||||||
|
"""An HTTP error response from an external service.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
service: Name of the service that returned the error (e.g. "jira").
|
||||||
|
status_code: HTTP status code from the response.
|
||||||
|
detail: Optional human-readable detail message.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *, service: str, status_code: int, detail: str | None) -> None:
|
||||||
|
self.service = service
|
||||||
|
self.status_code = status_code
|
||||||
|
self.detail = detail
|
||||||
|
super().__init__(f"[{service}] HTTP {status_code}: {detail}")
|
||||||
|
|
||||||
|
|
||||||
|
class AuthenticationError(ServiceError):
|
||||||
|
"""Authentication or authorisation failure (401 / 403)."""
|
||||||
|
|
||||||
|
def __init__(self, *, service: str, status_code: int = 401, detail: str | None = None) -> None:
|
||||||
|
super().__init__(service=service, status_code=status_code, detail=detail)
|
||||||
|
|
||||||
|
|
||||||
|
class NotFoundError(ServiceError):
|
||||||
|
"""Requested resource was not found (404)."""
|
||||||
|
|
||||||
|
def __init__(self, *, service: str, detail: str | None = None) -> None:
|
||||||
|
super().__init__(service=service, status_code=404, detail=detail)
|
||||||
|
|
||||||
|
|
||||||
|
class RateLimitError(ServiceError):
|
||||||
|
"""Rate limit exceeded (429).
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
retry_after: Seconds to wait before retrying, or None if not provided.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *, service: str, retry_after: int | None) -> None:
|
||||||
|
self.retry_after = retry_after
|
||||||
|
detail = f"retry after {retry_after}s" if retry_after is not None else None
|
||||||
|
super().__init__(service=service, status_code=429, detail=detail)
|
||||||
|
|
||||||
|
|
||||||
|
class ServiceUnavailableError(ServiceError):
|
||||||
|
"""Service temporarily unavailable (503)."""
|
||||||
|
|
||||||
|
def __init__(self, *, service: str, detail: str | None = None) -> None:
|
||||||
|
super().__init__(service=service, status_code=503, detail=detail)
|
||||||
0
src/release_agent/graph/__init__.py
Normal file
0
src/release_agent/graph/__init__.py
Normal file
180
src/release_agent/graph/ci_nodes.py
Normal file
180
src/release_agent/graph/ci_nodes.py
Normal file
@@ -0,0 +1,180 @@
|
|||||||
|
"""CI/CD graph nodes for triggering, polling, and notifying CI builds.
|
||||||
|
|
||||||
|
Each node is an async function (state, config) -> dict.
|
||||||
|
Nodes never mutate state — they return new dicts.
|
||||||
|
External clients are accessed via config["configurable"]["clients"].
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from release_agent.exceptions import ReleaseAgentError
|
||||||
|
from release_agent.graph.dependencies import ToolClients
|
||||||
|
from release_agent.graph.polling import poll_until
|
||||||
|
from release_agent.models.build import BuildStatus
|
||||||
|
from release_agent.tools.slack import _build_ci_status_blocks
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# NOTE(review): _DEFAULT_BRANCH is not referenced by this module's visible
# code — confirm before removing.
_DEFAULT_BRANCH = "refs/heads/main"
# Fallback polling values used when no settings object is injected via
# config["configurable"]["settings"].
_CI_POLL_INTERVAL = 30  # seconds between build-status polls
_CI_POLL_MAX_WAIT = 1800  # stop waiting for the build after 30 minutes
|
||||||
|
|
||||||
|
|
||||||
|
def _get_clients(config: dict) -> ToolClients:
    """Extract the injected ToolClients from the LangGraph config."""
    configurable = config["configurable"]
    return configurable["clients"]
|
||||||
|
|
||||||
|
|
||||||
|
def _get_settings(config: dict):
|
||||||
|
return config["configurable"].get("settings")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Node: trigger_ci_build
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def trigger_ci_build(state: dict[str, Any], config: dict) -> dict:
    """Kick off the CI pipeline for the repository named in state.

    Uses the first pipeline registered for the repo, building main when a
    release version is present in state and develop otherwise.

    Args:
        state: Graph state dict.
        config: LangGraph config dict with "configurable" key.

    Returns:
        Dict with ci_build_id on success, or errors on failure.
    """
    clients = _get_clients(config)
    repo_name = state.get("repo_name", "")
    version = state.get("version", "")

    # A version in state marks a release run -> build main; otherwise this
    # is a post-merge build -> build develop.
    branch = "refs/heads/main" if version else "refs/heads/develop"

    try:
        pipelines = await clients.azdo.list_build_pipelines(repo=repo_name)
    except ReleaseAgentError as exc:
        return {"errors": [f"trigger_ci_build: failed to list pipelines: {exc}"]}

    if not pipelines:
        return {"errors": [f"trigger_ci_build: no pipelines found for repo '{repo_name}'"]}

    pipeline = pipelines[0]
    try:
        result = await clients.azdo.trigger_pipeline(
            pipeline_id=pipeline.id,
            branch=branch,
        )
    except ReleaseAgentError as exc:
        return {"errors": [f"trigger_ci_build: failed to trigger pipeline {pipeline.id}: {exc}"]}

    raw_id = result.get("id")
    if not isinstance(raw_id, int):
        return {"errors": [f"trigger_ci_build: unexpected build id type: {raw_id!r}"]}
    return {
        "ci_build_id": raw_id,
        "messages": [f"CI build triggered: pipeline {pipeline.id}, build {raw_id}"],
    }
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Node: poll_ci_build
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def poll_ci_build(state: dict[str, Any], config: dict) -> dict:
    """Wait for the triggered CI build to finish, up to a timeout.

    Poll interval and maximum wait come from the injected settings when
    available, falling back to module defaults. On timeout the last
    observed status is returned together with an error entry.

    Args:
        state: Graph state dict containing ci_build_id.
        config: LangGraph config dict.

    Returns:
        Dict with ci_build_status, ci_build_result, and ci_build_url on
        completion, or errors on timeout/failure.
    """
    clients = _get_clients(config)
    build_id = state.get("ci_build_id")
    if not build_id:
        return {"errors": ["poll_ci_build: ci_build_id not set in state"]}

    settings = _get_settings(config)
    if settings:
        interval = getattr(settings, "ci_poll_interval_seconds", _CI_POLL_INTERVAL)
        max_wait = getattr(settings, "ci_poll_max_wait_seconds", _CI_POLL_MAX_WAIT)
    else:
        interval = _CI_POLL_INTERVAL
        max_wait = _CI_POLL_MAX_WAIT

    async def fetch_status() -> BuildStatus:
        return await clients.azdo.get_build_status(build_id=build_id)

    def build_finished(bs: BuildStatus) -> bool:
        return bs is not None and bs.status == "completed"

    last_status, completed = await poll_until(
        poll_fn=fetch_status,
        is_done=build_finished,
        interval_seconds=interval,
        max_wait_seconds=max_wait,
    )

    if last_status is None:
        return {"errors": [f"poll_ci_build: build {build_id} polling failed (no status returned)"]}

    outcome: dict = {
        "ci_build_status": last_status.status,
        "ci_build_result": last_status.result,
        "ci_build_url": last_status.build_url,
    }
    if not completed:
        outcome["errors"] = [
            f"poll_ci_build: build {build_id} did not complete within timeout "
            f"(last status: {last_status.status})"
        ]
    return outcome
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Node: notify_ci_result
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def notify_ci_result(state: dict[str, Any], config: dict) -> dict:
    """Post the CI build outcome to Slack.

    Non-critical: any failure is returned as an error entry rather than
    raised, so the graph keeps running.

    Args:
        state: Graph state dict.
        config: LangGraph config dict.

    Returns:
        Dict with messages on success, or errors on failure.
    """
    clients = _get_clients(config)
    repo_name = state.get("repo_name", "unknown")
    build_result = state.get("ci_build_result", "unknown")
    build_url = state.get("ci_build_url")
    branch = state.get("version", "main")

    blocks = _build_ci_status_blocks(
        repo=repo_name,
        branch=branch,
        status=build_result,
        build_url=build_url,
    )
    status_label = "passed" if build_result == "succeeded" else "failed"
    text = f"CI build {status_label} for {repo_name}"

    try:
        await clients.slack.send_notification(text=text, blocks=blocks)
    except ReleaseAgentError as exc:
        return {"errors": [f"notify_ci_result: {exc}"]}
    except Exception as exc:
        return {"errors": [f"notify_ci_result: unexpected error: {exc}"]}
    return {"messages": [text]}
|
||||||
171
src/release_agent/graph/dependencies.py
Normal file
171
src/release_agent/graph/dependencies.py
Normal file
@@ -0,0 +1,171 @@
|
|||||||
|
"""Graph dependency types: ToolClients dataclass and StagingStore protocol.
|
||||||
|
|
||||||
|
ToolClients is a frozen dataclass injected via config["configurable"]["clients"].
|
||||||
|
StagingStore is a Protocol with file-based implementation JsonFileStagingStore.
|
||||||
|
All StagingStore methods are async to support both file-based and DB backends.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import date
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Protocol
|
||||||
|
|
||||||
|
from release_agent.models.release import ArchivedRelease, StagingRelease
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# ToolClients
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class ToolClients:
    """Frozen container for all external service clients.

    Injected into graph nodes via config["configurable"]["clients"].
    Frozen so nodes cannot swap clients mid-run.

    Attributes:
        azdo: AzDoClient instance.
        jira: JiraClient instance.
        slack: SlackClient instance.
        reviewer: ClaudeReviewer instance.
    """

    # Fields are typed Any; concrete client types live in their own modules.
    azdo: Any
    jira: Any
    slack: Any
    reviewer: Any
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# StagingStore Protocol
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class StagingStore(Protocol):
    """Protocol for persistent staging release storage.

    Implementations handle reading and writing StagingRelease and
    ArchivedRelease objects. All methods are async to support both
    file-based and PostgreSQL backends. This is a structural interface:
    any object with these methods satisfies it, no inheritance required.
    """

    async def load(self, repo: str) -> StagingRelease | None:
        """Load the current staging release for the given repository.

        Returns None if no staging release exists.
        """
        ...  # pragma: no cover

    async def save(self, release: StagingRelease) -> None:
        """Persist a staging release, overwriting any existing entry."""
        ...  # pragma: no cover

    async def archive(self, release: StagingRelease, released_at: date) -> None:
        """Archive the staging release with the given release date.

        Creates an archive entry and removes the staging entry.
        """
        ...  # pragma: no cover

    async def list_versions(self, repo: str) -> list[str]:
        """Return all version strings known for the given repository.

        Includes both current staging version (if any) and all archived versions.
        """
        ...  # pragma: no cover
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# JsonFileStagingStore
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class JsonFileStagingStore:
    """File-based StagingStore implementation using JSON files.

    Files are stored in a single directory:
    - Staging: <directory>/<repo>.json
    - Archive: <directory>/<repo>_<version>_<date>.json

    Args:
        directory: Path to the directory for storing JSON files.
            Created automatically if it does not exist.
    """

    def __init__(self, *, directory: Path) -> None:
        self._dir = directory
        self._dir.mkdir(parents=True, exist_ok=True)

    # ------------------------------------------------------------------
    # Public interface (async)
    # ------------------------------------------------------------------

    async def load(self, repo: str) -> StagingRelease | None:
        """Load the current staging release for the given repository.

        Returns None if no staging file exists for the repo.
        """
        path = self._staging_path(repo)
        if not path.exists():
            return None
        data = json.loads(path.read_text(encoding="utf-8"))
        return StagingRelease.model_validate(data)

    async def save(self, release: StagingRelease) -> None:
        """Persist a staging release to disk, overwriting any existing file."""
        path = self._staging_path(release.repo)
        path.write_text(
            release.model_dump_json(),
            encoding="utf-8",
        )

    async def archive(self, release: StagingRelease, released_at: date) -> None:
        """Archive the staging release and remove the staging file."""
        archived = ArchivedRelease(
            version=release.version,
            repo=release.repo,
            started_at=release.started_at,
            tickets=list(release.tickets),
            released_at=released_at,
        )
        archive_path = self._archive_path(release.repo, release.version, released_at)
        archive_path.write_text(
            archived.model_dump_json(),
            encoding="utf-8",
        )
        # Remove staging file if it exists
        staging_path = self._staging_path(release.repo)
        if staging_path.exists():
            staging_path.unlink()

    async def list_versions(self, repo: str) -> list[str]:
        """Return all version strings known for the given repository.

        Includes the current staging version (if any) and all archived
        versions belonging to this repo.
        """
        versions: list[str] = []

        # Check current staging file
        staging = await self.load(repo)
        if staging is not None:
            versions.append(staging.version)

        # Check archived files. A filename prefix alone is ambiguous: for
        # repo "app", both "app_1.0.0_2024-01-01.json" (an archive of "app")
        # and "app_api.json" (the staging file of repo "app_api") start with
        # "app_". Bug fix: also verify the "repo" field stored inside the
        # JSON payload to avoid collecting versions from other repos.
        prefix = f"{repo}_"
        for path in self._dir.iterdir():
            if not (path.name.startswith(prefix) and path.name.endswith(".json")):
                continue
            try:
                data = json.loads(path.read_text(encoding="utf-8"))
            except json.JSONDecodeError:
                continue
            if data.get("repo") != repo:
                continue
            v = data.get("version")
            if v and v not in versions:
                versions.append(v)

        return versions

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    def _staging_path(self, repo: str) -> Path:
        """Return the path to the staging file for the given repo."""
        return self._dir / f"{repo}.json"

    def _archive_path(self, repo: str, version: str, released_at: date) -> Path:
        """Return the path to an archive file."""
        date_str = released_at.isoformat()
        return self._dir / f"{repo}_{version}_{date_str}.json"
|
||||||
44
src/release_agent/graph/full_cycle.py
Normal file
44
src/release_agent/graph/full_cycle.py
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
"""Full Cycle graph: composes pr_completed and release as subgraph nodes.
|
||||||
|
|
||||||
|
The full cycle graph handles the complete lifecycle from a PR merge webhook
|
||||||
|
through to a released version. After the PR completed subgraph, a conditional
|
||||||
|
edge routes to the release subgraph if continue_to_release is True.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from langgraph.graph import END, START, StateGraph
|
||||||
|
|
||||||
|
from release_agent.graph.pr_completed import build_pr_completed_graph
|
||||||
|
from release_agent.graph.release import build_release_graph
|
||||||
|
from release_agent.graph.routing import should_continue_to_release
|
||||||
|
from release_agent.state import ReleaseState
|
||||||
|
|
||||||
|
|
||||||
|
def build_full_cycle_graph():
    """Assemble and compile the Full Cycle StateGraph.

    Wires two pre-built subgraphs into one parent graph:
    - "pr_completed": webhook parsing, code review, and PR merge
    - "release": staging confirmation, release PR, and pipeline triggers

    After "pr_completed", should_continue_to_release routes either into
    the release subgraph ("yes") or straight to END ("no").

    Returns the compiled graph ready for invocation.
    """
    workflow = StateGraph(ReleaseState)
    workflow.add_node("pr_completed", build_pr_completed_graph())
    workflow.add_node("release", build_release_graph())

    workflow.add_edge(START, "pr_completed")
    workflow.add_conditional_edges(
        "pr_completed",
        should_continue_to_release,
        {"yes": "release", "no": END},
    )
    workflow.add_edge("release", END)

    return workflow.compile()
|
||||||
91
src/release_agent/graph/polling.py
Normal file
91
src/release_agent/graph/polling.py
Normal file
@@ -0,0 +1,91 @@
|
|||||||
|
"""Reusable async polling utility for CI/CD status checks.
|
||||||
|
|
||||||
|
poll_until runs a poll_fn repeatedly until is_done returns True,
|
||||||
|
a timeout is reached, or too many consecutive failures occur.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from typing import Any, Callable, TypeVar
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
_MAX_CONSECUTIVE_FAILURES = 3
|
||||||
|
|
||||||
|
|
||||||
|
async def poll_until(
|
||||||
|
*,
|
||||||
|
poll_fn: Callable[[], Any],
|
||||||
|
is_done: Callable[[Any], bool],
|
||||||
|
interval_seconds: float = 30,
|
||||||
|
max_wait_seconds: float = 1800,
|
||||||
|
sleep_fn: Callable[[float], Any] | None = None,
|
||||||
|
) -> tuple[Any, bool]:
|
||||||
|
"""Poll poll_fn until is_done returns True, timeout, or too many failures.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
poll_fn: Async callable that fetches the latest status.
|
||||||
|
is_done: Pure function that returns True when polling should stop.
|
||||||
|
interval_seconds: Seconds to wait between polls (default: 30).
|
||||||
|
max_wait_seconds: Maximum total seconds before giving up (default: 1800).
|
||||||
|
sleep_fn: Async sleep function to inject (default: asyncio.sleep).
|
||||||
|
Inject a no-op in tests to avoid real waits.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple (last_result, completed_within_timeout) where:
|
||||||
|
- last_result is the last value returned by poll_fn (or None if every
|
||||||
|
call raised an exception).
|
||||||
|
- completed_within_timeout is True if is_done returned True before the
|
||||||
|
timeout and before consecutive failure abort.
|
||||||
|
"""
|
||||||
|
if sleep_fn is None:
|
||||||
|
sleep_fn = asyncio.sleep
|
||||||
|
|
||||||
|
elapsed: float = 0.0
|
||||||
|
consecutive_failures = 0
|
||||||
|
last_result: Any = None
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
last_result = await poll_fn()
|
||||||
|
consecutive_failures = 0
|
||||||
|
except Exception as exc:
|
||||||
|
consecutive_failures += 1
|
||||||
|
logger.warning(
|
||||||
|
"poll_until: poll_fn raised exception (consecutive=%d): %s",
|
||||||
|
consecutive_failures,
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
if consecutive_failures >= _MAX_CONSECUTIVE_FAILURES:
|
||||||
|
logger.error(
|
||||||
|
"poll_until: aborting after %d consecutive failures",
|
||||||
|
_MAX_CONSECUTIVE_FAILURES,
|
||||||
|
)
|
||||||
|
return None, False
|
||||||
|
# Sleep before retry
|
||||||
|
await sleep_fn(interval_seconds)
|
||||||
|
elapsed += interval_seconds
|
||||||
|
if elapsed >= max_wait_seconds:
|
||||||
|
return last_result, False
|
||||||
|
continue
|
||||||
|
|
||||||
|
if is_done(last_result):
|
||||||
|
return last_result, True
|
||||||
|
|
||||||
|
# Check if we've already used the full budget
|
||||||
|
if elapsed >= max_wait_seconds:
|
||||||
|
return last_result, False
|
||||||
|
|
||||||
|
# Sleep and advance elapsed time
|
||||||
|
remaining = max_wait_seconds - elapsed
|
||||||
|
sleep_time = min(interval_seconds, remaining)
|
||||||
|
if sleep_time <= 0:
|
||||||
|
return last_result, False
|
||||||
|
|
||||||
|
await sleep_fn(sleep_time)
|
||||||
|
elapsed += sleep_time
|
||||||
|
|
||||||
|
if elapsed >= max_wait_seconds:
|
||||||
|
return last_result, False
|
||||||
148
src/release_agent/graph/postgres_staging_store.py
Normal file
148
src/release_agent/graph/postgres_staging_store.py
Normal file
@@ -0,0 +1,148 @@
|
|||||||
|
"""PostgreSQL-backed StagingStore implementation.
|
||||||
|
|
||||||
|
Uses psycopg async connection pool. Tables:
|
||||||
|
- staging_releases: current in-progress releases per repo
|
||||||
|
- archived_releases: completed releases (immutable history)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from datetime import date
|
||||||
|
|
||||||
|
from release_agent.models.release import ArchivedRelease, StagingRelease
|
||||||
|
from release_agent.models.ticket import TicketEntry
|
||||||
|
|
||||||
|
# Upsert one staging row per repo; repo is unique, so re-saving the same
# repo replaces version/started_at/tickets in place and bumps updated_at.
_INSERT_STAGING_SQL = """
INSERT INTO staging_releases (repo, version, started_at, tickets)
VALUES (%s, %s, %s, %s)
ON CONFLICT (repo) DO UPDATE
SET version = EXCLUDED.version,
    started_at = EXCLUDED.started_at,
    tickets = EXCLUDED.tickets,
    updated_at = NOW()
"""

# Fetch the single staging row for a repo; column order matches
# PostgresStagingStore._row_to_staging (repo, version, started_at, tickets).
_SELECT_STAGING_SQL = """
SELECT repo, version, started_at, tickets
FROM staging_releases
WHERE repo = %s
"""

# Append to the immutable archive; duplicate (repo, version) pairs are
# ignored so archiving is idempotent.
_INSERT_ARCHIVED_SQL = """
INSERT INTO archived_releases (repo, version, started_at, tickets, released_at)
VALUES (%s, %s, %s, %s, %s)
ON CONFLICT (repo, version) DO NOTHING
"""

# Remove the staging row once its release has been archived.
_DELETE_STAGING_SQL = """
DELETE FROM staging_releases WHERE repo = %s
"""

# List archived rows for a repo (version is column index 1 in the result).
_SELECT_ARCHIVED_SQL = """
SELECT repo, version, started_at, tickets, released_at
FROM archived_releases
WHERE repo = %s
"""
|
||||||
|
|
||||||
|
|
||||||
|
class PostgresStagingStore:
    """Async PostgreSQL-backed StagingStore.

    Args:
        pool: A psycopg_pool.AsyncConnectionPool instance (or compatible fake).
    """

    def __init__(self, *, pool) -> None:
        # The pool is injected so tests can substitute a fake exposing the
        # same connection()/cursor()/transaction() async context managers.
        self._pool = pool

    # ------------------------------------------------------------------
    # Public async interface
    # ------------------------------------------------------------------

    async def load(self, repo: str) -> StagingRelease | None:
        """Load the current staging release for the given repository.

        Returns None when no staging row exists for the repo.
        """
        async with self._pool.connection() as conn:
            async with conn.cursor() as cur:
                await cur.execute(_SELECT_STAGING_SQL, (repo,))
                row = await cur.fetchone()

        if row is None:
            return None

        return self._row_to_staging(row)

    async def save(self, release: StagingRelease) -> None:
        """Upsert the staging release into the database."""
        # Tickets are serialised to a JSON array and stored in one column.
        tickets_json = json.dumps(
            [t.model_dump(mode="json") for t in release.tickets]
        )
        async with self._pool.connection() as conn:
            async with conn.cursor() as cur:
                await cur.execute(
                    _INSERT_STAGING_SQL,
                    (
                        release.repo,
                        release.version,
                        release.started_at.isoformat(),
                        tickets_json,
                    ),
                )

    async def archive(self, release: StagingRelease, released_at: date) -> None:
        """Archive the staging release and delete it from staging_releases.

        Both statements run inside a single transaction so the release can
        never end up present in both tables (or neither) after a failure.
        """
        tickets_json = json.dumps(
            [t.model_dump(mode="json") for t in release.tickets]
        )
        async with self._pool.connection() as conn:
            async with conn.transaction():
                async with conn.cursor() as cur:
                    await cur.execute(
                        _INSERT_ARCHIVED_SQL,
                        (
                            release.repo,
                            release.version,
                            release.started_at.isoformat(),
                            tickets_json,
                            released_at.isoformat(),
                        ),
                    )
                    await cur.execute(_DELETE_STAGING_SQL, (release.repo,))

    async def list_versions(self, repo: str) -> list[str]:
        """Return all version strings for the given repository.

        The current staging version (if any) comes first, followed by
        archived versions; duplicates are skipped.
        """
        versions: list[str] = []

        # Check current staging
        staging = await self.load(repo)
        if staging is not None and staging.version not in versions:
            versions.append(staging.version)

        # Check archived
        async with self._pool.connection() as conn:
            async with conn.cursor() as cur:
                await cur.execute(_SELECT_ARCHIVED_SQL, (repo,))
                rows = await cur.fetchall()

        for row in rows:
            version = row[1]  # version is column index 1
            if version and version not in versions:
                versions.append(version)

        return versions

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    def _row_to_staging(self, row: tuple) -> StagingRelease:
        """Convert a DB row (repo, version, started_at, tickets) to StagingRelease."""
        repo, version, started_at_str, tickets_raw = row[0], row[1], row[2], row[3]

        # tickets may arrive as a JSON string (TEXT column) or already decoded
        # (e.g. JSONB with an adapting row factory) — handle both shapes.
        tickets_data = json.loads(tickets_raw) if isinstance(tickets_raw, str) else tickets_raw
        tickets = [TicketEntry.model_validate(t) for t in (tickets_data or [])]

        return StagingRelease(
            repo=repo,
            version=version,
            started_at=date.fromisoformat(str(started_at_str)),
            tickets=tickets,
        )
|
||||||
573
src/release_agent/graph/pr_completed.py
Normal file
573
src/release_agent/graph/pr_completed.py
Normal file
@@ -0,0 +1,573 @@
|
|||||||
|
"""Node functions for the PR Completed subgraph.
|
||||||
|
|
||||||
|
Each node is an async function (state, config) -> dict.
|
||||||
|
Nodes never mutate state — they return new dicts with updated fields.
|
||||||
|
External clients are accessed via config["configurable"]["clients"].
|
||||||
|
"""
|
||||||
|
|
||||||
|
from datetime import date
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langgraph.graph import END, START, StateGraph
|
||||||
|
from langgraph.types import interrupt
|
||||||
|
|
||||||
|
from release_agent.branch_parser import parse_branch
|
||||||
|
from release_agent.exceptions import ReleaseAgentError
|
||||||
|
from release_agent.graph.ci_nodes import notify_ci_result, poll_ci_build, trigger_ci_build
|
||||||
|
from release_agent.graph.dependencies import JsonFileStagingStore, ToolClients
|
||||||
|
from release_agent.graph.routing import has_ticket, is_pr_already_merged, is_review_approved, route_after_fetch
|
||||||
|
from release_agent.models.release import StagingRelease
|
||||||
|
from release_agent.models.review import ReviewResult
|
||||||
|
from release_agent.models.ticket import TicketEntry
|
||||||
|
from release_agent.models.webhook import WebhookPayload
|
||||||
|
from release_agent.state import ReleaseState
|
||||||
|
from release_agent.versioning import calculate_next_version
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _get_clients(config: dict) -> ToolClients:
    """Return the shared ToolClients bundle from the runnable config."""
    configurable = config["configurable"]
    return configurable["clients"]
|
||||||
|
|
||||||
|
|
||||||
|
def _get_staging_store(config: dict):
|
||||||
|
return config["configurable"].get("staging_store")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Node: parse_webhook
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def parse_webhook(state: dict[str, Any], config: dict) -> dict:
    """Parse the webhook_payload field and extract PR info and ticket ID.

    Returns a dict with pr_info, ticket_id, has_ticket, repo_name, pr_id.
    On any validation or parsing error, returns only an errors entry.
    """
    raw_payload = state.get("webhook_payload", {})
    try:
        resource = WebhookPayload.model_validate(raw_payload).resource
        pr_id = str(resource.pull_request_id)
        repo_name = resource.repository.name
        branch = resource.source_ref_name
        ticket_id, ticket_found = parse_branch(branch)

        return {
            "pr_info": {
                "pr_id": pr_id,
                "repo_name": repo_name,
                "branch": branch,
                "pr_title": resource.title,
                "pr_status": resource.status,
                "pr_url": f"{resource.repository.web_url}/pullrequest/{pr_id}",
                "ticket_id": ticket_id,
                "has_ticket": ticket_found,
            },
            "pr_id": pr_id,
            "repo_name": repo_name,
            "ticket_id": ticket_id,
            "has_ticket": ticket_found,
        }
    except Exception as exc:
        return {"errors": [f"parse_webhook failed: {exc}"]}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Node: fetch_pr_details
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def fetch_pr_details(state: dict[str, Any], config: dict) -> dict:
    """Fetch full PR details from AzDo and check if already merged.

    Sets pr_already_merged, pr_diff, and last_merge_source_commit.
    On error (invalid PR id or AzDo failure), appends to errors (non-critical).
    """
    clients = _get_clients(config)
    pr_id_raw = state.get("pr_id") or (state.get("pr_info") or {}).get("pr_id", "")
    try:
        # Bug fix: a missing or non-numeric PR id previously raised an
        # uncaught ValueError here, crashing the node despite its documented
        # non-critical contract. Treat it like any other fetch failure.
        pr_id = int(pr_id_raw)
    except (TypeError, ValueError) as exc:
        return {"errors": [f"fetch_pr_details failed: invalid pr_id {pr_id_raw!r}: {exc}"]}
    try:
        pr = await clients.azdo.get_pr(pr_id)
        already_merged = pr.pr_status == "completed"
        # The diff is only needed for review; skip fetching it for a PR
        # that has already been merged.
        diff = ""
        if not already_merged:
            diff = await clients.azdo.get_pr_diff(pr_id)

        # last_merge_source_commit is not directly on PRInfo; pass None if unavailable
        return {
            "pr_already_merged": already_merged,
            "pr_diff": diff,
            "last_merge_source_commit": None,
        }
    except ReleaseAgentError as exc:
        return {"errors": [f"fetch_pr_details failed: {exc}"]}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Node: auto_create_ticket
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def auto_create_ticket(state: dict[str, Any], config: dict) -> dict:
    """Auto-create a Jira ticket for a PR that has no existing ticket.

    Uses ClaudeReviewer to generate ticket content from the PR diff,
    then creates a new Jira Story in the configured default project.

    Sets ticket_id and has_ticket=True on success.
    On error, appends to errors (non-critical).
    """
    clients = _get_clients(config)
    pr_info = state.get("pr_info") or {}
    pr_title = pr_info.get("pr_title", "")
    repo_name = pr_info.get("repo_name", "")
    pr_diff = state.get("pr_diff", "")
    default_jira_project = config.get("configurable", {}).get("default_jira_project", "ALLPOST")

    # Point the reviewer at a local checkout (repos_base_dir/<repo>) when one
    # exists, so generated ticket content can reference the codebase.
    repos_base_dir = config.get("configurable", {}).get("repos_base_dir", "")
    cwd = None
    if repos_base_dir and repo_name:
        from pathlib import Path as _Path
        repo_path = _Path(repos_base_dir) / repo_name
        if repo_path.is_dir():
            cwd = str(repo_path)

    # Step 1: generate summary/description from the diff (non-critical).
    try:
        summary, description = await clients.reviewer.generate_ticket_content(
            diff=pr_diff,
            pr_title=pr_title,
            repo_name=repo_name,
            cwd=cwd,
        )
    except Exception as exc:
        return {"errors": [f"auto_create_ticket generate_ticket_content failed: {exc}"]}

    # Step 2: create the Jira issue with the generated content (non-critical).
    try:
        ticket_id = await clients.jira.create_issue(
            project=default_jira_project,
            summary=summary,
            description=description,
        )
        return {
            "ticket_id": ticket_id,
            "has_ticket": True,
            "messages": [f"Auto-created Jira ticket {ticket_id} for PR in {repo_name}"],
        }
    except Exception as exc:
        return {"errors": [f"auto_create_ticket create_issue failed: {exc}"]}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Node: move_jira_code_review
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def move_jira_code_review(state: dict[str, Any], config: dict) -> dict:
    """Transition the Jira ticket to Code Review status.

    No-op when the PR has no associated ticket. Non-critical: Jira
    failures are recorded in errors rather than raised.
    """
    if not state.get("has_ticket"):
        return {}
    ticket_id = state.get("ticket_id", "")
    jira = _get_clients(config).jira
    try:
        await jira.transition_issue(ticket_id, "code review")
    except ReleaseAgentError as exc:
        return {"errors": [f"move_jira_code_review failed: {exc}"]}
    return {"messages": [f"Jira {ticket_id} moved to Code Review"]}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helper: post review comments to Azure DevOps PR
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def _post_review_to_pr(
    clients: ToolClients,
    repo_name: str,
    pr_id: int,
    review: ReviewResult,
) -> None:
    """Post review results as comments on the Azure DevOps PR.

    Issues carrying a file path and start line become inline comments; the
    overall verdict plus a bullet per issue becomes one summary comment.
    Non-critical: failures are logged, never raised.
    """
    import logging

    logger = logging.getLogger(__name__)

    # Inline comments for issues pinned to a file location.
    for issue in review.issues:
        if not (issue.file_path and issue.line_start):
            continue
        label = issue.severity.upper()
        body = f"**[{label}]** {issue.description}"
        if issue.suggestion:
            body += f"\n\n**Suggestion:** {issue.suggestion}"
        try:
            await clients.azdo.add_pr_inline_comment(
                repo=repo_name,
                pr_id=pr_id,
                content=body,
                file_path=issue.file_path,
                line_start=issue.line_start,
                line_end=issue.line_end,
            )
        except Exception as exc:
            logger.warning("Failed to post inline comment on %s:%d: %s", issue.file_path, issue.line_start, exc)

    # One summary comment with the verdict and issue list.
    verdict_text = "APPROVE" if review.verdict == "approve" else "REQUEST CHANGES"
    parts = [f"## Claude Code Review: {verdict_text}\n\n{review.summary}\n\n"]
    if review.issues:
        parts.append(f"**{len(review.issues)} issue(s) found:**\n")
        for issue in review.issues:
            loc = f" ({issue.file_path}:{issue.line_start})" if issue.file_path and issue.line_start else ""
            parts.append(f"- **{issue.severity.upper()}**{loc}: {issue.description}\n")
    else:
        parts.append("No issues found.")

    try:
        await clients.azdo.add_pr_comment(repo=repo_name, pr_id=pr_id, content="".join(parts))
    except Exception as exc:
        logger.warning("Failed to post review summary comment: %s", exc)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Node: run_code_review
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def run_code_review(state: dict[str, Any], config: dict) -> dict:
    """Run Claude code review on the PR diff.

    Returns review_result as a serialisable dict.
    On error, appends to errors.
    """
    clients = _get_clients(config)
    diff = state.get("pr_diff", "")
    pr_info = state.get("pr_info") or {}
    pr_title = pr_info.get("pr_title", "")
    repo_name = pr_info.get("repo_name", "")

    # Build cwd from repos_base_dir + repo_name so Claude Code can read the codebase
    repos_base_dir = config.get("configurable", {}).get("repos_base_dir", "")
    cwd = None
    if repos_base_dir and repo_name:
        from pathlib import Path
        repo_path = Path(repos_base_dir) / repo_name
        if repo_path.is_dir():
            cwd = str(repo_path)

    try:
        review: ReviewResult = await clients.reviewer.review_pr(
            diff=diff, pr_title=pr_title, repo_name=repo_name, cwd=cwd
        )
        # Post review comments to Azure DevOps PR
        # NOTE(review): int("") raises ValueError here; it is swallowed by the
        # broad except below, turning an otherwise successful review into an
        # errors entry. Presumably parse_webhook always supplies a numeric
        # pr_id — confirm.
        pr_id = int(pr_info.get("pr_id", 0))
        if pr_id and repo_name:
            await _post_review_to_pr(clients, repo_name, pr_id, review)
        return {"review_result": review.model_dump(mode="json")}
    except Exception as exc:
        return {"errors": [f"run_code_review failed: {exc}"]}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Node: evaluate_review
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def evaluate_review(state: dict[str, Any], config: dict) -> dict:
    """Evaluate the review result and set review_approved.

    A review passes only when the verdict is "approve" AND no blocker is
    present — neither via the has_blockers flag nor as an issue whose
    severity is "blocker".
    """
    review = state.get("review_result") or {}
    verdict = review.get("verdict", "request_changes")
    blockers = review.get("has_blockers", False)
    if not blockers:
        # The flag may be absent/false even when an individual issue is a blocker.
        blockers = any(
            issue.get("severity") == "blocker"
            for issue in (review.get("issues") or [])
        )
    return {"review_approved": verdict == "approve" and not blockers}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Node: interrupt_confirm_merge
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def interrupt_confirm_merge(state: dict[str, Any], config: dict) -> dict:
    """Interrupt to ask the operator to confirm the PR merge.

    Builds a human-readable prompt from the PR info and review summary
    and hands it to interrupt(); the node itself contributes no state.
    """
    pr_info = state.get("pr_info") or {}
    review = state.get("review_result") or {}
    prompt_lines = [
        f"PR #{pr_info.get('pr_id')} - {pr_info.get('pr_title')} [repo: {pr_info.get('repo_name')}]",
        f"Review: {review.get('summary', 'No review summary')}",
        "Confirm merge? (yes/no)",
    ]
    interrupt("\n".join(prompt_lines))
    return {}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Node: merge_pr_node
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def merge_pr_node(state: dict[str, Any], config: dict) -> dict:
    """Merge the PR via AzDo API.

    Critical node — re-raises ReleaseAgentError.
    """
    clients = _get_clients(config)
    pr_info = state.get("pr_info") or {}
    # Prefer the id from pr_info; fall back to the top-level pr_id field.
    pr_id = int(pr_info.get("pr_id", state.get("pr_id", 0)))
    # NOTE(review): fetch_pr_details stores last_merge_source_commit as None,
    # so this default of "" only applies when the key is absent entirely and
    # merge_pr may receive None — confirm the AzDo client tolerates that.
    commit = state.get("last_merge_source_commit", "")
    await clients.azdo.merge_pr(pr_id=pr_id, last_merge_source_commit=commit)
    return {"messages": [f"PR #{pr_id} merged"]}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Node: move_jira_ready_for_stage
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def move_jira_ready_for_stage(state: dict[str, Any], config: dict) -> dict:
    """Transition the Jira ticket to Ready for stage (2).

    No-op when the PR has no associated ticket. Non-critical: Jira
    failures are recorded in errors rather than raised.
    """
    if not state.get("has_ticket"):
        return {}
    ticket_id = state.get("ticket_id", "")
    jira = _get_clients(config).jira
    try:
        await jira.transition_issue(ticket_id, "Ready for stage (2)")
    except ReleaseAgentError as exc:
        return {"errors": [f"move_jira_ready_for_stage failed: {exc}"]}
    return {"messages": [f"Jira {ticket_id} moved to Ready for stage (2)"]}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Node: add_jira_pr_link
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def add_jira_pr_link(state: dict[str, Any], config: dict) -> dict:
    """Add the PR URL as a remote link on the Jira ticket.

    Skipped when ``has_ticket`` is False. Non-critical node: failures are
    recorded in ``errors`` instead of raising.
    """
    if not state.get("has_ticket"):
        return {}

    clients = _get_clients(config)
    ticket_id = state.get("ticket_id", "")
    pr_details = state.get("pr_info") or {}
    pr_id = pr_details.get("pr_id", "")
    link_title = f"PR #{pr_id}: {pr_details.get('pr_title', 'PR')}"
    try:
        await clients.jira.add_remote_link(
            ticket_id=ticket_id,
            url=str(pr_details.get("pr_url", "")),
            title=link_title,
        )
    except ReleaseAgentError as exc:
        return {"errors": [f"add_jira_pr_link failed: {exc}"]}
    return {"messages": [f"PR link added to Jira {ticket_id}"]}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Node: calculate_version
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def calculate_version(state: dict[str, Any], config: dict) -> dict:
    """Calculate the next version for the repository using the staging store.

    Delegates the bump logic to ``calculate_next_version`` from the
    versioning module; with no store configured, no prior versions are known.
    """
    repo_name = state.get("repo_name", "")
    store = _get_staging_store(config)
    known_versions: list[str] = (
        await store.list_versions(repo_name) if store is not None else []
    )
    return {"version": calculate_next_version(repo_name, known_versions)}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Node: update_staging
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def update_staging(state: dict[str, Any], config: dict) -> dict:
    """Add the PR's ticket to the staging release.

    If no staging record exists for the repo, a new one is created with this
    ticket as its first entry; otherwise the ticket is appended to the
    existing record.

    NOTE(review): when ``has_ticket`` is False the node returns immediately —
    no staging record is created or ensured in that case, despite what older
    docs claimed. Confirm that is the intended behavior.

    Returns:
        ``{"staging": <json dump>}`` when a staging store is configured,
        otherwise ``{}`` (nothing persisted).
    """
    if not state.get("has_ticket"):
        return {}

    clients = _get_clients(config)
    staging_store = _get_staging_store(config)
    repo_name = state.get("repo_name", "")
    version = state.get("version", "v1.0.0")
    ticket_id = state.get("ticket_id", "")
    pr_info = state.get("pr_info") or {}

    # Fetch ticket summary from Jira; a broad catch makes this best-effort —
    # a Jira outage degrades to using the bare ticket id as the summary.
    try:
        issue = await clients.jira.get_issue(ticket_id)
        summary = issue.summary
    except Exception:
        summary = ticket_id

    pr_url = pr_info.get("pr_url", "")
    pr_id = pr_info.get("pr_id", "")
    pr_title = pr_info.get("pr_title", "")
    branch = pr_info.get("branch", "")

    ticket_entry = TicketEntry(
        id=ticket_id,
        summary=summary,
        pr_id=str(pr_id),
        # Placeholder URL for non-http values — presumably TicketEntry
        # validates pr_url as an absolute URL; TODO confirm.
        pr_url=pr_url if pr_url.startswith("http") else f"https://example.com/{pr_id}",
        pr_title=pr_title,
        branch=branch,
        merged_at=date.today(),
    )

    if staging_store is not None:
        existing = await staging_store.load(repo_name)
        if existing is None:
            # First merged PR for this release: open a fresh staging record.
            staging = StagingRelease(
                version=version,
                repo=repo_name,
                started_at=date.today(),
                tickets=[ticket_entry],
            )
        else:
            staging = existing.add_ticket(ticket_entry)
        await staging_store.save(staging)
        return {"staging": staging.model_dump(mode="json")}

    # No store configured — the ticket entry is discarded.
    return {}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Node: notify_request_changes
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def notify_request_changes(state: dict[str, Any], config: dict) -> dict:
    """Send a Slack notification when the review requests changes.

    Non-critical node: Slack failures are recorded in ``errors``.
    """
    clients = _get_clients(config)
    pr_info = state.get("pr_info") or {}
    review = state.get("review_result") or {}
    pr_id = pr_info.get("pr_id", "?")

    # One bullet per review issue, tagged with its severity.
    issue_lines = "\n".join(
        f"- [{item.get('severity', 'info')}] {item.get('description', '')}"
        for item in (review.get("issues") or [])
    )
    details = "\n".join(
        [
            f"PR #{pr_id}: {pr_info.get('pr_title', '')} [{pr_info.get('repo_name', '')}]",
            review.get("summary", "Review requested changes"),
            issue_lines,
        ]
    )
    try:
        await clients.slack.send_approval_request(
            action="Request Changes",
            details=details,
            approval_url="",
        )
    except ReleaseAgentError as exc:
        return {"errors": [f"notify_request_changes failed: {exc}"]}
    return {"messages": [f"Slack notified: request changes for PR #{pr_id}"]}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Graph builder
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def build_pr_completed_graph():
    """Assemble and compile the PR Completed StateGraph.

    Returns the compiled graph ready for invocation.

    Routing after fetch_pr_details:
    - "merged" -> calculate_version (PR already merged)
    - "active_with_ticket" -> move_jira_code_review (PR active, has Jira ticket)
    - "active_no_ticket" -> auto_create_ticket -> move_jira_code_review (no ticket)

    Review outcome routing:
    - "approve" -> interrupt_confirm_merge -> merge_pr_node -> Jira updates
    - "request_changes" -> notify_request_changes -> END
    """
    graph = StateGraph(ReleaseState)

    # Add all nodes
    graph.add_node("parse_webhook", parse_webhook)
    graph.add_node("fetch_pr_details", fetch_pr_details)
    graph.add_node("auto_create_ticket", auto_create_ticket)
    graph.add_node("move_jira_code_review", move_jira_code_review)
    graph.add_node("run_code_review", run_code_review)
    graph.add_node("evaluate_review", evaluate_review)
    graph.add_node("interrupt_confirm_merge", interrupt_confirm_merge)
    graph.add_node("merge_pr_node", merge_pr_node)
    graph.add_node("move_jira_ready_for_stage", move_jira_ready_for_stage)
    graph.add_node("add_jira_pr_link", add_jira_pr_link)
    graph.add_node("calculate_version", calculate_version)
    graph.add_node("update_staging", update_staging)
    graph.add_node("notify_request_changes", notify_request_changes)
    graph.add_node("trigger_ci_build", trigger_ci_build)
    graph.add_node("poll_ci_build", poll_ci_build)
    graph.add_node("notify_ci_result", notify_ci_result)

    # Entry
    graph.add_edge(START, "parse_webhook")
    graph.add_edge("parse_webhook", "fetch_pr_details")

    # Three-way route: merged, active_with_ticket, active_no_ticket
    graph.add_conditional_edges(
        "fetch_pr_details",
        route_after_fetch,
        {
            "merged": "calculate_version",
            "active_with_ticket": "move_jira_code_review",
            "active_no_ticket": "auto_create_ticket",
        },
    )

    # auto_create_ticket -> move_jira_code_review (ticket now exists after creation)
    graph.add_edge("auto_create_ticket", "move_jira_code_review")

    # Active PR path: code review
    graph.add_edge("move_jira_code_review", "run_code_review")
    graph.add_edge("run_code_review", "evaluate_review")

    # Route based on review outcome
    graph.add_conditional_edges(
        "evaluate_review",
        is_review_approved,
        {
            "approve": "interrupt_confirm_merge",
            "request_changes": "notify_request_changes",
        },
    )

    # Approved path: human confirm, merge, then Jira bookkeeping.
    graph.add_edge("interrupt_confirm_merge", "merge_pr_node")
    graph.add_edge("merge_pr_node", "move_jira_ready_for_stage")
    graph.add_edge("move_jira_ready_for_stage", "add_jira_pr_link")
    graph.add_edge("add_jira_pr_link", "calculate_version")

    # calculate_version -> update_staging -> trigger_ci_build -> poll_ci_build -> notify_ci_result -> END
    graph.add_edge("calculate_version", "update_staging")
    graph.add_edge("update_staging", "trigger_ci_build")
    graph.add_edge("trigger_ci_build", "poll_ci_build")
    graph.add_edge("poll_ci_build", "notify_ci_result")
    graph.add_edge("notify_ci_result", END)

    # notify_request_changes -> END
    graph.add_edge("notify_request_changes", END)

    return graph.compile()
|
||||||
627
src/release_agent/graph/release.py
Normal file
627
src/release_agent/graph/release.py
Normal file
@@ -0,0 +1,627 @@
|
|||||||
|
"""Node functions for the Release subgraph.
|
||||||
|
|
||||||
|
Each node is an async function (state, config) -> dict.
|
||||||
|
Nodes never mutate state — they return new dicts.
|
||||||
|
External clients are accessed via config["configurable"]["clients"].
|
||||||
|
|
||||||
|
Release flow (Phase 5):
|
||||||
|
merge_release_pr
|
||||||
|
-> trigger_ci_build_main
|
||||||
|
-> poll_ci_build_main
|
||||||
|
-> route_ci_result
|
||||||
|
-> ci_passed: wait_for_cd_release
|
||||||
|
-> poll_release_approvals
|
||||||
|
-> route_approval_stage
|
||||||
|
-> sandbox_pending: interrupt_sandbox_approval
|
||||||
|
-> execute_sandbox_approval
|
||||||
|
-> poll_release_approvals (loop)
|
||||||
|
-> prod_pending: interrupt_prod_approval
|
||||||
|
-> execute_prod_approval
|
||||||
|
-> poll_release_approvals (loop)
|
||||||
|
-> all_deployed: move_tickets_to_done
|
||||||
|
-> send_release_notification
|
||||||
|
-> archive_release
|
||||||
|
-> END
|
||||||
|
-> ci_failed: notify_ci_failure -> END
|
||||||
|
"""
|
||||||
|
|
||||||
|
from datetime import date
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langgraph.graph import END, START, StateGraph
|
||||||
|
from langgraph.types import interrupt
|
||||||
|
|
||||||
|
from release_agent.exceptions import ReleaseAgentError
|
||||||
|
from release_agent.graph.ci_nodes import poll_ci_build, trigger_ci_build
|
||||||
|
from release_agent.graph.dependencies import ToolClients
|
||||||
|
from release_agent.graph.routing import route_approval_stage, route_ci_result
|
||||||
|
from release_agent.models.release import StagingRelease
|
||||||
|
from release_agent.models.ticket import TicketEntry
|
||||||
|
from release_agent.state import ReleaseState
|
||||||
|
from release_agent.tools.slack import _build_ci_status_blocks
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _get_clients(config: dict) -> ToolClients:
|
||||||
|
return config["configurable"]["clients"]
|
||||||
|
|
||||||
|
|
||||||
|
def _get_staging_store(config: dict):
|
||||||
|
return config["configurable"].get("staging_store")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Node: load_staging
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def load_staging(state: dict[str, Any], config: dict) -> dict:
    """Load the current staging release from the store.

    Returns ``{"staging": None}`` when no store is configured or no record
    exists for the repo; otherwise the JSON dump of the loaded record.
    """
    store = _get_staging_store(config)
    if store is None:
        return {"staging": None}
    record = await store.load(state.get("repo_name", ""))
    if record is None:
        return {"staging": None}
    return {"staging": record.model_dump(mode="json")}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Node: interrupt_confirm_release
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def interrupt_confirm_release(state: dict[str, Any], config: dict) -> dict:
    """Interrupt to ask the operator to confirm starting the release."""
    staging_dict = state.get("staging") or {}
    tickets = staging_dict.get("tickets") or []
    # One indented bullet per staged ticket.
    ticket_lines = "\n".join(
        f"  - {ticket.get('id')}: {ticket.get('summary', '')}" for ticket in tickets
    )
    prompt = (
        f"Release {staging_dict.get('version', 'unknown')} "
        f"for repo '{state.get('repo_name', '')}'\n"
        f"Tickets ({len(tickets)}):\n{ticket_lines}\n"
        f"Confirm release? (yes/no)"
    )
    interrupt(prompt)
    return {}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Node: create_release_pr
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def create_release_pr(state: dict[str, Any], config: dict) -> dict:
    """Create a release PR in AzDo from a release branch to main.

    Returns the new PR id, its last-merge source commit (needed later to
    complete the merge), and a status message.
    """
    clients = _get_clients(config)
    repo_name = state.get("repo_name", "")
    version = state.get("version", "")
    tickets = (state.get("staging") or {}).get("tickets") or []
    ticket_ids = ", ".join(entry.get("id", "") for entry in tickets)

    data = await clients.azdo.create_pr(
        repo=repo_name,
        source=f"refs/heads/release/{version}",
        target="refs/heads/main",
        title=f"Release {version}",
        description=f"Release {version} for {repo_name}\n\nTickets: {ticket_ids}",
    )

    new_pr_id = str(data.get("pullRequestId", ""))
    merge_commit = (data.get("lastMergeSourceCommit") or {}).get("commitId", "")
    return {
        "release_pr_id": new_pr_id,
        "release_pr_commit": merge_commit,
        "messages": [f"Release PR #{new_pr_id} created for {version}"],
    }
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Node: interrupt_confirm_merge_release
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def interrupt_confirm_merge_release(state: dict[str, Any], config: dict) -> dict:
    """Interrupt to ask the operator to confirm merging the release PR."""
    prompt = (
        f"Merge release PR #{state.get('release_pr_id', '?')} "
        f"({state.get('version', '')}) into main for "
        f"'{state.get('repo_name', '')}'?\n"
        f"Confirm merge? (yes/no)"
    )
    interrupt(prompt)
    return {}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Node: merge_release_pr
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def merge_release_pr(state: dict[str, Any], config: dict) -> dict:
    """Merge the release PR via AzDo API.

    Critical node — ReleaseAgentError from the AzDo client propagates.

    Args:
        state: Graph state; reads ``release_pr_id`` (string, set by
            create_release_pr) and ``release_pr_commit``.
        config: Runnable config carrying the ToolClients bundle.

    Returns:
        State update with a confirmation message.
    """
    clients = _get_clients(config)
    # create_release_pr stores release_pr_id as a string and may leave it ""
    # when AzDo returned no id; ``or 0`` prevents int("") from raising.
    pr_id = int(state.get("release_pr_id") or 0)
    commit = state.get("release_pr_commit", "")
    await clients.azdo.merge_pr(pr_id=pr_id, last_merge_source_commit=commit)
    return {"messages": [f"Release PR #{pr_id} merged"]}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Node: trigger_ci_build_main (delegates to ci_nodes.trigger_ci_build)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def trigger_ci_build_main(state: dict[str, Any], config: dict) -> dict:
    """Trigger CI build on main after the release PR is merged.

    Thin alias around ci_nodes.trigger_ci_build so the release graph gets a
    distinctly named node; behavior is identical to the shared implementation.
    """
    return await trigger_ci_build(state, config)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Node: poll_ci_build_main (delegates to ci_nodes.poll_ci_build)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def poll_ci_build_main(state: dict[str, Any], config: dict) -> dict:
    """Poll the main branch CI build until completion.

    Thin alias around ci_nodes.poll_ci_build so the release graph gets a
    distinctly named node; behavior is identical to the shared implementation.
    """
    return await poll_ci_build(state, config)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Node: notify_ci_failure
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def notify_ci_failure(state: dict[str, Any], config: dict) -> dict:
    """Send a Slack notification that the CI build failed on main.

    Non-critical node: any notification failure is recorded in ``errors``.
    """
    clients = _get_clients(config)
    repo_name = state.get("repo_name", "unknown")
    version = state.get("version", "")
    branch_label = f"main (release {version})" if version else "main"

    blocks = _build_ci_status_blocks(
        repo=repo_name,
        branch=branch_label,
        status=state.get("ci_build_result", "failed"),
        build_url=state.get("ci_build_url"),
    )
    text = f"CI build failed on main for {repo_name} release {version}"

    try:
        await clients.slack.send_notification(text=text, blocks=blocks)
    except Exception as exc:
        return {"errors": [f"notify_ci_failure: {exc}"]}
    return {"messages": [text]}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Node: wait_for_cd_release
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def wait_for_cd_release(state: dict[str, Any], config: dict) -> dict:
    """Wait for the CD pipeline to create a release after CI passes.

    Fetches the latest release for the configured release definition.
    Non-critical node: failures and "not found" are recorded in ``errors``.
    """
    definition_id = state.get("release_definition_id")
    if not definition_id:
        return {"errors": ["wait_for_cd_release: release_definition_id not set in state"]}

    clients = _get_clients(config)
    try:
        release = await clients.azdo.get_latest_release(definition_id=definition_id)
    except ReleaseAgentError as exc:
        return {"errors": [f"wait_for_cd_release: {exc}"]}

    if not release:
        return {"errors": ["wait_for_cd_release: no release found for definition"]}
    return {
        "release_id": release["id"],
        "messages": [f"CD release found: {release.get('name', release['id'])}"],
    }
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Node: poll_release_approvals
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def poll_release_approvals(state: dict[str, Any], config: dict) -> dict:
    """Poll AzDo for pending release environment approvals.

    Non-critical node: failures are recorded in ``errors`` and an empty
    ``pending_approvals`` list is returned.
    """
    release_id = state.get("release_id")
    if not release_id:
        return {"pending_approvals": [], "errors": ["poll_release_approvals: release_id not set"]}

    clients = _get_clients(config)
    try:
        approvals = await clients.azdo.get_release_approvals(release_id=release_id)
    except ReleaseAgentError as exc:
        return {"errors": [f"poll_release_approvals: {exc}"], "pending_approvals": []}

    # Project each still-pending approval into a plain dict for the state.
    pending: list[dict] = []
    for approval in approvals:
        if approval.status != "pending":
            continue
        pending.append(
            {
                "approval_id": approval.approval_id,
                "stage_name": approval.stage_name,
                "status": approval.status,
                "release_id": approval.release_id,
            }
        )
    return {"pending_approvals": pending}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Node: interrupt_sandbox_approval
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def interrupt_sandbox_approval(state: dict[str, Any], config: dict) -> dict:
    """Interrupt to ask the operator to approve the sandbox deployment."""
    pending = state.get("pending_approvals") or []
    stage_names = ", ".join(entry.get("stage_name", "?") for entry in pending)
    prompt = (
        f"Approve sandbox deployment for release {state.get('version', '')}?\n"
        f"Stages: {stage_names}\n"
        f"Approve? (yes/no)"
    )
    interrupt(prompt)
    return {"current_stage": "sandbox_pending"}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Node: execute_sandbox_approval
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def execute_sandbox_approval(state: dict[str, Any], config: dict) -> dict:
    """Approve all pending sandbox stage approvals via AzDo VSRM.

    Non-critical per approval: each individual failure is collected into
    ``errors`` rather than aborting the loop.
    """
    azdo = _get_clients(config).azdo
    failures: list[str] = []

    for entry in state.get("pending_approvals") or []:
        approval_id = entry.get("approval_id", "")
        try:
            await azdo.approve_release(
                approval_id=approval_id,
                comment="Sandbox approved by release agent via Slack",
            )
        except ReleaseAgentError as exc:
            failures.append(f"execute_sandbox_approval failed for {approval_id}: {exc}")

    return {"errors": failures} if failures else {}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Node: interrupt_prod_approval
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def interrupt_prod_approval(state: dict[str, Any], config: dict) -> dict:
    """Interrupt to ask the operator to approve the production deployment."""
    pending = state.get("pending_approvals") or []
    stage_names = ", ".join(entry.get("stage_name", "?") for entry in pending)
    prompt = (
        f"Approve production deployment for release {state.get('version', '')}?\n"
        f"Stages: {stage_names}\n"
        f"Approve? (yes/no)"
    )
    interrupt(prompt)
    return {"current_stage": "prod_pending"}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Node: execute_prod_approval
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def execute_prod_approval(state: dict[str, Any], config: dict) -> dict:
    """Approve all pending production stage approvals via AzDo VSRM.

    Non-critical per approval: each individual failure is collected into
    ``errors`` rather than aborting the loop.
    """
    azdo = _get_clients(config).azdo
    failures: list[str] = []

    for entry in state.get("pending_approvals") or []:
        approval_id = entry.get("approval_id", "")
        try:
            await azdo.approve_release(
                approval_id=approval_id,
                comment="Production approved by release agent via Slack",
            )
        except ReleaseAgentError as exc:
            failures.append(f"execute_prod_approval failed for {approval_id}: {exc}")

    return {"errors": failures} if failures else {}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Node: move_tickets_to_done
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def move_tickets_to_done(state: dict[str, Any], config: dict) -> dict:
    """Transition all staging tickets to Done/Released in Jira.

    Each failed transition is collected into ``errors``; successful tickets
    are still transitioned.
    """
    jira = _get_clients(config).jira
    tickets = (state.get("staging") or {}).get("tickets") or []
    failures: list[str] = []

    for entry in tickets:
        ticket_id = entry.get("id", "")
        try:
            await jira.transition_issue(ticket_id, "Done")
        except ReleaseAgentError as exc:
            failures.append(f"move_tickets_to_done failed for {ticket_id}: {exc}")

    return {"errors": failures} if failures else {}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Node: send_slack_notification (send_release_notification)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def send_slack_notification(state: dict[str, Any], config: dict) -> dict:
    """Send a release notification to Slack.

    Non-critical for Slack failures: ReleaseAgentError is caught and
    appended to ``errors``.

    NOTE(review): TicketEntry.model_validate runs inside the try but
    presumably raises a pydantic ValidationError, not ReleaseAgentError —
    malformed ticket data would propagate out of this node uncaught.
    Confirm whether that is intended.
    """
    clients = _get_clients(config)
    repo_name = state.get("repo_name", "")
    version = state.get("version", "")
    staging_dict = state.get("staging") or {}
    tickets_raw = staging_dict.get("tickets") or []

    try:
        # Re-hydrate the serialized ticket dicts into TicketEntry models
        # for the Slack client.
        ticket_entries = [TicketEntry.model_validate(t) for t in tickets_raw]
        await clients.slack.send_release_notification(
            repo=repo_name,
            version=version,
            release_date=date.today(),
            tickets=ticket_entries,
        )
        return {"messages": [f"Slack notified: release {version} for {repo_name}"]}
    except ReleaseAgentError as exc:
        return {"errors": [f"send_slack_notification failed: {exc}"]}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Node: archive_release
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def archive_release(state: dict[str, Any], config: dict) -> dict:
    """Archive the staging release in the store.

    No-op when there is no configured store or no staging data in state.
    """
    store = _get_staging_store(config)
    staging_dict = state.get("staging") or {}
    if store is None or not staging_dict:
        return {}
    release = StagingRelease.model_validate(staging_dict)
    await store.archive(release, date.today())
    return {"messages": [f"Release {release.version} archived"]}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Legacy nodes kept for backward compatibility
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def list_pipelines(state: dict[str, Any], config: dict) -> dict:
    """List build pipelines for the repository via AzDo.

    Non-critical node: failures are recorded in ``errors`` with an empty
    ``pipelines`` list.
    """
    clients = _get_clients(config)
    try:
        pipelines = await clients.azdo.list_build_pipelines(repo=state.get("repo_name", ""))
    except ReleaseAgentError as exc:
        return {"errors": [f"list_pipelines failed: {exc}"], "pipelines": []}
    return {"pipelines": [pipeline.model_dump() for pipeline in pipelines]}
|
||||||
|
|
||||||
|
|
||||||
|
async def interrupt_confirm_trigger(state: dict[str, Any], config: dict) -> dict:
    """Interrupt to ask the operator to confirm triggering pipelines."""
    pipelines = state.get("pipelines") or []
    # Fall back to the pipeline id when a name is missing.
    names = ", ".join(p.get("name", str(p.get("id", ""))) for p in pipelines)
    prompt = (
        f"Trigger pipelines for {state.get('repo_name', '')} {state.get('version', '')}?\n"
        f"Pipelines: {names}\n"
        f"Confirm? (yes/no)"
    )
    interrupt(prompt)
    return {}
|
||||||
|
|
||||||
|
|
||||||
|
async def trigger_pipelines(state: dict[str, Any], config: dict) -> dict:
    """Trigger all listed pipelines for the release branch.

    Failures per pipeline are collected into ``errors``; successful
    triggers accumulate in ``triggered_builds``.
    """
    azdo = _get_clients(config).azdo
    branch = f"refs/heads/release/{state.get('version', '')}"

    triggered: list[dict] = []
    failures: list[str] = []

    for pipeline in state.get("pipelines") or []:
        pipeline_id = pipeline.get("id")
        try:
            run = await azdo.trigger_pipeline(pipeline_id=pipeline_id, branch=branch)
        except ReleaseAgentError as exc:
            failures.append(f"trigger_pipelines failed for pipeline {pipeline_id}: {exc}")
        else:
            triggered.append(run)

    outcome: dict = {"triggered_builds": triggered}
    if failures:
        outcome["errors"] = failures
    return outcome
|
||||||
|
|
||||||
|
|
||||||
|
async def check_release_approvals(state: dict[str, Any], config: dict) -> dict:
    """Check triggered builds for pending stage approvals.

    NOTE(review): ``pending`` is never populated — the build status fetched
    from AzDo is discarded — so this node always returns an empty
    ``pending_approvals`` list (plus any fetch errors). This looks like an
    unfinished legacy implementation; confirm before relying on it.
    """
    clients = _get_clients(config)
    triggered_builds = state.get("triggered_builds") or []
    pending: list[dict] = []
    errors: list[str] = []

    for build in triggered_builds:
        build_id = build.get("id")
        try:
            # Result is currently unused; see NOTE in the docstring.
            await clients.azdo.get_build_status(build_id=build_id)
        except ReleaseAgentError as exc:
            errors.append(f"check_release_approvals failed for build {build_id}: {exc}")

    result_dict: dict = {"pending_approvals": pending}
    if errors:
        result_dict["errors"] = errors
    return result_dict
|
||||||
|
|
||||||
|
|
||||||
|
async def interrupt_confirm_approve(state: dict[str, Any], config: dict) -> dict:
    """Interrupt to ask the operator to confirm approving pipeline stages.

    Builds a human-readable summary of the pending stage approvals and raises
    a LangGraph interrupt so the operator can answer out-of-band.

    Args:
        state: Graph state; reads "pending_approvals" and "version".
        config: LangGraph config (unused; kept for node signature parity).

    Returns:
        An empty dict; the operator's answer arrives on graph resume.
    """
    approvals = state.get("pending_approvals") or []
    version = state.get("version", "")
    # Fall back to the approval id (or "?") when a stage has no name.
    stage_names = ", ".join(a.get("stage_name", a.get("approval_id", "?")) for a in approvals)
    summary = (
        f"Approve pipeline stages for release {version}?\n"
        f"Stages: {stage_names}\n"
        "Confirm? (yes/no)"  # no placeholders, so no f-prefix needed
    )
    interrupt(summary)
    return {}
|
||||||
|
|
||||||
|
|
||||||
|
async def approve_stage(state: dict[str, Any], config: dict) -> dict:
    """Approve all pending pipeline stage approvals via AzDo VSRM.

    Iterates state["pending_approvals"], approving each gate with a fixed
    agent comment; failures are collected rather than raised so the
    remaining approvals are still attempted.
    """
    clients = _get_clients(config)
    errors: list[str] = []

    for approval in state.get("pending_approvals") or []:
        approval_id = approval.get("approval_id", "")
        try:
            await clients.azdo.approve_release(
                approval_id=approval_id,
                comment="Approved by release agent",
            )
        except ReleaseAgentError as exc:
            errors.append(f"approve_stage failed for {approval_id}: {exc}")

    return {"errors": errors} if errors else {}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Graph builder
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def build_release_graph():
    """Assemble and compile the Release StateGraph (Phase 5 architecture).

    Flow:
        load_staging -> interrupt_confirm_release -> create_release_pr
        -> interrupt_confirm_merge_release -> merge_release_pr
        -> trigger_ci_build_main -> poll_ci_build_main
        -> route_ci_result:
            ci_passed: wait_for_cd_release -> poll_release_approvals
                -> route_approval_stage:
                    sandbox_pending: interrupt_sandbox_approval
                        -> execute_sandbox_approval -> poll_release_approvals
                    prod_pending: interrupt_prod_approval
                        -> execute_prod_approval -> poll_release_approvals
                    all_deployed: move_tickets_to_done
                        -> send_slack_notification
                        -> archive_release -> END
            ci_failed: notify_ci_failure -> END
    """
    graph = StateGraph(ReleaseState)

    # Node registration: (graph name, callable). The names double as the
    # edge endpoints below; note send_slack_notification is registered under
    # the name "send_release_notification".
    for node_name, node_fn in (
        ("load_staging", load_staging),
        ("interrupt_confirm_release", interrupt_confirm_release),
        ("create_release_pr", create_release_pr),
        ("interrupt_confirm_merge_release", interrupt_confirm_merge_release),
        ("merge_release_pr", merge_release_pr),
        ("trigger_ci_build_main", trigger_ci_build_main),
        ("poll_ci_build_main", poll_ci_build_main),
        ("notify_ci_failure", notify_ci_failure),
        ("wait_for_cd_release", wait_for_cd_release),
        ("poll_release_approvals", poll_release_approvals),
        ("interrupt_sandbox_approval", interrupt_sandbox_approval),
        ("execute_sandbox_approval", execute_sandbox_approval),
        ("interrupt_prod_approval", interrupt_prod_approval),
        ("execute_prod_approval", execute_prod_approval),
        ("move_tickets_to_done", move_tickets_to_done),
        ("send_release_notification", send_slack_notification),
        ("archive_release", archive_release),
    ):
        graph.add_node(node_name, node_fn)

    # Unconditional edges: main flow up to merge, CI hand-off, both approval
    # loops, and the completion chain.
    for source, target in (
        (START, "load_staging"),
        ("load_staging", "interrupt_confirm_release"),
        ("interrupt_confirm_release", "create_release_pr"),
        ("create_release_pr", "interrupt_confirm_merge_release"),
        ("interrupt_confirm_merge_release", "merge_release_pr"),
        ("merge_release_pr", "trigger_ci_build_main"),
        ("trigger_ci_build_main", "poll_ci_build_main"),
        ("notify_ci_failure", END),
        ("wait_for_cd_release", "poll_release_approvals"),
        ("interrupt_sandbox_approval", "execute_sandbox_approval"),
        ("execute_sandbox_approval", "poll_release_approvals"),
        ("interrupt_prod_approval", "execute_prod_approval"),
        ("execute_prod_approval", "poll_release_approvals"),
        ("move_tickets_to_done", "send_release_notification"),
        ("send_release_notification", "archive_release"),
        ("archive_release", END),
    ):
        graph.add_edge(source, target)

    # Branch on the CI outcome after polling the main-branch build.
    graph.add_conditional_edges(
        "poll_ci_build_main",
        route_ci_result,
        {"ci_passed": "wait_for_cd_release", "ci_failed": "notify_ci_failure"},
    )
    # Branch on the CD approval stage; sandbox/prod loop back through polling.
    graph.add_conditional_edges(
        "poll_release_approvals",
        route_approval_stage,
        {
            "sandbox_pending": "interrupt_sandbox_approval",
            "prod_pending": "interrupt_prod_approval",
            "all_deployed": "move_tickets_to_done",
        },
    )

    return graph.compile()
|
||||||
131
src/release_agent/graph/routing.py
Normal file
131
src/release_agent/graph/routing.py
Normal file
@@ -0,0 +1,131 @@
|
|||||||
|
"""Pure routing functions for LangGraph conditional edges.
|
||||||
|
|
||||||
|
Each function takes a state dict and returns a routing string.
|
||||||
|
No side effects, no mutation, no I/O.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
def is_pr_already_merged(state: dict[str, Any]) -> str:
    """Route on whether the PR was already merged before processing.

    Returns:
        "merged" when state["pr_already_merged"] is truthy.
        "active" for any other value (False, None, or absent).
    """
    return "merged" if state.get("pr_already_merged") else "active"
|
||||||
|
|
||||||
|
|
||||||
|
def route_after_fetch(state: dict[str, Any]) -> str:
    """Three-way route after fetch_pr_details on merge status and ticket presence.

    Returns:
        "merged" when state["pr_already_merged"] is truthy.
        "active_with_ticket" when the PR is active and state["has_ticket"] is truthy.
        "active_no_ticket" otherwise.
    """
    if state.get("pr_already_merged"):
        return "merged"
    return "active_with_ticket" if state.get("has_ticket") else "active_no_ticket"
|
||||||
|
|
||||||
|
|
||||||
|
def is_review_approved(state: dict[str, Any]) -> str:
    """Route on the AI review outcome.

    Returns:
        "approve" when state["review_approved"] is truthy.
        "request_changes" otherwise (including missing or None).
    """
    return "approve" if state.get("review_approved") else "request_changes"
|
||||||
|
|
||||||
|
|
||||||
|
def has_ticket(state: dict[str, Any]) -> str:
    """Route on whether the PR branch carries a Jira ticket ID.

    Returns:
        "yes" when state["has_ticket"] is truthy, "no" otherwise
        (including missing or None).
    """
    return "yes" if state.get("has_ticket") else "no"
|
||||||
|
|
||||||
|
|
||||||
|
def should_continue_to_release(state: dict[str, Any]) -> str:
    """Route on the operator's decision to proceed to the release flow.

    Returns:
        "yes" only when state["continue_to_release"] is truthy AND no errors
        have accumulated; "no" in every other case.
    """
    proceed = bool(state.get("continue_to_release")) and not state.get("errors")
    return "yes" if proceed else "no"
|
||||||
|
|
||||||
|
|
||||||
|
def has_pipelines(state: dict[str, Any]) -> str:
    """Route on whether any pipelines are defined for the repository.

    Returns:
        "yes" when state["pipelines"] is a non-empty list; "no" otherwise
        (missing, None, or empty).
    """
    return "yes" if state.get("pipelines") else "no"
|
||||||
|
|
||||||
|
|
||||||
|
def has_pending_approvals(state: dict[str, Any]) -> str:
    """Route on whether pipeline stage approvals are outstanding.

    Returns:
        "yes" when state["pending_approvals"] is a non-empty list; "no"
        otherwise (missing, None, or empty).
    """
    return "yes" if state.get("pending_approvals") else "no"
|
||||||
|
|
||||||
|
|
||||||
|
def route_ci_result(state: dict[str, Any]) -> str:
    """Route on the CI build result.

    Returns:
        "ci_passed" only for the exact result string "succeeded";
        "ci_failed" for anything else (missing, None, or any other value).
    """
    return "ci_passed" if state.get("ci_build_result") == "succeeded" else "ci_failed"
|
||||||
|
|
||||||
|
|
||||||
|
def route_approval_stage(state: dict[str, Any]) -> str:
    """Route on the current CD approval stage.

    Uses state["current_stage"] when set. When approvals exist but no stage
    is recorded, assumes the first stage ("sandbox_pending").

    Returns:
        "all_deployed" when state["pending_approvals"] is empty or missing.
        "prod_pending" when state["current_stage"] == "prod_pending".
        "sandbox_pending" in every other case (explicit or default).
    """
    if not state.get("pending_approvals"):
        return "all_deployed"
    stage = state.get("current_stage", "")
    # Anything other than an explicit prod stage defaults to sandbox.
    return "prod_pending" if stage == "prod_pending" else "sandbox_pending"
|
||||||
361
src/release_agent/main.py
Normal file
361
src/release_agent/main.py
Normal file
@@ -0,0 +1,361 @@
|
|||||||
|
"""FastAPI application entry point.
|
||||||
|
|
||||||
|
Responsibilities:
|
||||||
|
- Create the FastAPI app with lifespan handler
|
||||||
|
- Startup: compile graphs, open DB pool, build tool clients, ensure schema
|
||||||
|
- Shutdown: wait for background tasks (30s timeout)
|
||||||
|
- Register routers and global exception handlers
|
||||||
|
- Expose schedule_graph and run_graph_in_background helpers
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from fastapi import FastAPI, Request
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
|
from release_agent.api.approvals import router as approvals_router
|
||||||
|
from release_agent.api.models import ErrorResponse
|
||||||
|
from release_agent.api.slack_interactions import router as slack_interactions_router
|
||||||
|
from release_agent.api.status import router as status_router
|
||||||
|
from release_agent.api.webhooks import router as webhook_router
|
||||||
|
from release_agent.config import Settings
|
||||||
|
from release_agent.exceptions import ReleaseAgentError
|
||||||
|
from release_agent.graph.dependencies import JsonFileStagingStore, ToolClients
|
||||||
|
from release_agent.graph.postgres_staging_store import PostgresStagingStore
|
||||||
|
from release_agent.graph.pr_completed import build_pr_completed_graph
|
||||||
|
from release_agent.graph.release import build_release_graph
|
||||||
|
from release_agent.services.pr_poller import run_pr_poll_loop
|
||||||
|
from release_agent.tools.azdo import AzDoClient
|
||||||
|
from release_agent.tools.claude_review import ClaudeReviewer
|
||||||
|
from release_agent.tools.jira import JiraClient
|
||||||
|
from release_agent.tools.slack import SlackClient
|
||||||
|
|
||||||
|
try:
|
||||||
|
from psycopg_pool import AsyncConnectionPool
|
||||||
|
except ImportError: # pragma: no cover
|
||||||
|
AsyncConnectionPool = None # type: ignore[assignment,misc]
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_SHUTDOWN_TIMEOUT_SECONDS = 30
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Startup helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _create_tool_clients(settings: Settings) -> tuple[ToolClients, list]:
    """Build the shared ToolClients bundle.

    Returns:
        A (ToolClients, closeables) pair where closeables are the httpx
        clients that must be closed on shutdown.
    """
    # Two separate HTTP clients: one shared by the regular AzDo/Jira/Slack
    # APIs, one dedicated to the AzDo VSRM (release management) endpoint.
    api_http = httpx.AsyncClient(timeout=30.0)
    vsrm_http = httpx.AsyncClient(timeout=30.0)

    bundle = ToolClients(
        azdo=AzDoClient(
            http_client=api_http,
            vsrm_http_client=vsrm_http,
            base_url=settings.azdo_api_url,
            vsrm_base_url=settings.azdo_vsrm_api_url,
            pat=settings.azdo_pat.get_secret_value(),
        ),
        jira=JiraClient(
            http_client=api_http,
            base_url=settings.jira_base_url,
            email=settings.jira_email,
            api_token=settings.jira_api_token.get_secret_value(),
        ),
        slack=SlackClient(
            http_client=api_http,
            webhook_url=settings.slack_webhook_url.get_secret_value(),
            bot_token=settings.slack_bot_token.get_secret_value(),
            channel_id=settings.slack_channel_id,
        ),
        reviewer=ClaudeReviewer(),
    )
    return bundle, [api_http, vsrm_http]
|
||||||
|
|
||||||
|
|
||||||
|
def _create_staging_store(pool=None) -> PostgresStagingStore | JsonFileStagingStore:
    """Return a StagingStore instance.

    With a PostgreSQL pool, returns a PostgresStagingStore; without one,
    falls back to a JSON-file store for local development.
    """
    if pool is None:
        return JsonFileStagingStore(directory=Path("data/staging"))
    return PostgresStagingStore(pool=pool)
|
||||||
|
|
||||||
|
|
||||||
|
async def _ensure_db_schema(pool) -> None:
    """Create all required tables if they do not already exist.

    Idempotent: every statement uses IF NOT EXISTS / ADD COLUMN IF NOT
    EXISTS, so this is safe to run on every startup.

    Args:
        pool: An open psycopg AsyncConnectionPool.
    """
    statements = [
        # Per-graph-run thread bookkeeping (status, interrupts, last state).
        """
        CREATE TABLE IF NOT EXISTS agent_threads (
            thread_id TEXT PRIMARY KEY,
            graph_name TEXT,
            status TEXT NOT NULL DEFAULT 'running',
            interrupt_value TEXT,
            state JSONB,
            repo_name TEXT,
            pr_id TEXT,
            version TEXT,
            slack_message_ts TEXT,
            created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
            updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
        )
        """,
        # Migration for installs created before slack_message_ts existed.
        """
        ALTER TABLE agent_threads
        ADD COLUMN IF NOT EXISTS slack_message_ts TEXT
        """,
        # One in-flight release per repo.
        """
        CREATE TABLE IF NOT EXISTS staging_releases (
            repo TEXT PRIMARY KEY,
            version TEXT NOT NULL,
            started_at DATE NOT NULL,
            tickets JSONB NOT NULL DEFAULT '[]',
            updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
        )
        """,
        # Historical record of completed releases, keyed by (repo, version).
        """
        CREATE TABLE IF NOT EXISTS archived_releases (
            repo TEXT NOT NULL,
            version TEXT NOT NULL,
            started_at DATE NOT NULL,
            tickets JSONB NOT NULL DEFAULT '[]',
            released_at DATE NOT NULL,
            PRIMARY KEY (repo, version)
        )
        """,
    ]
    async with pool.connection() as conn:
        async with conn.cursor() as cur:
            for sql in statements:
                await cur.execute(sql)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Background task helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def schedule_graph(
    *,
    app: FastAPI,
    graph,
    initial_state: dict,
    thread_id: str | None = None,
    tool_clients: ToolClients | None = None,
    db_pool=None,
) -> str:
    """Schedule a graph run as a background asyncio task.

    Generates a fresh thread_id when none is given and registers the task in
    app.state.background_tasks so shutdown can wait on it.

    Returns:
        The thread_id used for this execution.
    """
    if thread_id is None:
        thread_id = str(uuid.uuid4())

    # Fall back to app-level resources for anything not passed explicitly.
    if tool_clients is None:
        tool_clients = getattr(app.state, "tool_clients", None)
    if db_pool is None:
        db_pool = getattr(app.state, "db_pool", None)

    # Propagate relevant settings into the graph's runtime config.
    settings = getattr(app.state, "settings", None)
    repos_base_dir = getattr(settings, "repos_base_dir", "") if settings else ""
    default_jira_project = (
        getattr(settings, "default_jira_project", "ALLPOST") if settings else "ALLPOST"
    )

    task = asyncio.create_task(
        run_graph_in_background(
            graph=graph,
            initial_state=initial_state,
            thread_id=thread_id,
            tool_clients=tool_clients,
            db_pool=db_pool,
            repos_base_dir=repos_base_dir,
            default_jira_project=default_jira_project,
        )
    )
    # Track for graceful shutdown; discard on completion to avoid growth.
    app.state.background_tasks.add(task)
    task.add_done_callback(app.state.background_tasks.discard)
    return thread_id
|
||||||
|
|
||||||
|
|
||||||
|
async def run_graph_in_background(
    *,
    graph,
    initial_state: dict,
    thread_id: str,
    tool_clients: ToolClients | None = None,
    db_pool=None,
    repos_base_dir: str = "",
    default_jira_project: str = "ALLPOST",
) -> None:
    """Execute the graph and mirror its lifecycle into the threads table.

    Marks the thread "running" before invocation, "completed" with the
    final state on success, and "failed" with the error message on any
    exception. Persistence is skipped entirely when db_pool is None.
    """
    # Imported lazily to avoid a circular import with the webhooks module.
    from release_agent.api.webhooks import _upsert_thread

    config = {
        "configurable": {
            "thread_id": thread_id,
            "clients": tool_clients,
            "repos_base_dir": repos_base_dir,
            "default_jira_project": default_jira_project,
        }
    }
    persist = db_pool is not None
    try:
        if persist:
            await _upsert_thread(db_pool, thread_id=thread_id, thread_status="running", state=initial_state)
        final_state = await graph.ainvoke(initial_state, config=config)
        if persist:
            await _upsert_thread(db_pool, thread_id=thread_id, thread_status="completed", state=final_state or {})
    except Exception as exc:
        # Top-level boundary for a fire-and-forget task: log and record.
        logger.exception("Graph execution failed for thread %s: %s", thread_id, exc)
        if persist:
            await _upsert_thread(
                db_pool,
                thread_id=thread_id,
                thread_status="failed",
                state={"errors": [str(exc)]},
            )
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Lifespan
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage startup and graceful shutdown of shared resources.

    Startup: load settings, compile graphs, build tool clients, open the
    PostgreSQL pool (when psycopg_pool is installed), ensure the schema,
    build the staging store, and optionally start the PR polling loop.
    Shutdown: wait up to _SHUTDOWN_TIMEOUT_SECONDS for background tasks,
    cancel stragglers, then close HTTP clients and the DB pool.
    """
    settings = Settings()
    app.state.settings = settings
    app.state.background_tasks = set()
    app.state.started_at = datetime.now(tz=timezone.utc)

    # Compile graphs once at startup
    logger.info("Compiling graphs...")
    app.state.graphs = {
        "pr_completed": build_pr_completed_graph(),
        "release": build_release_graph(),
    }

    # Build tool clients
    logger.info("Building tool clients...")
    app.state.tool_clients, app.state._http_clients = _create_tool_clients(settings)

    # Open PostgreSQL connection pool. The AsyncConnectionPool import is
    # guarded at module level (it may be None when psycopg_pool is missing);
    # honor that guard here instead of crashing with a TypeError, and fall
    # back to the JSON staging store that _create_staging_store provides.
    pool = None
    if AsyncConnectionPool is not None:
        logger.info("Opening DB connection pool...")
        pool = AsyncConnectionPool(
            conninfo=settings.postgres_dsn.get_secret_value(),
            open=False,
        )
        await pool.open()
        # Ensure schema exists
        await _ensure_db_schema(pool)
    else:
        logger.warning("psycopg_pool not installed; using JSON file staging store")
    app.state.db_pool = pool

    # Build staging store (PostgreSQL-backed when a pool exists)
    app.state.staging_store = _create_staging_store(pool=pool)

    # Start PR polling background task if enabled
    if settings.pr_poll_enabled:
        logger.info("Starting PR polling background task...")
        poll_task = asyncio.create_task(
            run_pr_poll_loop(
                azdo_client=app.state.tool_clients.azdo,
                db_pool=pool,
                watched_repos=settings.watched_repos_list,
                target_branch=settings.pr_poll_target_branch,
                interval_seconds=settings.pr_poll_interval_seconds,
                schedule_fn=lambda *, initial_state: schedule_graph(
                    app=app,
                    graph=app.state.graphs["pr_completed"],
                    initial_state=initial_state,
                ),
            )
        )
        app.state.background_tasks.add(poll_task)
        poll_task.add_done_callback(app.state.background_tasks.discard)

    logger.info("Startup complete.")
    yield

    # Graceful shutdown: wait for background tasks
    tasks = list(app.state.background_tasks)
    if tasks:
        logger.info("Waiting for %d background task(s) to complete...", len(tasks))
        done, pending = await asyncio.wait(tasks, timeout=_SHUTDOWN_TIMEOUT_SECONDS)
        for task in pending:
            logger.warning("Cancelling timed-out background task %s", task)
            task.cancel()

    # Close HTTP clients
    for client in getattr(app.state, "_http_clients", []):
        await client.aclose()

    if pool is not None:
        await pool.close()
    logger.info("Shutdown complete.")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Exception handlers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def _release_agent_error_handler(request: Request, exc: ReleaseAgentError) -> JSONResponse:
    """Map any ReleaseAgentError to a 500 JSON error payload."""
    payload = ErrorResponse(error=type(exc).__name__, detail=str(exc))
    return JSONResponse(status_code=500, content=payload.model_dump())
|
||||||
|
|
||||||
|
|
||||||
|
async def _generic_error_handler(request: Request, exc: Exception) -> JSONResponse:
    """Last-resort handler: log the traceback, hide details from the client."""
    logger.exception("Unhandled exception: %s", exc)
    payload = ErrorResponse(error="InternalServerError", detail="An unexpected error occurred")
    return JSONResponse(status_code=500, content=payload.model_dump())
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# App factory
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def create_app() -> FastAPI:
    """Create and configure the FastAPI application."""
    application = FastAPI(
        title="Billo Release Agent",
        version="0.1.0",
        description="LangGraph-based release automation agent",
        lifespan=lifespan,
    )

    # Routers: webhooks, approvals, status, Slack interactivity.
    for router in (
        webhook_router,
        approvals_router,
        status_router,
        slack_interactions_router,
    ):
        application.include_router(router)

    # Exception handlers: domain errors first, then the catch-all.
    application.add_exception_handler(ReleaseAgentError, _release_agent_error_handler)  # type: ignore[arg-type]
    application.add_exception_handler(Exception, _generic_error_handler)

    return application
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# ASGI entry point
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
# Module-level ASGI application instance (e.g. `uvicorn release_agent.main:app`).
app = create_app()
|
||||||
1
src/release_agent/models/__init__.py
Normal file
1
src/release_agent/models/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Pydantic models for the release agent."""
|
||||||
43
src/release_agent/models/build.py
Normal file
43
src/release_agent/models/build.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
"""Build and approval dataclasses for CI/CD tracking.
|
||||||
|
|
||||||
|
Immutable (frozen) dataclasses used across graph nodes, routing functions,
|
||||||
|
and API endpoints.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class BuildStatus:
|
||||||
|
"""Immutable snapshot of an Azure DevOps build's current state.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
status: AzDo build status string (e.g. "notStarted", "inProgress",
|
||||||
|
"completed", "cancelling").
|
||||||
|
result: AzDo build result string (e.g. "succeeded", "failed",
|
||||||
|
"canceled", "partiallySucceeded"). None when not completed.
|
||||||
|
build_url: Direct URL to the build results page, or None if unknown.
|
||||||
|
"""
|
||||||
|
|
||||||
|
status: str
|
||||||
|
result: str | None
|
||||||
|
build_url: str | None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class ApprovalRecord:
    """One release-pipeline approval gate, frozen at observation time.

    Attributes:
        approval_id: Unique identifier of the approval record in AzDo.
        stage_name: Human-readable deployment stage name (e.g. "Sandbox",
            "Production").
        status: Current approval status (e.g. "pending", "approved",
            "rejected").
        release_id: Numeric ID of the AzDo release containing this approval.
    """

    approval_id: str
    stage_name: str
    status: str
    release_id: int
|
||||||
33
src/release_agent/models/jira.py
Normal file
33
src/release_agent/models/jira.py
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
"""Jira domain models."""
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
|
||||||
|
|
||||||
|
class JiraTransition(BaseModel):
    """A single Jira workflow transition, as returned by the Jira API.

    Attributes:
        id: Transition identifier (string form, per the Jira API).
        name: Human-readable transition name (e.g. "Done", "Released").
    """

    model_config = ConfigDict(frozen=True)  # immutable once constructed

    id: str
    name: str
|
||||||
|
|
||||||
|
|
||||||
|
class JiraIssue(BaseModel):
    """A Jira issue in the minimal shape the agent needs.

    Attributes:
        key: Issue key (e.g. "ALLPOST-100").
        summary: Issue summary / title.
        status: Current workflow status name (e.g. "In Progress").
    """

    model_config = ConfigDict(frozen=True)  # immutable once constructed

    key: str
    summary: str
    status: str
|
||||||
31
src/release_agent/models/pipeline.py
Normal file
31
src/release_agent/models/pipeline.py
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
"""Pipeline-related models for Azure DevOps pipeline tracking."""
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineInfo(BaseModel):
    """Identifies one Azure DevOps pipeline belonging to a repository."""

    model_config = ConfigDict(frozen=True)  # immutable once constructed

    id: int  # AzDo pipeline definition ID
    name: str
    repo: str  # repository the pipeline belongs to
|
||||||
|
|
||||||
|
|
||||||
|
class ReleasePipelineStage(BaseModel):
    """One stage of a release pipeline, with its approval requirement."""

    model_config = ConfigDict(frozen=True)  # immutable once constructed

    name: str
    rank: int = Field(ge=0)  # ordering of the stage within the pipeline
    requires_approval: bool
    approval_id: str | None = None  # only meaningful when requires_approval

    @model_validator(mode="after")
    def validate_approval_consistency(self) -> "ReleasePipelineStage":
        """Reject an approval_id on a stage that needs no approval."""
        has_spurious_id = self.approval_id is not None and not self.requires_approval
        if has_spurious_id:
            raise ValueError("approval_id must be None when requires_approval is False")
        return self
|
||||||
38
src/release_agent/models/pr.py
Normal file
38
src/release_agent/models/pr.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
"""PRInfo model representing a pull request with ticket extraction."""
|
||||||
|
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ConfigDict, HttpUrl, model_validator
|
||||||
|
|
||||||
|
from release_agent.branch_parser import parse_branch
|
||||||
|
|
||||||
|
|
||||||
|
class PRInfo(BaseModel):
    """An immutable representation of a pull request.

    ticket_id and has_ticket are derived from the branch name during model
    construction; any caller-supplied values for them are overridden.
    """

    model_config = ConfigDict(frozen=True)  # immutable once constructed

    pr_id: str
    pr_url: HttpUrl
    repo_name: str
    branch: str
    pr_title: str
    pr_status: Literal["active", "completed", "abandoned"]
    ticket_id: str | None = None  # derived; see extract_ticket_from_branch
    has_ticket: bool = False  # derived; see extract_ticket_from_branch

    @model_validator(mode="before")
    @classmethod
    def extract_ticket_from_branch(cls, values: dict) -> dict:
        """Derive ticket_id and has_ticket from the branch name.

        Always overwrites both fields so the model stays consistent with
        the branch, regardless of what the caller passed.
        """
        ticket_id, has_ticket = parse_branch(values.get("branch", ""))
        values["ticket_id"] = ticket_id
        values["has_ticket"] = has_ticket
        return values
|
||||||
65
src/release_agent/models/release.py
Normal file
65
src/release_agent/models/release.py
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
"""StagingRelease and ArchivedRelease models for tracking release state."""
|
||||||
|
|
||||||
|
import re
|
||||||
|
from datetime import date
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ConfigDict, field_validator, model_validator
|
||||||
|
|
||||||
|
from release_agent.models.ticket import TicketEntry
|
||||||
|
|
||||||
|
_VERSION_PATTERN = re.compile(r"^v\d+\.\d+\.\d+$")
|
||||||
|
|
||||||
|
|
||||||
|
class StagingRelease(BaseModel):
    """An immutable model representing a release currently in staging.

    All mutation operations return new instances rather than modifying
    the existing one.
    """

    model_config = ConfigDict(frozen=True)

    version: str
    repo: str
    started_at: date
    tickets: list[TicketEntry]

    @field_validator("version")
    @classmethod
    def validate_version_format(cls, value: str) -> str:
        """Validate version string matches vMAJOR.MINOR.PATCH."""
        if _VERSION_PATTERN.match(value) is None:
            msg = f"Invalid version format: {value!r}. Must match ^v\\d+\\.\\d+\\.\\d+$"
            raise ValueError(msg)
        return value

    def add_ticket(self, ticket: TicketEntry) -> "StagingRelease":
        """Return a new StagingRelease with the given ticket appended.

        This instance is left untouched (frozen model).
        """
        extended = [*self.tickets, ticket]
        return self.model_copy(update={"tickets": extended})

    def has_ticket(self, ticket_id: str) -> bool:
        """Return True if a ticket with the given ID exists in this release."""
        for entry in self.tickets:
            if entry.id == ticket_id:
                return True
        return False
|
||||||
|
|
||||||
|
|
||||||
|
class ArchivedRelease(StagingRelease):
    """An immutable model representing a completed/released release.

    Extends StagingRelease with a released_at date that must be
    on or after started_at.
    """

    released_at: date

    @model_validator(mode="after")
    def validate_released_after_started(self) -> "ArchivedRelease":
        """Validate that released_at is not before started_at."""
        if self.started_at <= self.released_at:
            return self
        raise ValueError(
            f"released_at ({self.released_at}) must be >= started_at ({self.started_at})"
        )
|
||||||
48
src/release_agent/models/review.py
Normal file
48
src/release_agent/models/review.py
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
"""Review models for Claude PR review structured output."""
|
||||||
|
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ConfigDict, computed_field
|
||||||
|
|
||||||
|
|
||||||
|
class ReviewIssue(BaseModel):
    """A single issue identified during PR review.

    Attributes:
        severity: One of "blocker", "error", "warning", or "info".
        description: Human-readable description of the issue.
        file_path: Optional path to the affected file.
        line_start: Optional first line of the affected range — presumably
            within file_path; confirm against the review prompt schema.
        line_end: Optional last line of the affected range.
        suggestion: Optional remediation suggestion.
    """

    model_config = ConfigDict(frozen=True)

    severity: Literal["blocker", "error", "warning", "info"]
    description: str
    file_path: str | None = None
    line_start: int | None = None
    line_end: int | None = None
    suggestion: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class ReviewResult(BaseModel):
    """Structured output from a Claude PR review.

    Attributes:
        verdict: Either "approve" or "request_changes".
        summary: A human-readable summary of the review.
        issues: Tuple of issues identified during review.
        has_blockers: Computed field - True if any issue has severity "blocker".
    """

    model_config = ConfigDict(frozen=True)

    verdict: Literal["approve", "request_changes"]
    summary: str
    issues: tuple[ReviewIssue, ...] = ()

    @computed_field  # type: ignore[misc]
    @property
    def has_blockers(self) -> bool:
        """Return True if any issue has blocker severity."""
        severities = (issue.severity for issue in self.issues)
        return "blocker" in severities
|
||||||
33
src/release_agent/models/ticket.py
Normal file
33
src/release_agent/models/ticket.py
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
"""TicketEntry model representing a Jira ticket linked to a merged PR."""
|
||||||
|
|
||||||
|
import re
|
||||||
|
from datetime import date
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ConfigDict, HttpUrl, field_validator
|
||||||
|
|
||||||
|
_JIRA_ID_PATTERN = re.compile(r"^[A-Z][A-Z0-9]*-\d+$")
|
||||||
|
|
||||||
|
|
||||||
|
class TicketEntry(BaseModel):
    """An immutable record of a Jira ticket merged into a release."""

    model_config = ConfigDict(frozen=True)

    id: str
    summary: str
    pr_id: str
    pr_url: HttpUrl
    pr_title: str
    branch: str
    merged_at: date

    @field_validator("id")
    @classmethod
    def validate_jira_id(cls, value: str) -> str:
        """Validate that the ID matches the Jira ticket format."""
        if _JIRA_ID_PATTERN.match(value) is None:
            message = (
                f"Invalid Jira ticket ID: {value!r}. "
                "Must match pattern ^[A-Z][A-Z0-9]*-\\d+$"
            )
            raise ValueError(message)
        return value
|
||||||
40
src/release_agent/models/webhook.py
Normal file
40
src/release_agent/models/webhook.py
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
"""WebhookPayload models for Azure DevOps webhook events."""
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ConfigDict, HttpUrl
|
||||||
|
|
||||||
|
|
||||||
|
class WebhookRepository(BaseModel):
    """An immutable record of a repository as reported in a webhook event."""

    model_config = ConfigDict(frozen=True)

    id: str  # Repository identifier as sent by Azure DevOps
    name: str  # Repository name
    web_url: HttpUrl  # Browser URL of the repository
|
||||||
|
|
||||||
|
|
||||||
|
class WebhookResource(BaseModel):
    """An immutable record of the resource object in a webhook payload."""

    model_config = ConfigDict(frozen=True)

    repository: WebhookRepository  # Repository the pull request belongs to
    pull_request_id: int  # Numeric pull request ID
    title: str  # Pull request title
    source_ref_name: str  # Source branch ref (poller fills this with the PR branch)
    target_ref_name: str  # Target branch ref; empty string for synthesized payloads
    status: Literal["active", "completed", "abandoned"]
    closed_date: datetime | None  # Presumably None while the PR is open — confirm
|
||||||
|
|
||||||
|
|
||||||
|
class WebhookPayload(BaseModel):
    """An immutable model representing an Azure DevOps webhook event payload."""

    model_config = ConfigDict(frozen=True)

    subscription_id: str  # AzDo subscription ID; poller synthesizes "polled-<repo>-<pr>"
    event_type: str  # e.g. "git.pullrequest.updated"
    resource: WebhookResource  # The pull-request resource of the event
|
||||||
0
src/release_agent/services/__init__.py
Normal file
0
src/release_agent/services/__init__.py
Normal file
46
src/release_agent/services/pr_dedup.py
Normal file
46
src/release_agent/services/pr_dedup.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
"""PR deduplication service.
|
||||||
|
|
||||||
|
Queries the agent_threads table to find which PRs from a given list have
|
||||||
|
not yet been processed. This prevents the PR poller from re-triggering
|
||||||
|
graph runs for PRs that already have an existing thread.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from release_agent.models.pr import PRInfo
|
||||||
|
|
||||||
|
_QUERY = """
SELECT pr_id, repo_name
FROM agent_threads
WHERE (pr_id, repo_name) IN (
    SELECT unnest(%s::text[]), unnest(%s::text[])
)
"""


async def find_unprocessed_prs(pool, prs: list[PRInfo]) -> list[PRInfo]:
    """Return the subset of prs that have no existing agent_threads record.

    A PR counts as already-processed when agent_threads contains a row with
    the same (pr_id, repo_name) pair. The SQL pairs the two arrays via
    unnest so matching is pair-wise rather than an independent ANY per column.

    Args:
        pool: Async psycopg connection pool.
        prs: List of PRInfo objects to check.

    Returns:
        A fresh list holding only the PRs without an existing thread;
        the input list is never mutated.
    """
    if not prs:
        return []

    id_column = [item.pr_id for item in prs]
    repo_column = [item.repo_name for item in prs]

    async with pool.connection() as conn, conn.cursor() as cur:
        await cur.execute(_QUERY, (id_column, repo_column))
        rows = await cur.fetchall()

    seen_pairs = {(str(pr_id), str(repo)) for pr_id, repo in rows}
    return [item for item in prs if (item.pr_id, item.repo_name) not in seen_pairs]
|
||||||
108
src/release_agent/services/pr_poller.py
Normal file
108
src/release_agent/services/pr_poller.py
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
"""PR polling service.
|
||||||
|
|
||||||
|
Polls Azure DevOps for active pull requests targeting a configured branch,
|
||||||
|
deduplicates against already-processed PRs, and schedules graph runs for
|
||||||
|
newly discovered PRs by synthesizing a fake webhook payload.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from collections.abc import Callable, Coroutine
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from release_agent.models.pr import PRInfo
|
||||||
|
from release_agent.services.pr_dedup import find_unprocessed_prs
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _synthesize_webhook_payload(pr: PRInfo) -> dict:
    """Build a fake webhook payload dict from a PRInfo.

    Produces a dict shaped like a real Azure DevOps webhook event so a
    polled PR can flow through the same parse_webhook node as a genuine
    webhook delivery.

    Args:
        pr: The PRInfo object representing the active pull request.

    Returns:
        A dict with event_type, subscription_id, and resource fields.
    """
    # Strip the "/pullrequest/<id>" suffix to recover the repo's web URL.
    repo_web_url = str(pr.pr_url).rsplit("/pullrequest/", 1)[0]
    repository = {
        "id": f"{pr.repo_name}-id",
        "name": pr.repo_name,
        "web_url": repo_web_url,
    }
    resource = {
        "pull_request_id": int(pr.pr_id),
        "title": pr.pr_title,
        "status": pr.pr_status,
        "source_ref_name": pr.branch,
        "target_ref_name": "",
        "repository": repository,
    }
    return {
        "event_type": "git.pullrequest.updated",
        "subscription_id": f"polled-{pr.repo_name}-{pr.pr_id}",
        "resource": resource,
    }
|
||||||
|
|
||||||
|
|
||||||
|
async def run_pr_poll_loop(
    *,
    azdo_client,
    db_pool,
    watched_repos: list[str],
    target_branch: str,
    interval_seconds: int,
    schedule_fn: Callable[..., Any],
    sleep_fn: Callable[[float], Coroutine] | None = None,
) -> None:
    """Poll Azure DevOps for active PRs and schedule unprocessed ones.

    Runs indefinitely until cancelled. For each polling interval:
    1. For each watched repo, list active PRs targeting target_branch.
    2. Find PRs not yet recorded in agent_threads.
    3. For each unprocessed PR, call schedule_fn with the synthesized payload.
    4. Sleep for interval_seconds.

    Args:
        azdo_client: AzDoClient instance.
        db_pool: Async psycopg connection pool.
        watched_repos: List of repository names to poll.
        target_branch: Target branch ref to filter PRs (e.g. "refs/heads/develop").
        interval_seconds: Seconds to sleep between polling iterations.
        schedule_fn: Callable to schedule a graph run; called with keyword arg:
            initial_state (dict).
        sleep_fn: Async sleep callable (default: asyncio.sleep). Injectable for testing.
    """
    if sleep_fn is None:
        sleep_fn = asyncio.sleep  # type: ignore[assignment]

    while True:
        all_prs: list[PRInfo] = []

        for repo in watched_repos:
            try:
                prs = await azdo_client.list_active_prs(repo, target_branch)
                all_prs.extend(prs)
            except Exception as exc:
                # Best-effort: a failure for one repo must not stop polling the rest.
                logger.warning("PR polling failed for repo %s: %s", repo, exc)

        try:
            unprocessed = await find_unprocessed_prs(db_pool, all_prs)
        except Exception as exc:
            # If dedup fails, schedule nothing this cycle rather than risk
            # re-triggering graph runs for PRs that may already have threads.
            logger.warning("PR dedup query failed: %s", exc)
            unprocessed = []

        for pr in unprocessed:
            webhook_payload = _synthesize_webhook_payload(pr)
            initial_state = {
                "webhook_payload": webhook_payload,
                "pr_id": pr.pr_id,
                "repo_name": pr.repo_name,
            }
            try:
                # Scheduling failures are logged per-PR so one bad PR does not
                # block the remaining unprocessed PRs in this cycle.
                schedule_fn(initial_state=initial_state)
            except Exception as exc:
                logger.warning("Failed to schedule PR %s/%s: %s", pr.repo_name, pr.pr_id, exc)

        await sleep_fn(interval_seconds)  # type: ignore[misc]
|
||||||
91
src/release_agent/state.py
Normal file
91
src/release_agent/state.py
Normal file
@@ -0,0 +1,91 @@
|
|||||||
|
"""LangGraph state definition for the release agent.
|
||||||
|
|
||||||
|
Uses TypedDict with total=False so all fields are optional, enabling
|
||||||
|
partial state updates throughout the graph execution.
|
||||||
|
|
||||||
|
Reducers follow immutable patterns - they always return new lists.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
|
|
||||||
|
def add_messages(existing: list[str], new: list[str]) -> list[str]:
    """Reducer that concatenates message lists immutably.

    Neither input list is modified; the result is a fresh list with all
    existing messages followed by the new ones.
    """
    return existing + new
|
||||||
|
|
||||||
|
|
||||||
|
def add_errors(existing: list[str], new: list[str]) -> list[str]:
    """Reducer that concatenates error lists immutably.

    Neither input list is modified; the result is a fresh list with all
    existing errors followed by the new ones.
    """
    combined = list(existing)
    combined.extend(new)
    return combined
|
||||||
|
|
||||||
|
|
||||||
|
class ReleaseState(TypedDict, total=False):
    """The mutable state passed between nodes in the release agent graph.

    All fields are optional (total=False). Each field can be updated
    independently by graph nodes. List fields use reducers that accumulate
    values rather than replacing them.
    """

    # Core identifiers
    repo_name: str  # Repository the PR/release belongs to
    pr_id: str  # Pull request ID, kept as a string
    ticket_id: str  # Jira ticket ID (derived from the branch name)
    version: str  # Release version string, e.g. "v1.2.3"

    # Accumulated lists (append-only via the module's reducers)
    messages: Annotated[list[str], add_messages]
    errors: Annotated[list[str], add_errors]

    # Phase 3: Webhook / PR fields
    webhook_payload: dict  # Raw payload — real webhook or synthesized by the poller
    pr_info: dict  # Parsed PR details
    pr_diff: str  # Diff text for the AI code review
    last_merge_source_commit: str  # Commit ID used when completing the PR

    # Phase 3: Jira / ticket fields
    ticket_summary: str  # Jira ticket summary text
    has_ticket: bool  # Whether the branch name yielded a ticket ID

    # Phase 3: Review fields
    review_result: dict  # Presumably a serialized ReviewResult — confirm against review node
    review_approved: bool  # Whether the review verdict allowed the PR to proceed

    # Phase 3: Staging fields
    staging: dict  # Presumably a serialized StagingRelease — confirm against staging node
    pr_already_merged: bool

    # Phase 3: Release PR fields
    release_pr_id: str
    release_pr_commit: str

    # Phase 3: Pipeline / approval fields
    pipelines: list[dict]
    triggered_builds: list[dict]
    pending_approvals: list[dict]

    # Phase 3: Flow control
    continue_to_release: bool  # Gate: whether the graph proceeds to the release phase

    # Phase 5: CI build tracking
    ci_build_id: int
    ci_build_status: str
    ci_build_result: str
    ci_build_url: str

    # Phase 5: CD release tracking
    release_definition_id: int
    release_id: int
    current_stage: str

    # Phase 5: Slack interactive message tracking
    approval_message_ts: str  # Slack timestamp of the interactive approval message
    slack_message_ts: str  # Slack timestamp of the notification message
|
||||||
0
src/release_agent/tools/__init__.py
Normal file
0
src/release_agent/tools/__init__.py
Normal file
103
src/release_agent/tools/_http.py
Normal file
103
src/release_agent/tools/_http.py
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
"""Shared HTTP helpers for service clients.
|
||||||
|
|
||||||
|
Provides a unified response-to-exception mapper and a Basic auth header builder.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import base64
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from release_agent.exceptions import (
|
||||||
|
AuthenticationError,
|
||||||
|
NotFoundError,
|
||||||
|
RateLimitError,
|
||||||
|
ServiceError,
|
||||||
|
ServiceUnavailableError,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Status codes that should not raise exceptions
|
||||||
|
_OK_RANGE_START = 200
|
||||||
|
_OK_RANGE_END = 299
|
||||||
|
_REDIRECT_RANGE_END = 399
|
||||||
|
|
||||||
|
|
||||||
|
def raise_for_status(response: httpx.Response, *, service: str) -> None:
|
||||||
|
"""Raise an appropriate exception based on the HTTP status code.
|
||||||
|
|
||||||
|
2xx and 3xx responses do not raise. 4xx/5xx responses raise typed exceptions
|
||||||
|
that carry the service name for debugging.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
response: The httpx response to check.
|
||||||
|
service: Name of the service that returned this response.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AuthenticationError: For 401 or 403 status codes.
|
||||||
|
NotFoundError: For 404 status codes.
|
||||||
|
RateLimitError: For 429 status codes.
|
||||||
|
ServiceUnavailableError: For 503 status codes.
|
||||||
|
ServiceError: For all other 4xx/5xx status codes.
|
||||||
|
"""
|
||||||
|
code = response.status_code
|
||||||
|
|
||||||
|
if code <= _REDIRECT_RANGE_END:
|
||||||
|
return
|
||||||
|
|
||||||
|
if code in (401, 403):
|
||||||
|
raise AuthenticationError(
|
||||||
|
service=service, status_code=code, detail=_extract_detail(response)
|
||||||
|
)
|
||||||
|
|
||||||
|
if code == 404:
|
||||||
|
raise NotFoundError(service=service, detail=_extract_detail(response))
|
||||||
|
|
||||||
|
if code == 429:
|
||||||
|
retry_after = _parse_retry_after(response)
|
||||||
|
raise RateLimitError(service=service, retry_after=retry_after)
|
||||||
|
|
||||||
|
if code == 503:
|
||||||
|
raise ServiceUnavailableError(service=service, detail=_extract_detail(response))
|
||||||
|
|
||||||
|
raise ServiceError(service=service, status_code=code, detail=_extract_detail(response))
|
||||||
|
|
||||||
|
|
||||||
|
def build_auth_header(username: str, password: str) -> dict[str, str]:
    """Build a Basic authentication header.

    Args:
        username: The username (may be empty string for PAT-only auth).
        password: The password or token.

    Returns:
        A dict with a single "Authorization" key containing the Basic auth value.
    """
    raw = f"{username}:{password}".encode()
    token = base64.b64encode(raw).decode()
    return {"Authorization": f"Basic {token}"}
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_detail(response: httpx.Response) -> str | None:
    """Try to extract a human-readable detail string from the response body."""
    try:
        body = response.json()
    except (ValueError, KeyError):
        # Non-JSON body: fall back to raw text (None when empty).
        return response.text or None

    if not isinstance(body, dict):
        return str(body)

    for candidate in (body.get("message"), body.get("detail")):
        if candidate:
            return candidate
    if body.get("errorMessages"):
        joined = "; ".join(body["errorMessages"])
        if joined:
            return joined
    # Last resort: stringify the whole dict so callers still get context.
    return str(body)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_retry_after(response: httpx.Response) -> int | None:
    """Parse the Retry-After header value as an integer number of seconds."""
    raw = response.headers.get("Retry-After")
    if raw is not None:
        try:
            return int(raw)
        except ValueError:
            # Non-numeric (e.g. HTTP-date form) — treat as absent.
            pass
    return None
|
||||||
74
src/release_agent/tools/_retry.py
Normal file
74
src/release_agent/tools/_retry.py
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
"""Async retry decorator with exponential backoff.
|
||||||
|
|
||||||
|
Retries only on RateLimitError and ServiceUnavailableError.
|
||||||
|
Respects the Retry-After header when present in RateLimitError.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import functools
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
|
from typing import Any, TypeVar
|
||||||
|
|
||||||
|
from release_agent.exceptions import RateLimitError, ServiceUnavailableError
|
||||||
|
|
||||||
|
_T = TypeVar("_T")

_DEFAULT_MAX_ATTEMPTS = 3
_DEFAULT_BASE_DELAY = 1.0
_BACKOFF_MULTIPLIER = 2.0

# Types that trigger a retry
_RETRYABLE = (RateLimitError, ServiceUnavailableError)


def with_retry(
    max_attempts: int = _DEFAULT_MAX_ATTEMPTS,
    base_delay: float = _DEFAULT_BASE_DELAY,
    sleep_fn: Callable[[float], Awaitable[None]] | None = None,
) -> Callable[[Callable[..., Awaitable[_T]]], Callable[..., Awaitable[_T]]]:
    """Decorator factory that retries an async function on retryable errors.

    Args:
        max_attempts: Maximum total attempts (including the first call).
        base_delay: Initial delay in seconds between retries. Doubles each retry.
        sleep_fn: Async sleep function (defaults to asyncio.sleep). Injected for testing.

    Returns:
        A decorator that wraps an async function with retry logic.
    """
    if max_attempts < 1:
        raise ValueError(f"max_attempts must be >= 1, got {max_attempts}")

    _sleep = asyncio.sleep if sleep_fn is None else sleep_fn

    def decorator(fn: Callable[..., Awaitable[_T]]) -> Callable[..., Awaitable[_T]]:
        @functools.wraps(fn)
        async def wrapper(*args: Any, **kwargs: Any) -> _T:
            delay = base_delay
            last_exc: Exception | None = None
            attempts_left = max_attempts
            while attempts_left:
                try:
                    return await fn(*args, **kwargs)
                except _RETRYABLE as exc:
                    last_exc = exc
                    attempts_left -= 1
                    if not attempts_left:
                        # Final attempt failed: propagate the original error.
                        raise
                    await _sleep(_resolve_wait(exc, delay))
                    delay *= _BACKOFF_MULTIPLIER
            raise RuntimeError("Retry loop exited unexpectedly") from last_exc

        return wrapper

    return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_wait(exc: Exception, fallback_delay: float) -> float:
    """Return the number of seconds to wait before the next attempt.

    A RateLimitError carrying a retry_after value takes precedence;
    otherwise the exponential backoff delay is used.
    """
    if not isinstance(exc, RateLimitError):
        return fallback_delay
    if exc.retry_after is None:
        return fallback_delay
    return float(exc.retry_after)
|
||||||
513
src/release_agent/tools/azdo.py
Normal file
513
src/release_agent/tools/azdo.py
Normal file
@@ -0,0 +1,513 @@
|
|||||||
|
"""Azure DevOps service client.
|
||||||
|
|
||||||
|
Uses two httpx.AsyncClient instances:
|
||||||
|
- http_client: main AzDo REST API (dev.azure.com)
|
||||||
|
- vsrm_http_client: VSRM release management API (vsrm.dev.azure.com)
|
||||||
|
|
||||||
|
Both clients are injected via constructor for testability.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from types import TracebackType
|
||||||
|
from typing import Self
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from release_agent.models.build import ApprovalRecord, BuildStatus
|
||||||
|
from release_agent.models.pipeline import PipelineInfo
|
||||||
|
from release_agent.models.pr import PRInfo
|
||||||
|
from release_agent.tools._http import build_auth_header, raise_for_status
|
||||||
|
|
||||||
|
_API_VERSION = "7.1"
|
||||||
|
|
||||||
|
|
||||||
|
class AzDoClient:
|
||||||
|
"""Client for the Azure DevOps REST API.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
base_url: Main AzDo project API base URL.
|
||||||
|
vsrm_base_url: VSRM release management API base URL.
|
||||||
|
pat: Personal Access Token for authentication.
|
||||||
|
http_client: Injected httpx.AsyncClient for the main API.
|
||||||
|
vsrm_http_client: Injected httpx.AsyncClient for the VSRM API.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
base_url: str,
|
||||||
|
vsrm_base_url: str,
|
||||||
|
pat: str,
|
||||||
|
http_client: httpx.AsyncClient,
|
||||||
|
vsrm_http_client: httpx.AsyncClient,
|
||||||
|
) -> None:
|
||||||
|
self._base_url = base_url.rstrip("/")
|
||||||
|
self._vsrm_base_url = vsrm_base_url.rstrip("/")
|
||||||
|
self._auth = build_auth_header("", pat)
|
||||||
|
self._http = http_client
|
||||||
|
self._vsrm = vsrm_http_client
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
"""Close both underlying HTTP clients."""
|
||||||
|
await self._http.aclose()
|
||||||
|
await self._vsrm.aclose()
|
||||||
|
|
||||||
|
async def __aenter__(self) -> Self:
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(
|
||||||
|
self,
|
||||||
|
exc_type: type[BaseException] | None,
|
||||||
|
exc_val: BaseException | None,
|
||||||
|
exc_tb: TracebackType | None,
|
||||||
|
) -> None:
|
||||||
|
await self.close()
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Pull requests
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def get_pr(self, pr_id: int) -> PRInfo:
|
||||||
|
"""Fetch a pull request by ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pr_id: Numeric pull request ID.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A PRInfo model populated from the API response.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotFoundError: If the PR does not exist.
|
||||||
|
AuthenticationError: If authentication fails.
|
||||||
|
ServiceError: For other HTTP errors.
|
||||||
|
"""
|
||||||
|
url = f"{self._base_url}/git/pullRequests/{pr_id}"
|
||||||
|
response = await self._http.get(url, headers=self._auth, params={"api-version": _API_VERSION})
|
||||||
|
raise_for_status(response, service="azdo")
|
||||||
|
data = response.json()
|
||||||
|
return _parse_pr(data)
|
||||||
|
|
||||||
|
async def list_active_prs(self, repo: str, target_branch: str) -> list[PRInfo]:
|
||||||
|
"""List active pull requests for a repository filtered by target branch.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
repo: Repository name.
|
||||||
|
target_branch: Target branch ref name (e.g. "refs/heads/develop").
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of PRInfo models for active PRs targeting the given branch.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotFoundError: If the repository does not exist.
|
||||||
|
AuthenticationError: If authentication fails.
|
||||||
|
ServiceError: For other HTTP errors.
|
||||||
|
"""
|
||||||
|
url = f"{self._base_url}/git/repositories/{repo}/pullRequests"
|
||||||
|
response = await self._http.get(
|
||||||
|
url,
|
||||||
|
headers=self._auth,
|
||||||
|
params={
|
||||||
|
"api-version": _API_VERSION,
|
||||||
|
"status": "active",
|
||||||
|
"targetRefName": target_branch,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
raise_for_status(response, service="azdo")
|
||||||
|
data = response.json()
|
||||||
|
return [_parse_pr(item) for item in data.get("value", [])]
|
||||||
|
|
||||||
|
async def get_pr_diff(self, pr_id: int) -> str:
|
||||||
|
"""Return a diff-like string for the given pull request.
|
||||||
|
|
||||||
|
Fetches the PR to determine the repository, then retrieves the
|
||||||
|
diffs endpoint and formats a summary string suitable for code review.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pr_id: Numeric pull request ID.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A text string describing changed files.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotFoundError: If the PR does not exist.
|
||||||
|
ServiceError: For other HTTP errors.
|
||||||
|
"""
|
||||||
|
pr = await self.get_pr(pr_id)
|
||||||
|
repo = pr.repo_name
|
||||||
|
url = f"{self._base_url}/git/repositories/{repo}/pullRequests/{pr_id}/diffs"
|
||||||
|
response = await self._http.get(
|
||||||
|
url,
|
||||||
|
headers=self._auth,
|
||||||
|
params={"api-version": _API_VERSION, "baseVersionDescriptor.versionType": "commit"},
|
||||||
|
)
|
||||||
|
raise_for_status(response, service="azdo")
|
||||||
|
data = response.json()
|
||||||
|
return _format_diff(data)
|
||||||
|
|
||||||
|
async def merge_pr(self, *, pr_id: int, last_merge_source_commit: str) -> bool:
|
||||||
|
"""Complete (merge) a pull request.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pr_id: Numeric pull request ID.
|
||||||
|
last_merge_source_commit: The commit ID to use as the merge source.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True on success.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotFoundError: If the PR does not exist.
|
||||||
|
ServiceError: For other HTTP errors including merge conflicts.
|
||||||
|
"""
|
||||||
|
url = f"{self._base_url}/git/pullRequests/{pr_id}"
|
||||||
|
payload = {
|
||||||
|
"status": "completed",
|
||||||
|
"lastMergeSourceCommit": {"commitId": last_merge_source_commit},
|
||||||
|
"completionOptions": {"mergeStrategy": "squash"},
|
||||||
|
}
|
||||||
|
response = await self._http.patch(
|
||||||
|
url,
|
||||||
|
headers={**self._auth, "Content-Type": "application/json"},
|
||||||
|
json=payload,
|
||||||
|
params={"api-version": _API_VERSION},
|
||||||
|
)
|
||||||
|
raise_for_status(response, service="azdo")
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def create_pr(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
repo: str,
|
||||||
|
source: str,
|
||||||
|
target: str,
|
||||||
|
title: str,
|
||||||
|
description: str,
|
||||||
|
) -> dict:
|
||||||
|
"""Create a new pull request.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
repo: Repository name.
|
||||||
|
source: Source branch ref name (e.g. "refs/heads/release/v1.2.0").
|
||||||
|
target: Target branch ref name (e.g. "refs/heads/main").
|
||||||
|
title: PR title.
|
||||||
|
description: PR description.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Raw dict from the API response.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ServiceError: For HTTP errors.
|
||||||
|
"""
|
||||||
|
url = f"{self._base_url}/git/repositories/{repo}/pullRequests"
|
||||||
|
payload = {
|
||||||
|
"sourceRefName": source,
|
||||||
|
"targetRefName": target,
|
||||||
|
"title": title,
|
||||||
|
"description": description,
|
||||||
|
}
|
||||||
|
response = await self._http.post(
|
||||||
|
url,
|
||||||
|
headers={**self._auth, "Content-Type": "application/json"},
|
||||||
|
json=payload,
|
||||||
|
params={"api-version": _API_VERSION},
|
||||||
|
)
|
||||||
|
raise_for_status(response, service="azdo")
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Pipelines
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def list_build_pipelines(self, *, repo: str) -> list[PipelineInfo]:
    """List the project's build pipelines.

    Note: the request itself is not filtered server-side; *repo* is only
    attached to each returned PipelineInfo for downstream bookkeeping.

    Args:
        repo: Repository name stamped onto each result.

    Returns:
        List of PipelineInfo models.

    Raises:
        ServiceError: For HTTP errors.
    """
    endpoint = f"{self._base_url}/pipelines"
    resp = await self._http.get(
        endpoint,
        headers=self._auth,
        params={"api-version": _API_VERSION},
    )
    raise_for_status(resp, service="azdo")
    pipelines: list[PipelineInfo] = []
    for entry in resp.json().get("value", []):
        pipelines.append(
            PipelineInfo(id=entry["id"], name=entry["name"], repo=repo)
        )
    return pipelines
|
||||||
|
|
||||||
|
async def trigger_pipeline(self, *, pipeline_id: int, branch: str) -> dict:
    """Queue a new run of a pipeline on the given branch.

    Args:
        pipeline_id: Numeric pipeline definition ID.
        branch: Branch ref to build (e.g. "refs/heads/main").

    Returns:
        Raw dict from the API response containing the build/run ID.

    Raises:
        NotFoundError: If the pipeline does not exist.
        ServiceError: For other HTTP errors.
    """
    endpoint = f"{self._base_url}/pipelines/{pipeline_id}/runs"
    # The run request pins which ref of the pipeline's own repository to build.
    run_request = {"resources": {"repositories": {"self": {"refName": branch}}}}
    resp = await self._http.post(
        endpoint,
        headers={**self._auth, "Content-Type": "application/json"},
        json=run_request,
        params={"api-version": _API_VERSION},
    )
    raise_for_status(resp, service="azdo")
    return resp.json()
|
||||||
|
|
||||||
|
async def get_build_status(self, *, build_id: int) -> BuildStatus:
    """Fetch the current state of a build by its numeric ID.

    Args:
        build_id: Numeric build ID.

    Returns:
        BuildStatus dataclass with status, result, and build_url fields.

    Raises:
        NotFoundError: If the build does not exist.
        ServiceError: For other HTTP errors.
    """
    endpoint = f"{self._base_url}/build/builds/{build_id}"
    resp = await self._http.get(
        endpoint,
        headers=self._auth,
        params={"api-version": _API_VERSION},
    )
    raise_for_status(resp, service="azdo")
    build = resp.json()
    # Prefer the browser-facing link from _links.web; fall back to the
    # raw API URL when it is absent.
    web_link = (build.get("_links") or {}).get("web") or {}
    return BuildStatus(
        status=build["status"],
        result=build.get("result"),
        build_url=web_link.get("href") or build.get("url"),
    )
|
||||||
|
|
||||||
|
async def get_release_approvals(self, *, release_id: int) -> list[ApprovalRecord]:
    """Fetch the approvals recorded for a release.

    Args:
        release_id: Numeric release ID.

    Returns:
        List of ApprovalRecord dataclasses for this release.

    Raises:
        NotFoundError: If the release does not exist.
        ServiceError: For other HTTP errors.
    """
    resp = await self._vsrm.get(
        f"{self._vsrm_base_url}/release/approvals",
        headers=self._auth,
        params={"api-version": _API_VERSION, "releaseId": release_id},
    )
    raise_for_status(resp, service="azdo")
    approvals: list[ApprovalRecord] = []
    for raw in resp.json().get("value", []):
        # Stage name and release id live on the nested environment object.
        environment = raw.get("releaseEnvironment") or {}
        release = environment.get("release") or {}
        approvals.append(
            ApprovalRecord(
                approval_id=str(raw["id"]),
                stage_name=environment.get("name", ""),
                status=raw.get("status", "pending"),
                release_id=release.get("id", release_id),
            )
        )
    return approvals
|
||||||
|
|
||||||
|
async def get_latest_release(self, *, definition_id: int) -> dict:
    """Return the most recent release for a release definition.

    Args:
        definition_id: Numeric release definition ID.

    Returns:
        Raw dict from the API response for the latest release,
        or an empty dict if no releases exist.

    Raises:
        NotFoundError: If the definition does not exist.
        ServiceError: For other HTTP errors.
    """
    endpoint = f"{self._vsrm_base_url}/release/releases"
    # Ask the server for just the single newest release.
    query = {
        "api-version": _API_VERSION,
        "definitionId": definition_id,
        "$top": 1,
        "$orderby": "id desc",
    }
    resp = await self._vsrm.get(endpoint, headers=self._auth, params=query)
    raise_for_status(resp, service="azdo")
    releases = resp.json().get("value", [])
    if not releases:
        return {}
    return releases[0]
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Release approvals (VSRM)
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def approve_release(self, *, approval_id: str, comment: str) -> dict:
    """Approve a release pipeline stage approval via the VSRM API.

    Args:
        approval_id: The approval record ID.
        comment: Comment to attach to the approval.

    Returns:
        Raw dict from the API response.

    Raises:
        NotFoundError: If the approval does not exist.
        ServiceError: For other HTTP errors.
    """
    endpoint = f"{self._vsrm_base_url}/release/approvals/{approval_id}"
    body = {"status": "approved", "comments": comment}
    resp = await self._vsrm.patch(
        endpoint,
        headers={**self._auth, "Content-Type": "application/json"},
        json=body,
        params={"api-version": _API_VERSION},
    )
    raise_for_status(resp, service="azdo")
    return resp.json()
|
||||||
|
|
||||||
|
|
||||||
|
# -------------------------------------------------------------------
|
||||||
|
# PR Comment methods
|
||||||
|
# -------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def add_pr_comment(
    self,
    *,
    repo: str,
    pr_id: int,
    content: str,
) -> dict:
    """Post a general (non file-anchored) comment thread on a PR.

    Args:
        repo: Repository name.
        pr_id: Pull request ID.
        content: Markdown content for the comment.

    Returns:
        Raw dict from the API response.
    """
    endpoint = (
        f"{self._base_url}/git/repositories/{repo}/pullRequests/{pr_id}/threads"
    )
    body = {
        "comments": [{"parentCommentId": 0, "content": content, "commentType": 1}],
        "status": "active",
    }
    resp = await self._http.post(
        endpoint,
        headers={**self._auth, "Content-Type": "application/json"},
        json=body,
        params={"api-version": _API_VERSION},
    )
    raise_for_status(resp, service="azdo")
    return resp.json()
|
||||||
|
|
||||||
|
async def add_pr_inline_comment(
    self,
    *,
    repo: str,
    pr_id: int,
    content: str,
    file_path: str,
    line_start: int,
    line_end: int | None = None,
) -> dict:
    """Attach a comment thread to specific lines of a file in a PR.

    Args:
        repo: Repository name.
        pr_id: Pull request ID.
        content: Markdown content for the comment.
        file_path: Path to the file (e.g., "/src/Handlers/InvoiceHandler.cs").
        line_start: Starting line number.
        line_end: Ending line number (defaults to line_start).

    Returns:
        Raw dict from the API response.
    """
    # AzDo thread contexts require a leading slash on the file path.
    anchored_path = file_path if file_path.startswith("/") else f"/{file_path}"
    last_line = line_start if line_end is None else line_end
    endpoint = (
        f"{self._base_url}/git/repositories/{repo}/pullRequests/{pr_id}/threads"
    )
    body = {
        "comments": [{"parentCommentId": 0, "content": content, "commentType": 1}],
        "threadContext": {
            "filePath": anchored_path,
            "rightFileStart": {"line": line_start, "offset": 1},
            "rightFileEnd": {"line": last_line, "offset": 1},
        },
        "status": "active",
    }
    resp = await self._http.post(
        endpoint,
        headers={**self._auth, "Content-Type": "application/json"},
        json=body,
        params={"api-version": _API_VERSION},
    )
    raise_for_status(resp, service="azdo")
    return resp.json()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Private parsing helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _parse_pr(data: dict) -> PRInfo:
    """Convert a raw AzDo pull-request payload into a PRInfo model."""
    repository = data["repository"]
    identifier = str(data["pullRequestId"])
    # Some API responses omit the url field; reconstruct a link from the
    # repository's remote URL in that case.
    link = data.get("url", "")
    if not link:
        link = f"{repository.get('remoteUrl', '')}/pullrequest/{identifier}"
    return PRInfo(
        pr_id=identifier,
        pr_url=link,
        repo_name=repository["name"],
        branch=data.get("sourceRefName", ""),
        pr_title=data.get("title", ""),
        pr_status=_map_pr_status(data.get("status", "active")),
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _map_pr_status(raw: str) -> str:
|
||||||
|
"""Normalise AzDo PR status to a value accepted by PRInfo."""
|
||||||
|
mapping = {"active": "active", "completed": "completed", "abandoned": "abandoned"}
|
||||||
|
return mapping.get(raw, "active")
|
||||||
|
|
||||||
|
|
||||||
|
def _format_diff(data: dict) -> str:
|
||||||
|
"""Format the diffs API response as a text string."""
|
||||||
|
changes = data.get("changes", [])
|
||||||
|
lines = []
|
||||||
|
for change in changes:
|
||||||
|
item = change.get("item", {})
|
||||||
|
path = item.get("path", "unknown")
|
||||||
|
change_type = change.get("changeType", "edit")
|
||||||
|
lines.append(f"{change_type}: {path}")
|
||||||
|
return "\n".join(lines)
|
||||||
335
src/release_agent/tools/claude_review.py
Normal file
335
src/release_agent/tools/claude_review.py
Normal file
@@ -0,0 +1,335 @@
|
|||||||
|
"""Claude PR reviewer using Claude Code CLI for code review.
|
||||||
|
|
||||||
|
Uses `claude -p` (print mode) with `--output-format json` to leverage the
|
||||||
|
Claude Code subscription instead of API key billing. Claude Code can
|
||||||
|
autonomously read files, grep code, and explore the codebase for deeper reviews.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from release_agent.models.review import ReviewIssue, ReviewResult
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Hard cap on diff text forwarded to the CLI; keeps prompt size bounded.
_MAX_DIFF_CHARS = 100_000
# Marker appended when the diff is cut, so the reviewer knows it is partial.
_TRUNCATION_NOTE = "\n\n[DIFF TRUNCATED: content exceeded 100,000 characters]"
# Default maximum seconds to wait for a single `claude -p` invocation.
_CLI_TIMEOUT_SECONDS = 300
|
||||||
|
|
||||||
|
# JSON schema for structured output from Claude Code CLI
|
||||||
|
_REVIEW_JSON_SCHEMA = json.dumps({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"verdict": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["approve", "request_changes"],
|
||||||
|
},
|
||||||
|
"summary": {
|
||||||
|
"type": "string",
|
||||||
|
},
|
||||||
|
"issues": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"severity": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["blocker", "error", "warning", "info"],
|
||||||
|
},
|
||||||
|
"description": {"type": "string"},
|
||||||
|
"file_path": {"type": "string"},
|
||||||
|
"line_start": {"type": "integer", "description": "Starting line number in the file"},
|
||||||
|
"line_end": {"type": "integer", "description": "Ending line number in the file"},
|
||||||
|
"suggestion": {"type": "string"},
|
||||||
|
},
|
||||||
|
"required": ["severity", "description"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["verdict", "summary", "issues"],
|
||||||
|
})
|
||||||
|
|
||||||
|
# System prompt for PR review runs; sent via `--system-prompt`.
_SYSTEM_PROMPT = (
    "You are a senior code reviewer for .NET C# projects. "
    "Review the PR for: backward compatibility (DB migration, serialization, integration events), "
    "null handling, business logic correctness, test coverage, hardcoded values, and security. "
    "You have access to the full codebase via Read, Glob, and Grep tools. "
    "Use them to understand the context around the changed files. "
    "For each issue, include the file_path and line_start/line_end where the issue occurs."
)
|
||||||
|
|
||||||
|
# JSON schema for structured ticket content output
|
||||||
|
_TICKET_JSON_SCHEMA = json.dumps({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"summary": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Short one-line summary of the work done (max 100 chars)",
|
||||||
|
},
|
||||||
|
"description": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Detailed description of the changes and business value",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["summary", "description"],
|
||||||
|
})
|
||||||
|
|
||||||
|
# System prompt for Jira-ticket generation runs; sent via `--system-prompt`.
_TICKET_SYSTEM_PROMPT = (
    "You are a Jira ticket writer for a software development team. "
    "Given a PR diff and title, produce a concise Jira Story ticket. "
    "The summary should be a clear, action-oriented title (max 100 chars). "
    "The description should explain what was changed and why, suitable for a product backlog."
)
|
||||||
|
|
||||||
|
|
||||||
|
class ClaudeReviewer:
    """Reviews pull request diffs using Claude Code CLI.

    Uses the user's Claude Code subscription (not API key) by invoking
    `claude -p` as a subprocess. This allows Claude to autonomously read
    files and explore the codebase for more thorough reviews.

    Args:
        claude_cmd: Path to the claude CLI binary (default: "claude").
        timeout: Maximum seconds to wait for the review (default: 300).
        run_subprocess: Async callable for running subprocesses (injected for testing).
    """

    def __init__(
        self,
        *,
        claude_cmd: str = "claude",
        timeout: int = _CLI_TIMEOUT_SECONDS,
        # Expected signature: (cmd, cwd, timeout) -> (stdout, stderr, returncode);
        # typed as `object` so test doubles need no Protocol import.
        run_subprocess: object | None = None,
    ) -> None:
        self._claude_cmd = claude_cmd
        self._timeout = timeout
        # Fall back to the real subprocess runner when no double is injected.
        self._run_subprocess = run_subprocess or _default_run_subprocess

    async def review_pr(
        self,
        *,
        diff: str,
        pr_title: str,
        repo_name: str,
        cwd: str | Path | None = None,
    ) -> ReviewResult:
        """Review a pull request and return a structured result.

        Args:
            diff: The raw diff text of the pull request.
            pr_title: Title of the pull request.
            repo_name: Name of the repository.
            cwd: Working directory for Claude Code (e.g., git worktree path).
                If provided, Claude can read files in this directory.

        Returns:
            A ReviewResult with verdict, summary, and list of issues.

        Raises:
            RuntimeError: If the Claude CLI exits non-zero or times out.
            ValueError: If the CLI output cannot be parsed.
        """
        # Diff is capped before prompting so huge PRs cannot blow up the CLI call.
        truncated_diff = _truncate_diff(diff)
        prompt = _build_prompt(
            diff=truncated_diff,
            pr_title=pr_title,
            repo_name=repo_name,
        )

        # NOTE(review): assumes the installed claude CLI supports
        # --json-schema / --allowedTools flags — confirm against CLI version.
        cmd = [
            self._claude_cmd, "-p", prompt,
            "--output-format", "json",
            "--json-schema", _REVIEW_JSON_SCHEMA,
            "--allowedTools", "Read,Glob,Grep",
            "--system-prompt", _SYSTEM_PROMPT,
        ]

        stdout, stderr, returncode = await self._run_subprocess(
            cmd=cmd,
            cwd=str(cwd) if cwd else None,
            timeout=self._timeout,
        )

        if returncode != 0:
            # Truncate stderr in the log/exception to keep messages bounded.
            logger.error(
                "Claude CLI failed (exit %d): %s", returncode, stderr[:500]
            )
            raise RuntimeError(f"Claude CLI exited with code {returncode}: {stderr[:200]}")

        return _parse_cli_output(stdout)

    async def generate_ticket_content(
        self,
        *,
        diff: str,
        pr_title: str,
        repo_name: str,
        cwd: str | Path | None = None,
    ) -> tuple[str, str]:
        """Generate Jira ticket summary and description from a PR diff.

        Args:
            diff: The raw diff text of the pull request.
            pr_title: Title of the pull request.
            repo_name: Name of the repository.
            cwd: Working directory for Claude Code.

        Returns:
            A tuple of (summary, description) strings suitable for Jira.

        Raises:
            RuntimeError: If the Claude CLI fails or times out.
            ValueError: If the output cannot be parsed.
        """
        truncated_diff = _truncate_diff(diff)
        prompt = (
            f"Generate a Jira Story ticket for this pull request from repository '{repo_name}'.\n\n"
            f"PR Title: {pr_title}\n\n"
            f"Diff:\n```\n{truncated_diff}\n```\n\n"
            "Produce a concise summary and description for a Jira Story ticket."
        )

        # Same CLI invocation shape as review_pr, but with the ticket schema
        # and ticket-writer system prompt.
        cmd = [
            self._claude_cmd, "-p", prompt,
            "--output-format", "json",
            "--json-schema", _TICKET_JSON_SCHEMA,
            "--allowedTools", "Read,Glob,Grep",
            "--system-prompt", _TICKET_SYSTEM_PROMPT,
        ]

        stdout, stderr, returncode = await self._run_subprocess(
            cmd=cmd,
            cwd=str(cwd) if cwd else None,
            timeout=self._timeout,
        )

        if returncode != 0:
            logger.error(
                "Claude CLI failed (exit %d): %s", returncode, stderr[:500]
            )
            raise RuntimeError(f"Claude CLI exited with code {returncode}: {stderr[:200]}")

        return _parse_ticket_output(stdout)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Private helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _truncate_diff(diff: str) -> str:
    """Cap *diff* at _MAX_DIFF_CHARS, appending a marker when it is cut."""
    if len(diff) > _MAX_DIFF_CHARS:
        return diff[:_MAX_DIFF_CHARS] + _TRUNCATION_NOTE
    return diff
|
||||||
|
|
||||||
|
|
||||||
|
def _build_prompt(*, diff: str, pr_title: str, repo_name: str) -> str:
|
||||||
|
"""Build the user prompt for the review request."""
|
||||||
|
return (
|
||||||
|
f"Review this pull request from repository '{repo_name}'.\n\n"
|
||||||
|
f"PR Title: {pr_title}\n\n"
|
||||||
|
f"Diff:\n```\n{diff}\n```\n\n"
|
||||||
|
"Read the related source files to understand context. "
|
||||||
|
"Provide your review as structured JSON output."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_cli_output(stdout: str) -> ReviewResult:
    """Parse the JSON output from Claude Code CLI into a ReviewResult.

    Args:
        stdout: Raw stdout captured from `claude -p --output-format json`.

    Returns:
        A ReviewResult built from the schema-conforming payload.

    Raises:
        ValueError: If stdout is not JSON, no payload is present, or the
            payload does not match the review schema.
    """
    try:
        data = json.loads(stdout)
    except json.JSONDecodeError as exc:
        raise ValueError(f"Failed to parse Claude CLI output as JSON: {exc}") from exc

    # Claude Code --output-format json wraps result in {"result": "...", ...}
    # The structured_output field contains our schema-conforming data
    structured = data.get("structured_output") or data.get("result")

    if structured is None:
        raise ValueError("No structured_output or result in Claude CLI response")

    # If structured is a string (result field), try to parse it as JSON
    if isinstance(structured, str):
        try:
            structured = json.loads(structured)
        except json.JSONDecodeError as exc:
            # Fix: chain the decode error so debugging shows the real cause
            # (was a bare `raise`, losing the __cause__ link).
            raise ValueError(
                "Claude CLI result is not valid JSON for review schema"
            ) from exc

    if not isinstance(structured, dict):
        raise ValueError(f"Expected dict, got {type(structured).__name__}")

    issues = [
        ReviewIssue(
            severity=issue["severity"],
            description=issue["description"],
            file_path=issue.get("file_path"),
            line_start=issue.get("line_start"),
            line_end=issue.get("line_end"),
            suggestion=issue.get("suggestion"),
        )
        for issue in structured.get("issues", [])
    ]

    return ReviewResult(
        verdict=structured["verdict"],
        summary=structured["summary"],
        issues=tuple(issues),
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_ticket_output(stdout: str) -> tuple[str, str]:
|
||||||
|
"""Parse the JSON output from Claude Code CLI into (summary, description)."""
|
||||||
|
try:
|
||||||
|
data = json.loads(stdout)
|
||||||
|
except json.JSONDecodeError as exc:
|
||||||
|
raise ValueError(f"Failed to parse Claude CLI ticket output as JSON: {exc}") from exc
|
||||||
|
|
||||||
|
structured = data.get("structured_output") or data.get("result")
|
||||||
|
|
||||||
|
if structured is None:
|
||||||
|
raise ValueError("No structured_output or result in Claude CLI response")
|
||||||
|
|
||||||
|
if isinstance(structured, str):
|
||||||
|
try:
|
||||||
|
structured = json.loads(structured)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
raise ValueError("Claude CLI result is not valid JSON for ticket schema")
|
||||||
|
|
||||||
|
if not isinstance(structured, dict):
|
||||||
|
raise ValueError(f"Expected dict, got {type(structured).__name__}")
|
||||||
|
|
||||||
|
summary = structured.get("summary", "")
|
||||||
|
description = structured.get("description", "")
|
||||||
|
if not summary:
|
||||||
|
raise ValueError("Claude ticket output missing required 'summary' field")
|
||||||
|
return (summary, description)
|
||||||
|
|
||||||
|
|
||||||
|
async def _default_run_subprocess(
    *,
    cmd: list[str],
    cwd: str | None,
    timeout: int,
) -> tuple[str, str, int]:
    """Run a subprocess and return (stdout, stderr, returncode).

    Args:
        cmd: Argument vector to execute (no shell interpretation).
        cwd: Working directory for the child process, or None to inherit.
        timeout: Seconds to wait before killing the process.

    Returns:
        Tuple of decoded stdout, decoded stderr, and the exit code.

    Raises:
        RuntimeError: If the process does not finish within *timeout* seconds.
    """
    process = await asyncio.create_subprocess_exec(
        *cmd,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
        cwd=cwd,
    )
    try:
        stdout_bytes, stderr_bytes = await asyncio.wait_for(
            process.communicate(), timeout=timeout
        )
    except asyncio.TimeoutError as exc:
        # Kill and reap the child before surfacing the failure (avoids zombies).
        process.kill()
        await process.wait()
        # Fix: chain the TimeoutError so the traceback shows the original cause.
        raise RuntimeError(f"Claude CLI timed out after {timeout} seconds") from exc

    return (
        # `replace` keeps us safe from non-UTF-8 bytes in CLI output.
        stdout_bytes.decode("utf-8", errors="replace"),
        stderr_bytes.decode("utf-8", errors="replace"),
        # returncode can be None in edge cases; normalise to 0.
        process.returncode or 0,
    )
|
||||||
269
src/release_agent/tools/jira.py
Normal file
269
src/release_agent/tools/jira.py
Normal file
@@ -0,0 +1,269 @@
|
|||||||
|
"""Jira service client.
|
||||||
|
|
||||||
|
Uses Basic auth (email:api_token) and the Jira REST API v3.
|
||||||
|
The httpx.AsyncClient is injected via constructor for testability.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from types import TracebackType
|
||||||
|
from typing import Self
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from release_agent.models.jira import JiraIssue, JiraTransition
|
||||||
|
from release_agent.tools._http import build_auth_header, raise_for_status
|
||||||
|
|
||||||
|
# Jira Cloud REST API v3 path prefix used by all endpoints below.
_API_VERSION = "rest/api/3"
# Intermediate workflow status used by the two-step transition fallback in
# JiraClient.transition_issue when the target transition is unavailable.
_DEV_IN_PROGRESS_TRANSITION = "Dev in Progress"
|
||||||
|
|
||||||
|
|
||||||
|
class JiraClient:
    """Client for the Jira REST API.

    Supports use as an async context manager; the injected HTTP client is
    closed on exit.

    Args:
        base_url: Jira instance base URL (e.g. "https://billolife.atlassian.net").
        email: Account email address for Basic auth.
        api_token: Jira API token for Basic auth.
        http_client: Injected httpx.AsyncClient.
    """

    def __init__(
        self,
        *,
        base_url: str,
        email: str,
        api_token: str,
        http_client: httpx.AsyncClient,
    ) -> None:
        # Strip trailing slash so URL joins below never produce "//".
        self._base_url = base_url.rstrip("/")
        self._auth = build_auth_header(email, api_token)
        self._http = http_client

    async def close(self) -> None:
        """Close the underlying HTTP client."""
        await self._http.aclose()

    async def __aenter__(self) -> Self:
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        await self.close()

    # ------------------------------------------------------------------
    # Issue operations
    # ------------------------------------------------------------------

    async def get_issue(self, ticket_id: str) -> JiraIssue:
        """Fetch a Jira issue by key.

        Args:
            ticket_id: Jira issue key (e.g. "ALLPOST-100").

        Returns:
            A JiraIssue model populated from the API response.

        Raises:
            NotFoundError: If the issue does not exist.
            AuthenticationError: If authentication fails.
            ServiceError: For other HTTP errors.
        """
        url = f"{self._base_url}/{_API_VERSION}/issue/{ticket_id}"
        response = await self._http.get(url, headers=self._auth)
        raise_for_status(response, service="jira")
        data = response.json()
        return _parse_issue(data)

    async def get_transitions(self, ticket_id: str) -> list[JiraTransition]:
        """Fetch available workflow transitions for a Jira issue.

        Args:
            ticket_id: Jira issue key.

        Returns:
            List of available JiraTransition models.

        Raises:
            NotFoundError: If the issue does not exist.
            ServiceError: For other HTTP errors.
        """
        url = f"{self._base_url}/{_API_VERSION}/issue/{ticket_id}/transitions"
        response = await self._http.get(url, headers=self._auth)
        raise_for_status(response, service="jira")
        data = response.json()
        return [
            JiraTransition(id=t["id"], name=t["name"])
            for t in data.get("transitions", [])
        ]

    async def transition_issue(self, ticket_id: str, transition_name: str) -> bool:
        """Move a Jira issue through a named workflow transition.

        If the target transition is not currently available, a two-step
        fallback is attempted: first transition to "Dev in Progress", then
        retry the target transition. Returns False if the transition is still
        unavailable after the fallback attempt.

        Args:
            ticket_id: Jira issue key.
            transition_name: Name of the target transition (e.g. "Released").

        Returns:
            True if the transition succeeded, False if unavailable.

        Raises:
            ServiceError: For HTTP errors during transition execution.
        """
        transitions = await self.get_transitions(ticket_id)
        transition_id = _find_transition_id(transitions, transition_name)

        if transition_id is None:
            # Two-step fallback: move to "Dev in Progress" first, then retry.
            dev_id = _find_transition_id(transitions, _DEV_IN_PROGRESS_TRANSITION)
            if dev_id is None:
                return False
            await self._do_transition(ticket_id, dev_id)

            # Retry target transition after the fallback step.
            # (Available transitions depend on the issue's new state, so
            # they must be re-fetched rather than reused.)
            transitions = await self.get_transitions(ticket_id)
            transition_id = _find_transition_id(transitions, transition_name)
            if transition_id is None:
                return False

        await self._do_transition(ticket_id, transition_id)
        return True

    async def create_issue(
        self,
        *,
        project: str,
        summary: str,
        description: str,
        issue_type: str = "Story",
    ) -> str:
        """Create a new Jira issue and return its ticket key.

        Args:
            project: Jira project key (e.g. "ALLPOST").
            summary: Issue summary / title.
            description: Plain-text description, converted to ADF format.
            issue_type: Issue type name (default: "Story").

        Returns:
            The ticket key of the created issue (e.g. "ALLPOST-42").

        Raises:
            AuthenticationError: If authentication fails.
            ServiceError: For other HTTP errors.
        """
        url = f"{self._base_url}/{_API_VERSION}/issue"
        payload = {
            "fields": {
                "project": {"key": project},
                "summary": summary,
                # Jira v3 requires descriptions in ADF, not plain text.
                "description": _text_to_adf(description),
                "issuetype": {"name": issue_type},
            }
        }
        response = await self._http.post(
            url,
            headers={**self._auth, "Content-Type": "application/json"},
            json=payload,
        )
        raise_for_status(response, service="jira")
        data = response.json()
        return data["key"]

    async def add_remote_link(self, *, ticket_id: str, url: str, title: str) -> bool:
        """Add a remote link (e.g. a PR URL) to a Jira issue.

        Args:
            ticket_id: Jira issue key.
            url: URL of the remote resource.
            title: Display title of the link.

        Returns:
            True on success.

        Raises:
            NotFoundError: If the issue does not exist.
            ServiceError: For other HTTP errors.
        """
        endpoint = f"{self._base_url}/{_API_VERSION}/issue/{ticket_id}/remotelink"
        payload = {
            "object": {
                "url": url,
                "title": title,
            }
        }
        response = await self._http.post(
            endpoint,
            headers={**self._auth, "Content-Type": "application/json"},
            json=payload,
        )
        raise_for_status(response, service="jira")
        return True

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    async def _do_transition(self, ticket_id: str, transition_id: str) -> None:
        """Execute a transition by ID (no availability check — caller's job)."""
        url = f"{self._base_url}/{_API_VERSION}/issue/{ticket_id}/transitions"
        payload = {"transition": {"id": transition_id}}
        response = await self._http.post(
            url,
            headers={**self._auth, "Content-Type": "application/json"},
            json=payload,
        )
        raise_for_status(response, service="jira")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Private parsing helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _text_to_adf(text: str) -> dict:
|
||||||
|
"""Convert plain text to Atlassian Document Format (ADF).
|
||||||
|
|
||||||
|
Each non-empty line becomes a paragraph node. Empty string produces
|
||||||
|
an empty document (no content nodes).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Plain text to convert.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ADF dict conforming to the Jira REST API v3 description format.
|
||||||
|
"""
|
||||||
|
content = []
|
||||||
|
for line in text.splitlines():
|
||||||
|
stripped = line.strip()
|
||||||
|
if stripped:
|
||||||
|
content.append({
|
||||||
|
"type": "paragraph",
|
||||||
|
"content": [{"type": "text", "text": stripped}],
|
||||||
|
})
|
||||||
|
return {"version": 1, "type": "doc", "content": content}
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_issue(data: dict) -> JiraIssue:
    """Map a raw Jira issue API response to a JiraIssue model."""
    fields: dict = data.get("fields", {})
    # Jira nests the status name under fields.status.name; default defensively.
    return JiraIssue(
        key=data["key"],
        summary=fields.get("summary", ""),
        status=fields.get("status", {}).get("name", "Unknown"),
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _find_transition_id(transitions: list[JiraTransition], name: str) -> str | None:
    """Return the ID of the first transition matching the given name, or None."""
    matching_ids = (candidate.id for candidate in transitions if candidate.name == name)
    return next(matching_ids, None)
|
||||||
500
src/release_agent/tools/slack.py
Normal file
500
src/release_agent/tools/slack.py
Normal file
@@ -0,0 +1,500 @@
|
|||||||
|
"""Slack service client — dual-mode: webhook fallback + Web API.
|
||||||
|
|
||||||
|
Supports two modes:
|
||||||
|
1. Webhook mode: uses webhook_url for fire-and-forget notifications.
|
||||||
|
2. Web API mode: uses bot_token + channel_id for interactive messages,
|
||||||
|
message updates, and retrieving message timestamps.
|
||||||
|
|
||||||
|
Block Kit payloads are built by pure functions that can be tested independently.
|
||||||
|
The httpx.AsyncClient is injected via constructor for testability.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from datetime import date
|
||||||
|
from types import TracebackType
|
||||||
|
from typing import Self
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from release_agent.models.ticket import TicketEntry
|
||||||
|
from release_agent.tools._http import raise_for_status
|
||||||
|
|
||||||
|
_SLACK_API_BASE = "https://slack.com/api"
|
||||||
|
|
||||||
|
|
||||||
|
class SlackClient:
    """Client for sending messages to Slack.

    Supports two modes determined by which parameters are provided:
    - Webhook mode: provide webhook_url.
    - Web API mode: provide bot_token and channel_id.

    Both modes can be configured simultaneously; Web API takes priority
    for methods that require it (send_interactive_approval, update_message).
    Webhook is used as fallback for simple notifications.

    Args:
        webhook_url: The Slack incoming webhook URL (optional).
        bot_token: Slack bot OAuth token (xoxb-...) for Web API (optional).
        channel_id: Slack channel ID for Web API messages (optional).
        http_client: Injected httpx.AsyncClient.
    """

    def __init__(
        self,
        *,
        webhook_url: str = "",
        bot_token: str = "",
        channel_id: str = "",
        http_client: httpx.AsyncClient,
    ) -> None:
        self._webhook_url = webhook_url
        self._bot_token = bot_token
        self._channel_id = channel_id
        # Injected (not owned until close()) to keep the client testable.
        self._http = http_client

    async def close(self) -> None:
        """Close the underlying HTTP client."""
        await self._http.aclose()

    async def __aenter__(self) -> Self:
        """Enter the async context manager, returning this client."""
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Close the HTTP client when leaving the async context."""
        await self.close()

    # ------------------------------------------------------------------
    # Public methods: fire-and-forget notifications
    # ------------------------------------------------------------------

    async def send_release_notification(
        self,
        *,
        repo: str,
        version: str,
        release_date: date,
        tickets: list[TicketEntry],
    ) -> bool:
        """Send a release notification to Slack.

        Args:
            repo: Repository name.
            version: Release version string (e.g. "v1.2.0").
            release_date: Date of the release.
            tickets: List of tickets included in the release.

        Returns:
            True on success.

        Raises:
            ServiceError: If the Slack webhook returns an error response.
        """
        blocks = _build_release_blocks(
            repo=repo,
            version=version,
            release_date=release_date,
            tickets=tickets,
        )
        # Webhook-only path: raises on HTTP error rather than returning False.
        return await self._post_blocks(blocks)

    async def send_approval_request(
        self,
        *,
        action: str,
        details: str,
        approval_url: str,
    ) -> bool:
        """Send an approval request notification to Slack.

        This is the non-interactive (link-based) variant; see
        send_interactive_approval for the button-based one.

        Args:
            action: Description of the action requiring approval.
            details: Additional context for the approval.
            approval_url: URL where the approver can approve the action.

        Returns:
            True on success.

        Raises:
            ServiceError: If the Slack webhook returns an error response.
        """
        blocks = _build_approval_blocks(
            action=action,
            details=details,
            approval_url=approval_url,
        )
        return await self._post_blocks(blocks)

    async def send_notification(self, *, text: str, blocks: list[dict]) -> bool:
        """Send a plain notification to Slack.

        Uses Web API if bot_token is configured, otherwise falls back to
        the webhook URL.

        Args:
            text: Fallback text for notifications.
            blocks: Block Kit blocks to include in the message.

        Returns:
            True on success, False on error.
        """
        # Web API preferred when fully configured (token AND channel).
        if self._bot_token and self._channel_id:
            ts = await self._post_via_web_api(text=text, blocks=blocks)
            # _post_via_web_api returns "" on failure, so a non-empty ts means success.
            return ts != ""
        if self._webhook_url:
            try:
                # If no blocks were supplied, wrap the fallback text in a minimal section.
                return await self._post_blocks(blocks or [{"type": "section", "text": {"type": "mrkdwn", "text": text}}])
            except Exception:
                # Best-effort notification: swallow webhook errors and report failure.
                return False
        # Neither mode configured.
        return False

    # ------------------------------------------------------------------
    # Public methods: interactive messages (Web API only)
    # ------------------------------------------------------------------

    async def send_interactive_approval(
        self,
        *,
        thread_id: str,
        action: str,
        details: str,
        buttons: list[dict],
    ) -> str:
        """Send an interactive approval message with clickable buttons.

        Requires bot_token and channel_id. Returns the message timestamp
        (ts) which can be used to update the message later.

        Args:
            thread_id: The graph thread_id, embedded in button payloads
                so the interactions endpoint can resume the correct thread.
            action: Human-readable description of the action (e.g. "Deploy to Sandbox").
            details: Additional context shown in the message body.
            buttons: List of button dicts with "text" and "value" keys.

        Returns:
            Message timestamp string (e.g. "1234567890.123456") on success,
            or empty string on failure.
        """
        blocks = _build_interactive_approval_blocks(
            thread_id=thread_id,
            action=action,
            details=details,
            buttons=buttons,
        )
        return await self._post_via_web_api(text=action, blocks=blocks)

    async def update_message(
        self,
        *,
        message_ts: str,
        text: str,
        blocks: list[dict],
    ) -> bool:
        """Update an existing Slack message (Web API chat.update).

        Args:
            message_ts: Timestamp of the message to update.
            text: New fallback text.
            blocks: New Block Kit blocks.

        Returns:
            True on success, False on failure.
        """
        url = f"{_SLACK_API_BASE}/chat.update"
        payload: dict = {
            "channel": self._channel_id,
            "ts": message_ts,
            "text": text,
        }
        if blocks:
            payload["blocks"] = blocks

        response = await self._http.post(
            url,
            json=payload,
            headers=self._web_api_headers(),
        )
        if response.status_code != 200:
            return False
        # Slack returns HTTP 200 with {"ok": false} for API-level errors,
        # so the body must be checked as well.
        data = response.json()
        return bool(data.get("ok"))

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    async def _post_blocks(self, blocks: list[dict]) -> bool:
        """POST a Block Kit payload to the webhook URL."""
        response = await self._http.post(
            self._webhook_url,
            json={"blocks": blocks},
            headers={"Content-Type": "application/json"},
        )
        # Raises ServiceError (via raise_for_status) on non-2xx responses.
        raise_for_status(response, service="slack")
        return True

    async def _post_via_web_api(self, *, text: str, blocks: list[dict]) -> str:
        """POST a message via Slack Web API chat.postMessage.

        Returns the message ts on success, or empty string on failure.
        """
        url = f"{_SLACK_API_BASE}/chat.postMessage"
        payload: dict = {
            "channel": self._channel_id,
            "text": text,
        }
        if blocks:
            payload["blocks"] = blocks

        response = await self._http.post(
            url,
            json=payload,
            headers=self._web_api_headers(),
        )
        if response.status_code != 200:
            return ""
        data = response.json()
        # API-level failure is reported in the body even on HTTP 200.
        if not data.get("ok"):
            return ""
        return data.get("ts", "")

    def _web_api_headers(self) -> dict:
        """Return headers for Slack Web API requests."""
        return {
            "Content-Type": "application/json; charset=utf-8",
            "Authorization": f"Bearer {self._bot_token}",
        }
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Pure Block Kit builder functions
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _build_release_blocks(
    *,
    repo: str,
    version: str,
    release_date: date,
    tickets: list[TicketEntry],
) -> list[dict]:
    """Build Slack Block Kit blocks for a release notification.

    Pure function (no side effects) so it can be tested independently of
    HTTP calls.

    Args:
        repo: Repository name.
        version: Release version string.
        release_date: Date of the release.
        tickets: Tickets included in the release.

    Returns:
        List of Slack Block Kit block dicts.
    """
    blocks: list[dict] = [
        {
            "type": "header",
            "text": {
                "type": "plain_text",
                "text": f"Release {version} - {repo}",
            },
        },
        {
            "type": "section",
            "text": {
                "type": "mrkdwn",
                "text": f"*Release date:* {release_date.isoformat()}",
            },
        },
    ]

    # Ticket section is omitted entirely for ticketless releases.
    if tickets:
        listing = "\n".join(f"- *{entry.id}*: {entry.summary}" for entry in tickets)
        blocks.append(
            {
                "type": "section",
                "text": {
                    "type": "mrkdwn",
                    "text": f"*Tickets:*\n{listing}",
                },
            }
        )

    return blocks
|
||||||
|
|
||||||
|
|
||||||
|
def _build_approval_blocks(
|
||||||
|
*,
|
||||||
|
action: str,
|
||||||
|
details: str,
|
||||||
|
approval_url: str,
|
||||||
|
) -> list[dict]:
|
||||||
|
"""Build Slack Block Kit blocks for an approval request.
|
||||||
|
|
||||||
|
This is a pure function with no side effects, allowing it to be tested
|
||||||
|
independently of HTTP calls.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
action: Description of the action requiring approval.
|
||||||
|
details: Additional context for the approval.
|
||||||
|
approval_url: URL to the approval page.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of Slack Block Kit block dicts.
|
||||||
|
"""
|
||||||
|
header_block = {
|
||||||
|
"type": "header",
|
||||||
|
"text": {
|
||||||
|
"type": "plain_text",
|
||||||
|
"text": f"Approval Required: {action}",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
details_block = {
|
||||||
|
"type": "section",
|
||||||
|
"text": {
|
||||||
|
"type": "mrkdwn",
|
||||||
|
"text": details,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
actions_block = {
|
||||||
|
"type": "section",
|
||||||
|
"text": {
|
||||||
|
"type": "mrkdwn",
|
||||||
|
"text": f"<{approval_url}|Review and Approve>",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return [header_block, details_block, actions_block]
|
||||||
|
|
||||||
|
|
||||||
|
def _build_interactive_approval_blocks(
|
||||||
|
*,
|
||||||
|
thread_id: str,
|
||||||
|
action: str,
|
||||||
|
details: str,
|
||||||
|
buttons: list[dict],
|
||||||
|
) -> list[dict]:
|
||||||
|
"""Build Slack Block Kit blocks with interactive approval buttons.
|
||||||
|
|
||||||
|
The thread_id is embedded in each button's action_id and value so the
|
||||||
|
/slack/interactions endpoint can resume the correct graph thread when
|
||||||
|
a button is clicked.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
thread_id: Graph thread ID to embed in button payloads.
|
||||||
|
action: Human-readable action description for the header.
|
||||||
|
details: Contextual details shown in the message body.
|
||||||
|
buttons: List of dicts with "text" (str) and "value" (str) keys.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of Slack Block Kit block dicts.
|
||||||
|
"""
|
||||||
|
header_block = {
|
||||||
|
"type": "header",
|
||||||
|
"text": {
|
||||||
|
"type": "plain_text",
|
||||||
|
"text": f"Approval Required: {action}",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
details_block = {
|
||||||
|
"type": "section",
|
||||||
|
"text": {
|
||||||
|
"type": "mrkdwn",
|
||||||
|
"text": details,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
blocks: list[dict] = [header_block, details_block]
|
||||||
|
|
||||||
|
if buttons:
|
||||||
|
button_elements = [
|
||||||
|
{
|
||||||
|
"type": "button",
|
||||||
|
"text": {"type": "plain_text", "text": btn["text"]},
|
||||||
|
"value": f"{thread_id}:{btn['value']}",
|
||||||
|
"action_id": f"approval_{btn['value']}_{thread_id}",
|
||||||
|
}
|
||||||
|
for btn in buttons
|
||||||
|
]
|
||||||
|
actions_block = {
|
||||||
|
"type": "actions",
|
||||||
|
"elements": button_elements,
|
||||||
|
}
|
||||||
|
blocks.append(actions_block)
|
||||||
|
|
||||||
|
return blocks
|
||||||
|
|
||||||
|
|
||||||
|
def _build_ci_status_blocks(
|
||||||
|
*,
|
||||||
|
repo: str,
|
||||||
|
branch: str,
|
||||||
|
status: str,
|
||||||
|
build_url: str | None,
|
||||||
|
) -> list[dict]:
|
||||||
|
"""Build Slack Block Kit blocks for a CI build status notification.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
repo: Repository name.
|
||||||
|
branch: Branch that was built.
|
||||||
|
status: Build result status (e.g. "succeeded", "failed").
|
||||||
|
build_url: Direct URL to the build results page, or None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of Slack Block Kit block dicts.
|
||||||
|
"""
|
||||||
|
header_block = {
|
||||||
|
"type": "header",
|
||||||
|
"text": {
|
||||||
|
"type": "plain_text",
|
||||||
|
"text": f"CI Build: {repo}",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
detail_text = f"*Branch:* {branch}\n*Status:* {status}"
|
||||||
|
if build_url:
|
||||||
|
detail_text += f"\n<{build_url}|View Build>"
|
||||||
|
|
||||||
|
details_block = {
|
||||||
|
"type": "section",
|
||||||
|
"text": {
|
||||||
|
"type": "mrkdwn",
|
||||||
|
"text": detail_text,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return [header_block, details_block]
|
||||||
|
|
||||||
|
|
||||||
|
def _build_resolved_approval_blocks(
|
||||||
|
*,
|
||||||
|
action: str,
|
||||||
|
outcome: str,
|
||||||
|
user: str,
|
||||||
|
) -> list[dict]:
|
||||||
|
"""Build Slack Block Kit blocks for a resolved approval message.
|
||||||
|
|
||||||
|
Replaces the interactive approval message after the decision is made.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
action: The original action that was approved or rejected.
|
||||||
|
outcome: The approval outcome (e.g. "approved", "rejected").
|
||||||
|
user: The Slack user who made the decision.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of Slack Block Kit block dicts.
|
||||||
|
"""
|
||||||
|
header_block = {
|
||||||
|
"type": "header",
|
||||||
|
"text": {
|
||||||
|
"type": "plain_text",
|
||||||
|
"text": f"Approval {outcome.capitalize()}: {action}",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
details_block = {
|
||||||
|
"type": "section",
|
||||||
|
"text": {
|
||||||
|
"type": "mrkdwn",
|
||||||
|
"text": f"Decision: *{outcome}* by {user}",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return [header_block, details_block]
|
||||||
70
src/release_agent/versioning.py
Normal file
70
src/release_agent/versioning.py
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
"""Version calculation utilities for release management.
|
||||||
|
|
||||||
|
Pure functions only - no side effects, no mutation.
|
||||||
|
All version strings must include a 'v' prefix (e.g. "v1.2.3").
|
||||||
|
"""
|
||||||
|
|
||||||
|
import re
|
||||||
|
|
||||||
|
_VERSION_PATTERN = re.compile(r"^v(\d+)\.(\d+)\.(\d+)$")
|
||||||
|
|
||||||
|
|
||||||
|
def parse_version(version_str: str) -> tuple[int, int, int]:
    """Parse a version string into a (major, minor, patch) tuple.

    Accepts strings with or without a leading 'v' prefix.

    Raises:
        ValueError: If the string does not match the expected format.
    """
    # Normalize to the canonical 'v'-prefixed form before matching.
    candidate = version_str if version_str.startswith("v") else f"v{version_str}"
    parts = re.match(r"^v(\d+)\.(\d+)\.(\d+)$", candidate)
    if parts is None:
        raise ValueError(f"Invalid version string: {version_str!r}")
    major, minor, patch = parts.groups()
    return int(major), int(minor), int(patch)
|
||||||
|
|
||||||
|
|
||||||
|
def format_version(major: int, minor: int, patch: int) -> str:
    """Format a (major, minor, patch) tuple into a version string with 'v' prefix.

    Example: format_version(1, 0, 3) -> "v1.0.3"
    """
    return "v" + ".".join(str(component) for component in (major, minor, patch))
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_next_version(repo_name: str, existing_versions: list[str]) -> str:
    """Calculate the next patch version given a list of existing version strings.

    Filters out any malformed version strings — anything that does not match
    vMAJOR.MINOR.PATCH, including strings without the 'v' prefix (treated as
    malformed per spec). Finds the highest valid version and increments the
    patch component by one. If no valid versions exist, returns "v1.0.0".

    Args:
        repo_name: The repository name (accepted for extensibility; unused
            in the calculation).
        existing_versions: A list of version strings, may include malformed ones.

    Returns:
        The next version string in "vX.Y.Z" format.
    """
    pattern = re.compile(r"^v(\d+)\.(\d+)\.(\d+)$")
    # Keep only well-formed 'v'-prefixed versions, as comparable int tuples.
    parsed: list[tuple[int, int, int]] = [
        (int(m.group(1)), int(m.group(2)), int(m.group(3)))
        for m in (pattern.match(candidate) for candidate in existing_versions)
        if m is not None
    ]

    if not parsed:
        return "v1.0.0"

    # Tuple comparison orders versions correctly (major, then minor, then patch).
    major, minor, patch = max(parsed)
    return f"v{major}.{minor}.{patch + 1}"
|
||||||
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
0
tests/api/__init__.py
Normal file
0
tests/api/__init__.py
Normal file
259
tests/api/test_approvals.py
Normal file
259
tests/api/test_approvals.py
Normal file
@@ -0,0 +1,259 @@
|
|||||||
|
"""Tests for approvals endpoint. Written FIRST (TDD RED phase)."""
|
||||||
|
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
from release_agent.api.approvals import router as approvals_router
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _make_test_app(
    *,
    interrupted_threads: list[dict] | None = None,
    graph_resume_result: dict | None = None,
) -> FastAPI:
    """Return a FastAPI app with mocked state for approvals tests.

    Args:
        interrupted_threads: Dicts describing threads the mocked DB should
            report as interrupted; each must have "thread_id" and may carry
            "graph_name", "interrupt_value", "created_at", "repo_name",
            "pr_id", "version".
        graph_resume_result: Currently unused; kept for signature stability
            across test modules.

    Returns:
        A FastAPI app with approvals routes and fully mocked app.state.
    """
    app = FastAPI()
    app.include_router(approvals_router)

    if interrupted_threads is None:
        interrupted_threads = []

    # Empty operator token means auth is effectively disabled for these tests.
    mock_settings = MagicMock()
    mock_settings.operator_token.get_secret_value.return_value = ""
    mock_graphs = {
        "pr_completed": MagicMock(),
        "release": MagicMock(),
    }
    mock_clients = MagicMock()

    # Mock pool that returns interrupted threads from DB
    mock_pool = MagicMock()
    mock_conn = AsyncMock()
    mock_cursor = AsyncMock()

    # Row tuples mirror the column order the pending-approvals query selects.
    rows = [
        (
            t["thread_id"],
            t.get("graph_name", "pr_completed"),
            t.get("interrupt_value", "Confirm?"),
            t.get("created_at", datetime.now(tz=timezone.utc)),
            t.get("repo_name"),
            t.get("pr_id"),
            t.get("version"),
        )
        for t in interrupted_threads
    ]
    mock_cursor.fetchall = AsyncMock(return_value=rows)
    mock_cursor.fetchone = AsyncMock(return_value=("pr_completed",))
    # Wire async context-manager protocol by hand so `async with` works on mocks.
    mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor)
    mock_cursor.__aexit__ = AsyncMock(return_value=False)
    mock_conn.cursor = MagicMock(return_value=mock_cursor)
    mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
    mock_conn.__aexit__ = AsyncMock(return_value=False)
    mock_pool.connection = MagicMock(return_value=mock_conn)

    app.state.settings = mock_settings
    app.state.graphs = mock_graphs
    app.state.tool_clients = mock_clients
    app.state.db_pool = mock_pool
    app.state.background_tasks = set()

    return app
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# POST /approvals/{thread_id}
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestPostApproval:
    """Tests for POST /approvals/{thread_id}."""

    def test_valid_merge_decision_returns_200(self) -> None:
        """A valid 'merge' decision resumes the graph and returns 200 with metadata."""
        app = _make_test_app()
        mock_graph = MagicMock()
        mock_graph.ainvoke = AsyncMock(return_value={"messages": ["done"]})
        app.state.graphs["pr_completed"] = mock_graph

        # Patch the resume helper so no real graph execution happens.
        with patch("release_agent.api.approvals._resume_graph", new_callable=AsyncMock) as mock_resume:
            mock_resume.return_value = {"messages": ["resumed"]}
            with TestClient(app) as client:
                response = client.post(
                    "/approvals/thread-123",
                    json={"decision": "merge"},
                )

        assert response.status_code == 200
        data = response.json()
        assert data["thread_id"] == "thread-123"
        assert "status" in data
        assert "message" in data

    def test_valid_cancel_decision_returns_200(self) -> None:
        """A valid 'cancel' decision is accepted with 200."""
        app = _make_test_app()
        with patch("release_agent.api.approvals._resume_graph", new_callable=AsyncMock) as mock_resume:
            mock_resume.return_value = {"messages": ["cancelled"]}
            with TestClient(app) as client:
                response = client.post(
                    "/approvals/thread-456",
                    json={"decision": "cancel"},
                )
        assert response.status_code == 200

    def test_invalid_decision_returns_422(self) -> None:
        """An unknown decision value is rejected by request validation (422)."""
        app = _make_test_app()
        with TestClient(app) as client:
            response = client.post(
                "/approvals/thread-123",
                json={"decision": "invalid_decision"},
            )
        assert response.status_code == 422

    def test_missing_decision_returns_422(self) -> None:
        """A body without the required 'decision' field is rejected (422)."""
        app = _make_test_app()
        with TestClient(app) as client:
            response = client.post(
                "/approvals/thread-123",
                json={},
            )
        assert response.status_code == 422

    def test_response_contains_thread_id(self) -> None:
        """The response echoes the thread_id from the URL path."""
        app = _make_test_app()
        with patch("release_agent.api.approvals._resume_graph", new_callable=AsyncMock) as mock_resume:
            mock_resume.return_value = {}
            with TestClient(app) as client:
                response = client.post(
                    "/approvals/my-thread-id",
                    json={"decision": "approve"},
                )
            assert response.json()["thread_id"] == "my-thread-id"

    def test_approve_decision_returns_200(self) -> None:
        """The 'approve' decision is accepted with 200."""
        app = _make_test_app()
        with patch("release_agent.api.approvals._resume_graph", new_callable=AsyncMock) as mock_resume:
            mock_resume.return_value = {}
            with TestClient(app) as client:
                response = client.post(
                    "/approvals/t1",
                    json={"decision": "approve"},
                )
            assert response.status_code == 200

    def test_skip_decision_returns_200(self) -> None:
        """The 'skip' decision is accepted with 200."""
        app = _make_test_app()
        with patch("release_agent.api.approvals._resume_graph", new_callable=AsyncMock) as mock_resume:
            mock_resume.return_value = {}
            with TestClient(app) as client:
                response = client.post(
                    "/approvals/t1",
                    json={"decision": "skip"},
                )
            assert response.status_code == 200

    def test_trigger_decision_returns_200(self) -> None:
        """The 'trigger' decision is accepted with 200."""
        app = _make_test_app()
        with patch("release_agent.api.approvals._resume_graph", new_callable=AsyncMock) as mock_resume:
            mock_resume.return_value = {}
            with TestClient(app) as client:
                response = client.post(
                    "/approvals/t1",
                    json={"decision": "trigger"},
                )
            assert response.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# GET /approvals/pending
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestGetPendingApprovals:
    """Tests for GET /approvals/pending."""

    def test_empty_pending_returns_200(self) -> None:
        """With no interrupted threads, the endpoint returns count 0 and an empty list."""
        app = _make_test_app(interrupted_threads=[])
        with TestClient(app) as client:
            response = client.get("/approvals/pending")
        assert response.status_code == 200
        data = response.json()
        assert data["count"] == 0
        assert data["items"] == []

    def test_pending_approvals_list_structure(self) -> None:
        """Each pending item exposes thread_id and graph_name from the DB row."""
        now = datetime.now(tz=timezone.utc)
        threads = [
            {
                "thread_id": "t1",
                "graph_name": "pr_completed",
                "interrupt_value": "Confirm merge?",
                "created_at": now,
                "repo_name": "my-repo",
                "pr_id": "42",
                "version": "v1.0.0",
            }
        ]
        app = _make_test_app(interrupted_threads=threads)
        with TestClient(app) as client:
            response = client.get("/approvals/pending")
        assert response.status_code == 200
        data = response.json()
        assert data["count"] == 1
        assert data["items"][0]["thread_id"] == "t1"
        assert data["items"][0]["graph_name"] == "pr_completed"

    def test_multiple_pending_approvals(self) -> None:
        """All interrupted threads are returned and counted."""
        now = datetime.now(tz=timezone.utc)
        threads = [
            {
                "thread_id": f"t{i}",
                "graph_name": "pr_completed",
                "interrupt_value": "Confirm?",
                "created_at": now,
                "repo_name": None,
                "pr_id": None,
                "version": None,
            }
            for i in range(3)
        ]
        app = _make_test_app(interrupted_threads=threads)
        with TestClient(app) as client:
            response = client.get("/approvals/pending")
        assert response.status_code == 200
        data = response.json()
        assert data["count"] == 3
        assert len(data["items"]) == 3

    def test_pending_approval_optional_fields_nullable(self) -> None:
        """repo_name, pr_id and version may be null in the serialized items."""
        now = datetime.now(tz=timezone.utc)
        threads = [
            {
                "thread_id": "t1",
                "graph_name": "release",
                "interrupt_value": "Run release?",
                "created_at": now,
                "repo_name": None,
                "pr_id": None,
                "version": None,
            }
        ]
        app = _make_test_app(interrupted_threads=threads)
        with TestClient(app) as client:
            response = client.get("/approvals/pending")
        item = response.json()["items"][0]
        assert item["repo_name"] is None
        assert item["pr_id"] is None
        assert item["version"] is None
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _resume_graph helper function tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestResumeGraph:
    """Smoke checks for the _resume_graph helper."""

    def test_resume_graph_callable(self) -> None:
        # _resume_graph must be an async function so it can be awaited
        # (or scheduled) by the approvals endpoint.
        import inspect

        from release_agent.api.approvals import _resume_graph

        assert inspect.iscoroutinefunction(_resume_graph)
||||||
139
tests/api/test_approvals_with_auth.py
Normal file
139
tests/api/test_approvals_with_auth.py
Normal file
@@ -0,0 +1,139 @@
|
|||||||
|
"""Tests for approvals endpoints with operator token authentication.
|
||||||
|
|
||||||
|
Phase 5 - Step 3: Verifies that POST /approvals/{thread_id} and
|
||||||
|
GET /approvals/pending require operator token when configured.
|
||||||
|
Written FIRST (TDD RED phase).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
from release_agent.api.approvals import router as approvals_router
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _make_test_app(operator_token: str = "") -> FastAPI:
    """Build a FastAPI app wired with the approvals router and mocked state.

    An empty *operator_token* means token auth is disabled.
    """
    app = FastAPI()
    app.include_router(approvals_router)

    settings = MagicMock()
    settings.operator_token.get_secret_value.return_value = operator_token

    # Emulate psycopg's async context-manager chain:
    # pool.connection() -> conn, conn.cursor() -> cursor.
    cursor = AsyncMock()
    cursor.fetchall = AsyncMock(return_value=[])
    cursor.fetchone = AsyncMock(return_value=("pr_completed",))
    cursor.__aenter__ = AsyncMock(return_value=cursor)
    cursor.__aexit__ = AsyncMock(return_value=False)

    conn = AsyncMock()
    conn.cursor = MagicMock(return_value=cursor)
    conn.__aenter__ = AsyncMock(return_value=conn)
    conn.__aexit__ = AsyncMock(return_value=False)

    pool = MagicMock()
    pool.connection = MagicMock(return_value=conn)

    app.state.settings = settings
    app.state.graphs = {
        "pr_completed": MagicMock(),
        "release": MagicMock(),
    }
    app.state.tool_clients = MagicMock()
    app.state.db_pool = pool
    app.state.background_tasks = set()

    return app
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# POST /approvals/{thread_id} with auth
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestPostApprovalWithAuth:
    """Operator-token enforcement on POST /approvals/{thread_id}."""

    def test_valid_token_allows_post_approval(self) -> None:
        app = _make_test_app(operator_token="secret-token")
        with patch(
            "release_agent.api.approvals._resume_graph", new_callable=AsyncMock
        ) as mock_resume:
            mock_resume.return_value = {}
            with TestClient(app) as client:
                resp = client.post(
                    "/approvals/thread-123",
                    json={"decision": "merge"},
                    headers={"X-Operator-Token": "secret-token"},
                )
                assert resp.status_code == 200

    def test_missing_token_rejects_post_approval(self) -> None:
        app = _make_test_app(operator_token="secret-token")
        with TestClient(app) as client:
            # No X-Operator-Token header at all.
            resp = client.post(
                "/approvals/thread-123",
                json={"decision": "merge"},
            )
            assert resp.status_code == 401

    def test_wrong_token_rejects_post_approval(self) -> None:
        app = _make_test_app(operator_token="secret-token")
        with TestClient(app) as client:
            resp = client.post(
                "/approvals/thread-123",
                json={"decision": "merge"},
                headers={"X-Operator-Token": "wrong-token"},
            )
            assert resp.status_code == 401

    def test_no_auth_required_when_token_not_configured(self) -> None:
        # An empty configured token disables auth entirely.
        app = _make_test_app(operator_token="")
        with patch(
            "release_agent.api.approvals._resume_graph", new_callable=AsyncMock
        ) as mock_resume:
            mock_resume.return_value = {}
            with TestClient(app) as client:
                resp = client.post(
                    "/approvals/thread-123",
                    json={"decision": "merge"},
                )
                assert resp.status_code == 200
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# GET /approvals/pending with auth
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestGetPendingApprovalsWithAuth:
    """Operator-token enforcement on GET /approvals/pending."""

    def test_valid_token_allows_get_pending(self) -> None:
        app = _make_test_app(operator_token="secret-token")
        with TestClient(app) as client:
            resp = client.get(
                "/approvals/pending",
                headers={"X-Operator-Token": "secret-token"},
            )
            assert resp.status_code == 200

    def test_missing_token_rejects_get_pending(self) -> None:
        app = _make_test_app(operator_token="secret-token")
        with TestClient(app) as client:
            resp = client.get("/approvals/pending")
            assert resp.status_code == 401

    def test_wrong_token_rejects_get_pending(self) -> None:
        app = _make_test_app(operator_token="secret-token")
        with TestClient(app) as client:
            resp = client.get(
                "/approvals/pending",
                headers={"X-Operator-Token": "wrong"},
            )
            assert resp.status_code == 401

    def test_no_auth_required_when_token_not_configured(self) -> None:
        # An empty configured token disables auth entirely.
        app = _make_test_app(operator_token="")
        with TestClient(app) as client:
            resp = client.get("/approvals/pending")
            assert resp.status_code == 200
||||||
149
tests/api/test_dependencies.py
Normal file
149
tests/api/test_dependencies.py
Normal file
@@ -0,0 +1,149 @@
|
|||||||
|
"""Tests for API FastAPI dependencies. Written FIRST (TDD RED phase)."""
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
from release_agent.api.dependencies import (
|
||||||
|
get_db_pool,
|
||||||
|
get_graphs,
|
||||||
|
get_settings,
|
||||||
|
get_staging_store,
|
||||||
|
get_tool_clients,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _make_app_with_state(**state_kwargs) -> FastAPI:
    """Return a minimal FastAPI app with the given app.state attributes set."""
    app = FastAPI()
    for attr, val in state_kwargs.items():
        setattr(app.state, attr, val)
    return app
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# get_settings
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestGetSettings:
    """Tests for the get_settings dependency."""

    def test_returns_settings_from_state(self) -> None:
        mock_settings = MagicMock()
        app = _make_app_with_state(settings=mock_settings)

        # Call the dependency directly with a simulated request — no HTTP
        # round-trip is involved, so the previously-opened TestClient was
        # dead code and has been removed.
        request = MagicMock()
        request.app = app
        result = get_settings(request)
        assert result is mock_settings

    def test_raises_when_settings_missing(self) -> None:
        app = FastAPI()  # no state.settings

        request = MagicMock()
        request.app = app

        # Accessing the absent state attribute should surface as-is.
        with pytest.raises(AttributeError):
            get_settings(request)
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# get_graphs
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestGetGraphs:
    """Tests for the get_graphs dependency."""

    def test_returns_graphs_from_state(self) -> None:
        graphs = {"pr_completed": MagicMock(), "release": MagicMock()}
        app = _make_app_with_state(graphs=graphs)

        request = MagicMock()
        request.app = app
        assert get_graphs(request) is graphs

    def test_raises_when_graphs_missing(self) -> None:
        app = FastAPI()

        request = MagicMock()
        request.app = app

        with pytest.raises(AttributeError):
            get_graphs(request)
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# get_tool_clients
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestGetToolClients:
    """Tests for the get_tool_clients dependency."""

    def test_returns_tool_clients_from_state(self) -> None:
        clients = MagicMock()
        app = _make_app_with_state(tool_clients=clients)

        request = MagicMock()
        request.app = app
        assert get_tool_clients(request) is clients

    def test_raises_when_tool_clients_missing(self) -> None:
        app = FastAPI()

        request = MagicMock()
        request.app = app

        with pytest.raises(AttributeError):
            get_tool_clients(request)
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# get_staging_store
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestGetStagingStore:
    """Tests for the get_staging_store dependency."""

    def test_returns_staging_store_from_state(self) -> None:
        store = MagicMock()
        app = _make_app_with_state(staging_store=store)

        request = MagicMock()
        request.app = app
        assert get_staging_store(request) is store

    def test_raises_when_staging_store_missing(self) -> None:
        app = FastAPI()

        request = MagicMock()
        request.app = app

        with pytest.raises(AttributeError):
            get_staging_store(request)
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# get_db_pool
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestGetDbPool:
    """Tests for the get_db_pool dependency."""

    def test_returns_db_pool_from_state(self) -> None:
        pool = MagicMock()
        app = _make_app_with_state(db_pool=pool)

        request = MagicMock()
        request.app = app
        assert get_db_pool(request) is pool

    def test_raises_when_db_pool_missing(self) -> None:
        app = FastAPI()

        request = MagicMock()
        request.app = app

        with pytest.raises(AttributeError):
            get_db_pool(request)
||||||
446
tests/api/test_internals.py
Normal file
446
tests/api/test_internals.py
Normal file
@@ -0,0 +1,446 @@
|
|||||||
|
"""Tests for internal async helper functions.
|
||||||
|
|
||||||
|
Tests _run_graph, _upsert_thread, _resume_graph, and exception handlers.
|
||||||
|
Written FIRST then verified (TDD GREEN phase for internal helpers).
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, call, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _upsert_thread tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestUpsertThread:
    """Tests for the _upsert_thread SQL helper.

    The identical pool/conn/cursor mock scaffold was copy-pasted in every
    test; it is extracted into _mock_db() to remove the duplication.
    """

    @staticmethod
    def _mock_db():
        """Return (pool, cursor) mocks emulating psycopg's async context chain."""
        cursor = AsyncMock()
        cursor.execute = AsyncMock()
        cursor.__aenter__ = AsyncMock(return_value=cursor)
        cursor.__aexit__ = AsyncMock(return_value=False)
        conn = AsyncMock()
        conn.cursor = MagicMock(return_value=cursor)
        conn.__aenter__ = AsyncMock(return_value=conn)
        conn.__aexit__ = AsyncMock(return_value=False)
        pool = MagicMock()
        pool.connection = MagicMock(return_value=conn)
        return pool, cursor

    @pytest.mark.asyncio
    async def test_upsert_thread_executes_sql(self) -> None:
        from release_agent.api.webhooks import _upsert_thread

        pool, cursor = self._mock_db()

        await _upsert_thread(
            pool,
            thread_id="t1",
            thread_status="running",
            state={"repo_name": "my-repo"},
        )

        cursor.execute.assert_called_once()
        args = cursor.execute.call_args[0]
        assert "agent_threads" in args[0]
        assert args[1][0] == "t1"
        assert args[1][4] == "running"
        # state is JSON-encoded
        state_json = json.loads(args[1][5])
        assert state_json["repo_name"] == "my-repo"

    @pytest.mark.asyncio
    async def test_upsert_thread_completed_status(self) -> None:
        from release_agent.api.webhooks import _upsert_thread

        pool, cursor = self._mock_db()

        await _upsert_thread(
            pool,
            thread_id="t2",
            thread_status="completed",
            state={},
        )

        args = cursor.execute.call_args[0]
        assert args[1][4] == "completed"

    @pytest.mark.asyncio
    async def test_upsert_thread_failed_status(self) -> None:
        from release_agent.api.webhooks import _upsert_thread

        pool, cursor = self._mock_db()

        await _upsert_thread(
            pool,
            thread_id="t3",
            thread_status="failed",
            state={"errors": ["something went wrong"]},
        )

        args = cursor.execute.call_args[0]
        assert args[1][4] == "failed"
        state_json = json.loads(args[1][5])
        assert state_json["errors"] == ["something went wrong"]
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _run_graph tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestRunGraph:
    """Tests for the _run_graph background runner.

    The repeated pool/conn/cursor mock scaffold is factored into _mock_db().
    """

    @staticmethod
    def _mock_db():
        """Return (pool, cursor) mocks emulating psycopg's async context chain."""
        cursor = AsyncMock()
        cursor.execute = AsyncMock()
        cursor.__aenter__ = AsyncMock(return_value=cursor)
        cursor.__aexit__ = AsyncMock(return_value=False)
        conn = AsyncMock()
        conn.cursor = MagicMock(return_value=cursor)
        conn.__aenter__ = AsyncMock(return_value=conn)
        conn.__aexit__ = AsyncMock(return_value=False)
        pool = MagicMock()
        pool.connection = MagicMock(return_value=conn)
        return pool, cursor

    @pytest.mark.asyncio
    async def test_run_graph_success_upserts_completed(self) -> None:
        from release_agent.api.webhooks import _run_graph

        mock_graph = AsyncMock()
        mock_graph.ainvoke = AsyncMock(return_value={"messages": ["done"]})
        pool, cursor = self._mock_db()

        await _run_graph(
            graph=mock_graph,
            initial_state={"repo_name": "test"},
            thread_id="t1",
            tool_clients=MagicMock(),
            db_pool=pool,
        )

        # Should have been called with "running" then "completed"
        calls = cursor.execute.call_args_list
        assert len(calls) == 2
        assert calls[0][0][1][4] == "running"
        assert calls[1][0][1][4] == "completed"

    @pytest.mark.asyncio
    async def test_run_graph_failure_upserts_failed(self) -> None:
        from release_agent.api.webhooks import _run_graph

        mock_graph = AsyncMock()
        mock_graph.ainvoke = AsyncMock(side_effect=RuntimeError("graph crashed"))
        pool, cursor = self._mock_db()

        await _run_graph(
            graph=mock_graph,
            initial_state={"repo_name": "test"},
            thread_id="t-fail",
            tool_clients=MagicMock(),
            db_pool=pool,
        )

        calls = cursor.execute.call_args_list
        # First call: "running", second call: "failed"
        assert calls[0][0][1][4] == "running"
        assert calls[1][0][1][4] == "failed"
        # State should contain errors
        failed_state = json.loads(calls[1][0][1][5])
        assert "errors" in failed_state

    @pytest.mark.asyncio
    async def test_run_graph_invokes_with_correct_config(self) -> None:
        from release_agent.api.webhooks import _run_graph

        mock_graph = AsyncMock()
        mock_graph.ainvoke = AsyncMock(return_value={})
        mock_clients = MagicMock()
        pool, _cursor = self._mock_db()

        await _run_graph(
            graph=mock_graph,
            initial_state={"repo_name": "test"},
            thread_id="t-config",
            tool_clients=mock_clients,
            db_pool=pool,
        )

        call_args = mock_graph.ainvoke.call_args
        config = call_args[1]["config"]
        assert config["configurable"]["thread_id"] == "t-config"
        assert config["configurable"]["clients"] is mock_clients
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _resume_graph tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestResumeGraphInternal:
    """Tests for the _resume_graph helper.

    The repeated pool/conn/cursor mock scaffold is factored into _mock_db().
    """

    @staticmethod
    def _mock_db():
        """Return (pool, cursor) mocks emulating psycopg's async context chain."""
        cursor = AsyncMock()
        cursor.execute = AsyncMock()
        cursor.__aenter__ = AsyncMock(return_value=cursor)
        cursor.__aexit__ = AsyncMock(return_value=False)
        conn = AsyncMock()
        conn.cursor = MagicMock(return_value=cursor)
        conn.__aenter__ = AsyncMock(return_value=conn)
        conn.__aexit__ = AsyncMock(return_value=False)
        pool = MagicMock()
        pool.connection = MagicMock(return_value=conn)
        return pool, cursor

    @pytest.mark.asyncio
    async def test_resume_graph_success(self) -> None:
        from release_agent.api.approvals import _resume_graph

        mock_graph = AsyncMock()
        mock_graph.ainvoke = AsyncMock(return_value={"result": "ok"})
        pool, _cursor = self._mock_db()

        result = await _resume_graph(
            graph=mock_graph,
            thread_id="t1",
            decision="merge",
            tool_clients=MagicMock(),
            db_pool=pool,
        )

        assert result == {"result": "ok"}
        mock_graph.ainvoke.assert_called_once()
        # Verify the decision was passed
        call_args = mock_graph.ainvoke.call_args
        assert call_args[0][0]["decision"] == "merge"

    @pytest.mark.asyncio
    async def test_resume_graph_failure_re_raises(self) -> None:
        from release_agent.api.approvals import _resume_graph

        mock_graph = AsyncMock()
        mock_graph.ainvoke = AsyncMock(side_effect=RuntimeError("resume failed"))
        pool, _cursor = self._mock_db()

        with pytest.raises(RuntimeError, match="resume failed"):
            await _resume_graph(
                graph=mock_graph,
                thread_id="t1",
                decision="cancel",
                tool_clients=MagicMock(),
                db_pool=pool,
            )

    @pytest.mark.asyncio
    async def test_resume_graph_upserts_completed_on_success(self) -> None:
        from release_agent.api.approvals import _resume_graph

        mock_graph = AsyncMock()
        mock_graph.ainvoke = AsyncMock(return_value={"messages": ["done"]})
        pool, cursor = self._mock_db()

        await _resume_graph(
            graph=mock_graph,
            thread_id="t-success",
            decision="approve",
            tool_clients=MagicMock(),
            db_pool=pool,
        )

        # The last execute call should be "completed"
        last_call = cursor.execute.call_args_list[-1]
        assert last_call[0][1][4] == "completed"

    @pytest.mark.asyncio
    async def test_resume_graph_upserts_failed_on_exception(self) -> None:
        from release_agent.api.approvals import _resume_graph

        mock_graph = AsyncMock()
        mock_graph.ainvoke = AsyncMock(side_effect=ValueError("bad"))
        pool, cursor = self._mock_db()

        with pytest.raises(ValueError):
            await _resume_graph(
                graph=mock_graph,
                thread_id="t-fail",
                decision="skip",
                tool_clients=MagicMock(),
                db_pool=pool,
            )

        last_call = cursor.execute.call_args_list[-1]
        assert last_call[0][1][4] == "failed"
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# run_graph_in_background tests (main.py)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestRunGraphInBackground:
    """Tests for run_graph_in_background in main.py.

    The repeated pool/conn/cursor mock scaffold is factored into _mock_db().
    """

    @staticmethod
    def _mock_db():
        """Return (pool, cursor) mocks emulating psycopg's async context chain."""
        cursor = AsyncMock()
        cursor.execute = AsyncMock()
        cursor.__aenter__ = AsyncMock(return_value=cursor)
        cursor.__aexit__ = AsyncMock(return_value=False)
        conn = AsyncMock()
        conn.cursor = MagicMock(return_value=cursor)
        conn.__aenter__ = AsyncMock(return_value=conn)
        conn.__aexit__ = AsyncMock(return_value=False)
        pool = MagicMock()
        pool.connection = MagicMock(return_value=conn)
        return pool, cursor

    @pytest.mark.asyncio
    async def test_success_with_db_pool(self) -> None:
        from release_agent.main import run_graph_in_background

        mock_graph = AsyncMock()
        mock_graph.ainvoke = AsyncMock(return_value={"done": True})
        pool, cursor = self._mock_db()

        await run_graph_in_background(
            graph=mock_graph,
            initial_state={"repo_name": "test"},
            thread_id="t-bg",
            db_pool=pool,
        )

        calls = cursor.execute.call_args_list
        assert calls[0][0][1][4] == "running"
        assert calls[1][0][1][4] == "completed"

    @pytest.mark.asyncio
    async def test_failure_with_db_pool(self) -> None:
        from release_agent.main import run_graph_in_background

        mock_graph = AsyncMock()
        mock_graph.ainvoke = AsyncMock(side_effect=RuntimeError("bg failed"))
        pool, cursor = self._mock_db()

        await run_graph_in_background(
            graph=mock_graph,
            initial_state={},
            thread_id="t-bg-fail",
            db_pool=pool,
        )

        last_call = cursor.execute.call_args_list[-1]
        assert last_call[0][1][4] == "failed"

    @pytest.mark.asyncio
    async def test_success_without_db_pool(self) -> None:
        """run_graph_in_background works even without a db_pool."""
        from release_agent.main import run_graph_in_background

        mock_graph = AsyncMock()
        mock_graph.ainvoke = AsyncMock(return_value={})

        # Should not raise even with no db_pool
        await run_graph_in_background(
            graph=mock_graph,
            initial_state={},
            thread_id="t-no-pool",
            db_pool=None,
        )

        mock_graph.ainvoke.assert_called_once()
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Exception handler tests (main.py)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestExceptionHandlerFunctions:
    """Tests for the FastAPI exception handlers defined in main.py."""

    @pytest.mark.asyncio
    async def test_release_agent_error_handler_returns_500(self) -> None:
        from release_agent.exceptions import ServiceError
        from release_agent.main import _release_agent_error_handler

        request = MagicMock()
        exc = ServiceError(service="azdo", status_code=503, detail="unavailable")

        response = await _release_agent_error_handler(request, exc)

        assert response.status_code == 500
        payload = json.loads(response.body)
        assert payload["error"] == "ServiceError"
        assert "unavailable" in payload["detail"]

    @pytest.mark.asyncio
    async def test_generic_error_handler_returns_500(self) -> None:
        from release_agent.main import _generic_error_handler

        request = MagicMock()
        exc = ValueError("something generic")

        response = await _generic_error_handler(request, exc)

        assert response.status_code == 500
        payload = json.loads(response.body)
        assert payload["error"] == "InternalServerError"
        assert "An unexpected error occurred" in payload["detail"]
||||||
294
tests/api/test_models.py
Normal file
294
tests/api/test_models.py
Normal file
@@ -0,0 +1,294 @@
|
|||||||
|
"""Tests for API request/response models. Written FIRST (TDD RED phase)."""
|
||||||
|
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
from release_agent.api.models import (
|
||||||
|
ApprovalDecision,
|
||||||
|
ApprovalResponse,
|
||||||
|
ErrorResponse,
|
||||||
|
HealthResponse,
|
||||||
|
ManualReleaseRequest,
|
||||||
|
ManualTriggerResponse,
|
||||||
|
PendingApproval,
|
||||||
|
PendingApprovalsResponse,
|
||||||
|
ReleaseVersionListResponse,
|
||||||
|
StagingResponse,
|
||||||
|
WebhookResponse,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# WebhookResponse
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestWebhookResponse:
    """WebhookResponse: required fields, round-trip, and immutability."""

    def test_valid_construction(self) -> None:
        model = WebhookResponse(thread_id="thread-123", message="scheduled")
        assert (model.thread_id, model.message) == ("thread-123", "scheduled")

    def test_frozen_immutable(self) -> None:
        model = WebhookResponse(thread_id="t1", message="ok")
        with pytest.raises((TypeError, ValidationError)):
            model.thread_id = "other"  # type: ignore[misc]

    def test_missing_thread_id_raises(self) -> None:
        with pytest.raises(ValidationError):
            WebhookResponse(message="ok")  # type: ignore[call-arg]

    def test_missing_message_raises(self) -> None:
        with pytest.raises(ValidationError):
            WebhookResponse(thread_id="t1")  # type: ignore[call-arg]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# ApprovalDecision
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestApprovalDecision:
    """ApprovalDecision: accepted decision literals and immutability.

    The five accepted literals were previously five copy-pasted test
    methods differing only in one string; they are parametrized here.
    """

    @pytest.mark.parametrize(
        "decision",
        ["merge", "cancel", "approve", "skip", "trigger"],
    )
    def test_valid_decisions(self, decision: str) -> None:
        # Every literal accepted by the model round-trips unchanged.
        assert ApprovalDecision(decision=decision).decision == decision

    def test_invalid_decision_raises(self) -> None:
        with pytest.raises(ValidationError):
            ApprovalDecision(decision="invalid")  # type: ignore[arg-type]

    def test_frozen_immutable(self) -> None:
        d = ApprovalDecision(decision="merge")
        with pytest.raises((TypeError, ValidationError)):
            d.decision = "cancel"  # type: ignore[misc]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# ApprovalResponse
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestApprovalResponse:
    """ApprovalResponse: field round-trip and immutability."""

    def test_valid_construction(self) -> None:
        model = ApprovalResponse(
            thread_id="t1", status="resumed", message="Graph resumed"
        )
        assert (model.thread_id, model.status, model.message) == (
            "t1",
            "resumed",
            "Graph resumed",
        )

    def test_frozen_immutable(self) -> None:
        model = ApprovalResponse(thread_id="t1", status="ok", message="m")
        with pytest.raises((TypeError, ValidationError)):
            model.status = "bad"  # type: ignore[misc]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# PendingApproval
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestPendingApproval:
    """PendingApproval: required vs. optional fields and immutability."""

    def test_full_construction(self) -> None:
        created = datetime.now(tz=timezone.utc)
        model = PendingApproval(
            thread_id="t1",
            graph_name="pr_completed",
            interrupt_value="Confirm merge?",
            created_at=created,
            repo_name="my-repo",
            pr_id="42",
            version="v1.2.3",
        )
        assert model.thread_id == "t1"
        assert model.graph_name == "pr_completed"
        assert model.repo_name == "my-repo"
        assert model.pr_id == "42"
        assert model.version == "v1.2.3"

    def test_optional_fields_none(self) -> None:
        model = PendingApproval(
            thread_id="t1",
            graph_name="release",
            interrupt_value="Confirm?",
            created_at=datetime.now(tz=timezone.utc),
        )
        # repo_name / pr_id / version default to None when omitted.
        assert (model.repo_name, model.pr_id, model.version) == (None, None, None)

    def test_frozen_immutable(self) -> None:
        model = PendingApproval(
            thread_id="t1",
            graph_name="g",
            interrupt_value="v",
            created_at=datetime.now(tz=timezone.utc),
        )
        with pytest.raises((TypeError, ValidationError)):
            model.thread_id = "other"  # type: ignore[misc]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# PendingApprovalsResponse
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestPendingApprovalsResponse:
    """PendingApprovalsResponse: container semantics and immutability."""

    def test_empty_list(self) -> None:
        model = PendingApprovalsResponse(items=[], count=0)
        assert model.items == []
        assert model.count == 0

    def test_with_items(self) -> None:
        single = PendingApproval(
            thread_id="t1",
            graph_name="g",
            interrupt_value="v",
            created_at=datetime.now(tz=timezone.utc),
        )
        model = PendingApprovalsResponse(items=[single], count=1)
        assert model.count == 1
        assert len(model.items) == 1

    def test_frozen_immutable(self) -> None:
        model = PendingApprovalsResponse(items=[], count=0)
        with pytest.raises((TypeError, ValidationError)):
            model.count = 5  # type: ignore[misc]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# HealthResponse
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestHealthResponse:
    """HealthResponse: status literal validation and immutability."""

    def test_ok_status(self) -> None:
        model = HealthResponse(status="ok", version="0.1.0", uptime_seconds=123.4)
        assert model.status == "ok"
        assert model.version == "0.1.0"
        assert model.uptime_seconds == pytest.approx(123.4)

    def test_degraded_status(self) -> None:
        model = HealthResponse(status="degraded", version="0.1.0", uptime_seconds=0.0)
        assert model.status == "degraded"

    def test_invalid_status_raises(self) -> None:
        with pytest.raises(ValidationError):
            HealthResponse(status="unknown", version="0.1.0", uptime_seconds=0.0)  # type: ignore[arg-type]

    def test_frozen_immutable(self) -> None:
        model = HealthResponse(status="ok", version="0.1.0", uptime_seconds=1.0)
        with pytest.raises((TypeError, ValidationError)):
            model.status = "degraded"  # type: ignore[misc]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# ReleaseVersionListResponse
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestReleaseVersionListResponse:
    """ReleaseVersionListResponse: versions list round-trip and immutability."""

    def test_valid_construction(self) -> None:
        model = ReleaseVersionListResponse(repo="my-repo", versions=["v1.0.0", "v1.1.0"])
        assert model.repo == "my-repo"
        assert model.versions == ["v1.0.0", "v1.1.0"]

    def test_empty_versions(self) -> None:
        model = ReleaseVersionListResponse(repo="r", versions=[])
        assert model.versions == []

    def test_frozen_immutable(self) -> None:
        model = ReleaseVersionListResponse(repo="r", versions=[])
        with pytest.raises((TypeError, ValidationError)):
            model.repo = "other"  # type: ignore[misc]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# StagingResponse
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestStagingResponse:
    """StagingResponse: optional staging payload and immutability."""

    def test_with_staging(self) -> None:
        payload = {"version": "v1.0.0", "repo": "my-repo", "tickets": []}
        model = StagingResponse(repo="my-repo", staging=payload)
        assert model.repo == "my-repo"
        assert model.staging is not None
        assert model.staging["version"] == "v1.0.0"

    def test_without_staging(self) -> None:
        model = StagingResponse(repo="my-repo", staging=None)
        assert model.staging is None

    def test_frozen_immutable(self) -> None:
        model = StagingResponse(repo="r", staging=None)
        with pytest.raises((TypeError, ValidationError)):
            model.repo = "other"  # type: ignore[misc]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# ManualTriggerResponse
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestManualTriggerResponse:
    """ManualTriggerResponse: field round-trip and immutability."""

    def test_valid_construction(self) -> None:
        model = ManualTriggerResponse(thread_id="t1", message="triggered")
        assert (model.thread_id, model.message) == ("t1", "triggered")

    def test_frozen_immutable(self) -> None:
        model = ManualTriggerResponse(thread_id="t1", message="m")
        with pytest.raises((TypeError, ValidationError)):
            model.thread_id = "other"  # type: ignore[misc]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# ManualReleaseRequest
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestManualReleaseRequest:
    """ManualReleaseRequest: required repo field and immutability."""

    def test_valid_construction(self) -> None:
        model = ManualReleaseRequest(repo="my-repo")
        assert model.repo == "my-repo"

    def test_missing_repo_raises(self) -> None:
        with pytest.raises(ValidationError):
            ManualReleaseRequest()  # type: ignore[call-arg]

    def test_frozen_immutable(self) -> None:
        model = ManualReleaseRequest(repo="r")
        with pytest.raises((TypeError, ValidationError)):
            model.repo = "other"  # type: ignore[misc]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# ErrorResponse
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestErrorResponse:
    """ErrorResponse: optional detail field and immutability."""

    def test_error_only(self) -> None:
        model = ErrorResponse(error="Something went wrong")
        assert model.error == "Something went wrong"
        assert model.detail is None

    def test_error_with_detail(self) -> None:
        model = ErrorResponse(error="Not found", detail="Thread t1 not found")
        assert model.detail == "Thread t1 not found"

    def test_frozen_immutable(self) -> None:
        model = ErrorResponse(error="e")
        with pytest.raises((TypeError, ValidationError)):
            model.error = "other"  # type: ignore[misc]
|
||||||
111
tests/api/test_operator_auth.py
Normal file
111
tests/api/test_operator_auth.py
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
"""Tests for operator token authentication dependency.
|
||||||
|
|
||||||
|
Phase 5 - Step 3: require_operator_token FastAPI dependency.
|
||||||
|
Written FIRST (TDD RED phase).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import FastAPI, Depends, HTTPException
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
from release_agent.api.dependencies import require_operator_token
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _make_app_with_token(operator_token: str = "") -> FastAPI:
    """Return a minimal app with a protected route and the given token config."""
    settings = MagicMock()
    settings.operator_token.get_secret_value.return_value = operator_token

    app = FastAPI()
    app.state.settings = settings

    @app.get("/protected")
    async def protected_route(_: None = Depends(require_operator_token)):
        return {"ok": True}

    return app
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# require_operator_token tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestRequireOperatorToken:
    """Behaviour of the X-Operator-Token header dependency."""

    @staticmethod
    def _get(token_config: str, headers: dict[str, str] | None = None):
        """GET /protected against an app configured with *token_config*."""
        with TestClient(_make_app_with_token(token_config)) as client:
            return client.get("/protected", headers=headers)

    def test_valid_token_allows_access(self) -> None:
        resp = self._get("super-secret-token", {"X-Operator-Token": "super-secret-token"})
        assert resp.status_code == 200

    def test_missing_token_header_returns_401_when_token_configured(self) -> None:
        assert self._get("super-secret-token").status_code == 401

    def test_wrong_token_returns_401(self) -> None:
        resp = self._get("super-secret-token", {"X-Operator-Token": "wrong-token"})
        assert resp.status_code == 401

    def test_empty_operator_token_config_skips_auth(self) -> None:
        """When operator_token is not configured (empty), all requests pass."""
        assert self._get("").status_code == 200

    def test_empty_operator_token_config_passes_even_without_header(self) -> None:
        assert self._get("", {}).status_code == 200

    def test_token_comparison_is_constant_time(self) -> None:
        """Verify hmac.compare_digest is used (not == operator) — tested by checking
        that the function still works correctly, not timing (which we can't test here)."""
        resp = self._get("my-token", {"X-Operator-Token": "my-token"})
        assert resp.status_code == 200

    def test_empty_string_token_header_rejected_when_token_configured(self) -> None:
        resp = self._get("configured-token", {"X-Operator-Token": ""})
        assert resp.status_code == 401

    def test_401_response_has_detail_field(self) -> None:
        assert "detail" in self._get("secret").json()

    def test_valid_token_returns_correct_response_body(self) -> None:
        resp = self._get("token123", {"X-Operator-Token": "token123"})
        assert resp.json() == {"ok": True}
|
||||||
473
tests/api/test_slack_interactions.py
Normal file
473
tests/api/test_slack_interactions.py
Normal file
@@ -0,0 +1,473 @@
|
|||||||
|
"""Tests for api/slack_interactions.py endpoint.
|
||||||
|
|
||||||
|
Written FIRST (TDD RED phase).
|
||||||
|
|
||||||
|
Tests cover:
|
||||||
|
- Signature verification (HMAC-SHA256)
|
||||||
|
- Payload parsing
|
||||||
|
- Button routing
|
||||||
|
- _resume_graph invocation
|
||||||
|
- Error handling
|
||||||
|
"""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import hmac
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import urllib.parse
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
from release_agent.api.slack_interactions import router as slack_interactions_router
|
||||||
|
from release_agent.api.slack_interactions import _verify_slack_signature
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
# Fake Slack signing secret shared by all signature helpers/tests below.
_TEST_SIGNING_SECRET = "test-signing-secret-abc"
|
||||||
|
|
||||||
|
|
||||||
|
def _make_slack_signature(*, signing_secret: str, timestamp: str, body: str) -> str:
|
||||||
|
"""Compute a valid Slack signing signature."""
|
||||||
|
base_string = f"v0:{timestamp}:{body}"
|
||||||
|
sig = hmac.new(
|
||||||
|
signing_secret.encode(),
|
||||||
|
base_string.encode(),
|
||||||
|
hashlib.sha256,
|
||||||
|
).hexdigest()
|
||||||
|
return f"v0={sig}"
|
||||||
|
|
||||||
|
|
||||||
|
def _make_test_app(
    *,
    signing_secret: str = _TEST_SIGNING_SECRET,
    thread_graph_name: str | None = "release",
    graph_result: dict | None = None,
) -> FastAPI:
    """Return a FastAPI test app with mocked state for slack interactions.

    Args:
        signing_secret: Value the endpoint reads to verify Slack signatures.
        thread_graph_name: Row returned by the thread lookup (``fetchone``);
            ``None`` simulates an unknown thread (no row found).
        graph_result: Value the mocked graph's ``ainvoke`` resolves to.
    """
    app = FastAPI()
    app.include_router(slack_interactions_router)

    # Settings stub: only the two secrets the endpoint consults are configured.
    mock_settings = MagicMock()
    mock_settings.slack_signing_secret.get_secret_value.return_value = signing_secret
    mock_settings.operator_token.get_secret_value.return_value = ""

    # Graph registry stub: "release" has a real AsyncMock so resume calls
    # can be awaited and asserted on; "pr_completed" is a plain placeholder.
    mock_graph = MagicMock()
    mock_graph.ainvoke = AsyncMock(return_value=graph_result or {"messages": ["done"]})

    mock_graphs = {
        "release": mock_graph,
        "pr_completed": MagicMock(),
    }
    mock_clients = MagicMock()

    # DB pool stub: emulates async-context-manager usage — presumably
    # `async with pool.connection() as conn` then `async with conn.cursor()`
    # followed by a single fetchone() — TODO confirm against the endpoint.
    mock_pool = MagicMock()
    mock_conn = AsyncMock()
    mock_cursor = AsyncMock()
    mock_cursor.fetchone = AsyncMock(
        return_value=(thread_graph_name,) if thread_graph_name else None
    )
    mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor)
    mock_cursor.__aexit__ = AsyncMock(return_value=False)
    mock_conn.cursor = MagicMock(return_value=mock_cursor)
    mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
    mock_conn.__aexit__ = AsyncMock(return_value=False)
    mock_pool.connection = MagicMock(return_value=mock_conn)

    # Wire everything onto app.state, where the endpoint reads it.
    app.state.settings = mock_settings
    app.state.graphs = mock_graphs
    app.state.tool_clients = mock_clients
    app.state.db_pool = mock_pool
    app.state.background_tasks = set()

    return app
|
||||||
|
|
||||||
|
|
||||||
|
def _make_button_payload(
|
||||||
|
*,
|
||||||
|
thread_id: str = "test-thread-123",
|
||||||
|
value: str = "approve",
|
||||||
|
user_id: str = "U12345",
|
||||||
|
user_name: str = "alice",
|
||||||
|
) -> str:
|
||||||
|
"""Build a URL-encoded Slack button action payload."""
|
||||||
|
payload = {
|
||||||
|
"type": "block_actions",
|
||||||
|
"user": {"id": user_id, "name": user_name},
|
||||||
|
"actions": [
|
||||||
|
{
|
||||||
|
"type": "button",
|
||||||
|
"value": f"{thread_id}:{value}",
|
||||||
|
"action_id": f"approval_{value}_{thread_id}",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
return urllib.parse.urlencode({"payload": json.dumps(payload)})
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _verify_slack_signature pure function tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestVerifySlackSignature:
    """Tests for the _verify_slack_signature pure function."""

    @staticmethod
    def _digest(secret: str, timestamp: str, body: str) -> str:
        """Raw HMAC-SHA256 hex digest over the v0 base string."""
        base = f"v0:{timestamp}:{body}"
        return hmac.new(secret.encode(), base.encode(), hashlib.sha256).hexdigest()

    def test_returns_true_for_valid_signature(self) -> None:
        ts = str(int(time.time()))
        payload = "test=body&data=here"
        good = _make_slack_signature(
            signing_secret=_TEST_SIGNING_SECRET, timestamp=ts, body=payload
        )
        assert _verify_slack_signature(
            signing_secret=_TEST_SIGNING_SECRET,
            timestamp=ts,
            body=payload,
            signature=good,
        ) is True

    def test_returns_false_for_wrong_secret(self) -> None:
        ts = str(int(time.time()))
        payload = "test=body"
        forged = _make_slack_signature(
            signing_secret="wrong-secret", timestamp=ts, body=payload
        )
        assert _verify_slack_signature(
            signing_secret=_TEST_SIGNING_SECRET,
            timestamp=ts,
            body=payload,
            signature=forged,
        ) is False

    def test_returns_false_for_tampered_body(self) -> None:
        ts = str(int(time.time()))
        signed_over = _make_slack_signature(
            signing_secret=_TEST_SIGNING_SECRET, timestamp=ts, body="original=body"
        )
        assert _verify_slack_signature(
            signing_secret=_TEST_SIGNING_SECRET,
            timestamp=ts,
            body="tampered=body",
            signature=signed_over,
        ) is False

    def test_returns_false_for_wrong_timestamp(self) -> None:
        payload = "test=body"
        sig_for_other_ts = _make_slack_signature(
            signing_secret=_TEST_SIGNING_SECRET, timestamp="1000000", body=payload
        )
        assert _verify_slack_signature(
            signing_secret=_TEST_SIGNING_SECRET,
            timestamp="9999999",
            body=payload,
            signature=sig_for_other_ts,
        ) is False

    def test_returns_false_for_malformed_signature(self) -> None:
        assert _verify_slack_signature(
            signing_secret=_TEST_SIGNING_SECRET,
            timestamp=str(int(time.time())),
            body="body",
            signature="not-a-valid-sig",
        ) is False

    def test_returns_false_for_empty_signature(self) -> None:
        assert _verify_slack_signature(
            signing_secret=_TEST_SIGNING_SECRET,
            timestamp=str(int(time.time())),
            body="body",
            signature="",
        ) is False

    def test_uses_hmac_sha256(self) -> None:
        ts = "1234567890"
        payload = "payload=data"
        sig = f"v0={self._digest(_TEST_SIGNING_SECRET, ts, payload)}"

        # Inject current_time matching timestamp to bypass replay prevention
        assert _verify_slack_signature(
            signing_secret=_TEST_SIGNING_SECRET,
            timestamp=ts,
            body=payload,
            signature=sig,
            current_time=1234567890.0,
        ) is True

    def test_rejects_stale_timestamp(self) -> None:
        old_ts = "1000000000"  # year 2001
        payload = "payload=data"
        sig = f"v0={self._digest(_TEST_SIGNING_SECRET, old_ts, payload)}"

        # Signature itself is valid, but the timestamp is far in the past.
        assert _verify_slack_signature(
            signing_secret=_TEST_SIGNING_SECRET,
            timestamp=old_ts,
            body=payload,
            signature=sig,
        ) is False

    def test_rejects_non_integer_timestamp(self) -> None:
        assert _verify_slack_signature(
            signing_secret=_TEST_SIGNING_SECRET,
            timestamp="not-a-number",
            body="body",
            signature="v0=abc",
        ) is False

    def test_signature_prefix_must_be_v0(self) -> None:
        ts = "1234567890"
        payload = "payload=data"
        wrong_prefix_sig = f"v1={self._digest(_TEST_SIGNING_SECRET, ts, payload)}"

        assert _verify_slack_signature(
            signing_secret=_TEST_SIGNING_SECRET,
            timestamp=ts,
            body=payload,
            signature=wrong_prefix_sig,
        ) is False
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# POST /slack/interactions endpoint tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestSlackInteractionsEndpoint:
|
||||||
|
"""Tests for POST /slack/interactions."""
|
||||||
|
|
||||||
|
def test_returns_200_for_valid_request(self) -> None:
|
||||||
|
app = _make_test_app()
|
||||||
|
client = TestClient(app, raise_server_exceptions=False)
|
||||||
|
|
||||||
|
timestamp = str(int(time.time()))
|
||||||
|
body = _make_button_payload(thread_id="t-abc", value="approve")
|
||||||
|
sig = _make_slack_signature(
|
||||||
|
signing_secret=_TEST_SIGNING_SECRET,
|
||||||
|
timestamp=timestamp,
|
||||||
|
body=body,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/slack/interactions",
|
||||||
|
content=body,
|
||||||
|
headers={
|
||||||
|
"Content-Type": "application/x-www-form-urlencoded",
|
||||||
|
"X-Slack-Request-Timestamp": timestamp,
|
||||||
|
"X-Slack-Signature": sig,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
def test_returns_403_for_invalid_signature(self) -> None:
|
||||||
|
app = _make_test_app()
|
||||||
|
client = TestClient(app, raise_server_exceptions=False)
|
||||||
|
|
||||||
|
timestamp = str(int(time.time()))
|
||||||
|
body = _make_button_payload()
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/slack/interactions",
|
||||||
|
content=body,
|
||||||
|
headers={
|
||||||
|
"Content-Type": "application/x-www-form-urlencoded",
|
||||||
|
"X-Slack-Request-Timestamp": timestamp,
|
||||||
|
"X-Slack-Signature": "v0=invalid_signature",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 403
|
||||||
|
|
||||||
|
def test_returns_400_when_missing_timestamp_header(self) -> None:
|
||||||
|
app = _make_test_app()
|
||||||
|
client = TestClient(app, raise_server_exceptions=False)
|
||||||
|
|
||||||
|
body = _make_button_payload()
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/slack/interactions",
|
||||||
|
content=body,
|
||||||
|
headers={
|
||||||
|
"Content-Type": "application/x-www-form-urlencoded",
|
||||||
|
"X-Slack-Signature": "v0=something",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code in (400, 403, 422)
|
||||||
|
|
||||||
|
def test_rejects_when_signing_secret_not_configured(self) -> None:
|
||||||
|
app = _make_test_app(signing_secret="")
|
||||||
|
client = TestClient(app, raise_server_exceptions=False)
|
||||||
|
|
||||||
|
timestamp = str(int(time.time()))
|
||||||
|
body = _make_button_payload(thread_id="t-abc", value="approve")
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/slack/interactions",
|
||||||
|
content=body,
|
||||||
|
headers={
|
||||||
|
"Content-Type": "application/x-www-form-urlencoded",
|
||||||
|
"X-Slack-Request-Timestamp": timestamp,
|
||||||
|
"X-Slack-Signature": "v0=any_sig",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 503
|
||||||
|
|
||||||
|
def test_returns_200_with_approve_action(self) -> None:
|
||||||
|
app = _make_test_app()
|
||||||
|
client = TestClient(app, raise_server_exceptions=False)
|
||||||
|
|
||||||
|
timestamp = str(int(time.time()))
|
||||||
|
body = _make_button_payload(thread_id="thread-1", value="approve")
|
||||||
|
sig = _make_slack_signature(
|
||||||
|
signing_secret=_TEST_SIGNING_SECRET,
|
||||||
|
timestamp=timestamp,
|
||||||
|
body=body,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/slack/interactions",
|
||||||
|
content=body,
|
||||||
|
headers={
|
||||||
|
"Content-Type": "application/x-www-form-urlencoded",
|
||||||
|
"X-Slack-Request-Timestamp": timestamp,
|
||||||
|
"X-Slack-Signature": sig,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
def test_returns_200_with_reject_action(self) -> None:
|
||||||
|
app = _make_test_app()
|
||||||
|
client = TestClient(app, raise_server_exceptions=False)
|
||||||
|
|
||||||
|
timestamp = str(int(time.time()))
|
||||||
|
body = _make_button_payload(thread_id="thread-2", value="reject")
|
||||||
|
sig = _make_slack_signature(
|
||||||
|
signing_secret=_TEST_SIGNING_SECRET,
|
||||||
|
timestamp=timestamp,
|
||||||
|
body=body,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/slack/interactions",
|
||||||
|
content=body,
|
||||||
|
headers={
|
||||||
|
"Content-Type": "application/x-www-form-urlencoded",
|
||||||
|
"X-Slack-Request-Timestamp": timestamp,
|
||||||
|
"X-Slack-Signature": sig,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
def test_schedules_graph_resume_in_background(self) -> None:
|
||||||
|
app = _make_test_app()
|
||||||
|
client = TestClient(app, raise_server_exceptions=False)
|
||||||
|
|
||||||
|
timestamp = str(int(time.time()))
|
||||||
|
body = _make_button_payload(thread_id="t-bg", value="approve")
|
||||||
|
sig = _make_slack_signature(
|
||||||
|
signing_secret=_TEST_SIGNING_SECRET,
|
||||||
|
timestamp=timestamp,
|
||||||
|
body=body,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/slack/interactions",
|
||||||
|
content=body,
|
||||||
|
headers={
|
||||||
|
"Content-Type": "application/x-www-form-urlencoded",
|
||||||
|
"X-Slack-Request-Timestamp": timestamp,
|
||||||
|
"X-Slack-Signature": sig,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
def test_returns_404_for_unknown_thread(self) -> None:
|
||||||
|
app = _make_test_app(thread_graph_name=None)
|
||||||
|
client = TestClient(app, raise_server_exceptions=False)
|
||||||
|
|
||||||
|
timestamp = str(int(time.time()))
|
||||||
|
body = _make_button_payload(thread_id="unknown-thread", value="approve")
|
||||||
|
sig = _make_slack_signature(
|
||||||
|
signing_secret=_TEST_SIGNING_SECRET,
|
||||||
|
timestamp=timestamp,
|
||||||
|
body=body,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/slack/interactions",
|
||||||
|
content=body,
|
||||||
|
headers={
|
||||||
|
"Content-Type": "application/x-www-form-urlencoded",
|
||||||
|
"X-Slack-Request-Timestamp": timestamp,
|
||||||
|
"X-Slack-Signature": sig,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should return 200 immediately (Slack requires immediate 200)
|
||||||
|
# but the background task may log an error
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
def test_response_body_is_empty_or_ok(self) -> None:
|
||||||
|
app = _make_test_app()
|
||||||
|
client = TestClient(app, raise_server_exceptions=False)
|
||||||
|
|
||||||
|
timestamp = str(int(time.time()))
|
||||||
|
body = _make_button_payload(thread_id="t-ok", value="approve")
|
||||||
|
sig = _make_slack_signature(
|
||||||
|
signing_secret=_TEST_SIGNING_SECRET,
|
||||||
|
timestamp=timestamp,
|
||||||
|
body=body,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/slack/interactions",
|
||||||
|
content=body,
|
||||||
|
headers={
|
||||||
|
"Content-Type": "application/x-www-form-urlencoded",
|
||||||
|
"X-Slack-Request-Timestamp": timestamp,
|
||||||
|
"X-Slack-Signature": sig,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
# Body may be empty or a simple JSON with ok=True
|
||||||
|
if response.content:
|
||||||
|
data = response.json()
|
||||||
|
assert data.get("ok") is True or "ok" not in data
|
||||||
270
tests/api/test_status.py
Normal file
270
tests/api/test_status.py
Normal file
@@ -0,0 +1,270 @@
|
|||||||
|
"""Tests for status, releases, staging, and manual trigger endpoints.
|
||||||
|
Written FIRST (TDD RED phase).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
from release_agent.api.status import router as status_router
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _make_test_app(
|
||||||
|
*,
|
||||||
|
versions: list[str] | None = None,
|
||||||
|
staging_data: dict | None = None,
|
||||||
|
) -> FastAPI:
|
||||||
|
"""Return a FastAPI app with mocked state for status tests."""
|
||||||
|
app = FastAPI()
|
||||||
|
app.include_router(status_router)
|
||||||
|
|
||||||
|
mock_settings = MagicMock()
|
||||||
|
mock_settings.operator_token.get_secret_value.return_value = ""
|
||||||
|
mock_graphs = {
|
||||||
|
"pr_completed": MagicMock(),
|
||||||
|
"release": MagicMock(),
|
||||||
|
}
|
||||||
|
mock_clients = MagicMock()
|
||||||
|
|
||||||
|
mock_staging_store = MagicMock()
|
||||||
|
mock_staging_store.list_versions = AsyncMock(return_value=versions or [])
|
||||||
|
|
||||||
|
# staging store returns StagingRelease-like or None
|
||||||
|
if staging_data is not None:
|
||||||
|
mock_staging_obj = MagicMock()
|
||||||
|
mock_staging_obj.model_dump = MagicMock(return_value=staging_data)
|
||||||
|
mock_staging_store.load = AsyncMock(return_value=mock_staging_obj)
|
||||||
|
else:
|
||||||
|
mock_staging_store.load = AsyncMock(return_value=None)
|
||||||
|
|
||||||
|
mock_pool = MagicMock()
|
||||||
|
|
||||||
|
app.state.settings = mock_settings
|
||||||
|
app.state.graphs = mock_graphs
|
||||||
|
app.state.tool_clients = mock_clients
|
||||||
|
app.state.staging_store = mock_staging_store
|
||||||
|
app.state.db_pool = mock_pool
|
||||||
|
app.state.background_tasks = set()
|
||||||
|
app.state.started_at = datetime.now(tz=timezone.utc)
|
||||||
|
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# GET /status
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestGetStatus:
|
||||||
|
def test_returns_200(self) -> None:
|
||||||
|
app = _make_test_app()
|
||||||
|
with TestClient(app) as client:
|
||||||
|
response = client.get("/status")
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
def test_response_has_status_field(self) -> None:
|
||||||
|
app = _make_test_app()
|
||||||
|
with TestClient(app) as client:
|
||||||
|
response = client.get("/status")
|
||||||
|
data = response.json()
|
||||||
|
assert "status" in data
|
||||||
|
assert data["status"] in ("ok", "degraded")
|
||||||
|
|
||||||
|
def test_response_has_version_field(self) -> None:
|
||||||
|
app = _make_test_app()
|
||||||
|
with TestClient(app) as client:
|
||||||
|
response = client.get("/status")
|
||||||
|
data = response.json()
|
||||||
|
assert "version" in data
|
||||||
|
assert isinstance(data["version"], str)
|
||||||
|
|
||||||
|
def test_response_has_uptime_seconds(self) -> None:
|
||||||
|
app = _make_test_app()
|
||||||
|
with TestClient(app) as client:
|
||||||
|
response = client.get("/status")
|
||||||
|
data = response.json()
|
||||||
|
assert "uptime_seconds" in data
|
||||||
|
assert data["uptime_seconds"] >= 0.0
|
||||||
|
|
||||||
|
def test_status_is_ok_when_healthy(self) -> None:
|
||||||
|
app = _make_test_app()
|
||||||
|
with TestClient(app) as client:
|
||||||
|
response = client.get("/status")
|
||||||
|
assert response.json()["status"] == "ok"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# GET /releases/{repo}
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestGetReleaseVersions:
|
||||||
|
def test_returns_200(self) -> None:
|
||||||
|
app = _make_test_app(versions=["v1.0.0", "v1.1.0"])
|
||||||
|
with TestClient(app) as client:
|
||||||
|
response = client.get("/releases/my-repo")
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
def test_response_has_repo_and_versions(self) -> None:
|
||||||
|
app = _make_test_app(versions=["v1.0.0", "v1.1.0"])
|
||||||
|
with TestClient(app) as client:
|
||||||
|
response = client.get("/releases/my-repo")
|
||||||
|
data = response.json()
|
||||||
|
assert data["repo"] == "my-repo"
|
||||||
|
assert data["versions"] == ["v1.0.0", "v1.1.0"]
|
||||||
|
|
||||||
|
def test_empty_versions_list(self) -> None:
|
||||||
|
app = _make_test_app(versions=[])
|
||||||
|
with TestClient(app) as client:
|
||||||
|
response = client.get("/releases/unknown-repo")
|
||||||
|
data = response.json()
|
||||||
|
assert data["versions"] == []
|
||||||
|
|
||||||
|
def test_repo_name_in_path_used(self) -> None:
|
||||||
|
mock_staging_store = MagicMock()
|
||||||
|
mock_staging_store.list_versions = AsyncMock(return_value=[])
|
||||||
|
app = _make_test_app()
|
||||||
|
app.state.staging_store = mock_staging_store
|
||||||
|
|
||||||
|
with TestClient(app) as client:
|
||||||
|
client.get("/releases/specific-repo")
|
||||||
|
|
||||||
|
mock_staging_store.list_versions.assert_called_once_with("specific-repo")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# GET /staging
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestGetStaging:
|
||||||
|
def test_returns_200_with_staging(self) -> None:
|
||||||
|
staging_data = {"version": "v1.0.0", "repo": "my-repo", "tickets": []}
|
||||||
|
app = _make_test_app(staging_data=staging_data)
|
||||||
|
with TestClient(app) as client:
|
||||||
|
response = client.get("/staging?repo=my-repo")
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
def test_response_has_repo_and_staging(self) -> None:
|
||||||
|
staging_data = {"version": "v1.0.0", "repo": "my-repo", "tickets": []}
|
||||||
|
app = _make_test_app(staging_data=staging_data)
|
||||||
|
with TestClient(app) as client:
|
||||||
|
response = client.get("/staging?repo=my-repo")
|
||||||
|
data = response.json()
|
||||||
|
assert data["repo"] == "my-repo"
|
||||||
|
assert data["staging"] is not None
|
||||||
|
assert data["staging"]["version"] == "v1.0.0"
|
||||||
|
|
||||||
|
def test_returns_null_staging_when_not_found(self) -> None:
|
||||||
|
app = _make_test_app(staging_data=None)
|
||||||
|
with TestClient(app) as client:
|
||||||
|
response = client.get("/staging?repo=no-staging-repo")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["staging"] is None
|
||||||
|
|
||||||
|
def test_missing_repo_query_returns_422(self) -> None:
|
||||||
|
app = _make_test_app()
|
||||||
|
with TestClient(app) as client:
|
||||||
|
response = client.get("/staging")
|
||||||
|
assert response.status_code == 422
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# POST /manual/pr/{pr_id}
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestManualPrTrigger:
|
||||||
|
def test_returns_202(self) -> None:
|
||||||
|
app = _make_test_app()
|
||||||
|
with patch(
|
||||||
|
"release_agent.api.status.asyncio.create_task", return_value=MagicMock()
|
||||||
|
):
|
||||||
|
with TestClient(app) as client:
|
||||||
|
response = client.post("/manual/pr/42")
|
||||||
|
assert response.status_code == 202
|
||||||
|
|
||||||
|
def test_response_has_thread_id(self) -> None:
|
||||||
|
app = _make_test_app()
|
||||||
|
with patch(
|
||||||
|
"release_agent.api.status.asyncio.create_task", return_value=MagicMock()
|
||||||
|
):
|
||||||
|
with TestClient(app) as client:
|
||||||
|
response = client.post("/manual/pr/42")
|
||||||
|
data = response.json()
|
||||||
|
assert "thread_id" in data
|
||||||
|
assert isinstance(data["thread_id"], str)
|
||||||
|
|
||||||
|
def test_response_has_message(self) -> None:
|
||||||
|
app = _make_test_app()
|
||||||
|
with patch(
|
||||||
|
"release_agent.api.status.asyncio.create_task", return_value=MagicMock()
|
||||||
|
):
|
||||||
|
with TestClient(app) as client:
|
||||||
|
response = client.post("/manual/pr/42")
|
||||||
|
assert "message" in response.json()
|
||||||
|
|
||||||
|
def test_schedules_background_task(self) -> None:
|
||||||
|
app = _make_test_app()
|
||||||
|
with patch(
|
||||||
|
"release_agent.api.status.asyncio.create_task", return_value=MagicMock()
|
||||||
|
) as mock_create:
|
||||||
|
with TestClient(app) as client:
|
||||||
|
client.post("/manual/pr/99")
|
||||||
|
mock_create.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# POST /manual/release
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestManualReleaseTrigger:
|
||||||
|
def test_returns_202(self) -> None:
|
||||||
|
app = _make_test_app()
|
||||||
|
with patch(
|
||||||
|
"release_agent.api.status.asyncio.create_task", return_value=MagicMock()
|
||||||
|
):
|
||||||
|
with TestClient(app) as client:
|
||||||
|
response = client.post(
|
||||||
|
"/manual/release",
|
||||||
|
json={"repo": "my-repo"},
|
||||||
|
)
|
||||||
|
assert response.status_code == 202
|
||||||
|
|
||||||
|
def test_response_has_thread_id(self) -> None:
|
||||||
|
app = _make_test_app()
|
||||||
|
with patch(
|
||||||
|
"release_agent.api.status.asyncio.create_task", return_value=MagicMock()
|
||||||
|
):
|
||||||
|
with TestClient(app) as client:
|
||||||
|
response = client.post(
|
||||||
|
"/manual/release",
|
||||||
|
json={"repo": "my-repo"},
|
||||||
|
)
|
||||||
|
data = response.json()
|
||||||
|
assert "thread_id" in data
|
||||||
|
|
||||||
|
def test_missing_repo_returns_422(self) -> None:
|
||||||
|
app = _make_test_app()
|
||||||
|
with TestClient(app) as client:
|
||||||
|
response = client.post(
|
||||||
|
"/manual/release",
|
||||||
|
json={},
|
||||||
|
)
|
||||||
|
assert response.status_code == 422
|
||||||
|
|
||||||
|
def test_schedules_background_task(self) -> None:
|
||||||
|
app = _make_test_app()
|
||||||
|
with patch(
|
||||||
|
"release_agent.api.status.asyncio.create_task", return_value=MagicMock()
|
||||||
|
) as mock_create:
|
||||||
|
with TestClient(app) as client:
|
||||||
|
client.post(
|
||||||
|
"/manual/release",
|
||||||
|
json={"repo": "my-repo"},
|
||||||
|
)
|
||||||
|
mock_create.assert_called_once()
|
||||||
166
tests/api/test_status_with_auth.py
Normal file
166
tests/api/test_status_with_auth.py
Normal file
@@ -0,0 +1,166 @@
|
|||||||
|
"""Tests for status/manual endpoints with operator token authentication.
|
||||||
|
|
||||||
|
Phase 5 - Step 3: Verifies that POST /manual/* require operator token
|
||||||
|
when configured. GET endpoints are not protected.
|
||||||
|
Written FIRST (TDD RED phase).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
from release_agent.api.status import router as status_router
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _make_test_app(operator_token: str = "") -> FastAPI:
|
||||||
|
"""Return a FastAPI app with status router and configurable operator token."""
|
||||||
|
app = FastAPI()
|
||||||
|
app.include_router(status_router)
|
||||||
|
|
||||||
|
mock_settings = MagicMock()
|
||||||
|
mock_settings.operator_token.get_secret_value.return_value = operator_token
|
||||||
|
|
||||||
|
mock_staging_store = MagicMock()
|
||||||
|
mock_staging_store.list_versions = AsyncMock(return_value=[])
|
||||||
|
mock_staging_store.load = AsyncMock(return_value=None)
|
||||||
|
|
||||||
|
mock_pool = MagicMock()
|
||||||
|
|
||||||
|
app.state.settings = mock_settings
|
||||||
|
app.state.graphs = {
|
||||||
|
"pr_completed": MagicMock(),
|
||||||
|
"release": MagicMock(),
|
||||||
|
}
|
||||||
|
app.state.tool_clients = MagicMock()
|
||||||
|
app.state.staging_store = mock_staging_store
|
||||||
|
app.state.db_pool = mock_pool
|
||||||
|
app.state.background_tasks = set()
|
||||||
|
app.state.started_at = datetime.now(tz=timezone.utc)
|
||||||
|
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# POST /manual/pr/{pr_id} with auth
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestManualPrTriggerWithAuth:
|
||||||
|
def test_valid_token_allows_manual_pr(self) -> None:
|
||||||
|
app = _make_test_app(operator_token="secure-token")
|
||||||
|
with patch(
|
||||||
|
"release_agent.api.status.asyncio.create_task", return_value=MagicMock()
|
||||||
|
):
|
||||||
|
with TestClient(app) as client:
|
||||||
|
response = client.post(
|
||||||
|
"/manual/pr/42",
|
||||||
|
headers={"X-Operator-Token": "secure-token"},
|
||||||
|
)
|
||||||
|
assert response.status_code == 202
|
||||||
|
|
||||||
|
def test_missing_token_rejects_manual_pr(self) -> None:
|
||||||
|
app = _make_test_app(operator_token="secure-token")
|
||||||
|
with TestClient(app) as client:
|
||||||
|
response = client.post("/manual/pr/42")
|
||||||
|
assert response.status_code == 401
|
||||||
|
|
||||||
|
def test_wrong_token_rejects_manual_pr(self) -> None:
|
||||||
|
app = _make_test_app(operator_token="secure-token")
|
||||||
|
with TestClient(app) as client:
|
||||||
|
response = client.post(
|
||||||
|
"/manual/pr/42",
|
||||||
|
headers={"X-Operator-Token": "bad-token"},
|
||||||
|
)
|
||||||
|
assert response.status_code == 401
|
||||||
|
|
||||||
|
def test_no_auth_required_when_token_not_configured(self) -> None:
|
||||||
|
app = _make_test_app(operator_token="")
|
||||||
|
with patch(
|
||||||
|
"release_agent.api.status.asyncio.create_task", return_value=MagicMock()
|
||||||
|
):
|
||||||
|
with TestClient(app) as client:
|
||||||
|
response = client.post("/manual/pr/42")
|
||||||
|
assert response.status_code == 202
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# POST /manual/release with auth
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestManualReleaseTriggerWithAuth:
|
||||||
|
def test_valid_token_allows_manual_release(self) -> None:
|
||||||
|
app = _make_test_app(operator_token="secure-token")
|
||||||
|
with patch(
|
||||||
|
"release_agent.api.status.asyncio.create_task", return_value=MagicMock()
|
||||||
|
):
|
||||||
|
with TestClient(app) as client:
|
||||||
|
response = client.post(
|
||||||
|
"/manual/release",
|
||||||
|
json={"repo": "my-repo"},
|
||||||
|
headers={"X-Operator-Token": "secure-token"},
|
||||||
|
)
|
||||||
|
assert response.status_code == 202
|
||||||
|
|
||||||
|
def test_missing_token_rejects_manual_release(self) -> None:
|
||||||
|
app = _make_test_app(operator_token="secure-token")
|
||||||
|
with TestClient(app) as client:
|
||||||
|
response = client.post(
|
||||||
|
"/manual/release",
|
||||||
|
json={"repo": "my-repo"},
|
||||||
|
)
|
||||||
|
assert response.status_code == 401
|
||||||
|
|
||||||
|
def test_wrong_token_rejects_manual_release(self) -> None:
|
||||||
|
app = _make_test_app(operator_token="secure-token")
|
||||||
|
with TestClient(app) as client:
|
||||||
|
response = client.post(
|
||||||
|
"/manual/release",
|
||||||
|
json={"repo": "my-repo"},
|
||||||
|
headers={"X-Operator-Token": "wrong"},
|
||||||
|
)
|
||||||
|
assert response.status_code == 401
|
||||||
|
|
||||||
|
def test_no_auth_required_when_token_not_configured(self) -> None:
|
||||||
|
app = _make_test_app(operator_token="")
|
||||||
|
with patch(
|
||||||
|
"release_agent.api.status.asyncio.create_task", return_value=MagicMock()
|
||||||
|
):
|
||||||
|
with TestClient(app) as client:
|
||||||
|
response = client.post(
|
||||||
|
"/manual/release",
|
||||||
|
json={"repo": "my-repo"},
|
||||||
|
)
|
||||||
|
assert response.status_code == 202
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# GET /status, /releases, /staging do NOT require auth
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestReadEndpointsNoAuth:
|
||||||
|
def test_get_status_no_token_needed(self) -> None:
|
||||||
|
"""GET /status should never require auth."""
|
||||||
|
app = _make_test_app(operator_token="super-secret")
|
||||||
|
with TestClient(app) as client:
|
||||||
|
response = client.get("/status")
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
def test_get_releases_no_token_needed(self) -> None:
|
||||||
|
app = _make_test_app(operator_token="super-secret")
|
||||||
|
with TestClient(app) as client:
|
||||||
|
response = client.get("/releases/my-repo")
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
def test_get_staging_no_token_needed(self) -> None:
|
||||||
|
app = _make_test_app(operator_token="super-secret")
|
||||||
|
with TestClient(app) as client:
|
||||||
|
response = client.get("/staging?repo=my-repo")
|
||||||
|
assert response.status_code == 200
|
||||||
218
tests/api/test_webhooks.py
Normal file
218
tests/api/test_webhooks.py
Normal file
@@ -0,0 +1,218 @@
|
|||||||
|
"""Tests for webhook endpoint. Written FIRST (TDD RED phase)."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
from release_agent.api.webhooks import (
|
||||||
|
_validate_webhook_secret,
|
||||||
|
router as webhook_router,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
VALID_SECRET = "super-secret-webhook-key"
|
||||||
|
|
||||||
|
_COMPLETED_PR_PAYLOAD = {
|
||||||
|
"subscription_id": "sub-1",
|
||||||
|
"event_type": "git.pullrequest.updated",
|
||||||
|
"resource": {
|
||||||
|
"repository": {
|
||||||
|
"id": "repo-1",
|
||||||
|
"name": "my-repo",
|
||||||
|
"web_url": "https://dev.azure.com/org/project/_git/my-repo",
|
||||||
|
},
|
||||||
|
"pull_request_id": 42,
|
||||||
|
"title": "feat: add feature",
|
||||||
|
"source_ref_name": "refs/heads/feature/BILL-123-add-feature",
|
||||||
|
"target_ref_name": "refs/heads/main",
|
||||||
|
"status": "completed",
|
||||||
|
"closed_date": "2024-01-15T10:00:00Z",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_ACTIVE_PR_PAYLOAD = {
|
||||||
|
"subscription_id": "sub-2",
|
||||||
|
"event_type": "git.pullrequest.updated",
|
||||||
|
"resource": {
|
||||||
|
"repository": {
|
||||||
|
"id": "repo-1",
|
||||||
|
"name": "my-repo",
|
||||||
|
"web_url": "https://dev.azure.com/org/project/_git/my-repo",
|
||||||
|
},
|
||||||
|
"pull_request_id": 43,
|
||||||
|
"title": "WIP: work in progress",
|
||||||
|
"source_ref_name": "refs/heads/feature/BILL-456",
|
||||||
|
"target_ref_name": "refs/heads/main",
|
||||||
|
"status": "active",
|
||||||
|
"closed_date": None,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _make_test_app(webhook_secret: str = VALID_SECRET) -> FastAPI:
|
||||||
|
"""Return a FastAPI app with mocked state for webhook tests."""
|
||||||
|
app = FastAPI()
|
||||||
|
app.include_router(webhook_router)
|
||||||
|
|
||||||
|
mock_settings = MagicMock()
|
||||||
|
mock_settings.webhook_secret.get_secret_value.return_value = webhook_secret
|
||||||
|
|
||||||
|
mock_graphs = {
|
||||||
|
"pr_completed": MagicMock(),
|
||||||
|
"release": MagicMock(),
|
||||||
|
}
|
||||||
|
|
||||||
|
mock_clients = MagicMock()
|
||||||
|
mock_pool = MagicMock()
|
||||||
|
|
||||||
|
# background_tasks set tracked on state
|
||||||
|
app.state.settings = mock_settings
|
||||||
|
app.state.graphs = mock_graphs
|
||||||
|
app.state.tool_clients = mock_clients
|
||||||
|
app.state.db_pool = mock_pool
|
||||||
|
app.state.background_tasks = set()
|
||||||
|
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _validate_webhook_secret (unit tests, pure function)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestValidateWebhookSecret:
|
||||||
|
def test_valid_secret_returns_true(self) -> None:
|
||||||
|
assert _validate_webhook_secret("mysecret", "mysecret") is True
|
||||||
|
|
||||||
|
def test_wrong_secret_returns_false(self) -> None:
|
||||||
|
assert _validate_webhook_secret("wrong", "mysecret") is False
|
||||||
|
|
||||||
|
def test_empty_header_returns_false(self) -> None:
|
||||||
|
assert _validate_webhook_secret("", "mysecret") is False
|
||||||
|
|
||||||
|
def test_none_header_returns_false(self) -> None:
|
||||||
|
assert _validate_webhook_secret(None, "mysecret") is False # type: ignore[arg-type]
|
||||||
|
|
||||||
|
def test_uses_constant_time_comparison(self) -> None:
|
||||||
|
# Should not raise even for very different lengths
|
||||||
|
assert _validate_webhook_secret("a", "very-long-secret-value") is False
|
||||||
|
|
||||||
|
def test_empty_expected_rejects_all(self) -> None:
|
||||||
|
# Empty expected secret = auth misconfigured, reject everything
|
||||||
|
assert _validate_webhook_secret("", "") is False
|
||||||
|
assert _validate_webhook_secret("any-value", "") is False
|
||||||
|
assert _validate_webhook_secret(None, "") is False
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# POST /webhooks/azdo — integration tests via TestClient
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestWebhookEndpoint:
|
||||||
|
def test_valid_completed_pr_returns_202(self) -> None:
|
||||||
|
app = _make_test_app()
|
||||||
|
with patch(
|
||||||
|
"release_agent.api.webhooks.asyncio.create_task", return_value=MagicMock()
|
||||||
|
):
|
||||||
|
with TestClient(app, raise_server_exceptions=True) as client:
|
||||||
|
response = client.post(
|
||||||
|
"/webhooks/azdo",
|
||||||
|
json=_COMPLETED_PR_PAYLOAD,
|
||||||
|
headers={"X-Webhook-Secret": VALID_SECRET},
|
||||||
|
)
|
||||||
|
assert response.status_code == 202
|
||||||
|
data = response.json()
|
||||||
|
assert "thread_id" in data
|
||||||
|
assert "message" in data
|
||||||
|
|
||||||
|
def test_missing_secret_header_returns_401(self) -> None:
|
||||||
|
app = _make_test_app()
|
||||||
|
with TestClient(app) as client:
|
||||||
|
response = client.post(
|
||||||
|
"/webhooks/azdo",
|
||||||
|
json=_COMPLETED_PR_PAYLOAD,
|
||||||
|
)
|
||||||
|
assert response.status_code == 401
|
||||||
|
|
||||||
|
def test_wrong_secret_header_returns_401(self) -> None:
|
||||||
|
app = _make_test_app()
|
||||||
|
with TestClient(app) as client:
|
||||||
|
response = client.post(
|
||||||
|
"/webhooks/azdo",
|
||||||
|
json=_COMPLETED_PR_PAYLOAD,
|
||||||
|
headers={"X-Webhook-Secret": "wrong-secret"},
|
||||||
|
)
|
||||||
|
assert response.status_code == 401
|
||||||
|
|
||||||
|
def test_invalid_payload_returns_422(self) -> None:
|
||||||
|
app = _make_test_app()
|
||||||
|
with TestClient(app) as client:
|
||||||
|
response = client.post(
|
||||||
|
"/webhooks/azdo",
|
||||||
|
json={"invalid": "payload"},
|
||||||
|
headers={"X-Webhook-Secret": VALID_SECRET},
|
||||||
|
)
|
||||||
|
assert response.status_code == 422
|
||||||
|
|
||||||
|
def test_active_pr_event_returns_200_ignored(self) -> None:
|
||||||
|
app = _make_test_app()
|
||||||
|
with TestClient(app) as client:
|
||||||
|
response = client.post(
|
||||||
|
"/webhooks/azdo",
|
||||||
|
json=_ACTIVE_PR_PAYLOAD,
|
||||||
|
headers={"X-Webhook-Secret": VALID_SECRET},
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert "ignored" in data.get("message", "").lower() or "ignored" in str(data).lower()
|
||||||
|
|
||||||
|
def test_completed_pr_thread_id_is_string(self) -> None:
|
||||||
|
app = _make_test_app()
|
||||||
|
with patch(
|
||||||
|
"release_agent.api.webhooks.asyncio.create_task", return_value=MagicMock()
|
||||||
|
):
|
||||||
|
with TestClient(app) as client:
|
||||||
|
response = client.post(
|
||||||
|
"/webhooks/azdo",
|
||||||
|
json=_COMPLETED_PR_PAYLOAD,
|
||||||
|
headers={"X-Webhook-Secret": VALID_SECRET},
|
||||||
|
)
|
||||||
|
assert response.status_code == 202
|
||||||
|
assert isinstance(response.json()["thread_id"], str)
|
||||||
|
|
||||||
|
def test_completed_pr_schedules_background_task(self) -> None:
|
||||||
|
app = _make_test_app()
|
||||||
|
task_mock = MagicMock()
|
||||||
|
with patch(
|
||||||
|
"release_agent.api.webhooks.asyncio.create_task", return_value=task_mock
|
||||||
|
) as mock_create:
|
||||||
|
with TestClient(app) as client:
|
||||||
|
client.post(
|
||||||
|
"/webhooks/azdo",
|
||||||
|
json=_COMPLETED_PR_PAYLOAD,
|
||||||
|
headers={"X-Webhook-Secret": VALID_SECRET},
|
||||||
|
)
|
||||||
|
mock_create.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Error response shape
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestWebhookErrorShape:
|
||||||
|
def test_401_has_detail_field(self) -> None:
|
||||||
|
app = _make_test_app()
|
||||||
|
with TestClient(app) as client:
|
||||||
|
response = client.post(
|
||||||
|
"/webhooks/azdo",
|
||||||
|
json=_COMPLETED_PR_PAYLOAD,
|
||||||
|
)
|
||||||
|
assert response.status_code == 401
|
||||||
|
# FastAPI HTTPException returns {"detail": ...}
|
||||||
|
assert "detail" in response.json()
|
||||||
0
tests/graph/__init__.py
Normal file
0
tests/graph/__init__.py
Normal file
44
tests/graph/conftest.py
Normal file
44
tests/graph/conftest.py
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
"""Shared fixtures for graph tests.
|
||||||
|
|
||||||
|
Provides build_mock_clients() to create ToolClients with AsyncMock fields
|
||||||
|
so individual node functions can be tested without compiling the full graph.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from release_agent.graph.dependencies import ToolClients
|
||||||
|
|
||||||
|
|
||||||
|
def build_mock_clients() -> ToolClients:
|
||||||
|
"""Return a ToolClients instance whose fields are all AsyncMock/MagicMock."""
|
||||||
|
azdo = AsyncMock()
|
||||||
|
jira = AsyncMock()
|
||||||
|
slack = AsyncMock()
|
||||||
|
reviewer = AsyncMock()
|
||||||
|
return ToolClients(azdo=azdo, jira=jira, slack=slack, reviewer=reviewer)
|
||||||
|
|
||||||
|
|
||||||
|
def build_config(clients: ToolClients | None = None, staging_store=None) -> dict:
|
||||||
|
"""Return a LangGraph-style config dict with clients and staging_store."""
|
||||||
|
if clients is None:
|
||||||
|
clients = build_mock_clients()
|
||||||
|
return {
|
||||||
|
"configurable": {
|
||||||
|
"clients": clients,
|
||||||
|
"staging_store": staging_store,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def mock_clients() -> ToolClients:
|
||||||
|
"""Pytest fixture returning fresh mock ToolClients."""
|
||||||
|
return build_mock_clients()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def config(mock_clients: ToolClients):
|
||||||
|
"""Pytest fixture returning a config dict with mock clients."""
|
||||||
|
return build_config(mock_clients)
|
||||||
294
tests/graph/test_ci_nodes.py
Normal file
294
tests/graph/test_ci_nodes.py
Normal file
@@ -0,0 +1,294 @@
|
|||||||
|
"""Tests for graph/ci_nodes.py.
|
||||||
|
|
||||||
|
Written FIRST (TDD RED phase).
|
||||||
|
|
||||||
|
All external calls (azdo, slack, poll_until) are mocked.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from release_agent.graph.ci_nodes import notify_ci_result, poll_ci_build, trigger_ci_build
|
||||||
|
from release_agent.models.build import BuildStatus
|
||||||
|
from release_agent.models.pipeline import PipelineInfo
|
||||||
|
from tests.graph.conftest import build_config, build_mock_clients
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _make_pipeline(pipeline_id: int = 10, name: str = "CI-build") -> dict:
|
||||||
|
return {"id": pipeline_id, "name": name, "repo": "my-repo"}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# trigger_ci_build
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestTriggerCiBuild:
    """Behavioral tests for the trigger_ci_build graph node."""

    async def test_triggers_pipeline_on_branch(self) -> None:
        mocks = build_mock_clients()
        mocks.azdo.list_build_pipelines.return_value = [
            PipelineInfo(id=10, name="CI", repo="my-repo")
        ]
        mocks.azdo.trigger_pipeline.return_value = {"id": 555, "state": "inProgress"}
        cfg = build_config(mocks)

        updates = await trigger_ci_build({"repo_name": "my-repo", "version": "v1.0.0"}, cfg)

        mocks.azdo.trigger_pipeline.assert_called_once()
        assert "ci_build_id" in updates
        assert updates["ci_build_id"] == 555

    async def test_returns_ci_build_id(self) -> None:
        mocks = build_mock_clients()
        mocks.azdo.list_build_pipelines.return_value = [
            PipelineInfo(id=20, name="build-and-test", repo="my-repo")
        ]
        mocks.azdo.trigger_pipeline.return_value = {"id": 999}
        cfg = build_config(mocks)

        updates = await trigger_ci_build({"repo_name": "my-repo", "version": "v2.0.0"}, cfg)

        assert updates["ci_build_id"] == 999

    async def test_appends_error_when_no_pipelines_found(self) -> None:
        mocks = build_mock_clients()
        mocks.azdo.list_build_pipelines.return_value = []
        cfg = build_config(mocks)

        updates = await trigger_ci_build({"repo_name": "my-repo", "version": "v1.0.0"}, cfg)

        assert "errors" in updates
        assert len(updates["errors"]) >= 1

    async def test_appends_error_on_trigger_failure(self) -> None:
        from release_agent.exceptions import ServiceError

        mocks = build_mock_clients()
        mocks.azdo.list_build_pipelines.return_value = [
            PipelineInfo(id=10, name="CI", repo="my-repo")
        ]
        mocks.azdo.trigger_pipeline.side_effect = ServiceError(
            service="azdo", status_code=500, detail="Internal error"
        )
        cfg = build_config(mocks)

        updates = await trigger_ci_build({"repo_name": "my-repo", "version": "v1.0.0"}, cfg)

        # The node must convert the service failure into a state error.
        assert "errors" in updates

    async def test_uses_main_branch_when_no_version(self) -> None:
        mocks = build_mock_clients()
        mocks.azdo.list_build_pipelines.return_value = [
            PipelineInfo(id=10, name="CI", repo="my-repo")
        ]
        mocks.azdo.trigger_pipeline.return_value = {"id": 1}
        cfg = build_config(mocks)

        # State deliberately carries no "version" key.
        await trigger_ci_build({"repo_name": "my-repo"}, cfg)

        branch = mocks.azdo.trigger_pipeline.call_args[1].get("branch", "")
        assert "main" in branch or "refs/heads" in branch

    async def test_appends_message_on_success(self) -> None:
        mocks = build_mock_clients()
        mocks.azdo.list_build_pipelines.return_value = [
            PipelineInfo(id=10, name="CI", repo="my-repo")
        ]
        mocks.azdo.trigger_pipeline.return_value = {"id": 123}
        cfg = build_config(mocks)

        updates = await trigger_ci_build({"repo_name": "my-repo", "version": "v1.0.0"}, cfg)

        assert "messages" in updates
        assert len(updates["messages"]) >= 1
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# poll_ci_build
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestPollCiBuild:
    """Behavioral tests for the poll_ci_build graph node."""

    async def test_returns_ci_build_status_and_result_on_completion(self) -> None:
        cfg = build_config(build_mock_clients())
        done = BuildStatus(status="completed", result="succeeded", build_url="https://build/1")

        with patch("release_agent.graph.ci_nodes.poll_until", return_value=(done, True)):
            updates = await poll_ci_build({"ci_build_id": 42, "repo_name": "my-repo"}, cfg)

        assert updates["ci_build_status"] == "completed"
        assert updates["ci_build_result"] == "succeeded"

    async def test_returns_build_url(self) -> None:
        cfg = build_config(build_mock_clients())
        done = BuildStatus(
            status="completed",
            result="succeeded",
            build_url="https://dev.azure.com/build/42",
        )

        with patch("release_agent.graph.ci_nodes.poll_until", return_value=(done, True)):
            updates = await poll_ci_build({"ci_build_id": 42, "repo_name": "my-repo"}, cfg)

        assert updates.get("ci_build_url") == "https://dev.azure.com/build/42"

    async def test_appends_error_on_timeout(self) -> None:
        cfg = build_config(build_mock_clients())
        still_running = BuildStatus(status="inProgress", result=None, build_url=None)

        # poll_until reports (last_status, completed=False) -> timeout path.
        with patch(
            "release_agent.graph.ci_nodes.poll_until",
            return_value=(still_running, False),
        ):
            updates = await poll_ci_build({"ci_build_id": 42, "repo_name": "my-repo"}, cfg)

        assert "errors" in updates

    async def test_appends_error_when_build_id_missing(self) -> None:
        cfg = build_config(build_mock_clients())

        # State deliberately lacks "ci_build_id".
        updates = await poll_ci_build({"repo_name": "my-repo"}, cfg)

        assert "errors" in updates

    async def test_passes_correct_build_id_to_poll_fn(self) -> None:
        mocks = build_mock_clients()
        mocks.azdo.get_build_status.return_value = BuildStatus(
            status="completed", result="succeeded", build_url=None
        )
        cfg = build_config(mocks)

        async def fake_poll_until(*, poll_fn, is_done, interval_seconds, max_wait_seconds, sleep_fn=None):
            # Invoke the node-supplied poll_fn exactly once and report completion.
            return await poll_fn(), True

        with patch("release_agent.graph.ci_nodes.poll_until", side_effect=fake_poll_until):
            await poll_ci_build({"ci_build_id": 77, "repo_name": "my-repo"}, cfg)

        mocks.azdo.get_build_status.assert_called_once_with(build_id=77)

    async def test_result_none_when_poll_returns_none(self) -> None:
        cfg = build_config(build_mock_clients())

        with patch(
            "release_agent.graph.ci_nodes.poll_until",
            return_value=(None, False),
        ):
            updates = await poll_ci_build({"ci_build_id": 42, "repo_name": "my-repo"}, cfg)

        assert "errors" in updates
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# notify_ci_result
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestNotifyCiResult:
    """Behavioral tests for the notify_ci_result graph node."""

    async def test_sends_notification_on_success(self) -> None:
        mocks = build_mock_clients()
        mocks.slack.send_notification.return_value = True
        cfg = build_config(mocks)
        state = {
            "repo_name": "my-repo",
            "ci_build_status": "completed",
            "ci_build_result": "succeeded",
            "ci_build_url": "https://build/99",
        }

        updates = await notify_ci_result(state, cfg)

        mocks.slack.send_notification.assert_called_once()
        assert "messages" in updates

    async def test_sends_notification_on_failure(self) -> None:
        mocks = build_mock_clients()
        mocks.slack.send_notification.return_value = True
        cfg = build_config(mocks)
        state = {
            "repo_name": "my-repo",
            "ci_build_status": "completed",
            "ci_build_result": "failed",
            "ci_build_url": None,
        }

        await notify_ci_result(state, cfg)

        mocks.slack.send_notification.assert_called_once()

    async def test_handles_slack_error_gracefully(self) -> None:
        from release_agent.exceptions import ServiceError

        mocks = build_mock_clients()
        mocks.slack.send_notification.side_effect = ServiceError(
            service="slack", status_code=500, detail="Slack error"
        )
        cfg = build_config(mocks)
        state = {
            "repo_name": "my-repo",
            "ci_build_result": "succeeded",
            "ci_build_url": None,
        }

        updates = await notify_ci_result(state, cfg)

        # The node must swallow the Slack failure and surface it as a state error,
        # never re-raise.
        assert "errors" in updates

    async def test_includes_repo_name_in_message(self) -> None:
        mocks = build_mock_clients()
        mocks.slack.send_notification.return_value = True
        cfg = build_config(mocks)
        state = {
            "repo_name": "super-service",
            "ci_build_result": "succeeded",
            "ci_build_url": None,
        }

        await notify_ci_result(state, cfg)

        # The repo name should show up somewhere in the notification payload
        # (text or blocks — stringify the kwargs to check both).
        assert "super-service" in str(mocks.slack.send_notification.call_args[1])

    async def test_returns_empty_dict_when_state_has_no_data(self) -> None:
        mocks = build_mock_clients()
        mocks.slack.send_notification.return_value = True
        cfg = build_config(mocks)

        updates = await notify_ci_result({}, cfg)

        # Node must tolerate an empty state without raising.
        assert isinstance(updates, dict)
|
||||||
283
tests/graph/test_dependencies.py
Normal file
283
tests/graph/test_dependencies.py
Normal file
@@ -0,0 +1,283 @@
|
|||||||
|
"""Tests for graph/dependencies.py. Written FIRST (TDD RED phase).
|
||||||
|
|
||||||
|
Covers:
|
||||||
|
- ToolClients frozen dataclass
|
||||||
|
- StagingStore Protocol (structural check)
|
||||||
|
- JsonFileStagingStore file I/O operations
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from datetime import date
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from release_agent.graph.dependencies import JsonFileStagingStore, StagingStore, ToolClients
|
||||||
|
from release_agent.models.release import ArchivedRelease, StagingRelease
|
||||||
|
from release_agent.models.ticket import TicketEntry
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _make_ticket(ticket_id: str = "ALLPOST-1") -> TicketEntry:
    """Build a TicketEntry populated with deterministic test data."""
    return TicketEntry(
        id=ticket_id,
        branch=f"feature/{ticket_id}-fix",
        summary="Fix something",
        pr_id="42",
        pr_title="Fix: something",
        pr_url="https://dev.azure.com/org/proj/_git/repo/pullrequest/42",
        merged_at=date(2025, 1, 15),
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _make_staging(repo: str = "my-repo", version: str = "v1.0.0") -> StagingRelease:
    """Build an empty StagingRelease for *repo* at *version*."""
    return StagingRelease(
        repo=repo,
        version=version,
        tickets=[],
        started_at=date(2025, 1, 1),
    )
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# ToolClients tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestToolClients:
    """Contract tests for the frozen ToolClients dataclass."""

    def test_can_be_constructed_with_all_fields(self) -> None:
        mocks = {name: AsyncMock() for name in ("azdo", "jira", "slack", "reviewer")}
        clients = ToolClients(**mocks)
        assert clients.azdo is mocks["azdo"]
        assert clients.jira is mocks["jira"]
        assert clients.slack is mocks["slack"]
        assert clients.reviewer is mocks["reviewer"]

    def test_is_frozen_cannot_reassign_field(self) -> None:
        clients = ToolClients(
            azdo=AsyncMock(), jira=AsyncMock(), slack=AsyncMock(), reviewer=AsyncMock()
        )
        # Frozen dataclasses raise FrozenInstanceError (an AttributeError subclass);
        # accept TypeError too in case the implementation differs.
        with pytest.raises((AttributeError, TypeError)):
            clients.azdo = AsyncMock()  # type: ignore[misc]

    def test_fields_are_accessible_by_name(self) -> None:
        azdo = object()
        clients = ToolClients(azdo=azdo, jira=object(), slack=object(), reviewer=object())
        assert clients.azdo is azdo

    def test_equality_for_same_instances(self) -> None:
        # Dataclass equality is field-wise; identical field objects -> equal.
        fields = dict(
            azdo=AsyncMock(), jira=AsyncMock(), slack=AsyncMock(), reviewer=AsyncMock()
        )
        assert ToolClients(**fields) == ToolClients(**fields)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# StagingStore Protocol structural tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestStagingStoreProtocol:
    """Structural sanity checks on the StagingStore protocol."""

    def test_json_file_store_satisfies_protocol(self, tmp_path: Path) -> None:
        store = JsonFileStagingStore(directory=tmp_path)
        # Duck-type check only — the Protocol is not @runtime_checkable.
        for method in ("load", "save", "archive", "list_versions"):
            assert hasattr(store, method)

    def test_protocol_is_importable(self) -> None:
        # Import-level smoke check.
        assert StagingStore is not None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# JsonFileStagingStore tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestJsonFileStagingStore:
    """File-I/O tests for JsonFileStagingStore, isolated via tmp_path."""

    # -- load ----------------------------------------------------------

    async def test_load_returns_none_when_file_missing(self, tmp_path: Path) -> None:
        json_store = JsonFileStagingStore(directory=tmp_path)
        assert await json_store.load("nonexistent-repo") is None

    async def test_load_returns_staging_release_after_save(self, tmp_path: Path) -> None:
        json_store = JsonFileStagingStore(directory=tmp_path)
        await json_store.save(_make_staging())
        loaded = await json_store.load("my-repo")
        assert loaded is not None
        assert loaded.version == "v1.0.0"
        assert loaded.repo == "my-repo"

    async def test_load_returns_staging_with_tickets(self, tmp_path: Path) -> None:
        json_store = JsonFileStagingStore(directory=tmp_path)
        await json_store.save(_make_staging().add_ticket(_make_ticket("BILL-10")))
        loaded = await json_store.load("my-repo")
        assert loaded is not None
        assert len(loaded.tickets) == 1
        assert loaded.tickets[0].id == "BILL-10"

    async def test_load_is_read_only_does_not_mutate_stored(self, tmp_path: Path) -> None:
        json_store = JsonFileStagingStore(directory=tmp_path)
        await json_store.save(_make_staging())
        first = await json_store.load("my-repo")
        second = await json_store.load("my-repo")
        # Every load must deserialize a fresh object.
        assert first is not second

    # -- save ----------------------------------------------------------

    async def test_save_creates_file_in_directory(self, tmp_path: Path) -> None:
        json_store = JsonFileStagingStore(directory=tmp_path)
        await json_store.save(_make_staging(repo="api-service"))
        assert (tmp_path / "api-service.json").exists()

    async def test_save_overwrites_existing_file(self, tmp_path: Path) -> None:
        json_store = JsonFileStagingStore(directory=tmp_path)
        await json_store.save(_make_staging(version="v1.0.0"))
        await json_store.save(_make_staging(version="v1.0.1"))
        loaded = await json_store.load("my-repo")
        assert loaded is not None
        assert loaded.version == "v1.0.1"

    async def test_save_writes_valid_json(self, tmp_path: Path) -> None:
        json_store = JsonFileStagingStore(directory=tmp_path)
        await json_store.save(_make_staging())
        data = json.loads((tmp_path / "my-repo.json").read_text())
        assert data["version"] == "v1.0.0"
        assert data["repo"] == "my-repo"

    async def test_save_does_not_mutate_staging_release(self, tmp_path: Path) -> None:
        json_store = JsonFileStagingStore(directory=tmp_path)
        release = _make_staging()
        tickets_before = list(release.tickets)
        await json_store.save(release)
        assert list(release.tickets) == tickets_before

    # -- archive -------------------------------------------------------

    async def test_archive_removes_staging_file(self, tmp_path: Path) -> None:
        json_store = JsonFileStagingStore(directory=tmp_path)
        release = _make_staging()
        await json_store.save(release)
        await json_store.archive(release, date(2025, 6, 1))
        assert await json_store.load("my-repo") is None

    async def test_archive_creates_archive_file(self, tmp_path: Path) -> None:
        json_store = JsonFileStagingStore(directory=tmp_path)
        release = _make_staging(repo="my-repo", version="v1.0.0")
        await json_store.save(release)
        await json_store.archive(release, date(2025, 6, 1))
        assert (tmp_path / "my-repo_v1.0.0_2025-06-01.json").exists()

    async def test_archive_file_contains_released_at(self, tmp_path: Path) -> None:
        json_store = JsonFileStagingStore(directory=tmp_path)
        release = _make_staging()
        await json_store.save(release)
        await json_store.archive(release, date(2025, 6, 1))
        data = json.loads((tmp_path / "my-repo_v1.0.0_2025-06-01.json").read_text())
        assert data["released_at"] == "2025-06-01"

    async def test_archive_without_prior_save_creates_archive(self, tmp_path: Path) -> None:
        json_store = JsonFileStagingStore(directory=tmp_path)
        # No save() first — archive must still write the archive file.
        await json_store.archive(_make_staging(), date(2025, 6, 1))
        assert (tmp_path / "my-repo_v1.0.0_2025-06-01.json").exists()

    # -- list_versions -------------------------------------------------

    async def test_list_versions_empty_directory(self, tmp_path: Path) -> None:
        json_store = JsonFileStagingStore(directory=tmp_path)
        assert await json_store.list_versions("my-repo") == []

    async def test_list_versions_returns_version_from_staging_file(self, tmp_path: Path) -> None:
        json_store = JsonFileStagingStore(directory=tmp_path)
        await json_store.save(_make_staging(version="v2.1.0"))
        assert "v2.1.0" in await json_store.list_versions("my-repo")

    async def test_list_versions_includes_archived_versions(self, tmp_path: Path) -> None:
        json_store = JsonFileStagingStore(directory=tmp_path)
        release = _make_staging(version="v1.5.0")
        await json_store.save(release)
        await json_store.archive(release, date(2025, 3, 1))
        # Start a fresh staging cycle for the same repo.
        await json_store.save(_make_staging(version="v1.6.0"))
        versions = await json_store.list_versions("my-repo")
        assert "v1.5.0" in versions
        assert "v1.6.0" in versions

    async def test_list_versions_only_returns_versions_for_given_repo(self, tmp_path: Path) -> None:
        json_store = JsonFileStagingStore(directory=tmp_path)
        await json_store.save(_make_staging(repo="repo-a", version="v1.0.0"))
        await json_store.save(_make_staging(repo="repo-b", version="v2.0.0"))
        versions_a = await json_store.list_versions("repo-a")
        assert "v1.0.0" in versions_a
        # repo-b's version must not leak into repo-a's listing.
        assert "v2.0.0" not in versions_a

    async def test_list_versions_no_duplicates(self, tmp_path: Path) -> None:
        json_store = JsonFileStagingStore(directory=tmp_path)
        await json_store.save(_make_staging(version="v1.0.0"))
        versions = await json_store.list_versions("my-repo")
        assert len(versions) == len(set(versions))

    async def test_list_versions_multiple_archives(self, tmp_path: Path) -> None:
        json_store = JsonFileStagingStore(directory=tmp_path)
        for patch_level in range(3):
            await json_store.archive(
                _make_staging(version=f"v1.0.{patch_level}"),
                date(2025, 1, patch_level + 1),
            )
        versions = await json_store.list_versions("my-repo")
        assert len(versions) == 3
        for patch_level in range(3):
            assert f"v1.0.{patch_level}" in versions

    # -- directory creation --------------------------------------------

    def test_store_creates_directory_if_not_exists(self, tmp_path: Path) -> None:
        target = tmp_path / "staging_data"
        assert not target.exists()
        JsonFileStagingStore(directory=target)  # constructor should mkdir
        assert target.exists()
|
||||||
177
tests/graph/test_dependencies_async.py
Normal file
177
tests/graph/test_dependencies_async.py
Normal file
@@ -0,0 +1,177 @@
|
|||||||
|
"""Tests for async StagingStore protocol and async JsonFileStagingStore.
|
||||||
|
|
||||||
|
Phase 5 - Step 1: All StagingStore methods become async def.
|
||||||
|
Written FIRST (TDD RED phase).
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from datetime import date
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from release_agent.graph.dependencies import JsonFileStagingStore, StagingStore, ToolClients
|
||||||
|
from release_agent.models.release import StagingRelease
|
||||||
|
from release_agent.models.ticket import TicketEntry
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _make_ticket(ticket_id: str = "ALLPOST-1") -> TicketEntry:
    """Build a TicketEntry populated with deterministic test data."""
    return TicketEntry(
        id=ticket_id,
        branch=f"feature/{ticket_id}-fix",
        summary="Fix something",
        pr_id="42",
        pr_title="Fix: something",
        pr_url="https://dev.azure.com/org/proj/_git/repo/pullrequest/42",
        merged_at=date(2025, 1, 15),
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _make_staging(repo: str = "my-repo", version: str = "v1.0.0") -> StagingRelease:
    """Build an empty StagingRelease for *repo* at *version*."""
    return StagingRelease(
        repo=repo,
        version=version,
        tickets=[],
        started_at=date(2025, 1, 1),
    )
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Protocol: all methods must be async
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestStagingStoreProtocolIsAsync:
    """Structural checks that the StagingStore protocol exposes the async API."""

    def test_protocol_has_load_method(self) -> None:
        # load() must be part of the protocol surface.
        assert hasattr(StagingStore, "load")

    def test_protocol_has_save_method(self) -> None:
        # save() must be part of the protocol surface.
        assert hasattr(StagingStore, "save")

    def test_protocol_has_archive_method(self) -> None:
        # archive() must be part of the protocol surface.
        assert hasattr(StagingStore, "archive")

    def test_protocol_has_list_versions_method(self) -> None:
        # list_versions() must be part of the protocol surface.
        assert hasattr(StagingStore, "list_versions")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# JsonFileStagingStore async interface
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestJsonFileStagingStoreAsync:
|
||||||
|
"""Verify that JsonFileStagingStore methods are awaitable (async def)."""
|
||||||
|
|
||||||
|
async def test_load_is_awaitable(self, tmp_path: Path) -> None:
|
||||||
|
store = JsonFileStagingStore(directory=tmp_path)
|
||||||
|
result = await store.load("nonexistent-repo")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
async def test_load_returns_staging_release_after_save(self, tmp_path: Path) -> None:
|
||||||
|
store = JsonFileStagingStore(directory=tmp_path)
|
||||||
|
staging = _make_staging()
|
||||||
|
await store.save(staging)
|
||||||
|
loaded = await store.load("my-repo")
|
||||||
|
assert loaded is not None
|
||||||
|
assert loaded.version == "v1.0.0"
|
||||||
|
assert loaded.repo == "my-repo"
|
||||||
|
|
||||||
|
async def test_load_returns_staging_with_tickets(self, tmp_path: Path) -> None:
|
||||||
|
store = JsonFileStagingStore(directory=tmp_path)
|
||||||
|
staging = _make_staging().add_ticket(_make_ticket("BILL-10"))
|
||||||
|
await store.save(staging)
|
||||||
|
loaded = await store.load("my-repo")
|
||||||
|
assert loaded is not None
|
||||||
|
assert len(loaded.tickets) == 1
|
||||||
|
assert loaded.tickets[0].id == "BILL-10"
|
||||||
|
|
||||||
|
async def test_load_returns_fresh_objects(self, tmp_path: Path) -> None:
|
||||||
|
store = JsonFileStagingStore(directory=tmp_path)
|
||||||
|
staging = _make_staging()
|
||||||
|
await store.save(staging)
|
||||||
|
loaded1 = await store.load("my-repo")
|
||||||
|
loaded2 = await store.load("my-repo")
|
||||||
|
assert loaded1 is not loaded2
|
||||||
|
|
||||||
|
async def test_save_is_awaitable(self, tmp_path: Path) -> None:
|
||||||
|
store = JsonFileStagingStore(directory=tmp_path)
|
||||||
|
staging = _make_staging(repo="api-service")
|
||||||
|
await store.save(staging)
|
||||||
|
expected_path = tmp_path / "api-service.json"
|
||||||
|
assert expected_path.exists()
|
||||||
|
|
||||||
|
async def test_save_overwrites_existing_file(self, tmp_path: Path) -> None:
|
||||||
|
store = JsonFileStagingStore(directory=tmp_path)
|
||||||
|
await store.save(_make_staging(version="v1.0.0"))
|
||||||
|
await store.save(_make_staging(version="v1.0.1"))
|
||||||
|
loaded = await store.load("my-repo")
|
||||||
|
assert loaded is not None
|
||||||
|
assert loaded.version == "v1.0.1"
|
||||||
|
|
||||||
|
async def test_save_writes_valid_json(self, tmp_path: Path) -> None:
|
||||||
|
store = JsonFileStagingStore(directory=tmp_path)
|
||||||
|
staging = _make_staging()
|
||||||
|
await store.save(staging)
|
||||||
|
raw = (tmp_path / "my-repo.json").read_text()
|
||||||
|
data = json.loads(raw)
|
||||||
|
assert data["version"] == "v1.0.0"
|
||||||
|
assert data["repo"] == "my-repo"
|
||||||
|
|
||||||
|
async def test_archive_is_awaitable(self, tmp_path: Path) -> None:
|
||||||
|
store = JsonFileStagingStore(directory=tmp_path)
|
||||||
|
staging = _make_staging()
|
||||||
|
await store.save(staging)
|
||||||
|
await store.archive(staging, date(2025, 6, 1))
|
||||||
|
assert await store.load("my-repo") is None
|
||||||
|
|
||||||
|
async def test_archive_creates_archive_file(self, tmp_path: Path) -> None:
|
||||||
|
store = JsonFileStagingStore(directory=tmp_path)
|
||||||
|
staging = _make_staging(repo="my-repo", version="v1.0.0")
|
||||||
|
await store.save(staging)
|
||||||
|
await store.archive(staging, date(2025, 6, 1))
|
||||||
|
archive_path = tmp_path / "my-repo_v1.0.0_2025-06-01.json"
|
||||||
|
assert archive_path.exists()
|
||||||
|
|
||||||
|
async def test_archive_file_contains_released_at(self, tmp_path: Path) -> None:
    """The archived JSON records the release date under released_at."""
    file_store = JsonFileStagingStore(directory=tmp_path)
    release = _make_staging()
    await file_store.save(release)

    await file_store.archive(release, date(2025, 6, 1))

    payload = json.loads((tmp_path / "my-repo_v1.0.0_2025-06-01.json").read_text())
    assert payload["released_at"] == "2025-06-01"
|
||||||
|
|
||||||
|
async def test_list_versions_is_awaitable(self, tmp_path: Path) -> None:
    """list_versions() is a coroutine; an empty store yields []."""
    file_store = JsonFileStagingStore(directory=tmp_path)

    assert await file_store.list_versions("my-repo") == []
|
||||||
|
|
||||||
|
async def test_list_versions_returns_staging_version(self, tmp_path: Path) -> None:
    """The currently staged version appears in list_versions()."""
    file_store = JsonFileStagingStore(directory=tmp_path)
    await file_store.save(_make_staging(version="v2.1.0"))

    assert "v2.1.0" in await file_store.list_versions("my-repo")
|
||||||
|
|
||||||
|
async def test_list_versions_includes_archived(self, tmp_path: Path) -> None:
    """Archived versions are reported alongside the active staging one."""
    file_store = JsonFileStagingStore(directory=tmp_path)
    older = _make_staging(version="v1.5.0")
    await file_store.save(older)
    await file_store.archive(older, date(2025, 3, 1))
    await file_store.save(_make_staging(version="v1.6.0"))

    found = await file_store.list_versions("my-repo")

    assert "v1.5.0" in found
    assert "v1.6.0" in found
|
||||||
|
|
||||||
|
async def test_list_versions_only_for_given_repo(self, tmp_path: Path) -> None:
    """Versions are scoped to the requested repo, not all repos."""
    file_store = JsonFileStagingStore(directory=tmp_path)
    await file_store.save(_make_staging(repo="repo-a", version="v1.0.0"))
    await file_store.save(_make_staging(repo="repo-b", version="v2.0.0"))

    found = await file_store.list_versions("repo-a")

    assert "v1.0.0" in found
    assert "v2.0.0" not in found
|
||||||
53
tests/graph/test_full_cycle.py
Normal file
53
tests/graph/test_full_cycle.py
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
"""Tests for graph/full_cycle.py.
|
||||||
|
|
||||||
|
Tests that the full cycle graph composes pr_completed and release subgraphs
|
||||||
|
correctly, and that the routing conditional edge works as expected.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from release_agent.graph.full_cycle import build_full_cycle_graph
|
||||||
|
from release_agent.graph.routing import should_continue_to_release
|
||||||
|
|
||||||
|
|
||||||
|
class TestBuildFullCycleGraph:
    """build_full_cycle_graph() returns a usable compiled graph object."""

    def test_returns_compiled_graph(self) -> None:
        """Building the graph yields a non-None compiled object."""
        assert build_full_cycle_graph() is not None

    def test_graph_can_be_built_multiple_times(self) -> None:
        """Repeated builds are independent and must not fail."""
        first = build_full_cycle_graph()
        second = build_full_cycle_graph()
        assert first is not None
        assert second is not None

    def test_graph_has_get_graph_method(self) -> None:
        """The compiled object exposes an introspection hook."""
        compiled = build_full_cycle_graph()
        assert hasattr(compiled, "get_graph") or hasattr(compiled, "nodes")
|
||||||
|
|
||||||
|
|
||||||
|
class TestFullCycleRouting:
    """should_continue_to_release() gates entry into the release subgraph:
    it returns "yes" only when the flag is set and no errors accumulated."""

    def test_continue_when_flag_true_and_no_errors(self) -> None:
        clean_state = {"continue_to_release": True, "errors": []}
        assert should_continue_to_release(clean_state) == "yes"

    def test_stop_when_flag_false(self) -> None:
        assert should_continue_to_release({"continue_to_release": False}) == "no"

    def test_stop_when_flag_missing(self) -> None:
        # An absent flag must default to not releasing.
        assert should_continue_to_release({}) == "no"

    def test_stop_when_errors_present(self) -> None:
        errored = {"continue_to_release": True, "errors": ["some error"]}
        assert should_continue_to_release(errored) == "no"

    def test_stop_when_flag_true_but_errors_present(self) -> None:
        # Errors override an affirmative flag.
        errored = {"continue_to_release": True, "errors": ["critical failure"]}
        assert should_continue_to_release(errored) == "no"

    def test_continue_when_errors_empty_list(self) -> None:
        clean_state = {"continue_to_release": True, "errors": []}
        assert should_continue_to_release(clean_state) == "yes"
|
||||||
356
tests/graph/test_polling.py
Normal file
356
tests/graph/test_polling.py
Normal file
@@ -0,0 +1,356 @@
|
|||||||
|
"""Tests for graph/polling.py — poll_until async utility.
|
||||||
|
|
||||||
|
Written FIRST (TDD RED phase).
|
||||||
|
|
||||||
|
All tests inject a fake_sleep_fn that returns immediately to avoid real waits.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from unittest.mock import AsyncMock, call
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from release_agent.graph.polling import poll_until
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def _immediate_sleep(seconds: float) -> None:
|
||||||
|
"""Drop-in replacement for asyncio.sleep that returns immediately."""
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Success path tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestPollUntilSuccess:
    """Tests for the happy path where poll_fn succeeds before timeout."""

    async def test_returns_tuple_of_result_and_completed_true(self) -> None:
        """poll_until returns (last_result, True) once is_done matches."""
        calls = iter(["running", "running", "completed"])

        async def poll_fn():
            return next(calls)

        result, completed = await poll_until(
            poll_fn=poll_fn,
            is_done=lambda r: r == "completed",
            interval_seconds=1,
            max_wait_seconds=60,
            sleep_fn=_immediate_sleep,
        )

        assert result == "completed"
        assert completed is True

    async def test_returns_immediately_when_already_done(self) -> None:
        """A first poll that is already done short-circuits the loop."""
        async def poll_fn():
            return "completed"

        result, completed = await poll_until(
            poll_fn=poll_fn,
            is_done=lambda r: r == "completed",
            interval_seconds=1,
            max_wait_seconds=60,
            sleep_fn=_immediate_sleep,
        )

        assert result == "completed"
        assert completed is True

    async def test_polls_multiple_times_before_done(self) -> None:
        """poll_fn is invoked repeatedly until is_done is satisfied."""
        call_count = 0

        async def poll_fn():
            nonlocal call_count
            call_count += 1
            return "done" if call_count >= 3 else "pending"

        result, completed = await poll_until(
            poll_fn=poll_fn,
            is_done=lambda r: r == "done",
            interval_seconds=1,
            max_wait_seconds=60,
            sleep_fn=_immediate_sleep,
        )

        assert result == "done"
        assert completed is True
        # Exactly three polls: two "pending" then the terminal "done".
        assert call_count == 3

    async def test_sleep_called_between_polls(self) -> None:
        """sleep_fn receives interval_seconds between consecutive polls."""
        call_count = 0
        sleep_calls: list[float] = []

        async def poll_fn():
            nonlocal call_count
            call_count += 1
            return "done" if call_count >= 2 else "pending"

        async def tracking_sleep(seconds: float) -> None:
            sleep_calls.append(seconds)

        await poll_until(
            poll_fn=poll_fn,
            is_done=lambda r: r == "done",
            interval_seconds=15,
            max_wait_seconds=60,
            sleep_fn=tracking_sleep,
        )

        assert len(sleep_calls) >= 1
        # Every recorded sleep must use the configured interval.
        assert all(s == 15 for s in sleep_calls)

    async def test_no_sleep_on_first_successful_poll(self) -> None:
        """No sleep happens if the very first poll is already done."""
        sleep_calls: list[float] = []

        async def poll_fn():
            return "done"

        async def tracking_sleep(seconds: float) -> None:
            sleep_calls.append(seconds)

        await poll_until(
            poll_fn=poll_fn,
            is_done=lambda r: r == "done",
            interval_seconds=10,
            max_wait_seconds=60,
            sleep_fn=tracking_sleep,
        )

        assert sleep_calls == []

    async def test_works_with_dict_results(self) -> None:
        """Results need not be strings; dict payloads work with is_done."""
        responses = iter([
            {"status": "inProgress"},
            {"status": "completed", "result": "succeeded"},
        ])

        async def poll_fn():
            return next(responses)

        result, completed = await poll_until(
            poll_fn=poll_fn,
            is_done=lambda r: r["status"] == "completed",
            interval_seconds=1,
            max_wait_seconds=60,
            sleep_fn=_immediate_sleep,
        )

        assert result["result"] == "succeeded"
        assert completed is True
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Timeout tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestPollUntilTimeout:
    """Behavior when max_wait_seconds elapses before is_done ever matches."""

    async def test_returns_last_result_and_completed_false_on_timeout(self) -> None:
        """On timeout the caller gets the most recent result plus False."""
        async def never_done():
            return "still_running"

        # interval (10s) exceeds max_wait (5s): one poll, then timeout.
        outcome, finished = await poll_until(
            poll_fn=never_done,
            is_done=lambda r: r == "done",
            interval_seconds=10,
            max_wait_seconds=5,
            sleep_fn=_immediate_sleep,
        )

        assert outcome == "still_running"
        assert finished is False

    async def test_at_least_one_poll_happens_before_timeout(self) -> None:
        """Even a tiny max_wait still permits one initial poll."""
        polls = 0

        async def counting_poll():
            nonlocal polls
            polls += 1
            return "running"

        await poll_until(
            poll_fn=counting_poll,
            is_done=lambda r: r == "done",
            interval_seconds=100,
            max_wait_seconds=1,
            sleep_fn=_immediate_sleep,
        )

        assert polls >= 1

    async def test_max_polls_bounded_by_max_wait_over_interval(self) -> None:
        """Total polls stay bounded by max_wait / interval (plus slack)."""
        polls = 0

        async def counting_poll():
            nonlocal polls
            polls += 1
            return "running"

        await poll_until(
            poll_fn=counting_poll,
            is_done=lambda r: False,
            interval_seconds=10,
            max_wait_seconds=30,
            sleep_fn=_immediate_sleep,
        )

        # interval=10, max_wait=30 -> at most ceil(30/10)+1 = 4 polls expected.
        assert polls <= 5
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Error handling tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestPollUntilErrorHandling:
    """Tests for error/exception handling in poll_until."""

    async def test_continues_after_transient_exception(self) -> None:
        """Exceptions from poll_fn are tolerated and polling retries."""
        call_count = 0

        async def poll_fn():
            nonlocal call_count
            call_count += 1
            # First two polls raise; the third succeeds.
            if call_count < 3:
                raise RuntimeError("Transient error")
            return "done"

        result, completed = await poll_until(
            poll_fn=poll_fn,
            is_done=lambda r: r == "done",
            interval_seconds=1,
            max_wait_seconds=60,
            sleep_fn=_immediate_sleep,
        )

        assert result == "done"
        assert completed is True

    async def test_aborts_after_three_consecutive_failures(self) -> None:
        """Three consecutive poll_fn failures abort with (None, False)."""
        call_count = 0

        async def poll_fn():
            nonlocal call_count
            call_count += 1
            raise RuntimeError("Persistent error")

        result, completed = await poll_until(
            poll_fn=poll_fn,
            is_done=lambda r: True,
            interval_seconds=1,
            max_wait_seconds=60,
            sleep_fn=_immediate_sleep,
        )

        # Should abort after 3 consecutive failures
        assert call_count == 3
        assert completed is False
        assert result is None

    async def test_resets_consecutive_failure_count_on_success(self) -> None:
        """A successful poll resets the consecutive-failure counter."""
        call_count = 0

        async def poll_fn():
            nonlocal call_count
            call_count += 1
            # Fail twice, succeed once, fail twice, succeed (done)
            if call_count in (1, 2):
                raise RuntimeError("fail")
            if call_count == 3:
                return "running"
            if call_count in (4, 5):
                raise RuntimeError("fail again")
            return "done"

        result, completed = await poll_until(
            poll_fn=poll_fn,
            is_done=lambda r: r == "done",
            interval_seconds=1,
            max_wait_seconds=120,
            sleep_fn=_immediate_sleep,
        )

        assert completed is True
        assert result == "done"

    async def test_single_exception_does_not_abort(self) -> None:
        """One isolated exception is below the abort threshold."""
        call_count = 0

        async def poll_fn():
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                raise ValueError("one error")
            return "done"

        result, completed = await poll_until(
            poll_fn=poll_fn,
            is_done=lambda r: r == "done",
            interval_seconds=1,
            max_wait_seconds=60,
            sleep_fn=_immediate_sleep,
        )

        assert completed is True
        assert result == "done"

    async def test_two_consecutive_failures_do_not_abort(self) -> None:
        """Two consecutive failures are still below the abort threshold."""
        call_count = 0

        async def poll_fn():
            nonlocal call_count
            call_count += 1
            if call_count <= 2:
                raise ConnectionError("two errors")
            return "done"

        result, completed = await poll_until(
            poll_fn=poll_fn,
            is_done=lambda r: r == "done",
            interval_seconds=1,
            max_wait_seconds=60,
            sleep_fn=_immediate_sleep,
        )

        assert completed is True
        assert result == "done"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Default parameter tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestPollUntilDefaults:
    """Defaults and call-shape requirements of poll_until's signature."""

    async def test_default_interval_is_30_seconds(self) -> None:
        """When interval_seconds is omitted, each sleep is 30 seconds."""
        recorded: list[float] = []

        async def poll_once_then_done():
            # Succeeds right after the first sleep has been recorded.
            return "done" if len(recorded) >= 1 else "running"

        async def recording_sleep(seconds: float) -> None:
            recorded.append(seconds)

        await poll_until(
            poll_fn=poll_once_then_done,
            is_done=lambda r: r == "done",
            sleep_fn=recording_sleep,
        )

        if recorded:
            assert recorded[0] == 30

    async def test_poll_fn_and_is_done_are_keyword_only(self) -> None:
        """poll_fn and is_done must be passed as keyword arguments."""
        async def trivial_poll():
            return "done"

        with pytest.raises(TypeError):
            await poll_until(trivial_poll, lambda r: r == "done")  # type: ignore[call-arg]
|
||||||
414
tests/graph/test_postgres_staging_store.py
Normal file
414
tests/graph/test_postgres_staging_store.py
Normal file
@@ -0,0 +1,414 @@
|
|||||||
|
"""Tests for PostgresStagingStore.
|
||||||
|
|
||||||
|
Phase 5 - Step 2: PostgreSQL-backed StagingStore using async pool.
|
||||||
|
Written FIRST (TDD RED phase).
|
||||||
|
|
||||||
|
All tests use FakeAsyncPool — no real PostgreSQL required.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from datetime import date
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from release_agent.graph.postgres_staging_store import PostgresStagingStore
|
||||||
|
from release_agent.models.release import ArchivedRelease, StagingRelease
|
||||||
|
from release_agent.models.ticket import TicketEntry
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Fake pool infrastructure
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class FakeAsyncCursor:
|
||||||
|
"""Records SQL calls and returns configured results."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.executed: list[tuple[str, tuple]] = []
|
||||||
|
self._fetchone_result: tuple | None = None
|
||||||
|
self._fetchall_result: list[tuple] = []
|
||||||
|
|
||||||
|
def set_fetchone(self, row: tuple | None) -> None:
|
||||||
|
self._fetchone_result = row
|
||||||
|
|
||||||
|
def set_fetchall(self, rows: list[tuple]) -> None:
|
||||||
|
self._fetchall_result = rows
|
||||||
|
|
||||||
|
async def execute(self, sql: str, params: tuple = ()) -> None:
|
||||||
|
self.executed.append((sql, params))
|
||||||
|
|
||||||
|
async def fetchone(self) -> tuple | None:
|
||||||
|
return self._fetchone_result
|
||||||
|
|
||||||
|
async def fetchall(self) -> list[tuple]:
|
||||||
|
return self._fetchall_result
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, *args):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class FakeAsyncTransaction:
    """No-op stand-in for an async database transaction context manager."""

    async def __aenter__(self):
        return self

    async def __aexit__(self, *args):
        return None
|
||||||
|
|
||||||
|
|
||||||
|
class FakeAsyncConnection:
    """Async context manager that hands out one pre-built fake cursor."""

    def __init__(self, cursor: FakeAsyncCursor) -> None:
        self._cursor = cursor

    def cursor(self):
        """Return the shared fake cursor (same object on every call)."""
        return self._cursor

    def transaction(self):
        """Return a fresh no-op transaction context manager."""
        return FakeAsyncTransaction()

    async def __aenter__(self):
        return self

    async def __aexit__(self, *args):
        return None
|
||||||
|
|
||||||
|
|
||||||
|
class FakeAsyncPool:
    """Fake connection pool; all SQL routes through one recording cursor."""

    def __init__(self, cursor: FakeAsyncCursor) -> None:
        self._cursor = cursor
        # One shared connection wrapping the same cursor.
        self._conn = FakeAsyncConnection(cursor)

    def connection(self):
        """Return the pool's single fake connection."""
        return self._conn
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _make_ticket(ticket_id: str = "ALLPOST-1") -> TicketEntry:
    """Build a minimal TicketEntry for tests, keyed off *ticket_id*."""
    return TicketEntry(
        id=ticket_id,
        branch=f"feature/{ticket_id}-fix",
        summary="Fix something",
        merged_at=date(2025, 1, 15),
        pr_id="42",
        pr_title="Fix: something",
        pr_url="https://dev.azure.com/org/proj/_git/repo/pullrequest/42",
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _make_staging(
    repo: str = "my-repo",
    version: str = "v1.0.0",
    tickets: list | None = None,
) -> StagingRelease:
    """Build a StagingRelease started 2025-01-01; tickets default to []."""
    return StagingRelease(
        repo=repo,
        version=version,
        started_at=date(2025, 1, 1),
        tickets=[] if tickets is None else tickets,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _staging_row(staging: StagingRelease) -> tuple:
    """Return (repo, version, started_at, tickets_json) as DB would store it."""
    tickets_json = json.dumps(
        [entry.model_dump(mode="json") for entry in staging.tickets]
    )
    return (
        staging.repo,
        staging.version,
        staging.started_at.isoformat(),
        tickets_json,
    )
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# load()
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestPostgresStagingStoreLoad:
    """load(): maps a staging_releases row to StagingRelease, or None."""

    async def test_load_returns_none_when_no_row(self) -> None:
        """Missing repo row yields None, not an exception."""
        cursor = FakeAsyncCursor()
        cursor.set_fetchone(None)
        pool = FakeAsyncPool(cursor)
        store = PostgresStagingStore(pool=pool)

        result = await store.load("nonexistent-repo")

        assert result is None

    async def test_load_returns_staging_release_when_row_exists(self) -> None:
        """A (repo, version, started_at, tickets_json) row round-trips."""
        staging = _make_staging(repo="api-service", version="v2.0.0")
        cursor = FakeAsyncCursor()
        cursor.set_fetchone((
            staging.repo,
            staging.version,
            staging.started_at.isoformat(),
            json.dumps([]),
        ))
        pool = FakeAsyncPool(cursor)
        store = PostgresStagingStore(pool=pool)

        result = await store.load("api-service")

        assert result is not None
        assert isinstance(result, StagingRelease)
        assert result.repo == "api-service"
        assert result.version == "v2.0.0"

    async def test_load_returns_staging_with_tickets(self) -> None:
        """The tickets JSON column is deserialized into TicketEntry objects."""
        ticket = _make_ticket("BILL-42")
        staging = _make_staging(tickets=[ticket])
        cursor = FakeAsyncCursor()
        cursor.set_fetchone((
            staging.repo,
            staging.version,
            staging.started_at.isoformat(),
            json.dumps([ticket.model_dump(mode="json")]),
        ))
        pool = FakeAsyncPool(cursor)
        store = PostgresStagingStore(pool=pool)

        result = await store.load("my-repo")

        assert result is not None
        assert len(result.tickets) == 1
        assert result.tickets[0].id == "BILL-42"

    async def test_load_executes_select_with_correct_repo(self) -> None:
        """load() issues a SELECT parameterized with the repo name."""
        cursor = FakeAsyncCursor()
        cursor.set_fetchone(None)
        pool = FakeAsyncPool(cursor)
        store = PostgresStagingStore(pool=pool)

        await store.load("target-repo")

        assert len(cursor.executed) >= 1
        sql, params = cursor.executed[-1]
        assert "SELECT" in sql.upper()
        assert "target-repo" in params

    async def test_load_queries_staging_releases_table(self) -> None:
        """load() reads from the staging_releases table."""
        cursor = FakeAsyncCursor()
        cursor.set_fetchone(None)
        pool = FakeAsyncPool(cursor)
        store = PostgresStagingStore(pool=pool)

        await store.load("my-repo")

        sql, _ = cursor.executed[-1]
        assert "staging_releases" in sql
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# save()
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestPostgresStagingStoreSave:
    """save(): upserts the staging release into staging_releases."""

    async def test_save_executes_upsert(self) -> None:
        """save() issues an INSERT ... ON CONFLICT / UPSERT statement."""
        cursor = FakeAsyncCursor()
        pool = FakeAsyncPool(cursor)
        store = PostgresStagingStore(pool=pool)
        staging = _make_staging()

        await store.save(staging)

        assert len(cursor.executed) >= 1
        sql, _ = cursor.executed[-1]
        # Should be an INSERT ... ON CONFLICT ... or UPSERT
        assert "INSERT" in sql.upper() or "UPSERT" in sql.upper()

    async def test_save_passes_repo_to_upsert(self) -> None:
        """The repo name is bound as a parameter, not interpolated."""
        cursor = FakeAsyncCursor()
        pool = FakeAsyncPool(cursor)
        store = PostgresStagingStore(pool=pool)
        staging = _make_staging(repo="payment-service")

        await store.save(staging)

        _, params = cursor.executed[-1]
        assert "payment-service" in params

    async def test_save_passes_version_to_upsert(self) -> None:
        """The version string is bound as a parameter."""
        cursor = FakeAsyncCursor()
        pool = FakeAsyncPool(cursor)
        store = PostgresStagingStore(pool=pool)
        staging = _make_staging(version="v3.1.0")

        await store.save(staging)

        _, params = cursor.executed[-1]
        assert "v3.1.0" in params

    async def test_save_targets_staging_releases_table(self) -> None:
        """save() writes into the staging_releases table."""
        cursor = FakeAsyncCursor()
        pool = FakeAsyncPool(cursor)
        store = PostgresStagingStore(pool=pool)
        staging = _make_staging()

        await store.save(staging)

        sql, _ = cursor.executed[-1]
        assert "staging_releases" in sql

    async def test_save_serializes_tickets_as_json(self) -> None:
        """Tickets are stored as a JSON-encoded parameter."""
        cursor = FakeAsyncCursor()
        pool = FakeAsyncPool(cursor)
        store = PostgresStagingStore(pool=pool)
        staging = _make_staging(tickets=[_make_ticket("ALLPOST-99")])

        await store.save(staging)

        _, params = cursor.executed[-1]
        # tickets param should be a JSON string containing the ticket id
        tickets_json = next(p for p in params if isinstance(p, str) and "ALLPOST-99" in p)
        parsed = json.loads(tickets_json)
        assert parsed[0]["id"] == "ALLPOST-99"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# archive()
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestPostgresStagingStoreArchive:
    """archive(): copies into archived_releases and clears staging."""

    async def test_archive_inserts_into_archived_releases(self) -> None:
        """archive() writes a row to the archived_releases table."""
        cursor = FakeAsyncCursor()
        pool = FakeAsyncPool(cursor)
        store = PostgresStagingStore(pool=pool)
        staging = _make_staging()

        await store.archive(staging, date(2025, 6, 1))

        sql_statements = [sql for sql, _ in cursor.executed]
        assert any("archived_releases" in sql for sql in sql_statements)

    async def test_archive_deletes_from_staging_releases(self) -> None:
        """archive() removes the active row from staging_releases."""
        cursor = FakeAsyncCursor()
        pool = FakeAsyncPool(cursor)
        store = PostgresStagingStore(pool=pool)
        staging = _make_staging(repo="my-repo")

        await store.archive(staging, date(2025, 6, 1))

        sql_statements = [sql for sql, _ in cursor.executed]
        assert any("DELETE" in sql.upper() and "staging_releases" in sql for sql in sql_statements)

    async def test_archive_passes_released_at_date(self) -> None:
        """The released_at date is bound as an ISO-format parameter."""
        cursor = FakeAsyncCursor()
        pool = FakeAsyncPool(cursor)
        store = PostgresStagingStore(pool=pool)
        staging = _make_staging()
        release_date = date(2025, 12, 31)

        await store.archive(staging, release_date)

        all_params = [params for _, params in cursor.executed]
        all_values = [v for params in all_params for v in params]
        assert "2025-12-31" in all_values or release_date.isoformat() in all_values

    async def test_archive_passes_repo_to_delete(self) -> None:
        """The DELETE is scoped to the archived release's repo."""
        cursor = FakeAsyncCursor()
        pool = FakeAsyncPool(cursor)
        store = PostgresStagingStore(pool=pool)
        staging = _make_staging(repo="payment-service")

        await store.archive(staging, date(2025, 6, 1))

        delete_calls = [(sql, params) for sql, params in cursor.executed if "DELETE" in sql.upper()]
        assert len(delete_calls) >= 1
        _, params = delete_calls[0]
        assert "payment-service" in params
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# list_versions()
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestPostgresStagingStoreListVersions:
    """list_versions(): union of active staging and archived versions."""

    async def test_list_versions_returns_empty_when_no_data(self) -> None:
        """No staging row and no archived rows yields an empty list."""
        cursor = FakeAsyncCursor()
        cursor.set_fetchone(None)
        cursor.set_fetchall([])
        pool = FakeAsyncPool(cursor)
        store = PostgresStagingStore(pool=pool)

        versions = await store.list_versions("my-repo")

        assert versions == []

    async def test_list_versions_includes_staging_version(self) -> None:
        """The active staging version appears in the result."""
        cursor = FakeAsyncCursor()
        # fetchone returns staging row
        cursor.set_fetchone(("my-repo", "v1.0.0", "2025-01-01", "[]"))
        # fetchall returns archived rows
        cursor.set_fetchall([])
        pool = FakeAsyncPool(cursor)
        store = PostgresStagingStore(pool=pool)

        versions = await store.list_versions("my-repo")

        assert "v1.0.0" in versions

    async def test_list_versions_includes_archived_versions(self) -> None:
        """All archived versions for the repo appear in the result."""
        cursor = FakeAsyncCursor()
        cursor.set_fetchone(None)
        cursor.set_fetchall([
            ("my-repo", "v1.0.0", "2025-01-01", "[]", "2025-06-01"),
            ("my-repo", "v1.1.0", "2025-02-01", "[]", "2025-07-01"),
        ])
        pool = FakeAsyncPool(cursor)
        store = PostgresStagingStore(pool=pool)

        versions = await store.list_versions("my-repo")

        assert "v1.0.0" in versions
        assert "v1.1.0" in versions

    async def test_list_versions_combines_staging_and_archived(self) -> None:
        """Staging and archived versions are merged into one list."""
        cursor = FakeAsyncCursor()
        cursor.set_fetchone(("my-repo", "v2.0.0", "2025-03-01", "[]"))
        cursor.set_fetchall([
            ("my-repo", "v1.0.0", "2025-01-01", "[]", "2025-02-01"),
        ])
        pool = FakeAsyncPool(cursor)
        store = PostgresStagingStore(pool=pool)

        versions = await store.list_versions("my-repo")

        assert "v2.0.0" in versions
        assert "v1.0.0" in versions

    async def test_list_versions_no_duplicates(self) -> None:
        """A version present both staged and archived appears only once."""
        cursor = FakeAsyncCursor()
        cursor.set_fetchone(("my-repo", "v1.0.0", "2025-01-01", "[]"))
        cursor.set_fetchall([
            ("my-repo", "v1.0.0", "2025-01-01", "[]", "2025-02-01"),
        ])
        pool = FakeAsyncPool(cursor)
        store = PostgresStagingStore(pool=pool)

        versions = await store.list_versions("my-repo")

        assert len(versions) == len(set(versions))

    async def test_list_versions_executes_queries_for_correct_repo(self) -> None:
        """Both underlying queries are parameterized with the repo name."""
        cursor = FakeAsyncCursor()
        cursor.set_fetchone(None)
        cursor.set_fetchall([])
        pool = FakeAsyncPool(cursor)
        store = PostgresStagingStore(pool=pool)

        await store.list_versions("target-repo")

        all_params = [params for _, params in cursor.executed]
        all_values = [v for params in all_params for v in params]
        assert "target-repo" in all_values
|
||||||
956
tests/graph/test_pr_completed.py
Normal file
956
tests/graph/test_pr_completed.py
Normal file
@@ -0,0 +1,956 @@
|
|||||||
|
"""Tests for graph/pr_completed.py node functions. Written FIRST (TDD RED phase).
|
||||||
|
|
||||||
|
Each node is an async function (state, config) -> dict.
|
||||||
|
Tests call nodes directly with a state dict and config dict — no graph compilation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from datetime import date, datetime, timezone
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from release_agent.graph.dependencies import JsonFileStagingStore, ToolClients
|
||||||
|
from release_agent.graph.pr_completed import (
|
||||||
|
_post_review_to_pr,
|
||||||
|
add_jira_pr_link,
|
||||||
|
auto_create_ticket,
|
||||||
|
calculate_version,
|
||||||
|
evaluate_review,
|
||||||
|
fetch_pr_details,
|
||||||
|
interrupt_confirm_merge,
|
||||||
|
merge_pr_node,
|
||||||
|
move_jira_code_review,
|
||||||
|
move_jira_ready_for_stage,
|
||||||
|
notify_request_changes,
|
||||||
|
parse_webhook,
|
||||||
|
run_code_review,
|
||||||
|
update_staging,
|
||||||
|
build_pr_completed_graph,
|
||||||
|
)
|
||||||
|
from release_agent.models.review import ReviewIssue
|
||||||
|
from release_agent.models.jira import JiraIssue
|
||||||
|
from release_agent.models.pr import PRInfo
|
||||||
|
from release_agent.models.review import ReviewResult
|
||||||
|
from tests.graph.conftest import build_config, build_mock_clients
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Webhook payload fixtures
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _make_webhook_payload(
|
||||||
|
*,
|
||||||
|
repo_name: str = "my-repo",
|
||||||
|
pr_id: int = 42,
|
||||||
|
source_ref: str = "refs/heads/feature/ALLPOST-100_fix-bug",
|
||||||
|
target_ref: str = "refs/heads/main",
|
||||||
|
status: str = "completed",
|
||||||
|
title: str = "Fix: bug",
|
||||||
|
closed_date: str | None = "2025-01-15T10:00:00Z",
|
||||||
|
) -> dict:
|
||||||
|
# Uses snake_case keys to match WebhookPayload Pydantic model field names
|
||||||
|
return {
|
||||||
|
"subscription_id": "sub-1",
|
||||||
|
"event_type": "git.pullrequest.merged",
|
||||||
|
"resource": {
|
||||||
|
"repository": {
|
||||||
|
"id": "repo-id-1",
|
||||||
|
"name": repo_name,
|
||||||
|
"web_url": "https://dev.azure.com/org/proj/_git/my-repo",
|
||||||
|
},
|
||||||
|
"pull_request_id": pr_id,
|
||||||
|
"title": title,
|
||||||
|
"source_ref_name": source_ref,
|
||||||
|
"target_ref_name": target_ref,
|
||||||
|
"status": status,
|
||||||
|
"closed_date": closed_date,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _make_pr_info(
    *,
    pr_id: str = "42",
    repo_name: str = "my-repo",
    branch: str = "refs/heads/feature/ALLPOST-100-fix-bug",
    status: str = "completed",
) -> PRInfo:
    """Build a PRInfo model with test defaults; url and title are fixed."""
    fields = {
        "pr_id": pr_id,
        "pr_url": "https://dev.azure.com/org/proj/_git/my-repo/pullrequest/42",
        "repo_name": repo_name,
        "branch": branch,
        "pr_title": "Fix: bug",
        "pr_status": status,
    }
    return PRInfo(**fields)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_approve_review() -> dict:
|
||||||
|
return {
|
||||||
|
"verdict": "approve",
|
||||||
|
"summary": "Looks good",
|
||||||
|
"issues": [],
|
||||||
|
"has_blockers": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _make_request_changes_review() -> dict:
|
||||||
|
return {
|
||||||
|
"verdict": "request_changes",
|
||||||
|
"summary": "Needs work",
|
||||||
|
"issues": [{"severity": "blocker", "description": "Missing tests"}],
|
||||||
|
"has_blockers": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# parse_webhook
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestParseWebhook:
    """parse_webhook: turns a raw webhook payload into pr_info/ticket fields."""

    async def test_extracts_pr_info_from_payload(self) -> None:
        update = await parse_webhook(
            {"webhook_payload": _make_webhook_payload()}, build_config()
        )
        assert "pr_info" in update
        info = update["pr_info"]
        assert info["pr_id"] == "42"
        assert info["repo_name"] == "my-repo"

    async def test_extracts_ticket_from_branch(self) -> None:
        payload = _make_webhook_payload(
            source_ref="refs/heads/feature/ALLPOST-100_fix-bug"
        )
        update = await parse_webhook({"webhook_payload": payload}, build_config())
        assert update["ticket_id"] == "ALLPOST-100"
        assert update["has_ticket"] is True

    async def test_no_ticket_when_branch_has_none(self) -> None:
        payload = _make_webhook_payload(source_ref="refs/heads/bugfix/generic_fix")
        update = await parse_webhook({"webhook_payload": payload}, build_config())
        assert update["has_ticket"] is False
        assert update["ticket_id"] is None

    async def test_sets_repo_name(self) -> None:
        payload = _make_webhook_payload(repo_name="backend-api")
        update = await parse_webhook({"webhook_payload": payload}, build_config())
        assert update["repo_name"] == "backend-api"

    async def test_sets_pr_id_as_string(self) -> None:
        # Webhooks carry the PR id as an int; the node normalizes it to str.
        payload = _make_webhook_payload(pr_id=99)
        update = await parse_webhook({"webhook_payload": payload}, build_config())
        assert update["pr_info"]["pr_id"] == "99"

    async def test_invalid_payload_adds_error(self) -> None:
        update = await parse_webhook(
            {"webhook_payload": {"bad": "data"}}, build_config()
        )
        assert "errors" in update
        assert len(update["errors"]) > 0
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# fetch_pr_details
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestFetchPrDetails:
    """fetch_pr_details: loads the PR and its diff from Azure DevOps."""

    @staticmethod
    def _clients_with_pr(pr, diff: str):
        # Mock client set whose azdo facade serves a fixed PR and diff.
        clients = build_mock_clients()
        clients.azdo.get_pr = AsyncMock(return_value=pr)
        clients.azdo.get_pr_diff = AsyncMock(return_value=diff)
        return clients

    async def test_fetches_pr_and_sets_pr_already_merged_false(self) -> None:
        clients = self._clients_with_pr(_make_pr_info(status="active"), "edit: main.py")
        state = {"pr_id": "42", "pr_info": {"pr_id": "42", "pr_status": "active"}}
        update = await fetch_pr_details(state, build_config(clients))
        assert update["pr_already_merged"] is False
        assert update["pr_diff"] == "edit: main.py"

    async def test_sets_pr_already_merged_true_when_completed(self) -> None:
        clients = self._clients_with_pr(_make_pr_info(status="completed"), "")
        state = {"pr_id": "42", "pr_info": {"pr_id": "42", "pr_status": "completed"}}
        update = await fetch_pr_details(state, build_config(clients))
        assert update["pr_already_merged"] is True

    async def test_stores_last_merge_source_commit(self) -> None:
        clients = self._clients_with_pr(_make_pr_info(status="active"), "edit: main.py")
        update = await fetch_pr_details(
            {"pr_id": "42", "pr_info": {"pr_id": "42"}}, build_config(clients)
        )
        # May be None when the PR carries no merge commit, but the key must exist.
        assert "last_merge_source_commit" in update

    async def test_adds_error_on_service_failure(self) -> None:
        from release_agent.exceptions import ServiceError

        clients = build_mock_clients()
        clients.azdo.get_pr = AsyncMock(
            side_effect=ServiceError(
                service="azdo", status_code=500, detail="Server error"
            )
        )
        update = await fetch_pr_details({"pr_id": "42"}, build_config(clients))
        assert "errors" in update
        assert len(update["errors"]) > 0
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# move_jira_code_review
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestMoveJiraCodeReview:
    """move_jira_code_review: transitions the ticket to 'code review'."""

    async def test_transitions_ticket_when_has_ticket(self) -> None:
        clients = build_mock_clients()
        clients.jira.transition_issue = AsyncMock(return_value=True)
        update = await move_jira_code_review(
            {"ticket_id": "ALLPOST-100", "has_ticket": True}, build_config(clients)
        )
        clients.jira.transition_issue.assert_called_once_with(
            "ALLPOST-100", "code review"
        )
        assert update == {} or "messages" in update

    async def test_skips_when_no_ticket(self) -> None:
        clients = build_mock_clients()
        clients.jira.transition_issue = AsyncMock(return_value=True)
        await move_jira_code_review(
            {"has_ticket": False, "ticket_id": None}, build_config(clients)
        )
        clients.jira.transition_issue.assert_not_called()

    async def test_appends_error_on_jira_failure(self) -> None:
        from release_agent.exceptions import ServiceError

        clients = build_mock_clients()
        clients.jira.transition_issue = AsyncMock(
            side_effect=ServiceError(service="jira", status_code=500, detail="Jira down")
        )
        update = await move_jira_code_review(
            {"ticket_id": "ALLPOST-100", "has_ticket": True}, build_config(clients)
        )
        assert "errors" in update
        assert len(update["errors"]) > 0
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# run_code_review
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestRunCodeReview:
    """run_code_review: sends the diff to the AI reviewer, stores the verdict."""

    async def test_calls_reviewer_with_diff(self) -> None:
        clients = build_mock_clients()
        clients.reviewer.review_pr = AsyncMock(
            return_value=ReviewResult(verdict="approve", summary="LGTM", issues=())
        )
        state = {
            "pr_diff": "edit: main.py",
            "pr_info": {"pr_title": "Fix: bug", "repo_name": "my-repo"},
        }
        update = await run_code_review(state, build_config(clients))
        clients.reviewer.review_pr.assert_called_once()
        assert "review_result" in update

    async def test_stores_review_result_as_dict(self) -> None:
        clients = build_mock_clients()
        clients.reviewer.review_pr = AsyncMock(
            return_value=ReviewResult(verdict="approve", summary="Clean code", issues=())
        )
        state = {
            "pr_diff": "edit: main.py",
            "pr_info": {"pr_title": "Fix", "repo_name": "repo"},
        }
        update = await run_code_review(state, build_config(clients))
        # Stored as a plain dict, not a model, so it survives checkpointing.
        assert update["review_result"]["verdict"] == "approve"

    async def test_adds_error_on_reviewer_failure(self) -> None:
        clients = build_mock_clients()
        clients.reviewer.review_pr = AsyncMock(side_effect=Exception("API error"))
        state = {
            "pr_diff": "edit: main.py",
            "pr_info": {"pr_title": "Fix", "repo_name": "repo"},
        }
        update = await run_code_review(state, build_config(clients))
        assert "errors" in update
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _post_review_to_pr
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestPostReviewToPr:
    """_post_review_to_pr: posts a summary comment plus inline comments per issue."""

    @staticmethod
    def _comment_clients():
        # Mock client set with both comment endpoints stubbed out.
        clients = build_mock_clients()
        clients.azdo.add_pr_comment = AsyncMock()
        clients.azdo.add_pr_inline_comment = AsyncMock()
        return clients

    async def test_posts_summary_comment(self) -> None:
        clients = self._comment_clients()
        review = ReviewResult(verdict="approve", summary="LGTM", issues=())
        await _post_review_to_pr(clients, "my-repo", 42, review)
        clients.azdo.add_pr_comment.assert_called_once()
        content = clients.azdo.add_pr_comment.call_args.kwargs["content"]
        assert "APPROVE" in content

    async def test_posts_inline_comment_for_issue_with_file_and_line(self) -> None:
        clients = self._comment_clients()
        issue = ReviewIssue(
            severity="error",
            description="Null check missing",
            file_path="src/Foo.cs",
            line_start=42,
            suggestion="Add null guard",
        )
        review = ReviewResult(verdict="request_changes", summary="Issues", issues=(issue,))
        await _post_review_to_pr(clients, "my-repo", 42, review)
        clients.azdo.add_pr_inline_comment.assert_called_once()
        kwargs = clients.azdo.add_pr_inline_comment.call_args.kwargs
        assert kwargs["file_path"] == "src/Foo.cs"
        assert kwargs["line_start"] == 42
        assert "Null check missing" in kwargs["content"]
        assert "Add null guard" in kwargs["content"]

    async def test_skips_inline_for_issue_without_line(self) -> None:
        clients = self._comment_clients()
        issue = ReviewIssue(
            severity="warning", description="Style issue", file_path="src/Foo.cs"
        )
        review = ReviewResult(verdict="approve", summary="OK", issues=(issue,))
        await _post_review_to_pr(clients, "my-repo", 42, review)
        clients.azdo.add_pr_inline_comment.assert_not_called()

    async def test_skips_inline_for_issue_without_file(self) -> None:
        clients = self._comment_clients()
        issue = ReviewIssue(severity="info", description="General note", line_start=10)
        review = ReviewResult(verdict="approve", summary="OK", issues=(issue,))
        await _post_review_to_pr(clients, "my-repo", 42, review)
        clients.azdo.add_pr_inline_comment.assert_not_called()

    async def test_inline_failure_does_not_prevent_summary(self) -> None:
        clients = self._comment_clients()
        clients.azdo.add_pr_inline_comment = AsyncMock(side_effect=Exception("API error"))
        issue = ReviewIssue(
            severity="blocker", description="Critical", file_path="a.cs", line_start=1
        )
        review = ReviewResult(verdict="request_changes", summary="Bad", issues=(issue,))
        await _post_review_to_pr(clients, "my-repo", 42, review)
        # The summary post must survive the inline failure.
        clients.azdo.add_pr_comment.assert_called_once()

    async def test_summary_failure_does_not_raise(self) -> None:
        clients = self._comment_clients()
        clients.azdo.add_pr_comment = AsyncMock(side_effect=Exception("Network error"))
        review = ReviewResult(verdict="approve", summary="LGTM", issues=())
        # Must not raise — comment posting is best-effort.
        await _post_review_to_pr(clients, "my-repo", 42, review)

    async def test_summary_contains_issue_count(self) -> None:
        clients = self._comment_clients()
        issues = (
            ReviewIssue(severity="warning", description="Issue 1"),
            ReviewIssue(severity="error", description="Issue 2"),
        )
        review = ReviewResult(verdict="request_changes", summary="Problems", issues=issues)
        await _post_review_to_pr(clients, "my-repo", 42, review)
        content = clients.azdo.add_pr_comment.call_args.kwargs["content"]
        assert "2 issue(s)" in content

    async def test_run_code_review_calls_post_review(self) -> None:
        """Integration: run_code_review posts comments when pr_id and repo_name present."""
        clients = self._comment_clients()
        clients.reviewer.review_pr = AsyncMock(
            return_value=ReviewResult(verdict="approve", summary="LGTM", issues=())
        )
        state = {
            "pr_diff": "edit: main.py",
            "pr_info": {"pr_id": "42", "pr_title": "Fix", "repo_name": "my-repo"},
        }
        await run_code_review(state, build_config(clients))
        clients.azdo.add_pr_comment.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# evaluate_review
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestEvaluateReview:
    """evaluate_review: derives the review_approved flag from the review result."""

    async def test_sets_review_approved_true_for_approve_verdict(self) -> None:
        update = await evaluate_review(
            {"review_result": _make_approve_review()}, build_config()
        )
        assert update["review_approved"] is True

    async def test_sets_review_approved_false_for_request_changes(self) -> None:
        update = await evaluate_review(
            {"review_result": _make_request_changes_review()}, build_config()
        )
        assert update["review_approved"] is False

    async def test_sets_false_when_review_result_missing(self) -> None:
        update = await evaluate_review({}, build_config())
        assert update["review_approved"] is False

    async def test_sets_false_when_has_blockers(self) -> None:
        # An "approve" verdict must still be rejected when blockers are flagged.
        review = {
            "verdict": "approve",
            "summary": "Approve with blocker?",
            "issues": [{"severity": "blocker", "description": "Problem"}],
            "has_blockers": True,
        }
        update = await evaluate_review({"review_result": review}, build_config())
        assert update["review_approved"] is False
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# interrupt_confirm_merge
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestInterruptConfirmMerge:
    """interrupt_confirm_merge: pauses the graph with a human-readable prompt."""

    async def test_calls_interrupt_with_summary_string(self) -> None:
        state = {
            "pr_info": {"pr_id": "42", "pr_title": "Fix: bug", "repo_name": "my-repo"},
            "review_result": {"summary": "LGTM"},
        }
        with patch("release_agent.graph.pr_completed.interrupt") as mock_interrupt:
            mock_interrupt.return_value = "confirm"
            await interrupt_confirm_merge(state, build_config())
        mock_interrupt.assert_called_once()
        prompt = mock_interrupt.call_args[0][0]
        assert isinstance(prompt, str)
        assert len(prompt) > 0

    async def test_interrupt_value_contains_pr_info(self) -> None:
        state = {
            "pr_info": {"pr_id": "42", "pr_title": "Fix: auth bug", "repo_name": "backend"},
            "review_result": {"summary": "All good"},
        }
        with patch("release_agent.graph.pr_completed.interrupt") as mock_interrupt:
            mock_interrupt.return_value = "confirm"
            await interrupt_confirm_merge(state, build_config())
        prompt = mock_interrupt.call_args[0][0]
        # The prompt must surface at least one identifying PR detail.
        assert "42" in prompt or "Fix: auth bug" in prompt or "backend" in prompt
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# merge_pr_node
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestMergePrNode:
    """merge_pr_node: completes the PR via the Azure DevOps client."""

    async def test_calls_azdo_merge_pr(self) -> None:
        clients = build_mock_clients()
        clients.azdo.merge_pr = AsyncMock(return_value=True)
        state = {"pr_info": {"pr_id": "42"}, "last_merge_source_commit": "abc123"}
        await merge_pr_node(state, build_config(clients))
        # pr_id is converted to int before reaching the client.
        clients.azdo.merge_pr.assert_called_once_with(
            pr_id=42, last_merge_source_commit="abc123"
        )

    async def test_returns_message_on_success(self) -> None:
        clients = build_mock_clients()
        clients.azdo.merge_pr = AsyncMock(return_value=True)
        state = {"pr_info": {"pr_id": "42"}, "last_merge_source_commit": "abc123"}
        update = await merge_pr_node(state, build_config(clients))
        assert "messages" in update

    async def test_re_raises_on_service_error(self) -> None:
        from release_agent.exceptions import ServiceError

        clients = build_mock_clients()
        clients.azdo.merge_pr = AsyncMock(
            side_effect=ServiceError(service="azdo", status_code=409, detail="Conflict")
        )
        state = {"pr_info": {"pr_id": "42"}, "last_merge_source_commit": "abc123"}
        # Merge failures must propagate so the graph can retry/abort.
        with pytest.raises(ServiceError):
            await merge_pr_node(state, build_config(clients))
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# move_jira_ready_for_stage
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestMoveJiraReadyForStage:
    """move_jira_ready_for_stage: transitions the ticket after a merge."""

    async def test_transitions_ticket(self) -> None:
        clients = build_mock_clients()
        clients.jira.transition_issue = AsyncMock(return_value=True)
        await move_jira_ready_for_stage(
            {"ticket_id": "ALLPOST-100", "has_ticket": True}, build_config(clients)
        )
        clients.jira.transition_issue.assert_called_once_with(
            "ALLPOST-100", "Ready for stage (2)"
        )

    async def test_skips_when_no_ticket(self) -> None:
        clients = build_mock_clients()
        clients.jira.transition_issue = AsyncMock()
        await move_jira_ready_for_stage({"has_ticket": False}, build_config(clients))
        clients.jira.transition_issue.assert_not_called()

    async def test_appends_error_on_failure(self) -> None:
        from release_agent.exceptions import ServiceError

        clients = build_mock_clients()
        clients.jira.transition_issue = AsyncMock(
            side_effect=ServiceError(service="jira", status_code=500, detail="Error")
        )
        update = await move_jira_ready_for_stage(
            {"ticket_id": "ALLPOST-100", "has_ticket": True}, build_config(clients)
        )
        assert "errors" in update
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# add_jira_pr_link
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestAddJiraPrLink:
    """add_jira_pr_link: attaches the PR URL to the ticket as a remote link."""

    @staticmethod
    def _state(title: str) -> dict:
        # State carrying a linked ticket plus the PR details needed for the link.
        return {
            "ticket_id": "ALLPOST-100",
            "has_ticket": True,
            "pr_info": {
                "pr_id": "42",
                "pr_url": "https://dev.azure.com/org/proj/_git/repo/pullrequest/42",
                "pr_title": title,
            },
        }

    async def test_calls_add_remote_link(self) -> None:
        clients = build_mock_clients()
        clients.jira.add_remote_link = AsyncMock(return_value=True)
        await add_jira_pr_link(self._state("Fix: bug"), build_config(clients))
        clients.jira.add_remote_link.assert_called_once()

    async def test_skips_when_no_ticket(self) -> None:
        clients = build_mock_clients()
        clients.jira.add_remote_link = AsyncMock()
        await add_jira_pr_link({"has_ticket": False}, build_config(clients))
        clients.jira.add_remote_link.assert_not_called()

    async def test_appends_error_on_failure(self) -> None:
        from release_agent.exceptions import ServiceError

        clients = build_mock_clients()
        clients.jira.add_remote_link = AsyncMock(
            side_effect=ServiceError(service="jira", status_code=500, detail="Error")
        )
        update = await add_jira_pr_link(self._state("Fix"), build_config(clients))
        assert "errors" in update
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# calculate_version
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestCalculateVersion:
    """calculate_version: derives the next semantic version for the repo."""

    async def test_returns_v1_0_0_for_empty_store(self, tmp_path) -> None:
        config = build_config(
            build_mock_clients(),
            staging_store=JsonFileStagingStore(directory=tmp_path),
        )
        update = await calculate_version({"repo_name": "my-repo"}, config)
        assert update["version"] == "v1.0.0"

    async def test_increments_patch_version(self, tmp_path) -> None:
        from release_agent.models.release import StagingRelease

        store = JsonFileStagingStore(directory=tmp_path)
        # Seed history with an archived v1.0.5 release.
        prior = StagingRelease(
            version="v1.0.5",
            repo="my-repo",
            started_at=date(2025, 1, 1),
            tickets=[],
        )
        await store.archive(prior, date(2025, 1, 10))
        config = build_config(build_mock_clients(), staging_store=store)
        update = await calculate_version({"repo_name": "my-repo"}, config)
        assert update["version"] == "v1.0.6"

    async def test_sets_version_in_state(self, tmp_path) -> None:
        config = build_config(
            build_mock_clients(),
            staging_store=JsonFileStagingStore(directory=tmp_path),
        )
        update = await calculate_version({"repo_name": "new-repo"}, config)
        assert "version" in update
        assert update["version"].startswith("v")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# update_staging
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestUpdateStaging:
    """Tests for the update_staging node (records merged PRs on the staging release)."""

    @staticmethod
    def _state_with_ticket(
        ticket_id: str, pr_id: str, pr_title: str, branch: str
    ) -> dict:
        """Build node input state for a merged PR that carries *ticket_id*."""
        return {
            "repo_name": "my-repo",
            "version": "v1.0.0",
            "ticket_id": ticket_id,
            "has_ticket": True,
            "pr_info": {
                "pr_id": pr_id,
                "pr_url": f"https://dev.azure.com/org/proj/_git/repo/pullrequest/{pr_id}",
                "pr_title": pr_title,
                "branch": branch,
            },
        }

    async def test_creates_new_staging_when_none_exists(self, tmp_path) -> None:
        from release_agent.models.jira import JiraIssue

        staging_store = JsonFileStagingStore(directory=tmp_path)
        clients = build_mock_clients()
        # Jira get_issue returns a summary for the ticket being staged.
        clients.jira.get_issue = AsyncMock(return_value=JiraIssue(
            key="ALLPOST-100", summary="Fix auth bug", status="Ready for stage (2)"
        ))
        config = build_config(clients, staging_store=staging_store)
        state = self._state_with_ticket(
            "ALLPOST-100", "42", "Fix: auth bug", "feature/ALLPOST-100-fix"
        )

        await update_staging(state, config)

        loaded = await staging_store.load("my-repo")
        assert loaded is not None
        assert loaded.has_ticket("ALLPOST-100")

    async def test_appends_ticket_to_existing_staging(self, tmp_path) -> None:
        from datetime import date

        from release_agent.models.jira import JiraIssue
        from release_agent.models.release import StagingRelease

        staging_store = JsonFileStagingStore(directory=tmp_path)
        existing = StagingRelease(
            version="v1.0.0", repo="my-repo",
            started_at=date(2025, 1, 1), tickets=[]
        )
        await staging_store.save(existing)
        clients = build_mock_clients()
        clients.jira.get_issue = AsyncMock(return_value=JiraIssue(
            key="BILL-99", summary="New feature", status="Ready"
        ))
        config = build_config(clients, staging_store=staging_store)
        state = self._state_with_ticket(
            "BILL-99", "55", "Feat: new feature", "feature/BILL-99-feat"
        )

        await update_staging(state, config)

        loaded = await staging_store.load("my-repo")
        assert loaded is not None
        assert loaded.has_ticket("BILL-99")

    async def test_skips_ticket_add_when_no_ticket(self, tmp_path) -> None:
        staging_store = JsonFileStagingStore(directory=tmp_path)
        clients = build_mock_clients()
        config = build_config(clients, staging_store=staging_store)
        state = {
            "repo_name": "my-repo",
            "version": "v1.0.0",
            "has_ticket": False,
        }

        await update_staging(state, config)

        # No staging file should be created for ticket-less PR if no existing staging
        # (or staging exists without new ticket added): a ticket-less PR never
        # triggers a Jira lookup.
        clients.jira.get_issue.assert_not_called()

    async def test_returns_empty_dict_when_no_staging_store(self) -> None:
        from release_agent.models.jira import JiraIssue

        clients = build_mock_clients()
        clients.jira.get_issue = AsyncMock(return_value=JiraIssue(
            key="ALLPOST-1", summary="Fix", status="Ready"
        ))
        config = build_config(clients, staging_store=None)
        state = self._state_with_ticket("ALLPOST-1", "1", "Fix", "feature/ALLPOST-1")

        result = await update_staging(state, config)

        assert result == {}

    async def test_uses_ticket_id_as_summary_on_jira_failure(self, tmp_path) -> None:
        staging_store = JsonFileStagingStore(directory=tmp_path)
        clients = build_mock_clients()
        clients.jira.get_issue = AsyncMock(side_effect=Exception("Jira unavailable"))
        config = build_config(clients, staging_store=staging_store)
        state = self._state_with_ticket(
            "ALLPOST-99", "5", "Fix something", "feature/ALLPOST-99_fix"
        )

        await update_staging(state, config)

        loaded = await staging_store.load("my-repo")
        assert loaded is not None
        assert loaded.tickets[0].id == "ALLPOST-99"
        # Jira being down falls back to using the ticket ID as the summary.
        assert loaded.tickets[0].summary == "ALLPOST-99"

    async def test_sets_staging_dict_in_result(self, tmp_path) -> None:
        from release_agent.models.jira import JiraIssue

        staging_store = JsonFileStagingStore(directory=tmp_path)
        clients = build_mock_clients()
        clients.jira.get_issue = AsyncMock(return_value=JiraIssue(
            key="ALLPOST-1", summary="S", status="Ready"
        ))
        config = build_config(clients, staging_store=staging_store)
        state = self._state_with_ticket("ALLPOST-1", "1", "Fix", "feature/ALLPOST-1")

        result = await update_staging(state, config)

        assert "staging" in result
        assert isinstance(result["staging"], dict)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# notify_request_changes
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestNotifyRequestChanges:
    """Tests for the notify_request_changes node (Slack approval messages)."""

    async def test_calls_slack_send_approval_request(self) -> None:
        """A request-changes verdict triggers exactly one Slack approval message."""
        clients = build_mock_clients()
        clients.slack.send_approval_request = AsyncMock(return_value=True)
        config = build_config(clients)
        state = {
            "pr_info": {"pr_id": "42", "pr_title": "Fix: bug", "repo_name": "my-repo"},
            "review_result": {
                "verdict": "request_changes",
                "summary": "Too many issues",
                "issues": [{"severity": "blocker", "description": "No tests"}],
            },
        }

        await notify_request_changes(state, config)

        clients.slack.send_approval_request.assert_called_once()

    async def test_appends_error_on_slack_failure(self) -> None:
        """A Slack ServiceError is swallowed and surfaced via the errors list."""
        from release_agent.exceptions import ServiceError

        clients = build_mock_clients()
        clients.slack.send_approval_request = AsyncMock(side_effect=ServiceError(
            service="slack", status_code=500, detail="Webhook error"
        ))
        config = build_config(clients)
        state = {
            "pr_info": {"pr_id": "42", "pr_title": "Fix", "repo_name": "repo"},
            "review_result": {"summary": "Issues found", "issues": []},
        }

        result = await notify_request_changes(state, config)

        assert "errors" in result
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# build_pr_completed_graph
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# auto_create_ticket node
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestAutoCreateTicket:
    """Tests for the auto_create_ticket node."""

    @staticmethod
    def _state(diff: str = "d", title: str = "t", repo: str = "r") -> dict:
        """Build the minimal node input state for a PR without a ticket ID."""
        return {"pr_diff": diff, "pr_info": {"pr_title": title, "repo_name": repo}}

    def _make_config_with_jira_project(self, jira_project: str = "ALLPOST"):
        """Return (config, clients) with Jira/reviewer mocks and a default project set."""
        clients = build_mock_clients()
        clients.jira.create_issue = AsyncMock(return_value="ALLPOST-99")
        clients.reviewer.generate_ticket_content = AsyncMock(
            return_value=("My summary", "My description")
        )
        config = build_config(clients)
        config["configurable"]["default_jira_project"] = jira_project
        return config, clients

    async def test_creates_jira_issue_and_returns_ticket_id(self) -> None:
        config, clients = self._make_config_with_jira_project()
        state = self._state(diff="edit: main.py", title="Fix bug", repo="my-repo")

        result = await auto_create_ticket(state, config)

        assert result.get("ticket_id") == "ALLPOST-99"

    async def test_sets_has_ticket_true(self) -> None:
        config, clients = self._make_config_with_jira_project()
        state = self._state(diff="edit: main.py", title="Fix bug", repo="my-repo")

        result = await auto_create_ticket(state, config)

        assert result.get("has_ticket") is True

    async def test_calls_generate_ticket_content(self) -> None:
        config, clients = self._make_config_with_jira_project()
        state = self._state(diff="edit: main.py", title="Fix login", repo="auth-service")

        await auto_create_ticket(state, config)

        clients.reviewer.generate_ticket_content.assert_awaited_once()

    async def test_calls_create_issue_with_project_key(self) -> None:
        # The helper already sets default_jira_project; no need to set it twice.
        config, clients = self._make_config_with_jira_project(jira_project="MYPROJ")
        clients.jira.create_issue = AsyncMock(return_value="MYPROJ-5")
        state = self._state()

        await auto_create_ticket(state, config)

        call_kwargs = clients.jira.create_issue.call_args.kwargs
        assert call_kwargs["project"] == "MYPROJ"

    async def test_appends_message_on_success(self) -> None:
        config, _ = self._make_config_with_jira_project()
        state = self._state()

        result = await auto_create_ticket(state, config)

        assert "messages" in result
        assert len(result["messages"]) > 0

    async def test_appends_error_on_create_issue_failure(self) -> None:
        config, clients = self._make_config_with_jira_project()
        clients.jira.create_issue = AsyncMock(side_effect=Exception("Jira down"))
        state = self._state()

        result = await auto_create_ticket(state, config)

        assert "errors" in result
        assert len(result["errors"]) > 0

    async def test_appends_error_on_generate_content_failure(self) -> None:
        config, clients = self._make_config_with_jira_project()
        clients.reviewer.generate_ticket_content = AsyncMock(side_effect=RuntimeError("CLI fail"))
        state = self._state()

        result = await auto_create_ticket(state, config)

        assert "errors" in result

    async def test_uses_default_project_from_config(self) -> None:
        config, clients = self._make_config_with_jira_project(jira_project="TEAM")
        clients.jira.create_issue = AsyncMock(return_value="TEAM-1")
        state = self._state()

        result = await auto_create_ticket(state, config)

        assert result["ticket_id"] == "TEAM-1"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# build_pr_completed_graph
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestBuildPrCompletedGraph:
    """Tests for build_pr_completed_graph wiring."""

    def test_returns_compiled_graph(self) -> None:
        graph = build_pr_completed_graph()
        assert graph is not None

    def test_graph_has_nodes(self) -> None:
        # Inspect the actual node map instead of repeating the not-None check
        # from the previous test — the name promises the graph has nodes.
        graph = build_pr_completed_graph()
        assert len(graph.get_graph().nodes) > 0

    def test_graph_includes_trigger_ci_build_node(self) -> None:
        # Graph nodes should include CI pipeline nodes
        graph = build_pr_completed_graph()
        graph_nodes = graph.get_graph().nodes
        assert "trigger_ci_build" in graph_nodes

    def test_graph_includes_poll_ci_build_node(self) -> None:
        graph = build_pr_completed_graph()
        graph_nodes = graph.get_graph().nodes
        assert "poll_ci_build" in graph_nodes

    def test_graph_includes_notify_ci_result_node(self) -> None:
        graph = build_pr_completed_graph()
        graph_nodes = graph.get_graph().nodes
        assert "notify_ci_result" in graph_nodes

    def test_graph_includes_auto_create_ticket_node(self) -> None:
        graph = build_pr_completed_graph()
        graph_nodes = graph.get_graph().nodes
        assert "auto_create_ticket" in graph_nodes
|
||||||
866
tests/graph/test_release.py
Normal file
866
tests/graph/test_release.py
Normal file
@@ -0,0 +1,866 @@
|
|||||||
|
"""Tests for graph/release.py node functions. Written FIRST (TDD RED phase).
|
||||||
|
|
||||||
|
Each node is an async function (state, config) -> dict.
|
||||||
|
Tests call nodes directly — no graph compilation required.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from datetime import date
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from release_agent.graph.dependencies import JsonFileStagingStore
|
||||||
|
from release_agent.graph.release import (
|
||||||
|
approve_stage,
|
||||||
|
archive_release,
|
||||||
|
check_release_approvals,
|
||||||
|
create_release_pr,
|
||||||
|
interrupt_confirm_approve,
|
||||||
|
interrupt_confirm_merge_release,
|
||||||
|
interrupt_confirm_release,
|
||||||
|
interrupt_confirm_trigger,
|
||||||
|
list_pipelines,
|
||||||
|
load_staging,
|
||||||
|
merge_release_pr,
|
||||||
|
move_tickets_to_done,
|
||||||
|
send_slack_notification,
|
||||||
|
trigger_pipelines,
|
||||||
|
build_release_graph,
|
||||||
|
)
|
||||||
|
from release_agent.models.pipeline import PipelineInfo, ReleasePipelineStage
|
||||||
|
from release_agent.models.release import StagingRelease
|
||||||
|
from release_agent.models.ticket import TicketEntry
|
||||||
|
from tests.graph.conftest import build_config, build_mock_clients
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _make_ticket(ticket_id: str = "ALLPOST-1") -> TicketEntry:
    """Construct a TicketEntry fixture whose id and branch derive from *ticket_id*."""
    branch_name = "feature/" + ticket_id + "-fix"
    return TicketEntry(
        id=ticket_id,
        summary="Fix something",
        pr_id="42",
        pr_url="https://dev.azure.com/org/proj/_git/repo/pullrequest/42",
        pr_title="Fix: something",
        branch=branch_name,
        merged_at=date(2025, 1, 15),
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _make_staging(
    *,
    repo: str = "my-repo",
    version: str = "v1.0.0",
    tickets: list | None = None,
) -> StagingRelease:
    """Build a StagingRelease fixture; defaults to a single ticket entry."""
    if tickets is None:
        tickets = [_make_ticket()]
    return StagingRelease(
        version=version,
        repo=repo,
        started_at=date(2025, 1, 1),
        tickets=tickets,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _staging_dict(staging: StagingRelease) -> dict:
    """Serialize a StagingRelease into the JSON-mode dict carried in graph state."""
    dumped = staging.model_dump(mode="json")
    return dumped
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# load_staging
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestLoadStaging:
    """Behavioral tests for the load_staging node."""

    async def test_loads_staging_from_store(self, tmp_path) -> None:
        """The node surfaces a previously saved staging release in state."""
        store = JsonFileStagingStore(directory=tmp_path)
        await store.save(_make_staging())
        config = build_config(build_mock_clients(), staging_store=store)

        result = await load_staging({"repo_name": "my-repo"}, config)

        assert "staging" in result
        assert result["staging"]["version"] == "v1.0.0"

    async def test_returns_none_when_no_staging(self, tmp_path) -> None:
        """An unknown repo yields staging=None rather than an error."""
        store = JsonFileStagingStore(directory=tmp_path)
        config = build_config(build_mock_clients(), staging_store=store)

        result = await load_staging({"repo_name": "nonexistent"}, config)

        assert result.get("staging") is None

    async def test_staging_includes_tickets(self, tmp_path) -> None:
        """Every ticket on the staging release is carried into state."""
        store = JsonFileStagingStore(directory=tmp_path)
        two_tickets = [_make_ticket("BILL-10"), _make_ticket("BILL-11")]
        await store.save(_make_staging(tickets=two_tickets))
        config = build_config(build_mock_clients(), staging_store=store)

        result = await load_staging({"repo_name": "my-repo"}, config)

        assert len(result["staging"]["tickets"]) == 2
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# interrupt_confirm_release
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestInterruptConfirmRelease:
    """Tests for the human-in-the-loop release confirmation interrupt."""

    async def test_calls_interrupt_with_staging_summary(self) -> None:
        """The node interrupts once, passing a human-readable summary string."""
        config = build_config()
        state = {
            "repo_name": "my-repo",
            "staging": _staging_dict(_make_staging()),
        }
        with patch("release_agent.graph.release.interrupt") as mock_interrupt:
            mock_interrupt.return_value = "confirm"
            await interrupt_confirm_release(state, config)
            mock_interrupt.assert_called_once()
            (prompt,), _ = mock_interrupt.call_args
            assert isinstance(prompt, str)

    async def test_interrupt_contains_version_and_repo(self) -> None:
        """The interrupt prompt mentions the version or the repo being released."""
        config = build_config()
        staging = _make_staging(version="v2.5.0", repo="backend")
        state = {
            "repo_name": "backend",
            "staging": _staging_dict(staging),
        }
        with patch("release_agent.graph.release.interrupt") as mock_interrupt:
            mock_interrupt.return_value = "confirm"
            await interrupt_confirm_release(state, config)
            (prompt,), _ = mock_interrupt.call_args
            assert "v2.5.0" in prompt or "backend" in prompt
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# create_release_pr
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestCreateReleasePr:
    """Tests for the create_release_pr node."""

    @staticmethod
    def _clients_with_created_pr(pr_id: int, commit: str):
        """Mock clients whose azdo.create_pr resolves to a PR payload."""
        clients = build_mock_clients()
        clients.azdo.create_pr = AsyncMock(return_value={
            "pullRequestId": pr_id,
            "lastMergeSourceCommit": {"commitId": commit},
        })
        return clients

    @staticmethod
    def _state(version: str) -> dict:
        """Node input state for my-repo releasing *version*."""
        return {
            "repo_name": "my-repo",
            "version": version,
            "staging": _staging_dict(_make_staging(version=version)),
        }

    async def test_calls_azdo_create_pr(self) -> None:
        clients = self._clients_with_created_pr(99, "deadbeef")
        config = build_config(clients)

        await create_release_pr(self._state("v1.2.0"), config)

        clients.azdo.create_pr.assert_called_once()
        assert clients.azdo.create_pr.call_args.kwargs["repo"] == "my-repo"

    async def test_sets_release_pr_id(self) -> None:
        clients = self._clients_with_created_pr(77, "cafe1234")
        config = build_config(clients)

        result = await create_release_pr(self._state("v1.0.3"), config)

        # The node stringifies the numeric PR id for state storage.
        assert result["release_pr_id"] == "77"

    async def test_sets_release_pr_commit(self) -> None:
        clients = self._clients_with_created_pr(77, "cafe1234")
        config = build_config(clients)

        result = await create_release_pr(self._state("v1.0.0"), config)

        assert result["release_pr_commit"] == "cafe1234"

    async def test_re_raises_on_service_error(self) -> None:
        """An AzDO ServiceError propagates — release PR creation must not fail silently."""
        from release_agent.exceptions import ServiceError

        clients = build_mock_clients()
        clients.azdo.create_pr = AsyncMock(side_effect=ServiceError(
            service="azdo", status_code=422, detail="Invalid branch"
        ))
        config = build_config(clients)

        with pytest.raises(ServiceError):
            await create_release_pr(self._state("v1.0.0"), config)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# interrupt_confirm_merge_release
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestInterruptConfirmMergeRelease:
    """Tests for the merge-confirmation interrupt."""

    async def test_calls_interrupt_with_pr_info(self) -> None:
        """The node interrupts once with a non-empty prompt string."""
        config = build_config()
        state = {
            "release_pr_id": "99",
            "version": "v1.0.0",
            "repo_name": "my-repo",
        }
        with patch("release_agent.graph.release.interrupt") as mock_interrupt:
            mock_interrupt.return_value = "confirm"
            await interrupt_confirm_merge_release(state, config)
            mock_interrupt.assert_called_once()
            (prompt,), _ = mock_interrupt.call_args
            assert isinstance(prompt, str)
            assert len(prompt) > 0
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# merge_release_pr
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestMergeReleasePr:
    """Tests for the merge_release_pr node."""

    async def test_calls_azdo_merge_pr(self) -> None:
        """The node forwards the PR id (as int) and merge commit to AzDO."""
        clients = build_mock_clients()
        clients.azdo.merge_pr = AsyncMock(return_value=True)
        config = build_config(clients)
        state = {"release_pr_id": "99", "release_pr_commit": "abc123"}

        await merge_release_pr(state, config)

        clients.azdo.merge_pr.assert_called_once_with(
            pr_id=99, last_merge_source_commit="abc123"
        )

    async def test_re_raises_on_service_error(self) -> None:
        """A merge conflict from AzDO propagates to the caller."""
        from release_agent.exceptions import ServiceError

        clients = build_mock_clients()
        clients.azdo.merge_pr = AsyncMock(side_effect=ServiceError(
            service="azdo", status_code=409, detail="Conflict"
        ))
        config = build_config(clients)
        state = {"release_pr_id": "99", "release_pr_commit": "abc"}

        with pytest.raises(ServiceError):
            await merge_release_pr(state, config)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# move_tickets_to_done
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestMoveTicketsToDone:
    """Tests for the move_tickets_to_done node."""

    async def test_transitions_all_tickets(self) -> None:
        """Each staged ticket gets its own Jira transition call."""
        clients = build_mock_clients()
        clients.jira.transition_issue = AsyncMock(return_value=True)
        config = build_config(clients)
        both = [_make_ticket("BILL-1"), _make_ticket("BILL-2")]
        state = {"staging": _staging_dict(_make_staging(tickets=both))}

        await move_tickets_to_done(state, config)

        assert clients.jira.transition_issue.call_count == 2

    async def test_calls_transition_with_done_name(self) -> None:
        """The transition target reads as a done/released status."""
        clients = build_mock_clients()
        clients.jira.transition_issue = AsyncMock(return_value=True)
        config = build_config(clients)
        state = {"staging": _staging_dict(_make_staging(tickets=[_make_ticket("BILL-1")]))}

        await move_tickets_to_done(state, config)

        first_call = clients.jira.transition_issue.call_args_list[0]
        ticket_id, transition = first_call[0]
        assert ticket_id == "BILL-1"
        # Exact label is configurable; accept any done/released variant.
        assert "done" in transition.lower() or "released" in transition.lower()

    async def test_appends_error_on_jira_failure(self) -> None:
        """Jira failures are recorded in errors rather than raised."""
        from release_agent.exceptions import ServiceError

        clients = build_mock_clients()
        clients.jira.transition_issue = AsyncMock(side_effect=ServiceError(
            service="jira", status_code=500, detail="Error"
        ))
        config = build_config(clients)
        state = {"staging": _staging_dict(_make_staging(tickets=[_make_ticket()]))}

        result = await move_tickets_to_done(state, config)

        assert "errors" in result

    async def test_empty_tickets_no_calls(self) -> None:
        """A staging release with no tickets makes no Jira calls at all."""
        clients = build_mock_clients()
        clients.jira.transition_issue = AsyncMock()
        config = build_config(clients)
        state = {"staging": _staging_dict(_make_staging(tickets=[]))}

        await move_tickets_to_done(state, config)

        clients.jira.transition_issue.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# send_slack_notification
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestSendSlackNotification:
    """Tests for the send_slack_notification node."""

    async def test_calls_slack_send_release_notification(self) -> None:
        """A successful release sends exactly one Slack notification."""
        clients = build_mock_clients()
        clients.slack.send_release_notification = AsyncMock(return_value=True)
        config = build_config(clients)
        state = {
            "repo_name": "my-repo",
            "version": "v1.0.0",
            "staging": _staging_dict(_make_staging()),
        }

        await send_slack_notification(state, config)

        clients.slack.send_release_notification.assert_called_once()

    async def test_appends_error_on_slack_failure(self) -> None:
        """Slack failures are non-fatal: recorded in errors, not raised."""
        from release_agent.exceptions import ServiceError

        clients = build_mock_clients()
        clients.slack.send_release_notification = AsyncMock(side_effect=ServiceError(
            service="slack", status_code=500, detail="Webhook error"
        ))
        config = build_config(clients)
        state = {
            "repo_name": "my-repo",
            "version": "v1.0.0",
            "staging": _staging_dict(_make_staging()),
        }

        result = await send_slack_notification(state, config)

        assert "errors" in result
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# archive_release
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestArchiveRelease:
    """Tests for the archive_release node."""

    async def test_archives_staging_to_store(self, tmp_path) -> None:
        """Archiving removes the active staging file for the repo."""
        store = JsonFileStagingStore(directory=tmp_path)
        staging = _make_staging()
        await store.save(staging)
        config = build_config(build_mock_clients(), staging_store=store)
        state = {"repo_name": "my-repo", "staging": _staging_dict(staging)}

        await archive_release(state, config)

        # Staging should be gone now
        assert await store.load("my-repo") is None

    async def test_archive_file_created_in_store(self, tmp_path) -> None:
        """The archived version becomes visible via list_versions."""
        store = JsonFileStagingStore(directory=tmp_path)
        staging = _make_staging(version="v3.0.0")
        await store.save(staging)
        config = build_config(build_mock_clients(), staging_store=store)
        state = {"repo_name": "my-repo", "staging": _staging_dict(staging)}

        await archive_release(state, config)

        archived = await store.list_versions("my-repo")
        assert "v3.0.0" in archived
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# list_pipelines
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestListPipelines:
    """Tests for the list_pipelines node."""

    async def test_fetches_pipelines_from_azdo(self) -> None:
        """The node queries AzDO for this repo's pipelines and stores them."""
        clients = build_mock_clients()
        clients.azdo.list_build_pipelines = AsyncMock(
            return_value=[PipelineInfo(id=1, name="build", repo="my-repo")]
        )
        config = build_config(clients)

        result = await list_pipelines({"repo_name": "my-repo"}, config)

        clients.azdo.list_build_pipelines.assert_called_once_with(repo="my-repo")
        assert "pipelines" in result
        assert len(result["pipelines"]) == 1

    async def test_stores_pipelines_as_list_of_dicts(self) -> None:
        """PipelineInfo models are serialized to plain dicts in state."""
        clients = build_mock_clients()
        found = [
            PipelineInfo(id=1, name="build", repo="my-repo"),
            PipelineInfo(id=2, name="deploy", repo="my-repo"),
        ]
        clients.azdo.list_build_pipelines = AsyncMock(return_value=found)
        config = build_config(clients)

        result = await list_pipelines({"repo_name": "my-repo"}, config)

        assert len(result["pipelines"]) == 2
        assert result["pipelines"][0]["id"] == 1

    async def test_empty_pipelines_stored_as_empty_list(self) -> None:
        """No pipelines yields an empty list, not a missing key."""
        clients = build_mock_clients()
        clients.azdo.list_build_pipelines = AsyncMock(return_value=[])
        config = build_config(clients)

        result = await list_pipelines({"repo_name": "my-repo"}, config)

        assert result["pipelines"] == []

    async def test_appends_error_on_service_failure(self) -> None:
        """AzDO failures are captured in the errors list."""
        from release_agent.exceptions import ServiceError

        clients = build_mock_clients()
        clients.azdo.list_build_pipelines = AsyncMock(side_effect=ServiceError(
            service="azdo", status_code=500, detail="Error"
        ))
        config = build_config(clients)

        result = await list_pipelines({"repo_name": "my-repo"}, config)

        assert "errors" in result
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# interrupt_confirm_trigger
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestInterruptConfirmTrigger:
    """Tests for the interrupt_confirm_trigger node."""

    async def test_calls_interrupt_with_pipelines_summary(self) -> None:
        """interrupt() is invoked exactly once with a human-readable string summary."""
        config = build_config()
        state = {
            "repo_name": "my-repo",
            "version": "v1.0.0",
            "pipelines": [{"id": 1, "name": "build", "repo": "my-repo"}],
        }
        # Patch LangGraph's interrupt so the node does not actually pause the graph.
        with patch("release_agent.graph.release.interrupt") as mock_interrupt:
            mock_interrupt.return_value = "confirm"
            await interrupt_confirm_trigger(state, config)
            mock_interrupt.assert_called_once()
            # First positional argument is the summary shown to the operator.
            call_arg = mock_interrupt.call_args[0][0]
            assert isinstance(call_arg, str)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# trigger_pipelines
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestTriggerPipelines:
    """Tests for the trigger_pipelines node."""

    async def test_triggers_each_pipeline(self) -> None:
        """Every pipeline in state is triggered once and each build is recorded."""
        clients = build_mock_clients()
        clients.azdo.trigger_pipeline = AsyncMock(return_value={"id": 1001})
        config = build_config(clients)
        state = {
            "repo_name": "my-repo",
            "version": "v1.0.0",
            "pipelines": [
                {"id": 1, "name": "build", "repo": "my-repo"},
                {"id": 2, "name": "deploy", "repo": "my-repo"},
            ],
        }
        result = await trigger_pipelines(state, config)
        assert clients.azdo.trigger_pipeline.call_count == 2
        assert "triggered_builds" in result
        assert len(result["triggered_builds"]) == 2

    async def test_no_pipelines_no_calls(self) -> None:
        """An empty pipelines list triggers nothing and yields an empty result list."""
        clients = build_mock_clients()
        clients.azdo.trigger_pipeline = AsyncMock()
        config = build_config(clients)
        state = {
            "repo_name": "my-repo",
            "version": "v1.0.0",
            "pipelines": [],
        }
        result = await trigger_pipelines(state, config)
        clients.azdo.trigger_pipeline.assert_not_called()
        assert result["triggered_builds"] == []

    async def test_appends_error_on_trigger_failure(self) -> None:
        """A ServiceError while triggering is captured in state["errors"], not raised."""
        from release_agent.exceptions import ServiceError
        clients = build_mock_clients()
        clients.azdo.trigger_pipeline = AsyncMock(side_effect=ServiceError(
            service="azdo", status_code=500, detail="Error"
        ))
        config = build_config(clients)
        state = {
            "repo_name": "my-repo",
            "version": "v1.0.0",
            "pipelines": [{"id": 1, "name": "build", "repo": "my-repo"}],
        }
        result = await trigger_pipelines(state, config)
        assert "errors" in result
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# check_release_approvals
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestCheckReleaseApprovals:
    """Tests for the check_release_approvals node."""

    async def test_fetches_pending_approvals_from_builds(self) -> None:
        """Build statuses are polled and pending_approvals is populated in state."""
        clients = build_mock_clients()
        clients.azdo.get_build_status = AsyncMock(return_value="completed")
        config = build_config(clients)
        state = {
            "triggered_builds": [{"id": 1001}],
        }
        result = await check_release_approvals(state, config)
        assert "pending_approvals" in result

    async def test_empty_builds_means_no_approvals(self) -> None:
        """With no triggered builds there is nothing to approve."""
        clients = build_mock_clients()
        config = build_config(clients)
        state = {"triggered_builds": []}
        result = await check_release_approvals(state, config)
        assert result["pending_approvals"] == []

    async def test_appends_error_on_failure(self) -> None:
        """A ServiceError while polling is captured in state["errors"], not raised."""
        from release_agent.exceptions import ServiceError
        clients = build_mock_clients()
        clients.azdo.get_build_status = AsyncMock(side_effect=ServiceError(
            service="azdo", status_code=500, detail="Error"
        ))
        config = build_config(clients)
        state = {"triggered_builds": [{"id": 1001}]}
        result = await check_release_approvals(state, config)
        assert "errors" in result
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# interrupt_confirm_approve
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestInterruptConfirmApprove:
    """Tests for the interrupt_confirm_approve node."""

    async def test_calls_interrupt_with_approvals_summary(self) -> None:
        """interrupt() is invoked once when pending approvals exist."""
        config = build_config()
        state = {
            "pending_approvals": [{"approval_id": "aaa", "stage_name": "Production"}],
            "version": "v1.0.0",
        }
        # Patch LangGraph's interrupt so the node does not actually pause the graph.
        with patch("release_agent.graph.release.interrupt") as mock_interrupt:
            mock_interrupt.return_value = "confirm"
            await interrupt_confirm_approve(state, config)
            mock_interrupt.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# approve_stage
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestApproveStage:
    """Tests for the approve_stage node."""

    async def test_approves_each_pending_approval(self) -> None:
        """Each pending approval results in one approve_release call."""
        clients = build_mock_clients()
        clients.azdo.approve_release = AsyncMock(return_value={"status": "approved"})
        config = build_config(clients)
        state = {
            "pending_approvals": [
                {"approval_id": "aaa"},
                {"approval_id": "bbb"},
            ],
        }
        result = await approve_stage(state, config)
        assert clients.azdo.approve_release.call_count == 2

    async def test_no_approvals_no_calls(self) -> None:
        """With nothing pending, no approval call is made."""
        clients = build_mock_clients()
        clients.azdo.approve_release = AsyncMock()
        config = build_config(clients)
        state = {"pending_approvals": []}
        await approve_stage(state, config)
        clients.azdo.approve_release.assert_not_called()

    async def test_appends_error_on_failure(self) -> None:
        """A ServiceError while approving is captured in state["errors"], not raised."""
        from release_agent.exceptions import ServiceError
        clients = build_mock_clients()
        clients.azdo.approve_release = AsyncMock(side_effect=ServiceError(
            service="azdo", status_code=500, detail="Error"
        ))
        config = build_config(clients)
        state = {"pending_approvals": [{"approval_id": "aaa"}]}
        result = await approve_stage(state, config)
        assert "errors" in result
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# build_release_graph
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestBuildReleaseGraph:
    """Structural tests: build_release_graph compiles and contains expected nodes."""

    def test_returns_compiled_graph(self) -> None:
        """The builder returns a compiled graph object."""
        graph = build_release_graph()
        assert graph is not None

    def test_graph_includes_trigger_ci_build_main_node(self) -> None:
        graph = build_release_graph()
        graph_nodes = graph.get_graph().nodes
        assert "trigger_ci_build_main" in graph_nodes

    def test_graph_includes_poll_ci_build_main_node(self) -> None:
        graph = build_release_graph()
        graph_nodes = graph.get_graph().nodes
        assert "poll_ci_build_main" in graph_nodes

    def test_graph_includes_wait_for_cd_release_node(self) -> None:
        graph = build_release_graph()
        graph_nodes = graph.get_graph().nodes
        assert "wait_for_cd_release" in graph_nodes

    def test_graph_includes_poll_release_approvals_node(self) -> None:
        graph = build_release_graph()
        graph_nodes = graph.get_graph().nodes
        assert "poll_release_approvals" in graph_nodes

    def test_graph_includes_interrupt_sandbox_approval_node(self) -> None:
        graph = build_release_graph()
        graph_nodes = graph.get_graph().nodes
        assert "interrupt_sandbox_approval" in graph_nodes

    def test_graph_includes_interrupt_prod_approval_node(self) -> None:
        graph = build_release_graph()
        graph_nodes = graph.get_graph().nodes
        assert "interrupt_prod_approval" in graph_nodes

    def test_graph_includes_execute_sandbox_approval_node(self) -> None:
        graph = build_release_graph()
        graph_nodes = graph.get_graph().nodes
        assert "execute_sandbox_approval" in graph_nodes

    def test_graph_includes_execute_prod_approval_node(self) -> None:
        graph = build_release_graph()
        graph_nodes = graph.get_graph().nodes
        assert "execute_prod_approval" in graph_nodes

    def test_graph_includes_notify_ci_failure_node(self) -> None:
        graph = build_release_graph()
        graph_nodes = graph.get_graph().nodes
        assert "notify_ci_failure" in graph_nodes
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# New release graph node: wait_for_cd_release
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestWaitForCdRelease:
    """Tests for wait_for_cd_release node."""

    async def test_sets_release_id_when_found(self) -> None:
        """When AzDo reports a latest release, its id is stored in state."""
        from release_agent.graph.release import wait_for_cd_release
        clients = build_mock_clients()
        clients.azdo.get_latest_release.return_value = {"id": 100, "name": "Release-100"}
        config = build_config(clients)
        state = {"release_definition_id": 5, "repo_name": "my-repo"}

        result = await wait_for_cd_release(state, config)

        assert "release_id" in result
        assert result["release_id"] == 100

    async def test_appends_error_when_no_release(self) -> None:
        """An empty release payload records an error rather than raising."""
        from release_agent.graph.release import wait_for_cd_release
        clients = build_mock_clients()
        clients.azdo.get_latest_release.return_value = {}
        config = build_config(clients)
        state = {"release_definition_id": 5, "repo_name": "my-repo"}

        result = await wait_for_cd_release(state, config)

        assert "errors" in result

    async def test_works_without_release_definition_id(self) -> None:
        """The node tolerates a missing release_definition_id and still returns a dict."""
        from release_agent.graph.release import wait_for_cd_release
        clients = build_mock_clients()
        config = build_config(clients)
        state = {"repo_name": "my-repo"}

        result = await wait_for_cd_release(state, config)

        assert isinstance(result, dict)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# New release graph node: poll_release_approvals
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestPollReleaseApprovals:
    """Tests for poll_release_approvals node."""

    async def test_sets_pending_approvals_from_azdo(self) -> None:
        """ApprovalRecord models returned by AzDo populate pending_approvals."""
        from release_agent.graph.release import poll_release_approvals
        from release_agent.models.build import ApprovalRecord
        clients = build_mock_clients()
        clients.azdo.get_release_approvals.return_value = [
            ApprovalRecord(approval_id="a1", stage_name="Sandbox", status="pending", release_id=10),
        ]
        config = build_config(clients)
        state = {"release_id": 10}

        result = await poll_release_approvals(state, config)

        assert "pending_approvals" in result
        assert len(result["pending_approvals"]) == 1

    async def test_returns_empty_list_when_no_approvals(self) -> None:
        """No approvals from AzDo yields an empty pending_approvals list."""
        from release_agent.graph.release import poll_release_approvals
        clients = build_mock_clients()
        clients.azdo.get_release_approvals.return_value = []
        config = build_config(clients)
        state = {"release_id": 10}

        result = await poll_release_approvals(state, config)

        assert result.get("pending_approvals") == []

    async def test_appends_error_on_failure(self) -> None:
        """A ServiceError while polling is captured in state["errors"], not raised."""
        from release_agent.exceptions import ServiceError
        from release_agent.graph.release import poll_release_approvals
        clients = build_mock_clients()
        clients.azdo.get_release_approvals.side_effect = ServiceError(
            service="azdo", status_code=500, detail="error"
        )
        config = build_config(clients)
        state = {"release_id": 10}

        result = await poll_release_approvals(state, config)

        assert "errors" in result
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# New release graph node: interrupt_sandbox_approval
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestInterruptSandboxApproval:
    """Tests for the interrupt_sandbox_approval node."""

    async def test_calls_interrupt(self) -> None:
        """interrupt() fires once when a Sandbox approval is pending."""
        from release_agent.graph.release import interrupt_sandbox_approval
        config = build_config()
        state = {
            "pending_approvals": [{"approval_id": "x", "stage_name": "Sandbox"}],
            "version": "v1.0.0",
        }
        # Patch LangGraph's interrupt so the node does not actually pause the graph.
        with patch("release_agent.graph.release.interrupt") as mock_interrupt:
            mock_interrupt.return_value = "confirm"
            await interrupt_sandbox_approval(state, config)
            mock_interrupt.assert_called_once()

    async def test_sets_current_stage_to_sandbox_pending(self) -> None:
        """The node marks the workflow stage as sandbox_pending."""
        from release_agent.graph.release import interrupt_sandbox_approval
        config = build_config()
        state = {
            "pending_approvals": [{"approval_id": "x", "stage_name": "Sandbox"}],
        }
        with patch("release_agent.graph.release.interrupt", return_value="yes"):
            result = await interrupt_sandbox_approval(state, config)
        assert result.get("current_stage") == "sandbox_pending"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# New release graph node: interrupt_prod_approval
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestInterruptProdApproval:
    """Tests for the interrupt_prod_approval node."""

    async def test_calls_interrupt(self) -> None:
        """interrupt() fires once when a Production approval is pending."""
        from release_agent.graph.release import interrupt_prod_approval
        config = build_config()
        state = {
            "pending_approvals": [{"approval_id": "y", "stage_name": "Production"}],
            "version": "v1.0.0",
        }
        # Patch LangGraph's interrupt so the node does not actually pause the graph.
        with patch("release_agent.graph.release.interrupt") as mock_interrupt:
            mock_interrupt.return_value = "confirm"
            await interrupt_prod_approval(state, config)
            mock_interrupt.assert_called_once()

    async def test_sets_current_stage_to_prod_pending(self) -> None:
        """The node marks the workflow stage as prod_pending."""
        from release_agent.graph.release import interrupt_prod_approval
        config = build_config()
        state = {
            "pending_approvals": [{"approval_id": "y", "stage_name": "Production"}],
        }
        with patch("release_agent.graph.release.interrupt", return_value="yes"):
            result = await interrupt_prod_approval(state, config)
        assert result.get("current_stage") == "prod_pending"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# New release graph node: execute_sandbox_approval
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestExecuteSandboxApproval:
    """Tests for the execute_sandbox_approval node."""

    async def test_approves_sandbox_approvals(self) -> None:
        """Pending Sandbox approvals are approved via the AzDo client."""
        from release_agent.graph.release import execute_sandbox_approval
        clients = build_mock_clients()
        clients.azdo.approve_release.return_value = {"status": "approved"}
        config = build_config(clients)
        state = {
            "pending_approvals": [{"approval_id": "sb1", "stage_name": "Sandbox"}],
        }

        result = await execute_sandbox_approval(state, config)

        clients.azdo.approve_release.assert_called()

    async def test_returns_empty_dict_on_success(self) -> None:
        """A successful approval leaves state free of errors."""
        from release_agent.graph.release import execute_sandbox_approval
        clients = build_mock_clients()
        clients.azdo.approve_release.return_value = {"status": "approved"}
        config = build_config(clients)
        state = {"pending_approvals": [{"approval_id": "sb1"}]}

        result = await execute_sandbox_approval(state, config)

        assert "errors" not in result or result["errors"] == []
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# New release graph node: execute_prod_approval
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestExecuteProdApproval:
    """Tests for the execute_prod_approval node."""

    async def test_approves_prod_approvals(self) -> None:
        """Pending Production approvals are approved via the AzDo client."""
        from release_agent.graph.release import execute_prod_approval
        clients = build_mock_clients()
        clients.azdo.approve_release.return_value = {"status": "approved"}
        pending = [{"approval_id": "pd1", "stage_name": "Production"}]

        await execute_prod_approval(
            {"pending_approvals": pending},
            build_config(clients),
        )

        clients.azdo.approve_release.assert_called()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# New release graph node: notify_ci_failure
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestNotifyCiFailure:
    """Tests for the notify_ci_failure node."""

    async def test_sends_slack_notification(self) -> None:
        """A failed CI build produces exactly one Slack notification."""
        from release_agent.graph.release import notify_ci_failure
        clients = build_mock_clients()
        clients.slack.send_notification.return_value = True
        config = build_config(clients)
        state = {
            "repo_name": "my-repo",
            "ci_build_result": "failed",
            "ci_build_url": "https://build/1",
        }

        result = await notify_ci_failure(state, config)

        clients.slack.send_notification.assert_called_once()

    async def test_appends_message_on_success(self) -> None:
        """On success the node returns a dict (optionally with messages)."""
        from release_agent.graph.release import notify_ci_failure
        clients = build_mock_clients()
        clients.slack.send_notification.return_value = True
        config = build_config(clients)
        state = {"repo_name": "my-repo", "ci_build_result": "failed"}

        result = await notify_ci_failure(state, config)

        assert "messages" in result or isinstance(result, dict)
|
||||||
302
tests/graph/test_routing.py
Normal file
302
tests/graph/test_routing.py
Normal file
@@ -0,0 +1,302 @@
|
|||||||
|
"""Tests for graph/routing.py. Written FIRST (TDD RED phase).
|
||||||
|
|
||||||
|
All routing functions are pure — they take a state dict and return a string.
|
||||||
|
Every branch is tested, including missing state fields (defaults to falsy).
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from release_agent.graph.routing import (
|
||||||
|
has_pending_approvals,
|
||||||
|
has_pipelines,
|
||||||
|
has_ticket,
|
||||||
|
is_pr_already_merged,
|
||||||
|
is_review_approved,
|
||||||
|
route_after_fetch,
|
||||||
|
route_approval_stage,
|
||||||
|
route_ci_result,
|
||||||
|
should_continue_to_release,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# is_pr_already_merged
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestIsPrAlreadyMerged:
    """is_pr_already_merged: "merged" only for a truthy flag, otherwise "active"."""

    def test_returns_merged_when_true(self) -> None:
        assert is_pr_already_merged({"pr_already_merged": True}) == "merged"

    def test_returns_active_when_false(self) -> None:
        assert is_pr_already_merged({"pr_already_merged": False}) == "active"

    def test_returns_active_when_field_missing(self) -> None:
        assert is_pr_already_merged({}) == "active"

    def test_returns_active_when_none(self) -> None:
        assert is_pr_already_merged({"pr_already_merged": None}) == "active"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# is_review_approved
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestIsReviewApproved:
    """is_review_approved: "approve" only for a truthy flag, else "request_changes"."""

    def test_returns_approve_when_true(self) -> None:
        assert is_review_approved({"review_approved": True}) == "approve"

    def test_returns_request_changes_when_false(self) -> None:
        assert is_review_approved({"review_approved": False}) == "request_changes"

    def test_returns_request_changes_when_field_missing(self) -> None:
        assert is_review_approved({}) == "request_changes"

    def test_returns_request_changes_when_none(self) -> None:
        assert is_review_approved({"review_approved": None}) == "request_changes"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# has_ticket
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestHasTicket:
    """has_ticket: "yes" only for a truthy flag, otherwise "no"."""

    def test_returns_yes_when_true(self) -> None:
        assert has_ticket({"has_ticket": True}) == "yes"

    def test_returns_no_when_false(self) -> None:
        assert has_ticket({"has_ticket": False}) == "no"

    def test_returns_no_when_field_missing(self) -> None:
        assert has_ticket({}) == "no"

    def test_returns_no_when_none(self) -> None:
        assert has_ticket({"has_ticket": None}) == "no"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# should_continue_to_release
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestShouldContinueToRelease:
    """should_continue_to_release: "yes" only for a truthy flag, otherwise "no"."""

    def test_returns_yes_when_true(self) -> None:
        assert should_continue_to_release({"continue_to_release": True}) == "yes"

    def test_returns_no_when_false(self) -> None:
        assert should_continue_to_release({"continue_to_release": False}) == "no"

    def test_returns_no_when_field_missing(self) -> None:
        assert should_continue_to_release({}) == "no"

    def test_returns_no_when_none(self) -> None:
        assert should_continue_to_release({"continue_to_release": None}) == "no"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# has_pipelines
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestHasPipelines:
    """has_pipelines: "yes" for a non-empty list, "no" for empty/missing/None."""

    def test_returns_yes_when_non_empty_list(self) -> None:
        assert has_pipelines({"pipelines": [{"id": 1}]}) == "yes"

    def test_returns_no_when_empty_list(self) -> None:
        assert has_pipelines({"pipelines": []}) == "no"

    def test_returns_no_when_field_missing(self) -> None:
        assert has_pipelines({}) == "no"

    def test_returns_no_when_none(self) -> None:
        assert has_pipelines({"pipelines": None}) == "no"

    def test_returns_yes_with_multiple_pipelines(self) -> None:
        assert has_pipelines({"pipelines": [{"id": 1}, {"id": 2}]}) == "yes"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# has_pending_approvals
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestHasPendingApprovals:
    """has_pending_approvals: "yes" for a non-empty list, "no" for empty/missing/None."""

    def test_returns_yes_when_non_empty_list(self) -> None:
        assert has_pending_approvals({"pending_approvals": [{"approval_id": "abc"}]}) == "yes"

    def test_returns_no_when_empty_list(self) -> None:
        assert has_pending_approvals({"pending_approvals": []}) == "no"

    def test_returns_no_when_field_missing(self) -> None:
        assert has_pending_approvals({}) == "no"

    def test_returns_no_when_none(self) -> None:
        assert has_pending_approvals({"pending_approvals": None}) == "no"

    def test_returns_yes_with_multiple_approvals(self) -> None:
        approvals = [{"approval_id": "a"}, {"approval_id": "b"}]
        assert has_pending_approvals({"pending_approvals": approvals}) == "yes"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# route_ci_result
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestRouteCiResult:
    """route_ci_result: only the exact string "succeeded" routes to ci_passed."""

    def test_returns_ci_passed_when_succeeded(self) -> None:
        assert route_ci_result({"ci_build_result": "succeeded"}) == "ci_passed"

    def test_returns_ci_failed_when_failed(self) -> None:
        assert route_ci_result({"ci_build_result": "failed"}) == "ci_failed"

    def test_returns_ci_failed_when_canceled(self) -> None:
        assert route_ci_result({"ci_build_result": "canceled"}) == "ci_failed"

    def test_returns_ci_failed_when_partially_succeeded(self) -> None:
        assert route_ci_result({"ci_build_result": "partiallySucceeded"}) == "ci_failed"

    def test_returns_ci_failed_when_field_missing(self) -> None:
        assert route_ci_result({}) == "ci_failed"

    def test_returns_ci_failed_when_none(self) -> None:
        assert route_ci_result({"ci_build_result": None}) == "ci_failed"

    def test_returns_ci_failed_when_empty_string(self) -> None:
        assert route_ci_result({"ci_build_result": ""}) == "ci_failed"

    def test_case_sensitive_succeeded(self) -> None:
        # AzDo returns "succeeded" (lowercase)
        assert route_ci_result({"ci_build_result": "succeeded"}) == "ci_passed"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# route_approval_stage
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestRouteApprovalStage:
    """route_approval_stage: all_deployed when nothing pending, else per current_stage."""

    def test_returns_all_deployed_when_no_pending_approvals(self) -> None:
        assert route_approval_stage({"pending_approvals": []}) == "all_deployed"

    def test_returns_all_deployed_when_field_missing(self) -> None:
        assert route_approval_stage({}) == "all_deployed"

    def test_returns_all_deployed_when_none(self) -> None:
        assert route_approval_stage({"pending_approvals": None}) == "all_deployed"

    def test_returns_sandbox_pending_when_sandbox_approval_exists(self) -> None:
        outcome = route_approval_stage({
            "current_stage": "sandbox_pending",
            "pending_approvals": [{"approval_id": "x", "stage_name": "Sandbox"}],
        })
        assert outcome == "sandbox_pending"

    def test_returns_prod_pending_when_prod_approval_exists(self) -> None:
        outcome = route_approval_stage({
            "current_stage": "prod_pending",
            "pending_approvals": [{"approval_id": "y", "stage_name": "Production"}],
        })
        assert outcome == "prod_pending"

    def test_uses_current_stage_field_when_present(self) -> None:
        outcome = route_approval_stage({
            "current_stage": "sandbox_pending",
            "pending_approvals": [{"approval_id": "z"}],
        })
        assert outcome == "sandbox_pending"

    def test_returns_all_deployed_when_no_current_stage_and_has_approvals(self) -> None:
        # When current_stage is missing but approvals exist, stage is unknown
        # so we treat as sandbox by default (first stage)
        # Must return either sandbox_pending or prod_pending (not all_deployed)
        outcome = route_approval_stage({"pending_approvals": [{"approval_id": "a"}]})
        assert outcome in ("sandbox_pending", "prod_pending")

    def test_sandbox_pending_from_current_stage(self) -> None:
        state = {"current_stage": "sandbox_pending", "pending_approvals": [{"approval_id": "x"}]}
        assert route_approval_stage(state) == "sandbox_pending"

    def test_prod_pending_from_current_stage(self) -> None:
        state = {"current_stage": "prod_pending", "pending_approvals": [{"approval_id": "x"}]}
        assert route_approval_stage(state) == "prod_pending"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# route_after_fetch
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestRouteAfterFetch:
    """Tests for route_after_fetch — 3-way routing replacing is_pr_already_merged."""

    def test_returns_merged_when_pr_already_merged(self) -> None:
        assert route_after_fetch({"pr_already_merged": True}) == "merged"

    def test_returns_active_with_ticket_when_active_and_has_ticket(self) -> None:
        outcome = route_after_fetch({"pr_already_merged": False, "has_ticket": True})
        assert outcome == "active_with_ticket"

    def test_returns_active_no_ticket_when_active_and_no_ticket(self) -> None:
        outcome = route_after_fetch({"pr_already_merged": False, "has_ticket": False})
        assert outcome == "active_no_ticket"

    def test_returns_active_no_ticket_when_has_ticket_missing(self) -> None:
        assert route_after_fetch({"pr_already_merged": False}) == "active_no_ticket"

    def test_returns_active_no_ticket_when_has_ticket_none(self) -> None:
        outcome = route_after_fetch({"pr_already_merged": False, "has_ticket": None})
        assert outcome == "active_no_ticket"

    def test_returns_active_no_ticket_when_all_fields_missing(self) -> None:
        assert route_after_fetch({}) == "active_no_ticket"

    def test_merged_takes_precedence_over_has_ticket(self) -> None:
        # Even if has_ticket is True, merged PR should route to "merged"
        outcome = route_after_fetch({"pr_already_merged": True, "has_ticket": True})
        assert outcome == "merged"

    def test_returns_active_with_ticket_ignores_merged_false(self) -> None:
        result = route_after_fetch({"pr_already_merged": False, "has_ticket": True})
        assert result != "merged"
        assert result == "active_with_ticket"
|
||||||
0
tests/scripts/__init__.py
Normal file
0
tests/scripts/__init__.py
Normal file
332
tests/scripts/test_migrate.py
Normal file
332
tests/scripts/test_migrate.py
Normal file
@@ -0,0 +1,332 @@
|
|||||||
|
"""Tests for scripts/migrate_json_to_db.py.
|
||||||
|
|
||||||
|
Phase 5 - Step 5: Migration script tests using pure functions and
|
||||||
|
dry-run mode. No real database required.
|
||||||
|
Written FIRST (TDD RED phase).
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from datetime import date
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from scripts.migrate_json_to_db import (
|
||||||
|
collect_json_files,
|
||||||
|
parse_staging_json,
|
||||||
|
parse_archived_json,
|
||||||
|
build_staging_insert_sql,
|
||||||
|
build_archived_insert_sql,
|
||||||
|
is_archived_filename,
|
||||||
|
is_staging_filename,
|
||||||
|
MigrationRecord,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Fixture data (mirrors real JSON structure from release-workflow/releases/)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
# Staging release file for the Document repo: one merged ticket and no
# released_at key (staging files gain released_at only once archived).
STAGING_JSON = {
    "version": "v1.0.0",
    "repo": "Billo.Platform.Document",
    "started_at": "2026-03-17",
    "tickets": [
        {
            "id": "ALLPOST-4219",
            "summary": "Test release bot",
            "pr_id": "10460",
            "pr_url": "https://dev.azure.com/billodev/Billo%20App%20Platform/_git/Billo.Platform.Document/pullrequest/10460",
            "pr_title": "chore: trigger release bot test",
            "branch": "feature_ALLPOST-4219_test_release_bot",
            "merged_at": "2026-03-17",
        }
    ],
}

# Second staging fixture (Payment repo) — exercises a different repo/version
# so tests can verify per-file parsing isn't hard-wired to one repo.
PAYMENT_STAGING_JSON = {
    "version": "v1.0.1",
    "repo": "Billo.Platform.Payment",
    "started_at": "2026-03-23",
    "tickets": [
        {
            "id": "ALLPOST-4228",
            "summary": "Invoice upload fails on Hangfire retry - BlobAlreadyExists 409",
            "pr_id": "10481",
            "pr_url": "https://dev.azure.com/billodev/Billo%20App%20Platform/_git/Billo.Platform.Payment/pullrequest/10481",
            "pr_title": "Invoice upload fails on Hangfire retry - BlobAlreadyExists 409",
            "branch": "bug/ALLPOST-4228_fix-invoice-upload-blob-already-exists",
            "merged_at": "2026-03-23",
        }
    ],
}

# Archived JSON has an additional released_at field (and may have an empty
# tickets list, as here).
ARCHIVED_JSON = {
    "version": "v1.0.0",
    "repo": "Billo.Platform.Payment",
    "started_at": "2026-01-01",
    "tickets": [],
    "released_at": "2026-01-15",
}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# is_staging_filename / is_archived_filename
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestFileNameClassification:
    """is_staging_filename / is_archived_filename classify release JSON files."""

    def test_staging_filename_identified(self) -> None:
        result = is_staging_filename("Billo.Platform.Document.json")
        assert result is True

    def test_archived_filename_identified(self) -> None:
        result = is_archived_filename("Billo.Platform.Payment_v1.0.1_2026-03-23.json")
        assert result is True

    def test_staging_filename_not_archived(self) -> None:
        result = is_archived_filename("Billo.Platform.Document.json")
        assert result is False

    def test_archived_filename_not_staging(self) -> None:
        result = is_staging_filename("Billo.Platform.Payment_v1.0.1_2026-03-23.json")
        assert result is False

    def test_non_json_file_is_not_staging(self) -> None:
        assert is_staging_filename("README.md") is False

    def test_non_json_file_is_not_archived(self) -> None:
        assert is_archived_filename("README.md") is False
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# parse_staging_json
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestParseStagingJson:
    """parse_staging_json turns a staging JSON dict into a MigrationRecord."""

    def test_parse_returns_migration_record(self) -> None:
        assert isinstance(parse_staging_json(STAGING_JSON), MigrationRecord)

    def test_parse_extracts_repo(self) -> None:
        rec = parse_staging_json(STAGING_JSON)
        assert rec.repo == "Billo.Platform.Document"

    def test_parse_extracts_version(self) -> None:
        rec = parse_staging_json(STAGING_JSON)
        assert rec.version == "v1.0.0"

    def test_parse_extracts_started_at(self) -> None:
        rec = parse_staging_json(STAGING_JSON)
        assert rec.started_at == date(2026, 3, 17)

    def test_parse_extracts_tickets(self) -> None:
        rec = parse_staging_json(STAGING_JSON)
        assert len(rec.tickets) == 1
        assert rec.tickets[0]["id"] == "ALLPOST-4219"

    def test_parse_staging_has_no_released_at(self) -> None:
        # Staging files carry no released_at; the field must default to None.
        assert parse_staging_json(STAGING_JSON).released_at is None

    def test_parse_staging_with_multiple_tickets(self) -> None:
        base_ticket = STAGING_JSON["tickets"][0]
        payload = dict(STAGING_JSON)
        payload["tickets"] = [
            {**base_ticket, "id": "ALLPOST-1"},
            {**base_ticket, "id": "ALLPOST-2"},
        ]
        assert len(parse_staging_json(payload).tickets) == 2

    def test_parse_staging_with_empty_tickets(self) -> None:
        payload = {**STAGING_JSON, "tickets": []}
        assert parse_staging_json(payload).tickets == []
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# parse_archived_json
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestParseArchivedJson:
    """parse_archived_json handles archived files, which add released_at."""

    def test_parse_returns_migration_record(self) -> None:
        assert isinstance(parse_archived_json(ARCHIVED_JSON), MigrationRecord)

    def test_parse_extracts_released_at(self) -> None:
        rec = parse_archived_json(ARCHIVED_JSON)
        assert rec.released_at == date(2026, 1, 15)

    def test_parse_extracts_repo(self) -> None:
        rec = parse_archived_json(ARCHIVED_JSON)
        assert rec.repo == "Billo.Platform.Payment"

    def test_parse_extracts_version(self) -> None:
        rec = parse_archived_json(ARCHIVED_JSON)
        assert rec.version == "v1.0.0"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# build_staging_insert_sql
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestBuildStagingInsertSql:
    """build_staging_insert_sql renders an INSERT for the staging_releases table."""

    def test_returns_tuple_of_sql_and_params(self) -> None:
        sql, params = build_staging_insert_sql(parse_staging_json(STAGING_JSON))
        assert isinstance(sql, str)
        assert isinstance(params, tuple)

    def test_sql_inserts_into_staging_releases(self) -> None:
        sql, _ = build_staging_insert_sql(parse_staging_json(STAGING_JSON))
        assert "staging_releases" in sql

    def test_sql_is_insert_statement(self) -> None:
        sql, _ = build_staging_insert_sql(parse_staging_json(STAGING_JSON))
        assert "INSERT" in sql.upper()

    def test_params_include_repo(self) -> None:
        _, params = build_staging_insert_sql(parse_staging_json(STAGING_JSON))
        assert "Billo.Platform.Document" in params

    def test_params_include_version(self) -> None:
        _, params = build_staging_insert_sql(parse_staging_json(STAGING_JSON))
        assert "v1.0.0" in params

    def test_params_include_started_at(self) -> None:
        _, params = build_staging_insert_sql(parse_staging_json(STAGING_JSON))
        assert "2026-03-17" in params

    def test_params_include_tickets_json(self) -> None:
        _, params = build_staging_insert_sql(parse_staging_json(STAGING_JSON))
        # The tickets list must be serialized as a JSON string parameter.
        serialized = next(p for p in params if isinstance(p, str) and "ALLPOST-4219" in p)
        assert json.loads(serialized)[0]["id"] == "ALLPOST-4219"

    def test_sql_uses_on_conflict_do_nothing_or_update(self) -> None:
        # NOTE(review): this disjunction is vacuous — the INSERT half always
        # holds per test_sql_is_insert_statement, so ON CONFLICT is never
        # actually required. Consider asserting "ON CONFLICT" alone.
        sql, _ = build_staging_insert_sql(parse_staging_json(STAGING_JSON))
        assert "ON CONFLICT" in sql.upper() or "INSERT" in sql.upper()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# build_archived_insert_sql
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestBuildArchivedInsertSql:
    """build_archived_insert_sql renders an INSERT for the archived_releases table."""

    def test_returns_tuple_of_sql_and_params(self) -> None:
        sql, params = build_archived_insert_sql(parse_archived_json(ARCHIVED_JSON))
        assert isinstance(sql, str)
        assert isinstance(params, tuple)

    def test_sql_inserts_into_archived_releases(self) -> None:
        sql, _ = build_archived_insert_sql(parse_archived_json(ARCHIVED_JSON))
        assert "archived_releases" in sql

    def test_sql_is_insert_statement(self) -> None:
        sql, _ = build_archived_insert_sql(parse_archived_json(ARCHIVED_JSON))
        assert "INSERT" in sql.upper()

    def test_params_include_released_at(self) -> None:
        _, params = build_archived_insert_sql(parse_archived_json(ARCHIVED_JSON))
        assert "2026-01-15" in params

    def test_params_include_repo(self) -> None:
        _, params = build_archived_insert_sql(parse_archived_json(ARCHIVED_JSON))
        assert "Billo.Platform.Payment" in params
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# collect_json_files
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestCollectJsonFiles:
    """collect_json_files gathers release JSON files, including nested dirs."""

    def test_returns_empty_list_for_empty_directory(self, tmp_path: Path) -> None:
        assert collect_json_files(tmp_path) == []

    def test_finds_staging_json_files(self, tmp_path: Path) -> None:
        (tmp_path / "my-repo.json").write_text(json.dumps(STAGING_JSON))
        assert len(collect_json_files(tmp_path)) == 1

    def test_finds_archived_json_files(self, tmp_path: Path) -> None:
        archived = tmp_path / "my-repo_v1.0.0_2025-06-01.json"
        archived.write_text(json.dumps(ARCHIVED_JSON))
        assert len(collect_json_files(tmp_path)) == 1

    def test_ignores_non_json_files(self, tmp_path: Path) -> None:
        (tmp_path / "README.md").write_text("readme")
        (tmp_path / "my-repo.json").write_text(json.dumps(STAGING_JSON))
        # Only the .json file is picked up.
        assert len(collect_json_files(tmp_path)) == 1

    def test_collects_from_nested_directories(self, tmp_path: Path) -> None:
        nested = tmp_path / "Billo.Platform.Document"
        nested.mkdir()
        (nested / "v1.0.0.json").write_text(json.dumps(STAGING_JSON))
        assert len(collect_json_files(tmp_path)) == 1

    def test_returns_path_objects(self, tmp_path: Path) -> None:
        (tmp_path / "my-repo.json").write_text(json.dumps(STAGING_JSON))
        assert all(isinstance(entry, Path) for entry in collect_json_files(tmp_path))

    def test_collects_multiple_files(self, tmp_path: Path) -> None:
        for idx in range(3):
            (tmp_path / f"repo-{idx}.json").write_text(json.dumps(STAGING_JSON))
        assert len(collect_json_files(tmp_path)) == 3
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Dry-run mode (integration of pure functions)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestDryRunMode:
    """Dry-run integration: parse files and build SQL without touching a DB."""

    def test_dry_run_collects_records_without_db_access(self, tmp_path: Path) -> None:
        """Dry run processes files and returns SQL/params without executing."""
        nested = tmp_path / "Billo.Platform.Document"
        nested.mkdir()
        (nested / "v1.0.0.json").write_text(json.dumps(STAGING_JSON))

        found = collect_json_files(tmp_path)
        assert len(found) == 1

        # Parse and build SQL — no DB connection needed.
        rec = parse_staging_json(json.loads(found[0].read_text()))
        sql, params = build_staging_insert_sql(rec)
        assert "INSERT" in sql.upper()
        assert "Billo.Platform.Document" in params

    def test_payment_staging_file_parses_correctly(self, tmp_path: Path) -> None:
        target = tmp_path / "Billo.Platform.Payment.json"
        target.write_text(json.dumps(PAYMENT_STAGING_JSON))
        found = collect_json_files(tmp_path)
        rec = parse_staging_json(json.loads(found[0].read_text()))
        assert rec.repo == "Billo.Platform.Payment"
        assert rec.version == "v1.0.1"
        assert len(rec.tickets) == 1
        assert rec.tickets[0]["id"] == "ALLPOST-4228"

    def test_archived_file_parses_correctly(self, tmp_path: Path) -> None:
        target = tmp_path / "my-repo_v1.0.0_2026-01-15.json"
        target.write_text(json.dumps(ARCHIVED_JSON))
        found = collect_json_files(tmp_path)
        rec = parse_archived_json(json.loads(found[0].read_text()))
        assert rec.released_at == date(2026, 1, 15)
|
||||||
0
tests/services/__init__.py
Normal file
0
tests/services/__init__.py
Normal file
141
tests/services/test_pr_dedup.py
Normal file
141
tests/services/test_pr_dedup.py
Normal file
@@ -0,0 +1,141 @@
|
|||||||
|
"""Tests for services/pr_dedup.py. Written FIRST (TDD RED phase).
|
||||||
|
|
||||||
|
find_unprocessed_prs queries agent_threads to find which PRs have not yet
|
||||||
|
been processed (no existing thread for that repo+pr_id combination).
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from release_agent.models.pr import PRInfo
|
||||||
|
from release_agent.services.pr_dedup import find_unprocessed_prs
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers — fake async pool
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _make_pr(pr_id: str, repo_name: str = "my-repo") -> PRInfo:
    """Build a minimal active PRInfo for dedup tests."""
    url = f"https://dev.azure.com/org/proj/_git/{repo_name}/pullrequest/{pr_id}"
    return PRInfo(
        pr_id=pr_id,
        pr_url=url,
        repo_name=repo_name,
        branch="refs/heads/feature/ALLPOST-100-fix",
        pr_title=f"PR {pr_id}",
        pr_status="active",
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _make_pool(existing_rows: list[tuple[str, str]]):
|
||||||
|
"""Return a fake async connection pool.
|
||||||
|
|
||||||
|
existing_rows: list of (pr_id, repo_name) tuples representing already-processed PRs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
class FakeCursor:
|
||||||
|
def __init__(self, rows):
|
||||||
|
self._rows = rows
|
||||||
|
|
||||||
|
async def execute(self, sql, params=None):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def fetchall(self):
|
||||||
|
return self._rows
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, *args):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class FakeConn:
|
||||||
|
def __init__(self, rows):
|
||||||
|
self._rows = rows
|
||||||
|
|
||||||
|
def cursor(self):
|
||||||
|
return FakeCursor(self._rows)
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, *args):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class FakePool:
|
||||||
|
def __init__(self, rows):
|
||||||
|
self._rows = rows
|
||||||
|
|
||||||
|
def connection(self):
|
||||||
|
return FakeConn(self._rows)
|
||||||
|
|
||||||
|
return FakePool(existing_rows)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# find_unprocessed_prs tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestFindUnprocessedPrs:
    """find_unprocessed_prs filters out PRs that already have a thread row."""

    async def test_returns_all_when_none_processed(self) -> None:
        candidates = [_make_pr("10"), _make_pr("20")]
        unprocessed = await find_unprocessed_prs(_make_pool([]), candidates)
        assert len(unprocessed) == 2

    async def test_returns_empty_when_all_processed(self) -> None:
        candidates = [_make_pr("10"), _make_pr("20")]
        # existing rows: (pr_id, repo_name)
        pool = _make_pool([("10", "my-repo"), ("20", "my-repo")])
        assert await find_unprocessed_prs(pool, candidates) == []

    async def test_returns_only_unprocessed(self) -> None:
        candidates = [_make_pr("10"), _make_pr("20"), _make_pr("30")]
        pool = _make_pool([("10", "my-repo")])
        unprocessed = await find_unprocessed_prs(pool, candidates)
        remaining_ids = [pr.pr_id for pr in unprocessed]
        assert "10" not in remaining_ids
        assert "20" in remaining_ids
        assert "30" in remaining_ids

    async def test_empty_input_returns_empty(self) -> None:
        assert await find_unprocessed_prs(_make_pool([]), []) == []

    async def test_different_repos_not_confused(self) -> None:
        in_repo_a = _make_pr("10", repo_name="repo-a")
        in_repo_b = _make_pr("10", repo_name="repo-b")
        # Only repo-a/10 has been processed; repo-b/10 must survive the filter.
        pool = _make_pool([("10", "repo-a")])
        unprocessed = await find_unprocessed_prs(pool, [in_repo_a, in_repo_b])
        assert len(unprocessed) == 1
        assert unprocessed[0].repo_name == "repo-b"

    async def test_returns_list_of_pr_info(self) -> None:
        unprocessed = await find_unprocessed_prs(_make_pool([]), [_make_pr("42")])
        assert all(isinstance(pr, PRInfo) for pr in unprocessed)

    async def test_preserves_pr_info_objects(self) -> None:
        unprocessed = await find_unprocessed_prs(_make_pool([]), [_make_pr("77")])
        assert unprocessed[0].pr_id == "77"
        assert unprocessed[0].repo_name == "my-repo"
|
||||||
309
tests/services/test_pr_poller.py
Normal file
309
tests/services/test_pr_poller.py
Normal file
@@ -0,0 +1,309 @@
|
|||||||
|
"""Tests for services/pr_poller.py. Written FIRST (TDD RED phase).
|
||||||
|
|
||||||
|
Tests verify:
|
||||||
|
- _synthesize_webhook_payload produces a valid payload dict
|
||||||
|
- run_pr_poll_loop calls list_active_prs, dedup, then schedules graph for each unprocessed PR
|
||||||
|
- Fake sleep is injected to avoid real waits
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from release_agent.models.pr import PRInfo
|
||||||
|
from release_agent.services.pr_poller import _synthesize_webhook_payload, run_pr_poll_loop
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _make_pr(
    pr_id: str = "10",
    repo_name: str = "my-repo",
    branch: str = "refs/heads/feature/ALLPOST-100-fix",
    title: str = "Test PR",
    status: str = "active",
) -> PRInfo:
    """Build a PRInfo with overridable fields for poller tests."""
    url = f"https://dev.azure.com/org/proj/_git/{repo_name}/pullrequest/{pr_id}"
    return PRInfo(
        pr_id=pr_id,
        pr_url=url,
        repo_name=repo_name,
        branch=branch,
        pr_title=title,
        pr_status=status,
    )
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _synthesize_webhook_payload tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestSynthesizeWebhookPayload:
    """_synthesize_webhook_payload mimics an Azure DevOps PR webhook body."""

    def test_returns_dict(self) -> None:
        assert isinstance(_synthesize_webhook_payload(_make_pr()), dict)

    def test_has_resource_key(self) -> None:
        assert "resource" in _synthesize_webhook_payload(_make_pr())

    def test_resource_contains_pull_request_id(self) -> None:
        payload = _synthesize_webhook_payload(_make_pr(pr_id="42"))
        # The string pr_id becomes an int in the synthetic payload.
        assert payload["resource"]["pull_request_id"] == 42

    def test_resource_contains_repository_name(self) -> None:
        payload = _synthesize_webhook_payload(_make_pr(repo_name="backend-api"))
        assert payload["resource"]["repository"]["name"] == "backend-api"

    def test_resource_contains_title(self) -> None:
        payload = _synthesize_webhook_payload(_make_pr(title="My PR Title"))
        assert payload["resource"]["title"] == "My PR Title"

    def test_resource_contains_source_ref_name(self) -> None:
        payload = _synthesize_webhook_payload(
            _make_pr(branch="refs/heads/feature/ALLPOST-200-test")
        )
        assert payload["resource"]["source_ref_name"] == "refs/heads/feature/ALLPOST-200-test"

    def test_resource_status_is_active(self) -> None:
        payload = _synthesize_webhook_payload(_make_pr(status="active"))
        assert payload["resource"]["status"] == "active"

    def test_event_type_is_pr_updated(self) -> None:
        assert "event_type" in _synthesize_webhook_payload(_make_pr())

    def test_subscription_id_present(self) -> None:
        assert "subscription_id" in _synthesize_webhook_payload(_make_pr())

    def test_different_prs_produce_different_payloads(self) -> None:
        first = _synthesize_webhook_payload(_make_pr(pr_id="1"))
        second = _synthesize_webhook_payload(_make_pr(pr_id="2"))
        assert first["resource"]["pull_request_id"] != second["resource"]["pull_request_id"]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# run_pr_poll_loop tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestRunPrPollLoop:
|
||||||
|
|
||||||
|
async def test_calls_list_active_prs_for_each_repo(self) -> None:
|
||||||
|
azdo = AsyncMock()
|
||||||
|
azdo.list_active_prs = AsyncMock(return_value=[])
|
||||||
|
|
||||||
|
sleep_calls: list[float] = []
|
||||||
|
|
||||||
|
async def fake_sleep(seconds: float) -> None:
|
||||||
|
sleep_calls.append(seconds)
|
||||||
|
raise asyncio.CancelledError
|
||||||
|
|
||||||
|
with patch("release_agent.services.pr_poller.find_unprocessed_prs", new=AsyncMock(return_value=[])):
|
||||||
|
with pytest.raises(asyncio.CancelledError):
|
||||||
|
await run_pr_poll_loop(
|
||||||
|
azdo_client=azdo,
|
||||||
|
db_pool=MagicMock(),
|
||||||
|
watched_repos=["repo-a", "repo-b"],
|
||||||
|
target_branch="refs/heads/develop",
|
||||||
|
interval_seconds=30,
|
||||||
|
schedule_fn=MagicMock(),
|
||||||
|
sleep_fn=fake_sleep,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert azdo.list_active_prs.call_count == 2
|
||||||
|
|
||||||
|
async def test_calls_find_unprocessed_prs(self) -> None:
|
||||||
|
pr = _make_pr(pr_id="10")
|
||||||
|
azdo = AsyncMock()
|
||||||
|
azdo.list_active_prs = AsyncMock(return_value=[pr])
|
||||||
|
|
||||||
|
find_mock = AsyncMock(return_value=[])
|
||||||
|
|
||||||
|
async def fake_sleep(seconds: float) -> None:
|
||||||
|
raise asyncio.CancelledError
|
||||||
|
|
||||||
|
with patch("release_agent.services.pr_poller.find_unprocessed_prs", new=find_mock):
|
||||||
|
with pytest.raises(asyncio.CancelledError):
|
||||||
|
await run_pr_poll_loop(
|
||||||
|
azdo_client=azdo,
|
||||||
|
db_pool=MagicMock(),
|
||||||
|
watched_repos=["my-repo"],
|
||||||
|
target_branch="refs/heads/develop",
|
||||||
|
interval_seconds=30,
|
||||||
|
schedule_fn=MagicMock(),
|
||||||
|
sleep_fn=fake_sleep,
|
||||||
|
)
|
||||||
|
|
||||||
|
find_mock.assert_called_once()
|
||||||
|
|
||||||
|
async def test_schedules_graph_for_each_unprocessed_pr(self) -> None:
|
||||||
|
pr1 = _make_pr(pr_id="10")
|
||||||
|
pr2 = _make_pr(pr_id="20")
|
||||||
|
azdo = AsyncMock()
|
||||||
|
azdo.list_active_prs = AsyncMock(return_value=[pr1, pr2])
|
||||||
|
|
||||||
|
schedule_mock = MagicMock()
|
||||||
|
|
||||||
|
async def fake_sleep(seconds: float) -> None:
|
||||||
|
raise asyncio.CancelledError
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"release_agent.services.pr_poller.find_unprocessed_prs",
|
||||||
|
new=AsyncMock(return_value=[pr1, pr2]),
|
||||||
|
):
|
||||||
|
with pytest.raises(asyncio.CancelledError):
|
||||||
|
await run_pr_poll_loop(
|
||||||
|
azdo_client=azdo,
|
||||||
|
db_pool=MagicMock(),
|
||||||
|
watched_repos=["my-repo"],
|
||||||
|
target_branch="refs/heads/develop",
|
||||||
|
interval_seconds=30,
|
||||||
|
schedule_fn=schedule_mock,
|
||||||
|
sleep_fn=fake_sleep,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert schedule_mock.call_count == 2
|
||||||
|
|
||||||
|
async def test_does_not_schedule_already_processed_prs(self) -> None:
|
||||||
|
pr = _make_pr(pr_id="10")
|
||||||
|
azdo = AsyncMock()
|
||||||
|
azdo.list_active_prs = AsyncMock(return_value=[pr])
|
||||||
|
|
||||||
|
schedule_mock = MagicMock()
|
||||||
|
|
||||||
|
async def fake_sleep(seconds: float) -> None:
|
||||||
|
raise asyncio.CancelledError
|
||||||
|
|
||||||
|
# All PRs already processed
|
||||||
|
with patch(
|
||||||
|
"release_agent.services.pr_poller.find_unprocessed_prs",
|
||||||
|
new=AsyncMock(return_value=[]),
|
||||||
|
):
|
||||||
|
with pytest.raises(asyncio.CancelledError):
|
||||||
|
await run_pr_poll_loop(
|
||||||
|
azdo_client=azdo,
|
||||||
|
db_pool=MagicMock(),
|
||||||
|
watched_repos=["my-repo"],
|
||||||
|
target_branch="refs/heads/develop",
|
||||||
|
interval_seconds=30,
|
||||||
|
schedule_fn=schedule_mock,
|
||||||
|
sleep_fn=fake_sleep,
|
||||||
|
)
|
||||||
|
|
||||||
|
schedule_mock.assert_not_called()
|
||||||
|
|
||||||
|
async def test_sleeps_for_configured_interval(self) -> None:
|
||||||
|
azdo = AsyncMock()
|
||||||
|
azdo.list_active_prs = AsyncMock(return_value=[])
|
||||||
|
|
||||||
|
sleep_calls: list[float] = []
|
||||||
|
|
||||||
|
async def fake_sleep(seconds: float) -> None:
|
||||||
|
sleep_calls.append(seconds)
|
||||||
|
raise asyncio.CancelledError
|
||||||
|
|
||||||
|
with patch("release_agent.services.pr_poller.find_unprocessed_prs", new=AsyncMock(return_value=[])):
|
||||||
|
with pytest.raises(asyncio.CancelledError):
|
||||||
|
await run_pr_poll_loop(
|
||||||
|
azdo_client=azdo,
|
||||||
|
db_pool=MagicMock(),
|
||||||
|
watched_repos=["my-repo"],
|
||||||
|
target_branch="refs/heads/develop",
|
||||||
|
interval_seconds=123,
|
||||||
|
schedule_fn=MagicMock(),
|
||||||
|
sleep_fn=fake_sleep,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert sleep_calls[0] == 123
|
||||||
|
|
||||||
|
async def test_handles_empty_watched_repos(self) -> None:
|
||||||
|
azdo = AsyncMock()
|
||||||
|
|
||||||
|
async def fake_sleep(seconds: float) -> None:
|
||||||
|
raise asyncio.CancelledError
|
||||||
|
|
||||||
|
with patch("release_agent.services.pr_poller.find_unprocessed_prs", new=AsyncMock(return_value=[])):
|
||||||
|
with pytest.raises(asyncio.CancelledError):
|
||||||
|
await run_pr_poll_loop(
|
||||||
|
azdo_client=azdo,
|
||||||
|
db_pool=MagicMock(),
|
||||||
|
watched_repos=[],
|
||||||
|
target_branch="refs/heads/develop",
|
||||||
|
interval_seconds=30,
|
||||||
|
schedule_fn=MagicMock(),
|
||||||
|
sleep_fn=fake_sleep,
|
||||||
|
)
|
||||||
|
|
||||||
|
azdo.list_active_prs.assert_not_called()
|
||||||
|
|
||||||
|
async def test_schedule_fn_receives_synthesized_payload(self) -> None:
    """schedule_fn receives a webhook-shaped payload built from the polled PR."""
    pr = _make_pr(pr_id="55", repo_name="test-repo")
    client = AsyncMock()
    client.list_active_prs = AsyncMock(return_value=[pr])

    captured: list[dict] = []

    def capture_schedule(**kwargs) -> None:
        captured.append(kwargs)

    async def cancel_immediately(seconds: float) -> None:
        raise asyncio.CancelledError

    with patch(
        "release_agent.services.pr_poller.find_unprocessed_prs",
        new=AsyncMock(return_value=[pr]),
    ):
        with pytest.raises(asyncio.CancelledError):
            await run_pr_poll_loop(
                azdo_client=client,
                db_pool=MagicMock(),
                watched_repos=["test-repo"],
                target_branch="refs/heads/develop",
                interval_seconds=30,
                schedule_fn=capture_schedule,
                sleep_fn=cancel_immediately,
            )

    assert len(captured) == 1
    state = captured[0]["initial_state"]
    # The synthesized payload mimics a real Azure DevOps webhook body.
    assert state["webhook_payload"]["resource"]["pull_request_id"] == 55
    assert state["pr_id"] == "55"
    assert state["repo_name"] == "test-repo"
||||||
|
async def test_continues_after_list_active_prs_error(self) -> None:
    """A failing repo does not abort the iteration; the loop still sleeps."""
    client = AsyncMock()
    # repo-a blows up, repo-b returns cleanly.
    client.list_active_prs = AsyncMock(side_effect=[Exception("API error"), []])

    intervals: list[float] = []

    async def record_and_cancel(seconds: float) -> None:
        intervals.append(seconds)
        raise asyncio.CancelledError

    with patch("release_agent.services.pr_poller.find_unprocessed_prs", new=AsyncMock(return_value=[])):
        with pytest.raises(asyncio.CancelledError):
            await run_pr_poll_loop(
                azdo_client=client,
                db_pool=MagicMock(),
                watched_repos=["repo-a", "repo-b"],
                target_branch="refs/heads/develop",
                interval_seconds=30,
                schedule_fn=MagicMock(),
                sleep_fn=record_and_cancel,
            )

    # The iteration completed despite the first repo's error, so the loop
    # reached its sleep exactly once.
    assert len(intervals) == 1
||||||
122
tests/test_branch_parser.py
Normal file
122
tests/test_branch_parser.py
Normal file
@@ -0,0 +1,122 @@
|
|||||||
|
"""Tests for branch_parser module. Written FIRST (TDD RED phase)."""
|
||||||
|
|
||||||
|
|
||||||
|
from release_agent.branch_parser import parse_branch, strip_refs_prefix
|
||||||
|
|
||||||
|
|
||||||
|
class TestStripRefsPrefix:
    """Behavioral tests for strip_refs_prefix."""

    def test_strips_refs_heads_prefix(self) -> None:
        result = strip_refs_prefix("refs/heads/fix/BILL-42_something")
        assert result == "fix/BILL-42_something"

    def test_strips_refs_heads_prefix_feature(self) -> None:
        result = strip_refs_prefix("refs/heads/feature/ALLPOST-100_add-feature")
        assert result == "feature/ALLPOST-100_add-feature"

    def test_no_refs_prefix_unchanged(self) -> None:
        # A branch name without the ref prefix passes through untouched.
        result = strip_refs_prefix("bug/ALLPOST-4229_fix-review")
        assert result == "bug/ALLPOST-4229_fix-review"

    def test_main_unchanged(self) -> None:
        assert strip_refs_prefix("main") == "main"

    def test_develop_unchanged(self) -> None:
        assert strip_refs_prefix("develop") == "develop"

    def test_empty_string(self) -> None:
        assert strip_refs_prefix("") == ""

    def test_only_refs_heads(self) -> None:
        # The bare prefix strips down to nothing.
        assert strip_refs_prefix("refs/heads/") == ""

    def test_refs_tags_not_stripped(self) -> None:
        # Only refs/heads/ is recognized; tag refs are left intact.
        assert strip_refs_prefix("refs/tags/v1.0.0") == "refs/tags/v1.0.0"
||||||
|
class TestParseBranch:
    """Behavioral tests for parse_branch (ticket-ID extraction)."""

    def test_bug_branch_with_ticket(self) -> None:
        ticket, found = parse_branch("bug/ALLPOST-4229_fix-review")
        assert ticket == "ALLPOST-4229"
        assert found is True

    def test_feature_branch_with_ticket(self) -> None:
        ticket, found = parse_branch("feature/ALLPOST-100_add-feature")
        assert ticket == "ALLPOST-100"
        assert found is True

    def test_refs_heads_fix_branch(self) -> None:
        # refs/heads/ prefix does not prevent ticket detection.
        ticket, found = parse_branch("refs/heads/fix/BILL-42_something")
        assert ticket == "BILL-42"
        assert found is True

    def test_feat_branch_short(self) -> None:
        ticket, found = parse_branch("feat/MY-1_x")
        assert ticket == "MY-1"
        assert found is True

    def test_chore_without_ticket(self) -> None:
        ticket, found = parse_branch("chore/update-dependencies")
        assert ticket is None
        assert found is False

    def test_main_branch(self) -> None:
        ticket, found = parse_branch("main")
        assert ticket is None
        assert found is False

    def test_develop_branch(self) -> None:
        ticket, found = parse_branch("develop")
        assert ticket is None
        assert found is False

    def test_release_branch(self) -> None:
        # Version tags like v1.0.3 are not Jira ticket IDs.
        ticket, found = parse_branch("release/v1.0.3")
        assert ticket is None
        assert found is False

    def test_returns_tuple(self) -> None:
        outcome = parse_branch("main")
        assert isinstance(outcome, tuple)
        assert len(outcome) == 2

    def test_ticket_id_type_when_present(self) -> None:
        ticket, found = parse_branch("bug/ALLPOST-4229_fix-review")
        assert isinstance(ticket, str)
        assert isinstance(found, bool)

    def test_ticket_id_type_when_absent(self) -> None:
        ticket, found = parse_branch("main")
        assert ticket is None
        assert isinstance(found, bool)

    def test_fix_prefix(self) -> None:
        ticket, found = parse_branch("fix/PROJ-999_some-fix")
        assert ticket == "PROJ-999"
        assert found is True

    def test_refs_heads_feature_branch(self) -> None:
        ticket, found = parse_branch("refs/heads/feature/ALLPOST-100_add-feature")
        assert ticket == "ALLPOST-100"
        assert found is True

    def test_ticket_with_multiple_digits(self) -> None:
        ticket, found = parse_branch("feature/ABC-12345_some-long-feature")
        assert ticket == "ABC-12345"
        assert found is True

    def test_branch_without_underscore_separator(self) -> None:
        # Branch has a ticket pattern but no underscore - still detects it.
        ticket, found = parse_branch("feature/PROJ-100")
        assert ticket == "PROJ-100"
        assert found is True

    def test_empty_string(self) -> None:
        ticket, found = parse_branch("")
        assert ticket is None
        assert found is False

    def test_ticket_with_numeric_project_prefix(self) -> None:
        # Project keys may contain digits after the first letter (e.g. AB2).
        ticket, found = parse_branch("feature/AB2-100_feature")
        assert ticket == "AB2-100"
        assert found is True
||||||
450
tests/test_config.py
Normal file
450
tests/test_config.py
Normal file
@@ -0,0 +1,450 @@
|
|||||||
|
"""Tests for config module. Written FIRST (TDD RED phase)."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from pydantic import SecretStr, ValidationError
|
||||||
|
|
||||||
|
from release_agent.config import Settings
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _base_env() -> dict[str, str]:
    """Return the smallest set of environment variables that lets Settings() validate."""
    # All keys are valid identifiers, so the keyword form of dict() keeps
    # the mapping compact and typo-resistant.
    return dict(
        AZDO_ORGANIZATION="my-org",
        AZDO_PROJECT="my-project",
        AZDO_PAT="super-secret-pat",
        ANTHROPIC_API_KEY="sk-ant-key",
        POSTGRES_DSN="postgresql://user:pass@localhost:5432/db",
        JIRA_EMAIL="user@example.com",
        JIRA_API_TOKEN="jira-token-abc",
        SLACK_WEBHOOK_URL="https://hooks.slack.com/services/T000/B000/xxxx",
        WEBHOOK_SECRET="test-webhook-secret",
    )
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Settings tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestSettings:
    """Behavioral tests for the core Settings class: env loading, secret
    wrapping, computed URLs, and port validation."""

    def test_loads_from_env_vars(self) -> None:
        with patch.dict(os.environ, _base_env(), clear=True):
            cfg = Settings()
            assert cfg.azdo_organization == "my-org"
            assert cfg.azdo_project == "my-project"

    def test_pat_is_secret_str(self) -> None:
        with patch.dict(os.environ, _base_env(), clear=True):
            assert isinstance(Settings().azdo_pat, SecretStr)

    def test_anthropic_key_is_secret_str(self) -> None:
        with patch.dict(os.environ, _base_env(), clear=True):
            assert isinstance(Settings().anthropic_api_key, SecretStr)

    def test_secret_str_not_leaked_in_repr(self) -> None:
        with patch.dict(os.environ, _base_env(), clear=True):
            rendered = repr(Settings())
            assert "super-secret-pat" not in rendered
            assert "sk-ant-key" not in rendered

    def test_secret_str_not_leaked_in_str(self) -> None:
        with patch.dict(os.environ, _base_env(), clear=True):
            assert "super-secret-pat" not in str(Settings())

    def test_missing_required_azdo_org_raises(self) -> None:
        environ = {k: v for k, v in _base_env().items() if k != "AZDO_ORGANIZATION"}
        with patch.dict(os.environ, environ, clear=True), pytest.raises(ValidationError):
            Settings()

    def test_missing_required_azdo_project_raises(self) -> None:
        environ = {k: v for k, v in _base_env().items() if k != "AZDO_PROJECT"}
        with patch.dict(os.environ, environ, clear=True), pytest.raises(ValidationError):
            Settings()

    def test_missing_required_pat_raises(self) -> None:
        environ = {k: v for k, v in _base_env().items() if k != "AZDO_PAT"}
        with patch.dict(os.environ, environ, clear=True), pytest.raises(ValidationError):
            Settings()

    def test_missing_anthropic_key_is_optional(self) -> None:
        # The Anthropic key is optional and defaults to an empty secret.
        environ = {k: v for k, v in _base_env().items() if k != "ANTHROPIC_API_KEY"}
        with patch.dict(os.environ, environ, clear=True):
            assert Settings().anthropic_api_key.get_secret_value() == ""

    def test_missing_postgres_dsn_raises(self) -> None:
        environ = {k: v for k, v in _base_env().items() if k != "POSTGRES_DSN"}
        with patch.dict(os.environ, environ, clear=True), pytest.raises(ValidationError):
            Settings()

    def test_azdo_base_url_computed(self) -> None:
        with patch.dict(os.environ, _base_env(), clear=True):
            assert Settings().azdo_base_url == "https://dev.azure.com/my-org"

    def test_azdo_api_url_computed(self) -> None:
        with patch.dict(os.environ, _base_env(), clear=True):
            assert Settings().azdo_api_url == "https://dev.azure.com/my-org/my-project/_apis"

    def test_default_port(self) -> None:
        with patch.dict(os.environ, _base_env(), clear=True):
            assert Settings().port == 8000

    def test_custom_port_from_env(self) -> None:
        with patch.dict(os.environ, _base_env() | {"PORT": "9000"}, clear=True):
            assert Settings().port == 9000

    def test_port_below_minimum_raises(self) -> None:
        with patch.dict(os.environ, _base_env() | {"PORT": "0"}, clear=True), pytest.raises(ValidationError):
            Settings()

    def test_port_above_maximum_raises(self) -> None:
        with patch.dict(os.environ, _base_env() | {"PORT": "65536"}, clear=True), pytest.raises(ValidationError):
            Settings()

    def test_port_minimum_valid(self) -> None:
        with patch.dict(os.environ, _base_env() | {"PORT": "1"}, clear=True):
            assert Settings().port == 1

    def test_port_maximum_valid(self) -> None:
        with patch.dict(os.environ, _base_env() | {"PORT": "65535"}, clear=True):
            assert Settings().port == 65535

    def test_get_pat_value(self) -> None:
        with patch.dict(os.environ, _base_env(), clear=True):
            assert Settings().azdo_pat.get_secret_value() == "super-secret-pat"

    def test_get_anthropic_key_value(self) -> None:
        with patch.dict(os.environ, _base_env(), clear=True):
            assert Settings().anthropic_api_key.get_secret_value() == "sk-ant-key"

    def test_postgres_dsn_is_secret_str(self) -> None:
        with patch.dict(os.environ, _base_env(), clear=True):
            dsn = Settings().postgres_dsn
            assert isinstance(dsn, SecretStr)
            assert "localhost" in dsn.get_secret_value()

    def test_postgres_dsn_not_leaked_in_repr(self) -> None:
        # The DSN contains credentials; repr() must not expose them.
        with patch.dict(os.environ, _base_env(), clear=True):
            assert "user:pass" not in repr(Settings())
|
|
||||||
|
class TestSettingsPhase2:
    """Behavioral tests for Phase 2 settings (Jira, Slack webhook, Claude model, VSRM URL)."""

    def test_jira_base_url_default(self) -> None:
        with patch.dict(os.environ, _base_env(), clear=True):
            assert Settings().jira_base_url == "https://billolife.atlassian.net"

    def test_jira_base_url_custom(self) -> None:
        environ = _base_env() | {"JIRA_BASE_URL": "https://custom.atlassian.net"}
        with patch.dict(os.environ, environ, clear=True):
            assert Settings().jira_base_url == "https://custom.atlassian.net"

    def test_jira_email_required(self) -> None:
        environ = {k: v for k, v in _base_env().items() if k != "JIRA_EMAIL"}
        with patch.dict(os.environ, environ, clear=True), pytest.raises(ValidationError):
            Settings()

    def test_jira_email_stored(self) -> None:
        with patch.dict(os.environ, _base_env(), clear=True):
            assert Settings().jira_email == "user@example.com"

    def test_jira_api_token_required(self) -> None:
        environ = {k: v for k, v in _base_env().items() if k != "JIRA_API_TOKEN"}
        with patch.dict(os.environ, environ, clear=True), pytest.raises(ValidationError):
            Settings()

    def test_jira_api_token_is_secret_str(self) -> None:
        with patch.dict(os.environ, _base_env(), clear=True):
            assert isinstance(Settings().jira_api_token, SecretStr)

    def test_jira_api_token_not_leaked_in_repr(self) -> None:
        with patch.dict(os.environ, _base_env(), clear=True):
            assert "jira-token-abc" not in repr(Settings())

    def test_slack_webhook_url_optional_defaults_empty(self) -> None:
        # Removing the variable entirely exercises the empty-secret default.
        environ = {k: v for k, v in _base_env().items() if k != "SLACK_WEBHOOK_URL"}
        with patch.dict(os.environ, environ, clear=True):
            assert Settings().slack_webhook_url.get_secret_value() == ""

    def test_slack_webhook_url_is_secret_str(self) -> None:
        with patch.dict(os.environ, _base_env(), clear=True):
            assert isinstance(Settings().slack_webhook_url, SecretStr)

    def test_slack_webhook_url_not_leaked_in_repr(self) -> None:
        # "xxxx" is the secret tail of the webhook URL in _base_env.
        with patch.dict(os.environ, _base_env(), clear=True):
            assert "xxxx" not in repr(Settings())

    def test_claude_review_model_default(self) -> None:
        with patch.dict(os.environ, _base_env(), clear=True):
            assert Settings().claude_review_model == "claude-sonnet-4-20250514"

    def test_claude_review_model_custom(self) -> None:
        environ = _base_env() | {"CLAUDE_REVIEW_MODEL": "claude-opus-4-20250514"}
        with patch.dict(os.environ, environ, clear=True):
            assert Settings().claude_review_model == "claude-opus-4-20250514"

    def test_azdo_vsrm_api_url_computed(self) -> None:
        # Release-management APIs live on the vsrm.dev.azure.com host.
        with patch.dict(os.environ, _base_env(), clear=True):
            assert Settings().azdo_vsrm_api_url == "https://vsrm.dev.azure.com/my-org/my-project/_apis"
|
||||||
|
class TestSettingsPhase4:
    """Tests for Phase 4 settings fields (webhook secret)."""

    def test_webhook_secret_optional_defaults_empty(self) -> None:
        """When WEBHOOK_SECRET is absent, the field defaults to an empty secret.

        Fixed two defects in the previous version: WEBHOOK_SECRET was left in
        the environment (so the default path was never exercised), and the
        assertion `x is not None or x == ""` is a tautology for any non-None
        value. Now mirrors test_slack_webhook_url_optional_defaults_empty.
        """
        env = _base_env()
        del env["WEBHOOK_SECRET"]
        with patch.dict(os.environ, env, clear=True):
            settings = Settings()
            assert settings.webhook_secret.get_secret_value() == ""

    def test_webhook_secret_custom_value(self) -> None:
        """An explicit WEBHOOK_SECRET is readable via get_secret_value()."""
        env = {**_base_env(), "WEBHOOK_SECRET": "my-super-secret"}
        with patch.dict(os.environ, env, clear=True):
            settings = Settings()
            assert settings.webhook_secret.get_secret_value() == "my-super-secret"

    def test_webhook_secret_is_secret_str(self) -> None:
        """The webhook secret is wrapped in SecretStr, not stored as plain str."""
        env = {**_base_env(), "WEBHOOK_SECRET": "secret-value"}
        with patch.dict(os.environ, env, clear=True):
            settings = Settings()
            assert isinstance(settings.webhook_secret, SecretStr)

    def test_webhook_secret_not_leaked_in_repr(self) -> None:
        """repr(Settings) must not expose the raw webhook secret."""
        env = {**_base_env(), "WEBHOOK_SECRET": "super-private-secret"}
        with patch.dict(os.environ, env, clear=True):
            settings = Settings()
            assert "super-private-secret" not in repr(settings)
|
|
||||||
|
class TestSettingsPhase5:
    """Behavioral tests for Phase 5 settings (Slack Web API + CI polling)."""

    def test_slack_webhook_url_optional_when_bot_token_provided(self) -> None:
        environ = {k: v for k, v in _base_env().items() if k != "SLACK_WEBHOOK_URL"}
        environ["SLACK_BOT_TOKEN"] = "xoxb-test-token"
        environ["SLACK_CHANNEL_ID"] = "C12345678"
        with patch.dict(os.environ, environ, clear=True):
            cfg = Settings()
            assert cfg.slack_bot_token is not None
            assert cfg.slack_bot_token.get_secret_value() == "xoxb-test-token"

    def test_slack_bot_token_optional_defaults_empty(self) -> None:
        with patch.dict(os.environ, _base_env(), clear=True):
            assert Settings().slack_bot_token.get_secret_value() == ""

    def test_slack_bot_token_is_secret_str(self) -> None:
        environ = _base_env() | {"SLACK_BOT_TOKEN": "xoxb-abc-123"}
        with patch.dict(os.environ, environ, clear=True):
            assert isinstance(Settings().slack_bot_token, SecretStr)

    def test_slack_bot_token_not_leaked_in_repr(self) -> None:
        environ = _base_env() | {"SLACK_BOT_TOKEN": "xoxb-super-secret"}
        with patch.dict(os.environ, environ, clear=True):
            assert "xoxb-super-secret" not in repr(Settings())

    def test_slack_signing_secret_optional_defaults_empty(self) -> None:
        with patch.dict(os.environ, _base_env(), clear=True):
            assert Settings().slack_signing_secret.get_secret_value() == ""

    def test_slack_signing_secret_custom_value(self) -> None:
        environ = _base_env() | {"SLACK_SIGNING_SECRET": "signing-secret-xyz"}
        with patch.dict(os.environ, environ, clear=True):
            assert Settings().slack_signing_secret.get_secret_value() == "signing-secret-xyz"

    def test_slack_signing_secret_is_secret_str(self) -> None:
        environ = _base_env() | {"SLACK_SIGNING_SECRET": "some-secret"}
        with patch.dict(os.environ, environ, clear=True):
            assert isinstance(Settings().slack_signing_secret, SecretStr)

    def test_slack_signing_secret_not_leaked_in_repr(self) -> None:
        environ = _base_env() | {"SLACK_SIGNING_SECRET": "private-signing-secret"}
        with patch.dict(os.environ, environ, clear=True):
            assert "private-signing-secret" not in repr(Settings())

    def test_slack_channel_id_optional_defaults_empty(self) -> None:
        # Channel ID is plain text (not a secret), so it defaults to "".
        with patch.dict(os.environ, _base_env(), clear=True):
            assert Settings().slack_channel_id == ""

    def test_slack_channel_id_custom_value(self) -> None:
        environ = _base_env() | {"SLACK_CHANNEL_ID": "C0987654321"}
        with patch.dict(os.environ, environ, clear=True):
            assert Settings().slack_channel_id == "C0987654321"

    def test_ci_poll_interval_seconds_default(self) -> None:
        with patch.dict(os.environ, _base_env(), clear=True):
            assert Settings().ci_poll_interval_seconds == 30

    def test_ci_poll_interval_seconds_custom(self) -> None:
        environ = _base_env() | {"CI_POLL_INTERVAL_SECONDS": "60"}
        with patch.dict(os.environ, environ, clear=True):
            assert Settings().ci_poll_interval_seconds == 60

    def test_ci_poll_max_wait_seconds_default(self) -> None:
        with patch.dict(os.environ, _base_env(), clear=True):
            assert Settings().ci_poll_max_wait_seconds == 1800

    def test_ci_poll_max_wait_seconds_custom(self) -> None:
        environ = _base_env() | {"CI_POLL_MAX_WAIT_SECONDS": "3600"}
        with patch.dict(os.environ, environ, clear=True):
            assert Settings().ci_poll_max_wait_seconds == 3600

    def test_slack_webhook_url_still_optional(self) -> None:
        environ = {k: v for k, v in _base_env().items() if k != "SLACK_WEBHOOK_URL"}
        with patch.dict(os.environ, environ, clear=True):
            assert Settings().slack_webhook_url.get_secret_value() == ""
|
||||||
|
class TestSettingsPrPolling:
    """Behavioral tests for the PR-polling configuration fields (Step 1)."""

    def test_watched_repos_defaults_empty(self) -> None:
        with patch.dict(os.environ, _base_env(), clear=True):
            assert Settings().watched_repos == ""

    def test_watched_repos_custom_value(self) -> None:
        environ = _base_env() | {"WATCHED_REPOS": "repo-a,repo-b"}
        with patch.dict(os.environ, environ, clear=True):
            assert Settings().watched_repos == "repo-a,repo-b"

    def test_watched_repos_list_empty_when_blank(self) -> None:
        with patch.dict(os.environ, _base_env(), clear=True):
            assert Settings().watched_repos_list == []

    def test_watched_repos_list_splits_comma_separated(self) -> None:
        environ = _base_env() | {"WATCHED_REPOS": "repo-a,repo-b,repo-c"}
        with patch.dict(os.environ, environ, clear=True):
            assert Settings().watched_repos_list == ["repo-a", "repo-b", "repo-c"]

    def test_watched_repos_list_strips_whitespace(self) -> None:
        environ = _base_env() | {"WATCHED_REPOS": " repo-a , repo-b "}
        with patch.dict(os.environ, environ, clear=True):
            assert Settings().watched_repos_list == ["repo-a", "repo-b"]

    def test_watched_repos_list_ignores_empty_entries(self) -> None:
        # Consecutive commas must not produce empty repo names.
        environ = _base_env() | {"WATCHED_REPOS": "repo-a,,repo-b"}
        with patch.dict(os.environ, environ, clear=True):
            assert Settings().watched_repos_list == ["repo-a", "repo-b"]

    def test_pr_poll_interval_seconds_default(self) -> None:
        with patch.dict(os.environ, _base_env(), clear=True):
            assert Settings().pr_poll_interval_seconds == 300

    def test_pr_poll_interval_seconds_custom(self) -> None:
        environ = _base_env() | {"PR_POLL_INTERVAL_SECONDS": "60"}
        with patch.dict(os.environ, environ, clear=True):
            assert Settings().pr_poll_interval_seconds == 60

    def test_pr_poll_target_branch_default(self) -> None:
        with patch.dict(os.environ, _base_env(), clear=True):
            assert Settings().pr_poll_target_branch == "refs/heads/develop"

    def test_pr_poll_target_branch_custom(self) -> None:
        environ = _base_env() | {"PR_POLL_TARGET_BRANCH": "refs/heads/main"}
        with patch.dict(os.environ, environ, clear=True):
            assert Settings().pr_poll_target_branch == "refs/heads/main"

    def test_pr_poll_enabled_defaults_false(self) -> None:
        # Polling is opt-in; webhook delivery is the default mechanism.
        with patch.dict(os.environ, _base_env(), clear=True):
            assert Settings().pr_poll_enabled is False

    def test_pr_poll_enabled_true_from_env(self) -> None:
        environ = _base_env() | {"PR_POLL_ENABLED": "true"}
        with patch.dict(os.environ, environ, clear=True):
            assert Settings().pr_poll_enabled is True

    def test_default_jira_project_default(self) -> None:
        with patch.dict(os.environ, _base_env(), clear=True):
            assert Settings().default_jira_project == "ALLPOST"

    def test_default_jira_project_custom(self) -> None:
        environ = _base_env() | {"DEFAULT_JIRA_PROJECT": "MYPROJ"}
        with patch.dict(os.environ, environ, clear=True):
            assert Settings().default_jira_project == "MYPROJ"

    def test_auto_create_ticket_enabled_defaults_true(self) -> None:
        with patch.dict(os.environ, _base_env(), clear=True):
            assert Settings().auto_create_ticket_enabled is True

    def test_auto_create_ticket_enabled_false_from_env(self) -> None:
        environ = _base_env() | {"AUTO_CREATE_TICKET_ENABLED": "false"}
        with patch.dict(os.environ, environ, clear=True):
            assert Settings().auto_create_ticket_enabled is False
||||||
198
tests/test_exceptions.py
Normal file
198
tests/test_exceptions.py
Normal file
@@ -0,0 +1,198 @@
|
|||||||
|
"""Tests for custom exception hierarchy. Written FIRST (TDD RED phase)."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from release_agent.exceptions import (
|
||||||
|
AuthenticationError,
|
||||||
|
NotFoundError,
|
||||||
|
RateLimitError,
|
||||||
|
ReleaseAgentError,
|
||||||
|
ServiceError,
|
||||||
|
ServiceUnavailableError,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestReleaseAgentError:
    """Behavioral tests for the root exception of the package hierarchy."""

    def test_is_exception(self) -> None:
        assert isinstance(ReleaseAgentError("something went wrong"), Exception)

    def test_message_stored(self) -> None:
        assert str(ReleaseAgentError("something went wrong")) == "something went wrong"

    def test_can_be_raised(self) -> None:
        with pytest.raises(ReleaseAgentError):
            raise ReleaseAgentError("boom")
|
|
||||||
|
class TestServiceError:
    """Behavioral tests for ServiceError (service name, status code, detail)."""

    def test_is_release_agent_error(self) -> None:
        exc = ServiceError(service="jira", status_code=500, detail="Internal error")
        assert isinstance(exc, ReleaseAgentError)

    def test_stores_service(self) -> None:
        exc = ServiceError(service="jira", status_code=500, detail="Internal error")
        assert exc.service == "jira"

    def test_stores_status_code(self) -> None:
        exc = ServiceError(service="azdo", status_code=422, detail="Unprocessable")
        assert exc.status_code == 422

    def test_stores_detail(self) -> None:
        exc = ServiceError(service="slack", status_code=400, detail="Bad payload")
        assert exc.detail == "Bad payload"

    def test_str_includes_service_and_status(self) -> None:
        rendered = str(ServiceError(service="jira", status_code=500, detail="Server error"))
        assert "jira" in rendered
        assert "500" in rendered

    def test_can_be_raised(self) -> None:
        with pytest.raises(ServiceError):
            raise ServiceError(service="azdo", status_code=400, detail="bad request")

    def test_detail_none_allowed(self) -> None:
        # detail is optional metadata and may be omitted entirely.
        exc = ServiceError(service="jira", status_code=404, detail=None)
        assert exc.detail is None
|
||||||
|
class TestAuthenticationError:
    """Behavioral tests for AuthenticationError (401/403 responses)."""

    def test_is_service_error(self) -> None:
        assert isinstance(AuthenticationError(service="azdo"), ServiceError)

    def test_default_status_code_401(self) -> None:
        assert AuthenticationError(service="azdo").status_code == 401

    def test_service_stored(self) -> None:
        assert AuthenticationError(service="jira").service == "jira"

    def test_custom_status_code(self) -> None:
        # 403 Forbidden may be passed explicitly in place of the 401 default.
        assert AuthenticationError(service="azdo", status_code=403).status_code == 403

    def test_str_contains_service(self) -> None:
        assert "slack" in str(AuthenticationError(service="slack"))

    def test_can_be_raised(self) -> None:
        with pytest.raises(AuthenticationError):
            raise AuthenticationError(service="azdo")

    def test_caught_as_service_error(self) -> None:
        # Subclass instances must be catchable via the parent type.
        with pytest.raises(ServiceError):
            raise AuthenticationError(service="azdo")
|
|
||||||
|
class TestNotFoundError:
    """Tests for NotFoundError (404)."""

    def test_is_service_error(self) -> None:
        """NotFoundError is a ServiceError subclass."""
        exc = NotFoundError(service="azdo", detail="PR not found")
        assert isinstance(exc, ServiceError)

    def test_status_code_is_404(self) -> None:
        """NotFoundError always carries status code 404."""
        exc = NotFoundError(service="jira", detail="Issue not found")
        assert exc.status_code == 404

    def test_detail_stored(self) -> None:
        """The detail message is kept on the instance."""
        exc = NotFoundError(service="azdo", detail="PR 999 not found")
        assert "PR 999" in exc.detail

    def test_can_be_raised(self) -> None:
        """NotFoundError is raisable and catchable by its own type."""
        with pytest.raises(NotFoundError):
            raise NotFoundError(service="azdo", detail="not found")

    def test_caught_as_release_agent_error(self) -> None:
        """NotFoundError is catchable via the ReleaseAgentError root."""
        with pytest.raises(ReleaseAgentError):
            raise NotFoundError(service="jira", detail="issue missing")
|
||||||
|
|
||||||
|
|
||||||
|
class TestRateLimitError:
    """Tests for RateLimitError (429) with retry_after."""

    def test_is_service_error(self) -> None:
        """RateLimitError is a ServiceError subclass."""
        exc = RateLimitError(service="jira", retry_after=30)
        assert isinstance(exc, ServiceError)

    def test_status_code_is_429(self) -> None:
        """RateLimitError always carries status code 429."""
        exc = RateLimitError(service="jira", retry_after=30)
        assert exc.status_code == 429

    def test_stores_retry_after(self) -> None:
        """The retry_after value is kept on the instance."""
        exc = RateLimitError(service="slack", retry_after=60)
        assert exc.retry_after == 60

    def test_retry_after_none_allowed(self) -> None:
        """retry_after=None is accepted and stored as None."""
        exc = RateLimitError(service="azdo", retry_after=None)
        assert exc.retry_after is None

    def test_str_contains_service(self) -> None:
        """The str() rendering mentions the throttling service."""
        exc = RateLimitError(service="jira", retry_after=5)
        assert "jira" in str(exc)

    def test_can_be_raised(self) -> None:
        """RateLimitError is raisable and catchable by its own type."""
        with pytest.raises(RateLimitError):
            raise RateLimitError(service="jira", retry_after=30)
|
||||||
|
|
||||||
|
|
||||||
|
class TestServiceUnavailableError:
    """Tests for ServiceUnavailableError (503)."""

    def test_is_service_error(self) -> None:
        """ServiceUnavailableError is a ServiceError subclass."""
        exc = ServiceUnavailableError(service="azdo")
        assert isinstance(exc, ServiceError)

    def test_status_code_is_503(self) -> None:
        """ServiceUnavailableError always carries status code 503."""
        exc = ServiceUnavailableError(service="azdo")
        assert exc.status_code == 503

    def test_service_stored(self) -> None:
        """The service name is kept on the instance."""
        exc = ServiceUnavailableError(service="slack")
        assert exc.service == "slack"

    def test_custom_detail(self) -> None:
        """A custom detail message is kept on the instance."""
        exc = ServiceUnavailableError(service="azdo", detail="Maintenance window")
        assert "Maintenance" in exc.detail

    def test_can_be_raised(self) -> None:
        """ServiceUnavailableError is raisable and catchable by its own type."""
        with pytest.raises(ServiceUnavailableError):
            raise ServiceUnavailableError(service="azdo")

    def test_caught_as_service_error(self) -> None:
        """ServiceUnavailableError is catchable via the ServiceError base."""
        with pytest.raises(ServiceError):
            raise ServiceUnavailableError(service="jira")
|
||||||
|
|
||||||
|
|
||||||
|
class TestExceptionHierarchyInheritance:
    """Tests verifying the full exception hierarchy is correct."""

    @staticmethod
    def _sample_errors() -> list:
        """Return one instance of every concrete service-error subclass.

        Shared by both hierarchy tests; the original duplicated this list,
        so adding a new subclass required editing two places.
        """
        return [
            AuthenticationError(service="svc"),
            NotFoundError(service="svc", detail="x"),
            RateLimitError(service="svc", retry_after=1),
            ServiceUnavailableError(service="svc"),
        ]

    def test_all_are_release_agent_errors(self) -> None:
        """Every concrete subclass descends from ReleaseAgentError."""
        for err in self._sample_errors():
            assert isinstance(err, ReleaseAgentError), f"{type(err)} not ReleaseAgentError"

    def test_all_are_service_errors(self) -> None:
        """Every concrete subclass descends from ServiceError."""
        for err in self._sample_errors():
            assert isinstance(err, ServiceError), f"{type(err)} not ServiceError"
|
||||||
666
tests/test_main.py
Normal file
666
tests/test_main.py
Normal file
@@ -0,0 +1,666 @@
|
|||||||
|
"""Tests for main FastAPI application. Written FIRST (TDD RED phase).
|
||||||
|
|
||||||
|
Heavy startup (PostgreSQL, httpx clients, graph compilation) is mocked.
|
||||||
|
Tests verify: routes registered, lifespan hooks, exception handlers.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers / fixtures
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _make_mock_settings():
|
||||||
|
s = MagicMock()
|
||||||
|
s.webhook_secret.get_secret_value.return_value = "test-secret"
|
||||||
|
s.postgres_dsn.get_secret_value.return_value = "postgresql://u:p@localhost/db"
|
||||||
|
s.azdo_pat.get_secret_value.return_value = "pat"
|
||||||
|
s.anthropic_api_key.get_secret_value.return_value = "key"
|
||||||
|
s.jira_api_token.get_secret_value.return_value = "jira"
|
||||||
|
s.slack_webhook_url.get_secret_value.return_value = "https://hooks.slack.com/x"
|
||||||
|
s.slack_bot_token.get_secret_value.return_value = ""
|
||||||
|
s.slack_channel_id = ""
|
||||||
|
s.slack_signing_secret.get_secret_value.return_value = ""
|
||||||
|
s.port = 8000
|
||||||
|
s.pr_poll_enabled = False
|
||||||
|
s.pr_poll_interval_seconds = 300
|
||||||
|
s.pr_poll_target_branch = "refs/heads/develop"
|
||||||
|
s.watched_repos_list = []
|
||||||
|
s.default_jira_project = "ALLPOST"
|
||||||
|
return s
|
||||||
|
|
||||||
|
|
||||||
|
def _make_patched_app():
    """Return the FastAPI app with all heavy startup mocked.

    Starts patches for Settings, the graph builders, the PostgreSQL pool,
    the tool clients, the staging store and schema setup, then builds the
    app via ``create_app``.

    Returns:
        Tuple of ``(app, mock_settings, mock_pool, mock_graphs)``.
    """
    mock_settings = _make_mock_settings()
    mock_pool = AsyncMock()
    mock_pool.__aenter__ = AsyncMock(return_value=mock_pool)
    mock_pool.__aexit__ = AsyncMock(return_value=False)

    mock_graphs = {
        "pr_completed": MagicMock(),
        "release": MagicMock(),
    }
    mock_clients = MagicMock()
    mock_staging_store = MagicMock()

    patches = [
        patch("release_agent.main.Settings", return_value=mock_settings),
        patch("release_agent.main.build_pr_completed_graph", return_value=mock_graphs["pr_completed"]),
        patch("release_agent.main.build_release_graph", return_value=mock_graphs["release"]),
        patch("release_agent.main.AsyncConnectionPool", return_value=mock_pool),
        patch("release_agent.main._create_tool_clients", return_value=mock_clients),
        patch("release_agent.main._create_staging_store", return_value=mock_staging_store),
        patch("release_agent.main._ensure_db_schema", new_callable=AsyncMock),
    ]

    for p in patches:
        p.start()
    try:
        # Import while patches are active so create_app sees the mocks.
        from release_agent.main import create_app

        app = create_app()
    finally:
        # BUGFIX: the original stopped the patches only on the success path;
        # an exception inside create_app() leaked active patches into every
        # subsequent test. try/finally guarantees cleanup either way.
        for p in patches:
            p.stop()

    return app, mock_settings, mock_pool, mock_graphs
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Route registration tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestRouteRegistration:
    """Tests that create_app registers the expected API routes.

    The original repeated the same seven-patch stack in all three tests;
    it is extracted into ``_build_app`` so a change to the startup mocking
    is made in one place.
    """

    @staticmethod
    def _build_app():
        """Create the app with all heavy startup dependencies patched out."""
        from release_agent.main import create_app

        with (
            patch("release_agent.main.Settings", return_value=_make_mock_settings()),
            patch("release_agent.main.build_pr_completed_graph", return_value=MagicMock()),
            patch("release_agent.main.build_release_graph", return_value=MagicMock()),
            patch("release_agent.main.AsyncConnectionPool", return_value=AsyncMock()),
            patch("release_agent.main._create_tool_clients", return_value=MagicMock()),
            patch("release_agent.main._create_staging_store", return_value=MagicMock()),
            patch("release_agent.main._ensure_db_schema", new_callable=AsyncMock),
        ):
            return create_app()

    def test_webhook_route_registered(self) -> None:
        """The Azure DevOps webhook endpoint is registered."""
        app = self._build_app()

        routes = {r.path for r in app.routes}  # type: ignore[attr-defined]
        assert "/webhooks/azdo" in routes

    def test_approvals_routes_registered(self) -> None:
        """The approvals listing and per-thread endpoints are registered."""
        app = self._build_app()

        routes = {r.path for r in app.routes}  # type: ignore[attr-defined]
        assert "/approvals/pending" in routes
        assert "/approvals/{thread_id}" in routes

    def test_status_routes_registered(self) -> None:
        """The status and staging endpoints are registered."""
        app = self._build_app()

        routes = {r.path for r in app.routes}  # type: ignore[attr-defined]
        assert "/status" in routes
        assert "/staging" in routes
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# schedule_graph / run_graph_in_background tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestScheduleGraph:
    """Tests for schedule_graph and its background-task bookkeeping."""

    @staticmethod
    def _fresh_app() -> MagicMock:
        """Return a mock app whose state carries an empty task registry."""
        fake_app = MagicMock()
        fake_app.state.background_tasks = set()
        return fake_app

    def test_schedule_graph_returns_thread_id(self) -> None:
        """Without an explicit id, a non-empty string id is generated."""
        from release_agent.main import schedule_graph

        fake_app = self._fresh_app()
        fake_graph = MagicMock()

        with patch("release_agent.main.asyncio.create_task", return_value=MagicMock()):
            thread_id = schedule_graph(
                app=fake_app,
                graph=fake_graph,
                initial_state={"repo_name": "my-repo"},
                thread_id=None,
            )

        assert isinstance(thread_id, str)
        assert thread_id != ""

    def test_schedule_graph_uses_provided_thread_id(self) -> None:
        """An explicitly supplied thread id is returned unchanged."""
        from release_agent.main import schedule_graph

        fake_app = self._fresh_app()

        with patch("release_agent.main.asyncio.create_task", return_value=MagicMock()):
            result = schedule_graph(
                app=fake_app,
                graph=MagicMock(),
                initial_state={},
                thread_id="custom-thread-id",
            )

        assert result == "custom-thread-id"

    def test_schedule_graph_adds_task_to_background_tasks(self) -> None:
        """The created task is tracked in app.state.background_tasks."""
        from release_agent.main import schedule_graph

        fake_app = self._fresh_app()
        fake_task = MagicMock()

        with patch("release_agent.main.asyncio.create_task", return_value=fake_task):
            schedule_graph(
                app=fake_app,
                graph=MagicMock(),
                initial_state={},
                thread_id=None,
            )

        assert fake_task in fake_app.state.background_tasks

    def test_run_graph_in_background_is_coroutine(self) -> None:
        """run_graph_in_background must be declared async def."""
        import inspect

        from release_agent.main import run_graph_in_background

        assert inspect.iscoroutinefunction(run_graph_in_background)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _ensure_db_schema tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestEnsureDbSchema:
    """Tests for the _ensure_db_schema startup helper."""

    @pytest.mark.asyncio
    async def test_ensure_db_schema_creates_table(self) -> None:
        """The helper issues DDL whose text mentions agent_threads."""
        from release_agent.main import _ensure_db_schema

        # Async cursor stub: used as `async with conn.cursor() as cur`.
        cursor = AsyncMock()
        cursor.execute = AsyncMock()
        cursor.__aenter__ = AsyncMock(return_value=cursor)
        cursor.__aexit__ = AsyncMock(return_value=False)

        # Connection stub: cursor() is a sync call returning the async CM.
        conn = AsyncMock()
        conn.cursor = MagicMock(return_value=cursor)
        conn.__aenter__ = AsyncMock(return_value=conn)
        conn.__aexit__ = AsyncMock(return_value=False)

        pool = AsyncMock()
        pool.connection = MagicMock(return_value=conn)

        await _ensure_db_schema(pool)

        # Phase 5: now executes multiple DDL statements (agent_threads +
        # staging_releases + archived_releases), so called_once no longer holds.
        assert cursor.execute.call_count >= 1
        combined_sql = " ".join(c.args[0] for c in cursor.execute.call_args_list)
        assert "agent_threads" in combined_sql
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _create_tool_clients tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestCreateToolClients:
    """Tests for the _create_tool_clients factory."""

    def test_create_tool_clients_returns_tool_clients_instance(self) -> None:
        """With all SDK client classes patched out, a ToolClients is returned."""
        from release_agent.graph.dependencies import ToolClients
        from release_agent.main import _create_tool_clients

        mock_settings = _make_mock_settings()

        # Patch the client classes so no real network clients are built.
        # The original bound each patch to an unused `as mock_x` name; those
        # bindings (and the unused http_clients local) are dropped here.
        with (
            patch("release_agent.main.AzDoClient"),
            patch("release_agent.main.JiraClient"),
            patch("release_agent.main.SlackClient"),
            patch("release_agent.main.ClaudeReviewer"),
            patch("release_agent.main.httpx.AsyncClient"),
        ):
            clients, _http_clients = _create_tool_clients(mock_settings)

        assert isinstance(clients, ToolClients)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _create_staging_store tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestCreateStagingStore:
    """Tests for the _create_staging_store factory."""

    def test_create_staging_store_returns_store(self) -> None:
        """The factory produces a JsonFileStagingStore instance."""
        from release_agent.graph.dependencies import JsonFileStagingStore
        from release_agent.main import _create_staging_store

        store = _create_staging_store()
        assert isinstance(store, JsonFileStagingStore)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Global exception handler tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestExceptionHandlers:
    """Tests for the app-level exception handler wiring."""

    def test_app_has_exception_handlers(self) -> None:
        """create_app() yields an app exposing its exception_handlers mapping."""
        from release_agent.main import create_app

        with (
            patch("release_agent.main.Settings", return_value=_make_mock_settings()),
            patch("release_agent.main.build_pr_completed_graph", return_value=MagicMock()),
            patch("release_agent.main.build_release_graph", return_value=MagicMock()),
            patch("release_agent.main.AsyncConnectionPool", return_value=AsyncMock()),
            patch("release_agent.main._create_tool_clients", return_value=MagicMock()),
            patch("release_agent.main._create_staging_store", return_value=MagicMock()),
            patch("release_agent.main._ensure_db_schema", new_callable=AsyncMock),
        ):
            application = create_app()

        # FastAPI stores exception handlers in exception_handlers attribute.
        assert hasattr(application, "exception_handlers")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Lifespan tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestGracefulShutdown:
    @pytest.mark.asyncio
    async def test_lifespan_cancels_timed_out_tasks(self) -> None:
        """Verify the lifespan waits for tasks and cancels timed-out ones."""
        from release_agent.main import lifespan

        # Async pool stub: the lifespan awaits open()/close() on it.
        mock_pool = AsyncMock()
        mock_pool.open = AsyncMock()
        mock_pool.close = AsyncMock()
        # Wire pool.connection() -> async-with conn -> conn.cursor() ->
        # async-with cursor.  Note cursor()/connection() are *sync* calls
        # that return async context managers, hence MagicMock, not AsyncMock.
        mock_conn = AsyncMock()
        mock_cursor = AsyncMock()
        mock_cursor.execute = AsyncMock()
        mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor)
        mock_cursor.__aexit__ = AsyncMock(return_value=False)
        mock_conn.cursor = MagicMock(return_value=mock_cursor)
        mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
        mock_conn.__aexit__ = AsyncMock(return_value=False)
        mock_pool.connection = MagicMock(return_value=mock_conn)

        from fastapi import FastAPI

        # A bare FastAPI app with an empty task registry, as lifespan expects.
        app = FastAPI()
        app.state.background_tasks = set()

        mock_settings = _make_mock_settings()
        # The fake task only needs a cancel() we can assert on afterwards.
        mock_task = MagicMock()
        mock_task.cancel = MagicMock()

        with (
            patch("release_agent.main.Settings", return_value=mock_settings),
            patch("release_agent.main.build_pr_completed_graph", return_value=MagicMock()),
            patch("release_agent.main.build_release_graph", return_value=MagicMock()),
            patch("release_agent.main.AsyncConnectionPool", return_value=mock_pool),
            patch("release_agent.main._create_tool_clients", return_value=(MagicMock(), [])),
            patch("release_agent.main._create_staging_store", return_value=MagicMock()),
            patch("release_agent.main._ensure_db_schema", new_callable=AsyncMock),
            # asyncio.wait is patched to report mock_task as still pending
            # (done=set(), pending={mock_task}), simulating a background task
            # that did not finish before the shutdown timeout.
            patch(
                "release_agent.main.asyncio.wait",
                new_callable=AsyncMock,
                return_value=(set(), {mock_task}),
            ),
        ):
            # Drive the lifespan manually so we can mutate state between
            # startup and shutdown.
            ctx = lifespan(app)
            await ctx.__aenter__()
            # Add a fake task to background_tasks after startup
            app.state.background_tasks.add(mock_task)
            await ctx.__aexit__(None, None, None)

        # The pending task should have been cancelled
        mock_task.cancel.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
class TestLifespan:
    def test_app_state_set_after_lifespan(self) -> None:
        """Verify app.state.graphs and app.state.settings are set during lifespan."""
        from release_agent.main import create_app

        # Async pool stub usable both as `async with pool` and via open()/close().
        mock_pool = AsyncMock()
        mock_pool.__aenter__ = AsyncMock(return_value=mock_pool)
        mock_pool.__aexit__ = AsyncMock(return_value=False)
        mock_pool.open = AsyncMock()
        mock_pool.close = AsyncMock()

        # pool.connection() -> async-with conn -> conn.cursor() -> async-with
        # cursor; cursor()/connection() are sync calls returning async CMs,
        # hence MagicMock rather than AsyncMock for those two.
        mock_conn = AsyncMock()
        mock_cursor = AsyncMock()
        mock_cursor.execute = AsyncMock()
        mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor)
        mock_cursor.__aexit__ = AsyncMock(return_value=False)
        mock_conn.cursor = MagicMock(return_value=mock_cursor)
        mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
        mock_conn.__aexit__ = AsyncMock(return_value=False)
        mock_pool.connection = MagicMock(return_value=mock_conn)

        mock_settings = _make_mock_settings()
        mock_graphs = {"pr_completed": MagicMock(), "release": MagicMock()}
        mock_clients = MagicMock()
        mock_staging_store = MagicMock()

        with (
            patch("release_agent.main.Settings", return_value=mock_settings),
            patch("release_agent.main.build_pr_completed_graph", return_value=mock_graphs["pr_completed"]),
            patch("release_agent.main.build_release_graph", return_value=mock_graphs["release"]),
            patch("release_agent.main.AsyncConnectionPool", return_value=mock_pool),
            patch("release_agent.main._create_tool_clients", return_value=(mock_clients, [])),
            patch("release_agent.main._create_staging_store", return_value=mock_staging_store),
            patch("release_agent.main._ensure_db_schema", new_callable=AsyncMock),
        ):
            app = create_app()
            # Entering TestClient runs the app's lifespan (startup/shutdown).
            with TestClient(app) as client:
                # App is started; state should be accessible
                response = client.get("/status")
                # We just verify no crash
                assert response.status_code in (200, 500)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Phase 5: Slack interactions route + new config tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestPhase5Routes:
    """Tests for Phase 5 additions to main.py."""

    def _make_patches(self):
        """Return the patch stack that stubs out heavy app startup."""
        mock_settings = _make_mock_settings()
        mock_pool = AsyncMock()
        mock_pool.open = AsyncMock()
        mock_pool.close = AsyncMock()
        mock_pool.__aenter__ = AsyncMock(return_value=mock_pool)
        mock_pool.__aexit__ = AsyncMock(return_value=False)

        return [
            patch("release_agent.main.Settings", return_value=mock_settings),
            patch("release_agent.main.build_pr_completed_graph", return_value=MagicMock()),
            patch("release_agent.main.build_release_graph", return_value=MagicMock()),
            patch("release_agent.main.AsyncConnectionPool", return_value=mock_pool),
            patch("release_agent.main._create_tool_clients", return_value=MagicMock()),
            patch("release_agent.main._create_staging_store", return_value=MagicMock()),
            patch("release_agent.main._ensure_db_schema", new_callable=AsyncMock),
        ]

    def test_slack_interactions_route_registered(self) -> None:
        """The Slack interactivity endpoint is registered on the app."""
        from release_agent.main import create_app

        patches = self._make_patches()
        for p in patches:
            p.start()

        try:
            app = create_app()
        finally:
            for p in patches:
                p.stop()

        routes = {r.path for r in app.routes}  # type: ignore[attr-defined]
        assert "/slack/interactions" in routes

    def test_create_tool_clients_uses_bot_token(self) -> None:
        """With a bot token configured, _create_tool_clients builds clients."""
        from release_agent.main import _create_tool_clients

        mock_settings = _make_mock_settings()
        mock_settings.slack_bot_token.get_secret_value.return_value = "xoxb-test"
        mock_settings.slack_channel_id = "C12345"
        mock_settings.slack_webhook_url.get_secret_value.return_value = ""
        mock_settings.azdo_api_url = "https://dev.azure.com/org/proj/_apis"
        mock_settings.azdo_vsrm_api_url = "https://vsrm.dev.azure.com/org/proj/_apis"
        mock_settings.jira_base_url = "https://example.atlassian.net"
        mock_settings.jira_email = "test@example.com"

        # Should not raise
        clients, http_clients = _create_tool_clients(mock_settings)
        assert clients is not None

        # Clean up the real httpx clients the factory created.
        # BUGFIX: the original used asyncio.get_event_loop().run_until_complete
        # per client, which is deprecated since Python 3.10 and raises on
        # newer versions when no event loop exists; asyncio.run owns its loop.
        async def _close_all() -> None:
            for hc in http_clients:
                await hc.aclose()

        asyncio.run(_close_all())
|
||||||
|
|
||||||
|
|
||||||
|
class TestPhase5DbSchema:
    """Tests that _ensure_db_schema adds the slack_message_ts column."""

    @staticmethod
    def _capture_pool():
        """Return a (mock_pool, executed_sql) pair recording all DDL text.

        Shared by both tests below; the original duplicated this wiring.
        """
        mock_pool = MagicMock()
        mock_conn = AsyncMock()
        mock_cursor = AsyncMock()
        executed_sql: list[str] = []

        async def capture_execute(sql, *args):
            executed_sql.append(sql.strip())

        mock_cursor.execute = capture_execute
        mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor)
        mock_cursor.__aexit__ = AsyncMock(return_value=False)
        # cursor()/connection() are sync calls returning async context
        # managers, hence MagicMock rather than AsyncMock.
        mock_conn.cursor = MagicMock(return_value=mock_cursor)
        mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
        mock_conn.__aexit__ = AsyncMock(return_value=False)
        mock_pool.connection = MagicMock(return_value=mock_conn)
        return mock_pool, executed_sql

    # BUGFIX: these async tests lacked @pytest.mark.asyncio, unlike
    # TestEnsureDbSchema above; under pytest-asyncio strict mode unmarked
    # async tests are not executed.
    @pytest.mark.asyncio
    async def test_ensure_db_schema_executes_sql_statements(self) -> None:
        """The helper runs several DDL statements covering the known tables."""
        from release_agent.main import _ensure_db_schema

        mock_pool, executed_sql = self._capture_pool()

        await _ensure_db_schema(mock_pool)

        # Should have executed CREATE TABLE statements
        assert len(executed_sql) >= 3
        combined = " ".join(executed_sql)
        assert "agent_threads" in combined
        assert "staging_releases" in combined

    @pytest.mark.asyncio
    async def test_ensure_db_schema_includes_slack_message_ts_column(self) -> None:
        """The executed DDL mentions the slack_message_ts column."""
        from release_agent.main import _ensure_db_schema

        mock_pool, executed_sql = self._capture_pool()

        await _ensure_db_schema(mock_pool)

        combined = " ".join(executed_sql)
        assert "slack_message_ts" in combined
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# PR polling lifespan integration tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestPrPollingLifespan:
    """Tests for PR polling startup in the lifespan handler."""

    def _make_polling_settings(self, *, pr_poll_enabled: bool = True) -> MagicMock:
        """Mock settings with polling configured for a single watched repo."""
        s = _make_mock_settings()
        s.pr_poll_enabled = pr_poll_enabled
        s.pr_poll_interval_seconds = 30
        s.pr_poll_target_branch = "refs/heads/develop"
        s.watched_repos_list = ["repo-a"]
        s.default_jira_project = "ALLPOST"
        return s

    @staticmethod
    def _make_pool() -> AsyncMock:
        """Mock async connection pool for lifespan startup (deduplicated)."""
        pool = AsyncMock()
        pool.open = AsyncMock()
        pool.close = AsyncMock()
        pool.connection = MagicMock()
        return pool

    # BUGFIX: both async tests lacked @pytest.mark.asyncio, unlike
    # TestEnsureDbSchema; under pytest-asyncio strict mode unmarked async
    # tests are not executed.
    @pytest.mark.asyncio
    async def test_poll_loop_started_when_pr_poll_enabled(self) -> None:
        """When pr_poll_enabled=True, a background task for polling is created."""
        from release_agent.main import create_app

        mock_settings = self._make_polling_settings(pr_poll_enabled=True)
        mock_pool = self._make_pool()

        poll_loop_started: list[bool] = []

        async def fake_run_poll_loop(**kwargs):
            poll_loop_started.append(True)
            # Simulate an immediate cancellation to avoid infinite loop
            raise asyncio.CancelledError

        with (
            patch("release_agent.main.Settings", return_value=mock_settings),
            patch("release_agent.main.build_pr_completed_graph", return_value=MagicMock()),
            patch("release_agent.main.build_release_graph", return_value=MagicMock()),
            patch("release_agent.main.AsyncConnectionPool", return_value=mock_pool),
            patch("release_agent.main._create_tool_clients", return_value=(MagicMock(), [])),
            patch("release_agent.main._create_staging_store", return_value=MagicMock()),
            patch("release_agent.main._ensure_db_schema", new_callable=AsyncMock),
            patch("release_agent.main.run_pr_poll_loop", new=fake_run_poll_loop),
        ):
            app = create_app()
            async with app.router.lifespan_context(app):
                # Give the event loop a chance to start background tasks
                await asyncio.sleep(0)

        assert len(poll_loop_started) > 0

    @pytest.mark.asyncio
    async def test_poll_loop_not_started_when_pr_poll_disabled(self) -> None:
        """When pr_poll_enabled=False, no polling background task is created."""
        from release_agent.main import create_app

        mock_settings = self._make_polling_settings(pr_poll_enabled=False)
        mock_pool = self._make_pool()

        poll_loop_started: list[bool] = []

        async def fake_run_poll_loop(**kwargs):
            poll_loop_started.append(True)

        with (
            patch("release_agent.main.Settings", return_value=mock_settings),
            patch("release_agent.main.build_pr_completed_graph", return_value=MagicMock()),
            patch("release_agent.main.build_release_graph", return_value=MagicMock()),
            patch("release_agent.main.AsyncConnectionPool", return_value=mock_pool),
            patch("release_agent.main._create_tool_clients", return_value=(MagicMock(), [])),
            patch("release_agent.main._create_staging_store", return_value=MagicMock()),
            patch("release_agent.main._ensure_db_schema", new_callable=AsyncMock),
            patch("release_agent.main.run_pr_poll_loop", new=fake_run_poll_loop),
        ):
            app = create_app()
            async with app.router.lifespan_context(app):
                await asyncio.sleep(0)

        assert len(poll_loop_started) == 0
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _run_graph default_jira_project injection tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestRunGraphJiraProjectInjection:
    """Verify _run_graph forwards default_jira_project via config['configurable']."""

    @staticmethod
    def _capturing_graph(sink: list) -> MagicMock:
        """Return a graph mock whose ainvoke records the config it receives."""
        graph = MagicMock()

        async def _ainvoke(state, config=None):
            sink.append(config or {})
            return {}

        graph.ainvoke = _ainvoke
        return graph

    @staticmethod
    def _mock_db_pool() -> MagicMock:
        """Build a pool mock supporting `async with pool.connection()` and cursors."""
        cursor = AsyncMock()
        cursor.__aenter__ = AsyncMock(return_value=cursor)
        cursor.__aexit__ = AsyncMock(return_value=False)

        conn = AsyncMock()
        conn.cursor = MagicMock(return_value=cursor)
        conn.__aenter__ = AsyncMock(return_value=conn)
        conn.__aexit__ = AsyncMock(return_value=False)

        pool = MagicMock()
        pool.connection = MagicMock(return_value=conn)
        return pool

    async def test_default_jira_project_passed_to_graph_config(self) -> None:
        from release_agent.api.webhooks import _run_graph

        seen: list[dict] = []

        await _run_graph(
            graph=self._capturing_graph(seen),
            initial_state={"pr_id": "1", "repo_name": "r"},
            thread_id="tid-1",
            tool_clients=MagicMock(),
            db_pool=self._mock_db_pool(),
            repos_base_dir="",
            graph_name="pr_completed",
            default_jira_project="MYPROJ",
        )

        assert len(seen) == 1
        assert seen[0].get("configurable", {}).get("default_jira_project") == "MYPROJ"

    async def test_default_jira_project_defaults_to_allpost(self) -> None:
        from release_agent.api.webhooks import _run_graph

        seen: list[dict] = []

        await _run_graph(
            graph=self._capturing_graph(seen),
            initial_state={"pr_id": "1", "repo_name": "r"},
            thread_id="tid-2",
            tool_clients=MagicMock(),
            db_pool=self._mock_db_pool(),
        )

        assert seen[0].get("configurable", {}).get("default_jira_project") == "ALLPOST"
|
||||||
147
tests/test_main_phase5.py
Normal file
147
tests/test_main_phase5.py
Normal file
@@ -0,0 +1,147 @@
|
|||||||
|
"""Tests for main.py Phase 5 changes.
|
||||||
|
|
||||||
|
Phase 5 - Step 4: _ensure_db_schema creates staging/archived tables,
|
||||||
|
and lifespan uses PostgresStagingStore.
|
||||||
|
Written FIRST (TDD RED phase).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _ensure_db_schema includes staging DDL
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestEnsureDbSchemaPhase5:
    """_ensure_db_schema must create the Phase-5 tables idempotently.

    The four original tests each repeated an identical ~25-line mock-pool
    setup; that boilerplate is factored into _run_schema_and_capture so a
    schema change needs updating in exactly one place.
    """

    @staticmethod
    async def _run_schema_and_capture() -> str:
        """Run _ensure_db_schema against a fully mocked pool.

        Returns all executed SQL statements joined into one string for
        simple substring assertions.
        """
        from release_agent.main import _ensure_db_schema

        executed_sqls: list[str] = []

        mock_cursor = AsyncMock()
        mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor)
        mock_cursor.__aexit__ = AsyncMock(return_value=False)

        async def capture_execute(sql: str, *args) -> None:
            executed_sqls.append(sql)

        mock_cursor.execute = capture_execute

        mock_conn = AsyncMock()
        mock_conn.cursor = MagicMock(return_value=mock_cursor)
        mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
        mock_conn.__aexit__ = AsyncMock(return_value=False)

        mock_pool = MagicMock()
        mock_pool.connection = MagicMock(return_value=mock_conn)

        await _ensure_db_schema(mock_pool)
        return " ".join(executed_sqls)

    async def test_schema_creates_staging_releases_table(self) -> None:
        all_sql = await self._run_schema_and_capture()
        assert "staging_releases" in all_sql

    async def test_schema_creates_archived_releases_table(self) -> None:
        all_sql = await self._run_schema_and_capture()
        assert "archived_releases" in all_sql

    async def test_schema_still_creates_agent_threads_table(self) -> None:
        all_sql = await self._run_schema_and_capture()
        assert "agent_threads" in all_sql

    async def test_schema_uses_if_not_exists(self) -> None:
        # DDL must be idempotent so lifespan can run it on every startup.
        all_sql = await self._run_schema_and_capture()
        assert "IF NOT EXISTS" in all_sql.upper()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Lifespan: PostgresStagingStore wired in
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestLifespanUsesPostgresStagingStore:
    """_create_staging_store picks the Postgres or JSON-file backend by pool presence."""

    def test_lifespan_creates_postgres_staging_store(self) -> None:
        """When PostgresStagingStore is imported in main, it is used in lifespan."""
        from release_agent.graph.postgres_staging_store import PostgresStagingStore
        from release_agent.main import _create_staging_store

        store = _create_staging_store(pool=MagicMock())
        assert isinstance(store, PostgresStagingStore)

    def test_create_staging_store_without_pool_falls_back_to_json(self) -> None:
        """Without a pool, falls back to JsonFileStagingStore for local dev."""
        from release_agent.graph.dependencies import JsonFileStagingStore
        from release_agent.main import _create_staging_store

        store = _create_staging_store(pool=None)
        assert isinstance(store, JsonFileStagingStore)
|
||||||
635
tests/test_models.py
Normal file
635
tests/test_models.py
Normal file
@@ -0,0 +1,635 @@
|
|||||||
|
"""Tests for Pydantic models. Written FIRST (TDD RED phase)."""
|
||||||
|
|
||||||
|
from datetime import date, datetime
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
from release_agent.models.jira import JiraIssue, JiraTransition
|
||||||
|
from release_agent.models.pipeline import PipelineInfo, ReleasePipelineStage
|
||||||
|
from release_agent.models.pr import PRInfo
|
||||||
|
from release_agent.models.release import ArchivedRelease, StagingRelease
|
||||||
|
from release_agent.models.review import ReviewIssue, ReviewResult
|
||||||
|
from release_agent.models.ticket import TicketEntry
|
||||||
|
from release_agent.models.webhook import WebhookPayload, WebhookRepository, WebhookResource
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# PRInfo tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestPRInfo:
    """Unit tests for PRInfo: ticket extraction from branch names, status enum, frozen-ness."""

    def _make_pr(self, **overrides) -> PRInfo:
        """Build a valid PRInfo, applying any keyword overrides."""
        base = {
            "pr_id": "PR-1",
            "pr_url": "https://dev.azure.com/org/project/_git/repo/pullrequest/1",
            "repo_name": "my-repo",
            "branch": "feature/ALLPOST-100_add-feature",
            "pr_title": "Add new feature",
            "pr_status": "active",
        }
        return PRInfo(**{**base, **overrides})

    def test_ticket_id_extracted_from_branch(self) -> None:
        pr = self._make_pr(branch="feature/ALLPOST-100_add-feature")
        assert (pr.ticket_id, pr.has_ticket) == ("ALLPOST-100", True)

    def test_branch_without_ticket(self) -> None:
        pr = self._make_pr(branch="chore/update-dependencies")
        assert pr.ticket_id is None
        assert not pr.has_ticket

    def test_main_branch_no_ticket(self) -> None:
        pr = self._make_pr(branch="main")
        assert pr.ticket_id is None
        assert not pr.has_ticket

    def test_refs_heads_branch_parsed(self) -> None:
        # The refs/heads/ prefix must be stripped before ticket extraction.
        pr = self._make_pr(branch="refs/heads/fix/BILL-42_fix-bug")
        assert (pr.ticket_id, pr.has_ticket) == ("BILL-42", True)

    def test_pr_status_active(self) -> None:
        assert self._make_pr(pr_status="active").pr_status == "active"

    def test_pr_status_completed(self) -> None:
        assert self._make_pr(pr_status="completed").pr_status == "completed"

    def test_pr_status_abandoned(self) -> None:
        assert self._make_pr(pr_status="abandoned").pr_status == "abandoned"

    def test_invalid_pr_status_raises(self) -> None:
        with pytest.raises(ValidationError):
            self._make_pr(pr_status="unknown")

    def test_model_is_frozen(self) -> None:
        pr = self._make_pr()
        with pytest.raises(ValidationError):
            pr.pr_id = "modified"  # type: ignore[misc]

    def test_pr_url_is_valid_url(self) -> None:
        pr = self._make_pr()
        # HttpUrl should have been validated
        assert "dev.azure.com" in str(pr.pr_url)

    def test_invalid_url_raises(self) -> None:
        with pytest.raises(ValidationError):
            self._make_pr(pr_url="not-a-url")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# TicketEntry tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestTicketEntry:
    """Unit tests for TicketEntry: Jira-key format validation and immutability."""

    def _make_ticket(self, **overrides) -> TicketEntry:
        """Build a valid TicketEntry, applying any keyword overrides."""
        base = {
            "id": "ALLPOST-4229",
            "summary": "Fix review bug",
            "pr_id": "PR-42",
            "pr_url": "https://dev.azure.com/org/project/_git/repo/pullrequest/42",
            "pr_title": "Fix review",
            "branch": "bug/ALLPOST-4229_fix-review",
            "merged_at": date(2024, 1, 15),
        }
        return TicketEntry(**{**base, **overrides})

    def test_valid_ticket_entry(self) -> None:
        entry = self._make_ticket()
        assert (entry.id, entry.summary) == ("ALLPOST-4229", "Fix review bug")

    def test_valid_jira_id_format(self) -> None:
        assert self._make_ticket(id="BILL-42").id == "BILL-42"

    def test_invalid_id_lowercase_raises(self) -> None:
        with pytest.raises(ValidationError):
            self._make_ticket(id="allpost-4229")

    def test_invalid_id_no_number_raises(self) -> None:
        with pytest.raises(ValidationError):
            self._make_ticket(id="ALLPOST-")

    def test_invalid_id_no_dash_raises(self) -> None:
        with pytest.raises(ValidationError):
            self._make_ticket(id="ALLPOST4229")

    def test_invalid_id_starts_with_number_raises(self) -> None:
        with pytest.raises(ValidationError):
            self._make_ticket(id="4ALLPOST-4229")

    def test_merged_at_is_date(self) -> None:
        assert isinstance(self._make_ticket().merged_at, date)

    def test_model_is_frozen(self) -> None:
        entry = self._make_ticket()
        with pytest.raises(ValidationError):
            entry.id = "OTHER-1"  # type: ignore[misc]

    def test_minimum_valid_id(self) -> None:
        # Single uppercase letter prefix followed by dash and digits
        assert self._make_ticket(id="A-1").id == "A-1"

    def test_numeric_in_project_key(self) -> None:
        assert self._make_ticket(id="AB2-100").id == "AB2-100"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# StagingRelease tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestStagingRelease:
    """Unit tests for StagingRelease: version pattern, immutable ticket handling."""

    def _make_ticket(self, ticket_id: str = "ALLPOST-1") -> TicketEntry:
        """Build a valid TicketEntry whose branch embeds *ticket_id*."""
        return TicketEntry(
            id=ticket_id,
            summary="Some ticket",
            pr_id="PR-1",
            pr_url="https://dev.azure.com/org/project/_git/repo/pullrequest/1",
            pr_title="Some PR",
            branch=f"feature/{ticket_id}_some-feature",
            merged_at=date(2024, 1, 15),
        )

    def _make_release(self, **overrides) -> StagingRelease:
        """Build a valid StagingRelease, applying any keyword overrides."""
        base = {
            "version": "v1.0.0",
            "repo": "my-repo",
            "started_at": date(2024, 1, 1),
            "tickets": [],
        }
        return StagingRelease(**{**base, **overrides})

    def test_valid_release(self) -> None:
        assert self._make_release().version == "v1.0.0"

    def test_version_must_match_pattern(self) -> None:
        # Leading "v" is mandatory.
        with pytest.raises(ValidationError):
            self._make_release(version="1.0.0")

    def test_version_missing_patch_raises(self) -> None:
        with pytest.raises(ValidationError):
            self._make_release(version="v1.0")

    def test_version_extra_segments_raises(self) -> None:
        with pytest.raises(ValidationError):
            self._make_release(version="v1.0.0.1")

    def test_version_letters_in_numbers_raises(self) -> None:
        with pytest.raises(ValidationError):
            self._make_release(version="v1.a.0")

    def test_add_ticket_returns_new_instance(self) -> None:
        original = self._make_release()
        updated = original.add_ticket(self._make_ticket("ALLPOST-1"))
        assert updated is not original

    def test_add_ticket_immutability(self) -> None:
        original = self._make_release()
        updated = original.add_ticket(self._make_ticket("ALLPOST-1"))
        # The source release is untouched; only the copy gains the ticket.
        assert len(original.tickets) == 0
        assert len(updated.tickets) == 1

    def test_add_ticket_contains_ticket(self) -> None:
        entry = self._make_ticket("ALLPOST-1")
        updated = self._make_release().add_ticket(entry)
        assert entry in updated.tickets

    def test_has_ticket_true(self) -> None:
        release = self._make_release(tickets=[self._make_ticket("ALLPOST-1")])
        assert release.has_ticket("ALLPOST-1") is True

    def test_has_ticket_false(self) -> None:
        assert self._make_release().has_ticket("ALLPOST-99") is False

    def test_has_ticket_after_add(self) -> None:
        updated = self._make_release().add_ticket(self._make_ticket("ALLPOST-5"))
        assert updated.has_ticket("ALLPOST-5") is True

    def test_model_is_frozen(self) -> None:
        release = self._make_release()
        with pytest.raises(ValidationError):
            release.version = "v2.0.0"  # type: ignore[misc]

    def test_multiple_tickets(self) -> None:
        pair = [self._make_ticket("ALLPOST-1"), self._make_ticket("ALLPOST-2")]
        assert len(self._make_release(tickets=pair).tickets) == 2
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# ArchivedRelease tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestArchivedRelease:
    """Unit tests for ArchivedRelease: released_at ordering and inherited validation."""

    def _make_archived(self, **overrides) -> ArchivedRelease:
        """Build a valid ArchivedRelease, applying any keyword overrides."""
        base = {
            "version": "v1.0.0",
            "repo": "my-repo",
            "started_at": date(2024, 1, 1),
            "tickets": [],
            "released_at": date(2024, 1, 10),
        }
        return ArchivedRelease(**{**base, **overrides})

    def test_valid_archived_release(self) -> None:
        assert self._make_archived().released_at == date(2024, 1, 10)

    def test_released_at_same_as_started_at_is_valid(self) -> None:
        same_day = date(2024, 1, 1)
        archived = self._make_archived(started_at=same_day, released_at=same_day)
        assert archived.released_at == archived.started_at

    def test_released_at_before_started_at_raises(self) -> None:
        # A release cannot finish before it started.
        with pytest.raises(ValidationError):
            self._make_archived(started_at=date(2024, 1, 10), released_at=date(2024, 1, 1))

    def test_model_is_frozen(self) -> None:
        archived = self._make_archived()
        with pytest.raises(ValidationError):
            archived.released_at = date(2024, 12, 31)  # type: ignore[misc]

    def test_inherits_version_validation(self) -> None:
        # The "vX.Y.Z" pattern check from the parent model still applies.
        with pytest.raises(ValidationError):
            self._make_archived(version="1.0.0")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# PipelineInfo tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestPipelineInfo:
    """Unit tests for PipelineInfo: construction and immutability."""

    def test_valid_pipeline_info(self) -> None:
        info = PipelineInfo(id=42, name="Release Pipeline", repo="my-repo")
        assert (info.id, info.name, info.repo) == (42, "Release Pipeline", "my-repo")

    def test_model_is_frozen(self) -> None:
        info = PipelineInfo(id=1, name="Test", repo="repo")
        with pytest.raises(ValidationError):
            info.id = 2  # type: ignore[misc]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# ReleasePipelineStage tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestReleasePipelineStage:
    """Unit tests for ReleasePipelineStage: rank bounds and approval-id consistency."""

    def test_valid_stage_without_approval(self) -> None:
        stage = ReleasePipelineStage(name="Build", rank=0, requires_approval=False, approval_id=None)
        assert (stage.name, stage.rank) == ("Build", 0)

    def test_valid_stage_with_approval(self) -> None:
        stage = ReleasePipelineStage(
            name="Production",
            rank=2,
            requires_approval=True,
            approval_id="approval-uuid-123",
        )
        assert stage.requires_approval is True
        assert stage.approval_id == "approval-uuid-123"

    def test_negative_rank_raises(self) -> None:
        with pytest.raises(ValidationError):
            ReleasePipelineStage(name="Bad", rank=-1, requires_approval=False, approval_id=None)

    def test_requires_approval_false_with_approval_id_raises(self) -> None:
        # An approval id without the approval flag is an inconsistent state.
        with pytest.raises(ValidationError):
            ReleasePipelineStage(name="Bad", rank=0, requires_approval=False, approval_id="some-id")

    def test_requires_approval_true_without_approval_id_is_valid(self) -> None:
        # The id may be assigned later, so it is allowed to be absent.
        stage = ReleasePipelineStage(
            name="Production", rank=2, requires_approval=True, approval_id=None
        )
        assert stage.requires_approval is True
        assert stage.approval_id is None

    def test_model_is_frozen(self) -> None:
        stage = ReleasePipelineStage(name="Build", rank=0, requires_approval=False, approval_id=None)
        with pytest.raises(ValidationError):
            stage.name = "Changed"  # type: ignore[misc]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# WebhookPayload tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestWebhookPayload:
    """Unit tests for WebhookPayload and its nested resource/repository models."""

    @staticmethod
    def _payload_data() -> dict:
        """Return a fresh, fully valid raw payload dict (safe to mutate)."""
        return {
            "subscription_id": "sub-123",
            "event_type": "git.pullrequest.merged",
            "resource": {
                "repository": {
                    "id": "repo-uuid-456",
                    "name": "my-repo",
                    "web_url": "https://dev.azure.com/org/project/_git/my-repo",
                },
                "pull_request_id": 42,
                "title": "Fix the bug",
                "source_ref_name": "refs/heads/bug/ALLPOST-4229_fix-review",
                "target_ref_name": "refs/heads/main",
                "status": "completed",
                "closed_date": None,
            },
        }

    def _make_payload(self, **overrides) -> WebhookPayload:
        data = self._payload_data()
        data.update(overrides)
        return WebhookPayload(**data)

    def test_valid_payload(self) -> None:
        payload = self._make_payload()
        assert payload.subscription_id == "sub-123"
        assert payload.event_type == "git.pullrequest.merged"

    def test_resource_parsed(self) -> None:
        resource = self._make_payload().resource
        assert isinstance(resource, WebhookResource)
        assert (resource.pull_request_id, resource.title) == (42, "Fix the bug")

    def test_repository_parsed(self) -> None:
        repository = self._make_payload().resource.repository
        assert isinstance(repository, WebhookRepository)
        assert repository.name == "my-repo"

    def test_repository_web_url(self) -> None:
        payload = self._make_payload()
        assert "dev.azure.com" in str(payload.resource.repository.web_url)

    def test_closed_date_none(self) -> None:
        assert self._make_payload().resource.closed_date is None

    def test_closed_date_populated(self) -> None:
        data = self._payload_data()
        data["resource"]["closed_date"] = "2024-01-15T10:30:00Z"
        payload = WebhookPayload(**data)
        assert payload.resource.closed_date is not None
        assert isinstance(payload.resource.closed_date, datetime)

    def test_model_is_frozen(self) -> None:
        payload = self._make_payload()
        with pytest.raises(ValidationError):
            payload.subscription_id = "changed"  # type: ignore[misc]

    def test_source_ref_name_preserved(self) -> None:
        payload = self._make_payload()
        assert payload.resource.source_ref_name == "refs/heads/bug/ALLPOST-4229_fix-review"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# ReviewIssue tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestReviewIssue:
    """Unit tests for ReviewIssue: severity enum, optional fields, immutability."""

    def _make_issue(self, **overrides) -> ReviewIssue:
        """Build a valid ReviewIssue, applying any keyword overrides."""
        base = {
            "severity": "warning",
            "description": "Variable name is not descriptive",
        }
        return ReviewIssue(**{**base, **overrides})

    def test_valid_warning_issue(self) -> None:
        issue = self._make_issue(severity="warning", description="Unclear variable")
        assert (issue.severity, issue.description) == ("warning", "Unclear variable")

    def test_valid_error_issue(self) -> None:
        assert self._make_issue(severity="error", description="Null pointer risk").severity == "error"

    def test_valid_info_issue(self) -> None:
        assert self._make_issue(severity="info", description="Minor style note").severity == "info"

    def test_valid_blocker_issue(self) -> None:
        issue = self._make_issue(severity="blocker", description="Security vulnerability")
        assert issue.severity == "blocker"

    def test_invalid_severity_raises(self) -> None:
        # "critical" is not in the allowed severity set.
        with pytest.raises(ValidationError):
            self._make_issue(severity="critical")

    def test_file_path_optional_none_by_default(self) -> None:
        assert self._make_issue().file_path is None

    def test_file_path_can_be_set(self) -> None:
        assert self._make_issue(file_path="src/foo.py").file_path == "src/foo.py"

    def test_suggestion_optional_none_by_default(self) -> None:
        assert self._make_issue().suggestion is None

    def test_suggestion_can_be_set(self) -> None:
        issue = self._make_issue(suggestion="Rename to `user_count`")
        assert issue.suggestion == "Rename to `user_count`"

    def test_model_is_frozen(self) -> None:
        issue = self._make_issue()
        with pytest.raises(ValidationError):
            issue.severity = "error"  # type: ignore[misc]

    def test_description_required(self) -> None:
        with pytest.raises(ValidationError):
            ReviewIssue(severity="warning")  # type: ignore[call-arg]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# ReviewResult tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestReviewResult:
    """Tests for the ReviewResult model and its derived has_blockers flag."""

    def _make_blocker_issue(self) -> ReviewIssue:
        """Return a blocker-severity issue fixture."""
        return ReviewIssue(severity="blocker", description="Must fix this")

    def _make_warning_issue(self) -> ReviewIssue:
        """Return a warning-severity issue fixture."""
        return ReviewIssue(severity="warning", description="Minor issue")

    def _make_result(self, **kwargs) -> ReviewResult:
        """Build a ReviewResult with sensible defaults, overridable per test."""
        params = {
            "verdict": "approve",
            "summary": "Looks good overall",
            "issues": [],
            **kwargs,
        }
        return ReviewResult(**params)

    def test_valid_approve_verdict(self) -> None:
        """'approve' is a legal verdict."""
        assert self._make_result(verdict="approve").verdict == "approve"

    def test_valid_request_changes_verdict(self) -> None:
        """'request_changes' is a legal verdict."""
        assert self._make_result(verdict="request_changes").verdict == "request_changes"

    def test_invalid_verdict_raises(self) -> None:
        """An unlisted verdict ('reject') is rejected with ValidationError."""
        with pytest.raises(ValidationError):
            self._make_result(verdict="reject")

    def test_summary_stored(self) -> None:
        """The summary text round-trips unchanged."""
        assert self._make_result(summary="Great PR").summary == "Great PR"

    def test_issues_empty_by_default(self) -> None:
        """The default fixture carries no issues."""
        assert len(self._make_result().issues) == 0

    def test_has_blockers_false_with_no_issues(self) -> None:
        """No issues at all means no blockers."""
        assert self._make_result(issues=[]).has_blockers is False

    def test_has_blockers_false_with_only_warnings(self) -> None:
        """Warnings alone do not count as blockers."""
        warnings_only = self._make_result(issues=[self._make_warning_issue()])
        assert warnings_only.has_blockers is False

    def test_has_blockers_true_with_blocker_issue(self) -> None:
        """A single blocker issue flips has_blockers to True."""
        with_blocker = self._make_result(issues=[self._make_blocker_issue()])
        assert with_blocker.has_blockers is True

    def test_has_blockers_true_mixed_issues(self) -> None:
        """A blocker among warnings still flips has_blockers to True."""
        mixed = self._make_result(
            issues=[self._make_warning_issue(), self._make_blocker_issue()]
        )
        assert mixed.has_blockers is True

    def test_model_is_frozen(self) -> None:
        """Attribute assignment on the frozen model raises ValidationError."""
        frozen = self._make_result()
        with pytest.raises(ValidationError):
            frozen.verdict = "request_changes"  # type: ignore[misc]

    def test_multiple_issues_stored(self) -> None:
        """All supplied issues are retained."""
        pair = [self._make_warning_issue(), self._make_blocker_issue()]
        assert len(self._make_result(issues=pair).issues) == 2

    def test_has_blockers_is_computed(self) -> None:
        """has_blockers is derived from the issues list, not passed in.

        NOTE(review): this only checks the derived value; it does not
        actually prove the field cannot be set directly — confirm against
        the model definition.
        """
        derived = self._make_result(issues=[self._make_blocker_issue()])
        assert derived.has_blockers is True
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# JiraTransition tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestJiraTransition:
    """Tests for the JiraTransition model."""

    def _sample(self) -> JiraTransition:
        """Return a minimal valid transition fixture."""
        return JiraTransition(id="11", name="To Do")

    def test_valid_transition(self) -> None:
        """Both fields round-trip unchanged."""
        sample = self._sample()
        assert sample.id == "11"
        assert sample.name == "To Do"

    def test_model_is_frozen(self) -> None:
        """Attribute assignment on the frozen model raises ValidationError."""
        sample = self._sample()
        with pytest.raises(ValidationError):
            sample.id = "22"  # type: ignore[misc]

    def test_id_required(self) -> None:
        """Omitting id raises ValidationError."""
        with pytest.raises(ValidationError):
            JiraTransition(name="To Do")  # type: ignore[call-arg]

    def test_name_required(self) -> None:
        """Omitting name raises ValidationError."""
        with pytest.raises(ValidationError):
            JiraTransition(id="11")  # type: ignore[call-arg]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# JiraIssue tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestJiraIssue:
    """Tests for the JiraIssue model."""

    def _sample(self) -> JiraIssue:
        """Return a fully-populated JiraIssue fixture."""
        return JiraIssue(key="ALLPOST-100", summary="Fix the bug", status="In Progress")

    def test_valid_issue(self) -> None:
        """All three fields round-trip unchanged."""
        sample = self._sample()
        assert sample.key == "ALLPOST-100"
        assert sample.summary == "Fix the bug"
        assert sample.status == "In Progress"

    def test_model_is_frozen(self) -> None:
        """Attribute assignment on the frozen model raises ValidationError."""
        sample = self._sample()
        with pytest.raises(ValidationError):
            sample.key = "ALLPOST-200"  # type: ignore[misc]

    def test_key_required(self) -> None:
        """Omitting key raises ValidationError."""
        with pytest.raises(ValidationError):
            JiraIssue(summary="Fix the bug", status="In Progress")  # type: ignore[call-arg]

    def test_summary_required(self) -> None:
        """Omitting summary raises ValidationError."""
        with pytest.raises(ValidationError):
            JiraIssue(key="ALLPOST-100", status="In Progress")  # type: ignore[call-arg]

    def test_status_required(self) -> None:
        """Omitting status raises ValidationError."""
        with pytest.raises(ValidationError):
            JiraIssue(key="ALLPOST-100", summary="Fix the bug")  # type: ignore[call-arg]

    def test_various_statuses(self) -> None:
        """Any workflow status string is accepted and stored verbatim."""
        for status in ("To Do", "In Progress", "Done", "Released"):
            issue = JiraIssue(key="ALLPOST-1", summary="Test", status=status)
            assert issue.status == status
|
||||||
148
tests/test_models_build.py
Normal file
148
tests/test_models_build.py
Normal file
@@ -0,0 +1,148 @@
|
|||||||
|
"""Tests for models/build.py — BuildStatus and ApprovalRecord.
|
||||||
|
|
||||||
|
Written FIRST (TDD RED phase).
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from dataclasses import FrozenInstanceError
|
||||||
|
|
||||||
|
from release_agent.models.build import ApprovalRecord, BuildStatus
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# BuildStatus tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestBuildStatus:
    """Tests for the BuildStatus frozen dataclass."""

    def _make(self, **overrides) -> BuildStatus:
        """Build a BuildStatus with defaults, overridable per test.

        Mirrors the ``_make_*`` factory pattern used by the other model
        test classes; avoids re-spelling the full constructor in every test.
        Defaults: status='completed', result='succeeded', build_url=None.
        """
        params = {"status": "completed", "result": "succeeded", "build_url": None}
        params.update(overrides)
        return BuildStatus(**params)

    def test_can_be_created_with_all_fields(self) -> None:
        """All three fields round-trip unchanged."""
        bs = self._make(
            build_url="https://dev.azure.com/org/proj/_build/results?buildId=42",
        )
        assert bs.status == "completed"
        assert bs.result == "succeeded"
        assert bs.build_url == "https://dev.azure.com/org/proj/_build/results?buildId=42"

    def test_result_can_be_none(self) -> None:
        """result is optional (None while a build is still in progress)."""
        bs = self._make(
            status="inProgress",
            result=None,
            build_url="https://dev.azure.com/org/proj/_build/results?buildId=99",
        )
        assert bs.result is None

    def test_build_url_can_be_none(self) -> None:
        """build_url is optional (None before the build is queued)."""
        bs = self._make(status="notStarted", result=None)
        assert bs.build_url is None

    def test_is_frozen_status(self) -> None:
        """status cannot be reassigned on the frozen dataclass."""
        bs = self._make()
        with pytest.raises((FrozenInstanceError, AttributeError)):
            bs.status = "inProgress"  # type: ignore[misc]

    def test_is_frozen_result(self) -> None:
        """result cannot be reassigned on the frozen dataclass."""
        bs = self._make()
        with pytest.raises((FrozenInstanceError, AttributeError)):
            bs.result = "failed"  # type: ignore[misc]

    def test_equality(self) -> None:
        """Value equality: identical fields compare equal."""
        assert self._make(build_url="http://x") == self._make(build_url="http://x")

    def test_inequality_on_status(self) -> None:
        """Differing status makes instances unequal."""
        assert self._make(status="completed") != self._make(status="inProgress")

    def test_inequality_on_result(self) -> None:
        """Differing result makes instances unequal."""
        assert self._make(result="succeeded") != self._make(result="failed")

    def test_repr_contains_status(self) -> None:
        """repr() exposes the status for debuggability."""
        assert "completed" in repr(self._make())

    def test_status_values_typical(self) -> None:
        """Typical Azure DevOps build states are accepted verbatim."""
        for s in ("notStarted", "inProgress", "completed", "cancelling"):
            assert self._make(status=s, result=None).status == s

    def test_result_values_typical(self) -> None:
        """Typical Azure DevOps build results are accepted verbatim."""
        for r in ("succeeded", "failed", "canceled", "partiallySucceeded"):
            assert self._make(result=r).result == r
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# ApprovalRecord tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestApprovalRecord:
    """Tests for the ApprovalRecord frozen dataclass."""

    def _make(self, **overrides) -> ApprovalRecord:
        """Build an ApprovalRecord with defaults, overridable per test.

        Mirrors the ``_make_*`` factory pattern used by the other model
        test classes; avoids re-spelling the full constructor in every test.
        Defaults: approval_id='abc', stage_name='Sandbox',
        status='pending', release_id=1.
        """
        params = {
            "approval_id": "abc",
            "stage_name": "Sandbox",
            "status": "pending",
            "release_id": 1,
        }
        params.update(overrides)
        return ApprovalRecord(**params)

    def test_can_be_created_with_all_fields(self) -> None:
        """All four fields round-trip unchanged."""
        ar = self._make(approval_id="approval-abc-123", release_id=42)
        assert ar.approval_id == "approval-abc-123"
        assert ar.stage_name == "Sandbox"
        assert ar.status == "pending"
        assert ar.release_id == 42

    def test_is_frozen_approval_id(self) -> None:
        """approval_id cannot be reassigned on the frozen dataclass."""
        ar = self._make()
        with pytest.raises((FrozenInstanceError, AttributeError)):
            ar.approval_id = "xyz"  # type: ignore[misc]

    def test_is_frozen_stage_name(self) -> None:
        """stage_name cannot be reassigned on the frozen dataclass."""
        ar = self._make()
        with pytest.raises((FrozenInstanceError, AttributeError)):
            ar.stage_name = "Production"  # type: ignore[misc]

    def test_equality(self) -> None:
        """Value equality: identical fields compare equal."""
        a = self._make(approval_id="x", stage_name="S")
        b = self._make(approval_id="x", stage_name="S")
        assert a == b

    def test_inequality_on_approval_id(self) -> None:
        """Differing approval_id makes instances unequal."""
        a = self._make(approval_id="x", stage_name="S")
        b = self._make(approval_id="y", stage_name="S")
        assert a != b

    def test_status_pending(self) -> None:
        """'pending' status round-trips."""
        ar = self._make(approval_id="a", stage_name="Stage", release_id=10)
        assert ar.status == "pending"

    def test_status_approved(self) -> None:
        """'approved' status round-trips."""
        ar = self._make(approval_id="a", stage_name="Stage", status="approved", release_id=10)
        assert ar.status == "approved"

    def test_status_rejected(self) -> None:
        """'rejected' status round-trips."""
        ar = self._make(approval_id="a", stage_name="Stage", status="rejected", release_id=10)
        assert ar.status == "rejected"

    def test_repr_contains_stage_name(self) -> None:
        """repr() exposes the stage name for debuggability."""
        ar = self._make(approval_id="a", stage_name="Production", release_id=5)
        assert "Production" in repr(ar)

    def test_release_id_is_int(self) -> None:
        """release_id is stored as an int."""
        ar = self._make(approval_id="a", stage_name="S", release_id=999)
        assert isinstance(ar.release_id, int)
|
||||||
241
tests/test_state.py
Normal file
241
tests/test_state.py
Normal file
@@ -0,0 +1,241 @@
|
|||||||
|
"""Tests for LangGraph state module. Written FIRST (TDD RED phase)."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
|
||||||
|
from release_agent.state import ReleaseState, add_errors, add_messages
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# ReleaseState tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestReleaseState:
    """Tests for the ReleaseState TypedDict (total=False: every key is optional)."""

    def test_empty_state_is_valid(self) -> None:
        """An empty dict is a valid state since all fields are optional."""
        st: ReleaseState = {}
        assert st == {}

    def test_partial_state_with_repo(self) -> None:
        st: ReleaseState = {"repo_name": "my-repo"}
        assert st["repo_name"] == "my-repo"

    def test_partial_state_with_messages(self) -> None:
        st: ReleaseState = {"messages": ["Hello"]}
        assert st["messages"] == ["Hello"]

    def test_partial_state_with_errors(self) -> None:
        st: ReleaseState = {"errors": ["Something went wrong"]}
        assert st["errors"] == ["Something went wrong"]

    def test_state_with_pr_id(self) -> None:
        st: ReleaseState = {"pr_id": "PR-42"}
        assert st["pr_id"] == "PR-42"

    def test_state_with_ticket_id(self) -> None:
        st: ReleaseState = {"ticket_id": "ALLPOST-100"}
        assert st["ticket_id"] == "ALLPOST-100"

    def test_state_with_version(self) -> None:
        st: ReleaseState = {"version": "v1.0.1"}
        assert st["version"] == "v1.0.1"

    # --- Phase 3 fields -----------------------------------------------------

    def test_state_with_webhook_payload(self) -> None:
        st: ReleaseState = {"webhook_payload": {"event_type": "git.pullrequest.merged"}}
        assert st["webhook_payload"]["event_type"] == "git.pullrequest.merged"

    def test_state_with_pr_info(self) -> None:
        st: ReleaseState = {"pr_info": {"pr_id": "42", "repo_name": "my-repo"}}
        assert st["pr_info"]["repo_name"] == "my-repo"

    def test_state_with_pr_diff(self) -> None:
        st: ReleaseState = {"pr_diff": "edit: src/main.py"}
        assert st["pr_diff"] == "edit: src/main.py"

    def test_state_with_last_merge_source_commit(self) -> None:
        st: ReleaseState = {"last_merge_source_commit": "abc123"}
        assert st["last_merge_source_commit"] == "abc123"

    def test_state_with_ticket_summary(self) -> None:
        st: ReleaseState = {"ticket_summary": "Fix login bug"}
        assert st["ticket_summary"] == "Fix login bug"

    def test_state_with_has_ticket(self) -> None:
        st: ReleaseState = {"has_ticket": True}
        assert st["has_ticket"] is True

    def test_state_with_review_result(self) -> None:
        st: ReleaseState = {"review_result": {"verdict": "approve", "summary": "LGTM"}}
        assert st["review_result"]["verdict"] == "approve"

    def test_state_with_review_approved(self) -> None:
        st: ReleaseState = {"review_approved": True}
        assert st["review_approved"] is True

    def test_state_with_staging(self) -> None:
        st: ReleaseState = {"staging": {"version": "v1.0.0", "tickets": []}}
        assert st["staging"]["version"] == "v1.0.0"

    def test_state_with_pr_already_merged(self) -> None:
        st: ReleaseState = {"pr_already_merged": False}
        assert st["pr_already_merged"] is False

    def test_state_with_release_pr_id(self) -> None:
        st: ReleaseState = {"release_pr_id": "123"}
        assert st["release_pr_id"] == "123"

    def test_state_with_release_pr_commit(self) -> None:
        st: ReleaseState = {"release_pr_commit": "deadbeef"}
        assert st["release_pr_commit"] == "deadbeef"

    def test_state_with_pipelines(self) -> None:
        st: ReleaseState = {"pipelines": [{"id": 1, "name": "build"}]}
        assert len(st["pipelines"]) == 1

    def test_state_with_triggered_builds(self) -> None:
        st: ReleaseState = {"triggered_builds": [{"id": 99}]}
        assert st["triggered_builds"][0]["id"] == 99

    def test_state_with_pending_approvals(self) -> None:
        st: ReleaseState = {"pending_approvals": [{"approval_id": "aaa"}]}
        assert st["pending_approvals"][0]["approval_id"] == "aaa"

    def test_state_with_continue_to_release(self) -> None:
        st: ReleaseState = {"continue_to_release": True}
        assert st["continue_to_release"] is True

    # --- Phase 5: CI/CD and approval fields ----------------------------------

    def test_state_with_ci_build_id(self) -> None:
        st: ReleaseState = {"ci_build_id": 12345}
        assert st["ci_build_id"] == 12345

    def test_state_with_ci_build_status(self) -> None:
        st: ReleaseState = {"ci_build_status": "inProgress"}
        assert st["ci_build_status"] == "inProgress"

    def test_state_with_ci_build_result(self) -> None:
        st: ReleaseState = {"ci_build_result": "succeeded"}
        assert st["ci_build_result"] == "succeeded"

    def test_state_with_ci_build_url(self) -> None:
        st: ReleaseState = {"ci_build_url": "https://dev.azure.com/org/proj/_build/results?buildId=99"}
        assert "buildId=99" in st["ci_build_url"]

    def test_state_with_release_definition_id(self) -> None:
        st: ReleaseState = {"release_definition_id": 7}
        assert st["release_definition_id"] == 7

    def test_state_with_release_id(self) -> None:
        st: ReleaseState = {"release_id": 456}
        assert st["release_id"] == 456

    def test_state_with_current_stage(self) -> None:
        st: ReleaseState = {"current_stage": "sandbox_pending"}
        assert st["current_stage"] == "sandbox_pending"

    def test_state_with_approval_message_ts(self) -> None:
        st: ReleaseState = {"approval_message_ts": "1234567890.123456"}
        assert st["approval_message_ts"] == "1234567890.123456"

    def test_state_with_slack_message_ts(self) -> None:
        st: ReleaseState = {"slack_message_ts": "9876543210.000001"}
        assert st["slack_message_ts"] == "9876543210.000001"

    # --- JSON round-trips (state must survive checkpoint persistence) --------

    def test_state_json_serializable_empty(self) -> None:
        st: ReleaseState = {}
        assert json.loads(json.dumps(st)) == {}

    def test_state_json_serializable_with_strings(self) -> None:
        st: ReleaseState = {
            "repo_name": "my-repo",
            "pr_id": "PR-1",
            "ticket_id": "ALLPOST-1",
            "version": "v1.0.0",
        }
        round_tripped = json.loads(json.dumps(st))
        assert round_tripped["repo_name"] == "my-repo"
        assert round_tripped["pr_id"] == "PR-1"

    def test_state_json_serializable_with_lists(self) -> None:
        st: ReleaseState = {
            "messages": ["msg1", "msg2"],
            "errors": ["err1"],
        }
        round_tripped = json.loads(json.dumps(st))
        assert round_tripped["messages"] == ["msg1", "msg2"]
        assert round_tripped["errors"] == ["err1"]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Reducer tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestAddMessages:
    """Tests for the add_messages reducer (append-only, non-mutating)."""

    def test_accumulates_to_empty(self) -> None:
        """Appending to an empty list yields just the new messages."""
        assert add_messages([], ["Hello"]) == ["Hello"]

    def test_accumulates_to_existing(self) -> None:
        """New messages land after the existing ones."""
        assert add_messages(["Hello"], ["World"]) == ["Hello", "World"]

    def test_accumulates_multiple(self) -> None:
        """Multiple new messages keep their relative order."""
        assert add_messages(["A", "B"], ["C", "D"]) == ["A", "B", "C", "D"]

    def test_existing_unchanged(self) -> None:
        """The reducer must not mutate the input list."""
        prior = ["Hello"]
        add_messages(prior, ["World"])
        assert prior == ["Hello"]

    def test_empty_new_messages(self) -> None:
        """An empty update leaves the contents unchanged."""
        assert add_messages(["Hello"], []) == ["Hello"]

    def test_both_empty(self) -> None:
        """Two empty inputs produce an empty result."""
        assert add_messages([], []) == []

    def test_returns_new_list(self) -> None:
        """The result is a fresh list, not an alias of either input."""
        prior, incoming = ["Hello"], ["World"]
        combined = add_messages(prior, incoming)
        assert combined is not prior
        assert combined is not incoming
|
||||||
|
|
||||||
|
|
||||||
|
class TestAddErrors:
    """Tests for the add_errors reducer (append-only, non-mutating)."""

    def test_accumulates_to_empty(self) -> None:
        """Appending to an empty list yields just the new errors."""
        assert add_errors([], ["Error occurred"]) == ["Error occurred"]

    def test_accumulates_to_existing(self) -> None:
        """New errors land after the existing ones."""
        assert add_errors(["First error"], ["Second error"]) == ["First error", "Second error"]

    def test_existing_unchanged(self) -> None:
        """The reducer must not mutate the input list."""
        prior = ["First error"]
        add_errors(prior, ["Second error"])
        assert prior == ["First error"]

    def test_empty_new_errors(self) -> None:
        """An empty update leaves the contents unchanged."""
        assert add_errors(["Existing"], []) == ["Existing"]

    def test_both_empty(self) -> None:
        """Two empty inputs produce an empty result."""
        assert add_errors([], []) == []

    def test_returns_new_list(self) -> None:
        """The result is a fresh list, not an alias of the input."""
        prior = ["Error"]
        assert add_errors(prior, ["New error"]) is not prior
|
||||||
124
tests/test_versioning.py
Normal file
124
tests/test_versioning.py
Normal file
@@ -0,0 +1,124 @@
|
|||||||
|
"""Tests for versioning module. Written FIRST (TDD RED phase)."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from release_agent.versioning import (
|
||||||
|
calculate_next_version,
|
||||||
|
format_version,
|
||||||
|
parse_version,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestParseVersion:
    """Tests for parse_version: 'vX.Y.Z' or 'X.Y.Z' -> (X, Y, Z)."""

    def test_parse_with_v_prefix(self) -> None:
        """The leading 'v' is stripped before parsing."""
        assert parse_version("v1.2.3") == (1, 2, 3)

    def test_parse_without_v_prefix(self) -> None:
        """A bare dotted triple also parses."""
        assert parse_version("1.2.3") == (1, 2, 3)

    def test_parse_zeros(self) -> None:
        assert parse_version("v0.0.0") == (0, 0, 0)

    def test_parse_large_numbers(self) -> None:
        """Multi-digit components are parsed as full integers."""
        assert parse_version("v10.20.300") == (10, 20, 300)

    def test_parse_returns_tuple_of_ints(self) -> None:
        """The return type is a 3-tuple of ints."""
        parsed = parse_version("v1.2.3")
        assert isinstance(parsed, tuple)
        assert len(parsed) == 3
        assert all(isinstance(part, int) for part in parsed)

    def test_parse_invalid_raises_value_error(self) -> None:
        """Non-version text raises ValueError."""
        with pytest.raises(ValueError):
            parse_version("invalid")

    def test_parse_partial_version_raises_value_error(self) -> None:
        """Two-part versions ('v1.2') are rejected."""
        with pytest.raises(ValueError):
            parse_version("v1.2")

    def test_parse_non_numeric_raises_value_error(self) -> None:
        """Non-numeric components are rejected."""
        with pytest.raises(ValueError):
            parse_version("va.b.c")
|
||||||
|
|
||||||
|
|
||||||
|
class TestFormatVersion:
    """Tests for format_version: (X, Y, Z) -> 'vX.Y.Z'."""

    def test_format_basic(self) -> None:
        assert format_version(1, 0, 3) == "v1.0.3"

    def test_format_zeros(self) -> None:
        assert format_version(0, 0, 0) == "v0.0.0"

    def test_format_large_numbers(self) -> None:
        """Multi-digit components are rendered without padding."""
        assert format_version(10, 20, 300) == "v10.20.300"

    def test_format_returns_string(self) -> None:
        assert isinstance(format_version(1, 2, 3), str)

    def test_format_starts_with_v(self) -> None:
        """Output always carries the 'v' prefix."""
        assert format_version(1, 2, 3).startswith("v")
|
||||||
|
|
||||||
|
|
||||||
|
class TestCalculateNextVersion:
    """Tests for calculate_next_version: bump patch of the highest valid tag."""

    def test_empty_list_returns_v1_0_0(self) -> None:
        """No prior tags -> seed version v1.0.0."""
        assert calculate_next_version("my-repo", []) == "v1.0.0"

    def test_single_version_increments_patch(self) -> None:
        assert calculate_next_version("my-repo", ["v1.0.0"]) == "v1.0.1"

    def test_multiple_versions_uses_highest(self) -> None:
        """The highest existing version wins, regardless of list order."""
        assert calculate_next_version("my-repo", ["v1.0.3", "v1.0.1"]) == "v1.0.4"

    def test_different_major_versions(self) -> None:
        """Comparison is numeric across all three components."""
        assert calculate_next_version("my-repo", ["v2.1.0", "v1.9.9"]) == "v2.1.1"

    def test_skips_malformed_versions(self) -> None:
        """Unparseable entries are ignored, not fatal."""
        assert calculate_next_version("my-repo", ["invalid", "v1.0.0"]) == "v1.0.1"

    def test_all_malformed_versions_returns_v1_0_0(self) -> None:
        """If nothing parses, fall back to the seed version."""
        assert calculate_next_version("my-repo", ["invalid", "bad", "nope"]) == "v1.0.0"

    def test_repo_name_does_not_affect_result(self) -> None:
        """The repo name is context only; it never changes the arithmetic."""
        assert calculate_next_version("repo-a", ["v1.0.0"]) == calculate_next_version(
            "repo-b", ["v1.0.0"]
        )

    def test_versions_out_of_order(self) -> None:
        assert calculate_next_version("my-repo", ["v1.0.1", "v1.0.3", "v1.0.2"]) == "v1.0.4"

    def test_patch_overflow_does_not_occur(self) -> None:
        """Patch just increments numerically — no carry/overflow logic."""
        assert calculate_next_version("my-repo", ["v1.0.99"]) == "v1.0.100"

    def test_versions_without_v_prefix_skipped(self) -> None:
        """Entries without the 'v' prefix are treated as malformed per spec."""
        assert calculate_next_version("my-repo", ["1.0.0", "v2.0.0"]) == "v2.0.1"

    def test_result_format_starts_with_v(self) -> None:
        assert calculate_next_version("my-repo", ["v1.0.0"]).startswith("v")

    def test_result_has_three_parts(self) -> None:
        """The output is always 'v' + three dot-separated numbers."""
        components = calculate_next_version("my-repo", ["v1.0.0"])[1:].split(".")
        assert len(components) == 3
        assert all(component.isdigit() for component in components)

    def test_v_prefix_with_nonnumeric_parts_skipped(self) -> None:
        """'v' prefix alone is not enough — components must be numeric."""
        assert calculate_next_version("my-repo", ["va.b.c", "v1.0.0"]) == "v1.0.1"

    def test_v_prefix_partial_version_skipped(self) -> None:
        """Two-part 'v' tags are skipped gracefully."""
        assert calculate_next_version("my-repo", ["v1.0", "v2.0.0"]) == "v2.0.1"
|
||||||
0
tests/tools/__init__.py
Normal file
0
tests/tools/__init__.py
Normal file
9
tests/tools/fixtures/azdo_approve_release.json
Normal file
9
tests/tools/fixtures/azdo_approve_release.json
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
{
|
||||||
|
"id": "approval-uuid-123",
|
||||||
|
"status": "approved",
|
||||||
|
"approver": {
|
||||||
|
"id": "user-uuid-456",
|
||||||
|
"displayName": "Release Bot"
|
||||||
|
},
|
||||||
|
"comments": "Approved via release agent"
|
||||||
|
}
|
||||||
9
tests/tools/fixtures/azdo_build_status.json
Normal file
9
tests/tools/fixtures/azdo_build_status.json
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
{
|
||||||
|
"id": 1001,
|
||||||
|
"buildNumber": "20240115.1",
|
||||||
|
"status": "completed",
|
||||||
|
"result": "succeeded",
|
||||||
|
"queueTime": "2024-01-15T10:00:00Z",
|
||||||
|
"startTime": "2024-01-15T10:01:00Z",
|
||||||
|
"finishTime": "2024-01-15T10:10:00Z"
|
||||||
|
}
|
||||||
8
tests/tools/fixtures/azdo_create_pr.json
Normal file
8
tests/tools/fixtures/azdo_create_pr.json
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
{
|
||||||
|
"pullRequestId": 99,
|
||||||
|
"title": "Release v1.2.0",
|
||||||
|
"status": "active",
|
||||||
|
"sourceRefName": "refs/heads/release/v1.2.0",
|
||||||
|
"targetRefName": "refs/heads/main",
|
||||||
|
"url": "https://dev.azure.com/my-org/my-project/_apis/git/repositories/my-repo/pullRequests/99"
|
||||||
|
}
|
||||||
8
tests/tools/fixtures/azdo_merge_pr.json
Normal file
8
tests/tools/fixtures/azdo_merge_pr.json
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
{
|
||||||
|
"pullRequestId": 42,
|
||||||
|
"status": "completed",
|
||||||
|
"title": "Fix the auth bug",
|
||||||
|
"completionOptions": {
|
||||||
|
"mergeStrategy": "squash"
|
||||||
|
}
|
||||||
|
}
|
||||||
15
tests/tools/fixtures/azdo_pipelines.json
Normal file
15
tests/tools/fixtures/azdo_pipelines.json
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
{
|
||||||
|
"value": [
|
||||||
|
{
|
||||||
|
"id": 10,
|
||||||
|
"name": "Release Pipeline",
|
||||||
|
"folder": "\\"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 20,
|
||||||
|
"name": "Build Pipeline",
|
||||||
|
"folder": "\\"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"count": 2
|
||||||
|
}
|
||||||
16
tests/tools/fixtures/azdo_pr.json
Normal file
16
tests/tools/fixtures/azdo_pr.json
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
{
|
||||||
|
"pullRequestId": 42,
|
||||||
|
"title": "Fix the auth bug",
|
||||||
|
"status": "active",
|
||||||
|
"sourceRefName": "refs/heads/bug/ALLPOST-999_fix-auth",
|
||||||
|
"targetRefName": "refs/heads/main",
|
||||||
|
"url": "https://dev.azure.com/my-org/my-project/_apis/git/repositories/my-repo/pullRequests/42",
|
||||||
|
"repository": {
|
||||||
|
"id": "repo-uuid-123",
|
||||||
|
"name": "my-repo",
|
||||||
|
"remoteUrl": "https://dev.azure.com/my-org/my-project/_git/my-repo"
|
||||||
|
},
|
||||||
|
"lastMergeSourceCommit": {
|
||||||
|
"commitId": "abc123def456"
|
||||||
|
}
|
||||||
|
}
|
||||||
11
tests/tools/fixtures/azdo_pr_diff.txt
Normal file
11
tests/tools/fixtures/azdo_pr_diff.txt
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
diff --git a/src/auth.py b/src/auth.py
|
||||||
|
index 1234567..abcdefg 100644
|
||||||
|
--- a/src/auth.py
|
||||||
|
+++ b/src/auth.py
|
||||||
|
@@ -10,6 +10,10 @@ class AuthService:
|
||||||
|
def authenticate(self, token: str) -> bool:
|
||||||
|
- return token == "hardcoded"
|
||||||
|
+ return self._validate_token(token)
|
||||||
|
+
|
||||||
|
+ def _validate_token(self, token: str) -> bool:
|
||||||
|
+ return len(token) > 0 and token.startswith("Bearer ")
|
||||||
11
tests/tools/fixtures/azdo_trigger_pipeline.json
Normal file
11
tests/tools/fixtures/azdo_trigger_pipeline.json
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
{
|
||||||
|
"id": 1001,
|
||||||
|
"buildNumber": "20240115.1",
|
||||||
|
"status": "notStarted",
|
||||||
|
"queueTime": "2024-01-15T10:00:00Z",
|
||||||
|
"definition": {
|
||||||
|
"id": 10,
|
||||||
|
"name": "Release Pipeline"
|
||||||
|
},
|
||||||
|
"sourceBranch": "refs/heads/main"
|
||||||
|
}
|
||||||
10
tests/tools/fixtures/jira_issue.json
Normal file
10
tests/tools/fixtures/jira_issue.json
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
{
|
||||||
|
"id": "12345",
|
||||||
|
"key": "ALLPOST-100",
|
||||||
|
"fields": {
|
||||||
|
"summary": "Fix the authentication bug",
|
||||||
|
"status": {
|
||||||
|
"name": "In Progress"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
20
tests/tools/fixtures/jira_transitions.json
Normal file
20
tests/tools/fixtures/jira_transitions.json
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
{
|
||||||
|
"transitions": [
|
||||||
|
{
|
||||||
|
"id": "11",
|
||||||
|
"name": "To Do"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "21",
|
||||||
|
"name": "In Progress"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "31",
|
||||||
|
"name": "Done"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "41",
|
||||||
|
"name": "Released"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
819
tests/tools/test_azdo.py
Normal file
819
tests/tools/test_azdo.py
Normal file
@@ -0,0 +1,819 @@
|
|||||||
|
"""Tests for AzDoClient. Written FIRST (TDD RED phase)."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from release_agent.exceptions import AuthenticationError, NotFoundError, ServiceError
|
||||||
|
from release_agent.models.build import ApprovalRecord, BuildStatus
|
||||||
|
from release_agent.models.pipeline import PipelineInfo
|
||||||
|
from release_agent.models.pr import PRInfo
|
||||||
|
from release_agent.tools.azdo import AzDoClient
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Fixture helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
FIXTURES = Path(__file__).parent / "fixtures"
|
||||||
|
|
||||||
|
|
||||||
|
def _load_json(name: str) -> dict:
|
||||||
|
return json.loads((FIXTURES / name).read_text())
|
||||||
|
|
||||||
|
|
||||||
|
def _load_text(name: str) -> str:
|
||||||
|
return (FIXTURES / name).read_text()
|
||||||
|
|
||||||
|
|
||||||
|
def _make_transport(routes: dict[tuple[str, str], tuple[int, bytes | str]]) -> httpx.MockTransport:
|
||||||
|
"""Build a MockTransport that dispatches based on (method, url_substring)."""
|
||||||
|
|
||||||
|
def handler(request: httpx.Request) -> httpx.Response:
|
||||||
|
url = str(request.url)
|
||||||
|
method = request.method
|
||||||
|
for (m, url_fragment), (status, body) in routes.items():
|
||||||
|
if m == method and url_fragment in url:
|
||||||
|
content = body if isinstance(body, bytes) else body.encode()
|
||||||
|
return httpx.Response(status_code=status, content=content)
|
||||||
|
return httpx.Response(status_code=404, content=b'{"message": "Not found"}')
|
||||||
|
|
||||||
|
return httpx.MockTransport(handler)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_client(routes: dict) -> AzDoClient:
|
||||||
|
"""Create an AzDoClient with mocked HTTP transport."""
|
||||||
|
transport = _make_transport(routes)
|
||||||
|
http_client = httpx.AsyncClient(transport=transport)
|
||||||
|
vsrm_client = httpx.AsyncClient(transport=transport)
|
||||||
|
return AzDoClient(
|
||||||
|
base_url="https://dev.azure.com/my-org/my-project/_apis",
|
||||||
|
vsrm_base_url="https://vsrm.dev.azure.com/my-org/my-project/_apis",
|
||||||
|
pat="test-pat",
|
||||||
|
http_client=http_client,
|
||||||
|
vsrm_http_client=vsrm_client,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# AzDoClient construction tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestAzDoClientConstruction:
|
||||||
|
"""Tests for AzDoClient initialization."""
|
||||||
|
|
||||||
|
def test_can_be_instantiated_with_injected_clients(self) -> None:
|
||||||
|
transport = httpx.MockTransport(lambda r: httpx.Response(200, content=b"{}"))
|
||||||
|
http_client = httpx.AsyncClient(transport=transport)
|
||||||
|
vsrm_client = httpx.AsyncClient(transport=transport)
|
||||||
|
client = AzDoClient(
|
||||||
|
base_url="https://dev.azure.com/org/proj/_apis",
|
||||||
|
vsrm_base_url="https://vsrm.dev.azure.com/org/proj/_apis",
|
||||||
|
pat="my-pat",
|
||||||
|
http_client=http_client,
|
||||||
|
vsrm_http_client=vsrm_client,
|
||||||
|
)
|
||||||
|
assert client is not None
|
||||||
|
|
||||||
|
async def test_context_manager_closes_clients(self) -> None:
|
||||||
|
transport = httpx.MockTransport(lambda r: httpx.Response(200, content=b"{}"))
|
||||||
|
http_client = httpx.AsyncClient(transport=transport)
|
||||||
|
vsrm_client = httpx.AsyncClient(transport=transport)
|
||||||
|
async with AzDoClient(
|
||||||
|
base_url="https://dev.azure.com/org/proj/_apis",
|
||||||
|
vsrm_base_url="https://vsrm.dev.azure.com/org/proj/_apis",
|
||||||
|
pat="my-pat",
|
||||||
|
http_client=http_client,
|
||||||
|
vsrm_http_client=vsrm_client,
|
||||||
|
) as client:
|
||||||
|
assert client is not None
|
||||||
|
# After context manager exits, clients should be closed
|
||||||
|
assert http_client.is_closed
|
||||||
|
assert vsrm_client.is_closed
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# get_pr tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestGetPr:
|
||||||
|
"""Tests for AzDoClient.get_pr."""
|
||||||
|
|
||||||
|
async def test_returns_pr_info(self) -> None:
|
||||||
|
pr_data = _load_json("azdo_pr.json")
|
||||||
|
routes = {("GET", "pullRequests/42"): (200, json.dumps(pr_data))}
|
||||||
|
client = _make_client(routes)
|
||||||
|
|
||||||
|
result = await client.get_pr(42)
|
||||||
|
|
||||||
|
assert isinstance(result, PRInfo)
|
||||||
|
assert result.pr_id == "42"
|
||||||
|
|
||||||
|
async def test_pr_title_extracted(self) -> None:
|
||||||
|
pr_data = _load_json("azdo_pr.json")
|
||||||
|
routes = {("GET", "pullRequests/42"): (200, json.dumps(pr_data))}
|
||||||
|
client = _make_client(routes)
|
||||||
|
|
||||||
|
result = await client.get_pr(42)
|
||||||
|
assert result.pr_title == "Fix the auth bug"
|
||||||
|
|
||||||
|
async def test_pr_branch_extracted(self) -> None:
|
||||||
|
pr_data = _load_json("azdo_pr.json")
|
||||||
|
routes = {("GET", "pullRequests/42"): (200, json.dumps(pr_data))}
|
||||||
|
client = _make_client(routes)
|
||||||
|
|
||||||
|
result = await client.get_pr(42)
|
||||||
|
assert "ALLPOST-999" in result.branch or "bug" in result.branch
|
||||||
|
|
||||||
|
async def test_pr_status_extracted(self) -> None:
|
||||||
|
pr_data = _load_json("azdo_pr.json")
|
||||||
|
routes = {("GET", "pullRequests/42"): (200, json.dumps(pr_data))}
|
||||||
|
client = _make_client(routes)
|
||||||
|
|
||||||
|
result = await client.get_pr(42)
|
||||||
|
assert result.pr_status == "active"
|
||||||
|
|
||||||
|
async def test_404_raises_not_found(self) -> None:
|
||||||
|
routes = {("GET", "pullRequests/999"): (404, b'{"message": "PR not found"}')}
|
||||||
|
client = _make_client(routes)
|
||||||
|
|
||||||
|
with pytest.raises(NotFoundError):
|
||||||
|
await client.get_pr(999)
|
||||||
|
|
||||||
|
async def test_401_raises_authentication_error(self) -> None:
|
||||||
|
routes = {("GET", "pullRequests/42"): (401, b'{"message": "Unauthorized"}')}
|
||||||
|
client = _make_client(routes)
|
||||||
|
|
||||||
|
with pytest.raises(AuthenticationError):
|
||||||
|
await client.get_pr(42)
|
||||||
|
|
||||||
|
async def test_500_raises_service_error(self) -> None:
|
||||||
|
routes = {("GET", "pullRequests/42"): (500, b'{"message": "Internal error"}')}
|
||||||
|
client = _make_client(routes)
|
||||||
|
|
||||||
|
with pytest.raises(ServiceError):
|
||||||
|
await client.get_pr(42)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# get_pr_diff tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestGetPrDiff:
|
||||||
|
"""Tests for AzDoClient.get_pr_diff."""
|
||||||
|
|
||||||
|
async def test_returns_diff_string(self) -> None:
|
||||||
|
pr_data = _load_json("azdo_pr.json")
|
||||||
|
routes = {
|
||||||
|
("GET", "pullRequests/42"): (200, json.dumps(pr_data)),
|
||||||
|
("GET", "diffs"): (200, json.dumps({
|
||||||
|
"changes": [
|
||||||
|
{
|
||||||
|
"item": {"path": "/src/auth.py"},
|
||||||
|
"changeType": "edit",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
})),
|
||||||
|
}
|
||||||
|
client = _make_client(routes)
|
||||||
|
|
||||||
|
result = await client.get_pr_diff(42)
|
||||||
|
assert isinstance(result, str)
|
||||||
|
|
||||||
|
async def test_diff_includes_file_paths(self) -> None:
|
||||||
|
pr_data = _load_json("azdo_pr.json")
|
||||||
|
diff_data = {
|
||||||
|
"changes": [
|
||||||
|
{"item": {"path": "/src/auth.py"}, "changeType": "edit"},
|
||||||
|
{"item": {"path": "/src/util.py"}, "changeType": "add"},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
def handler(request: httpx.Request) -> httpx.Response:
|
||||||
|
url = str(request.url)
|
||||||
|
if "diffs" in url:
|
||||||
|
return httpx.Response(200, content=json.dumps(diff_data).encode())
|
||||||
|
if "pullRequests/42" in url:
|
||||||
|
return httpx.Response(200, content=json.dumps(pr_data).encode())
|
||||||
|
return httpx.Response(404, content=b"{}")
|
||||||
|
|
||||||
|
transport = httpx.MockTransport(handler)
|
||||||
|
http_client = httpx.AsyncClient(transport=transport)
|
||||||
|
vsrm_client = httpx.AsyncClient(transport=transport)
|
||||||
|
client = AzDoClient(
|
||||||
|
base_url="https://dev.azure.com/my-org/my-project/_apis",
|
||||||
|
vsrm_base_url="https://vsrm.dev.azure.com/my-org/my-project/_apis",
|
||||||
|
pat="test-pat",
|
||||||
|
http_client=http_client,
|
||||||
|
vsrm_http_client=vsrm_client,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await client.get_pr_diff(42)
|
||||||
|
assert "/src/auth.py" in result
|
||||||
|
assert "/src/util.py" in result
|
||||||
|
|
||||||
|
async def test_empty_changes_returns_empty_string(self) -> None:
|
||||||
|
pr_data = _load_json("azdo_pr.json")
|
||||||
|
diff_data: dict = {"changes": []}
|
||||||
|
|
||||||
|
def handler(request: httpx.Request) -> httpx.Response:
|
||||||
|
url = str(request.url)
|
||||||
|
if "diffs" in url:
|
||||||
|
return httpx.Response(200, content=json.dumps(diff_data).encode())
|
||||||
|
if "pullRequests/42" in url:
|
||||||
|
return httpx.Response(200, content=json.dumps(pr_data).encode())
|
||||||
|
return httpx.Response(404, content=b"{}")
|
||||||
|
|
||||||
|
transport = httpx.MockTransport(handler)
|
||||||
|
http_client = httpx.AsyncClient(transport=transport)
|
||||||
|
vsrm_client = httpx.AsyncClient(transport=transport)
|
||||||
|
client = AzDoClient(
|
||||||
|
base_url="https://dev.azure.com/my-org/my-project/_apis",
|
||||||
|
vsrm_base_url="https://vsrm.dev.azure.com/my-org/my-project/_apis",
|
||||||
|
pat="test-pat",
|
||||||
|
http_client=http_client,
|
||||||
|
vsrm_http_client=vsrm_client,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await client.get_pr_diff(42)
|
||||||
|
assert result == ""
|
||||||
|
|
||||||
|
async def test_404_raises_not_found(self) -> None:
|
||||||
|
routes = {("GET", "pullRequests/999"): (404, b'{"message": "PR not found"}')}
|
||||||
|
client = _make_client(routes)
|
||||||
|
|
||||||
|
with pytest.raises(NotFoundError):
|
||||||
|
await client.get_pr_diff(999)
|
||||||
|
|
||||||
|
async def test_pr_without_url_field_uses_remote_url(self) -> None:
|
||||||
|
"""When the API response omits the 'url' field, fallback URL is built."""
|
||||||
|
pr_data = {
|
||||||
|
"pullRequestId": 42,
|
||||||
|
"title": "Fix bug",
|
||||||
|
"status": "active",
|
||||||
|
"sourceRefName": "refs/heads/fix/ALLPOST-1_fix",
|
||||||
|
"repository": {
|
||||||
|
"id": "repo-uuid",
|
||||||
|
"name": "my-repo",
|
||||||
|
"remoteUrl": "https://dev.azure.com/org/proj/_git/my-repo",
|
||||||
|
},
|
||||||
|
# NOTE: 'url' field is intentionally omitted
|
||||||
|
}
|
||||||
|
routes = {
|
||||||
|
("GET", "pullRequests/42"): (200, json.dumps(pr_data)),
|
||||||
|
("GET", "diffs"): (200, json.dumps({"changes": []})),
|
||||||
|
}
|
||||||
|
client = _make_client(routes)
|
||||||
|
|
||||||
|
result = await client.get_pr(42)
|
||||||
|
assert "42" in str(result.pr_url)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# merge_pr tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestMergePr:
|
||||||
|
"""Tests for AzDoClient.merge_pr."""
|
||||||
|
|
||||||
|
async def test_returns_true_on_success(self) -> None:
|
||||||
|
merge_data = _load_json("azdo_merge_pr.json")
|
||||||
|
routes = {("PATCH", "pullRequests/42"): (200, json.dumps(merge_data))}
|
||||||
|
client = _make_client(routes)
|
||||||
|
|
||||||
|
result = await client.merge_pr(pr_id=42, last_merge_source_commit="abc123def456")
|
||||||
|
assert result is True
|
||||||
|
|
||||||
|
async def test_404_raises_not_found(self) -> None:
|
||||||
|
routes = {("PATCH", "pullRequests/999"): (404, b'{"message": "Not found"}')}
|
||||||
|
client = _make_client(routes)
|
||||||
|
|
||||||
|
with pytest.raises(NotFoundError):
|
||||||
|
await client.merge_pr(pr_id=999, last_merge_source_commit="abc123")
|
||||||
|
|
||||||
|
async def test_409_raises_service_error(self) -> None:
|
||||||
|
routes = {("PATCH", "pullRequests/42"): (409, b'{"message": "Conflict"}')}
|
||||||
|
client = _make_client(routes)
|
||||||
|
|
||||||
|
with pytest.raises(ServiceError):
|
||||||
|
await client.merge_pr(pr_id=42, last_merge_source_commit="abc123")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# create_pr tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestCreatePr:
|
||||||
|
"""Tests for AzDoClient.create_pr."""
|
||||||
|
|
||||||
|
async def test_returns_dict_with_pr_id(self) -> None:
|
||||||
|
create_data = _load_json("azdo_create_pr.json")
|
||||||
|
routes = {("POST", "pullRequests"): (201, json.dumps(create_data))}
|
||||||
|
client = _make_client(routes)
|
||||||
|
|
||||||
|
result = await client.create_pr(
|
||||||
|
repo="my-repo",
|
||||||
|
source="refs/heads/release/v1.2.0",
|
||||||
|
target="refs/heads/main",
|
||||||
|
title="Release v1.2.0",
|
||||||
|
description="Release notes",
|
||||||
|
)
|
||||||
|
assert isinstance(result, dict)
|
||||||
|
assert result["pullRequestId"] == 99
|
||||||
|
|
||||||
|
async def test_400_raises_service_error(self) -> None:
|
||||||
|
routes = {("POST", "pullRequests"): (400, b'{"message": "Bad request"}')}
|
||||||
|
client = _make_client(routes)
|
||||||
|
|
||||||
|
with pytest.raises(ServiceError):
|
||||||
|
await client.create_pr(
|
||||||
|
repo="my-repo",
|
||||||
|
source="refs/heads/release/v1.2.0",
|
||||||
|
target="refs/heads/main",
|
||||||
|
title="Release",
|
||||||
|
description="",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# list_build_pipelines tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestListBuildPipelines:
|
||||||
|
"""Tests for AzDoClient.list_build_pipelines."""
|
||||||
|
|
||||||
|
async def test_returns_list_of_pipeline_info(self) -> None:
|
||||||
|
pipeline_data = _load_json("azdo_pipelines.json")
|
||||||
|
routes = {("GET", "pipelines"): (200, json.dumps(pipeline_data))}
|
||||||
|
client = _make_client(routes)
|
||||||
|
|
||||||
|
result = await client.list_build_pipelines(repo="my-repo")
|
||||||
|
assert isinstance(result, list)
|
||||||
|
assert len(result) == 2
|
||||||
|
assert all(isinstance(p, PipelineInfo) for p in result)
|
||||||
|
|
||||||
|
async def test_pipeline_ids_extracted(self) -> None:
|
||||||
|
pipeline_data = _load_json("azdo_pipelines.json")
|
||||||
|
routes = {("GET", "pipelines"): (200, json.dumps(pipeline_data))}
|
||||||
|
client = _make_client(routes)
|
||||||
|
|
||||||
|
result = await client.list_build_pipelines(repo="my-repo")
|
||||||
|
ids = [p.id for p in result]
|
||||||
|
assert 10 in ids
|
||||||
|
assert 20 in ids
|
||||||
|
|
||||||
|
async def test_empty_list_on_no_pipelines(self) -> None:
|
||||||
|
routes = {("GET", "pipelines"): (200, json.dumps({"value": [], "count": 0}))}
|
||||||
|
client = _make_client(routes)
|
||||||
|
|
||||||
|
result = await client.list_build_pipelines(repo="my-repo")
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# trigger_pipeline tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestTriggerPipeline:
|
||||||
|
"""Tests for AzDoClient.trigger_pipeline."""
|
||||||
|
|
||||||
|
async def test_returns_dict_with_build_id(self) -> None:
|
||||||
|
trigger_data = _load_json("azdo_trigger_pipeline.json")
|
||||||
|
routes = {("POST", "pipelines/10/runs"): (200, json.dumps(trigger_data))}
|
||||||
|
client = _make_client(routes)
|
||||||
|
|
||||||
|
result = await client.trigger_pipeline(pipeline_id=10, branch="refs/heads/main")
|
||||||
|
assert isinstance(result, dict)
|
||||||
|
assert result["id"] == 1001
|
||||||
|
|
||||||
|
async def test_404_raises_not_found(self) -> None:
|
||||||
|
routes = {("POST", "pipelines/999/runs"): (404, b'{"message": "Pipeline not found"}')}
|
||||||
|
client = _make_client(routes)
|
||||||
|
|
||||||
|
with pytest.raises(NotFoundError):
|
||||||
|
await client.trigger_pipeline(pipeline_id=999, branch="refs/heads/main")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# get_build_status tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestGetBuildStatus:
|
||||||
|
"""Tests for AzDoClient.get_build_status."""
|
||||||
|
|
||||||
|
async def test_returns_build_status_object(self) -> None:
|
||||||
|
build_data = _load_json("azdo_build_status.json")
|
||||||
|
routes = {("GET", "build/builds/1001"): (200, json.dumps(build_data))}
|
||||||
|
client = _make_client(routes)
|
||||||
|
|
||||||
|
result = await client.get_build_status(build_id=1001)
|
||||||
|
assert isinstance(result, BuildStatus)
|
||||||
|
|
||||||
|
async def test_status_field_populated(self) -> None:
|
||||||
|
build_data = _load_json("azdo_build_status.json")
|
||||||
|
routes = {("GET", "build/builds/1001"): (200, json.dumps(build_data))}
|
||||||
|
client = _make_client(routes)
|
||||||
|
|
||||||
|
result = await client.get_build_status(build_id=1001)
|
||||||
|
assert result.status == "completed"
|
||||||
|
|
||||||
|
async def test_result_field_populated(self) -> None:
|
||||||
|
build_data = _load_json("azdo_build_status.json")
|
||||||
|
routes = {("GET", "build/builds/1001"): (200, json.dumps(build_data))}
|
||||||
|
client = _make_client(routes)
|
||||||
|
|
||||||
|
result = await client.get_build_status(build_id=1001)
|
||||||
|
assert result.result == "succeeded"
|
||||||
|
|
||||||
|
async def test_build_url_present(self) -> None:
|
||||||
|
build_data = _load_json("azdo_build_status.json")
|
||||||
|
routes = {("GET", "build/builds/1001"): (200, json.dumps(build_data))}
|
||||||
|
client = _make_client(routes)
|
||||||
|
|
||||||
|
result = await client.get_build_status(build_id=1001)
|
||||||
|
# build_url may be None if not in fixture, but field must exist
|
||||||
|
assert hasattr(result, "build_url")
|
||||||
|
|
||||||
|
async def test_result_none_when_not_completed(self) -> None:
|
||||||
|
build_data = {"id": 99, "status": "inProgress", "buildNumber": "20240101.1"}
|
||||||
|
routes = {("GET", "build/builds/99"): (200, json.dumps(build_data))}
|
||||||
|
client = _make_client(routes)
|
||||||
|
|
||||||
|
result = await client.get_build_status(build_id=99)
|
||||||
|
assert result.status == "inProgress"
|
||||||
|
assert result.result is None
|
||||||
|
|
||||||
|
async def test_404_raises_not_found(self) -> None:
|
||||||
|
routes = {("GET", "build/builds/9999"): (404, b'{"message": "Build not found"}')}
|
||||||
|
client = _make_client(routes)
|
||||||
|
|
||||||
|
with pytest.raises(NotFoundError):
|
||||||
|
await client.get_build_status(build_id=9999)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# get_release_approvals tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestGetReleaseApprovals:
|
||||||
|
"""Tests for AzDoClient.get_release_approvals."""
|
||||||
|
|
||||||
|
async def test_returns_list_of_approval_records(self) -> None:
|
||||||
|
approvals_data = {
|
||||||
|
"value": [
|
||||||
|
{
|
||||||
|
"id": 101,
|
||||||
|
"status": "pending",
|
||||||
|
"releaseEnvironment": {"name": "Sandbox", "release": {"id": 55}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 102,
|
||||||
|
"status": "approved",
|
||||||
|
"releaseEnvironment": {"name": "Production", "release": {"id": 55}},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"count": 2,
|
||||||
|
}
|
||||||
|
routes = {("GET", "release/approvals"): (200, json.dumps(approvals_data))}
|
||||||
|
client = _make_client(routes)
|
||||||
|
|
||||||
|
result = await client.get_release_approvals(release_id=55)
|
||||||
|
assert isinstance(result, list)
|
||||||
|
assert len(result) == 2
|
||||||
|
assert all(isinstance(a, ApprovalRecord) for a in result)
|
||||||
|
|
||||||
|
async def test_approval_id_populated(self) -> None:
|
||||||
|
approvals_data = {
|
||||||
|
"value": [
|
||||||
|
{
|
||||||
|
"id": 201,
|
||||||
|
"status": "pending",
|
||||||
|
"releaseEnvironment": {"name": "Sandbox", "release": {"id": 10}},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"count": 1,
|
||||||
|
}
|
||||||
|
routes = {("GET", "release/approvals"): (200, json.dumps(approvals_data))}
|
||||||
|
client = _make_client(routes)
|
||||||
|
|
||||||
|
result = await client.get_release_approvals(release_id=10)
|
||||||
|
assert result[0].approval_id == "201"
|
||||||
|
|
||||||
|
async def test_stage_name_populated(self) -> None:
|
||||||
|
approvals_data = {
|
||||||
|
"value": [
|
||||||
|
{
|
||||||
|
"id": 300,
|
||||||
|
"status": "pending",
|
||||||
|
"releaseEnvironment": {"name": "Production", "release": {"id": 20}},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"count": 1,
|
||||||
|
}
|
||||||
|
routes = {("GET", "release/approvals"): (200, json.dumps(approvals_data))}
|
||||||
|
client = _make_client(routes)
|
||||||
|
|
||||||
|
result = await client.get_release_approvals(release_id=20)
|
||||||
|
assert result[0].stage_name == "Production"
|
||||||
|
|
||||||
|
async def test_release_id_populated(self) -> None:
|
||||||
|
approvals_data = {
|
||||||
|
"value": [
|
||||||
|
{
|
||||||
|
"id": 400,
|
||||||
|
"status": "pending",
|
||||||
|
"releaseEnvironment": {"name": "Stage", "release": {"id": 99}},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"count": 1,
|
||||||
|
}
|
||||||
|
routes = {("GET", "release/approvals"): (200, json.dumps(approvals_data))}
|
||||||
|
client = _make_client(routes)
|
||||||
|
|
||||||
|
result = await client.get_release_approvals(release_id=99)
|
||||||
|
assert result[0].release_id == 99
|
||||||
|
|
||||||
|
async def test_empty_list_when_no_approvals(self) -> None:
|
||||||
|
approvals_data = {"value": [], "count": 0}
|
||||||
|
routes = {("GET", "release/approvals"): (200, json.dumps(approvals_data))}
|
||||||
|
client = _make_client(routes)
|
||||||
|
|
||||||
|
result = await client.get_release_approvals(release_id=77)
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
async def test_filters_by_release_id_in_query(self) -> None:
|
||||||
|
captured_urls: list[str] = []
|
||||||
|
|
||||||
|
def handler(request: httpx.Request) -> httpx.Response:
|
||||||
|
captured_urls.append(str(request.url))
|
||||||
|
return httpx.Response(200, content=b'{"value": [], "count": 0}')
|
||||||
|
|
||||||
|
transport = httpx.MockTransport(handler)
|
||||||
|
http_client = httpx.AsyncClient(transport=transport)
|
||||||
|
vsrm_client = httpx.AsyncClient(transport=transport)
|
||||||
|
client = AzDoClient(
|
||||||
|
base_url="https://dev.azure.com/my-org/my-project/_apis",
|
||||||
|
vsrm_base_url="https://vsrm.dev.azure.com/my-org/my-project/_apis",
|
||||||
|
pat="test-pat",
|
||||||
|
http_client=http_client,
|
||||||
|
vsrm_http_client=vsrm_client,
|
||||||
|
)
|
||||||
|
|
||||||
|
await client.get_release_approvals(release_id=42)
|
||||||
|
assert any("approvals" in url for url in captured_urls)
|
||||||
|
|
||||||
|
async def test_404_raises_not_found(self) -> None:
|
||||||
|
routes = {("GET", "release/approvals"): (404, b'{"message": "Not found"}')}
|
||||||
|
client = _make_client(routes)
|
||||||
|
|
||||||
|
with pytest.raises(NotFoundError):
|
||||||
|
await client.get_release_approvals(release_id=999)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# get_latest_release tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestGetLatestRelease:
|
||||||
|
"""Tests for AzDoClient.get_latest_release."""
|
||||||
|
|
||||||
|
async def test_returns_dict(self) -> None:
|
||||||
|
release_data = {
|
||||||
|
"value": [{"id": 55, "name": "Release-55", "status": "active"}],
|
||||||
|
"count": 1,
|
||||||
|
}
|
||||||
|
routes = {("GET", "release/releases"): (200, json.dumps(release_data))}
|
||||||
|
client = _make_client(routes)
|
||||||
|
|
||||||
|
result = await client.get_latest_release(definition_id=7)
|
||||||
|
assert isinstance(result, dict)
|
||||||
|
assert result["id"] == 55
|
||||||
|
|
||||||
|
async def test_returns_empty_dict_when_no_releases(self) -> None:
|
||||||
|
release_data = {"value": [], "count": 0}
|
||||||
|
routes = {("GET", "release/releases"): (200, json.dumps(release_data))}
|
||||||
|
client = _make_client(routes)
|
||||||
|
|
||||||
|
result = await client.get_latest_release(definition_id=99)
|
||||||
|
assert result == {}
|
||||||
|
|
||||||
|
async def test_404_raises_not_found(self) -> None:
|
||||||
|
routes = {("GET", "release/releases"): (404, b'{"message": "Not found"}')}
|
||||||
|
client = _make_client(routes)
|
||||||
|
|
||||||
|
with pytest.raises(NotFoundError):
|
||||||
|
await client.get_latest_release(definition_id=999)
|
||||||
|
|
||||||
|
async def test_passes_definition_id_as_filter(self) -> None:
|
||||||
|
captured_urls: list[str] = []
|
||||||
|
|
||||||
|
def handler(request: httpx.Request) -> httpx.Response:
|
||||||
|
captured_urls.append(str(request.url))
|
||||||
|
return httpx.Response(200, content=b'{"value": [], "count": 0}')
|
||||||
|
|
||||||
|
transport = httpx.MockTransport(handler)
|
||||||
|
http_client = httpx.AsyncClient(transport=transport)
|
||||||
|
vsrm_client = httpx.AsyncClient(transport=transport)
|
||||||
|
client = AzDoClient(
|
||||||
|
base_url="https://dev.azure.com/my-org/my-project/_apis",
|
||||||
|
vsrm_base_url="https://vsrm.dev.azure.com/my-org/my-project/_apis",
|
||||||
|
pat="test-pat",
|
||||||
|
http_client=http_client,
|
||||||
|
vsrm_http_client=vsrm_client,
|
||||||
|
)
|
||||||
|
|
||||||
|
await client.get_latest_release(definition_id=13)
|
||||||
|
assert any("releases" in url for url in captured_urls)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# approve_release tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestApproveRelease:
|
||||||
|
"""Tests for AzDoClient.approve_release."""
|
||||||
|
|
||||||
|
async def test_returns_dict_with_status(self) -> None:
|
||||||
|
approve_data = _load_json("azdo_approve_release.json")
|
||||||
|
routes = {("PATCH", "release/approvals"): (200, json.dumps(approve_data))}
|
||||||
|
client = _make_client(routes)
|
||||||
|
|
||||||
|
result = await client.approve_release(
|
||||||
|
approval_id="approval-uuid-123", comment="Approved"
|
||||||
|
)
|
||||||
|
assert isinstance(result, dict)
|
||||||
|
assert result["status"] == "approved"
|
||||||
|
|
||||||
|
async def test_404_raises_not_found(self) -> None:
|
||||||
|
routes = {("PATCH", "release/approvals"): (404, b'{"message": "Approval not found"}')}
|
||||||
|
client = _make_client(routes)
|
||||||
|
|
||||||
|
with pytest.raises(NotFoundError):
|
||||||
|
await client.approve_release(approval_id="bad-id", comment="Approve")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# close() lifecycle tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestAzDoClientLifecycle:
|
||||||
|
"""Tests for AzDoClient close() method."""
|
||||||
|
|
||||||
|
async def test_close_closes_both_clients(self) -> None:
|
||||||
|
transport = httpx.MockTransport(lambda r: httpx.Response(200, content=b"{}"))
|
||||||
|
http_client = httpx.AsyncClient(transport=transport)
|
||||||
|
vsrm_client = httpx.AsyncClient(transport=transport)
|
||||||
|
client = AzDoClient(
|
||||||
|
base_url="https://dev.azure.com/org/proj/_apis",
|
||||||
|
vsrm_base_url="https://vsrm.dev.azure.com/org/proj/_apis",
|
||||||
|
pat="my-pat",
|
||||||
|
http_client=http_client,
|
||||||
|
vsrm_http_client=vsrm_client,
|
||||||
|
)
|
||||||
|
|
||||||
|
await client.close()
|
||||||
|
|
||||||
|
assert http_client.is_closed
|
||||||
|
assert vsrm_client.is_closed
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# list_active_prs tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _make_pr_list_response(prs: list[dict]) -> str:
|
||||||
|
return json.dumps({"value": prs, "count": len(prs)})
|
||||||
|
|
||||||
|
|
||||||
|
def _make_active_pr_item(
|
||||||
|
pr_id: int = 10,
|
||||||
|
title: str = "Test PR",
|
||||||
|
branch: str = "refs/heads/feature/ALLPOST-100_fix",
|
||||||
|
status: str = "active",
|
||||||
|
repo_name: str = "my-repo",
|
||||||
|
) -> dict:
|
||||||
|
return {
|
||||||
|
"pullRequestId": pr_id,
|
||||||
|
"title": title,
|
||||||
|
"status": status,
|
||||||
|
"sourceRefName": branch,
|
||||||
|
"targetRefName": "refs/heads/develop",
|
||||||
|
"url": f"https://dev.azure.com/org/proj/_apis/git/repositories/{repo_name}/pullRequests/{pr_id}",
|
||||||
|
"repository": {
|
||||||
|
"id": "repo-uuid",
|
||||||
|
"name": repo_name,
|
||||||
|
"remoteUrl": f"https://dev.azure.com/org/proj/_git/{repo_name}",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class TestListActivePrs:
    """Tests for AzDoClient.list_active_prs."""

    async def test_returns_list_of_pr_info(self) -> None:
        """A single active PR is mapped into a one-element list of PRInfo."""
        pr_item = _make_active_pr_item()
        routes = {
            ("GET", "git/repositories/my-repo/pullRequests"): (
                200,
                _make_pr_list_response([pr_item]),
            )
        }
        client = _make_client(routes)

        result = await client.list_active_prs("my-repo", "refs/heads/develop")

        assert isinstance(result, list)
        assert len(result) == 1
        assert isinstance(result[0], PRInfo)

    async def test_pr_id_extracted(self) -> None:
        """The numeric pullRequestId is exposed as a string pr_id."""
        pr_item = _make_active_pr_item(pr_id=55)
        routes = {
            ("GET", "git/repositories/my-repo/pullRequests"): (
                200,
                _make_pr_list_response([pr_item]),
            )
        }
        client = _make_client(routes)

        result = await client.list_active_prs("my-repo", "refs/heads/develop")
        # Note: the client stringifies the integer ID from the API payload.
        assert result[0].pr_id == "55"

    async def test_pr_title_extracted(self) -> None:
        """The PR title from the payload is carried through to PRInfo."""
        pr_item = _make_active_pr_item(title="My Feature")
        routes = {
            ("GET", "git/repositories/my-repo/pullRequests"): (
                200,
                _make_pr_list_response([pr_item]),
            )
        }
        client = _make_client(routes)

        result = await client.list_active_prs("my-repo", "refs/heads/develop")
        assert result[0].pr_title == "My Feature"

    async def test_empty_list_when_no_prs(self) -> None:
        """An empty API response yields an empty list, not None or an error."""
        routes = {
            ("GET", "git/repositories/my-repo/pullRequests"): (
                200,
                _make_pr_list_response([]),
            )
        }
        client = _make_client(routes)

        result = await client.list_active_prs("my-repo", "refs/heads/develop")
        assert result == []

    async def test_multiple_prs_returned(self) -> None:
        """All PRs in the payload are returned, with their IDs preserved."""
        prs = [
            _make_active_pr_item(pr_id=10, title="PR 10"),
            _make_active_pr_item(pr_id=20, title="PR 20"),
        ]
        routes = {
            ("GET", "git/repositories/my-repo/pullRequests"): (
                200,
                _make_pr_list_response(prs),
            )
        }
        client = _make_client(routes)

        result = await client.list_active_prs("my-repo", "refs/heads/develop")
        assert len(result) == 2
        assert {r.pr_id for r in result} == {"10", "20"}

    async def test_404_raises_not_found(self) -> None:
        """HTTP 404 from the PR-list endpoint maps to NotFoundError."""
        routes = {
            ("GET", "git/repositories/missing-repo/pullRequests"): (
                404,
                b'{"message": "Repo not found"}',
            )
        }
        client = _make_client(routes)

        with pytest.raises(NotFoundError):
            await client.list_active_prs("missing-repo", "refs/heads/develop")

    async def test_401_raises_authentication_error(self) -> None:
        """HTTP 401 maps to AuthenticationError."""
        routes = {
            ("GET", "git/repositories/my-repo/pullRequests"): (
                401,
                b'{"message": "Unauthorized"}',
            )
        }
        client = _make_client(routes)

        with pytest.raises(AuthenticationError):
            await client.list_active_prs("my-repo", "refs/heads/develop")

    async def test_500_raises_service_error(self) -> None:
        """HTTP 500 maps to ServiceError."""
        routes = {
            ("GET", "git/repositories/my-repo/pullRequests"): (
                500,
                b'{"message": "Internal error"}',
            )
        }
        client = _make_client(routes)

        with pytest.raises(ServiceError):
            await client.list_active_prs("my-repo", "refs/heads/develop")
|
||||||
454
tests/tools/test_claude_review.py
Normal file
454
tests/tools/test_claude_review.py
Normal file
@@ -0,0 +1,454 @@
|
|||||||
|
"""Tests for ClaudeReviewer using Claude Code CLI subprocess."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from release_agent.models.review import ReviewResult
|
||||||
|
from release_agent.tools.claude_review import (
|
||||||
|
ClaudeReviewer,
|
||||||
|
_build_prompt,
|
||||||
|
_parse_cli_output,
|
||||||
|
_truncate_diff,
|
||||||
|
)
|
||||||
|
|
||||||
|
MAX_DIFF_CHARS = 100_000
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers — fake subprocess runner
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _make_cli_output(
|
||||||
|
verdict: str = "approve",
|
||||||
|
summary: str = "LGTM",
|
||||||
|
issues: list | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""Build a JSON string mimicking Claude Code CLI --output-format json."""
|
||||||
|
structured = {
|
||||||
|
"verdict": verdict,
|
||||||
|
"summary": summary,
|
||||||
|
"issues": issues or [],
|
||||||
|
}
|
||||||
|
return json.dumps({"result": "", "structured_output": structured})
|
||||||
|
|
||||||
|
|
||||||
|
def _make_subprocess_runner(
|
||||||
|
stdout: str = "",
|
||||||
|
stderr: str = "",
|
||||||
|
returncode: int = 0,
|
||||||
|
):
|
||||||
|
"""Return a fake run_subprocess callable that records calls."""
|
||||||
|
calls: list[dict] = []
|
||||||
|
|
||||||
|
async def fake_run(*, cmd, cwd, timeout):
|
||||||
|
calls.append({"cmd": cmd, "cwd": cwd, "timeout": timeout})
|
||||||
|
return (stdout, stderr, returncode)
|
||||||
|
|
||||||
|
return fake_run, calls
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _truncate_diff tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestTruncateDiff:
    """Behavior of _truncate_diff around the MAX_DIFF_CHARS boundary."""

    def test_short_diff_not_truncated(self) -> None:
        small = "short diff"
        assert _truncate_diff(small) == small

    def test_exact_limit_not_truncated(self) -> None:
        at_limit = "x" * MAX_DIFF_CHARS
        assert _truncate_diff(at_limit) == at_limit

    def test_over_limit_truncated(self) -> None:
        oversized = "x" * (MAX_DIFF_CHARS + 1000)
        truncated = _truncate_diff(oversized)
        assert len(truncated) < len(oversized)
        assert "TRUNCATED" in truncated
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _build_prompt tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestBuildPrompt:
    """The review prompt must embed the PR title, repo name, and diff text."""

    def test_contains_pr_title(self) -> None:
        rendered = _build_prompt(diff="d", pr_title="My Title", repo_name="repo")
        assert "My Title" in rendered

    def test_contains_repo_name(self) -> None:
        rendered = _build_prompt(diff="d", pr_title="t", repo_name="my-repo")
        assert "my-repo" in rendered

    def test_contains_diff(self) -> None:
        rendered = _build_prompt(diff="UNIQUE_DIFF", pr_title="t", repo_name="r")
        assert "UNIQUE_DIFF" in rendered
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _parse_cli_output tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestParseCliOutput:
    """Tests for _parse_cli_output: structured-output parsing and error paths."""

    def test_parses_structured_output(self) -> None:
        """The structured_output field is parsed into a ReviewResult."""
        stdout = _make_cli_output(verdict="approve", summary="Good")
        result = _parse_cli_output(stdout)
        assert isinstance(result, ReviewResult)
        assert result.verdict == "approve"
        assert result.summary == "Good"

    def test_parses_request_changes(self) -> None:
        """A request_changes verdict with a blocker issue sets has_blockers."""
        stdout = _make_cli_output(
            verdict="request_changes",
            summary="Has issues",
            issues=[{"severity": "blocker", "description": "SQL injection"}],
        )
        result = _parse_cli_output(stdout)
        assert result.verdict == "request_changes"
        assert len(result.issues) == 1
        assert result.has_blockers is True

    def test_parses_issues_with_optional_fields(self) -> None:
        """Optional issue fields (file_path, suggestion) are carried through."""
        stdout = _make_cli_output(
            verdict="request_changes",
            summary="Issues found",
            issues=[{
                "severity": "warning",
                "description": "Style issue",
                "file_path": "src/foo.py",
                "suggestion": "Fix it",
            }],
        )
        result = _parse_cli_output(stdout)
        assert result.issues[0].file_path == "src/foo.py"
        assert result.issues[0].suggestion == "Fix it"

    def test_empty_issues_no_blockers(self) -> None:
        """An empty issues list means no blockers."""
        stdout = _make_cli_output(verdict="approve", summary="Clean", issues=[])
        result = _parse_cli_output(stdout)
        assert result.has_blockers is False
        assert len(result.issues) == 0

    def test_result_field_as_json_string(self) -> None:
        """When structured_output is absent, falls back to parsing result as JSON."""
        inner = {"verdict": "approve", "summary": "OK", "issues": []}
        # The CLI sometimes nests the JSON payload as a string in "result".
        stdout = json.dumps({"result": json.dumps(inner)})
        result = _parse_cli_output(stdout)
        assert result.verdict == "approve"

    def test_invalid_json_raises(self) -> None:
        """Non-JSON stdout raises ValueError with a 'Failed to parse' message."""
        with pytest.raises(ValueError, match="Failed to parse"):
            _parse_cli_output("not json at all")

    def test_missing_structured_output_and_result_raises(self) -> None:
        """Valid JSON without structured_output or result fields is rejected."""
        with pytest.raises(ValueError, match="No structured_output"):
            _parse_cli_output(json.dumps({"other": "data"}))

    def test_non_dict_structured_output_raises(self) -> None:
        """A structured_output that is not a JSON object is rejected."""
        stdout = json.dumps({"structured_output": ["not", "a", "dict"]})
        with pytest.raises(ValueError, match="Expected dict"):
            _parse_cli_output(stdout)

    def test_result_is_non_json_string_raises(self) -> None:
        """A result fallback that is not itself valid JSON is rejected."""
        stdout = json.dumps({"result": "just plain text, not json"})
        with pytest.raises(ValueError, match="not valid JSON"):
            _parse_cli_output(stdout)
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# ClaudeReviewer construction tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestClaudeReviewerConstruction:
    """Construction-time configuration of ClaudeReviewer."""

    def test_can_be_instantiated(self) -> None:
        default_reviewer = ClaudeReviewer()
        assert default_reviewer is not None

    def test_custom_claude_cmd(self) -> None:
        with_cmd = ClaudeReviewer(claude_cmd="/usr/local/bin/claude")
        assert with_cmd._claude_cmd == "/usr/local/bin/claude"

    def test_custom_timeout(self) -> None:
        with_timeout = ClaudeReviewer(timeout=60)
        assert with_timeout._timeout == 60
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# review_pr tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestReviewPr:
    """Tests for ClaudeReviewer.review_pr: CLI invocation shape and parsing."""

    async def test_returns_review_result(self) -> None:
        """A successful CLI run is parsed into a ReviewResult."""
        stdout = _make_cli_output(verdict="approve", summary="Looks good")
        runner, _ = _make_subprocess_runner(stdout=stdout)
        reviewer = ClaudeReviewer(run_subprocess=runner)

        result = await reviewer.review_pr(
            diff="diff --git a/foo.py ...",
            pr_title="Fix bug",
            repo_name="my-repo",
        )

        assert isinstance(result, ReviewResult)
        assert result.verdict == "approve"

    async def test_passes_cwd_to_subprocess(self) -> None:
        """An explicit cwd (the PR worktree) is forwarded to the subprocess."""
        stdout = _make_cli_output()
        runner, calls = _make_subprocess_runner(stdout=stdout)
        reviewer = ClaudeReviewer(run_subprocess=runner)

        await reviewer.review_pr(
            diff="diff",
            pr_title="PR",
            repo_name="repo",
            cwd="/path/to/worktree",
        )

        assert calls[0]["cwd"] == "/path/to/worktree"

    async def test_cmd_includes_claude_p(self) -> None:
        """The command starts with the claude binary and uses -p (print mode)."""
        stdout = _make_cli_output()
        runner, calls = _make_subprocess_runner(stdout=stdout)
        reviewer = ClaudeReviewer(run_subprocess=runner)

        await reviewer.review_pr(diff="d", pr_title="t", repo_name="r")

        cmd = calls[0]["cmd"]
        assert cmd[0] == "claude"
        assert "-p" in cmd

    async def test_cmd_includes_output_format_json(self) -> None:
        """--output-format json is requested so stdout is machine-parseable."""
        stdout = _make_cli_output()
        runner, calls = _make_subprocess_runner(stdout=stdout)
        reviewer = ClaudeReviewer(run_subprocess=runner)

        await reviewer.review_pr(diff="d", pr_title="t", repo_name="r")

        cmd = calls[0]["cmd"]
        # The value must immediately follow the flag.
        idx = cmd.index("--output-format")
        assert cmd[idx + 1] == "json"

    async def test_cmd_includes_json_schema(self) -> None:
        """A --json-schema flag constrains the structured output shape."""
        stdout = _make_cli_output()
        runner, calls = _make_subprocess_runner(stdout=stdout)
        reviewer = ClaudeReviewer(run_subprocess=runner)

        await reviewer.review_pr(diff="d", pr_title="t", repo_name="r")

        cmd = calls[0]["cmd"]
        assert "--json-schema" in cmd

    async def test_cmd_includes_allowed_tools(self) -> None:
        """--allowedTools limits the CLI's tool use; Read must be permitted."""
        stdout = _make_cli_output()
        runner, calls = _make_subprocess_runner(stdout=stdout)
        reviewer = ClaudeReviewer(run_subprocess=runner)

        await reviewer.review_pr(diff="d", pr_title="t", repo_name="r")

        cmd = calls[0]["cmd"]
        idx = cmd.index("--allowedTools")
        assert "Read" in cmd[idx + 1]

    async def test_cmd_includes_system_prompt(self) -> None:
        """A --system-prompt flag is passed to steer the review behavior."""
        stdout = _make_cli_output()
        runner, calls = _make_subprocess_runner(stdout=stdout)
        reviewer = ClaudeReviewer(run_subprocess=runner)

        await reviewer.review_pr(diff="d", pr_title="t", repo_name="r")

        cmd = calls[0]["cmd"]
        assert "--system-prompt" in cmd

    async def test_nonzero_exit_raises(self) -> None:
        """A non-zero CLI exit code raises RuntimeError mentioning the code."""
        runner, _ = _make_subprocess_runner(
            stdout="", stderr="error occurred", returncode=1
        )
        reviewer = ClaudeReviewer(run_subprocess=runner)

        with pytest.raises(RuntimeError, match="exited with code 1"):
            await reviewer.review_pr(diff="d", pr_title="t", repo_name="r")

    async def test_timeout_passed_to_subprocess(self) -> None:
        """The constructor's timeout is forwarded to the subprocess runner."""
        stdout = _make_cli_output()
        runner, calls = _make_subprocess_runner(stdout=stdout)
        reviewer = ClaudeReviewer(run_subprocess=runner, timeout=120)

        await reviewer.review_pr(diff="d", pr_title="t", repo_name="r")

        assert calls[0]["timeout"] == 120

    async def test_pr_title_in_prompt(self) -> None:
        """The PR title appears in the prompt argument that follows -p."""
        stdout = _make_cli_output()
        runner, calls = _make_subprocess_runner(stdout=stdout)
        reviewer = ClaudeReviewer(run_subprocess=runner)

        await reviewer.review_pr(
            diff="d", pr_title="Specific Title", repo_name="r"
        )

        cmd = calls[0]["cmd"]
        prompt = cmd[cmd.index("-p") + 1]
        assert "Specific Title" in prompt

    async def test_repo_name_in_prompt(self) -> None:
        """The repo name appears in the prompt argument that follows -p."""
        stdout = _make_cli_output()
        runner, calls = _make_subprocess_runner(stdout=stdout)
        reviewer = ClaudeReviewer(run_subprocess=runner)

        await reviewer.review_pr(
            diff="d", pr_title="t", repo_name="special-repo"
        )

        cmd = calls[0]["cmd"]
        prompt = cmd[cmd.index("-p") + 1]
        assert "special-repo" in prompt

    async def test_cwd_none_when_not_provided(self) -> None:
        """When no cwd is given, None is forwarded to the subprocess runner."""
        stdout = _make_cli_output()
        runner, calls = _make_subprocess_runner(stdout=stdout)
        reviewer = ClaudeReviewer(run_subprocess=runner)

        await reviewer.review_pr(diff="d", pr_title="t", repo_name="r")

        assert calls[0]["cwd"] is None

    async def test_request_changes_with_issues(self) -> None:
        """Multiple issues are parsed; a blocker among them sets has_blockers."""
        stdout = _make_cli_output(
            verdict="request_changes",
            summary="Problems found",
            issues=[
                {"severity": "blocker", "description": "Security flaw"},
                {"severity": "warning", "description": "Missing docs"},
            ],
        )
        runner, _ = _make_subprocess_runner(stdout=stdout)
        reviewer = ClaudeReviewer(run_subprocess=runner)

        result = await reviewer.review_pr(diff="d", pr_title="t", repo_name="r")

        assert result.verdict == "request_changes"
        assert len(result.issues) == 2
        assert result.has_blockers is True
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# ClaudeReviewer.generate_ticket_content tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _make_ticket_cli_output(summary: str = "My summary", description: str = "My desc") -> str:
|
||||||
|
"""Build a JSON string mimicking Claude Code CLI output for ticket generation."""
|
||||||
|
structured = {"summary": summary, "description": description}
|
||||||
|
return json.dumps({"result": "", "structured_output": structured})
|
||||||
|
|
||||||
|
|
||||||
|
class TestGenerateTicketContent:
    """Tests for ClaudeReviewer.generate_ticket_content."""

    async def test_returns_tuple_of_summary_and_description(self) -> None:
        """The method returns a 2-tuple (summary, description)."""
        stdout = _make_ticket_cli_output(summary="Fix login bug", description="Detailed desc")
        runner, _ = _make_subprocess_runner(stdout=stdout)
        reviewer = ClaudeReviewer(run_subprocess=runner)

        result = await reviewer.generate_ticket_content(
            diff="edit: main.py", pr_title="Fix login", repo_name="backend"
        )

        assert isinstance(result, tuple)
        assert len(result) == 2

    async def test_returns_correct_summary(self) -> None:
        """The summary from structured_output is returned first."""
        stdout = _make_ticket_cli_output(summary="Implement OAuth2 login")
        runner, _ = _make_subprocess_runner(stdout=stdout)
        reviewer = ClaudeReviewer(run_subprocess=runner)

        summary, _ = await reviewer.generate_ticket_content(
            diff="d", pr_title="Add OAuth", repo_name="auth-service"
        )

        assert summary == "Implement OAuth2 login"

    async def test_returns_correct_description(self) -> None:
        """The description from structured_output is returned second."""
        stdout = _make_ticket_cli_output(description="This adds OAuth2 support for the login flow")
        runner, _ = _make_subprocess_runner(stdout=stdout)
        reviewer = ClaudeReviewer(run_subprocess=runner)

        _, description = await reviewer.generate_ticket_content(
            diff="d", pr_title="Add OAuth", repo_name="auth-service"
        )

        assert description == "This adds OAuth2 support for the login flow"

    async def test_uses_json_schema_with_summary_and_description_fields(self) -> None:
        """The --json-schema argument declares summary/description properties."""
        stdout = _make_ticket_cli_output()
        runner, calls = _make_subprocess_runner(stdout=stdout)
        reviewer = ClaudeReviewer(run_subprocess=runner)

        await reviewer.generate_ticket_content(diff="d", pr_title="t", repo_name="r")

        cmd = calls[0]["cmd"]
        # Verify --json-schema flag was used
        assert "--json-schema" in cmd
        schema_idx = cmd.index("--json-schema")
        schema_json = cmd[schema_idx + 1]
        schema = json.loads(schema_json)
        assert "summary" in schema["properties"]
        assert "description" in schema["properties"]

    async def test_passes_pr_title_in_prompt(self) -> None:
        """The PR title appears somewhere in the invoked command."""
        stdout = _make_ticket_cli_output()
        runner, calls = _make_subprocess_runner(stdout=stdout)
        reviewer = ClaudeReviewer(run_subprocess=runner)

        await reviewer.generate_ticket_content(
            diff="d", pr_title="My Unique PR Title", repo_name="r"
        )

        cmd_str = " ".join(calls[0]["cmd"])
        assert "My Unique PR Title" in cmd_str

    async def test_passes_repo_name_in_prompt(self) -> None:
        """The repo name appears somewhere in the invoked command."""
        stdout = _make_ticket_cli_output()
        runner, calls = _make_subprocess_runner(stdout=stdout)
        reviewer = ClaudeReviewer(run_subprocess=runner)

        await reviewer.generate_ticket_content(
            diff="d", pr_title="t", repo_name="my-special-repo"
        )

        cmd_str = " ".join(calls[0]["cmd"])
        assert "my-special-repo" in cmd_str

    async def test_passes_cwd_to_subprocess(self) -> None:
        """An explicit cwd is forwarded to the subprocess runner."""
        stdout = _make_ticket_cli_output()
        runner, calls = _make_subprocess_runner(stdout=stdout)
        reviewer = ClaudeReviewer(run_subprocess=runner)

        await reviewer.generate_ticket_content(
            diff="d", pr_title="t", repo_name="r", cwd="/some/path"
        )

        assert calls[0]["cwd"] == "/some/path"

    async def test_cwd_none_by_default(self) -> None:
        """Omitting cwd forwards None to the subprocess runner."""
        stdout = _make_ticket_cli_output()
        runner, calls = _make_subprocess_runner(stdout=stdout)
        reviewer = ClaudeReviewer(run_subprocess=runner)

        await reviewer.generate_ticket_content(diff="d", pr_title="t", repo_name="r")

        assert calls[0]["cwd"] is None

    async def test_raises_on_nonzero_exit_code(self) -> None:
        """A non-zero CLI exit code raises RuntimeError mentioning the CLI."""
        runner, _ = _make_subprocess_runner(stdout="", stderr="Error", returncode=1)
        reviewer = ClaudeReviewer(run_subprocess=runner)

        with pytest.raises(RuntimeError, match="Claude CLI"):
            await reviewer.generate_ticket_content(diff="d", pr_title="t", repo_name="r")

    async def test_raises_on_invalid_json_output(self) -> None:
        """Unparseable CLI stdout raises an error."""
        runner, _ = _make_subprocess_runner(stdout="not json at all")
        reviewer = ClaudeReviewer(run_subprocess=runner)

        # NOTE(review): (ValueError, Exception) is effectively just Exception,
        # so this assertion is very weak — consider narrowing to the concrete
        # parse error type once the implementation's contract is confirmed.
        with pytest.raises((ValueError, Exception)):
            await reviewer.generate_ticket_content(diff="d", pr_title="t", repo_name="r")
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user