- SSRF protection: private IP blocking, DNS rebinding defense, redirect validation - OpenAPI fetcher with SSRF guard, JSON/YAML auto-detection, 10MB limit - Structural spec validator (3.0.x/3.1.x) - Endpoint parser with $ref resolution, auto-generated operation IDs - Heuristic + LLM endpoint classifier with Protocol interface - Review API at /api/openapi (import, job status, classification CRUD, approve) - @tool code generator + Agent YAML generator - Import orchestrator (fetch -> validate -> parse -> classify pipeline) - 125 new tests, 322 total passing, 93.23% coverage
117 lines
3.9 KiB
Python
117 lines
3.9 KiB
Python
"""Import orchestrator for OpenAPI auto-discovery pipeline.
|
|
|
|
Orchestrates: fetch -> validate -> parse -> classify
|
|
Each stage updates the job status and calls the on_progress callback.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
from collections.abc import Callable
|
|
|
|
from app.openapi.classifier import ClassifierProtocol, HeuristicClassifier
|
|
from app.openapi.fetcher import fetch_spec
|
|
from app.openapi.models import ImportJob
|
|
from app.openapi.parser import parse_endpoints
|
|
from app.openapi.ssrf import DEFAULT_POLICY, SSRFPolicy
|
|
from app.openapi.validator import validate_spec
|
|
|
|
# Module-level logger for the import orchestrator.
logger = logging.getLogger(name=__name__)

# Optional progress hook, invoked as hook(stage_name, job_snapshot).
# ``None`` means no hook is installed.
ProgressCallback = Callable[[str, ImportJob], None] | None
|
|
|
|
|
|
class ImportOrchestrator:
    """Orchestrates the full OpenAPI import pipeline.

    Stages, in order. Each stage is recorded in ``ImportJob.status`` and
    reported to the progress callback before its work runs:

    1. fetching    -- download and parse spec from URL
    2. validating  -- check spec structure
    3. parsing     -- extract endpoint definitions
    4. classifying -- classify endpoints for agent routing
    5. done / failed
    """

    def __init__(
        self,
        classifier: ClassifierProtocol | None = None,
        policy: SSRFPolicy = DEFAULT_POLICY,
    ) -> None:
        """Configure the orchestrator.

        Args:
            classifier: Endpoint classifier to use. ``None`` selects a new
                ``HeuristicClassifier``.
            policy: SSRF policy passed to ``fetch_spec`` on every import.
        """
        # Compare with None explicitly: ``classifier or ...`` would also
        # replace a caller-supplied classifier object that happens to be
        # falsy (e.g. one that defines ``__len__`` or ``__bool__``).
        self._classifier = (
            classifier if classifier is not None else HeuristicClassifier()
        )
        self._policy = policy

    async def start_import(
        self,
        url: str,
        job_id: str,
        on_progress: ProgressCallback = None,
    ) -> ImportJob:
        """Run the full import pipeline for a spec URL.

        Args:
            url: Spec URL, fetched via ``fetch_spec`` under this
                orchestrator's SSRF policy.
            job_id: Identifier carried on every ImportJob snapshot.
            on_progress: Optional hook called with ``(stage_name, job)`` when
                each stage starts and once more with the final status.
                ``None`` (the default) disables progress reporting.

        Returns:
            An ImportJob whose status is ``"done"`` or ``"failed"``. When a
            stage raises an ``Exception``, the error is logged and stored in
            ``error_message`` instead of being propagated.
        """
        job = ImportJob(job_id=job_id, status="pending", spec_url=url)

        try:
            # Stage 1: fetch, using the configured SSRF policy
            job = _update(job, status="fetching")
            _notify(on_progress, "fetching", job)
            spec_dict = await fetch_spec(url, self._policy)

            # Stage 2: validate; any structural error aborts the pipeline
            job = _update(job, status="validating")
            _notify(on_progress, "validating", job)
            errors = validate_spec(spec_dict)
            if errors:
                raise ValueError(f"Invalid spec: {'; '.join(errors)}")

            # Stage 3: parse
            job = _update(job, status="parsing")
            _notify(on_progress, "parsing", job)
            endpoints = parse_endpoints(spec_dict)

            # Stage 4: classify; the endpoint count is known from here on
            job = _update(job, status="classifying", total_endpoints=len(endpoints))
            _notify(on_progress, "classifying", job)
            classifications = await self._classifier.classify(endpoints)

            # Done
            job = _update(
                job,
                status="done",
                total_endpoints=len(endpoints),
                classified_count=len(classifications),
            )
            _notify(on_progress, "done", job)
            return job

        except Exception as exc:
            # Pipeline boundary: record the failure on the job. Catching
            # ``Exception`` (not ``BaseException``) lets task cancellation
            # propagate normally.
            logger.exception("Import pipeline failed for job %s", job_id)
            # Some exceptions (e.g. bare timeouts) stringify to "", which
            # would leave a failed job with no usable error message.
            message = str(exc) or type(exc).__name__
            job = _update(job, status="failed", error_message=message)
            _notify(on_progress, "failed", job)
            return job
|
|
|
|
|
|
def _update(job: ImportJob, **kwargs: object) -> ImportJob:
|
|
"""Return a new ImportJob with updated fields (immutable update)."""
|
|
return ImportJob(
|
|
job_id=job.job_id,
|
|
status=kwargs.get("status", job.status), # type: ignore[arg-type]
|
|
spec_url=job.spec_url,
|
|
total_endpoints=kwargs.get("total_endpoints", job.total_endpoints), # type: ignore[arg-type]
|
|
classified_count=kwargs.get("classified_count", job.classified_count), # type: ignore[arg-type]
|
|
error_message=kwargs.get("error_message", job.error_message), # type: ignore[arg-type]
|
|
)
|
|
|
|
|
|
def _notify(callback: ProgressCallback, stage: str, job: ImportJob) -> None:
|
|
"""Call the progress callback if provided."""
|
|
if callback is not None:
|
|
callback(stage, job)
|