"""Tests for app.conversation_tracker module."""

from __future__ import annotations

from unittest.mock import AsyncMock, MagicMock

import pytest

from app.conversation_tracker import (
    ConversationTrackerProtocol,
    NoOpConversationTracker,
    PostgresConversationTracker,
)

pytestmark = pytest.mark.unit


class _AsyncContextManager:
    """Minimal async context manager that yields a fixed value.

    Stands in for the object returned by ``pool.connection()``. ``__aexit__``
    returns ``None``, so exceptions raised inside the ``async with`` body are
    not suppressed.
    """

    def __init__(self, value: object) -> None:
        self._value = value

    async def __aenter__(self) -> object:
        return self._value

    async def __aexit__(self, *args: object) -> None:
        return None


def _make_pool() -> tuple[AsyncMock, AsyncMock]:
    """Create a mock async connection pool together with its connection.

    Returns:
        ``(pool, conn)``: calling ``pool.connection()`` returns an async
        context manager that yields ``conn``. ``conn.execute`` is an
        ``AsyncMock`` so tests can inspect the awaited SQL and parameters.
    """
    pool = AsyncMock()
    conn = AsyncMock()
    conn.execute = AsyncMock()
    # Child attributes of an AsyncMock are themselves AsyncMocks, so calling
    # them would return a coroutine. Use a MagicMock so ``pool.connection()``
    # returns the context manager directly.
    pool.connection = MagicMock(return_value=_AsyncContextManager(conn))
    return pool, conn


class TestConversationTrackerProtocol:
    """Both tracker implementations must satisfy the shared protocol."""

    def test_noop_satisfies_protocol(self) -> None:
        tracker = NoOpConversationTracker()
        assert isinstance(tracker, ConversationTrackerProtocol)

    def test_postgres_satisfies_protocol(self) -> None:
        tracker = PostgresConversationTracker()
        assert isinstance(tracker, ConversationTrackerProtocol)


class TestNoOpConversationTracker:
    """The no-op tracker must accept every call without using the pool."""

    @pytest.mark.asyncio
    async def test_ensure_conversation_does_nothing(self) -> None:
        tracker = NoOpConversationTracker()
        pool = AsyncMock()
        # Should not raise, and should not touch the pool.
        await tracker.ensure_conversation(pool, "thread-1")
        assert not pool.mock_calls

    @pytest.mark.asyncio
    async def test_record_turn_does_nothing(self) -> None:
        tracker = NoOpConversationTracker()
        pool = AsyncMock()
        await tracker.record_turn(pool, "thread-1", "agent_a", 100, 0.05)
        assert not pool.mock_calls

    @pytest.mark.asyncio
    async def test_resolve_does_nothing(self) -> None:
        tracker = NoOpConversationTracker()
        pool = AsyncMock()
        await tracker.resolve(pool, "thread-1", "resolved")
        assert not pool.mock_calls

    @pytest.mark.asyncio
    async def test_accepts_none_agent_name(self) -> None:
        tracker = NoOpConversationTracker()
        pool = AsyncMock()
        await tracker.record_turn(pool, "thread-1", None, 0, 0.0)
        assert not pool.mock_calls


class TestPostgresConversationTracker:
    """The Postgres tracker issues one parameterised statement per call.

    These tests assume ``conn.execute`` is awaited with positional
    ``(sql, params)`` where ``params`` is a mapping of named parameters.
    """

    @pytest.mark.asyncio
    async def test_ensure_conversation_executes_insert(self) -> None:
        tracker = PostgresConversationTracker()
        pool, conn = _make_pool()
        await tracker.ensure_conversation(pool, "thread-abc")
        conn.execute.assert_awaited_once()
        sql, params = conn.execute.call_args[0]
        assert "INSERT" in sql
        assert "ON CONFLICT" in sql
        assert params["thread_id"] == "thread-abc"

    @pytest.mark.asyncio
    async def test_record_turn_executes_update(self) -> None:
        tracker = PostgresConversationTracker()
        pool, conn = _make_pool()
        await tracker.record_turn(pool, "thread-abc", "order_agent", 250, 0.12)
        conn.execute.assert_awaited_once()
        sql, params = conn.execute.call_args[0]
        assert "UPDATE" in sql
        assert params["thread_id"] == "thread-abc"
        assert params["agent_name"] == "order_agent"
        assert params["tokens"] == 250
        assert params["cost"] == 0.12

    @pytest.mark.asyncio
    async def test_record_turn_accepts_none_agent_name(self) -> None:
        tracker = PostgresConversationTracker()
        pool, conn = _make_pool()
        await tracker.record_turn(pool, "thread-abc", None, 0, 0.0)
        conn.execute.assert_awaited_once()
        _sql, params = conn.execute.call_args[0]
        assert params["agent_name"] is None

    @pytest.mark.asyncio
    async def test_resolve_executes_update(self) -> None:
        tracker = PostgresConversationTracker()
        pool, conn = _make_pool()
        await tracker.resolve(pool, "thread-abc", "resolved")
        conn.execute.assert_awaited_once()
        sql, params = conn.execute.call_args[0]
        assert "UPDATE" in sql
        assert params["thread_id"] == "thread-abc"
        assert params["resolution_type"] == "resolved"

    @pytest.mark.asyncio
    async def test_resolve_sets_ended_at(self) -> None:
        tracker = PostgresConversationTracker()
        pool, conn = _make_pool()
        await tracker.resolve(pool, "thread-abc", "escalated")
        conn.execute.assert_awaited_once()
        sql, params = conn.execute.call_args[0]
        assert "ended_at" in sql.lower()
        assert params["resolution_type"] == "escalated"

    @pytest.mark.asyncio
    async def test_ensure_conversation_with_special_thread_id(self) -> None:
        tracker = PostgresConversationTracker()
        pool, conn = _make_pool()
        await tracker.ensure_conversation(pool, "thread-123-abc-XYZ")
        conn.execute.assert_awaited_once()
        _sql, params = conn.execute.call_args[0]
        # The id must be passed through verbatim as a bound parameter.
        assert params["thread_id"] == "thread-123-abc-XYZ"

    @pytest.mark.asyncio
    async def test_record_turn_with_zero_cost(self) -> None:
        tracker = PostgresConversationTracker()
        pool, conn = _make_pool()
        await tracker.record_turn(pool, "t1", "agent", 0, 0.0)
        conn.execute.assert_awaited_once()
        _sql, params = conn.execute.call_args[0]
        assert params["tokens"] == 0
        assert params["cost"] == 0.0