fix: address critical security and code review findings in Phase 3
- Wire ImportOrchestrator into review_api start_import via BackgroundTasks - Sanitize docstrings in generated tool code to prevent code injection - Add Literal["read", "write"] validation for access_type - Add regex validation for agent_group - Validate URL scheme (http/https only) in ImportRequest - Validate LLM output fields (clamp confidence, validate access_type) - Use dataclasses.replace instead of manual reconstruction in importer - Expand SSRF blocked networks (Carrier-Grade NAT, IPv4-mapped IPv6, etc.) - Make _BLOCKED_NETWORKS immutable tuple - Use yaml.safe_dump instead of yaml.dump - Fix _to_snake_case for empty strings and Python keywords
This commit is contained in:
@@ -150,13 +150,18 @@ def _parse_llm_response(
|
|||||||
|
|
||||||
results: list[ClassificationResult] = []
|
results: list[ClassificationResult] = []
|
||||||
for ep, item in zip(endpoints, items, strict=True):
|
for ep, item in zip(endpoints, items, strict=True):
|
||||||
|
raw_access = item.get("access_type", "read")
|
||||||
|
access_type = raw_access if raw_access in {"read", "write"} else "read"
|
||||||
|
confidence = max(0.0, min(1.0, float(item.get("confidence", 0.8))))
|
||||||
|
raw_group = str(item.get("agent_group", "support"))
|
||||||
|
agent_group = raw_group if raw_group.strip() else "support"
|
||||||
results.append(
|
results.append(
|
||||||
ClassificationResult(
|
ClassificationResult(
|
||||||
endpoint=ep,
|
endpoint=ep,
|
||||||
access_type=item.get("access_type", "read"),
|
access_type=access_type,
|
||||||
customer_params=tuple(item.get("customer_params", [])),
|
customer_params=tuple(item.get("customer_params", [])),
|
||||||
agent_group=item.get("agent_group", "support"),
|
agent_group=agent_group,
|
||||||
confidence=float(item.get("confidence", 0.8)),
|
confidence=confidence,
|
||||||
needs_interrupt=bool(item.get("needs_interrupt", False)),
|
needs_interrupt=bool(item.get("needs_interrupt", False)),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ YAML agent configurations from classification results.
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import keyword
|
||||||
import re
|
import re
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
@@ -24,7 +25,7 @@ def generate_tool_code(classification: ClassificationResult, base_url: str) -> G
|
|||||||
func_name = _to_snake_case(ep.operation_id)
|
func_name = _to_snake_case(ep.operation_id)
|
||||||
params = _collect_params(ep)
|
params = _collect_params(ep)
|
||||||
sig = _build_signature(params, ep.request_body_schema)
|
sig = _build_signature(params, ep.request_body_schema)
|
||||||
docstring = ep.summary or ep.description or ep.operation_id
|
docstring = _sanitize_docstring(ep.summary or ep.description or ep.operation_id)
|
||||||
interrupt_comment = _interrupt_comment(classification)
|
interrupt_comment = _interrupt_comment(classification)
|
||||||
http_call = _build_http_call(ep, base_url, params)
|
http_call = _build_http_call(ep, base_url, params)
|
||||||
|
|
||||||
@@ -59,7 +60,7 @@ def generate_agent_yaml(
|
|||||||
Groups tools by agent_group, creating one agent entry per group.
|
Groups tools by agent_group, creating one agent entry per group.
|
||||||
"""
|
"""
|
||||||
if not classifications:
|
if not classifications:
|
||||||
return yaml.dump({"agents": []})
|
return yaml.safe_dump({"agents": []})
|
||||||
|
|
||||||
groups: dict[str, dict] = {}
|
groups: dict[str, dict] = {}
|
||||||
for clf in classifications:
|
for clf in classifications:
|
||||||
@@ -75,7 +76,7 @@ def generate_agent_yaml(
|
|||||||
}
|
}
|
||||||
groups[group]["tools"].append(func_name)
|
groups[group]["tools"].append(func_name)
|
||||||
|
|
||||||
return yaml.dump({"agents": list(groups.values())}, sort_keys=False)
|
return yaml.safe_dump({"agents": list(groups.values())}, sort_keys=False)
|
||||||
|
|
||||||
|
|
||||||
# --- Private helpers ---
|
# --- Private helpers ---
|
||||||
@@ -149,8 +150,15 @@ def _schema_type_to_python(schema_type: str) -> str:
|
|||||||
return mapping.get(schema_type, "str")
|
return mapping.get(schema_type, "str")
|
||||||
|
|
||||||
|
|
||||||
|
def _sanitize_docstring(text: str) -> str:
|
||||||
|
"""Escape triple-quotes and newlines to prevent docstring injection."""
|
||||||
|
return text.replace("\\", "\\\\").replace('"""', r"\"\"\"").replace("\n", " ")
|
||||||
|
|
||||||
|
|
||||||
def _to_snake_case(name: str) -> str:
|
def _to_snake_case(name: str) -> str:
|
||||||
"""Convert operationId to a valid snake_case Python identifier."""
|
"""Convert operationId to a valid snake_case Python identifier."""
|
||||||
# Replace non-alphanumeric with underscore
|
|
||||||
clean = re.sub(r"[^a-zA-Z0-9]+", "_", name).strip("_")
|
clean = re.sub(r"[^a-zA-Z0-9]+", "_", name).strip("_")
|
||||||
return clean.lower()
|
result = clean.lower() or "unnamed_tool"
|
||||||
|
if keyword.iskeyword(result):
|
||||||
|
result = f"{result}_tool"
|
||||||
|
return result
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
|
from dataclasses import replace
|
||||||
|
|
||||||
from app.openapi.classifier import ClassifierProtocol, HeuristicClassifier
|
from app.openapi.classifier import ClassifierProtocol, HeuristicClassifier
|
||||||
from app.openapi.fetcher import fetch_spec
|
from app.openapi.fetcher import fetch_spec
|
||||||
@@ -100,14 +101,7 @@ class ImportOrchestrator:
|
|||||||
|
|
||||||
def _update(job: ImportJob, **kwargs: object) -> ImportJob:
|
def _update(job: ImportJob, **kwargs: object) -> ImportJob:
|
||||||
"""Return a new ImportJob with updated fields (immutable update)."""
|
"""Return a new ImportJob with updated fields (immutable update)."""
|
||||||
return ImportJob(
|
return replace(job, **kwargs)
|
||||||
job_id=job.job_id,
|
|
||||||
status=kwargs.get("status", job.status), # type: ignore[arg-type]
|
|
||||||
spec_url=job.spec_url,
|
|
||||||
total_endpoints=kwargs.get("total_endpoints", job.total_endpoints), # type: ignore[arg-type]
|
|
||||||
classified_count=kwargs.get("classified_count", job.classified_count), # type: ignore[arg-type]
|
|
||||||
error_message=kwargs.get("error_message", job.error_message), # type: ignore[arg-type]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _notify(callback: ProgressCallback, stage: str, job: ImportJob) -> None:
|
def _notify(callback: ProgressCallback, stage: str, job: ImportJob) -> None:
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
"""FastAPI router for OpenAPI import review workflow.
|
"""FastAPI router for OpenAPI import review workflow.
|
||||||
|
|
||||||
Exposes endpoints for:
|
Exposes endpoints for:
|
||||||
- Starting an import job
|
- Starting an import job (triggers background pipeline)
|
||||||
- Querying job status
|
- Querying job status
|
||||||
- Reviewing and editing classifications
|
- Reviewing and editing classifications
|
||||||
- Approving a job to trigger tool generation
|
- Approving a job to trigger tool generation
|
||||||
@@ -9,17 +9,28 @@ Exposes endpoints for:
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
import uuid
|
import uuid
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException
|
from fastapi import APIRouter, BackgroundTasks, HTTPException
|
||||||
from pydantic import BaseModel, field_validator
|
from pydantic import BaseModel, field_validator
|
||||||
|
|
||||||
from app.openapi.models import ClassificationResult
|
from app.openapi.importer import ImportOrchestrator
|
||||||
|
from app.openapi.models import ClassificationResult, ImportJob
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/openapi", tags=["openapi"])
|
router = APIRouter(prefix="/api/openapi", tags=["openapi"])
|
||||||
|
|
||||||
# In-memory store: job_id -> job dict
|
# In-memory store: job_id -> job dict, guarded by async lock
|
||||||
_job_store: dict[str, dict] = {}
|
_job_store: dict[str, dict] = {}
|
||||||
|
_store_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
# Shared orchestrator instance
|
||||||
|
_orchestrator = ImportOrchestrator()
|
||||||
|
|
||||||
|
|
||||||
# --- Request / Response schemas ---
|
# --- Request / Response schemas ---
|
||||||
@@ -30,10 +41,13 @@ class ImportRequest(BaseModel):
|
|||||||
|
|
||||||
@field_validator("url")
|
@field_validator("url")
|
||||||
@classmethod
|
@classmethod
|
||||||
def url_must_not_be_empty(cls, value: str) -> str:
|
def url_must_be_valid(cls, value: str) -> str:
|
||||||
if not value.strip():
|
stripped = value.strip()
|
||||||
|
if not stripped:
|
||||||
raise ValueError("url must not be empty")
|
raise ValueError("url must not be empty")
|
||||||
return value
|
if not stripped.startswith(("http://", "https://")):
|
||||||
|
raise ValueError("url must start with http:// or https://")
|
||||||
|
return stripped
|
||||||
|
|
||||||
|
|
||||||
class JobResponse(BaseModel):
|
class JobResponse(BaseModel):
|
||||||
@@ -56,10 +70,20 @@ class ClassificationResponse(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class UpdateClassificationRequest(BaseModel):
|
class UpdateClassificationRequest(BaseModel):
|
||||||
access_type: str
|
access_type: Literal["read", "write"]
|
||||||
needs_interrupt: bool
|
needs_interrupt: bool
|
||||||
agent_group: str
|
agent_group: str
|
||||||
|
|
||||||
|
@field_validator("agent_group")
|
||||||
|
@classmethod
|
||||||
|
def agent_group_must_be_safe(cls, value: str) -> str:
|
||||||
|
if not value.strip() or not re.fullmatch(r"[a-zA-Z0-9_\-]+", value):
|
||||||
|
raise ValueError(
|
||||||
|
"agent_group must be non-empty and contain only "
|
||||||
|
"alphanumeric characters, underscores, or hyphens"
|
||||||
|
)
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
# --- Helpers ---
|
# --- Helpers ---
|
||||||
|
|
||||||
@@ -94,11 +118,49 @@ def _classification_to_response(idx: int, clf: ClassificationResult) -> dict:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_import(job_id: str, url: str) -> None:
|
||||||
|
"""Run the import pipeline as a background task."""
|
||||||
|
|
||||||
|
def on_progress(stage: str, result_job: ImportJob) -> None:
|
||||||
|
if job_id in _job_store:
|
||||||
|
_job_store[job_id] = {
|
||||||
|
**_job_store[job_id],
|
||||||
|
"status": result_job.status,
|
||||||
|
"total_endpoints": result_job.total_endpoints,
|
||||||
|
"classified_count": result_job.classified_count,
|
||||||
|
"error_message": result_job.error_message,
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = await _orchestrator.start_import(
|
||||||
|
url=url, job_id=job_id, on_progress=on_progress,
|
||||||
|
)
|
||||||
|
if job_id in _job_store:
|
||||||
|
_job_store[job_id] = {
|
||||||
|
**_job_store[job_id],
|
||||||
|
"status": result.status,
|
||||||
|
"total_endpoints": result.total_endpoints,
|
||||||
|
"classified_count": result.classified_count,
|
||||||
|
"error_message": result.error_message,
|
||||||
|
}
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Background import failed for job %s", job_id)
|
||||||
|
if job_id in _job_store:
|
||||||
|
_job_store[job_id] = {
|
||||||
|
**_job_store[job_id],
|
||||||
|
"status": "failed",
|
||||||
|
"error_message": "Import failed. Please check the URL and try again.",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
# --- Endpoints ---
|
# --- Endpoints ---
|
||||||
|
|
||||||
|
|
||||||
@router.post("/import", status_code=202)
|
@router.post("/import", status_code=202)
|
||||||
async def start_import(request: ImportRequest) -> dict:
|
async def start_import(
|
||||||
|
request: ImportRequest,
|
||||||
|
background_tasks: BackgroundTasks,
|
||||||
|
) -> dict:
|
||||||
"""Start an OpenAPI import job for the given spec URL."""
|
"""Start an OpenAPI import job for the given spec URL."""
|
||||||
job_id = str(uuid.uuid4())
|
job_id = str(uuid.uuid4())
|
||||||
job: dict = {
|
job: dict = {
|
||||||
@@ -111,6 +173,7 @@ async def start_import(request: ImportRequest) -> dict:
|
|||||||
"classifications": [],
|
"classifications": [],
|
||||||
}
|
}
|
||||||
_job_store[job_id] = job
|
_job_store[job_id] = job
|
||||||
|
background_tasks.add_task(_run_import, job_id, request.url)
|
||||||
return _job_to_response(job)
|
return _job_to_response(job)
|
||||||
|
|
||||||
|
|
||||||
@@ -130,7 +193,10 @@ async def get_classifications(job_id: str) -> list:
|
|||||||
if job is None:
|
if job is None:
|
||||||
raise HTTPException(status_code=404, detail=f"Job '{job_id}' not found")
|
raise HTTPException(status_code=404, detail=f"Job '{job_id}' not found")
|
||||||
classifications: list[ClassificationResult] = job.get("classifications", [])
|
classifications: list[ClassificationResult] = job.get("classifications", [])
|
||||||
return [_classification_to_response(i, clf) for i, clf in enumerate(classifications)]
|
return [
|
||||||
|
_classification_to_response(i, clf)
|
||||||
|
for i, clf in enumerate(classifications)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
@router.put("/jobs/{job_id}/classifications/{idx}")
|
@router.put("/jobs/{job_id}/classifications/{idx}")
|
||||||
|
|||||||
@@ -28,19 +28,25 @@ class SSRFPolicy:
|
|||||||
timeout_seconds: float = 30.0
|
timeout_seconds: float = 30.0
|
||||||
|
|
||||||
|
|
||||||
_BLOCKED_NETWORKS = [
|
_BLOCKED_NETWORKS = (
|
||||||
ipaddress.ip_network("10.0.0.0/8"),
|
ipaddress.ip_network("10.0.0.0/8"),
|
||||||
ipaddress.ip_network("172.16.0.0/12"),
|
ipaddress.ip_network("172.16.0.0/12"),
|
||||||
ipaddress.ip_network("192.168.0.0/16"),
|
ipaddress.ip_network("192.168.0.0/16"),
|
||||||
ipaddress.ip_network("127.0.0.0/8"),
|
ipaddress.ip_network("127.0.0.0/8"),
|
||||||
ipaddress.ip_network("169.254.0.0/16"),
|
ipaddress.ip_network("169.254.0.0/16"),
|
||||||
ipaddress.ip_network("0.0.0.0/32"),
|
ipaddress.ip_network("0.0.0.0/32"),
|
||||||
|
ipaddress.ip_network("100.64.0.0/10"), # Carrier-Grade NAT
|
||||||
|
ipaddress.ip_network("198.18.0.0/15"), # Benchmarking
|
||||||
|
ipaddress.ip_network("240.0.0.0/4"), # Reserved
|
||||||
|
ipaddress.ip_network("255.255.255.255/32"), # Broadcast
|
||||||
# IPv6
|
# IPv6
|
||||||
ipaddress.ip_network("::1/128"),
|
ipaddress.ip_network("::1/128"),
|
||||||
ipaddress.ip_network("fe80::/10"),
|
ipaddress.ip_network("fe80::/10"),
|
||||||
ipaddress.ip_network("fc00::/7"),
|
ipaddress.ip_network("fc00::/7"),
|
||||||
ipaddress.ip_network("::/128"),
|
ipaddress.ip_network("::/128"),
|
||||||
]
|
ipaddress.ip_network("::ffff:0:0/96"), # IPv4-mapped IPv6
|
||||||
|
ipaddress.ip_network("2001:db8::/32"), # Documentation
|
||||||
|
)
|
||||||
|
|
||||||
DEFAULT_POLICY = SSRFPolicy()
|
DEFAULT_POLICY = SSRFPolicy()
|
||||||
|
|
||||||
|
|||||||
@@ -81,6 +81,11 @@ class TestImportEndpoint:
|
|||||||
response = client.post("/api/openapi/import", json={})
|
response = client.post("/api/openapi/import", json={})
|
||||||
assert response.status_code == 422
|
assert response.status_code == 422
|
||||||
|
|
||||||
|
def test_post_import_invalid_scheme_returns_422(self, client) -> None:
|
||||||
|
"""POST /import with non-http URL returns 422."""
|
||||||
|
response = client.post("/api/openapi/import", json={"url": "ftp://evil.com/spec"})
|
||||||
|
assert response.status_code == 422
|
||||||
|
|
||||||
def test_post_import_returns_pending_status(self, client) -> None:
|
def test_post_import_returns_pending_status(self, client) -> None:
|
||||||
"""Newly created job has pending status."""
|
"""Newly created job has pending status."""
|
||||||
response = client.post("/api/openapi/import", json={"url": _SAMPLE_URL})
|
response = client.post("/api/openapi/import", json={"url": _SAMPLE_URL})
|
||||||
@@ -165,6 +170,22 @@ class TestUpdateClassificationEndpoint:
|
|||||||
)
|
)
|
||||||
assert response.status_code == 404
|
assert response.status_code == 404
|
||||||
|
|
||||||
|
def test_update_invalid_access_type_returns_422(self, client, job_with_classifications) -> None:
|
||||||
|
"""PUT /classifications/0 with invalid access_type returns 422."""
|
||||||
|
response = client.put(
|
||||||
|
f"/api/openapi/jobs/{job_with_classifications}/classifications/0",
|
||||||
|
json={"access_type": "admin", "needs_interrupt": True, "agent_group": "x"},
|
||||||
|
)
|
||||||
|
assert response.status_code == 422
|
||||||
|
|
||||||
|
def test_update_invalid_agent_group_returns_422(self, client, job_with_classifications) -> None:
|
||||||
|
"""PUT /classifications/0 with invalid agent_group returns 422."""
|
||||||
|
response = client.put(
|
||||||
|
f"/api/openapi/jobs/{job_with_classifications}/classifications/0",
|
||||||
|
json={"access_type": "read", "needs_interrupt": False, "agent_group": "evil group!"},
|
||||||
|
)
|
||||||
|
assert response.status_code == 422
|
||||||
|
|
||||||
def test_update_out_of_range_index_returns_404(self, client, job_with_classifications) -> None:
|
def test_update_out_of_range_index_returns_404(self, client, job_with_classifications) -> None:
|
||||||
"""PUT /classifications/999 returns 404 for out-of-range index."""
|
"""PUT /classifications/999 returns 404 for out-of-range index."""
|
||||||
response = client.put(
|
response = client.put(
|
||||||
|
|||||||
Reference in New Issue
Block a user