Files
smart-support/backend/app/openapi/classifier.py
Yaojia Wang a54eb224e0 feat: complete phase 3 -- OpenAPI auto-discovery, SSRF protection, tool generation
- SSRF protection: private IP blocking, DNS rebinding defense, redirect validation
- OpenAPI fetcher with SSRF guard, JSON/YAML auto-detection, 10MB limit
- Structural spec validator (3.0.x/3.1.x)
- Endpoint parser with $ref resolution, auto-generated operation IDs
- Heuristic + LLM endpoint classifier with Protocol interface
- Review API at /api/openapi (import, job status, classification CRUD, approve)
- @tool code generator + Agent YAML generator
- Import orchestrator (fetch -> validate -> parse -> classify pipeline)
- 125 new tests, 322 total passing, 93.23% coverage
2026-03-31 00:10:44 +02:00

164 lines
5.5 KiB
Python

"""OpenAPI endpoint classifier.

Classifies endpoints into read/write access types and identifies
customer-identifying parameters. Provides a rule-based heuristic
classifier and an LLM-backed classifier with heuristic fallback; both
satisfy ``ClassifierProtocol`` (an async ``classify`` over a tuple of
endpoints).
"""
from __future__ import annotations

import json
import logging
import math
import re
from typing import Protocol

from app.openapi.models import ClassificationResult, EndpointInfo
logger = logging.getLogger(__name__)

# HTTP methods that mutate state and are therefore classified as "write".
_WRITE_METHODS = frozenset({"POST", "PUT", "PATCH", "DELETE"})
# HTTP methods that require human approval (interrupt) before execution.
# Currently identical to _WRITE_METHODS; presumably kept as a separate
# constant so the approval policy can be tuned independently.
_INTERRUPT_METHODS = frozenset({"POST", "PUT", "PATCH", "DELETE"})

# Parameter names that identify the customer/order context. Intended for use
# with ``fullmatch``. Case-insensitive, and the separator before "id" is
# optional, so snake_case (customer_id), camelCase (customerId), kebab-case
# (customer-id) and run-together (customerid) spellings all match. OpenAPI
# specs use all of these spellings.
_CUSTOMER_PARAM_PATTERNS = re.compile(
    r"(customer|order|user|account|client|subscriber)[_-]?id",
    re.IGNORECASE,
)
class ClassifierProtocol(Protocol):
    """Structural interface shared by every endpoint classifier.

    Any object exposing a compatible async ``classify`` coroutine satisfies
    this protocol; no inheritance is required.
    """

    async def classify(
        self, endpoints: tuple[EndpointInfo, ...]
    ) -> tuple[ClassificationResult, ...]:
        """Return a classification for each endpoint in ``endpoints``."""
        ...
class HeuristicClassifier:
    """Classify endpoints purely from their HTTP method.

    * POST / PUT / PATCH / DELETE -> ``write`` with ``needs_interrupt=True``.
    * GET (and any other non-write method) -> ``read``, no interrupt.
    """

    async def classify(
        self, endpoints: tuple[EndpointInfo, ...]
    ) -> tuple[ClassificationResult, ...]:
        """Return one heuristic classification per endpoint, in input order."""
        if endpoints:
            return tuple(map(_classify_one, endpoints))
        return ()
def _classify_one(ep: EndpointInfo) -> ClassificationResult:
    """Classify a single endpoint from its HTTP method.

    Methods in ``_WRITE_METHODS`` map to ``access_type="write"`` and the
    ``write_agent`` group; everything else maps to ``read`` / ``read_agent``.
    ``needs_interrupt`` is set for methods in ``_INTERRUPT_METHODS``. The
    heuristic always reports confidence 0.7.
    """
    # Normalise case so a lowercase method (OpenAPI path-item keys are
    # lowercase, e.g. "delete") can never slip through as a read with no
    # interrupt.
    method = ep.method.upper()
    access_type = "write" if method in _WRITE_METHODS else "read"
    return ClassificationResult(
        endpoint=ep,
        access_type=access_type,
        customer_params=_detect_customer_params(ep),
        agent_group="write_agent" if access_type == "write" else "read_agent",
        confidence=0.7,
        needs_interrupt=method in _INTERRUPT_METHODS,
    )
def _detect_customer_params(ep: EndpointInfo) -> tuple[str, ...]:
    """Return the names of ``ep``'s parameters that identify a customer/order.

    A name qualifies only when the whole name matches
    ``_CUSTOMER_PARAM_PATTERNS``. Matches keep the endpoint's parameter order.
    """
    matches: list[str] = []
    for param in ep.parameters:
        if _CUSTOMER_PARAM_PATTERNS.fullmatch(param.name) is not None:
            matches.append(param.name)
    return tuple(matches)
class LLMClassifier:
    """Endpoint classifier backed by an LLM, degrading to heuristics on failure.

    The wrapped ``llm`` must expose an awaitable ``ainvoke(prompt)`` whose
    result has a ``content`` attribute. It is presumably a LangChain-style chat
    model; confirm against the caller. If any exception is raised while
    prompting the model or parsing its reply, the error is logged and the
    batch is classified by a ``HeuristicClassifier`` instead.
    """

    def __init__(self, llm: object) -> None:
        self._llm = llm
        self._fallback = HeuristicClassifier()

    async def classify(
        self, endpoints: tuple[EndpointInfo, ...]
    ) -> tuple[ClassificationResult, ...]:
        """Classify ``endpoints`` with the LLM, or heuristically if that fails."""
        if not endpoints:
            return ()
        try:
            results = await self._classify_with_llm(endpoints)
        except Exception:
            # Deliberately broad: any LLM, transport or parse failure degrades
            # to the heuristic instead of propagating to the caller.
            logger.warning(
                "LLM classification failed, falling back to heuristic",
                exc_info=True,
            )
            results = await self._fallback.classify(endpoints)
        return results

    async def _classify_with_llm(
        self, endpoints: tuple[EndpointInfo, ...]
    ) -> tuple[ClassificationResult, ...]:
        """Prompt the LLM once for the whole batch and parse its reply."""
        reply = await self._llm.ainvoke(_build_classification_prompt(endpoints))
        return _parse_llm_response(reply.content, endpoints)
def _build_classification_prompt(endpoints: tuple[EndpointInfo, ...]) -> str:
    """Build the batch classification prompt for ``endpoints``.

    Each endpoint is rendered on its own numbered line as
    ``"<index>: <METHOD> <path> - <summary or operation_id>"``. The model is
    asked for a JSON array with one object per endpoint, in the listed order,
    because the response parser pairs results with endpoints by position.

    Summaries originate from an imported OpenAPI spec (presumably third-party
    content). Their whitespace, including newlines, is collapsed so each
    endpoint stays on exactly one line. This also keeps a crafted summary from
    forging extra numbered entries in the list.
    """

    def one_line(value: object) -> str:
        # Collapse every whitespace run (including newlines) to one space.
        return " ".join(str(value).split())

    endpoint_list = "\n".join(
        f"{i}: {ep.method} {one_line(ep.path)} - "
        f"{one_line(ep.summary or ep.operation_id)}"
        for i, ep in enumerate(endpoints)
    )
    return (
        "Classify each API endpoint as 'read' or 'write'. "
        "For each, determine if it needs human interrupt approval, "
        "identify customer-identifying parameters, and assign an agent_group.\n\n"
        f"Endpoints:\n{endpoint_list}\n\n"
        "Return the objects in the same order as the endpoints are listed. "
        "Respond with a JSON array with one object per endpoint:\n"
        '[{"access_type": "read|write", "agent_group": "...", '
        '"needs_interrupt": true|false, "customer_params": [...], "confidence": 0.0-1.0}]'
    )
def _parse_llm_response(
    content: str, endpoints: tuple[EndpointInfo, ...]
) -> tuple[ClassificationResult, ...]:
    """Parse an LLM reply into one ClassificationResult per endpoint.

    The reply may surround the JSON with prose or a Markdown code fence. The
    first JSON array whose elements are all objects is used, and its items are
    paired with ``endpoints`` by position.

    Missing fields get conservative defaults derived from the endpoint's HTTP
    method:

    * ``access_type`` defaults to "write" for write methods, otherwise "read".
    * ``needs_interrupt`` defaults to ``access_type == "write"``.

    Raises:
        ValueError: if ``content`` is not text, no JSON array of objects is
            found, the item count differs from ``len(endpoints)``, or an item
            carries an invalid field value.
    """
    items = _extract_json_array(content)
    if len(items) != len(endpoints):
        raise ValueError(
            f"LLM returned {len(items)} items for {len(endpoints)} endpoints"
        )
    return tuple(
        _item_to_result(ep, item)
        for ep, item in zip(endpoints, items, strict=True)
    )


def _extract_json_array(content: str) -> list[dict]:
    """Return the first JSON array in ``content`` whose elements are all objects.

    Each "[" is tried in turn as the start of a JSON document using
    ``JSONDecoder.raw_decode``, which stops at the end of the first complete
    value. Unlike a greedy "[.*]" regex, this is not derailed by bracketed
    prose before or after the array.
    """
    if not isinstance(content, str):
        raise ValueError(
            f"Expected a text LLM response, got {type(content).__name__}"
        )
    decoder = json.JSONDecoder()
    start = content.find("[")
    while start != -1:
        try:
            value, _end = decoder.raw_decode(content[start:])
        except json.JSONDecodeError:
            value = None  # Not valid JSON from here; try the next "[".
        if isinstance(value, list) and all(isinstance(v, dict) for v in value):
            return value
        start = content.find("[", start + 1)
    raise ValueError(f"No JSON array found in LLM response: {content!r}")


def _item_to_result(ep: EndpointInfo, item: dict) -> ClassificationResult:
    """Validate one decoded LLM item and build the result for ``ep``."""
    is_write_method = ep.method.upper() in _WRITE_METHODS

    raw_access = item.get("access_type") or ("write" if is_write_method else "read")
    access_type = str(raw_access).strip().lower()
    if access_type not in ("read", "write"):
        raise ValueError(f"Invalid access_type from LLM: {raw_access!r}")

    raw_interrupt = item.get("needs_interrupt")
    if raw_interrupt is None:
        needs_interrupt = access_type == "write"
    elif isinstance(raw_interrupt, str):
        # bool("false") is True, so interpret textual booleans explicitly;
        # anything not clearly negative keeps the safer "interrupt" answer.
        needs_interrupt = raw_interrupt.strip().lower() not in (
            "false",
            "no",
            "0",
        )
    else:
        needs_interrupt = bool(raw_interrupt)

    raw_params = item.get("customer_params")
    if raw_params is None:
        customer_params: tuple[str, ...] = ()
    elif isinstance(raw_params, str):
        # tuple("customer_id") would explode the name into single characters.
        customer_params = (raw_params,)
    elif isinstance(raw_params, list):
        customer_params = tuple(str(p) for p in raw_params)
    else:
        raise ValueError(f"Invalid customer_params from LLM: {raw_params!r}")

    raw_confidence = item.get("confidence")
    try:
        confidence = 0.8 if raw_confidence is None else float(raw_confidence)
    except (TypeError, ValueError) as exc:
        raise ValueError(
            f"Invalid confidence from LLM: {raw_confidence!r}"
        ) from exc
    # json.loads accepts NaN/Infinity, which would survive clamping unchanged.
    if not math.isfinite(confidence):
        raise ValueError(f"Invalid confidence from LLM: {raw_confidence!r}")

    return ClassificationResult(
        endpoint=ep,
        access_type=access_type,
        customer_params=customer_params,
        agent_group=str(item.get("agent_group") or "support"),
        confidence=min(max(confidence, 0.0), 1.0),
        needs_interrupt=needs_interrupt,
    )