"""Unit tests for app.replay.api."""

from __future__ import annotations

from unittest.mock import AsyncMock, MagicMock

import pytest
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from fastapi.testclient import TestClient

from app.api_utils import envelope

pytestmark = pytest.mark.unit


def _build_app() -> FastAPI:
    """Build a minimal FastAPI app that mounts the replay router.

    ``HTTPException`` is rendered through ``envelope`` so error responses
    carry the same ``success`` / ``data`` / ``error`` keys the tests assert
    on. No pool is attached here; each test sets ``app.state.pool`` itself.
    """
    from app.replay.api import router

    app = FastAPI()
    app.include_router(router)

    @app.exception_handler(HTTPException)
    async def _http_exc(request, exc):  # type: ignore[no-untyped-def]
        return JSONResponse(
            status_code=exc.status_code,
            content=envelope(None, success=False, error=exc.detail),
        )

    return app


def _make_mock_pool(
    fetchall_result: list[dict],
    *,
    count: int | None = None,
) -> MagicMock:
    """Build a mock pool that returns the given rows from fetchall.

    When *count* is provided, the first execute() call returns a cursor
    whose fetchone() yields ``(count,)`` (for the COUNT query) and the
    second call returns the rows via fetchall(). Because the cursors are
    supplied through ``side_effect``, a third execute() call fails
    (AsyncMock raises StopAsyncIteration once the sequence is exhausted).

    When *count* is None (the default), a single cursor is returned for
    every execute() call: its fetchall() yields *fetchall_result* and its
    fetchone() yields ``None``.

    ``pool.connection()`` returns an async context manager whose
    ``__aenter__`` yields the mock connection.
    """
    if count is not None:
        count_cursor = AsyncMock()
        count_cursor.fetchone = AsyncMock(return_value=(count,))
        rows_cursor = AsyncMock()
        rows_cursor.fetchall = AsyncMock(return_value=fetchall_result)
        mock_conn = AsyncMock()
        mock_conn.execute = AsyncMock(side_effect=[count_cursor, rows_cursor])
    else:
        mock_cursor = AsyncMock()
        mock_cursor.fetchall = AsyncMock(return_value=fetchall_result)
        mock_cursor.fetchone = AsyncMock(return_value=None)
        mock_conn = AsyncMock()
        mock_conn.execute = AsyncMock(return_value=mock_cursor)

    mock_ctx = AsyncMock()
    mock_ctx.__aenter__ = AsyncMock(return_value=mock_conn)
    mock_ctx.__aexit__ = AsyncMock(return_value=None)

    mock_pool = MagicMock()
    mock_pool.connection.return_value = mock_ctx
    return mock_pool


class TestListConversations:
    """Tests for ``GET /api/v1/conversations``."""

    def test_returns_200_with_empty_list(self) -> None:
        app = _build_app()
        app.state.pool = _make_mock_pool([], count=0)
        with TestClient(app) as client:
            resp = client.get("/api/v1/conversations")
        assert resp.status_code == 200
        body = resp.json()
        assert body["success"] is True
        data = body["data"]
        assert isinstance(data["conversations"], list)
        assert data["total"] == 0
        assert data["page"] == 1
        assert body["error"] is None

    def test_returns_conversations_list(self) -> None:
        app = _build_app()
        mock_rows = [
            {
                "thread_id": "t1",
                "created_at": "2026-01-01T00:00:00",
                "last_activity": "2026-01-01T00:01:00",
                "status": "active",
                "total_tokens": 100,
                "total_cost_usd": 0.01,
            }
        ]
        app.state.pool = _make_mock_pool(mock_rows, count=1)
        with TestClient(app) as client:
            resp = client.get("/api/v1/conversations")
        # Check the status first so a failure reports it rather than a
        # KeyError from an unexpected body shape.
        assert resp.status_code == 200
        data = resp.json()["data"]
        assert len(data["conversations"]) == 1
        assert data["conversations"][0]["thread_id"] == "t1"
        assert data["total"] == 1

    def test_pagination_defaults(self) -> None:
        """Without query params the first page is returned."""
        app = _build_app()
        app.state.pool = _make_mock_pool([], count=0)
        with TestClient(app) as client:
            resp = client.get("/api/v1/conversations")
        assert resp.status_code == 200
        data = resp.json()["data"]
        assert data["page"] == 1
        assert data["total"] == 0

    def test_pagination_custom_params(self) -> None:
        """An explicit ``page`` is reflected in the response."""
        app = _build_app()
        app.state.pool = _make_mock_pool([], count=0)
        with TestClient(app) as client:
            resp = client.get("/api/v1/conversations?page=2&per_page=10")
        assert resp.status_code == 200
        assert resp.json()["data"]["page"] == 2

    def test_per_page_at_max_is_accepted(self) -> None:
        """``per_page`` exactly at the limit (100) is still valid."""
        app = _build_app()
        app.state.pool = _make_mock_pool([], count=0)
        with TestClient(app) as client:
            resp = client.get("/api/v1/conversations?per_page=100")
        assert resp.status_code == 200

    def test_per_page_max_capped_at_100(self) -> None:
        """``per_page`` above 100 is rejected with 422, not silently capped."""
        app = _build_app()
        app.state.pool = _make_mock_pool([], count=0)
        with TestClient(app) as client:
            resp = client.get("/api/v1/conversations?per_page=200")
        # FastAPI Query(le=100) rejects values > 100
        assert resp.status_code == 422


class TestGetReplay:
    """Tests for ``GET /api/v1/replay/{thread_id}``."""

    def test_thread_not_found_returns_404(self) -> None:
        app = _build_app()
        app.state.pool = _make_mock_pool([])
        with TestClient(app) as client:
            resp = client.get("/api/v1/replay/nonexistent-thread")
        assert resp.status_code == 404

    def test_returns_replay_page_for_existing_thread(self) -> None:
        app = _build_app()
        mock_rows = [
            {
                "thread_id": "thread-123",
                "checkpoint_id": "cp-001",
                "checkpoint": {
                    "channel_values": {
                        "messages": [{"type": "human", "content": "Hello"}]
                    }
                },
                "metadata": {},
            }
        ]
        app.state.pool = _make_mock_pool(mock_rows)
        with TestClient(app) as client:
            resp = client.get("/api/v1/replay/thread-123")
        assert resp.status_code == 200
        body = resp.json()
        assert body["success"] is True
        assert body["data"]["thread_id"] == "thread-123"
        assert "steps" in body["data"]
        assert "total_steps" in body["data"]
        assert "page" in body["data"]
        assert "per_page" in body["data"]

    def test_replay_pagination_params(self) -> None:
        """Explicit ``page`` / ``per_page`` are reflected in the response."""
        app = _build_app()
        mock_rows = [
            {
                "thread_id": "t1",
                "checkpoint_id": "cp-001",
                "checkpoint": {
                    "channel_values": {"messages": [{"type": "human", "content": "Hi"}]}
                },
                "metadata": {},
            }
        ]
        app.state.pool = _make_mock_pool(mock_rows)
        with TestClient(app) as client:
            resp = client.get("/api/v1/replay/t1?page=1&per_page=5")
        assert resp.status_code == 200
        data = resp.json()["data"]
        assert data["page"] == 1
        assert data["per_page"] == 5

    def test_error_response_has_envelope(self) -> None:
        """Error responses use the envelope shape with ``data`` set to None."""
        app = _build_app()
        app.state.pool = _make_mock_pool([])
        with TestClient(app) as client:
            resp = client.get("/api/v1/replay/missing")
        assert resp.status_code == 404
        body = resp.json()
        assert body["success"] is False
        assert body["data"] is None
        assert body["error"] is not None

    def test_invalid_thread_id_returns_400(self) -> None:
        """A thread id containing (percent-encoded) spaces is rejected."""
        app = _build_app()
        app.state.pool = _make_mock_pool([])
        with TestClient(app) as client:
            resp = client.get("/api/v1/replay/id%20with%20spaces")
        assert resp.status_code == 400

    def test_thread_id_special_chars_returns_400(self) -> None:
        """A thread id with ``;`` and a space is rejected with 400."""
        app = _build_app()
        app.state.pool = _make_mock_pool([])
        with TestClient(app) as client:
            resp = client.get("/api/v1/replay/id;DROP TABLE")
        assert resp.status_code == 400