"""Token usage tracking callback handler.""" from __future__ import annotations from dataclasses import dataclass from typing import TYPE_CHECKING, Any from langchain_core.callbacks import BaseCallbackHandler if TYPE_CHECKING: from langchain_core.outputs import LLMResult COST_PER_1K_TOKENS: dict[str, dict[str, float]] = { "claude-sonnet-4-6": {"prompt": 0.003, "completion": 0.015}, "claude-haiku-4-5-20251001": {"prompt": 0.0008, "completion": 0.004}, "gpt-4o": {"prompt": 0.0025, "completion": 0.01}, "gpt-4o-mini": {"prompt": 0.00015, "completion": 0.0006}, } DEFAULT_COST = {"prompt": 0.003, "completion": 0.015} @dataclass(frozen=True) class TokenUsage: prompt_tokens: int completion_tokens: int total_tokens: int total_cost_usd: float class TokenUsageCallbackHandler(BaseCallbackHandler): """Accumulates token usage and cost across LLM invocations.""" def __init__(self, model_name: str = "") -> None: self._model_name = model_name self._prompt_tokens = 0 self._completion_tokens = 0 def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: if response.llm_output and "token_usage" in response.llm_output: usage = response.llm_output["token_usage"] self._prompt_tokens += usage.get("prompt_tokens", 0) self._completion_tokens += usage.get("completion_tokens", 0) def get_usage(self) -> TokenUsage: costs = COST_PER_1K_TOKENS.get(self._model_name, DEFAULT_COST) cost = ( self._prompt_tokens * costs["prompt"] / 1000 + self._completion_tokens * costs["completion"] / 1000 ) return TokenUsage( prompt_tokens=self._prompt_tokens, completion_tokens=self._completion_tokens, total_tokens=self._prompt_tokens + self._completion_tokens, total_cost_usd=round(cost, 6), ) def reset(self) -> None: self._prompt_tokens = 0 self._completion_tokens = 0