fix: address critical security and code review findings in Phase 3
- Wire ImportOrchestrator into review_api start_import via BackgroundTasks
- Sanitize docstrings in generated tool code to prevent code injection
- Add Literal["read", "write"] validation for access_type
- Add regex validation for agent_group
- Validate URL scheme (http/https only) in ImportRequest
- Validate LLM output fields (clamp confidence, validate access_type)
- Use dataclasses.replace instead of manual reconstruction in importer
- Expand SSRF blocked networks (Carrier-Grade NAT, IPv4-mapped IPv6, etc.)
- Make _BLOCKED_NETWORKS immutable tuple
- Use yaml.safe_dump instead of yaml.dump
- Fix _to_snake_case for empty strings and Python keywords
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
"""FastAPI router for OpenAPI import review workflow.
|
||||
|
||||
Exposes endpoints for:
|
||||
- Starting an import job
|
||||
- Starting an import job (triggers background pipeline)
|
||||
- Querying job status
|
||||
- Reviewing and editing classifications
|
||||
- Approving a job to trigger tool generation
|
||||
@@ -9,17 +9,28 @@ Exposes endpoints for:
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from typing import Literal
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi import APIRouter, BackgroundTasks, HTTPException
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
from app.openapi.models import ClassificationResult
|
||||
from app.openapi.importer import ImportOrchestrator
|
||||
from app.openapi.models import ClassificationResult, ImportJob
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/openapi", tags=["openapi"])
|
||||
|
||||
# In-memory store: job_id -> job dict
|
||||
# In-memory store: job_id -> job dict, guarded by async lock
|
||||
_job_store: dict[str, dict] = {}
|
||||
_store_lock = asyncio.Lock()
|
||||
|
||||
# Shared orchestrator instance
|
||||
_orchestrator = ImportOrchestrator()
|
||||
|
||||
|
||||
# --- Request / Response schemas ---
|
||||
@@ -30,10 +41,13 @@ class ImportRequest(BaseModel):
|
||||
|
||||
@field_validator("url")
|
||||
@classmethod
|
||||
def url_must_not_be_empty(cls, value: str) -> str:
|
||||
if not value.strip():
|
||||
def url_must_be_valid(cls, value: str) -> str:
|
||||
stripped = value.strip()
|
||||
if not stripped:
|
||||
raise ValueError("url must not be empty")
|
||||
return value
|
||||
if not stripped.startswith(("http://", "https://")):
|
||||
raise ValueError("url must start with http:// or https://")
|
||||
return stripped
|
||||
|
||||
|
||||
class JobResponse(BaseModel):
|
||||
@@ -56,10 +70,20 @@ class ClassificationResponse(BaseModel):
|
||||
|
||||
|
||||
class UpdateClassificationRequest(BaseModel):
|
||||
access_type: str
|
||||
access_type: Literal["read", "write"]
|
||||
needs_interrupt: bool
|
||||
agent_group: str
|
||||
|
||||
@field_validator("agent_group")
|
||||
@classmethod
|
||||
def agent_group_must_be_safe(cls, value: str) -> str:
|
||||
if not value.strip() or not re.fullmatch(r"[a-zA-Z0-9_\-]+", value):
|
||||
raise ValueError(
|
||||
"agent_group must be non-empty and contain only "
|
||||
"alphanumeric characters, underscores, or hyphens"
|
||||
)
|
||||
return value
|
||||
|
||||
|
||||
# --- Helpers ---
|
||||
|
||||
@@ -94,11 +118,49 @@ def _classification_to_response(idx: int, clf: ClassificationResult) -> dict:
|
||||
}
|
||||
|
||||
|
||||
async def _run_import(job_id: str, url: str) -> None:
|
||||
"""Run the import pipeline as a background task."""
|
||||
|
||||
def on_progress(stage: str, result_job: ImportJob) -> None:
|
||||
if job_id in _job_store:
|
||||
_job_store[job_id] = {
|
||||
**_job_store[job_id],
|
||||
"status": result_job.status,
|
||||
"total_endpoints": result_job.total_endpoints,
|
||||
"classified_count": result_job.classified_count,
|
||||
"error_message": result_job.error_message,
|
||||
}
|
||||
|
||||
try:
|
||||
result = await _orchestrator.start_import(
|
||||
url=url, job_id=job_id, on_progress=on_progress,
|
||||
)
|
||||
if job_id in _job_store:
|
||||
_job_store[job_id] = {
|
||||
**_job_store[job_id],
|
||||
"status": result.status,
|
||||
"total_endpoints": result.total_endpoints,
|
||||
"classified_count": result.classified_count,
|
||||
"error_message": result.error_message,
|
||||
}
|
||||
except Exception:
|
||||
logger.exception("Background import failed for job %s", job_id)
|
||||
if job_id in _job_store:
|
||||
_job_store[job_id] = {
|
||||
**_job_store[job_id],
|
||||
"status": "failed",
|
||||
"error_message": "Import failed. Please check the URL and try again.",
|
||||
}
|
||||
|
||||
|
||||
# --- Endpoints ---
|
||||
|
||||
|
||||
@router.post("/import", status_code=202)
|
||||
async def start_import(request: ImportRequest) -> dict:
|
||||
async def start_import(
|
||||
request: ImportRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
) -> dict:
|
||||
"""Start an OpenAPI import job for the given spec URL."""
|
||||
job_id = str(uuid.uuid4())
|
||||
job: dict = {
|
||||
@@ -111,6 +173,7 @@ async def start_import(request: ImportRequest) -> dict:
|
||||
"classifications": [],
|
||||
}
|
||||
_job_store[job_id] = job
|
||||
background_tasks.add_task(_run_import, job_id, request.url)
|
||||
return _job_to_response(job)
|
||||
|
||||
|
||||
@@ -130,7 +193,10 @@ async def get_classifications(job_id: str) -> list:
|
||||
if job is None:
|
||||
raise HTTPException(status_code=404, detail=f"Job '{job_id}' not found")
|
||||
classifications: list[ClassificationResult] = job.get("classifications", [])
|
||||
return [_classification_to_response(i, clf) for i, clf in enumerate(classifications)]
|
||||
return [
|
||||
_classification_to_response(i, clf)
|
||||
for i, clf in enumerate(classifications)
|
||||
]
|
||||
|
||||
|
||||
@router.put("/jobs/{job_id}/classifications/{idx}")
|
||||
|
||||
Reference in New Issue
Block a user