"""Tests for SSRF protection module.""" from __future__ import annotations from unittest.mock import patch import pytest from app.openapi.ssrf import ( SSRFError, SSRFPolicy, is_private_ip, safe_fetch, safe_fetch_text, validate_url, ) pytestmark = pytest.mark.unit # --- is_private_ip --- class TestIsPrivateIP: """Tests for private IP detection.""" @pytest.mark.parametrize( "ip", [ "10.0.0.1", "10.255.255.255", "172.16.0.1", "172.31.255.255", "192.168.0.1", "192.168.1.100", "127.0.0.1", "127.0.0.2", "169.254.1.1", "169.254.169.254", # AWS metadata "0.0.0.0", "::1", "fe80::1", "fc00::1", ], ) def test_private_ips_detected(self, ip: str) -> None: assert is_private_ip(ip) is True @pytest.mark.parametrize( "ip", [ "8.8.8.8", "1.1.1.1", "203.0.113.1", "93.184.216.34", "2001:4860:4860::8888", ], ) def test_public_ips_allowed(self, ip: str) -> None: assert is_private_ip(ip) is False def test_invalid_ip_treated_as_blocked(self) -> None: assert is_private_ip("not-an-ip") is True def test_empty_string_blocked(self) -> None: assert is_private_ip("") is True # --- validate_url --- class TestValidateURL: """Tests for URL validation.""" def _mock_resolve(self, ips: list[str]): return patch("app.openapi.ssrf.resolve_hostname", return_value=ips) def test_valid_https_url(self) -> None: with self._mock_resolve(["93.184.216.34"]): result = validate_url("https://example.com/api/v1/spec.json") assert result == "https://example.com/api/v1/spec.json" def test_valid_http_url(self) -> None: with self._mock_resolve(["93.184.216.34"]): result = validate_url("http://example.com/spec.yaml") assert result == "http://example.com/spec.yaml" def test_rejects_ftp_scheme(self) -> None: with pytest.raises(SSRFError, match="scheme.*not allowed"): validate_url("ftp://example.com/spec") def test_rejects_file_scheme(self) -> None: with pytest.raises(SSRFError, match="scheme.*not allowed"): validate_url("file:///etc/passwd") def test_rejects_no_hostname(self) -> None: with pytest.raises(SSRFError, match="no hostname"): validate_url("https://") def test_rejects_private_ip_literal(self) -> None: with ( self._mock_resolve(["127.0.0.1"]), pytest.raises(SSRFError, match="private/reserved IP"), ): validate_url("http://127.0.0.1/api") def test_rejects_localhost(self) -> None: with ( self._mock_resolve(["127.0.0.1"]), pytest.raises(SSRFError, match="private/reserved IP"), ): validate_url("http://localhost/api") def test_rejects_10_network(self) -> None: with ( self._mock_resolve(["10.0.0.5"]), pytest.raises(SSRFError, match="private/reserved IP"), ): validate_url("http://internal.corp/api") def test_rejects_172_16_network(self) -> None: with ( self._mock_resolve(["172.16.0.1"]), pytest.raises(SSRFError, match="private/reserved IP"), ): validate_url("http://internal.corp/api") def test_rejects_192_168_network(self) -> None: with ( self._mock_resolve(["192.168.1.1"]), pytest.raises(SSRFError, match="private/reserved IP"), ): validate_url("http://internal.corp/api") def test_rejects_metadata_ip(self) -> None: """Block cloud metadata endpoint (169.254.169.254).""" with ( self._mock_resolve(["169.254.169.254"]), pytest.raises(SSRFError, match="private/reserved IP"), ): validate_url("http://169.254.169.254/latest/meta-data/") def test_rejects_ipv6_loopback(self) -> None: with ( self._mock_resolve(["::1"]), pytest.raises(SSRFError, match="private/reserved IP"), ): validate_url("http://[::1]/api") def test_rejects_unresolvable_host(self) -> None: with self._mock_resolve([]), pytest.raises(SSRFError, match="Could not resolve"): validate_url("http://nonexistent.invalid/api") def test_allowed_hosts_whitelist(self) -> None: policy = SSRFPolicy(allowed_hosts=frozenset({"api.example.com"})) with self._mock_resolve(["93.184.216.34"]): validate_url("https://api.example.com/spec", policy=policy) def test_allowed_hosts_rejects_unlisted(self) -> None: policy = SSRFPolicy(allowed_hosts=frozenset({"api.example.com"})) with pytest.raises(SSRFError, match="not in the allowed hosts"): validate_url("https://evil.com/spec", policy=policy) def test_dns_rebinding_detection(self) -> None: """A hostname that resolves to both public and private IPs should be blocked.""" with ( self._mock_resolve(["93.184.216.34", "127.0.0.1"]), pytest.raises(SSRFError, match="private/reserved IP"), ): validate_url("http://evil-rebind.com/api") # --- safe_fetch --- class TestSafeFetch: """Tests for safe HTTP fetching.""" @pytest.fixture def _mock_public_dns(self): with patch("app.openapi.ssrf.resolve_hostname", return_value=["93.184.216.34"]): yield @pytest.mark.usefixtures("_mock_public_dns") async def test_fetch_success(self, httpx_mock) -> None: httpx_mock.add_response(url="https://example.com/spec.json", text='{"openapi":"3.0.0"}') response = await safe_fetch("https://example.com/spec.json") assert response.status_code == 200 @pytest.mark.usefixtures("_mock_public_dns") async def test_fetch_text_success(self, httpx_mock) -> None: httpx_mock.add_response(url="https://example.com/spec.json", text='{"openapi":"3.0.0"}') text = await safe_fetch_text("https://example.com/spec.json") assert "openapi" in text async def test_fetch_blocks_private_ip(self) -> None: with ( patch("app.openapi.ssrf.resolve_hostname", return_value=["10.0.0.1"]), pytest.raises(SSRFError, match="private/reserved"), ): await safe_fetch("http://internal.corp/api") async def test_redirect_to_private_ip_blocked(self, httpx_mock) -> None: httpx_mock.add_response( url="https://example.com/spec.json", status_code=302, headers={"Location": "http://evil-redirect.com/steal"}, ) call_count = 0 def _resolve_side_effect(hostname: str) -> list[str]: nonlocal call_count call_count += 1 if hostname == "evil-redirect.com": return ["127.0.0.1"] return ["93.184.216.34"] with ( patch("app.openapi.ssrf.resolve_hostname", side_effect=_resolve_side_effect), pytest.raises(SSRFError, match="private/reserved"), ): await safe_fetch("https://example.com/spec.json") @pytest.mark.usefixtures("_mock_public_dns") async def test_too_many_redirects(self, httpx_mock) -> None: # Create a redirect chain longer than max_redirects policy = SSRFPolicy(max_redirects=2) for i in range(3): httpx_mock.add_response( url=f"https://example.com/r{i}", status_code=302, headers={"Location": f"https://example.com/r{i + 1}"}, ) with pytest.raises(SSRFError, match="Too many redirects"): await safe_fetch("https://example.com/r0", policy=policy)