Address all architecture review findings.

P0 fixes:
- Add API key authentication for admin endpoints (analytics, replay, openapi) and WebSocket connections via ADMIN_API_KEY env var
- Add PostgreSQL-backed PgSessionManager and PgInterruptManager for multi-worker production deployments (in-memory defaults preserved)

P1 fixes:
- Implement actual tool generation in the OpenAPI approve_job endpoint using generate_tool_code() and generate_agent_yaml()
- Add missing clarification, interrupt_expired, and tool_result message handlers in frontend ChatPage

P2 fixes:
- Replace monkey-patching on CompiledStateGraph with typed GraphContext
- Replace 9-param dispatch_message with a WebSocketContext dataclass
- Extract duplicate _envelope() into shared app/api_utils.py
- Replace mutable module-level counter with crypto.randomUUID()
- Remove hardcoded mock data from ReviewPage; use api.ts wrappers
- Remove `as any` type escape from ReplayPage

All 516 tests passing, 0 TypeScript errors.
283 lines · 8.4 KiB · Python
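Every route in the file below is guarded by the require_admin_api_key dependency added in the P0 fix. The dependency itself lives in app/auth.py and is not shown here; a minimal sketch of what such a dependency could look like, assuming it compares an X-API-Key request header against the ADMIN_API_KEY environment variable:

import os

from fastapi import Header, HTTPException


async def require_admin_api_key(x_api_key: str | None = Header(default=None)) -> None:
    """Reject the request unless the X-API-Key header matches ADMIN_API_KEY."""
    expected = os.environ.get("ADMIN_API_KEY")
    if not expected or x_api_key != expected:
        raise HTTPException(status_code=401, detail="Invalid or missing admin API key")

The router applies the check once via dependencies=[Depends(require_admin_api_key)], so the individual endpoint handlers stay auth-agnostic.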
"""FastAPI router for OpenAPI import review workflow.
|
|
|
|
Exposes endpoints for:
|
|
- Starting an import job (triggers background pipeline)
|
|
- Querying job status
|
|
- Reviewing and editing classifications
|
|
- Approving a job to trigger tool generation
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
import re
|
|
import uuid
|
|
from typing import Literal
|
|
|
|
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
|
|
from pydantic import BaseModel, field_validator
|
|
|
|
from app.auth import require_admin_api_key
|
|
from app.openapi.generator import generate_agent_yaml, generate_tool_code
|
|
from app.openapi.importer import ImportOrchestrator
|
|
from app.openapi.models import ClassificationResult, ImportJob
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
router = APIRouter(
|
|
prefix="/api/openapi",
|
|
tags=["openapi"],
|
|
dependencies=[Depends(require_admin_api_key)],
|
|
)
|
|
|
|
# In-memory job store: job_id -> job dict. Single-process only; handlers
# mutate it with single dict assignments (no await between read and write),
# so the lock is reserved for future multi-step updates.
_job_store: dict[str, dict] = {}
_store_lock = asyncio.Lock()

# Shared orchestrator instance
_orchestrator = ImportOrchestrator()

# --- Request / Response schemas ---


class ImportRequest(BaseModel):
    url: str

    @field_validator("url")
    @classmethod
    def url_must_be_valid(cls, value: str) -> str:
        stripped = value.strip()
        if not stripped:
            raise ValueError("url must not be empty")
        if not stripped.startswith(("http://", "https://")):
            raise ValueError("url must start with http:// or https://")
        return stripped


class JobResponse(BaseModel):
    job_id: str
    status: str
    spec_url: str
    total_endpoints: int = 0
    classified_count: int = 0
    error_message: str | None = None


class ClassificationResponse(BaseModel):
    index: int
    access_type: str
    needs_interrupt: bool
    agent_group: str
    confidence: float
    customer_params: list[str]
    endpoint: dict


class UpdateClassificationRequest(BaseModel):
    access_type: Literal["read", "write"]
    needs_interrupt: bool
    agent_group: str

    @field_validator("agent_group")
    @classmethod
    def agent_group_must_be_safe(cls, value: str) -> str:
        if not value.strip() or not re.fullmatch(r"[a-zA-Z0-9_\-]+", value):
            raise ValueError(
                "agent_group must be non-empty and contain only "
                "alphanumeric characters, underscores, or hyphens"
            )
        return value


# --- Helpers ---


def _job_to_response(job: dict) -> dict:
    return {
        "job_id": job["job_id"],
        "status": job["status"],
        "spec_url": job["spec_url"],
        "total_endpoints": job.get("total_endpoints", 0),
        "classified_count": job.get("classified_count", 0),
        "error_message": job.get("error_message"),
    }


def _classification_to_response(idx: int, clf: ClassificationResult) -> dict:
    ep = clf.endpoint
    return {
        "index": idx,
        "access_type": clf.access_type,
        "needs_interrupt": clf.needs_interrupt,
        "agent_group": clf.agent_group,
        "confidence": clf.confidence,
        "customer_params": list(clf.customer_params),
        "endpoint": {
            "path": ep.path,
            "method": ep.method,
            "operation_id": ep.operation_id,
            "summary": ep.summary,
            "description": ep.description,
        },
    }


async def _run_import(job_id: str, url: str) -> None:
    """Run the import pipeline as a background task."""

    def on_progress(stage: str, result_job: ImportJob) -> None:
        if job_id in _job_store:
            _job_store[job_id] = {
                **_job_store[job_id],
                "status": result_job.status,
                "total_endpoints": result_job.total_endpoints,
                "classified_count": result_job.classified_count,
                "error_message": result_job.error_message,
            }

    try:
        result = await _orchestrator.start_import(
            url=url, job_id=job_id, on_progress=on_progress,
        )
        if job_id in _job_store:
            _job_store[job_id] = {
                **_job_store[job_id],
                "status": result.status,
                "total_endpoints": result.total_endpoints,
                "classified_count": result.classified_count,
                "error_message": result.error_message,
            }
    except Exception:
        logger.exception("Background import failed for job %s", job_id)
        if job_id in _job_store:
            _job_store[job_id] = {
                **_job_store[job_id],
                "status": "failed",
                "error_message": "Import failed. Please check the URL and try again.",
            }


# --- Endpoints ---


@router.post("/import", status_code=202)
async def start_import(
    request: ImportRequest,
    background_tasks: BackgroundTasks,
) -> dict:
    """Start an OpenAPI import job for the given spec URL."""
    job_id = str(uuid.uuid4())
    job: dict = {
        "job_id": job_id,
        "status": "pending",
        "spec_url": request.url,
        "total_endpoints": 0,
        "classified_count": 0,
        "error_message": None,
        "classifications": [],
    }
    _job_store[job_id] = job
    background_tasks.add_task(_run_import, job_id, request.url)
    return _job_to_response(job)


@router.get("/jobs/{job_id}")
async def get_job(job_id: str) -> dict:
    """Get the status of an import job."""
    job = _job_store.get(job_id)
    if job is None:
        raise HTTPException(status_code=404, detail=f"Job '{job_id}' not found")
    return _job_to_response(job)


@router.get("/jobs/{job_id}/classifications")
async def get_classifications(job_id: str) -> list:
    """Get all classifications for an import job."""
    job = _job_store.get(job_id)
    if job is None:
        raise HTTPException(status_code=404, detail=f"Job '{job_id}' not found")
    classifications: list[ClassificationResult] = job.get("classifications", [])
    return [
        _classification_to_response(i, clf)
        for i, clf in enumerate(classifications)
    ]


@router.put("/jobs/{job_id}/classifications/{idx}")
async def update_classification(
    job_id: str,
    idx: int,
    request: UpdateClassificationRequest,
) -> dict:
    """Update a specific classification by index."""
    job = _job_store.get(job_id)
    if job is None:
        raise HTTPException(status_code=404, detail=f"Job '{job_id}' not found")

    classifications: list[ClassificationResult] = job.get("classifications", [])
    if idx < 0 or idx >= len(classifications):
        raise HTTPException(
            status_code=404,
            detail=f"Classification index {idx} out of range",
        )

    original = classifications[idx]
    updated = ClassificationResult(
        endpoint=original.endpoint,
        access_type=request.access_type,
        customer_params=original.customer_params,
        agent_group=request.agent_group,
        confidence=original.confidence,
        needs_interrupt=request.needs_interrupt,
    )
    new_classifications = list(classifications)
    new_classifications[idx] = updated
    _job_store[job_id] = {**job, "classifications": new_classifications}

    return _classification_to_response(idx, updated)


@router.post("/jobs/{job_id}/approve")
async def approve_job(job_id: str) -> dict:
    """Approve a job's classifications and trigger tool generation.

    Generates Python tool code for each classified endpoint and
    produces an agent YAML configuration snippet.
    """
    job = _job_store.get(job_id)
    if job is None:
        raise HTTPException(status_code=404, detail=f"Job '{job_id}' not found")

    classifications: list[ClassificationResult] = job.get("classifications", [])
    if not classifications:
        raise HTTPException(
            status_code=400,
            detail="No classifications to approve. Import must complete first.",
        )

base_url = job["spec_url"].rsplit("/", 1)[0]
|
|
    generated_tools = []
    for clf in classifications:
        tool = generate_tool_code(clf, base_url)
        generated_tools.append({
            "function_name": tool.function_name,
            "agent_group": clf.agent_group,
            "code": tool.code,
        })

    agent_yaml = generate_agent_yaml(tuple(classifications), base_url)

    updated_job = {
        **job,
        "status": "approved",
        "generated_tools": generated_tools,
        "agent_yaml": agent_yaml,
    }
    _job_store[job_id] = updated_job

    response = _job_to_response(updated_job)
    response["generated_tools_count"] = len(generated_tools)
    return response
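For reference, a sketch of the end-to-end review flow against this router. The host, the admin key value, the X-API-Key header name, the spec URL, and the agent_group value are illustrative assumptions; the real header is whatever require_admin_api_key in app/auth.py expects.

import httpx

BASE = "http://localhost:8000/api/openapi"   # assumed local deployment
HEADERS = {"X-API-Key": "change-me"}          # assumed header; value of ADMIN_API_KEY

# 1. Start an import job (returns 202 with the initial job record).
job = httpx.post(
    f"{BASE}/import",
    json={"url": "https://api.example.com/openapi.json"},
    headers=HEADERS,
).json()
job_id = job["job_id"]

# 2. Poll the job until classification has finished.
status = httpx.get(f"{BASE}/jobs/{job_id}", headers=HEADERS).json()

# 3. Review the classifications and correct one of them.
rows = httpx.get(f"{BASE}/jobs/{job_id}/classifications", headers=HEADERS).json()
httpx.put(
    f"{BASE}/jobs/{job_id}/classifications/0",
    json={"access_type": "write", "needs_interrupt": True, "agent_group": "billing"},
    headers=HEADERS,
)

# 4. Approve the job to generate tool code and the agent YAML snippet.
approved = httpx.post(f"{BASE}/jobs/{job_id}/approve", headers=HEADERS).json()
print(approved["generated_tools_count"])

The import endpoint answers 202 immediately because classification runs in a BackgroundTask, so step 2 has to poll until the job reports a terminal status.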