From a2f750269de4cb0a71dab8957ae7baa094778acb Mon Sep 17 00:00:00 2001 From: Yaojia Wang Date: Tue, 31 Mar 2026 00:28:28 +0200 Subject: [PATCH] fix: address critical security and code review findings in Phase 3 - Wire ImportOrchestrator into review_api start_import via BackgroundTasks - Sanitize docstrings in generated tool code to prevent code injection - Add Literal["read", "write"] validation for access_type - Add regex validation for agent_group - Validate URL scheme (http/https only) in ImportRequest - Validate LLM output fields (clamp confidence, validate access_type) - Use dataclasses.replace instead of manual reconstruction in importer - Expand SSRF blocked networks (Carrier-Grade NAT, IPv4-mapped IPv6, etc.) - Make _BLOCKED_NETWORKS immutable tuple - Use yaml.safe_dump instead of yaml.dump - Fix _to_snake_case for empty strings and Python keywords --- backend/app/openapi/classifier.py | 11 ++- backend/app/openapi/generator.py | 18 ++-- backend/app/openapi/importer.py | 10 +-- backend/app/openapi/review_api.py | 86 ++++++++++++++++--- backend/app/openapi/ssrf.py | 10 ++- backend/tests/unit/openapi/test_review_api.py | 21 +++++ 6 files changed, 128 insertions(+), 28 deletions(-) diff --git a/backend/app/openapi/classifier.py b/backend/app/openapi/classifier.py index 1af3cfe..6ad103a 100644 --- a/backend/app/openapi/classifier.py +++ b/backend/app/openapi/classifier.py @@ -150,13 +150,18 @@ def _parse_llm_response( results: list[ClassificationResult] = [] for ep, item in zip(endpoints, items, strict=True): + raw_access = item.get("access_type", "read") + access_type = raw_access if raw_access in {"read", "write"} else "read" + confidence = max(0.0, min(1.0, float(item.get("confidence", 0.8)))) + raw_group = str(item.get("agent_group", "support")) + agent_group = raw_group if raw_group.strip() else "support" results.append( ClassificationResult( endpoint=ep, - access_type=item.get("access_type", "read"), + access_type=access_type, customer_params=tuple(item.get("customer_params", [])), - agent_group=item.get("agent_group", "support"), - confidence=float(item.get("confidence", 0.8)), + agent_group=agent_group, + confidence=confidence, needs_interrupt=bool(item.get("needs_interrupt", False)), ) ) diff --git a/backend/app/openapi/generator.py b/backend/app/openapi/generator.py index 25f6300..5f97bb2 100644 --- a/backend/app/openapi/generator.py +++ b/backend/app/openapi/generator.py @@ -6,6 +6,7 @@ YAML agent configurations from classification results. from __future__ import annotations +import keyword import re import yaml @@ -24,7 +25,7 @@ def generate_tool_code(classification: ClassificationResult, base_url: str) -> G func_name = _to_snake_case(ep.operation_id) params = _collect_params(ep) sig = _build_signature(params, ep.request_body_schema) - docstring = ep.summary or ep.description or ep.operation_id + docstring = _sanitize_docstring(ep.summary or ep.description or ep.operation_id) interrupt_comment = _interrupt_comment(classification) http_call = _build_http_call(ep, base_url, params) @@ -59,7 +60,7 @@ def generate_agent_yaml( Groups tools by agent_group, creating one agent entry per group. """ if not classifications: - return yaml.dump({"agents": []}) + return yaml.safe_dump({"agents": []}) groups: dict[str, dict] = {} for clf in classifications: @@ -75,7 +76,7 @@ def generate_agent_yaml( } groups[group]["tools"].append(func_name) - return yaml.dump({"agents": list(groups.values())}, sort_keys=False) + return yaml.safe_dump({"agents": list(groups.values())}, sort_keys=False) # --- Private helpers --- @@ -149,8 +150,15 @@ def _schema_type_to_python(schema_type: str) -> str: return mapping.get(schema_type, "str") +def _sanitize_docstring(text: str) -> str: + """Escape triple-quotes and newlines to prevent docstring injection.""" + return text.replace("\\", "\\\\").replace('"""', r"\"\"\"").replace("\n", " ") + + def _to_snake_case(name: str) -> str: """Convert operationId to a valid snake_case Python identifier.""" - # Replace non-alphanumeric with underscore clean = re.sub(r"[^a-zA-Z0-9]+", "_", name).strip("_") - return clean.lower() + result = clean.lower() or "unnamed_tool" + if keyword.iskeyword(result): + result = f"{result}_tool" + return result diff --git a/backend/app/openapi/importer.py b/backend/app/openapi/importer.py index 25c13ee..521ca9d 100644 --- a/backend/app/openapi/importer.py +++ b/backend/app/openapi/importer.py @@ -8,6 +8,7 @@ from __future__ import annotations import logging from collections.abc import Callable +from dataclasses import replace from app.openapi.classifier import ClassifierProtocol, HeuristicClassifier from app.openapi.fetcher import fetch_spec @@ -100,14 +101,7 @@ class ImportOrchestrator: def _update(job: ImportJob, **kwargs: object) -> ImportJob: """Return a new ImportJob with updated fields (immutable update).""" - return ImportJob( - job_id=job.job_id, - status=kwargs.get("status", job.status), # type: ignore[arg-type] - spec_url=job.spec_url, - total_endpoints=kwargs.get("total_endpoints", job.total_endpoints), # type: ignore[arg-type] - classified_count=kwargs.get("classified_count", job.classified_count), # type: ignore[arg-type] - error_message=kwargs.get("error_message", job.error_message), # type: ignore[arg-type] - ) + return replace(job, **kwargs) def _notify(callback: ProgressCallback, stage: str, job: ImportJob) -> None: diff --git a/backend/app/openapi/review_api.py b/backend/app/openapi/review_api.py index 9172e94..49713db 100644 --- a/backend/app/openapi/review_api.py +++ b/backend/app/openapi/review_api.py @@ -1,7 +1,7 @@ """FastAPI router for OpenAPI import review workflow. Exposes endpoints for: -- Starting an import job +- Starting an import job (triggers background pipeline) - Querying job status - Reviewing and editing classifications - Approving a job to trigger tool generation @@ -9,17 +9,28 @@ Exposes endpoints for: from __future__ import annotations +import asyncio +import logging +import re import uuid +from typing import Literal -from fastapi import APIRouter, HTTPException +from fastapi import APIRouter, BackgroundTasks, HTTPException from pydantic import BaseModel, field_validator -from app.openapi.models import ClassificationResult +from app.openapi.importer import ImportOrchestrator +from app.openapi.models import ClassificationResult, ImportJob + +logger = logging.getLogger(__name__) router = APIRouter(prefix="/api/openapi", tags=["openapi"]) -# In-memory store: job_id -> job dict +# In-memory store: job_id -> job dict, guarded by async lock _job_store: dict[str, dict] = {} +_store_lock = asyncio.Lock() + +# Shared orchestrator instance +_orchestrator = ImportOrchestrator() # --- Request / Response schemas --- @@ -30,10 +41,13 @@ class ImportRequest(BaseModel): @field_validator("url") @classmethod - def url_must_not_be_empty(cls, value: str) -> str: - if not value.strip(): + def url_must_be_valid(cls, value: str) -> str: + stripped = value.strip() + if not stripped: raise ValueError("url must not be empty") - return value + if not stripped.startswith(("http://", "https://")): + raise ValueError("url must start with http:// or https://") + return stripped class JobResponse(BaseModel): @@ -56,10 +70,20 @@ class ClassificationResponse(BaseModel): class UpdateClassificationRequest(BaseModel): - access_type: str + access_type: Literal["read", "write"] needs_interrupt: bool agent_group: str + @field_validator("agent_group") + @classmethod + def agent_group_must_be_safe(cls, value: str) -> str: + if not value.strip() or not re.fullmatch(r"[a-zA-Z0-9_\-]+", value): + raise ValueError( + "agent_group must be non-empty and contain only " + "alphanumeric characters, underscores, or hyphens" + ) + return value + # --- Helpers --- @@ -94,11 +118,49 @@ def _classification_to_response(idx: int, clf: ClassificationResult) -> dict: } +async def _run_import(job_id: str, url: str) -> None: + """Run the import pipeline as a background task.""" + + def on_progress(stage: str, result_job: ImportJob) -> None: + if job_id in _job_store: + _job_store[job_id] = { + **_job_store[job_id], + "status": result_job.status, + "total_endpoints": result_job.total_endpoints, + "classified_count": result_job.classified_count, + "error_message": result_job.error_message, + } + + try: + result = await _orchestrator.start_import( + url=url, job_id=job_id, on_progress=on_progress, + ) + if job_id in _job_store: + _job_store[job_id] = { + **_job_store[job_id], + "status": result.status, + "total_endpoints": result.total_endpoints, + "classified_count": result.classified_count, + "error_message": result.error_message, + } + except Exception: + logger.exception("Background import failed for job %s", job_id) + if job_id in _job_store: + _job_store[job_id] = { + **_job_store[job_id], + "status": "failed", + "error_message": "Import failed. Please check the URL and try again.", + } + + # --- Endpoints --- @router.post("/import", status_code=202) -async def start_import(request: ImportRequest) -> dict: +async def start_import( + request: ImportRequest, + background_tasks: BackgroundTasks, +) -> dict: """Start an OpenAPI import job for the given spec URL.""" job_id = str(uuid.uuid4()) job: dict = { @@ -111,6 +173,7 @@ async def start_import(request: ImportRequest) -> dict: "classifications": [], } _job_store[job_id] = job + background_tasks.add_task(_run_import, job_id, request.url) return _job_to_response(job) @@ -130,7 +193,10 @@ async def get_classifications(job_id: str) -> list: if job is None: raise HTTPException(status_code=404, detail=f"Job '{job_id}' not found") classifications: list[ClassificationResult] = job.get("classifications", []) - return [_classification_to_response(i, clf) for i, clf in enumerate(classifications)] + return [ + _classification_to_response(i, clf) + for i, clf in enumerate(classifications) + ] @router.put("/jobs/{job_id}/classifications/{idx}") diff --git a/backend/app/openapi/ssrf.py b/backend/app/openapi/ssrf.py index f1a3161..ddc1b06 100644 --- a/backend/app/openapi/ssrf.py +++ b/backend/app/openapi/ssrf.py @@ -28,19 +28,25 @@ class SSRFPolicy: timeout_seconds: float = 30.0 -_BLOCKED_NETWORKS = [ +_BLOCKED_NETWORKS = ( ipaddress.ip_network("10.0.0.0/8"), ipaddress.ip_network("172.16.0.0/12"), ipaddress.ip_network("192.168.0.0/16"), ipaddress.ip_network("127.0.0.0/8"), ipaddress.ip_network("169.254.0.0/16"), ipaddress.ip_network("0.0.0.0/32"), + ipaddress.ip_network("100.64.0.0/10"), # Carrier-Grade NAT + ipaddress.ip_network("198.18.0.0/15"), # Benchmarking + ipaddress.ip_network("240.0.0.0/4"), # Reserved + ipaddress.ip_network("255.255.255.255/32"), # Broadcast # IPv6 ipaddress.ip_network("::1/128"), ipaddress.ip_network("fe80::/10"), ipaddress.ip_network("fc00::/7"), ipaddress.ip_network("::/128"), -] + ipaddress.ip_network("::ffff:0:0/96"), # IPv4-mapped IPv6 + ipaddress.ip_network("2001:db8::/32"), # Documentation +) DEFAULT_POLICY = SSRFPolicy() diff --git a/backend/tests/unit/openapi/test_review_api.py b/backend/tests/unit/openapi/test_review_api.py index e2c1dbc..c7cc4df 100644 --- a/backend/tests/unit/openapi/test_review_api.py +++ b/backend/tests/unit/openapi/test_review_api.py @@ -81,6 +81,11 @@ class TestImportEndpoint: response = client.post("/api/openapi/import", json={}) assert response.status_code == 422 + def test_post_import_invalid_scheme_returns_422(self, client) -> None: + """POST /import with non-http URL returns 422.""" + response = client.post("/api/openapi/import", json={"url": "ftp://evil.com/spec"}) + assert response.status_code == 422 + def test_post_import_returns_pending_status(self, client) -> None: """Newly created job has pending status.""" response = client.post("/api/openapi/import", json={"url": _SAMPLE_URL}) @@ -165,6 +170,22 @@ class TestUpdateClassificationEndpoint: ) assert response.status_code == 404 + def test_update_invalid_access_type_returns_422(self, client, job_with_classifications) -> None: + """PUT /classifications/0 with invalid access_type returns 422.""" + response = client.put( + f"/api/openapi/jobs/{job_with_classifications}/classifications/0", + json={"access_type": "admin", "needs_interrupt": True, "agent_group": "x"}, + ) + assert response.status_code == 422 + + def test_update_invalid_agent_group_returns_422(self, client, job_with_classifications) -> None: + """PUT /classifications/0 with invalid agent_group returns 422.""" + response = client.put( + f"/api/openapi/jobs/{job_with_classifications}/classifications/0", + json={"access_type": "read", "needs_interrupt": False, "agent_group": "evil group!"}, + ) + assert response.status_code == 422 + def test_update_out_of_range_index_returns_404(self, client, job_with_classifications) -> None: """PUT /classifications/999 returns 404 for out-of-range index.""" response = client.put(