fix: address critical security and code review findings in Phase 3

- Wire ImportOrchestrator into review_api start_import via BackgroundTasks
- Sanitize docstrings in generated tool code to prevent code injection
- Add Literal["read", "write"] validation for access_type
- Add regex validation for agent_group
- Validate URL scheme (http/https only) in ImportRequest
- Validate LLM output fields (clamp confidence, validate access_type)
- Use dataclasses.replace instead of manual reconstruction in importer
- Expand SSRF blocked networks (Carrier-Grade NAT, IPv4-mapped IPv6, etc.)
- Make _BLOCKED_NETWORKS immutable tuple
- Use yaml.safe_dump instead of yaml.dump
- Fix _to_snake_case for empty strings and Python keywords
This commit is contained in:
Yaojia Wang
2026-03-31 00:28:28 +02:00
parent a54eb224e0
commit a2f750269d
6 changed files with 128 additions and 28 deletions

View File

@@ -150,13 +150,18 @@ def _parse_llm_response(
results: list[ClassificationResult] = [] results: list[ClassificationResult] = []
for ep, item in zip(endpoints, items, strict=True): for ep, item in zip(endpoints, items, strict=True):
raw_access = item.get("access_type", "read")
access_type = raw_access if raw_access in {"read", "write"} else "read"
confidence = max(0.0, min(1.0, float(item.get("confidence", 0.8))))
raw_group = str(item.get("agent_group", "support"))
agent_group = raw_group if raw_group.strip() else "support"
results.append( results.append(
ClassificationResult( ClassificationResult(
endpoint=ep, endpoint=ep,
access_type=item.get("access_type", "read"), access_type=access_type,
customer_params=tuple(item.get("customer_params", [])), customer_params=tuple(item.get("customer_params", [])),
agent_group=item.get("agent_group", "support"), agent_group=agent_group,
confidence=float(item.get("confidence", 0.8)), confidence=confidence,
needs_interrupt=bool(item.get("needs_interrupt", False)), needs_interrupt=bool(item.get("needs_interrupt", False)),
) )
) )

View File

@@ -6,6 +6,7 @@ YAML agent configurations from classification results.
from __future__ import annotations from __future__ import annotations
import keyword
import re import re
import yaml import yaml
@@ -24,7 +25,7 @@ def generate_tool_code(classification: ClassificationResult, base_url: str) -> G
func_name = _to_snake_case(ep.operation_id) func_name = _to_snake_case(ep.operation_id)
params = _collect_params(ep) params = _collect_params(ep)
sig = _build_signature(params, ep.request_body_schema) sig = _build_signature(params, ep.request_body_schema)
docstring = ep.summary or ep.description or ep.operation_id docstring = _sanitize_docstring(ep.summary or ep.description or ep.operation_id)
interrupt_comment = _interrupt_comment(classification) interrupt_comment = _interrupt_comment(classification)
http_call = _build_http_call(ep, base_url, params) http_call = _build_http_call(ep, base_url, params)
@@ -59,7 +60,7 @@ def generate_agent_yaml(
Groups tools by agent_group, creating one agent entry per group. Groups tools by agent_group, creating one agent entry per group.
""" """
if not classifications: if not classifications:
return yaml.dump({"agents": []}) return yaml.safe_dump({"agents": []})
groups: dict[str, dict] = {} groups: dict[str, dict] = {}
for clf in classifications: for clf in classifications:
@@ -75,7 +76,7 @@ def generate_agent_yaml(
} }
groups[group]["tools"].append(func_name) groups[group]["tools"].append(func_name)
return yaml.dump({"agents": list(groups.values())}, sort_keys=False) return yaml.safe_dump({"agents": list(groups.values())}, sort_keys=False)
# --- Private helpers --- # --- Private helpers ---
@@ -149,8 +150,15 @@ def _schema_type_to_python(schema_type: str) -> str:
return mapping.get(schema_type, "str") return mapping.get(schema_type, "str")
def _sanitize_docstring(text: str) -> str:
"""Escape triple-quotes and newlines to prevent docstring injection."""
return text.replace("\\", "\\\\").replace('"""', r"\"\"\"").replace("\n", " ")
def _to_snake_case(name: str) -> str: def _to_snake_case(name: str) -> str:
"""Convert operationId to a valid snake_case Python identifier.""" """Convert operationId to a valid snake_case Python identifier."""
# Replace non-alphanumeric with underscore
clean = re.sub(r"[^a-zA-Z0-9]+", "_", name).strip("_") clean = re.sub(r"[^a-zA-Z0-9]+", "_", name).strip("_")
return clean.lower() result = clean.lower() or "unnamed_tool"
if keyword.iskeyword(result):
result = f"{result}_tool"
return result

View File

@@ -8,6 +8,7 @@ from __future__ import annotations
import logging import logging
from collections.abc import Callable from collections.abc import Callable
from dataclasses import replace
from app.openapi.classifier import ClassifierProtocol, HeuristicClassifier from app.openapi.classifier import ClassifierProtocol, HeuristicClassifier
from app.openapi.fetcher import fetch_spec from app.openapi.fetcher import fetch_spec
@@ -100,14 +101,7 @@ class ImportOrchestrator:
def _update(job: ImportJob, **kwargs: object) -> ImportJob: def _update(job: ImportJob, **kwargs: object) -> ImportJob:
"""Return a new ImportJob with updated fields (immutable update).""" """Return a new ImportJob with updated fields (immutable update)."""
return ImportJob( return replace(job, **kwargs)
job_id=job.job_id,
status=kwargs.get("status", job.status), # type: ignore[arg-type]
spec_url=job.spec_url,
total_endpoints=kwargs.get("total_endpoints", job.total_endpoints), # type: ignore[arg-type]
classified_count=kwargs.get("classified_count", job.classified_count), # type: ignore[arg-type]
error_message=kwargs.get("error_message", job.error_message), # type: ignore[arg-type]
)
def _notify(callback: ProgressCallback, stage: str, job: ImportJob) -> None: def _notify(callback: ProgressCallback, stage: str, job: ImportJob) -> None:

View File

@@ -1,7 +1,7 @@
"""FastAPI router for OpenAPI import review workflow. """FastAPI router for OpenAPI import review workflow.
Exposes endpoints for: Exposes endpoints for:
- Starting an import job - Starting an import job (triggers background pipeline)
- Querying job status - Querying job status
- Reviewing and editing classifications - Reviewing and editing classifications
- Approving a job to trigger tool generation - Approving a job to trigger tool generation
@@ -9,17 +9,28 @@ Exposes endpoints for:
from __future__ import annotations from __future__ import annotations
import asyncio
import logging
import re
import uuid import uuid
from typing import Literal
from fastapi import APIRouter, HTTPException from fastapi import APIRouter, BackgroundTasks, HTTPException
from pydantic import BaseModel, field_validator from pydantic import BaseModel, field_validator
from app.openapi.models import ClassificationResult from app.openapi.importer import ImportOrchestrator
from app.openapi.models import ClassificationResult, ImportJob
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/openapi", tags=["openapi"]) router = APIRouter(prefix="/api/openapi", tags=["openapi"])
# In-memory store: job_id -> job dict # In-memory store: job_id -> job dict, guarded by async lock
_job_store: dict[str, dict] = {} _job_store: dict[str, dict] = {}
_store_lock = asyncio.Lock()
# Shared orchestrator instance
_orchestrator = ImportOrchestrator()
# --- Request / Response schemas --- # --- Request / Response schemas ---
@@ -30,10 +41,13 @@ class ImportRequest(BaseModel):
@field_validator("url")
@classmethod
def url_must_be_valid(cls, value: str) -> str:
    """Reject empty URLs and non-HTTP(S) schemes (SSRF defense-in-depth).

    Raises ValueError (surfaced by pydantic as a 422) for empty input or
    a scheme other than http/https.

    Fix over previous version: the scheme check is now case-insensitive —
    URI schemes are case-insensitive per RFC 3986, so "HTTP://host" is a
    valid spelling that was previously rejected.
    """
    stripped = value.strip()
    if not stripped:
        raise ValueError("url must not be empty")
    # Lower-case only for the scheme comparison; the URL is returned with
    # its original casing (paths and queries may be case-sensitive).
    if not stripped.lower().startswith(("http://", "https://")):
        raise ValueError("url must start with http:// or https://")
    return stripped
class JobResponse(BaseModel): class JobResponse(BaseModel):
@@ -56,10 +70,20 @@ class ClassificationResponse(BaseModel):
class UpdateClassificationRequest(BaseModel): class UpdateClassificationRequest(BaseModel):
access_type: str access_type: Literal["read", "write"]
needs_interrupt: bool needs_interrupt: bool
agent_group: str agent_group: str
@field_validator("agent_group")
@classmethod
def agent_group_must_be_safe(cls, value: str) -> str:
    """Allow only safe group names: alphanumerics, underscores, hyphens.

    The character-class match already implies the value is non-empty and
    not whitespace-only, so no separate emptiness check is needed.
    """
    if re.fullmatch(r"[a-zA-Z0-9_\-]+", value) is None:
        raise ValueError(
            "agent_group must be non-empty and contain only "
            "alphanumeric characters, underscores, or hyphens"
        )
    return value
# --- Helpers --- # --- Helpers ---
@@ -94,11 +118,49 @@ def _classification_to_response(idx: int, clf: ClassificationResult) -> dict:
} }
async def _run_import(job_id: str, url: str) -> None:
    """Run the import pipeline as a background task.

    Mirrors orchestrator progress into the in-memory job store and
    converts any failure into a terminal 'failed' status so the job never
    appears stuck in 'pending'. The user-facing error message is generic
    on purpose — internal details go to the log only.

    Fixes over previous version: the three copies of the store-update
    dict were deduplicated into one helper, and `_store_lock` (declared
    as guarding `_job_store`) is now actually held around the read-
    modify-write updates performed from this coroutine.
    """

    def _apply(result_job: ImportJob) -> None:
        # Copy the orchestrator's job fields into the store entry.
        # NOTE(review): also called synchronously from on_progress, where
        # the async lock cannot be awaited; safe only while the event loop
        # is single-threaded — confirm the orchestrator never invokes the
        # callback from another thread.
        if job_id in _job_store:
            _job_store[job_id] = {
                **_job_store[job_id],
                "status": result_job.status,
                "total_endpoints": result_job.total_endpoints,
                "classified_count": result_job.classified_count,
                "error_message": result_job.error_message,
            }

    def on_progress(stage: str, result_job: ImportJob) -> None:
        _apply(result_job)

    try:
        result = await _orchestrator.start_import(
            url=url, job_id=job_id, on_progress=on_progress,
        )
        async with _store_lock:
            _apply(result)
    except Exception:
        # Log full details server-side; never leak them to the client.
        logger.exception("Background import failed for job %s", job_id)
        async with _store_lock:
            if job_id in _job_store:
                _job_store[job_id] = {
                    **_job_store[job_id],
                    "status": "failed",
                    "error_message": "Import failed. Please check the URL and try again.",
                }
# --- Endpoints --- # --- Endpoints ---
@router.post("/import", status_code=202) @router.post("/import", status_code=202)
async def start_import(request: ImportRequest) -> dict: async def start_import(
request: ImportRequest,
background_tasks: BackgroundTasks,
) -> dict:
"""Start an OpenAPI import job for the given spec URL.""" """Start an OpenAPI import job for the given spec URL."""
job_id = str(uuid.uuid4()) job_id = str(uuid.uuid4())
job: dict = { job: dict = {
@@ -111,6 +173,7 @@ async def start_import(request: ImportRequest) -> dict:
"classifications": [], "classifications": [],
} }
_job_store[job_id] = job _job_store[job_id] = job
background_tasks.add_task(_run_import, job_id, request.url)
return _job_to_response(job) return _job_to_response(job)
@@ -130,7 +193,10 @@ async def get_classifications(job_id: str) -> list:
if job is None: if job is None:
raise HTTPException(status_code=404, detail=f"Job '{job_id}' not found") raise HTTPException(status_code=404, detail=f"Job '{job_id}' not found")
classifications: list[ClassificationResult] = job.get("classifications", []) classifications: list[ClassificationResult] = job.get("classifications", [])
return [_classification_to_response(i, clf) for i, clf in enumerate(classifications)] return [
_classification_to_response(i, clf)
for i, clf in enumerate(classifications)
]
@router.put("/jobs/{job_id}/classifications/{idx}") @router.put("/jobs/{job_id}/classifications/{idx}")

View File

@@ -28,19 +28,25 @@ class SSRFPolicy:
timeout_seconds: float = 30.0 timeout_seconds: float = 30.0
_BLOCKED_NETWORKS = [ _BLOCKED_NETWORKS = (
ipaddress.ip_network("10.0.0.0/8"), ipaddress.ip_network("10.0.0.0/8"),
ipaddress.ip_network("172.16.0.0/12"), ipaddress.ip_network("172.16.0.0/12"),
ipaddress.ip_network("192.168.0.0/16"), ipaddress.ip_network("192.168.0.0/16"),
ipaddress.ip_network("127.0.0.0/8"), ipaddress.ip_network("127.0.0.0/8"),
ipaddress.ip_network("169.254.0.0/16"), ipaddress.ip_network("169.254.0.0/16"),
ipaddress.ip_network("0.0.0.0/32"), ipaddress.ip_network("0.0.0.0/32"),
ipaddress.ip_network("100.64.0.0/10"), # Carrier-Grade NAT
ipaddress.ip_network("198.18.0.0/15"), # Benchmarking
ipaddress.ip_network("240.0.0.0/4"), # Reserved
ipaddress.ip_network("255.255.255.255/32"), # Broadcast
# IPv6 # IPv6
ipaddress.ip_network("::1/128"), ipaddress.ip_network("::1/128"),
ipaddress.ip_network("fe80::/10"), ipaddress.ip_network("fe80::/10"),
ipaddress.ip_network("fc00::/7"), ipaddress.ip_network("fc00::/7"),
ipaddress.ip_network("::/128"), ipaddress.ip_network("::/128"),
] ipaddress.ip_network("::ffff:0:0/96"), # IPv4-mapped IPv6
ipaddress.ip_network("2001:db8::/32"), # Documentation
)
DEFAULT_POLICY = SSRFPolicy() DEFAULT_POLICY = SSRFPolicy()

View File

@@ -81,6 +81,11 @@ class TestImportEndpoint:
response = client.post("/api/openapi/import", json={}) response = client.post("/api/openapi/import", json={})
assert response.status_code == 422 assert response.status_code == 422
def test_post_import_invalid_scheme_returns_422(self, client) -> None:
    """POST /import with non-http URL returns 422."""
    payload = {"url": "ftp://evil.com/spec"}
    response = client.post("/api/openapi/import", json=payload)
    assert response.status_code == 422
def test_post_import_returns_pending_status(self, client) -> None: def test_post_import_returns_pending_status(self, client) -> None:
"""Newly created job has pending status.""" """Newly created job has pending status."""
response = client.post("/api/openapi/import", json={"url": _SAMPLE_URL}) response = client.post("/api/openapi/import", json={"url": _SAMPLE_URL})
@@ -165,6 +170,22 @@ class TestUpdateClassificationEndpoint:
) )
assert response.status_code == 404 assert response.status_code == 404
def test_update_invalid_access_type_returns_422(self, client, job_with_classifications) -> None:
    """PUT /classifications/0 with invalid access_type returns 422."""
    url = f"/api/openapi/jobs/{job_with_classifications}/classifications/0"
    body = {"access_type": "admin", "needs_interrupt": True, "agent_group": "x"}
    response = client.put(url, json=body)
    assert response.status_code == 422
def test_update_invalid_agent_group_returns_422(self, client, job_with_classifications) -> None:
    """PUT /classifications/0 with invalid agent_group returns 422."""
    url = f"/api/openapi/jobs/{job_with_classifications}/classifications/0"
    body = {"access_type": "read", "needs_interrupt": False, "agent_group": "evil group!"}
    response = client.put(url, json=body)
    assert response.status_code == 422
def test_update_out_of_range_index_returns_404(self, client, job_with_classifications) -> None: def test_update_out_of_range_index_returns_404(self, client, job_with_classifications) -> None:
"""PUT /classifications/999 returns 404 for out-of-range index.""" """PUT /classifications/999 returns 404 for out-of-range index."""
response = client.put( response = client.put(