"""Replay API router -- conversation listing and step-by-step replay."""

from __future__ import annotations

import re
from typing import TYPE_CHECKING, Annotated, Any

from fastapi import APIRouter, HTTPException, Query, Request

# Allowed thread ids: 1-128 ASCII letters, digits, '-' or '_'.
# Apply with ``fullmatch``: with ``re.match`` the trailing ``$`` also matches
# just before a final "\n", and path parameters are percent-decoded, so a
# value like "abc%0A" would otherwise pass validation.
_THREAD_ID_PATTERN = re.compile(r"^[a-zA-Z0-9\-_]{1,128}$")

if TYPE_CHECKING:
    from psycopg_pool import AsyncConnectionPool

router = APIRouter(prefix="/api", tags=["replay"])

# thread_id is a secondary sort key so LIMIT/OFFSET pagination is
# deterministic when several conversations share the same last_activity
# (otherwise rows can repeat or go missing between pages).
_LIST_CONVERSATIONS_SQL = """
    SELECT thread_id, created_at, last_activity, status,
           total_tokens, total_cost_usd
    FROM conversations
    ORDER BY last_activity DESC, thread_id DESC
    LIMIT %(limit)s OFFSET %(offset)s
"""

# Every checkpoint for one thread, oldest checkpoint_id first.
_GET_CHECKPOINTS_SQL = """
    SELECT thread_id, checkpoint_id, checkpoint, metadata
    FROM checkpoints
    WHERE thread_id = %(thread_id)s
    ORDER BY checkpoint_id ASC
"""


async def get_pool(request: Request) -> AsyncConnectionPool:
    """Dependency: extract the shared pool from app state.

    Returns ``request.app.state.pool``. Whatever stored it there (not visible
    in this module) is responsible for opening and closing it.
    """
    return request.app.state.pool


def _envelope(data: Any, *, success: bool = True, error: str | None = None) -> dict:
    """Wrap *data* in the standard ``{"success", "data", "error"}`` response shape."""
    return {"success": success, "data": data, "error": error}


def _serialize_step(step: Any) -> dict:
    """Convert one replay step (from ``transform_checkpoints``) into a JSON-ready dict.

    ``step.type`` is emitted via ``.value``; presumably it is an Enum (confirm
    against ``app.replay.transformer``).
    """
    return {
        "step": step.step,
        "type": step.type.value,
        "timestamp": step.timestamp,
        "content": step.content,
        "agent": step.agent,
        "tool": step.tool,
        "params": step.params,
        "result": step.result,
        "reasoning": step.reasoning,
        "tokens": step.tokens,
        "duration_ms": step.duration_ms,
    }


@router.get("/conversations")
async def list_conversations(
    request: Request,
    page: Annotated[int, Query(ge=1)] = 1,
    per_page: Annotated[int, Query(ge=1, le=100)] = 20,
) -> dict:
    """List conversations with pagination.

    Args:
        request: Incoming request; used only to reach the shared pool.
        page: 1-based page number.
        per_page: Page size, between 1 and 100.

    Returns:
        An envelope whose ``data`` is a list of conversation rows, most
        recently active first.
    """
    pool = await get_pool(request)
    offset = (page - 1) * per_page
    async with pool.connection() as conn:
        cursor = await conn.execute(
            _LIST_CONVERSATIONS_SQL,
            {"limit": per_page, "offset": offset},
        )
        rows = await cursor.fetchall()
    # Connection is back in the pool before the response is built.
    # NOTE(review): dict(row) assumes the pool yields mapping rows (e.g. a
    # psycopg ``dict_row`` row factory); default tuple rows would fail here.
    return _envelope([dict(row) for row in rows])


@router.get("/replay/{thread_id}")
async def get_replay(
    thread_id: str,
    request: Request,
    page: Annotated[int, Query(ge=1)] = 1,
    per_page: Annotated[int, Query(ge=1, le=100)] = 20,
) -> dict:
    """Return paginated replay steps for a conversation thread.

    Args:
        thread_id: Thread identifier; must fully match ``_THREAD_ID_PATTERN``.
        request: Incoming request; used only to reach the shared pool.
        page: 1-based page number over the replay steps.
        per_page: Steps per page, between 1 and 100.

    Returns:
        An envelope whose ``data`` holds ``thread_id``, ``total_steps``,
        ``page``, ``per_page`` and the serialized ``steps`` for this page.
        A page past the end yields an empty ``steps`` list.

    Raises:
        HTTPException: 400 for a malformed ``thread_id``; 404 when the thread
            has no checkpoints.
    """
    # Function-scope import kept from the original (possibly to avoid an
    # import cycle -- unverified).
    from app.replay.transformer import transform_checkpoints

    if not _THREAD_ID_PATTERN.fullmatch(thread_id):
        raise HTTPException(status_code=400, detail="Invalid thread_id format")

    pool = await get_pool(request)
    async with pool.connection() as conn:
        cursor = await conn.execute(_GET_CHECKPOINTS_SQL, {"thread_id": thread_id})
        rows = await cursor.fetchall()

    if not rows:
        raise HTTPException(status_code=404, detail="Thread not found")

    # Steps are derived from the thread's full checkpoint list, so pagination
    # happens in memory over the transformed steps rather than in SQL.
    # NOTE(review): same mapping-row assumption as in list_conversations.
    all_steps = transform_checkpoints([dict(row) for row in rows])
    start = (page - 1) * per_page
    page_steps = all_steps[start : start + per_page]

    data = {
        "thread_id": thread_id,
        "total_steps": len(all_steps),
        "page": page,
        "per_page": per_page,
        "steps": [_serialize_step(s) for s in page_steps],
    }
    return _envelope(data)