- SSRF protection: private IP blocking, DNS rebinding defense, redirect validation - OpenAPI fetcher with SSRF guard, JSON/YAML auto-detection, 10MB limit - Structural spec validator (3.0.x/3.1.x) - Endpoint parser with $ref resolution, auto-generated operation IDs - Heuristic + LLM endpoint classifier with Protocol interface - Review API at /api/openapi (import, job status, classification CRUD, approve) - @tool code generator + Agent YAML generator - Import orchestrator (fetch -> validate -> parse -> classify pipeline) - 125 new tests, 322 total passing, 93.23% coverage
157 lines
4.9 KiB
Python
157 lines
4.9 KiB
Python
"""Tool code generator for classified OpenAPI endpoints.
|
|
|
|
Generates Python source code for LangChain @tool functions and
|
|
YAML agent configurations from classification results.
|
|
"""
|
|
|
|
from __future__ import annotations

import keyword
import re

import yaml

from app.openapi.models import ClassificationResult, EndpointInfo, GeneratedTool, ParameterInfo
|
|
|
|
_INDENT = " "
|
|
|
|
|
|
def generate_tool_code(classification: ClassificationResult, base_url: str) -> GeneratedTool:
    """Generate a LangChain @tool function for a classified endpoint.

    The emitted source is an ``async def`` decorated with ``@tool`` that makes
    a single request inside ``async with httpx.AsyncClient()`` and returns
    ``response.text``.  The generated code references ``tool`` and ``httpx``
    but does not import them.  When the classification needs an interrupt, a
    marker comment is placed right after the docstring.

    Args:
        classification: Classified endpoint to render.
        base_url: Prefix for the endpoint path in the generated request URL.

    Returns:
        A GeneratedTool holding the function name, endpoint, classification
        and the function source code as a string.
    """
    ep = classification.endpoint
    func_name = _to_snake_case(ep.operation_id)
    params = _collect_params(ep)
    sig = _build_signature(params, ep.request_body_schema)

    # The docstring text comes from the imported (untrusted) spec.  Escape
    # backslashes and double quotes so it cannot terminate the generated
    # triple-quoted string early: a summary ending in a double quote, or one
    # containing three of them, would otherwise be a SyntaxError or inject
    # code into the tool.  Newlines are legal inside a triple-quoted string
    # and are kept as-is.
    doc_escapes = str.maketrans({"\\": "\\\\", '"': '\\"', "\r": "\\r", "\0": "\\x00"})
    docstring = (ep.summary or ep.description or ep.operation_id).translate(doc_escapes)

    interrupt_comment = _interrupt_comment(classification)
    http_call = _build_http_call(ep, base_url, params)

    lines = [
        "@tool",
        f"async def {func_name}({sig}) -> str:",
        f'{_INDENT}"""{docstring}"""',
    ]
    if interrupt_comment:
        lines.append(f"{_INDENT}{interrupt_comment}")
    lines += [
        f"{_INDENT}async with httpx.AsyncClient() as client:",
        f"{_INDENT}{_INDENT}{http_call}",
        f"{_INDENT}{_INDENT}return response.text",
    ]

    return GeneratedTool(
        function_name=func_name,
        endpoint=ep,
        classification=classification,
        code="\n".join(lines),
    )
|
|
|
|
|
|
def generate_agent_yaml(
    classifications: tuple[ClassificationResult, ...],
    base_url: str,
) -> str:
    """Generate an agents.yaml string from a set of classification results.

    Tools are grouped by ``agent_group``, creating one agent entry per group in
    first-seen order.  A group's ``permission`` is ``"write"`` if ANY of its
    tools has an ``access_type`` other than ``"read"``, otherwise ``"read"``.

    Args:
        classifications: Classification results to group into agents.
        base_url: Currently unused; kept for interface compatibility.

    Returns:
        YAML text with a top-level ``agents`` list (``agents: []`` when no
        classifications are given).
    """
    groups: dict[str, dict] = {}
    for clf in classifications:
        group = clf.agent_group
        entry = groups.get(group)
        if entry is None:
            entry = {
                "name": group,
                "description": f"Agent for {group} operations",
                "permission": "read",
                "tools": [],
            }
            groups[group] = entry
        # Escalate instead of trusting the group's first tool: one mutating
        # tool means the agent as a whole needs write permission.
        if clf.access_type != "read":
            entry["permission"] = "write"
        entry["tools"].append(_to_snake_case(clf.endpoint.operation_id))

    # sort_keys=False keeps name/description/permission/tools in the order above.
    return yaml.dump({"agents": list(groups.values())}, sort_keys=False)
|
|
|
|
|
|
# --- Private helpers ---
|
|
|
|
|
|
def _collect_params(ep: EndpointInfo) -> list[ParameterInfo]:
|
|
"""Return path params first, then query params."""
|
|
path_params = [p for p in ep.parameters if p.location == "path"]
|
|
other_params = [p for p in ep.parameters if p.location != "path"]
|
|
return path_params + other_params
|
|
|
|
|
|
def _build_signature(params: list[ParameterInfo], body_schema: dict | None) -> str:
|
|
"""Build a Python function signature string from parameters."""
|
|
parts: list[str] = []
|
|
for p in params:
|
|
py_type = _schema_type_to_python(p.schema_type)
|
|
if p.required:
|
|
parts.append(f"{p.name}: {py_type}")
|
|
else:
|
|
parts.append(f"{p.name}: {py_type} | None = None")
|
|
if body_schema:
|
|
parts.append("body: dict | None = None")
|
|
return ", ".join(parts)
|
|
|
|
|
|
def _build_http_call(
|
|
ep: EndpointInfo, base_url: str, params: list[ParameterInfo]
|
|
) -> str:
|
|
"""Build the httpx client call line."""
|
|
method = ep.method.lower()
|
|
path = ep.path
|
|
|
|
# Replace path parameters with f-string expressions
|
|
for p in params:
|
|
if p.location == "path":
|
|
path = path.replace(f"{{{p.name}}}", f"{{{p.name}}}")
|
|
|
|
url_expr = f'f"{base_url}{path}"'
|
|
|
|
query_params = [p for p in params if p.location == "query"]
|
|
extra_args = []
|
|
if query_params:
|
|
qp_dict = "{" + ", ".join(f'"{p.name}": {p.name}' for p in query_params) + "}"
|
|
extra_args.append(f"params={qp_dict}")
|
|
|
|
if ep.request_body_schema and method in ("post", "put", "patch"):
|
|
extra_args.append("json=body")
|
|
|
|
args_str = ", ".join([url_expr] + extra_args)
|
|
return f"response = await client.{method}({args_str})"
|
|
|
|
|
|
def _interrupt_comment(classification: ClassificationResult) -> str:
|
|
"""Return a comment line if the endpoint requires interrupt/approval."""
|
|
if classification.needs_interrupt:
|
|
return "# INTERRUPT: requires human approval before execution"
|
|
return ""
|
|
|
|
|
|
def _schema_type_to_python(schema_type: str) -> str:
|
|
"""Map OpenAPI schema type to Python type annotation."""
|
|
mapping = {
|
|
"string": "str",
|
|
"integer": "int",
|
|
"number": "float",
|
|
"boolean": "bool",
|
|
"array": "list",
|
|
"object": "dict",
|
|
}
|
|
return mapping.get(schema_type, "str")
|
|
|
|
|
|
def _to_snake_case(name: str) -> str:
|
|
"""Convert operationId to a valid snake_case Python identifier."""
|
|
# Replace non-alphanumeric with underscore
|
|
clean = re.sub(r"[^a-zA-Z0-9]+", "_", name).strip("_")
|
|
return clean.lower()
|