feat: complete phase 3 -- OpenAPI auto-discovery, SSRF protection, tool generation
- SSRF protection: private IP blocking, DNS rebinding defense, redirect validation
- OpenAPI fetcher with SSRF guard, JSON/YAML auto-detection, 10MB limit
- Structural spec validator (3.0.x/3.1.x)
- Endpoint parser with $ref resolution, auto-generated operation IDs
- Heuristic + LLM endpoint classifier with Protocol interface
- Review API at /api/openapi (import, job status, classification CRUD, approve)
- @tool code generator + Agent YAML generator
- Import orchestrator (fetch -> validate -> parse -> classify pipeline)
- 125 new tests, 322 total passing, 93.23% coverage
This commit is contained in:
203
backend/tests/integration/test_import_pipeline.py
Normal file
203
backend/tests/integration/test_import_pipeline.py
Normal file
@@ -0,0 +1,203 @@
|
||||
"""Integration tests for the OpenAPI import pipeline orchestrator.
|
||||
|
||||
Tests the full pipeline: fetch -> validate -> parse -> classify.
|
||||
Uses mocked HTTP and mocked LLM classifier.
|
||||
|
||||
RED phase: written before implementation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.openapi.models import ImportJob
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
_VALID_SPEC_JSON = """{
|
||||
"openapi": "3.0.0",
|
||||
"info": {"title": "Test API", "version": "1.0.0"},
|
||||
"paths": {
|
||||
"/orders": {
|
||||
"get": {
|
||||
"operationId": "list_orders",
|
||||
"summary": "List orders",
|
||||
"description": "Returns all orders",
|
||||
"responses": {"200": {"description": "OK"}}
|
||||
}
|
||||
},
|
||||
"/orders/{id}": {
|
||||
"delete": {
|
||||
"operationId": "delete_order",
|
||||
"summary": "Delete order",
|
||||
"description": "Deletes an order",
|
||||
"parameters": [
|
||||
{"name": "id", "in": "path", "required": true, "schema": {"type": "string"}}
|
||||
],
|
||||
"responses": {"204": {"description": "Deleted"}}
|
||||
}
|
||||
}
|
||||
}
|
||||
}"""
|
||||
|
||||
_INVALID_SPEC_JSON = '{"not": "a valid openapi spec"}'
|
||||
|
||||
_PUBLIC_IP = "93.184.216.34"
|
||||
|
||||
|
||||
@pytest.fixture
def mock_classifier():
    """Provide a deterministic, rule-based classifier for pipeline tests."""
    from app.openapi.classifier import HeuristicClassifier

    return HeuristicClassifier()
|
||||
|
||||
|
||||
@pytest.fixture
def orchestrator(mock_classifier):
    """Build an ImportOrchestrator wired to the heuristic mock classifier."""
    from app.openapi.importer import ImportOrchestrator

    return ImportOrchestrator(classifier=mock_classifier)
|
||||
|
||||
|
||||
class TestImportOrchestratorSuccess:
    """Tests covering successful runs of the import pipeline."""

    @pytest.mark.usefixtures("_mock_public_dns")
    async def test_full_pipeline_succeeds(self, orchestrator, httpx_mock) -> None:
        """A valid spec served over mocked HTTP completes the whole pipeline."""
        httpx_mock.add_response(
            url="https://example.com/api/spec.json",
            text=_VALID_SPEC_JSON,
            headers={"content-type": "application/json"},
        )
        job = await orchestrator.start_import(
            url="https://example.com/api/spec.json",
            job_id="test-job-1",
            on_progress=None,
        )
        assert isinstance(job, ImportJob)
        assert job.job_id == "test-job-1"
        assert job.status == "done"
        assert job.error_message is None
        # The fixture spec declares exactly two operations.
        assert job.total_endpoints == 2
        assert job.classified_count == 2

    @pytest.mark.usefixtures("_mock_public_dns")
    async def test_progress_callback_called_at_stages(self, orchestrator, httpx_mock) -> None:
        """on_progress fires for every pipeline stage."""
        httpx_mock.add_response(
            url="https://example.com/api/spec.json",
            text=_VALID_SPEC_JSON,
            headers={"content-type": "application/json"},
        )
        observed: list[str] = []

        await orchestrator.start_import(
            url="https://example.com/api/spec.json",
            job_id="test-job-2",
            on_progress=lambda stage, job: observed.append(stage),
        )
        for expected_stage in ("fetching", "validating", "parsing", "classifying"):
            assert expected_stage in observed

    @pytest.mark.usefixtures("_mock_public_dns")
    async def test_none_progress_callback_does_not_raise(
        self, orchestrator, httpx_mock
    ) -> None:
        """A None on_progress callback is tolerated silently."""
        httpx_mock.add_response(
            url="https://example.com/api/spec.json",
            text=_VALID_SPEC_JSON,
            headers={"content-type": "application/json"},
        )
        job = await orchestrator.start_import(
            url="https://example.com/api/spec.json",
            job_id="test-job-3",
            on_progress=None,
        )
        assert job.status == "done"
|
||||
|
||||
|
||||
class TestImportOrchestratorFailures:
    """Tests for error handling in the import pipeline."""

    async def test_fetch_failure_sets_failed_status(self, orchestrator) -> None:
        """When the fetch stage fails (SSRF-blocked private IP), job status is 'failed'."""
        with patch("app.openapi.ssrf.resolve_hostname", return_value=["10.0.0.1"]):
            job = await orchestrator.start_import(
                url="http://internal.corp/spec.json",
                job_id="test-job-fail-1",
                on_progress=None,
            )
        assert job.status == "failed"
        assert job.error_message is not None

    @pytest.mark.usefixtures("_mock_public_dns")
    async def test_validation_failure_sets_failed_status(
        self, orchestrator, httpx_mock
    ) -> None:
        """When spec validation fails, job status is 'failed'."""
        httpx_mock.add_response(
            url="https://example.com/api/bad.json",
            text=_INVALID_SPEC_JSON,
            headers={"content-type": "application/json"},
        )
        job = await orchestrator.start_import(
            url="https://example.com/api/bad.json",
            job_id="test-job-fail-2",
            on_progress=None,
        )
        assert job.status == "failed"
        assert job.error_message is not None

    @pytest.mark.usefixtures("_mock_public_dns")
    async def test_error_message_is_descriptive(self, orchestrator, httpx_mock) -> None:
        """Error message contains useful context."""
        httpx_mock.add_response(
            url="https://example.com/api/bad.json",
            text=_INVALID_SPEC_JSON,
            headers={"content-type": "application/json"},
        )
        job = await orchestrator.start_import(
            url="https://example.com/api/bad.json",
            job_id="test-job-fail-3",
            on_progress=None,
        )
        # Guard first: if error_message were None, len(None) would raise
        # TypeError and obscure the real failure instead of asserting cleanly.
        assert job.error_message is not None
        assert len(job.error_message) > 0

    @pytest.mark.usefixtures("_mock_public_dns")
    async def test_failed_status_progress_called_with_failed(
        self, orchestrator, httpx_mock
    ) -> None:
        """When the pipeline fails, on_progress is called with the 'failed' stage."""
        httpx_mock.add_response(
            url="https://example.com/api/bad.json",
            text=_INVALID_SPEC_JSON,
            headers={"content-type": "application/json"},
        )
        stages_seen: list[str] = []

        def on_progress(stage: str, job: ImportJob) -> None:
            stages_seen.append(stage)

        await orchestrator.start_import(
            url="https://example.com/api/bad.json",
            job_id="test-job-fail-4",
            on_progress=on_progress,
        )
        assert "failed" in stages_seen
|
||||
|
||||
|
||||
@pytest.fixture
def _mock_public_dns():
    """Patch DNS resolution so every hostname resolves to a public IP.

    Keeps the SSRF guard from rejecting the mocked test URLs.
    """
    patcher = patch("app.openapi.ssrf.resolve_hostname", return_value=[_PUBLIC_IP])
    with patcher:
        yield
|
||||
0
backend/tests/unit/openapi/__init__.py
Normal file
0
backend/tests/unit/openapi/__init__.py
Normal file
249
backend/tests/unit/openapi/test_classifier.py
Normal file
249
backend/tests/unit/openapi/test_classifier.py
Normal file
@@ -0,0 +1,249 @@
|
||||
"""Tests for OpenAPI endpoint classifier module.
|
||||
|
||||
RED phase: written before implementation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from app.openapi.models import EndpointInfo, ParameterInfo
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def _make_endpoint(
    path: str = "/items",
    method: str = "GET",
    operation_id: str = "list_items",
    summary: str = "List items",
    description: str = "",
    parameters: tuple[ParameterInfo, ...] = (),
) -> EndpointInfo:
    """Build an EndpointInfo populated with sensible test defaults."""
    fields = {
        "path": path,
        "method": method,
        "operation_id": operation_id,
        "summary": summary,
        "description": description,
        "parameters": parameters,
    }
    return EndpointInfo(**fields)
|
||||
|
||||
|
||||
# Parameters whose names the classifier should recognize as customer-scoped.
_ORDER_PARAM = ParameterInfo(
    name="order_id",
    location="path",
    required=True,
    schema_type="string",
)
_CUSTOMER_PARAM = ParameterInfo(
    name="customer_id",
    location="query",
    required=False,
    schema_type="string",
)
|
||||
|
||||
|
||||
class TestHeuristicClassifier:
    """Behavioral tests for the rule-based HeuristicClassifier."""

    @staticmethod
    async def _classify_one(endpoint):
        """Run a fresh HeuristicClassifier over a single-endpoint tuple."""
        from app.openapi.classifier import HeuristicClassifier

        return await HeuristicClassifier().classify((endpoint,))

    async def test_get_classified_as_read(self) -> None:
        """GET maps to read access."""
        results = await self._classify_one(_make_endpoint(method="GET"))
        assert len(results) == 1
        assert results[0].access_type == "read"

    async def test_post_classified_as_write(self) -> None:
        """POST maps to write access."""
        results = await self._classify_one(_make_endpoint(method="POST"))
        assert results[0].access_type == "write"

    async def test_post_needs_interrupt(self) -> None:
        """POST requires interrupt/approval."""
        results = await self._classify_one(_make_endpoint(method="POST"))
        assert results[0].needs_interrupt is True

    async def test_put_classified_as_write(self) -> None:
        """PUT maps to write access."""
        results = await self._classify_one(_make_endpoint(method="PUT"))
        assert results[0].access_type == "write"

    async def test_delete_classified_as_write_with_interrupt(self) -> None:
        """DELETE maps to write access and requires interrupt."""
        results = await self._classify_one(_make_endpoint(method="DELETE"))
        assert results[0].access_type == "write"
        assert results[0].needs_interrupt is True

    async def test_get_does_not_need_interrupt(self) -> None:
        """GET does not require interrupt."""
        results = await self._classify_one(_make_endpoint(method="GET"))
        assert results[0].needs_interrupt is False

    async def test_empty_endpoints_returns_empty_tuple(self) -> None:
        """Classifying no endpoints yields an empty tuple."""
        from app.openapi.classifier import HeuristicClassifier

        assert await HeuristicClassifier().classify(()) == ()

    async def test_customer_params_detected_order_id(self) -> None:
        """Parameters named order_id are flagged as customer params."""
        results = await self._classify_one(
            _make_endpoint(method="GET", parameters=(_ORDER_PARAM,))
        )
        assert "order_id" in results[0].customer_params

    async def test_customer_params_detected_customer_id(self) -> None:
        """Parameters named customer_id are flagged as customer params."""
        results = await self._classify_one(
            _make_endpoint(method="GET", parameters=(_CUSTOMER_PARAM,))
        )
        assert "customer_id" in results[0].customer_params

    async def test_result_is_tuple(self) -> None:
        """classify returns an immutable tuple."""
        results = await self._classify_one(_make_endpoint())
        assert isinstance(results, tuple)

    async def test_classification_has_confidence(self) -> None:
        """Confidence lies in the closed interval [0, 1]."""
        results = await self._classify_one(_make_endpoint())
        assert 0.0 <= results[0].confidence <= 1.0

    async def test_patch_classified_as_write(self) -> None:
        """PATCH maps to write access."""
        results = await self._classify_one(_make_endpoint(method="PATCH"))
        assert results[0].access_type == "write"
|
||||
|
||||
|
||||
class TestLLMClassifier:
    """Tests for the LLM-backed classifier."""

    def _make_mock_llm(self, classifications: list[dict]) -> MagicMock:
        """Create a mock LLM whose response body is *valid JSON*.

        Bug fix: the original used str(classifications), which produces
        Python repr (single quotes, True/False) -- not parseable JSON --
        so any test using this helper would silently exercise only the
        heuristic-fallback path. Serialize with json.dumps instead.
        """
        import json

        mock_llm = MagicMock()
        mock_response = MagicMock()
        mock_response.content = json.dumps(classifications)
        mock_llm.ainvoke = AsyncMock(return_value=mock_response)
        return mock_llm

    async def test_llm_classifier_classifies_endpoints(self) -> None:
        """LLM classifier returns ClassificationResult for each endpoint."""
        from app.openapi.classifier import LLMClassifier

        ep = _make_endpoint(method="GET")
        # Use the shared helper (previously dead code) to build the mock.
        mock_llm = self._make_mock_llm(
            [
                {
                    "access_type": "read",
                    "agent_group": "support",
                    "needs_interrupt": False,
                    "customer_params": [],
                    "confidence": 0.9,
                }
            ]
        )

        clf = LLMClassifier(llm=mock_llm)
        results = await clf.classify((ep,))
        assert len(results) == 1
        assert results[0].access_type == "read"

    async def test_llm_failure_falls_back_to_heuristic(self) -> None:
        """When the LLM raises, the classifier falls back to heuristics."""
        from app.openapi.classifier import LLMClassifier

        ep = _make_endpoint(method="GET")

        mock_llm = MagicMock()
        mock_llm.ainvoke = AsyncMock(side_effect=RuntimeError("LLM unavailable"))

        clf = LLMClassifier(llm=mock_llm)
        results = await clf.classify((ep,))
        # Falls back to heuristic: GET = read
        assert len(results) == 1
        assert results[0].access_type == "read"

    async def test_llm_invalid_json_falls_back_to_heuristic(self) -> None:
        """When the LLM returns unparseable output, falls back to heuristics."""
        from app.openapi.classifier import LLMClassifier

        ep = _make_endpoint(method="DELETE")

        mock_llm = MagicMock()
        mock_response = MagicMock()
        mock_response.content = "this is not valid json at all"
        mock_llm.ainvoke = AsyncMock(return_value=mock_response)

        clf = LLMClassifier(llm=mock_llm)
        results = await clf.classify((ep,))
        # Fallback: DELETE = write with interrupt
        assert results[0].access_type == "write"
        assert results[0].needs_interrupt is True

    async def test_llm_empty_endpoints_returns_empty(self) -> None:
        """Empty input yields empty output without ever calling the LLM."""
        from app.openapi.classifier import LLMClassifier

        mock_llm = MagicMock()
        mock_llm.ainvoke = AsyncMock()

        clf = LLMClassifier(llm=mock_llm)
        results = await clf.classify(())
        assert results == ()
        mock_llm.ainvoke.assert_not_called()
|
||||
|
||||
|
||||
class TestClassifierProtocol:
    """Both classifier implementations satisfy ClassifierProtocol."""

    def test_heuristic_has_classify_method(self) -> None:
        """HeuristicClassifier exposes a callable classify attribute."""
        from app.openapi.classifier import HeuristicClassifier

        classify = getattr(HeuristicClassifier(), "classify", None)
        assert classify is not None
        assert callable(classify)

    def test_llm_has_classify_method(self) -> None:
        """LLMClassifier exposes a callable classify attribute."""
        from app.openapi.classifier import LLMClassifier

        classify = getattr(LLMClassifier(llm=MagicMock()), "classify", None)
        assert classify is not None
        assert callable(classify)
|
||||
120
backend/tests/unit/openapi/test_fetcher.py
Normal file
120
backend/tests/unit/openapi/test_fetcher.py
Normal file
@@ -0,0 +1,120 @@
|
||||
"""Tests for OpenAPI spec fetcher module.
|
||||
|
||||
RED phase: written before implementation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.openapi.ssrf import SSRFError
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
# Minimal valid OpenAPI 3.0 documents in both supported encodings.
_SAMPLE_JSON = '{"openapi": "3.0.0", "info": {"title": "Test", "version": "1.0"}, "paths": {}}'
_SAMPLE_YAML = "openapi: '3.0.0'\ninfo:\n title: Test\n version: '1.0'\npaths: {}\n"

# Routable public address so the SSRF guard admits mocked requests.
_PUBLIC_IP = "93.184.216.34"
|
||||
|
||||
|
||||
@pytest.fixture
def _mock_public_dns():
    """Patch DNS resolution so every hostname resolves to a public IP."""
    patcher = patch("app.openapi.ssrf.resolve_hostname", return_value=[_PUBLIC_IP])
    with patcher:
        yield
|
||||
|
||||
|
||||
class TestFetchSpec:
    """Tests for the fetch_spec coroutine."""

    @pytest.mark.usefixtures("_mock_public_dns")
    async def test_fetch_json_spec_succeeds(self, httpx_mock) -> None:
        """A JSON spec is fetched and returned as a parsed dict."""
        from app.openapi.fetcher import fetch_spec

        httpx_mock.add_response(
            url="https://example.com/spec.json",
            text=_SAMPLE_JSON,
            headers={"content-type": "application/json"},
        )
        spec = await fetch_spec("https://example.com/spec.json")
        assert isinstance(spec, dict)
        assert spec["openapi"] == "3.0.0"

    @pytest.mark.usefixtures("_mock_public_dns")
    async def test_fetch_yaml_spec_succeeds(self, httpx_mock) -> None:
        """A YAML spec is fetched and returned as a parsed dict."""
        from app.openapi.fetcher import fetch_spec

        httpx_mock.add_response(
            url="https://example.com/spec.yaml",
            text=_SAMPLE_YAML,
            headers={"content-type": "application/x-yaml"},
        )
        spec = await fetch_spec("https://example.com/spec.yaml")
        assert isinstance(spec, dict)
        assert spec["openapi"] == "3.0.0"

    @pytest.mark.usefixtures("_mock_public_dns")
    async def test_fetch_yaml_by_url_extension(self, httpx_mock) -> None:
        """YAML is auto-detected from a .yaml URL even with a generic content type."""
        from app.openapi.fetcher import fetch_spec

        httpx_mock.add_response(
            url="https://example.com/api.yaml",
            text=_SAMPLE_YAML,
            headers={"content-type": "text/plain"},
        )
        spec = await fetch_spec("https://example.com/api.yaml")
        assert isinstance(spec, dict)
        assert spec["openapi"] == "3.0.0"

    async def test_ssrf_blocked_url_raises(self) -> None:
        """A URL resolving to a private IP raises SSRFError."""
        from app.openapi.fetcher import fetch_spec

        with (
            patch("app.openapi.ssrf.resolve_hostname", return_value=["10.0.0.1"]),
            pytest.raises(SSRFError),
        ):
            await fetch_spec("http://internal.corp/spec.json")

    @pytest.mark.usefixtures("_mock_public_dns")
    async def test_oversized_response_raises(self, httpx_mock) -> None:
        """A response just past the 10MB cap raises ValueError."""
        from app.openapi.fetcher import fetch_spec

        oversized = "x" * (10 * 1024 * 1024 + 1)
        httpx_mock.add_response(
            url="https://example.com/huge.json",
            text=oversized,
            headers={"content-type": "application/json"},
        )
        with pytest.raises(ValueError, match="too large"):
            await fetch_spec("https://example.com/huge.json")

    @pytest.mark.usefixtures("_mock_public_dns")
    async def test_invalid_json_raises(self, httpx_mock) -> None:
        """Unparseable JSON raises ValueError."""
        from app.openapi.fetcher import fetch_spec

        httpx_mock.add_response(
            url="https://example.com/bad.json",
            text="not valid json {{{",
            headers={"content-type": "application/json"},
        )
        with pytest.raises(ValueError, match="[Pp]arse|[Ii]nvalid|[Dd]ecode"):
            await fetch_spec("https://example.com/bad.json")

    @pytest.mark.usefixtures("_mock_public_dns")
    async def test_invalid_yaml_raises(self, httpx_mock) -> None:
        """Unparseable YAML raises ValueError."""
        from app.openapi.fetcher import fetch_spec

        httpx_mock.add_response(
            url="https://example.com/bad.yaml",
            text=": invalid: yaml: {\n",
            headers={"content-type": "application/x-yaml"},
        )
        with pytest.raises(ValueError, match="[Pp]arse|[Ii]nvalid|[Yy]AML"):
            await fetch_spec("https://example.com/bad.yaml")
|
||||
258
backend/tests/unit/openapi/test_generator.py
Normal file
258
backend/tests/unit/openapi/test_generator.py
Normal file
@@ -0,0 +1,258 @@
|
||||
"""Tests for OpenAPI tool generator module.
|
||||
|
||||
RED phase: written before implementation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.openapi.models import ClassificationResult, EndpointInfo, ParameterInfo
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
_BASE_URL = "https://api.example.com"
|
||||
|
||||
|
||||
def _make_endpoint(
    path: str = "/items",
    method: str = "GET",
    operation_id: str = "list_items",
    summary: str = "List items",
    description: str = "Returns all items",
    parameters: tuple[ParameterInfo, ...] = (),
    request_body_schema: dict | None = None,
) -> EndpointInfo:
    """Build an EndpointInfo populated with sensible test defaults."""
    fields = {
        "path": path,
        "method": method,
        "operation_id": operation_id,
        "summary": summary,
        "description": description,
        "parameters": parameters,
        "request_body_schema": request_body_schema,
    }
    return EndpointInfo(**fields)
|
||||
|
||||
|
||||
def _make_classification(
    endpoint: EndpointInfo,
    access_type: str = "read",
    needs_interrupt: bool = False,
    agent_group: str = "read_agent",
) -> ClassificationResult:
    """Wrap *endpoint* in a ClassificationResult with test defaults."""
    return ClassificationResult(
        endpoint=endpoint,
        access_type=access_type,
        agent_group=agent_group,
        customer_params=(),
        needs_interrupt=needs_interrupt,
        confidence=0.9,
    )
|
||||
|
||||
|
||||
# Reusable parameter fixtures for signature-generation tests.
_PATH_PARAM = ParameterInfo(
    name="item_id",
    location="path",
    required=True,
    schema_type="string",
)
_QUERY_PARAM = ParameterInfo(
    name="filter",
    location="query",
    required=False,
    schema_type="string",
)
|
||||
|
||||
|
||||
class TestGenerateToolCode:
    """Tests for generate_tool_code."""

    @staticmethod
    def _generate(endpoint, **classification_kwargs):
        """Classify *endpoint* with test defaults and generate its tool."""
        from app.openapi.generator import generate_tool_code

        classification = _make_classification(endpoint, **classification_kwargs)
        return generate_tool_code(classification, _BASE_URL)

    def test_generate_tool_for_get_endpoint(self) -> None:
        """A GET endpoint yields a GeneratedTool with non-empty @tool code."""
        tool = self._generate(_make_endpoint(method="GET"))
        assert tool.function_name == "list_items"
        assert tool.code != ""
        assert "@tool" in tool.code

    def test_generate_tool_contains_function_name(self) -> None:
        """The operation id appears in the generated source."""
        tool = self._generate(_make_endpoint(operation_id="get_order", method="GET"))
        assert "get_order" in tool.code

    def test_generate_tool_contains_base_url(self) -> None:
        """The base URL appears in the generated source."""
        tool = self._generate(_make_endpoint())
        assert _BASE_URL in tool.code

    def test_generate_tool_contains_http_method(self) -> None:
        """The HTTP verb appears in the generated source."""
        tool = self._generate(_make_endpoint(method="POST"), access_type="write")
        assert "post" in tool.code.lower()

    def test_generate_tool_for_post_with_body(self) -> None:
        """A POST endpoint with a request body still generates code using POST."""
        endpoint = _make_endpoint(
            method="POST",
            request_body_schema={"type": "object", "properties": {"name": {"type": "string"}}},
        )
        tool = self._generate(endpoint, access_type="write")
        assert tool.code != ""
        assert "POST" in tool.code or "post" in tool.code

    def test_generate_tool_with_path_params(self) -> None:
        """Path parameters appear in the generated function signature."""
        endpoint = _make_endpoint(
            path="/items/{item_id}",
            operation_id="get_item",
            parameters=(_PATH_PARAM,),
        )
        tool = self._generate(endpoint)
        assert "item_id" in tool.code

    def test_write_tool_includes_interrupt_marker(self) -> None:
        """Interrupt-requiring write tools carry an interrupt/approval marker."""
        endpoint = _make_endpoint(method="DELETE", operation_id="delete_item")
        tool = self._generate(endpoint, access_type="write", needs_interrupt=True)
        lowered = tool.code.lower()
        assert "interrupt" in lowered or "approval" in lowered

    def test_generated_code_is_executable(self) -> None:
        """Generated source compiles as valid Python."""
        endpoint = _make_endpoint(
            path="/items/{item_id}",
            operation_id="fetch_item",
            parameters=(_PATH_PARAM,),
        )
        tool = self._generate(endpoint)
        # compile() raises SyntaxError on malformed output.
        compile(tool.code, "<generated>", "exec")

    def test_generated_tool_code_exec_imports(self) -> None:
        """Generated code executes cleanly when given its required imports."""
        tool = self._generate(_make_endpoint())
        try:
            import httpx
            from langchain_core.tools import tool as lc_tool

            exec(tool.code, {"httpx": httpx, "tool": lc_tool})  # noqa: S102
        except ImportError:
            pytest.skip("langchain_core not available for exec test")

    def test_returns_generated_tool_instance(self) -> None:
        """The return value is a GeneratedTool."""
        from app.openapi.models import GeneratedTool

        tool = self._generate(_make_endpoint())
        assert isinstance(tool, GeneratedTool)

    def test_generated_tool_is_frozen(self) -> None:
        """GeneratedTool rejects attribute assignment."""
        tool = self._generate(_make_endpoint())
        with pytest.raises((AttributeError, TypeError)):
            tool.code = "new code"  # type: ignore[misc]
|
||||
|
||||
|
||||
class TestGenerateAgentYaml:
    """Tests for generate_agent_yaml function."""

    def test_generate_yaml_is_valid_string(self) -> None:
        """generate_agent_yaml returns a non-empty string."""
        from app.openapi.generator import generate_agent_yaml

        ep = _make_endpoint()
        clf = _make_classification(ep)
        result = generate_agent_yaml((clf,), _BASE_URL)

        assert isinstance(result, str)
        assert len(result) > 0

    def test_generated_yaml_is_parseable(self) -> None:
        """Output can be parsed as YAML."""
        import yaml

        from app.openapi.generator import generate_agent_yaml

        ep = _make_endpoint()
        clf = _make_classification(ep)
        result = generate_agent_yaml((clf,), _BASE_URL)

        parsed = yaml.safe_load(result)
        assert isinstance(parsed, dict)

    def test_generated_yaml_contains_agents_key(self) -> None:
        """Generated YAML has an 'agents' key matching AgentConfig format."""
        import yaml

        from app.openapi.generator import generate_agent_yaml

        ep = _make_endpoint()
        clf = _make_classification(ep)
        result = generate_agent_yaml((clf,), _BASE_URL)

        parsed = yaml.safe_load(result)
        assert "agents" in parsed

    def test_generated_yaml_contains_tool_name(self) -> None:
        """Generated YAML references the tool function name."""
        from app.openapi.generator import generate_agent_yaml

        ep = _make_endpoint(operation_id="list_orders")
        clf = _make_classification(ep)
        result = generate_agent_yaml((clf,), _BASE_URL)

        assert "list_orders" in result

    def test_empty_classifications_returns_empty_agents(self) -> None:
        """No classifications yields YAML with an empty or absent agents list."""
        import yaml

        from app.openapi.generator import generate_agent_yaml

        result = generate_agent_yaml((), _BASE_URL)
        # yaml.safe_load("") returns None; coalesce to {} so an empty
        # document fails the assertion cleanly instead of raising
        # AttributeError on NoneType.
        parsed = yaml.safe_load(result) or {}
        assert parsed.get("agents") in ([], None)
|
||||
290
backend/tests/unit/openapi/test_parser.py
Normal file
290
backend/tests/unit/openapi/test_parser.py
Normal file
@@ -0,0 +1,290 @@
|
||||
"""Tests for OpenAPI endpoint parser module.
|
||||
|
||||
RED phase: written before implementation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
_MINIMAL_SPEC = {
|
||||
"openapi": "3.0.0",
|
||||
"info": {"title": "Test API", "version": "1.0.0"},
|
||||
"paths": {},
|
||||
}
|
||||
|
||||
_GET_SPEC = {
|
||||
"openapi": "3.0.0",
|
||||
"info": {"title": "Orders API", "version": "1.0.0"},
|
||||
"paths": {
|
||||
"/orders/{order_id}": {
|
||||
"get": {
|
||||
"operationId": "get_order",
|
||||
"summary": "Get an order",
|
||||
"description": "Retrieves a single order by ID",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "order_id",
|
||||
"in": "path",
|
||||
"required": True,
|
||||
"schema": {"type": "string"},
|
||||
"description": "The order identifier",
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Order found",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {"id": {"type": "string"}},
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
_POST_SPEC = {
|
||||
"openapi": "3.0.0",
|
||||
"info": {"title": "Orders API", "version": "1.0.0"},
|
||||
"paths": {
|
||||
"/orders": {
|
||||
"post": {
|
||||
"operationId": "create_order",
|
||||
"summary": "Create an order",
|
||||
"description": "Creates a new order",
|
||||
"requestBody": {
|
||||
"required": True,
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"item": {"type": "string"},
|
||||
"quantity": {"type": "integer"},
|
||||
},
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
"responses": {"201": {"description": "Created"}},
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
_MULTI_PARAM_SPEC = {
|
||||
"openapi": "3.0.0",
|
||||
"info": {"title": "Items API", "version": "1.0.0"},
|
||||
"paths": {
|
||||
"/items/{item_id}": {
|
||||
"get": {
|
||||
"operationId": "get_item",
|
||||
"summary": "Get item",
|
||||
"description": "",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "item_id",
|
||||
"in": "path",
|
||||
"required": True,
|
||||
"schema": {"type": "integer"},
|
||||
},
|
||||
{
|
||||
"name": "include_details",
|
||||
"in": "query",
|
||||
"required": False,
|
||||
"schema": {"type": "boolean"},
|
||||
},
|
||||
],
|
||||
"responses": {"200": {"description": "OK"}},
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
_REF_SPEC = {
|
||||
"openapi": "3.0.0",
|
||||
"info": {"title": "Ref API", "version": "1.0.0"},
|
||||
"components": {
|
||||
"schemas": {
|
||||
"Item": {
|
||||
"type": "object",
|
||||
"properties": {"id": {"type": "string"}},
|
||||
}
|
||||
}
|
||||
},
|
||||
"paths": {
|
||||
"/items": {
|
||||
"get": {
|
||||
"operationId": "list_items",
|
||||
"summary": "List items",
|
||||
"description": "",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {"$ref": "#/components/schemas/Item"}
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
_MULTI_ENDPOINT_SPEC = {
|
||||
"openapi": "3.0.0",
|
||||
"info": {"title": "Multi API", "version": "1.0.0"},
|
||||
"paths": {
|
||||
"/users": {
|
||||
"get": {
|
||||
"operationId": "list_users",
|
||||
"summary": "List users",
|
||||
"description": "",
|
||||
"responses": {"200": {"description": "OK"}},
|
||||
},
|
||||
"post": {
|
||||
"operationId": "create_user",
|
||||
"summary": "Create user",
|
||||
"description": "",
|
||||
"responses": {"201": {"description": "Created"}},
|
||||
},
|
||||
},
|
||||
"/users/{id}": {
|
||||
"delete": {
|
||||
"operationId": "delete_user",
|
||||
"summary": "Delete user",
|
||||
"description": "",
|
||||
"parameters": [
|
||||
{"name": "id", "in": "path", "required": True, "schema": {"type": "string"}}
|
||||
],
|
||||
"responses": {"204": {"description": "Deleted"}},
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class TestParseEndpoints:
|
||||
"""Tests for parse_endpoints function."""
|
||||
|
||||
def test_empty_paths_returns_empty_tuple(self) -> None:
|
||||
"""Spec with no paths yields no endpoints."""
|
||||
from app.openapi.parser import parse_endpoints
|
||||
|
||||
result = parse_endpoints(_MINIMAL_SPEC)
|
||||
assert result == ()
|
||||
|
||||
def test_parse_get_endpoint(self) -> None:
|
||||
"""Parse a GET endpoint with path parameter."""
|
||||
from app.openapi.parser import parse_endpoints
|
||||
|
||||
result = parse_endpoints(_GET_SPEC)
|
||||
assert len(result) == 1
|
||||
ep = result[0]
|
||||
assert ep.path == "/orders/{order_id}"
|
||||
assert ep.method == "GET"
|
||||
assert ep.operation_id == "get_order"
|
||||
assert ep.summary == "Get an order"
|
||||
|
||||
def test_parse_get_endpoint_parameters(self) -> None:
|
||||
"""Path parameters are extracted correctly."""
|
||||
from app.openapi.parser import parse_endpoints
|
||||
|
||||
result = parse_endpoints(_GET_SPEC)
|
||||
ep = result[0]
|
||||
assert len(ep.parameters) == 1
|
||||
param = ep.parameters[0]
|
||||
assert param.name == "order_id"
|
||||
assert param.location == "path"
|
||||
assert param.required is True
|
||||
assert param.schema_type == "string"
|
||||
|
||||
def test_parse_post_with_request_body(self) -> None:
|
||||
"""POST endpoint with request body is extracted."""
|
||||
from app.openapi.parser import parse_endpoints
|
||||
|
||||
result = parse_endpoints(_POST_SPEC)
|
||||
assert len(result) == 1
|
||||
ep = result[0]
|
||||
assert ep.method == "POST"
|
||||
assert ep.request_body_schema is not None
|
||||
assert ep.request_body_schema["type"] == "object"
|
||||
|
||||
def test_parse_path_and_query_params(self) -> None:
|
||||
"""Both path and query parameters are extracted."""
|
||||
from app.openapi.parser import parse_endpoints
|
||||
|
||||
result = parse_endpoints(_MULTI_PARAM_SPEC)
|
||||
ep = result[0]
|
||||
locations = {p.location for p in ep.parameters}
|
||||
assert "path" in locations
|
||||
assert "query" in locations
|
||||
|
||||
def test_autogenerate_operation_id_when_missing(self) -> None:
|
||||
"""Auto-generate operation_id when not provided in spec."""
|
||||
from app.openapi.parser import parse_endpoints
|
||||
|
||||
spec = {
|
||||
"openapi": "3.0.0",
|
||||
"info": {"title": "Test", "version": "1.0"},
|
||||
"paths": {
|
||||
"/things/{id}": {
|
||||
"get": {
|
||||
"summary": "Get thing",
|
||||
"description": "",
|
||||
"responses": {"200": {"description": "OK"}},
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
result = parse_endpoints(spec)
|
||||
ep = result[0]
|
||||
assert ep.operation_id != ""
|
||||
assert len(ep.operation_id) > 0
|
||||
|
||||
def test_multiple_endpoints_extracted(self) -> None:
|
||||
"""Multiple path+method combinations are all extracted."""
|
||||
from app.openapi.parser import parse_endpoints
|
||||
|
||||
result = parse_endpoints(_MULTI_ENDPOINT_SPEC)
|
||||
assert len(result) == 3
|
||||
methods = {ep.method for ep in result}
|
||||
assert "GET" in methods
|
||||
assert "POST" in methods
|
||||
assert "DELETE" in methods
|
||||
|
||||
def test_ref_in_response_schema_resolved(self) -> None:
|
||||
"""$ref in response schema is resolved to the target schema."""
|
||||
from app.openapi.parser import parse_endpoints
|
||||
|
||||
result = parse_endpoints(_REF_SPEC)
|
||||
ep = result[0]
|
||||
assert ep.response_schema is not None
|
||||
# Resolved ref should contain the properties
|
||||
assert "properties" in ep.response_schema or "$ref" not in ep.response_schema
|
||||
|
||||
def test_result_is_tuple(self) -> None:
|
||||
"""parse_endpoints returns a tuple (immutable)."""
|
||||
from app.openapi.parser import parse_endpoints
|
||||
|
||||
result = parse_endpoints(_GET_SPEC)
|
||||
assert isinstance(result, tuple)
|
||||
|
||||
def test_endpoint_info_is_frozen(self) -> None:
|
||||
"""EndpointInfo instances are frozen/immutable."""
|
||||
from app.openapi.parser import parse_endpoints
|
||||
|
||||
result = parse_endpoints(_GET_SPEC)
|
||||
ep = result[0]
|
||||
with pytest.raises((AttributeError, TypeError)):
|
||||
ep.method = "POST" # type: ignore[misc]
|
||||
198
backend/tests/unit/openapi/test_review_api.py
Normal file
198
backend/tests/unit/openapi/test_review_api.py
Normal file
@@ -0,0 +1,198 @@
|
||||
"""Tests for OpenAPI review API endpoints.
|
||||
|
||||
RED phase: written before implementation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
_SAMPLE_URL = "https://example.com/api/spec.json"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
"""Create TestClient for the review API app."""
|
||||
from fastapi import FastAPI
|
||||
|
||||
from app.openapi.review_api import router
|
||||
|
||||
app = FastAPI()
|
||||
app.include_router(router)
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def job_id(client):
|
||||
"""Create a job and return its ID."""
|
||||
response = client.post("/api/openapi/import", json={"url": _SAMPLE_URL})
|
||||
assert response.status_code == 202
|
||||
return response.json()["job_id"]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def job_with_classifications(client, job_id):
|
||||
"""Return job_id for a job that has mock classifications injected."""
|
||||
from app.openapi.models import ClassificationResult, EndpointInfo
|
||||
from app.openapi.review_api import _job_store
|
||||
|
||||
ep = EndpointInfo(
|
||||
path="/orders",
|
||||
method="GET",
|
||||
operation_id="list_orders",
|
||||
summary="List orders",
|
||||
description="",
|
||||
)
|
||||
clf = ClassificationResult(
|
||||
endpoint=ep,
|
||||
access_type="read",
|
||||
customer_params=(),
|
||||
agent_group="read_agent",
|
||||
confidence=0.9,
|
||||
needs_interrupt=False,
|
||||
)
|
||||
# Inject classifications directly into the store
|
||||
job = _job_store[job_id]
|
||||
_job_store[job_id] = {**job, "classifications": [clf]}
|
||||
return job_id
|
||||
|
||||
|
||||
class TestImportEndpoint:
|
||||
"""Tests for POST /api/openapi/import."""
|
||||
|
||||
def test_post_import_returns_job_id(self, client) -> None:
|
||||
"""POST /import returns 202 with a job_id."""
|
||||
response = client.post("/api/openapi/import", json={"url": _SAMPLE_URL})
|
||||
assert response.status_code == 202
|
||||
data = response.json()
|
||||
assert "job_id" in data
|
||||
assert len(data["job_id"]) > 0
|
||||
|
||||
def test_post_import_empty_url_returns_422(self, client) -> None:
|
||||
"""POST /import with empty URL returns 422 validation error."""
|
||||
response = client.post("/api/openapi/import", json={"url": ""})
|
||||
assert response.status_code == 422
|
||||
|
||||
def test_post_import_missing_url_returns_422(self, client) -> None:
|
||||
"""POST /import with missing URL field returns 422."""
|
||||
response = client.post("/api/openapi/import", json={})
|
||||
assert response.status_code == 422
|
||||
|
||||
def test_post_import_returns_pending_status(self, client) -> None:
|
||||
"""Newly created job has pending status."""
|
||||
response = client.post("/api/openapi/import", json={"url": _SAMPLE_URL})
|
||||
data = response.json()
|
||||
assert data["status"] == "pending"
|
||||
|
||||
def test_post_import_returns_spec_url(self, client) -> None:
|
||||
"""Response includes the original spec URL."""
|
||||
response = client.post("/api/openapi/import", json={"url": _SAMPLE_URL})
|
||||
data = response.json()
|
||||
assert data["spec_url"] == _SAMPLE_URL
|
||||
|
||||
|
||||
class TestGetJobEndpoint:
|
||||
"""Tests for GET /api/openapi/jobs/{job_id}."""
|
||||
|
||||
def test_get_job_returns_status(self, client, job_id) -> None:
|
||||
"""GET /jobs/{id} returns job status."""
|
||||
response = client.get(f"/api/openapi/jobs/{job_id}")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "status" in data
|
||||
assert "job_id" in data
|
||||
|
||||
def test_get_unknown_job_returns_404(self, client) -> None:
|
||||
"""GET /jobs/nonexistent returns 404."""
|
||||
response = client.get("/api/openapi/jobs/nonexistent-id")
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_get_job_includes_spec_url(self, client, job_id) -> None:
|
||||
"""Job response includes the spec URL."""
|
||||
response = client.get(f"/api/openapi/jobs/{job_id}")
|
||||
data = response.json()
|
||||
assert data["spec_url"] == _SAMPLE_URL
|
||||
|
||||
|
||||
class TestGetClassificationsEndpoint:
|
||||
"""Tests for GET /api/openapi/jobs/{job_id}/classifications."""
|
||||
|
||||
def test_get_classifications_returns_list(self, client, job_with_classifications) -> None:
|
||||
"""GET /classifications returns a list."""
|
||||
response = client.get(
|
||||
f"/api/openapi/jobs/{job_with_classifications}/classifications"
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert isinstance(data, list)
|
||||
assert len(data) == 1
|
||||
|
||||
def test_get_classifications_unknown_job_returns_404(self, client) -> None:
|
||||
"""GET /classifications for unknown job returns 404."""
|
||||
response = client.get("/api/openapi/jobs/unknown/classifications")
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_classification_has_expected_fields(self, client, job_with_classifications) -> None:
|
||||
"""Each classification item has access_type and endpoint fields."""
|
||||
response = client.get(
|
||||
f"/api/openapi/jobs/{job_with_classifications}/classifications"
|
||||
)
|
||||
item = response.json()[0]
|
||||
assert "access_type" in item
|
||||
assert "endpoint" in item
|
||||
assert "needs_interrupt" in item
|
||||
|
||||
|
||||
class TestUpdateClassificationEndpoint:
|
||||
"""Tests for PUT /api/openapi/jobs/{job_id}/classifications/{idx}."""
|
||||
|
||||
def test_update_classification_succeeds(self, client, job_with_classifications) -> None:
|
||||
"""PUT /classifications/0 updates the classification."""
|
||||
response = client.put(
|
||||
f"/api/openapi/jobs/{job_with_classifications}/classifications/0",
|
||||
json={"access_type": "write", "needs_interrupt": True, "agent_group": "write_agent"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_update_unknown_job_returns_404(self, client) -> None:
|
||||
"""PUT /classifications/0 for unknown job returns 404."""
|
||||
response = client.put(
|
||||
"/api/openapi/jobs/unknown/classifications/0",
|
||||
json={"access_type": "write", "needs_interrupt": True, "agent_group": "write_agent"},
|
||||
)
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_update_out_of_range_index_returns_404(self, client, job_with_classifications) -> None:
|
||||
"""PUT /classifications/999 returns 404 for out-of-range index."""
|
||||
response = client.put(
|
||||
f"/api/openapi/jobs/{job_with_classifications}/classifications/999",
|
||||
json={"access_type": "read", "needs_interrupt": False, "agent_group": "read_agent"},
|
||||
)
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
class TestApproveEndpoint:
|
||||
"""Tests for POST /api/openapi/jobs/{job_id}/approve."""
|
||||
|
||||
def test_approve_job_succeeds(self, client, job_with_classifications) -> None:
|
||||
"""POST /approve transitions job to approved status."""
|
||||
response = client.post(
|
||||
f"/api/openapi/jobs/{job_with_classifications}/approve"
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_approve_unknown_job_returns_404(self, client) -> None:
|
||||
"""POST /approve for unknown job returns 404."""
|
||||
response = client.post("/api/openapi/jobs/unknown/approve")
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_approve_returns_job_status(self, client, job_with_classifications) -> None:
|
||||
"""POST /approve returns updated job status."""
|
||||
response = client.post(
|
||||
f"/api/openapi/jobs/{job_with_classifications}/approve"
|
||||
)
|
||||
data = response.json()
|
||||
assert "status" in data
|
||||
93
backend/tests/unit/openapi/test_validator.py
Normal file
93
backend/tests/unit/openapi/test_validator.py
Normal file
@@ -0,0 +1,93 @@
|
||||
"""Tests for OpenAPI spec validator module.
|
||||
|
||||
RED phase: written before implementation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
_VALID_SPEC = {
|
||||
"openapi": "3.0.0",
|
||||
"info": {"title": "Test API", "version": "1.0.0"},
|
||||
"paths": {
|
||||
"/items": {
|
||||
"get": {
|
||||
"summary": "List items",
|
||||
"responses": {"200": {"description": "OK"}},
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class TestValidateSpec:
|
||||
"""Tests for validate_spec function."""
|
||||
|
||||
def test_valid_minimal_spec_passes(self) -> None:
|
||||
"""A valid minimal spec returns empty error list."""
|
||||
from app.openapi.validator import validate_spec
|
||||
|
||||
errors = validate_spec(_VALID_SPEC)
|
||||
assert errors == []
|
||||
|
||||
def test_missing_openapi_key_returns_error(self) -> None:
|
||||
"""Missing 'openapi' field returns an error."""
|
||||
from app.openapi.validator import validate_spec
|
||||
|
||||
spec = {k: v for k, v in _VALID_SPEC.items() if k != "openapi"}
|
||||
errors = validate_spec(spec)
|
||||
assert len(errors) > 0
|
||||
assert any("openapi" in e.lower() for e in errors)
|
||||
|
||||
def test_missing_info_returns_error(self) -> None:
|
||||
"""Missing 'info' field returns an error."""
|
||||
from app.openapi.validator import validate_spec
|
||||
|
||||
spec = {k: v for k, v in _VALID_SPEC.items() if k != "info"}
|
||||
errors = validate_spec(spec)
|
||||
assert len(errors) > 0
|
||||
assert any("info" in e.lower() for e in errors)
|
||||
|
||||
def test_missing_paths_returns_error(self) -> None:
|
||||
"""Missing 'paths' field returns an error."""
|
||||
from app.openapi.validator import validate_spec
|
||||
|
||||
spec = {k: v for k, v in _VALID_SPEC.items() if k != "paths"}
|
||||
errors = validate_spec(spec)
|
||||
assert len(errors) > 0
|
||||
assert any("paths" in e.lower() for e in errors)
|
||||
|
||||
def test_non_dict_input_returns_error(self) -> None:
|
||||
"""Non-dict input returns an error without raising."""
|
||||
from app.openapi.validator import validate_spec
|
||||
|
||||
errors = validate_spec("not a dict") # type: ignore[arg-type]
|
||||
assert len(errors) > 0
|
||||
|
||||
def test_empty_dict_returns_multiple_errors(self) -> None:
|
||||
"""Empty dict returns errors for all required fields."""
|
||||
from app.openapi.validator import validate_spec
|
||||
|
||||
errors = validate_spec({})
|
||||
# Should have at least one error for each required field
|
||||
assert len(errors) >= 3
|
||||
|
||||
def test_invalid_openapi_version_returns_error(self) -> None:
|
||||
"""Unsupported openapi version string returns an error."""
|
||||
from app.openapi.validator import validate_spec
|
||||
|
||||
spec = {**_VALID_SPEC, "openapi": "1.0.0"}
|
||||
errors = validate_spec(spec)
|
||||
assert len(errors) > 0
|
||||
|
||||
def test_errors_are_descriptive_strings(self) -> None:
|
||||
"""All returned errors are non-empty strings."""
|
||||
from app.openapi.validator import validate_spec
|
||||
|
||||
errors = validate_spec({})
|
||||
for e in errors:
|
||||
assert isinstance(e, str)
|
||||
assert len(e) > 0
|
||||
@@ -13,7 +13,7 @@ class TestMainModule:
|
||||
assert app.title == "Smart Support"
|
||||
|
||||
def test_app_version(self) -> None:
|
||||
assert app.version == "0.2.0"
|
||||
assert app.version == "0.3.0"
|
||||
|
||||
def test_agents_yaml_path_exists(self) -> None:
|
||||
assert AGENTS_YAML.name == "agents.yaml"
|
||||
|
||||
236
backend/tests/unit/test_ssrf.py
Normal file
236
backend/tests/unit/test_ssrf.py
Normal file
@@ -0,0 +1,236 @@
|
||||
"""Tests for SSRF protection module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.openapi.ssrf import (
|
||||
SSRFError,
|
||||
SSRFPolicy,
|
||||
is_private_ip,
|
||||
safe_fetch,
|
||||
safe_fetch_text,
|
||||
validate_url,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
# --- is_private_ip ---
|
||||
|
||||
|
||||
class TestIsPrivateIP:
|
||||
"""Tests for private IP detection."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"ip",
|
||||
[
|
||||
"10.0.0.1",
|
||||
"10.255.255.255",
|
||||
"172.16.0.1",
|
||||
"172.31.255.255",
|
||||
"192.168.0.1",
|
||||
"192.168.1.100",
|
||||
"127.0.0.1",
|
||||
"127.0.0.2",
|
||||
"169.254.1.1",
|
||||
"169.254.169.254", # AWS metadata
|
||||
"0.0.0.0",
|
||||
"::1",
|
||||
"fe80::1",
|
||||
"fc00::1",
|
||||
],
|
||||
)
|
||||
def test_private_ips_detected(self, ip: str) -> None:
|
||||
assert is_private_ip(ip) is True
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"ip",
|
||||
[
|
||||
"8.8.8.8",
|
||||
"1.1.1.1",
|
||||
"203.0.113.1",
|
||||
"93.184.216.34",
|
||||
"2001:4860:4860::8888",
|
||||
],
|
||||
)
|
||||
def test_public_ips_allowed(self, ip: str) -> None:
|
||||
assert is_private_ip(ip) is False
|
||||
|
||||
def test_invalid_ip_treated_as_blocked(self) -> None:
|
||||
assert is_private_ip("not-an-ip") is True
|
||||
|
||||
def test_empty_string_blocked(self) -> None:
|
||||
assert is_private_ip("") is True
|
||||
|
||||
|
||||
# --- validate_url ---
|
||||
|
||||
|
||||
class TestValidateURL:
|
||||
"""Tests for URL validation."""
|
||||
|
||||
def _mock_resolve(self, ips: list[str]):
|
||||
return patch("app.openapi.ssrf.resolve_hostname", return_value=ips)
|
||||
|
||||
def test_valid_https_url(self) -> None:
|
||||
with self._mock_resolve(["93.184.216.34"]):
|
||||
result = validate_url("https://example.com/api/v1/spec.json")
|
||||
assert result == "https://example.com/api/v1/spec.json"
|
||||
|
||||
def test_valid_http_url(self) -> None:
|
||||
with self._mock_resolve(["93.184.216.34"]):
|
||||
result = validate_url("http://example.com/spec.yaml")
|
||||
assert result == "http://example.com/spec.yaml"
|
||||
|
||||
def test_rejects_ftp_scheme(self) -> None:
|
||||
with pytest.raises(SSRFError, match="scheme.*not allowed"):
|
||||
validate_url("ftp://example.com/spec")
|
||||
|
||||
def test_rejects_file_scheme(self) -> None:
|
||||
with pytest.raises(SSRFError, match="scheme.*not allowed"):
|
||||
validate_url("file:///etc/passwd")
|
||||
|
||||
def test_rejects_no_hostname(self) -> None:
|
||||
with pytest.raises(SSRFError, match="no hostname"):
|
||||
validate_url("https://")
|
||||
|
||||
def test_rejects_private_ip_literal(self) -> None:
|
||||
with (
|
||||
self._mock_resolve(["127.0.0.1"]),
|
||||
pytest.raises(SSRFError, match="private/reserved IP"),
|
||||
):
|
||||
validate_url("http://127.0.0.1/api")
|
||||
|
||||
def test_rejects_localhost(self) -> None:
|
||||
with (
|
||||
self._mock_resolve(["127.0.0.1"]),
|
||||
pytest.raises(SSRFError, match="private/reserved IP"),
|
||||
):
|
||||
validate_url("http://localhost/api")
|
||||
|
||||
def test_rejects_10_network(self) -> None:
|
||||
with (
|
||||
self._mock_resolve(["10.0.0.5"]),
|
||||
pytest.raises(SSRFError, match="private/reserved IP"),
|
||||
):
|
||||
validate_url("http://internal.corp/api")
|
||||
|
||||
def test_rejects_172_16_network(self) -> None:
|
||||
with (
|
||||
self._mock_resolve(["172.16.0.1"]),
|
||||
pytest.raises(SSRFError, match="private/reserved IP"),
|
||||
):
|
||||
validate_url("http://internal.corp/api")
|
||||
|
||||
def test_rejects_192_168_network(self) -> None:
|
||||
with (
|
||||
self._mock_resolve(["192.168.1.1"]),
|
||||
pytest.raises(SSRFError, match="private/reserved IP"),
|
||||
):
|
||||
validate_url("http://internal.corp/api")
|
||||
|
||||
def test_rejects_metadata_ip(self) -> None:
|
||||
"""Block cloud metadata endpoint (169.254.169.254)."""
|
||||
with (
|
||||
self._mock_resolve(["169.254.169.254"]),
|
||||
pytest.raises(SSRFError, match="private/reserved IP"),
|
||||
):
|
||||
validate_url("http://169.254.169.254/latest/meta-data/")
|
||||
|
||||
def test_rejects_ipv6_loopback(self) -> None:
|
||||
with (
|
||||
self._mock_resolve(["::1"]),
|
||||
pytest.raises(SSRFError, match="private/reserved IP"),
|
||||
):
|
||||
validate_url("http://[::1]/api")
|
||||
|
||||
def test_rejects_unresolvable_host(self) -> None:
|
||||
with self._mock_resolve([]), pytest.raises(SSRFError, match="Could not resolve"):
|
||||
validate_url("http://nonexistent.invalid/api")
|
||||
|
||||
def test_allowed_hosts_whitelist(self) -> None:
|
||||
policy = SSRFPolicy(allowed_hosts=frozenset({"api.example.com"}))
|
||||
with self._mock_resolve(["93.184.216.34"]):
|
||||
validate_url("https://api.example.com/spec", policy=policy)
|
||||
|
||||
def test_allowed_hosts_rejects_unlisted(self) -> None:
|
||||
policy = SSRFPolicy(allowed_hosts=frozenset({"api.example.com"}))
|
||||
with pytest.raises(SSRFError, match="not in the allowed hosts"):
|
||||
validate_url("https://evil.com/spec", policy=policy)
|
||||
|
||||
def test_dns_rebinding_detection(self) -> None:
|
||||
"""A hostname that resolves to both public and private IPs should be blocked."""
|
||||
with (
|
||||
self._mock_resolve(["93.184.216.34", "127.0.0.1"]),
|
||||
pytest.raises(SSRFError, match="private/reserved IP"),
|
||||
):
|
||||
validate_url("http://evil-rebind.com/api")
|
||||
|
||||
|
||||
# --- safe_fetch ---
|
||||
|
||||
|
||||
class TestSafeFetch:
|
||||
"""Tests for safe HTTP fetching."""
|
||||
|
||||
@pytest.fixture
|
||||
def _mock_public_dns(self):
|
||||
with patch("app.openapi.ssrf.resolve_hostname", return_value=["93.184.216.34"]):
|
||||
yield
|
||||
|
||||
@pytest.mark.usefixtures("_mock_public_dns")
|
||||
async def test_fetch_success(self, httpx_mock) -> None:
|
||||
httpx_mock.add_response(url="https://example.com/spec.json", text='{"openapi":"3.0.0"}')
|
||||
response = await safe_fetch("https://example.com/spec.json")
|
||||
assert response.status_code == 200
|
||||
|
||||
@pytest.mark.usefixtures("_mock_public_dns")
|
||||
async def test_fetch_text_success(self, httpx_mock) -> None:
|
||||
httpx_mock.add_response(url="https://example.com/spec.json", text='{"openapi":"3.0.0"}')
|
||||
text = await safe_fetch_text("https://example.com/spec.json")
|
||||
assert "openapi" in text
|
||||
|
||||
async def test_fetch_blocks_private_ip(self) -> None:
|
||||
with (
|
||||
patch("app.openapi.ssrf.resolve_hostname", return_value=["10.0.0.1"]),
|
||||
pytest.raises(SSRFError, match="private/reserved"),
|
||||
):
|
||||
await safe_fetch("http://internal.corp/api")
|
||||
|
||||
async def test_redirect_to_private_ip_blocked(self, httpx_mock) -> None:
|
||||
httpx_mock.add_response(
|
||||
url="https://example.com/spec.json",
|
||||
status_code=302,
|
||||
headers={"Location": "http://evil-redirect.com/steal"},
|
||||
)
|
||||
|
||||
call_count = 0
|
||||
|
||||
def _resolve_side_effect(hostname: str) -> list[str]:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if hostname == "evil-redirect.com":
|
||||
return ["127.0.0.1"]
|
||||
return ["93.184.216.34"]
|
||||
|
||||
with (
|
||||
patch("app.openapi.ssrf.resolve_hostname", side_effect=_resolve_side_effect),
|
||||
pytest.raises(SSRFError, match="private/reserved"),
|
||||
):
|
||||
await safe_fetch("https://example.com/spec.json")
|
||||
|
||||
@pytest.mark.usefixtures("_mock_public_dns")
|
||||
async def test_too_many_redirects(self, httpx_mock) -> None:
|
||||
# Create a redirect chain longer than max_redirects
|
||||
policy = SSRFPolicy(max_redirects=2)
|
||||
for i in range(3):
|
||||
httpx_mock.add_response(
|
||||
url=f"https://example.com/r{i}",
|
||||
status_code=302,
|
||||
headers={"Location": f"https://example.com/r{i + 1}"},
|
||||
)
|
||||
with pytest.raises(SSRFError, match="Too many redirects"):
|
||||
await safe_fetch("https://example.com/r0", policy=policy)
|
||||
Reference in New Issue
Block a user