- API versioning: all REST endpoints prefixed with /api/v1/
- Structured logging: replaced stdlib logging with structlog (console/JSON modes; see the sketch below)
- Alembic migrations: versioned DB schema with initial migration
- Error standardization: global exception handlers for consistent envelope format
- Interrupt cleanup: asyncio background task for expired interrupt removal
- Integration tests: +30 tests (analytics, replay, openapi, error, session APIs)
- Frontend tests: +57 tests (all components, pages, useWebSocket hook)
- Backend: 557 tests, 89.75% coverage | Frontend: 80 tests, 16 test files
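A minimal sketch of the console/JSON switch mentioned above, assuming a structlog setup along these lines (the `configure_logging` name and `json_logs` flag are illustrative, not taken from the codebase):

```python
import structlog

def configure_logging(json_logs: bool) -> None:
    """Render logs for the console in dev, as JSON in production."""
    renderer = (
        structlog.processors.JSONRenderer()
        if json_logs
        else structlog.dev.ConsoleRenderer()
    )
    structlog.configure(
        processors=[
            structlog.contextvars.merge_contextvars,
            structlog.processors.add_log_level,
            structlog.processors.TimeStamper(fmt="iso"),
            renderer,
        ],
    )
```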
170 lines
5.7 KiB
Python
"""OpenAPI endpoint classifier.
|
|
|
|
Classifies endpoints into read/write access types and identifies
|
|
customer-identifying parameters. Provides a rule-based heuristic
|
|
classifier and an LLM-backed classifier with heuristic fallback.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import re
|
|
from typing import Protocol
|
|
|
|
import structlog
|
|
|
|
from app.openapi.models import ClassificationResult, EndpointInfo
|
|
|
|
logger = structlog.get_logger()
|
|
|
|
_WRITE_METHODS = frozenset({"POST", "PUT", "PATCH", "DELETE"})
|
|
_INTERRUPT_METHODS = frozenset({"POST", "PUT", "PATCH", "DELETE"})
|
|
|
|
# Parameter names that identify the customer/order context
|
|
_CUSTOMER_PARAM_PATTERNS = re.compile(
|
|
r"(customer_id|order_id|user_id|account_id|client_id|subscriber_id)",
|
|
re.IGNORECASE,
|
|
)
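# Illustrative matches: "customer_id", "ORDER_ID", and "Account_Id" all match
# thanks to re.IGNORECASE, while "customer_identifier" does not, because
# _detect_customer_params() applies the pattern with fullmatch().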
class ClassifierProtocol(Protocol):
    """Protocol for endpoint classifiers."""

    async def classify(
        self, endpoints: tuple[EndpointInfo, ...]
    ) -> tuple[ClassificationResult, ...]: ...


class HeuristicClassifier:
    """Rule-based endpoint classifier.

    GET -> read, no interrupt.
    POST/PUT/PATCH/DELETE -> write, needs_interrupt=True.
    """

    async def classify(
        self, endpoints: tuple[EndpointInfo, ...]
    ) -> tuple[ClassificationResult, ...]:
        """Classify endpoints using HTTP method heuristics."""
        if not endpoints:
            return ()
        return tuple(_classify_one(ep) for ep in endpoints)


def _classify_one(ep: EndpointInfo) -> ClassificationResult:
    """Classify a single endpoint using heuristics."""
    access_type = "write" if ep.method in _WRITE_METHODS else "read"
    needs_interrupt = ep.method in _INTERRUPT_METHODS
    customer_params = _detect_customer_params(ep)
    agent_group = "write_agent" if access_type == "write" else "read_agent"
    return ClassificationResult(
        endpoint=ep,
        access_type=access_type,
        customer_params=customer_params,
        agent_group=agent_group,
        confidence=0.7,
        needs_interrupt=needs_interrupt,
    )
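# Worked example of the heuristic above: an (illustrative) endpoint
# "DELETE /orders/{order_id}" whose path parameter is named "order_id" yields
# access_type="write", needs_interrupt=True, agent_group="write_agent",
# customer_params=("order_id",), confidence=0.7; a "GET /health" endpoint
# yields access_type="read", needs_interrupt=False, agent_group="read_agent",
# customer_params=().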
def _detect_customer_params(ep: EndpointInfo) -> tuple[str, ...]:
    """Extract parameter names that identify the customer/order context."""
    return tuple(
        p.name
        for p in ep.parameters
        if _CUSTOMER_PARAM_PATTERNS.fullmatch(p.name)
    )


class LLMClassifier:
    """LLM-backed endpoint classifier with heuristic fallback.

    Uses an LLM to classify endpoints with higher accuracy.
    Falls back to HeuristicClassifier on any LLM error.
    """

    def __init__(self, llm: object) -> None:
        self._llm = llm
        self._fallback = HeuristicClassifier()

    async def classify(
        self, endpoints: tuple[EndpointInfo, ...]
    ) -> tuple[ClassificationResult, ...]:
        """Classify endpoints using LLM with heuristic fallback."""
        if not endpoints:
            return ()
        try:
            return await self._classify_with_llm(endpoints)
        except Exception:
            logger.warning(
                "LLM classification failed, falling back to heuristic",
                exc_info=True,
            )
            return await self._fallback.classify(endpoints)

    async def _classify_with_llm(
        self, endpoints: tuple[EndpointInfo, ...]
    ) -> tuple[ClassificationResult, ...]:
        """Attempt LLM-based classification."""
        prompt = _build_classification_prompt(endpoints)
        response = await self._llm.ainvoke(prompt)
        parsed = _parse_llm_response(response.content, endpoints)
        return parsed


def _build_classification_prompt(endpoints: tuple[EndpointInfo, ...]) -> str:
    """Build a prompt for classifying endpoints."""
    items = []
    for i, ep in enumerate(endpoints):
        items.append(
            f"{i}: {ep.method} {ep.path} - {ep.summary or ep.operation_id}"
        )
    endpoint_list = "\n".join(items)
    return (
        "Classify each API endpoint as 'read' or 'write'. "
        "For each, determine if it needs human interrupt approval, "
        "identify customer-identifying parameters, and assign an agent_group.\n\n"
        f"Endpoints:\n{endpoint_list}\n\n"
        "Respond with a JSON array with one object per endpoint:\n"
        '[{"access_type": "read|write", "agent_group": "...", '
        '"needs_interrupt": true|false, "customer_params": [...], "confidence": 0.0-1.0}]'
    )
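# Example of the response shape the prompt above requests and that
# _parse_llm_response() below accepts (values are illustrative):
#   [{"access_type": "write", "agent_group": "write_agent",
#     "needs_interrupt": true, "customer_params": ["order_id"],
#     "confidence": 0.9}]
# The array must contain exactly one object per endpoint, in the same order
# as the endpoints passed in.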
def _parse_llm_response(
    content: str, endpoints: tuple[EndpointInfo, ...]
) -> tuple[ClassificationResult, ...]:
    """Parse LLM JSON response into ClassificationResult instances.

    Raises ValueError if the response cannot be parsed or is mismatched.
    """
    # Extract JSON array from response
    match = re.search(r"\[.*\]", content, re.DOTALL)
    if not match:
        raise ValueError(f"No JSON array found in LLM response: {content!r}")

    items = json.loads(match.group())
    if not isinstance(items, list) or len(items) != len(endpoints):
        raise ValueError(
            f"LLM returned {len(items)} items for {len(endpoints)} endpoints"
        )

    results: list[ClassificationResult] = []
    for ep, item in zip(endpoints, items, strict=True):
        raw_access = item.get("access_type", "read")
        access_type = raw_access if raw_access in {"read", "write"} else "read"
        confidence = max(0.0, min(1.0, float(item.get("confidence", 0.8))))
        raw_group = str(item.get("agent_group", "support"))
        agent_group = raw_group if raw_group.strip() else "support"
        results.append(
            ClassificationResult(
                endpoint=ep,
                access_type=access_type,
                customer_params=tuple(item.get("customer_params", [])),
                agent_group=agent_group,
                confidence=confidence,
                needs_interrupt=bool(item.get("needs_interrupt", False)),
            )
        )
    return tuple(results)
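# Example wiring (illustrative sketch, not part of the module's public API;
# assumes an LLM client exposing an async ``ainvoke(prompt)`` that returns an
# object with a ``.content`` string, e.g. a LangChain chat model, and that
# ``endpoints`` is a tuple of EndpointInfo built elsewhere from a parsed spec):
#
#     classifier: ClassifierProtocol = LLMClassifier(llm=chat_model)
#     results = await classifier.classify(endpoints)
#     write_results = [r for r in results if r.access_type == "write"]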