"""OpenAPI endpoint classifier.

Classifies endpoints into read/write access types and identifies
customer-identifying parameters. Provides a rule-based heuristic classifier
and an LLM-backed classifier with heuristic fallback.
"""

from __future__ import annotations

import json
import math
import re
from typing import Protocol

import structlog

from app.openapi.models import ClassificationResult, EndpointInfo

logger = structlog.get_logger()

# HTTP methods the heuristic treats as state-changing ("write").
_WRITE_METHODS = frozenset({"POST", "PUT", "PATCH", "DELETE"})
# HTTP methods that require human approval. Kept separate from _WRITE_METHODS
# so the interrupt policy can diverge from the access-type policy.
_INTERRUPT_METHODS = frozenset({"POST", "PUT", "PATCH", "DELETE"})

# Access types accepted from the LLM; anything else falls back to the heuristic.
_VALID_ACCESS_TYPES = frozenset({"read", "write"})
# Used when the LLM omits or garbles the agent_group / confidence fields.
_DEFAULT_AGENT_GROUP = "support"
_DEFAULT_LLM_CONFIDENCE = 0.8
# Fixed confidence reported for purely method-based classifications.
_HEURISTIC_CONFIDENCE = 0.7

# Parameter names that identify the customer/order context. Matched against
# the whole name, case-insensitively. Accepts snake_case (customer_id),
# kebab-case (customer-id), camelCase (customerId) and run-together
# (customerid) spellings; the original snake_case names still match.
_CUSTOMER_PARAM_PATTERNS = re.compile(
    r"(customer|order|user|account|client|subscriber)[_-]?id",
    re.IGNORECASE,
)


class ClassifierProtocol(Protocol):
    """Protocol for endpoint classifiers."""

    async def classify(
        self, endpoints: tuple[EndpointInfo, ...]
    ) -> tuple[ClassificationResult, ...]: ...


class HeuristicClassifier:
    """Rule-based endpoint classifier.

    GET -> read, no interrupt.
    POST/PUT/PATCH/DELETE -> write, needs_interrupt=True.
    Method names are compared case-insensitively.
    """

    async def classify(
        self, endpoints: tuple[EndpointInfo, ...]
    ) -> tuple[ClassificationResult, ...]:
        """Classify endpoints using HTTP method heuristics.

        Returns one ClassificationResult per endpoint, in input order, or an
        empty tuple when *endpoints* is empty.
        """
        if not endpoints:
            return ()
        return tuple(_classify_one(ep) for ep in endpoints)


def _normalize_method(method: object) -> str:
    """Return the HTTP method upper-cased so comparisons ignore case."""
    # str() first so an unexpected non-str value degrades to "not a write
    # method" instead of raising inside the fallback path.
    return str(method).upper()


def _heuristic_access_type(ep: EndpointInfo) -> str:
    """Return "write" for write HTTP methods, otherwise "read"."""
    return "write" if _normalize_method(ep.method) in _WRITE_METHODS else "read"


def _heuristic_needs_interrupt(ep: EndpointInfo) -> bool:
    """Return True when the endpoint's HTTP method requires human approval."""
    return _normalize_method(ep.method) in _INTERRUPT_METHODS


def _classify_one(ep: EndpointInfo) -> ClassificationResult:
    """Classify a single endpoint using heuristics."""
    access_type = _heuristic_access_type(ep)
    agent_group = "write_agent" if access_type == "write" else "read_agent"
    return ClassificationResult(
        endpoint=ep,
        access_type=access_type,
        customer_params=_detect_customer_params(ep),
        agent_group=agent_group,
        confidence=_HEURISTIC_CONFIDENCE,
        needs_interrupt=_heuristic_needs_interrupt(ep),
    )


def _detect_customer_params(ep: EndpointInfo) -> tuple[str, ...]:
    """Extract parameter names that identify the customer/order context.

    Names are returned in the order they appear in ``ep.parameters``.
    """
    return tuple(
        p.name for p in ep.parameters if _CUSTOMER_PARAM_PATTERNS.fullmatch(p.name)
    )


class LLMClassifier:
    """LLM-backed endpoint classifier with heuristic fallback.

    Uses an LLM to classify endpoints with higher accuracy. Falls back to
    HeuristicClassifier on any LLM error, including unparseable output.
    """

    def __init__(self, llm: object) -> None:
        # NOTE(review): llm must expose an awaitable ``ainvoke(prompt)`` whose
        # result has a text ``.content`` attribute (LangChain-style chat model,
        # presumably) — confirm against callers.
        self._llm = llm
        self._fallback = HeuristicClassifier()

    async def classify(
        self, endpoints: tuple[EndpointInfo, ...]
    ) -> tuple[ClassificationResult, ...]:
        """Classify endpoints using LLM with heuristic fallback.

        Returns one ClassificationResult per endpoint, in input order, or an
        empty tuple when *endpoints* is empty. Never raises for LLM/parse
        failures; those are logged and the heuristic result is returned.
        """
        if not endpoints:
            return ()
        try:
            return await self._classify_with_llm(endpoints)
        except Exception:
            # Deliberate boundary catch: any LLM transport or parsing problem
            # must degrade to the heuristic rather than fail classification.
            logger.warning(
                "LLM classification failed, falling back to heuristic",
                exc_info=True,
            )
            return await self._fallback.classify(endpoints)

    async def _classify_with_llm(
        self, endpoints: tuple[EndpointInfo, ...]
    ) -> tuple[ClassificationResult, ...]:
        """Attempt LLM-based classification.

        Raises:
            ValueError: if the LLM response is not text or cannot be parsed.
        """
        prompt = _build_classification_prompt(endpoints)
        response = await self._llm.ainvoke(prompt)
        content = response.content
        if not isinstance(content, str):
            raise ValueError(
                f"LLM response content is not text: {type(content).__name__}"
            )
        return _parse_llm_response(content, endpoints)


def _build_classification_prompt(endpoints: tuple[EndpointInfo, ...]) -> str:
    """Build a prompt for classifying endpoints.

    Endpoints are listed with their positional index, and the model is asked
    to answer in the same order because _parse_llm_response pairs results
    with endpoints positionally.
    """

    def _describe(index: int, ep: EndpointInfo) -> str:
        # Omit the description entirely instead of printing a literal "None".
        description = ep.summary or ep.operation_id
        line = f"{index}: {ep.method} {ep.path}"
        return f"{line} - {description}" if description else line

    endpoint_list = "\n".join(_describe(i, ep) for i, ep in enumerate(endpoints))
    return (
        "Classify each API endpoint as 'read' or 'write'. "
        "For each, determine if it needs human interrupt approval, "
        "identify customer-identifying parameters, and assign an agent_group.\n\n"
        f"Endpoints:\n{endpoint_list}\n\n"
        "Respond with a JSON array with one object per endpoint, "
        "in the same order as the endpoints above:\n"
        '[{"access_type": "read|write", "agent_group": "...", '
        '"needs_interrupt": true|false, "customer_params": [...], "confidence": 0.0-1.0}]'
    )


def _parse_llm_response(
    content: str, endpoints: tuple[EndpointInfo, ...]
) -> tuple[ClassificationResult, ...]:
    """Parse LLM JSON response into ClassificationResult instances.

    Items are paired with *endpoints* by position. Individual malformed
    fields fall back to per-field defaults (see the ``_coerce_*`` helpers).

    Raises:
        ValueError: if no JSON array of objects is found in *content*, or its
            length does not match the number of endpoints.
    """
    items = _extract_json_array(content)
    if len(items) != len(endpoints):
        raise ValueError(
            f"LLM returned {len(items)} items for {len(endpoints)} endpoints"
        )
    return tuple(
        _result_from_llm_item(ep, item)
        for ep, item in zip(endpoints, items, strict=True)
    )


def _extract_json_array(content: str) -> list[dict[str, object]]:
    """Return the first JSON array of objects embedded in *content*.

    Tries to decode a JSON value starting at each ``[`` in turn. This
    tolerates prose or code fences around the array, including text after
    it that contains brackets.

    Raises:
        ValueError: if no decodable array whose items are all objects exists.
    """
    decoder = json.JSONDecoder()
    start = content.find("[")
    while start != -1:
        try:
            value, _end = decoder.raw_decode(content, start)
        except json.JSONDecodeError:
            pass
        else:
            # Decoding from "[" always yields a list; require object items so
            # stray bracketed prose such as "[1]" is skipped.
            if all(isinstance(entry, dict) for entry in value):
                return value
        start = content.find("[", start + 1)
    raise ValueError(f"No JSON array found in LLM response: {content!r}")


def _result_from_llm_item(
    ep: EndpointInfo, item: dict[str, object]
) -> ClassificationResult:
    """Build a ClassificationResult for *ep* from one LLM JSON object."""
    return ClassificationResult(
        endpoint=ep,
        access_type=_coerce_access_type(item.get("access_type"), ep),
        customer_params=_coerce_customer_params(item.get("customer_params"), ep),
        agent_group=_coerce_agent_group(item.get("agent_group")),
        confidence=_coerce_confidence(item.get("confidence")),
        needs_interrupt=_coerce_needs_interrupt(item.get("needs_interrupt"), ep),
    )


def _coerce_access_type(raw: object, ep: EndpointInfo) -> str:
    """Return "read"/"write" from the LLM, else the method-based heuristic.

    Falling back to the heuristic (rather than a fixed "read") prevents a
    garbled answer from silently downgrading a DELETE/POST to read access.
    """
    if isinstance(raw, str):
        normalized = raw.strip().lower()
        if normalized in _VALID_ACCESS_TYPES:
            return normalized
    return _heuristic_access_type(ep)


def _coerce_needs_interrupt(raw: object, ep: EndpointInfo) -> bool:
    """Return the LLM's boolean, else the method-based heuristic.

    Only real JSON booleans are trusted; ``bool("false")`` would be True, and
    a missing field must not skip approval for write methods.
    """
    if isinstance(raw, bool):
        return raw
    return _heuristic_needs_interrupt(ep)


def _coerce_agent_group(raw: object) -> str:
    """Return the stripped agent_group, or the default if missing/blank/non-text."""
    if isinstance(raw, str) and raw.strip():
        return raw.strip()
    return _DEFAULT_AGENT_GROUP


def _coerce_confidence(raw: object) -> float:
    """Return confidence clamped to [0, 1], or the default if unusable.

    NaN is rejected explicitly because ``max(0, min(1, nan))`` evaluates to
    1.0, which would report full confidence for a meaningless value.
    """
    try:
        value = float(raw)  # type: ignore[arg-type]
    except (TypeError, ValueError, OverflowError):
        return _DEFAULT_LLM_CONFIDENCE
    if math.isnan(value):
        return _DEFAULT_LLM_CONFIDENCE
    return max(0.0, min(1.0, value))


def _coerce_customer_params(raw: object, ep: EndpointInfo) -> tuple[str, ...]:
    """Return customer-identifying parameter names from the LLM answer.

    A bare string is treated as a single name (``tuple("x_id")`` would split
    it into characters). Non-string entries are dropped. A missing or
    non-list value falls back to heuristic detection; an explicit empty list
    is respected.
    """
    if isinstance(raw, str):
        raw = [raw]
    if not isinstance(raw, list):
        return _detect_customer_params(ep)
    return tuple(
        name.strip() for name in raw if isinstance(name, str) and name.strip()
    )