Files
smart-support/backend/app/openapi/ssrf.py
Yaojia Wang a2f750269d fix: address critical security and code review findings in Phase 3
- Wire ImportOrchestrator into review_api start_import via BackgroundTasks
- Sanitize docstrings in generated tool code to prevent code injection
- Add Literal["read", "write"] validation for access_type
- Add regex validation for agent_group
- Validate URL scheme (http/https only) in ImportRequest
- Validate LLM output fields (clamp confidence, validate access_type)
- Use dataclasses.replace instead of manual reconstruction in importer
- Expand SSRF blocked networks (Carrier-Grade NAT, IPv4-mapped IPv6, etc.)
- Make _BLOCKED_NETWORKS immutable tuple
- Use yaml.safe_dump instead of yaml.dump
- Fix _to_snake_case for empty strings and Python keywords
2026-03-31 00:28:28 +02:00

168 lines
5.2 KiB
Python

"""SSRF protection module.
Validates URLs before making external HTTP requests.
Blocks private IPs, loopback addresses, link-local, and DNS rebinding attacks.
"""
from __future__ import annotations
import ipaddress
import socket
from dataclasses import dataclass
from urllib.parse import urlparse
import httpx
class SSRFError(Exception):
    """Signal that a URL was rejected by this module's SSRF checks."""
@dataclass(frozen=True)
class SSRFPolicy:
    """Immutable settings consulted by the URL validation and fetch helpers."""

    # URL schemes accepted by validate_url; any other scheme is rejected.
    allowed_schemes: frozenset[str] = frozenset({"http", "https"})
    # Optional hostname allow-list. None means no allow-list is applied, so
    # any host that passes the IP checks is accepted.
    allowed_hosts: frozenset[str] | None = None
    # Upper bound on the number of redirect hops safe_fetch will follow.
    max_redirects: int = 5
    # Timeout in seconds, handed to the HTTP client as httpx.Timeout(...).
    timeout_seconds: float = 30.0
# Address ranges that must never be the target of an outbound request.
# Stored as a tuple so the collection itself cannot be mutated at runtime.
_BLOCKED_NETWORKS = tuple(
    ipaddress.ip_network(cidr)
    for cidr in (
        # IPv4
        "10.0.0.0/8",  # Private
        "172.16.0.0/12",  # Private
        "192.168.0.0/16",  # Private
        "127.0.0.0/8",  # Loopback
        "169.254.0.0/16",  # Link-local
        "0.0.0.0/32",  # Unspecified
        "100.64.0.0/10",  # Carrier-Grade NAT
        "198.18.0.0/15",  # Benchmarking
        "240.0.0.0/4",  # Reserved
        "255.255.255.255/32",  # Broadcast
        # IPv6
        "::1/128",  # Loopback
        "fe80::/10",  # Link-local
        "fc00::/7",  # Unique local
        "::/128",  # Unspecified
        "::ffff:0:0/96",  # IPv4-mapped IPv6
        "2001:db8::/32",  # Documentation
    )
)
# Shared policy instance built from SSRFPolicy's field defaults; used as the
# default `policy` argument of validate_url, safe_fetch and safe_fetch_text.
DEFAULT_POLICY = SSRFPolicy()
def is_private_ip(ip_str: str) -> bool:
    """Return True when *ip_str* must not be contacted.

    Args:
        ip_str: Textual IPv4 or IPv6 address.

    Returns:
        True if the string is not a parseable IP address (fail closed), if the
        standard library classifies the address as private, loopback,
        link-local, multicast or unspecified, or if it falls inside one of the
        networks in ``_BLOCKED_NETWORKS``; False otherwise.
    """
    try:
        addr = ipaddress.ip_address(ip_str)
    except ValueError:
        return True  # Invalid IP treated as blocked

    # The hand-maintained table does not list every special-purpose range
    # (multicast, for instance), so consult the stdlib classification first.
    if (
        addr.is_private
        or addr.is_loopback
        or addr.is_link_local
        or addr.is_multicast
        or addr.is_unspecified
    ):
        return True

    return any(addr in network for network in _BLOCKED_NETWORKS)
def validate_url(url: str, policy: SSRFPolicy = DEFAULT_POLICY) -> str:
    """Validate a URL against an SSRF policy.

    Checks, in order: the URL parses, its scheme is allowed, it has a
    hostname, the hostname is on the allow-list (when one is configured), the
    hostname resolves, and none of the resolved addresses is private/reserved.

    Args:
        url: The URL to check.
        policy: The policy to enforce; defaults to ``DEFAULT_POLICY``.

    Returns:
        The original ``url`` string, unchanged, when every check passes.

    Raises:
        SSRFError: If the URL is malformed or fails any of the checks above.
    """
    # urlparse raises ValueError on malformed input such as an unterminated
    # IPv6 literal ("http://[::1"). Report that as a policy failure so callers
    # only ever need to handle SSRFError.
    try:
        parsed = urlparse(url)
        hostname = parsed.hostname
    except ValueError as exc:
        raise SSRFError(f"URL could not be parsed: {exc}") from exc

    # Check scheme
    if parsed.scheme not in policy.allowed_schemes:
        raise SSRFError(
            f"URL scheme '{parsed.scheme}' is not allowed. "
            f"Allowed: {', '.join(sorted(policy.allowed_schemes))}"
        )

    # Check hostname exists
    if not hostname:
        raise SSRFError("URL has no hostname")

    # Check allowed hosts whitelist
    if policy.allowed_hosts is not None and hostname not in policy.allowed_hosts:
        raise SSRFError(f"Host '{hostname}' is not in the allowed hosts list")

    # DNS resolution -- resolve before making any request
    resolved_ips = resolve_hostname(hostname)
    if not resolved_ips:
        raise SSRFError(f"Could not resolve hostname '{hostname}'")

    # A single private/reserved address among the answers is enough to reject
    # the host, since the client may connect to any of them.
    for ip_str in resolved_ips:
        if is_private_ip(ip_str):
            raise SSRFError(
                f"Host '{hostname}' resolves to private/reserved IP {ip_str}. "
                "Request blocked for SSRF protection."
            )

    return url
def resolve_hostname(hostname: str) -> list[str]:
    """Resolve *hostname* to its distinct IP addresses.

    Args:
        hostname: Host name or IP literal to look up.

    Returns:
        The unique address strings in the order the resolver reported them,
        or an empty list when the name cannot be resolved.
    """
    try:
        results = socket.getaddrinfo(hostname, None, socket.AF_UNSPEC, socket.SOCK_STREAM)
    except (socket.gaierror, UnicodeError):
        # gaierror: the lookup failed. UnicodeError: the name cannot be
        # IDNA-encoded (e.g. an empty or over-long label), so it can never
        # resolve. Both mean "no addresses".
        return []
    # dict.fromkeys de-duplicates while keeping resolver order, so the result
    # is deterministic for a given answer set.
    return list(dict.fromkeys(info[4][0] for info in results))
async def safe_fetch(
    url: str,
    *,
    policy: SSRFPolicy = DEFAULT_POLICY,
    method: str = "GET",
    headers: dict[str, str] | None = None,
) -> httpx.Response:
    """Fetch a URL, validating it and every redirect target against *policy*.

    Automatic redirect following is disabled; each redirect target is passed
    through ``validate_url`` before it is requested.

    NOTE(review): validation resolves the hostname, and the HTTP client then
    performs its own lookup when connecting. An address that changes between
    those two lookups (DNS rebinding) is not detected here; closing that gap
    would require connecting to the already-validated IP.

    Args:
        url: The URL to fetch.
        policy: SSRF policy to enforce; also supplies the timeout and the
            redirect limit.
        method: HTTP method for the initial request.
        headers: Optional headers for the initial request.

    Returns:
        The first non-redirect response.

    Raises:
        SSRFError: If the URL or a redirect target fails validation, a
            redirect carries no usable target, or more than
            ``policy.max_redirects`` redirects are encountered.
    """
    import asyncio  # local import: only this coroutine needs it

    loop = asyncio.get_running_loop()

    async def _validate(target: str) -> None:
        # validate_url performs blocking DNS resolution; run it in the default
        # executor so other tasks on the event loop are not stalled.
        await loop.run_in_executor(None, validate_url, target, policy)

    await _validate(url)

    # Redirect following is disabled so each hop can be checked before use.
    async with httpx.AsyncClient(
        follow_redirects=False,
        timeout=httpx.Timeout(policy.timeout_seconds),
    ) as client:
        request = client.build_request(method, url, headers=headers)
        for _redirect_count in range(policy.max_redirects + 1):
            response = await client.send(request)
            if not response.is_redirect:
                return response

            next_request = response.next_request
            if next_request is None:
                raise SSRFError("Redirect with no target URL")

            # Validate the redirect target before contacting it.
            await _validate(str(next_request.url))
            # Send the request httpx prepared for this hop so its redirect
            # rules (method rewriting, dropping credentials on cross-origin
            # hops) apply, rather than replaying the caller's method and
            # headers verbatim against whatever host the redirect names.
            request = next_request

    raise SSRFError(
        f"Too many redirects (max {policy.max_redirects}). "
        "Possible redirect loop or evasion attempt."
    )
async def safe_fetch_text(
    url: str,
    *,
    policy: SSRFPolicy = DEFAULT_POLICY,
    headers: dict[str, str] | None = None,
) -> str:
    """Return the body of *url* as text, fetched through ``safe_fetch``.

    Propagates whatever ``safe_fetch`` raises, plus the exception raised by
    ``raise_for_status()`` when the final response has an error status.
    """
    fetched = await safe_fetch(url, policy=policy, headers=headers)
    # Surface HTTP error statuses instead of returning an error page's body.
    fetched.raise_for_status()
    return fetched.text