Files
smart-support/backend/tests/unit/test_callbacks.py
Yaojia Wang 33488fd634 feat: complete phase 1 -- core framework with chat loop, agents, and React UI
Backend:
- FastAPI WebSocket /ws endpoint with streaming via LangGraph astream
- LangGraph Supervisor connecting 3 mock agents (order_lookup, order_actions, fallback)
- YAML Agent Registry with Pydantic validation and immutable configs
- PostgresSaver checkpoint persistence via langgraph-checkpoint-postgres
- Session TTL with 30-min sliding window and interrupt extension
- LLM provider abstraction (Anthropic/OpenAI/Google)
- Token usage + cost tracking callback handler
- Input validation: message size cap, thread_id format, content length
- Security: no hardcoded defaults, startup API key validation, no input reflection

Frontend:
- React 19 + TypeScript + Vite chat UI
- WebSocket hook with reconnect + exponential backoff
- Streaming token display with agent attribution
- Interrupt approval/reject UI for write operations
- Collapsible tool call viewer

Testing:
- 87 unit tests, 87% coverage (exceeds 80% requirement)
- Ruff lint + format clean

Infrastructure:
- Docker Compose (PostgreSQL 16 + backend)
- pyproject.toml with full dependency management
2026-03-30 00:54:21 +02:00

103 lines
3.4 KiB
Python

"""Tests for app.callbacks module."""
from __future__ import annotations
import pytest
from app.callbacks import TokenUsageCallbackHandler
@pytest.mark.unit
class TestTokenUsageCallbackHandler:
    """Unit tests for TokenUsageCallbackHandler token and cost accounting."""

    @staticmethod
    def _result(llm_output: dict | None) -> object:
        """Build a minimal stand-in for an LLM result carrying *llm_output*."""
        # type() builds a throwaway class with a single class attribute,
        # mirroring the inline FakeResult classes the tests previously defined.
        return type("FakeResult", (), {"llm_output": llm_output})()

    def test_initial_state(self) -> None:
        """A fresh handler reports zero tokens and zero cost."""
        handler = TokenUsageCallbackHandler(model_name="claude-sonnet-4-6")
        usage = handler.get_usage()
        assert usage.prompt_tokens == 0
        assert usage.completion_tokens == 0
        assert usage.total_tokens == 0
        assert usage.total_cost_usd == 0.0

    def test_accumulates_tokens(self) -> None:
        """A single on_llm_end call adds its prompt/completion counts."""
        handler = TokenUsageCallbackHandler(model_name="claude-sonnet-4-6")
        payload = {"token_usage": {"prompt_tokens": 100, "completion_tokens": 50}}
        handler.on_llm_end(self._result(payload))
        usage = handler.get_usage()
        assert usage.prompt_tokens == 100
        assert usage.completion_tokens == 50
        assert usage.total_tokens == 150

    def test_accumulates_across_calls(self) -> None:
        """Counts from successive on_llm_end calls are summed, not replaced."""
        handler = TokenUsageCallbackHandler(model_name="claude-sonnet-4-6")
        payload = {"token_usage": {"prompt_tokens": 100, "completion_tokens": 50}}
        for _ in range(2):
            handler.on_llm_end(self._result(payload))
        usage = handler.get_usage()
        assert usage.prompt_tokens == 200
        assert usage.completion_tokens == 100
        assert usage.total_tokens == 300

    def test_cost_calculation(self) -> None:
        """Cost is derived from the per-model prompt/completion rates."""
        handler = TokenUsageCallbackHandler(model_name="claude-sonnet-4-6")
        payload = {"token_usage": {"prompt_tokens": 1000, "completion_tokens": 1000}}
        handler.on_llm_end(self._result(payload))
        usage = handler.get_usage()
        # claude-sonnet-4-6: prompt $0.003/1K, completion $0.015/1K
        expected_cost = 1000 * 0.003 / 1000 + 1000 * 0.015 / 1000
        assert usage.total_cost_usd == pytest.approx(expected_cost)

    def test_reset(self) -> None:
        """reset() discards all previously accumulated usage."""
        handler = TokenUsageCallbackHandler(model_name="claude-sonnet-4-6")
        payload = {"token_usage": {"prompt_tokens": 100, "completion_tokens": 50}}
        handler.on_llm_end(self._result(payload))
        handler.reset()
        assert handler.get_usage().total_tokens == 0

    def test_usage_is_immutable(self) -> None:
        """The usage snapshot rejects attribute assignment."""
        handler = TokenUsageCallbackHandler()
        usage = handler.get_usage()
        # NOTE(review): Exception is broad; narrowing requires knowing whether
        # the usage model is a frozen dataclass or a Pydantic model — confirm.
        with pytest.raises(Exception):
            usage.prompt_tokens = 999

    def test_unknown_model_uses_default_cost(self) -> None:
        """An unrecognized model name still produces a nonzero cost estimate."""
        handler = TokenUsageCallbackHandler(model_name="unknown-model")
        payload = {"token_usage": {"prompt_tokens": 1000, "completion_tokens": 1000}}
        handler.on_llm_end(self._result(payload))
        assert handler.get_usage().total_cost_usd > 0

    def test_handles_missing_token_usage(self) -> None:
        """An llm_output without a token_usage key is ignored gracefully."""
        handler = TokenUsageCallbackHandler()
        handler.on_llm_end(self._result({}))
        assert handler.get_usage().total_tokens == 0

    def test_handles_none_llm_output(self) -> None:
        """A result whose llm_output is None is ignored gracefully."""
        handler = TokenUsageCallbackHandler()
        handler.on_llm_end(self._result(None))
        assert handler.get_usage().total_tokens == 0