Files
smart-support/backend/tests/unit/test_ssrf.py
Yaojia Wang a54eb224e0 feat: complete phase 3 -- OpenAPI auto-discovery, SSRF protection, tool generation
- SSRF protection: private IP blocking, DNS rebinding defense, redirect validation
- OpenAPI fetcher with SSRF guard, JSON/YAML auto-detection, 10MB limit
- Structural spec validator (3.0.x/3.1.x)
- Endpoint parser with $ref resolution, auto-generated operation IDs
- Heuristic + LLM endpoint classifier with Protocol interface
- Review API at /api/openapi (import, job status, classification CRUD, approve)
- @tool code generator + Agent YAML generator
- Import orchestrator (fetch -> validate -> parse -> classify pipeline)
- 125 new tests, 322 total passing, 93.23% coverage
2026-03-31 00:10:44 +02:00

237 lines
7.8 KiB
Python

"""Tests for SSRF protection module."""
from __future__ import annotations
from unittest.mock import patch
import pytest
from app.openapi.ssrf import (
SSRFError,
SSRFPolicy,
is_private_ip,
safe_fetch,
safe_fetch_text,
validate_url,
)
# Module-level pytest marker: tags every test collected from this file as `unit`.
pytestmark = pytest.mark.unit
# --- is_private_ip ---
class TestIsPrivateIP:
    """Verify that ``is_private_ip`` classifies address strings correctly.

    Anything the function cannot parse as an IP address is expected to be
    treated as private (i.e. blocked).
    """

    @pytest.mark.parametrize(
        "ip",
        [
            # 10.0.0.0/8
            "10.0.0.1",
            "10.255.255.255",
            # 172.16.0.0/12
            "172.16.0.1",
            "172.31.255.255",
            # 192.168.0.0/16
            "192.168.0.1",
            "192.168.1.100",
            # IPv4 loopback
            "127.0.0.1",
            "127.0.0.2",
            # IPv4 link-local
            "169.254.1.1",
            "169.254.169.254",  # AWS metadata
            # Unspecified address
            "0.0.0.0",
            # IPv6 loopback, link-local and unique-local
            "::1",
            "fe80::1",
            "fc00::1",
        ],
    )
    def test_private_ips_detected(self, ip: str) -> None:
        """Every listed private/reserved address must be reported as private."""
        verdict = is_private_ip(ip)
        assert verdict is True

    @pytest.mark.parametrize(
        "ip",
        [
            "8.8.8.8",
            "1.1.1.1",
            "203.0.113.1",
            "93.184.216.34",
            "2001:4860:4860::8888",
        ],
    )
    def test_public_ips_allowed(self, ip: str) -> None:
        """Every listed address must be reported as not private."""
        verdict = is_private_ip(ip)
        assert verdict is False

    def test_invalid_ip_treated_as_blocked(self) -> None:
        """A string that is not an IP address is reported as private."""
        verdict = is_private_ip("not-an-ip")
        assert verdict is True

    def test_empty_string_blocked(self) -> None:
        """The empty string is reported as private."""
        verdict = is_private_ip("")
        assert verdict is True
# --- validate_url ---
class TestValidateURL:
    """Tests for ``validate_url``: scheme, hostname, resolved-IP and allow-list checks.

    DNS resolution is replaced by patching ``app.openapi.ssrf.resolve_hostname``
    so that each test controls exactly which addresses a hostname maps to.
    """

    def _mock_resolve(self, ips: list[str]):
        """Return a patcher that makes ``resolve_hostname`` return *ips*.

        The returned object is meant to be used as a context manager.
        """
        return patch("app.openapi.ssrf.resolve_hostname", return_value=ips)

    def test_valid_https_url(self) -> None:
        """An https URL resolving to a public IP is returned unchanged."""
        with self._mock_resolve(["93.184.216.34"]):
            result = validate_url("https://example.com/api/v1/spec.json")
        assert result == "https://example.com/api/v1/spec.json"

    def test_valid_http_url(self) -> None:
        """An http URL resolving to a public IP is returned unchanged."""
        with self._mock_resolve(["93.184.216.34"]):
            result = validate_url("http://example.com/spec.yaml")
        assert result == "http://example.com/spec.yaml"

    def test_rejects_ftp_scheme(self) -> None:
        """The ftp scheme is rejected."""
        with pytest.raises(SSRFError, match="scheme.*not allowed"):
            validate_url("ftp://example.com/spec")

    def test_rejects_file_scheme(self) -> None:
        """The file scheme is rejected."""
        with pytest.raises(SSRFError, match="scheme.*not allowed"):
            validate_url("file:///etc/passwd")

    def test_rejects_no_hostname(self) -> None:
        """A URL without a hostname is rejected."""
        with pytest.raises(SSRFError, match="no hostname"):
            validate_url("https://")

    def test_rejects_private_ip_literal(self) -> None:
        """A loopback IP literal in the URL is rejected."""
        with (
            self._mock_resolve(["127.0.0.1"]),
            pytest.raises(SSRFError, match="private/reserved IP"),
        ):
            validate_url("http://127.0.0.1/api")

    def test_rejects_localhost(self) -> None:
        """A hostname resolving to loopback is rejected."""
        with (
            self._mock_resolve(["127.0.0.1"]),
            pytest.raises(SSRFError, match="private/reserved IP"),
        ):
            validate_url("http://localhost/api")

    def test_rejects_10_network(self) -> None:
        """A hostname resolving into 10.0.0.0/8 is rejected."""
        with (
            self._mock_resolve(["10.0.0.5"]),
            pytest.raises(SSRFError, match="private/reserved IP"),
        ):
            validate_url("http://internal.corp/api")

    def test_rejects_172_16_network(self) -> None:
        """A hostname resolving into 172.16.0.0/12 is rejected."""
        with (
            self._mock_resolve(["172.16.0.1"]),
            pytest.raises(SSRFError, match="private/reserved IP"),
        ):
            validate_url("http://internal.corp/api")

    def test_rejects_192_168_network(self) -> None:
        """A hostname resolving into 192.168.0.0/16 is rejected."""
        with (
            self._mock_resolve(["192.168.1.1"]),
            pytest.raises(SSRFError, match="private/reserved IP"),
        ):
            validate_url("http://internal.corp/api")

    def test_rejects_metadata_ip(self) -> None:
        """Block cloud metadata endpoint (169.254.169.254)."""
        with (
            self._mock_resolve(["169.254.169.254"]),
            pytest.raises(SSRFError, match="private/reserved IP"),
        ):
            validate_url("http://169.254.169.254/latest/meta-data/")

    def test_rejects_ipv6_loopback(self) -> None:
        """The IPv6 loopback literal is rejected."""
        with (
            self._mock_resolve(["::1"]),
            pytest.raises(SSRFError, match="private/reserved IP"),
        ):
            validate_url("http://[::1]/api")

    def test_rejects_unresolvable_host(self) -> None:
        """A hostname that resolves to no addresses is rejected."""
        with self._mock_resolve([]), pytest.raises(SSRFError, match="Could not resolve"):
            validate_url("http://nonexistent.invalid/api")

    def test_allowed_hosts_whitelist(self) -> None:
        """A host on the policy allow-list is accepted (passes if nothing is raised)."""
        policy = SSRFPolicy(allowed_hosts=frozenset({"api.example.com"}))
        with self._mock_resolve(["93.184.216.34"]):
            validate_url("https://api.example.com/spec", policy=policy)

    def test_allowed_hosts_rejects_unlisted(self) -> None:
        """A host missing from the policy allow-list is rejected."""
        policy = SSRFPolicy(allowed_hosts=frozenset({"api.example.com"}))
        # Patch the resolver to a public IP so the test never performs a real
        # DNS lookup, whatever order validate_url runs its checks in, and so
        # the allow-list is the only reason the URL can be rejected.
        with (
            self._mock_resolve(["93.184.216.34"]),
            pytest.raises(SSRFError, match="not in the allowed hosts"),
        ):
            validate_url("https://evil.com/spec", policy=policy)

    def test_dns_rebinding_detection(self) -> None:
        """A hostname that resolves to both public and private IPs should be blocked."""
        with (
            self._mock_resolve(["93.184.216.34", "127.0.0.1"]),
            pytest.raises(SSRFError, match="private/reserved IP"),
        ):
            validate_url("http://evil-rebind.com/api")
# --- safe_fetch ---
class TestSafeFetch:
    """Tests for ``safe_fetch`` / ``safe_fetch_text``.

    HTTP traffic is stubbed through the ``httpx_mock`` fixture (presumably
    provided by the pytest-httpx plugin -- confirm against the project's test
    dependencies) and DNS is stubbed by patching
    ``app.openapi.ssrf.resolve_hostname``.
    """

    @pytest.fixture
    def _mock_public_dns(self):
        """Make every hostname resolve to a single public IP for the test's duration."""
        with patch("app.openapi.ssrf.resolve_hostname", return_value=["93.184.216.34"]):
            yield

    @pytest.mark.usefixtures("_mock_public_dns")
    async def test_fetch_success(self, httpx_mock) -> None:
        """A fetch of a publicly resolving URL returns the stubbed 200 response."""
        httpx_mock.add_response(url="https://example.com/spec.json", text='{"openapi":"3.0.0"}')
        response = await safe_fetch("https://example.com/spec.json")
        assert response.status_code == 200

    @pytest.mark.usefixtures("_mock_public_dns")
    async def test_fetch_text_success(self, httpx_mock) -> None:
        """``safe_fetch_text`` returns the stubbed response body as text."""
        httpx_mock.add_response(url="https://example.com/spec.json", text='{"openapi":"3.0.0"}')
        text = await safe_fetch_text("https://example.com/spec.json")
        assert "openapi" in text

    async def test_fetch_blocks_private_ip(self) -> None:
        """A host resolving to a private IP is rejected."""
        with (
            patch("app.openapi.ssrf.resolve_hostname", return_value=["10.0.0.1"]),
            pytest.raises(SSRFError, match="private/reserved"),
        ):
            await safe_fetch("http://internal.corp/api")

    async def test_redirect_to_private_ip_blocked(self, httpx_mock) -> None:
        """A redirect whose target resolves to loopback is rejected."""
        httpx_mock.add_response(
            url="https://example.com/spec.json",
            status_code=302,
            headers={"Location": "http://evil-redirect.com/steal"},
        )

        def _resolve_side_effect(hostname: str) -> list[str]:
            # Only the redirect target maps to loopback; the original host
            # stays public so the initial request passes validation.
            if hostname == "evil-redirect.com":
                return ["127.0.0.1"]
            return ["93.184.216.34"]

        with (
            patch("app.openapi.ssrf.resolve_hostname", side_effect=_resolve_side_effect),
            pytest.raises(SSRFError, match="private/reserved"),
        ):
            await safe_fetch("https://example.com/spec.json")

    @pytest.mark.usefixtures("_mock_public_dns")
    async def test_too_many_redirects(self, httpx_mock) -> None:
        """A redirect chain longer than ``max_redirects`` is rejected."""
        # Create a redirect chain longer than max_redirects
        policy = SSRFPolicy(max_redirects=2)
        for i in range(3):
            httpx_mock.add_response(
                url=f"https://example.com/r{i}",
                status_code=302,
                headers={"Location": f"https://example.com/r{i + 1}"},
            )
        with pytest.raises(SSRFError, match="Too many redirects"):
            await safe_fetch("https://example.com/r0", policy=policy)