Files
smart-support/backend/tests/unit/replay/test_api.py
Yaojia Wang f0699436c5 refactor: engineering improvements -- API versioning, structured logging, Alembic, error standardization, test coverage
- API versioning: all REST endpoints prefixed with /api/v1/
- Structured logging: replaced stdlib logging with structlog (console/JSON modes)
- Alembic migrations: versioned DB schema with initial migration
- Error standardization: global exception handlers for consistent envelope format
- Interrupt cleanup: asyncio background task for expired interrupt removal
- Integration tests: +30 tests (analytics, replay, openapi, error, session APIs)
- Frontend tests: +57 tests (all components, pages, useWebSocket hook)
- Backend: 557 tests, 89.75% coverage | Frontend: 80 tests, 16 test files
2026-04-06 23:19:29 +02:00

218 lines
7.0 KiB
Python

"""Unit tests for app.replay.api."""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock
import pytest
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from fastapi.testclient import TestClient
from app.api_utils import envelope
# Apply the ``unit`` marker to every test collected from this module so the
# suite can be filtered with ``pytest -m unit`` / ``pytest -m "not unit"``.
pytestmark = [pytest.mark.unit]
def _build_app() -> FastAPI:
    """Return a bare FastAPI app that serves only the replay router.

    HTTP errors are rendered with the project's response envelope
    (``envelope(None, success=False, error=detail)``) instead of FastAPI's
    default ``{"detail": ...}`` body, so tests can assert on the envelope.

    The handler is registered for Starlette's ``HTTPException``, which is the
    base class of ``fastapi.HTTPException``. That covers exceptions raised by
    the router's endpoints and also those Starlette raises itself (e.g. 404
    for an unmatched path, 405 for a wrong method); a handler keyed on the
    FastAPI subclass would miss the latter. Headers attached to the exception
    are forwarded, as FastAPI's default handler does.
    """
    # Imported here rather than at module level: app.replay.api is only
    # loaded once a test actually builds the app.
    from starlette.exceptions import HTTPException as StarletteHTTPException

    from fastapi import Request
    from app.replay.api import router

    app = FastAPI()
    app.include_router(router)

    @app.exception_handler(StarletteHTTPException)
    async def _http_exc(
        request: Request, exc: StarletteHTTPException
    ) -> JSONResponse:
        return JSONResponse(
            status_code=exc.status_code,
            content=envelope(None, success=False, error=exc.detail),
            headers=exc.headers,
        )

    return app
def _make_mock_pool(
    fetchall_result: list[dict],
    *,
    count: int | None = None,
) -> MagicMock:
    """Return a mock pool whose ``connection()`` yields a fake async connection.

    ``pool.connection()`` returns an async context manager; entering it gives
    a connection whose awaitable ``execute()`` hands back cursor mocks.

    * *count* given: the first ``execute()`` returns a cursor whose
      ``fetchone()`` resolves to ``(count,)`` (the COUNT query). The second
      returns a cursor whose ``fetchall()`` resolves to *fetchall_result*.
      Only two results are queued, so a third call raises.
    * *count* omitted: every ``execute()`` returns one shared cursor whose
      ``fetchall()`` resolves to *fetchall_result* and whose ``fetchone()``
      resolves to ``None``.
    """
    conn = AsyncMock()
    if count is None:
        shared_cursor = AsyncMock()
        shared_cursor.fetchall = AsyncMock(return_value=fetchall_result)
        shared_cursor.fetchone = AsyncMock(return_value=None)
        conn.execute = AsyncMock(return_value=shared_cursor)
    else:
        counting_cursor = AsyncMock()
        counting_cursor.fetchone = AsyncMock(return_value=(count,))
        listing_cursor = AsyncMock()
        listing_cursor.fetchall = AsyncMock(return_value=fetchall_result)
        # Consumed in order: COUNT query first, then the page of rows.
        conn.execute = AsyncMock(side_effect=[counting_cursor, listing_cursor])

    connection_cm = AsyncMock()
    connection_cm.__aenter__ = AsyncMock(return_value=conn)
    connection_cm.__aexit__ = AsyncMock(return_value=None)

    pool = MagicMock()
    pool.connection.return_value = connection_cm
    return pool
class TestListConversations:
    """GET /api/v1/conversations: envelope shape, row mapping and pagination."""

    @staticmethod
    def _client(rows: list[dict], *, count: int) -> TestClient:
        """Return a client whose pool answers the COUNT query, then *rows*."""
        app = _build_app()
        app.state.pool = _make_mock_pool(rows, count=count)
        return TestClient(app)

    def test_returns_200_with_empty_list(self) -> None:
        with self._client([], count=0) as client:
            resp = client.get("/api/v1/conversations")
        assert resp.status_code == 200
        body = resp.json()
        assert body["success"] is True
        data = body["data"]
        assert isinstance(data["conversations"], list)
        assert data["total"] == 0
        assert data["page"] == 1
        assert body["error"] is None

    def test_returns_conversations_list(self) -> None:
        mock_rows = [
            {
                "thread_id": "t1",
                "created_at": "2026-01-01T00:00:00",
                "last_activity": "2026-01-01T00:01:00",
                "status": "active",
                "total_tokens": 100,
                "total_cost_usd": 0.01,
            }
        ]
        with self._client(mock_rows, count=1) as client:
            resp = client.get("/api/v1/conversations")
        assert resp.status_code == 200
        data = resp.json()["data"]
        assert len(data["conversations"]) == 1
        assert data["conversations"][0]["thread_id"] == "t1"
        assert data["total"] == 1

    def test_pagination_defaults(self) -> None:
        # Without a query string the endpoint must fall back to page 1.
        with self._client([], count=0) as client:
            resp = client.get("/api/v1/conversations")
        assert resp.status_code == 200
        body = resp.json()
        assert body["success"] is True
        assert body["data"]["page"] == 1

    def test_pagination_custom_params(self) -> None:
        with self._client([], count=0) as client:
            resp = client.get("/api/v1/conversations?page=2&per_page=10")
        assert resp.status_code == 200
        assert resp.json()["success"] is True

    def test_per_page_at_max_is_accepted(self) -> None:
        # Upper boundary of Query(le=100): exactly 100 is still valid.
        with self._client([], count=0) as client:
            resp = client.get("/api/v1/conversations?per_page=100")
        assert resp.status_code == 200

    def test_per_page_max_capped_at_100(self) -> None:
        # Despite the name, values above the cap are rejected rather than
        # clamped: FastAPI Query(le=100) answers 422 for per_page > 100.
        with self._client([], count=0) as client:
            resp = client.get("/api/v1/conversations?per_page=200")
        assert resp.status_code == 422
class TestGetReplay:
    """GET /api/v1/replay/{thread_id}: lookup, pagination and id validation."""

    @staticmethod
    def _checkpoint_row(thread_id: str, content: str) -> dict:
        """Build one checkpoint row carrying a single human message."""
        return {
            "thread_id": thread_id,
            "checkpoint_id": "cp-001",
            "checkpoint": {
                "channel_values": {
                    "messages": [{"type": "human", "content": content}],
                },
            },
            "metadata": {},
        }

    @staticmethod
    def _client(rows: list[dict]) -> TestClient:
        """Return a client whose pool yields *rows* for every query."""
        app = _build_app()
        app.state.pool = _make_mock_pool(rows)
        return TestClient(app)

    def test_thread_not_found_returns_404(self) -> None:
        with self._client([]) as client:
            resp = client.get("/api/v1/replay/nonexistent-thread")
        assert resp.status_code == 404

    def test_returns_replay_page_for_existing_thread(self) -> None:
        rows = [self._checkpoint_row("thread-123", "Hello")]
        with self._client(rows) as client:
            resp = client.get("/api/v1/replay/thread-123")
        assert resp.status_code == 200
        payload = resp.json()
        assert payload["success"] is True
        replay = payload["data"]
        assert replay["thread_id"] == "thread-123"
        for key in ("steps", "total_steps", "page", "per_page"):
            assert key in replay

    def test_replay_pagination_params(self) -> None:
        rows = [self._checkpoint_row("t1", "Hi")]
        with self._client(rows) as client:
            resp = client.get("/api/v1/replay/t1?page=1&per_page=5")
        assert resp.status_code == 200

    def test_error_response_has_envelope(self) -> None:
        with self._client([]) as client:
            resp = client.get("/api/v1/replay/missing")
        assert resp.status_code == 404
        payload = resp.json()
        assert payload["success"] is False
        assert payload["data"] is None
        assert payload["error"] is not None

    def test_invalid_thread_id_returns_400(self) -> None:
        with self._client([]) as client:
            resp = client.get("/api/v1/replay/id%20with%20spaces")
        assert resp.status_code == 400

    def test_thread_id_special_chars_returns_400(self) -> None:
        with self._client([]) as client:
            resp = client.get("/api/v1/replay/id;DROP TABLE")
        assert resp.status_code == 400