"""Replay API router -- conversation listing and step-by-step replay.""" from __future__ import annotations import re from typing import TYPE_CHECKING, Annotated from fastapi import APIRouter, Depends, HTTPException, Query, Request from app.api_utils import envelope from app.auth import require_admin_api_key _THREAD_ID_PATTERN = re.compile(r"^[a-zA-Z0-9\-_]{1,128}$") if TYPE_CHECKING: from psycopg_pool import AsyncConnectionPool router = APIRouter( prefix="/api/v1", tags=["replay"], dependencies=[Depends(require_admin_api_key)], ) _COUNT_CONVERSATIONS_SQL = """ SELECT COUNT(*) FROM conversations """ _LIST_CONVERSATIONS_SQL = """ SELECT thread_id, created_at, last_activity, status, total_tokens, total_cost_usd FROM conversations ORDER BY last_activity DESC LIMIT %(limit)s OFFSET %(offset)s """ _GET_CHECKPOINTS_SQL = """ SELECT thread_id, checkpoint_id, checkpoint, metadata FROM checkpoints WHERE thread_id = %(thread_id)s ORDER BY checkpoint_id ASC """ async def get_pool(request: Request) -> AsyncConnectionPool: """Dependency: extract the shared pool from app state.""" return request.app.state.pool @router.get("/conversations") async def list_conversations( request: Request, page: Annotated[int, Query(ge=1)] = 1, per_page: Annotated[int, Query(ge=1, le=100)] = 20, ) -> dict: """List conversations with pagination.""" pool = await get_pool(request) offset = (page - 1) * per_page async with pool.connection() as conn: count_cursor = await conn.execute(_COUNT_CONVERSATIONS_SQL) count_row = await count_cursor.fetchone() total = count_row[0] if count_row else 0 cursor = await conn.execute( _LIST_CONVERSATIONS_SQL, {"limit": per_page, "offset": offset}, ) rows = await cursor.fetchall() return envelope({ "conversations": [dict(row) for row in rows], "total": total, "page": page, "per_page": per_page, }) @router.get("/replay/{thread_id}") async def get_replay( thread_id: str, request: Request, page: Annotated[int, Query(ge=1)] = 1, per_page: Annotated[int, Query(ge=1, le=100)] = 20, ) -> dict: """Return paginated replay steps for a conversation thread.""" from app.replay.transformer import transform_checkpoints if not _THREAD_ID_PATTERN.match(thread_id): raise HTTPException(status_code=400, detail="Invalid thread_id format") pool = await get_pool(request) async with pool.connection() as conn: cursor = await conn.execute(_GET_CHECKPOINTS_SQL, {"thread_id": thread_id}) rows = await cursor.fetchall() if not rows: raise HTTPException(status_code=404, detail="Thread not found") all_steps = transform_checkpoints([dict(row) for row in rows]) total_steps = len(all_steps) start = (page - 1) * per_page end = start + per_page page_steps = all_steps[start:end] data = { "thread_id": thread_id, "total_steps": total_steps, "page": page, "per_page": per_page, "steps": [ { "step": s.step, "type": s.type.value, "timestamp": s.timestamp, "content": s.content, "agent": s.agent, "tool": s.tool, "params": s.params, "result": s.result, "reasoning": s.reasoning, "tokens": s.tokens, "duration_ms": s.duration_ms, } for s in page_steps ], } return envelope(data)