fix: address critical security and code review findings in Phase 3

- Wire ImportOrchestrator into review_api start_import via BackgroundTasks
- Sanitize docstrings in generated tool code to prevent code injection
- Add Literal["read", "write"] validation for access_type
- Add regex validation for agent_group
- Validate URL scheme (http/https only) in ImportRequest
- Validate LLM output fields (clamp confidence, validate access_type)
- Use dataclasses.replace instead of manual reconstruction in importer
- Expand SSRF blocked networks (Carrier-Grade NAT, IPv4-mapped IPv6, etc.)
- Make _BLOCKED_NETWORKS immutable tuple
- Use yaml.safe_dump instead of yaml.dump
- Fix _to_snake_case for empty strings and Python keywords
This commit is contained in:
Yaojia Wang
2026-03-31 00:28:28 +02:00
parent a54eb224e0
commit a2f750269d
6 changed files with 128 additions and 28 deletions

View File

@@ -150,13 +150,18 @@ def _parse_llm_response(
results: list[ClassificationResult] = []
for ep, item in zip(endpoints, items, strict=True):
raw_access = item.get("access_type", "read")
access_type = raw_access if raw_access in {"read", "write"} else "read"
confidence = max(0.0, min(1.0, float(item.get("confidence", 0.8))))
raw_group = str(item.get("agent_group", "support"))
agent_group = raw_group if raw_group.strip() else "support"
results.append(
ClassificationResult(
endpoint=ep,
access_type=item.get("access_type", "read"),
access_type=access_type,
customer_params=tuple(item.get("customer_params", [])),
agent_group=item.get("agent_group", "support"),
confidence=float(item.get("confidence", 0.8)),
agent_group=agent_group,
confidence=confidence,
needs_interrupt=bool(item.get("needs_interrupt", False)),
)
)

View File

@@ -6,6 +6,7 @@ YAML agent configurations from classification results.
from __future__ import annotations
import keyword
import re
import yaml
@@ -24,7 +25,7 @@ def generate_tool_code(classification: ClassificationResult, base_url: str) -> G
func_name = _to_snake_case(ep.operation_id)
params = _collect_params(ep)
sig = _build_signature(params, ep.request_body_schema)
docstring = ep.summary or ep.description or ep.operation_id
docstring = _sanitize_docstring(ep.summary or ep.description or ep.operation_id)
interrupt_comment = _interrupt_comment(classification)
http_call = _build_http_call(ep, base_url, params)
@@ -59,7 +60,7 @@ def generate_agent_yaml(
Groups tools by agent_group, creating one agent entry per group.
"""
if not classifications:
return yaml.dump({"agents": []})
return yaml.safe_dump({"agents": []})
groups: dict[str, dict] = {}
for clf in classifications:
@@ -75,7 +76,7 @@ def generate_agent_yaml(
}
groups[group]["tools"].append(func_name)
return yaml.dump({"agents": list(groups.values())}, sort_keys=False)
return yaml.safe_dump({"agents": list(groups.values())}, sort_keys=False)
# --- Private helpers ---
@@ -149,8 +150,15 @@ def _schema_type_to_python(schema_type: str) -> str:
return mapping.get(schema_type, "str")
def _sanitize_docstring(text: str) -> str:
"""Escape triple-quotes and newlines to prevent docstring injection."""
return text.replace("\\", "\\\\").replace('"""', r"\"\"\"").replace("\n", " ")
def _to_snake_case(name: str) -> str:
"""Convert operationId to a valid snake_case Python identifier."""
# Replace non-alphanumeric with underscore
clean = re.sub(r"[^a-zA-Z0-9]+", "_", name).strip("_")
return clean.lower()
result = clean.lower() or "unnamed_tool"
if keyword.iskeyword(result):
result = f"{result}_tool"
return result

View File

@@ -8,6 +8,7 @@ from __future__ import annotations
import logging
from collections.abc import Callable
from dataclasses import replace
from app.openapi.classifier import ClassifierProtocol, HeuristicClassifier
from app.openapi.fetcher import fetch_spec
@@ -100,14 +101,7 @@ class ImportOrchestrator:
def _update(job: ImportJob, **kwargs: object) -> ImportJob:
"""Return a new ImportJob with updated fields (immutable update)."""
return ImportJob(
job_id=job.job_id,
status=kwargs.get("status", job.status), # type: ignore[arg-type]
spec_url=job.spec_url,
total_endpoints=kwargs.get("total_endpoints", job.total_endpoints), # type: ignore[arg-type]
classified_count=kwargs.get("classified_count", job.classified_count), # type: ignore[arg-type]
error_message=kwargs.get("error_message", job.error_message), # type: ignore[arg-type]
)
return replace(job, **kwargs)
def _notify(callback: ProgressCallback, stage: str, job: ImportJob) -> None:

View File

@@ -1,7 +1,7 @@
"""FastAPI router for OpenAPI import review workflow.
Exposes endpoints for:
- Starting an import job
- Starting an import job (triggers background pipeline)
- Querying job status
- Reviewing and editing classifications
- Approving a job to trigger tool generation
@@ -9,17 +9,28 @@ Exposes endpoints for:
from __future__ import annotations
import asyncio
import logging
import re
import uuid
from typing import Literal
from fastapi import APIRouter, HTTPException
from fastapi import APIRouter, BackgroundTasks, HTTPException
from pydantic import BaseModel, field_validator
from app.openapi.models import ClassificationResult
from app.openapi.importer import ImportOrchestrator
from app.openapi.models import ClassificationResult, ImportJob
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/openapi", tags=["openapi"])
# In-memory store: job_id -> job dict
# In-memory store: job_id -> job dict, guarded by async lock
_job_store: dict[str, dict] = {}
_store_lock = asyncio.Lock()
# Shared orchestrator instance
_orchestrator = ImportOrchestrator()
# --- Request / Response schemas ---
@@ -30,10 +41,13 @@ class ImportRequest(BaseModel):
@field_validator("url")
@classmethod
def url_must_be_valid(cls, value: str) -> str:
    """Reject empty URLs and non-http(s) schemes (e.g. ftp://, file://).

    Returns the stripped URL so downstream code never sees surrounding
    whitespace.
    """
    stripped = value.strip()
    if not stripped:
        raise ValueError("url must not be empty")
    # Only http/https are fetchable by the importer; other schemes are a
    # likely SSRF / local-file-read vector.
    if not stripped.startswith(("http://", "https://")):
        raise ValueError("url must start with http:// or https://")
    return stripped
class JobResponse(BaseModel):
@@ -56,10 +70,20 @@ class ClassificationResponse(BaseModel):
class UpdateClassificationRequest(BaseModel):
    """Request body for editing a single endpoint classification."""

    # Literal restricts access_type at validation time; invalid values
    # produce a 422 instead of propagating into generated tool code.
    access_type: Literal["read", "write"]
    needs_interrupt: bool
    agent_group: str

    @field_validator("agent_group")
    @classmethod
    def agent_group_must_be_safe(cls, value: str) -> str:
        """Ensure agent_group is a non-empty, injection-safe identifier."""
        if not value.strip() or not re.fullmatch(r"[a-zA-Z0-9_\-]+", value):
            raise ValueError(
                "agent_group must be non-empty and contain only "
                "alphanumeric characters, underscores, or hyphens"
            )
        return value
# --- Helpers ---
@@ -94,11 +118,49 @@ def _classification_to_response(idx: int, clf: ClassificationResult) -> dict:
}
async def _run_import(job_id: str, url: str) -> None:
    """Run the import pipeline as a background task.

    Streams progress into _job_store via the on_progress callback and
    records the terminal state when the pipeline finishes; failures are
    logged server-side and surfaced as a generic 'failed' status.
    """

    def on_progress(stage: str, result_job: ImportJob) -> None:
        # Synchronous callback: it runs on the event loop between awaits,
        # so this dict update is atomic without taking the store lock.
        if job_id in _job_store:
            _job_store[job_id] = {
                **_job_store[job_id],
                "status": result_job.status,
                "total_endpoints": result_job.total_endpoints,
                "classified_count": result_job.classified_count,
                "error_message": result_job.error_message,
            }

    try:
        result = await _orchestrator.start_import(
            url=url, job_id=job_id, on_progress=on_progress,
        )
        # The store is documented as guarded by _store_lock; take it here
        # since we resume at an arbitrary await boundary.
        async with _store_lock:
            if job_id in _job_store:
                _job_store[job_id] = {
                    **_job_store[job_id],
                    "status": result.status,
                    "total_endpoints": result.total_endpoints,
                    "classified_count": result.classified_count,
                    "error_message": result.error_message,
                }
    except Exception:
        # Log the full traceback server-side; expose only a generic
        # message to clients to avoid leaking internal details.
        logger.exception("Background import failed for job %s", job_id)
        async with _store_lock:
            if job_id in _job_store:
                _job_store[job_id] = {
                    **_job_store[job_id],
                    "status": "failed",
                    "error_message": "Import failed. Please check the URL and try again.",
                }
# --- Endpoints ---
@router.post("/import", status_code=202)
async def start_import(request: ImportRequest) -> dict:
async def start_import(
request: ImportRequest,
background_tasks: BackgroundTasks,
) -> dict:
"""Start an OpenAPI import job for the given spec URL."""
job_id = str(uuid.uuid4())
job: dict = {
@@ -111,6 +173,7 @@ async def start_import(request: ImportRequest) -> dict:
"classifications": [],
}
_job_store[job_id] = job
background_tasks.add_task(_run_import, job_id, request.url)
return _job_to_response(job)
@@ -130,7 +193,10 @@ async def get_classifications(job_id: str) -> list:
if job is None:
raise HTTPException(status_code=404, detail=f"Job '{job_id}' not found")
classifications: list[ClassificationResult] = job.get("classifications", [])
return [_classification_to_response(i, clf) for i, clf in enumerate(classifications)]
return [
_classification_to_response(i, clf)
for i, clf in enumerate(classifications)
]
@router.put("/jobs/{job_id}/classifications/{idx}")

View File

@@ -28,19 +28,25 @@ class SSRFPolicy:
timeout_seconds: float = 30.0
_BLOCKED_NETWORKS = [
_BLOCKED_NETWORKS = (
ipaddress.ip_network("10.0.0.0/8"),
ipaddress.ip_network("172.16.0.0/12"),
ipaddress.ip_network("192.168.0.0/16"),
ipaddress.ip_network("127.0.0.0/8"),
ipaddress.ip_network("169.254.0.0/16"),
ipaddress.ip_network("0.0.0.0/32"),
ipaddress.ip_network("100.64.0.0/10"), # Carrier-Grade NAT
ipaddress.ip_network("198.18.0.0/15"), # Benchmarking
ipaddress.ip_network("240.0.0.0/4"), # Reserved
ipaddress.ip_network("255.255.255.255/32"), # Broadcast
# IPv6
ipaddress.ip_network("::1/128"),
ipaddress.ip_network("fe80::/10"),
ipaddress.ip_network("fc00::/7"),
ipaddress.ip_network("::/128"),
]
ipaddress.ip_network("::ffff:0:0/96"), # IPv4-mapped IPv6
ipaddress.ip_network("2001:db8::/32"), # Documentation
)
DEFAULT_POLICY = SSRFPolicy()

View File

@@ -81,6 +81,11 @@ class TestImportEndpoint:
response = client.post("/api/openapi/import", json={})
assert response.status_code == 422
def test_post_import_invalid_scheme_returns_422(self, client) -> None:
    """POST /import with non-http URL returns 422."""
    payload = {"url": "ftp://evil.com/spec"}
    resp = client.post("/api/openapi/import", json=payload)
    assert resp.status_code == 422
def test_post_import_returns_pending_status(self, client) -> None:
"""Newly created job has pending status."""
response = client.post("/api/openapi/import", json={"url": _SAMPLE_URL})
@@ -165,6 +170,22 @@ class TestUpdateClassificationEndpoint:
)
assert response.status_code == 404
def test_update_invalid_access_type_returns_422(self, client, job_with_classifications) -> None:
    """PUT /classifications/0 with invalid access_type returns 422."""
    url = f"/api/openapi/jobs/{job_with_classifications}/classifications/0"
    body = {"access_type": "admin", "needs_interrupt": True, "agent_group": "x"}
    resp = client.put(url, json=body)
    assert resp.status_code == 422
def test_update_invalid_agent_group_returns_422(self, client, job_with_classifications) -> None:
    """PUT /classifications/0 with invalid agent_group returns 422."""
    url = f"/api/openapi/jobs/{job_with_classifications}/classifications/0"
    body = {"access_type": "read", "needs_interrupt": False, "agent_group": "evil group!"}
    resp = client.put(url, json=body)
    assert resp.status_code == 422
def test_update_out_of_range_index_returns_404(self, client, job_with_classifications) -> None:
"""PUT /classifications/999 returns 404 for out-of-range index."""
response = client.put(