feat: complete phase 3 -- OpenAPI auto-discovery, SSRF protection, tool generation
- SSRF protection: private IP blocking, DNS rebinding defense, redirect validation - OpenAPI fetcher with SSRF guard, JSON/YAML auto-detection, 10MB limit - Structural spec validator (3.0.x/3.1.x) - Endpoint parser with $ref resolution, auto-generated operation IDs - Heuristic + LLM endpoint classifier with Protocol interface - Review API at /api/openapi (import, job status, classification CRUD, approve) - @tool code generator + Agent YAML generator - Import orchestrator (fetch -> validate -> parse -> classify pipeline) - 125 new tests, 322 total passing, 93.23% coverage
This commit is contained in:
156
backend/app/openapi/generator.py
Normal file
156
backend/app/openapi/generator.py
Normal file
@@ -0,0 +1,156 @@
|
||||
"""Tool code generator for classified OpenAPI endpoints.
|
||||
|
||||
Generates Python source code for LangChain @tool functions and
|
||||
YAML agent configurations from classification results.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
|
||||
import yaml
|
||||
|
||||
from app.openapi.models import ClassificationResult, EndpointInfo, GeneratedTool, ParameterInfo
|
||||
|
||||
# One indentation unit; generate_tool_code prefixes it (once or twice) to the
# body lines of the function source it emits.
# NOTE(review): the value here is a single space - confirm whether four spaces
# (PEP 8) were intended for the generated code.
_INDENT = " "
|
||||
|
||||
|
||||
def generate_tool_code(classification: ClassificationResult, base_url: str) -> GeneratedTool:
    """Generate a LangChain ``@tool`` function for a classified endpoint.

    The emitted source is an ``async def`` decorated with ``@tool`` whose body
    opens an ``httpx.AsyncClient``, performs the request and returns
    ``response.text``. The source contains no import statements.

    Args:
        classification: Classification result whose ``endpoint`` describes the
            operation to wrap.
        base_url: URL prefix placed in front of the endpoint path.

    Returns:
        A GeneratedTool carrying the function name, the endpoint, the
        classification and the generated source code as a string.
    """
    ep = classification.endpoint
    func_name = _to_snake_case(ep.operation_id)
    params = _collect_params(ep)
    sig = _build_signature(params, ep.request_body_schema)

    # The text comes from the imported spec and is pasted between triple
    # quotes. Escape backslashes first, then double quotes, so the text can
    # neither terminate the literal early nor smuggle statements into the
    # generated source. Text without those characters is emitted unchanged.
    raw_doc = ep.summary or ep.description or ep.operation_id
    docstring = raw_doc.replace("\\", "\\\\").replace('"', '\\"')

    interrupt_comment = _interrupt_comment(classification)
    http_call = _build_http_call(ep, base_url, params)

    lines = [
        "@tool",
        f"async def {func_name}({sig}) -> str:",
        f'{_INDENT}"""{docstring}"""',
    ]
    if interrupt_comment:
        lines.append(f"{_INDENT}{interrupt_comment}")
    lines += [
        f"{_INDENT}async with httpx.AsyncClient() as client:",
        f"{_INDENT}{_INDENT}{http_call}",
        f"{_INDENT}{_INDENT}return response.text",
    ]

    code = "\n".join(lines)
    return GeneratedTool(
        function_name=func_name,
        endpoint=ep,
        classification=classification,
        code=code,
    )
|
||||
|
||||
|
||||
def generate_agent_yaml(
    classifications: tuple[ClassificationResult, ...],
    base_url: str,
) -> str:
    """Generate an agents.yaml string from a set of classification results.

    Tools are grouped by ``agent_group``; each group becomes one agent entry
    with ``name``, ``description``, ``permission`` and ``tools`` keys, in
    first-seen group order.

    A group's permission is ``"read"`` only when every endpoint in it has
    ``access_type == "read"``; a single non-read endpoint makes the whole
    group ``"write"``, regardless of the order of ``classifications``.

    Args:
        classifications: Classification results to turn into agent entries.
        base_url: Not used by this function; kept so the signature stays
            unchanged for callers.

    Returns:
        The YAML document as a string. With no classifications the document
        holds an empty ``agents`` list.
    """
    if not classifications:
        return yaml.dump({"agents": []})

    groups: dict[str, dict] = {}
    for clf in classifications:
        group = clf.agent_group
        entry = groups.get(group)
        if entry is None:
            entry = groups[group] = {
                "name": group,
                "description": f"Agent for {group} operations",
                "permission": "read",
                "tools": [],
            }
        # Escalate instead of trusting whichever endpoint happened to be
        # listed first: an agent that owns a write tool needs "write".
        if clf.access_type != "read":
            entry["permission"] = "write"
        entry["tools"].append(_to_snake_case(clf.endpoint.operation_id))

    return yaml.dump({"agents": list(groups.values())}, sort_keys=False)
|
||||
|
||||
|
||||
# --- Private helpers ---
|
||||
|
||||
|
||||
def _collect_params(ep: EndpointInfo) -> list[ParameterInfo]:
    """Order an endpoint's parameters with the path parameters in front.

    Every parameter whose location is not ``"path"`` follows the path
    parameters. Within each of the two groups the order of ``ep.parameters``
    is kept.
    """
    # sorted() is stable and False orders before True, so the key moves path
    # parameters to the front without disturbing relative order.
    return sorted(ep.parameters, key=lambda param: param.location != "path")
|
||||
|
||||
|
||||
def _build_signature(params: list[ParameterInfo], body_schema: dict | None) -> str:
    """Build the parameter list of the generated function as a string.

    Required parameters are emitted first, followed by optional ones (which
    default to ``None``). Python rejects a parameter without a default after
    one with a default, so emitting them in spec order could produce source
    that does not compile. Relative order inside each group is preserved.

    Args:
        params: Parameters of the endpoint; ``name``, ``schema_type`` and
            ``required`` are read from each.
        body_schema: Request body schema; when truthy an optional ``body``
            argument is appended at the end.

    Returns:
        The comma-separated parameter list, or an empty string when there is
        nothing to declare.
    """
    required_parts: list[str] = []
    optional_parts: list[str] = []
    for p in params:
        py_type = _schema_type_to_python(p.schema_type)
        if p.required:
            required_parts.append(f"{p.name}: {py_type}")
        else:
            optional_parts.append(f"{p.name}: {py_type} | None = None")
    if body_schema:
        # The body always has a default, so it belongs with the optionals.
        optional_parts.append("body: dict | None = None")
    return ", ".join(required_parts + optional_parts)
|
||||
|
||||
|
||||
def _build_http_call(
    ep: EndpointInfo, base_url: str, params: list[ParameterInfo]
) -> str:
    """Build the ``response = await client.<method>(...)`` statement.

    The URL is emitted as an f-string literal. Only ``{name}`` segments that
    match a declared path parameter (and are valid identifiers) stay live
    placeholders; every other character of ``base_url`` and the path is
    escaped so it is reproduced literally and cannot be evaluated as code or
    terminate the string.

    Query parameters are passed via ``params=``. When at least one of them is
    optional, entries whose value is ``None`` are filtered out at call time,
    because httpx would otherwise send them as empty ``name=`` pairs.

    ``json=body`` is added for POST, PUT and PATCH when the endpoint has a
    request body schema.

    Args:
        ep: Endpoint; ``method``, ``path`` and ``request_body_schema`` are read.
        base_url: URL prefix placed in front of the path.
        params: Parameters of the endpoint; only those located in ``path`` and
            ``query`` influence the generated call.

    Returns:
        A single line of Python source.
    """
    method = ep.method.lower()
    path_names = {p.name for p in params if p.location == "path"}

    # With one capture group, re.split alternates literal text (even indices)
    # and the text found between braces (odd indices).
    rendered: list[str] = []
    for index, piece in enumerate(re.split(r"\{([^{}]*)\}", ep.path)):
        is_placeholder = index % 2 == 1
        if is_placeholder and piece in path_names and piece.isidentifier():
            rendered.append("{" + piece + "}")
        elif is_placeholder:
            # Undeclared or non-identifier segment: keep it as literal text.
            rendered.append(_escape_fstring_literal("{" + piece + "}"))
        else:
            rendered.append(_escape_fstring_literal(piece))
    url_expr = 'f"' + _escape_fstring_literal(base_url) + "".join(rendered) + '"'

    query_params = [p for p in params if p.location == "query"]
    extra_args = []
    if query_params:
        qp_dict = "{" + ", ".join(f'"{p.name}": {p.name}' for p in query_params) + "}"
        if all(p.required for p in query_params):
            extra_args.append(f"params={qp_dict}")
        else:
            # Drop optionals that were left at their None default.
            extra_args.append(
                f"params={{k: v for k, v in {qp_dict}.items() if v is not None}}"
            )

    if ep.request_body_schema and method in ("post", "put", "patch"):
        extra_args.append("json=body")

    args_str = ", ".join([url_expr] + extra_args)
    return f"response = await client.{method}({args_str})"


def _escape_fstring_literal(text: str) -> str:
    """Escape *text* for the literal part of a double-quoted f-string."""
    # unicode_escape doubles backslashes and escapes control/non-ASCII
    # characters; quotes and braces are handled afterwards so the backslash
    # added in front of a quote is not doubled again.
    escaped = text.encode("unicode_escape").decode("ascii")
    return escaped.replace('"', '\\"').replace("{", "{{").replace("}", "}}")
|
||||
|
||||
|
||||
def _interrupt_comment(classification: ClassificationResult) -> str:
    """Give the approval marker comment for endpoints flagged ``needs_interrupt``.

    An empty string is returned when the flag is falsy.
    """
    marker = "# INTERRUPT: requires human approval before execution"
    return marker if classification.needs_interrupt else ""
|
||||
|
||||
|
||||
def _schema_type_to_python(schema_type: str) -> str:
    """Translate an OpenAPI schema type name into a Python annotation name.

    Names outside the six handled here fall back to ``"str"``.
    """
    python_names = dict(
        string="str",
        integer="int",
        number="float",
        boolean="bool",
        array="list",
        object="dict",
    )
    if schema_type in python_names:
        return python_names[schema_type]
    return "str"
|
||||
|
||||
|
||||
def _to_snake_case(name: str) -> str:
    """Convert an operationId into a lower-case, valid Python identifier.

    Each run of characters other than ASCII letters and digits becomes one
    underscore, leading/trailing underscores are stripped and the result is
    lower-cased. camelCase boundaries are not split.

    The result is then made safe to use after ``def``:
      * an empty result becomes ``"operation"``;
      * a result starting with a digit is prefixed with ``"op_"``;
      * a Python keyword gets a trailing underscore.

    Args:
        name: The operationId taken from the spec.

    Returns:
        A non-empty identifier that is not a reserved keyword.
    """
    import keyword  # local import: only needed by this helper

    clean = re.sub(r"[^a-zA-Z0-9]+", "_", name).strip("_").lower()
    if not clean:
        return "operation"
    if clean[0].isdigit():
        clean = f"op_{clean}"
    if keyword.iskeyword(clean):
        clean += "_"
    return clean
|
||||
Reference in New Issue
Block a user