# Change notes (from the introducing commit):
# - SSRF protection: private IP blocking, DNS rebinding defense, redirect validation
# - OpenAPI fetcher with SSRF guard, JSON/YAML auto-detection, 10MB limit
# - Structural spec validator (3.0.x/3.1.x)
# - Endpoint parser with $ref resolution, auto-generated operation IDs
# - Heuristic + LLM endpoint classifier with Protocol interface
# - Review API at /api/openapi (import, job status, classification CRUD, approve)
# - @tool code generator + Agent YAML generator
# - Import orchestrator (fetch -> validate -> parse -> classify pipeline)
"""SSRF protection module.

Validates URLs before making external HTTP requests.
Blocks private IPs, loopback addresses, link-local, and DNS rebinding attacks.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import ipaddress
|
|
import socket
|
|
from dataclasses import dataclass
|
|
from urllib.parse import urlparse
|
|
|
|
import httpx
|
|
|
|
|
|
class SSRFError(Exception):
    """Raised when a URL fails SSRF validation.

    Raised by validate_url() and safe_fetch() for disallowed schemes,
    missing hostnames, hosts outside the allowlist, hosts that resolve to
    private/reserved IPs, unresolvable hosts, and unsafe redirects.
    """
|
|
|
|
|
|
@dataclass(frozen=True)
class SSRFPolicy:
    """Configuration for SSRF protection.

    Frozen (immutable) so a single policy instance can be shared safely,
    e.g. as the module-level DEFAULT_POLICY.
    """

    # URL schemes accepted by validate_url(); checked before any DNS lookup.
    allowed_schemes: frozenset[str] = frozenset({"http", "https"})
    # Optional exact-hostname allowlist checked by validate_url().
    allowed_hosts: frozenset[str] | None = None  # None = all public hosts allowed
    # Maximum redirect hops safe_fetch() will follow before raising SSRFError.
    max_redirects: int = 5
    # Per-request timeout (seconds) passed to the HTTP client.
    timeout_seconds: float = 30.0
|
|
|
|
|
|
# Private, loopback, link-local, and otherwise non-public ranges that must
# never be contacted.  is_private_ip() checks every resolved address
# against this list.
_BLOCKED_NETWORKS = [
    # IPv4
    ipaddress.ip_network("10.0.0.0/8"),      # RFC 1918 private
    ipaddress.ip_network("172.16.0.0/12"),   # RFC 1918 private
    ipaddress.ip_network("192.168.0.0/16"),  # RFC 1918 private
    ipaddress.ip_network("127.0.0.0/8"),     # loopback
    ipaddress.ip_network("169.254.0.0/16"),  # link-local (incl. metadata endpoint 169.254.169.254)
    # Widened from 0.0.0.0/32: any 0.x.x.x ("this network", RFC 6890)
    # is treated as local by common stacks and must be blocked too.
    ipaddress.ip_network("0.0.0.0/8"),
    ipaddress.ip_network("100.64.0.0/10"),   # carrier-grade NAT (RFC 6598)
    ipaddress.ip_network("192.0.0.0/24"),    # IETF protocol assignments
    ipaddress.ip_network("198.18.0.0/15"),   # benchmarking (RFC 2544)
    ipaddress.ip_network("224.0.0.0/4"),     # multicast
    ipaddress.ip_network("240.0.0.0/4"),     # reserved; includes 255.255.255.255 broadcast
    # IPv6
    ipaddress.ip_network("::1/128"),         # loopback
    ipaddress.ip_network("fe80::/10"),       # link-local
    ipaddress.ip_network("fc00::/7"),        # unique local
    ipaddress.ip_network("::/128"),          # unspecified
    # IPv4-mapped IPv6 (::ffff:a.b.c.d): blocks the v4-in-v6 bypass where
    # e.g. ::ffff:10.0.0.1 would dodge the IPv4 checks above.
    ipaddress.ip_network("::ffff:0:0/96"),
    ipaddress.ip_network("64:ff9b::/96"),    # NAT64 well-known prefix (RFC 6052)
]
|
|
|
|
# Shared module-wide default policy: http/https only, no host allowlist,
# 5 redirects, 30s timeout.  Safe to share because SSRFPolicy is frozen.
DEFAULT_POLICY = SSRFPolicy()
|
|
|
|
|
|
def is_private_ip(ip_str: str) -> bool:
    """Check if an IP address is private/reserved.

    Args:
        ip_str: Textual IPv4 or IPv6 address.

    Returns:
        True when the address must not be contacted: it is unparseable
        (fail closed), falls within a blocked network, or is not globally
        routable.
    """
    try:
        addr = ipaddress.ip_address(ip_str)
    except ValueError:
        return True  # Invalid IP treated as blocked (fail closed)

    # Unwrap IPv4-mapped IPv6 (::ffff:a.b.c.d).  Without this,
    # ::ffff:10.0.0.1 parses as an IPv6Address and never matches the
    # IPv4 networks in the block list, bypassing the check entirely.
    if isinstance(addr, ipaddress.IPv6Address) and addr.ipv4_mapped is not None:
        addr = addr.ipv4_mapped

    if any(addr in network for network in _BLOCKED_NETWORKS):
        return True
    # Defense in depth: also block anything the stdlib does not consider
    # globally routable (covers ranges absent from the explicit list,
    # e.g. 0.0.0.0/8 and carrier-grade NAT).
    return not addr.is_global
|
|
|
|
|
|
def validate_url(url: str, policy: SSRFPolicy = DEFAULT_POLICY) -> str:
    """Validate a URL against SSRF policy.

    Checks run in order: scheme, hostname presence, optional host
    allowlist, DNS resolvability, then a private/reserved check on every
    resolved address.

    Returns the validated URL string unchanged.
    Raises SSRFError if any check fails.
    """
    parts = urlparse(url)

    # Scheme gate runs first -- cheapest check, no network involved.
    if parts.scheme not in policy.allowed_schemes:
        allowed = ", ".join(sorted(policy.allowed_schemes))
        raise SSRFError(
            f"URL scheme '{parts.scheme}' is not allowed. Allowed: {allowed}"
        )

    host = parts.hostname
    if not host:
        raise SSRFError("URL has no hostname")

    # Optional allowlist: when configured, the host must match exactly.
    if policy.allowed_hosts is not None and host not in policy.allowed_hosts:
        raise SSRFError(f"Host '{host}' is not in the allowed hosts list")

    # Resolve DNS up front so blocked targets are rejected before any
    # request is attempted.
    addresses = resolve_hostname(host)
    if not addresses:
        raise SSRFError(f"Could not resolve hostname '{host}'")

    # A single private answer among the resolved set blocks the whole URL.
    for candidate in addresses:
        if is_private_ip(candidate):
            raise SSRFError(
                f"Host '{host}' resolves to private/reserved IP {candidate}. "
                "Request blocked for SSRF protection."
            )

    return url
|
|
|
|
|
|
def resolve_hostname(hostname: str) -> list[str]:
    """Resolve hostname to IP addresses via DNS.

    Returns a de-duplicated list of address strings (IPv4 and/or IPv6),
    or an empty list when resolution fails.
    """
    try:
        infos = socket.getaddrinfo(
            hostname, None, socket.AF_UNSPEC, socket.SOCK_STREAM
        )
    except socket.gaierror:
        # Unresolvable host: signal via empty list; caller decides to block.
        return []
    # Each entry is (family, type, proto, canonname, sockaddr); the
    # address string is sockaddr[0].  A set removes duplicate answers.
    unique = {entry[4][0] for entry in infos}
    return list(unique)
|
|
|
|
|
|
async def safe_fetch(
    url: str,
    *,
    policy: SSRFPolicy = DEFAULT_POLICY,
    method: str = "GET",
    headers: dict[str, str] | None = None,
) -> httpx.Response:
    """Fetch a URL with SSRF protection.

    Validates the URL (scheme, host allowlist, DNS resolution, private-IP
    check) before the request, then follows redirects manually so every
    hop is re-validated before it is fetched.

    NOTE(review): there is a TOCTOU window for DNS rebinding -- validate_url()
    resolves DNS separately from the HTTP client's own lookup, so a rebinding
    resolver could return a public IP to our check and a private IP to the
    client.  Closing it requires pinning the connection to the validated IP
    (e.g. a custom transport); confirm whether the threat model requires that.

    Args:
        url: Target URL; must pass validate_url() against *policy*.
        policy: SSRF policy (allowed schemes/hosts, redirect and timeout caps).
        method: HTTP method, re-used as-is on every redirect hop.
        headers: Optional headers sent on every hop.

    Returns:
        The first non-redirect httpx.Response.

    Raises:
        SSRFError: if the URL or any redirect target fails validation,
            a redirect carries no target, or more than policy.max_redirects
            hops are followed.
    """
    validate_url(url, policy)

    # Make the request with redirect following disabled so we can check each hop
    async with httpx.AsyncClient(
        follow_redirects=False,
        timeout=httpx.Timeout(policy.timeout_seconds),
    ) as client:
        current_url = url
        # range(max_redirects + 1): one initial request plus max_redirects hops.
        for _redirect_count in range(policy.max_redirects + 1):
            response = await client.request(
                method,
                current_url,
                headers=headers,
            )

            if response.is_redirect:
                # httpx pre-builds the follow-up request; its URL is already
                # resolved against current_url (handles relative Location).
                redirect_url = str(response.next_request.url) if response.next_request else None
                if not redirect_url:
                    raise SSRFError("Redirect with no target URL")
                # Validate the redirect target
                validate_url(redirect_url, policy)
                current_url = redirect_url
                continue

            return response

        raise SSRFError(
            f"Too many redirects (max {policy.max_redirects}). "
            "Possible redirect loop or evasion attempt."
        )
|
|
|
|
|
|
async def safe_fetch_text(
    url: str,
    *,
    policy: SSRFPolicy = DEFAULT_POLICY,
    headers: dict[str, str] | None = None,
) -> str:
    """Fetch a URL and return text content with SSRF protection.

    Thin wrapper over safe_fetch(): the response must be a success status
    (raise_for_status() is called on it) and its decoded text is returned.
    """
    resp = await safe_fetch(url, policy=policy, headers=headers)
    # Surface HTTP error statuses as exceptions rather than returning
    # an error page's body as if it were the requested content.
    resp.raise_for_status()
    return resp.text
|