Files
smart-support/backend/app/openapi/generator.py
Yaojia Wang a2f750269d fix: address critical security and code review findings in Phase 3
- Wire ImportOrchestrator into review_api start_import via BackgroundTasks
- Sanitize docstrings in generated tool code to prevent code injection
- Add Literal["read", "write"] validation for access_type
- Add regex validation for agent_group
- Validate URL scheme (http/https only) in ImportRequest
- Validate LLM output fields (clamp confidence, validate access_type)
- Use dataclasses.replace instead of manual reconstruction in importer
- Expand SSRF blocked networks (Carrier-Grade NAT, IPv4-mapped IPv6, etc.)
- Make _BLOCKED_NETWORKS immutable tuple
- Use yaml.safe_dump instead of yaml.dump
- Fix _to_snake_case for empty strings and Python keywords
2026-03-31 00:28:28 +02:00

165 lines
5.2 KiB
Python

"""Tool code generator for classified OpenAPI endpoints.
Generates Python source code for LangChain @tool functions and
YAML agent configurations from classification results.
"""
from __future__ import annotations

import json
import keyword
import re

import yaml

from app.openapi.models import ClassificationResult, EndpointInfo, GeneratedTool, ParameterInfo
# Indentation unit prepended to lines of the generated tool source;
# generate_tool_code repeats it twice for lines nested inside the
# `async with` block.
# NOTE(review): shown here as a single space; PEP 8 code would use four —
# confirm the intended value.
_INDENT = " "
def generate_tool_code(classification: ClassificationResult, base_url: str) -> GeneratedTool:
    """Render a classified endpoint as source code for a LangChain ``@tool``.

    The emitted text is an ``async def`` decorated with ``@tool``. Its body
    opens an ``httpx.AsyncClient``, runs the request line produced by
    ``_build_http_call`` and returns ``response.text``. When the
    classification asks for an interrupt, a marker comment is placed right
    after the docstring. The snippet references ``tool`` and ``httpx``
    without importing them; NOTE(review): presumably whoever writes the
    generated module supplies those imports — confirm.

    Args:
        classification: Classification result whose ``endpoint`` is rendered.
        base_url: Prefix used for the request URL.

    Returns:
        GeneratedTool holding the function name, endpoint, classification and
        the generated source text in ``code``.
    """
    endpoint = classification.endpoint
    function_name = _to_snake_case(endpoint.operation_id)
    ordered = _collect_params(endpoint)
    signature = _build_signature(ordered, endpoint.request_body_schema)
    summary = _sanitize_docstring(
        endpoint.summary or endpoint.description or endpoint.operation_id
    )
    marker = _interrupt_comment(classification)
    request_line = _build_http_call(endpoint, base_url, ordered)

    inner = _INDENT * 2  # lines nested inside the `async with` block
    source_lines = [
        "@tool",
        f"async def {function_name}({signature}) -> str:",
        f'{_INDENT}"""{summary}"""',
        *([_INDENT + marker] if marker else []),
        f"{_INDENT}async with httpx.AsyncClient() as client:",
        inner + request_line,
        inner + "return response.text",
    ]
    return GeneratedTool(
        function_name=function_name,
        endpoint=endpoint,
        classification=classification,
        code="\n".join(source_lines),
    )
def generate_agent_yaml(
    classifications: tuple[ClassificationResult, ...],
    base_url: str,
) -> str:
    """Generate an ``agents.yaml`` string from a set of classification results.

    Classifications are grouped by ``agent_group``, producing one agent entry
    per group in first-seen order. Each entry has the keys ``name``,
    ``description``, ``permission`` and ``tools`` (in that order, since keys
    are not sorted). A group's ``permission`` is ``"write"`` if ANY of its
    classifications has an ``access_type`` other than ``"read"``; otherwise
    it is ``"read"``. An empty input yields ``agents: []``.

    ``base_url`` is currently unused.
    """
    groups: dict[str, dict] = {}
    for clf in classifications:
        group = clf.agent_group
        entry = groups.get(group)
        if entry is None:
            entry = groups[group] = {
                "name": group,
                "description": f"Agent for {group} operations",
                "permission": "read",
                "tools": [],
            }
        # Escalate as soon as any tool in the group writes. Deciding from the
        # first tool alone would label a mixed group "read" while it still
        # carries write tools.
        if clf.access_type != "read":
            entry["permission"] = "write"
        entry["tools"].append(_to_snake_case(clf.endpoint.operation_id))
    return yaml.safe_dump({"agents": list(groups.values())}, sort_keys=False)
# --- Private helpers ---
def _collect_params(ep: EndpointInfo) -> list[ParameterInfo]:
"""Return path params first, then query params."""
path_params = [p for p in ep.parameters if p.location == "path"]
other_params = [p for p in ep.parameters if p.location != "path"]
return path_params + other_params
def _build_signature(params: list[ParameterInfo], body_schema: dict | None) -> str:
"""Build a Python function signature string from parameters."""
parts: list[str] = []
for p in params:
py_type = _schema_type_to_python(p.schema_type)
if p.required:
parts.append(f"{p.name}: {py_type}")
else:
parts.append(f"{p.name}: {py_type} | None = None")
if body_schema:
parts.append("body: dict | None = None")
return ", ".join(parts)
def _build_http_call(
ep: EndpointInfo, base_url: str, params: list[ParameterInfo]
) -> str:
"""Build the httpx client call line."""
method = ep.method.lower()
path = ep.path
# Replace path parameters with f-string expressions
for p in params:
if p.location == "path":
path = path.replace(f"{{{p.name}}}", f"{{{p.name}}}")
url_expr = f'f"{base_url}{path}"'
query_params = [p for p in params if p.location == "query"]
extra_args = []
if query_params:
qp_dict = "{" + ", ".join(f'"{p.name}": {p.name}' for p in query_params) + "}"
extra_args.append(f"params={qp_dict}")
if ep.request_body_schema and method in ("post", "put", "patch"):
extra_args.append("json=body")
args_str = ", ".join([url_expr] + extra_args)
return f"response = await client.{method}({args_str})"
def _interrupt_comment(classification: ClassificationResult) -> str:
"""Return a comment line if the endpoint requires interrupt/approval."""
if classification.needs_interrupt:
return "# INTERRUPT: requires human approval before execution"
return ""
def _schema_type_to_python(schema_type: str) -> str:
"""Map OpenAPI schema type to Python type annotation."""
mapping = {
"string": "str",
"integer": "int",
"number": "float",
"boolean": "bool",
"array": "list",
"object": "dict",
}
return mapping.get(schema_type, "str")
def _sanitize_docstring(text: str) -> str:
"""Escape triple-quotes and newlines to prevent docstring injection."""
return text.replace("\\", "\\\\").replace('"""', r"\"\"\"").replace("\n", " ")
def _to_snake_case(name: str) -> str:
"""Convert operationId to a valid snake_case Python identifier."""
clean = re.sub(r"[^a-zA-Z0-9]+", "_", name).strip("_")
result = clean.lower() or "unnamed_tool"
if keyword.iskeyword(result):
result = f"{result}_tool"
return result