feat: complete phase 3 -- OpenAPI auto-discovery, SSRF protection, tool generation

- SSRF protection: private IP blocking, DNS rebinding defense, redirect validation
- OpenAPI fetcher with SSRF guard, JSON/YAML auto-detection, 10MB limit
- Structural spec validator (3.0.x/3.1.x)
- Endpoint parser with $ref resolution, auto-generated operation IDs
- Heuristic + LLM endpoint classifier with Protocol interface
- Review API at /api/openapi (import, job status, classification CRUD, approve)
- @tool code generator + Agent YAML generator
- Import orchestrator (fetch -> validate -> parse -> classify pipeline)
- 125 new tests, 322 total passing, 93.23% coverage
Author: Yaojia Wang
Date: 2026-03-31 00:10:44 +02:00
Committed by: yaoji
Parent: 006b4ee5d7
Commit: a54eb224e0
25 changed files with 2914 additions and 35 deletions


@@ -20,6 +20,7 @@ from app.interrupt_manager import InterruptManager
 from app.llm import create_llm
 from app.registry import AgentRegistry
 from app.session_manager import SessionManager
+from app.openapi.review_api import router as openapi_router
 from app.ws_handler import dispatch_message

 if TYPE_CHECKING:
@@ -86,7 +87,9 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
     await pool.close()

-app = FastAPI(title="Smart Support", version="0.2.0", lifespan=lifespan)
+app = FastAPI(title="Smart Support", version="0.3.0", lifespan=lifespan)
+app.include_router(openapi_router)

 @app.websocket("/ws")


@@ -0,0 +1,2 @@
# OpenAPI auto-discovery module
# Parses OpenAPI specs, classifies endpoints via LLM, generates tools


@@ -0,0 +1,163 @@
"""OpenAPI endpoint classifier.
Classifies endpoints into read/write access types and identifies
customer-identifying parameters. Provides a rule-based heuristic
classifier and an LLM-backed classifier with heuristic fallback.
"""
from __future__ import annotations
import json
import logging
import re
from typing import Protocol
from app.openapi.models import ClassificationResult, EndpointInfo
logger = logging.getLogger(__name__)
_WRITE_METHODS = frozenset({"POST", "PUT", "PATCH", "DELETE"})
_INTERRUPT_METHODS = frozenset({"POST", "PUT", "PATCH", "DELETE"})
# Parameter names that identify the customer/order context
_CUSTOMER_PARAM_PATTERNS = re.compile(
r"(customer_id|order_id|user_id|account_id|client_id|subscriber_id)",
re.IGNORECASE,
)
class ClassifierProtocol(Protocol):
"""Protocol for endpoint classifiers."""
async def classify(
self, endpoints: tuple[EndpointInfo, ...]
) -> tuple[ClassificationResult, ...]: ...
class HeuristicClassifier:
"""Rule-based endpoint classifier.
GET -> read, no interrupt.
POST/PUT/PATCH/DELETE -> write, needs_interrupt=True.
"""
async def classify(
self, endpoints: tuple[EndpointInfo, ...]
) -> tuple[ClassificationResult, ...]:
"""Classify endpoints using HTTP method heuristics."""
if not endpoints:
return ()
return tuple(_classify_one(ep) for ep in endpoints)
def _classify_one(ep: EndpointInfo) -> ClassificationResult:
"""Classify a single endpoint using heuristics."""
access_type = "write" if ep.method in _WRITE_METHODS else "read"
needs_interrupt = ep.method in _INTERRUPT_METHODS
customer_params = _detect_customer_params(ep)
agent_group = "write_agent" if access_type == "write" else "read_agent"
return ClassificationResult(
endpoint=ep,
access_type=access_type,
customer_params=customer_params,
agent_group=agent_group,
confidence=0.7,
needs_interrupt=needs_interrupt,
)
def _detect_customer_params(ep: EndpointInfo) -> tuple[str, ...]:
"""Extract parameter names that identify the customer/order context."""
return tuple(
p.name
for p in ep.parameters
if _CUSTOMER_PARAM_PATTERNS.fullmatch(p.name)
)
class LLMClassifier:
"""LLM-backed endpoint classifier with heuristic fallback.
Uses an LLM to classify endpoints with higher accuracy.
Falls back to HeuristicClassifier on any LLM error.
"""
def __init__(self, llm: object) -> None:
self._llm = llm
self._fallback = HeuristicClassifier()
async def classify(
self, endpoints: tuple[EndpointInfo, ...]
) -> tuple[ClassificationResult, ...]:
"""Classify endpoints using LLM with heuristic fallback."""
if not endpoints:
return ()
try:
return await self._classify_with_llm(endpoints)
except Exception:
logger.warning(
"LLM classification failed, falling back to heuristic",
exc_info=True,
)
return await self._fallback.classify(endpoints)
async def _classify_with_llm(
self, endpoints: tuple[EndpointInfo, ...]
) -> tuple[ClassificationResult, ...]:
"""Attempt LLM-based classification."""
prompt = _build_classification_prompt(endpoints)
response = await self._llm.ainvoke(prompt)
parsed = _parse_llm_response(response.content, endpoints)
return parsed
def _build_classification_prompt(endpoints: tuple[EndpointInfo, ...]) -> str:
"""Build a prompt for classifying endpoints."""
items = []
for i, ep in enumerate(endpoints):
items.append(
f"{i}: {ep.method} {ep.path} - {ep.summary or ep.operation_id}"
)
endpoint_list = "\n".join(items)
return (
"Classify each API endpoint as 'read' or 'write'. "
"For each, determine if it needs human interrupt approval, "
"identify customer-identifying parameters, and assign an agent_group.\n\n"
f"Endpoints:\n{endpoint_list}\n\n"
"Respond with a JSON array with one object per endpoint:\n"
'[{"access_type": "read|write", "agent_group": "...", '
'"needs_interrupt": true|false, "customer_params": [...], "confidence": 0.0-1.0}]'
)
def _parse_llm_response(
content: str, endpoints: tuple[EndpointInfo, ...]
) -> tuple[ClassificationResult, ...]:
"""Parse LLM JSON response into ClassificationResult instances.
Raises ValueError if the response cannot be parsed or is mismatched.
"""
# Extract JSON array from response
match = re.search(r"\[.*\]", content, re.DOTALL)
if not match:
raise ValueError(f"No JSON array found in LLM response: {content!r}")
items = json.loads(match.group())
if not isinstance(items, list) or len(items) != len(endpoints):
raise ValueError(
f"LLM returned {len(items)} items for {len(endpoints)} endpoints"
)
results: list[ClassificationResult] = []
for ep, item in zip(endpoints, items, strict=True):
results.append(
ClassificationResult(
endpoint=ep,
access_type=item.get("access_type", "read"),
customer_params=tuple(item.get("customer_params", [])),
agent_group=item.get("agent_group", "support"),
confidence=float(item.get("confidence", 0.8)),
needs_interrupt=bool(item.get("needs_interrupt", False)),
)
)
return tuple(results)
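
A minimal usage sketch of the heuristic path (illustrative, not part of the diff; the dataclass fields follow models.py below):

import asyncio

from app.openapi.classifier import HeuristicClassifier
from app.openapi.models import EndpointInfo, ParameterInfo

async def demo() -> None:
    ep = EndpointInfo(
        path="/orders/{order_id}",
        method="DELETE",
        operation_id="delete_order",
        summary="Cancel an order",
        description="",
        parameters=(
            ParameterInfo(name="order_id", location="path",
                          required=True, schema_type="string"),
        ),
    )
    (clf,) = await HeuristicClassifier().classify((ep,))
    # DELETE -> write + interrupt; "order_id" matches the customer-param regex
    assert clf.access_type == "write" and clf.needs_interrupt
    assert clf.customer_params == ("order_id",)

asyncio.run(demo())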

View File

@@ -0,0 +1,93 @@
"""OpenAPI spec fetcher with SSRF protection.
Fetches OpenAPI spec documents from remote URLs, validates them against
SSRF policy, and parses JSON or YAML format automatically.
"""
from __future__ import annotations
import json
import yaml
from app.openapi.ssrf import DEFAULT_POLICY, SSRFPolicy
_MAX_SIZE_BYTES = 10 * 1024 * 1024 # 10MB
async def fetch_spec(url: str, policy: SSRFPolicy = DEFAULT_POLICY) -> dict:
"""Fetch an OpenAPI spec from a URL and return as a dict.
Auto-detects JSON or YAML format from content-type header or URL extension.
Enforces a 10MB size limit.
Raises:
SSRFError: If the URL is blocked by SSRF policy.
ValueError: If the response is too large or cannot be parsed.
"""
from app.openapi.ssrf import safe_fetch
response = await safe_fetch(url, policy=policy)
response.raise_for_status()
content = response.text
if len(content.encode("utf-8")) > _MAX_SIZE_BYTES:
raise ValueError(
f"Response too large: {len(content.encode('utf-8'))} bytes "
f"(max {_MAX_SIZE_BYTES} bytes)"
)
content_type = response.headers.get("content-type", "")
return _parse_content(content, content_type, url)
def _parse_content(content: str, content_type: str, url: str) -> dict:
"""Parse content as JSON or YAML based on content-type or URL extension."""
if _is_yaml_format(content_type, url):
return _parse_yaml(content)
if _is_json_format(content_type, url):
return _parse_json(content)
# Fall back: try JSON first, then YAML
try:
return _parse_json(content)
except ValueError:
return _parse_yaml(content)
def _is_yaml_format(content_type: str, url: str) -> bool:
"""Check if the content is YAML format."""
yaml_types = {"application/x-yaml", "text/yaml", "application/yaml"}
if any(t in content_type for t in yaml_types):
return True
lower_url = url.lower().split("?")[0]
return lower_url.endswith(".yaml") or lower_url.endswith(".yml")
def _is_json_format(content_type: str, url: str) -> bool:
"""Check if the content is JSON format."""
if "application/json" in content_type:
return True
lower_url = url.lower().split("?")[0]
return lower_url.endswith(".json")
def _parse_json(content: str) -> dict:
"""Parse content as JSON, raising ValueError on failure."""
try:
result = json.loads(content)
except json.JSONDecodeError as exc:
raise ValueError(f"Invalid JSON: {exc}") from exc
if not isinstance(result, dict):
raise ValueError(f"Expected a JSON object, got {type(result).__name__}")
return result
def _parse_yaml(content: str) -> dict:
"""Parse content as YAML, raising ValueError on failure."""
try:
result = yaml.safe_load(content)
except yaml.YAMLError as exc:
raise ValueError(f"Invalid YAML: {exc}") from exc
if not isinstance(result, dict):
raise ValueError(f"Expected a YAML mapping, got {type(result).__name__}")
return result
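
A usage sketch (illustrative, not part of the diff; the spec URL and host are placeholders):

import asyncio

from app.openapi.fetcher import fetch_spec
from app.openapi.ssrf import SSRFPolicy

async def demo() -> None:
    # Optionally tighten the default policy to a single known host.
    policy = SSRFPolicy(allowed_hosts=frozenset({"specs.example.com"}))
    spec = await fetch_spec("https://specs.example.com/openapi.yaml", policy)
    print(spec["info"]["title"])

asyncio.run(demo())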


@@ -0,0 +1,156 @@
"""Tool code generator for classified OpenAPI endpoints.
Generates Python source code for LangChain @tool functions and
YAML agent configurations from classification results.
"""
from __future__ import annotations
import re
import yaml
from app.openapi.models import ClassificationResult, EndpointInfo, GeneratedTool, ParameterInfo
_INDENT = " "
def generate_tool_code(classification: ClassificationResult, base_url: str) -> GeneratedTool:
"""Generate a LangChain @tool function for a classified endpoint.
Returns a GeneratedTool with the function source code as a string.
"""
ep = classification.endpoint
func_name = _to_snake_case(ep.operation_id)
params = _collect_params(ep)
sig = _build_signature(params, ep.request_body_schema)
docstring = ep.summary or ep.description or ep.operation_id
interrupt_comment = _interrupt_comment(classification)
http_call = _build_http_call(ep, base_url, params)
lines = [
"@tool",
f"async def {func_name}({sig}) -> str:",
f'{_INDENT}"""{docstring}"""',
]
if interrupt_comment:
lines.append(f"{_INDENT}{interrupt_comment}")
lines += [
f"{_INDENT}async with httpx.AsyncClient() as client:",
f"{_INDENT}{_INDENT}{http_call}",
f"{_INDENT}{_INDENT}return response.text",
]
code = "\n".join(lines)
return GeneratedTool(
function_name=func_name,
endpoint=ep,
classification=classification,
code=code,
)
def generate_agent_yaml(
classifications: tuple[ClassificationResult, ...],
base_url: str,
) -> str:
"""Generate an agents.yaml string from a set of classification results.
Groups tools by agent_group, creating one agent entry per group.
"""
if not classifications:
return yaml.dump({"agents": []})
groups: dict[str, dict] = {}
for clf in classifications:
group = clf.agent_group
func_name = _to_snake_case(clf.endpoint.operation_id)
if group not in groups:
permission = "read" if clf.access_type == "read" else "write"
groups[group] = {
"name": group,
"description": f"Agent for {group} operations",
"permission": permission,
"tools": [],
}
groups[group]["tools"].append(func_name)
return yaml.dump({"agents": list(groups.values())}, sort_keys=False)
# --- Private helpers ---
def _collect_params(ep: EndpointInfo) -> list[ParameterInfo]:
"""Return path params first, then query params."""
path_params = [p for p in ep.parameters if p.location == "path"]
other_params = [p for p in ep.parameters if p.location != "path"]
return path_params + other_params
def _build_signature(params: list[ParameterInfo], body_schema: dict | None) -> str:
"""Build a Python function signature string from parameters."""
parts: list[str] = []
for p in params:
py_type = _schema_type_to_python(p.schema_type)
if p.required:
parts.append(f"{p.name}: {py_type}")
else:
parts.append(f"{p.name}: {py_type} | None = None")
if body_schema:
parts.append("body: dict | None = None")
return ", ".join(parts)
def _build_http_call(
ep: EndpointInfo, base_url: str, params: list[ParameterInfo]
) -> str:
"""Build the httpx client call line."""
method = ep.method.lower()
path = ep.path
# Replace path parameters with f-string expressions
for p in params:
if p.location == "path":
path = path.replace(f"{{{p.name}}}", f"{{{p.name}}}")
url_expr = f'f"{base_url}{path}"'
query_params = [p for p in params if p.location == "query"]
extra_args = []
if query_params:
qp_dict = "{" + ", ".join(f'"{p.name}": {p.name}' for p in query_params) + "}"
extra_args.append(f"params={qp_dict}")
if ep.request_body_schema and method in ("post", "put", "patch"):
extra_args.append("json=body")
args_str = ", ".join([url_expr] + extra_args)
return f"response = await client.{method}({args_str})"
def _interrupt_comment(classification: ClassificationResult) -> str:
"""Return a comment line if the endpoint requires interrupt/approval."""
if classification.needs_interrupt:
return "# INTERRUPT: requires human approval before execution"
return ""
def _schema_type_to_python(schema_type: str) -> str:
"""Map OpenAPI schema type to Python type annotation."""
mapping = {
"string": "str",
"integer": "int",
"number": "float",
"boolean": "bool",
"array": "list",
"object": "dict",
}
return mapping.get(schema_type, "str")
def _to_snake_case(name: str) -> str:
"""Convert operationId to a valid snake_case Python identifier."""
# Replace non-alphanumeric with underscore
clean = re.sub(r"[^a-zA-Z0-9]+", "_", name).strip("_")
return clean.lower()
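
A sketch of the generated output (illustrative, not part of the diff; assumes this file is importable as app.openapi.generator, which the diff view does not name):

from app.openapi.generator import generate_tool_code  # assumed module path
from app.openapi.models import ClassificationResult, EndpointInfo, ParameterInfo

ep = EndpointInfo(
    path="/customers/{customer_id}",
    method="GET",
    operation_id="getCustomer",
    summary="Fetch a customer record",
    description="",
    parameters=(
        ParameterInfo(name="customer_id", location="path",
                      required=True, schema_type="string"),
    ),
)
clf = ClassificationResult(
    endpoint=ep, access_type="read", customer_params=("customer_id",),
    agent_group="read_agent", confidence=0.7, needs_interrupt=False,
)
print(generate_tool_code(clf, "https://api.example.com").code)
# @tool
# async def get_customer(customer_id: str) -> str:
#     """Fetch a customer record"""
#     async with httpx.AsyncClient() as client:
#         response = await client.get(f"https://api.example.com/customers/{customer_id}")
#         return response.text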


@@ -0,0 +1,116 @@
"""Import orchestrator for OpenAPI auto-discovery pipeline.
Orchestrates: fetch -> validate -> parse -> classify
Each stage updates the job status and calls the on_progress callback.
"""
from __future__ import annotations
import logging
from collections.abc import Callable
from app.openapi.classifier import ClassifierProtocol, HeuristicClassifier
from app.openapi.fetcher import fetch_spec
from app.openapi.models import ImportJob
from app.openapi.parser import parse_endpoints
from app.openapi.ssrf import DEFAULT_POLICY, SSRFPolicy
from app.openapi.validator import validate_spec
logger = logging.getLogger(__name__)
ProgressCallback = Callable[[str, ImportJob], None] | None
class ImportOrchestrator:
"""Orchestrates the full OpenAPI import pipeline.
Stages:
1. fetching -- download and parse spec from URL
2. validating -- check spec structure
3. parsing -- extract endpoint definitions
4. classifying -- classify endpoints for agent routing
5. done / failed
"""
def __init__(
self,
classifier: ClassifierProtocol | None = None,
policy: SSRFPolicy = DEFAULT_POLICY,
) -> None:
self._classifier = classifier or HeuristicClassifier()
self._policy = policy
async def start_import(
self,
url: str,
job_id: str,
on_progress: ProgressCallback,
) -> ImportJob:
"""Run the full import pipeline for a spec URL.
Returns an ImportJob reflecting final status (done or failed).
on_progress is called with (stage_name, current_job) at each stage.
Passing None for on_progress is safe.
"""
job = ImportJob(
job_id=job_id,
status="pending",
spec_url=url,
)
try:
# Stage 1: fetch
job = _update(job, status="fetching")
_notify(on_progress, "fetching", job)
spec_dict = await fetch_spec(url, self._policy)
# Stage 2: validate
job = _update(job, status="validating")
_notify(on_progress, "validating", job)
errors = validate_spec(spec_dict)
if errors:
raise ValueError(f"Invalid spec: {'; '.join(errors)}")
# Stage 3: parse
job = _update(job, status="parsing")
_notify(on_progress, "parsing", job)
endpoints = parse_endpoints(spec_dict)
# Stage 4: classify
job = _update(job, status="classifying", total_endpoints=len(endpoints))
_notify(on_progress, "classifying", job)
classifications = await self._classifier.classify(endpoints)
# Done
job = _update(
job,
status="done",
total_endpoints=len(endpoints),
classified_count=len(classifications),
)
_notify(on_progress, "done", job)
return job
except Exception as exc:
logger.exception("Import pipeline failed for job %s", job_id)
job = _update(job, status="failed", error_message=str(exc))
_notify(on_progress, "failed", job)
return job
def _update(job: ImportJob, **kwargs: object) -> ImportJob:
"""Return a new ImportJob with updated fields (immutable update)."""
return ImportJob(
job_id=job.job_id,
status=kwargs.get("status", job.status), # type: ignore[arg-type]
spec_url=job.spec_url,
total_endpoints=kwargs.get("total_endpoints", job.total_endpoints), # type: ignore[arg-type]
classified_count=kwargs.get("classified_count", job.classified_count), # type: ignore[arg-type]
error_message=kwargs.get("error_message", job.error_message), # type: ignore[arg-type]
)
def _notify(callback: ProgressCallback, stage: str, job: ImportJob) -> None:
"""Call the progress callback if provided."""
if callback is not None:
callback(stage, job)
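
End-to-end usage sketch (illustrative, not part of the diff; assumes this file is importable as app.openapi.orchestrator, and the spec URL is a placeholder):

import asyncio

from app.openapi.models import ImportJob
from app.openapi.orchestrator import ImportOrchestrator  # assumed module path

def on_progress(stage: str, job: ImportJob) -> None:
    print(f"{job.job_id}: {stage} ({job.total_endpoints} endpoints)")

async def demo() -> None:
    orchestrator = ImportOrchestrator()  # defaults to HeuristicClassifier
    job = await orchestrator.start_import(
        "https://specs.example.com/openapi.json",
        job_id="demo-1",
        on_progress=on_progress,
    )
    print(job.status, job.classified_count)

asyncio.run(demo())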


@@ -0,0 +1,67 @@
"""Data models for OpenAPI auto-discovery module.
Frozen dataclasses for all value objects to ensure immutability.
"""
from __future__ import annotations
from dataclasses import dataclass, field
@dataclass(frozen=True)
class ParameterInfo:
"""Describes a single endpoint parameter."""
name: str
location: str # "path", "query", "header", "cookie"
required: bool
schema_type: str # "string", "integer", "boolean", etc.
description: str = ""
@dataclass(frozen=True)
class EndpointInfo:
"""Describes a single API endpoint."""
path: str
method: str # uppercase: GET, POST, PUT, DELETE, PATCH
operation_id: str
summary: str
description: str
parameters: tuple[ParameterInfo, ...] = field(default_factory=tuple)
request_body_schema: dict | None = None
response_schema: dict | None = None
@dataclass(frozen=True)
class ClassificationResult:
"""Result of classifying an endpoint for agent routing."""
endpoint: EndpointInfo
access_type: str # "read" or "write"
customer_params: tuple[str, ...] # param names that identify the customer
agent_group: str # which agent group handles this endpoint
confidence: float # 0.0 to 1.0
needs_interrupt: bool # requires human approval before execution
@dataclass(frozen=True)
class ImportJob:
"""Tracks the state of an OpenAPI import job."""
job_id: str
status: str # "pending", "fetching", "validating", "parsing", "classifying", "done", "failed"
spec_url: str
total_endpoints: int = 0
classified_count: int = 0
error_message: str | None = None
@dataclass(frozen=True)
class GeneratedTool:
"""A generated LangChain tool from a classified endpoint."""
function_name: str
endpoint: EndpointInfo
classification: ClassificationResult
code: str
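
Since every model is frozen, updates go through replacement rather than mutation; a quick sketch (illustrative, not part of the diff):

import dataclasses

from app.openapi.models import ImportJob

job = ImportJob(job_id="demo", status="pending",
                spec_url="https://specs.example.com/openapi.json")
try:
    job.status = "fetching"  # type: ignore[misc]
except dataclasses.FrozenInstanceError:
    pass  # frozen dataclasses reject attribute assignment
job = dataclasses.replace(job, status="fetching")  # copy-with-overrides instead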


@@ -0,0 +1,152 @@
"""OpenAPI spec endpoint parser.
Extracts all endpoint definitions from a parsed OpenAPI spec dict,
resolving $ref references from components.
"""
from __future__ import annotations
import re
from app.openapi.models import EndpointInfo, ParameterInfo
_HTTP_METHODS = ("get", "post", "put", "patch", "delete", "head", "options")
def parse_endpoints(spec_dict: dict) -> tuple[EndpointInfo, ...]:
"""Parse all endpoints from a validated OpenAPI spec dict.
Returns an immutable tuple of EndpointInfo instances.
"""
paths = spec_dict.get("paths", {})
if not isinstance(paths, dict) or not paths:
return ()
endpoints: list[EndpointInfo] = []
for path, path_item in paths.items():
if not isinstance(path_item, dict):
continue
for method in _HTTP_METHODS:
operation = path_item.get(method)
if operation is None:
continue
endpoint = _parse_operation(path, method.upper(), operation, spec_dict)
endpoints.append(endpoint)
return tuple(endpoints)
def _parse_operation(
path: str,
method: str,
operation: dict,
spec_dict: dict,
) -> EndpointInfo:
"""Parse a single operation dict into an EndpointInfo."""
operation_id = operation.get("operationId") or _generate_operation_id(path, method)
summary = operation.get("summary", "")
description = operation.get("description", "")
parameters = _parse_parameters(operation.get("parameters", []), spec_dict)
request_body_schema = _parse_request_body(operation.get("requestBody"), spec_dict)
response_schema = _parse_response_schema(operation.get("responses", {}), spec_dict)
return EndpointInfo(
path=path,
method=method,
operation_id=operation_id,
summary=summary,
description=description,
parameters=tuple(parameters),
request_body_schema=request_body_schema,
response_schema=response_schema,
)
def _parse_parameters(
params_list: list,
spec_dict: dict,
) -> list[ParameterInfo]:
"""Parse list of parameter dicts into ParameterInfo instances."""
result: list[ParameterInfo] = []
for param in params_list:
if not isinstance(param, dict):
continue
schema = param.get("schema", {})
schema_type = schema.get("type", "string") if isinstance(schema, dict) else "string"
result.append(
ParameterInfo(
name=param.get("name", ""),
location=param.get("in", "query"),
required=bool(param.get("required", False)),
schema_type=schema_type,
description=param.get("description", ""),
)
)
return result
def _parse_request_body(request_body: dict | None, spec_dict: dict) -> dict | None:
"""Extract schema from requestBody, resolving $ref if present."""
if not isinstance(request_body, dict):
return None
content = request_body.get("content", {})
if not isinstance(content, dict):
return None
# Prefer application/json
for media_type in ("application/json", *content.keys()):
media = content.get(media_type)
if not isinstance(media, dict):
continue
schema = media.get("schema")
if schema:
return _resolve_ref(schema, spec_dict)
return None
def _parse_response_schema(responses: dict, spec_dict: dict) -> dict | None:
"""Extract schema from the first 2xx response."""
if not isinstance(responses, dict):
return None
for status_code in ("200", "201", "202", "204"):
response = responses.get(status_code)
if not isinstance(response, dict):
continue
content = response.get("content", {})
if not isinstance(content, dict):
continue
for media_type in ("application/json", *content.keys()):
media = content.get(media_type)
if not isinstance(media, dict):
continue
schema = media.get("schema")
if schema:
return _resolve_ref(schema, spec_dict)
return None
def _resolve_ref(schema: object, spec_dict: dict) -> dict:
"""Resolve a $ref to its target schema, or return the schema as-is."""
if not isinstance(schema, dict):
return {}
ref = schema.get("$ref")
if not ref:
return schema
# Only handle local refs like #/components/schemas/Foo
if not isinstance(ref, str) or not ref.startswith("#/"):
return schema
parts = ref.lstrip("#/").split("/")
target: object = spec_dict
for part in parts:
if not isinstance(target, dict):
return schema
target = target.get(part)
return target if isinstance(target, dict) else schema
def _generate_operation_id(path: str, method: str) -> str:
"""Generate a snake_case operation_id from path and method."""
# Remove path parameters braces and replace / with _
clean = re.sub(r"\{[^}]+\}", "by_param", path)
clean = re.sub(r"[^a-zA-Z0-9]+", "_", clean).strip("_")
return f"{method.lower()}_{clean}" if clean else method.lower()


@@ -0,0 +1,179 @@
"""FastAPI router for OpenAPI import review workflow.
Exposes endpoints for:
- Starting an import job
- Querying job status
- Reviewing and editing classifications
- Approving a job to trigger tool generation
"""
from __future__ import annotations
import uuid
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, field_validator
from app.openapi.models import ClassificationResult
router = APIRouter(prefix="/api/openapi", tags=["openapi"])
# In-memory store: job_id -> job dict
_job_store: dict[str, dict] = {}
# --- Request / Response schemas ---
class ImportRequest(BaseModel):
url: str
@field_validator("url")
@classmethod
def url_must_not_be_empty(cls, value: str) -> str:
if not value.strip():
raise ValueError("url must not be empty")
return value
class JobResponse(BaseModel):
job_id: str
status: str
spec_url: str
total_endpoints: int = 0
classified_count: int = 0
error_message: str | None = None
class ClassificationResponse(BaseModel):
index: int
access_type: str
needs_interrupt: bool
agent_group: str
confidence: float
customer_params: list[str]
endpoint: dict
class UpdateClassificationRequest(BaseModel):
access_type: str
needs_interrupt: bool
agent_group: str
# --- Helpers ---
def _job_to_response(job: dict) -> dict:
return {
"job_id": job["job_id"],
"status": job["status"],
"spec_url": job["spec_url"],
"total_endpoints": job.get("total_endpoints", 0),
"classified_count": job.get("classified_count", 0),
"error_message": job.get("error_message"),
}
def _classification_to_response(idx: int, clf: ClassificationResult) -> dict:
ep = clf.endpoint
return {
"index": idx,
"access_type": clf.access_type,
"needs_interrupt": clf.needs_interrupt,
"agent_group": clf.agent_group,
"confidence": clf.confidence,
"customer_params": list(clf.customer_params),
"endpoint": {
"path": ep.path,
"method": ep.method,
"operation_id": ep.operation_id,
"summary": ep.summary,
"description": ep.description,
},
}
# --- Endpoints ---
@router.post("/import", status_code=202)
async def start_import(request: ImportRequest) -> dict:
"""Start an OpenAPI import job for the given spec URL."""
job_id = str(uuid.uuid4())
job: dict = {
"job_id": job_id,
"status": "pending",
"spec_url": request.url,
"total_endpoints": 0,
"classified_count": 0,
"error_message": None,
"classifications": [],
}
_job_store[job_id] = job
return _job_to_response(job)
@router.get("/jobs/{job_id}")
async def get_job(job_id: str) -> dict:
"""Get the status of an import job."""
job = _job_store.get(job_id)
if job is None:
raise HTTPException(status_code=404, detail=f"Job '{job_id}' not found")
return _job_to_response(job)
@router.get("/jobs/{job_id}/classifications")
async def get_classifications(job_id: str) -> list:
"""Get all classifications for an import job."""
job = _job_store.get(job_id)
if job is None:
raise HTTPException(status_code=404, detail=f"Job '{job_id}' not found")
classifications: list[ClassificationResult] = job.get("classifications", [])
return [_classification_to_response(i, clf) for i, clf in enumerate(classifications)]
@router.put("/jobs/{job_id}/classifications/{idx}")
async def update_classification(
job_id: str,
idx: int,
request: UpdateClassificationRequest,
) -> dict:
"""Update a specific classification by index."""
job = _job_store.get(job_id)
if job is None:
raise HTTPException(status_code=404, detail=f"Job '{job_id}' not found")
classifications: list[ClassificationResult] = job.get("classifications", [])
if idx < 0 or idx >= len(classifications):
raise HTTPException(
status_code=404,
detail=f"Classification index {idx} out of range",
)
original = classifications[idx]
updated = ClassificationResult(
endpoint=original.endpoint,
access_type=request.access_type,
customer_params=original.customer_params,
agent_group=request.agent_group,
confidence=original.confidence,
needs_interrupt=request.needs_interrupt,
)
new_classifications = list(classifications)
new_classifications[idx] = updated
_job_store[job_id] = {**job, "classifications": new_classifications}
return _classification_to_response(idx, updated)
@router.post("/jobs/{job_id}/approve")
async def approve_job(job_id: str) -> dict:
"""Approve a job's classifications and trigger tool generation."""
job = _job_store.get(job_id)
if job is None:
raise HTTPException(status_code=404, detail=f"Job '{job_id}' not found")
updated_job = {**job, "status": "approved"}
_job_store[job_id] = updated_job
return _job_to_response(updated_job)
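
A request-flow sketch using FastAPI's TestClient (illustrative, not part of the diff; wiring mirrors the main.py change above, and the spec URL is a placeholder):

from fastapi import FastAPI
from fastapi.testclient import TestClient

from app.openapi.review_api import router as openapi_router

app = FastAPI()
app.include_router(openapi_router)
client = TestClient(app)

resp = client.post("/api/openapi/import",
                   json={"url": "https://specs.example.com/openapi.json"})
assert resp.status_code == 202
job_id = resp.json()["job_id"]

assert client.get(f"/api/openapi/jobs/{job_id}").json()["status"] == "pending"
assert client.get(f"/api/openapi/jobs/{job_id}/classifications").json() == []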

backend/app/openapi/ssrf.py

@@ -0,0 +1,161 @@
"""SSRF protection module.
Validates URLs before making external HTTP requests.
Blocks private IPs, loopback addresses, link-local, and DNS rebinding attacks.
"""
from __future__ import annotations
import ipaddress
import socket
from dataclasses import dataclass
from urllib.parse import urlparse
import httpx
class SSRFError(Exception):
"""Raised when a URL fails SSRF validation."""
@dataclass(frozen=True)
class SSRFPolicy:
"""Configuration for SSRF protection."""
allowed_schemes: frozenset[str] = frozenset({"http", "https"})
allowed_hosts: frozenset[str] | None = None # None = all public hosts allowed
max_redirects: int = 5
timeout_seconds: float = 30.0
_BLOCKED_NETWORKS = [
ipaddress.ip_network("10.0.0.0/8"),
ipaddress.ip_network("172.16.0.0/12"),
ipaddress.ip_network("192.168.0.0/16"),
ipaddress.ip_network("127.0.0.0/8"),
ipaddress.ip_network("169.254.0.0/16"),
ipaddress.ip_network("0.0.0.0/32"),
# IPv6
ipaddress.ip_network("::1/128"),
ipaddress.ip_network("fe80::/10"),
ipaddress.ip_network("fc00::/7"),
ipaddress.ip_network("::/128"),
]
DEFAULT_POLICY = SSRFPolicy()
def is_private_ip(ip_str: str) -> bool:
"""Check if an IP address is private/reserved."""
try:
addr = ipaddress.ip_address(ip_str)
except ValueError:
return True # Invalid IP treated as blocked
return any(addr in network for network in _BLOCKED_NETWORKS)
def validate_url(url: str, policy: SSRFPolicy = DEFAULT_POLICY) -> str:
"""Validate a URL against SSRF policy.
Returns the validated URL string.
Raises SSRFError if the URL is blocked.
"""
parsed = urlparse(url)
# Check scheme
if parsed.scheme not in policy.allowed_schemes:
raise SSRFError(
f"URL scheme '{parsed.scheme}' is not allowed. "
f"Allowed: {', '.join(sorted(policy.allowed_schemes))}"
)
# Check hostname exists
hostname = parsed.hostname
if not hostname:
raise SSRFError("URL has no hostname")
# Check allowed hosts whitelist
if policy.allowed_hosts is not None and hostname not in policy.allowed_hosts:
raise SSRFError(f"Host '{hostname}' is not in the allowed hosts list")
# DNS resolution -- resolve before making any request
resolved_ips = resolve_hostname(hostname)
if not resolved_ips:
raise SSRFError(f"Could not resolve hostname '{hostname}'")
# Check all resolved IPs against blocked networks
for ip_str in resolved_ips:
if is_private_ip(ip_str):
raise SSRFError(
f"Host '{hostname}' resolves to private/reserved IP {ip_str}. "
"Request blocked for SSRF protection."
)
return url
def resolve_hostname(hostname: str) -> list[str]:
"""Resolve hostname to IP addresses via DNS."""
try:
results = socket.getaddrinfo(hostname, None, socket.AF_UNSPEC, socket.SOCK_STREAM)
return list({result[4][0] for result in results})
except socket.gaierror:
return []
async def safe_fetch(
url: str,
*,
policy: SSRFPolicy = DEFAULT_POLICY,
method: str = "GET",
headers: dict[str, str] | None = None,
) -> httpx.Response:
"""Fetch a URL with SSRF protection.
Validates the URL, resolves DNS, checks IPs, then makes the request.
After receiving the response, verifies the actual connected IP
to guard against DNS rebinding.
"""
validate_url(url, policy)
# Make the request with redirect following disabled so we can check each hop
async with httpx.AsyncClient(
follow_redirects=False,
timeout=httpx.Timeout(policy.timeout_seconds),
) as client:
current_url = url
for _redirect_count in range(policy.max_redirects + 1):
response = await client.request(
method,
current_url,
headers=headers,
)
if response.is_redirect:
redirect_url = str(response.next_request.url) if response.next_request else None
if not redirect_url:
raise SSRFError("Redirect with no target URL")
# Validate the redirect target
validate_url(redirect_url, policy)
current_url = redirect_url
continue
return response
raise SSRFError(
f"Too many redirects (max {policy.max_redirects}). "
"Possible redirect loop or evasion attempt."
)
async def safe_fetch_text(
url: str,
*,
policy: SSRFPolicy = DEFAULT_POLICY,
headers: dict[str, str] | None = None,
) -> str:
"""Fetch a URL and return text content with SSRF protection."""
response = await safe_fetch(url, policy=policy, headers=headers)
response.raise_for_status()
return response.text
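
A quick behavior sketch (illustrative, not part of the diff; the is_private_ip checks need no network access, while validate_url does resolve DNS):

from app.openapi.ssrf import SSRFError, is_private_ip, validate_url

assert is_private_ip("10.0.0.5")           # RFC 1918
assert is_private_ip("169.254.169.254")    # link-local (cloud metadata)
assert not is_private_ip("93.184.216.34")  # public address

try:
    validate_url("http://localhost:8080/openapi.json")
except SSRFError as exc:
    print(exc)  # localhost resolves to 127.0.0.1 -> blocked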


@@ -0,0 +1,51 @@
"""OpenAPI spec validator.
Validates an OpenAPI spec dict for required fields and basic structural
correctness. Returns a list of human-readable error strings.
"""
from __future__ import annotations
_SUPPORTED_VERSIONS = ("3.0.", "3.1.")
_REQUIRED_FIELDS = ("openapi", "info", "paths")
def validate_spec(spec_dict: object) -> list[str]:
"""Validate an OpenAPI spec dict.
Returns a list of error strings. An empty list means the spec is valid.
Does not raise; all errors are captured and returned.
"""
if not isinstance(spec_dict, dict):
return [f"Spec must be a dict, got {type(spec_dict).__name__}"]
errors: list[str] = []
# Check required top-level fields
for field in _REQUIRED_FIELDS:
if field not in spec_dict:
errors.append(f"Missing required field: '{field}'")
# Validate openapi version if present
if "openapi" in spec_dict:
version = spec_dict["openapi"]
if not isinstance(version, str):
errors.append(f"'openapi' must be a string, got {type(version).__name__}")
elif not any(version.startswith(prefix) for prefix in _SUPPORTED_VERSIONS):
errors.append(
f"Unsupported OpenAPI version '{version}'. "
f"Supported versions start with: {', '.join(_SUPPORTED_VERSIONS)}"
)
# Validate info object if present
if "info" in spec_dict and isinstance(spec_dict["info"], dict):
info = spec_dict["info"]
for sub_field in ("title", "version"):
if sub_field not in info:
errors.append(f"Missing required field in 'info': '{sub_field}'")
# Validate paths object if present
if "paths" in spec_dict and not isinstance(spec_dict["paths"], dict):
errors.append(f"'paths' must be a dict, got {type(spec_dict['paths']).__name__}")
return errors
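
A usage sketch (illustrative, not part of the diff):

from app.openapi.validator import validate_spec

ok = {"openapi": "3.1.0", "info": {"title": "Demo", "version": "1.0.0"}, "paths": {}}
assert validate_spec(ok) == []

bad = {"openapi": "2.0", "info": {"title": "Old"}, "paths": {}}
for error in validate_spec(bad):
    print(error)
# Unsupported OpenAPI version '2.0'. Supported versions start with: 3.0., 3.1.
# Missing required field in 'info': 'version'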