- Wire ImportOrchestrator into review_api start_import via BackgroundTasks
- Sanitize docstrings in generated tool code to prevent code injection
- Add Literal["read", "write"] validation for access_type
- Add regex validation for agent_group
- Validate URL scheme (http/https only) in ImportRequest
- Validate LLM output fields (clamp confidence, validate access_type)
- Use dataclasses.replace instead of manual reconstruction in importer
- Expand SSRF blocked networks (Carrier-Grade NAT, IPv4-mapped IPv6, etc.)
- Make _BLOCKED_NETWORKS immutable tuple
- Use yaml.safe_dump instead of yaml.dump
- Fix _to_snake_case for empty strings and Python keywords
168 lines
5.2 KiB
Python
168 lines
5.2 KiB
Python
"""SSRF protection module.
|
|
|
|
Validates URLs before making external HTTP requests.
|
|
Blocks private IPs, loopback addresses, link-local, and DNS rebinding attacks.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import ipaddress
|
|
import socket
|
|
from dataclasses import dataclass
|
|
from urllib.parse import urlparse
|
|
|
|
import httpx
|
|
|
|
|
|
class SSRFError(Exception):
    """Raised when a URL fails SSRF validation.

    Covers disallowed schemes, missing or non-whitelisted hostnames,
    unresolvable hosts, private/reserved target addresses, and
    redirect-policy violations.
    """
|
|
|
|
|
|
@dataclass(frozen=True)
class SSRFPolicy:
    """Configuration for SSRF protection.

    Frozen so a policy instance can be shared as a module-level default
    (see DEFAULT_POLICY) without risk of runtime mutation.
    """

    # URL schemes permitted for outbound requests.
    allowed_schemes: frozenset[str] = frozenset({"http", "https"})
    # Optional hostname whitelist; checked by validate_url() when set.
    allowed_hosts: frozenset[str] | None = None  # None = all public hosts allowed
    # Maximum number of redirect hops safe_fetch() will follow.
    max_redirects: int = 5
    # Per-request timeout (seconds) applied to the httpx client.
    timeout_seconds: float = 30.0
|
|
|
|
|
|
_BLOCKED_NETWORKS = (
|
|
ipaddress.ip_network("10.0.0.0/8"),
|
|
ipaddress.ip_network("172.16.0.0/12"),
|
|
ipaddress.ip_network("192.168.0.0/16"),
|
|
ipaddress.ip_network("127.0.0.0/8"),
|
|
ipaddress.ip_network("169.254.0.0/16"),
|
|
ipaddress.ip_network("0.0.0.0/32"),
|
|
ipaddress.ip_network("100.64.0.0/10"), # Carrier-Grade NAT
|
|
ipaddress.ip_network("198.18.0.0/15"), # Benchmarking
|
|
ipaddress.ip_network("240.0.0.0/4"), # Reserved
|
|
ipaddress.ip_network("255.255.255.255/32"), # Broadcast
|
|
# IPv6
|
|
ipaddress.ip_network("::1/128"),
|
|
ipaddress.ip_network("fe80::/10"),
|
|
ipaddress.ip_network("fc00::/7"),
|
|
ipaddress.ip_network("::/128"),
|
|
ipaddress.ip_network("::ffff:0:0/96"), # IPv4-mapped IPv6
|
|
ipaddress.ip_network("2001:db8::/32"), # Documentation
|
|
)
|
|
|
|
# Shared default policy: http/https only, any public host, 5 redirects, 30 s timeout.
DEFAULT_POLICY = SSRFPolicy()
|
|
|
|
|
|
def is_private_ip(ip_str: str) -> bool:
    """Return True if *ip_str* is a private/reserved address.

    Fails closed: a string that does not parse as an IP address at all
    is also reported as private so the caller blocks it.
    """
    try:
        parsed = ipaddress.ip_address(ip_str)
    except ValueError:
        # Malformed input is treated the same as a blocked address.
        return True  # Invalid IP treated as blocked

    for blocked in _BLOCKED_NETWORKS:
        if parsed in blocked:
            return True
    return False
|
|
|
|
|
|
def validate_url(url: str, policy: SSRFPolicy = DEFAULT_POLICY) -> str:
    """Validate a URL against SSRF policy.

    Checks, in order: URL scheme, presence of a hostname, the optional
    host whitelist, DNS resolvability, and finally that no resolved
    address falls inside a blocked private/reserved network.

    Returns the validated URL string.
    Raises SSRFError if the URL is blocked.
    """
    parts = urlparse(url)

    # Reject anything but the whitelisted schemes (no file://, ftp://, ...).
    if parts.scheme not in policy.allowed_schemes:
        allowed = ", ".join(sorted(policy.allowed_schemes))
        raise SSRFError(
            f"URL scheme '{parts.scheme}' is not allowed. "
            f"Allowed: {allowed}"
        )

    host = parts.hostname
    if not host:
        raise SSRFError("URL has no hostname")

    # When a whitelist is configured, only listed hosts may be contacted.
    if policy.allowed_hosts is not None and host not in policy.allowed_hosts:
        raise SSRFError(f"Host '{host}' is not in the allowed hosts list")

    # Resolve DNS up front so every candidate address can be inspected
    # before any connection is attempted.
    addresses = resolve_hostname(host)
    if not addresses:
        raise SSRFError(f"Could not resolve hostname '{host}'")

    # Every resolved address (IPv4 and IPv6) must be public.
    for address in addresses:
        if is_private_ip(address):
            raise SSRFError(
                f"Host '{host}' resolves to private/reserved IP {address}. "
                "Request blocked for SSRF protection."
            )

    return url
|
|
|
|
|
|
def resolve_hostname(hostname: str) -> list[str]:
    """Resolve *hostname* to a deduplicated list of IP address strings.

    Uses getaddrinfo over both address families (AF_UNSPEC) so callers
    can check every IPv4 and IPv6 address the host resolves to.

    Returns an empty list on resolution failure. This includes hostnames
    that cannot be IDNA-encoded (e.g. over-long labels), which
    getaddrinfo reports as UnicodeError rather than gaierror; failing
    closed here is safe because callers treat an empty result as blocked.
    """
    try:
        results = socket.getaddrinfo(hostname, None, socket.AF_UNSPEC, socket.SOCK_STREAM)
    except (socket.gaierror, UnicodeError):
        return []
    # result[4] is the sockaddr tuple; element 0 is the IP string.
    return list({result[4][0] for result in results})
|
|
|
|
|
|
async def safe_fetch(
    url: str,
    *,
    policy: SSRFPolicy = DEFAULT_POLICY,
    method: str = "GET",
    headers: dict[str, str] | None = None,
) -> httpx.Response:
    """Fetch a URL with SSRF protection.

    Validates the URL (scheme, host whitelist, DNS resolution, blocked
    networks) before the request, follows redirects manually so that
    each hop's target is re-validated, and caps the hop count at
    ``policy.max_redirects``.

    NOTE(review): validation resolves DNS separately from the connection
    httpx later makes, so a racing DNS answer could still rebind between
    check and use (TOCTOU). The code does not verify the actually
    connected peer address -- confirm whether that guarantee is needed.

    Raises SSRFError if the URL, any redirect target, or the redirect
    count violates the policy; httpx transport errors propagate.
    """
    validate_url(url, policy)

    # Make the request with redirect following disabled so we can check each hop
    async with httpx.AsyncClient(
        follow_redirects=False,
        timeout=httpx.Timeout(policy.timeout_seconds),
    ) as client:
        current_url = url
        # range(max_redirects + 1): the initial request plus up to
        # max_redirects redirect hops.
        for _redirect_count in range(policy.max_redirects + 1):
            response = await client.request(
                method,
                current_url,
                headers=headers,
            )

            if response.is_redirect:
                redirect_url = str(response.next_request.url) if response.next_request else None
                if not redirect_url:
                    raise SSRFError("Redirect with no target URL")
                # Validate the redirect target
                validate_url(redirect_url, policy)
                current_url = redirect_url
                continue

            # Non-redirect response: done.
            return response

        # Loop exhausted: every hop was a redirect.
        raise SSRFError(
            f"Too many redirects (max {policy.max_redirects}). "
            "Possible redirect loop or evasion attempt."
        )
|
|
|
|
|
|
async def safe_fetch_text(
    url: str,
    *,
    policy: SSRFPolicy = DEFAULT_POLICY,
    headers: dict[str, str] | None = None,
) -> str:
    """Fetch *url* with SSRF protection and return the body as text.

    Thin wrapper over safe_fetch() that additionally raises
    httpx.HTTPStatusError for non-2xx responses.
    """
    result = await safe_fetch(url, policy=policy, headers=headers)
    result.raise_for_status()
    return result.text
|