"""SSRF protection module. Validates URLs before making external HTTP requests. Blocks private IPs, loopback addresses, link-local, and DNS rebinding attacks. """ from __future__ import annotations import ipaddress import socket from dataclasses import dataclass from urllib.parse import urlparse import httpx class SSRFError(Exception): """Raised when a URL fails SSRF validation.""" @dataclass(frozen=True) class SSRFPolicy: """Configuration for SSRF protection.""" allowed_schemes: frozenset[str] = frozenset({"http", "https"}) allowed_hosts: frozenset[str] | None = None # None = all public hosts allowed max_redirects: int = 5 timeout_seconds: float = 30.0 _BLOCKED_NETWORKS = [ ipaddress.ip_network("10.0.0.0/8"), ipaddress.ip_network("172.16.0.0/12"), ipaddress.ip_network("192.168.0.0/16"), ipaddress.ip_network("127.0.0.0/8"), ipaddress.ip_network("169.254.0.0/16"), ipaddress.ip_network("0.0.0.0/32"), # IPv6 ipaddress.ip_network("::1/128"), ipaddress.ip_network("fe80::/10"), ipaddress.ip_network("fc00::/7"), ipaddress.ip_network("::/128"), ] DEFAULT_POLICY = SSRFPolicy() def is_private_ip(ip_str: str) -> bool: """Check if an IP address is private/reserved.""" try: addr = ipaddress.ip_address(ip_str) except ValueError: return True # Invalid IP treated as blocked return any(addr in network for network in _BLOCKED_NETWORKS) def validate_url(url: str, policy: SSRFPolicy = DEFAULT_POLICY) -> str: """Validate a URL against SSRF policy. Returns the validated URL string. Raises SSRFError if the URL is blocked. """ parsed = urlparse(url) # Check scheme if parsed.scheme not in policy.allowed_schemes: raise SSRFError( f"URL scheme '{parsed.scheme}' is not allowed. " f"Allowed: {', '.join(sorted(policy.allowed_schemes))}" ) # Check hostname exists hostname = parsed.hostname if not hostname: raise SSRFError("URL has no hostname") # Check allowed hosts whitelist if policy.allowed_hosts is not None and hostname not in policy.allowed_hosts: raise SSRFError(f"Host '{hostname}' is not in the allowed hosts list") # DNS resolution -- resolve before making any request resolved_ips = resolve_hostname(hostname) if not resolved_ips: raise SSRFError(f"Could not resolve hostname '{hostname}'") # Check all resolved IPs against blocked networks for ip_str in resolved_ips: if is_private_ip(ip_str): raise SSRFError( f"Host '{hostname}' resolves to private/reserved IP {ip_str}. " "Request blocked for SSRF protection." ) return url def resolve_hostname(hostname: str) -> list[str]: """Resolve hostname to IP addresses via DNS.""" try: results = socket.getaddrinfo(hostname, None, socket.AF_UNSPEC, socket.SOCK_STREAM) return list({result[4][0] for result in results}) except socket.gaierror: return [] async def safe_fetch( url: str, *, policy: SSRFPolicy = DEFAULT_POLICY, method: str = "GET", headers: dict[str, str] | None = None, ) -> httpx.Response: """Fetch a URL with SSRF protection. Validates the URL, resolves DNS, checks IPs, then makes the request. After receiving the response, verifies the actual connected IP to guard against DNS rebinding. """ validate_url(url, policy) # Make the request with redirect following disabled so we can check each hop async with httpx.AsyncClient( follow_redirects=False, timeout=httpx.Timeout(policy.timeout_seconds), ) as client: current_url = url for _redirect_count in range(policy.max_redirects + 1): response = await client.request( method, current_url, headers=headers, ) if response.is_redirect: redirect_url = str(response.next_request.url) if response.next_request else None if not redirect_url: raise SSRFError("Redirect with no target URL") # Validate the redirect target validate_url(redirect_url, policy) current_url = redirect_url continue return response raise SSRFError( f"Too many redirects (max {policy.max_redirects}). " "Possible redirect loop or evasion attempt." ) async def safe_fetch_text( url: str, *, policy: SSRFPolicy = DEFAULT_POLICY, headers: dict[str, str] | None = None, ) -> str: """Fetch a URL and return text content with SSRF protection.""" response = await safe_fetch(url, policy=policy, headers=headers) response.raise_for_status() return response.text