"""Tool code generator for classified OpenAPI endpoints. Generates Python source code for LangChain @tool functions and YAML agent configurations from classification results. """ from __future__ import annotations import re import yaml from app.openapi.models import ClassificationResult, EndpointInfo, GeneratedTool, ParameterInfo _INDENT = " " def generate_tool_code(classification: ClassificationResult, base_url: str) -> GeneratedTool: """Generate a LangChain @tool function for a classified endpoint. Returns a GeneratedTool with the function source code as a string. """ ep = classification.endpoint func_name = _to_snake_case(ep.operation_id) params = _collect_params(ep) sig = _build_signature(params, ep.request_body_schema) docstring = ep.summary or ep.description or ep.operation_id interrupt_comment = _interrupt_comment(classification) http_call = _build_http_call(ep, base_url, params) lines = [ "@tool", f"async def {func_name}({sig}) -> str:", f'{_INDENT}"""{docstring}"""', ] if interrupt_comment: lines.append(f"{_INDENT}{interrupt_comment}") lines += [ f"{_INDENT}async with httpx.AsyncClient() as client:", f"{_INDENT}{_INDENT}{http_call}", f"{_INDENT}{_INDENT}return response.text", ] code = "\n".join(lines) return GeneratedTool( function_name=func_name, endpoint=ep, classification=classification, code=code, ) def generate_agent_yaml( classifications: tuple[ClassificationResult, ...], base_url: str, ) -> str: """Generate an agents.yaml string from a set of classification results. Groups tools by agent_group, creating one agent entry per group. """ if not classifications: return yaml.dump({"agents": []}) groups: dict[str, dict] = {} for clf in classifications: group = clf.agent_group func_name = _to_snake_case(clf.endpoint.operation_id) if group not in groups: permission = "read" if clf.access_type == "read" else "write" groups[group] = { "name": group, "description": f"Agent for {group} operations", "permission": permission, "tools": [], } groups[group]["tools"].append(func_name) return yaml.dump({"agents": list(groups.values())}, sort_keys=False) # --- Private helpers --- def _collect_params(ep: EndpointInfo) -> list[ParameterInfo]: """Return path params first, then query params.""" path_params = [p for p in ep.parameters if p.location == "path"] other_params = [p for p in ep.parameters if p.location != "path"] return path_params + other_params def _build_signature(params: list[ParameterInfo], body_schema: dict | None) -> str: """Build a Python function signature string from parameters.""" parts: list[str] = [] for p in params: py_type = _schema_type_to_python(p.schema_type) if p.required: parts.append(f"{p.name}: {py_type}") else: parts.append(f"{p.name}: {py_type} | None = None") if body_schema: parts.append("body: dict | None = None") return ", ".join(parts) def _build_http_call( ep: EndpointInfo, base_url: str, params: list[ParameterInfo] ) -> str: """Build the httpx client call line.""" method = ep.method.lower() path = ep.path # Replace path parameters with f-string expressions for p in params: if p.location == "path": path = path.replace(f"{{{p.name}}}", f"{{{p.name}}}") url_expr = f'f"{base_url}{path}"' query_params = [p for p in params if p.location == "query"] extra_args = [] if query_params: qp_dict = "{" + ", ".join(f'"{p.name}": {p.name}' for p in query_params) + "}" extra_args.append(f"params={qp_dict}") if ep.request_body_schema and method in ("post", "put", "patch"): extra_args.append("json=body") args_str = ", ".join([url_expr] + extra_args) return f"response = await client.{method}({args_str})" def _interrupt_comment(classification: ClassificationResult) -> str: """Return a comment line if the endpoint requires interrupt/approval.""" if classification.needs_interrupt: return "# INTERRUPT: requires human approval before execution" return "" def _schema_type_to_python(schema_type: str) -> str: """Map OpenAPI schema type to Python type annotation.""" mapping = { "string": "str", "integer": "int", "number": "float", "boolean": "bool", "array": "list", "object": "dict", } return mapping.get(schema_type, "str") def _to_snake_case(name: str) -> str: """Convert operationId to a valid snake_case Python identifier.""" # Replace non-alphanumeric with underscore clean = re.sub(r"[^a-zA-Z0-9]+", "_", name).strip("_") return clean.lower()