"""Tests for app.db module."""

from __future__ import annotations

from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from app.config import Settings
from app.db import (
    _CONVERSATIONS_DDL,
    _INTERRUPTS_DDL,
    create_checkpointer,
    create_pool,
    setup_app_tables,
)


@pytest.mark.unit
class TestDbModule:
    """Unit tests for the pool, checkpointer and table-setup helpers in ``app.db``.

    ``AsyncConnectionPool`` and ``AsyncPostgresSaver`` are patched on the
    ``app.db`` module itself, so the helpers under test pick up the mocks
    when they look those names up at call time.
    """

    @pytest.mark.asyncio
    async def test_create_pool_sets_correct_params(self) -> None:
        """``create_pool`` constructs exactly one pool with the configured URL and opens it."""
        database_url = "postgresql://user:pass@localhost:5432/testdb"
        settings = Settings(
            database_url=database_url,
            anthropic_api_key="key",
        )

        with patch("app.db.AsyncConnectionPool") as MockPool:
            mock_pool = AsyncMock()
            MockPool.return_value = mock_pool

            await create_pool(settings)

            MockPool.assert_called_once()
            # Look at the actual call arguments instead of the repr of the
            # whole mock call object. The URL may be passed positionally or
            # by keyword, so search both.
            call = MockPool.call_args
            passed_values = [*call.args, *call.kwargs.values()]
            assert any(database_url in str(value) for value in passed_values)
            mock_pool.open.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_create_checkpointer_calls_setup(self) -> None:
        """``create_checkpointer`` wraps the pool in a saver and awaits its ``setup()``."""
        mock_pool = AsyncMock()

        with patch("app.db.AsyncPostgresSaver") as MockSaver:
            mock_saver = AsyncMock()
            MockSaver.return_value = mock_saver

            await create_checkpointer(mock_pool)

            MockSaver.assert_called_once_with(conn=mock_pool)
            mock_saver.setup.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_setup_app_tables_executes_ddl(self) -> None:
        """``setup_app_tables`` runs its DDL statements on a connection taken from the pool."""
        mock_conn = AsyncMock()

        # Stand-in for the async context manager returned by pool.connection();
        # entering it yields mock_conn.
        mock_ctx = AsyncMock()
        mock_ctx.__aenter__ = AsyncMock(return_value=mock_conn)
        mock_ctx.__aexit__ = AsyncMock(return_value=None)

        mock_pool = MagicMock()
        mock_pool.connection.return_value = mock_ctx

        await setup_app_tables(mock_pool)

        # NOTE(review): 4 is presumably the two table DDLs plus two further
        # statements (e.g. indexes); confirm against app.db when it changes.
        assert mock_conn.execute.await_count == 4

    def test_ddl_statements_valid(self) -> None:
        """The DDL constants create the expected tables idempotently and define their key columns."""
        assert "CREATE TABLE IF NOT EXISTS conversations" in _CONVERSATIONS_DDL
        assert "CREATE TABLE IF NOT EXISTS active_interrupts" in _INTERRUPTS_DDL
        assert "thread_id" in _CONVERSATIONS_DDL
        assert "interrupt_id" in _INTERRUPTS_DDL