"""Error classification and retry logic for tool calls."""

from __future__ import annotations

import asyncio
from enum import Enum
from typing import TYPE_CHECKING, Any

# httpx is needed at runtime (isinstance checks in classify_error), so it must
# not live under the TYPE_CHECKING guard.
import httpx

if TYPE_CHECKING:
    from collections.abc import Callable


class ErrorCategory(Enum):
    """Categories for error classification to guide retry decisions."""

    # Failures that with_retry() retries with exponential backoff.
    RETRYABLE = "retryable"
    # Fallback category for everything classify_error() does not recognise.
    NON_RETRYABLE = "non_retryable"
    # HTTP 401 / 403 responses.
    AUTH_FAILURE = "auth_failure"
    # httpx.TimeoutException and its subclasses.
    TIMEOUT = "timeout"
    # httpx.ConnectError (connection could not be established).
    NETWORK = "network"


def classify_error(exc: Exception) -> ErrorCategory:
    """Classify an exception into an ErrorCategory.

    Rules, checked in this order:
    - httpx.TimeoutException -> TIMEOUT
    - httpx.ConnectError -> NETWORK
    - httpx.HTTPStatusError 401/403 -> AUTH_FAILURE
    - httpx.HTTPStatusError 429/500/502/503 -> RETRYABLE
    - anything else -> NON_RETRYABLE

    Args:
        exc: The exception raised by a tool call.

    Returns:
        The category matching the first rule that applies.
    """
    if isinstance(exc, httpx.TimeoutException):
        return ErrorCategory.TIMEOUT
    if isinstance(exc, httpx.ConnectError):
        return ErrorCategory.NETWORK
    if isinstance(exc, httpx.HTTPStatusError):
        code = exc.response.status_code
        if code in (401, 403):
            return ErrorCategory.AUTH_FAILURE
        if code in (429, 500, 502, 503):
            return ErrorCategory.RETRYABLE
    # Unrecognised exception types and all other HTTP status codes.
    return ErrorCategory.NON_RETRYABLE


async def with_retry(
    fn: Callable[..., Any],
    max_retries: int = 3,
    base_delay: float = 1.0,
) -> Any:
    """Execute an async callable with exponential backoff for RETRYABLE errors.

    Only ErrorCategory.RETRYABLE errors trigger retries. All other error
    categories raise immediately after the first attempt.

    Args:
        fn: Zero-argument callable returning an awaitable; it is invoked
            afresh as ``await fn()`` on every attempt.
        max_retries: Total number of attempts, including the first one.
            Must be at least 1.
        base_delay: Seconds to wait after the first failed attempt; the wait
            doubles after each subsequent failure.

    Returns:
        Whatever ``fn()`` returns on the first successful attempt.

    Raises:
        ValueError: If ``max_retries`` is less than 1.
        Exception: The exception raised by ``fn`` when it is not classified
            as RETRYABLE, or the exception from the final attempt once all
            attempts are exhausted.
    """
    if max_retries < 1:
        # Without this guard the loop would never run and there would be no
        # exception to propagate, which previously surfaced as an unrelated
        # TypeError from ``raise None``.
        raise ValueError(f"max_retries must be >= 1, got {max_retries}")

    for attempt in range(1, max_retries + 1):
        try:
            return await fn()
        except Exception as exc:
            out_of_attempts = attempt == max_retries
            if classify_error(exc) is not ErrorCategory.RETRYABLE or out_of_attempts:
                # Bare raise keeps the original exception and traceback.
                raise
        # Backoff schedule: base_delay, 2 * base_delay, 4 * base_delay, ...
        await asyncio.sleep(base_delay * (2 ** (attempt - 1)))

    # The final iteration always returns or raises; this only satisfies
    # static analysis that the function cannot fall off the end.
    raise AssertionError("unreachable")  # pragma: no cover