Files
smart-support/backend/app/openapi/ssrf.py
Yaojia Wang a54eb224e0 feat: complete phase 3 -- OpenAPI auto-discovery, SSRF protection, tool generation
- SSRF protection: private IP blocking, DNS rebinding defense, redirect validation
- OpenAPI fetcher with SSRF guard, JSON/YAML auto-detection, 10MB limit
- Structural spec validator (3.0.x/3.1.x)
- Endpoint parser with $ref resolution, auto-generated operation IDs
- Heuristic + LLM endpoint classifier with Protocol interface
- Review API at /api/openapi (import, job status, classification CRUD, approve)
- @tool code generator + Agent YAML generator
- Import orchestrator (fetch -> validate -> parse -> classify pipeline)
- 125 new tests, 322 total passing, 93.23% coverage
2026-03-31 00:10:44 +02:00

162 lines
4.8 KiB
Python

"""SSRF protection module.
Validates URLs before making external HTTP requests.
Blocks private IPs, loopback addresses, link-local, and DNS rebinding attacks.
"""
from __future__ import annotations

import asyncio
import ipaddress
import socket
from dataclasses import dataclass
from urllib.parse import urlparse

import httpx
class SSRFError(Exception):
    """Signal that a request target was rejected by this module's SSRF checks.

    Raised for a blocked URL (scheme, host whitelist, unresolvable host or a
    private/reserved address), for a redirect without a target, and when the
    redirect limit is exceeded.
    """

    pass
@dataclass(frozen=True)
class SSRFPolicy:
    """Immutable settings for the SSRF guard.

    Attributes:
        allowed_schemes: URL schemes that may be requested.
        allowed_hosts: Optional hostname whitelist. ``None`` means any host is
            acceptable as long as all of its resolved addresses are public.
        max_redirects: Maximum number of redirect hops that will be followed.
        timeout_seconds: Timeout handed to the HTTP client for each request.
    """

    # Field order is part of the positional constructor signature -- keep it.
    allowed_schemes: frozenset[str] = frozenset(("http", "https"))
    allowed_hosts: frozenset[str] | None = None  # None disables the whitelist
    max_redirects: int = 5
    timeout_seconds: float = 30.0
# Any address inside one of these networks is reported as blocked by
# ``is_private_ip``. Covers private, loopback, link-local, shared/CGNAT,
# multicast and reserved space for both IP versions.
_BLOCKED_NETWORKS = [
    # IPv4
    ipaddress.ip_network("0.0.0.0/8"),        # "this network", incl. 0.0.0.0
    ipaddress.ip_network("10.0.0.0/8"),       # RFC 1918 private
    ipaddress.ip_network("100.64.0.0/10"),    # RFC 6598 shared / carrier-grade NAT
    ipaddress.ip_network("127.0.0.0/8"),      # loopback
    ipaddress.ip_network("169.254.0.0/16"),   # link-local (cloud metadata endpoints)
    ipaddress.ip_network("172.16.0.0/12"),    # RFC 1918 private
    ipaddress.ip_network("192.0.0.0/24"),     # IETF protocol assignments
    ipaddress.ip_network("192.168.0.0/16"),   # RFC 1918 private
    ipaddress.ip_network("198.18.0.0/15"),    # benchmarking
    ipaddress.ip_network("224.0.0.0/4"),      # multicast
    ipaddress.ip_network("240.0.0.0/4"),      # reserved, incl. 255.255.255.255
    # IPv6
    ipaddress.ip_network("::/128"),           # unspecified
    ipaddress.ip_network("::1/128"),          # loopback
    # IPv4-mapped: a dual-stack socket connecting to ::ffff:127.0.0.1 reaches
    # IPv4 loopback, and IPv6 addresses never match the IPv4 networks above.
    ipaddress.ip_network("::ffff:0:0/96"),
    ipaddress.ip_network("fc00::/7"),         # unique local
    ipaddress.ip_network("fe80::/10"),        # link-local
    ipaddress.ip_network("ff00::/8"),         # multicast
]

# Policy used when callers do not pass one explicitly.
DEFAULT_POLICY = SSRFPolicy()
def is_private_ip(ip_str: str) -> bool:
    """Return True if *ip_str* must not be contacted.

    An address is blocked when any of these holds:
    - it does not parse as an IP address (fail closed);
    - ``ipaddress`` does not classify it as globally reachable (``is_global``);
    - it is multicast;
    - it lies inside any network in ``_BLOCKED_NETWORKS``.

    IPv4-mapped IPv6 addresses (``::ffff:a.b.c.d``) are unwrapped to their
    IPv4 form first. Comparing an ``IPv6Address`` with an IPv4 network never
    matches, so without unwrapping ``::ffff:127.0.0.1`` would be treated as
    public even though a dual-stack socket reaches IPv4 loopback through it.

    Args:
        ip_str: Textual IPv4/IPv6 address, as returned by ``getaddrinfo``.

    Returns:
        True if the address is private/reserved/invalid, False if public.
    """
    try:
        addr = ipaddress.ip_address(ip_str)
    except ValueError:
        return True  # Invalid IP treated as blocked
    if isinstance(addr, ipaddress.IPv6Address) and addr.ipv4_mapped is not None:
        addr = addr.ipv4_mapped
    # is_global is False for private, loopback, link-local, shared (100.64/10)
    # and reserved space; multicast can still be "global", so check it too.
    if not addr.is_global or addr.is_multicast:
        return True
    return any(addr in network for network in _BLOCKED_NETWORKS)
def validate_url(url: str, policy: SSRFPolicy = DEFAULT_POLICY) -> str:
    """Validate a URL against an SSRF policy before any request is made.

    Checks, in order:
    1. the URL parses and its scheme is in ``policy.allowed_schemes``;
    2. it has a hostname;
    3. if ``policy.allowed_hosts`` is set, the hostname is in it. Entries are
       compared case-insensitively because ``urlparse`` already lower-cases
       the hostname, so mixed-case whitelist entries would otherwise never
       match;
    4. the hostname resolves, and none of its addresses is private/reserved
       according to ``is_private_ip``.

    Args:
        url: The URL to check.
        policy: The policy to enforce; defaults to ``DEFAULT_POLICY``.

    Returns:
        The validated URL string, unchanged.

    Raises:
        SSRFError: If the URL is malformed or any check fails.
    """
    try:
        parsed = urlparse(url)
    except ValueError as exc:
        # e.g. an unbalanced IPv6 bracket ("http://[::1/"). Report it as a
        # validation failure instead of leaking a bare ValueError to callers.
        raise SSRFError(f"Malformed URL: {exc}") from exc

    # Check scheme
    if parsed.scheme not in policy.allowed_schemes:
        raise SSRFError(
            f"URL scheme '{parsed.scheme}' is not allowed. "
            f"Allowed: {', '.join(sorted(policy.allowed_schemes))}"
        )

    # Check hostname exists
    hostname = parsed.hostname
    if not hostname:
        raise SSRFError("URL has no hostname")

    # Check allowed hosts whitelist (hostname is already lower-case)
    if policy.allowed_hosts is not None:
        allowed = {host.lower() for host in policy.allowed_hosts}
        if hostname not in allowed:
            raise SSRFError(f"Host '{hostname}' is not in the allowed hosts list")

    # DNS resolution -- resolve before making any request
    resolved_ips = resolve_hostname(hostname)
    if not resolved_ips:
        raise SSRFError(f"Could not resolve hostname '{hostname}'")

    # A single blocked address is enough to reject: the client may use any.
    for ip_str in resolved_ips:
        if is_private_ip(ip_str):
            raise SSRFError(
                f"Host '{hostname}' resolves to private/reserved IP {ip_str}. "
                "Request blocked for SSRF protection."
            )
    return url
def resolve_hostname(hostname: str) -> list[str]:
    """Resolve *hostname* to its distinct IP address strings via DNS.

    Queries ``socket.getaddrinfo`` for TCP with both address families and
    returns each address once, in the order the resolver reported them. The
    order is therefore deterministic rather than set-iteration order.

    Returns an empty list when the name cannot be resolved. This includes
    names that Python's IDNA encoding rejects before any lookup happens (for
    example ``"a..b"`` or a label longer than 63 characters). Those raise
    ``UnicodeError`` rather than ``socket.gaierror``, so callers see
    "unresolvable" instead of an unexpected exception.

    Args:
        hostname: Host name or IP literal to resolve.

    Returns:
        De-duplicated list of address strings; empty if resolution fails.
    """
    try:
        results = socket.getaddrinfo(hostname, None, socket.AF_UNSPEC, socket.SOCK_STREAM)
    except (OSError, ValueError):
        # OSError covers socket.gaierror; ValueError covers UnicodeError
        # raised while IDNA-encoding an invalid host name.
        return []
    return list(dict.fromkeys(info[4][0] for info in results))
def _check_connected_peer(response: httpx.Response) -> None:
    """Raise SSRFError if *response* came over a connection to a blocked IP.

    ``validate_url`` resolves the hostname itself, but httpx resolves it again
    when it opens the connection. A hostile DNS server can answer with a
    public address the first time and a private one the second (DNS
    rebinding). This inspects the peer of the socket that was actually used.

    The peer address comes from the httpcore ``network_stream`` response
    extension. When the transport does not expose it (e.g. a mocked transport
    in tests), there is nothing to inspect and the check is skipped.

    NOTE(review): if httpx picks up an HTTP(S) proxy from the environment
    (``trust_env`` is left at its default), the peer is the proxy itself, and
    a proxy on a private address would be rejected. Confirm deployments do
    not route egress through such a proxy.
    """
    network_stream = response.extensions.get("network_stream")
    if network_stream is None:
        return
    server_addr = network_stream.get_extra_info("server_addr")
    if not server_addr:
        return
    peer_ip = server_addr[0] if isinstance(server_addr, tuple) else server_addr
    if is_private_ip(str(peer_ip)):
        raise SSRFError(
            f"Connection for '{response.url}' was made to private/reserved IP "
            f"{peer_ip}. Request blocked for SSRF protection."
        )


async def safe_fetch(
    url: str,
    *,
    policy: SSRFPolicy = DEFAULT_POLICY,
    method: str = "GET",
    headers: dict[str, str] | None = None,
) -> httpx.Response:
    """Fetch a URL with SSRF protection and validated, manual redirects.

    1. ``validate_url`` checks the initial URL (scheme, whitelist, DNS ->
       public IPs). It does blocking DNS resolution, so it runs in the default
       executor instead of on the event loop.
    2. Each hop is sent with automatic redirects disabled. Before the body is
       read, the IP of the actual connection is checked
       (``_check_connected_peer``) to guard against DNS rebinding between
       validation and connect. The request has already been sent at that
       point; the check prevents the reply from being read and returned.
    3. Redirect targets come from ``response.next_request``, so httpx's own
       redirect rules apply: method rewriting for 301/302/303, and dropping
       ``Authorization`` on cross-origin hops. Re-sending the caller's
       original headers to an arbitrary redirect target would leak
       credentials. Every target is validated before it is requested.

    Args:
        url: Initial URL to fetch.
        policy: SSRF policy to enforce on every hop.
        method: HTTP method of the initial request.
        headers: Optional headers for the initial request.

    Returns:
        The final non-redirect response, with its body already read.

    Raises:
        SSRFError: If a URL or connected peer is blocked, a redirect has no
            target, or more than ``policy.max_redirects`` redirects occur.
    """
    loop = asyncio.get_running_loop()
    # validate_url performs blocking DNS resolution; keep it off the event loop.
    await loop.run_in_executor(None, validate_url, url, policy)

    async with httpx.AsyncClient(
        follow_redirects=False,
        timeout=httpx.Timeout(policy.timeout_seconds),
    ) as client:
        request = client.build_request(method, url, headers=headers)
        # One initial request plus at most ``max_redirects`` follow-ups.
        for _hop in range(policy.max_redirects + 1):
            response = await client.send(request, stream=True)
            try:
                # Inspect the real peer before any of the body is consumed.
                _check_connected_peer(response)
                if not response.is_redirect:
                    await response.aread()
                    return response
                next_request = response.next_request
            finally:
                await response.aclose()

            if next_request is None:
                raise SSRFError("Redirect with no target URL")
            # Validate the redirect target before requesting it.
            await loop.run_in_executor(
                None, validate_url, str(next_request.url), policy
            )
            request = next_request

    raise SSRFError(
        f"Too many redirects (max {policy.max_redirects}). "
        "Possible redirect loop or evasion attempt."
    )
async def safe_fetch_text(
    url: str,
    *,
    policy: SSRFPolicy = DEFAULT_POLICY,
    headers: dict[str, str] | None = None,
) -> str:
    """Return the decoded body of *url*, fetched through ``safe_fetch``.

    Args:
        url: Target URL; validated against *policy* before any request.
        policy: SSRF policy forwarded to ``safe_fetch``.
        headers: Optional request headers forwarded to ``safe_fetch``.

    Returns:
        ``response.text`` of the final (non-redirect) response.

    Raises:
        SSRFError: Propagated from ``safe_fetch``.
        httpx.HTTPStatusError: Raised by ``raise_for_status()`` for an error
            status on the final response.
    """
    final_response = await safe_fetch(url, headers=headers, policy=policy)
    # Turn error statuses into exceptions instead of returning an error page.
    final_response.raise_for_status()
    text_body = final_response.text
    return text_body