Files
Yaojia Wang f0699436c5 refactor: engineering improvements -- API versioning, structured logging, Alembic, error standardization, test coverage
- API versioning: all REST endpoints prefixed with /api/v1/
- Structured logging: replaced stdlib logging with structlog (console/JSON modes)
- Alembic migrations: versioned DB schema with initial migration
- Error standardization: global exception handlers for consistent envelope format
- Interrupt cleanup: asyncio background task for expired interrupt removal
- Integration tests: +30 tests (analytics, replay, openapi, error, session APIs)
- Frontend tests: +57 tests (all components, pages, useWebSocket hook)
- Backend: 557 tests, 89.75% coverage | Frontend: 80 tests, 16 test files
2026-04-06 23:19:29 +02:00

126 lines
3.5 KiB
Python

"""Replay API router -- conversation listing and step-by-step replay."""
from __future__ import annotations
import re
from typing import TYPE_CHECKING, Annotated
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from app.api_utils import envelope
from app.auth import require_admin_api_key
# Allowed thread_id shape: 1 to 128 characters drawn from ASCII letters,
# digits, "-" and "_".
# NOTE(review): "$" also matches just before a single trailing newline, so
# checking with .match() accepts e.g. "abc\n"; use .fullmatch() when the whole
# string must conform.
_THREAD_ID_PATTERN = re.compile(r"^[a-zA-Z0-9\-_]{1,128}$")
if TYPE_CHECKING:
from psycopg_pool import AsyncConnectionPool
# Router for the replay endpoints defined in this module. Every route is
# mounted under the "/api/v1" prefix, tagged "replay", and runs the
# require_admin_api_key dependency (imported from app.auth) before the handler.
router = APIRouter(
prefix="/api/v1",
tags=["replay"],
dependencies=[Depends(require_admin_api_key)],
)
# Counts every row in the conversations table (no filter); used for the
# "total" field of the conversation listing.
_COUNT_CONVERSATIONS_SQL = """
SELECT COUNT(*) FROM conversations
"""
# Fetches one page of conversations, most recently active first. The named
# placeholders "limit" and "offset" are filled from the parameter dict passed
# to execute().
_LIST_CONVERSATIONS_SQL = """
SELECT thread_id, created_at, last_activity, status, total_tokens, total_cost_usd
FROM conversations
ORDER BY last_activity DESC
LIMIT %(limit)s OFFSET %(offset)s
"""
# Fetches all checkpoint rows for a single thread, ordered by checkpoint_id
# ascending. The named placeholder "thread_id" is filled from the parameter
# dict passed to execute(). There is no LIMIT: every checkpoint is returned.
_GET_CHECKPOINTS_SQL = """
SELECT thread_id, checkpoint_id, checkpoint, metadata
FROM checkpoints
WHERE thread_id = %(thread_id)s
ORDER BY checkpoint_id ASC
"""
async def get_pool(request: Request) -> AsyncConnectionPool:
    """Return the connection pool kept on the application state.

    Args:
        request: The incoming request; its ``app.state.pool`` attribute is read.

    Returns:
        The object stored at ``request.app.state.pool``.
    """
    app_state = request.app.state
    return app_state.pool
@router.get("/conversations")
async def list_conversations(
    request: Request,
    page: Annotated[int, Query(ge=1)] = 1,
    per_page: Annotated[int, Query(ge=1, le=100)] = 20,
) -> dict:
    """List conversations with pagination.

    Args:
        request: Incoming request; the connection pool is taken from its
            application state via ``get_pool``.
        page: 1-based page number (validated as ``>= 1``).
        per_page: Number of conversations per page (validated as 1..100).

    Returns:
        The result of ``envelope`` applied to a dict with the keys
        ``conversations`` (list of row dicts), ``total`` (row count of the
        conversations table), ``page`` and ``per_page``.
    """
    pool = await get_pool(request)
    offset = (page - 1) * per_page

    async with pool.connection() as conn:
        count_cursor = await conn.execute(_COUNT_CONVERSATIONS_SQL)
        count_row = await count_cursor.fetchone()

        cursor = await conn.execute(
            _LIST_CONVERSATIONS_SQL,
            {"limit": per_page, "offset": offset},
        )
        rows = await cursor.fetchall()

    # The listing rows below are converted with dict(row), which requires
    # mapping-style rows. The count row must therefore be readable in that
    # configuration too: a mapping row has no key 0, so take its single value.
    # Sequence rows keep the original positional lookup.
    if not count_row:
        total = 0
    elif isinstance(count_row, dict):
        total = next(iter(count_row.values()), 0)
    else:
        total = count_row[0]

    return envelope({
        "conversations": [dict(row) for row in rows],
        "total": total,
        "page": page,
        "per_page": per_page,
    })
@router.get("/replay/{thread_id}")
async def get_replay(
    thread_id: str,
    request: Request,
    page: Annotated[int, Query(ge=1)] = 1,
    per_page: Annotated[int, Query(ge=1, le=100)] = 20,
) -> dict:
    """Return paginated replay steps for a conversation thread.

    All checkpoint rows for the thread are loaded and passed through
    ``transform_checkpoints``; pagination is then applied to the resulting
    list of steps (not to the checkpoint rows).

    Args:
        thread_id: Thread identifier from the URL path; must consist solely of
            1-128 characters matching ``_THREAD_ID_PATTERN``.
        request: Incoming request; the connection pool is taken from its
            application state via ``get_pool``.
        page: 1-based page number (validated as ``>= 1``).
        per_page: Number of steps per page (validated as 1..100).

    Returns:
        The result of ``envelope`` applied to a dict with the keys
        ``thread_id``, ``total_steps``, ``page``, ``per_page`` and ``steps``.
        A page past the last step yields an empty ``steps`` list.

    Raises:
        HTTPException: 400 when ``thread_id`` has an invalid format, 404 when
            no checkpoint rows exist for the thread.
    """
    # Imported at call time, as in the original code.
    from app.replay.transformer import transform_checkpoints

    # fullmatch() rather than match(): the pattern ends in "$", which also
    # matches just before a trailing newline, so match() would let a value
    # such as "abc\n" through the validation.
    if not _THREAD_ID_PATTERN.fullmatch(thread_id):
        raise HTTPException(status_code=400, detail="Invalid thread_id format")

    pool = await get_pool(request)
    async with pool.connection() as conn:
        cursor = await conn.execute(_GET_CHECKPOINTS_SQL, {"thread_id": thread_id})
        rows = await cursor.fetchall()

    if not rows:
        raise HTTPException(status_code=404, detail="Thread not found")

    # dict(row) requires mapping-style rows from the connection.
    all_steps = transform_checkpoints([dict(row) for row in rows])
    total_steps = len(all_steps)

    # Slice bounds beyond the end of the list simply produce a shorter (or
    # empty) page.
    start = (page - 1) * per_page
    end = start + per_page
    page_steps = all_steps[start:end]

    data = {
        "thread_id": thread_id,
        "total_steps": total_steps,
        "page": page,
        "per_page": per_page,
        "steps": [
            {
                "step": s.step,
                "type": s.type.value,
                "timestamp": s.timestamp,
                "content": s.content,
                "agent": s.agent,
                "tool": s.tool,
                "params": s.params,
                "result": s.result,
                "reasoning": s.reasoning,
                "tokens": s.tokens,
                "duration_ms": s.duration_ms,
            }
            for s in page_steps
        ],
    }
    return envelope(data)