Files
smart-support/backend/tests/unit/openapi/test_classifier.py
Yaojia Wang a54eb224e0 feat: complete phase 3 -- OpenAPI auto-discovery, SSRF protection, tool generation
- SSRF protection: private IP blocking, DNS rebinding defense, redirect validation
- OpenAPI fetcher with SSRF guard, JSON/YAML auto-detection, 10MB limit
- Structural spec validator (3.0.x/3.1.x)
- Endpoint parser with $ref resolution, auto-generated operation IDs
- Heuristic + LLM endpoint classifier with Protocol interface
- Review API at /api/openapi (import, job status, classification CRUD, approve)
- @tool code generator + Agent YAML generator
- Import orchestrator (fetch -> validate -> parse -> classify pipeline)
- 125 new tests, 322 total passing, 93.23% coverage
2026-03-31 00:10:44 +02:00

250 lines
8.8 KiB
Python

"""Tests for OpenAPI endpoint classifier module.
RED phase: written before implementation.
"""
from __future__ import annotations

import json
from unittest.mock import AsyncMock, MagicMock

import pytest

from app.openapi.models import EndpointInfo, ParameterInfo
# Apply the ``unit`` marker to every test in this module; pytest applies a
# module-level ``pytestmark`` to all tests collected from the module.
# NOTE(review): the async tests in this file carry no explicit asyncio marker,
# so they presumably rely on pytest-asyncio running in auto mode -- confirm in
# the pytest configuration.
pytestmark = pytest.mark.unit
def _make_endpoint(
    path: str = "/items",
    method: str = "GET",
    operation_id: str = "list_items",
    summary: str = "List items",
    description: str = "",
    parameters: tuple[ParameterInfo, ...] = (),
) -> EndpointInfo:
    """Build an ``EndpointInfo`` fixture with overridable defaults.

    Every field has a default describing a simple ``GET /items`` listing
    endpoint, so each test only spells out the attributes it cares about.
    """
    endpoint = EndpointInfo(
        method=method,
        path=path,
        operation_id=operation_id,
        summary=summary,
        description=description,
        parameters=parameters,
    )
    return endpoint
# Required path parameter; the heuristic tests expect ``order_id`` to be
# reported among an endpoint's customer params.
_ORDER_PARAM = ParameterInfo(
    name="order_id",
    location="path",
    required=True,
    schema_type="string",
)
# Optional query parameter; likewise expected to be reported as a customer
# param by the heuristic tests.
_CUSTOMER_PARAM = ParameterInfo(
    name="customer_id",
    location="query",
    required=False,
    schema_type="string",
)
class TestHeuristicClassifier:
    """Tests for the rule-based HeuristicClassifier."""

    @staticmethod
    async def _classify(*endpoints: EndpointInfo) -> tuple:
        """Classify ``endpoints`` with a freshly constructed HeuristicClassifier.

        The classifier import stays function-local, as in the original
        test-first layout, so module collection does not depend on it.
        """
        from app.openapi.classifier import HeuristicClassifier

        classifier = HeuristicClassifier()
        return await classifier.classify(endpoints)

    async def test_get_classified_as_read(self) -> None:
        """GET endpoints are classified as read access."""
        results = await self._classify(_make_endpoint(method="GET"))
        assert len(results) == 1
        assert results[0].access_type == "read"

    async def test_post_classified_as_write(self) -> None:
        """POST endpoints are classified as write access."""
        results = await self._classify(_make_endpoint(method="POST"))
        assert results[0].access_type == "write"

    async def test_post_needs_interrupt(self) -> None:
        """POST endpoints require interrupt/approval."""
        results = await self._classify(_make_endpoint(method="POST"))
        assert results[0].needs_interrupt is True

    async def test_put_classified_as_write(self) -> None:
        """PUT endpoints are classified as write access."""
        results = await self._classify(_make_endpoint(method="PUT"))
        assert results[0].access_type == "write"

    async def test_delete_classified_as_write_with_interrupt(self) -> None:
        """DELETE endpoints are classified as write and require interrupt."""
        results = await self._classify(_make_endpoint(method="DELETE"))
        first = results[0]
        assert first.access_type == "write"
        assert first.needs_interrupt is True

    async def test_get_does_not_need_interrupt(self) -> None:
        """GET endpoints do not require interrupt."""
        results = await self._classify(_make_endpoint(method="GET"))
        assert results[0].needs_interrupt is False

    async def test_empty_endpoints_returns_empty_tuple(self) -> None:
        """Empty input yields empty output."""
        results = await self._classify()
        assert results == ()

    async def test_customer_params_detected_order_id(self) -> None:
        """Parameters named order_id are recognized as customer params."""
        endpoint = _make_endpoint(method="GET", parameters=(_ORDER_PARAM,))
        results = await self._classify(endpoint)
        assert "order_id" in results[0].customer_params

    async def test_customer_params_detected_customer_id(self) -> None:
        """Parameters named customer_id are recognized as customer params."""
        endpoint = _make_endpoint(method="GET", parameters=(_CUSTOMER_PARAM,))
        results = await self._classify(endpoint)
        assert "customer_id" in results[0].customer_params

    async def test_result_is_tuple(self) -> None:
        """classify returns a tuple (immutable)."""
        results = await self._classify(_make_endpoint())
        assert isinstance(results, tuple)

    async def test_classification_has_confidence(self) -> None:
        """Heuristic results have a confidence value between 0 and 1."""
        results = await self._classify(_make_endpoint())
        assert 0.0 <= results[0].confidence <= 1.0

    async def test_patch_classified_as_write(self) -> None:
        """PATCH endpoints are classified as write access."""
        results = await self._classify(_make_endpoint(method="PATCH"))
        assert results[0].access_type == "write"
class TestLLMClassifier:
    """Tests for the LLM-backed classifier."""

    def _make_raw_mock_llm(self, content: str) -> MagicMock:
        """Create a mock LLM whose ``ainvoke`` resolves to a response with ``content``.

        Args:
            content: The raw text placed on the mocked response's ``content``.

        Returns:
            A ``MagicMock`` whose ``ainvoke`` is an ``AsyncMock``.
        """
        mock_llm = MagicMock()
        mock_response = MagicMock()
        mock_response.content = content
        mock_llm.ainvoke = AsyncMock(return_value=mock_response)
        return mock_llm

    def _make_mock_llm(self, classifications: list[dict]) -> MagicMock:
        """Create a mock LLM that returns ``classifications`` serialized as JSON.

        ``json.dumps`` is required: ``str()`` yields a Python repr (single
        quotes, ``True``/``False``) that is not valid JSON, which the classifier
        would presumably treat like the unparseable output in
        ``test_llm_invalid_json_falls_back_to_heuristic``.
        """
        return self._make_raw_mock_llm(json.dumps(classifications))

    async def test_llm_classifier_classifies_endpoints(self) -> None:
        """LLM classifier returns ClassificationResult for each endpoint."""
        from app.openapi.classifier import LLMClassifier

        ep = _make_endpoint(method="GET")
        mock_llm = self._make_mock_llm(
            [
                {
                    "access_type": "read",
                    "agent_group": "support",
                    "needs_interrupt": False,
                    "customer_params": [],
                    "confidence": 0.9,
                }
            ]
        )
        clf = LLMClassifier(llm=mock_llm)
        results = await clf.classify((ep,))
        assert len(results) == 1
        assert results[0].access_type == "read"
        # GET is also "read" under the heuristic fallback, so the access_type
        # check alone cannot tell the two paths apart; at least confirm the
        # LLM was actually invoked.
        mock_llm.ainvoke.assert_awaited()

    async def test_llm_failure_falls_back_to_heuristic(self) -> None:
        """When LLM raises an exception, falls back to heuristic classifier."""
        from app.openapi.classifier import LLMClassifier

        ep = _make_endpoint(method="GET")
        mock_llm = MagicMock()
        mock_llm.ainvoke = AsyncMock(side_effect=RuntimeError("LLM unavailable"))
        clf = LLMClassifier(llm=mock_llm)
        results = await clf.classify((ep,))
        # Falls back to heuristic: GET = read
        assert len(results) == 1
        assert results[0].access_type == "read"

    async def test_llm_invalid_json_falls_back_to_heuristic(self) -> None:
        """When LLM returns unparseable output, falls back to heuristic."""
        from app.openapi.classifier import LLMClassifier

        ep = _make_endpoint(method="DELETE")
        mock_llm = self._make_raw_mock_llm("this is not valid json at all")
        clf = LLMClassifier(llm=mock_llm)
        results = await clf.classify((ep,))
        # Fallback: DELETE = write with interrupt
        assert results[0].access_type == "write"
        assert results[0].needs_interrupt is True

    async def test_llm_empty_endpoints_returns_empty(self) -> None:
        """Empty input yields empty output without calling LLM."""
        from app.openapi.classifier import LLMClassifier

        mock_llm = MagicMock()
        mock_llm.ainvoke = AsyncMock()
        clf = LLMClassifier(llm=mock_llm)
        results = await clf.classify(())
        assert results == ()
        mock_llm.ainvoke.assert_not_called()
class TestClassifierProtocol:
    """Check that each classifier exposes the ``classify`` entry point."""

    @staticmethod
    def _assert_has_callable_classify(classifier: object) -> None:
        """Fail unless ``classifier`` has a callable ``classify`` attribute."""
        assert hasattr(classifier, "classify")
        assert callable(getattr(classifier, "classify"))

    def test_heuristic_has_classify_method(self) -> None:
        """HeuristicClassifier exposes classify method."""
        from app.openapi.classifier import HeuristicClassifier

        self._assert_has_callable_classify(HeuristicClassifier())

    def test_llm_has_classify_method(self) -> None:
        """LLMClassifier exposes classify method."""
        from app.openapi.classifier import LLMClassifier

        self._assert_has_callable_classify(LLMClassifier(llm=MagicMock()))