feat: initial commit — Billo Release Agent (LangGraph)

LangGraph-based release automation agent with:
- PR discovery (webhook + polling)
- AI code review via Claude Code CLI (subscription-based)
- Auto-create Jira tickets for PRs without ticket ID
- Jira ticket lifecycle management (code review -> staging -> done)
- CI/CD pipeline trigger, polling, and approval gates
- Slack interactive messages with approval buttons
- Per-repo semantic versioning
- PostgreSQL persistence (threads, staging, releases)
- FastAPI API (webhooks, approvals, status, manual triggers)
- Docker Compose deployment

1069 tests, 95%+ coverage.
This commit is contained in:
Yaojia Wang
2026-03-24 17:38:23 +01:00
commit f5c2733cfb
104 changed files with 19721 additions and 0 deletions

55
.env.example Normal file
View File

@@ -0,0 +1,55 @@
# Billo Release Agent - Environment Variables
# Copy this file to .env and fill in your values.
# Never commit .env to source control.
# ===========================================================================
# REQUIRED
# ===========================================================================
# --- Azure DevOps ----------------------------------------------------------
AZDO_ORGANIZATION=billodev
AZDO_PROJECT=Billo App Platform
AZDO_PAT= # Personal access token (Code read/write, Build read/execute)
# --- PostgreSQL ------------------------------------------------------------
POSTGRES_PASSWORD= # Used by docker-compose db service
POSTGRES_DSN=postgresql://agent:${POSTGRES_PASSWORD}@localhost:5432/agent
# --- Jira ------------------------------------------------------------------
JIRA_EMAIL= # Jira account email
JIRA_API_TOKEN= # Jira API token
# --- Slack -----------------------------------------------------------------
SLACK_WEBHOOK_URL= # Incoming webhook URL (for release notifications)
# --- Webhook Security ------------------------------------------------------
WEBHOOK_SECRET= # HMAC secret for Azure DevOps webhook validation
# ===========================================================================
# OPTIONAL (have defaults)
# ===========================================================================
# --- Slack App (for interactive buttons) -----------------------------------
SLACK_BOT_TOKEN= # xoxb-... Bot token (needed for Slack buttons)
SLACK_SIGNING_SECRET= # Slack app signing secret (needed for /slack/interactions)
SLACK_CHANNEL_ID= # Channel ID to post messages (e.g., C0123456789)
# --- CI/CD Polling ---------------------------------------------------------
CI_POLL_INTERVAL_SECONDS=30 # Seconds between CI status polls
CI_POLL_MAX_WAIT_SECONDS=1800 # Max wait for CI completion (30 min)
# --- Claude Code Review ----------------------------------------------------
# ANTHROPIC_API_KEY= # Not needed — uses Claude Code CLI subscription
# CLAUDE_REVIEW_MODEL=claude-sonnet-4-20250514
# --- Local Repo Path -------------------------------------------------------
REPOS_BASE_DIR= # Base dir containing Billo repos (e.g., /c/Users/yaoji/git/Billo)
# --- Jira ------------------------------------------------------------------
JIRA_BASE_URL=https://billolife.atlassian.net
# --- Operator API ----------------------------------------------------------
OPERATOR_TOKEN= # When set, protects POST /approvals/* and POST /manual/*
# --- Server ----------------------------------------------------------------
PORT=8000

36
.gitignore vendored Normal file
View File

@@ -0,0 +1,36 @@
# Environment
.env
.env.local
# Python
__pycache__/
*.py[cod]
*.egg-info/
dist/
build/
.eggs/
# Virtual environment
.venv/
venv/
# uv
uv.lock
# IDE
.idea/
.vscode/
*.swp
*.swo
# Testing
.coverage
htmlcov/
.pytest_cache/
# Data
data/staging/
# OS
.DS_Store
Thumbs.db

32
Dockerfile Normal file
View File

@@ -0,0 +1,32 @@
FROM python:3.12-slim
WORKDIR /app
# Install system dependencies required for psycopg binary
RUN apt-get update && apt-get install -y --no-install-recommends \
libpq5 \
&& rm -rf /var/lib/apt/lists/*
# Copy dependency files first for layer caching
COPY pyproject.toml uv.lock ./
# Install uv for fast dependency management
RUN pip install --no-cache-dir uv
# Install runtime dependencies (no dev extras)
RUN uv pip install --system --no-cache-dir -e .
# Copy application source
COPY src/ ./src/
# Create data directory for staging store
RUN mkdir -p data/staging
# Non-root user for security
RUN adduser --disabled-password --gecos "" appuser && \
chown -R appuser:appuser /app
USER appuser
EXPOSE 8000
CMD ["uvicorn", "release_agent.main:app", "--host", "0.0.0.0", "--port", "8000"]

341
README.md Normal file
View File

@@ -0,0 +1,341 @@
# Billo Release Agent
A LangGraph-based release automation agent for Billo. Automates the full release
pipeline: PR discovery, code review (via Claude Code CLI), Jira ticket management,
staging release tracking, CI/CD pipeline triggering with approval gates, and Slack
interactive notifications.
## Architecture
```
+--- Azure DevOps Webhook ---+ +--- PR Poller (every 5 min) ---+
| POST /webhooks/azdo | | Scans WATCHED_REPOS for |
| (push-based) | | active PRs (pull-based) |
+------------+---------------+ +------------+------------------+
| |
v v
+---------------------------------------------+
| FastAPI Application |
| /webhooks/azdo /slack/interactions |
| /approvals/* /manual/* /status |
+---------------------+------------------------+
|
v
+---------------------------------------------+
| LangGraph Graphs |
| |
| pr_completed: |
| parse -> fetch -> [has ticket?] |
| no -> Claude generates ticket |
| yes -> Claude code review |
| -> merge -> Jira -> staging -> CI build |
| |
| release: |
| create release PR -> merge -> CI build |
| -> CD release -> [Sandbox approve] |
| -> [Production approve] -> Slack notify |
+---------------------+------------------------+
|
+----------------+----------------+
| | |
v v v
PostgreSQL Azure DevOps Slack (buttons)
- threads - PRs/Pipelines - Notifications
- staging - Builds/Releases - Approvals
- releases
```
## Key Features
| Feature | Description |
|---------|-------------|
| **PR Discovery** | Webhook-based (push) or polling-based (pull) — or both |
| **Auto-Create Jira Ticket** | When PR branch has no ticket ID, Claude generates summary + description and creates a Jira Story |
| **AI Code Review** | Claude Code CLI reviews PRs with full repo context (Read/Glob/Grep), using your subscription |
| **CI/CD Integration** | Triggers CI builds after merge, polls for completion, handles CD release approval gates |
| **Slack Interactive** | Approval requests with [Approve]/[Cancel] buttons, CI/CD status notifications |
| **Human-in-the-loop** | 5 interrupt points where operator confirmation is required before destructive actions |
| **Per-repo Versioning** | Independent semantic versioning per repository (patch auto-increment) |
## Prerequisites
- Python 3.12+
- PostgreSQL 16+
- [uv](https://github.com/astral-sh/uv) (recommended) or pip
- Claude Code CLI installed and authenticated (`claude` in PATH)
- Slack App (for interactive buttons) or Slack Incoming Webhook (for notifications only)
## Quick Start
### Local Development
```bash
# Install dependencies
uv sync --all-extras
# Copy and configure environment
cp .env.example .env
# Edit .env -- fill in all REQUIRED variables
# Start PostgreSQL
docker compose up -d db
# Run the server
uv run uvicorn release_agent.main:app --reload --port 8000
# Verify
curl http://localhost:8000/status
```
### Docker Compose (Production)
```bash
cp .env.example .env
# Edit .env -- POSTGRES_PASSWORD, WEBHOOK_SECRET, etc. are required
docker compose up -d
```
Tables are created automatically on first startup.
## Configuration
All configuration is via environment variables. See `.env.example` for the full list.
### Required Variables
| Variable | Description |
|----------|-------------|
| `AZDO_ORGANIZATION` | Azure DevOps organization name |
| `AZDO_PROJECT` | Azure DevOps project name |
| `AZDO_PAT` | Azure DevOps personal access token |
| `POSTGRES_DSN` | PostgreSQL connection string |
| `POSTGRES_PASSWORD` | PostgreSQL password (used by docker-compose) |
| `JIRA_EMAIL` | Jira account email |
| `JIRA_API_TOKEN` | Jira API token |
| `SLACK_WEBHOOK_URL` | Slack incoming webhook URL |
| `WEBHOOK_SECRET` | Shared secret for validating AzDo webhooks (must be non-empty) |
### Optional Variables
| Variable | Default | Description |
|----------|---------|-------------|
| `REPOS_BASE_DIR` | `""` | Base dir with Billo repos (e.g., `/c/Users/yaoji/git/Billo`) |
| `WATCHED_REPOS` | `""` | Comma-separated repos to poll (e.g., `Billo.Platform.Payment,Billo.Platform.Document.DocumentAnalyser`) |
| `PR_POLL_ENABLED` | `False` | Enable periodic PR polling |
| `PR_POLL_INTERVAL_SECONDS` | `300` | Polling interval (5 min) |
| `PR_POLL_TARGET_BRANCH` | `refs/heads/develop` | Target branch filter |
| `DEFAULT_JIRA_PROJECT` | `ALLPOST` | Jira project key for auto-created tickets |
| `AUTO_CREATE_TICKET_ENABLED` | `True` | Auto-create Jira ticket when branch has no ticket ID |
| `SLACK_BOT_TOKEN` | `""` | Slack App bot token (for interactive buttons) |
| `SLACK_SIGNING_SECRET` | `""` | Slack App signing secret (required for /slack/interactions) |
| `SLACK_CHANNEL_ID` | `""` | Channel for interactive messages |
| `CI_POLL_INTERVAL_SECONDS` | `30` | CI build status poll interval |
| `CI_POLL_MAX_WAIT_SECONDS` | `1800` | Max wait for CI completion (30 min) |
| `OPERATOR_TOKEN` | `""` | Token for operator endpoints (empty = no auth) |
| `JIRA_BASE_URL` | `https://billolife.atlassian.net` | Jira instance URL |
| `PORT` | `8000` | HTTP server port |
### Security Notes
- `WEBHOOK_SECRET` must be non-empty; empty secret rejects all webhooks
- `POSTGRES_PASSWORD` has no default in docker-compose; fails if unset
- `SLACK_SIGNING_SECRET` must be set for `/slack/interactions` to accept requests (returns 503 if empty)
- Slack signature verification includes 5-minute replay attack prevention
- All secrets use `SecretStr` and are never logged or included in error responses
- Set `OPERATOR_TOKEN` in production to protect approval and manual trigger endpoints
## API Endpoints
### Webhooks
| Method | Path | Auth | Description |
|--------|------|------|-------------|
| POST | `/webhooks/azdo` | `X-Webhook-Secret` | Receive Azure DevOps PR webhook events |
| POST | `/slack/interactions` | Slack Signing Secret | Receive Slack button click callbacks |
### Approvals (requires `X-Operator-Token` when configured)
| Method | Path | Description |
|--------|------|-------------|
| GET | `/approvals/pending` | List threads awaiting operator approval |
| POST | `/approvals/{thread_id}` | Submit approval decision (merge/cancel/approve/skip) |
### Status and Manual Triggers
| Method | Path | Auth | Description |
|--------|------|------|-------------|
| GET | `/status` | None | Health check |
| GET | `/releases/{repo}` | None | List versions for a repo |
| GET | `/staging?repo={repo}` | None | Current staging release |
| POST | `/manual/pr/{pr_id}` | `X-Operator-Token` | Manually trigger PR processing |
| POST | `/manual/release` | `X-Operator-Token` | Manually trigger a release |
## Graph Workflows
### PR Completed
```
parse_webhook -> fetch_pr_details -> route_after_fetch
|-- merged -----------------> calculate_version -> update_staging -> CI build -> END
|-- active_with_ticket -----> move_jira_code_review -+
|-- active_no_ticket -------> auto_create_ticket ----+
|
run_code_review -> evaluate_review
|-- approve -> [Slack: Merge?] -> merge_pr
|-- request_changes -> notify -> END
-> Jira transitions -> calculate_version
-> update_staging -> CI build -> notify -> END
```
### Release
```
load_staging -> [Slack: Create release?] -> create_release_pr
-> [Slack: Merge release?] -> merge_release_pr
-> CI build on main -> poll until complete
|-- ci_failed -> notify failure -> END
|-- ci_passed -> wait for CD release -> approval loop:
|-- [Slack: Approve Sandbox?] -> approve -> poll again
|-- [Slack: Approve Production?] -> approve -> poll again
|-- all_deployed -> move tickets to Done
-> Slack release notification -> archive -> END
```
### Interrupt Points (Slack buttons)
| # | When | Slack Message | Buttons |
|---|------|--------------|---------|
| 1 | After code review approves | PR title + review summary | [Merge] [Cancel] |
| 2 | Before creating release PR | Version + ticket list | [Create] [Cancel] |
| 3 | Before merging release PR | Release PR link | [Merge] [Cancel] |
| 4 | Before triggering pipelines | Pipeline list | [Trigger] [Skip] |
| 5 | Before approving release stage | Stage name + status | [Approve] [Skip] |
## PR Polling (Alternative to Webhooks)
When `PR_POLL_ENABLED=True`, the agent periodically scans all `WATCHED_REPOS` for
active PRs targeting the configured branch. New PRs not yet tracked in `agent_threads`
are automatically processed through the `pr_completed` graph.
This eliminates the need for Azure DevOps webhook configuration and works behind
firewalls without public endpoint exposure.
## Auto-Create Jira Ticket
When a PR branch has no ticket ID (e.g., `chore/update-dependencies` instead of
`feature/ALLPOST-4028_login-page`), the agent automatically:
1. Sends the PR diff to Claude Code CLI
2. Claude generates a concise ticket summary and description
3. Creates a Jira Story in the `DEFAULT_JIRA_PROJECT`
4. Continues the normal workflow with the created ticket
## Database Schema
Tables are created automatically on startup:
```sql
-- Thread tracking for LangGraph interrupts and PR dedup
agent_threads (thread_id, graph_name, repo_name, pr_id, status, state JSONB,
slack_message_ts, created_at, updated_at)
-- Current in-progress releases (one per repo)
staging_releases (repo, version, started_at, tickets JSONB, updated_at)
-- Completed releases (immutable history)
archived_releases (repo, version, started_at, tickets JSONB, released_at)
```
## Migrating Existing JSON Files
If you have existing release data from the Claude Code skill:
```bash
# Dry run
uv run python scripts/migrate_json_to_db.py \
--source ../release-workflow/releases --dry-run
# Execute
uv run python scripts/migrate_json_to_db.py \
--source ../release-workflow/releases \
--dsn "postgresql://agent:password@localhost/agent"
```
## Development
### Running Tests
```bash
# Run all tests with coverage (1061 tests, 96%+ coverage)
uv run pytest
# Run without coverage (faster)
uv run pytest --no-cov
# Run specific module
uv run pytest tests/graph/test_pr_completed.py -v
```
### Project Structure
```
src/release_agent/
main.py # FastAPI app, lifespan, task management
config.py # pydantic-settings (all env vars)
state.py # LangGraph ReleaseState TypedDict
exceptions.py # Exception hierarchy
branch_parser.py # Extract ticket ID from branch name
versioning.py # Per-repo version calculation
api/
models.py # HTTP request/response Pydantic models
dependencies.py # FastAPI Depends() + operator auth
webhooks.py # POST /webhooks/azdo
approvals.py # Approval endpoints
status.py # Status, releases, manual triggers
slack_interactions.py # POST /slack/interactions (button callbacks)
graph/
dependencies.py # ToolClients, StagingStore Protocol
postgres_staging_store.py # PostgreSQL-backed store
routing.py # Pure routing functions (route_after_fetch, etc.)
pr_completed.py # PR graph nodes + auto_create_ticket
release.py # Release graph nodes + CI/CD approval loop
full_cycle.py # Subgraph composition
ci_nodes.py # CI trigger, poll, notify nodes
polling.py # Reusable async poll_until utility
models/
pr.py, ticket.py, release.py, pipeline.py
webhook.py, review.py, jira.py, build.py
tools/
azdo.py # Azure DevOps REST client
jira.py # Jira REST client (transitions + create_issue)
slack.py # Slack dual-mode (webhook + Web API)
claude_review.py # Claude Code CLI (review + ticket generation)
_http.py, _retry.py # Shared helpers
services/
pr_poller.py # Background PR polling loop
pr_dedup.py # PR deduplication via agent_threads
scripts/
migrate_json_to_db.py # One-time JSON -> PostgreSQL migration
tests/ # 1061 tests, 96%+ coverage
```
## Docker
```bash
# Required: set POSTGRES_PASSWORD and WEBHOOK_SECRET in .env
docker compose up -d
```
The agent service includes a health check at `/status`. PostgreSQL uses
`pg_isready` with `service_healthy` dependency.
## Slack App Setup
To use interactive buttons (optional — REST API approvals still work without it):
1. Create a Slack App at https://api.slack.com/apps
2. Enable **Interactivity** with Request URL: `https://<your-domain>/slack/interactions`
3. Add Bot Token Scopes: `chat:write`, `chat:update`
4. Install to workspace, get Bot Token (`xoxb-...`)
5. Set `SLACK_BOT_TOKEN`, `SLACK_SIGNING_SECRET`, `SLACK_CHANNEL_ID` in `.env`

48
docker-compose.yml Normal file
View File

@@ -0,0 +1,48 @@
services:
agent:
build: .
ports:
- "8000:8000"
environment:
AZDO_ORGANIZATION: ${AZDO_ORGANIZATION}
AZDO_PROJECT: ${AZDO_PROJECT}
AZDO_PAT: ${AZDO_PAT}
ANTHROPIC_API_KEY: ${ANTHROPIC_API_KEY:-unused}
POSTGRES_DSN: postgresql://agent:${POSTGRES_PASSWORD:?POSTGRES_PASSWORD must be set}@db:5432/agent
JIRA_EMAIL: ${JIRA_EMAIL}
JIRA_API_TOKEN: ${JIRA_API_TOKEN}
SLACK_WEBHOOK_URL: ${SLACK_WEBHOOK_URL}
WEBHOOK_SECRET: ${WEBHOOK_SECRET:?WEBHOOK_SECRET must be set}
JIRA_BASE_URL: ${JIRA_BASE_URL:-https://billolife.atlassian.net}
CLAUDE_REVIEW_MODEL: ${CLAUDE_REVIEW_MODEL:-claude-sonnet-4-20250514}
REPOS_BASE_DIR: ${REPOS_BASE_DIR:-}
OPERATOR_TOKEN: ${OPERATOR_TOKEN:-}
PORT: 8000
depends_on:
db:
condition: service_healthy
restart: unless-stopped
healthcheck:
test: ["CMD-SHELL", "curl -sf http://localhost:8000/status || exit 1"]
interval: 10s
timeout: 5s
retries: 3
start_period: 15s
db:
image: postgres:16-alpine
environment:
POSTGRES_USER: agent
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:?POSTGRES_PASSWORD must be set}
POSTGRES_DB: agent
volumes:
- postgres_data:/var/lib/postgresql/data
healthcheck:
test: ["CMD-SHELL", "pg_isready -U agent -d agent"]
interval: 5s
timeout: 5s
retries: 5
restart: unless-stopped
volumes:
postgres_data:

58
pyproject.toml Normal file
View File

@@ -0,0 +1,58 @@
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[project]
name = "billo-release-agent"
version = "0.1.0"
description = "LangGraph-based release automation agent for Billo"
requires-python = ">=3.12"
dependencies = [
"langgraph>=0.3.30",
"langgraph-checkpoint-postgres>=2.0.15",
"fastapi>=0.115.0",
"uvicorn[standard]>=0.34.0",
"httpx>=0.28.0",
"anthropic>=0.52.0",
"pydantic>=2.11.0",
"pydantic-settings>=2.8.0",
"psycopg[binary]>=3.2.0",
"psycopg-pool>=3.2.0",
]
[project.optional-dependencies]
dev = [
"pytest>=8.3.0",
"pytest-asyncio>=0.25.0",
"pytest-cov>=6.1.0",
"ruff>=0.11.0",
"psycopg[binary]>=3.2.0",
"psycopg-pool>=3.2.0",
]
[tool.hatch.build.targets.wheel]
packages = ["src/release_agent"]
[tool.pytest.ini_options]
testpaths = ["tests"]
asyncio_mode = "auto"
pythonpath = ["."]
addopts = "--cov=src/release_agent --cov-report=term-missing --cov-fail-under=80"
[tool.coverage.run]
source = ["src/release_agent"]
[tool.coverage.report]
exclude_lines = [
"pragma: no cover",
"if TYPE_CHECKING:",
"raise NotImplementedError",
]
[tool.ruff]
line-length = 100
target-version = "py312"
[tool.ruff.lint]
select = ["E", "F", "I", "N", "UP", "B", "SIM"]
ignore = ["E501"]

0
scripts/__init__.py Normal file
View File

View File

@@ -0,0 +1,284 @@
"""Migration script: JSON files -> PostgreSQL.
Reads staging and archived release JSON files from a directory tree and
inserts them into the staging_releases and archived_releases tables.
All business logic is implemented as pure functions so it can be tested
without a real database. The main() entry point wires together the pure
functions with actual I/O.
Usage:
python scripts/migrate_json_to_db.py --source /path/to/releases \\
--dsn "postgresql://user:pass@localhost/db" [--dry-run]
Pure functions (testable without DB):
collect_json_files(directory) -> list[Path]
is_staging_filename(name) -> bool
is_archived_filename(name) -> bool
parse_staging_json(data) -> MigrationRecord
parse_archived_json(data) -> MigrationRecord
build_staging_insert_sql(record) -> tuple[str, tuple]
build_archived_insert_sql(record) -> tuple[str, tuple]
"""
from __future__ import annotations
import argparse
import json
import re
import sys
from dataclasses import dataclass, field
from datetime import date
from pathlib import Path
# ---------------------------------------------------------------------------
# Data model
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class MigrationRecord:
"""Parsed record ready for database insertion.
released_at is None for staging records, set for archived records.
"""
repo: str
version: str
started_at: date
tickets: list[dict]
released_at: date | None = None
# ---------------------------------------------------------------------------
# File classification
# ---------------------------------------------------------------------------
# Archived filenames match: <repo>_<version>_<date>.json
# e.g. Billo.Platform.Payment_v1.0.1_2026-03-23.json
_ARCHIVED_PATTERN = re.compile(
r"^.+_v\d+\.\d+\.\d+_\d{4}-\d{2}-\d{2}\.json$"
)
def is_staging_filename(name: str) -> bool:
"""Return True if the filename looks like a staging JSON file.
Staging files end in .json and do not match the archived pattern.
"""
if not name.endswith(".json"):
return False
return not _ARCHIVED_PATTERN.match(name)
def is_archived_filename(name: str) -> bool:
"""Return True if the filename looks like an archived release JSON file."""
return bool(_ARCHIVED_PATTERN.match(name))
# ---------------------------------------------------------------------------
# File collection
# ---------------------------------------------------------------------------
def collect_json_files(directory: Path) -> list[Path]:
"""Recursively collect all .json files under directory.
Returns a sorted list of Path objects.
"""
return sorted(directory.rglob("*.json"))
# ---------------------------------------------------------------------------
# Parsing pure functions
# ---------------------------------------------------------------------------
def parse_staging_json(data: dict) -> MigrationRecord:
"""Parse a staging release JSON dict into a MigrationRecord.
Args:
data: Parsed JSON dict with keys: version, repo, started_at, tickets.
Returns:
MigrationRecord with released_at=None.
"""
return MigrationRecord(
repo=data["repo"],
version=data["version"],
started_at=date.fromisoformat(data["started_at"]),
tickets=list(data.get("tickets") or []),
released_at=None,
)
def parse_archived_json(data: dict) -> MigrationRecord:
"""Parse an archived release JSON dict into a MigrationRecord.
Args:
data: Parsed JSON dict with keys: version, repo, started_at, tickets,
released_at.
Returns:
MigrationRecord with released_at set.
"""
return MigrationRecord(
repo=data["repo"],
version=data["version"],
started_at=date.fromisoformat(data["started_at"]),
tickets=list(data.get("tickets") or []),
released_at=date.fromisoformat(data["released_at"]),
)
# ---------------------------------------------------------------------------
# SQL builder pure functions
# ---------------------------------------------------------------------------
_STAGING_INSERT_SQL = """
INSERT INTO staging_releases (repo, version, started_at, tickets)
VALUES (%s, %s, %s, %s)
ON CONFLICT (repo) DO NOTHING
""".strip()
_ARCHIVED_INSERT_SQL = """
INSERT INTO archived_releases (repo, version, started_at, tickets, released_at)
VALUES (%s, %s, %s, %s, %s)
ON CONFLICT (repo, version) DO NOTHING
""".strip()
def build_staging_insert_sql(record: MigrationRecord) -> tuple[str, tuple]:
"""Build the INSERT SQL and parameters for a staging release record.
Returns:
(sql_string, params_tuple) ready for cursor.execute().
"""
tickets_json = json.dumps(record.tickets)
params = (
record.repo,
record.version,
record.started_at.isoformat(),
tickets_json,
)
return _STAGING_INSERT_SQL, params
def build_archived_insert_sql(record: MigrationRecord) -> tuple[str, tuple]:
"""Build the INSERT SQL and parameters for an archived release record.
Returns:
(sql_string, params_tuple) ready for cursor.execute().
"""
tickets_json = json.dumps(record.tickets)
params = (
record.repo,
record.version,
record.started_at.isoformat(),
tickets_json,
record.released_at.isoformat() if record.released_at else None,
)
return _ARCHIVED_INSERT_SQL, params
# ---------------------------------------------------------------------------
# Main entry point
# ---------------------------------------------------------------------------
def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Migrate JSON release files to PostgreSQL"
)
parser.add_argument(
"--source",
type=Path,
required=True,
help="Root directory containing release JSON files",
)
parser.add_argument(
"--dsn",
type=str,
default="",
help="PostgreSQL DSN (e.g. postgresql://user:pass@localhost/db)",
)
parser.add_argument(
"--dry-run",
action="store_true",
help="Print SQL statements without executing them",
)
return parser.parse_args(argv)
def main(argv: list[str] | None = None) -> int:
"""Entry point for the migration script.
Returns:
0 on success, 1 on error.
"""
args = _parse_args(argv)
if not args.source.exists():
print(f"ERROR: source directory does not exist: {args.source}", file=sys.stderr)
return 1
files = collect_json_files(args.source)
print(f"Found {len(files)} JSON file(s) under {args.source}")
statements: list[tuple[str, tuple]] = []
errors: list[str] = []
for path in files:
try:
data = json.loads(path.read_text(encoding="utf-8"))
except (json.JSONDecodeError, OSError) as exc:
errors.append(f"Failed to read {path}: {exc}")
continue
# Determine file type from filename
name = path.name
try:
if is_archived_filename(name) or "released_at" in data:
record = parse_archived_json(data)
sql, params = build_archived_insert_sql(record)
else:
record = parse_staging_json(data)
sql, params = build_staging_insert_sql(record)
except (KeyError, ValueError) as exc:
errors.append(f"Failed to parse {path}: {exc}")
continue
statements.append((sql, params))
if errors:
for err in errors:
print(f"WARNING: {err}", file=sys.stderr)
if args.dry_run:
print(f"\nDry run: {len(statements)} statement(s) would be executed:")
for sql, params in statements:
print(f" SQL: {sql!r}")
print(f" Params: {params}")
return 0
if not args.dsn:
print("ERROR: --dsn is required when not using --dry-run", file=sys.stderr)
return 1
try:
import psycopg # noqa: PLC0415
except ImportError:
print("ERROR: psycopg not installed. Run: pip install psycopg[binary]", file=sys.stderr)
return 1
inserted = 0
with psycopg.connect(args.dsn) as conn:
with conn.cursor() as cur:
for sql, params in statements:
cur.execute(sql, params)
inserted += 1
conn.commit()
print(f"Migration complete: {inserted} record(s) inserted.")
return 0
if __name__ == "__main__":
sys.exit(main())

View File

@@ -0,0 +1 @@
"""Billo Release Agent - LangGraph-based release automation."""

View File

View File

@@ -0,0 +1,166 @@
"""Approvals endpoint for resuming interrupted graph threads.
POST /approvals/{thread_id} — resume an interrupted graph with a decision.
GET /approvals/pending — list threads with status="interrupted".
"""
from datetime import datetime, timezone
from fastapi import APIRouter, Depends, HTTPException, Request
from release_agent.api.dependencies import get_db_pool, get_graphs, get_tool_clients, require_operator_token
from release_agent.api.models import (
ApprovalDecision,
ApprovalResponse,
PendingApproval,
PendingApprovalsResponse,
)
router = APIRouter()
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------
async def _resume_graph(
*,
graph,
thread_id: str,
decision: str,
tool_clients,
db_pool,
) -> dict:
"""Resume an interrupted graph thread with the given decision.
The decision string is passed as the resume value to the graph.
Thread status is updated in the database.
"""
from release_agent.api.webhooks import _upsert_thread
config = {
"configurable": {
"thread_id": thread_id,
"clients": tool_clients,
}
}
try:
result = await graph.ainvoke({"decision": decision}, config=config)
await _upsert_thread(db_pool, thread_id=thread_id, thread_status="completed", state=result or {})
return result or {}
except Exception as exc:
await _upsert_thread(
db_pool,
thread_id=thread_id,
thread_status="failed",
state={"errors": [str(exc)]},
)
raise
async def _get_thread_graph_name(db_pool, thread_id: str) -> str | None:
"""Look up the graph_name for a thread_id from the database."""
sql = "SELECT graph_name FROM agent_threads WHERE thread_id = %s"
async with db_pool.connection() as conn:
async with conn.cursor() as cur:
await cur.execute(sql, (thread_id,))
row = await cur.fetchone()
return row[0] if row else None
async def _fetch_pending_threads(db_pool) -> list[tuple]:
"""Query agent_threads for rows with status='interrupted'."""
sql = """
SELECT
thread_id,
graph_name,
interrupt_value,
created_at,
repo_name,
pr_id,
version
FROM agent_threads
WHERE status = 'interrupted'
ORDER BY created_at ASC
"""
async with db_pool.connection() as conn:
async with conn.cursor() as cur:
await cur.execute(sql)
return await cur.fetchall()
# ---------------------------------------------------------------------------
# GET /approvals/pending
# ---------------------------------------------------------------------------
@router.get("/approvals/pending")
async def list_pending_approvals(
request: Request, # noqa: ARG001
db_pool=Depends(get_db_pool),
_auth: None = Depends(require_operator_token),
) -> PendingApprovalsResponse:
"""Return all graph threads currently awaiting operator approval."""
rows = await _fetch_pending_threads(db_pool)
items = [
PendingApproval(
thread_id=row[0],
graph_name=row[1],
interrupt_value=row[2],
created_at=row[3] if isinstance(row[3], datetime) else datetime.fromisoformat(str(row[3])),
repo_name=row[4],
pr_id=row[5],
version=row[6],
)
for row in rows
]
return PendingApprovalsResponse(items=items, count=len(items))
# ---------------------------------------------------------------------------
# POST /approvals/{thread_id}
# ---------------------------------------------------------------------------
@router.post("/approvals/{thread_id}")
async def submit_approval(
thread_id: str,
body: ApprovalDecision,
request: Request, # noqa: ARG001
graphs=Depends(get_graphs),
tool_clients=Depends(get_tool_clients),
db_pool=Depends(get_db_pool),
_auth: None = Depends(require_operator_token),
) -> ApprovalResponse:
"""Submit an approval decision to resume an interrupted graph thread.
The decision is forwarded as the interrupt resume value to the graph.
Thread status is updated to 'completed' or 'failed' based on result.
"""
# Look up graph_name from DB for the correct graph
graph_name = await _get_thread_graph_name(db_pool, thread_id)
if graph_name is None:
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
graph = graphs.get(graph_name)
if graph is None:
raise HTTPException(status_code=400, detail=f"Unknown graph: {graph_name}")
try:
await _resume_graph(
graph=graph,
thread_id=thread_id,
decision=body.decision,
tool_clients=tool_clients,
db_pool=db_pool,
)
except Exception:
return ApprovalResponse(
thread_id=thread_id,
status="failed",
message=f"Thread {thread_id} failed during resume",
)
return ApprovalResponse(
thread_id=thread_id,
status="resumed",
message=f"Thread {thread_id} resumed with decision '{body.decision}'",
)

View File

@@ -0,0 +1,66 @@
"""FastAPI dependency callables that extract objects from app.state.
Each function accepts a Request and returns the corresponding object stored
on app.state during lifespan startup. Use with Depends() in route handlers.
"""
import hmac
from fastapi import Depends, HTTPException, Request, status
from fastapi.security import APIKeyHeader
from release_agent.config import Settings
from release_agent.graph.dependencies import StagingStore, ToolClients
_operator_token_header = APIKeyHeader(name="X-Operator-Token", auto_error=False)
def get_settings(request: Request) -> Settings:
"""Return the Settings instance stored on app.state."""
return request.app.state.settings
def get_graphs(request: Request) -> dict:
"""Return the compiled graphs dict stored on app.state."""
return request.app.state.graphs
def get_tool_clients(request: Request) -> ToolClients:
"""Return the ToolClients instance stored on app.state."""
return request.app.state.tool_clients
def get_staging_store(request: Request) -> StagingStore:
"""Return the StagingStore instance stored on app.state."""
return request.app.state.staging_store
def get_db_pool(request: Request):
"""Return the database connection pool stored on app.state."""
return request.app.state.db_pool
async def require_operator_token(
request: Request,
token: str | None = Depends(_operator_token_header),
) -> None:
"""Validate the X-Operator-Token header against the configured operator token.
If operator_token is not configured (empty string), authentication is skipped
and all requests are allowed. When configured, uses hmac.compare_digest
for constant-time comparison to prevent timing attacks.
Raises:
HTTPException: 401 if a token is configured but the provided token is
missing or does not match.
"""
configured = request.app.state.settings.operator_token.get_secret_value()
if not configured:
# No token configured — auth disabled
return
if not token or not hmac.compare_digest(token, configured):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid or missing operator token",
)

View File

@@ -0,0 +1,137 @@
"""API request/response Pydantic models for the HTTP layer.
All models are frozen (immutable) following the project immutability policy.
"""
from datetime import datetime
from typing import Literal
from pydantic import BaseModel, ConfigDict
# ---------------------------------------------------------------------------
# Webhook models
# ---------------------------------------------------------------------------
class WebhookResponse(BaseModel):
"""Response returned after a webhook event is accepted."""
model_config = ConfigDict(frozen=True)
thread_id: str
message: str
# ---------------------------------------------------------------------------
# Approval models
# ---------------------------------------------------------------------------
class ApprovalDecision(BaseModel):
"""Body for POST /approvals/{thread_id}."""
model_config = ConfigDict(frozen=True)
decision: Literal["merge", "cancel", "approve", "skip", "trigger"]
class ApprovalResponse(BaseModel):
"""Response returned after an approval decision is processed."""
model_config = ConfigDict(frozen=True)
thread_id: str
status: str
message: str
class PendingApproval(BaseModel):
"""A single interrupted graph thread awaiting operator input."""
model_config = ConfigDict(frozen=True)
thread_id: str
graph_name: str
interrupt_value: str
created_at: datetime
repo_name: str | None = None
pr_id: str | None = None
version: str | None = None
class PendingApprovalsResponse(BaseModel):
"""Response for GET /approvals/pending."""
model_config = ConfigDict(frozen=True)
items: list[PendingApproval]
count: int
# ---------------------------------------------------------------------------
# Health / status models
# ---------------------------------------------------------------------------
class HealthResponse(BaseModel):
"""Response for GET /status."""
model_config = ConfigDict(frozen=True)
status: Literal["ok", "degraded"]
version: str
uptime_seconds: float
# ---------------------------------------------------------------------------
# Release / staging models
# ---------------------------------------------------------------------------
class ReleaseVersionListResponse(BaseModel):
"""Response for GET /releases/{repo}."""
model_config = ConfigDict(frozen=True)
repo: str
versions: list[str]
class StagingResponse(BaseModel):
"""Response for GET /staging."""
model_config = ConfigDict(frozen=True)
repo: str
staging: dict | None
# ---------------------------------------------------------------------------
# Manual trigger models
# ---------------------------------------------------------------------------
class ManualTriggerResponse(BaseModel):
"""Response for manual trigger endpoints."""
model_config = ConfigDict(frozen=True)
thread_id: str
message: str
class ManualReleaseRequest(BaseModel):
"""Request body for POST /manual/release."""
model_config = ConfigDict(frozen=True)
repo: str
# ---------------------------------------------------------------------------
# Error model
# ---------------------------------------------------------------------------
class ErrorResponse(BaseModel):
"""Standard error response envelope."""
model_config = ConfigDict(frozen=True)
error: str
detail: str | None = None

View File

@@ -0,0 +1,264 @@
"""Slack interactive messages endpoint.
POST /slack/interactions — receives Slack button click payloads, verifies
the signing secret, extracts thread_id + decision, and resumes the
appropriate graph thread in the background.
Slack requires a 200 response within 3 seconds. The graph resume is
therefore scheduled as a background task.
"""
import asyncio
import hashlib
import hmac
import json
import logging
import time
import urllib.parse
from typing import Any
from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.responses import JSONResponse
from release_agent.api.dependencies import get_db_pool, get_graphs, get_settings, get_tool_clients
logger = logging.getLogger(__name__)
router = APIRouter()
# ---------------------------------------------------------------------------
# Pure helper: signature verification
# ---------------------------------------------------------------------------
_MAX_TIMESTAMP_AGE_SECONDS = 300 # 5 minutes — Slack's recommendation
_ALLOWED_DECISIONS = frozenset({"merge", "cancel", "approve", "skip", "trigger", "yes", "no"})
def _verify_slack_signature(
*,
signing_secret: str,
timestamp: str,
body: str,
signature: str,
current_time: float | None = None,
) -> bool:
"""Verify a Slack request signature using HMAC-SHA256.
Also rejects requests with timestamps older than 5 minutes
to prevent replay attacks.
Args:
signing_secret: The app's Slack signing secret.
timestamp: Value of the X-Slack-Request-Timestamp header.
body: Raw request body as a string.
signature: Value of the X-Slack-Signature header.
current_time: Current unix timestamp (injectable for testing).
Returns:
True if the signature is valid and fresh, False otherwise.
"""
if not signature or not signature.startswith("v0="):
return False
# Replay attack prevention
try:
request_ts = int(timestamp)
except ValueError:
return False
now = current_time if current_time is not None else time.time()
if abs(now - request_ts) > _MAX_TIMESTAMP_AGE_SECONDS:
return False
base_string = f"v0:{timestamp}:{body}"
computed_hash = hmac.new(
signing_secret.encode(),
base_string.encode(),
hashlib.sha256,
).hexdigest()
expected_sig = f"v0={computed_hash}"
return hmac.compare_digest(expected_sig, signature)
# ---------------------------------------------------------------------------
# Internal: parse button payload
# ---------------------------------------------------------------------------
def _parse_button_payload(form_body: str) -> dict[str, Any] | None:
"""Parse a URL-encoded Slack button action payload.
Args:
form_body: Raw URL-encoded form body from Slack.
Returns:
Dict with "thread_id" and "decision" keys, or None if parsing fails.
"""
try:
parsed = urllib.parse.parse_qs(form_body)
payload_json = parsed.get("payload", [None])[0]
if not payload_json:
return None
payload = json.loads(payload_json)
actions = payload.get("actions", [])
if not actions:
return None
action = actions[0]
value = action.get("value", "")
# value format: "{thread_id}:{decision}"
if ":" not in value:
return None
parts = value.split(":", 1)
thread_id = parts[0]
decision = parts[1]
user = payload.get("user") or {}
user_name = user.get("name", user.get("id", "unknown"))
return {
"thread_id": thread_id,
"decision": decision,
"user_name": user_name,
}
except (json.JSONDecodeError, KeyError, IndexError, ValueError) as exc:
logger.warning("Failed to parse Slack button payload: %s", exc)
return None
# ---------------------------------------------------------------------------
# Internal: background graph resume
# ---------------------------------------------------------------------------
async def _background_resume(
*,
thread_id: str,
decision: str,
graphs: dict,
tool_clients: Any,
db_pool: Any,
) -> None:
"""Fetch the graph for thread_id from DB and resume it."""
from release_agent.api.approvals import _get_thread_graph_name, _resume_graph
graph_name = await _get_thread_graph_name(db_pool, thread_id)
if graph_name is None:
logger.warning("slack_interactions: thread %s not found in DB", thread_id)
return
graph = graphs.get(graph_name)
if graph is None:
logger.warning(
"slack_interactions: unknown graph '%s' for thread %s",
graph_name,
thread_id,
)
return
try:
await _resume_graph(
graph=graph,
thread_id=thread_id,
decision=decision,
tool_clients=tool_clients,
db_pool=db_pool,
)
except Exception as exc:
logger.error(
"slack_interactions: failed to resume thread %s: %s",
thread_id,
exc,
)
# ---------------------------------------------------------------------------
# Endpoint: POST /slack/interactions
# ---------------------------------------------------------------------------
@router.post("/slack/interactions")
async def slack_interactions(
request: Request,
graphs: dict = Depends(get_graphs),
tool_clients: Any = Depends(get_tool_clients),
db_pool: Any = Depends(get_db_pool),
settings: Any = Depends(get_settings),
) -> JSONResponse:
"""Receive and process Slack interactive button clicks.
1. Verify the Slack signing secret.
2. Parse the payload to extract thread_id and decision.
3. Schedule the graph resume as a background task.
4. Return 200 immediately (Slack requires a fast response).
Returns:
200 {"ok": True} on success.
400 if required headers are missing.
403 if signature verification fails.
"""
raw_body = await request.body()
body_str = raw_body.decode("utf-8")
timestamp = request.headers.get("X-Slack-Request-Timestamp")
signature = request.headers.get("X-Slack-Signature")
if not timestamp:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Missing X-Slack-Request-Timestamp header",
)
if not signature:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Missing X-Slack-Signature header",
)
signing_secret = settings.slack_signing_secret.get_secret_value()
if not signing_secret:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Slack signing secret not configured",
)
valid = _verify_slack_signature(
signing_secret=signing_secret,
timestamp=timestamp,
body=body_str,
signature=signature,
)
if not valid:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Invalid Slack signature",
)
parsed = _parse_button_payload(body_str)
if not parsed:
logger.warning("slack_interactions: could not parse payload")
return JSONResponse({"ok": True})
thread_id = parsed["thread_id"]
decision = parsed["decision"]
if decision not in _ALLOWED_DECISIONS:
logger.warning("slack_interactions: unexpected decision '%s', ignoring", decision)
return JSONResponse({"ok": True})
# Schedule the graph resume as a background task (Slack needs 200 fast)
task = asyncio.create_task(
_background_resume(
thread_id=thread_id,
decision=decision,
graphs=graphs,
tool_clients=tool_clients,
db_pool=db_pool,
)
)
if hasattr(request.app.state, "background_tasks"):
request.app.state.background_tasks.add(task)
task.add_done_callback(request.app.state.background_tasks.discard)
return JSONResponse({"ok": True})

View File

@@ -0,0 +1,153 @@
"""Status, releases, staging, and manual trigger endpoints.
GET /status — health check
GET /releases/{repo} — list release versions for a repo
GET /staging — current staging release for a repo
POST /manual/pr/{pr_id} — manually trigger PR processing
POST /manual/release — manually trigger a release run
"""
import asyncio
import uuid
from datetime import datetime, timezone
from fastapi import APIRouter, Depends, Request, status
from fastapi.responses import JSONResponse
from release_agent.api.dependencies import (
get_db_pool,
get_graphs,
get_staging_store,
get_tool_clients,
require_operator_token,
)
from release_agent.api.models import (
HealthResponse,
ManualReleaseRequest,
ManualTriggerResponse,
ReleaseVersionListResponse,
StagingResponse,
)
router = APIRouter()
_VERSION = "0.1.0"
# ---------------------------------------------------------------------------
# GET /status
# ---------------------------------------------------------------------------
@router.get("/status")
async def get_status(request: Request) -> HealthResponse:
"""Return the health status of the agent service."""
started_at: datetime = request.app.state.started_at
uptime = (datetime.now(tz=timezone.utc) - started_at).total_seconds()
return HealthResponse(
status="ok",
version=_VERSION,
uptime_seconds=uptime,
)
# ---------------------------------------------------------------------------
# GET /releases/{repo}
# ---------------------------------------------------------------------------
@router.get("/releases/{repo}")
async def list_release_versions(
repo: str,
staging_store=Depends(get_staging_store),
) -> ReleaseVersionListResponse:
"""List all known release versions for the given repository."""
versions = await staging_store.list_versions(repo)
return ReleaseVersionListResponse(repo=repo, versions=versions)
# ---------------------------------------------------------------------------
# GET /staging
# ---------------------------------------------------------------------------
@router.get("/staging")
async def get_staging(
repo: str,
staging_store=Depends(get_staging_store),
) -> StagingResponse:
"""Return the current staging release for the given repository."""
staging_obj = await staging_store.load(repo)
staging_dict = staging_obj.model_dump(mode="json") if staging_obj is not None else None
return StagingResponse(repo=repo, staging=staging_dict)
# ---------------------------------------------------------------------------
# POST /manual/pr/{pr_id}
# ---------------------------------------------------------------------------
@router.post("/manual/pr/{pr_id}", status_code=status.HTTP_202_ACCEPTED)
async def manual_pr_trigger(
pr_id: str,
request: Request,
graphs=Depends(get_graphs),
tool_clients=Depends(get_tool_clients),
db_pool=Depends(get_db_pool),
_auth: None = Depends(require_operator_token),
) -> ManualTriggerResponse:
"""Manually trigger PR processing for the given PR ID."""
from release_agent.api.webhooks import _run_graph
thread_id = str(uuid.uuid4())
initial_state = {"pr_id": pr_id}
task = asyncio.create_task(
_run_graph(
graph=graphs["pr_completed"],
initial_state=initial_state,
thread_id=thread_id,
tool_clients=tool_clients,
db_pool=db_pool,
)
)
request.app.state.background_tasks.add(task)
task.add_done_callback(request.app.state.background_tasks.discard)
return ManualTriggerResponse(
thread_id=thread_id,
message=f"PR {pr_id} processing scheduled as thread {thread_id}",
)
# ---------------------------------------------------------------------------
# POST /manual/release
# ---------------------------------------------------------------------------
@router.post("/manual/release", status_code=status.HTTP_202_ACCEPTED)
async def manual_release_trigger(
body: ManualReleaseRequest,
request: Request,
graphs=Depends(get_graphs),
tool_clients=Depends(get_tool_clients),
db_pool=Depends(get_db_pool),
_auth: None = Depends(require_operator_token),
) -> ManualTriggerResponse:
"""Manually trigger a release run for the given repository."""
from release_agent.api.webhooks import _run_graph
thread_id = str(uuid.uuid4())
initial_state = {"repo_name": body.repo}
task = asyncio.create_task(
_run_graph(
graph=graphs["release"],
initial_state=initial_state,
thread_id=thread_id,
tool_clients=tool_clients,
db_pool=db_pool,
)
)
request.app.state.background_tasks.add(task)
task.add_done_callback(request.app.state.background_tasks.discard)
return ManualTriggerResponse(
thread_id=thread_id,
message=f"Release for repo '{body.repo}' scheduled as thread {thread_id}",
)

View File

@@ -0,0 +1,195 @@
"""Webhook endpoint for Azure DevOps events.
POST /webhooks/azdo — validates secret, parses payload, filters events,
schedules graph execution as a background task, returns 202.
"""
import asyncio
import hmac
import logging
import uuid
from typing import Annotated
logger = logging.getLogger(__name__)
from fastapi import APIRouter, Depends, Header, HTTPException, Request, status
from fastapi.responses import JSONResponse
from pydantic import ValidationError
from release_agent.api.dependencies import get_db_pool, get_graphs, get_settings, get_tool_clients
from release_agent.api.models import WebhookResponse
from release_agent.models.webhook import WebhookPayload
router = APIRouter()
# ---------------------------------------------------------------------------
# Secret validation helper (pure function — easily unit-tested)
# ---------------------------------------------------------------------------
def _validate_webhook_secret(header: str | None, expected: str) -> bool:
"""Return True if header matches expected using constant-time comparison.
Returns False if:
- expected is empty (auth misconfigured, reject all)
- header is None (no header sent)
- header does not match expected
"""
if not expected:
return False
if header is None:
return False
return hmac.compare_digest(header, expected)
# ---------------------------------------------------------------------------
# POST /webhooks/azdo
# ---------------------------------------------------------------------------
@router.post("/webhooks/azdo")
async def receive_azdo_webhook(
request: Request,
x_webhook_secret: Annotated[str | None, Header()] = None,
settings=Depends(get_settings),
graphs=Depends(get_graphs),
tool_clients=Depends(get_tool_clients),
db_pool=Depends(get_db_pool),
):
"""Receive an Azure DevOps webhook event.
Validates the X-Webhook-Secret header, parses the payload,
filters to completed PR events only, then schedules graph execution.
Returns 202 with thread_id for accepted events.
Returns 200 with "event ignored" for non-completed PR events.
Returns 401 for missing or invalid secret.
Returns 422 for invalid payload structure.
"""
expected_secret = settings.webhook_secret.get_secret_value()
if not _validate_webhook_secret(x_webhook_secret, expected_secret):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid or missing webhook secret",
)
body = await request.json()
try:
payload = WebhookPayload.model_validate(body)
except (ValidationError, Exception) as exc:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=f"Invalid payload: {exc}",
) from exc
# Only process completed PRs
if payload.resource.status != "completed":
return JSONResponse(
status_code=status.HTTP_200_OK,
content={"message": "event ignored: PR not completed", "status": payload.resource.status},
)
# Generate a unique thread ID for this execution
thread_id = str(uuid.uuid4())
# Build initial state from webhook payload
initial_state = {
"webhook_payload": body,
"pr_id": str(payload.resource.pull_request_id),
"repo_name": payload.resource.repository.name,
}
graph = graphs["pr_completed"]
# Schedule as a background task (non-blocking)
settings = request.app.state.settings
task = asyncio.create_task(
_run_graph(
graph=graph,
initial_state=initial_state,
thread_id=thread_id,
tool_clients=tool_clients,
db_pool=db_pool,
repos_base_dir=settings.repos_base_dir,
graph_name="pr_completed",
default_jira_project=settings.default_jira_project,
)
)
request.app.state.background_tasks.add(task)
task.add_done_callback(request.app.state.background_tasks.discard)
return JSONResponse(
status_code=status.HTTP_202_ACCEPTED,
content=WebhookResponse(
thread_id=thread_id,
message=f"Webhook accepted, scheduled as thread {thread_id}",
).model_dump(),
)
# ---------------------------------------------------------------------------
# Internal graph runner
# ---------------------------------------------------------------------------
async def _run_graph(
*,
graph,
initial_state: dict,
thread_id: str,
tool_clients,
db_pool,
repos_base_dir: str = "",
graph_name: str = "",
default_jira_project: str = "ALLPOST",
) -> None:
"""Invoke the graph and persist thread status to the database."""
repo_name = initial_state.get("repo_name", "")
pr_id = initial_state.get("pr_id", "")
config = {
"configurable": {
"thread_id": thread_id,
"clients": tool_clients,
"repos_base_dir": repos_base_dir,
"default_jira_project": default_jira_project,
}
}
try:
await _upsert_thread(
db_pool, thread_id=thread_id, thread_status="running",
state=initial_state, graph_name=graph_name,
repo_name=repo_name, pr_id=pr_id,
)
result = await graph.ainvoke(initial_state, config=config)
await _upsert_thread(db_pool, thread_id=thread_id, thread_status="completed", state=result or {})
except Exception as exc:
logger.exception("Graph execution failed for thread %s", thread_id)
error_state = {**initial_state, "errors": [str(exc)]}
await _upsert_thread(db_pool, thread_id=thread_id, thread_status="failed", state=error_state)
async def _upsert_thread(
pool,
*,
thread_id: str,
thread_status: str,
state: dict,
graph_name: str = "",
repo_name: str = "",
pr_id: str = "",
) -> None:
"""Insert or update a thread record in agent_threads."""
import json
sql = """
INSERT INTO agent_threads (thread_id, graph_name, repo_name, pr_id, status, state, updated_at)
VALUES (%s, %s, %s, %s, %s, %s, NOW())
ON CONFLICT (thread_id) DO UPDATE
SET status = EXCLUDED.status,
state = EXCLUDED.state,
updated_at = EXCLUDED.updated_at
"""
async with pool.connection() as conn:
async with conn.cursor() as cur:
await cur.execute(sql, (
thread_id, graph_name, repo_name, pr_id,
thread_status, json.dumps(state),
))

View File

@@ -0,0 +1,46 @@
"""Branch name parsing utilities for extracting Jira ticket IDs.
Pure functions only - no side effects, no mutation.
"""
import re
# Matches Jira-style ticket IDs: one or more uppercase letters/digits (starting
# with a letter) followed by a dash and one or more digits.
# Examples: ALLPOST-4229, BILL-42, MY-1, AB2-100
_TICKET_PATTERN = re.compile(r"(?<![A-Z0-9])([A-Z][A-Z0-9]+-\d+)(?![A-Z0-9-])")
# The refs/heads/ prefix that Azure DevOps sometimes includes in branch names.
_REFS_HEADS_PREFIX = "refs/heads/"
def strip_refs_prefix(branch: str) -> str:
"""Remove the 'refs/heads/' prefix from a branch name if present.
Returns the branch name unchanged if the prefix is not present.
Does not strip 'refs/tags/' or any other prefix.
"""
if branch.startswith(_REFS_HEADS_PREFIX):
return branch[len(_REFS_HEADS_PREFIX):]
return branch
def parse_branch(branch: str) -> tuple[str | None, bool]:
"""Parse a branch name and extract a Jira ticket ID if present.
Handles the 'refs/heads/' prefix automatically.
Returns:
A tuple of (ticket_id, has_ticket) where ticket_id is the Jira ID
string (e.g. "ALLPOST-4229") or None, and has_ticket is a bool.
"""
if not branch:
return None, False
normalized = strip_refs_prefix(branch)
match = _TICKET_PATTERN.search(normalized)
if match:
return match.group(1), True
return None, False

110
src/release_agent/config.py Normal file
View File

@@ -0,0 +1,110 @@
"""Application configuration using pydantic-settings.
All secrets are stored as SecretStr to prevent accidental logging.
"""
from pydantic import Field, SecretStr, computed_field
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
"""Application settings loaded from environment variables.
Required variables:
AZDO_ORGANIZATION - Azure DevOps organization name
AZDO_PROJECT - Azure DevOps project name
AZDO_PAT - Azure DevOps personal access token
ANTHROPIC_API_KEY - Anthropic API key for Claude
POSTGRES_DSN - PostgreSQL connection string
JIRA_EMAIL - Jira account email address
JIRA_API_TOKEN - Jira API token
SLACK_WEBHOOK_URL - Slack incoming webhook URL
Optional variables:
PORT - HTTP server port (default: 8000, range: 1-65535)
JIRA_BASE_URL - Jira base URL (default: https://billolife.atlassian.net)
CLAUDE_REVIEW_MODEL - Anthropic model for PR review (default: claude-sonnet-4-20250514)
WATCHED_REPOS - Comma-separated repo names to poll for PRs
PR_POLL_INTERVAL_SECONDS - Seconds between PR polling runs (default: 300)
PR_POLL_TARGET_BRANCH - Target branch ref to filter PRs (default: refs/heads/develop)
PR_POLL_ENABLED - Enable PR polling background task (default: False)
DEFAULT_JIRA_PROJECT - Default Jira project key for auto-created tickets (default: ALLPOST)
AUTO_CREATE_TICKET_ENABLED - Auto-create Jira ticket when PR has no ticket (default: True)
"""
model_config = SettingsConfigDict(
env_file=".env",
env_file_encoding="utf-8",
case_sensitive=False,
extra="ignore",
)
azdo_organization: str
azdo_project: str
azdo_pat: SecretStr
anthropic_api_key: SecretStr = Field(default=SecretStr("")) # Optional: only needed if using API directly
postgres_dsn: SecretStr
port: int = Field(default=8000, ge=1, le=65535)
# Jira settings
jira_base_url: str = "https://billolife.atlassian.net"
jira_email: str
jira_api_token: SecretStr
# Slack settings
slack_webhook_url: SecretStr = Field(default=SecretStr(""))
slack_bot_token: SecretStr = Field(default=SecretStr(""))
slack_signing_secret: SecretStr = Field(default=SecretStr(""))
slack_channel_id: str = ""
# CI polling settings
ci_poll_interval_seconds: int = 30
ci_poll_max_wait_seconds: int = 1800
# Claude settings
claude_review_model: str = "claude-sonnet-4-20250514"
# Local repo settings
repos_base_dir: str = "" # Base directory containing Billo repos (e.g., /c/Users/yaoji/git/Billo)
# Webhook settings
webhook_secret: SecretStr
# Operator API settings
operator_token: SecretStr = Field(default=SecretStr(""))
# PR polling settings
watched_repos: str = ""
pr_poll_interval_seconds: int = 300
pr_poll_target_branch: str = "refs/heads/develop"
pr_poll_enabled: bool = False
# Auto-create Jira ticket settings
default_jira_project: str = "ALLPOST"
auto_create_ticket_enabled: bool = True
@computed_field # type: ignore[misc]
@property
def watched_repos_list(self) -> list[str]:
"""Return watched_repos as a list, splitting on commas and stripping whitespace."""
if not self.watched_repos:
return []
return [r.strip() for r in self.watched_repos.split(",") if r.strip()]
@computed_field # type: ignore[misc]
@property
def azdo_base_url(self) -> str:
"""Base URL for the Azure DevOps organization."""
return f"https://dev.azure.com/{self.azdo_organization}"
@computed_field # type: ignore[misc]
@property
def azdo_api_url(self) -> str:
"""Base API URL for the Azure DevOps project."""
return f"https://dev.azure.com/{self.azdo_organization}/{self.azdo_project}/_apis"
@computed_field # type: ignore[misc]
@property
def azdo_vsrm_api_url(self) -> str:
"""Base API URL for the Azure DevOps VSRM release management."""
return f"https://vsrm.dev.azure.com/{self.azdo_organization}/{self.azdo_project}/_apis"

View File

@@ -0,0 +1,59 @@
"""Custom exception hierarchy for the release agent.
All exceptions inherit from ReleaseAgentError so callers can catch
the entire hierarchy with a single except clause when needed.
"""
class ReleaseAgentError(Exception):
"""Base exception for all release agent errors."""
class ServiceError(ReleaseAgentError):
"""An HTTP error response from an external service.
Attributes:
service: Name of the service that returned the error (e.g. "jira").
status_code: HTTP status code from the response.
detail: Optional human-readable detail message.
"""
def __init__(self, *, service: str, status_code: int, detail: str | None) -> None:
self.service = service
self.status_code = status_code
self.detail = detail
super().__init__(f"[{service}] HTTP {status_code}: {detail}")
class AuthenticationError(ServiceError):
"""Authentication or authorisation failure (401 / 403)."""
def __init__(self, *, service: str, status_code: int = 401, detail: str | None = None) -> None:
super().__init__(service=service, status_code=status_code, detail=detail)
class NotFoundError(ServiceError):
"""Requested resource was not found (404)."""
def __init__(self, *, service: str, detail: str | None = None) -> None:
super().__init__(service=service, status_code=404, detail=detail)
class RateLimitError(ServiceError):
"""Rate limit exceeded (429).
Attributes:
retry_after: Seconds to wait before retrying, or None if not provided.
"""
def __init__(self, *, service: str, retry_after: int | None) -> None:
self.retry_after = retry_after
detail = f"retry after {retry_after}s" if retry_after is not None else None
super().__init__(service=service, status_code=429, detail=detail)
class ServiceUnavailableError(ServiceError):
"""Service temporarily unavailable (503)."""
def __init__(self, *, service: str, detail: str | None = None) -> None:
super().__init__(service=service, status_code=503, detail=detail)

View File

View File

@@ -0,0 +1,180 @@
"""CI/CD graph nodes for triggering, polling, and notifying CI builds.
Each node is an async function (state, config) -> dict.
Nodes never mutate state — they return new dicts.
External clients are accessed via config["configurable"]["clients"].
"""
import logging
from typing import Any
from release_agent.exceptions import ReleaseAgentError
from release_agent.graph.dependencies import ToolClients
from release_agent.graph.polling import poll_until
from release_agent.models.build import BuildStatus
from release_agent.tools.slack import _build_ci_status_blocks
logger = logging.getLogger(__name__)
_DEFAULT_BRANCH = "refs/heads/main"
_CI_POLL_INTERVAL = 30
_CI_POLL_MAX_WAIT = 1800
def _get_clients(config: dict) -> ToolClients:
return config["configurable"]["clients"]
def _get_settings(config: dict):
return config["configurable"].get("settings")
# ---------------------------------------------------------------------------
# Node: trigger_ci_build
# ---------------------------------------------------------------------------
async def trigger_ci_build(state: dict[str, Any], config: dict) -> dict:
"""Trigger the CI build pipeline for the repository.
Finds the first available pipeline for the repo and triggers it on
the appropriate branch (main for post-merge, release branch otherwise).
Args:
state: Graph state dict.
config: LangGraph config dict with "configurable" key.
Returns:
Dict with ci_build_id on success, or errors on failure.
"""
clients = _get_clients(config)
repo_name = state.get("repo_name", "")
version = state.get("version", "")
# Determine branch: main if version present (release), develop otherwise (PR merge)
if version:
branch = "refs/heads/main"
else:
branch = "refs/heads/develop"
try:
pipelines = await clients.azdo.list_build_pipelines(repo=repo_name)
except ReleaseAgentError as exc:
return {"errors": [f"trigger_ci_build: failed to list pipelines: {exc}"]}
if not pipelines:
return {"errors": [f"trigger_ci_build: no pipelines found for repo '{repo_name}'"]}
pipeline = pipelines[0]
try:
result = await clients.azdo.trigger_pipeline(
pipeline_id=pipeline.id,
branch=branch,
)
raw_id = result.get("id")
if not isinstance(raw_id, int):
return {"errors": [f"trigger_ci_build: unexpected build id type: {raw_id!r}"]}
return {
"ci_build_id": raw_id,
"messages": [f"CI build triggered: pipeline {pipeline.id}, build {raw_id}"],
}
except ReleaseAgentError as exc:
return {"errors": [f"trigger_ci_build: failed to trigger pipeline {pipeline.id}: {exc}"]}
# ---------------------------------------------------------------------------
# Node: poll_ci_build
# ---------------------------------------------------------------------------
async def poll_ci_build(state: dict[str, Any], config: dict) -> dict:
"""Poll the CI build until completion or timeout.
Uses the polling utility with configurable interval and max wait.
On timeout, appends an error and returns the last known status.
Args:
state: Graph state dict containing ci_build_id.
config: LangGraph config dict.
Returns:
Dict with ci_build_status, ci_build_result, and ci_build_url on
completion, or errors on timeout/failure.
"""
clients = _get_clients(config)
build_id = state.get("ci_build_id")
if not build_id:
return {"errors": ["poll_ci_build: ci_build_id not set in state"]}
settings = _get_settings(config)
interval = getattr(settings, "ci_poll_interval_seconds", _CI_POLL_INTERVAL) if settings else _CI_POLL_INTERVAL
max_wait = getattr(settings, "ci_poll_max_wait_seconds", _CI_POLL_MAX_WAIT) if settings else _CI_POLL_MAX_WAIT
async def poll_fn() -> BuildStatus:
return await clients.azdo.get_build_status(build_id=build_id)
def is_done(bs: BuildStatus) -> bool:
return bs is not None and bs.status == "completed"
last_status, completed = await poll_until(
poll_fn=poll_fn,
is_done=is_done,
interval_seconds=interval,
max_wait_seconds=max_wait,
)
if last_status is None:
return {"errors": [f"poll_ci_build: build {build_id} polling failed (no status returned)"]}
result: dict = {
"ci_build_status": last_status.status,
"ci_build_result": last_status.result,
"ci_build_url": last_status.build_url,
}
if not completed:
result["errors"] = [
f"poll_ci_build: build {build_id} did not complete within timeout "
f"(last status: {last_status.status})"
]
return result
# ---------------------------------------------------------------------------
# Node: notify_ci_result
# ---------------------------------------------------------------------------
async def notify_ci_result(state: dict[str, Any], config: dict) -> dict:
"""Send a Slack notification with the CI build result.
Non-critical: errors are appended rather than re-raised.
Args:
state: Graph state dict.
config: LangGraph config dict.
Returns:
Dict with messages on success, or errors on failure.
"""
clients = _get_clients(config)
repo_name = state.get("repo_name", "unknown")
build_result = state.get("ci_build_result", "unknown")
build_url = state.get("ci_build_url")
branch = state.get("version", "main")
blocks = _build_ci_status_blocks(
repo=repo_name,
branch=branch,
status=build_result,
build_url=build_url,
)
status_label = "passed" if build_result == "succeeded" else "failed"
text = f"CI build {status_label} for {repo_name}"
try:
await clients.slack.send_notification(text=text, blocks=blocks)
return {"messages": [text]}
except ReleaseAgentError as exc:
return {"errors": [f"notify_ci_result: {exc}"]}
except Exception as exc:
return {"errors": [f"notify_ci_result: unexpected error: {exc}"]}

View File

@@ -0,0 +1,171 @@
"""Graph dependency types: ToolClients dataclass and StagingStore protocol.
ToolClients is a frozen dataclass injected via config["configurable"]["clients"].
StagingStore is a Protocol with file-based implementation JsonFileStagingStore.
All StagingStore methods are async to support both file-based and DB backends.
"""
import json
from dataclasses import dataclass
from datetime import date
from pathlib import Path
from typing import Any, Protocol
from release_agent.models.release import ArchivedRelease, StagingRelease
# ---------------------------------------------------------------------------
# ToolClients
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class ToolClients:
"""Frozen container for all external service clients.
Injected into graph nodes via config["configurable"]["clients"].
Attributes:
azdo: AzDoClient instance.
jira: JiraClient instance.
slack: SlackClient instance.
reviewer: ClaudeReviewer instance.
"""
azdo: Any
jira: Any
slack: Any
reviewer: Any
# ---------------------------------------------------------------------------
# StagingStore Protocol
# ---------------------------------------------------------------------------
class StagingStore(Protocol):
"""Protocol for persistent staging release storage.
Implementations handle reading and writing StagingRelease and
ArchivedRelease objects. All methods are async to support both
file-based and PostgreSQL backends.
"""
async def load(self, repo: str) -> StagingRelease | None:
"""Load the current staging release for the given repository.
Returns None if no staging release exists.
"""
... # pragma: no cover
async def save(self, release: StagingRelease) -> None:
"""Persist a staging release, overwriting any existing entry."""
... # pragma: no cover
async def archive(self, release: StagingRelease, released_at: date) -> None:
"""Archive the staging release with the given release date.
Creates an archive entry and removes the staging entry.
"""
... # pragma: no cover
async def list_versions(self, repo: str) -> list[str]:
"""Return all version strings known for the given repository.
Includes both current staging version (if any) and all archived versions.
"""
... # pragma: no cover
# ---------------------------------------------------------------------------
# JsonFileStagingStore
# ---------------------------------------------------------------------------
class JsonFileStagingStore:
"""File-based StagingStore implementation using JSON files.
Files are stored in a single directory:
- Staging: <directory>/<repo>.json
- Archive: <directory>/<repo>_<version>_<date>.json
Args:
directory: Path to the directory for storing JSON files.
Created automatically if it does not exist.
"""
def __init__(self, *, directory: Path) -> None:
self._dir = directory
self._dir.mkdir(parents=True, exist_ok=True)
# ------------------------------------------------------------------
# Public interface (async)
# ------------------------------------------------------------------
async def load(self, repo: str) -> StagingRelease | None:
"""Load the current staging release for the given repository."""
path = self._staging_path(repo)
if not path.exists():
return None
data = json.loads(path.read_text(encoding="utf-8"))
return StagingRelease.model_validate(data)
async def save(self, release: StagingRelease) -> None:
"""Persist a staging release to disk, overwriting any existing file."""
path = self._staging_path(release.repo)
path.write_text(
release.model_dump_json(),
encoding="utf-8",
)
async def archive(self, release: StagingRelease, released_at: date) -> None:
"""Archive the staging release and remove the staging file."""
archived = ArchivedRelease(
version=release.version,
repo=release.repo,
started_at=release.started_at,
tickets=list(release.tickets),
released_at=released_at,
)
archive_path = self._archive_path(release.repo, release.version, released_at)
archive_path.write_text(
archived.model_dump_json(),
encoding="utf-8",
)
# Remove staging file if it exists
staging_path = self._staging_path(release.repo)
if staging_path.exists():
staging_path.unlink()
async def list_versions(self, repo: str) -> list[str]:
"""Return all version strings known for the given repository."""
versions: list[str] = []
# Check current staging file
staging = await self.load(repo)
if staging is not None:
versions.append(staging.version)
# Check archived files
prefix = f"{repo}_"
for path in self._dir.iterdir():
if path.name.startswith(prefix) and path.name.endswith(".json"):
try:
data = json.loads(path.read_text(encoding="utf-8"))
v = data.get("version")
if v and v not in versions:
versions.append(v)
except (json.JSONDecodeError, KeyError):
continue
return versions
# ------------------------------------------------------------------
# Private helpers
# ------------------------------------------------------------------
def _staging_path(self, repo: str) -> Path:
"""Return the path to the staging file for the given repo."""
return self._dir / f"{repo}.json"
def _archive_path(self, repo: str, version: str, released_at: date) -> Path:
"""Return the path to an archive file."""
date_str = released_at.isoformat()
return self._dir / f"{repo}_{version}_{date_str}.json"

View File

@@ -0,0 +1,44 @@
"""Full Cycle graph: composes pr_completed and release as subgraph nodes.
The full cycle graph handles the complete lifecycle from a PR merge webhook
through to a released version. After the PR completed subgraph, a conditional
edge routes to the release subgraph if continue_to_release is True.
"""
from langgraph.graph import END, START, StateGraph
from release_agent.graph.pr_completed import build_pr_completed_graph
from release_agent.graph.release import build_release_graph
from release_agent.graph.routing import should_continue_to_release
from release_agent.state import ReleaseState
def build_full_cycle_graph():
"""Assemble and compile the Full Cycle StateGraph.
The graph consists of:
- pr_completed subgraph: handles webhook parsing, code review, and PR merge
- A conditional edge routing to the release subgraph or END
- release subgraph: handles staging confirmation, release PR, and pipeline triggers
Returns the compiled graph ready for invocation.
"""
pr_completed = build_pr_completed_graph()
release = build_release_graph()
graph = StateGraph(ReleaseState)
graph.add_node("pr_completed", pr_completed)
graph.add_node("release", release)
graph.add_edge(START, "pr_completed")
graph.add_conditional_edges(
"pr_completed",
should_continue_to_release,
{"yes": "release", "no": END},
)
graph.add_edge("release", END)
return graph.compile()

View File

@@ -0,0 +1,91 @@
"""Reusable async polling utility for CI/CD status checks.
poll_until runs a poll_fn repeatedly until is_done returns True,
a timeout is reached, or too many consecutive failures occur.
"""
import asyncio
import logging
from typing import Any, Callable, TypeVar
logger = logging.getLogger(__name__)
T = TypeVar("T")
_MAX_CONSECUTIVE_FAILURES = 3
async def poll_until(
*,
poll_fn: Callable[[], Any],
is_done: Callable[[Any], bool],
interval_seconds: float = 30,
max_wait_seconds: float = 1800,
sleep_fn: Callable[[float], Any] | None = None,
) -> tuple[Any, bool]:
"""Poll poll_fn until is_done returns True, timeout, or too many failures.
Args:
poll_fn: Async callable that fetches the latest status.
is_done: Pure function that returns True when polling should stop.
interval_seconds: Seconds to wait between polls (default: 30).
max_wait_seconds: Maximum total seconds before giving up (default: 1800).
sleep_fn: Async sleep function to inject (default: asyncio.sleep).
Inject a no-op in tests to avoid real waits.
Returns:
A tuple (last_result, completed_within_timeout) where:
- last_result is the last value returned by poll_fn (or None if every
call raised an exception).
- completed_within_timeout is True if is_done returned True before the
timeout and before consecutive failure abort.
"""
if sleep_fn is None:
sleep_fn = asyncio.sleep
elapsed: float = 0.0
consecutive_failures = 0
last_result: Any = None
while True:
try:
last_result = await poll_fn()
consecutive_failures = 0
except Exception as exc:
consecutive_failures += 1
logger.warning(
"poll_until: poll_fn raised exception (consecutive=%d): %s",
consecutive_failures,
exc,
)
if consecutive_failures >= _MAX_CONSECUTIVE_FAILURES:
logger.error(
"poll_until: aborting after %d consecutive failures",
_MAX_CONSECUTIVE_FAILURES,
)
return None, False
# Sleep before retry
await sleep_fn(interval_seconds)
elapsed += interval_seconds
if elapsed >= max_wait_seconds:
return last_result, False
continue
if is_done(last_result):
return last_result, True
# Check if we've already used the full budget
if elapsed >= max_wait_seconds:
return last_result, False
# Sleep and advance elapsed time
remaining = max_wait_seconds - elapsed
sleep_time = min(interval_seconds, remaining)
if sleep_time <= 0:
return last_result, False
await sleep_fn(sleep_time)
elapsed += sleep_time
if elapsed >= max_wait_seconds:
return last_result, False

View File

@@ -0,0 +1,148 @@
"""PostgreSQL-backed StagingStore implementation.
Uses psycopg async connection pool. Tables:
- staging_releases: current in-progress releases per repo
- archived_releases: completed releases (immutable history)
"""
import json
from datetime import date
from release_agent.models.release import ArchivedRelease, StagingRelease
from release_agent.models.ticket import TicketEntry
_INSERT_STAGING_SQL = """
INSERT INTO staging_releases (repo, version, started_at, tickets)
VALUES (%s, %s, %s, %s)
ON CONFLICT (repo) DO UPDATE
SET version = EXCLUDED.version,
started_at = EXCLUDED.started_at,
tickets = EXCLUDED.tickets,
updated_at = NOW()
"""
_SELECT_STAGING_SQL = """
SELECT repo, version, started_at, tickets
FROM staging_releases
WHERE repo = %s
"""
_INSERT_ARCHIVED_SQL = """
INSERT INTO archived_releases (repo, version, started_at, tickets, released_at)
VALUES (%s, %s, %s, %s, %s)
ON CONFLICT (repo, version) DO NOTHING
"""
_DELETE_STAGING_SQL = """
DELETE FROM staging_releases WHERE repo = %s
"""
_SELECT_ARCHIVED_SQL = """
SELECT repo, version, started_at, tickets, released_at
FROM archived_releases
WHERE repo = %s
"""
class PostgresStagingStore:
"""Async PostgreSQL-backed StagingStore.
Args:
pool: A psycopg_pool.AsyncConnectionPool instance (or compatible fake).
"""
def __init__(self, *, pool) -> None:
self._pool = pool
# ------------------------------------------------------------------
# Public async interface
# ------------------------------------------------------------------
async def load(self, repo: str) -> StagingRelease | None:
"""Load the current staging release for the given repository."""
async with self._pool.connection() as conn:
async with conn.cursor() as cur:
await cur.execute(_SELECT_STAGING_SQL, (repo,))
row = await cur.fetchone()
if row is None:
return None
return self._row_to_staging(row)
async def save(self, release: StagingRelease) -> None:
"""Upsert the staging release into the database."""
tickets_json = json.dumps(
[t.model_dump(mode="json") for t in release.tickets]
)
async with self._pool.connection() as conn:
async with conn.cursor() as cur:
await cur.execute(
_INSERT_STAGING_SQL,
(
release.repo,
release.version,
release.started_at.isoformat(),
tickets_json,
),
)
async def archive(self, release: StagingRelease, released_at: date) -> None:
"""Archive the staging release and delete it from staging_releases."""
tickets_json = json.dumps(
[t.model_dump(mode="json") for t in release.tickets]
)
async with self._pool.connection() as conn:
async with conn.transaction():
async with conn.cursor() as cur:
await cur.execute(
_INSERT_ARCHIVED_SQL,
(
release.repo,
release.version,
release.started_at.isoformat(),
tickets_json,
released_at.isoformat(),
),
)
await cur.execute(_DELETE_STAGING_SQL, (release.repo,))
async def list_versions(self, repo: str) -> list[str]:
"""Return all version strings for the given repository."""
versions: list[str] = []
# Check current staging
staging = await self.load(repo)
if staging is not None and staging.version not in versions:
versions.append(staging.version)
# Check archived
async with self._pool.connection() as conn:
async with conn.cursor() as cur:
await cur.execute(_SELECT_ARCHIVED_SQL, (repo,))
rows = await cur.fetchall()
for row in rows:
version = row[1] # version is column index 1
if version and version not in versions:
versions.append(version)
return versions
# ------------------------------------------------------------------
# Private helpers
# ------------------------------------------------------------------
def _row_to_staging(self, row: tuple) -> StagingRelease:
"""Convert a DB row (repo, version, started_at, tickets) to StagingRelease."""
repo, version, started_at_str, tickets_raw = row[0], row[1], row[2], row[3]
tickets_data = json.loads(tickets_raw) if isinstance(tickets_raw, str) else tickets_raw
tickets = [TicketEntry.model_validate(t) for t in (tickets_data or [])]
return StagingRelease(
repo=repo,
version=version,
started_at=date.fromisoformat(str(started_at_str)),
tickets=tickets,
)

View File

@@ -0,0 +1,573 @@
"""Node functions for the PR Completed subgraph.
Each node is an async function (state, config) -> dict.
Nodes never mutate state — they return new dicts with updated fields.
External clients are accessed via config["configurable"]["clients"].
"""
from datetime import date
from typing import Any
from langgraph.graph import END, START, StateGraph
from langgraph.types import interrupt
from release_agent.branch_parser import parse_branch
from release_agent.exceptions import ReleaseAgentError
from release_agent.graph.ci_nodes import notify_ci_result, poll_ci_build, trigger_ci_build
from release_agent.graph.dependencies import JsonFileStagingStore, ToolClients
from release_agent.graph.routing import has_ticket, is_pr_already_merged, is_review_approved, route_after_fetch
from release_agent.models.release import StagingRelease
from release_agent.models.review import ReviewResult
from release_agent.models.ticket import TicketEntry
from release_agent.models.webhook import WebhookPayload
from release_agent.state import ReleaseState
from release_agent.versioning import calculate_next_version
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _get_clients(config: dict) -> ToolClients:
return config["configurable"]["clients"]
def _get_staging_store(config: dict):
return config["configurable"].get("staging_store")
# ---------------------------------------------------------------------------
# Node: parse_webhook
# ---------------------------------------------------------------------------
async def parse_webhook(state: dict[str, Any], config: dict) -> dict:
"""Parse the webhook_payload field and extract PR info and ticket ID.
Returns a dict with pr_info, ticket_id, has_ticket, repo_name, pr_id.
On validation error, appends to errors and returns partial state.
"""
payload_raw = state.get("webhook_payload", {})
try:
payload = WebhookPayload.model_validate(payload_raw)
resource = payload.resource
pr_id = str(resource.pull_request_id)
repo_name = resource.repository.name
branch = resource.source_ref_name
ticket_id, ticket_found = parse_branch(branch)
pr_info = {
"pr_id": pr_id,
"repo_name": repo_name,
"branch": branch,
"pr_title": resource.title,
"pr_status": resource.status,
"pr_url": str(resource.repository.web_url) + f"/pullrequest/{pr_id}",
"ticket_id": ticket_id,
"has_ticket": ticket_found,
}
return {
"pr_info": pr_info,
"pr_id": pr_id,
"repo_name": repo_name,
"ticket_id": ticket_id,
"has_ticket": ticket_found,
}
except Exception as exc:
return {"errors": [f"parse_webhook failed: {exc}"]}
# ---------------------------------------------------------------------------
# Node: fetch_pr_details
# ---------------------------------------------------------------------------
async def fetch_pr_details(state: dict[str, Any], config: dict) -> dict:
"""Fetch full PR details from AzDo and check if already merged.
Sets pr_already_merged, pr_diff, and last_merge_source_commit.
On ReleaseAgentError, appends to errors (non-critical).
"""
clients = _get_clients(config)
pr_id_raw = state.get("pr_id") or (state.get("pr_info") or {}).get("pr_id", "")
try:
pr_id = int(pr_id_raw)
pr = await clients.azdo.get_pr(pr_id)
already_merged = pr.pr_status == "completed"
diff = ""
if not already_merged:
diff = await clients.azdo.get_pr_diff(pr_id)
# last_merge_source_commit is not directly on PRInfo; pass None if unavailable
return {
"pr_already_merged": already_merged,
"pr_diff": diff,
"last_merge_source_commit": None,
}
except ReleaseAgentError as exc:
return {"errors": [f"fetch_pr_details failed: {exc}"]}
# ---------------------------------------------------------------------------
# Node: auto_create_ticket
# ---------------------------------------------------------------------------
async def auto_create_ticket(state: dict[str, Any], config: dict) -> dict:
"""Auto-create a Jira ticket for a PR that has no existing ticket.
Uses ClaudeReviewer to generate ticket content from the PR diff,
then creates a new Jira Story in the configured default project.
Sets ticket_id and has_ticket=True on success.
On error, appends to errors (non-critical).
"""
clients = _get_clients(config)
pr_info = state.get("pr_info") or {}
pr_title = pr_info.get("pr_title", "")
repo_name = pr_info.get("repo_name", "")
pr_diff = state.get("pr_diff", "")
default_jira_project = config.get("configurable", {}).get("default_jira_project", "ALLPOST")
repos_base_dir = config.get("configurable", {}).get("repos_base_dir", "")
cwd = None
if repos_base_dir and repo_name:
from pathlib import Path as _Path
repo_path = _Path(repos_base_dir) / repo_name
if repo_path.is_dir():
cwd = str(repo_path)
try:
summary, description = await clients.reviewer.generate_ticket_content(
diff=pr_diff,
pr_title=pr_title,
repo_name=repo_name,
cwd=cwd,
)
except Exception as exc:
return {"errors": [f"auto_create_ticket generate_ticket_content failed: {exc}"]}
try:
ticket_id = await clients.jira.create_issue(
project=default_jira_project,
summary=summary,
description=description,
)
return {
"ticket_id": ticket_id,
"has_ticket": True,
"messages": [f"Auto-created Jira ticket {ticket_id} for PR in {repo_name}"],
}
except Exception as exc:
return {"errors": [f"auto_create_ticket create_issue failed: {exc}"]}
# ---------------------------------------------------------------------------
# Node: move_jira_code_review
# ---------------------------------------------------------------------------
async def move_jira_code_review(state: dict[str, Any], config: dict) -> dict:
"""Transition the Jira ticket to Code Review status.
Skipped if has_ticket is False. Non-critical: errors are appended.
"""
if not state.get("has_ticket"):
return {}
clients = _get_clients(config)
ticket_id = state.get("ticket_id", "")
try:
await clients.jira.transition_issue(ticket_id, "code review")
return {"messages": [f"Jira {ticket_id} moved to Code Review"]}
except ReleaseAgentError as exc:
return {"errors": [f"move_jira_code_review failed: {exc}"]}
# ---------------------------------------------------------------------------
# Helper: post review comments to Azure DevOps PR
# ---------------------------------------------------------------------------
async def _post_review_to_pr(
clients: ToolClients,
repo_name: str,
pr_id: int,
review: ReviewResult,
) -> None:
"""Post review results as comments on the Azure DevOps PR.
Posts inline comments for issues with file_path + line_start,
and a summary comment for the overall review.
Non-critical: errors are logged but not raised.
"""
import logging
logger = logging.getLogger(__name__)
# Post inline comments for issues with file path and line numbers
for issue in review.issues:
if issue.file_path and issue.line_start:
severity_icon = {"blocker": "BLOCKER", "error": "ERROR", "warning": "WARNING", "info": "INFO"}
label = severity_icon.get(issue.severity, issue.severity.upper())
content = f"**[{label}]** {issue.description}"
if issue.suggestion:
content += f"\n\n**Suggestion:** {issue.suggestion}"
try:
await clients.azdo.add_pr_inline_comment(
repo=repo_name,
pr_id=pr_id,
content=content,
file_path=issue.file_path,
line_start=issue.line_start,
line_end=issue.line_end,
)
except Exception as exc:
logger.warning("Failed to post inline comment on %s:%d: %s", issue.file_path, issue.line_start, exc)
# Post summary comment
verdict_text = "APPROVE" if review.verdict == "approve" else "REQUEST CHANGES"
issue_count = len(review.issues)
summary = f"## Claude Code Review: {verdict_text}\n\n{review.summary}\n\n"
if issue_count > 0:
summary += f"**{issue_count} issue(s) found:**\n"
for issue in review.issues:
loc = f" ({issue.file_path}:{issue.line_start})" if issue.file_path and issue.line_start else ""
summary += f"- **{issue.severity.upper()}**{loc}: {issue.description}\n"
else:
summary += "No issues found."
try:
await clients.azdo.add_pr_comment(repo=repo_name, pr_id=pr_id, content=summary)
except Exception as exc:
logger.warning("Failed to post review summary comment: %s", exc)
# ---------------------------------------------------------------------------
# Node: run_code_review
# ---------------------------------------------------------------------------
async def run_code_review(state: dict[str, Any], config: dict) -> dict:
"""Run Claude code review on the PR diff.
Returns review_result as a serialisable dict.
On error, appends to errors.
"""
clients = _get_clients(config)
diff = state.get("pr_diff", "")
pr_info = state.get("pr_info") or {}
pr_title = pr_info.get("pr_title", "")
repo_name = pr_info.get("repo_name", "")
# Build cwd from repos_base_dir + repo_name so Claude Code can read the codebase
repos_base_dir = config.get("configurable", {}).get("repos_base_dir", "")
cwd = None
if repos_base_dir and repo_name:
from pathlib import Path
repo_path = Path(repos_base_dir) / repo_name
if repo_path.is_dir():
cwd = str(repo_path)
try:
review: ReviewResult = await clients.reviewer.review_pr(
diff=diff, pr_title=pr_title, repo_name=repo_name, cwd=cwd
)
# Post review comments to Azure DevOps PR
pr_id = int(pr_info.get("pr_id", 0))
if pr_id and repo_name:
await _post_review_to_pr(clients, repo_name, pr_id, review)
return {"review_result": review.model_dump(mode="json")}
except Exception as exc:
return {"errors": [f"run_code_review failed: {exc}"]}
# ---------------------------------------------------------------------------
# Node: evaluate_review
# ---------------------------------------------------------------------------
async def evaluate_review(state: dict[str, Any], config: dict) -> dict:
"""Evaluate the review result and set review_approved.
Approved only when verdict=="approve" and no blockers present.
"""
review = state.get("review_result") or {}
verdict = review.get("verdict", "request_changes")
has_blockers = review.get("has_blockers", False)
# Also check issues list for blocker severity
if not has_blockers:
issues = review.get("issues") or []
has_blockers = any(i.get("severity") == "blocker" for i in issues)
approved = (verdict == "approve") and not has_blockers
return {"review_approved": approved}
# ---------------------------------------------------------------------------
# Node: interrupt_confirm_merge
# ---------------------------------------------------------------------------
async def interrupt_confirm_merge(state: dict[str, Any], config: dict) -> dict:
"""Interrupt to ask the operator to confirm the PR merge.
Passes a human-readable summary string to interrupt().
"""
pr_info = state.get("pr_info") or {}
review = state.get("review_result") or {}
summary = (
f"PR #{pr_info.get('pr_id')} - {pr_info.get('pr_title')} "
f"[repo: {pr_info.get('repo_name')}]\n"
f"Review: {review.get('summary', 'No review summary')}\n"
f"Confirm merge? (yes/no)"
)
interrupt(summary)
return {}
# ---------------------------------------------------------------------------
# Node: merge_pr_node
# ---------------------------------------------------------------------------
async def merge_pr_node(state: dict[str, Any], config: dict) -> dict:
"""Merge the PR via AzDo API.
Critical node — re-raises ReleaseAgentError.
"""
clients = _get_clients(config)
pr_info = state.get("pr_info") or {}
pr_id = int(pr_info.get("pr_id", state.get("pr_id", 0)))
commit = state.get("last_merge_source_commit", "")
await clients.azdo.merge_pr(pr_id=pr_id, last_merge_source_commit=commit)
return {"messages": [f"PR #{pr_id} merged"]}
# ---------------------------------------------------------------------------
# Node: move_jira_ready_for_stage
# ---------------------------------------------------------------------------
async def move_jira_ready_for_stage(state: dict[str, Any], config: dict) -> dict:
"""Transition the Jira ticket to Ready for stage (2).
Skipped if has_ticket is False. Non-critical: errors appended.
"""
if not state.get("has_ticket"):
return {}
clients = _get_clients(config)
ticket_id = state.get("ticket_id", "")
try:
await clients.jira.transition_issue(ticket_id, "Ready for stage (2)")
return {"messages": [f"Jira {ticket_id} moved to Ready for stage (2)"]}
except ReleaseAgentError as exc:
return {"errors": [f"move_jira_ready_for_stage failed: {exc}"]}
# ---------------------------------------------------------------------------
# Node: add_jira_pr_link
# ---------------------------------------------------------------------------
async def add_jira_pr_link(state: dict[str, Any], config: dict) -> dict:
"""Add the PR URL as a remote link on the Jira ticket.
Skipped if has_ticket is False. Non-critical: errors appended.
"""
if not state.get("has_ticket"):
return {}
clients = _get_clients(config)
ticket_id = state.get("ticket_id", "")
pr_info = state.get("pr_info") or {}
pr_url = pr_info.get("pr_url", "")
pr_title = pr_info.get("pr_title", "PR")
pr_id = pr_info.get("pr_id", "")
try:
await clients.jira.add_remote_link(
ticket_id=ticket_id,
url=str(pr_url),
title=f"PR #{pr_id}: {pr_title}",
)
return {"messages": [f"PR link added to Jira {ticket_id}"]}
except ReleaseAgentError as exc:
return {"errors": [f"add_jira_pr_link failed: {exc}"]}
# ---------------------------------------------------------------------------
# Node: calculate_version
# ---------------------------------------------------------------------------
async def calculate_version(state: dict[str, Any], config: dict) -> dict:
"""Calculate the next version for the repository using the staging store.
Uses calculate_next_version from versioning module.
"""
repo_name = state.get("repo_name", "")
staging_store = _get_staging_store(config)
existing_versions: list[str] = []
if staging_store is not None:
existing_versions = await staging_store.list_versions(repo_name)
version = calculate_next_version(repo_name, existing_versions)
return {"version": version}
# ---------------------------------------------------------------------------
# Node: update_staging
# ---------------------------------------------------------------------------
async def update_staging(state: dict[str, Any], config: dict) -> dict:
"""Add the PR's ticket to the staging release.
If no staging exists, creates a new one. If has_ticket is False, skips
the ticket addition but still ensures a staging record exists.
"""
if not state.get("has_ticket"):
return {}
clients = _get_clients(config)
staging_store = _get_staging_store(config)
repo_name = state.get("repo_name", "")
version = state.get("version", "v1.0.0")
ticket_id = state.get("ticket_id", "")
pr_info = state.get("pr_info") or {}
# Fetch ticket summary from Jira
try:
issue = await clients.jira.get_issue(ticket_id)
summary = issue.summary
except Exception:
summary = ticket_id
pr_url = pr_info.get("pr_url", "")
pr_id = pr_info.get("pr_id", "")
pr_title = pr_info.get("pr_title", "")
branch = pr_info.get("branch", "")
ticket_entry = TicketEntry(
id=ticket_id,
summary=summary,
pr_id=str(pr_id),
pr_url=pr_url if pr_url.startswith("http") else f"https://example.com/{pr_id}",
pr_title=pr_title,
branch=branch,
merged_at=date.today(),
)
if staging_store is not None:
existing = await staging_store.load(repo_name)
if existing is None:
staging = StagingRelease(
version=version,
repo=repo_name,
started_at=date.today(),
tickets=[ticket_entry],
)
else:
staging = existing.add_ticket(ticket_entry)
await staging_store.save(staging)
return {"staging": staging.model_dump(mode="json")}
return {}
# ---------------------------------------------------------------------------
# Node: notify_request_changes
# ---------------------------------------------------------------------------
async def notify_request_changes(state: dict[str, Any], config: dict) -> dict:
"""Send a Slack notification when the review requests changes.
Non-critical: errors appended.
"""
clients = _get_clients(config)
pr_info = state.get("pr_info") or {}
review = state.get("review_result") or {}
pr_id = pr_info.get("pr_id", "?")
pr_title = pr_info.get("pr_title", "")
repo = pr_info.get("repo_name", "")
summary = review.get("summary", "Review requested changes")
issues = review.get("issues") or []
issues_text = "\n".join(
f"- [{i.get('severity', 'info')}] {i.get('description', '')}" for i in issues
)
details = f"PR #{pr_id}: {pr_title} [{repo}]\n{summary}\n{issues_text}"
try:
await clients.slack.send_approval_request(
action="Request Changes",
details=details,
approval_url="",
)
return {"messages": [f"Slack notified: request changes for PR #{pr_id}"]}
except ReleaseAgentError as exc:
return {"errors": [f"notify_request_changes failed: {exc}"]}
# ---------------------------------------------------------------------------
# Graph builder
# ---------------------------------------------------------------------------
def build_pr_completed_graph():
"""Assemble and compile the PR Completed StateGraph.
Returns the compiled graph ready for invocation.
Routing after fetch_pr_details:
- "merged" -> calculate_version (PR already merged)
- "active_with_ticket" -> move_jira_code_review (PR active, has Jira ticket)
- "active_no_ticket" -> auto_create_ticket -> move_jira_code_review (no ticket)
"""
graph = StateGraph(ReleaseState)
# Add all nodes
graph.add_node("parse_webhook", parse_webhook)
graph.add_node("fetch_pr_details", fetch_pr_details)
graph.add_node("auto_create_ticket", auto_create_ticket)
graph.add_node("move_jira_code_review", move_jira_code_review)
graph.add_node("run_code_review", run_code_review)
graph.add_node("evaluate_review", evaluate_review)
graph.add_node("interrupt_confirm_merge", interrupt_confirm_merge)
graph.add_node("merge_pr_node", merge_pr_node)
graph.add_node("move_jira_ready_for_stage", move_jira_ready_for_stage)
graph.add_node("add_jira_pr_link", add_jira_pr_link)
graph.add_node("calculate_version", calculate_version)
graph.add_node("update_staging", update_staging)
graph.add_node("notify_request_changes", notify_request_changes)
graph.add_node("trigger_ci_build", trigger_ci_build)
graph.add_node("poll_ci_build", poll_ci_build)
graph.add_node("notify_ci_result", notify_ci_result)
# Entry
graph.add_edge(START, "parse_webhook")
graph.add_edge("parse_webhook", "fetch_pr_details")
# Three-way route: merged, active_with_ticket, active_no_ticket
graph.add_conditional_edges(
"fetch_pr_details",
route_after_fetch,
{
"merged": "calculate_version",
"active_with_ticket": "move_jira_code_review",
"active_no_ticket": "auto_create_ticket",
},
)
# auto_create_ticket -> move_jira_code_review (ticket now exists after creation)
graph.add_edge("auto_create_ticket", "move_jira_code_review")
# Active PR path: code review
graph.add_edge("move_jira_code_review", "run_code_review")
graph.add_edge("run_code_review", "evaluate_review")
# Route based on review outcome
graph.add_conditional_edges(
"evaluate_review",
is_review_approved,
{
"approve": "interrupt_confirm_merge",
"request_changes": "notify_request_changes",
},
)
graph.add_edge("interrupt_confirm_merge", "merge_pr_node")
graph.add_edge("merge_pr_node", "move_jira_ready_for_stage")
graph.add_edge("move_jira_ready_for_stage", "add_jira_pr_link")
graph.add_edge("add_jira_pr_link", "calculate_version")
# calculate_version -> update_staging -> trigger_ci_build -> poll_ci_build -> notify_ci_result -> END
graph.add_edge("calculate_version", "update_staging")
graph.add_edge("update_staging", "trigger_ci_build")
graph.add_edge("trigger_ci_build", "poll_ci_build")
graph.add_edge("poll_ci_build", "notify_ci_result")
graph.add_edge("notify_ci_result", END)
# notify_request_changes -> END
graph.add_edge("notify_request_changes", END)
return graph.compile()

View File

@@ -0,0 +1,627 @@
"""Node functions for the Release subgraph.
Each node is an async function (state, config) -> dict.
Nodes never mutate state — they return new dicts.
External clients are accessed via config["configurable"]["clients"].
Release flow (Phase 5):
merge_release_pr
-> trigger_ci_build_main
-> poll_ci_build_main
-> route_ci_result
-> ci_passed: wait_for_cd_release
-> poll_release_approvals
-> route_approval_stage
-> sandbox_pending: interrupt_sandbox_approval
-> execute_sandbox_approval
-> poll_release_approvals (loop)
-> prod_pending: interrupt_prod_approval
-> execute_prod_approval
-> poll_release_approvals (loop)
-> all_deployed: move_tickets_to_done
-> send_release_notification
-> archive_release
-> END
-> ci_failed: notify_ci_failure -> END
"""
from datetime import date
from typing import Any
from langgraph.graph import END, START, StateGraph
from langgraph.types import interrupt
from release_agent.exceptions import ReleaseAgentError
from release_agent.graph.ci_nodes import poll_ci_build, trigger_ci_build
from release_agent.graph.dependencies import ToolClients
from release_agent.graph.routing import route_approval_stage, route_ci_result
from release_agent.models.release import StagingRelease
from release_agent.models.ticket import TicketEntry
from release_agent.state import ReleaseState
from release_agent.tools.slack import _build_ci_status_blocks
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _get_clients(config: dict) -> ToolClients:
return config["configurable"]["clients"]
def _get_staging_store(config: dict):
return config["configurable"].get("staging_store")
# ---------------------------------------------------------------------------
# Node: load_staging
# ---------------------------------------------------------------------------
async def load_staging(state: dict[str, Any], config: dict) -> dict:
"""Load the current staging release from the store."""
repo_name = state.get("repo_name", "")
staging_store = _get_staging_store(config)
if staging_store is None:
return {"staging": None}
staging = await staging_store.load(repo_name)
if staging is None:
return {"staging": None}
return {"staging": staging.model_dump(mode="json")}
# ---------------------------------------------------------------------------
# Node: interrupt_confirm_release
# ---------------------------------------------------------------------------
async def interrupt_confirm_release(state: dict[str, Any], config: dict) -> dict:
"""Interrupt to ask the operator to confirm starting the release."""
repo_name = state.get("repo_name", "")
staging_dict = state.get("staging") or {}
version = staging_dict.get("version", "unknown")
tickets = staging_dict.get("tickets") or []
ticket_lines = "\n".join(f" - {t.get('id')}: {t.get('summary', '')}" for t in tickets)
summary = (
f"Release {version} for repo '{repo_name}'\n"
f"Tickets ({len(tickets)}):\n{ticket_lines}\n"
f"Confirm release? (yes/no)"
)
interrupt(summary)
return {}
# ---------------------------------------------------------------------------
# Node: create_release_pr
# ---------------------------------------------------------------------------
async def create_release_pr(state: dict[str, Any], config: dict) -> dict:
"""Create a release PR in AzDo from a release branch to main."""
clients = _get_clients(config)
repo_name = state.get("repo_name", "")
version = state.get("version", "")
staging_dict = state.get("staging") or {}
tickets = staging_dict.get("tickets") or []
ticket_ids = ", ".join(t.get("id", "") for t in tickets)
source_branch = f"refs/heads/release/{version}"
target_branch = "refs/heads/main"
title = f"Release {version}"
description = f"Release {version} for {repo_name}\n\nTickets: {ticket_ids}"
data = await clients.azdo.create_pr(
repo=repo_name,
source=source_branch,
target=target_branch,
title=title,
description=description,
)
pr_id = str(data.get("pullRequestId", ""))
commit = (data.get("lastMergeSourceCommit") or {}).get("commitId", "")
return {
"release_pr_id": pr_id,
"release_pr_commit": commit,
"messages": [f"Release PR #{pr_id} created for {version}"],
}
# ---------------------------------------------------------------------------
# Node: interrupt_confirm_merge_release
# ---------------------------------------------------------------------------
async def interrupt_confirm_merge_release(state: dict[str, Any], config: dict) -> dict:
"""Interrupt to ask the operator to confirm merging the release PR."""
release_pr_id = state.get("release_pr_id", "?")
version = state.get("version", "")
repo_name = state.get("repo_name", "")
summary = (
f"Merge release PR #{release_pr_id} ({version}) into main for '{repo_name}'?\n"
f"Confirm merge? (yes/no)"
)
interrupt(summary)
return {}
# ---------------------------------------------------------------------------
# Node: merge_release_pr
# ---------------------------------------------------------------------------
async def merge_release_pr(state: dict[str, Any], config: dict) -> dict:
"""Merge the release PR via AzDo API."""
clients = _get_clients(config)
pr_id = int(state.get("release_pr_id", 0))
commit = state.get("release_pr_commit", "")
await clients.azdo.merge_pr(pr_id=pr_id, last_merge_source_commit=commit)
return {"messages": [f"Release PR #{pr_id} merged"]}
# ---------------------------------------------------------------------------
# Node: trigger_ci_build_main (delegates to ci_nodes.trigger_ci_build)
# ---------------------------------------------------------------------------
async def trigger_ci_build_main(state: dict[str, Any], config: dict) -> dict:
"""Trigger CI build on main after the release PR is merged."""
return await trigger_ci_build(state, config)
# ---------------------------------------------------------------------------
# Node: poll_ci_build_main (delegates to ci_nodes.poll_ci_build)
# ---------------------------------------------------------------------------
async def poll_ci_build_main(state: dict[str, Any], config: dict) -> dict:
"""Poll the main branch CI build until completion."""
return await poll_ci_build(state, config)
# ---------------------------------------------------------------------------
# Node: notify_ci_failure
# ---------------------------------------------------------------------------
async def notify_ci_failure(state: dict[str, Any], config: dict) -> dict:
"""Send a Slack notification that the CI build failed on main.
Non-critical: errors appended.
"""
clients = _get_clients(config)
repo_name = state.get("repo_name", "unknown")
build_result = state.get("ci_build_result", "failed")
build_url = state.get("ci_build_url")
version = state.get("version", "")
blocks = _build_ci_status_blocks(
repo=repo_name,
branch=f"main (release {version})" if version else "main",
status=build_result,
build_url=build_url,
)
text = f"CI build failed on main for {repo_name} release {version}"
try:
await clients.slack.send_notification(text=text, blocks=blocks)
return {"messages": [text]}
except Exception as exc:
return {"errors": [f"notify_ci_failure: {exc}"]}
# ---------------------------------------------------------------------------
# Node: wait_for_cd_release
# ---------------------------------------------------------------------------
async def wait_for_cd_release(state: dict[str, Any], config: dict) -> dict:
"""Wait for the CD pipeline to create a release after CI passes.
Fetches the latest release for the configured release definition.
Non-critical: errors appended if not found.
"""
clients = _get_clients(config)
definition_id = state.get("release_definition_id")
if not definition_id:
return {"errors": ["wait_for_cd_release: release_definition_id not set in state"]}
try:
release = await clients.azdo.get_latest_release(definition_id=definition_id)
if not release:
return {"errors": ["wait_for_cd_release: no release found for definition"]}
return {
"release_id": release["id"],
"messages": [f"CD release found: {release.get('name', release['id'])}"],
}
except ReleaseAgentError as exc:
return {"errors": [f"wait_for_cd_release: {exc}"]}
# ---------------------------------------------------------------------------
# Node: poll_release_approvals
# ---------------------------------------------------------------------------
async def poll_release_approvals(state: dict[str, Any], config: dict) -> dict:
"""Poll AzDo for pending release environment approvals.
Non-critical: errors appended on failure.
"""
clients = _get_clients(config)
release_id = state.get("release_id")
if not release_id:
return {"pending_approvals": [], "errors": ["poll_release_approvals: release_id not set"]}
try:
approvals = await clients.azdo.get_release_approvals(release_id=release_id)
pending = [
{
"approval_id": a.approval_id,
"stage_name": a.stage_name,
"status": a.status,
"release_id": a.release_id,
}
for a in approvals
if a.status == "pending"
]
return {"pending_approvals": pending}
except ReleaseAgentError as exc:
return {"errors": [f"poll_release_approvals: {exc}"], "pending_approvals": []}
# ---------------------------------------------------------------------------
# Node: interrupt_sandbox_approval
# ---------------------------------------------------------------------------
async def interrupt_sandbox_approval(state: dict[str, Any], config: dict) -> dict:
"""Interrupt to ask the operator to approve the sandbox deployment."""
approvals = state.get("pending_approvals") or []
version = state.get("version", "")
stage_names = ", ".join(a.get("stage_name", "?") for a in approvals)
summary = (
f"Approve sandbox deployment for release {version}?\n"
f"Stages: {stage_names}\n"
f"Approve? (yes/no)"
)
interrupt(summary)
return {"current_stage": "sandbox_pending"}
# ---------------------------------------------------------------------------
# Node: execute_sandbox_approval
# ---------------------------------------------------------------------------
async def execute_sandbox_approval(state: dict[str, Any], config: dict) -> dict:
"""Approve all pending sandbox stage approvals via AzDo VSRM.
Non-critical per approval: errors appended on individual failures.
"""
clients = _get_clients(config)
approvals = state.get("pending_approvals") or []
errors: list[str] = []
for approval in approvals:
approval_id = approval.get("approval_id", "")
try:
await clients.azdo.approve_release(
approval_id=approval_id,
comment="Sandbox approved by release agent via Slack",
)
except ReleaseAgentError as exc:
errors.append(f"execute_sandbox_approval failed for {approval_id}: {exc}")
if errors:
return {"errors": errors}
return {}
# ---------------------------------------------------------------------------
# Node: interrupt_prod_approval
# ---------------------------------------------------------------------------
async def interrupt_prod_approval(state: dict[str, Any], config: dict) -> dict:
"""Interrupt to ask the operator to approve the production deployment."""
approvals = state.get("pending_approvals") or []
version = state.get("version", "")
stage_names = ", ".join(a.get("stage_name", "?") for a in approvals)
summary = (
f"Approve production deployment for release {version}?\n"
f"Stages: {stage_names}\n"
f"Approve? (yes/no)"
)
interrupt(summary)
return {"current_stage": "prod_pending"}
# ---------------------------------------------------------------------------
# Node: execute_prod_approval
# ---------------------------------------------------------------------------
async def execute_prod_approval(state: dict[str, Any], config: dict) -> dict:
"""Approve all pending production stage approvals via AzDo VSRM.
Non-critical per approval: errors appended on individual failures.
"""
clients = _get_clients(config)
approvals = state.get("pending_approvals") or []
errors: list[str] = []
for approval in approvals:
approval_id = approval.get("approval_id", "")
try:
await clients.azdo.approve_release(
approval_id=approval_id,
comment="Production approved by release agent via Slack",
)
except ReleaseAgentError as exc:
errors.append(f"execute_prod_approval failed for {approval_id}: {exc}")
if errors:
return {"errors": errors}
return {}
# ---------------------------------------------------------------------------
# Node: move_tickets_to_done
# ---------------------------------------------------------------------------
async def move_tickets_to_done(state: dict[str, Any], config: dict) -> dict:
"""Transition all staging tickets to Done/Released in Jira."""
clients = _get_clients(config)
staging_dict = state.get("staging") or {}
tickets = staging_dict.get("tickets") or []
errors: list[str] = []
for ticket_raw in tickets:
ticket_id = ticket_raw.get("id", "")
try:
await clients.jira.transition_issue(ticket_id, "Done")
except ReleaseAgentError as exc:
errors.append(f"move_tickets_to_done failed for {ticket_id}: {exc}")
if errors:
return {"errors": errors}
return {}
# ---------------------------------------------------------------------------
# Node: send_slack_notification (send_release_notification)
# ---------------------------------------------------------------------------
async def send_slack_notification(state: dict[str, Any], config: dict) -> dict:
"""Send a release notification to Slack."""
clients = _get_clients(config)
repo_name = state.get("repo_name", "")
version = state.get("version", "")
staging_dict = state.get("staging") or {}
tickets_raw = staging_dict.get("tickets") or []
try:
ticket_entries = [TicketEntry.model_validate(t) for t in tickets_raw]
await clients.slack.send_release_notification(
repo=repo_name,
version=version,
release_date=date.today(),
tickets=ticket_entries,
)
return {"messages": [f"Slack notified: release {version} for {repo_name}"]}
except ReleaseAgentError as exc:
return {"errors": [f"send_slack_notification failed: {exc}"]}
# ---------------------------------------------------------------------------
# Node: archive_release
# ---------------------------------------------------------------------------
async def archive_release(state: dict[str, Any], config: dict) -> dict:
"""Archive the staging release in the store."""
staging_store = _get_staging_store(config)
staging_dict = state.get("staging") or {}
if staging_store is None or not staging_dict:
return {}
staging = StagingRelease.model_validate(staging_dict)
await staging_store.archive(staging, date.today())
return {"messages": [f"Release {staging.version} archived"]}
# ---------------------------------------------------------------------------
# Legacy nodes kept for backward compatibility
# ---------------------------------------------------------------------------
async def list_pipelines(state: dict[str, Any], config: dict) -> dict:
"""List build pipelines for the repository via AzDo."""
clients = _get_clients(config)
repo_name = state.get("repo_name", "")
try:
pipelines = await clients.azdo.list_build_pipelines(repo=repo_name)
return {"pipelines": [p.model_dump() for p in pipelines]}
except ReleaseAgentError as exc:
return {"errors": [f"list_pipelines failed: {exc}"], "pipelines": []}
async def interrupt_confirm_trigger(state: dict[str, Any], config: dict) -> dict:
"""Interrupt to ask the operator to confirm triggering pipelines."""
repo_name = state.get("repo_name", "")
version = state.get("version", "")
pipelines = state.get("pipelines") or []
pipeline_names = ", ".join(p.get("name", str(p.get("id", ""))) for p in pipelines)
summary = (
f"Trigger pipelines for {repo_name} {version}?\n"
f"Pipelines: {pipeline_names}\n"
f"Confirm? (yes/no)"
)
interrupt(summary)
return {}
async def trigger_pipelines(state: dict[str, Any], config: dict) -> dict:
"""Trigger all listed pipelines for the release branch."""
clients = _get_clients(config)
repo_name = state.get("repo_name", "")
version = state.get("version", "")
pipelines = state.get("pipelines") or []
branch = f"refs/heads/release/{version}"
triggered: list[dict] = []
errors: list[str] = []
for pipeline in pipelines:
pipeline_id = pipeline.get("id")
try:
result = await clients.azdo.trigger_pipeline(
pipeline_id=pipeline_id, branch=branch
)
triggered.append(result)
except ReleaseAgentError as exc:
errors.append(f"trigger_pipelines failed for pipeline {pipeline_id}: {exc}")
result_dict: dict = {"triggered_builds": triggered}
if errors:
result_dict["errors"] = errors
return result_dict
async def check_release_approvals(state: dict[str, Any], config: dict) -> dict:
"""Check triggered builds for pending stage approvals."""
clients = _get_clients(config)
triggered_builds = state.get("triggered_builds") or []
pending: list[dict] = []
errors: list[str] = []
for build in triggered_builds:
build_id = build.get("id")
try:
await clients.azdo.get_build_status(build_id=build_id)
except ReleaseAgentError as exc:
errors.append(f"check_release_approvals failed for build {build_id}: {exc}")
result_dict: dict = {"pending_approvals": pending}
if errors:
result_dict["errors"] = errors
return result_dict
async def interrupt_confirm_approve(state: dict[str, Any], config: dict) -> dict:
"""Interrupt to ask the operator to confirm approving pipeline stages."""
approvals = state.get("pending_approvals") or []
version = state.get("version", "")
stage_names = ", ".join(a.get("stage_name", a.get("approval_id", "?")) for a in approvals)
summary = (
f"Approve pipeline stages for release {version}?\n"
f"Stages: {stage_names}\n"
f"Confirm? (yes/no)"
)
interrupt(summary)
return {}
async def approve_stage(state: dict[str, Any], config: dict) -> dict:
"""Approve all pending pipeline stage approvals via AzDo VSRM."""
clients = _get_clients(config)
approvals = state.get("pending_approvals") or []
errors: list[str] = []
for approval in approvals:
approval_id = approval.get("approval_id", "")
try:
await clients.azdo.approve_release(
approval_id=approval_id,
comment="Approved by release agent",
)
except ReleaseAgentError as exc:
errors.append(f"approve_stage failed for {approval_id}: {exc}")
if errors:
return {"errors": errors}
return {}
# ---------------------------------------------------------------------------
# Graph builder
# ---------------------------------------------------------------------------
def build_release_graph():
"""Assemble and compile the Release StateGraph (Phase 5 architecture).
Flow:
load_staging -> interrupt_confirm_release -> create_release_pr
-> interrupt_confirm_merge_release -> merge_release_pr
-> trigger_ci_build_main -> poll_ci_build_main
-> route_ci_result:
ci_passed: wait_for_cd_release -> poll_release_approvals
-> route_approval_stage:
sandbox_pending: interrupt_sandbox_approval
-> execute_sandbox_approval
-> poll_release_approvals
prod_pending: interrupt_prod_approval
-> execute_prod_approval
-> poll_release_approvals
all_deployed: move_tickets_to_done
-> send_slack_notification
-> archive_release -> END
ci_failed: notify_ci_failure -> END
"""
graph = StateGraph(ReleaseState)
# Core release PR nodes
graph.add_node("load_staging", load_staging)
graph.add_node("interrupt_confirm_release", interrupt_confirm_release)
graph.add_node("create_release_pr", create_release_pr)
graph.add_node("interrupt_confirm_merge_release", interrupt_confirm_merge_release)
graph.add_node("merge_release_pr", merge_release_pr)
# CI nodes
graph.add_node("trigger_ci_build_main", trigger_ci_build_main)
graph.add_node("poll_ci_build_main", poll_ci_build_main)
graph.add_node("notify_ci_failure", notify_ci_failure)
# CD approval nodes
graph.add_node("wait_for_cd_release", wait_for_cd_release)
graph.add_node("poll_release_approvals", poll_release_approvals)
graph.add_node("interrupt_sandbox_approval", interrupt_sandbox_approval)
graph.add_node("execute_sandbox_approval", execute_sandbox_approval)
graph.add_node("interrupt_prod_approval", interrupt_prod_approval)
graph.add_node("execute_prod_approval", execute_prod_approval)
# Completion nodes
graph.add_node("move_tickets_to_done", move_tickets_to_done)
graph.add_node("send_release_notification", send_slack_notification)
graph.add_node("archive_release", archive_release)
# Main release flow up to merge
graph.add_edge(START, "load_staging")
graph.add_edge("load_staging", "interrupt_confirm_release")
graph.add_edge("interrupt_confirm_release", "create_release_pr")
graph.add_edge("create_release_pr", "interrupt_confirm_merge_release")
graph.add_edge("interrupt_confirm_merge_release", "merge_release_pr")
# CI pipeline after merge
graph.add_edge("merge_release_pr", "trigger_ci_build_main")
graph.add_edge("trigger_ci_build_main", "poll_ci_build_main")
graph.add_conditional_edges(
"poll_ci_build_main",
route_ci_result,
{"ci_passed": "wait_for_cd_release", "ci_failed": "notify_ci_failure"},
)
graph.add_edge("notify_ci_failure", END)
# CD approval flow
graph.add_edge("wait_for_cd_release", "poll_release_approvals")
graph.add_conditional_edges(
"poll_release_approvals",
route_approval_stage,
{
"sandbox_pending": "interrupt_sandbox_approval",
"prod_pending": "interrupt_prod_approval",
"all_deployed": "move_tickets_to_done",
},
)
# Sandbox approval loop
graph.add_edge("interrupt_sandbox_approval", "execute_sandbox_approval")
graph.add_edge("execute_sandbox_approval", "poll_release_approvals")
# Prod approval loop
graph.add_edge("interrupt_prod_approval", "execute_prod_approval")
graph.add_edge("execute_prod_approval", "poll_release_approvals")
# Completion
graph.add_edge("move_tickets_to_done", "send_release_notification")
graph.add_edge("send_release_notification", "archive_release")
graph.add_edge("archive_release", END)
return graph.compile()

View File

@@ -0,0 +1,131 @@
"""Pure routing functions for LangGraph conditional edges.
Each function takes a state dict and returns a routing string.
No side effects, no mutation, no I/O.
"""
from typing import Any
def is_pr_already_merged(state: dict[str, Any]) -> str:
"""Route based on whether the PR was already merged before processing.
Returns:
"merged" if state["pr_already_merged"] is True.
"active" otherwise (including missing or None).
"""
if state.get("pr_already_merged"):
return "merged"
return "active"
def route_after_fetch(state: dict[str, Any]) -> str:
"""Three-way route after fetch_pr_details based on merge status and ticket presence.
Returns:
"merged" if state["pr_already_merged"] is True.
"active_with_ticket" if PR is active and state["has_ticket"] is True.
"active_no_ticket" if PR is active and state["has_ticket"] is falsy.
"""
if state.get("pr_already_merged"):
return "merged"
if state.get("has_ticket"):
return "active_with_ticket"
return "active_no_ticket"
def is_review_approved(state: dict[str, Any]) -> str:
"""Route based on the review approval result.
Returns:
"approve" if state["review_approved"] is True.
"request_changes" otherwise (including missing or None).
"""
if state.get("review_approved"):
return "approve"
return "request_changes"
def has_ticket(state: dict[str, Any]) -> str:
"""Route based on whether the PR branch contains a Jira ticket ID.
Returns:
"yes" if state["has_ticket"] is True.
"no" otherwise (including missing or None).
"""
if state.get("has_ticket"):
return "yes"
return "no"
def should_continue_to_release(state: dict[str, Any]) -> str:
"""Route based on whether the operator chose to proceed to the release flow.
Returns:
"yes" if state["continue_to_release"] is True and no errors accumulated.
"no" otherwise (including missing, None, or if errors are present).
"""
if state.get("continue_to_release") and not state.get("errors"):
return "yes"
return "no"
def has_pipelines(state: dict[str, Any]) -> str:
"""Route based on whether any pipelines are defined for the repository.
Returns:
"yes" if state["pipelines"] is a non-empty list.
"no" otherwise (including missing, None, or empty list).
"""
pipelines = state.get("pipelines")
if pipelines:
return "yes"
return "no"
def has_pending_approvals(state: dict[str, Any]) -> str:
"""Route based on whether there are pipeline stage approvals pending.
Returns:
"yes" if state["pending_approvals"] is a non-empty list.
"no" otherwise (including missing, None, or empty list).
"""
approvals = state.get("pending_approvals")
if approvals:
return "yes"
return "no"
def route_ci_result(state: dict[str, Any]) -> str:
"""Route based on the CI build result.
Returns:
"ci_passed" if state["ci_build_result"] == "succeeded".
"ci_failed" otherwise (including missing, None, or any other value).
"""
if state.get("ci_build_result") == "succeeded":
return "ci_passed"
return "ci_failed"
def route_approval_stage(state: dict[str, Any]) -> str:
"""Route based on the current CD approval stage.
Uses state["current_stage"] if present. Falls back to "sandbox_pending"
when approvals exist but no stage is set (first stage assumption).
Returns:
"sandbox_pending" if state["current_stage"] == "sandbox_pending" or
approvals exist with no explicit stage (default first stage).
"prod_pending" if state["current_stage"] == "prod_pending".
"all_deployed" if state["pending_approvals"] is empty or missing.
"""
approvals = state.get("pending_approvals")
if not approvals:
return "all_deployed"
current_stage = state.get("current_stage", "")
if current_stage == "prod_pending":
return "prod_pending"
# Default: sandbox_pending (covers both explicit sandbox_pending and unknown)
return "sandbox_pending"

361
src/release_agent/main.py Normal file
View File

@@ -0,0 +1,361 @@
"""FastAPI application entry point.
Responsibilities:
- Create the FastAPI app with lifespan handler
- Startup: compile graphs, open DB pool, build tool clients, ensure schema
- Shutdown: wait for background tasks (30s timeout)
- Register routers and global exception handlers
- Expose schedule_graph and run_graph_in_background helpers
"""
import asyncio
import logging
import uuid
from contextlib import asynccontextmanager
from datetime import datetime, timezone
from pathlib import Path
import httpx
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from release_agent.api.approvals import router as approvals_router
from release_agent.api.models import ErrorResponse
from release_agent.api.slack_interactions import router as slack_interactions_router
from release_agent.api.status import router as status_router
from release_agent.api.webhooks import router as webhook_router
from release_agent.config import Settings
from release_agent.exceptions import ReleaseAgentError
from release_agent.graph.dependencies import JsonFileStagingStore, ToolClients
from release_agent.graph.postgres_staging_store import PostgresStagingStore
from release_agent.graph.pr_completed import build_pr_completed_graph
from release_agent.graph.release import build_release_graph
from release_agent.services.pr_poller import run_pr_poll_loop
from release_agent.tools.azdo import AzDoClient
from release_agent.tools.claude_review import ClaudeReviewer
from release_agent.tools.jira import JiraClient
from release_agent.tools.slack import SlackClient
try:
from psycopg_pool import AsyncConnectionPool
except ImportError: # pragma: no cover
AsyncConnectionPool = None # type: ignore[assignment,misc]
logger = logging.getLogger(__name__)
_SHUTDOWN_TIMEOUT_SECONDS = 30
# ---------------------------------------------------------------------------
# Startup helpers
# ---------------------------------------------------------------------------
def _create_tool_clients(settings: Settings) -> tuple[ToolClients, list]:
"""Build and return a ToolClients instance and closeable HTTP clients."""
http_client = httpx.AsyncClient(timeout=30.0)
vsrm_http_client = httpx.AsyncClient(timeout=30.0)
azdo = AzDoClient(
http_client=http_client,
vsrm_http_client=vsrm_http_client,
base_url=settings.azdo_api_url,
vsrm_base_url=settings.azdo_vsrm_api_url,
pat=settings.azdo_pat.get_secret_value(),
)
jira = JiraClient(
http_client=http_client,
base_url=settings.jira_base_url,
email=settings.jira_email,
api_token=settings.jira_api_token.get_secret_value(),
)
slack = SlackClient(
http_client=http_client,
webhook_url=settings.slack_webhook_url.get_secret_value(),
bot_token=settings.slack_bot_token.get_secret_value(),
channel_id=settings.slack_channel_id,
)
reviewer = ClaudeReviewer()
clients = ToolClients(azdo=azdo, jira=jira, slack=slack, reviewer=reviewer)
return clients, [http_client, vsrm_http_client]
def _create_staging_store(pool=None) -> PostgresStagingStore | JsonFileStagingStore:
"""Return a StagingStore instance.
When a PostgreSQL pool is provided, returns a PostgresStagingStore.
Falls back to JsonFileStagingStore for local development without a DB.
"""
if pool is not None:
return PostgresStagingStore(pool=pool)
return JsonFileStagingStore(directory=Path("data/staging"))
async def _ensure_db_schema(pool) -> None:
"""Create all required tables if they do not already exist."""
statements = [
"""
CREATE TABLE IF NOT EXISTS agent_threads (
thread_id TEXT PRIMARY KEY,
graph_name TEXT,
status TEXT NOT NULL DEFAULT 'running',
interrupt_value TEXT,
state JSONB,
repo_name TEXT,
pr_id TEXT,
version TEXT,
slack_message_ts TEXT,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
)
""",
"""
ALTER TABLE agent_threads
ADD COLUMN IF NOT EXISTS slack_message_ts TEXT
""",
"""
CREATE TABLE IF NOT EXISTS staging_releases (
repo TEXT PRIMARY KEY,
version TEXT NOT NULL,
started_at DATE NOT NULL,
tickets JSONB NOT NULL DEFAULT '[]',
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
)
""",
"""
CREATE TABLE IF NOT EXISTS archived_releases (
repo TEXT NOT NULL,
version TEXT NOT NULL,
started_at DATE NOT NULL,
tickets JSONB NOT NULL DEFAULT '[]',
released_at DATE NOT NULL,
PRIMARY KEY (repo, version)
)
""",
]
async with pool.connection() as conn:
async with conn.cursor() as cur:
for sql in statements:
await cur.execute(sql)
# ---------------------------------------------------------------------------
# Background task helpers
# ---------------------------------------------------------------------------
def schedule_graph(
*,
app: FastAPI,
graph,
initial_state: dict,
thread_id: str | None = None,
tool_clients: ToolClients | None = None,
db_pool=None,
) -> str:
"""Schedule a graph run as a background asyncio task.
Creates a new thread_id if not provided. Adds the task to
app.state.background_tasks for graceful shutdown tracking.
Returns the thread_id used for this execution.
"""
if thread_id is None:
thread_id = str(uuid.uuid4())
if tool_clients is None and hasattr(app.state, "tool_clients"):
tool_clients = app.state.tool_clients
if db_pool is None and hasattr(app.state, "db_pool"):
db_pool = app.state.db_pool
# Extract settings for config propagation
settings = getattr(app.state, "settings", None)
repos_base_dir = getattr(settings, "repos_base_dir", "") if settings else ""
default_jira_project = getattr(settings, "default_jira_project", "ALLPOST") if settings else "ALLPOST"
task = asyncio.create_task(
run_graph_in_background(
graph=graph,
initial_state=initial_state,
thread_id=thread_id,
tool_clients=tool_clients,
db_pool=db_pool,
repos_base_dir=repos_base_dir,
default_jira_project=default_jira_project,
)
)
app.state.background_tasks.add(task)
task.add_done_callback(app.state.background_tasks.discard)
return thread_id
async def run_graph_in_background(
*,
graph,
initial_state: dict,
thread_id: str,
tool_clients: ToolClients | None = None,
db_pool=None,
repos_base_dir: str = "",
default_jira_project: str = "ALLPOST",
) -> None:
"""Execute the graph and update thread status in the database."""
from release_agent.api.webhooks import _upsert_thread
config = {
"configurable": {
"thread_id": thread_id,
"clients": tool_clients,
"repos_base_dir": repos_base_dir,
"default_jira_project": default_jira_project,
}
}
try:
if db_pool is not None:
await _upsert_thread(db_pool, thread_id=thread_id, thread_status="running", state=initial_state)
result = await graph.ainvoke(initial_state, config=config)
if db_pool is not None:
await _upsert_thread(db_pool, thread_id=thread_id, thread_status="completed", state=result or {})
except Exception as exc:
logger.exception("Graph execution failed for thread %s: %s", thread_id, exc)
if db_pool is not None:
await _upsert_thread(
db_pool,
thread_id=thread_id,
thread_status="failed",
state={"errors": [str(exc)]},
)
# ---------------------------------------------------------------------------
# Lifespan
# ---------------------------------------------------------------------------
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Manage startup and graceful shutdown of shared resources."""
settings = Settings()
app.state.settings = settings
app.state.background_tasks = set()
app.state.started_at = datetime.now(tz=timezone.utc)
# Compile graphs once at startup
logger.info("Compiling graphs...")
app.state.graphs = {
"pr_completed": build_pr_completed_graph(),
"release": build_release_graph(),
}
# Build tool clients
logger.info("Building tool clients...")
app.state.tool_clients, app.state._http_clients = _create_tool_clients(settings)
# Open PostgreSQL connection pool
logger.info("Opening DB connection pool...")
pool = AsyncConnectionPool(
conninfo=settings.postgres_dsn.get_secret_value(),
open=False,
)
await pool.open()
app.state.db_pool = pool
# Ensure schema exists
await _ensure_db_schema(pool)
# Build staging store (backed by PostgreSQL)
app.state.staging_store = _create_staging_store(pool=pool)
# Start PR polling background task if enabled
if settings.pr_poll_enabled:
logger.info("Starting PR polling background task...")
poll_task = asyncio.create_task(
run_pr_poll_loop(
azdo_client=app.state.tool_clients.azdo,
db_pool=pool,
watched_repos=settings.watched_repos_list,
target_branch=settings.pr_poll_target_branch,
interval_seconds=settings.pr_poll_interval_seconds,
schedule_fn=lambda *, initial_state: schedule_graph(
app=app,
graph=app.state.graphs["pr_completed"],
initial_state=initial_state,
),
)
)
app.state.background_tasks.add(poll_task)
poll_task.add_done_callback(app.state.background_tasks.discard)
logger.info("Startup complete.")
yield
# Graceful shutdown: wait for background tasks
tasks = list(app.state.background_tasks)
if tasks:
logger.info("Waiting for %d background task(s) to complete...", len(tasks))
done, pending = await asyncio.wait(tasks, timeout=_SHUTDOWN_TIMEOUT_SECONDS)
for task in pending:
logger.warning("Cancelling timed-out background task %s", task)
task.cancel()
# Close HTTP clients
for client in getattr(app.state, "_http_clients", []):
await client.aclose()
await pool.close()
logger.info("Shutdown complete.")
# ---------------------------------------------------------------------------
# Exception handlers
# ---------------------------------------------------------------------------
async def _release_agent_error_handler(request: Request, exc: ReleaseAgentError) -> JSONResponse:
return JSONResponse(
status_code=500,
content=ErrorResponse(
error=type(exc).__name__,
detail=str(exc),
).model_dump(),
)
async def _generic_error_handler(request: Request, exc: Exception) -> JSONResponse:
logger.exception("Unhandled exception: %s", exc)
return JSONResponse(
status_code=500,
content=ErrorResponse(
error="InternalServerError",
detail="An unexpected error occurred",
).model_dump(),
)
# ---------------------------------------------------------------------------
# App factory
# ---------------------------------------------------------------------------
def create_app() -> FastAPI:
"""Create and configure the FastAPI application."""
app = FastAPI(
title="Billo Release Agent",
version="0.1.0",
description="LangGraph-based release automation agent",
lifespan=lifespan,
)
# Register routers
app.include_router(webhook_router)
app.include_router(approvals_router)
app.include_router(status_router)
app.include_router(slack_interactions_router)
# Register exception handlers
app.add_exception_handler(ReleaseAgentError, _release_agent_error_handler) # type: ignore[arg-type]
app.add_exception_handler(Exception, _generic_error_handler)
return app
# ---------------------------------------------------------------------------
# ASGI entry point
# ---------------------------------------------------------------------------
app = create_app()

View File

@@ -0,0 +1 @@
"""Pydantic models for the release agent."""

View File

@@ -0,0 +1,43 @@
"""Build and approval dataclasses for CI/CD tracking.
Immutable (frozen) dataclasses used across graph nodes, routing functions,
and API endpoints.
"""
from dataclasses import dataclass
@dataclass(frozen=True)
class BuildStatus:
"""Immutable snapshot of an Azure DevOps build's current state.
Attributes:
status: AzDo build status string (e.g. "notStarted", "inProgress",
"completed", "cancelling").
result: AzDo build result string (e.g. "succeeded", "failed",
"canceled", "partiallySucceeded"). None when not completed.
build_url: Direct URL to the build results page, or None if unknown.
"""
status: str
result: str | None
build_url: str | None
@dataclass(frozen=True)
class ApprovalRecord:
"""Immutable record representing a single release pipeline approval gate.
Attributes:
approval_id: Unique identifier of the approval record in AzDo.
stage_name: Human-readable name of the deployment stage (e.g.
"Sandbox", "Production").
status: Current approval status (e.g. "pending", "approved",
"rejected").
release_id: Numeric ID of the AzDo release containing this approval.
"""
approval_id: str
stage_name: str
status: str
release_id: int

View File

@@ -0,0 +1,33 @@
"""Jira domain models."""
from pydantic import BaseModel, ConfigDict
class JiraTransition(BaseModel):
"""A Jira workflow transition.
Attributes:
id: Transition identifier as returned by the Jira API.
name: Human-readable transition name (e.g. "Done", "Released").
"""
model_config = ConfigDict(frozen=True)
id: str
name: str
class JiraIssue(BaseModel):
"""A Jira issue.
Attributes:
key: Issue key (e.g. "ALLPOST-100").
summary: Issue summary / title.
status: Current workflow status name (e.g. "In Progress").
"""
model_config = ConfigDict(frozen=True)
key: str
summary: str
status: str

View File

@@ -0,0 +1,31 @@
"""Pipeline-related models for Azure DevOps pipeline tracking."""
from pydantic import BaseModel, ConfigDict, Field, model_validator
class PipelineInfo(BaseModel):
"""An immutable record identifying an Azure DevOps pipeline."""
model_config = ConfigDict(frozen=True)
id: int
name: str
repo: str
class ReleasePipelineStage(BaseModel):
"""An immutable record representing a single stage in a release pipeline."""
model_config = ConfigDict(frozen=True)
name: str
rank: int = Field(ge=0)
requires_approval: bool
approval_id: str | None = None
@model_validator(mode="after")
def validate_approval_consistency(self) -> "ReleasePipelineStage":
"""Ensure approval_id is consistent with requires_approval."""
if not self.requires_approval and self.approval_id is not None:
raise ValueError("approval_id must be None when requires_approval is False")
return self

View File

@@ -0,0 +1,38 @@
"""PRInfo model representing a pull request with ticket extraction."""
from typing import Literal
from pydantic import BaseModel, ConfigDict, HttpUrl, model_validator
from release_agent.branch_parser import parse_branch
class PRInfo(BaseModel):
"""An immutable representation of a pull request.
The ticket_id and has_ticket fields are automatically derived
from the branch name during model construction.
"""
model_config = ConfigDict(frozen=True)
pr_id: str
pr_url: HttpUrl
repo_name: str
branch: str
pr_title: str
pr_status: Literal["active", "completed", "abandoned"]
ticket_id: str | None = None
has_ticket: bool = False
@model_validator(mode="before")
@classmethod
def extract_ticket_from_branch(cls, values: dict) -> dict:
"""Derive ticket_id and has_ticket from the branch name."""
branch = values.get("branch", "")
ticket_id, has_ticket = parse_branch(branch)
# ticket_id and has_ticket are always derived from branch parsing.
# Any caller-provided values are overridden to keep the model consistent.
values["ticket_id"] = ticket_id
values["has_ticket"] = has_ticket
return values

View File

@@ -0,0 +1,65 @@
"""StagingRelease and ArchivedRelease models for tracking release state."""
import re
from datetime import date
from pydantic import BaseModel, ConfigDict, field_validator, model_validator
from release_agent.models.ticket import TicketEntry
_VERSION_PATTERN = re.compile(r"^v\d+\.\d+\.\d+$")
class StagingRelease(BaseModel):
"""An immutable model representing a release currently in staging.
All mutation operations return new instances rather than modifying
the existing one.
"""
model_config = ConfigDict(frozen=True)
version: str
repo: str
started_at: date
tickets: list[TicketEntry]
@field_validator("version")
@classmethod
def validate_version_format(cls, value: str) -> str:
"""Validate version string matches vMAJOR.MINOR.PATCH."""
if not _VERSION_PATTERN.match(value):
raise ValueError(
f"Invalid version format: {value!r}. Must match ^v\\d+\\.\\d+\\.\\d+$"
)
return value
def add_ticket(self, ticket: TicketEntry) -> "StagingRelease":
"""Return a new StagingRelease with the given ticket appended.
Does not mutate this instance.
"""
return self.model_copy(update={"tickets": [*self.tickets, ticket]})
def has_ticket(self, ticket_id: str) -> bool:
"""Return True if a ticket with the given ID exists in this release."""
return any(t.id == ticket_id for t in self.tickets)
class ArchivedRelease(StagingRelease):
"""An immutable model representing a completed/released release.
Extends StagingRelease with a released_at date that must be
on or after started_at.
"""
released_at: date
@model_validator(mode="after")
def validate_released_after_started(self) -> "ArchivedRelease":
"""Validate that released_at is not before started_at."""
if self.released_at < self.started_at:
raise ValueError(
f"released_at ({self.released_at}) must be >= started_at ({self.started_at})"
)
return self

View File

@@ -0,0 +1,48 @@
"""Review models for Claude PR review structured output."""
from typing import Literal
from pydantic import BaseModel, ConfigDict, computed_field
class ReviewIssue(BaseModel):
"""A single issue identified during PR review.
Attributes:
severity: One of "blocker", "error", "warning", or "info".
description: Human-readable description of the issue.
file_path: Optional path to the affected file.
suggestion: Optional remediation suggestion.
"""
model_config = ConfigDict(frozen=True)
severity: Literal["blocker", "error", "warning", "info"]
description: str
file_path: str | None = None
line_start: int | None = None
line_end: int | None = None
suggestion: str | None = None
class ReviewResult(BaseModel):
"""Structured output from a Claude PR review.
Attributes:
verdict: Either "approve" or "request_changes".
summary: A human-readable summary of the review.
issues: List of issues identified during review.
has_blockers: Computed field - True if any issue has severity "blocker".
"""
model_config = ConfigDict(frozen=True)
verdict: Literal["approve", "request_changes"]
summary: str
issues: tuple[ReviewIssue, ...] = ()
@computed_field # type: ignore[misc]
@property
def has_blockers(self) -> bool:
"""Return True if any issue has blocker severity."""
return any(issue.severity == "blocker" for issue in self.issues)

View File

@@ -0,0 +1,33 @@
"""TicketEntry model representing a Jira ticket linked to a merged PR."""
import re
from datetime import date
from pydantic import BaseModel, ConfigDict, HttpUrl, field_validator
_JIRA_ID_PATTERN = re.compile(r"^[A-Z][A-Z0-9]*-\d+$")
class TicketEntry(BaseModel):
"""An immutable record of a Jira ticket merged into a release."""
model_config = ConfigDict(frozen=True)
id: str
summary: str
pr_id: str
pr_url: HttpUrl
pr_title: str
branch: str
merged_at: date
@field_validator("id")
@classmethod
def validate_jira_id(cls, value: str) -> str:
"""Validate that the ID matches the Jira ticket format."""
if not _JIRA_ID_PATTERN.match(value):
raise ValueError(
f"Invalid Jira ticket ID: {value!r}. "
"Must match pattern ^[A-Z][A-Z0-9]*-\\d+$"
)
return value

View File

@@ -0,0 +1,40 @@
"""WebhookPayload models for Azure DevOps webhook events."""
from datetime import datetime
from typing import Literal
from pydantic import BaseModel, ConfigDict, HttpUrl
class WebhookRepository(BaseModel):
"""An immutable record of a repository as reported in a webhook event."""
model_config = ConfigDict(frozen=True)
id: str
name: str
web_url: HttpUrl
class WebhookResource(BaseModel):
"""An immutable record of the resource object in a webhook payload."""
model_config = ConfigDict(frozen=True)
repository: WebhookRepository
pull_request_id: int
title: str
source_ref_name: str
target_ref_name: str
status: Literal["active", "completed", "abandoned"]
closed_date: datetime | None
class WebhookPayload(BaseModel):
"""An immutable model representing an Azure DevOps webhook event payload."""
model_config = ConfigDict(frozen=True)
subscription_id: str
event_type: str
resource: WebhookResource

View File

View File

@@ -0,0 +1,46 @@
"""PR deduplication service.
Queries the agent_threads table to find which PRs from a given list have
not yet been processed. This prevents the PR poller from re-triggering
graph runs for PRs that already have an existing thread.
"""
from release_agent.models.pr import PRInfo
_QUERY = """
SELECT pr_id, repo_name
FROM agent_threads
WHERE (pr_id, repo_name) IN (
SELECT unnest(%s::text[]), unnest(%s::text[])
)
"""
async def find_unprocessed_prs(pool, prs: list[PRInfo]) -> list[PRInfo]:
"""Return the subset of prs that have no existing agent_threads record.
A PR is considered already-processed if there exists a row in agent_threads
with matching (pr_id, repo_name) pair. The SQL uses unnest to enforce
pair-wise matching (not independent ANY on each column).
Args:
pool: Async psycopg connection pool.
prs: List of PRInfo objects to check.
Returns:
A new list containing only PRs with no existing thread.
Original list is not mutated.
"""
if not prs:
return []
pr_ids = [p.pr_id for p in prs]
repo_names = [p.repo_name for p in prs]
async with pool.connection() as conn:
async with conn.cursor() as cur:
await cur.execute(_QUERY, (pr_ids, repo_names))
rows = await cur.fetchall()
processed = {(str(row[0]), str(row[1])) for row in rows}
return [p for p in prs if (p.pr_id, p.repo_name) not in processed]

View File

@@ -0,0 +1,108 @@
"""PR polling service.
Polls Azure DevOps for active pull requests targeting a configured branch,
deduplicates against already-processed PRs, and schedules graph runs for
newly discovered PRs by synthesizing a fake webhook payload.
"""
import asyncio
import logging
from collections.abc import Callable, Coroutine
from typing import Any
from release_agent.models.pr import PRInfo
from release_agent.services.pr_dedup import find_unprocessed_prs
logger = logging.getLogger(__name__)
def _synthesize_webhook_payload(pr: PRInfo) -> dict:
"""Build a fake webhook payload dict from a PRInfo.
Produces a dict that matches the structure expected by the parse_webhook
node so the PR can be processed as if it arrived via the real webhook.
Args:
pr: The PRInfo object representing the active pull request.
Returns:
A dict with event_type, subscription_id, and resource fields.
"""
return {
"event_type": "git.pullrequest.updated",
"subscription_id": f"polled-{pr.repo_name}-{pr.pr_id}",
"resource": {
"pull_request_id": int(pr.pr_id),
"title": pr.pr_title,
"status": pr.pr_status,
"source_ref_name": pr.branch,
"target_ref_name": "",
"repository": {
"id": f"{pr.repo_name}-id",
"name": pr.repo_name,
"web_url": str(pr.pr_url).rsplit("/pullrequest/", 1)[0],
},
},
}
async def run_pr_poll_loop(
*,
azdo_client,
db_pool,
watched_repos: list[str],
target_branch: str,
interval_seconds: int,
schedule_fn: Callable[..., Any],
sleep_fn: Callable[[float], Coroutine] | None = None,
) -> None:
"""Poll Azure DevOps for active PRs and schedule unprocessed ones.
Runs indefinitely until cancelled. For each polling interval:
1. For each watched repo, list active PRs targeting target_branch.
2. Find PRs not yet recorded in agent_threads.
3. For each unprocessed PR, call schedule_fn with the synthesized payload.
4. Sleep for interval_seconds.
Args:
azdo_client: AzDoClient instance.
db_pool: Async psycopg connection pool.
watched_repos: List of repository names to poll.
target_branch: Target branch ref to filter PRs (e.g. "refs/heads/develop").
interval_seconds: Seconds to sleep between polling iterations.
schedule_fn: Callable to schedule a graph run; called with keyword arg:
initial_state (dict).
sleep_fn: Async sleep callable (default: asyncio.sleep). Injectable for testing.
"""
if sleep_fn is None:
sleep_fn = asyncio.sleep # type: ignore[assignment]
while True:
all_prs: list[PRInfo] = []
for repo in watched_repos:
try:
prs = await azdo_client.list_active_prs(repo, target_branch)
all_prs.extend(prs)
except Exception as exc:
logger.warning("PR polling failed for repo %s: %s", repo, exc)
try:
unprocessed = await find_unprocessed_prs(db_pool, all_prs)
except Exception as exc:
logger.warning("PR dedup query failed: %s", exc)
unprocessed = []
for pr in unprocessed:
webhook_payload = _synthesize_webhook_payload(pr)
initial_state = {
"webhook_payload": webhook_payload,
"pr_id": pr.pr_id,
"repo_name": pr.repo_name,
}
try:
schedule_fn(initial_state=initial_state)
except Exception as exc:
logger.warning("Failed to schedule PR %s/%s: %s", pr.repo_name, pr.pr_id, exc)
await sleep_fn(interval_seconds) # type: ignore[misc]

View File

@@ -0,0 +1,91 @@
"""LangGraph state definition for the release agent.
Uses TypedDict with total=False so all fields are optional, enabling
partial state updates throughout the graph execution.
Reducers follow immutable patterns - they always return new lists.
"""
from typing import Annotated
from typing_extensions import TypedDict
def add_messages(existing: list[str], new: list[str]) -> list[str]:
"""Accumulate messages without mutating the existing list.
Returns a new list containing all existing messages followed by new ones.
"""
return [*existing, *new]
def add_errors(existing: list[str], new: list[str]) -> list[str]:
"""Accumulate error messages without mutating the existing list.
Returns a new list containing all existing errors followed by new ones.
"""
return [*existing, *new]
class ReleaseState(TypedDict, total=False):
"""The mutable state passed between nodes in the release agent graph.
All fields are optional (total=False). Each field can be updated
independently by graph nodes. List fields use reducers that accumulate
values rather than replacing them.
"""
# Core identifiers
repo_name: str
pr_id: str
ticket_id: str
version: str
# Accumulated lists
messages: Annotated[list[str], add_messages]
errors: Annotated[list[str], add_errors]
# Phase 3: Webhook / PR fields
webhook_payload: dict
pr_info: dict
pr_diff: str
last_merge_source_commit: str
# Phase 3: Jira / ticket fields
ticket_summary: str
has_ticket: bool
# Phase 3: Review fields
review_result: dict
review_approved: bool
# Phase 3: Staging fields
staging: dict
pr_already_merged: bool
# Phase 3: Release PR fields
release_pr_id: str
release_pr_commit: str
# Phase 3: Pipeline / approval fields
pipelines: list[dict]
triggered_builds: list[dict]
pending_approvals: list[dict]
# Phase 3: Flow control
continue_to_release: bool
# Phase 5: CI build tracking
ci_build_id: int
ci_build_status: str
ci_build_result: str
ci_build_url: str
# Phase 5: CD release tracking
release_definition_id: int
release_id: int
current_stage: str
# Phase 5: Slack interactive message tracking
approval_message_ts: str
slack_message_ts: str

View File

View File

@@ -0,0 +1,103 @@
"""Shared HTTP helpers for service clients.
Provides a unified response-to-exception mapper and a Basic auth header builder.
"""
import base64
import httpx
from release_agent.exceptions import (
AuthenticationError,
NotFoundError,
RateLimitError,
ServiceError,
ServiceUnavailableError,
)
# Status codes that should not raise exceptions
_OK_RANGE_START = 200
_OK_RANGE_END = 299
_REDIRECT_RANGE_END = 399
def raise_for_status(response: httpx.Response, *, service: str) -> None:
"""Raise an appropriate exception based on the HTTP status code.
2xx and 3xx responses do not raise. 4xx/5xx responses raise typed exceptions
that carry the service name for debugging.
Args:
response: The httpx response to check.
service: Name of the service that returned this response.
Raises:
AuthenticationError: For 401 or 403 status codes.
NotFoundError: For 404 status codes.
RateLimitError: For 429 status codes.
ServiceUnavailableError: For 503 status codes.
ServiceError: For all other 4xx/5xx status codes.
"""
code = response.status_code
if code <= _REDIRECT_RANGE_END:
return
if code in (401, 403):
raise AuthenticationError(
service=service, status_code=code, detail=_extract_detail(response)
)
if code == 404:
raise NotFoundError(service=service, detail=_extract_detail(response))
if code == 429:
retry_after = _parse_retry_after(response)
raise RateLimitError(service=service, retry_after=retry_after)
if code == 503:
raise ServiceUnavailableError(service=service, detail=_extract_detail(response))
raise ServiceError(service=service, status_code=code, detail=_extract_detail(response))
def build_auth_header(username: str, password: str) -> dict[str, str]:
"""Build a Basic authentication header.
Args:
username: The username (may be empty string for PAT-only auth).
password: The password or token.
Returns:
A dict with a single "Authorization" key containing the Basic auth value.
"""
credentials = f"{username}:{password}"
encoded = base64.b64encode(credentials.encode()).decode()
return {"Authorization": f"Basic {encoded}"}
def _extract_detail(response: httpx.Response) -> str | None:
"""Try to extract a human-readable detail string from the response body."""
try:
body = response.json()
if isinstance(body, dict):
return (
body.get("message")
or body.get("detail")
or ("; ".join(body["errorMessages"]) if body.get("errorMessages") else None)
or str(body)
)
return str(body)
except (ValueError, KeyError):
return response.text or None
def _parse_retry_after(response: httpx.Response) -> int | None:
"""Parse the Retry-After header value as an integer number of seconds."""
value = response.headers.get("Retry-After")
if value is None:
return None
try:
return int(value)
except ValueError:
return None

View File

@@ -0,0 +1,74 @@
"""Async retry decorator with exponential backoff.
Retries only on RateLimitError and ServiceUnavailableError.
Respects the Retry-After header when present in RateLimitError.
"""
import asyncio
import functools
from collections.abc import Awaitable, Callable
from typing import Any, TypeVar
from release_agent.exceptions import RateLimitError, ServiceUnavailableError
_T = TypeVar("_T")
_DEFAULT_MAX_ATTEMPTS = 3
_DEFAULT_BASE_DELAY = 1.0
_BACKOFF_MULTIPLIER = 2.0
# Types that trigger a retry
_RETRYABLE = (RateLimitError, ServiceUnavailableError)
def with_retry(
max_attempts: int = _DEFAULT_MAX_ATTEMPTS,
base_delay: float = _DEFAULT_BASE_DELAY,
sleep_fn: Callable[[float], Awaitable[None]] | None = None,
) -> Callable[[Callable[..., Awaitable[_T]]], Callable[..., Awaitable[_T]]]:
"""Decorator factory that retries an async function on retryable errors.
Args:
max_attempts: Maximum total attempts (including the first call).
base_delay: Initial delay in seconds between retries. Doubles each retry.
sleep_fn: Async sleep function (defaults to asyncio.sleep). Injected for testing.
Returns:
A decorator that wraps an async function with retry logic.
"""
if max_attempts < 1:
raise ValueError(f"max_attempts must be >= 1, got {max_attempts}")
_sleep = sleep_fn if sleep_fn is not None else asyncio.sleep
def decorator(fn: Callable[..., Awaitable[_T]]) -> Callable[..., Awaitable[_T]]:
@functools.wraps(fn)
async def wrapper(*args: Any, **kwargs: Any) -> _T:
delay = base_delay
last_exc: Exception | None = None
for attempt in range(1, max_attempts + 1):
try:
return await fn(*args, **kwargs)
except _RETRYABLE as exc:
last_exc = exc
if attempt >= max_attempts:
raise
wait = _resolve_wait(exc, delay)
await _sleep(wait)
delay *= _BACKOFF_MULTIPLIER
raise RuntimeError("Retry loop exited unexpectedly") from last_exc
return wrapper
return decorator
def _resolve_wait(exc: Exception, fallback_delay: float) -> float:
"""Return the number of seconds to wait before the next attempt.
For RateLimitError with a retry_after value, that value takes precedence.
Otherwise the exponential backoff delay is used.
"""
if isinstance(exc, RateLimitError) and exc.retry_after is not None:
return float(exc.retry_after)
return fallback_delay

View File

@@ -0,0 +1,513 @@
"""Azure DevOps service client.
Uses two httpx.AsyncClient instances:
- http_client: main AzDo REST API (dev.azure.com)
- vsrm_http_client: VSRM release management API (vsrm.dev.azure.com)
Both clients are injected via constructor for testability.
"""
from types import TracebackType
from typing import Self
import httpx
from release_agent.models.build import ApprovalRecord, BuildStatus
from release_agent.models.pipeline import PipelineInfo
from release_agent.models.pr import PRInfo
from release_agent.tools._http import build_auth_header, raise_for_status
_API_VERSION = "7.1"
class AzDoClient:
"""Client for the Azure DevOps REST API.
Args:
base_url: Main AzDo project API base URL.
vsrm_base_url: VSRM release management API base URL.
pat: Personal Access Token for authentication.
http_client: Injected httpx.AsyncClient for the main API.
vsrm_http_client: Injected httpx.AsyncClient for the VSRM API.
"""
def __init__(
self,
*,
base_url: str,
vsrm_base_url: str,
pat: str,
http_client: httpx.AsyncClient,
vsrm_http_client: httpx.AsyncClient,
) -> None:
self._base_url = base_url.rstrip("/")
self._vsrm_base_url = vsrm_base_url.rstrip("/")
self._auth = build_auth_header("", pat)
self._http = http_client
self._vsrm = vsrm_http_client
async def close(self) -> None:
"""Close both underlying HTTP clients."""
await self._http.aclose()
await self._vsrm.aclose()
async def __aenter__(self) -> Self:
return self
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
await self.close()
# ------------------------------------------------------------------
# Pull requests
# ------------------------------------------------------------------
async def get_pr(self, pr_id: int) -> PRInfo:
"""Fetch a pull request by ID.
Args:
pr_id: Numeric pull request ID.
Returns:
A PRInfo model populated from the API response.
Raises:
NotFoundError: If the PR does not exist.
AuthenticationError: If authentication fails.
ServiceError: For other HTTP errors.
"""
url = f"{self._base_url}/git/pullRequests/{pr_id}"
response = await self._http.get(url, headers=self._auth, params={"api-version": _API_VERSION})
raise_for_status(response, service="azdo")
data = response.json()
return _parse_pr(data)
async def list_active_prs(self, repo: str, target_branch: str) -> list[PRInfo]:
"""List active pull requests for a repository filtered by target branch.
Args:
repo: Repository name.
target_branch: Target branch ref name (e.g. "refs/heads/develop").
Returns:
List of PRInfo models for active PRs targeting the given branch.
Raises:
NotFoundError: If the repository does not exist.
AuthenticationError: If authentication fails.
ServiceError: For other HTTP errors.
"""
url = f"{self._base_url}/git/repositories/{repo}/pullRequests"
response = await self._http.get(
url,
headers=self._auth,
params={
"api-version": _API_VERSION,
"status": "active",
"targetRefName": target_branch,
},
)
raise_for_status(response, service="azdo")
data = response.json()
return [_parse_pr(item) for item in data.get("value", [])]
async def get_pr_diff(self, pr_id: int) -> str:
"""Return a diff-like string for the given pull request.
Fetches the PR to determine the repository, then retrieves the
diffs endpoint and formats a summary string suitable for code review.
Args:
pr_id: Numeric pull request ID.
Returns:
A text string describing changed files.
Raises:
NotFoundError: If the PR does not exist.
ServiceError: For other HTTP errors.
"""
pr = await self.get_pr(pr_id)
repo = pr.repo_name
url = f"{self._base_url}/git/repositories/{repo}/pullRequests/{pr_id}/diffs"
response = await self._http.get(
url,
headers=self._auth,
params={"api-version": _API_VERSION, "baseVersionDescriptor.versionType": "commit"},
)
raise_for_status(response, service="azdo")
data = response.json()
return _format_diff(data)
async def merge_pr(self, *, pr_id: int, last_merge_source_commit: str) -> bool:
"""Complete (merge) a pull request.
Args:
pr_id: Numeric pull request ID.
last_merge_source_commit: The commit ID to use as the merge source.
Returns:
True on success.
Raises:
NotFoundError: If the PR does not exist.
ServiceError: For other HTTP errors including merge conflicts.
"""
url = f"{self._base_url}/git/pullRequests/{pr_id}"
payload = {
"status": "completed",
"lastMergeSourceCommit": {"commitId": last_merge_source_commit},
"completionOptions": {"mergeStrategy": "squash"},
}
response = await self._http.patch(
url,
headers={**self._auth, "Content-Type": "application/json"},
json=payload,
params={"api-version": _API_VERSION},
)
raise_for_status(response, service="azdo")
return True
async def create_pr(
self,
*,
repo: str,
source: str,
target: str,
title: str,
description: str,
) -> dict:
"""Create a new pull request.
Args:
repo: Repository name.
source: Source branch ref name (e.g. "refs/heads/release/v1.2.0").
target: Target branch ref name (e.g. "refs/heads/main").
title: PR title.
description: PR description.
Returns:
Raw dict from the API response.
Raises:
ServiceError: For HTTP errors.
"""
url = f"{self._base_url}/git/repositories/{repo}/pullRequests"
payload = {
"sourceRefName": source,
"targetRefName": target,
"title": title,
"description": description,
}
response = await self._http.post(
url,
headers={**self._auth, "Content-Type": "application/json"},
json=payload,
params={"api-version": _API_VERSION},
)
raise_for_status(response, service="azdo")
return response.json()
# ------------------------------------------------------------------
# Pipelines
# ------------------------------------------------------------------
async def list_build_pipelines(self, *, repo: str) -> list[PipelineInfo]:
"""List build pipelines for a repository.
Args:
repo: Repository name (used to filter pipelines).
Returns:
List of PipelineInfo models.
Raises:
ServiceError: For HTTP errors.
"""
url = f"{self._base_url}/pipelines"
response = await self._http.get(
url,
headers=self._auth,
params={"api-version": _API_VERSION},
)
raise_for_status(response, service="azdo")
data = response.json()
return [
PipelineInfo(id=item["id"], name=item["name"], repo=repo)
for item in data.get("value", [])
]
async def trigger_pipeline(self, *, pipeline_id: int, branch: str) -> dict:
"""Trigger a pipeline run.
Args:
pipeline_id: Numeric pipeline definition ID.
branch: Branch ref to build (e.g. "refs/heads/main").
Returns:
Raw dict from the API response containing the build/run ID.
Raises:
NotFoundError: If the pipeline does not exist.
ServiceError: For other HTTP errors.
"""
url = f"{self._base_url}/pipelines/{pipeline_id}/runs"
payload = {"resources": {"repositories": {"self": {"refName": branch}}}}
response = await self._http.post(
url,
headers={**self._auth, "Content-Type": "application/json"},
json=payload,
params={"api-version": _API_VERSION},
)
raise_for_status(response, service="azdo")
return response.json()
async def get_build_status(self, *, build_id: int) -> BuildStatus:
"""Get the status of a build.
Args:
build_id: Numeric build ID.
Returns:
BuildStatus dataclass with status, result, and build_url fields.
Raises:
NotFoundError: If the build does not exist.
ServiceError: For other HTTP errors.
"""
url = f"{self._base_url}/build/builds/{build_id}"
response = await self._http.get(
url,
headers=self._auth,
params={"api-version": _API_VERSION},
)
raise_for_status(response, service="azdo")
data = response.json()
links = data.get("_links") or {}
web_link = links.get("web") or {}
build_url = web_link.get("href") or data.get("url")
return BuildStatus(
status=data["status"],
result=data.get("result"),
build_url=build_url,
)
async def get_release_approvals(self, *, release_id: int) -> list[ApprovalRecord]:
"""Get pending approvals for a release.
Args:
release_id: Numeric release ID.
Returns:
List of ApprovalRecord dataclasses for this release.
Raises:
NotFoundError: If the release does not exist.
ServiceError: For other HTTP errors.
"""
url = f"{self._vsrm_base_url}/release/approvals"
response = await self._vsrm.get(
url,
headers=self._auth,
params={"api-version": _API_VERSION, "releaseId": release_id},
)
raise_for_status(response, service="azdo")
data = response.json()
records: list[ApprovalRecord] = []
for item in data.get("value", []):
env = item.get("releaseEnvironment") or {}
rel = env.get("release") or {}
records.append(
ApprovalRecord(
approval_id=str(item["id"]),
stage_name=env.get("name", ""),
status=item.get("status", "pending"),
release_id=rel.get("id", release_id),
)
)
return records
async def get_latest_release(self, *, definition_id: int) -> dict:
"""Get the latest release for a release definition.
Args:
definition_id: Numeric release definition ID.
Returns:
Raw dict from the API response for the latest release,
or an empty dict if no releases exist.
Raises:
NotFoundError: If the definition does not exist.
ServiceError: For other HTTP errors.
"""
url = f"{self._vsrm_base_url}/release/releases"
response = await self._vsrm.get(
url,
headers=self._auth,
params={
"api-version": _API_VERSION,
"definitionId": definition_id,
"$top": 1,
"$orderby": "id desc",
},
)
raise_for_status(response, service="azdo")
data = response.json()
values = data.get("value", [])
return values[0] if values else {}
# ------------------------------------------------------------------
# Release approvals (VSRM)
# ------------------------------------------------------------------
async def approve_release(self, *, approval_id: str, comment: str) -> dict:
"""Approve a release pipeline stage approval.
Uses the VSRM API endpoint.
Args:
approval_id: The approval record ID.
comment: Comment to attach to the approval.
Returns:
Raw dict from the API response.
Raises:
NotFoundError: If the approval does not exist.
ServiceError: For other HTTP errors.
"""
url = f"{self._vsrm_base_url}/release/approvals/{approval_id}"
payload = {"status": "approved", "comments": comment}
response = await self._vsrm.patch(
url,
headers={**self._auth, "Content-Type": "application/json"},
json=payload,
params={"api-version": _API_VERSION},
)
raise_for_status(response, service="azdo")
return response.json()
# -------------------------------------------------------------------
# PR Comment methods
# -------------------------------------------------------------------
async def add_pr_comment(
self,
*,
repo: str,
pr_id: int,
content: str,
) -> dict:
"""Add a general comment thread to a PR (not file-specific).
Args:
repo: Repository name.
pr_id: Pull request ID.
content: Markdown content for the comment.
Returns:
Raw dict from the API response.
"""
url = f"{self._base_url}/git/repositories/{repo}/pullRequests/{pr_id}/threads"
payload = {
"comments": [{"parentCommentId": 0, "content": content, "commentType": 1}],
"status": "active",
}
response = await self._http.post(
url,
headers={**self._auth, "Content-Type": "application/json"},
json=payload,
params={"api-version": _API_VERSION},
)
raise_for_status(response, service="azdo")
return response.json()
async def add_pr_inline_comment(
self,
*,
repo: str,
pr_id: int,
content: str,
file_path: str,
line_start: int,
line_end: int | None = None,
) -> dict:
"""Add an inline comment thread to a specific file and line in a PR.
Args:
repo: Repository name.
pr_id: Pull request ID.
content: Markdown content for the comment.
file_path: Path to the file (e.g., "/src/Handlers/InvoiceHandler.cs").
line_start: Starting line number.
line_end: Ending line number (defaults to line_start).
Returns:
Raw dict from the API response.
"""
effective_end = line_end if line_end is not None else line_start
url = f"{self._base_url}/git/repositories/{repo}/pullRequests/{pr_id}/threads"
payload = {
"comments": [{"parentCommentId": 0, "content": content, "commentType": 1}],
"threadContext": {
"filePath": file_path if file_path.startswith("/") else f"/{file_path}",
"rightFileStart": {"line": line_start, "offset": 1},
"rightFileEnd": {"line": effective_end, "offset": 1},
},
"status": "active",
}
response = await self._http.post(
url,
headers={**self._auth, "Content-Type": "application/json"},
json=payload,
params={"api-version": _API_VERSION},
)
raise_for_status(response, service="azdo")
return response.json()
# ---------------------------------------------------------------------------
# Private parsing helpers
# ---------------------------------------------------------------------------
def _parse_pr(data: dict) -> PRInfo:
"""Map a raw AzDo PR API response to a PRInfo model."""
pr_id = str(data["pullRequestId"])
repo = data["repository"]["name"]
branch = data.get("sourceRefName", "")
url = data.get("url", "")
if not url:
repo_url = data["repository"].get("remoteUrl", "")
url = f"{repo_url}/pullrequest/{pr_id}"
return PRInfo(
pr_id=pr_id,
pr_url=url,
repo_name=repo,
branch=branch,
pr_title=data.get("title", ""),
pr_status=_map_pr_status(data.get("status", "active")),
)
def _map_pr_status(raw: str) -> str:
"""Normalise AzDo PR status to a value accepted by PRInfo."""
mapping = {"active": "active", "completed": "completed", "abandoned": "abandoned"}
return mapping.get(raw, "active")
def _format_diff(data: dict) -> str:
"""Format the diffs API response as a text string."""
changes = data.get("changes", [])
lines = []
for change in changes:
item = change.get("item", {})
path = item.get("path", "unknown")
change_type = change.get("changeType", "edit")
lines.append(f"{change_type}: {path}")
return "\n".join(lines)

View File

@@ -0,0 +1,335 @@
"""Claude PR reviewer using Claude Code CLI for code review.
Uses `claude -p` (print mode) with `--output-format json` to leverage the
Claude Code subscription instead of API key billing. Claude Code can
autonomously read files, grep code, and explore the codebase for deeper reviews.
"""
import asyncio
import json
import logging
from pathlib import Path
from release_agent.models.review import ReviewIssue, ReviewResult
logger = logging.getLogger(__name__)
_MAX_DIFF_CHARS = 100_000
_TRUNCATION_NOTE = "\n\n[DIFF TRUNCATED: content exceeded 100,000 characters]"
_CLI_TIMEOUT_SECONDS = 300
# JSON schema for structured output from Claude Code CLI
_REVIEW_JSON_SCHEMA = json.dumps({
"type": "object",
"properties": {
"verdict": {
"type": "string",
"enum": ["approve", "request_changes"],
},
"summary": {
"type": "string",
},
"issues": {
"type": "array",
"items": {
"type": "object",
"properties": {
"severity": {
"type": "string",
"enum": ["blocker", "error", "warning", "info"],
},
"description": {"type": "string"},
"file_path": {"type": "string"},
"line_start": {"type": "integer", "description": "Starting line number in the file"},
"line_end": {"type": "integer", "description": "Ending line number in the file"},
"suggestion": {"type": "string"},
},
"required": ["severity", "description"],
},
},
},
"required": ["verdict", "summary", "issues"],
})
_SYSTEM_PROMPT = (
"You are a senior code reviewer for .NET C# projects. "
"Review the PR for: backward compatibility (DB migration, serialization, integration events), "
"null handling, business logic correctness, test coverage, hardcoded values, and security. "
"You have access to the full codebase via Read, Glob, and Grep tools. "
"Use them to understand the context around the changed files. "
"For each issue, include the file_path and line_start/line_end where the issue occurs."
)
# JSON schema for structured ticket content output
_TICKET_JSON_SCHEMA = json.dumps({
"type": "object",
"properties": {
"summary": {
"type": "string",
"description": "Short one-line summary of the work done (max 100 chars)",
},
"description": {
"type": "string",
"description": "Detailed description of the changes and business value",
},
},
"required": ["summary", "description"],
})
_TICKET_SYSTEM_PROMPT = (
"You are a Jira ticket writer for a software development team. "
"Given a PR diff and title, produce a concise Jira Story ticket. "
"The summary should be a clear, action-oriented title (max 100 chars). "
"The description should explain what was changed and why, suitable for a product backlog."
)
class ClaudeReviewer:
"""Reviews pull request diffs using Claude Code CLI.
Uses the user's Claude Code subscription (not API key) by invoking
`claude -p` as a subprocess. This allows Claude to autonomously read
files and explore the codebase for more thorough reviews.
Args:
claude_cmd: Path to the claude CLI binary (default: "claude").
timeout: Maximum seconds to wait for the review (default: 300).
run_subprocess: Async callable for running subprocesses (injected for testing).
"""
def __init__(
self,
*,
claude_cmd: str = "claude",
timeout: int = _CLI_TIMEOUT_SECONDS,
run_subprocess: object | None = None,
) -> None:
self._claude_cmd = claude_cmd
self._timeout = timeout
self._run_subprocess = run_subprocess or _default_run_subprocess
async def review_pr(
self,
*,
diff: str,
pr_title: str,
repo_name: str,
cwd: str | Path | None = None,
) -> ReviewResult:
"""Review a pull request and return a structured result.
Args:
diff: The raw diff text of the pull request.
pr_title: Title of the pull request.
repo_name: Name of the repository.
cwd: Working directory for Claude Code (e.g., git worktree path).
If provided, Claude can read files in this directory.
Returns:
A ReviewResult with verdict, summary, and list of issues.
"""
truncated_diff = _truncate_diff(diff)
prompt = _build_prompt(
diff=truncated_diff,
pr_title=pr_title,
repo_name=repo_name,
)
cmd = [
self._claude_cmd, "-p", prompt,
"--output-format", "json",
"--json-schema", _REVIEW_JSON_SCHEMA,
"--allowedTools", "Read,Glob,Grep",
"--system-prompt", _SYSTEM_PROMPT,
]
stdout, stderr, returncode = await self._run_subprocess(
cmd=cmd,
cwd=str(cwd) if cwd else None,
timeout=self._timeout,
)
if returncode != 0:
logger.error(
"Claude CLI failed (exit %d): %s", returncode, stderr[:500]
)
raise RuntimeError(f"Claude CLI exited with code {returncode}: {stderr[:200]}")
return _parse_cli_output(stdout)
async def generate_ticket_content(
self,
*,
diff: str,
pr_title: str,
repo_name: str,
cwd: str | Path | None = None,
) -> tuple[str, str]:
"""Generate Jira ticket summary and description from a PR diff.
Args:
diff: The raw diff text of the pull request.
pr_title: Title of the pull request.
repo_name: Name of the repository.
cwd: Working directory for Claude Code.
Returns:
A tuple of (summary, description) strings suitable for Jira.
Raises:
RuntimeError: If the Claude CLI fails or times out.
ValueError: If the output cannot be parsed.
"""
truncated_diff = _truncate_diff(diff)
prompt = (
f"Generate a Jira Story ticket for this pull request from repository '{repo_name}'.\n\n"
f"PR Title: {pr_title}\n\n"
f"Diff:\n```\n{truncated_diff}\n```\n\n"
"Produce a concise summary and description for a Jira Story ticket."
)
cmd = [
self._claude_cmd, "-p", prompt,
"--output-format", "json",
"--json-schema", _TICKET_JSON_SCHEMA,
"--allowedTools", "Read,Glob,Grep",
"--system-prompt", _TICKET_SYSTEM_PROMPT,
]
stdout, stderr, returncode = await self._run_subprocess(
cmd=cmd,
cwd=str(cwd) if cwd else None,
timeout=self._timeout,
)
if returncode != 0:
logger.error(
"Claude CLI failed (exit %d): %s", returncode, stderr[:500]
)
raise RuntimeError(f"Claude CLI exited with code {returncode}: {stderr[:200]}")
return _parse_ticket_output(stdout)
# ---------------------------------------------------------------------------
# Private helpers
# ---------------------------------------------------------------------------
def _truncate_diff(diff: str) -> str:
"""Truncate diff to at most _MAX_DIFF_CHARS characters."""
if len(diff) <= _MAX_DIFF_CHARS:
return diff
return diff[:_MAX_DIFF_CHARS] + _TRUNCATION_NOTE
def _build_prompt(*, diff: str, pr_title: str, repo_name: str) -> str:
"""Build the user prompt for the review request."""
return (
f"Review this pull request from repository '{repo_name}'.\n\n"
f"PR Title: {pr_title}\n\n"
f"Diff:\n```\n{diff}\n```\n\n"
"Read the related source files to understand context. "
"Provide your review as structured JSON output."
)
def _parse_cli_output(stdout: str) -> ReviewResult:
"""Parse the JSON output from Claude Code CLI into a ReviewResult."""
try:
data = json.loads(stdout)
except json.JSONDecodeError as exc:
raise ValueError(f"Failed to parse Claude CLI output as JSON: {exc}") from exc
# Claude Code --output-format json wraps result in {"result": "...", ...}
# The structured_output field contains our schema-conforming data
structured = data.get("structured_output") or data.get("result")
if structured is None:
raise ValueError("No structured_output or result in Claude CLI response")
# If structured is a string (result field), try to parse it as JSON
if isinstance(structured, str):
try:
structured = json.loads(structured)
except json.JSONDecodeError:
raise ValueError(
"Claude CLI result is not valid JSON for review schema"
)
if not isinstance(structured, dict):
raise ValueError(f"Expected dict, got {type(structured).__name__}")
issues = [
ReviewIssue(
severity=issue["severity"],
description=issue["description"],
file_path=issue.get("file_path"),
line_start=issue.get("line_start"),
line_end=issue.get("line_end"),
suggestion=issue.get("suggestion"),
)
for issue in structured.get("issues", [])
]
return ReviewResult(
verdict=structured["verdict"],
summary=structured["summary"],
issues=tuple(issues),
)
def _parse_ticket_output(stdout: str) -> tuple[str, str]:
"""Parse the JSON output from Claude Code CLI into (summary, description)."""
try:
data = json.loads(stdout)
except json.JSONDecodeError as exc:
raise ValueError(f"Failed to parse Claude CLI ticket output as JSON: {exc}") from exc
structured = data.get("structured_output") or data.get("result")
if structured is None:
raise ValueError("No structured_output or result in Claude CLI response")
if isinstance(structured, str):
try:
structured = json.loads(structured)
except json.JSONDecodeError:
raise ValueError("Claude CLI result is not valid JSON for ticket schema")
if not isinstance(structured, dict):
raise ValueError(f"Expected dict, got {type(structured).__name__}")
summary = structured.get("summary", "")
description = structured.get("description", "")
if not summary:
raise ValueError("Claude ticket output missing required 'summary' field")
return (summary, description)
async def _default_run_subprocess(
*,
cmd: list[str],
cwd: str | None,
timeout: int,
) -> tuple[str, str, int]:
"""Run a subprocess and return (stdout, stderr, returncode)."""
process = await asyncio.create_subprocess_exec(
*cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
cwd=cwd,
)
try:
stdout_bytes, stderr_bytes = await asyncio.wait_for(
process.communicate(), timeout=timeout
)
except asyncio.TimeoutError:
process.kill()
await process.wait()
raise RuntimeError(f"Claude CLI timed out after {timeout} seconds")
return (
stdout_bytes.decode("utf-8", errors="replace"),
stderr_bytes.decode("utf-8", errors="replace"),
process.returncode or 0,
)

View File

@@ -0,0 +1,269 @@
"""Jira service client.
Uses Basic auth (email:api_token) and the Jira REST API v3.
The httpx.AsyncClient is injected via constructor for testability.
"""
from types import TracebackType
from typing import Self
import httpx
from release_agent.models.jira import JiraIssue, JiraTransition
from release_agent.tools._http import build_auth_header, raise_for_status
_API_VERSION = "rest/api/3"
_DEV_IN_PROGRESS_TRANSITION = "Dev in Progress"
class JiraClient:
"""Client for the Jira REST API.
Args:
base_url: Jira instance base URL (e.g. "https://billolife.atlassian.net").
email: Account email address for Basic auth.
api_token: Jira API token for Basic auth.
http_client: Injected httpx.AsyncClient.
"""
def __init__(
self,
*,
base_url: str,
email: str,
api_token: str,
http_client: httpx.AsyncClient,
) -> None:
self._base_url = base_url.rstrip("/")
self._auth = build_auth_header(email, api_token)
self._http = http_client
async def close(self) -> None:
"""Close the underlying HTTP client."""
await self._http.aclose()
async def __aenter__(self) -> Self:
return self
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
await self.close()
# ------------------------------------------------------------------
# Issue operations
# ------------------------------------------------------------------
async def get_issue(self, ticket_id: str) -> JiraIssue:
"""Fetch a Jira issue by key.
Args:
ticket_id: Jira issue key (e.g. "ALLPOST-100").
Returns:
A JiraIssue model populated from the API response.
Raises:
NotFoundError: If the issue does not exist.
AuthenticationError: If authentication fails.
ServiceError: For other HTTP errors.
"""
url = f"{self._base_url}/{_API_VERSION}/issue/{ticket_id}"
response = await self._http.get(url, headers=self._auth)
raise_for_status(response, service="jira")
data = response.json()
return _parse_issue(data)
async def get_transitions(self, ticket_id: str) -> list[JiraTransition]:
"""Fetch available workflow transitions for a Jira issue.
Args:
ticket_id: Jira issue key.
Returns:
List of available JiraTransition models.
Raises:
NotFoundError: If the issue does not exist.
ServiceError: For other HTTP errors.
"""
url = f"{self._base_url}/{_API_VERSION}/issue/{ticket_id}/transitions"
response = await self._http.get(url, headers=self._auth)
raise_for_status(response, service="jira")
data = response.json()
return [
JiraTransition(id=t["id"], name=t["name"])
for t in data.get("transitions", [])
]
async def transition_issue(self, ticket_id: str, transition_name: str) -> bool:
"""Move a Jira issue through a named workflow transition.
If the target transition is not currently available, a two-step
fallback is attempted: first transition to "Dev in Progress", then
retry the target transition. Returns False if the transition is still
unavailable after the fallback attempt.
Args:
ticket_id: Jira issue key.
transition_name: Name of the target transition (e.g. "Released").
Returns:
True if the transition succeeded, False if unavailable.
Raises:
ServiceError: For HTTP errors during transition execution.
"""
transitions = await self.get_transitions(ticket_id)
transition_id = _find_transition_id(transitions, transition_name)
if transition_id is None:
# Two-step fallback: move to "Dev in Progress" first, then retry.
dev_id = _find_transition_id(transitions, _DEV_IN_PROGRESS_TRANSITION)
if dev_id is None:
return False
await self._do_transition(ticket_id, dev_id)
# Retry target transition after the fallback step.
transitions = await self.get_transitions(ticket_id)
transition_id = _find_transition_id(transitions, transition_name)
if transition_id is None:
return False
await self._do_transition(ticket_id, transition_id)
return True
async def create_issue(
self,
*,
project: str,
summary: str,
description: str,
issue_type: str = "Story",
) -> str:
"""Create a new Jira issue and return its ticket key.
Args:
project: Jira project key (e.g. "ALLPOST").
summary: Issue summary / title.
description: Plain-text description, converted to ADF format.
issue_type: Issue type name (default: "Story").
Returns:
The ticket key of the created issue (e.g. "ALLPOST-42").
Raises:
AuthenticationError: If authentication fails.
ServiceError: For other HTTP errors.
"""
url = f"{self._base_url}/{_API_VERSION}/issue"
payload = {
"fields": {
"project": {"key": project},
"summary": summary,
"description": _text_to_adf(description),
"issuetype": {"name": issue_type},
}
}
response = await self._http.post(
url,
headers={**self._auth, "Content-Type": "application/json"},
json=payload,
)
raise_for_status(response, service="jira")
data = response.json()
return data["key"]
async def add_remote_link(self, *, ticket_id: str, url: str, title: str) -> bool:
"""Add a remote link (e.g. a PR URL) to a Jira issue.
Args:
ticket_id: Jira issue key.
url: URL of the remote resource.
title: Display title of the link.
Returns:
True on success.
Raises:
NotFoundError: If the issue does not exist.
ServiceError: For other HTTP errors.
"""
endpoint = f"{self._base_url}/{_API_VERSION}/issue/{ticket_id}/remotelink"
payload = {
"object": {
"url": url,
"title": title,
}
}
response = await self._http.post(
endpoint,
headers={**self._auth, "Content-Type": "application/json"},
json=payload,
)
raise_for_status(response, service="jira")
return True
# ------------------------------------------------------------------
# Private helpers
# ------------------------------------------------------------------
async def _do_transition(self, ticket_id: str, transition_id: str) -> None:
"""Execute a transition by ID."""
url = f"{self._base_url}/{_API_VERSION}/issue/{ticket_id}/transitions"
payload = {"transition": {"id": transition_id}}
response = await self._http.post(
url,
headers={**self._auth, "Content-Type": "application/json"},
json=payload,
)
raise_for_status(response, service="jira")
# ---------------------------------------------------------------------------
# Private parsing helpers
# ---------------------------------------------------------------------------
def _text_to_adf(text: str) -> dict:
"""Convert plain text to Atlassian Document Format (ADF).
Each non-empty line becomes a paragraph node. Empty string produces
an empty document (no content nodes).
Args:
text: Plain text to convert.
Returns:
ADF dict conforming to the Jira REST API v3 description format.
"""
content = []
for line in text.splitlines():
stripped = line.strip()
if stripped:
content.append({
"type": "paragraph",
"content": [{"type": "text", "text": stripped}],
})
return {"version": 1, "type": "doc", "content": content}
def _parse_issue(data: dict) -> JiraIssue:
"""Map a raw Jira issue API response to a JiraIssue model."""
fields = data.get("fields", {})
status = fields.get("status", {}).get("name", "Unknown")
return JiraIssue(
key=data["key"],
summary=fields.get("summary", ""),
status=status,
)
def _find_transition_id(transitions: list[JiraTransition], name: str) -> str | None:
"""Return the ID of the first transition matching the given name, or None."""
for transition in transitions:
if transition.name == name:
return transition.id
return None

View File

@@ -0,0 +1,500 @@
"""Slack service client — dual-mode: webhook fallback + Web API.
Supports two modes:
1. Webhook mode: uses webhook_url for fire-and-forget notifications.
2. Web API mode: uses bot_token + channel_id for interactive messages,
message updates, and retrieving message timestamps.
Block Kit payloads are built by pure functions that can be tested independently.
The httpx.AsyncClient is injected via constructor for testability.
"""
from datetime import date
from types import TracebackType
from typing import Self
import httpx
from release_agent.models.ticket import TicketEntry
from release_agent.tools._http import raise_for_status
_SLACK_API_BASE = "https://slack.com/api"
class SlackClient:
"""Client for sending messages to Slack.
Supports two modes determined by which parameters are provided:
- Webhook mode: provide webhook_url.
- Web API mode: provide bot_token and channel_id.
Both modes can be configured simultaneously; Web API takes priority
for methods that require it (send_interactive_approval, update_message).
Webhook is used as fallback for simple notifications.
Args:
webhook_url: The Slack incoming webhook URL (optional).
bot_token: Slack bot OAuth token (xoxb-...) for Web API (optional).
channel_id: Slack channel ID for Web API messages (optional).
http_client: Injected httpx.AsyncClient.
"""
def __init__(
self,
*,
webhook_url: str = "",
bot_token: str = "",
channel_id: str = "",
http_client: httpx.AsyncClient,
) -> None:
self._webhook_url = webhook_url
self._bot_token = bot_token
self._channel_id = channel_id
self._http = http_client
async def close(self) -> None:
"""Close the underlying HTTP client."""
await self._http.aclose()
async def __aenter__(self) -> Self:
return self
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
await self.close()
# ------------------------------------------------------------------
# Public methods: fire-and-forget notifications
# ------------------------------------------------------------------
async def send_release_notification(
self,
*,
repo: str,
version: str,
release_date: date,
tickets: list[TicketEntry],
) -> bool:
"""Send a release notification to Slack.
Args:
repo: Repository name.
version: Release version string (e.g. "v1.2.0").
release_date: Date of the release.
tickets: List of tickets included in the release.
Returns:
True on success.
Raises:
ServiceError: If the Slack webhook returns an error response.
"""
blocks = _build_release_blocks(
repo=repo,
version=version,
release_date=release_date,
tickets=tickets,
)
return await self._post_blocks(blocks)
async def send_approval_request(
self,
*,
action: str,
details: str,
approval_url: str,
) -> bool:
"""Send an approval request notification to Slack.
Args:
action: Description of the action requiring approval.
details: Additional context for the approval.
approval_url: URL where the approver can approve the action.
Returns:
True on success.
Raises:
ServiceError: If the Slack webhook returns an error response.
"""
blocks = _build_approval_blocks(
action=action,
details=details,
approval_url=approval_url,
)
return await self._post_blocks(blocks)
async def send_notification(self, *, text: str, blocks: list[dict]) -> bool:
"""Send a plain notification to Slack.
Uses Web API if bot_token is configured, otherwise falls back to
the webhook URL.
Args:
text: Fallback text for notifications.
blocks: Block Kit blocks to include in the message.
Returns:
True on success, False on error.
"""
if self._bot_token and self._channel_id:
ts = await self._post_via_web_api(text=text, blocks=blocks)
return ts != ""
if self._webhook_url:
try:
return await self._post_blocks(blocks or [{"type": "section", "text": {"type": "mrkdwn", "text": text}}])
except Exception:
return False
return False
# ------------------------------------------------------------------
# Public methods: interactive messages (Web API only)
# ------------------------------------------------------------------
async def send_interactive_approval(
self,
*,
thread_id: str,
action: str,
details: str,
buttons: list[dict],
) -> str:
"""Send an interactive approval message with clickable buttons.
Requires bot_token and channel_id. Returns the message timestamp
(ts) which can be used to update the message later.
Args:
thread_id: The graph thread_id, embedded in button payloads
so the interactions endpoint can resume the correct thread.
action: Human-readable description of the action (e.g. "Deploy to Sandbox").
details: Additional context shown in the message body.
buttons: List of button dicts with "text" and "value" keys.
Returns:
Message timestamp string (e.g. "1234567890.123456") on success,
or empty string on failure.
"""
blocks = _build_interactive_approval_blocks(
thread_id=thread_id,
action=action,
details=details,
buttons=buttons,
)
return await self._post_via_web_api(text=action, blocks=blocks)
async def update_message(
self,
*,
message_ts: str,
text: str,
blocks: list[dict],
) -> bool:
"""Update an existing Slack message (Web API chat.update).
Args:
message_ts: Timestamp of the message to update.
text: New fallback text.
blocks: New Block Kit blocks.
Returns:
True on success, False on failure.
"""
url = f"{_SLACK_API_BASE}/chat.update"
payload: dict = {
"channel": self._channel_id,
"ts": message_ts,
"text": text,
}
if blocks:
payload["blocks"] = blocks
response = await self._http.post(
url,
json=payload,
headers=self._web_api_headers(),
)
if response.status_code != 200:
return False
data = response.json()
return bool(data.get("ok"))
# ------------------------------------------------------------------
# Private helpers
# ------------------------------------------------------------------
async def _post_blocks(self, blocks: list[dict]) -> bool:
"""POST a Block Kit payload to the webhook URL."""
response = await self._http.post(
self._webhook_url,
json={"blocks": blocks},
headers={"Content-Type": "application/json"},
)
raise_for_status(response, service="slack")
return True
async def _post_via_web_api(self, *, text: str, blocks: list[dict]) -> str:
"""POST a message via Slack Web API chat.postMessage.
Returns the message ts on success, or empty string on failure.
"""
url = f"{_SLACK_API_BASE}/chat.postMessage"
payload: dict = {
"channel": self._channel_id,
"text": text,
}
if blocks:
payload["blocks"] = blocks
response = await self._http.post(
url,
json=payload,
headers=self._web_api_headers(),
)
if response.status_code != 200:
return ""
data = response.json()
if not data.get("ok"):
return ""
return data.get("ts", "")
def _web_api_headers(self) -> dict:
"""Return headers for Slack Web API requests."""
return {
"Content-Type": "application/json; charset=utf-8",
"Authorization": f"Bearer {self._bot_token}",
}
# ---------------------------------------------------------------------------
# Pure Block Kit builder functions
# ---------------------------------------------------------------------------
def _build_release_blocks(
*,
repo: str,
version: str,
release_date: date,
tickets: list[TicketEntry],
) -> list[dict]:
"""Build Slack Block Kit blocks for a release notification.
This is a pure function with no side effects, allowing it to be tested
independently of HTTP calls.
Args:
repo: Repository name.
version: Release version string.
release_date: Date of the release.
tickets: Tickets included in the release.
Returns:
List of Slack Block Kit block dicts.
"""
header_block = {
"type": "header",
"text": {
"type": "plain_text",
"text": f"Release {version} - {repo}",
},
}
date_block = {
"type": "section",
"text": {
"type": "mrkdwn",
"text": f"*Release date:* {release_date.isoformat()}",
},
}
blocks: list[dict] = [header_block, date_block]
if tickets:
ticket_lines = [f"- *{t.id}*: {t.summary}" for t in tickets]
ticket_text = "\n".join(ticket_lines)
ticket_block = {
"type": "section",
"text": {
"type": "mrkdwn",
"text": f"*Tickets:*\n{ticket_text}",
},
}
blocks.append(ticket_block)
return blocks
def _build_approval_blocks(
*,
action: str,
details: str,
approval_url: str,
) -> list[dict]:
"""Build Slack Block Kit blocks for an approval request.
This is a pure function with no side effects, allowing it to be tested
independently of HTTP calls.
Args:
action: Description of the action requiring approval.
details: Additional context for the approval.
approval_url: URL to the approval page.
Returns:
List of Slack Block Kit block dicts.
"""
header_block = {
"type": "header",
"text": {
"type": "plain_text",
"text": f"Approval Required: {action}",
},
}
details_block = {
"type": "section",
"text": {
"type": "mrkdwn",
"text": details,
},
}
actions_block = {
"type": "section",
"text": {
"type": "mrkdwn",
"text": f"<{approval_url}|Review and Approve>",
},
}
return [header_block, details_block, actions_block]
def _build_interactive_approval_blocks(
*,
thread_id: str,
action: str,
details: str,
buttons: list[dict],
) -> list[dict]:
"""Build Slack Block Kit blocks with interactive approval buttons.
The thread_id is embedded in each button's action_id and value so the
/slack/interactions endpoint can resume the correct graph thread when
a button is clicked.
Args:
thread_id: Graph thread ID to embed in button payloads.
action: Human-readable action description for the header.
details: Contextual details shown in the message body.
buttons: List of dicts with "text" (str) and "value" (str) keys.
Returns:
List of Slack Block Kit block dicts.
"""
header_block = {
"type": "header",
"text": {
"type": "plain_text",
"text": f"Approval Required: {action}",
},
}
details_block = {
"type": "section",
"text": {
"type": "mrkdwn",
"text": details,
},
}
blocks: list[dict] = [header_block, details_block]
if buttons:
button_elements = [
{
"type": "button",
"text": {"type": "plain_text", "text": btn["text"]},
"value": f"{thread_id}:{btn['value']}",
"action_id": f"approval_{btn['value']}_{thread_id}",
}
for btn in buttons
]
actions_block = {
"type": "actions",
"elements": button_elements,
}
blocks.append(actions_block)
return blocks
def _build_ci_status_blocks(
*,
repo: str,
branch: str,
status: str,
build_url: str | None,
) -> list[dict]:
"""Build Slack Block Kit blocks for a CI build status notification.
Args:
repo: Repository name.
branch: Branch that was built.
status: Build result status (e.g. "succeeded", "failed").
build_url: Direct URL to the build results page, or None.
Returns:
List of Slack Block Kit block dicts.
"""
header_block = {
"type": "header",
"text": {
"type": "plain_text",
"text": f"CI Build: {repo}",
},
}
detail_text = f"*Branch:* {branch}\n*Status:* {status}"
if build_url:
detail_text += f"\n<{build_url}|View Build>"
details_block = {
"type": "section",
"text": {
"type": "mrkdwn",
"text": detail_text,
},
}
return [header_block, details_block]
def _build_resolved_approval_blocks(
*,
action: str,
outcome: str,
user: str,
) -> list[dict]:
"""Build Slack Block Kit blocks for a resolved approval message.
Replaces the interactive approval message after the decision is made.
Args:
action: The original action that was approved or rejected.
outcome: The approval outcome (e.g. "approved", "rejected").
user: The Slack user who made the decision.
Returns:
List of Slack Block Kit block dicts.
"""
header_block = {
"type": "header",
"text": {
"type": "plain_text",
"text": f"Approval {outcome.capitalize()}: {action}",
},
}
details_block = {
"type": "section",
"text": {
"type": "mrkdwn",
"text": f"Decision: *{outcome}* by {user}",
},
}
return [header_block, details_block]

View File

@@ -0,0 +1,70 @@
"""Version calculation utilities for release management.
Pure functions only - no side effects, no mutation.
All version strings must include a 'v' prefix (e.g. "v1.2.3").
"""
import re
_VERSION_PATTERN = re.compile(r"^v(\d+)\.(\d+)\.(\d+)$")
def parse_version(version_str: str) -> tuple[int, int, int]:
"""Parse a version string into a (major, minor, patch) tuple.
Accepts strings with or without a leading 'v' prefix.
Raises:
ValueError: If the string does not match the expected format.
"""
normalized = version_str if version_str.startswith("v") else f"v{version_str}"
match = _VERSION_PATTERN.match(normalized)
if not match:
raise ValueError(f"Invalid version string: {version_str!r}")
return int(match.group(1)), int(match.group(2)), int(match.group(3))
def format_version(major: int, minor: int, patch: int) -> str:
"""Format a (major, minor, patch) tuple into a version string with 'v' prefix.
Example: format_version(1, 0, 3) -> "v1.0.3"
"""
return f"v{major}.{minor}.{patch}"
def calculate_next_version(repo_name: str, existing_versions: list[str]) -> str:
"""Calculate the next patch version given a list of existing version strings.
Filters out any malformed version strings (those that do not match
vMAJOR.MINOR.PATCH). Finds the highest valid version and increments
the patch component by one. If no valid versions exist, returns "v1.0.0".
The repo_name parameter is accepted for extensibility but does not
affect the calculation.
Args:
repo_name: The repository name (unused in calculation).
existing_versions: A list of version strings, may include malformed ones.
Returns:
The next version string in "vX.Y.Z" format.
"""
valid_versions: list[tuple[int, int, int]] = []
for version_str in existing_versions:
# Only consider strings that already have the 'v' prefix to match spec:
# "Versions without 'v' prefix are treated as malformed."
if not version_str.startswith("v"):
continue
try:
parsed = parse_version(version_str)
valid_versions.append(parsed)
except ValueError:
continue
if not valid_versions:
return "v1.0.0"
highest = max(valid_versions)
major, minor, patch = highest
return format_version(major, minor, patch + 1)

0
tests/__init__.py Normal file
View File

0
tests/api/__init__.py Normal file
View File

259
tests/api/test_approvals.py Normal file
View File

@@ -0,0 +1,259 @@
"""Tests for approvals endpoint. Written FIRST (TDD RED phase)."""
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from release_agent.api.approvals import router as approvals_router
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_test_app(
*,
interrupted_threads: list[dict] | None = None,
graph_resume_result: dict | None = None,
) -> FastAPI:
"""Return a FastAPI app with mocked state for approvals tests."""
app = FastAPI()
app.include_router(approvals_router)
if interrupted_threads is None:
interrupted_threads = []
mock_settings = MagicMock()
mock_settings.operator_token.get_secret_value.return_value = ""
mock_graphs = {
"pr_completed": MagicMock(),
"release": MagicMock(),
}
mock_clients = MagicMock()
# Mock pool that returns interrupted threads from DB
mock_pool = MagicMock()
mock_conn = AsyncMock()
mock_cursor = AsyncMock()
rows = [
(
t["thread_id"],
t.get("graph_name", "pr_completed"),
t.get("interrupt_value", "Confirm?"),
t.get("created_at", datetime.now(tz=timezone.utc)),
t.get("repo_name"),
t.get("pr_id"),
t.get("version"),
)
for t in interrupted_threads
]
mock_cursor.fetchall = AsyncMock(return_value=rows)
mock_cursor.fetchone = AsyncMock(return_value=("pr_completed",))
mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor)
mock_cursor.__aexit__ = AsyncMock(return_value=False)
mock_conn.cursor = MagicMock(return_value=mock_cursor)
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
mock_conn.__aexit__ = AsyncMock(return_value=False)
mock_pool.connection = MagicMock(return_value=mock_conn)
app.state.settings = mock_settings
app.state.graphs = mock_graphs
app.state.tool_clients = mock_clients
app.state.db_pool = mock_pool
app.state.background_tasks = set()
return app
# ---------------------------------------------------------------------------
# POST /approvals/{thread_id}
# ---------------------------------------------------------------------------
class TestPostApproval:
def test_valid_merge_decision_returns_200(self) -> None:
app = _make_test_app()
mock_graph = MagicMock()
mock_graph.ainvoke = AsyncMock(return_value={"messages": ["done"]})
app.state.graphs["pr_completed"] = mock_graph
with patch("release_agent.api.approvals._resume_graph", new_callable=AsyncMock) as mock_resume:
mock_resume.return_value = {"messages": ["resumed"]}
with TestClient(app) as client:
response = client.post(
"/approvals/thread-123",
json={"decision": "merge"},
)
assert response.status_code == 200
data = response.json()
assert data["thread_id"] == "thread-123"
assert "status" in data
assert "message" in data
def test_valid_cancel_decision_returns_200(self) -> None:
app = _make_test_app()
with patch("release_agent.api.approvals._resume_graph", new_callable=AsyncMock) as mock_resume:
mock_resume.return_value = {"messages": ["cancelled"]}
with TestClient(app) as client:
response = client.post(
"/approvals/thread-456",
json={"decision": "cancel"},
)
assert response.status_code == 200
def test_invalid_decision_returns_422(self) -> None:
app = _make_test_app()
with TestClient(app) as client:
response = client.post(
"/approvals/thread-123",
json={"decision": "invalid_decision"},
)
assert response.status_code == 422
def test_missing_decision_returns_422(self) -> None:
app = _make_test_app()
with TestClient(app) as client:
response = client.post(
"/approvals/thread-123",
json={},
)
assert response.status_code == 422
def test_response_contains_thread_id(self) -> None:
app = _make_test_app()
with patch("release_agent.api.approvals._resume_graph", new_callable=AsyncMock) as mock_resume:
mock_resume.return_value = {}
with TestClient(app) as client:
response = client.post(
"/approvals/my-thread-id",
json={"decision": "approve"},
)
assert response.json()["thread_id"] == "my-thread-id"
def test_approve_decision_returns_200(self) -> None:
app = _make_test_app()
with patch("release_agent.api.approvals._resume_graph", new_callable=AsyncMock) as mock_resume:
mock_resume.return_value = {}
with TestClient(app) as client:
response = client.post(
"/approvals/t1",
json={"decision": "approve"},
)
assert response.status_code == 200
def test_skip_decision_returns_200(self) -> None:
app = _make_test_app()
with patch("release_agent.api.approvals._resume_graph", new_callable=AsyncMock) as mock_resume:
mock_resume.return_value = {}
with TestClient(app) as client:
response = client.post(
"/approvals/t1",
json={"decision": "skip"},
)
assert response.status_code == 200
def test_trigger_decision_returns_200(self) -> None:
app = _make_test_app()
with patch("release_agent.api.approvals._resume_graph", new_callable=AsyncMock) as mock_resume:
mock_resume.return_value = {}
with TestClient(app) as client:
response = client.post(
"/approvals/t1",
json={"decision": "trigger"},
)
assert response.status_code == 200
# ---------------------------------------------------------------------------
# GET /approvals/pending
# ---------------------------------------------------------------------------
class TestGetPendingApprovals:
def test_empty_pending_returns_200(self) -> None:
app = _make_test_app(interrupted_threads=[])
with TestClient(app) as client:
response = client.get("/approvals/pending")
assert response.status_code == 200
data = response.json()
assert data["count"] == 0
assert data["items"] == []
def test_pending_approvals_list_structure(self) -> None:
now = datetime.now(tz=timezone.utc)
threads = [
{
"thread_id": "t1",
"graph_name": "pr_completed",
"interrupt_value": "Confirm merge?",
"created_at": now,
"repo_name": "my-repo",
"pr_id": "42",
"version": "v1.0.0",
}
]
app = _make_test_app(interrupted_threads=threads)
with TestClient(app) as client:
response = client.get("/approvals/pending")
assert response.status_code == 200
data = response.json()
assert data["count"] == 1
assert data["items"][0]["thread_id"] == "t1"
assert data["items"][0]["graph_name"] == "pr_completed"
def test_multiple_pending_approvals(self) -> None:
now = datetime.now(tz=timezone.utc)
threads = [
{
"thread_id": f"t{i}",
"graph_name": "pr_completed",
"interrupt_value": "Confirm?",
"created_at": now,
"repo_name": None,
"pr_id": None,
"version": None,
}
for i in range(3)
]
app = _make_test_app(interrupted_threads=threads)
with TestClient(app) as client:
response = client.get("/approvals/pending")
assert response.status_code == 200
data = response.json()
assert data["count"] == 3
assert len(data["items"]) == 3
def test_pending_approval_optional_fields_nullable(self) -> None:
now = datetime.now(tz=timezone.utc)
threads = [
{
"thread_id": "t1",
"graph_name": "release",
"interrupt_value": "Run release?",
"created_at": now,
"repo_name": None,
"pr_id": None,
"version": None,
}
]
app = _make_test_app(interrupted_threads=threads)
with TestClient(app) as client:
response = client.get("/approvals/pending")
item = response.json()["items"][0]
assert item["repo_name"] is None
assert item["pr_id"] is None
assert item["version"] is None
# ---------------------------------------------------------------------------
# _resume_graph helper function tests
# ---------------------------------------------------------------------------
class TestResumeGraph:
def test_resume_graph_callable(self) -> None:
from release_agent.api.approvals import _resume_graph
import inspect
assert inspect.iscoroutinefunction(_resume_graph)

View File

@@ -0,0 +1,139 @@
"""Tests for approvals endpoints with operator token authentication.
Phase 5 - Step 3: Verifies that POST /approvals/{thread_id} and
GET /approvals/pending require operator token when configured.
Written FIRST (TDD RED phase).
"""
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from release_agent.api.approvals import router as approvals_router
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_test_app(operator_token: str = "") -> FastAPI:
"""Return a FastAPI app with approvals router and configurable operator token."""
app = FastAPI()
app.include_router(approvals_router)
mock_settings = MagicMock()
mock_settings.operator_token.get_secret_value.return_value = operator_token
mock_pool = MagicMock()
mock_conn = AsyncMock()
mock_cursor = AsyncMock()
mock_cursor.fetchall = AsyncMock(return_value=[])
mock_cursor.fetchone = AsyncMock(return_value=("pr_completed",))
mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor)
mock_cursor.__aexit__ = AsyncMock(return_value=False)
mock_conn.cursor = MagicMock(return_value=mock_cursor)
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
mock_conn.__aexit__ = AsyncMock(return_value=False)
mock_pool.connection = MagicMock(return_value=mock_conn)
app.state.settings = mock_settings
app.state.graphs = {
"pr_completed": MagicMock(),
"release": MagicMock(),
}
app.state.tool_clients = MagicMock()
app.state.db_pool = mock_pool
app.state.background_tasks = set()
return app
# ---------------------------------------------------------------------------
# POST /approvals/{thread_id} with auth
# ---------------------------------------------------------------------------
class TestPostApprovalWithAuth:
def test_valid_token_allows_post_approval(self) -> None:
app = _make_test_app(operator_token="secret-token")
with patch(
"release_agent.api.approvals._resume_graph", new_callable=AsyncMock
) as mock_resume:
mock_resume.return_value = {}
with TestClient(app) as client:
response = client.post(
"/approvals/thread-123",
json={"decision": "merge"},
headers={"X-Operator-Token": "secret-token"},
)
assert response.status_code == 200
def test_missing_token_rejects_post_approval(self) -> None:
app = _make_test_app(operator_token="secret-token")
with TestClient(app) as client:
response = client.post(
"/approvals/thread-123",
json={"decision": "merge"},
)
assert response.status_code == 401
def test_wrong_token_rejects_post_approval(self) -> None:
app = _make_test_app(operator_token="secret-token")
with TestClient(app) as client:
response = client.post(
"/approvals/thread-123",
json={"decision": "merge"},
headers={"X-Operator-Token": "wrong-token"},
)
assert response.status_code == 401
def test_no_auth_required_when_token_not_configured(self) -> None:
app = _make_test_app(operator_token="")
with patch(
"release_agent.api.approvals._resume_graph", new_callable=AsyncMock
) as mock_resume:
mock_resume.return_value = {}
with TestClient(app) as client:
response = client.post(
"/approvals/thread-123",
json={"decision": "merge"},
)
assert response.status_code == 200
# ---------------------------------------------------------------------------
# GET /approvals/pending with auth
# ---------------------------------------------------------------------------
class TestGetPendingApprovalsWithAuth:
def test_valid_token_allows_get_pending(self) -> None:
app = _make_test_app(operator_token="secret-token")
with TestClient(app) as client:
response = client.get(
"/approvals/pending",
headers={"X-Operator-Token": "secret-token"},
)
assert response.status_code == 200
def test_missing_token_rejects_get_pending(self) -> None:
app = _make_test_app(operator_token="secret-token")
with TestClient(app) as client:
response = client.get("/approvals/pending")
assert response.status_code == 401
def test_wrong_token_rejects_get_pending(self) -> None:
app = _make_test_app(operator_token="secret-token")
with TestClient(app) as client:
response = client.get(
"/approvals/pending",
headers={"X-Operator-Token": "wrong"},
)
assert response.status_code == 401
def test_no_auth_required_when_token_not_configured(self) -> None:
app = _make_test_app(operator_token="")
with TestClient(app) as client:
response = client.get("/approvals/pending")
assert response.status_code == 200

View File

@@ -0,0 +1,149 @@
"""Tests for API FastAPI dependencies. Written FIRST (TDD RED phase)."""
from unittest.mock import MagicMock
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from release_agent.api.dependencies import (
get_db_pool,
get_graphs,
get_settings,
get_staging_store,
get_tool_clients,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_app_with_state(**state_kwargs) -> FastAPI:
"""Return a minimal FastAPI app with app.state attributes set."""
app = FastAPI()
for key, value in state_kwargs.items():
setattr(app.state, key, value)
return app
# ---------------------------------------------------------------------------
# get_settings
# ---------------------------------------------------------------------------
class TestGetSettings:
def test_returns_settings_from_state(self) -> None:
mock_settings = MagicMock()
app = _make_app_with_state(settings=mock_settings)
with TestClient(app) as client:
# We test the dependency directly by simulating a request
request = MagicMock()
request.app = app
result = get_settings(request)
assert result is mock_settings
def test_raises_when_settings_missing(self) -> None:
app = FastAPI() # no state.settings
request = MagicMock()
request.app = app
with pytest.raises(AttributeError):
get_settings(request)
# ---------------------------------------------------------------------------
# get_graphs
# ---------------------------------------------------------------------------
class TestGetGraphs:
def test_returns_graphs_from_state(self) -> None:
mock_graphs = {"pr_completed": MagicMock(), "release": MagicMock()}
app = _make_app_with_state(graphs=mock_graphs)
request = MagicMock()
request.app = app
result = get_graphs(request)
assert result is mock_graphs
def test_raises_when_graphs_missing(self) -> None:
app = FastAPI()
request = MagicMock()
request.app = app
with pytest.raises(AttributeError):
get_graphs(request)
# ---------------------------------------------------------------------------
# get_tool_clients
# ---------------------------------------------------------------------------
class TestGetToolClients:
def test_returns_tool_clients_from_state(self) -> None:
mock_clients = MagicMock()
app = _make_app_with_state(tool_clients=mock_clients)
request = MagicMock()
request.app = app
result = get_tool_clients(request)
assert result is mock_clients
def test_raises_when_tool_clients_missing(self) -> None:
app = FastAPI()
request = MagicMock()
request.app = app
with pytest.raises(AttributeError):
get_tool_clients(request)
# ---------------------------------------------------------------------------
# get_staging_store
# ---------------------------------------------------------------------------
class TestGetStagingStore:
def test_returns_staging_store_from_state(self) -> None:
mock_store = MagicMock()
app = _make_app_with_state(staging_store=mock_store)
request = MagicMock()
request.app = app
result = get_staging_store(request)
assert result is mock_store
def test_raises_when_staging_store_missing(self) -> None:
app = FastAPI()
request = MagicMock()
request.app = app
with pytest.raises(AttributeError):
get_staging_store(request)
# ---------------------------------------------------------------------------
# get_db_pool
# ---------------------------------------------------------------------------
class TestGetDbPool:
def test_returns_db_pool_from_state(self) -> None:
mock_pool = MagicMock()
app = _make_app_with_state(db_pool=mock_pool)
request = MagicMock()
request.app = app
result = get_db_pool(request)
assert result is mock_pool
def test_raises_when_db_pool_missing(self) -> None:
app = FastAPI()
request = MagicMock()
request.app = app
with pytest.raises(AttributeError):
get_db_pool(request)

446
tests/api/test_internals.py Normal file
View File

@@ -0,0 +1,446 @@
"""Tests for internal async helper functions.
Tests _run_graph, _upsert_thread, _resume_graph, and exception handlers.
Written FIRST then verified (TDD GREEN phase for internal helpers).
"""
import json
from unittest.mock import AsyncMock, MagicMock, call, patch
import pytest
# ---------------------------------------------------------------------------
# _upsert_thread tests
# ---------------------------------------------------------------------------
class TestUpsertThread:
@pytest.mark.asyncio
async def test_upsert_thread_executes_sql(self) -> None:
from release_agent.api.webhooks import _upsert_thread
mock_pool = MagicMock()
mock_conn = AsyncMock()
mock_cursor = AsyncMock()
mock_cursor.execute = AsyncMock()
mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor)
mock_cursor.__aexit__ = AsyncMock(return_value=False)
mock_conn.cursor = MagicMock(return_value=mock_cursor)
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
mock_conn.__aexit__ = AsyncMock(return_value=False)
mock_pool.connection = MagicMock(return_value=mock_conn)
await _upsert_thread(
mock_pool,
thread_id="t1",
thread_status="running",
state={"repo_name": "my-repo"},
)
mock_cursor.execute.assert_called_once()
args = mock_cursor.execute.call_args[0]
assert "agent_threads" in args[0]
assert args[1][0] == "t1"
assert args[1][4] == "running"
# state is JSON-encoded
state_json = json.loads(args[1][5])
assert state_json["repo_name"] == "my-repo"
@pytest.mark.asyncio
async def test_upsert_thread_completed_status(self) -> None:
from release_agent.api.webhooks import _upsert_thread
mock_pool = MagicMock()
mock_conn = AsyncMock()
mock_cursor = AsyncMock()
mock_cursor.execute = AsyncMock()
mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor)
mock_cursor.__aexit__ = AsyncMock(return_value=False)
mock_conn.cursor = MagicMock(return_value=mock_cursor)
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
mock_conn.__aexit__ = AsyncMock(return_value=False)
mock_pool.connection = MagicMock(return_value=mock_conn)
await _upsert_thread(
mock_pool,
thread_id="t2",
thread_status="completed",
state={},
)
args = mock_cursor.execute.call_args[0]
assert args[1][4] == "completed"
@pytest.mark.asyncio
async def test_upsert_thread_failed_status(self) -> None:
from release_agent.api.webhooks import _upsert_thread
mock_pool = MagicMock()
mock_conn = AsyncMock()
mock_cursor = AsyncMock()
mock_cursor.execute = AsyncMock()
mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor)
mock_cursor.__aexit__ = AsyncMock(return_value=False)
mock_conn.cursor = MagicMock(return_value=mock_cursor)
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
mock_conn.__aexit__ = AsyncMock(return_value=False)
mock_pool.connection = MagicMock(return_value=mock_conn)
await _upsert_thread(
mock_pool,
thread_id="t3",
thread_status="failed",
state={"errors": ["something went wrong"]},
)
args = mock_cursor.execute.call_args[0]
assert args[1][4] == "failed"
state_json = json.loads(args[1][5])
assert state_json["errors"] == ["something went wrong"]
# ---------------------------------------------------------------------------
# _run_graph tests
# ---------------------------------------------------------------------------
class TestRunGraph:
@pytest.mark.asyncio
async def test_run_graph_success_upserts_completed(self) -> None:
from release_agent.api.webhooks import _run_graph
mock_graph = AsyncMock()
mock_graph.ainvoke = AsyncMock(return_value={"messages": ["done"]})
mock_pool = MagicMock()
mock_conn = AsyncMock()
mock_cursor = AsyncMock()
mock_cursor.execute = AsyncMock()
mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor)
mock_cursor.__aexit__ = AsyncMock(return_value=False)
mock_conn.cursor = MagicMock(return_value=mock_cursor)
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
mock_conn.__aexit__ = AsyncMock(return_value=False)
mock_pool.connection = MagicMock(return_value=mock_conn)
await _run_graph(
graph=mock_graph,
initial_state={"repo_name": "test"},
thread_id="t1",
tool_clients=MagicMock(),
db_pool=mock_pool,
)
# Should have been called with "running" then "completed"
calls = mock_cursor.execute.call_args_list
assert len(calls) == 2
# First call: "running", second call: "completed"
assert calls[0][0][1][4] == "running"
assert calls[1][0][1][4] == "completed"
@pytest.mark.asyncio
async def test_run_graph_failure_upserts_failed(self) -> None:
from release_agent.api.webhooks import _run_graph
mock_graph = AsyncMock()
mock_graph.ainvoke = AsyncMock(side_effect=RuntimeError("graph crashed"))
mock_pool = MagicMock()
mock_conn = AsyncMock()
mock_cursor = AsyncMock()
mock_cursor.execute = AsyncMock()
mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor)
mock_cursor.__aexit__ = AsyncMock(return_value=False)
mock_conn.cursor = MagicMock(return_value=mock_cursor)
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
mock_conn.__aexit__ = AsyncMock(return_value=False)
mock_pool.connection = MagicMock(return_value=mock_conn)
await _run_graph(
graph=mock_graph,
initial_state={"repo_name": "test"},
thread_id="t-fail",
tool_clients=MagicMock(),
db_pool=mock_pool,
)
calls = mock_cursor.execute.call_args_list
# First call: "running", second call: "failed"
assert calls[0][0][1][4] == "running"
assert calls[1][0][1][4] == "failed"
# State should contain errors
failed_state = json.loads(calls[1][0][1][5])
assert "errors" in failed_state
@pytest.mark.asyncio
async def test_run_graph_invokes_with_correct_config(self) -> None:
from release_agent.api.webhooks import _run_graph
mock_graph = AsyncMock()
mock_graph.ainvoke = AsyncMock(return_value={})
mock_clients = MagicMock()
mock_pool = MagicMock()
mock_conn = AsyncMock()
mock_cursor = AsyncMock()
mock_cursor.execute = AsyncMock()
mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor)
mock_cursor.__aexit__ = AsyncMock(return_value=False)
mock_conn.cursor = MagicMock(return_value=mock_cursor)
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
mock_conn.__aexit__ = AsyncMock(return_value=False)
mock_pool.connection = MagicMock(return_value=mock_conn)
await _run_graph(
graph=mock_graph,
initial_state={"repo_name": "test"},
thread_id="t-config",
tool_clients=mock_clients,
db_pool=mock_pool,
)
call_args = mock_graph.ainvoke.call_args
config = call_args[1]["config"]
assert config["configurable"]["thread_id"] == "t-config"
assert config["configurable"]["clients"] is mock_clients
# ---------------------------------------------------------------------------
# _resume_graph tests
# ---------------------------------------------------------------------------
class TestResumeGraphInternal:
@pytest.mark.asyncio
async def test_resume_graph_success(self) -> None:
from release_agent.api.approvals import _resume_graph
mock_graph = AsyncMock()
mock_graph.ainvoke = AsyncMock(return_value={"result": "ok"})
mock_pool = MagicMock()
mock_conn = AsyncMock()
mock_cursor = AsyncMock()
mock_cursor.execute = AsyncMock()
mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor)
mock_cursor.__aexit__ = AsyncMock(return_value=False)
mock_conn.cursor = MagicMock(return_value=mock_cursor)
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
mock_conn.__aexit__ = AsyncMock(return_value=False)
mock_pool.connection = MagicMock(return_value=mock_conn)
result = await _resume_graph(
graph=mock_graph,
thread_id="t1",
decision="merge",
tool_clients=MagicMock(),
db_pool=mock_pool,
)
assert result == {"result": "ok"}
mock_graph.ainvoke.assert_called_once()
# Verify the decision was passed
call_args = mock_graph.ainvoke.call_args
assert call_args[0][0]["decision"] == "merge"
@pytest.mark.asyncio
async def test_resume_graph_failure_re_raises(self) -> None:
from release_agent.api.approvals import _resume_graph
mock_graph = AsyncMock()
mock_graph.ainvoke = AsyncMock(side_effect=RuntimeError("resume failed"))
mock_pool = MagicMock()
mock_conn = AsyncMock()
mock_cursor = AsyncMock()
mock_cursor.execute = AsyncMock()
mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor)
mock_cursor.__aexit__ = AsyncMock(return_value=False)
mock_conn.cursor = MagicMock(return_value=mock_cursor)
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
mock_conn.__aexit__ = AsyncMock(return_value=False)
mock_pool.connection = MagicMock(return_value=mock_conn)
with pytest.raises(RuntimeError, match="resume failed"):
await _resume_graph(
graph=mock_graph,
thread_id="t1",
decision="cancel",
tool_clients=MagicMock(),
db_pool=mock_pool,
)
@pytest.mark.asyncio
async def test_resume_graph_upserts_completed_on_success(self) -> None:
from release_agent.api.approvals import _resume_graph
mock_graph = AsyncMock()
mock_graph.ainvoke = AsyncMock(return_value={"messages": ["done"]})
mock_pool = MagicMock()
mock_conn = AsyncMock()
mock_cursor = AsyncMock()
mock_cursor.execute = AsyncMock()
mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor)
mock_cursor.__aexit__ = AsyncMock(return_value=False)
mock_conn.cursor = MagicMock(return_value=mock_cursor)
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
mock_conn.__aexit__ = AsyncMock(return_value=False)
mock_pool.connection = MagicMock(return_value=mock_conn)
await _resume_graph(
graph=mock_graph,
thread_id="t-success",
decision="approve",
tool_clients=MagicMock(),
db_pool=mock_pool,
)
# The last execute call should be "completed"
last_call = mock_cursor.execute.call_args_list[-1]
assert last_call[0][1][4] == "completed"
@pytest.mark.asyncio
async def test_resume_graph_upserts_failed_on_exception(self) -> None:
from release_agent.api.approvals import _resume_graph
mock_graph = AsyncMock()
mock_graph.ainvoke = AsyncMock(side_effect=ValueError("bad"))
mock_pool = MagicMock()
mock_conn = AsyncMock()
mock_cursor = AsyncMock()
mock_cursor.execute = AsyncMock()
mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor)
mock_cursor.__aexit__ = AsyncMock(return_value=False)
mock_conn.cursor = MagicMock(return_value=mock_cursor)
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
mock_conn.__aexit__ = AsyncMock(return_value=False)
mock_pool.connection = MagicMock(return_value=mock_conn)
with pytest.raises(ValueError):
await _resume_graph(
graph=mock_graph,
thread_id="t-fail",
decision="skip",
tool_clients=MagicMock(),
db_pool=mock_pool,
)
last_call = mock_cursor.execute.call_args_list[-1]
assert last_call[0][1][4] == "failed"
# ---------------------------------------------------------------------------
# run_graph_in_background tests (main.py)
# ---------------------------------------------------------------------------
class TestRunGraphInBackground:
@pytest.mark.asyncio
async def test_success_with_db_pool(self) -> None:
from release_agent.main import run_graph_in_background
mock_graph = AsyncMock()
mock_graph.ainvoke = AsyncMock(return_value={"done": True})
mock_pool = MagicMock()
mock_conn = AsyncMock()
mock_cursor = AsyncMock()
mock_cursor.execute = AsyncMock()
mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor)
mock_cursor.__aexit__ = AsyncMock(return_value=False)
mock_conn.cursor = MagicMock(return_value=mock_cursor)
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
mock_conn.__aexit__ = AsyncMock(return_value=False)
mock_pool.connection = MagicMock(return_value=mock_conn)
await run_graph_in_background(
graph=mock_graph,
initial_state={"repo_name": "test"},
thread_id="t-bg",
db_pool=mock_pool,
)
calls = mock_cursor.execute.call_args_list
assert calls[0][0][1][4] == "running"
assert calls[1][0][1][4] == "completed"
@pytest.mark.asyncio
async def test_failure_with_db_pool(self) -> None:
from release_agent.main import run_graph_in_background
mock_graph = AsyncMock()
mock_graph.ainvoke = AsyncMock(side_effect=RuntimeError("bg failed"))
mock_pool = MagicMock()
mock_conn = AsyncMock()
mock_cursor = AsyncMock()
mock_cursor.execute = AsyncMock()
mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor)
mock_cursor.__aexit__ = AsyncMock(return_value=False)
mock_conn.cursor = MagicMock(return_value=mock_cursor)
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
mock_conn.__aexit__ = AsyncMock(return_value=False)
mock_pool.connection = MagicMock(return_value=mock_conn)
await run_graph_in_background(
graph=mock_graph,
initial_state={},
thread_id="t-bg-fail",
db_pool=mock_pool,
)
last_call = mock_cursor.execute.call_args_list[-1]
assert last_call[0][1][4] == "failed"
@pytest.mark.asyncio
async def test_success_without_db_pool(self) -> None:
"""run_graph_in_background works even without a db_pool."""
from release_agent.main import run_graph_in_background
mock_graph = AsyncMock()
mock_graph.ainvoke = AsyncMock(return_value={})
# Should not raise even with no db_pool
await run_graph_in_background(
graph=mock_graph,
initial_state={},
thread_id="t-no-pool",
db_pool=None,
)
mock_graph.ainvoke.assert_called_once()
# ---------------------------------------------------------------------------
# Exception handler tests (main.py)
# ---------------------------------------------------------------------------
class TestExceptionHandlerFunctions:
@pytest.mark.asyncio
async def test_release_agent_error_handler_returns_500(self) -> None:
from release_agent.main import _release_agent_error_handler
from release_agent.exceptions import ServiceError
request = MagicMock()
exc = ServiceError(service="azdo", status_code=503, detail="unavailable")
response = await _release_agent_error_handler(request, exc)
assert response.status_code == 500
body = json.loads(response.body)
assert body["error"] == "ServiceError"
assert "unavailable" in body["detail"]
@pytest.mark.asyncio
async def test_generic_error_handler_returns_500(self) -> None:
from release_agent.main import _generic_error_handler
request = MagicMock()
exc = ValueError("something generic")
response = await _generic_error_handler(request, exc)
assert response.status_code == 500
body = json.loads(response.body)
assert body["error"] == "InternalServerError"
assert "An unexpected error occurred" in body["detail"]

294
tests/api/test_models.py Normal file
View File

@@ -0,0 +1,294 @@
"""Tests for API request/response models. Written FIRST (TDD RED phase)."""
from datetime import datetime, timezone
import pytest
from pydantic import ValidationError
from release_agent.api.models import (
ApprovalDecision,
ApprovalResponse,
ErrorResponse,
HealthResponse,
ManualReleaseRequest,
ManualTriggerResponse,
PendingApproval,
PendingApprovalsResponse,
ReleaseVersionListResponse,
StagingResponse,
WebhookResponse,
)
# ---------------------------------------------------------------------------
# WebhookResponse
# ---------------------------------------------------------------------------
class TestWebhookResponse:
def test_valid_construction(self) -> None:
resp = WebhookResponse(thread_id="thread-123", message="scheduled")
assert resp.thread_id == "thread-123"
assert resp.message == "scheduled"
def test_frozen_immutable(self) -> None:
resp = WebhookResponse(thread_id="t1", message="ok")
with pytest.raises((TypeError, ValidationError)):
resp.thread_id = "other" # type: ignore[misc]
def test_missing_thread_id_raises(self) -> None:
with pytest.raises(ValidationError):
WebhookResponse(message="ok") # type: ignore[call-arg]
def test_missing_message_raises(self) -> None:
with pytest.raises(ValidationError):
WebhookResponse(thread_id="t1") # type: ignore[call-arg]
# ---------------------------------------------------------------------------
# ApprovalDecision
# ---------------------------------------------------------------------------
class TestApprovalDecision:
def test_merge_decision(self) -> None:
d = ApprovalDecision(decision="merge")
assert d.decision == "merge"
def test_cancel_decision(self) -> None:
d = ApprovalDecision(decision="cancel")
assert d.decision == "cancel"
def test_approve_decision(self) -> None:
d = ApprovalDecision(decision="approve")
assert d.decision == "approve"
def test_skip_decision(self) -> None:
d = ApprovalDecision(decision="skip")
assert d.decision == "skip"
def test_trigger_decision(self) -> None:
d = ApprovalDecision(decision="trigger")
assert d.decision == "trigger"
def test_invalid_decision_raises(self) -> None:
with pytest.raises(ValidationError):
ApprovalDecision(decision="invalid") # type: ignore[arg-type]
def test_frozen_immutable(self) -> None:
d = ApprovalDecision(decision="merge")
with pytest.raises((TypeError, ValidationError)):
d.decision = "cancel" # type: ignore[misc]
# ---------------------------------------------------------------------------
# ApprovalResponse
# ---------------------------------------------------------------------------
class TestApprovalResponse:
def test_valid_construction(self) -> None:
resp = ApprovalResponse(
thread_id="t1", status="resumed", message="Graph resumed"
)
assert resp.thread_id == "t1"
assert resp.status == "resumed"
assert resp.message == "Graph resumed"
def test_frozen_immutable(self) -> None:
resp = ApprovalResponse(thread_id="t1", status="ok", message="m")
with pytest.raises((TypeError, ValidationError)):
resp.status = "bad" # type: ignore[misc]
# ---------------------------------------------------------------------------
# PendingApproval
# ---------------------------------------------------------------------------
class TestPendingApproval:
def test_full_construction(self) -> None:
now = datetime.now(tz=timezone.utc)
pa = PendingApproval(
thread_id="t1",
graph_name="pr_completed",
interrupt_value="Confirm merge?",
created_at=now,
repo_name="my-repo",
pr_id="42",
version="v1.2.3",
)
assert pa.thread_id == "t1"
assert pa.graph_name == "pr_completed"
assert pa.repo_name == "my-repo"
assert pa.pr_id == "42"
assert pa.version == "v1.2.3"
def test_optional_fields_none(self) -> None:
now = datetime.now(tz=timezone.utc)
pa = PendingApproval(
thread_id="t1",
graph_name="release",
interrupt_value="Confirm?",
created_at=now,
)
assert pa.repo_name is None
assert pa.pr_id is None
assert pa.version is None
def test_frozen_immutable(self) -> None:
now = datetime.now(tz=timezone.utc)
pa = PendingApproval(
thread_id="t1",
graph_name="g",
interrupt_value="v",
created_at=now,
)
with pytest.raises((TypeError, ValidationError)):
pa.thread_id = "other" # type: ignore[misc]
# ---------------------------------------------------------------------------
# PendingApprovalsResponse
# ---------------------------------------------------------------------------
class TestPendingApprovalsResponse:
def test_empty_list(self) -> None:
resp = PendingApprovalsResponse(items=[], count=0)
assert resp.items == []
assert resp.count == 0
def test_with_items(self) -> None:
now = datetime.now(tz=timezone.utc)
item = PendingApproval(
thread_id="t1",
graph_name="g",
interrupt_value="v",
created_at=now,
)
resp = PendingApprovalsResponse(items=[item], count=1)
assert resp.count == 1
assert len(resp.items) == 1
def test_frozen_immutable(self) -> None:
resp = PendingApprovalsResponse(items=[], count=0)
with pytest.raises((TypeError, ValidationError)):
resp.count = 5 # type: ignore[misc]
# ---------------------------------------------------------------------------
# HealthResponse
# ---------------------------------------------------------------------------
class TestHealthResponse:
def test_ok_status(self) -> None:
resp = HealthResponse(status="ok", version="0.1.0", uptime_seconds=123.4)
assert resp.status == "ok"
assert resp.version == "0.1.0"
assert resp.uptime_seconds == pytest.approx(123.4)
def test_degraded_status(self) -> None:
resp = HealthResponse(status="degraded", version="0.1.0", uptime_seconds=0.0)
assert resp.status == "degraded"
def test_invalid_status_raises(self) -> None:
with pytest.raises(ValidationError):
HealthResponse(status="unknown", version="0.1.0", uptime_seconds=0.0) # type: ignore[arg-type]
def test_frozen_immutable(self) -> None:
resp = HealthResponse(status="ok", version="0.1.0", uptime_seconds=1.0)
with pytest.raises((TypeError, ValidationError)):
resp.status = "degraded" # type: ignore[misc]
# ---------------------------------------------------------------------------
# ReleaseVersionListResponse
# ---------------------------------------------------------------------------
class TestReleaseVersionListResponse:
def test_valid_construction(self) -> None:
resp = ReleaseVersionListResponse(repo="my-repo", versions=["v1.0.0", "v1.1.0"])
assert resp.repo == "my-repo"
assert resp.versions == ["v1.0.0", "v1.1.0"]
def test_empty_versions(self) -> None:
resp = ReleaseVersionListResponse(repo="r", versions=[])
assert resp.versions == []
def test_frozen_immutable(self) -> None:
resp = ReleaseVersionListResponse(repo="r", versions=[])
with pytest.raises((TypeError, ValidationError)):
resp.repo = "other" # type: ignore[misc]
# ---------------------------------------------------------------------------
# StagingResponse
# ---------------------------------------------------------------------------
class TestStagingResponse:
def test_with_staging(self) -> None:
staging_data = {"version": "v1.0.0", "repo": "my-repo", "tickets": []}
resp = StagingResponse(repo="my-repo", staging=staging_data)
assert resp.repo == "my-repo"
assert resp.staging is not None
assert resp.staging["version"] == "v1.0.0"
def test_without_staging(self) -> None:
resp = StagingResponse(repo="my-repo", staging=None)
assert resp.staging is None
def test_frozen_immutable(self) -> None:
resp = StagingResponse(repo="r", staging=None)
with pytest.raises((TypeError, ValidationError)):
resp.repo = "other" # type: ignore[misc]
# ---------------------------------------------------------------------------
# ManualTriggerResponse
# ---------------------------------------------------------------------------
class TestManualTriggerResponse:
def test_valid_construction(self) -> None:
resp = ManualTriggerResponse(thread_id="t1", message="triggered")
assert resp.thread_id == "t1"
assert resp.message == "triggered"
def test_frozen_immutable(self) -> None:
resp = ManualTriggerResponse(thread_id="t1", message="m")
with pytest.raises((TypeError, ValidationError)):
resp.thread_id = "other" # type: ignore[misc]
# ---------------------------------------------------------------------------
# ManualReleaseRequest
# ---------------------------------------------------------------------------
class TestManualReleaseRequest:
def test_valid_construction(self) -> None:
req = ManualReleaseRequest(repo="my-repo")
assert req.repo == "my-repo"
def test_missing_repo_raises(self) -> None:
with pytest.raises(ValidationError):
ManualReleaseRequest() # type: ignore[call-arg]
def test_frozen_immutable(self) -> None:
req = ManualReleaseRequest(repo="r")
with pytest.raises((TypeError, ValidationError)):
req.repo = "other" # type: ignore[misc]
# ---------------------------------------------------------------------------
# ErrorResponse
# ---------------------------------------------------------------------------
class TestErrorResponse:
def test_error_only(self) -> None:
resp = ErrorResponse(error="Something went wrong")
assert resp.error == "Something went wrong"
assert resp.detail is None
def test_error_with_detail(self) -> None:
resp = ErrorResponse(error="Not found", detail="Thread t1 not found")
assert resp.detail == "Thread t1 not found"
def test_frozen_immutable(self) -> None:
resp = ErrorResponse(error="e")
with pytest.raises((TypeError, ValidationError)):
resp.error = "other" # type: ignore[misc]

View File

@@ -0,0 +1,111 @@
"""Tests for operator token authentication dependency.
Phase 5 - Step 3: require_operator_token FastAPI dependency.
Written FIRST (TDD RED phase).
"""
from unittest.mock import MagicMock
import pytest
from fastapi import FastAPI, Depends, HTTPException
from fastapi.testclient import TestClient
from release_agent.api.dependencies import require_operator_token
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_app_with_token(operator_token: str = "") -> FastAPI:
"""Return a minimal app with a protected route and the given token config."""
app = FastAPI()
mock_settings = MagicMock()
mock_settings.operator_token.get_secret_value.return_value = operator_token
app.state.settings = mock_settings
@app.get("/protected")
async def protected_route(_: None = Depends(require_operator_token)):
return {"ok": True}
return app
# ---------------------------------------------------------------------------
# require_operator_token tests
# ---------------------------------------------------------------------------
class TestRequireOperatorToken:
def test_valid_token_allows_access(self) -> None:
app = _make_app_with_token("super-secret-token")
with TestClient(app) as client:
response = client.get(
"/protected",
headers={"X-Operator-Token": "super-secret-token"},
)
assert response.status_code == 200
def test_missing_token_header_returns_401_when_token_configured(self) -> None:
app = _make_app_with_token("super-secret-token")
with TestClient(app) as client:
response = client.get("/protected")
assert response.status_code == 401
def test_wrong_token_returns_401(self) -> None:
app = _make_app_with_token("super-secret-token")
with TestClient(app) as client:
response = client.get(
"/protected",
headers={"X-Operator-Token": "wrong-token"},
)
assert response.status_code == 401
def test_empty_operator_token_config_skips_auth(self) -> None:
"""When operator_token is not configured (empty), all requests pass."""
app = _make_app_with_token("")
with TestClient(app) as client:
response = client.get("/protected")
assert response.status_code == 200
def test_empty_operator_token_config_passes_even_without_header(self) -> None:
app = _make_app_with_token("")
with TestClient(app) as client:
response = client.get("/protected", headers={})
assert response.status_code == 200
def test_token_comparison_is_constant_time(self) -> None:
"""Verify hmac.compare_digest is used (not == operator) — tested by checking
that the function still works correctly, not timing (which we can't test here)."""
app = _make_app_with_token("my-token")
with TestClient(app) as client:
response = client.get(
"/protected",
headers={"X-Operator-Token": "my-token"},
)
assert response.status_code == 200
def test_empty_string_token_header_rejected_when_token_configured(self) -> None:
app = _make_app_with_token("configured-token")
with TestClient(app) as client:
response = client.get(
"/protected",
headers={"X-Operator-Token": ""},
)
assert response.status_code == 401
def test_401_response_has_detail_field(self) -> None:
app = _make_app_with_token("secret")
with TestClient(app) as client:
response = client.get("/protected")
data = response.json()
assert "detail" in data
def test_valid_token_returns_correct_response_body(self) -> None:
app = _make_app_with_token("token123")
with TestClient(app) as client:
response = client.get(
"/protected",
headers={"X-Operator-Token": "token123"},
)
assert response.json() == {"ok": True}

View File

@@ -0,0 +1,473 @@
"""Tests for api/slack_interactions.py endpoint.
Written FIRST (TDD RED phase).
Tests cover:
- Signature verification (HMAC-SHA256)
- Payload parsing
- Button routing
- _resume_graph invocation
- Error handling
"""
import hashlib
import hmac
import json
import time
import urllib.parse
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from release_agent.api.slack_interactions import router as slack_interactions_router
from release_agent.api.slack_interactions import _verify_slack_signature
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
_TEST_SIGNING_SECRET = "test-signing-secret-abc"
def _make_slack_signature(*, signing_secret: str, timestamp: str, body: str) -> str:
"""Compute a valid Slack signing signature."""
base_string = f"v0:{timestamp}:{body}"
sig = hmac.new(
signing_secret.encode(),
base_string.encode(),
hashlib.sha256,
).hexdigest()
return f"v0={sig}"
def _make_test_app(
*,
signing_secret: str = _TEST_SIGNING_SECRET,
thread_graph_name: str | None = "release",
graph_result: dict | None = None,
) -> FastAPI:
"""Return a FastAPI test app with mocked state for slack interactions."""
app = FastAPI()
app.include_router(slack_interactions_router)
mock_settings = MagicMock()
mock_settings.slack_signing_secret.get_secret_value.return_value = signing_secret
mock_settings.operator_token.get_secret_value.return_value = ""
mock_graph = MagicMock()
mock_graph.ainvoke = AsyncMock(return_value=graph_result or {"messages": ["done"]})
mock_graphs = {
"release": mock_graph,
"pr_completed": MagicMock(),
}
mock_clients = MagicMock()
mock_pool = MagicMock()
mock_conn = AsyncMock()
mock_cursor = AsyncMock()
mock_cursor.fetchone = AsyncMock(
return_value=(thread_graph_name,) if thread_graph_name else None
)
mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor)
mock_cursor.__aexit__ = AsyncMock(return_value=False)
mock_conn.cursor = MagicMock(return_value=mock_cursor)
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
mock_conn.__aexit__ = AsyncMock(return_value=False)
mock_pool.connection = MagicMock(return_value=mock_conn)
app.state.settings = mock_settings
app.state.graphs = mock_graphs
app.state.tool_clients = mock_clients
app.state.db_pool = mock_pool
app.state.background_tasks = set()
return app
def _make_button_payload(
*,
thread_id: str = "test-thread-123",
value: str = "approve",
user_id: str = "U12345",
user_name: str = "alice",
) -> str:
"""Build a URL-encoded Slack button action payload."""
payload = {
"type": "block_actions",
"user": {"id": user_id, "name": user_name},
"actions": [
{
"type": "button",
"value": f"{thread_id}:{value}",
"action_id": f"approval_{value}_{thread_id}",
}
],
}
return urllib.parse.urlencode({"payload": json.dumps(payload)})
# ---------------------------------------------------------------------------
# _verify_slack_signature pure function tests
# ---------------------------------------------------------------------------
class TestVerifySlackSignature:
"""Tests for the _verify_slack_signature pure function."""
def test_returns_true_for_valid_signature(self) -> None:
timestamp = str(int(time.time()))
body = "test=body&data=here"
sig = _make_slack_signature(
signing_secret=_TEST_SIGNING_SECRET,
timestamp=timestamp,
body=body,
)
assert _verify_slack_signature(
signing_secret=_TEST_SIGNING_SECRET,
timestamp=timestamp,
body=body,
signature=sig,
) is True
def test_returns_false_for_wrong_secret(self) -> None:
timestamp = str(int(time.time()))
body = "test=body"
sig = _make_slack_signature(
signing_secret="wrong-secret",
timestamp=timestamp,
body=body,
)
assert _verify_slack_signature(
signing_secret=_TEST_SIGNING_SECRET,
timestamp=timestamp,
body=body,
signature=sig,
) is False
def test_returns_false_for_tampered_body(self) -> None:
timestamp = str(int(time.time()))
original_body = "original=body"
sig = _make_slack_signature(
signing_secret=_TEST_SIGNING_SECRET,
timestamp=timestamp,
body=original_body,
)
assert _verify_slack_signature(
signing_secret=_TEST_SIGNING_SECRET,
timestamp=timestamp,
body="tampered=body",
signature=sig,
) is False
def test_returns_false_for_wrong_timestamp(self) -> None:
body = "test=body"
sig = _make_slack_signature(
signing_secret=_TEST_SIGNING_SECRET,
timestamp="1000000",
body=body,
)
assert _verify_slack_signature(
signing_secret=_TEST_SIGNING_SECRET,
timestamp="9999999",
body=body,
signature=sig,
) is False
def test_returns_false_for_malformed_signature(self) -> None:
timestamp = str(int(time.time()))
assert _verify_slack_signature(
signing_secret=_TEST_SIGNING_SECRET,
timestamp=timestamp,
body="body",
signature="not-a-valid-sig",
) is False
def test_returns_false_for_empty_signature(self) -> None:
timestamp = str(int(time.time()))
assert _verify_slack_signature(
signing_secret=_TEST_SIGNING_SECRET,
timestamp=timestamp,
body="body",
signature="",
) is False
def test_uses_hmac_sha256(self) -> None:
timestamp = "1234567890"
body = "payload=data"
base = f"v0:{timestamp}:{body}"
expected_hash = hmac.new(
_TEST_SIGNING_SECRET.encode(),
base.encode(),
hashlib.sha256,
).hexdigest()
sig = f"v0={expected_hash}"
# Inject current_time matching timestamp to bypass replay prevention
assert _verify_slack_signature(
signing_secret=_TEST_SIGNING_SECRET,
timestamp=timestamp,
body=body,
signature=sig,
current_time=1234567890.0,
) is True
def test_rejects_stale_timestamp(self) -> None:
old_timestamp = "1000000000" # year 2001
body = "payload=data"
base = f"v0:{old_timestamp}:{body}"
expected_hash = hmac.new(
_TEST_SIGNING_SECRET.encode(),
base.encode(),
hashlib.sha256,
).hexdigest()
sig = f"v0={expected_hash}"
# Valid signature but timestamp too old
assert _verify_slack_signature(
signing_secret=_TEST_SIGNING_SECRET,
timestamp=old_timestamp,
body=body,
signature=sig,
) is False
def test_rejects_non_integer_timestamp(self) -> None:
assert _verify_slack_signature(
signing_secret=_TEST_SIGNING_SECRET,
timestamp="not-a-number",
body="body",
signature="v0=abc",
) is False
def test_signature_prefix_must_be_v0(self) -> None:
timestamp = "1234567890"
body = "payload=data"
base = f"v0:{timestamp}:{body}"
hash_val = hmac.new(
_TEST_SIGNING_SECRET.encode(),
base.encode(),
hashlib.sha256,
).hexdigest()
wrong_prefix_sig = f"v1={hash_val}"
assert _verify_slack_signature(
signing_secret=_TEST_SIGNING_SECRET,
timestamp=timestamp,
body=body,
signature=wrong_prefix_sig,
) is False
# ---------------------------------------------------------------------------
# POST /slack/interactions endpoint tests
# ---------------------------------------------------------------------------
class TestSlackInteractionsEndpoint:
"""Tests for POST /slack/interactions."""
def test_returns_200_for_valid_request(self) -> None:
app = _make_test_app()
client = TestClient(app, raise_server_exceptions=False)
timestamp = str(int(time.time()))
body = _make_button_payload(thread_id="t-abc", value="approve")
sig = _make_slack_signature(
signing_secret=_TEST_SIGNING_SECRET,
timestamp=timestamp,
body=body,
)
response = client.post(
"/slack/interactions",
content=body,
headers={
"Content-Type": "application/x-www-form-urlencoded",
"X-Slack-Request-Timestamp": timestamp,
"X-Slack-Signature": sig,
},
)
assert response.status_code == 200
def test_returns_403_for_invalid_signature(self) -> None:
app = _make_test_app()
client = TestClient(app, raise_server_exceptions=False)
timestamp = str(int(time.time()))
body = _make_button_payload()
response = client.post(
"/slack/interactions",
content=body,
headers={
"Content-Type": "application/x-www-form-urlencoded",
"X-Slack-Request-Timestamp": timestamp,
"X-Slack-Signature": "v0=invalid_signature",
},
)
assert response.status_code == 403
def test_returns_400_when_missing_timestamp_header(self) -> None:
app = _make_test_app()
client = TestClient(app, raise_server_exceptions=False)
body = _make_button_payload()
response = client.post(
"/slack/interactions",
content=body,
headers={
"Content-Type": "application/x-www-form-urlencoded",
"X-Slack-Signature": "v0=something",
},
)
assert response.status_code in (400, 403, 422)
def test_rejects_when_signing_secret_not_configured(self) -> None:
app = _make_test_app(signing_secret="")
client = TestClient(app, raise_server_exceptions=False)
timestamp = str(int(time.time()))
body = _make_button_payload(thread_id="t-abc", value="approve")
response = client.post(
"/slack/interactions",
content=body,
headers={
"Content-Type": "application/x-www-form-urlencoded",
"X-Slack-Request-Timestamp": timestamp,
"X-Slack-Signature": "v0=any_sig",
},
)
assert response.status_code == 503
def test_returns_200_with_approve_action(self) -> None:
app = _make_test_app()
client = TestClient(app, raise_server_exceptions=False)
timestamp = str(int(time.time()))
body = _make_button_payload(thread_id="thread-1", value="approve")
sig = _make_slack_signature(
signing_secret=_TEST_SIGNING_SECRET,
timestamp=timestamp,
body=body,
)
response = client.post(
"/slack/interactions",
content=body,
headers={
"Content-Type": "application/x-www-form-urlencoded",
"X-Slack-Request-Timestamp": timestamp,
"X-Slack-Signature": sig,
},
)
assert response.status_code == 200
def test_returns_200_with_reject_action(self) -> None:
app = _make_test_app()
client = TestClient(app, raise_server_exceptions=False)
timestamp = str(int(time.time()))
body = _make_button_payload(thread_id="thread-2", value="reject")
sig = _make_slack_signature(
signing_secret=_TEST_SIGNING_SECRET,
timestamp=timestamp,
body=body,
)
response = client.post(
"/slack/interactions",
content=body,
headers={
"Content-Type": "application/x-www-form-urlencoded",
"X-Slack-Request-Timestamp": timestamp,
"X-Slack-Signature": sig,
},
)
assert response.status_code == 200
def test_schedules_graph_resume_in_background(self) -> None:
app = _make_test_app()
client = TestClient(app, raise_server_exceptions=False)
timestamp = str(int(time.time()))
body = _make_button_payload(thread_id="t-bg", value="approve")
sig = _make_slack_signature(
signing_secret=_TEST_SIGNING_SECRET,
timestamp=timestamp,
body=body,
)
response = client.post(
"/slack/interactions",
content=body,
headers={
"Content-Type": "application/x-www-form-urlencoded",
"X-Slack-Request-Timestamp": timestamp,
"X-Slack-Signature": sig,
},
)
assert response.status_code == 200
def test_returns_404_for_unknown_thread(self) -> None:
app = _make_test_app(thread_graph_name=None)
client = TestClient(app, raise_server_exceptions=False)
timestamp = str(int(time.time()))
body = _make_button_payload(thread_id="unknown-thread", value="approve")
sig = _make_slack_signature(
signing_secret=_TEST_SIGNING_SECRET,
timestamp=timestamp,
body=body,
)
response = client.post(
"/slack/interactions",
content=body,
headers={
"Content-Type": "application/x-www-form-urlencoded",
"X-Slack-Request-Timestamp": timestamp,
"X-Slack-Signature": sig,
},
)
# Should return 200 immediately (Slack requires immediate 200)
# but the background task may log an error
assert response.status_code == 200
def test_response_body_is_empty_or_ok(self) -> None:
app = _make_test_app()
client = TestClient(app, raise_server_exceptions=False)
timestamp = str(int(time.time()))
body = _make_button_payload(thread_id="t-ok", value="approve")
sig = _make_slack_signature(
signing_secret=_TEST_SIGNING_SECRET,
timestamp=timestamp,
body=body,
)
response = client.post(
"/slack/interactions",
content=body,
headers={
"Content-Type": "application/x-www-form-urlencoded",
"X-Slack-Request-Timestamp": timestamp,
"X-Slack-Signature": sig,
},
)
assert response.status_code == 200
# Body may be empty or a simple JSON with ok=True
if response.content:
data = response.json()
assert data.get("ok") is True or "ok" not in data

270
tests/api/test_status.py Normal file
View File

@@ -0,0 +1,270 @@
"""Tests for status, releases, staging, and manual trigger endpoints.
Written FIRST (TDD RED phase).
"""
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from release_agent.api.status import router as status_router
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_test_app(
*,
versions: list[str] | None = None,
staging_data: dict | None = None,
) -> FastAPI:
"""Return a FastAPI app with mocked state for status tests."""
app = FastAPI()
app.include_router(status_router)
mock_settings = MagicMock()
mock_settings.operator_token.get_secret_value.return_value = ""
mock_graphs = {
"pr_completed": MagicMock(),
"release": MagicMock(),
}
mock_clients = MagicMock()
mock_staging_store = MagicMock()
mock_staging_store.list_versions = AsyncMock(return_value=versions or [])
# staging store returns StagingRelease-like or None
if staging_data is not None:
mock_staging_obj = MagicMock()
mock_staging_obj.model_dump = MagicMock(return_value=staging_data)
mock_staging_store.load = AsyncMock(return_value=mock_staging_obj)
else:
mock_staging_store.load = AsyncMock(return_value=None)
mock_pool = MagicMock()
app.state.settings = mock_settings
app.state.graphs = mock_graphs
app.state.tool_clients = mock_clients
app.state.staging_store = mock_staging_store
app.state.db_pool = mock_pool
app.state.background_tasks = set()
app.state.started_at = datetime.now(tz=timezone.utc)
return app
# ---------------------------------------------------------------------------
# GET /status
# ---------------------------------------------------------------------------
class TestGetStatus:
def test_returns_200(self) -> None:
app = _make_test_app()
with TestClient(app) as client:
response = client.get("/status")
assert response.status_code == 200
def test_response_has_status_field(self) -> None:
app = _make_test_app()
with TestClient(app) as client:
response = client.get("/status")
data = response.json()
assert "status" in data
assert data["status"] in ("ok", "degraded")
def test_response_has_version_field(self) -> None:
app = _make_test_app()
with TestClient(app) as client:
response = client.get("/status")
data = response.json()
assert "version" in data
assert isinstance(data["version"], str)
def test_response_has_uptime_seconds(self) -> None:
app = _make_test_app()
with TestClient(app) as client:
response = client.get("/status")
data = response.json()
assert "uptime_seconds" in data
assert data["uptime_seconds"] >= 0.0
def test_status_is_ok_when_healthy(self) -> None:
app = _make_test_app()
with TestClient(app) as client:
response = client.get("/status")
assert response.json()["status"] == "ok"
# ---------------------------------------------------------------------------
# GET /releases/{repo}
# ---------------------------------------------------------------------------
class TestGetReleaseVersions:
def test_returns_200(self) -> None:
app = _make_test_app(versions=["v1.0.0", "v1.1.0"])
with TestClient(app) as client:
response = client.get("/releases/my-repo")
assert response.status_code == 200
def test_response_has_repo_and_versions(self) -> None:
app = _make_test_app(versions=["v1.0.0", "v1.1.0"])
with TestClient(app) as client:
response = client.get("/releases/my-repo")
data = response.json()
assert data["repo"] == "my-repo"
assert data["versions"] == ["v1.0.0", "v1.1.0"]
def test_empty_versions_list(self) -> None:
app = _make_test_app(versions=[])
with TestClient(app) as client:
response = client.get("/releases/unknown-repo")
data = response.json()
assert data["versions"] == []
def test_repo_name_in_path_used(self) -> None:
mock_staging_store = MagicMock()
mock_staging_store.list_versions = AsyncMock(return_value=[])
app = _make_test_app()
app.state.staging_store = mock_staging_store
with TestClient(app) as client:
client.get("/releases/specific-repo")
mock_staging_store.list_versions.assert_called_once_with("specific-repo")
# ---------------------------------------------------------------------------
# GET /staging
# ---------------------------------------------------------------------------
class TestGetStaging:
def test_returns_200_with_staging(self) -> None:
staging_data = {"version": "v1.0.0", "repo": "my-repo", "tickets": []}
app = _make_test_app(staging_data=staging_data)
with TestClient(app) as client:
response = client.get("/staging?repo=my-repo")
assert response.status_code == 200
def test_response_has_repo_and_staging(self) -> None:
staging_data = {"version": "v1.0.0", "repo": "my-repo", "tickets": []}
app = _make_test_app(staging_data=staging_data)
with TestClient(app) as client:
response = client.get("/staging?repo=my-repo")
data = response.json()
assert data["repo"] == "my-repo"
assert data["staging"] is not None
assert data["staging"]["version"] == "v1.0.0"
def test_returns_null_staging_when_not_found(self) -> None:
app = _make_test_app(staging_data=None)
with TestClient(app) as client:
response = client.get("/staging?repo=no-staging-repo")
assert response.status_code == 200
data = response.json()
assert data["staging"] is None
def test_missing_repo_query_returns_422(self) -> None:
app = _make_test_app()
with TestClient(app) as client:
response = client.get("/staging")
assert response.status_code == 422
# ---------------------------------------------------------------------------
# POST /manual/pr/{pr_id}
# ---------------------------------------------------------------------------
class TestManualPrTrigger:
def test_returns_202(self) -> None:
app = _make_test_app()
with patch(
"release_agent.api.status.asyncio.create_task", return_value=MagicMock()
):
with TestClient(app) as client:
response = client.post("/manual/pr/42")
assert response.status_code == 202
def test_response_has_thread_id(self) -> None:
app = _make_test_app()
with patch(
"release_agent.api.status.asyncio.create_task", return_value=MagicMock()
):
with TestClient(app) as client:
response = client.post("/manual/pr/42")
data = response.json()
assert "thread_id" in data
assert isinstance(data["thread_id"], str)
def test_response_has_message(self) -> None:
app = _make_test_app()
with patch(
"release_agent.api.status.asyncio.create_task", return_value=MagicMock()
):
with TestClient(app) as client:
response = client.post("/manual/pr/42")
assert "message" in response.json()
def test_schedules_background_task(self) -> None:
app = _make_test_app()
with patch(
"release_agent.api.status.asyncio.create_task", return_value=MagicMock()
) as mock_create:
with TestClient(app) as client:
client.post("/manual/pr/99")
mock_create.assert_called_once()
# ---------------------------------------------------------------------------
# POST /manual/release
# ---------------------------------------------------------------------------
class TestManualReleaseTrigger:
def test_returns_202(self) -> None:
app = _make_test_app()
with patch(
"release_agent.api.status.asyncio.create_task", return_value=MagicMock()
):
with TestClient(app) as client:
response = client.post(
"/manual/release",
json={"repo": "my-repo"},
)
assert response.status_code == 202
def test_response_has_thread_id(self) -> None:
app = _make_test_app()
with patch(
"release_agent.api.status.asyncio.create_task", return_value=MagicMock()
):
with TestClient(app) as client:
response = client.post(
"/manual/release",
json={"repo": "my-repo"},
)
data = response.json()
assert "thread_id" in data
def test_missing_repo_returns_422(self) -> None:
app = _make_test_app()
with TestClient(app) as client:
response = client.post(
"/manual/release",
json={},
)
assert response.status_code == 422
def test_schedules_background_task(self) -> None:
app = _make_test_app()
with patch(
"release_agent.api.status.asyncio.create_task", return_value=MagicMock()
) as mock_create:
with TestClient(app) as client:
client.post(
"/manual/release",
json={"repo": "my-repo"},
)
mock_create.assert_called_once()

View File

@@ -0,0 +1,166 @@
"""Tests for status/manual endpoints with operator token authentication.
Phase 5 - Step 3: Verifies that POST /manual/* require operator token
when configured. GET endpoints are not protected.
Written FIRST (TDD RED phase).
"""
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from release_agent.api.status import router as status_router
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_test_app(operator_token: str = "") -> FastAPI:
"""Return a FastAPI app with status router and configurable operator token."""
app = FastAPI()
app.include_router(status_router)
mock_settings = MagicMock()
mock_settings.operator_token.get_secret_value.return_value = operator_token
mock_staging_store = MagicMock()
mock_staging_store.list_versions = AsyncMock(return_value=[])
mock_staging_store.load = AsyncMock(return_value=None)
mock_pool = MagicMock()
app.state.settings = mock_settings
app.state.graphs = {
"pr_completed": MagicMock(),
"release": MagicMock(),
}
app.state.tool_clients = MagicMock()
app.state.staging_store = mock_staging_store
app.state.db_pool = mock_pool
app.state.background_tasks = set()
app.state.started_at = datetime.now(tz=timezone.utc)
return app
# ---------------------------------------------------------------------------
# POST /manual/pr/{pr_id} with auth
# ---------------------------------------------------------------------------
class TestManualPrTriggerWithAuth:
def test_valid_token_allows_manual_pr(self) -> None:
app = _make_test_app(operator_token="secure-token")
with patch(
"release_agent.api.status.asyncio.create_task", return_value=MagicMock()
):
with TestClient(app) as client:
response = client.post(
"/manual/pr/42",
headers={"X-Operator-Token": "secure-token"},
)
assert response.status_code == 202
def test_missing_token_rejects_manual_pr(self) -> None:
app = _make_test_app(operator_token="secure-token")
with TestClient(app) as client:
response = client.post("/manual/pr/42")
assert response.status_code == 401
def test_wrong_token_rejects_manual_pr(self) -> None:
app = _make_test_app(operator_token="secure-token")
with TestClient(app) as client:
response = client.post(
"/manual/pr/42",
headers={"X-Operator-Token": "bad-token"},
)
assert response.status_code == 401
def test_no_auth_required_when_token_not_configured(self) -> None:
app = _make_test_app(operator_token="")
with patch(
"release_agent.api.status.asyncio.create_task", return_value=MagicMock()
):
with TestClient(app) as client:
response = client.post("/manual/pr/42")
assert response.status_code == 202
# ---------------------------------------------------------------------------
# POST /manual/release with auth
# ---------------------------------------------------------------------------
class TestManualReleaseTriggerWithAuth:
def test_valid_token_allows_manual_release(self) -> None:
app = _make_test_app(operator_token="secure-token")
with patch(
"release_agent.api.status.asyncio.create_task", return_value=MagicMock()
):
with TestClient(app) as client:
response = client.post(
"/manual/release",
json={"repo": "my-repo"},
headers={"X-Operator-Token": "secure-token"},
)
assert response.status_code == 202
def test_missing_token_rejects_manual_release(self) -> None:
app = _make_test_app(operator_token="secure-token")
with TestClient(app) as client:
response = client.post(
"/manual/release",
json={"repo": "my-repo"},
)
assert response.status_code == 401
def test_wrong_token_rejects_manual_release(self) -> None:
app = _make_test_app(operator_token="secure-token")
with TestClient(app) as client:
response = client.post(
"/manual/release",
json={"repo": "my-repo"},
headers={"X-Operator-Token": "wrong"},
)
assert response.status_code == 401
def test_no_auth_required_when_token_not_configured(self) -> None:
app = _make_test_app(operator_token="")
with patch(
"release_agent.api.status.asyncio.create_task", return_value=MagicMock()
):
with TestClient(app) as client:
response = client.post(
"/manual/release",
json={"repo": "my-repo"},
)
assert response.status_code == 202
# ---------------------------------------------------------------------------
# GET /status, /releases, /staging do NOT require auth
# ---------------------------------------------------------------------------
class TestReadEndpointsNoAuth:
def test_get_status_no_token_needed(self) -> None:
"""GET /status should never require auth."""
app = _make_test_app(operator_token="super-secret")
with TestClient(app) as client:
response = client.get("/status")
assert response.status_code == 200
def test_get_releases_no_token_needed(self) -> None:
app = _make_test_app(operator_token="super-secret")
with TestClient(app) as client:
response = client.get("/releases/my-repo")
assert response.status_code == 200
def test_get_staging_no_token_needed(self) -> None:
app = _make_test_app(operator_token="super-secret")
with TestClient(app) as client:
response = client.get("/staging?repo=my-repo")
assert response.status_code == 200

218
tests/api/test_webhooks.py Normal file
View File

@@ -0,0 +1,218 @@
"""Tests for webhook endpoint. Written FIRST (TDD RED phase)."""
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from release_agent.api.webhooks import (
_validate_webhook_secret,
router as webhook_router,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
VALID_SECRET = "super-secret-webhook-key"
_COMPLETED_PR_PAYLOAD = {
"subscription_id": "sub-1",
"event_type": "git.pullrequest.updated",
"resource": {
"repository": {
"id": "repo-1",
"name": "my-repo",
"web_url": "https://dev.azure.com/org/project/_git/my-repo",
},
"pull_request_id": 42,
"title": "feat: add feature",
"source_ref_name": "refs/heads/feature/BILL-123-add-feature",
"target_ref_name": "refs/heads/main",
"status": "completed",
"closed_date": "2024-01-15T10:00:00Z",
},
}
_ACTIVE_PR_PAYLOAD = {
"subscription_id": "sub-2",
"event_type": "git.pullrequest.updated",
"resource": {
"repository": {
"id": "repo-1",
"name": "my-repo",
"web_url": "https://dev.azure.com/org/project/_git/my-repo",
},
"pull_request_id": 43,
"title": "WIP: work in progress",
"source_ref_name": "refs/heads/feature/BILL-456",
"target_ref_name": "refs/heads/main",
"status": "active",
"closed_date": None,
},
}
def _make_test_app(webhook_secret: str = VALID_SECRET) -> FastAPI:
"""Return a FastAPI app with mocked state for webhook tests."""
app = FastAPI()
app.include_router(webhook_router)
mock_settings = MagicMock()
mock_settings.webhook_secret.get_secret_value.return_value = webhook_secret
mock_graphs = {
"pr_completed": MagicMock(),
"release": MagicMock(),
}
mock_clients = MagicMock()
mock_pool = MagicMock()
# background_tasks set tracked on state
app.state.settings = mock_settings
app.state.graphs = mock_graphs
app.state.tool_clients = mock_clients
app.state.db_pool = mock_pool
app.state.background_tasks = set()
return app
# ---------------------------------------------------------------------------
# _validate_webhook_secret (unit tests, pure function)
# ---------------------------------------------------------------------------
class TestValidateWebhookSecret:
def test_valid_secret_returns_true(self) -> None:
assert _validate_webhook_secret("mysecret", "mysecret") is True
def test_wrong_secret_returns_false(self) -> None:
assert _validate_webhook_secret("wrong", "mysecret") is False
def test_empty_header_returns_false(self) -> None:
assert _validate_webhook_secret("", "mysecret") is False
def test_none_header_returns_false(self) -> None:
assert _validate_webhook_secret(None, "mysecret") is False # type: ignore[arg-type]
def test_uses_constant_time_comparison(self) -> None:
# Should not raise even for very different lengths
assert _validate_webhook_secret("a", "very-long-secret-value") is False
def test_empty_expected_rejects_all(self) -> None:
# Empty expected secret = auth misconfigured, reject everything
assert _validate_webhook_secret("", "") is False
assert _validate_webhook_secret("any-value", "") is False
assert _validate_webhook_secret(None, "") is False
# ---------------------------------------------------------------------------
# POST /webhooks/azdo — integration tests via TestClient
# ---------------------------------------------------------------------------
class TestWebhookEndpoint:
def test_valid_completed_pr_returns_202(self) -> None:
app = _make_test_app()
with patch(
"release_agent.api.webhooks.asyncio.create_task", return_value=MagicMock()
):
with TestClient(app, raise_server_exceptions=True) as client:
response = client.post(
"/webhooks/azdo",
json=_COMPLETED_PR_PAYLOAD,
headers={"X-Webhook-Secret": VALID_SECRET},
)
assert response.status_code == 202
data = response.json()
assert "thread_id" in data
assert "message" in data
def test_missing_secret_header_returns_401(self) -> None:
app = _make_test_app()
with TestClient(app) as client:
response = client.post(
"/webhooks/azdo",
json=_COMPLETED_PR_PAYLOAD,
)
assert response.status_code == 401
def test_wrong_secret_header_returns_401(self) -> None:
app = _make_test_app()
with TestClient(app) as client:
response = client.post(
"/webhooks/azdo",
json=_COMPLETED_PR_PAYLOAD,
headers={"X-Webhook-Secret": "wrong-secret"},
)
assert response.status_code == 401
def test_invalid_payload_returns_422(self) -> None:
app = _make_test_app()
with TestClient(app) as client:
response = client.post(
"/webhooks/azdo",
json={"invalid": "payload"},
headers={"X-Webhook-Secret": VALID_SECRET},
)
assert response.status_code == 422
def test_active_pr_event_returns_200_ignored(self) -> None:
app = _make_test_app()
with TestClient(app) as client:
response = client.post(
"/webhooks/azdo",
json=_ACTIVE_PR_PAYLOAD,
headers={"X-Webhook-Secret": VALID_SECRET},
)
assert response.status_code == 200
data = response.json()
assert "ignored" in data.get("message", "").lower() or "ignored" in str(data).lower()
def test_completed_pr_thread_id_is_string(self) -> None:
app = _make_test_app()
with patch(
"release_agent.api.webhooks.asyncio.create_task", return_value=MagicMock()
):
with TestClient(app) as client:
response = client.post(
"/webhooks/azdo",
json=_COMPLETED_PR_PAYLOAD,
headers={"X-Webhook-Secret": VALID_SECRET},
)
assert response.status_code == 202
assert isinstance(response.json()["thread_id"], str)
def test_completed_pr_schedules_background_task(self) -> None:
app = _make_test_app()
task_mock = MagicMock()
with patch(
"release_agent.api.webhooks.asyncio.create_task", return_value=task_mock
) as mock_create:
with TestClient(app) as client:
client.post(
"/webhooks/azdo",
json=_COMPLETED_PR_PAYLOAD,
headers={"X-Webhook-Secret": VALID_SECRET},
)
mock_create.assert_called_once()
# ---------------------------------------------------------------------------
# Error response shape
# ---------------------------------------------------------------------------
class TestWebhookErrorShape:
def test_401_has_detail_field(self) -> None:
app = _make_test_app()
with TestClient(app) as client:
response = client.post(
"/webhooks/azdo",
json=_COMPLETED_PR_PAYLOAD,
)
assert response.status_code == 401
# FastAPI HTTPException returns {"detail": ...}
assert "detail" in response.json()

0
tests/graph/__init__.py Normal file
View File

44
tests/graph/conftest.py Normal file
View File

@@ -0,0 +1,44 @@
"""Shared fixtures for graph tests.
Provides build_mock_clients() to create ToolClients with AsyncMock fields
so individual node functions can be tested without compiling the full graph.
"""
from unittest.mock import AsyncMock, MagicMock
import pytest
from release_agent.graph.dependencies import ToolClients
def build_mock_clients() -> ToolClients:
"""Return a ToolClients instance whose fields are all AsyncMock/MagicMock."""
azdo = AsyncMock()
jira = AsyncMock()
slack = AsyncMock()
reviewer = AsyncMock()
return ToolClients(azdo=azdo, jira=jira, slack=slack, reviewer=reviewer)
def build_config(clients: ToolClients | None = None, staging_store=None) -> dict:
"""Return a LangGraph-style config dict with clients and staging_store."""
if clients is None:
clients = build_mock_clients()
return {
"configurable": {
"clients": clients,
"staging_store": staging_store,
}
}
@pytest.fixture()
def mock_clients() -> ToolClients:
"""Pytest fixture returning fresh mock ToolClients."""
return build_mock_clients()
@pytest.fixture()
def config(mock_clients: ToolClients):
"""Pytest fixture returning a config dict with mock clients."""
return build_config(mock_clients)

View File

@@ -0,0 +1,294 @@
"""Tests for graph/ci_nodes.py.
Written FIRST (TDD RED phase).
All external calls (azdo, slack, poll_until) are mocked.
"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from release_agent.graph.ci_nodes import notify_ci_result, poll_ci_build, trigger_ci_build
from release_agent.models.build import BuildStatus
from release_agent.models.pipeline import PipelineInfo
from tests.graph.conftest import build_config, build_mock_clients
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_pipeline(pipeline_id: int = 10, name: str = "CI-build") -> dict:
return {"id": pipeline_id, "name": name, "repo": "my-repo"}
# ---------------------------------------------------------------------------
# trigger_ci_build
# ---------------------------------------------------------------------------
class TestTriggerCiBuild:
"""Tests for trigger_ci_build node."""
async def test_triggers_pipeline_on_branch(self) -> None:
clients = build_mock_clients()
clients.azdo.list_build_pipelines.return_value = [
PipelineInfo(id=10, name="CI", repo="my-repo")
]
clients.azdo.trigger_pipeline.return_value = {"id": 555, "state": "inProgress"}
config = build_config(clients)
state = {"repo_name": "my-repo", "version": "v1.0.0"}
result = await trigger_ci_build(state, config)
clients.azdo.trigger_pipeline.assert_called_once()
assert "ci_build_id" in result
assert result["ci_build_id"] == 555
async def test_returns_ci_build_id(self) -> None:
clients = build_mock_clients()
clients.azdo.list_build_pipelines.return_value = [
PipelineInfo(id=20, name="build-and-test", repo="my-repo")
]
clients.azdo.trigger_pipeline.return_value = {"id": 999}
config = build_config(clients)
state = {"repo_name": "my-repo", "version": "v2.0.0"}
result = await trigger_ci_build(state, config)
assert result["ci_build_id"] == 999
async def test_appends_error_when_no_pipelines_found(self) -> None:
clients = build_mock_clients()
clients.azdo.list_build_pipelines.return_value = []
config = build_config(clients)
state = {"repo_name": "my-repo", "version": "v1.0.0"}
result = await trigger_ci_build(state, config)
assert "errors" in result
assert len(result["errors"]) >= 1
async def test_appends_error_on_trigger_failure(self) -> None:
from release_agent.exceptions import ServiceError
clients = build_mock_clients()
clients.azdo.list_build_pipelines.return_value = [
PipelineInfo(id=10, name="CI", repo="my-repo")
]
clients.azdo.trigger_pipeline.side_effect = ServiceError(
service="azdo", status_code=500, detail="Internal error"
)
config = build_config(clients)
state = {"repo_name": "my-repo", "version": "v1.0.0"}
result = await trigger_ci_build(state, config)
assert "errors" in result
async def test_uses_main_branch_when_no_version(self) -> None:
clients = build_mock_clients()
clients.azdo.list_build_pipelines.return_value = [
PipelineInfo(id=10, name="CI", repo="my-repo")
]
clients.azdo.trigger_pipeline.return_value = {"id": 1}
config = build_config(clients)
state = {"repo_name": "my-repo"}
result = await trigger_ci_build(state, config)
call_kwargs = clients.azdo.trigger_pipeline.call_args[1]
branch = call_kwargs.get("branch", "")
assert "main" in branch or "refs/heads" in branch
async def test_appends_message_on_success(self) -> None:
clients = build_mock_clients()
clients.azdo.list_build_pipelines.return_value = [
PipelineInfo(id=10, name="CI", repo="my-repo")
]
clients.azdo.trigger_pipeline.return_value = {"id": 123}
config = build_config(clients)
state = {"repo_name": "my-repo", "version": "v1.0.0"}
result = await trigger_ci_build(state, config)
assert "messages" in result
assert len(result["messages"]) >= 1
# ---------------------------------------------------------------------------
# poll_ci_build
# ---------------------------------------------------------------------------
class TestPollCiBuild:
"""Tests for poll_ci_build node."""
async def test_returns_ci_build_status_and_result_on_completion(self) -> None:
clients = build_mock_clients()
completed_status = BuildStatus(status="completed", result="succeeded", build_url="https://build/1")
config = build_config(clients)
state = {"ci_build_id": 42, "repo_name": "my-repo"}
with patch(
"release_agent.graph.ci_nodes.poll_until",
return_value=(completed_status, True),
):
result = await poll_ci_build(state, config)
assert result["ci_build_status"] == "completed"
assert result["ci_build_result"] == "succeeded"
async def test_returns_build_url(self) -> None:
clients = build_mock_clients()
completed_status = BuildStatus(
status="completed",
result="succeeded",
build_url="https://dev.azure.com/build/42",
)
config = build_config(clients)
state = {"ci_build_id": 42, "repo_name": "my-repo"}
with patch(
"release_agent.graph.ci_nodes.poll_until",
return_value=(completed_status, True),
):
result = await poll_ci_build(state, config)
assert result.get("ci_build_url") == "https://dev.azure.com/build/42"
async def test_appends_error_on_timeout(self) -> None:
clients = build_mock_clients()
running_status = BuildStatus(status="inProgress", result=None, build_url=None)
config = build_config(clients)
state = {"ci_build_id": 42, "repo_name": "my-repo"}
with patch(
"release_agent.graph.ci_nodes.poll_until",
return_value=(running_status, False),
):
result = await poll_ci_build(state, config)
assert "errors" in result
async def test_appends_error_when_build_id_missing(self) -> None:
clients = build_mock_clients()
config = build_config(clients)
state = {"repo_name": "my-repo"} # no ci_build_id
result = await poll_ci_build(state, config)
assert "errors" in result
async def test_passes_correct_build_id_to_poll_fn(self) -> None:
clients = build_mock_clients()
clients.azdo.get_build_status.return_value = BuildStatus(
status="completed", result="succeeded", build_url=None
)
config = build_config(clients)
state = {"ci_build_id": 77, "repo_name": "my-repo"}
async def fake_poll_until(*, poll_fn, is_done, interval_seconds, max_wait_seconds, sleep_fn=None):
result = await poll_fn()
return result, True
with patch("release_agent.graph.ci_nodes.poll_until", side_effect=fake_poll_until):
await poll_ci_build(state, config)
clients.azdo.get_build_status.assert_called_once_with(build_id=77)
async def test_result_none_when_poll_returns_none(self) -> None:
clients = build_mock_clients()
config = build_config(clients)
state = {"ci_build_id": 42, "repo_name": "my-repo"}
with patch(
"release_agent.graph.ci_nodes.poll_until",
return_value=(None, False),
):
result = await poll_ci_build(state, config)
assert "errors" in result
# ---------------------------------------------------------------------------
# notify_ci_result
# ---------------------------------------------------------------------------
class TestNotifyCiResult:
"""Tests for notify_ci_result node."""
async def test_sends_notification_on_success(self) -> None:
clients = build_mock_clients()
clients.slack.send_notification.return_value = True
config = build_config(clients)
state = {
"repo_name": "my-repo",
"ci_build_status": "completed",
"ci_build_result": "succeeded",
"ci_build_url": "https://build/99",
}
result = await notify_ci_result(state, config)
clients.slack.send_notification.assert_called_once()
assert "messages" in result
async def test_sends_notification_on_failure(self) -> None:
clients = build_mock_clients()
clients.slack.send_notification.return_value = True
config = build_config(clients)
state = {
"repo_name": "my-repo",
"ci_build_status": "completed",
"ci_build_result": "failed",
"ci_build_url": None,
}
result = await notify_ci_result(state, config)
clients.slack.send_notification.assert_called_once()
async def test_handles_slack_error_gracefully(self) -> None:
from release_agent.exceptions import ServiceError
clients = build_mock_clients()
clients.slack.send_notification.side_effect = ServiceError(
service="slack", status_code=500, detail="Slack error"
)
config = build_config(clients)
state = {
"repo_name": "my-repo",
"ci_build_result": "succeeded",
"ci_build_url": None,
}
result = await notify_ci_result(state, config)
# Should not re-raise; should append error
assert "errors" in result
async def test_includes_repo_name_in_message(self) -> None:
clients = build_mock_clients()
clients.slack.send_notification.return_value = True
config = build_config(clients)
state = {
"repo_name": "super-service",
"ci_build_result": "succeeded",
"ci_build_url": None,
}
await notify_ci_result(state, config)
call_kwargs = clients.slack.send_notification.call_args[1]
text_or_blocks = str(call_kwargs)
assert "super-service" in text_or_blocks
async def test_returns_empty_dict_when_state_has_no_data(self) -> None:
clients = build_mock_clients()
clients.slack.send_notification.return_value = True
config = build_config(clients)
state = {}
result = await notify_ci_result(state, config)
# Should not crash; may return messages or empty dict
assert isinstance(result, dict)

View File

@@ -0,0 +1,283 @@
"""Tests for graph/dependencies.py. Written FIRST (TDD RED phase).
Covers:
- ToolClients frozen dataclass
- StagingStore Protocol (structural check)
- JsonFileStagingStore file I/O operations
"""
import json
from datetime import date
from pathlib import Path
from unittest.mock import AsyncMock
import pytest
from release_agent.graph.dependencies import JsonFileStagingStore, StagingStore, ToolClients
from release_agent.models.release import ArchivedRelease, StagingRelease
from release_agent.models.ticket import TicketEntry
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_ticket(ticket_id: str = "ALLPOST-1") -> TicketEntry:
return TicketEntry(
id=ticket_id,
summary="Fix something",
pr_id="42",
pr_url="https://dev.azure.com/org/proj/_git/repo/pullrequest/42",
pr_title="Fix: something",
branch=f"feature/{ticket_id}-fix",
merged_at=date(2025, 1, 15),
)
def _make_staging(repo: str = "my-repo", version: str = "v1.0.0") -> StagingRelease:
return StagingRelease(
version=version,
repo=repo,
started_at=date(2025, 1, 1),
tickets=[],
)
# ---------------------------------------------------------------------------
# ToolClients tests
# ---------------------------------------------------------------------------
class TestToolClients:
"""Tests for the ToolClients frozen dataclass."""
def test_can_be_constructed_with_all_fields(self) -> None:
azdo = AsyncMock()
jira = AsyncMock()
slack = AsyncMock()
reviewer = AsyncMock()
clients = ToolClients(azdo=azdo, jira=jira, slack=slack, reviewer=reviewer)
assert clients.azdo is azdo
assert clients.jira is jira
assert clients.slack is slack
assert clients.reviewer is reviewer
def test_is_frozen_cannot_reassign_field(self) -> None:
clients = ToolClients(
azdo=AsyncMock(), jira=AsyncMock(), slack=AsyncMock(), reviewer=AsyncMock()
)
with pytest.raises((AttributeError, TypeError)):
clients.azdo = AsyncMock() # type: ignore[misc]
def test_fields_are_accessible_by_name(self) -> None:
azdo = object()
clients = ToolClients(
azdo=azdo, jira=object(), slack=object(), reviewer=object()
)
assert clients.azdo is azdo
def test_equality_for_same_instances(self) -> None:
azdo = AsyncMock()
jira = AsyncMock()
slack = AsyncMock()
reviewer = AsyncMock()
c1 = ToolClients(azdo=azdo, jira=jira, slack=slack, reviewer=reviewer)
c2 = ToolClients(azdo=azdo, jira=jira, slack=slack, reviewer=reviewer)
assert c1 == c2
# ---------------------------------------------------------------------------
# StagingStore Protocol structural tests
# ---------------------------------------------------------------------------
class TestStagingStoreProtocol:
"""Verify that the Protocol is structurally correct."""
def test_json_file_store_satisfies_protocol(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
# runtime_checkable would need @runtime_checkable; check duck-typing instead
assert hasattr(store, "load")
assert hasattr(store, "save")
assert hasattr(store, "archive")
assert hasattr(store, "list_versions")
def test_protocol_is_importable(self) -> None:
# Just import-level check
assert StagingStore is not None
# ---------------------------------------------------------------------------
# JsonFileStagingStore tests
# ---------------------------------------------------------------------------
class TestJsonFileStagingStore:
"""Tests for JsonFileStagingStore using tmp_path for file I/O."""
# ------------------------------------------------------------------
# load
# ------------------------------------------------------------------
async def test_load_returns_none_when_file_missing(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
result = await store.load("nonexistent-repo")
assert result is None
async def test_load_returns_staging_release_after_save(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
staging = _make_staging()
await store.save(staging)
loaded = await store.load("my-repo")
assert loaded is not None
assert loaded.version == "v1.0.0"
assert loaded.repo == "my-repo"
async def test_load_returns_staging_with_tickets(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
staging = _make_staging().add_ticket(_make_ticket("BILL-10"))
await store.save(staging)
loaded = await store.load("my-repo")
assert loaded is not None
assert len(loaded.tickets) == 1
assert loaded.tickets[0].id == "BILL-10"
async def test_load_is_read_only_does_not_mutate_stored(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
staging = _make_staging()
await store.save(staging)
loaded1 = await store.load("my-repo")
loaded2 = await store.load("my-repo")
assert loaded1 is not loaded2 # fresh objects each time
# ------------------------------------------------------------------
# save
# ------------------------------------------------------------------
async def test_save_creates_file_in_directory(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
staging = _make_staging(repo="api-service")
await store.save(staging)
expected_path = tmp_path / "api-service.json"
assert expected_path.exists()
async def test_save_overwrites_existing_file(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
staging_v1 = _make_staging(version="v1.0.0")
staging_v2 = _make_staging(version="v1.0.1")
await store.save(staging_v1)
await store.save(staging_v2)
loaded = await store.load("my-repo")
assert loaded is not None
assert loaded.version == "v1.0.1"
async def test_save_writes_valid_json(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
staging = _make_staging()
await store.save(staging)
raw = (tmp_path / "my-repo.json").read_text()
data = json.loads(raw)
assert data["version"] == "v1.0.0"
assert data["repo"] == "my-repo"
async def test_save_does_not_mutate_staging_release(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
staging = _make_staging()
original_tickets = list(staging.tickets)
await store.save(staging)
assert list(staging.tickets) == original_tickets
# ------------------------------------------------------------------
# archive
# ------------------------------------------------------------------
async def test_archive_removes_staging_file(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
staging = _make_staging()
await store.save(staging)
await store.archive(staging, date(2025, 6, 1))
assert await store.load("my-repo") is None
async def test_archive_creates_archive_file(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
staging = _make_staging(repo="my-repo", version="v1.0.0")
await store.save(staging)
await store.archive(staging, date(2025, 6, 1))
archive_path = tmp_path / "my-repo_v1.0.0_2025-06-01.json"
assert archive_path.exists()
async def test_archive_file_contains_released_at(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
staging = _make_staging()
await store.save(staging)
release_date = date(2025, 6, 1)
await store.archive(staging, release_date)
archive_path = tmp_path / "my-repo_v1.0.0_2025-06-01.json"
data = json.loads(archive_path.read_text())
assert data["released_at"] == "2025-06-01"
async def test_archive_without_prior_save_creates_archive(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
staging = _make_staging()
await store.archive(staging, date(2025, 6, 1))
archive_path = tmp_path / "my-repo_v1.0.0_2025-06-01.json"
assert archive_path.exists()
# ------------------------------------------------------------------
# list_versions
# ------------------------------------------------------------------
async def test_list_versions_empty_directory(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
versions = await store.list_versions("my-repo")
assert versions == []
async def test_list_versions_returns_version_from_staging_file(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
await store.save(_make_staging(version="v2.1.0"))
versions = await store.list_versions("my-repo")
assert "v2.1.0" in versions
async def test_list_versions_includes_archived_versions(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
staging = _make_staging(version="v1.5.0")
await store.save(staging)
await store.archive(staging, date(2025, 3, 1))
# Now save a new staging for the same repo
await store.save(_make_staging(version="v1.6.0"))
versions = await store.list_versions("my-repo")
assert "v1.5.0" in versions
assert "v1.6.0" in versions
async def test_list_versions_only_returns_versions_for_given_repo(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
await store.save(_make_staging(repo="repo-a", version="v1.0.0"))
await store.save(_make_staging(repo="repo-b", version="v2.0.0"))
versions_a = await store.list_versions("repo-a")
assert "v1.0.0" in versions_a
# repo-b version should not appear in repo-a's list
assert "v2.0.0" not in versions_a
async def test_list_versions_no_duplicates(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
await store.save(_make_staging(version="v1.0.0"))
versions = await store.list_versions("my-repo")
assert len(versions) == len(set(versions))
async def test_list_versions_multiple_archives(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
for i in range(3):
staging = _make_staging(version=f"v1.0.{i}")
await store.archive(staging, date(2025, 1, i + 1))
versions = await store.list_versions("my-repo")
assert len(versions) == 3
assert "v1.0.0" in versions
assert "v1.0.1" in versions
assert "v1.0.2" in versions
# ------------------------------------------------------------------
# directory creation
# ------------------------------------------------------------------
def test_store_creates_directory_if_not_exists(self, tmp_path: Path) -> None:
new_dir = tmp_path / "staging_data"
assert not new_dir.exists()
JsonFileStagingStore(directory=new_dir)
assert new_dir.exists()

View File

@@ -0,0 +1,177 @@
"""Tests for async StagingStore protocol and async JsonFileStagingStore.
Phase 5 - Step 1: All StagingStore methods become async def.
Written FIRST (TDD RED phase).
"""
import json
from datetime import date
from pathlib import Path
from unittest.mock import AsyncMock
import pytest
from release_agent.graph.dependencies import JsonFileStagingStore, StagingStore, ToolClients
from release_agent.models.release import StagingRelease
from release_agent.models.ticket import TicketEntry
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_ticket(ticket_id: str = "ALLPOST-1") -> TicketEntry:
return TicketEntry(
id=ticket_id,
summary="Fix something",
pr_id="42",
pr_url="https://dev.azure.com/org/proj/_git/repo/pullrequest/42",
pr_title="Fix: something",
branch=f"feature/{ticket_id}-fix",
merged_at=date(2025, 1, 15),
)
def _make_staging(repo: str = "my-repo", version: str = "v1.0.0") -> StagingRelease:
return StagingRelease(
version=version,
repo=repo,
started_at=date(2025, 1, 1),
tickets=[],
)
# ---------------------------------------------------------------------------
# Protocol: all methods must be async
# ---------------------------------------------------------------------------
class TestStagingStoreProtocolIsAsync:
"""Verify that StagingStore protocol methods are async-compatible."""
def test_protocol_has_load_method(self) -> None:
assert hasattr(StagingStore, "load")
def test_protocol_has_save_method(self) -> None:
assert hasattr(StagingStore, "save")
def test_protocol_has_archive_method(self) -> None:
assert hasattr(StagingStore, "archive")
def test_protocol_has_list_versions_method(self) -> None:
assert hasattr(StagingStore, "list_versions")
# ---------------------------------------------------------------------------
# JsonFileStagingStore async interface
# ---------------------------------------------------------------------------
class TestJsonFileStagingStoreAsync:
"""Verify that JsonFileStagingStore methods are awaitable (async def)."""
async def test_load_is_awaitable(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
result = await store.load("nonexistent-repo")
assert result is None
async def test_load_returns_staging_release_after_save(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
staging = _make_staging()
await store.save(staging)
loaded = await store.load("my-repo")
assert loaded is not None
assert loaded.version == "v1.0.0"
assert loaded.repo == "my-repo"
async def test_load_returns_staging_with_tickets(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
staging = _make_staging().add_ticket(_make_ticket("BILL-10"))
await store.save(staging)
loaded = await store.load("my-repo")
assert loaded is not None
assert len(loaded.tickets) == 1
assert loaded.tickets[0].id == "BILL-10"
async def test_load_returns_fresh_objects(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
staging = _make_staging()
await store.save(staging)
loaded1 = await store.load("my-repo")
loaded2 = await store.load("my-repo")
assert loaded1 is not loaded2
async def test_save_is_awaitable(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
staging = _make_staging(repo="api-service")
await store.save(staging)
expected_path = tmp_path / "api-service.json"
assert expected_path.exists()
async def test_save_overwrites_existing_file(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
await store.save(_make_staging(version="v1.0.0"))
await store.save(_make_staging(version="v1.0.1"))
loaded = await store.load("my-repo")
assert loaded is not None
assert loaded.version == "v1.0.1"
async def test_save_writes_valid_json(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
staging = _make_staging()
await store.save(staging)
raw = (tmp_path / "my-repo.json").read_text()
data = json.loads(raw)
assert data["version"] == "v1.0.0"
assert data["repo"] == "my-repo"
async def test_archive_is_awaitable(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
staging = _make_staging()
await store.save(staging)
await store.archive(staging, date(2025, 6, 1))
assert await store.load("my-repo") is None
async def test_archive_creates_archive_file(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
staging = _make_staging(repo="my-repo", version="v1.0.0")
await store.save(staging)
await store.archive(staging, date(2025, 6, 1))
archive_path = tmp_path / "my-repo_v1.0.0_2025-06-01.json"
assert archive_path.exists()
async def test_archive_file_contains_released_at(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
staging = _make_staging()
await store.save(staging)
await store.archive(staging, date(2025, 6, 1))
archive_path = tmp_path / "my-repo_v1.0.0_2025-06-01.json"
data = json.loads(archive_path.read_text())
assert data["released_at"] == "2025-06-01"
async def test_list_versions_is_awaitable(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
versions = await store.list_versions("my-repo")
assert versions == []
async def test_list_versions_returns_staging_version(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
await store.save(_make_staging(version="v2.1.0"))
versions = await store.list_versions("my-repo")
assert "v2.1.0" in versions
async def test_list_versions_includes_archived(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
staging = _make_staging(version="v1.5.0")
await store.save(staging)
await store.archive(staging, date(2025, 3, 1))
await store.save(_make_staging(version="v1.6.0"))
versions = await store.list_versions("my-repo")
assert "v1.5.0" in versions
assert "v1.6.0" in versions
async def test_list_versions_only_for_given_repo(self, tmp_path: Path) -> None:
store = JsonFileStagingStore(directory=tmp_path)
await store.save(_make_staging(repo="repo-a", version="v1.0.0"))
await store.save(_make_staging(repo="repo-b", version="v2.0.0"))
versions_a = await store.list_versions("repo-a")
assert "v1.0.0" in versions_a
assert "v2.0.0" not in versions_a

View File

@@ -0,0 +1,53 @@
"""Tests for graph/full_cycle.py.
Tests that the full cycle graph composes pr_completed and release subgraphs
correctly, and that the routing conditional edge works as expected.
"""
from release_agent.graph.full_cycle import build_full_cycle_graph
from release_agent.graph.routing import should_continue_to_release
class TestBuildFullCycleGraph:
def test_returns_compiled_graph(self) -> None:
graph = build_full_cycle_graph()
assert graph is not None
def test_graph_can_be_built_multiple_times(self) -> None:
graph1 = build_full_cycle_graph()
graph2 = build_full_cycle_graph()
assert graph1 is not None
assert graph2 is not None
def test_graph_has_get_graph_method(self) -> None:
graph = build_full_cycle_graph()
assert hasattr(graph, "get_graph") or hasattr(graph, "nodes")
class TestFullCycleRouting:
"""Test that the routing function used by full_cycle correctly
determines whether to continue to the release subgraph."""
def test_continue_when_flag_true_and_no_errors(self) -> None:
state = {"continue_to_release": True, "errors": []}
assert should_continue_to_release(state) == "yes"
def test_stop_when_flag_false(self) -> None:
state = {"continue_to_release": False}
assert should_continue_to_release(state) == "no"
def test_stop_when_flag_missing(self) -> None:
state = {}
assert should_continue_to_release(state) == "no"
def test_stop_when_errors_present(self) -> None:
state = {"continue_to_release": True, "errors": ["some error"]}
assert should_continue_to_release(state) == "no"
def test_stop_when_flag_true_but_errors_present(self) -> None:
state = {"continue_to_release": True, "errors": ["critical failure"]}
assert should_continue_to_release(state) == "no"
def test_continue_when_errors_empty_list(self) -> None:
state = {"continue_to_release": True, "errors": []}
assert should_continue_to_release(state) == "yes"

356
tests/graph/test_polling.py Normal file
View File

@@ -0,0 +1,356 @@
"""Tests for graph/polling.py — poll_until async utility.
Written FIRST (TDD RED phase).
All tests inject a fake_sleep_fn that returns immediately to avoid real waits.
"""
import asyncio
from unittest.mock import AsyncMock, call
import pytest
from release_agent.graph.polling import poll_until
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
async def _immediate_sleep(seconds: float) -> None:
"""Drop-in replacement for asyncio.sleep that returns immediately."""
return
# ---------------------------------------------------------------------------
# Success path tests
# ---------------------------------------------------------------------------
class TestPollUntilSuccess:
"""Tests for the happy path where poll_fn succeeds before timeout."""
async def test_returns_tuple_of_result_and_completed_true(self) -> None:
calls = iter(["running", "running", "completed"])
async def poll_fn():
return next(calls)
result, completed = await poll_until(
poll_fn=poll_fn,
is_done=lambda r: r == "completed",
interval_seconds=1,
max_wait_seconds=60,
sleep_fn=_immediate_sleep,
)
assert result == "completed"
assert completed is True
async def test_returns_immediately_when_already_done(self) -> None:
async def poll_fn():
return "completed"
result, completed = await poll_until(
poll_fn=poll_fn,
is_done=lambda r: r == "completed",
interval_seconds=1,
max_wait_seconds=60,
sleep_fn=_immediate_sleep,
)
assert result == "completed"
assert completed is True
async def test_polls_multiple_times_before_done(self) -> None:
call_count = 0
async def poll_fn():
nonlocal call_count
call_count += 1
return "done" if call_count >= 3 else "pending"
result, completed = await poll_until(
poll_fn=poll_fn,
is_done=lambda r: r == "done",
interval_seconds=1,
max_wait_seconds=60,
sleep_fn=_immediate_sleep,
)
assert result == "done"
assert completed is True
assert call_count == 3
async def test_sleep_called_between_polls(self) -> None:
call_count = 0
sleep_calls: list[float] = []
async def poll_fn():
nonlocal call_count
call_count += 1
return "done" if call_count >= 2 else "pending"
async def tracking_sleep(seconds: float) -> None:
sleep_calls.append(seconds)
await poll_until(
poll_fn=poll_fn,
is_done=lambda r: r == "done",
interval_seconds=15,
max_wait_seconds=60,
sleep_fn=tracking_sleep,
)
assert len(sleep_calls) >= 1
assert all(s == 15 for s in sleep_calls)
async def test_no_sleep_on_first_successful_poll(self) -> None:
sleep_calls: list[float] = []
async def poll_fn():
return "done"
async def tracking_sleep(seconds: float) -> None:
sleep_calls.append(seconds)
await poll_until(
poll_fn=poll_fn,
is_done=lambda r: r == "done",
interval_seconds=10,
max_wait_seconds=60,
sleep_fn=tracking_sleep,
)
assert sleep_calls == []
async def test_works_with_dict_results(self) -> None:
responses = iter([
{"status": "inProgress"},
{"status": "completed", "result": "succeeded"},
])
async def poll_fn():
return next(responses)
result, completed = await poll_until(
poll_fn=poll_fn,
is_done=lambda r: r["status"] == "completed",
interval_seconds=1,
max_wait_seconds=60,
sleep_fn=_immediate_sleep,
)
assert result["result"] == "succeeded"
assert completed is True
# ---------------------------------------------------------------------------
# Timeout tests
# ---------------------------------------------------------------------------
class TestPollUntilTimeout:
"""Tests for timeout behavior."""
async def test_returns_last_result_and_completed_false_on_timeout(self) -> None:
async def poll_fn():
return "still_running"
# With interval=10, max_wait=5, it should time out after one poll
result, completed = await poll_until(
poll_fn=poll_fn,
is_done=lambda r: r == "done",
interval_seconds=10,
max_wait_seconds=5,
sleep_fn=_immediate_sleep,
)
assert result == "still_running"
assert completed is False
async def test_at_least_one_poll_happens_before_timeout(self) -> None:
call_count = 0
async def poll_fn():
nonlocal call_count
call_count += 1
return "running"
await poll_until(
poll_fn=poll_fn,
is_done=lambda r: r == "done",
interval_seconds=100,
max_wait_seconds=1,
sleep_fn=_immediate_sleep,
)
assert call_count >= 1
async def test_max_polls_bounded_by_max_wait_over_interval(self) -> None:
call_count = 0
async def poll_fn():
nonlocal call_count
call_count += 1
return "running"
await poll_until(
poll_fn=poll_fn,
is_done=lambda r: False,
interval_seconds=10,
max_wait_seconds=30,
sleep_fn=_immediate_sleep,
)
# With interval=10, max_wait=30: should poll at most ceil(30/10)+1 = 4 times
assert call_count <= 5
# ---------------------------------------------------------------------------
# Error handling tests
# ---------------------------------------------------------------------------
class TestPollUntilErrorHandling:
"""Tests for error/exception handling in poll_until."""
async def test_continues_after_transient_exception(self) -> None:
call_count = 0
async def poll_fn():
nonlocal call_count
call_count += 1
if call_count < 3:
raise RuntimeError("Transient error")
return "done"
result, completed = await poll_until(
poll_fn=poll_fn,
is_done=lambda r: r == "done",
interval_seconds=1,
max_wait_seconds=60,
sleep_fn=_immediate_sleep,
)
assert result == "done"
assert completed is True
async def test_aborts_after_three_consecutive_failures(self) -> None:
call_count = 0
async def poll_fn():
nonlocal call_count
call_count += 1
raise RuntimeError("Persistent error")
result, completed = await poll_until(
poll_fn=poll_fn,
is_done=lambda r: True,
interval_seconds=1,
max_wait_seconds=60,
sleep_fn=_immediate_sleep,
)
# Should abort after 3 consecutive failures
assert call_count == 3
assert completed is False
assert result is None
async def test_resets_consecutive_failure_count_on_success(self) -> None:
call_count = 0
async def poll_fn():
nonlocal call_count
call_count += 1
# Fail twice, succeed once, fail twice, succeed (done)
if call_count in (1, 2):
raise RuntimeError("fail")
if call_count == 3:
return "running"
if call_count in (4, 5):
raise RuntimeError("fail again")
return "done"
result, completed = await poll_until(
poll_fn=poll_fn,
is_done=lambda r: r == "done",
interval_seconds=1,
max_wait_seconds=120,
sleep_fn=_immediate_sleep,
)
assert completed is True
assert result == "done"
async def test_single_exception_does_not_abort(self) -> None:
call_count = 0
async def poll_fn():
nonlocal call_count
call_count += 1
if call_count == 1:
raise ValueError("one error")
return "done"
result, completed = await poll_until(
poll_fn=poll_fn,
is_done=lambda r: r == "done",
interval_seconds=1,
max_wait_seconds=60,
sleep_fn=_immediate_sleep,
)
assert completed is True
assert result == "done"
async def test_two_consecutive_failures_do_not_abort(self) -> None:
call_count = 0
async def poll_fn():
nonlocal call_count
call_count += 1
if call_count <= 2:
raise ConnectionError("two errors")
return "done"
result, completed = await poll_until(
poll_fn=poll_fn,
is_done=lambda r: r == "done",
interval_seconds=1,
max_wait_seconds=60,
sleep_fn=_immediate_sleep,
)
assert completed is True
assert result == "done"
# ---------------------------------------------------------------------------
# Default parameter tests
# ---------------------------------------------------------------------------
class TestPollUntilDefaults:
"""Tests that default parameters match the spec."""
async def test_default_interval_is_30_seconds(self) -> None:
sleep_calls: list[float] = []
async def poll_fn():
return "done" if len(sleep_calls) >= 1 else "running"
async def tracking_sleep(seconds: float) -> None:
sleep_calls.append(seconds)
await poll_until(
poll_fn=poll_fn,
is_done=lambda r: r == "done",
sleep_fn=tracking_sleep,
)
if sleep_calls:
assert sleep_calls[0] == 30
async def test_poll_fn_and_is_done_are_keyword_only(self) -> None:
"""poll_fn and is_done must be passed as keyword arguments."""
async def poll_fn():
return "done"
with pytest.raises(TypeError):
await poll_until(poll_fn, lambda r: r == "done") # type: ignore[call-arg]

View File

@@ -0,0 +1,414 @@
"""Tests for PostgresStagingStore.
Phase 5 - Step 2: PostgreSQL-backed StagingStore using async pool.
Written FIRST (TDD RED phase).
All tests use FakeAsyncPool — no real PostgreSQL required.
"""
import json
from datetime import date
from unittest.mock import AsyncMock, MagicMock
import pytest
from release_agent.graph.postgres_staging_store import PostgresStagingStore
from release_agent.models.release import ArchivedRelease, StagingRelease
from release_agent.models.ticket import TicketEntry
# ---------------------------------------------------------------------------
# Fake pool infrastructure
# ---------------------------------------------------------------------------
class FakeAsyncCursor:
"""Records SQL calls and returns configured results."""
def __init__(self) -> None:
self.executed: list[tuple[str, tuple]] = []
self._fetchone_result: tuple | None = None
self._fetchall_result: list[tuple] = []
def set_fetchone(self, row: tuple | None) -> None:
self._fetchone_result = row
def set_fetchall(self, rows: list[tuple]) -> None:
self._fetchall_result = rows
async def execute(self, sql: str, params: tuple = ()) -> None:
self.executed.append((sql, params))
async def fetchone(self) -> tuple | None:
return self._fetchone_result
async def fetchall(self) -> list[tuple]:
return self._fetchall_result
async def __aenter__(self):
return self
async def __aexit__(self, *args):
pass
class FakeAsyncTransaction:
"""Fake async transaction context manager (no-op)."""
async def __aenter__(self):
return self
async def __aexit__(self, *args):
pass
class FakeAsyncConnection:
"""Async context manager returning a FakeAsyncCursor."""
def __init__(self, cursor: FakeAsyncCursor) -> None:
self._cursor = cursor
def cursor(self):
return self._cursor
def transaction(self):
return FakeAsyncTransaction()
async def __aenter__(self):
return self
async def __aexit__(self, *args):
pass
class FakeAsyncPool:
"""Records all SQL executed through it."""
def __init__(self, cursor: FakeAsyncCursor) -> None:
self._cursor = cursor
self._conn = FakeAsyncConnection(cursor)
def connection(self):
return self._conn
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_ticket(ticket_id: str = "ALLPOST-1") -> TicketEntry:
return TicketEntry(
id=ticket_id,
summary="Fix something",
pr_id="42",
pr_url="https://dev.azure.com/org/proj/_git/repo/pullrequest/42",
pr_title="Fix: something",
branch=f"feature/{ticket_id}-fix",
merged_at=date(2025, 1, 15),
)
def _make_staging(
repo: str = "my-repo",
version: str = "v1.0.0",
tickets: list | None = None,
) -> StagingRelease:
t = tickets if tickets is not None else []
return StagingRelease(
version=version,
repo=repo,
started_at=date(2025, 1, 1),
tickets=t,
)
def _staging_row(staging: StagingRelease) -> tuple:
"""Return (repo, version, started_at, tickets_json) as DB would store it."""
return (
staging.repo,
staging.version,
staging.started_at.isoformat(),
json.dumps([t.model_dump(mode="json") for t in staging.tickets]),
)
# ---------------------------------------------------------------------------
# load()
# ---------------------------------------------------------------------------
class TestPostgresStagingStoreLoad:
async def test_load_returns_none_when_no_row(self) -> None:
cursor = FakeAsyncCursor()
cursor.set_fetchone(None)
pool = FakeAsyncPool(cursor)
store = PostgresStagingStore(pool=pool)
result = await store.load("nonexistent-repo")
assert result is None
async def test_load_returns_staging_release_when_row_exists(self) -> None:
staging = _make_staging(repo="api-service", version="v2.0.0")
cursor = FakeAsyncCursor()
cursor.set_fetchone((
staging.repo,
staging.version,
staging.started_at.isoformat(),
json.dumps([]),
))
pool = FakeAsyncPool(cursor)
store = PostgresStagingStore(pool=pool)
result = await store.load("api-service")
assert result is not None
assert isinstance(result, StagingRelease)
assert result.repo == "api-service"
assert result.version == "v2.0.0"
async def test_load_returns_staging_with_tickets(self) -> None:
ticket = _make_ticket("BILL-42")
staging = _make_staging(tickets=[ticket])
cursor = FakeAsyncCursor()
cursor.set_fetchone((
staging.repo,
staging.version,
staging.started_at.isoformat(),
json.dumps([ticket.model_dump(mode="json")]),
))
pool = FakeAsyncPool(cursor)
store = PostgresStagingStore(pool=pool)
result = await store.load("my-repo")
assert result is not None
assert len(result.tickets) == 1
assert result.tickets[0].id == "BILL-42"
async def test_load_executes_select_with_correct_repo(self) -> None:
cursor = FakeAsyncCursor()
cursor.set_fetchone(None)
pool = FakeAsyncPool(cursor)
store = PostgresStagingStore(pool=pool)
await store.load("target-repo")
assert len(cursor.executed) >= 1
sql, params = cursor.executed[-1]
assert "SELECT" in sql.upper()
assert "target-repo" in params
async def test_load_queries_staging_releases_table(self) -> None:
cursor = FakeAsyncCursor()
cursor.set_fetchone(None)
pool = FakeAsyncPool(cursor)
store = PostgresStagingStore(pool=pool)
await store.load("my-repo")
sql, _ = cursor.executed[-1]
assert "staging_releases" in sql
# ---------------------------------------------------------------------------
# save()
# ---------------------------------------------------------------------------
class TestPostgresStagingStoreSave:
async def test_save_executes_upsert(self) -> None:
cursor = FakeAsyncCursor()
pool = FakeAsyncPool(cursor)
store = PostgresStagingStore(pool=pool)
staging = _make_staging()
await store.save(staging)
assert len(cursor.executed) >= 1
sql, _ = cursor.executed[-1]
# Should be an INSERT ... ON CONFLICT ... or UPSERT
assert "INSERT" in sql.upper() or "UPSERT" in sql.upper()
async def test_save_passes_repo_to_upsert(self) -> None:
cursor = FakeAsyncCursor()
pool = FakeAsyncPool(cursor)
store = PostgresStagingStore(pool=pool)
staging = _make_staging(repo="payment-service")
await store.save(staging)
_, params = cursor.executed[-1]
assert "payment-service" in params
async def test_save_passes_version_to_upsert(self) -> None:
cursor = FakeAsyncCursor()
pool = FakeAsyncPool(cursor)
store = PostgresStagingStore(pool=pool)
staging = _make_staging(version="v3.1.0")
await store.save(staging)
_, params = cursor.executed[-1]
assert "v3.1.0" in params
async def test_save_targets_staging_releases_table(self) -> None:
cursor = FakeAsyncCursor()
pool = FakeAsyncPool(cursor)
store = PostgresStagingStore(pool=pool)
staging = _make_staging()
await store.save(staging)
sql, _ = cursor.executed[-1]
assert "staging_releases" in sql
async def test_save_serializes_tickets_as_json(self) -> None:
cursor = FakeAsyncCursor()
pool = FakeAsyncPool(cursor)
store = PostgresStagingStore(pool=pool)
staging = _make_staging(tickets=[_make_ticket("ALLPOST-99")])
await store.save(staging)
_, params = cursor.executed[-1]
# tickets param should be a JSON string containing the ticket id
tickets_json = next(p for p in params if isinstance(p, str) and "ALLPOST-99" in p)
parsed = json.loads(tickets_json)
assert parsed[0]["id"] == "ALLPOST-99"
# ---------------------------------------------------------------------------
# archive()
# ---------------------------------------------------------------------------
class TestPostgresStagingStoreArchive:
async def test_archive_inserts_into_archived_releases(self) -> None:
cursor = FakeAsyncCursor()
pool = FakeAsyncPool(cursor)
store = PostgresStagingStore(pool=pool)
staging = _make_staging()
await store.archive(staging, date(2025, 6, 1))
sql_statements = [sql for sql, _ in cursor.executed]
assert any("archived_releases" in sql for sql in sql_statements)
async def test_archive_deletes_from_staging_releases(self) -> None:
cursor = FakeAsyncCursor()
pool = FakeAsyncPool(cursor)
store = PostgresStagingStore(pool=pool)
staging = _make_staging(repo="my-repo")
await store.archive(staging, date(2025, 6, 1))
sql_statements = [sql for sql, _ in cursor.executed]
assert any("DELETE" in sql.upper() and "staging_releases" in sql for sql in sql_statements)
async def test_archive_passes_released_at_date(self) -> None:
cursor = FakeAsyncCursor()
pool = FakeAsyncPool(cursor)
store = PostgresStagingStore(pool=pool)
staging = _make_staging()
release_date = date(2025, 12, 31)
await store.archive(staging, release_date)
all_params = [params for _, params in cursor.executed]
all_values = [v for params in all_params for v in params]
assert "2025-12-31" in all_values or release_date.isoformat() in all_values
async def test_archive_passes_repo_to_delete(self) -> None:
cursor = FakeAsyncCursor()
pool = FakeAsyncPool(cursor)
store = PostgresStagingStore(pool=pool)
staging = _make_staging(repo="payment-service")
await store.archive(staging, date(2025, 6, 1))
delete_calls = [(sql, params) for sql, params in cursor.executed if "DELETE" in sql.upper()]
assert len(delete_calls) >= 1
_, params = delete_calls[0]
assert "payment-service" in params
# ---------------------------------------------------------------------------
# list_versions()
# ---------------------------------------------------------------------------
class TestPostgresStagingStoreListVersions:
async def test_list_versions_returns_empty_when_no_data(self) -> None:
cursor = FakeAsyncCursor()
cursor.set_fetchone(None)
cursor.set_fetchall([])
pool = FakeAsyncPool(cursor)
store = PostgresStagingStore(pool=pool)
versions = await store.list_versions("my-repo")
assert versions == []
async def test_list_versions_includes_staging_version(self) -> None:
cursor = FakeAsyncCursor()
# fetchone returns staging row
cursor.set_fetchone(("my-repo", "v1.0.0", "2025-01-01", "[]"))
# fetchall returns archived rows
cursor.set_fetchall([])
pool = FakeAsyncPool(cursor)
store = PostgresStagingStore(pool=pool)
versions = await store.list_versions("my-repo")
assert "v1.0.0" in versions
async def test_list_versions_includes_archived_versions(self) -> None:
cursor = FakeAsyncCursor()
cursor.set_fetchone(None)
cursor.set_fetchall([
("my-repo", "v1.0.0", "2025-01-01", "[]", "2025-06-01"),
("my-repo", "v1.1.0", "2025-02-01", "[]", "2025-07-01"),
])
pool = FakeAsyncPool(cursor)
store = PostgresStagingStore(pool=pool)
versions = await store.list_versions("my-repo")
assert "v1.0.0" in versions
assert "v1.1.0" in versions
async def test_list_versions_combines_staging_and_archived(self) -> None:
cursor = FakeAsyncCursor()
cursor.set_fetchone(("my-repo", "v2.0.0", "2025-03-01", "[]"))
cursor.set_fetchall([
("my-repo", "v1.0.0", "2025-01-01", "[]", "2025-02-01"),
])
pool = FakeAsyncPool(cursor)
store = PostgresStagingStore(pool=pool)
versions = await store.list_versions("my-repo")
assert "v2.0.0" in versions
assert "v1.0.0" in versions
async def test_list_versions_no_duplicates(self) -> None:
cursor = FakeAsyncCursor()
cursor.set_fetchone(("my-repo", "v1.0.0", "2025-01-01", "[]"))
cursor.set_fetchall([
("my-repo", "v1.0.0", "2025-01-01", "[]", "2025-02-01"),
])
pool = FakeAsyncPool(cursor)
store = PostgresStagingStore(pool=pool)
versions = await store.list_versions("my-repo")
assert len(versions) == len(set(versions))
async def test_list_versions_executes_queries_for_correct_repo(self) -> None:
cursor = FakeAsyncCursor()
cursor.set_fetchone(None)
cursor.set_fetchall([])
pool = FakeAsyncPool(cursor)
store = PostgresStagingStore(pool=pool)
await store.list_versions("target-repo")
all_params = [params for _, params in cursor.executed]
all_values = [v for params in all_params for v in params]
assert "target-repo" in all_values

View File

@@ -0,0 +1,956 @@
"""Tests for graph/pr_completed.py node functions. Written FIRST (TDD RED phase).
Each node is an async function (state, config) -> dict.
Tests call nodes directly with a state dict and config dict — no graph compilation.
"""
from datetime import date, datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from release_agent.graph.dependencies import JsonFileStagingStore, ToolClients
from release_agent.graph.pr_completed import (
_post_review_to_pr,
add_jira_pr_link,
auto_create_ticket,
calculate_version,
evaluate_review,
fetch_pr_details,
interrupt_confirm_merge,
merge_pr_node,
move_jira_code_review,
move_jira_ready_for_stage,
notify_request_changes,
parse_webhook,
run_code_review,
update_staging,
build_pr_completed_graph,
)
from release_agent.models.review import ReviewIssue
from release_agent.models.jira import JiraIssue
from release_agent.models.pr import PRInfo
from release_agent.models.review import ReviewResult
from tests.graph.conftest import build_config, build_mock_clients
# ---------------------------------------------------------------------------
# Webhook payload fixtures
# ---------------------------------------------------------------------------
def _make_webhook_payload(
*,
repo_name: str = "my-repo",
pr_id: int = 42,
source_ref: str = "refs/heads/feature/ALLPOST-100_fix-bug",
target_ref: str = "refs/heads/main",
status: str = "completed",
title: str = "Fix: bug",
closed_date: str | None = "2025-01-15T10:00:00Z",
) -> dict:
# Uses snake_case keys to match WebhookPayload Pydantic model field names
return {
"subscription_id": "sub-1",
"event_type": "git.pullrequest.merged",
"resource": {
"repository": {
"id": "repo-id-1",
"name": repo_name,
"web_url": "https://dev.azure.com/org/proj/_git/my-repo",
},
"pull_request_id": pr_id,
"title": title,
"source_ref_name": source_ref,
"target_ref_name": target_ref,
"status": status,
"closed_date": closed_date,
},
}
def _make_pr_info(
*,
pr_id: str = "42",
repo_name: str = "my-repo",
branch: str = "refs/heads/feature/ALLPOST-100-fix-bug",
status: str = "completed",
) -> PRInfo:
return PRInfo(
pr_id=pr_id,
pr_url="https://dev.azure.com/org/proj/_git/my-repo/pullrequest/42",
repo_name=repo_name,
branch=branch,
pr_title="Fix: bug",
pr_status=status,
)
def _make_approve_review() -> dict:
return {
"verdict": "approve",
"summary": "Looks good",
"issues": [],
"has_blockers": False,
}
def _make_request_changes_review() -> dict:
return {
"verdict": "request_changes",
"summary": "Needs work",
"issues": [{"severity": "blocker", "description": "Missing tests"}],
"has_blockers": True,
}
# ---------------------------------------------------------------------------
# parse_webhook
# ---------------------------------------------------------------------------
class TestParseWebhook:
async def test_extracts_pr_info_from_payload(self) -> None:
state = {"webhook_payload": _make_webhook_payload()}
config = build_config()
result = await parse_webhook(state, config)
assert "pr_info" in result
pr = result["pr_info"]
assert pr["pr_id"] == "42"
assert pr["repo_name"] == "my-repo"
async def test_extracts_ticket_from_branch(self) -> None:
state = {"webhook_payload": _make_webhook_payload(
source_ref="refs/heads/feature/ALLPOST-100_fix-bug"
)}
config = build_config()
result = await parse_webhook(state, config)
assert result["ticket_id"] == "ALLPOST-100"
assert result["has_ticket"] is True
async def test_no_ticket_when_branch_has_none(self) -> None:
state = {"webhook_payload": _make_webhook_payload(
source_ref="refs/heads/bugfix/generic_fix"
)}
config = build_config()
result = await parse_webhook(state, config)
assert result["has_ticket"] is False
assert result["ticket_id"] is None
async def test_sets_repo_name(self) -> None:
state = {"webhook_payload": _make_webhook_payload(repo_name="backend-api")}
config = build_config()
result = await parse_webhook(state, config)
assert result["repo_name"] == "backend-api"
async def test_sets_pr_id_as_string(self) -> None:
state = {"webhook_payload": _make_webhook_payload(pr_id=99)}
config = build_config()
result = await parse_webhook(state, config)
assert result["pr_info"]["pr_id"] == "99"
async def test_invalid_payload_adds_error(self) -> None:
state = {"webhook_payload": {"bad": "data"}}
config = build_config()
result = await parse_webhook(state, config)
assert "errors" in result
assert len(result["errors"]) > 0
# ---------------------------------------------------------------------------
# fetch_pr_details
# ---------------------------------------------------------------------------
class TestFetchPrDetails:
async def test_fetches_pr_and_sets_pr_already_merged_false(self) -> None:
clients = build_mock_clients()
pr = _make_pr_info(status="active")
clients.azdo.get_pr = AsyncMock(return_value=pr)
clients.azdo.get_pr_diff = AsyncMock(return_value="edit: main.py")
config = build_config(clients)
state = {"pr_id": "42", "pr_info": {"pr_id": "42", "pr_status": "active"}}
result = await fetch_pr_details(state, config)
assert result["pr_already_merged"] is False
assert result["pr_diff"] == "edit: main.py"
async def test_sets_pr_already_merged_true_when_completed(self) -> None:
clients = build_mock_clients()
pr = _make_pr_info(status="completed")
clients.azdo.get_pr = AsyncMock(return_value=pr)
clients.azdo.get_pr_diff = AsyncMock(return_value="")
config = build_config(clients)
state = {"pr_id": "42", "pr_info": {"pr_id": "42", "pr_status": "completed"}}
result = await fetch_pr_details(state, config)
assert result["pr_already_merged"] is True
async def test_stores_last_merge_source_commit(self) -> None:
clients = build_mock_clients()
pr = _make_pr_info(status="active")
clients.azdo.get_pr = AsyncMock(return_value=pr)
clients.azdo.get_pr_diff = AsyncMock(return_value="edit: main.py")
config = build_config(clients)
state = {"pr_id": "42", "pr_info": {"pr_id": "42"}}
result = await fetch_pr_details(state, config)
# last_merge_source_commit may be None if pr doesn't have it, but key must be present
assert "last_merge_source_commit" in result
async def test_adds_error_on_service_failure(self) -> None:
from release_agent.exceptions import ServiceError
clients = build_mock_clients()
clients.azdo.get_pr = AsyncMock(side_effect=ServiceError(
service="azdo", status_code=500, detail="Server error"
))
config = build_config(clients)
state = {"pr_id": "42"}
result = await fetch_pr_details(state, config)
assert "errors" in result
assert len(result["errors"]) > 0
# ---------------------------------------------------------------------------
# move_jira_code_review
# ---------------------------------------------------------------------------
class TestMoveJiraCodeReview:
async def test_transitions_ticket_when_has_ticket(self) -> None:
clients = build_mock_clients()
clients.jira.transition_issue = AsyncMock(return_value=True)
config = build_config(clients)
state = {"ticket_id": "ALLPOST-100", "has_ticket": True}
result = await move_jira_code_review(state, config)
clients.jira.transition_issue.assert_called_once_with("ALLPOST-100", "code review")
assert result == {} or "messages" in result
async def test_skips_when_no_ticket(self) -> None:
clients = build_mock_clients()
clients.jira.transition_issue = AsyncMock(return_value=True)
config = build_config(clients)
state = {"has_ticket": False, "ticket_id": None}
result = await move_jira_code_review(state, config)
clients.jira.transition_issue.assert_not_called()
async def test_appends_error_on_jira_failure(self) -> None:
from release_agent.exceptions import ServiceError
clients = build_mock_clients()
clients.jira.transition_issue = AsyncMock(side_effect=ServiceError(
service="jira", status_code=500, detail="Jira down"
))
config = build_config(clients)
state = {"ticket_id": "ALLPOST-100", "has_ticket": True}
result = await move_jira_code_review(state, config)
assert "errors" in result
assert len(result["errors"]) > 0
# ---------------------------------------------------------------------------
# run_code_review
# ---------------------------------------------------------------------------
class TestRunCodeReview:
async def test_calls_reviewer_with_diff(self) -> None:
clients = build_mock_clients()
review = ReviewResult(verdict="approve", summary="LGTM", issues=())
clients.reviewer.review_pr = AsyncMock(return_value=review)
config = build_config(clients)
state = {
"pr_diff": "edit: main.py",
"pr_info": {"pr_title": "Fix: bug", "repo_name": "my-repo"},
}
result = await run_code_review(state, config)
clients.reviewer.review_pr.assert_called_once()
assert "review_result" in result
async def test_stores_review_result_as_dict(self) -> None:
clients = build_mock_clients()
review = ReviewResult(verdict="approve", summary="Clean code", issues=())
clients.reviewer.review_pr = AsyncMock(return_value=review)
config = build_config(clients)
state = {
"pr_diff": "edit: main.py",
"pr_info": {"pr_title": "Fix", "repo_name": "repo"},
}
result = await run_code_review(state, config)
assert result["review_result"]["verdict"] == "approve"
async def test_adds_error_on_reviewer_failure(self) -> None:
clients = build_mock_clients()
clients.reviewer.review_pr = AsyncMock(side_effect=Exception("API error"))
config = build_config(clients)
state = {
"pr_diff": "edit: main.py",
"pr_info": {"pr_title": "Fix", "repo_name": "repo"},
}
result = await run_code_review(state, config)
assert "errors" in result
# ---------------------------------------------------------------------------
# _post_review_to_pr
# ---------------------------------------------------------------------------
class TestPostReviewToPr:
async def test_posts_summary_comment(self) -> None:
clients = build_mock_clients()
clients.azdo.add_pr_comment = AsyncMock()
clients.azdo.add_pr_inline_comment = AsyncMock()
review = ReviewResult(verdict="approve", summary="LGTM", issues=())
await _post_review_to_pr(clients, "my-repo", 42, review)
clients.azdo.add_pr_comment.assert_called_once()
call_kwargs = clients.azdo.add_pr_comment.call_args
assert "APPROVE" in call_kwargs.kwargs["content"]
async def test_posts_inline_comment_for_issue_with_file_and_line(self) -> None:
clients = build_mock_clients()
clients.azdo.add_pr_comment = AsyncMock()
clients.azdo.add_pr_inline_comment = AsyncMock()
issue = ReviewIssue(
severity="error", description="Null check missing",
file_path="src/Foo.cs", line_start=42, suggestion="Add null guard",
)
review = ReviewResult(verdict="request_changes", summary="Issues", issues=(issue,))
await _post_review_to_pr(clients, "my-repo", 42, review)
clients.azdo.add_pr_inline_comment.assert_called_once()
call_kwargs = clients.azdo.add_pr_inline_comment.call_args.kwargs
assert call_kwargs["file_path"] == "src/Foo.cs"
assert call_kwargs["line_start"] == 42
assert "Null check missing" in call_kwargs["content"]
assert "Add null guard" in call_kwargs["content"]
async def test_skips_inline_for_issue_without_line(self) -> None:
clients = build_mock_clients()
clients.azdo.add_pr_comment = AsyncMock()
clients.azdo.add_pr_inline_comment = AsyncMock()
issue = ReviewIssue(severity="warning", description="Style issue", file_path="src/Foo.cs")
review = ReviewResult(verdict="approve", summary="OK", issues=(issue,))
await _post_review_to_pr(clients, "my-repo", 42, review)
clients.azdo.add_pr_inline_comment.assert_not_called()
async def test_skips_inline_for_issue_without_file(self) -> None:
clients = build_mock_clients()
clients.azdo.add_pr_comment = AsyncMock()
clients.azdo.add_pr_inline_comment = AsyncMock()
issue = ReviewIssue(severity="info", description="General note", line_start=10)
review = ReviewResult(verdict="approve", summary="OK", issues=(issue,))
await _post_review_to_pr(clients, "my-repo", 42, review)
clients.azdo.add_pr_inline_comment.assert_not_called()
async def test_inline_failure_does_not_prevent_summary(self) -> None:
clients = build_mock_clients()
clients.azdo.add_pr_comment = AsyncMock()
clients.azdo.add_pr_inline_comment = AsyncMock(side_effect=Exception("API error"))
issue = ReviewIssue(
severity="blocker", description="Critical", file_path="a.cs", line_start=1
)
review = ReviewResult(verdict="request_changes", summary="Bad", issues=(issue,))
await _post_review_to_pr(clients, "my-repo", 42, review)
# Summary should still be posted even though inline failed
clients.azdo.add_pr_comment.assert_called_once()
async def test_summary_failure_does_not_raise(self) -> None:
clients = build_mock_clients()
clients.azdo.add_pr_comment = AsyncMock(side_effect=Exception("Network error"))
clients.azdo.add_pr_inline_comment = AsyncMock()
review = ReviewResult(verdict="approve", summary="LGTM", issues=())
# Should not raise
await _post_review_to_pr(clients, "my-repo", 42, review)
async def test_summary_contains_issue_count(self) -> None:
clients = build_mock_clients()
clients.azdo.add_pr_comment = AsyncMock()
clients.azdo.add_pr_inline_comment = AsyncMock()
issues = (
ReviewIssue(severity="warning", description="Issue 1"),
ReviewIssue(severity="error", description="Issue 2"),
)
review = ReviewResult(verdict="request_changes", summary="Problems", issues=issues)
await _post_review_to_pr(clients, "my-repo", 42, review)
content = clients.azdo.add_pr_comment.call_args.kwargs["content"]
assert "2 issue(s)" in content
async def test_run_code_review_calls_post_review(self) -> None:
"""Integration: run_code_review posts comments when pr_id and repo_name present."""
clients = build_mock_clients()
review = ReviewResult(verdict="approve", summary="LGTM", issues=())
clients.reviewer.review_pr = AsyncMock(return_value=review)
clients.azdo.add_pr_comment = AsyncMock()
clients.azdo.add_pr_inline_comment = AsyncMock()
config = build_config(clients)
state = {
"pr_diff": "edit: main.py",
"pr_info": {"pr_id": "42", "pr_title": "Fix", "repo_name": "my-repo"},
}
await run_code_review(state, config)
clients.azdo.add_pr_comment.assert_called_once()
# ---------------------------------------------------------------------------
# evaluate_review
# ---------------------------------------------------------------------------
class TestEvaluateReview:
async def test_sets_review_approved_true_for_approve_verdict(self) -> None:
config = build_config()
state = {"review_result": _make_approve_review()}
result = await evaluate_review(state, config)
assert result["review_approved"] is True
async def test_sets_review_approved_false_for_request_changes(self) -> None:
config = build_config()
state = {"review_result": _make_request_changes_review()}
result = await evaluate_review(state, config)
assert result["review_approved"] is False
async def test_sets_false_when_review_result_missing(self) -> None:
config = build_config()
state = {}
result = await evaluate_review(state, config)
assert result["review_approved"] is False
async def test_sets_false_when_has_blockers(self) -> None:
config = build_config()
state = {
"review_result": {
"verdict": "approve",
"summary": "Approve with blocker?",
"issues": [{"severity": "blocker", "description": "Problem"}],
"has_blockers": True,
}
}
result = await evaluate_review(state, config)
assert result["review_approved"] is False
# ---------------------------------------------------------------------------
# interrupt_confirm_merge
# ---------------------------------------------------------------------------
class TestInterruptConfirmMerge:
async def test_calls_interrupt_with_summary_string(self) -> None:
config = build_config()
state = {
"pr_info": {"pr_id": "42", "pr_title": "Fix: bug", "repo_name": "my-repo"},
"review_result": {"summary": "LGTM"},
}
with patch("release_agent.graph.pr_completed.interrupt") as mock_interrupt:
mock_interrupt.return_value = "confirm"
await interrupt_confirm_merge(state, config)
mock_interrupt.assert_called_once()
call_arg = mock_interrupt.call_args[0][0]
assert isinstance(call_arg, str)
assert len(call_arg) > 0
async def test_interrupt_value_contains_pr_info(self) -> None:
config = build_config()
state = {
"pr_info": {"pr_id": "42", "pr_title": "Fix: auth bug", "repo_name": "backend"},
"review_result": {"summary": "All good"},
}
with patch("release_agent.graph.pr_completed.interrupt") as mock_interrupt:
mock_interrupt.return_value = "confirm"
await interrupt_confirm_merge(state, config)
call_arg = mock_interrupt.call_args[0][0]
assert "42" in call_arg or "Fix: auth bug" in call_arg or "backend" in call_arg
# ---------------------------------------------------------------------------
# merge_pr_node
# ---------------------------------------------------------------------------
class TestMergePrNode:
async def test_calls_azdo_merge_pr(self) -> None:
clients = build_mock_clients()
clients.azdo.merge_pr = AsyncMock(return_value=True)
config = build_config(clients)
state = {
"pr_info": {"pr_id": "42"},
"last_merge_source_commit": "abc123",
}
result = await merge_pr_node(state, config)
clients.azdo.merge_pr.assert_called_once_with(
pr_id=42, last_merge_source_commit="abc123"
)
async def test_returns_message_on_success(self) -> None:
clients = build_mock_clients()
clients.azdo.merge_pr = AsyncMock(return_value=True)
config = build_config(clients)
state = {
"pr_info": {"pr_id": "42"},
"last_merge_source_commit": "abc123",
}
result = await merge_pr_node(state, config)
assert "messages" in result
async def test_re_raises_on_service_error(self) -> None:
from release_agent.exceptions import ServiceError
clients = build_mock_clients()
clients.azdo.merge_pr = AsyncMock(side_effect=ServiceError(
service="azdo", status_code=409, detail="Conflict"
))
config = build_config(clients)
state = {
"pr_info": {"pr_id": "42"},
"last_merge_source_commit": "abc123",
}
with pytest.raises(ServiceError):
await merge_pr_node(state, config)
# ---------------------------------------------------------------------------
# move_jira_ready_for_stage
# ---------------------------------------------------------------------------
class TestMoveJiraReadyForStage:
async def test_transitions_ticket(self) -> None:
clients = build_mock_clients()
clients.jira.transition_issue = AsyncMock(return_value=True)
config = build_config(clients)
state = {"ticket_id": "ALLPOST-100", "has_ticket": True}
result = await move_jira_ready_for_stage(state, config)
clients.jira.transition_issue.assert_called_once_with(
"ALLPOST-100", "Ready for stage (2)"
)
async def test_skips_when_no_ticket(self) -> None:
clients = build_mock_clients()
clients.jira.transition_issue = AsyncMock()
config = build_config(clients)
state = {"has_ticket": False}
await move_jira_ready_for_stage(state, config)
clients.jira.transition_issue.assert_not_called()
async def test_appends_error_on_failure(self) -> None:
from release_agent.exceptions import ServiceError
clients = build_mock_clients()
clients.jira.transition_issue = AsyncMock(side_effect=ServiceError(
service="jira", status_code=500, detail="Error"
))
config = build_config(clients)
state = {"ticket_id": "ALLPOST-100", "has_ticket": True}
result = await move_jira_ready_for_stage(state, config)
assert "errors" in result
# ---------------------------------------------------------------------------
# add_jira_pr_link
# ---------------------------------------------------------------------------
class TestAddJiraPrLink:
async def test_calls_add_remote_link(self) -> None:
clients = build_mock_clients()
clients.jira.add_remote_link = AsyncMock(return_value=True)
config = build_config(clients)
state = {
"ticket_id": "ALLPOST-100",
"has_ticket": True,
"pr_info": {
"pr_id": "42",
"pr_url": "https://dev.azure.com/org/proj/_git/repo/pullrequest/42",
"pr_title": "Fix: bug",
},
}
result = await add_jira_pr_link(state, config)
clients.jira.add_remote_link.assert_called_once()
async def test_skips_when_no_ticket(self) -> None:
clients = build_mock_clients()
clients.jira.add_remote_link = AsyncMock()
config = build_config(clients)
state = {"has_ticket": False}
await add_jira_pr_link(state, config)
clients.jira.add_remote_link.assert_not_called()
async def test_appends_error_on_failure(self) -> None:
from release_agent.exceptions import ServiceError
clients = build_mock_clients()
clients.jira.add_remote_link = AsyncMock(side_effect=ServiceError(
service="jira", status_code=500, detail="Error"
))
config = build_config(clients)
state = {
"ticket_id": "ALLPOST-100",
"has_ticket": True,
"pr_info": {
"pr_id": "42",
"pr_url": "https://dev.azure.com/org/proj/_git/repo/pullrequest/42",
"pr_title": "Fix",
},
}
result = await add_jira_pr_link(state, config)
assert "errors" in result
# ---------------------------------------------------------------------------
# calculate_version
# ---------------------------------------------------------------------------
class TestCalculateVersion:
async def test_returns_v1_0_0_for_empty_store(self, tmp_path) -> None:
staging_store = JsonFileStagingStore(directory=tmp_path)
clients = build_mock_clients()
config = build_config(clients, staging_store=staging_store)
state = {"repo_name": "my-repo"}
result = await calculate_version(state, config)
assert result["version"] == "v1.0.0"
async def test_increments_patch_version(self, tmp_path) -> None:
from release_agent.models.release import StagingRelease
staging_store = JsonFileStagingStore(directory=tmp_path)
# Pre-populate with an existing version
staging = StagingRelease(
version="v1.0.5",
repo="my-repo",
started_at=date(2025, 1, 1),
tickets=[],
)
await staging_store.archive(staging, date(2025, 1, 10))
clients = build_mock_clients()
config = build_config(clients, staging_store=staging_store)
state = {"repo_name": "my-repo"}
result = await calculate_version(state, config)
assert result["version"] == "v1.0.6"
async def test_sets_version_in_state(self, tmp_path) -> None:
staging_store = JsonFileStagingStore(directory=tmp_path)
clients = build_mock_clients()
config = build_config(clients, staging_store=staging_store)
state = {"repo_name": "new-repo"}
result = await calculate_version(state, config)
assert "version" in result
assert result["version"].startswith("v")
# ---------------------------------------------------------------------------
# update_staging
# ---------------------------------------------------------------------------
class TestUpdateStaging:
async def test_creates_new_staging_when_none_exists(self, tmp_path) -> None:
staging_store = JsonFileStagingStore(directory=tmp_path)
clients = build_mock_clients()
# Jira get_issue returns a summary
from release_agent.models.jira import JiraIssue
clients.jira.get_issue = AsyncMock(return_value=JiraIssue(
key="ALLPOST-100", summary="Fix auth bug", status="Ready for stage (2)"
))
config = build_config(clients, staging_store=staging_store)
state = {
"repo_name": "my-repo",
"version": "v1.0.0",
"ticket_id": "ALLPOST-100",
"has_ticket": True,
"pr_info": {
"pr_id": "42",
"pr_url": "https://dev.azure.com/org/proj/_git/repo/pullrequest/42",
"pr_title": "Fix: auth bug",
"branch": "feature/ALLPOST-100-fix",
},
}
result = await update_staging(state, config)
loaded = await staging_store.load("my-repo")
assert loaded is not None
assert loaded.has_ticket("ALLPOST-100")
async def test_appends_ticket_to_existing_staging(self, tmp_path) -> None:
from datetime import date
from release_agent.models.release import StagingRelease
from release_agent.models.jira import JiraIssue
staging_store = JsonFileStagingStore(directory=tmp_path)
existing = StagingRelease(
version="v1.0.0", repo="my-repo",
started_at=date(2025, 1, 1), tickets=[]
)
await staging_store.save(existing)
clients = build_mock_clients()
clients.jira.get_issue = AsyncMock(return_value=JiraIssue(
key="BILL-99", summary="New feature", status="Ready"
))
config = build_config(clients, staging_store=staging_store)
state = {
"repo_name": "my-repo",
"version": "v1.0.0",
"ticket_id": "BILL-99",
"has_ticket": True,
"pr_info": {
"pr_id": "55",
"pr_url": "https://dev.azure.com/org/proj/_git/repo/pullrequest/55",
"pr_title": "Feat: new feature",
"branch": "feature/BILL-99-feat",
},
}
await update_staging(state, config)
loaded = await staging_store.load("my-repo")
assert loaded is not None
assert loaded.has_ticket("BILL-99")
async def test_skips_ticket_add_when_no_ticket(self, tmp_path) -> None:
staging_store = JsonFileStagingStore(directory=tmp_path)
clients = build_mock_clients()
config = build_config(clients, staging_store=staging_store)
state = {
"repo_name": "my-repo",
"version": "v1.0.0",
"has_ticket": False,
}
await update_staging(state, config)
# No staging file should be created for ticket-less PR if no existing staging
# (or staging exists without new ticket added)
clients.jira.get_issue.assert_not_called()
async def test_returns_empty_dict_when_no_staging_store(self) -> None:
from release_agent.models.jira import JiraIssue
clients = build_mock_clients()
clients.jira.get_issue = AsyncMock(return_value=JiraIssue(
key="ALLPOST-1", summary="Fix", status="Ready"
))
config = build_config(clients, staging_store=None)
state = {
"repo_name": "my-repo",
"version": "v1.0.0",
"ticket_id": "ALLPOST-1",
"has_ticket": True,
"pr_info": {
"pr_id": "1",
"pr_url": "https://dev.azure.com/org/proj/_git/repo/pullrequest/1",
"pr_title": "Fix",
"branch": "feature/ALLPOST-1",
},
}
result = await update_staging(state, config)
assert result == {}
async def test_uses_ticket_id_as_summary_on_jira_failure(self, tmp_path) -> None:
staging_store = JsonFileStagingStore(directory=tmp_path)
clients = build_mock_clients()
clients.jira.get_issue = AsyncMock(side_effect=Exception("Jira unavailable"))
config = build_config(clients, staging_store=staging_store)
state = {
"repo_name": "my-repo",
"version": "v1.0.0",
"ticket_id": "ALLPOST-99",
"has_ticket": True,
"pr_info": {
"pr_id": "5",
"pr_url": "https://dev.azure.com/org/proj/_git/repo/pullrequest/5",
"pr_title": "Fix something",
"branch": "feature/ALLPOST-99_fix",
},
}
result = await update_staging(state, config)
loaded = await staging_store.load("my-repo")
assert loaded is not None
assert loaded.tickets[0].id == "ALLPOST-99"
assert loaded.tickets[0].summary == "ALLPOST-99"
async def test_sets_staging_dict_in_result(self, tmp_path) -> None:
from release_agent.models.jira import JiraIssue
staging_store = JsonFileStagingStore(directory=tmp_path)
clients = build_mock_clients()
clients.jira.get_issue = AsyncMock(return_value=JiraIssue(
key="ALLPOST-1", summary="S", status="Ready"
))
config = build_config(clients, staging_store=staging_store)
state = {
"repo_name": "my-repo",
"version": "v1.0.0",
"ticket_id": "ALLPOST-1",
"has_ticket": True,
"pr_info": {
"pr_id": "1",
"pr_url": "https://dev.azure.com/org/proj/_git/repo/pullrequest/1",
"pr_title": "Fix",
"branch": "feature/ALLPOST-1",
},
}
result = await update_staging(state, config)
assert "staging" in result
assert isinstance(result["staging"], dict)
# ---------------------------------------------------------------------------
# notify_request_changes
# ---------------------------------------------------------------------------
class TestNotifyRequestChanges:
async def test_calls_slack_send_approval_request(self) -> None:
clients = build_mock_clients()
clients.slack.send_approval_request = AsyncMock(return_value=True)
config = build_config(clients)
state = {
"pr_info": {"pr_id": "42", "pr_title": "Fix: bug", "repo_name": "my-repo"},
"review_result": {
"verdict": "request_changes",
"summary": "Too many issues",
"issues": [{"severity": "blocker", "description": "No tests"}],
},
}
result = await notify_request_changes(state, config)
clients.slack.send_approval_request.assert_called_once()
async def test_appends_error_on_slack_failure(self) -> None:
from release_agent.exceptions import ServiceError
clients = build_mock_clients()
clients.slack.send_approval_request = AsyncMock(side_effect=ServiceError(
service="slack", status_code=500, detail="Webhook error"
))
config = build_config(clients)
state = {
"pr_info": {"pr_id": "42", "pr_title": "Fix", "repo_name": "repo"},
"review_result": {"summary": "Issues found", "issues": []},
}
result = await notify_request_changes(state, config)
assert "errors" in result
# ---------------------------------------------------------------------------
# build_pr_completed_graph
# ---------------------------------------------------------------------------
# ---------------------------------------------------------------------------
# auto_create_ticket node
# ---------------------------------------------------------------------------
class TestAutoCreateTicket:
"""Tests for the auto_create_ticket node."""
def _make_config_with_jira_project(
self, jira_project: str = "ALLPOST"
):
clients = build_mock_clients()
clients.jira.create_issue = AsyncMock(return_value="ALLPOST-99")
clients.reviewer.generate_ticket_content = AsyncMock(
return_value=("My summary", "My description")
)
config = build_config(clients)
config["configurable"]["default_jira_project"] = jira_project
return config, clients
async def test_creates_jira_issue_and_returns_ticket_id(self) -> None:
config, clients = self._make_config_with_jira_project()
state = {
"pr_diff": "edit: main.py",
"pr_info": {"pr_title": "Fix bug", "repo_name": "my-repo"},
}
result = await auto_create_ticket(state, config)
assert result.get("ticket_id") == "ALLPOST-99"
async def test_sets_has_ticket_true(self) -> None:
config, clients = self._make_config_with_jira_project()
state = {
"pr_diff": "edit: main.py",
"pr_info": {"pr_title": "Fix bug", "repo_name": "my-repo"},
}
result = await auto_create_ticket(state, config)
assert result.get("has_ticket") is True
async def test_calls_generate_ticket_content(self) -> None:
config, clients = self._make_config_with_jira_project()
state = {
"pr_diff": "edit: main.py",
"pr_info": {"pr_title": "Fix login", "repo_name": "auth-service"},
}
await auto_create_ticket(state, config)
clients.reviewer.generate_ticket_content.assert_awaited_once()
async def test_calls_create_issue_with_project_key(self) -> None:
config, clients = self._make_config_with_jira_project(jira_project="MYPROJ")
clients.jira.create_issue = AsyncMock(return_value="MYPROJ-5")
config["configurable"]["default_jira_project"] = "MYPROJ"
state = {
"pr_diff": "d",
"pr_info": {"pr_title": "t", "repo_name": "r"},
}
await auto_create_ticket(state, config)
call_kwargs = clients.jira.create_issue.call_args.kwargs
assert call_kwargs["project"] == "MYPROJ"
async def test_appends_message_on_success(self) -> None:
config, _ = self._make_config_with_jira_project()
state = {
"pr_diff": "d",
"pr_info": {"pr_title": "t", "repo_name": "r"},
}
result = await auto_create_ticket(state, config)
assert "messages" in result
assert len(result["messages"]) > 0
async def test_appends_error_on_create_issue_failure(self) -> None:
config, clients = self._make_config_with_jira_project()
clients.jira.create_issue = AsyncMock(side_effect=Exception("Jira down"))
state = {
"pr_diff": "d",
"pr_info": {"pr_title": "t", "repo_name": "r"},
}
result = await auto_create_ticket(state, config)
assert "errors" in result
assert len(result["errors"]) > 0
async def test_appends_error_on_generate_content_failure(self) -> None:
config, clients = self._make_config_with_jira_project()
clients.reviewer.generate_ticket_content = AsyncMock(side_effect=RuntimeError("CLI fail"))
state = {
"pr_diff": "d",
"pr_info": {"pr_title": "t", "repo_name": "r"},
}
result = await auto_create_ticket(state, config)
assert "errors" in result
async def test_uses_default_project_from_config(self) -> None:
config, clients = self._make_config_with_jira_project(jira_project="TEAM")
clients.jira.create_issue = AsyncMock(return_value="TEAM-1")
state = {
"pr_diff": "d",
"pr_info": {"pr_title": "t", "repo_name": "r"},
}
result = await auto_create_ticket(state, config)
assert result["ticket_id"] == "TEAM-1"
# ---------------------------------------------------------------------------
# build_pr_completed_graph
# ---------------------------------------------------------------------------
class TestBuildPrCompletedGraph:
def test_returns_compiled_graph(self) -> None:
graph = build_pr_completed_graph()
assert graph is not None
def test_graph_has_nodes(self) -> None:
graph = build_pr_completed_graph()
# The compiled graph object should be truthy
assert graph is not None
def test_graph_includes_trigger_ci_build_node(self) -> None:
graph = build_pr_completed_graph()
# Graph nodes should include CI pipeline nodes
graph_nodes = graph.get_graph().nodes
assert "trigger_ci_build" in graph_nodes
def test_graph_includes_poll_ci_build_node(self) -> None:
graph = build_pr_completed_graph()
graph_nodes = graph.get_graph().nodes
assert "poll_ci_build" in graph_nodes
def test_graph_includes_notify_ci_result_node(self) -> None:
graph = build_pr_completed_graph()
graph_nodes = graph.get_graph().nodes
assert "notify_ci_result" in graph_nodes
def test_graph_includes_auto_create_ticket_node(self) -> None:
graph = build_pr_completed_graph()
graph_nodes = graph.get_graph().nodes
assert "auto_create_ticket" in graph_nodes

866
tests/graph/test_release.py Normal file
View File

@@ -0,0 +1,866 @@
"""Tests for graph/release.py node functions. Written FIRST (TDD RED phase).
Each node is an async function (state, config) -> dict.
Tests call nodes directly — no graph compilation required.
"""
from datetime import date
from unittest.mock import AsyncMock, patch
import pytest
from release_agent.graph.dependencies import JsonFileStagingStore
from release_agent.graph.release import (
approve_stage,
archive_release,
check_release_approvals,
create_release_pr,
interrupt_confirm_approve,
interrupt_confirm_merge_release,
interrupt_confirm_release,
interrupt_confirm_trigger,
list_pipelines,
load_staging,
merge_release_pr,
move_tickets_to_done,
send_slack_notification,
trigger_pipelines,
build_release_graph,
)
from release_agent.models.pipeline import PipelineInfo, ReleasePipelineStage
from release_agent.models.release import StagingRelease
from release_agent.models.ticket import TicketEntry
from tests.graph.conftest import build_config, build_mock_clients
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_ticket(ticket_id: str = "ALLPOST-1") -> TicketEntry:
return TicketEntry(
id=ticket_id,
summary="Fix something",
pr_id="42",
pr_url="https://dev.azure.com/org/proj/_git/repo/pullrequest/42",
pr_title="Fix: something",
branch=f"feature/{ticket_id}-fix",
merged_at=date(2025, 1, 15),
)
def _make_staging(
*,
repo: str = "my-repo",
version: str = "v1.0.0",
tickets: list | None = None,
) -> StagingRelease:
t = tickets if tickets is not None else [_make_ticket()]
return StagingRelease(
version=version,
repo=repo,
started_at=date(2025, 1, 1),
tickets=t,
)
def _staging_dict(staging: StagingRelease) -> dict:
return staging.model_dump(mode="json")
# ---------------------------------------------------------------------------
# load_staging
# ---------------------------------------------------------------------------
class TestLoadStaging:
async def test_loads_staging_from_store(self, tmp_path) -> None:
staging_store = JsonFileStagingStore(directory=tmp_path)
staging = _make_staging()
await staging_store.save(staging)
clients = build_mock_clients()
config = build_config(clients, staging_store=staging_store)
state = {"repo_name": "my-repo"}
result = await load_staging(state, config)
assert "staging" in result
assert result["staging"]["version"] == "v1.0.0"
async def test_returns_none_when_no_staging(self, tmp_path) -> None:
staging_store = JsonFileStagingStore(directory=tmp_path)
clients = build_mock_clients()
config = build_config(clients, staging_store=staging_store)
state = {"repo_name": "nonexistent"}
result = await load_staging(state, config)
assert result.get("staging") is None
async def test_staging_includes_tickets(self, tmp_path) -> None:
staging_store = JsonFileStagingStore(directory=tmp_path)
staging = _make_staging(tickets=[_make_ticket("BILL-10"), _make_ticket("BILL-11")])
await staging_store.save(staging)
clients = build_mock_clients()
config = build_config(clients, staging_store=staging_store)
state = {"repo_name": "my-repo"}
result = await load_staging(state, config)
assert len(result["staging"]["tickets"]) == 2
# ---------------------------------------------------------------------------
# interrupt_confirm_release
# ---------------------------------------------------------------------------
class TestInterruptConfirmRelease:
async def test_calls_interrupt_with_staging_summary(self) -> None:
config = build_config()
staging = _make_staging()
state = {
"repo_name": "my-repo",
"staging": _staging_dict(staging),
}
with patch("release_agent.graph.release.interrupt") as mock_interrupt:
mock_interrupt.return_value = "confirm"
await interrupt_confirm_release(state, config)
mock_interrupt.assert_called_once()
call_arg = mock_interrupt.call_args[0][0]
assert isinstance(call_arg, str)
async def test_interrupt_contains_version_and_repo(self) -> None:
config = build_config()
staging = _make_staging(version="v2.5.0", repo="backend")
state = {
"repo_name": "backend",
"staging": _staging_dict(staging),
}
with patch("release_agent.graph.release.interrupt") as mock_interrupt:
mock_interrupt.return_value = "confirm"
await interrupt_confirm_release(state, config)
call_arg = mock_interrupt.call_args[0][0]
assert "v2.5.0" in call_arg or "backend" in call_arg
# ---------------------------------------------------------------------------
# create_release_pr
# ---------------------------------------------------------------------------
class TestCreateReleasePr:
async def test_calls_azdo_create_pr(self) -> None:
clients = build_mock_clients()
clients.azdo.create_pr = AsyncMock(return_value={
"pullRequestId": 99,
"lastMergeSourceCommit": {"commitId": "deadbeef"},
})
config = build_config(clients)
staging = _make_staging(version="v1.2.0")
state = {
"repo_name": "my-repo",
"version": "v1.2.0",
"staging": _staging_dict(staging),
}
result = await create_release_pr(state, config)
clients.azdo.create_pr.assert_called_once()
call_kwargs = clients.azdo.create_pr.call_args.kwargs
assert call_kwargs["repo"] == "my-repo"
async def test_sets_release_pr_id(self) -> None:
clients = build_mock_clients()
clients.azdo.create_pr = AsyncMock(return_value={
"pullRequestId": 77,
"lastMergeSourceCommit": {"commitId": "cafe1234"},
})
config = build_config(clients)
staging = _make_staging(version="v1.0.3")
state = {
"repo_name": "my-repo",
"version": "v1.0.3",
"staging": _staging_dict(staging),
}
result = await create_release_pr(state, config)
assert result["release_pr_id"] == "77"
async def test_sets_release_pr_commit(self) -> None:
clients = build_mock_clients()
clients.azdo.create_pr = AsyncMock(return_value={
"pullRequestId": 77,
"lastMergeSourceCommit": {"commitId": "cafe1234"},
})
config = build_config(clients)
staging = _make_staging()
state = {
"repo_name": "my-repo",
"version": "v1.0.0",
"staging": _staging_dict(staging),
}
result = await create_release_pr(state, config)
assert result["release_pr_commit"] == "cafe1234"
async def test_re_raises_on_service_error(self) -> None:
from release_agent.exceptions import ServiceError
clients = build_mock_clients()
clients.azdo.create_pr = AsyncMock(side_effect=ServiceError(
service="azdo", status_code=422, detail="Invalid branch"
))
config = build_config(clients)
staging = _make_staging()
state = {
"repo_name": "my-repo",
"version": "v1.0.0",
"staging": _staging_dict(staging),
}
with pytest.raises(ServiceError):
await create_release_pr(state, config)
# ---------------------------------------------------------------------------
# interrupt_confirm_merge_release
# ---------------------------------------------------------------------------
class TestInterruptConfirmMergeRelease:
async def test_calls_interrupt_with_pr_info(self) -> None:
config = build_config()
state = {
"release_pr_id": "99",
"version": "v1.0.0",
"repo_name": "my-repo",
}
with patch("release_agent.graph.release.interrupt") as mock_interrupt:
mock_interrupt.return_value = "confirm"
await interrupt_confirm_merge_release(state, config)
mock_interrupt.assert_called_once()
call_arg = mock_interrupt.call_args[0][0]
assert isinstance(call_arg, str)
assert len(call_arg) > 0
# ---------------------------------------------------------------------------
# merge_release_pr
# ---------------------------------------------------------------------------
class TestMergeReleasePr:
async def test_calls_azdo_merge_pr(self) -> None:
clients = build_mock_clients()
clients.azdo.merge_pr = AsyncMock(return_value=True)
config = build_config(clients)
state = {
"release_pr_id": "99",
"release_pr_commit": "abc123",
}
await merge_release_pr(state, config)
clients.azdo.merge_pr.assert_called_once_with(
pr_id=99, last_merge_source_commit="abc123"
)
async def test_re_raises_on_service_error(self) -> None:
from release_agent.exceptions import ServiceError
clients = build_mock_clients()
clients.azdo.merge_pr = AsyncMock(side_effect=ServiceError(
service="azdo", status_code=409, detail="Conflict"
))
config = build_config(clients)
state = {"release_pr_id": "99", "release_pr_commit": "abc"}
with pytest.raises(ServiceError):
await merge_release_pr(state, config)
# ---------------------------------------------------------------------------
# move_tickets_to_done
# ---------------------------------------------------------------------------
class TestMoveTicketsToDone:
async def test_transitions_all_tickets(self) -> None:
clients = build_mock_clients()
clients.jira.transition_issue = AsyncMock(return_value=True)
config = build_config(clients)
staging = _make_staging(tickets=[_make_ticket("BILL-1"), _make_ticket("BILL-2")])
state = {"staging": _staging_dict(staging)}
await move_tickets_to_done(state, config)
assert clients.jira.transition_issue.call_count == 2
async def test_calls_transition_with_done_name(self) -> None:
clients = build_mock_clients()
clients.jira.transition_issue = AsyncMock(return_value=True)
config = build_config(clients)
staging = _make_staging(tickets=[_make_ticket("BILL-1")])
state = {"staging": _staging_dict(staging)}
await move_tickets_to_done(state, config)
call_args = clients.jira.transition_issue.call_args_list[0]
ticket_id, transition = call_args[0]
assert ticket_id == "BILL-1"
assert "done" in transition.lower() or "released" in transition.lower()
async def test_appends_error_on_jira_failure(self) -> None:
from release_agent.exceptions import ServiceError
clients = build_mock_clients()
clients.jira.transition_issue = AsyncMock(side_effect=ServiceError(
service="jira", status_code=500, detail="Error"
))
config = build_config(clients)
staging = _make_staging(tickets=[_make_ticket()])
state = {"staging": _staging_dict(staging)}
result = await move_tickets_to_done(state, config)
assert "errors" in result
async def test_empty_tickets_no_calls(self) -> None:
clients = build_mock_clients()
clients.jira.transition_issue = AsyncMock()
config = build_config(clients)
staging = _make_staging(tickets=[])
state = {"staging": _staging_dict(staging)}
await move_tickets_to_done(state, config)
clients.jira.transition_issue.assert_not_called()
# ---------------------------------------------------------------------------
# send_slack_notification
# ---------------------------------------------------------------------------
class TestSendSlackNotification:
async def test_calls_slack_send_release_notification(self) -> None:
clients = build_mock_clients()
clients.slack.send_release_notification = AsyncMock(return_value=True)
config = build_config(clients)
staging = _make_staging()
state = {
"repo_name": "my-repo",
"version": "v1.0.0",
"staging": _staging_dict(staging),
}
result = await send_slack_notification(state, config)
clients.slack.send_release_notification.assert_called_once()
async def test_appends_error_on_slack_failure(self) -> None:
from release_agent.exceptions import ServiceError
clients = build_mock_clients()
clients.slack.send_release_notification = AsyncMock(side_effect=ServiceError(
service="slack", status_code=500, detail="Webhook error"
))
config = build_config(clients)
staging = _make_staging()
state = {
"repo_name": "my-repo",
"version": "v1.0.0",
"staging": _staging_dict(staging),
}
result = await send_slack_notification(state, config)
assert "errors" in result
# ---------------------------------------------------------------------------
# archive_release
# ---------------------------------------------------------------------------
class TestArchiveRelease:
async def test_archives_staging_to_store(self, tmp_path) -> None:
staging_store = JsonFileStagingStore(directory=tmp_path)
staging = _make_staging()
await staging_store.save(staging)
clients = build_mock_clients()
config = build_config(clients, staging_store=staging_store)
state = {
"repo_name": "my-repo",
"staging": _staging_dict(staging),
}
await archive_release(state, config)
# Staging should be gone now
assert await staging_store.load("my-repo") is None
async def test_archive_file_created_in_store(self, tmp_path) -> None:
staging_store = JsonFileStagingStore(directory=tmp_path)
staging = _make_staging(version="v3.0.0")
await staging_store.save(staging)
clients = build_mock_clients()
config = build_config(clients, staging_store=staging_store)
state = {
"repo_name": "my-repo",
"staging": _staging_dict(staging),
}
await archive_release(state, config)
versions = await staging_store.list_versions("my-repo")
assert "v3.0.0" in versions
# ---------------------------------------------------------------------------
# list_pipelines
# ---------------------------------------------------------------------------
class TestListPipelines:
async def test_fetches_pipelines_from_azdo(self) -> None:
clients = build_mock_clients()
pipelines = [PipelineInfo(id=1, name="build", repo="my-repo")]
clients.azdo.list_build_pipelines = AsyncMock(return_value=pipelines)
config = build_config(clients)
state = {"repo_name": "my-repo"}
result = await list_pipelines(state, config)
clients.azdo.list_build_pipelines.assert_called_once_with(repo="my-repo")
assert "pipelines" in result
assert len(result["pipelines"]) == 1
async def test_stores_pipelines_as_list_of_dicts(self) -> None:
clients = build_mock_clients()
pipelines = [
PipelineInfo(id=1, name="build", repo="my-repo"),
PipelineInfo(id=2, name="deploy", repo="my-repo"),
]
clients.azdo.list_build_pipelines = AsyncMock(return_value=pipelines)
config = build_config(clients)
state = {"repo_name": "my-repo"}
result = await list_pipelines(state, config)
assert len(result["pipelines"]) == 2
assert result["pipelines"][0]["id"] == 1
async def test_empty_pipelines_stored_as_empty_list(self) -> None:
clients = build_mock_clients()
clients.azdo.list_build_pipelines = AsyncMock(return_value=[])
config = build_config(clients)
state = {"repo_name": "my-repo"}
result = await list_pipelines(state, config)
assert result["pipelines"] == []
async def test_appends_error_on_service_failure(self) -> None:
from release_agent.exceptions import ServiceError
clients = build_mock_clients()
clients.azdo.list_build_pipelines = AsyncMock(side_effect=ServiceError(
service="azdo", status_code=500, detail="Error"
))
config = build_config(clients)
state = {"repo_name": "my-repo"}
result = await list_pipelines(state, config)
assert "errors" in result
# ---------------------------------------------------------------------------
# interrupt_confirm_trigger
# ---------------------------------------------------------------------------
class TestInterruptConfirmTrigger:
async def test_calls_interrupt_with_pipelines_summary(self) -> None:
config = build_config()
state = {
"repo_name": "my-repo",
"version": "v1.0.0",
"pipelines": [{"id": 1, "name": "build", "repo": "my-repo"}],
}
with patch("release_agent.graph.release.interrupt") as mock_interrupt:
mock_interrupt.return_value = "confirm"
await interrupt_confirm_trigger(state, config)
mock_interrupt.assert_called_once()
call_arg = mock_interrupt.call_args[0][0]
assert isinstance(call_arg, str)
# ---------------------------------------------------------------------------
# trigger_pipelines
# ---------------------------------------------------------------------------
class TestTriggerPipelines:
async def test_triggers_each_pipeline(self) -> None:
clients = build_mock_clients()
clients.azdo.trigger_pipeline = AsyncMock(return_value={"id": 1001})
config = build_config(clients)
state = {
"repo_name": "my-repo",
"version": "v1.0.0",
"pipelines": [
{"id": 1, "name": "build", "repo": "my-repo"},
{"id": 2, "name": "deploy", "repo": "my-repo"},
],
}
result = await trigger_pipelines(state, config)
assert clients.azdo.trigger_pipeline.call_count == 2
assert "triggered_builds" in result
assert len(result["triggered_builds"]) == 2
async def test_no_pipelines_no_calls(self) -> None:
clients = build_mock_clients()
clients.azdo.trigger_pipeline = AsyncMock()
config = build_config(clients)
state = {
"repo_name": "my-repo",
"version": "v1.0.0",
"pipelines": [],
}
result = await trigger_pipelines(state, config)
clients.azdo.trigger_pipeline.assert_not_called()
assert result["triggered_builds"] == []
async def test_appends_error_on_trigger_failure(self) -> None:
from release_agent.exceptions import ServiceError
clients = build_mock_clients()
clients.azdo.trigger_pipeline = AsyncMock(side_effect=ServiceError(
service="azdo", status_code=500, detail="Error"
))
config = build_config(clients)
state = {
"repo_name": "my-repo",
"version": "v1.0.0",
"pipelines": [{"id": 1, "name": "build", "repo": "my-repo"}],
}
result = await trigger_pipelines(state, config)
assert "errors" in result
# ---------------------------------------------------------------------------
# check_release_approvals
# ---------------------------------------------------------------------------
class TestCheckReleaseApprovals:
async def test_fetches_pending_approvals_from_builds(self) -> None:
clients = build_mock_clients()
clients.azdo.get_build_status = AsyncMock(return_value="completed")
config = build_config(clients)
state = {
"triggered_builds": [{"id": 1001}],
}
result = await check_release_approvals(state, config)
assert "pending_approvals" in result
async def test_empty_builds_means_no_approvals(self) -> None:
clients = build_mock_clients()
config = build_config(clients)
state = {"triggered_builds": []}
result = await check_release_approvals(state, config)
assert result["pending_approvals"] == []
async def test_appends_error_on_failure(self) -> None:
from release_agent.exceptions import ServiceError
clients = build_mock_clients()
clients.azdo.get_build_status = AsyncMock(side_effect=ServiceError(
service="azdo", status_code=500, detail="Error"
))
config = build_config(clients)
state = {"triggered_builds": [{"id": 1001}]}
result = await check_release_approvals(state, config)
assert "errors" in result
# ---------------------------------------------------------------------------
# interrupt_confirm_approve
# ---------------------------------------------------------------------------
class TestInterruptConfirmApprove:
async def test_calls_interrupt_with_approvals_summary(self) -> None:
config = build_config()
state = {
"pending_approvals": [{"approval_id": "aaa", "stage_name": "Production"}],
"version": "v1.0.0",
}
with patch("release_agent.graph.release.interrupt") as mock_interrupt:
mock_interrupt.return_value = "confirm"
await interrupt_confirm_approve(state, config)
mock_interrupt.assert_called_once()
# ---------------------------------------------------------------------------
# approve_stage
# ---------------------------------------------------------------------------
class TestApproveStage:
async def test_approves_each_pending_approval(self) -> None:
clients = build_mock_clients()
clients.azdo.approve_release = AsyncMock(return_value={"status": "approved"})
config = build_config(clients)
state = {
"pending_approvals": [
{"approval_id": "aaa"},
{"approval_id": "bbb"},
],
}
result = await approve_stage(state, config)
assert clients.azdo.approve_release.call_count == 2
async def test_no_approvals_no_calls(self) -> None:
clients = build_mock_clients()
clients.azdo.approve_release = AsyncMock()
config = build_config(clients)
state = {"pending_approvals": []}
await approve_stage(state, config)
clients.azdo.approve_release.assert_not_called()
async def test_appends_error_on_failure(self) -> None:
from release_agent.exceptions import ServiceError
clients = build_mock_clients()
clients.azdo.approve_release = AsyncMock(side_effect=ServiceError(
service="azdo", status_code=500, detail="Error"
))
config = build_config(clients)
state = {"pending_approvals": [{"approval_id": "aaa"}]}
result = await approve_stage(state, config)
assert "errors" in result
# ---------------------------------------------------------------------------
# build_release_graph
# ---------------------------------------------------------------------------
class TestBuildReleaseGraph:
def test_returns_compiled_graph(self) -> None:
graph = build_release_graph()
assert graph is not None
def test_graph_includes_trigger_ci_build_main_node(self) -> None:
graph = build_release_graph()
graph_nodes = graph.get_graph().nodes
assert "trigger_ci_build_main" in graph_nodes
def test_graph_includes_poll_ci_build_main_node(self) -> None:
graph = build_release_graph()
graph_nodes = graph.get_graph().nodes
assert "poll_ci_build_main" in graph_nodes
def test_graph_includes_wait_for_cd_release_node(self) -> None:
graph = build_release_graph()
graph_nodes = graph.get_graph().nodes
assert "wait_for_cd_release" in graph_nodes
def test_graph_includes_poll_release_approvals_node(self) -> None:
graph = build_release_graph()
graph_nodes = graph.get_graph().nodes
assert "poll_release_approvals" in graph_nodes
def test_graph_includes_interrupt_sandbox_approval_node(self) -> None:
graph = build_release_graph()
graph_nodes = graph.get_graph().nodes
assert "interrupt_sandbox_approval" in graph_nodes
def test_graph_includes_interrupt_prod_approval_node(self) -> None:
graph = build_release_graph()
graph_nodes = graph.get_graph().nodes
assert "interrupt_prod_approval" in graph_nodes
def test_graph_includes_execute_sandbox_approval_node(self) -> None:
graph = build_release_graph()
graph_nodes = graph.get_graph().nodes
assert "execute_sandbox_approval" in graph_nodes
def test_graph_includes_execute_prod_approval_node(self) -> None:
graph = build_release_graph()
graph_nodes = graph.get_graph().nodes
assert "execute_prod_approval" in graph_nodes
def test_graph_includes_notify_ci_failure_node(self) -> None:
graph = build_release_graph()
graph_nodes = graph.get_graph().nodes
assert "notify_ci_failure" in graph_nodes
# ---------------------------------------------------------------------------
# New release graph node: wait_for_cd_release
# ---------------------------------------------------------------------------
class TestWaitForCdRelease:
"""Tests for wait_for_cd_release node."""
async def test_sets_release_id_when_found(self) -> None:
from release_agent.graph.release import wait_for_cd_release
clients = build_mock_clients()
clients.azdo.get_latest_release.return_value = {"id": 100, "name": "Release-100"}
config = build_config(clients)
state = {"release_definition_id": 5, "repo_name": "my-repo"}
result = await wait_for_cd_release(state, config)
assert "release_id" in result
assert result["release_id"] == 100
async def test_appends_error_when_no_release(self) -> None:
from release_agent.graph.release import wait_for_cd_release
clients = build_mock_clients()
clients.azdo.get_latest_release.return_value = {}
config = build_config(clients)
state = {"release_definition_id": 5, "repo_name": "my-repo"}
result = await wait_for_cd_release(state, config)
assert "errors" in result
async def test_works_without_release_definition_id(self) -> None:
from release_agent.graph.release import wait_for_cd_release
clients = build_mock_clients()
config = build_config(clients)
state = {"repo_name": "my-repo"}
result = await wait_for_cd_release(state, config)
assert isinstance(result, dict)
# ---------------------------------------------------------------------------
# New release graph node: poll_release_approvals
# ---------------------------------------------------------------------------
class TestPollReleaseApprovals:
"""Tests for poll_release_approvals node."""
async def test_sets_pending_approvals_from_azdo(self) -> None:
from release_agent.graph.release import poll_release_approvals
from release_agent.models.build import ApprovalRecord
clients = build_mock_clients()
clients.azdo.get_release_approvals.return_value = [
ApprovalRecord(approval_id="a1", stage_name="Sandbox", status="pending", release_id=10),
]
config = build_config(clients)
state = {"release_id": 10}
result = await poll_release_approvals(state, config)
assert "pending_approvals" in result
assert len(result["pending_approvals"]) == 1
async def test_returns_empty_list_when_no_approvals(self) -> None:
from release_agent.graph.release import poll_release_approvals
clients = build_mock_clients()
clients.azdo.get_release_approvals.return_value = []
config = build_config(clients)
state = {"release_id": 10}
result = await poll_release_approvals(state, config)
assert result.get("pending_approvals") == []
async def test_appends_error_on_failure(self) -> None:
from release_agent.exceptions import ServiceError
from release_agent.graph.release import poll_release_approvals
clients = build_mock_clients()
clients.azdo.get_release_approvals.side_effect = ServiceError(
service="azdo", status_code=500, detail="error"
)
config = build_config(clients)
state = {"release_id": 10}
result = await poll_release_approvals(state, config)
assert "errors" in result
# ---------------------------------------------------------------------------
# New release graph node: interrupt_sandbox_approval
# ---------------------------------------------------------------------------
class TestInterruptSandboxApproval:
async def test_calls_interrupt(self) -> None:
from release_agent.graph.release import interrupt_sandbox_approval
config = build_config()
state = {
"pending_approvals": [{"approval_id": "x", "stage_name": "Sandbox"}],
"version": "v1.0.0",
}
with patch("release_agent.graph.release.interrupt") as mock_interrupt:
mock_interrupt.return_value = "confirm"
await interrupt_sandbox_approval(state, config)
mock_interrupt.assert_called_once()
async def test_sets_current_stage_to_sandbox_pending(self) -> None:
from release_agent.graph.release import interrupt_sandbox_approval
config = build_config()
state = {
"pending_approvals": [{"approval_id": "x", "stage_name": "Sandbox"}],
}
with patch("release_agent.graph.release.interrupt", return_value="yes"):
result = await interrupt_sandbox_approval(state, config)
assert result.get("current_stage") == "sandbox_pending"
# ---------------------------------------------------------------------------
# New release graph node: interrupt_prod_approval
# ---------------------------------------------------------------------------
class TestInterruptProdApproval:
async def test_calls_interrupt(self) -> None:
from release_agent.graph.release import interrupt_prod_approval
config = build_config()
state = {
"pending_approvals": [{"approval_id": "y", "stage_name": "Production"}],
"version": "v1.0.0",
}
with patch("release_agent.graph.release.interrupt") as mock_interrupt:
mock_interrupt.return_value = "confirm"
await interrupt_prod_approval(state, config)
mock_interrupt.assert_called_once()
async def test_sets_current_stage_to_prod_pending(self) -> None:
from release_agent.graph.release import interrupt_prod_approval
config = build_config()
state = {
"pending_approvals": [{"approval_id": "y", "stage_name": "Production"}],
}
with patch("release_agent.graph.release.interrupt", return_value="yes"):
result = await interrupt_prod_approval(state, config)
assert result.get("current_stage") == "prod_pending"
# ---------------------------------------------------------------------------
# New release graph node: execute_sandbox_approval
# ---------------------------------------------------------------------------
class TestExecuteSandboxApproval:
async def test_approves_sandbox_approvals(self) -> None:
from release_agent.graph.release import execute_sandbox_approval
clients = build_mock_clients()
clients.azdo.approve_release.return_value = {"status": "approved"}
config = build_config(clients)
state = {
"pending_approvals": [{"approval_id": "sb1", "stage_name": "Sandbox"}],
}
result = await execute_sandbox_approval(state, config)
clients.azdo.approve_release.assert_called()
async def test_returns_empty_dict_on_success(self) -> None:
from release_agent.graph.release import execute_sandbox_approval
clients = build_mock_clients()
clients.azdo.approve_release.return_value = {"status": "approved"}
config = build_config(clients)
state = {"pending_approvals": [{"approval_id": "sb1"}]}
result = await execute_sandbox_approval(state, config)
assert "errors" not in result or result["errors"] == []
# ---------------------------------------------------------------------------
# New release graph node: execute_prod_approval
# ---------------------------------------------------------------------------
class TestExecuteProdApproval:
async def test_approves_prod_approvals(self) -> None:
from release_agent.graph.release import execute_prod_approval
clients = build_mock_clients()
clients.azdo.approve_release.return_value = {"status": "approved"}
config = build_config(clients)
state = {
"pending_approvals": [{"approval_id": "pd1", "stage_name": "Production"}],
}
result = await execute_prod_approval(state, config)
clients.azdo.approve_release.assert_called()
# ---------------------------------------------------------------------------
# New release graph node: notify_ci_failure
# ---------------------------------------------------------------------------
class TestNotifyCiFailure:
async def test_sends_slack_notification(self) -> None:
from release_agent.graph.release import notify_ci_failure
clients = build_mock_clients()
clients.slack.send_notification.return_value = True
config = build_config(clients)
state = {
"repo_name": "my-repo",
"ci_build_result": "failed",
"ci_build_url": "https://build/1",
}
result = await notify_ci_failure(state, config)
clients.slack.send_notification.assert_called_once()
async def test_appends_message_on_success(self) -> None:
from release_agent.graph.release import notify_ci_failure
clients = build_mock_clients()
clients.slack.send_notification.return_value = True
config = build_config(clients)
state = {"repo_name": "my-repo", "ci_build_result": "failed"}
result = await notify_ci_failure(state, config)
assert "messages" in result or isinstance(result, dict)

302
tests/graph/test_routing.py Normal file
View File

@@ -0,0 +1,302 @@
"""Tests for graph/routing.py. Written FIRST (TDD RED phase).
All routing functions are pure — they take a state dict and return a string.
Every branch is tested, including missing state fields (defaults to falsy).
"""
import pytest
from release_agent.graph.routing import (
has_pending_approvals,
has_pipelines,
has_ticket,
is_pr_already_merged,
is_review_approved,
route_after_fetch,
route_approval_stage,
route_ci_result,
should_continue_to_release,
)
# ---------------------------------------------------------------------------
# is_pr_already_merged
# ---------------------------------------------------------------------------
class TestIsPrAlreadyMerged:
def test_returns_merged_when_true(self) -> None:
state = {"pr_already_merged": True}
assert is_pr_already_merged(state) == "merged"
def test_returns_active_when_false(self) -> None:
state = {"pr_already_merged": False}
assert is_pr_already_merged(state) == "active"
def test_returns_active_when_field_missing(self) -> None:
state = {}
assert is_pr_already_merged(state) == "active"
def test_returns_active_when_none(self) -> None:
state = {"pr_already_merged": None}
assert is_pr_already_merged(state) == "active"
# ---------------------------------------------------------------------------
# is_review_approved
# ---------------------------------------------------------------------------
class TestIsReviewApproved:
def test_returns_approve_when_true(self) -> None:
state = {"review_approved": True}
assert is_review_approved(state) == "approve"
def test_returns_request_changes_when_false(self) -> None:
state = {"review_approved": False}
assert is_review_approved(state) == "request_changes"
def test_returns_request_changes_when_field_missing(self) -> None:
state = {}
assert is_review_approved(state) == "request_changes"
def test_returns_request_changes_when_none(self) -> None:
state = {"review_approved": None}
assert is_review_approved(state) == "request_changes"
# ---------------------------------------------------------------------------
# has_ticket
# ---------------------------------------------------------------------------
class TestHasTicket:
def test_returns_yes_when_true(self) -> None:
state = {"has_ticket": True}
assert has_ticket(state) == "yes"
def test_returns_no_when_false(self) -> None:
state = {"has_ticket": False}
assert has_ticket(state) == "no"
def test_returns_no_when_field_missing(self) -> None:
state = {}
assert has_ticket(state) == "no"
def test_returns_no_when_none(self) -> None:
state = {"has_ticket": None}
assert has_ticket(state) == "no"
# ---------------------------------------------------------------------------
# should_continue_to_release
# ---------------------------------------------------------------------------
class TestShouldContinueToRelease:
def test_returns_yes_when_true(self) -> None:
state = {"continue_to_release": True}
assert should_continue_to_release(state) == "yes"
def test_returns_no_when_false(self) -> None:
state = {"continue_to_release": False}
assert should_continue_to_release(state) == "no"
def test_returns_no_when_field_missing(self) -> None:
state = {}
assert should_continue_to_release(state) == "no"
def test_returns_no_when_none(self) -> None:
state = {"continue_to_release": None}
assert should_continue_to_release(state) == "no"
# ---------------------------------------------------------------------------
# has_pipelines
# ---------------------------------------------------------------------------
class TestHasPipelines:
def test_returns_yes_when_non_empty_list(self) -> None:
state = {"pipelines": [{"id": 1}]}
assert has_pipelines(state) == "yes"
def test_returns_no_when_empty_list(self) -> None:
state = {"pipelines": []}
assert has_pipelines(state) == "no"
def test_returns_no_when_field_missing(self) -> None:
state = {}
assert has_pipelines(state) == "no"
def test_returns_no_when_none(self) -> None:
state = {"pipelines": None}
assert has_pipelines(state) == "no"
def test_returns_yes_with_multiple_pipelines(self) -> None:
state = {"pipelines": [{"id": 1}, {"id": 2}]}
assert has_pipelines(state) == "yes"
# ---------------------------------------------------------------------------
# has_pending_approvals
# ---------------------------------------------------------------------------
class TestHasPendingApprovals:
def test_returns_yes_when_non_empty_list(self) -> None:
state = {"pending_approvals": [{"approval_id": "abc"}]}
assert has_pending_approvals(state) == "yes"
def test_returns_no_when_empty_list(self) -> None:
state = {"pending_approvals": []}
assert has_pending_approvals(state) == "no"
def test_returns_no_when_field_missing(self) -> None:
state = {}
assert has_pending_approvals(state) == "no"
def test_returns_no_when_none(self) -> None:
state = {"pending_approvals": None}
assert has_pending_approvals(state) == "no"
def test_returns_yes_with_multiple_approvals(self) -> None:
state = {"pending_approvals": [{"approval_id": "a"}, {"approval_id": "b"}]}
assert has_pending_approvals(state) == "yes"
# ---------------------------------------------------------------------------
# route_ci_result
# ---------------------------------------------------------------------------
class TestRouteCiResult:
"""Tests for route_ci_result routing function."""
def test_returns_ci_passed_when_succeeded(self) -> None:
state = {"ci_build_result": "succeeded"}
assert route_ci_result(state) == "ci_passed"
def test_returns_ci_failed_when_failed(self) -> None:
state = {"ci_build_result": "failed"}
assert route_ci_result(state) == "ci_failed"
def test_returns_ci_failed_when_canceled(self) -> None:
state = {"ci_build_result": "canceled"}
assert route_ci_result(state) == "ci_failed"
def test_returns_ci_failed_when_partially_succeeded(self) -> None:
state = {"ci_build_result": "partiallySucceeded"}
assert route_ci_result(state) == "ci_failed"
def test_returns_ci_failed_when_field_missing(self) -> None:
state = {}
assert route_ci_result(state) == "ci_failed"
def test_returns_ci_failed_when_none(self) -> None:
state = {"ci_build_result": None}
assert route_ci_result(state) == "ci_failed"
def test_returns_ci_failed_when_empty_string(self) -> None:
state = {"ci_build_result": ""}
assert route_ci_result(state) == "ci_failed"
def test_case_sensitive_succeeded(self) -> None:
# AzDo returns "succeeded" (lowercase)
state = {"ci_build_result": "succeeded"}
assert route_ci_result(state) == "ci_passed"
# ---------------------------------------------------------------------------
# route_approval_stage
# ---------------------------------------------------------------------------
class TestRouteApprovalStage:
"""Tests for route_approval_stage routing function."""
def test_returns_all_deployed_when_no_pending_approvals(self) -> None:
state = {"pending_approvals": []}
assert route_approval_stage(state) == "all_deployed"
def test_returns_all_deployed_when_field_missing(self) -> None:
state = {}
assert route_approval_stage(state) == "all_deployed"
def test_returns_all_deployed_when_none(self) -> None:
state = {"pending_approvals": None}
assert route_approval_stage(state) == "all_deployed"
def test_returns_sandbox_pending_when_sandbox_approval_exists(self) -> None:
state = {
"current_stage": "sandbox_pending",
"pending_approvals": [{"approval_id": "x", "stage_name": "Sandbox"}],
}
assert route_approval_stage(state) == "sandbox_pending"
def test_returns_prod_pending_when_prod_approval_exists(self) -> None:
state = {
"current_stage": "prod_pending",
"pending_approvals": [{"approval_id": "y", "stage_name": "Production"}],
}
assert route_approval_stage(state) == "prod_pending"
def test_uses_current_stage_field_when_present(self) -> None:
state = {
"current_stage": "sandbox_pending",
"pending_approvals": [{"approval_id": "z"}],
}
assert route_approval_stage(state) == "sandbox_pending"
def test_returns_all_deployed_when_no_current_stage_and_has_approvals(self) -> None:
# When current_stage is missing but approvals exist, stage is unknown
# so we treat as sandbox by default (first stage)
state = {
"pending_approvals": [{"approval_id": "a"}],
}
# Must return either sandbox_pending or prod_pending (not all_deployed)
result = route_approval_stage(state)
assert result in ("sandbox_pending", "prod_pending")
def test_sandbox_pending_from_current_stage(self) -> None:
state = {"current_stage": "sandbox_pending", "pending_approvals": [{"approval_id": "x"}]}
assert route_approval_stage(state) == "sandbox_pending"
def test_prod_pending_from_current_stage(self) -> None:
state = {"current_stage": "prod_pending", "pending_approvals": [{"approval_id": "x"}]}
assert route_approval_stage(state) == "prod_pending"
# ---------------------------------------------------------------------------
# route_after_fetch
# ---------------------------------------------------------------------------
class TestRouteAfterFetch:
"""Tests for route_after_fetch — 3-way routing replacing is_pr_already_merged."""
def test_returns_merged_when_pr_already_merged(self) -> None:
state = {"pr_already_merged": True}
assert route_after_fetch(state) == "merged"
def test_returns_active_with_ticket_when_active_and_has_ticket(self) -> None:
state = {"pr_already_merged": False, "has_ticket": True}
assert route_after_fetch(state) == "active_with_ticket"
def test_returns_active_no_ticket_when_active_and_no_ticket(self) -> None:
state = {"pr_already_merged": False, "has_ticket": False}
assert route_after_fetch(state) == "active_no_ticket"
def test_returns_active_no_ticket_when_has_ticket_missing(self) -> None:
state = {"pr_already_merged": False}
assert route_after_fetch(state) == "active_no_ticket"
def test_returns_active_no_ticket_when_has_ticket_none(self) -> None:
state = {"pr_already_merged": False, "has_ticket": None}
assert route_after_fetch(state) == "active_no_ticket"
def test_returns_active_no_ticket_when_all_fields_missing(self) -> None:
state = {}
assert route_after_fetch(state) == "active_no_ticket"
def test_merged_takes_precedence_over_has_ticket(self) -> None:
# Even if has_ticket is True, merged PR should route to "merged"
state = {"pr_already_merged": True, "has_ticket": True}
assert route_after_fetch(state) == "merged"
def test_returns_active_with_ticket_ignores_merged_false(self) -> None:
state = {"pr_already_merged": False, "has_ticket": True}
result = route_after_fetch(state)
assert result != "merged"
assert result == "active_with_ticket"

View File

View File

@@ -0,0 +1,332 @@
"""Tests for scripts/migrate_json_to_db.py.
Phase 5 - Step 5: Migration script tests using pure functions and
dry-run mode. No real database required.
Written FIRST (TDD RED phase).
"""
import json
from datetime import date
from pathlib import Path
import pytest
from scripts.migrate_json_to_db import (
collect_json_files,
parse_staging_json,
parse_archived_json,
build_staging_insert_sql,
build_archived_insert_sql,
is_archived_filename,
is_staging_filename,
MigrationRecord,
)
# ---------------------------------------------------------------------------
# Fixture data (mirrors real JSON structure from release-workflow/releases/)
# ---------------------------------------------------------------------------
STAGING_JSON = {
"version": "v1.0.0",
"repo": "Billo.Platform.Document",
"started_at": "2026-03-17",
"tickets": [
{
"id": "ALLPOST-4219",
"summary": "Test release bot",
"pr_id": "10460",
"pr_url": "https://dev.azure.com/billodev/Billo%20App%20Platform/_git/Billo.Platform.Document/pullrequest/10460",
"pr_title": "chore: trigger release bot test",
"branch": "feature_ALLPOST-4219_test_release_bot",
"merged_at": "2026-03-17",
}
],
}
PAYMENT_STAGING_JSON = {
"version": "v1.0.1",
"repo": "Billo.Platform.Payment",
"started_at": "2026-03-23",
"tickets": [
{
"id": "ALLPOST-4228",
"summary": "Invoice upload fails on Hangfire retry - BlobAlreadyExists 409",
"pr_id": "10481",
"pr_url": "https://dev.azure.com/billodev/Billo%20App%20Platform/_git/Billo.Platform.Payment/pullrequest/10481",
"pr_title": "Invoice upload fails on Hangfire retry - BlobAlreadyExists 409",
"branch": "bug/ALLPOST-4228_fix-invoice-upload-blob-already-exists",
"merged_at": "2026-03-23",
}
],
}
# Archived JSON has an additional released_at field
ARCHIVED_JSON = {
"version": "v1.0.0",
"repo": "Billo.Platform.Payment",
"started_at": "2026-01-01",
"tickets": [],
"released_at": "2026-01-15",
}
# ---------------------------------------------------------------------------
# is_staging_filename / is_archived_filename
# ---------------------------------------------------------------------------
class TestFileNameClassification:
def test_staging_filename_identified(self) -> None:
assert is_staging_filename("Billo.Platform.Document.json") is True
def test_archived_filename_identified(self) -> None:
assert is_archived_filename("Billo.Platform.Payment_v1.0.1_2026-03-23.json") is True
def test_staging_filename_not_archived(self) -> None:
assert is_archived_filename("Billo.Platform.Document.json") is False
def test_archived_filename_not_staging(self) -> None:
assert is_staging_filename("Billo.Platform.Payment_v1.0.1_2026-03-23.json") is False
def test_non_json_file_is_not_staging(self) -> None:
assert is_staging_filename("README.md") is False
def test_non_json_file_is_not_archived(self) -> None:
assert is_archived_filename("README.md") is False
# ---------------------------------------------------------------------------
# parse_staging_json
# ---------------------------------------------------------------------------
class TestParseStagingJson:
def test_parse_returns_migration_record(self) -> None:
record = parse_staging_json(STAGING_JSON)
assert isinstance(record, MigrationRecord)
def test_parse_extracts_repo(self) -> None:
record = parse_staging_json(STAGING_JSON)
assert record.repo == "Billo.Platform.Document"
def test_parse_extracts_version(self) -> None:
record = parse_staging_json(STAGING_JSON)
assert record.version == "v1.0.0"
def test_parse_extracts_started_at(self) -> None:
record = parse_staging_json(STAGING_JSON)
assert record.started_at == date(2026, 3, 17)
def test_parse_extracts_tickets(self) -> None:
record = parse_staging_json(STAGING_JSON)
assert len(record.tickets) == 1
assert record.tickets[0]["id"] == "ALLPOST-4219"
def test_parse_staging_has_no_released_at(self) -> None:
record = parse_staging_json(STAGING_JSON)
assert record.released_at is None
def test_parse_staging_with_multiple_tickets(self) -> None:
data = {
**STAGING_JSON,
"tickets": [
{**STAGING_JSON["tickets"][0], "id": "ALLPOST-1"},
{**STAGING_JSON["tickets"][0], "id": "ALLPOST-2"},
],
}
record = parse_staging_json(data)
assert len(record.tickets) == 2
def test_parse_staging_with_empty_tickets(self) -> None:
data = {**STAGING_JSON, "tickets": []}
record = parse_staging_json(data)
assert record.tickets == []
# ---------------------------------------------------------------------------
# parse_archived_json
# ---------------------------------------------------------------------------
class TestParseArchivedJson:
def test_parse_returns_migration_record(self) -> None:
record = parse_archived_json(ARCHIVED_JSON)
assert isinstance(record, MigrationRecord)
def test_parse_extracts_released_at(self) -> None:
record = parse_archived_json(ARCHIVED_JSON)
assert record.released_at == date(2026, 1, 15)
def test_parse_extracts_repo(self) -> None:
record = parse_archived_json(ARCHIVED_JSON)
assert record.repo == "Billo.Platform.Payment"
def test_parse_extracts_version(self) -> None:
record = parse_archived_json(ARCHIVED_JSON)
assert record.version == "v1.0.0"
# ---------------------------------------------------------------------------
# build_staging_insert_sql
# ---------------------------------------------------------------------------
class TestBuildStagingInsertSql:
def test_returns_tuple_of_sql_and_params(self) -> None:
record = parse_staging_json(STAGING_JSON)
sql, params = build_staging_insert_sql(record)
assert isinstance(sql, str)
assert isinstance(params, tuple)
def test_sql_inserts_into_staging_releases(self) -> None:
record = parse_staging_json(STAGING_JSON)
sql, _ = build_staging_insert_sql(record)
assert "staging_releases" in sql
def test_sql_is_insert_statement(self) -> None:
record = parse_staging_json(STAGING_JSON)
sql, _ = build_staging_insert_sql(record)
assert "INSERT" in sql.upper()
def test_params_include_repo(self) -> None:
record = parse_staging_json(STAGING_JSON)
_, params = build_staging_insert_sql(record)
assert "Billo.Platform.Document" in params
def test_params_include_version(self) -> None:
record = parse_staging_json(STAGING_JSON)
_, params = build_staging_insert_sql(record)
assert "v1.0.0" in params
def test_params_include_started_at(self) -> None:
record = parse_staging_json(STAGING_JSON)
_, params = build_staging_insert_sql(record)
assert "2026-03-17" in params
def test_params_include_tickets_json(self) -> None:
record = parse_staging_json(STAGING_JSON)
_, params = build_staging_insert_sql(record)
# tickets should be serialized as JSON string
tickets_json = next(p for p in params if isinstance(p, str) and "ALLPOST-4219" in p)
parsed = json.loads(tickets_json)
assert parsed[0]["id"] == "ALLPOST-4219"
def test_sql_uses_on_conflict_do_nothing_or_update(self) -> None:
record = parse_staging_json(STAGING_JSON)
sql, _ = build_staging_insert_sql(record)
assert "ON CONFLICT" in sql.upper() or "INSERT" in sql.upper()
# ---------------------------------------------------------------------------
# build_archived_insert_sql
# ---------------------------------------------------------------------------
class TestBuildArchivedInsertSql:
def test_returns_tuple_of_sql_and_params(self) -> None:
record = parse_archived_json(ARCHIVED_JSON)
sql, params = build_archived_insert_sql(record)
assert isinstance(sql, str)
assert isinstance(params, tuple)
def test_sql_inserts_into_archived_releases(self) -> None:
record = parse_archived_json(ARCHIVED_JSON)
sql, _ = build_archived_insert_sql(record)
assert "archived_releases" in sql
def test_sql_is_insert_statement(self) -> None:
record = parse_archived_json(ARCHIVED_JSON)
sql, _ = build_archived_insert_sql(record)
assert "INSERT" in sql.upper()
def test_params_include_released_at(self) -> None:
record = parse_archived_json(ARCHIVED_JSON)
_, params = build_archived_insert_sql(record)
assert "2026-01-15" in params
def test_params_include_repo(self) -> None:
record = parse_archived_json(ARCHIVED_JSON)
_, params = build_archived_insert_sql(record)
assert "Billo.Platform.Payment" in params
# ---------------------------------------------------------------------------
# collect_json_files
# ---------------------------------------------------------------------------
class TestCollectJsonFiles:
def test_returns_empty_list_for_empty_directory(self, tmp_path: Path) -> None:
result = collect_json_files(tmp_path)
assert result == []
def test_finds_staging_json_files(self, tmp_path: Path) -> None:
(tmp_path / "my-repo.json").write_text(json.dumps(STAGING_JSON))
result = collect_json_files(tmp_path)
assert len(result) == 1
def test_finds_archived_json_files(self, tmp_path: Path) -> None:
(tmp_path / "my-repo_v1.0.0_2025-06-01.json").write_text(
json.dumps(ARCHIVED_JSON)
)
result = collect_json_files(tmp_path)
assert len(result) == 1
def test_ignores_non_json_files(self, tmp_path: Path) -> None:
(tmp_path / "README.md").write_text("readme")
(tmp_path / "my-repo.json").write_text(json.dumps(STAGING_JSON))
result = collect_json_files(tmp_path)
assert len(result) == 1
def test_collects_from_nested_directories(self, tmp_path: Path) -> None:
repo_dir = tmp_path / "Billo.Platform.Document"
repo_dir.mkdir()
(repo_dir / "v1.0.0.json").write_text(json.dumps(STAGING_JSON))
result = collect_json_files(tmp_path)
assert len(result) == 1
def test_returns_path_objects(self, tmp_path: Path) -> None:
(tmp_path / "my-repo.json").write_text(json.dumps(STAGING_JSON))
result = collect_json_files(tmp_path)
assert all(isinstance(p, Path) for p in result)
def test_collects_multiple_files(self, tmp_path: Path) -> None:
for i in range(3):
(tmp_path / f"repo-{i}.json").write_text(json.dumps(STAGING_JSON))
result = collect_json_files(tmp_path)
assert len(result) == 3
# ---------------------------------------------------------------------------
# Dry-run mode (integration of pure functions)
# ---------------------------------------------------------------------------
class TestDryRunMode:
def test_dry_run_collects_records_without_db_access(self, tmp_path: Path) -> None:
"""Dry run processes files and returns SQL/params without executing."""
repo_dir = tmp_path / "Billo.Platform.Document"
repo_dir.mkdir()
(repo_dir / "v1.0.0.json").write_text(json.dumps(STAGING_JSON))
files = collect_json_files(tmp_path)
assert len(files) == 1
# Parse and build SQL — no DB connection needed
record = parse_staging_json(json.loads(files[0].read_text()))
sql, params = build_staging_insert_sql(record)
assert "INSERT" in sql.upper()
assert "Billo.Platform.Document" in params
def test_payment_staging_file_parses_correctly(self, tmp_path: Path) -> None:
(tmp_path / "Billo.Platform.Payment.json").write_text(
json.dumps(PAYMENT_STAGING_JSON)
)
files = collect_json_files(tmp_path)
record = parse_staging_json(json.loads(files[0].read_text()))
assert record.repo == "Billo.Platform.Payment"
assert record.version == "v1.0.1"
assert len(record.tickets) == 1
assert record.tickets[0]["id"] == "ALLPOST-4228"
def test_archived_file_parses_correctly(self, tmp_path: Path) -> None:
(tmp_path / "my-repo_v1.0.0_2026-01-15.json").write_text(
json.dumps(ARCHIVED_JSON)
)
files = collect_json_files(tmp_path)
record = parse_archived_json(json.loads(files[0].read_text()))
assert record.released_at == date(2026, 1, 15)

View File

View File

@@ -0,0 +1,141 @@
"""Tests for services/pr_dedup.py. Written FIRST (TDD RED phase).
find_unprocessed_prs queries agent_threads to find which PRs have not yet
been processed (no existing thread for that repo+pr_id combination).
"""
import pytest
from release_agent.models.pr import PRInfo
from release_agent.services.pr_dedup import find_unprocessed_prs
# ---------------------------------------------------------------------------
# Helpers — fake async pool
# ---------------------------------------------------------------------------
def _make_pr(pr_id: str, repo_name: str = "my-repo") -> PRInfo:
return PRInfo(
pr_id=pr_id,
pr_url=f"https://dev.azure.com/org/proj/_git/{repo_name}/pullrequest/{pr_id}",
repo_name=repo_name,
branch="refs/heads/feature/ALLPOST-100-fix",
pr_title=f"PR {pr_id}",
pr_status="active",
)
def _make_pool(existing_rows: list[tuple[str, str]]):
"""Return a fake async connection pool.
existing_rows: list of (pr_id, repo_name) tuples representing already-processed PRs.
"""
class FakeCursor:
def __init__(self, rows):
self._rows = rows
async def execute(self, sql, params=None):
pass
async def fetchall(self):
return self._rows
async def __aenter__(self):
return self
async def __aexit__(self, *args):
pass
class FakeConn:
def __init__(self, rows):
self._rows = rows
def cursor(self):
return FakeCursor(self._rows)
async def __aenter__(self):
return self
async def __aexit__(self, *args):
pass
class FakePool:
def __init__(self, rows):
self._rows = rows
def connection(self):
return FakeConn(self._rows)
return FakePool(existing_rows)
# ---------------------------------------------------------------------------
# find_unprocessed_prs tests
# ---------------------------------------------------------------------------
class TestFindUnprocessedPrs:
async def test_returns_all_when_none_processed(self) -> None:
prs = [_make_pr("10"), _make_pr("20")]
pool = _make_pool([])
result = await find_unprocessed_prs(pool, prs)
assert len(result) == 2
async def test_returns_empty_when_all_processed(self) -> None:
prs = [_make_pr("10"), _make_pr("20")]
# existing rows: (pr_id, repo_name)
pool = _make_pool([("10", "my-repo"), ("20", "my-repo")])
result = await find_unprocessed_prs(pool, prs)
assert result == []
async def test_returns_only_unprocessed(self) -> None:
prs = [_make_pr("10"), _make_pr("20"), _make_pr("30")]
pool = _make_pool([("10", "my-repo")])
result = await find_unprocessed_prs(pool, prs)
pr_ids = [p.pr_id for p in result]
assert "10" not in pr_ids
assert "20" in pr_ids
assert "30" in pr_ids
async def test_empty_input_returns_empty(self) -> None:
pool = _make_pool([])
result = await find_unprocessed_prs(pool, [])
assert result == []
async def test_different_repos_not_confused(self) -> None:
pr_repo_a = _make_pr("10", repo_name="repo-a")
pr_repo_b = _make_pr("10", repo_name="repo-b")
# Only repo-a/10 is processed
pool = _make_pool([("10", "repo-a")])
result = await find_unprocessed_prs(pool, [pr_repo_a, pr_repo_b])
# repo-b/10 should still be returned (different repo)
assert len(result) == 1
assert result[0].repo_name == "repo-b"
async def test_returns_list_of_pr_info(self) -> None:
prs = [_make_pr("42")]
pool = _make_pool([])
result = await find_unprocessed_prs(pool, prs)
assert all(isinstance(p, PRInfo) for p in result)
async def test_preserves_pr_info_objects(self) -> None:
pr = _make_pr("77")
pool = _make_pool([])
result = await find_unprocessed_prs(pool, [pr])
assert result[0].pr_id == "77"
assert result[0].repo_name == "my-repo"

View File

@@ -0,0 +1,309 @@
"""Tests for services/pr_poller.py. Written FIRST (TDD RED phase).
Tests verify:
- _synthesize_webhook_payload produces a valid payload dict
- run_pr_poll_loop calls list_active_prs, dedup, then schedules graph for each unprocessed PR
- Fake sleep is injected to avoid real waits
"""
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from release_agent.models.pr import PRInfo
from release_agent.services.pr_poller import _synthesize_webhook_payload, run_pr_poll_loop
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_pr(
pr_id: str = "10",
repo_name: str = "my-repo",
branch: str = "refs/heads/feature/ALLPOST-100-fix",
title: str = "Test PR",
status: str = "active",
) -> PRInfo:
return PRInfo(
pr_id=pr_id,
pr_url=f"https://dev.azure.com/org/proj/_git/{repo_name}/pullrequest/{pr_id}",
repo_name=repo_name,
branch=branch,
pr_title=title,
pr_status=status,
)
# ---------------------------------------------------------------------------
# _synthesize_webhook_payload tests
# ---------------------------------------------------------------------------
class TestSynthesizeWebhookPayload:
def test_returns_dict(self) -> None:
pr = _make_pr()
result = _synthesize_webhook_payload(pr)
assert isinstance(result, dict)
def test_has_resource_key(self) -> None:
pr = _make_pr()
result = _synthesize_webhook_payload(pr)
assert "resource" in result
def test_resource_contains_pull_request_id(self) -> None:
pr = _make_pr(pr_id="42")
result = _synthesize_webhook_payload(pr)
assert result["resource"]["pull_request_id"] == 42
def test_resource_contains_repository_name(self) -> None:
pr = _make_pr(repo_name="backend-api")
result = _synthesize_webhook_payload(pr)
assert result["resource"]["repository"]["name"] == "backend-api"
def test_resource_contains_title(self) -> None:
pr = _make_pr(title="My PR Title")
result = _synthesize_webhook_payload(pr)
assert result["resource"]["title"] == "My PR Title"
def test_resource_contains_source_ref_name(self) -> None:
pr = _make_pr(branch="refs/heads/feature/ALLPOST-200-test")
result = _synthesize_webhook_payload(pr)
assert result["resource"]["source_ref_name"] == "refs/heads/feature/ALLPOST-200-test"
def test_resource_status_is_active(self) -> None:
pr = _make_pr(status="active")
result = _synthesize_webhook_payload(pr)
assert result["resource"]["status"] == "active"
def test_event_type_is_pr_updated(self) -> None:
pr = _make_pr()
result = _synthesize_webhook_payload(pr)
assert "event_type" in result
def test_subscription_id_present(self) -> None:
pr = _make_pr()
result = _synthesize_webhook_payload(pr)
assert "subscription_id" in result
def test_different_prs_produce_different_payloads(self) -> None:
pr1 = _make_pr(pr_id="1")
pr2 = _make_pr(pr_id="2")
r1 = _synthesize_webhook_payload(pr1)
r2 = _synthesize_webhook_payload(pr2)
assert r1["resource"]["pull_request_id"] != r2["resource"]["pull_request_id"]
# ---------------------------------------------------------------------------
# run_pr_poll_loop tests
# ---------------------------------------------------------------------------
class TestRunPrPollLoop:
async def test_calls_list_active_prs_for_each_repo(self) -> None:
azdo = AsyncMock()
azdo.list_active_prs = AsyncMock(return_value=[])
sleep_calls: list[float] = []
async def fake_sleep(seconds: float) -> None:
sleep_calls.append(seconds)
raise asyncio.CancelledError
with patch("release_agent.services.pr_poller.find_unprocessed_prs", new=AsyncMock(return_value=[])):
with pytest.raises(asyncio.CancelledError):
await run_pr_poll_loop(
azdo_client=azdo,
db_pool=MagicMock(),
watched_repos=["repo-a", "repo-b"],
target_branch="refs/heads/develop",
interval_seconds=30,
schedule_fn=MagicMock(),
sleep_fn=fake_sleep,
)
assert azdo.list_active_prs.call_count == 2
async def test_calls_find_unprocessed_prs(self) -> None:
pr = _make_pr(pr_id="10")
azdo = AsyncMock()
azdo.list_active_prs = AsyncMock(return_value=[pr])
find_mock = AsyncMock(return_value=[])
async def fake_sleep(seconds: float) -> None:
raise asyncio.CancelledError
with patch("release_agent.services.pr_poller.find_unprocessed_prs", new=find_mock):
with pytest.raises(asyncio.CancelledError):
await run_pr_poll_loop(
azdo_client=azdo,
db_pool=MagicMock(),
watched_repos=["my-repo"],
target_branch="refs/heads/develop",
interval_seconds=30,
schedule_fn=MagicMock(),
sleep_fn=fake_sleep,
)
find_mock.assert_called_once()
async def test_schedules_graph_for_each_unprocessed_pr(self) -> None:
pr1 = _make_pr(pr_id="10")
pr2 = _make_pr(pr_id="20")
azdo = AsyncMock()
azdo.list_active_prs = AsyncMock(return_value=[pr1, pr2])
schedule_mock = MagicMock()
async def fake_sleep(seconds: float) -> None:
raise asyncio.CancelledError
with patch(
"release_agent.services.pr_poller.find_unprocessed_prs",
new=AsyncMock(return_value=[pr1, pr2]),
):
with pytest.raises(asyncio.CancelledError):
await run_pr_poll_loop(
azdo_client=azdo,
db_pool=MagicMock(),
watched_repos=["my-repo"],
target_branch="refs/heads/develop",
interval_seconds=30,
schedule_fn=schedule_mock,
sleep_fn=fake_sleep,
)
assert schedule_mock.call_count == 2
async def test_does_not_schedule_already_processed_prs(self) -> None:
pr = _make_pr(pr_id="10")
azdo = AsyncMock()
azdo.list_active_prs = AsyncMock(return_value=[pr])
schedule_mock = MagicMock()
async def fake_sleep(seconds: float) -> None:
raise asyncio.CancelledError
# All PRs already processed
with patch(
"release_agent.services.pr_poller.find_unprocessed_prs",
new=AsyncMock(return_value=[]),
):
with pytest.raises(asyncio.CancelledError):
await run_pr_poll_loop(
azdo_client=azdo,
db_pool=MagicMock(),
watched_repos=["my-repo"],
target_branch="refs/heads/develop",
interval_seconds=30,
schedule_fn=schedule_mock,
sleep_fn=fake_sleep,
)
schedule_mock.assert_not_called()
async def test_sleeps_for_configured_interval(self) -> None:
azdo = AsyncMock()
azdo.list_active_prs = AsyncMock(return_value=[])
sleep_calls: list[float] = []
async def fake_sleep(seconds: float) -> None:
sleep_calls.append(seconds)
raise asyncio.CancelledError
with patch("release_agent.services.pr_poller.find_unprocessed_prs", new=AsyncMock(return_value=[])):
with pytest.raises(asyncio.CancelledError):
await run_pr_poll_loop(
azdo_client=azdo,
db_pool=MagicMock(),
watched_repos=["my-repo"],
target_branch="refs/heads/develop",
interval_seconds=123,
schedule_fn=MagicMock(),
sleep_fn=fake_sleep,
)
assert sleep_calls[0] == 123
async def test_handles_empty_watched_repos(self) -> None:
azdo = AsyncMock()
async def fake_sleep(seconds: float) -> None:
raise asyncio.CancelledError
with patch("release_agent.services.pr_poller.find_unprocessed_prs", new=AsyncMock(return_value=[])):
with pytest.raises(asyncio.CancelledError):
await run_pr_poll_loop(
azdo_client=azdo,
db_pool=MagicMock(),
watched_repos=[],
target_branch="refs/heads/develop",
interval_seconds=30,
schedule_fn=MagicMock(),
sleep_fn=fake_sleep,
)
azdo.list_active_prs.assert_not_called()
async def test_schedule_fn_receives_synthesized_payload(self) -> None:
pr = _make_pr(pr_id="55", repo_name="test-repo")
azdo = AsyncMock()
azdo.list_active_prs = AsyncMock(return_value=[pr])
schedule_calls: list[dict] = []
def schedule_mock(**kwargs) -> None:
schedule_calls.append(kwargs)
async def fake_sleep(seconds: float) -> None:
raise asyncio.CancelledError
with patch(
"release_agent.services.pr_poller.find_unprocessed_prs",
new=AsyncMock(return_value=[pr]),
):
with pytest.raises(asyncio.CancelledError):
await run_pr_poll_loop(
azdo_client=azdo,
db_pool=MagicMock(),
watched_repos=["test-repo"],
target_branch="refs/heads/develop",
interval_seconds=30,
schedule_fn=schedule_mock,
sleep_fn=fake_sleep,
)
assert len(schedule_calls) == 1
initial_state = schedule_calls[0]["initial_state"]
assert initial_state["webhook_payload"]["resource"]["pull_request_id"] == 55
assert initial_state["pr_id"] == "55"
assert initial_state["repo_name"] == "test-repo"
async def test_continues_after_list_active_prs_error(self) -> None:
azdo = AsyncMock()
# First repo raises, second succeeds
azdo.list_active_prs = AsyncMock(side_effect=[Exception("API error"), []])
sleep_calls: list[float] = []
async def fake_sleep(seconds: float) -> None:
sleep_calls.append(seconds)
raise asyncio.CancelledError
with patch("release_agent.services.pr_poller.find_unprocessed_prs", new=AsyncMock(return_value=[])):
with pytest.raises(asyncio.CancelledError):
await run_pr_poll_loop(
azdo_client=azdo,
db_pool=MagicMock(),
watched_repos=["repo-a", "repo-b"],
target_branch="refs/heads/develop",
interval_seconds=30,
schedule_fn=MagicMock(),
sleep_fn=fake_sleep,
)
# Should still sleep (loop iteration completed despite error)
assert len(sleep_calls) == 1

122
tests/test_branch_parser.py Normal file
View File

@@ -0,0 +1,122 @@
"""Tests for branch_parser module. Written FIRST (TDD RED phase)."""
from release_agent.branch_parser import parse_branch, strip_refs_prefix
class TestStripRefsPrefix:
"""Tests for strip_refs_prefix function."""
def test_strips_refs_heads_prefix(self) -> None:
assert strip_refs_prefix("refs/heads/fix/BILL-42_something") == "fix/BILL-42_something"
def test_strips_refs_heads_prefix_feature(self) -> None:
assert strip_refs_prefix("refs/heads/feature/ALLPOST-100_add-feature") == "feature/ALLPOST-100_add-feature"
def test_no_refs_prefix_unchanged(self) -> None:
assert strip_refs_prefix("bug/ALLPOST-4229_fix-review") == "bug/ALLPOST-4229_fix-review"
def test_main_unchanged(self) -> None:
assert strip_refs_prefix("main") == "main"
def test_develop_unchanged(self) -> None:
assert strip_refs_prefix("develop") == "develop"
def test_empty_string(self) -> None:
assert strip_refs_prefix("") == ""
def test_only_refs_heads(self) -> None:
assert strip_refs_prefix("refs/heads/") == ""
def test_refs_tags_not_stripped(self) -> None:
assert strip_refs_prefix("refs/tags/v1.0.0") == "refs/tags/v1.0.0"
class TestParseBranch:
"""Tests for parse_branch function."""
def test_bug_branch_with_ticket(self) -> None:
ticket_id, has_ticket = parse_branch("bug/ALLPOST-4229_fix-review")
assert ticket_id == "ALLPOST-4229"
assert has_ticket is True
def test_feature_branch_with_ticket(self) -> None:
ticket_id, has_ticket = parse_branch("feature/ALLPOST-100_add-feature")
assert ticket_id == "ALLPOST-100"
assert has_ticket is True
def test_refs_heads_fix_branch(self) -> None:
ticket_id, has_ticket = parse_branch("refs/heads/fix/BILL-42_something")
assert ticket_id == "BILL-42"
assert has_ticket is True
def test_feat_branch_short(self) -> None:
ticket_id, has_ticket = parse_branch("feat/MY-1_x")
assert ticket_id == "MY-1"
assert has_ticket is True
def test_chore_without_ticket(self) -> None:
ticket_id, has_ticket = parse_branch("chore/update-dependencies")
assert ticket_id is None
assert has_ticket is False
def test_main_branch(self) -> None:
ticket_id, has_ticket = parse_branch("main")
assert ticket_id is None
assert has_ticket is False
def test_develop_branch(self) -> None:
ticket_id, has_ticket = parse_branch("develop")
assert ticket_id is None
assert has_ticket is False
def test_release_branch(self) -> None:
ticket_id, has_ticket = parse_branch("release/v1.0.3")
assert ticket_id is None
assert has_ticket is False
def test_returns_tuple(self) -> None:
result = parse_branch("main")
assert isinstance(result, tuple)
assert len(result) == 2
def test_ticket_id_type_when_present(self) -> None:
ticket_id, has_ticket = parse_branch("bug/ALLPOST-4229_fix-review")
assert isinstance(ticket_id, str)
assert isinstance(has_ticket, bool)
def test_ticket_id_type_when_absent(self) -> None:
ticket_id, has_ticket = parse_branch("main")
assert ticket_id is None
assert isinstance(has_ticket, bool)
def test_fix_prefix(self) -> None:
ticket_id, has_ticket = parse_branch("fix/PROJ-999_some-fix")
assert ticket_id == "PROJ-999"
assert has_ticket is True
def test_refs_heads_feature_branch(self) -> None:
ticket_id, has_ticket = parse_branch("refs/heads/feature/ALLPOST-100_add-feature")
assert ticket_id == "ALLPOST-100"
assert has_ticket is True
def test_ticket_with_multiple_digits(self) -> None:
ticket_id, has_ticket = parse_branch("feature/ABC-12345_some-long-feature")
assert ticket_id == "ABC-12345"
assert has_ticket is True
def test_branch_without_underscore_separator(self) -> None:
# Branch has ticket pattern but no underscore - still detects ticket
ticket_id, has_ticket = parse_branch("feature/PROJ-100")
assert ticket_id == "PROJ-100"
assert has_ticket is True
def test_empty_string(self) -> None:
ticket_id, has_ticket = parse_branch("")
assert ticket_id is None
assert has_ticket is False
def test_ticket_with_numeric_project_prefix(self) -> None:
ticket_id, has_ticket = parse_branch("feature/AB2-100_feature")
assert ticket_id == "AB2-100"
assert has_ticket is True

450
tests/test_config.py Normal file
View File

@@ -0,0 +1,450 @@
"""Tests for config module. Written FIRST (TDD RED phase)."""
import os
from unittest.mock import patch
import pytest
from pydantic import SecretStr, ValidationError
from release_agent.config import Settings
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _base_env() -> dict[str, str]:
"""Return minimal valid environment variables."""
return {
"AZDO_ORGANIZATION": "my-org",
"AZDO_PROJECT": "my-project",
"AZDO_PAT": "super-secret-pat",
"ANTHROPIC_API_KEY": "sk-ant-key",
"POSTGRES_DSN": "postgresql://user:pass@localhost:5432/db",
"JIRA_EMAIL": "user@example.com",
"JIRA_API_TOKEN": "jira-token-abc",
"SLACK_WEBHOOK_URL": "https://hooks.slack.com/services/T000/B000/xxxx",
"WEBHOOK_SECRET": "test-webhook-secret",
}
# ---------------------------------------------------------------------------
# Settings tests
# ---------------------------------------------------------------------------
class TestSettings:
"""Tests for Settings config class."""
def test_loads_from_env_vars(self) -> None:
with patch.dict(os.environ, _base_env(), clear=True):
settings = Settings()
assert settings.azdo_organization == "my-org"
assert settings.azdo_project == "my-project"
def test_pat_is_secret_str(self) -> None:
with patch.dict(os.environ, _base_env(), clear=True):
settings = Settings()
assert isinstance(settings.azdo_pat, SecretStr)
def test_anthropic_key_is_secret_str(self) -> None:
with patch.dict(os.environ, _base_env(), clear=True):
settings = Settings()
assert isinstance(settings.anthropic_api_key, SecretStr)
def test_secret_str_not_leaked_in_repr(self) -> None:
with patch.dict(os.environ, _base_env(), clear=True):
settings = Settings()
repr_str = repr(settings)
assert "super-secret-pat" not in repr_str
assert "sk-ant-key" not in repr_str
def test_secret_str_not_leaked_in_str(self) -> None:
with patch.dict(os.environ, _base_env(), clear=True):
settings = Settings()
str_repr = str(settings)
assert "super-secret-pat" not in str_repr
def test_missing_required_azdo_org_raises(self) -> None:
env = _base_env()
del env["AZDO_ORGANIZATION"]
with patch.dict(os.environ, env, clear=True), pytest.raises(ValidationError):
Settings()
def test_missing_required_azdo_project_raises(self) -> None:
env = _base_env()
del env["AZDO_PROJECT"]
with patch.dict(os.environ, env, clear=True), pytest.raises(ValidationError):
Settings()
def test_missing_required_pat_raises(self) -> None:
env = _base_env()
del env["AZDO_PAT"]
with patch.dict(os.environ, env, clear=True), pytest.raises(ValidationError):
Settings()
def test_missing_anthropic_key_is_optional(self) -> None:
env = _base_env()
del env["ANTHROPIC_API_KEY"]
with patch.dict(os.environ, env, clear=True):
settings = Settings()
assert settings.anthropic_api_key.get_secret_value() == ""
def test_missing_postgres_dsn_raises(self) -> None:
env = _base_env()
del env["POSTGRES_DSN"]
with patch.dict(os.environ, env, clear=True), pytest.raises(ValidationError):
Settings()
def test_azdo_base_url_computed(self) -> None:
with patch.dict(os.environ, _base_env(), clear=True):
settings = Settings()
expected = "https://dev.azure.com/my-org"
assert settings.azdo_base_url == expected
def test_azdo_api_url_computed(self) -> None:
with patch.dict(os.environ, _base_env(), clear=True):
settings = Settings()
expected = "https://dev.azure.com/my-org/my-project/_apis"
assert settings.azdo_api_url == expected
def test_default_port(self) -> None:
with patch.dict(os.environ, _base_env(), clear=True):
settings = Settings()
assert settings.port == 8000
def test_custom_port_from_env(self) -> None:
env = {**_base_env(), "PORT": "9000"}
with patch.dict(os.environ, env, clear=True):
settings = Settings()
assert settings.port == 9000
def test_port_below_minimum_raises(self) -> None:
env = {**_base_env(), "PORT": "0"}
with patch.dict(os.environ, env, clear=True), pytest.raises(ValidationError):
Settings()
def test_port_above_maximum_raises(self) -> None:
env = {**_base_env(), "PORT": "65536"}
with patch.dict(os.environ, env, clear=True), pytest.raises(ValidationError):
Settings()
def test_port_minimum_valid(self) -> None:
env = {**_base_env(), "PORT": "1"}
with patch.dict(os.environ, env, clear=True):
settings = Settings()
assert settings.port == 1
def test_port_maximum_valid(self) -> None:
env = {**_base_env(), "PORT": "65535"}
with patch.dict(os.environ, env, clear=True):
settings = Settings()
assert settings.port == 65535
def test_get_pat_value(self) -> None:
with patch.dict(os.environ, _base_env(), clear=True):
settings = Settings()
assert settings.azdo_pat.get_secret_value() == "super-secret-pat"
def test_get_anthropic_key_value(self) -> None:
with patch.dict(os.environ, _base_env(), clear=True):
settings = Settings()
assert settings.anthropic_api_key.get_secret_value() == "sk-ant-key"
def test_postgres_dsn_is_secret_str(self) -> None:
with patch.dict(os.environ, _base_env(), clear=True):
settings = Settings()
assert isinstance(settings.postgres_dsn, SecretStr)
assert "localhost" in settings.postgres_dsn.get_secret_value()
def test_postgres_dsn_not_leaked_in_repr(self) -> None:
with patch.dict(os.environ, _base_env(), clear=True):
settings = Settings()
assert "user:pass" not in repr(settings)
class TestSettingsPhase2:
"""Tests for Phase 2 settings fields."""
def test_jira_base_url_default(self) -> None:
with patch.dict(os.environ, _base_env(), clear=True):
settings = Settings()
assert settings.jira_base_url == "https://billolife.atlassian.net"
def test_jira_base_url_custom(self) -> None:
env = {**_base_env(), "JIRA_BASE_URL": "https://custom.atlassian.net"}
with patch.dict(os.environ, env, clear=True):
settings = Settings()
assert settings.jira_base_url == "https://custom.atlassian.net"
def test_jira_email_required(self) -> None:
env = _base_env()
del env["JIRA_EMAIL"]
with patch.dict(os.environ, env, clear=True), pytest.raises(ValidationError):
Settings()
def test_jira_email_stored(self) -> None:
with patch.dict(os.environ, _base_env(), clear=True):
settings = Settings()
assert settings.jira_email == "user@example.com"
def test_jira_api_token_required(self) -> None:
env = _base_env()
del env["JIRA_API_TOKEN"]
with patch.dict(os.environ, env, clear=True), pytest.raises(ValidationError):
Settings()
def test_jira_api_token_is_secret_str(self) -> None:
with patch.dict(os.environ, _base_env(), clear=True):
settings = Settings()
assert isinstance(settings.jira_api_token, SecretStr)
def test_jira_api_token_not_leaked_in_repr(self) -> None:
with patch.dict(os.environ, _base_env(), clear=True):
settings = Settings()
assert "jira-token-abc" not in repr(settings)
def test_slack_webhook_url_optional_defaults_empty(self) -> None:
env = _base_env()
del env["SLACK_WEBHOOK_URL"]
with patch.dict(os.environ, env, clear=True):
settings = Settings()
assert settings.slack_webhook_url.get_secret_value() == ""
def test_slack_webhook_url_is_secret_str(self) -> None:
with patch.dict(os.environ, _base_env(), clear=True):
settings = Settings()
assert isinstance(settings.slack_webhook_url, SecretStr)
def test_slack_webhook_url_not_leaked_in_repr(self) -> None:
with patch.dict(os.environ, _base_env(), clear=True):
settings = Settings()
assert "xxxx" not in repr(settings)
def test_claude_review_model_default(self) -> None:
with patch.dict(os.environ, _base_env(), clear=True):
settings = Settings()
assert settings.claude_review_model == "claude-sonnet-4-20250514"
def test_claude_review_model_custom(self) -> None:
env = {**_base_env(), "CLAUDE_REVIEW_MODEL": "claude-opus-4-20250514"}
with patch.dict(os.environ, env, clear=True):
settings = Settings()
assert settings.claude_review_model == "claude-opus-4-20250514"
def test_azdo_vsrm_api_url_computed(self) -> None:
with patch.dict(os.environ, _base_env(), clear=True):
settings = Settings()
expected = "https://vsrm.dev.azure.com/my-org/my-project/_apis"
assert settings.azdo_vsrm_api_url == expected
class TestSettingsPhase4:
"""Tests for Phase 4 settings fields (webhook secret)."""
def test_webhook_secret_optional_defaults_empty(self) -> None:
with patch.dict(os.environ, _base_env(), clear=True):
settings = Settings()
# When not provided, defaults to empty string or None
assert settings.webhook_secret is not None or settings.webhook_secret == ""
def test_webhook_secret_custom_value(self) -> None:
env = {**_base_env(), "WEBHOOK_SECRET": "my-super-secret"}
with patch.dict(os.environ, env, clear=True):
settings = Settings()
assert settings.webhook_secret.get_secret_value() == "my-super-secret"
def test_webhook_secret_is_secret_str(self) -> None:
env = {**_base_env(), "WEBHOOK_SECRET": "secret-value"}
with patch.dict(os.environ, env, clear=True):
settings = Settings()
assert isinstance(settings.webhook_secret, SecretStr)
def test_webhook_secret_not_leaked_in_repr(self) -> None:
env = {**_base_env(), "WEBHOOK_SECRET": "super-private-secret"}
with patch.dict(os.environ, env, clear=True):
settings = Settings()
assert "super-private-secret" not in repr(settings)
class TestSettingsPhase5:
"""Tests for Phase 5 settings fields (Slack Web API + CI polling)."""
def test_slack_webhook_url_optional_when_bot_token_provided(self) -> None:
env = {k: v for k, v in _base_env().items() if k != "SLACK_WEBHOOK_URL"}
env["SLACK_BOT_TOKEN"] = "xoxb-test-token"
env["SLACK_CHANNEL_ID"] = "C12345678"
with patch.dict(os.environ, env, clear=True):
settings = Settings()
assert settings.slack_bot_token is not None
assert settings.slack_bot_token.get_secret_value() == "xoxb-test-token"
def test_slack_bot_token_optional_defaults_empty(self) -> None:
with patch.dict(os.environ, _base_env(), clear=True):
settings = Settings()
assert settings.slack_bot_token.get_secret_value() == ""
def test_slack_bot_token_is_secret_str(self) -> None:
env = {**_base_env(), "SLACK_BOT_TOKEN": "xoxb-abc-123"}
with patch.dict(os.environ, env, clear=True):
settings = Settings()
assert isinstance(settings.slack_bot_token, SecretStr)
def test_slack_bot_token_not_leaked_in_repr(self) -> None:
env = {**_base_env(), "SLACK_BOT_TOKEN": "xoxb-super-secret"}
with patch.dict(os.environ, env, clear=True):
settings = Settings()
assert "xoxb-super-secret" not in repr(settings)
def test_slack_signing_secret_optional_defaults_empty(self) -> None:
with patch.dict(os.environ, _base_env(), clear=True):
settings = Settings()
assert settings.slack_signing_secret.get_secret_value() == ""
def test_slack_signing_secret_custom_value(self) -> None:
env = {**_base_env(), "SLACK_SIGNING_SECRET": "signing-secret-xyz"}
with patch.dict(os.environ, env, clear=True):
settings = Settings()
assert settings.slack_signing_secret.get_secret_value() == "signing-secret-xyz"
def test_slack_signing_secret_is_secret_str(self) -> None:
env = {**_base_env(), "SLACK_SIGNING_SECRET": "some-secret"}
with patch.dict(os.environ, env, clear=True):
settings = Settings()
assert isinstance(settings.slack_signing_secret, SecretStr)
def test_slack_signing_secret_not_leaked_in_repr(self) -> None:
env = {**_base_env(), "SLACK_SIGNING_SECRET": "private-signing-secret"}
with patch.dict(os.environ, env, clear=True):
settings = Settings()
assert "private-signing-secret" not in repr(settings)
def test_slack_channel_id_optional_defaults_empty(self) -> None:
with patch.dict(os.environ, _base_env(), clear=True):
settings = Settings()
assert settings.slack_channel_id == ""
def test_slack_channel_id_custom_value(self) -> None:
env = {**_base_env(), "SLACK_CHANNEL_ID": "C0987654321"}
with patch.dict(os.environ, env, clear=True):
settings = Settings()
assert settings.slack_channel_id == "C0987654321"
def test_ci_poll_interval_seconds_default(self) -> None:
with patch.dict(os.environ, _base_env(), clear=True):
settings = Settings()
assert settings.ci_poll_interval_seconds == 30
def test_ci_poll_interval_seconds_custom(self) -> None:
env = {**_base_env(), "CI_POLL_INTERVAL_SECONDS": "60"}
with patch.dict(os.environ, env, clear=True):
settings = Settings()
assert settings.ci_poll_interval_seconds == 60
def test_ci_poll_max_wait_seconds_default(self) -> None:
with patch.dict(os.environ, _base_env(), clear=True):
settings = Settings()
assert settings.ci_poll_max_wait_seconds == 1800
def test_ci_poll_max_wait_seconds_custom(self) -> None:
env = {**_base_env(), "CI_POLL_MAX_WAIT_SECONDS": "3600"}
with patch.dict(os.environ, env, clear=True):
settings = Settings()
assert settings.ci_poll_max_wait_seconds == 3600
def test_slack_webhook_url_still_optional(self) -> None:
env = {k: v for k, v in _base_env().items() if k != "SLACK_WEBHOOK_URL"}
with patch.dict(os.environ, env, clear=True):
settings = Settings()
assert settings.slack_webhook_url.get_secret_value() == ""
class TestSettingsPrPolling:
"""Tests for PR polling config fields (Step 1)."""
def test_watched_repos_defaults_empty(self) -> None:
with patch.dict(os.environ, _base_env(), clear=True):
settings = Settings()
assert settings.watched_repos == ""
def test_watched_repos_custom_value(self) -> None:
env = {**_base_env(), "WATCHED_REPOS": "repo-a,repo-b"}
with patch.dict(os.environ, env, clear=True):
settings = Settings()
assert settings.watched_repos == "repo-a,repo-b"
def test_watched_repos_list_empty_when_blank(self) -> None:
with patch.dict(os.environ, _base_env(), clear=True):
settings = Settings()
assert settings.watched_repos_list == []
def test_watched_repos_list_splits_comma_separated(self) -> None:
env = {**_base_env(), "WATCHED_REPOS": "repo-a,repo-b,repo-c"}
with patch.dict(os.environ, env, clear=True):
settings = Settings()
assert settings.watched_repos_list == ["repo-a", "repo-b", "repo-c"]
def test_watched_repos_list_strips_whitespace(self) -> None:
env = {**_base_env(), "WATCHED_REPOS": " repo-a , repo-b "}
with patch.dict(os.environ, env, clear=True):
settings = Settings()
assert settings.watched_repos_list == ["repo-a", "repo-b"]
def test_watched_repos_list_ignores_empty_entries(self) -> None:
env = {**_base_env(), "WATCHED_REPOS": "repo-a,,repo-b"}
with patch.dict(os.environ, env, clear=True):
settings = Settings()
assert settings.watched_repos_list == ["repo-a", "repo-b"]
def test_pr_poll_interval_seconds_default(self) -> None:
with patch.dict(os.environ, _base_env(), clear=True):
settings = Settings()
assert settings.pr_poll_interval_seconds == 300
def test_pr_poll_interval_seconds_custom(self) -> None:
env = {**_base_env(), "PR_POLL_INTERVAL_SECONDS": "60"}
with patch.dict(os.environ, env, clear=True):
settings = Settings()
assert settings.pr_poll_interval_seconds == 60
def test_pr_poll_target_branch_default(self) -> None:
with patch.dict(os.environ, _base_env(), clear=True):
settings = Settings()
assert settings.pr_poll_target_branch == "refs/heads/develop"
def test_pr_poll_target_branch_custom(self) -> None:
env = {**_base_env(), "PR_POLL_TARGET_BRANCH": "refs/heads/main"}
with patch.dict(os.environ, env, clear=True):
settings = Settings()
assert settings.pr_poll_target_branch == "refs/heads/main"
def test_pr_poll_enabled_defaults_false(self) -> None:
with patch.dict(os.environ, _base_env(), clear=True):
settings = Settings()
assert settings.pr_poll_enabled is False
def test_pr_poll_enabled_true_from_env(self) -> None:
env = {**_base_env(), "PR_POLL_ENABLED": "true"}
with patch.dict(os.environ, env, clear=True):
settings = Settings()
assert settings.pr_poll_enabled is True
def test_default_jira_project_default(self) -> None:
with patch.dict(os.environ, _base_env(), clear=True):
settings = Settings()
assert settings.default_jira_project == "ALLPOST"
def test_default_jira_project_custom(self) -> None:
env = {**_base_env(), "DEFAULT_JIRA_PROJECT": "MYPROJ"}
with patch.dict(os.environ, env, clear=True):
settings = Settings()
assert settings.default_jira_project == "MYPROJ"
def test_auto_create_ticket_enabled_defaults_true(self) -> None:
with patch.dict(os.environ, _base_env(), clear=True):
settings = Settings()
assert settings.auto_create_ticket_enabled is True
def test_auto_create_ticket_enabled_false_from_env(self) -> None:
env = {**_base_env(), "AUTO_CREATE_TICKET_ENABLED": "false"}
with patch.dict(os.environ, env, clear=True):
settings = Settings()
assert settings.auto_create_ticket_enabled is False

198
tests/test_exceptions.py Normal file
View File

@@ -0,0 +1,198 @@
"""Tests for custom exception hierarchy. Written FIRST (TDD RED phase)."""
import pytest
from release_agent.exceptions import (
AuthenticationError,
NotFoundError,
RateLimitError,
ReleaseAgentError,
ServiceError,
ServiceUnavailableError,
)
class TestReleaseAgentError:
"""Tests for the base exception class."""
def test_is_exception(self) -> None:
err = ReleaseAgentError("something went wrong")
assert isinstance(err, Exception)
def test_message_stored(self) -> None:
err = ReleaseAgentError("something went wrong")
assert str(err) == "something went wrong"
def test_can_be_raised(self) -> None:
with pytest.raises(ReleaseAgentError):
raise ReleaseAgentError("boom")
class TestServiceError:
"""Tests for ServiceError with service name and status code."""
def test_is_release_agent_error(self) -> None:
err = ServiceError(service="jira", status_code=500, detail="Internal error")
assert isinstance(err, ReleaseAgentError)
def test_stores_service(self) -> None:
err = ServiceError(service="jira", status_code=500, detail="Internal error")
assert err.service == "jira"
def test_stores_status_code(self) -> None:
err = ServiceError(service="azdo", status_code=422, detail="Unprocessable")
assert err.status_code == 422
def test_stores_detail(self) -> None:
err = ServiceError(service="slack", status_code=400, detail="Bad payload")
assert err.detail == "Bad payload"
def test_str_includes_service_and_status(self) -> None:
err = ServiceError(service="jira", status_code=500, detail="Server error")
text = str(err)
assert "jira" in text
assert "500" in text
def test_can_be_raised(self) -> None:
with pytest.raises(ServiceError):
raise ServiceError(service="azdo", status_code=400, detail="bad request")
def test_detail_none_allowed(self) -> None:
err = ServiceError(service="jira", status_code=404, detail=None)
assert err.detail is None
class TestAuthenticationError:
"""Tests for AuthenticationError (401/403)."""
def test_is_service_error(self) -> None:
err = AuthenticationError(service="azdo")
assert isinstance(err, ServiceError)
def test_default_status_code_401(self) -> None:
err = AuthenticationError(service="azdo")
assert err.status_code == 401
def test_service_stored(self) -> None:
err = AuthenticationError(service="jira")
assert err.service == "jira"
def test_custom_status_code(self) -> None:
err = AuthenticationError(service="azdo", status_code=403)
assert err.status_code == 403
def test_str_contains_service(self) -> None:
err = AuthenticationError(service="slack")
assert "slack" in str(err)
def test_can_be_raised(self) -> None:
with pytest.raises(AuthenticationError):
raise AuthenticationError(service="azdo")
def test_caught_as_service_error(self) -> None:
with pytest.raises(ServiceError):
raise AuthenticationError(service="azdo")
class TestNotFoundError:
"""Tests for NotFoundError (404)."""
def test_is_service_error(self) -> None:
err = NotFoundError(service="azdo", detail="PR not found")
assert isinstance(err, ServiceError)
def test_status_code_is_404(self) -> None:
err = NotFoundError(service="jira", detail="Issue not found")
assert err.status_code == 404
def test_detail_stored(self) -> None:
err = NotFoundError(service="azdo", detail="PR 999 not found")
assert "PR 999" in err.detail
def test_can_be_raised(self) -> None:
with pytest.raises(NotFoundError):
raise NotFoundError(service="azdo", detail="not found")
def test_caught_as_release_agent_error(self) -> None:
with pytest.raises(ReleaseAgentError):
raise NotFoundError(service="jira", detail="issue missing")
class TestRateLimitError:
"""Tests for RateLimitError (429) with retry_after."""
def test_is_service_error(self) -> None:
err = RateLimitError(service="jira", retry_after=30)
assert isinstance(err, ServiceError)
def test_status_code_is_429(self) -> None:
err = RateLimitError(service="jira", retry_after=30)
assert err.status_code == 429
def test_stores_retry_after(self) -> None:
err = RateLimitError(service="slack", retry_after=60)
assert err.retry_after == 60
def test_retry_after_none_allowed(self) -> None:
err = RateLimitError(service="azdo", retry_after=None)
assert err.retry_after is None
def test_str_contains_service(self) -> None:
err = RateLimitError(service="jira", retry_after=5)
assert "jira" in str(err)
def test_can_be_raised(self) -> None:
with pytest.raises(RateLimitError):
raise RateLimitError(service="jira", retry_after=30)
class TestServiceUnavailableError:
"""Tests for ServiceUnavailableError (503)."""
def test_is_service_error(self) -> None:
err = ServiceUnavailableError(service="azdo")
assert isinstance(err, ServiceError)
def test_status_code_is_503(self) -> None:
err = ServiceUnavailableError(service="azdo")
assert err.status_code == 503
def test_service_stored(self) -> None:
err = ServiceUnavailableError(service="slack")
assert err.service == "slack"
def test_custom_detail(self) -> None:
err = ServiceUnavailableError(service="azdo", detail="Maintenance window")
assert "Maintenance" in err.detail
def test_can_be_raised(self) -> None:
with pytest.raises(ServiceUnavailableError):
raise ServiceUnavailableError(service="azdo")
def test_caught_as_service_error(self) -> None:
with pytest.raises(ServiceError):
raise ServiceUnavailableError(service="jira")
class TestExceptionHierarchyInheritance:
"""Tests verifying the full exception hierarchy is correct."""
def test_all_are_release_agent_errors(self) -> None:
errors = [
AuthenticationError(service="svc"),
NotFoundError(service="svc", detail="x"),
RateLimitError(service="svc", retry_after=1),
ServiceUnavailableError(service="svc"),
]
for err in errors:
assert isinstance(err, ReleaseAgentError), f"{type(err)} not ReleaseAgentError"
def test_all_are_service_errors(self) -> None:
errors = [
AuthenticationError(service="svc"),
NotFoundError(service="svc", detail="x"),
RateLimitError(service="svc", retry_after=1),
ServiceUnavailableError(service="svc"),
]
for err in errors:
assert isinstance(err, ServiceError), f"{type(err)} not ServiceError"

666
tests/test_main.py Normal file
View File

@@ -0,0 +1,666 @@
"""Tests for main FastAPI application. Written FIRST (TDD RED phase).
Heavy startup (PostgreSQL, httpx clients, graph compilation) is mocked.
Tests verify: routes registered, lifespan hooks, exception handlers.
"""
import asyncio
from contextlib import asynccontextmanager
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
# ---------------------------------------------------------------------------
# Helpers / fixtures
# ---------------------------------------------------------------------------
def _make_mock_settings():
s = MagicMock()
s.webhook_secret.get_secret_value.return_value = "test-secret"
s.postgres_dsn.get_secret_value.return_value = "postgresql://u:p@localhost/db"
s.azdo_pat.get_secret_value.return_value = "pat"
s.anthropic_api_key.get_secret_value.return_value = "key"
s.jira_api_token.get_secret_value.return_value = "jira"
s.slack_webhook_url.get_secret_value.return_value = "https://hooks.slack.com/x"
s.slack_bot_token.get_secret_value.return_value = ""
s.slack_channel_id = ""
s.slack_signing_secret.get_secret_value.return_value = ""
s.port = 8000
s.pr_poll_enabled = False
s.pr_poll_interval_seconds = 300
s.pr_poll_target_branch = "refs/heads/develop"
s.watched_repos_list = []
s.default_jira_project = "ALLPOST"
return s
def _make_patched_app():
"""Return the FastAPI app with all heavy startup mocked."""
mock_settings = _make_mock_settings()
mock_pool = AsyncMock()
mock_pool.__aenter__ = AsyncMock(return_value=mock_pool)
mock_pool.__aexit__ = AsyncMock(return_value=False)
mock_graphs = {
"pr_completed": MagicMock(),
"release": MagicMock(),
}
mock_clients = MagicMock()
mock_staging_store = MagicMock()
patches = [
patch("release_agent.main.Settings", return_value=mock_settings),
patch("release_agent.main.build_pr_completed_graph", return_value=mock_graphs["pr_completed"]),
patch("release_agent.main.build_release_graph", return_value=mock_graphs["release"]),
patch("release_agent.main.AsyncConnectionPool", return_value=mock_pool),
patch("release_agent.main._create_tool_clients", return_value=mock_clients),
patch("release_agent.main._create_staging_store", return_value=mock_staging_store),
patch("release_agent.main._ensure_db_schema", new_callable=AsyncMock),
]
for p in patches:
p.start()
from release_agent.main import create_app
app = create_app()
for p in patches:
p.stop()
return app, mock_settings, mock_pool, mock_graphs
# ---------------------------------------------------------------------------
# Route registration tests
# ---------------------------------------------------------------------------
class TestRouteRegistration:
def test_webhook_route_registered(self) -> None:
from release_agent.main import create_app
with (
patch("release_agent.main.Settings", return_value=_make_mock_settings()),
patch("release_agent.main.build_pr_completed_graph", return_value=MagicMock()),
patch("release_agent.main.build_release_graph", return_value=MagicMock()),
patch("release_agent.main.AsyncConnectionPool", return_value=AsyncMock()),
patch("release_agent.main._create_tool_clients", return_value=MagicMock()),
patch("release_agent.main._create_staging_store", return_value=MagicMock()),
patch("release_agent.main._ensure_db_schema", new_callable=AsyncMock),
):
app = create_app()
routes = {r.path for r in app.routes} # type: ignore[attr-defined]
assert "/webhooks/azdo" in routes
def test_approvals_routes_registered(self) -> None:
from release_agent.main import create_app
with (
patch("release_agent.main.Settings", return_value=_make_mock_settings()),
patch("release_agent.main.build_pr_completed_graph", return_value=MagicMock()),
patch("release_agent.main.build_release_graph", return_value=MagicMock()),
patch("release_agent.main.AsyncConnectionPool", return_value=AsyncMock()),
patch("release_agent.main._create_tool_clients", return_value=MagicMock()),
patch("release_agent.main._create_staging_store", return_value=MagicMock()),
patch("release_agent.main._ensure_db_schema", new_callable=AsyncMock),
):
app = create_app()
routes = {r.path for r in app.routes} # type: ignore[attr-defined]
assert "/approvals/pending" in routes
assert "/approvals/{thread_id}" in routes
def test_status_routes_registered(self) -> None:
from release_agent.main import create_app
with (
patch("release_agent.main.Settings", return_value=_make_mock_settings()),
patch("release_agent.main.build_pr_completed_graph", return_value=MagicMock()),
patch("release_agent.main.build_release_graph", return_value=MagicMock()),
patch("release_agent.main.AsyncConnectionPool", return_value=AsyncMock()),
patch("release_agent.main._create_tool_clients", return_value=MagicMock()),
patch("release_agent.main._create_staging_store", return_value=MagicMock()),
patch("release_agent.main._ensure_db_schema", new_callable=AsyncMock),
):
app = create_app()
routes = {r.path for r in app.routes} # type: ignore[attr-defined]
assert "/status" in routes
assert "/staging" in routes
# ---------------------------------------------------------------------------
# schedule_graph / run_graph_in_background tests
# ---------------------------------------------------------------------------
class TestScheduleGraph:
def test_schedule_graph_returns_thread_id(self) -> None:
from release_agent.main import schedule_graph
mock_app = MagicMock()
mock_app.state.background_tasks = set()
mock_graph = MagicMock()
with patch("release_agent.main.asyncio.create_task", return_value=MagicMock()):
thread_id = schedule_graph(
app=mock_app,
graph=mock_graph,
initial_state={"repo_name": "my-repo"},
thread_id=None,
)
assert isinstance(thread_id, str)
assert len(thread_id) > 0
def test_schedule_graph_uses_provided_thread_id(self) -> None:
from release_agent.main import schedule_graph
mock_app = MagicMock()
mock_app.state.background_tasks = set()
mock_graph = MagicMock()
with patch("release_agent.main.asyncio.create_task", return_value=MagicMock()):
thread_id = schedule_graph(
app=mock_app,
graph=mock_graph,
initial_state={},
thread_id="custom-thread-id",
)
assert thread_id == "custom-thread-id"
def test_schedule_graph_adds_task_to_background_tasks(self) -> None:
from release_agent.main import schedule_graph
mock_app = MagicMock()
mock_app.state.background_tasks = set()
mock_graph = MagicMock()
mock_task = MagicMock()
with patch("release_agent.main.asyncio.create_task", return_value=mock_task):
schedule_graph(
app=mock_app,
graph=mock_graph,
initial_state={},
thread_id=None,
)
assert mock_task in mock_app.state.background_tasks
def test_run_graph_in_background_is_coroutine(self) -> None:
from release_agent.main import run_graph_in_background
import inspect
assert inspect.iscoroutinefunction(run_graph_in_background)
# ---------------------------------------------------------------------------
# _ensure_db_schema tests
# ---------------------------------------------------------------------------
class TestEnsureDbSchema:
@pytest.mark.asyncio
async def test_ensure_db_schema_creates_table(self) -> None:
from release_agent.main import _ensure_db_schema
mock_pool = AsyncMock()
mock_conn = AsyncMock()
mock_cursor = AsyncMock()
mock_cursor.execute = AsyncMock()
mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor)
mock_cursor.__aexit__ = AsyncMock(return_value=False)
mock_conn.cursor = MagicMock(return_value=mock_cursor)
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
mock_conn.__aexit__ = AsyncMock(return_value=False)
mock_pool.connection = MagicMock(return_value=mock_conn)
await _ensure_db_schema(mock_pool)
# Phase 5: now executes multiple DDL statements (agent_threads +
# staging_releases + archived_releases), so called_once no longer holds.
assert mock_cursor.execute.call_count >= 1
all_sql = " ".join(
call.args[0] for call in mock_cursor.execute.call_args_list
)
assert "agent_threads" in all_sql
# ---------------------------------------------------------------------------
# _create_tool_clients tests
# ---------------------------------------------------------------------------
class TestCreateToolClients:
def test_create_tool_clients_returns_tool_clients_instance(self) -> None:
from release_agent.main import _create_tool_clients
from release_agent.graph.dependencies import ToolClients
mock_settings = _make_mock_settings()
with (
patch("release_agent.main.AzDoClient") as mock_azdo,
patch("release_agent.main.JiraClient") as mock_jira,
patch("release_agent.main.SlackClient") as mock_slack,
patch("release_agent.main.ClaudeReviewer") as mock_reviewer,
patch("release_agent.main.httpx.AsyncClient") as mock_httpx,
):
clients, http_clients = _create_tool_clients(mock_settings)
assert isinstance(clients, ToolClients)
# ---------------------------------------------------------------------------
# _create_staging_store tests
# ---------------------------------------------------------------------------
class TestCreateStagingStore:
def test_create_staging_store_returns_store(self) -> None:
from release_agent.main import _create_staging_store
from release_agent.graph.dependencies import JsonFileStagingStore
result = _create_staging_store()
assert isinstance(result, JsonFileStagingStore)
# ---------------------------------------------------------------------------
# Global exception handler tests
# ---------------------------------------------------------------------------
class TestExceptionHandlers:
def test_app_has_exception_handlers(self) -> None:
from release_agent.main import create_app
with (
patch("release_agent.main.Settings", return_value=_make_mock_settings()),
patch("release_agent.main.build_pr_completed_graph", return_value=MagicMock()),
patch("release_agent.main.build_release_graph", return_value=MagicMock()),
patch("release_agent.main.AsyncConnectionPool", return_value=AsyncMock()),
patch("release_agent.main._create_tool_clients", return_value=MagicMock()),
patch("release_agent.main._create_staging_store", return_value=MagicMock()),
patch("release_agent.main._ensure_db_schema", new_callable=AsyncMock),
):
app = create_app()
# FastAPI stores exception handlers in exception_handlers attribute
assert hasattr(app, "exception_handlers")
# ---------------------------------------------------------------------------
# Lifespan tests
# ---------------------------------------------------------------------------
class TestGracefulShutdown:
@pytest.mark.asyncio
async def test_lifespan_cancels_timed_out_tasks(self) -> None:
"""Verify the lifespan waits for tasks and cancels timed-out ones."""
from release_agent.main import lifespan
mock_pool = AsyncMock()
mock_pool.open = AsyncMock()
mock_pool.close = AsyncMock()
mock_conn = AsyncMock()
mock_cursor = AsyncMock()
mock_cursor.execute = AsyncMock()
mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor)
mock_cursor.__aexit__ = AsyncMock(return_value=False)
mock_conn.cursor = MagicMock(return_value=mock_cursor)
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
mock_conn.__aexit__ = AsyncMock(return_value=False)
mock_pool.connection = MagicMock(return_value=mock_conn)
from fastapi import FastAPI
app = FastAPI()
app.state.background_tasks = set()
mock_settings = _make_mock_settings()
mock_task = MagicMock()
mock_task.cancel = MagicMock()
with (
patch("release_agent.main.Settings", return_value=mock_settings),
patch("release_agent.main.build_pr_completed_graph", return_value=MagicMock()),
patch("release_agent.main.build_release_graph", return_value=MagicMock()),
patch("release_agent.main.AsyncConnectionPool", return_value=mock_pool),
patch("release_agent.main._create_tool_clients", return_value=(MagicMock(), [])),
patch("release_agent.main._create_staging_store", return_value=MagicMock()),
patch("release_agent.main._ensure_db_schema", new_callable=AsyncMock),
patch(
"release_agent.main.asyncio.wait",
new_callable=AsyncMock,
return_value=(set(), {mock_task}),
),
):
ctx = lifespan(app)
await ctx.__aenter__()
# Add a fake task to background_tasks after startup
app.state.background_tasks.add(mock_task)
await ctx.__aexit__(None, None, None)
# The pending task should have been cancelled
mock_task.cancel.assert_called_once()
class TestLifespan:
def test_app_state_set_after_lifespan(self) -> None:
"""Verify app.state.graphs and app.state.settings are set during lifespan."""
from release_agent.main import create_app
mock_pool = AsyncMock()
mock_pool.__aenter__ = AsyncMock(return_value=mock_pool)
mock_pool.__aexit__ = AsyncMock(return_value=False)
mock_pool.open = AsyncMock()
mock_pool.close = AsyncMock()
mock_conn = AsyncMock()
mock_cursor = AsyncMock()
mock_cursor.execute = AsyncMock()
mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor)
mock_cursor.__aexit__ = AsyncMock(return_value=False)
mock_conn.cursor = MagicMock(return_value=mock_cursor)
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
mock_conn.__aexit__ = AsyncMock(return_value=False)
mock_pool.connection = MagicMock(return_value=mock_conn)
mock_settings = _make_mock_settings()
mock_graphs = {"pr_completed": MagicMock(), "release": MagicMock()}
mock_clients = MagicMock()
mock_staging_store = MagicMock()
with (
patch("release_agent.main.Settings", return_value=mock_settings),
patch("release_agent.main.build_pr_completed_graph", return_value=mock_graphs["pr_completed"]),
patch("release_agent.main.build_release_graph", return_value=mock_graphs["release"]),
patch("release_agent.main.AsyncConnectionPool", return_value=mock_pool),
patch("release_agent.main._create_tool_clients", return_value=(mock_clients, [])),
patch("release_agent.main._create_staging_store", return_value=mock_staging_store),
patch("release_agent.main._ensure_db_schema", new_callable=AsyncMock),
):
app = create_app()
with TestClient(app) as client:
# App is started; state should be accessible
response = client.get("/status")
# We just verify no crash
assert response.status_code in (200, 500)
# ---------------------------------------------------------------------------
# Phase 5: Slack interactions route + new config tests
# ---------------------------------------------------------------------------
class TestPhase5Routes:
"""Tests for Phase 5 additions to main.py."""
def _make_patches(self):
mock_settings = _make_mock_settings()
mock_pool = AsyncMock()
mock_pool.open = AsyncMock()
mock_pool.close = AsyncMock()
mock_pool.__aenter__ = AsyncMock(return_value=mock_pool)
mock_pool.__aexit__ = AsyncMock(return_value=False)
return [
patch("release_agent.main.Settings", return_value=mock_settings),
patch("release_agent.main.build_pr_completed_graph", return_value=MagicMock()),
patch("release_agent.main.build_release_graph", return_value=MagicMock()),
patch("release_agent.main.AsyncConnectionPool", return_value=mock_pool),
patch("release_agent.main._create_tool_clients", return_value=MagicMock()),
patch("release_agent.main._create_staging_store", return_value=MagicMock()),
patch("release_agent.main._ensure_db_schema", new_callable=AsyncMock),
]
def test_slack_interactions_route_registered(self) -> None:
from release_agent.main import create_app
patches = self._make_patches()
for p in patches:
p.start()
try:
app = create_app()
finally:
for p in patches:
p.stop()
routes = {r.path for r in app.routes} # type: ignore[attr-defined]
assert "/slack/interactions" in routes
def test_create_tool_clients_uses_bot_token(self) -> None:
from release_agent.main import _create_tool_clients
mock_settings = _make_mock_settings()
mock_settings.slack_bot_token.get_secret_value.return_value = "xoxb-test"
mock_settings.slack_channel_id = "C12345"
mock_settings.slack_webhook_url.get_secret_value.return_value = ""
mock_settings.azdo_api_url = "https://dev.azure.com/org/proj/_apis"
mock_settings.azdo_vsrm_api_url = "https://vsrm.dev.azure.com/org/proj/_apis"
mock_settings.jira_base_url = "https://example.atlassian.net"
mock_settings.jira_email = "test@example.com"
# Should not raise
clients, http_clients = _create_tool_clients(mock_settings)
assert clients is not None
# Clean up
for hc in http_clients:
asyncio.get_event_loop().run_until_complete(hc.aclose())
class TestPhase5DbSchema:
"""Tests that _ensure_db_schema adds the slack_message_ts column."""
async def test_ensure_db_schema_executes_sql_statements(self) -> None:
from release_agent.main import _ensure_db_schema
mock_pool = MagicMock()
mock_conn = AsyncMock()
mock_cursor = AsyncMock()
executed_sql: list[str] = []
async def capture_execute(sql, *args):
executed_sql.append(sql.strip())
mock_cursor.execute = capture_execute
mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor)
mock_cursor.__aexit__ = AsyncMock(return_value=False)
mock_conn.cursor = MagicMock(return_value=mock_cursor)
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
mock_conn.__aexit__ = AsyncMock(return_value=False)
mock_pool.connection = MagicMock(return_value=mock_conn)
await _ensure_db_schema(mock_pool)
# Should have executed CREATE TABLE statements
assert len(executed_sql) >= 3
combined = " ".join(executed_sql)
assert "agent_threads" in combined
assert "staging_releases" in combined
async def test_ensure_db_schema_includes_slack_message_ts_column(self) -> None:
from release_agent.main import _ensure_db_schema
mock_pool = MagicMock()
mock_conn = AsyncMock()
mock_cursor = AsyncMock()
executed_sql: list[str] = []
async def capture_execute(sql, *args):
executed_sql.append(sql.strip())
mock_cursor.execute = capture_execute
mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor)
mock_cursor.__aexit__ = AsyncMock(return_value=False)
mock_conn.cursor = MagicMock(return_value=mock_cursor)
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
mock_conn.__aexit__ = AsyncMock(return_value=False)
mock_pool.connection = MagicMock(return_value=mock_conn)
await _ensure_db_schema(mock_pool)
combined = " ".join(executed_sql)
assert "slack_message_ts" in combined
# ---------------------------------------------------------------------------
# PR polling lifespan integration tests
# ---------------------------------------------------------------------------
class TestPrPollingLifespan:
"""Tests for PR polling startup in the lifespan handler."""
def _make_polling_settings(self, *, pr_poll_enabled: bool = True) -> MagicMock:
s = _make_mock_settings()
s.pr_poll_enabled = pr_poll_enabled
s.pr_poll_interval_seconds = 30
s.pr_poll_target_branch = "refs/heads/develop"
s.watched_repos_list = ["repo-a"]
s.default_jira_project = "ALLPOST"
return s
async def test_poll_loop_started_when_pr_poll_enabled(self) -> None:
"""When pr_poll_enabled=True, a background task for polling is created."""
from release_agent.main import create_app
mock_settings = self._make_polling_settings(pr_poll_enabled=True)
mock_pool = AsyncMock()
mock_pool.open = AsyncMock()
mock_pool.close = AsyncMock()
mock_pool.connection = MagicMock()
poll_loop_started = []
async def fake_run_poll_loop(**kwargs):
poll_loop_started.append(True)
# Simulate an immediate cancellation to avoid infinite loop
raise asyncio.CancelledError
with (
patch("release_agent.main.Settings", return_value=mock_settings),
patch("release_agent.main.build_pr_completed_graph", return_value=MagicMock()),
patch("release_agent.main.build_release_graph", return_value=MagicMock()),
patch("release_agent.main.AsyncConnectionPool", return_value=mock_pool),
patch("release_agent.main._create_tool_clients", return_value=(MagicMock(), [])),
patch("release_agent.main._create_staging_store", return_value=MagicMock()),
patch("release_agent.main._ensure_db_schema", new_callable=AsyncMock),
patch("release_agent.main.run_pr_poll_loop", new=fake_run_poll_loop),
):
app = create_app()
async with app.router.lifespan_context(app):
# Give the event loop a chance to start background tasks
await asyncio.sleep(0)
assert len(poll_loop_started) > 0
async def test_poll_loop_not_started_when_pr_poll_disabled(self) -> None:
"""When pr_poll_enabled=False, no polling background task is created."""
from release_agent.main import create_app
mock_settings = self._make_polling_settings(pr_poll_enabled=False)
mock_pool = AsyncMock()
mock_pool.open = AsyncMock()
mock_pool.close = AsyncMock()
mock_pool.connection = MagicMock()
poll_loop_started = []
async def fake_run_poll_loop(**kwargs):
poll_loop_started.append(True)
with (
patch("release_agent.main.Settings", return_value=mock_settings),
patch("release_agent.main.build_pr_completed_graph", return_value=MagicMock()),
patch("release_agent.main.build_release_graph", return_value=MagicMock()),
patch("release_agent.main.AsyncConnectionPool", return_value=mock_pool),
patch("release_agent.main._create_tool_clients", return_value=(MagicMock(), [])),
patch("release_agent.main._create_staging_store", return_value=MagicMock()),
patch("release_agent.main._ensure_db_schema", new_callable=AsyncMock),
patch("release_agent.main.run_pr_poll_loop", new=fake_run_poll_loop),
):
app = create_app()
async with app.router.lifespan_context(app):
await asyncio.sleep(0)
assert len(poll_loop_started) == 0
# ---------------------------------------------------------------------------
# _run_graph default_jira_project injection tests
# ---------------------------------------------------------------------------
class TestRunGraphJiraProjectInjection:
"""Tests that _run_graph passes default_jira_project into the graph config."""
async def test_default_jira_project_passed_to_graph_config(self) -> None:
from release_agent.api.webhooks import _run_graph
captured_configs: list[dict] = []
mock_graph = MagicMock()
async def fake_ainvoke(state, config=None):
captured_configs.append(config or {})
return {}
mock_graph.ainvoke = fake_ainvoke
mock_pool = MagicMock()
mock_conn = AsyncMock()
mock_cursor = AsyncMock()
mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor)
mock_cursor.__aexit__ = AsyncMock(return_value=False)
mock_conn.cursor = MagicMock(return_value=mock_cursor)
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
mock_conn.__aexit__ = AsyncMock(return_value=False)
mock_pool.connection = MagicMock(return_value=mock_conn)
await _run_graph(
graph=mock_graph,
initial_state={"pr_id": "1", "repo_name": "r"},
thread_id="tid-1",
tool_clients=MagicMock(),
db_pool=mock_pool,
repos_base_dir="",
graph_name="pr_completed",
default_jira_project="MYPROJ",
)
assert len(captured_configs) == 1
configurable = captured_configs[0].get("configurable", {})
assert configurable.get("default_jira_project") == "MYPROJ"
async def test_default_jira_project_defaults_to_allpost(self) -> None:
from release_agent.api.webhooks import _run_graph
captured_configs: list[dict] = []
mock_graph = MagicMock()
async def fake_ainvoke(state, config=None):
captured_configs.append(config or {})
return {}
mock_graph.ainvoke = fake_ainvoke
mock_pool = MagicMock()
mock_conn = AsyncMock()
mock_cursor = AsyncMock()
mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor)
mock_cursor.__aexit__ = AsyncMock(return_value=False)
mock_conn.cursor = MagicMock(return_value=mock_cursor)
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
mock_conn.__aexit__ = AsyncMock(return_value=False)
mock_pool.connection = MagicMock(return_value=mock_conn)
await _run_graph(
graph=mock_graph,
initial_state={"pr_id": "1", "repo_name": "r"},
thread_id="tid-2",
tool_clients=MagicMock(),
db_pool=mock_pool,
)
configurable = captured_configs[0].get("configurable", {})
assert configurable.get("default_jira_project") == "ALLPOST"

147
tests/test_main_phase5.py Normal file
View File

@@ -0,0 +1,147 @@
"""Tests for main.py Phase 5 changes.
Phase 5 - Step 4: _ensure_db_schema creates staging/archived tables,
and lifespan uses PostgresStagingStore.
Written FIRST (TDD RED phase).
"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
# ---------------------------------------------------------------------------
# _ensure_db_schema includes staging DDL
# ---------------------------------------------------------------------------
class TestEnsureDbSchemaPhase5:
async def test_schema_creates_staging_releases_table(self) -> None:
from release_agent.main import _ensure_db_schema
executed_sqls: list[str] = []
mock_cursor = AsyncMock()
mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor)
mock_cursor.__aexit__ = AsyncMock(return_value=False)
async def capture_execute(sql: str, *args) -> None:
executed_sqls.append(sql)
mock_cursor.execute = capture_execute
mock_conn = AsyncMock()
mock_conn.cursor = MagicMock(return_value=mock_cursor)
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
mock_conn.__aexit__ = AsyncMock(return_value=False)
mock_pool = MagicMock()
mock_pool.connection = MagicMock(return_value=mock_conn)
await _ensure_db_schema(mock_pool)
all_sql = " ".join(executed_sqls)
assert "staging_releases" in all_sql
async def test_schema_creates_archived_releases_table(self) -> None:
from release_agent.main import _ensure_db_schema
executed_sqls: list[str] = []
mock_cursor = AsyncMock()
mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor)
mock_cursor.__aexit__ = AsyncMock(return_value=False)
async def capture_execute(sql: str, *args) -> None:
executed_sqls.append(sql)
mock_cursor.execute = capture_execute
mock_conn = AsyncMock()
mock_conn.cursor = MagicMock(return_value=mock_cursor)
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
mock_conn.__aexit__ = AsyncMock(return_value=False)
mock_pool = MagicMock()
mock_pool.connection = MagicMock(return_value=mock_conn)
await _ensure_db_schema(mock_pool)
all_sql = " ".join(executed_sqls)
assert "archived_releases" in all_sql
async def test_schema_still_creates_agent_threads_table(self) -> None:
from release_agent.main import _ensure_db_schema
executed_sqls: list[str] = []
mock_cursor = AsyncMock()
mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor)
mock_cursor.__aexit__ = AsyncMock(return_value=False)
async def capture_execute(sql: str, *args) -> None:
executed_sqls.append(sql)
mock_cursor.execute = capture_execute
mock_conn = AsyncMock()
mock_conn.cursor = MagicMock(return_value=mock_cursor)
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
mock_conn.__aexit__ = AsyncMock(return_value=False)
mock_pool = MagicMock()
mock_pool.connection = MagicMock(return_value=mock_conn)
await _ensure_db_schema(mock_pool)
all_sql = " ".join(executed_sqls)
assert "agent_threads" in all_sql
async def test_schema_uses_if_not_exists(self) -> None:
from release_agent.main import _ensure_db_schema
executed_sqls: list[str] = []
mock_cursor = AsyncMock()
mock_cursor.__aenter__ = AsyncMock(return_value=mock_cursor)
mock_cursor.__aexit__ = AsyncMock(return_value=False)
async def capture_execute(sql: str, *args) -> None:
executed_sqls.append(sql)
mock_cursor.execute = capture_execute
mock_conn = AsyncMock()
mock_conn.cursor = MagicMock(return_value=mock_cursor)
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
mock_conn.__aexit__ = AsyncMock(return_value=False)
mock_pool = MagicMock()
mock_pool.connection = MagicMock(return_value=mock_conn)
await _ensure_db_schema(mock_pool)
all_sql = " ".join(executed_sqls)
assert "IF NOT EXISTS" in all_sql.upper()
# ---------------------------------------------------------------------------
# Lifespan: PostgresStagingStore wired in
# ---------------------------------------------------------------------------
class TestLifespanUsesPostgresStagingStore:
def test_lifespan_creates_postgres_staging_store(self) -> None:
"""When PostgresStagingStore is imported in main, it is used in lifespan."""
from release_agent.main import _create_staging_store
from release_agent.graph.postgres_staging_store import PostgresStagingStore
mock_pool = MagicMock()
result = _create_staging_store(pool=mock_pool)
assert isinstance(result, PostgresStagingStore)
def test_create_staging_store_without_pool_falls_back_to_json(self) -> None:
"""Without a pool, falls back to JsonFileStagingStore for local dev."""
from release_agent.main import _create_staging_store
from release_agent.graph.dependencies import JsonFileStagingStore
result = _create_staging_store(pool=None)
assert isinstance(result, JsonFileStagingStore)

635
tests/test_models.py Normal file
View File

@@ -0,0 +1,635 @@
"""Tests for Pydantic models. Written FIRST (TDD RED phase)."""
from datetime import date, datetime
import pytest
from pydantic import ValidationError
from release_agent.models.jira import JiraIssue, JiraTransition
from release_agent.models.pipeline import PipelineInfo, ReleasePipelineStage
from release_agent.models.pr import PRInfo
from release_agent.models.release import ArchivedRelease, StagingRelease
from release_agent.models.review import ReviewIssue, ReviewResult
from release_agent.models.ticket import TicketEntry
from release_agent.models.webhook import WebhookPayload, WebhookRepository, WebhookResource
# ---------------------------------------------------------------------------
# PRInfo tests
# ---------------------------------------------------------------------------
class TestPRInfo:
"""Tests for PRInfo model."""
def _make_pr(self, **kwargs) -> PRInfo:
defaults = {
"pr_id": "PR-1",
"pr_url": "https://dev.azure.com/org/project/_git/repo/pullrequest/1",
"repo_name": "my-repo",
"branch": "feature/ALLPOST-100_add-feature",
"pr_title": "Add new feature",
"pr_status": "active",
}
defaults.update(kwargs)
return PRInfo(**defaults)
def test_ticket_id_extracted_from_branch(self) -> None:
pr = self._make_pr(branch="feature/ALLPOST-100_add-feature")
assert pr.ticket_id == "ALLPOST-100"
assert pr.has_ticket is True
def test_branch_without_ticket(self) -> None:
pr = self._make_pr(branch="chore/update-dependencies")
assert pr.ticket_id is None
assert pr.has_ticket is False
def test_main_branch_no_ticket(self) -> None:
pr = self._make_pr(branch="main")
assert pr.ticket_id is None
assert pr.has_ticket is False
def test_refs_heads_branch_parsed(self) -> None:
pr = self._make_pr(branch="refs/heads/fix/BILL-42_fix-bug")
assert pr.ticket_id == "BILL-42"
assert pr.has_ticket is True
def test_pr_status_active(self) -> None:
pr = self._make_pr(pr_status="active")
assert pr.pr_status == "active"
def test_pr_status_completed(self) -> None:
pr = self._make_pr(pr_status="completed")
assert pr.pr_status == "completed"
def test_pr_status_abandoned(self) -> None:
pr = self._make_pr(pr_status="abandoned")
assert pr.pr_status == "abandoned"
def test_invalid_pr_status_raises(self) -> None:
with pytest.raises(ValidationError):
self._make_pr(pr_status="unknown")
def test_model_is_frozen(self) -> None:
pr = self._make_pr()
with pytest.raises(ValidationError):
pr.pr_id = "modified" # type: ignore[misc]
def test_pr_url_is_valid_url(self) -> None:
pr = self._make_pr()
# HttpUrl should have been validated
assert "dev.azure.com" in str(pr.pr_url)
def test_invalid_url_raises(self) -> None:
with pytest.raises(ValidationError):
self._make_pr(pr_url="not-a-url")
# ---------------------------------------------------------------------------
# TicketEntry tests
# ---------------------------------------------------------------------------
class TestTicketEntry:
"""Tests for TicketEntry model."""
def _make_ticket(self, **kwargs) -> TicketEntry:
defaults = {
"id": "ALLPOST-4229",
"summary": "Fix review bug",
"pr_id": "PR-42",
"pr_url": "https://dev.azure.com/org/project/_git/repo/pullrequest/42",
"pr_title": "Fix review",
"branch": "bug/ALLPOST-4229_fix-review",
"merged_at": date(2024, 1, 15),
}
defaults.update(kwargs)
return TicketEntry(**defaults)
def test_valid_ticket_entry(self) -> None:
ticket = self._make_ticket()
assert ticket.id == "ALLPOST-4229"
assert ticket.summary == "Fix review bug"
def test_valid_jira_id_format(self) -> None:
ticket = self._make_ticket(id="BILL-42")
assert ticket.id == "BILL-42"
def test_invalid_id_lowercase_raises(self) -> None:
with pytest.raises(ValidationError):
self._make_ticket(id="allpost-4229")
def test_invalid_id_no_number_raises(self) -> None:
with pytest.raises(ValidationError):
self._make_ticket(id="ALLPOST-")
def test_invalid_id_no_dash_raises(self) -> None:
with pytest.raises(ValidationError):
self._make_ticket(id="ALLPOST4229")
def test_invalid_id_starts_with_number_raises(self) -> None:
with pytest.raises(ValidationError):
self._make_ticket(id="4ALLPOST-4229")
def test_merged_at_is_date(self) -> None:
ticket = self._make_ticket()
assert isinstance(ticket.merged_at, date)
def test_model_is_frozen(self) -> None:
ticket = self._make_ticket()
with pytest.raises(ValidationError):
ticket.id = "OTHER-1" # type: ignore[misc]
def test_minimum_valid_id(self) -> None:
# Single uppercase letter prefix followed by dash and digits
ticket = self._make_ticket(id="A-1")
assert ticket.id == "A-1"
def test_numeric_in_project_key(self) -> None:
ticket = self._make_ticket(id="AB2-100")
assert ticket.id == "AB2-100"
# ---------------------------------------------------------------------------
# StagingRelease tests
# ---------------------------------------------------------------------------
class TestStagingRelease:
"""Tests for StagingRelease model."""
def _make_ticket(self, ticket_id: str = "ALLPOST-1") -> TicketEntry:
return TicketEntry(
id=ticket_id,
summary="Some ticket",
pr_id="PR-1",
pr_url="https://dev.azure.com/org/project/_git/repo/pullrequest/1",
pr_title="Some PR",
branch=f"feature/{ticket_id}_some-feature",
merged_at=date(2024, 1, 15),
)
def _make_release(self, **kwargs) -> StagingRelease:
defaults = {
"version": "v1.0.0",
"repo": "my-repo",
"started_at": date(2024, 1, 1),
"tickets": [],
}
defaults.update(kwargs)
return StagingRelease(**defaults)
def test_valid_release(self) -> None:
release = self._make_release()
assert release.version == "v1.0.0"
def test_version_must_match_pattern(self) -> None:
with pytest.raises(ValidationError):
self._make_release(version="1.0.0")
def test_version_missing_patch_raises(self) -> None:
with pytest.raises(ValidationError):
self._make_release(version="v1.0")
def test_version_extra_segments_raises(self) -> None:
with pytest.raises(ValidationError):
self._make_release(version="v1.0.0.1")
def test_version_letters_in_numbers_raises(self) -> None:
with pytest.raises(ValidationError):
self._make_release(version="v1.a.0")
def test_add_ticket_returns_new_instance(self) -> None:
release = self._make_release()
ticket = self._make_ticket("ALLPOST-1")
new_release = release.add_ticket(ticket)
assert new_release is not release
def test_add_ticket_immutability(self) -> None:
release = self._make_release()
ticket = self._make_ticket("ALLPOST-1")
new_release = release.add_ticket(ticket)
assert len(release.tickets) == 0
assert len(new_release.tickets) == 1
def test_add_ticket_contains_ticket(self) -> None:
release = self._make_release()
ticket = self._make_ticket("ALLPOST-1")
new_release = release.add_ticket(ticket)
assert ticket in new_release.tickets
def test_has_ticket_true(self) -> None:
ticket = self._make_ticket("ALLPOST-1")
release = self._make_release(tickets=[ticket])
assert release.has_ticket("ALLPOST-1") is True
def test_has_ticket_false(self) -> None:
release = self._make_release()
assert release.has_ticket("ALLPOST-99") is False
def test_has_ticket_after_add(self) -> None:
release = self._make_release()
ticket = self._make_ticket("ALLPOST-5")
new_release = release.add_ticket(ticket)
assert new_release.has_ticket("ALLPOST-5") is True
def test_model_is_frozen(self) -> None:
release = self._make_release()
with pytest.raises(ValidationError):
release.version = "v2.0.0" # type: ignore[misc]
def test_multiple_tickets(self) -> None:
t1 = self._make_ticket("ALLPOST-1")
t2 = self._make_ticket("ALLPOST-2")
release = self._make_release(tickets=[t1, t2])
assert len(release.tickets) == 2
# ---------------------------------------------------------------------------
# ArchivedRelease tests
# ---------------------------------------------------------------------------
class TestArchivedRelease:
"""Tests for ArchivedRelease model."""
def _make_archived(self, **kwargs) -> ArchivedRelease:
defaults = {
"version": "v1.0.0",
"repo": "my-repo",
"started_at": date(2024, 1, 1),
"tickets": [],
"released_at": date(2024, 1, 10),
}
defaults.update(kwargs)
return ArchivedRelease(**defaults)
def test_valid_archived_release(self) -> None:
release = self._make_archived()
assert release.released_at == date(2024, 1, 10)
def test_released_at_same_as_started_at_is_valid(self) -> None:
release = self._make_archived(started_at=date(2024, 1, 1), released_at=date(2024, 1, 1))
assert release.released_at == release.started_at
def test_released_at_before_started_at_raises(self) -> None:
with pytest.raises(ValidationError):
self._make_archived(
started_at=date(2024, 1, 10),
released_at=date(2024, 1, 1),
)
def test_model_is_frozen(self) -> None:
release = self._make_archived()
with pytest.raises(ValidationError):
release.released_at = date(2024, 12, 31) # type: ignore[misc]
def test_inherits_version_validation(self) -> None:
with pytest.raises(ValidationError):
self._make_archived(version="1.0.0")
# ---------------------------------------------------------------------------
# PipelineInfo tests
# ---------------------------------------------------------------------------
class TestPipelineInfo:
"""Tests for PipelineInfo model."""
def test_valid_pipeline_info(self) -> None:
pipeline = PipelineInfo(id=42, name="Release Pipeline", repo="my-repo")
assert pipeline.id == 42
assert pipeline.name == "Release Pipeline"
assert pipeline.repo == "my-repo"
def test_model_is_frozen(self) -> None:
pipeline = PipelineInfo(id=1, name="Test", repo="repo")
with pytest.raises(ValidationError):
pipeline.id = 2 # type: ignore[misc]
# ---------------------------------------------------------------------------
# ReleasePipelineStage tests
# ---------------------------------------------------------------------------
class TestReleasePipelineStage:
"""Tests for ReleasePipelineStage model."""
def test_valid_stage_without_approval(self) -> None:
stage = ReleasePipelineStage(
name="Build", rank=0, requires_approval=False, approval_id=None
)
assert stage.name == "Build"
assert stage.rank == 0
def test_valid_stage_with_approval(self) -> None:
stage = ReleasePipelineStage(
name="Production", rank=2, requires_approval=True, approval_id="approval-uuid-123"
)
assert stage.requires_approval is True
assert stage.approval_id == "approval-uuid-123"
def test_negative_rank_raises(self) -> None:
with pytest.raises(ValidationError):
ReleasePipelineStage(
name="Bad", rank=-1, requires_approval=False, approval_id=None
)
def test_requires_approval_false_with_approval_id_raises(self) -> None:
with pytest.raises(ValidationError):
ReleasePipelineStage(
name="Bad", rank=0, requires_approval=False, approval_id="some-id"
)
def test_requires_approval_true_without_approval_id_is_valid(self) -> None:
stage = ReleasePipelineStage(
name="Production", rank=2, requires_approval=True, approval_id=None
)
assert stage.requires_approval is True
assert stage.approval_id is None
def test_model_is_frozen(self) -> None:
stage = ReleasePipelineStage(name="Build", rank=0, requires_approval=False, approval_id=None)
with pytest.raises(ValidationError):
stage.name = "Changed" # type: ignore[misc]
# ---------------------------------------------------------------------------
# WebhookPayload tests
# ---------------------------------------------------------------------------
class TestWebhookPayload:
"""Tests for WebhookPayload and nested models."""
def _make_payload(self, **kwargs) -> WebhookPayload:
defaults = {
"subscription_id": "sub-123",
"event_type": "git.pullrequest.merged",
"resource": {
"repository": {
"id": "repo-uuid-456",
"name": "my-repo",
"web_url": "https://dev.azure.com/org/project/_git/my-repo",
},
"pull_request_id": 42,
"title": "Fix the bug",
"source_ref_name": "refs/heads/bug/ALLPOST-4229_fix-review",
"target_ref_name": "refs/heads/main",
"status": "completed",
"closed_date": None,
},
}
defaults.update(kwargs)
return WebhookPayload(**defaults)
def test_valid_payload(self) -> None:
payload = self._make_payload()
assert payload.subscription_id == "sub-123"
assert payload.event_type == "git.pullrequest.merged"
def test_resource_parsed(self) -> None:
payload = self._make_payload()
assert isinstance(payload.resource, WebhookResource)
assert payload.resource.pull_request_id == 42
assert payload.resource.title == "Fix the bug"
def test_repository_parsed(self) -> None:
payload = self._make_payload()
repo = payload.resource.repository
assert isinstance(repo, WebhookRepository)
assert repo.name == "my-repo"
def test_repository_web_url(self) -> None:
payload = self._make_payload()
assert "dev.azure.com" in str(payload.resource.repository.web_url)
def test_closed_date_none(self) -> None:
payload = self._make_payload()
assert payload.resource.closed_date is None
def test_closed_date_populated(self) -> None:
payload_data = {
"subscription_id": "sub-123",
"event_type": "git.pullrequest.merged",
"resource": {
"repository": {
"id": "repo-uuid-456",
"name": "my-repo",
"web_url": "https://dev.azure.com/org/project/_git/my-repo",
},
"pull_request_id": 42,
"title": "Fix the bug",
"source_ref_name": "refs/heads/bug/ALLPOST-4229_fix-review",
"target_ref_name": "refs/heads/main",
"status": "completed",
"closed_date": "2024-01-15T10:30:00Z",
},
}
payload = WebhookPayload(**payload_data)
assert payload.resource.closed_date is not None
assert isinstance(payload.resource.closed_date, datetime)
def test_model_is_frozen(self) -> None:
payload = self._make_payload()
with pytest.raises(ValidationError):
payload.subscription_id = "changed" # type: ignore[misc]
def test_source_ref_name_preserved(self) -> None:
payload = self._make_payload()
assert payload.resource.source_ref_name == "refs/heads/bug/ALLPOST-4229_fix-review"
# ---------------------------------------------------------------------------
# ReviewIssue tests
# ---------------------------------------------------------------------------
class TestReviewIssue:
"""Tests for ReviewIssue model."""
def _make_issue(self, **kwargs) -> ReviewIssue:
defaults = {
"severity": "warning",
"description": "Variable name is not descriptive",
}
defaults.update(kwargs)
return ReviewIssue(**defaults)
def test_valid_warning_issue(self) -> None:
issue = self._make_issue(severity="warning", description="Unclear variable")
assert issue.severity == "warning"
assert issue.description == "Unclear variable"
def test_valid_error_issue(self) -> None:
issue = self._make_issue(severity="error", description="Null pointer risk")
assert issue.severity == "error"
def test_valid_info_issue(self) -> None:
issue = self._make_issue(severity="info", description="Minor style note")
assert issue.severity == "info"
def test_valid_blocker_issue(self) -> None:
issue = self._make_issue(severity="blocker", description="Security vulnerability")
assert issue.severity == "blocker"
def test_invalid_severity_raises(self) -> None:
with pytest.raises(ValidationError):
self._make_issue(severity="critical")
def test_file_path_optional_none_by_default(self) -> None:
issue = self._make_issue()
assert issue.file_path is None
def test_file_path_can_be_set(self) -> None:
issue = self._make_issue(file_path="src/foo.py")
assert issue.file_path == "src/foo.py"
def test_suggestion_optional_none_by_default(self) -> None:
issue = self._make_issue()
assert issue.suggestion is None
def test_suggestion_can_be_set(self) -> None:
issue = self._make_issue(suggestion="Rename to `user_count`")
assert issue.suggestion == "Rename to `user_count`"
def test_model_is_frozen(self) -> None:
issue = self._make_issue()
with pytest.raises(ValidationError):
issue.severity = "error" # type: ignore[misc]
def test_description_required(self) -> None:
with pytest.raises(ValidationError):
ReviewIssue(severity="warning") # type: ignore[call-arg]
# ---------------------------------------------------------------------------
# ReviewResult tests
# ---------------------------------------------------------------------------
class TestReviewResult:
"""Tests for ReviewResult model."""
def _make_blocker_issue(self) -> ReviewIssue:
return ReviewIssue(severity="blocker", description="Must fix this")
def _make_warning_issue(self) -> ReviewIssue:
return ReviewIssue(severity="warning", description="Minor issue")
def _make_result(self, **kwargs) -> ReviewResult:
defaults = {
"verdict": "approve",
"summary": "Looks good overall",
"issues": [],
}
defaults.update(kwargs)
return ReviewResult(**defaults)
def test_valid_approve_verdict(self) -> None:
result = self._make_result(verdict="approve")
assert result.verdict == "approve"
def test_valid_request_changes_verdict(self) -> None:
result = self._make_result(verdict="request_changes")
assert result.verdict == "request_changes"
def test_invalid_verdict_raises(self) -> None:
with pytest.raises(ValidationError):
self._make_result(verdict="reject")
def test_summary_stored(self) -> None:
result = self._make_result(summary="Great PR")
assert result.summary == "Great PR"
def test_issues_empty_by_default(self) -> None:
result = self._make_result()
assert len(result.issues) == 0
def test_has_blockers_false_with_no_issues(self) -> None:
result = self._make_result(issues=[])
assert result.has_blockers is False
def test_has_blockers_false_with_only_warnings(self) -> None:
result = self._make_result(issues=[self._make_warning_issue()])
assert result.has_blockers is False
def test_has_blockers_true_with_blocker_issue(self) -> None:
result = self._make_result(issues=[self._make_blocker_issue()])
assert result.has_blockers is True
def test_has_blockers_true_mixed_issues(self) -> None:
result = self._make_result(
issues=[self._make_warning_issue(), self._make_blocker_issue()]
)
assert result.has_blockers is True
def test_model_is_frozen(self) -> None:
result = self._make_result()
with pytest.raises(ValidationError):
result.verdict = "request_changes" # type: ignore[misc]
def test_multiple_issues_stored(self) -> None:
issues = [self._make_warning_issue(), self._make_blocker_issue()]
result = self._make_result(issues=issues)
assert len(result.issues) == 2
def test_has_blockers_is_computed(self) -> None:
# Verify has_blockers cannot be set directly (it's computed)
result = self._make_result(issues=[self._make_blocker_issue()])
assert result.has_blockers is True
# ---------------------------------------------------------------------------
# JiraTransition tests
# ---------------------------------------------------------------------------
class TestJiraTransition:
"""Tests for JiraTransition model."""
def test_valid_transition(self) -> None:
transition = JiraTransition(id="11", name="To Do")
assert transition.id == "11"
assert transition.name == "To Do"
def test_model_is_frozen(self) -> None:
transition = JiraTransition(id="11", name="To Do")
with pytest.raises(ValidationError):
transition.id = "22" # type: ignore[misc]
def test_id_required(self) -> None:
with pytest.raises(ValidationError):
JiraTransition(name="To Do") # type: ignore[call-arg]
def test_name_required(self) -> None:
with pytest.raises(ValidationError):
JiraTransition(id="11") # type: ignore[call-arg]
# ---------------------------------------------------------------------------
# JiraIssue tests
# ---------------------------------------------------------------------------
class TestJiraIssue:
"""Tests for JiraIssue model."""
def test_valid_issue(self) -> None:
issue = JiraIssue(key="ALLPOST-100", summary="Fix the bug", status="In Progress")
assert issue.key == "ALLPOST-100"
assert issue.summary == "Fix the bug"
assert issue.status == "In Progress"
def test_model_is_frozen(self) -> None:
issue = JiraIssue(key="ALLPOST-100", summary="Fix the bug", status="In Progress")
with pytest.raises(ValidationError):
issue.key = "ALLPOST-200" # type: ignore[misc]
def test_key_required(self) -> None:
with pytest.raises(ValidationError):
JiraIssue(summary="Fix the bug", status="In Progress") # type: ignore[call-arg]
def test_summary_required(self) -> None:
with pytest.raises(ValidationError):
JiraIssue(key="ALLPOST-100", status="In Progress") # type: ignore[call-arg]
def test_status_required(self) -> None:
with pytest.raises(ValidationError):
JiraIssue(key="ALLPOST-100", summary="Fix the bug") # type: ignore[call-arg]
def test_various_statuses(self) -> None:
statuses = ["To Do", "In Progress", "Done", "Released"]
for status in statuses:
issue = JiraIssue(key="ALLPOST-1", summary="Test", status=status)
assert issue.status == status

148
tests/test_models_build.py Normal file
View File

@@ -0,0 +1,148 @@
"""Tests for models/build.py — BuildStatus and ApprovalRecord.
Written FIRST (TDD RED phase).
"""
import pytest
from dataclasses import FrozenInstanceError
from release_agent.models.build import ApprovalRecord, BuildStatus
# ---------------------------------------------------------------------------
# BuildStatus tests
# ---------------------------------------------------------------------------
class TestBuildStatus:
"""Tests for BuildStatus frozen dataclass."""
def test_can_be_created_with_all_fields(self) -> None:
bs = BuildStatus(
status="completed",
result="succeeded",
build_url="https://dev.azure.com/org/proj/_build/results?buildId=42",
)
assert bs.status == "completed"
assert bs.result == "succeeded"
assert bs.build_url == "https://dev.azure.com/org/proj/_build/results?buildId=42"
def test_result_can_be_none(self) -> None:
bs = BuildStatus(
status="inProgress",
result=None,
build_url="https://dev.azure.com/org/proj/_build/results?buildId=99",
)
assert bs.result is None
def test_build_url_can_be_none(self) -> None:
bs = BuildStatus(status="notStarted", result=None, build_url=None)
assert bs.build_url is None
def test_is_frozen_status(self) -> None:
bs = BuildStatus(status="completed", result="succeeded", build_url=None)
with pytest.raises((FrozenInstanceError, AttributeError)):
bs.status = "inProgress" # type: ignore[misc]
def test_is_frozen_result(self) -> None:
bs = BuildStatus(status="completed", result="succeeded", build_url=None)
with pytest.raises((FrozenInstanceError, AttributeError)):
bs.result = "failed" # type: ignore[misc]
def test_equality(self) -> None:
a = BuildStatus(status="completed", result="succeeded", build_url="http://x")
b = BuildStatus(status="completed", result="succeeded", build_url="http://x")
assert a == b
def test_inequality_on_status(self) -> None:
a = BuildStatus(status="completed", result="succeeded", build_url=None)
b = BuildStatus(status="inProgress", result="succeeded", build_url=None)
assert a != b
def test_inequality_on_result(self) -> None:
a = BuildStatus(status="completed", result="succeeded", build_url=None)
b = BuildStatus(status="completed", result="failed", build_url=None)
assert a != b
def test_repr_contains_status(self) -> None:
bs = BuildStatus(status="completed", result="succeeded", build_url=None)
assert "completed" in repr(bs)
def test_status_values_typical(self) -> None:
for s in ("notStarted", "inProgress", "completed", "cancelling"):
bs = BuildStatus(status=s, result=None, build_url=None)
assert bs.status == s
def test_result_values_typical(self) -> None:
for r in ("succeeded", "failed", "canceled", "partiallySucceeded"):
bs = BuildStatus(status="completed", result=r, build_url=None)
assert bs.result == r
# ---------------------------------------------------------------------------
# ApprovalRecord tests
# ---------------------------------------------------------------------------
class TestApprovalRecord:
"""Tests for ApprovalRecord frozen dataclass."""
def test_can_be_created_with_all_fields(self) -> None:
ar = ApprovalRecord(
approval_id="approval-abc-123",
stage_name="Sandbox",
status="pending",
release_id=42,
)
assert ar.approval_id == "approval-abc-123"
assert ar.stage_name == "Sandbox"
assert ar.status == "pending"
assert ar.release_id == 42
def test_is_frozen_approval_id(self) -> None:
ar = ApprovalRecord(
approval_id="abc",
stage_name="Sandbox",
status="pending",
release_id=1,
)
with pytest.raises((FrozenInstanceError, AttributeError)):
ar.approval_id = "xyz" # type: ignore[misc]
def test_is_frozen_stage_name(self) -> None:
ar = ApprovalRecord(
approval_id="abc",
stage_name="Sandbox",
status="pending",
release_id=1,
)
with pytest.raises((FrozenInstanceError, AttributeError)):
ar.stage_name = "Production" # type: ignore[misc]
def test_equality(self) -> None:
a = ApprovalRecord(approval_id="x", stage_name="S", status="pending", release_id=1)
b = ApprovalRecord(approval_id="x", stage_name="S", status="pending", release_id=1)
assert a == b
def test_inequality_on_approval_id(self) -> None:
a = ApprovalRecord(approval_id="x", stage_name="S", status="pending", release_id=1)
b = ApprovalRecord(approval_id="y", stage_name="S", status="pending", release_id=1)
assert a != b
def test_status_pending(self) -> None:
ar = ApprovalRecord(approval_id="a", stage_name="Stage", status="pending", release_id=10)
assert ar.status == "pending"
def test_status_approved(self) -> None:
ar = ApprovalRecord(approval_id="a", stage_name="Stage", status="approved", release_id=10)
assert ar.status == "approved"
def test_status_rejected(self) -> None:
ar = ApprovalRecord(approval_id="a", stage_name="Stage", status="rejected", release_id=10)
assert ar.status == "rejected"
def test_repr_contains_stage_name(self) -> None:
ar = ApprovalRecord(approval_id="a", stage_name="Production", status="pending", release_id=5)
assert "Production" in repr(ar)
def test_release_id_is_int(self) -> None:
ar = ApprovalRecord(approval_id="a", stage_name="S", status="pending", release_id=999)
assert isinstance(ar.release_id, int)

241
tests/test_state.py Normal file
View File

@@ -0,0 +1,241 @@
"""Tests for LangGraph state module. Written FIRST (TDD RED phase)."""
import json
from release_agent.state import ReleaseState, add_errors, add_messages
# ---------------------------------------------------------------------------
# ReleaseState tests
# ---------------------------------------------------------------------------
class TestReleaseState:
"""Tests for ReleaseState TypedDict."""
def test_empty_state_is_valid(self) -> None:
# total=False means all fields are optional
state: ReleaseState = {}
assert state == {}
def test_partial_state_with_repo(self) -> None:
state: ReleaseState = {"repo_name": "my-repo"}
assert state["repo_name"] == "my-repo"
def test_partial_state_with_messages(self) -> None:
state: ReleaseState = {"messages": ["Hello"]}
assert state["messages"] == ["Hello"]
def test_partial_state_with_errors(self) -> None:
state: ReleaseState = {"errors": ["Something went wrong"]}
assert state["errors"] == ["Something went wrong"]
def test_state_with_pr_id(self) -> None:
state: ReleaseState = {"pr_id": "PR-42"}
assert state["pr_id"] == "PR-42"
def test_state_with_ticket_id(self) -> None:
state: ReleaseState = {"ticket_id": "ALLPOST-100"}
assert state["ticket_id"] == "ALLPOST-100"
def test_state_with_version(self) -> None:
state: ReleaseState = {"version": "v1.0.1"}
assert state["version"] == "v1.0.1"
# Phase 3 new fields
def test_state_with_webhook_payload(self) -> None:
state: ReleaseState = {"webhook_payload": {"event_type": "git.pullrequest.merged"}}
assert state["webhook_payload"]["event_type"] == "git.pullrequest.merged"
def test_state_with_pr_info(self) -> None:
state: ReleaseState = {"pr_info": {"pr_id": "42", "repo_name": "my-repo"}}
assert state["pr_info"]["repo_name"] == "my-repo"
def test_state_with_pr_diff(self) -> None:
state: ReleaseState = {"pr_diff": "edit: src/main.py"}
assert state["pr_diff"] == "edit: src/main.py"
def test_state_with_last_merge_source_commit(self) -> None:
state: ReleaseState = {"last_merge_source_commit": "abc123"}
assert state["last_merge_source_commit"] == "abc123"
def test_state_with_ticket_summary(self) -> None:
state: ReleaseState = {"ticket_summary": "Fix login bug"}
assert state["ticket_summary"] == "Fix login bug"
def test_state_with_has_ticket(self) -> None:
state: ReleaseState = {"has_ticket": True}
assert state["has_ticket"] is True
def test_state_with_review_result(self) -> None:
state: ReleaseState = {"review_result": {"verdict": "approve", "summary": "LGTM"}}
assert state["review_result"]["verdict"] == "approve"
def test_state_with_review_approved(self) -> None:
state: ReleaseState = {"review_approved": True}
assert state["review_approved"] is True
def test_state_with_staging(self) -> None:
state: ReleaseState = {"staging": {"version": "v1.0.0", "tickets": []}}
assert state["staging"]["version"] == "v1.0.0"
def test_state_with_pr_already_merged(self) -> None:
state: ReleaseState = {"pr_already_merged": False}
assert state["pr_already_merged"] is False
def test_state_with_release_pr_id(self) -> None:
state: ReleaseState = {"release_pr_id": "123"}
assert state["release_pr_id"] == "123"
def test_state_with_release_pr_commit(self) -> None:
state: ReleaseState = {"release_pr_commit": "deadbeef"}
assert state["release_pr_commit"] == "deadbeef"
def test_state_with_pipelines(self) -> None:
state: ReleaseState = {"pipelines": [{"id": 1, "name": "build"}]}
assert len(state["pipelines"]) == 1
def test_state_with_triggered_builds(self) -> None:
state: ReleaseState = {"triggered_builds": [{"id": 99}]}
assert state["triggered_builds"][0]["id"] == 99
def test_state_with_pending_approvals(self) -> None:
state: ReleaseState = {"pending_approvals": [{"approval_id": "aaa"}]}
assert state["pending_approvals"][0]["approval_id"] == "aaa"
def test_state_with_continue_to_release(self) -> None:
state: ReleaseState = {"continue_to_release": True}
assert state["continue_to_release"] is True
# Phase 5: CI/CD and approval fields
def test_state_with_ci_build_id(self) -> None:
state: ReleaseState = {"ci_build_id": 12345}
assert state["ci_build_id"] == 12345
def test_state_with_ci_build_status(self) -> None:
state: ReleaseState = {"ci_build_status": "inProgress"}
assert state["ci_build_status"] == "inProgress"
def test_state_with_ci_build_result(self) -> None:
state: ReleaseState = {"ci_build_result": "succeeded"}
assert state["ci_build_result"] == "succeeded"
def test_state_with_ci_build_url(self) -> None:
state: ReleaseState = {"ci_build_url": "https://dev.azure.com/org/proj/_build/results?buildId=99"}
assert "buildId=99" in state["ci_build_url"]
def test_state_with_release_definition_id(self) -> None:
state: ReleaseState = {"release_definition_id": 7}
assert state["release_definition_id"] == 7
def test_state_with_release_id(self) -> None:
state: ReleaseState = {"release_id": 456}
assert state["release_id"] == 456
def test_state_with_current_stage(self) -> None:
state: ReleaseState = {"current_stage": "sandbox_pending"}
assert state["current_stage"] == "sandbox_pending"
def test_state_with_approval_message_ts(self) -> None:
state: ReleaseState = {"approval_message_ts": "1234567890.123456"}
assert state["approval_message_ts"] == "1234567890.123456"
def test_state_with_slack_message_ts(self) -> None:
state: ReleaseState = {"slack_message_ts": "9876543210.000001"}
assert state["slack_message_ts"] == "9876543210.000001"
def test_state_json_serializable_empty(self) -> None:
state: ReleaseState = {}
serialized = json.dumps(state)
assert json.loads(serialized) == {}
def test_state_json_serializable_with_strings(self) -> None:
state: ReleaseState = {
"repo_name": "my-repo",
"pr_id": "PR-1",
"ticket_id": "ALLPOST-1",
"version": "v1.0.0",
}
serialized = json.dumps(state)
loaded = json.loads(serialized)
assert loaded["repo_name"] == "my-repo"
assert loaded["pr_id"] == "PR-1"
def test_state_json_serializable_with_lists(self) -> None:
state: ReleaseState = {
"messages": ["msg1", "msg2"],
"errors": ["err1"],
}
serialized = json.dumps(state)
loaded = json.loads(serialized)
assert loaded["messages"] == ["msg1", "msg2"]
assert loaded["errors"] == ["err1"]
# ---------------------------------------------------------------------------
# Reducer tests
# ---------------------------------------------------------------------------
class TestAddMessages:
"""Tests for add_messages reducer."""
def test_accumulates_to_empty(self) -> None:
result = add_messages([], ["Hello"])
assert result == ["Hello"]
def test_accumulates_to_existing(self) -> None:
result = add_messages(["Hello"], ["World"])
assert result == ["Hello", "World"]
def test_accumulates_multiple(self) -> None:
result = add_messages(["A", "B"], ["C", "D"])
assert result == ["A", "B", "C", "D"]
def test_existing_unchanged(self) -> None:
existing = ["Hello"]
add_messages(existing, ["World"])
# Original should not be mutated
assert existing == ["Hello"]
def test_empty_new_messages(self) -> None:
result = add_messages(["Hello"], [])
assert result == ["Hello"]
def test_both_empty(self) -> None:
result = add_messages([], [])
assert result == []
def test_returns_new_list(self) -> None:
existing = ["Hello"]
new_msgs = ["World"]
result = add_messages(existing, new_msgs)
assert result is not existing
assert result is not new_msgs
class TestAddErrors:
"""Tests for add_errors reducer."""
def test_accumulates_to_empty(self) -> None:
result = add_errors([], ["Error occurred"])
assert result == ["Error occurred"]
def test_accumulates_to_existing(self) -> None:
result = add_errors(["First error"], ["Second error"])
assert result == ["First error", "Second error"]
def test_existing_unchanged(self) -> None:
existing = ["First error"]
add_errors(existing, ["Second error"])
assert existing == ["First error"]
def test_empty_new_errors(self) -> None:
result = add_errors(["Existing"], [])
assert result == ["Existing"]
def test_both_empty(self) -> None:
result = add_errors([], [])
assert result == []
def test_returns_new_list(self) -> None:
existing = ["Error"]
result = add_errors(existing, ["New error"])
assert result is not existing

124
tests/test_versioning.py Normal file
View File

@@ -0,0 +1,124 @@
"""Tests for versioning module. Written FIRST (TDD RED phase)."""
import pytest
from release_agent.versioning import (
calculate_next_version,
format_version,
parse_version,
)
class TestParseVersion:
"""Tests for parse_version function."""
def test_parse_with_v_prefix(self) -> None:
assert parse_version("v1.2.3") == (1, 2, 3)
def test_parse_without_v_prefix(self) -> None:
assert parse_version("1.2.3") == (1, 2, 3)
def test_parse_zeros(self) -> None:
assert parse_version("v0.0.0") == (0, 0, 0)
def test_parse_large_numbers(self) -> None:
assert parse_version("v10.20.300") == (10, 20, 300)
def test_parse_returns_tuple_of_ints(self) -> None:
result = parse_version("v1.2.3")
assert isinstance(result, tuple)
assert len(result) == 3
assert all(isinstance(x, int) for x in result)
def test_parse_invalid_raises_value_error(self) -> None:
with pytest.raises(ValueError):
parse_version("invalid")
def test_parse_partial_version_raises_value_error(self) -> None:
with pytest.raises(ValueError):
parse_version("v1.2")
def test_parse_non_numeric_raises_value_error(self) -> None:
with pytest.raises(ValueError):
parse_version("va.b.c")
class TestFormatVersion:
"""Tests for format_version function."""
def test_format_basic(self) -> None:
assert format_version(1, 0, 3) == "v1.0.3"
def test_format_zeros(self) -> None:
assert format_version(0, 0, 0) == "v0.0.0"
def test_format_large_numbers(self) -> None:
assert format_version(10, 20, 300) == "v10.20.300"
def test_format_returns_string(self) -> None:
result = format_version(1, 2, 3)
assert isinstance(result, str)
def test_format_starts_with_v(self) -> None:
result = format_version(1, 2, 3)
assert result.startswith("v")
class TestCalculateNextVersion:
"""Tests for calculate_next_version function."""
def test_empty_list_returns_v1_0_0(self) -> None:
assert calculate_next_version("my-repo", []) == "v1.0.0"
def test_single_version_increments_patch(self) -> None:
assert calculate_next_version("my-repo", ["v1.0.0"]) == "v1.0.1"
def test_multiple_versions_uses_highest(self) -> None:
assert calculate_next_version("my-repo", ["v1.0.3", "v1.0.1"]) == "v1.0.4"
def test_different_major_versions(self) -> None:
assert calculate_next_version("my-repo", ["v2.1.0", "v1.9.9"]) == "v2.1.1"
def test_skips_malformed_versions(self) -> None:
assert calculate_next_version("my-repo", ["invalid", "v1.0.0"]) == "v1.0.1"
def test_all_malformed_versions_returns_v1_0_0(self) -> None:
assert calculate_next_version("my-repo", ["invalid", "bad", "nope"]) == "v1.0.0"
def test_repo_name_does_not_affect_result(self) -> None:
result_a = calculate_next_version("repo-a", ["v1.0.0"])
result_b = calculate_next_version("repo-b", ["v1.0.0"])
assert result_a == result_b
def test_versions_out_of_order(self) -> None:
assert calculate_next_version("my-repo", ["v1.0.1", "v1.0.3", "v1.0.2"]) == "v1.0.4"
def test_patch_overflow_does_not_occur(self) -> None:
# Just increments patch - no overflow logic required
result = calculate_next_version("my-repo", ["v1.0.99"])
assert result == "v1.0.100"
def test_versions_without_v_prefix_skipped(self) -> None:
# Versions without 'v' prefix are treated as malformed per spec
result = calculate_next_version("my-repo", ["1.0.0", "v2.0.0"])
assert result == "v2.0.1"
def test_result_format_starts_with_v(self) -> None:
result = calculate_next_version("my-repo", ["v1.0.0"])
assert result.startswith("v")
def test_result_has_three_parts(self) -> None:
result = calculate_next_version("my-repo", ["v1.0.0"])
parts = result[1:].split(".")
assert len(parts) == 3
assert all(p.isdigit() for p in parts)
def test_v_prefix_with_nonnumeric_parts_skipped(self) -> None:
# Starts with 'v' but is malformed - should be skipped gracefully
result = calculate_next_version("my-repo", ["va.b.c", "v1.0.0"])
assert result == "v1.0.1"
def test_v_prefix_partial_version_skipped(self) -> None:
# Starts with 'v' but only has two parts - should be skipped
result = calculate_next_version("my-repo", ["v1.0", "v2.0.0"])
assert result == "v2.0.1"

0
tests/tools/__init__.py Normal file
View File

View File

@@ -0,0 +1,9 @@
{
"id": "approval-uuid-123",
"status": "approved",
"approver": {
"id": "user-uuid-456",
"displayName": "Release Bot"
},
"comments": "Approved via release agent"
}

View File

@@ -0,0 +1,9 @@
{
"id": 1001,
"buildNumber": "20240115.1",
"status": "completed",
"result": "succeeded",
"queueTime": "2024-01-15T10:00:00Z",
"startTime": "2024-01-15T10:01:00Z",
"finishTime": "2024-01-15T10:10:00Z"
}

View File

@@ -0,0 +1,8 @@
{
"pullRequestId": 99,
"title": "Release v1.2.0",
"status": "active",
"sourceRefName": "refs/heads/release/v1.2.0",
"targetRefName": "refs/heads/main",
"url": "https://dev.azure.com/my-org/my-project/_apis/git/repositories/my-repo/pullRequests/99"
}

View File

@@ -0,0 +1,8 @@
{
"pullRequestId": 42,
"status": "completed",
"title": "Fix the auth bug",
"completionOptions": {
"mergeStrategy": "squash"
}
}

View File

@@ -0,0 +1,15 @@
{
"value": [
{
"id": 10,
"name": "Release Pipeline",
"folder": "\\"
},
{
"id": 20,
"name": "Build Pipeline",
"folder": "\\"
}
],
"count": 2
}

View File

@@ -0,0 +1,16 @@
{
"pullRequestId": 42,
"title": "Fix the auth bug",
"status": "active",
"sourceRefName": "refs/heads/bug/ALLPOST-999_fix-auth",
"targetRefName": "refs/heads/main",
"url": "https://dev.azure.com/my-org/my-project/_apis/git/repositories/my-repo/pullRequests/42",
"repository": {
"id": "repo-uuid-123",
"name": "my-repo",
"remoteUrl": "https://dev.azure.com/my-org/my-project/_git/my-repo"
},
"lastMergeSourceCommit": {
"commitId": "abc123def456"
}
}

View File

@@ -0,0 +1,11 @@
diff --git a/src/auth.py b/src/auth.py
index 1234567..abcdefg 100644
--- a/src/auth.py
+++ b/src/auth.py
@@ -10,6 +10,10 @@ class AuthService:
def authenticate(self, token: str) -> bool:
- return token == "hardcoded"
+ return self._validate_token(token)
+
+ def _validate_token(self, token: str) -> bool:
+ return len(token) > 0 and token.startswith("Bearer ")

View File

@@ -0,0 +1,11 @@
{
"id": 1001,
"buildNumber": "20240115.1",
"status": "notStarted",
"queueTime": "2024-01-15T10:00:00Z",
"definition": {
"id": 10,
"name": "Release Pipeline"
},
"sourceBranch": "refs/heads/main"
}

View File

@@ -0,0 +1,10 @@
{
"id": "12345",
"key": "ALLPOST-100",
"fields": {
"summary": "Fix the authentication bug",
"status": {
"name": "In Progress"
}
}
}

View File

@@ -0,0 +1,20 @@
{
"transitions": [
{
"id": "11",
"name": "To Do"
},
{
"id": "21",
"name": "In Progress"
},
{
"id": "31",
"name": "Done"
},
{
"id": "41",
"name": "Released"
}
]
}

819
tests/tools/test_azdo.py Normal file
View File

@@ -0,0 +1,819 @@
"""Tests for AzDoClient. Written FIRST (TDD RED phase)."""
import json
from pathlib import Path
import httpx
import pytest
from release_agent.exceptions import AuthenticationError, NotFoundError, ServiceError
from release_agent.models.build import ApprovalRecord, BuildStatus
from release_agent.models.pipeline import PipelineInfo
from release_agent.models.pr import PRInfo
from release_agent.tools.azdo import AzDoClient
# ---------------------------------------------------------------------------
# Fixture helpers
# ---------------------------------------------------------------------------
FIXTURES = Path(__file__).parent / "fixtures"
def _load_json(name: str) -> dict:
return json.loads((FIXTURES / name).read_text())
def _load_text(name: str) -> str:
return (FIXTURES / name).read_text()
def _make_transport(routes: dict[tuple[str, str], tuple[int, bytes | str]]) -> httpx.MockTransport:
"""Build a MockTransport that dispatches based on (method, url_substring)."""
def handler(request: httpx.Request) -> httpx.Response:
url = str(request.url)
method = request.method
for (m, url_fragment), (status, body) in routes.items():
if m == method and url_fragment in url:
content = body if isinstance(body, bytes) else body.encode()
return httpx.Response(status_code=status, content=content)
return httpx.Response(status_code=404, content=b'{"message": "Not found"}')
return httpx.MockTransport(handler)
def _make_client(routes: dict) -> AzDoClient:
"""Create an AzDoClient with mocked HTTP transport."""
transport = _make_transport(routes)
http_client = httpx.AsyncClient(transport=transport)
vsrm_client = httpx.AsyncClient(transport=transport)
return AzDoClient(
base_url="https://dev.azure.com/my-org/my-project/_apis",
vsrm_base_url="https://vsrm.dev.azure.com/my-org/my-project/_apis",
pat="test-pat",
http_client=http_client,
vsrm_http_client=vsrm_client,
)
# ---------------------------------------------------------------------------
# AzDoClient construction tests
# ---------------------------------------------------------------------------
class TestAzDoClientConstruction:
"""Tests for AzDoClient initialization."""
def test_can_be_instantiated_with_injected_clients(self) -> None:
transport = httpx.MockTransport(lambda r: httpx.Response(200, content=b"{}"))
http_client = httpx.AsyncClient(transport=transport)
vsrm_client = httpx.AsyncClient(transport=transport)
client = AzDoClient(
base_url="https://dev.azure.com/org/proj/_apis",
vsrm_base_url="https://vsrm.dev.azure.com/org/proj/_apis",
pat="my-pat",
http_client=http_client,
vsrm_http_client=vsrm_client,
)
assert client is not None
async def test_context_manager_closes_clients(self) -> None:
transport = httpx.MockTransport(lambda r: httpx.Response(200, content=b"{}"))
http_client = httpx.AsyncClient(transport=transport)
vsrm_client = httpx.AsyncClient(transport=transport)
async with AzDoClient(
base_url="https://dev.azure.com/org/proj/_apis",
vsrm_base_url="https://vsrm.dev.azure.com/org/proj/_apis",
pat="my-pat",
http_client=http_client,
vsrm_http_client=vsrm_client,
) as client:
assert client is not None
# After context manager exits, clients should be closed
assert http_client.is_closed
assert vsrm_client.is_closed
# ---------------------------------------------------------------------------
# get_pr tests
# ---------------------------------------------------------------------------
class TestGetPr:
"""Tests for AzDoClient.get_pr."""
async def test_returns_pr_info(self) -> None:
pr_data = _load_json("azdo_pr.json")
routes = {("GET", "pullRequests/42"): (200, json.dumps(pr_data))}
client = _make_client(routes)
result = await client.get_pr(42)
assert isinstance(result, PRInfo)
assert result.pr_id == "42"
async def test_pr_title_extracted(self) -> None:
pr_data = _load_json("azdo_pr.json")
routes = {("GET", "pullRequests/42"): (200, json.dumps(pr_data))}
client = _make_client(routes)
result = await client.get_pr(42)
assert result.pr_title == "Fix the auth bug"
async def test_pr_branch_extracted(self) -> None:
pr_data = _load_json("azdo_pr.json")
routes = {("GET", "pullRequests/42"): (200, json.dumps(pr_data))}
client = _make_client(routes)
result = await client.get_pr(42)
assert "ALLPOST-999" in result.branch or "bug" in result.branch
async def test_pr_status_extracted(self) -> None:
pr_data = _load_json("azdo_pr.json")
routes = {("GET", "pullRequests/42"): (200, json.dumps(pr_data))}
client = _make_client(routes)
result = await client.get_pr(42)
assert result.pr_status == "active"
async def test_404_raises_not_found(self) -> None:
routes = {("GET", "pullRequests/999"): (404, b'{"message": "PR not found"}')}
client = _make_client(routes)
with pytest.raises(NotFoundError):
await client.get_pr(999)
async def test_401_raises_authentication_error(self) -> None:
routes = {("GET", "pullRequests/42"): (401, b'{"message": "Unauthorized"}')}
client = _make_client(routes)
with pytest.raises(AuthenticationError):
await client.get_pr(42)
async def test_500_raises_service_error(self) -> None:
routes = {("GET", "pullRequests/42"): (500, b'{"message": "Internal error"}')}
client = _make_client(routes)
with pytest.raises(ServiceError):
await client.get_pr(42)
# ---------------------------------------------------------------------------
# get_pr_diff tests
# ---------------------------------------------------------------------------
class TestGetPrDiff:
"""Tests for AzDoClient.get_pr_diff."""
async def test_returns_diff_string(self) -> None:
pr_data = _load_json("azdo_pr.json")
routes = {
("GET", "pullRequests/42"): (200, json.dumps(pr_data)),
("GET", "diffs"): (200, json.dumps({
"changes": [
{
"item": {"path": "/src/auth.py"},
"changeType": "edit",
}
]
})),
}
client = _make_client(routes)
result = await client.get_pr_diff(42)
assert isinstance(result, str)
async def test_diff_includes_file_paths(self) -> None:
pr_data = _load_json("azdo_pr.json")
diff_data = {
"changes": [
{"item": {"path": "/src/auth.py"}, "changeType": "edit"},
{"item": {"path": "/src/util.py"}, "changeType": "add"},
]
}
def handler(request: httpx.Request) -> httpx.Response:
url = str(request.url)
if "diffs" in url:
return httpx.Response(200, content=json.dumps(diff_data).encode())
if "pullRequests/42" in url:
return httpx.Response(200, content=json.dumps(pr_data).encode())
return httpx.Response(404, content=b"{}")
transport = httpx.MockTransport(handler)
http_client = httpx.AsyncClient(transport=transport)
vsrm_client = httpx.AsyncClient(transport=transport)
client = AzDoClient(
base_url="https://dev.azure.com/my-org/my-project/_apis",
vsrm_base_url="https://vsrm.dev.azure.com/my-org/my-project/_apis",
pat="test-pat",
http_client=http_client,
vsrm_http_client=vsrm_client,
)
result = await client.get_pr_diff(42)
assert "/src/auth.py" in result
assert "/src/util.py" in result
async def test_empty_changes_returns_empty_string(self) -> None:
pr_data = _load_json("azdo_pr.json")
diff_data: dict = {"changes": []}
def handler(request: httpx.Request) -> httpx.Response:
url = str(request.url)
if "diffs" in url:
return httpx.Response(200, content=json.dumps(diff_data).encode())
if "pullRequests/42" in url:
return httpx.Response(200, content=json.dumps(pr_data).encode())
return httpx.Response(404, content=b"{}")
transport = httpx.MockTransport(handler)
http_client = httpx.AsyncClient(transport=transport)
vsrm_client = httpx.AsyncClient(transport=transport)
client = AzDoClient(
base_url="https://dev.azure.com/my-org/my-project/_apis",
vsrm_base_url="https://vsrm.dev.azure.com/my-org/my-project/_apis",
pat="test-pat",
http_client=http_client,
vsrm_http_client=vsrm_client,
)
result = await client.get_pr_diff(42)
assert result == ""
async def test_404_raises_not_found(self) -> None:
routes = {("GET", "pullRequests/999"): (404, b'{"message": "PR not found"}')}
client = _make_client(routes)
with pytest.raises(NotFoundError):
await client.get_pr_diff(999)
async def test_pr_without_url_field_uses_remote_url(self) -> None:
"""When the API response omits the 'url' field, fallback URL is built."""
pr_data = {
"pullRequestId": 42,
"title": "Fix bug",
"status": "active",
"sourceRefName": "refs/heads/fix/ALLPOST-1_fix",
"repository": {
"id": "repo-uuid",
"name": "my-repo",
"remoteUrl": "https://dev.azure.com/org/proj/_git/my-repo",
},
# NOTE: 'url' field is intentionally omitted
}
routes = {
("GET", "pullRequests/42"): (200, json.dumps(pr_data)),
("GET", "diffs"): (200, json.dumps({"changes": []})),
}
client = _make_client(routes)
result = await client.get_pr(42)
assert "42" in str(result.pr_url)
# ---------------------------------------------------------------------------
# merge_pr tests
# ---------------------------------------------------------------------------
class TestMergePr:
"""Tests for AzDoClient.merge_pr."""
async def test_returns_true_on_success(self) -> None:
merge_data = _load_json("azdo_merge_pr.json")
routes = {("PATCH", "pullRequests/42"): (200, json.dumps(merge_data))}
client = _make_client(routes)
result = await client.merge_pr(pr_id=42, last_merge_source_commit="abc123def456")
assert result is True
async def test_404_raises_not_found(self) -> None:
routes = {("PATCH", "pullRequests/999"): (404, b'{"message": "Not found"}')}
client = _make_client(routes)
with pytest.raises(NotFoundError):
await client.merge_pr(pr_id=999, last_merge_source_commit="abc123")
async def test_409_raises_service_error(self) -> None:
routes = {("PATCH", "pullRequests/42"): (409, b'{"message": "Conflict"}')}
client = _make_client(routes)
with pytest.raises(ServiceError):
await client.merge_pr(pr_id=42, last_merge_source_commit="abc123")
# ---------------------------------------------------------------------------
# create_pr tests
# ---------------------------------------------------------------------------
class TestCreatePr:
"""Tests for AzDoClient.create_pr."""
async def test_returns_dict_with_pr_id(self) -> None:
create_data = _load_json("azdo_create_pr.json")
routes = {("POST", "pullRequests"): (201, json.dumps(create_data))}
client = _make_client(routes)
result = await client.create_pr(
repo="my-repo",
source="refs/heads/release/v1.2.0",
target="refs/heads/main",
title="Release v1.2.0",
description="Release notes",
)
assert isinstance(result, dict)
assert result["pullRequestId"] == 99
async def test_400_raises_service_error(self) -> None:
routes = {("POST", "pullRequests"): (400, b'{"message": "Bad request"}')}
client = _make_client(routes)
with pytest.raises(ServiceError):
await client.create_pr(
repo="my-repo",
source="refs/heads/release/v1.2.0",
target="refs/heads/main",
title="Release",
description="",
)
# ---------------------------------------------------------------------------
# list_build_pipelines tests
# ---------------------------------------------------------------------------
class TestListBuildPipelines:
"""Tests for AzDoClient.list_build_pipelines."""
async def test_returns_list_of_pipeline_info(self) -> None:
pipeline_data = _load_json("azdo_pipelines.json")
routes = {("GET", "pipelines"): (200, json.dumps(pipeline_data))}
client = _make_client(routes)
result = await client.list_build_pipelines(repo="my-repo")
assert isinstance(result, list)
assert len(result) == 2
assert all(isinstance(p, PipelineInfo) for p in result)
async def test_pipeline_ids_extracted(self) -> None:
pipeline_data = _load_json("azdo_pipelines.json")
routes = {("GET", "pipelines"): (200, json.dumps(pipeline_data))}
client = _make_client(routes)
result = await client.list_build_pipelines(repo="my-repo")
ids = [p.id for p in result]
assert 10 in ids
assert 20 in ids
async def test_empty_list_on_no_pipelines(self) -> None:
routes = {("GET", "pipelines"): (200, json.dumps({"value": [], "count": 0}))}
client = _make_client(routes)
result = await client.list_build_pipelines(repo="my-repo")
assert result == []
# ---------------------------------------------------------------------------
# trigger_pipeline tests
# ---------------------------------------------------------------------------
class TestTriggerPipeline:
"""Tests for AzDoClient.trigger_pipeline."""
async def test_returns_dict_with_build_id(self) -> None:
trigger_data = _load_json("azdo_trigger_pipeline.json")
routes = {("POST", "pipelines/10/runs"): (200, json.dumps(trigger_data))}
client = _make_client(routes)
result = await client.trigger_pipeline(pipeline_id=10, branch="refs/heads/main")
assert isinstance(result, dict)
assert result["id"] == 1001
async def test_404_raises_not_found(self) -> None:
routes = {("POST", "pipelines/999/runs"): (404, b'{"message": "Pipeline not found"}')}
client = _make_client(routes)
with pytest.raises(NotFoundError):
await client.trigger_pipeline(pipeline_id=999, branch="refs/heads/main")
# ---------------------------------------------------------------------------
# get_build_status tests
# ---------------------------------------------------------------------------
class TestGetBuildStatus:
"""Tests for AzDoClient.get_build_status."""
async def test_returns_build_status_object(self) -> None:
build_data = _load_json("azdo_build_status.json")
routes = {("GET", "build/builds/1001"): (200, json.dumps(build_data))}
client = _make_client(routes)
result = await client.get_build_status(build_id=1001)
assert isinstance(result, BuildStatus)
async def test_status_field_populated(self) -> None:
build_data = _load_json("azdo_build_status.json")
routes = {("GET", "build/builds/1001"): (200, json.dumps(build_data))}
client = _make_client(routes)
result = await client.get_build_status(build_id=1001)
assert result.status == "completed"
async def test_result_field_populated(self) -> None:
build_data = _load_json("azdo_build_status.json")
routes = {("GET", "build/builds/1001"): (200, json.dumps(build_data))}
client = _make_client(routes)
result = await client.get_build_status(build_id=1001)
assert result.result == "succeeded"
async def test_build_url_present(self) -> None:
build_data = _load_json("azdo_build_status.json")
routes = {("GET", "build/builds/1001"): (200, json.dumps(build_data))}
client = _make_client(routes)
result = await client.get_build_status(build_id=1001)
# build_url may be None if not in fixture, but field must exist
assert hasattr(result, "build_url")
async def test_result_none_when_not_completed(self) -> None:
build_data = {"id": 99, "status": "inProgress", "buildNumber": "20240101.1"}
routes = {("GET", "build/builds/99"): (200, json.dumps(build_data))}
client = _make_client(routes)
result = await client.get_build_status(build_id=99)
assert result.status == "inProgress"
assert result.result is None
async def test_404_raises_not_found(self) -> None:
routes = {("GET", "build/builds/9999"): (404, b'{"message": "Build not found"}')}
client = _make_client(routes)
with pytest.raises(NotFoundError):
await client.get_build_status(build_id=9999)
# ---------------------------------------------------------------------------
# get_release_approvals tests
# ---------------------------------------------------------------------------
class TestGetReleaseApprovals:
"""Tests for AzDoClient.get_release_approvals."""
async def test_returns_list_of_approval_records(self) -> None:
approvals_data = {
"value": [
{
"id": 101,
"status": "pending",
"releaseEnvironment": {"name": "Sandbox", "release": {"id": 55}},
},
{
"id": 102,
"status": "approved",
"releaseEnvironment": {"name": "Production", "release": {"id": 55}},
},
],
"count": 2,
}
routes = {("GET", "release/approvals"): (200, json.dumps(approvals_data))}
client = _make_client(routes)
result = await client.get_release_approvals(release_id=55)
assert isinstance(result, list)
assert len(result) == 2
assert all(isinstance(a, ApprovalRecord) for a in result)
async def test_approval_id_populated(self) -> None:
approvals_data = {
"value": [
{
"id": 201,
"status": "pending",
"releaseEnvironment": {"name": "Sandbox", "release": {"id": 10}},
}
],
"count": 1,
}
routes = {("GET", "release/approvals"): (200, json.dumps(approvals_data))}
client = _make_client(routes)
result = await client.get_release_approvals(release_id=10)
assert result[0].approval_id == "201"
async def test_stage_name_populated(self) -> None:
approvals_data = {
"value": [
{
"id": 300,
"status": "pending",
"releaseEnvironment": {"name": "Production", "release": {"id": 20}},
}
],
"count": 1,
}
routes = {("GET", "release/approvals"): (200, json.dumps(approvals_data))}
client = _make_client(routes)
result = await client.get_release_approvals(release_id=20)
assert result[0].stage_name == "Production"
async def test_release_id_populated(self) -> None:
approvals_data = {
"value": [
{
"id": 400,
"status": "pending",
"releaseEnvironment": {"name": "Stage", "release": {"id": 99}},
}
],
"count": 1,
}
routes = {("GET", "release/approvals"): (200, json.dumps(approvals_data))}
client = _make_client(routes)
result = await client.get_release_approvals(release_id=99)
assert result[0].release_id == 99
async def test_empty_list_when_no_approvals(self) -> None:
approvals_data = {"value": [], "count": 0}
routes = {("GET", "release/approvals"): (200, json.dumps(approvals_data))}
client = _make_client(routes)
result = await client.get_release_approvals(release_id=77)
assert result == []
async def test_filters_by_release_id_in_query(self) -> None:
captured_urls: list[str] = []
def handler(request: httpx.Request) -> httpx.Response:
captured_urls.append(str(request.url))
return httpx.Response(200, content=b'{"value": [], "count": 0}')
transport = httpx.MockTransport(handler)
http_client = httpx.AsyncClient(transport=transport)
vsrm_client = httpx.AsyncClient(transport=transport)
client = AzDoClient(
base_url="https://dev.azure.com/my-org/my-project/_apis",
vsrm_base_url="https://vsrm.dev.azure.com/my-org/my-project/_apis",
pat="test-pat",
http_client=http_client,
vsrm_http_client=vsrm_client,
)
await client.get_release_approvals(release_id=42)
assert any("approvals" in url for url in captured_urls)
async def test_404_raises_not_found(self) -> None:
routes = {("GET", "release/approvals"): (404, b'{"message": "Not found"}')}
client = _make_client(routes)
with pytest.raises(NotFoundError):
await client.get_release_approvals(release_id=999)
# ---------------------------------------------------------------------------
# get_latest_release tests
# ---------------------------------------------------------------------------
class TestGetLatestRelease:
"""Tests for AzDoClient.get_latest_release."""
async def test_returns_dict(self) -> None:
release_data = {
"value": [{"id": 55, "name": "Release-55", "status": "active"}],
"count": 1,
}
routes = {("GET", "release/releases"): (200, json.dumps(release_data))}
client = _make_client(routes)
result = await client.get_latest_release(definition_id=7)
assert isinstance(result, dict)
assert result["id"] == 55
async def test_returns_empty_dict_when_no_releases(self) -> None:
release_data = {"value": [], "count": 0}
routes = {("GET", "release/releases"): (200, json.dumps(release_data))}
client = _make_client(routes)
result = await client.get_latest_release(definition_id=99)
assert result == {}
async def test_404_raises_not_found(self) -> None:
routes = {("GET", "release/releases"): (404, b'{"message": "Not found"}')}
client = _make_client(routes)
with pytest.raises(NotFoundError):
await client.get_latest_release(definition_id=999)
async def test_passes_definition_id_as_filter(self) -> None:
captured_urls: list[str] = []
def handler(request: httpx.Request) -> httpx.Response:
captured_urls.append(str(request.url))
return httpx.Response(200, content=b'{"value": [], "count": 0}')
transport = httpx.MockTransport(handler)
http_client = httpx.AsyncClient(transport=transport)
vsrm_client = httpx.AsyncClient(transport=transport)
client = AzDoClient(
base_url="https://dev.azure.com/my-org/my-project/_apis",
vsrm_base_url="https://vsrm.dev.azure.com/my-org/my-project/_apis",
pat="test-pat",
http_client=http_client,
vsrm_http_client=vsrm_client,
)
await client.get_latest_release(definition_id=13)
assert any("releases" in url for url in captured_urls)
# ---------------------------------------------------------------------------
# approve_release tests
# ---------------------------------------------------------------------------
class TestApproveRelease:
"""Tests for AzDoClient.approve_release."""
async def test_returns_dict_with_status(self) -> None:
approve_data = _load_json("azdo_approve_release.json")
routes = {("PATCH", "release/approvals"): (200, json.dumps(approve_data))}
client = _make_client(routes)
result = await client.approve_release(
approval_id="approval-uuid-123", comment="Approved"
)
assert isinstance(result, dict)
assert result["status"] == "approved"
async def test_404_raises_not_found(self) -> None:
routes = {("PATCH", "release/approvals"): (404, b'{"message": "Approval not found"}')}
client = _make_client(routes)
with pytest.raises(NotFoundError):
await client.approve_release(approval_id="bad-id", comment="Approve")
# ---------------------------------------------------------------------------
# close() lifecycle tests
# ---------------------------------------------------------------------------
class TestAzDoClientLifecycle:
"""Tests for AzDoClient close() method."""
async def test_close_closes_both_clients(self) -> None:
transport = httpx.MockTransport(lambda r: httpx.Response(200, content=b"{}"))
http_client = httpx.AsyncClient(transport=transport)
vsrm_client = httpx.AsyncClient(transport=transport)
client = AzDoClient(
base_url="https://dev.azure.com/org/proj/_apis",
vsrm_base_url="https://vsrm.dev.azure.com/org/proj/_apis",
pat="my-pat",
http_client=http_client,
vsrm_http_client=vsrm_client,
)
await client.close()
assert http_client.is_closed
assert vsrm_client.is_closed
# ---------------------------------------------------------------------------
# list_active_prs tests
# ---------------------------------------------------------------------------
def _make_pr_list_response(prs: list[dict]) -> str:
return json.dumps({"value": prs, "count": len(prs)})
def _make_active_pr_item(
pr_id: int = 10,
title: str = "Test PR",
branch: str = "refs/heads/feature/ALLPOST-100_fix",
status: str = "active",
repo_name: str = "my-repo",
) -> dict:
return {
"pullRequestId": pr_id,
"title": title,
"status": status,
"sourceRefName": branch,
"targetRefName": "refs/heads/develop",
"url": f"https://dev.azure.com/org/proj/_apis/git/repositories/{repo_name}/pullRequests/{pr_id}",
"repository": {
"id": "repo-uuid",
"name": repo_name,
"remoteUrl": f"https://dev.azure.com/org/proj/_git/{repo_name}",
},
}
class TestListActivePrs:
"""Tests for AzDoClient.list_active_prs."""
async def test_returns_list_of_pr_info(self) -> None:
pr_item = _make_active_pr_item()
routes = {
("GET", "git/repositories/my-repo/pullRequests"): (
200,
_make_pr_list_response([pr_item]),
)
}
client = _make_client(routes)
result = await client.list_active_prs("my-repo", "refs/heads/develop")
assert isinstance(result, list)
assert len(result) == 1
assert isinstance(result[0], PRInfo)
async def test_pr_id_extracted(self) -> None:
pr_item = _make_active_pr_item(pr_id=55)
routes = {
("GET", "git/repositories/my-repo/pullRequests"): (
200,
_make_pr_list_response([pr_item]),
)
}
client = _make_client(routes)
result = await client.list_active_prs("my-repo", "refs/heads/develop")
assert result[0].pr_id == "55"
async def test_pr_title_extracted(self) -> None:
pr_item = _make_active_pr_item(title="My Feature")
routes = {
("GET", "git/repositories/my-repo/pullRequests"): (
200,
_make_pr_list_response([pr_item]),
)
}
client = _make_client(routes)
result = await client.list_active_prs("my-repo", "refs/heads/develop")
assert result[0].pr_title == "My Feature"
async def test_empty_list_when_no_prs(self) -> None:
routes = {
("GET", "git/repositories/my-repo/pullRequests"): (
200,
_make_pr_list_response([]),
)
}
client = _make_client(routes)
result = await client.list_active_prs("my-repo", "refs/heads/develop")
assert result == []
async def test_multiple_prs_returned(self) -> None:
prs = [
_make_active_pr_item(pr_id=10, title="PR 10"),
_make_active_pr_item(pr_id=20, title="PR 20"),
]
routes = {
("GET", "git/repositories/my-repo/pullRequests"): (
200,
_make_pr_list_response(prs),
)
}
client = _make_client(routes)
result = await client.list_active_prs("my-repo", "refs/heads/develop")
assert len(result) == 2
assert {r.pr_id for r in result} == {"10", "20"}
async def test_404_raises_not_found(self) -> None:
routes = {
("GET", "git/repositories/missing-repo/pullRequests"): (
404,
b'{"message": "Repo not found"}',
)
}
client = _make_client(routes)
with pytest.raises(NotFoundError):
await client.list_active_prs("missing-repo", "refs/heads/develop")
async def test_401_raises_authentication_error(self) -> None:
routes = {
("GET", "git/repositories/my-repo/pullRequests"): (
401,
b'{"message": "Unauthorized"}',
)
}
client = _make_client(routes)
with pytest.raises(AuthenticationError):
await client.list_active_prs("my-repo", "refs/heads/develop")
async def test_500_raises_service_error(self) -> None:
routes = {
("GET", "git/repositories/my-repo/pullRequests"): (
500,
b'{"message": "Internal error"}',
)
}
client = _make_client(routes)
with pytest.raises(ServiceError):
await client.list_active_prs("my-repo", "refs/heads/develop")

View File

@@ -0,0 +1,454 @@
"""Tests for ClaudeReviewer using Claude Code CLI subprocess."""
import json
import pytest
from release_agent.models.review import ReviewResult
from release_agent.tools.claude_review import (
ClaudeReviewer,
_build_prompt,
_parse_cli_output,
_truncate_diff,
)
MAX_DIFF_CHARS = 100_000
# ---------------------------------------------------------------------------
# Helpers — fake subprocess runner
# ---------------------------------------------------------------------------
def _make_cli_output(
verdict: str = "approve",
summary: str = "LGTM",
issues: list | None = None,
) -> str:
"""Build a JSON string mimicking Claude Code CLI --output-format json."""
structured = {
"verdict": verdict,
"summary": summary,
"issues": issues or [],
}
return json.dumps({"result": "", "structured_output": structured})
def _make_subprocess_runner(
stdout: str = "",
stderr: str = "",
returncode: int = 0,
):
"""Return a fake run_subprocess callable that records calls."""
calls: list[dict] = []
async def fake_run(*, cmd, cwd, timeout):
calls.append({"cmd": cmd, "cwd": cwd, "timeout": timeout})
return (stdout, stderr, returncode)
return fake_run, calls
# ---------------------------------------------------------------------------
# _truncate_diff tests
# ---------------------------------------------------------------------------
class TestTruncateDiff:
def test_short_diff_not_truncated(self) -> None:
diff = "short diff"
assert _truncate_diff(diff) == diff
def test_exact_limit_not_truncated(self) -> None:
diff = "x" * MAX_DIFF_CHARS
assert _truncate_diff(diff) == diff
def test_over_limit_truncated(self) -> None:
diff = "x" * (MAX_DIFF_CHARS + 1000)
result = _truncate_diff(diff)
assert len(result) < len(diff)
assert "TRUNCATED" in result
# ---------------------------------------------------------------------------
# _build_prompt tests
# ---------------------------------------------------------------------------
class TestBuildPrompt:
def test_contains_pr_title(self) -> None:
prompt = _build_prompt(diff="d", pr_title="My Title", repo_name="repo")
assert "My Title" in prompt
def test_contains_repo_name(self) -> None:
prompt = _build_prompt(diff="d", pr_title="t", repo_name="my-repo")
assert "my-repo" in prompt
def test_contains_diff(self) -> None:
prompt = _build_prompt(diff="UNIQUE_DIFF", pr_title="t", repo_name="r")
assert "UNIQUE_DIFF" in prompt
# ---------------------------------------------------------------------------
# _parse_cli_output tests
# ---------------------------------------------------------------------------
class TestParseCliOutput:
def test_parses_structured_output(self) -> None:
stdout = _make_cli_output(verdict="approve", summary="Good")
result = _parse_cli_output(stdout)
assert isinstance(result, ReviewResult)
assert result.verdict == "approve"
assert result.summary == "Good"
def test_parses_request_changes(self) -> None:
stdout = _make_cli_output(
verdict="request_changes",
summary="Has issues",
issues=[{"severity": "blocker", "description": "SQL injection"}],
)
result = _parse_cli_output(stdout)
assert result.verdict == "request_changes"
assert len(result.issues) == 1
assert result.has_blockers is True
def test_parses_issues_with_optional_fields(self) -> None:
stdout = _make_cli_output(
verdict="request_changes",
summary="Issues found",
issues=[{
"severity": "warning",
"description": "Style issue",
"file_path": "src/foo.py",
"suggestion": "Fix it",
}],
)
result = _parse_cli_output(stdout)
assert result.issues[0].file_path == "src/foo.py"
assert result.issues[0].suggestion == "Fix it"
def test_empty_issues_no_blockers(self) -> None:
stdout = _make_cli_output(verdict="approve", summary="Clean", issues=[])
result = _parse_cli_output(stdout)
assert result.has_blockers is False
assert len(result.issues) == 0
def test_result_field_as_json_string(self) -> None:
"""When structured_output is absent, falls back to parsing result as JSON."""
inner = {"verdict": "approve", "summary": "OK", "issues": []}
stdout = json.dumps({"result": json.dumps(inner)})
result = _parse_cli_output(stdout)
assert result.verdict == "approve"
def test_invalid_json_raises(self) -> None:
with pytest.raises(ValueError, match="Failed to parse"):
_parse_cli_output("not json at all")
def test_missing_structured_output_and_result_raises(self) -> None:
with pytest.raises(ValueError, match="No structured_output"):
_parse_cli_output(json.dumps({"other": "data"}))
def test_non_dict_structured_output_raises(self) -> None:
stdout = json.dumps({"structured_output": ["not", "a", "dict"]})
with pytest.raises(ValueError, match="Expected dict"):
_parse_cli_output(stdout)
def test_result_is_non_json_string_raises(self) -> None:
stdout = json.dumps({"result": "just plain text, not json"})
with pytest.raises(ValueError, match="not valid JSON"):
_parse_cli_output(stdout)
# ---------------------------------------------------------------------------
# ClaudeReviewer construction tests
# ---------------------------------------------------------------------------
class TestClaudeReviewerConstruction:
def test_can_be_instantiated(self) -> None:
reviewer = ClaudeReviewer()
assert reviewer is not None
def test_custom_claude_cmd(self) -> None:
reviewer = ClaudeReviewer(claude_cmd="/usr/local/bin/claude")
assert reviewer._claude_cmd == "/usr/local/bin/claude"
def test_custom_timeout(self) -> None:
reviewer = ClaudeReviewer(timeout=60)
assert reviewer._timeout == 60
# ---------------------------------------------------------------------------
# review_pr tests
# ---------------------------------------------------------------------------
class TestReviewPr:
async def test_returns_review_result(self) -> None:
stdout = _make_cli_output(verdict="approve", summary="Looks good")
runner, _ = _make_subprocess_runner(stdout=stdout)
reviewer = ClaudeReviewer(run_subprocess=runner)
result = await reviewer.review_pr(
diff="diff --git a/foo.py ...",
pr_title="Fix bug",
repo_name="my-repo",
)
assert isinstance(result, ReviewResult)
assert result.verdict == "approve"
async def test_passes_cwd_to_subprocess(self) -> None:
stdout = _make_cli_output()
runner, calls = _make_subprocess_runner(stdout=stdout)
reviewer = ClaudeReviewer(run_subprocess=runner)
await reviewer.review_pr(
diff="diff",
pr_title="PR",
repo_name="repo",
cwd="/path/to/worktree",
)
assert calls[0]["cwd"] == "/path/to/worktree"
async def test_cmd_includes_claude_p(self) -> None:
stdout = _make_cli_output()
runner, calls = _make_subprocess_runner(stdout=stdout)
reviewer = ClaudeReviewer(run_subprocess=runner)
await reviewer.review_pr(diff="d", pr_title="t", repo_name="r")
cmd = calls[0]["cmd"]
assert cmd[0] == "claude"
assert "-p" in cmd
async def test_cmd_includes_output_format_json(self) -> None:
stdout = _make_cli_output()
runner, calls = _make_subprocess_runner(stdout=stdout)
reviewer = ClaudeReviewer(run_subprocess=runner)
await reviewer.review_pr(diff="d", pr_title="t", repo_name="r")
cmd = calls[0]["cmd"]
idx = cmd.index("--output-format")
assert cmd[idx + 1] == "json"
async def test_cmd_includes_json_schema(self) -> None:
stdout = _make_cli_output()
runner, calls = _make_subprocess_runner(stdout=stdout)
reviewer = ClaudeReviewer(run_subprocess=runner)
await reviewer.review_pr(diff="d", pr_title="t", repo_name="r")
cmd = calls[0]["cmd"]
assert "--json-schema" in cmd
async def test_cmd_includes_allowed_tools(self) -> None:
stdout = _make_cli_output()
runner, calls = _make_subprocess_runner(stdout=stdout)
reviewer = ClaudeReviewer(run_subprocess=runner)
await reviewer.review_pr(diff="d", pr_title="t", repo_name="r")
cmd = calls[0]["cmd"]
idx = cmd.index("--allowedTools")
assert "Read" in cmd[idx + 1]
async def test_cmd_includes_system_prompt(self) -> None:
stdout = _make_cli_output()
runner, calls = _make_subprocess_runner(stdout=stdout)
reviewer = ClaudeReviewer(run_subprocess=runner)
await reviewer.review_pr(diff="d", pr_title="t", repo_name="r")
cmd = calls[0]["cmd"]
assert "--system-prompt" in cmd
async def test_nonzero_exit_raises(self) -> None:
runner, _ = _make_subprocess_runner(
stdout="", stderr="error occurred", returncode=1
)
reviewer = ClaudeReviewer(run_subprocess=runner)
with pytest.raises(RuntimeError, match="exited with code 1"):
await reviewer.review_pr(diff="d", pr_title="t", repo_name="r")
async def test_timeout_passed_to_subprocess(self) -> None:
stdout = _make_cli_output()
runner, calls = _make_subprocess_runner(stdout=stdout)
reviewer = ClaudeReviewer(run_subprocess=runner, timeout=120)
await reviewer.review_pr(diff="d", pr_title="t", repo_name="r")
assert calls[0]["timeout"] == 120
async def test_pr_title_in_prompt(self) -> None:
stdout = _make_cli_output()
runner, calls = _make_subprocess_runner(stdout=stdout)
reviewer = ClaudeReviewer(run_subprocess=runner)
await reviewer.review_pr(
diff="d", pr_title="Specific Title", repo_name="r"
)
cmd = calls[0]["cmd"]
prompt = cmd[cmd.index("-p") + 1]
assert "Specific Title" in prompt
async def test_repo_name_in_prompt(self) -> None:
stdout = _make_cli_output()
runner, calls = _make_subprocess_runner(stdout=stdout)
reviewer = ClaudeReviewer(run_subprocess=runner)
await reviewer.review_pr(
diff="d", pr_title="t", repo_name="special-repo"
)
cmd = calls[0]["cmd"]
prompt = cmd[cmd.index("-p") + 1]
assert "special-repo" in prompt
async def test_cwd_none_when_not_provided(self) -> None:
stdout = _make_cli_output()
runner, calls = _make_subprocess_runner(stdout=stdout)
reviewer = ClaudeReviewer(run_subprocess=runner)
await reviewer.review_pr(diff="d", pr_title="t", repo_name="r")
assert calls[0]["cwd"] is None
async def test_request_changes_with_issues(self) -> None:
stdout = _make_cli_output(
verdict="request_changes",
summary="Problems found",
issues=[
{"severity": "blocker", "description": "Security flaw"},
{"severity": "warning", "description": "Missing docs"},
],
)
runner, _ = _make_subprocess_runner(stdout=stdout)
reviewer = ClaudeReviewer(run_subprocess=runner)
result = await reviewer.review_pr(diff="d", pr_title="t", repo_name="r")
assert result.verdict == "request_changes"
assert len(result.issues) == 2
assert result.has_blockers is True
# ---------------------------------------------------------------------------
# ClaudeReviewer.generate_ticket_content tests
# ---------------------------------------------------------------------------
def _make_ticket_cli_output(summary: str = "My summary", description: str = "My desc") -> str:
"""Build a JSON string mimicking Claude Code CLI output for ticket generation."""
structured = {"summary": summary, "description": description}
return json.dumps({"result": "", "structured_output": structured})
class TestGenerateTicketContent:
"""Tests for ClaudeReviewer.generate_ticket_content."""
async def test_returns_tuple_of_summary_and_description(self) -> None:
stdout = _make_ticket_cli_output(summary="Fix login bug", description="Detailed desc")
runner, _ = _make_subprocess_runner(stdout=stdout)
reviewer = ClaudeReviewer(run_subprocess=runner)
result = await reviewer.generate_ticket_content(
diff="edit: main.py", pr_title="Fix login", repo_name="backend"
)
assert isinstance(result, tuple)
assert len(result) == 2
async def test_returns_correct_summary(self) -> None:
stdout = _make_ticket_cli_output(summary="Implement OAuth2 login")
runner, _ = _make_subprocess_runner(stdout=stdout)
reviewer = ClaudeReviewer(run_subprocess=runner)
summary, _ = await reviewer.generate_ticket_content(
diff="d", pr_title="Add OAuth", repo_name="auth-service"
)
assert summary == "Implement OAuth2 login"
async def test_returns_correct_description(self) -> None:
stdout = _make_ticket_cli_output(description="This adds OAuth2 support for the login flow")
runner, _ = _make_subprocess_runner(stdout=stdout)
reviewer = ClaudeReviewer(run_subprocess=runner)
_, description = await reviewer.generate_ticket_content(
diff="d", pr_title="Add OAuth", repo_name="auth-service"
)
assert description == "This adds OAuth2 support for the login flow"
async def test_uses_json_schema_with_summary_and_description_fields(self) -> None:
stdout = _make_ticket_cli_output()
runner, calls = _make_subprocess_runner(stdout=stdout)
reviewer = ClaudeReviewer(run_subprocess=runner)
await reviewer.generate_ticket_content(diff="d", pr_title="t", repo_name="r")
cmd = calls[0]["cmd"]
# Verify --json-schema flag was used
assert "--json-schema" in cmd
schema_idx = cmd.index("--json-schema")
schema_json = cmd[schema_idx + 1]
schema = json.loads(schema_json)
assert "summary" in schema["properties"]
assert "description" in schema["properties"]
async def test_passes_pr_title_in_prompt(self) -> None:
stdout = _make_ticket_cli_output()
runner, calls = _make_subprocess_runner(stdout=stdout)
reviewer = ClaudeReviewer(run_subprocess=runner)
await reviewer.generate_ticket_content(
diff="d", pr_title="My Unique PR Title", repo_name="r"
)
cmd_str = " ".join(calls[0]["cmd"])
assert "My Unique PR Title" in cmd_str
async def test_passes_repo_name_in_prompt(self) -> None:
stdout = _make_ticket_cli_output()
runner, calls = _make_subprocess_runner(stdout=stdout)
reviewer = ClaudeReviewer(run_subprocess=runner)
await reviewer.generate_ticket_content(
diff="d", pr_title="t", repo_name="my-special-repo"
)
cmd_str = " ".join(calls[0]["cmd"])
assert "my-special-repo" in cmd_str
async def test_passes_cwd_to_subprocess(self) -> None:
stdout = _make_ticket_cli_output()
runner, calls = _make_subprocess_runner(stdout=stdout)
reviewer = ClaudeReviewer(run_subprocess=runner)
await reviewer.generate_ticket_content(
diff="d", pr_title="t", repo_name="r", cwd="/some/path"
)
assert calls[0]["cwd"] == "/some/path"
async def test_cwd_none_by_default(self) -> None:
stdout = _make_ticket_cli_output()
runner, calls = _make_subprocess_runner(stdout=stdout)
reviewer = ClaudeReviewer(run_subprocess=runner)
await reviewer.generate_ticket_content(diff="d", pr_title="t", repo_name="r")
assert calls[0]["cwd"] is None
async def test_raises_on_nonzero_exit_code(self) -> None:
runner, _ = _make_subprocess_runner(stdout="", stderr="Error", returncode=1)
reviewer = ClaudeReviewer(run_subprocess=runner)
with pytest.raises(RuntimeError, match="Claude CLI"):
await reviewer.generate_ticket_content(diff="d", pr_title="t", repo_name="r")
async def test_raises_on_invalid_json_output(self) -> None:
runner, _ = _make_subprocess_runner(stdout="not json at all")
reviewer = ClaudeReviewer(run_subprocess=runner)
with pytest.raises((ValueError, Exception)):
await reviewer.generate_ticket_content(diff="d", pr_title="t", repo_name="r")

Some files were not shown because too many files have changed in this diff Show More