Files
smart-support/backend/app/replay/api.py
Yaojia Wang ef6e5ac2be fix: address security findings in Phase 4 analytics and replay
- Fix CRITICAL: use parameterized INTERVAL arithmetic (%(days)s * INTERVAL '1 day')
  instead of string interpolation inside SQL literal
- Use asyncio.gather() for parallel query execution in get_analytics()
- Add range upper bound (max 365 days) to prevent DoS via full-table scans
- Add thread_id validation (alphanumeric, max 128 chars) in replay API
- Sanitize error messages to not reflect user input
2026-03-31 13:38:09 +02:00

110 lines
3.2 KiB
Python

"""Replay API router -- conversation listing and step-by-step replay."""
from __future__ import annotations
import re
from typing import TYPE_CHECKING, Annotated, Any
from fastapi import APIRouter, HTTPException, Query, Request
_THREAD_ID_PATTERN = re.compile(r"^[a-zA-Z0-9\-_]{1,128}$")
# The pool class is only referenced in annotations (which this module keeps
# lazy via ``from __future__ import annotations``), so the import is guarded
# and never executed at runtime.
if TYPE_CHECKING:
    from psycopg_pool import AsyncConnectionPool

# Router for every endpoint in this module: mounted under /api and tagged
# "replay" in the generated OpenAPI schema.
router = APIRouter(tags=["replay"], prefix="/api")
_LIST_CONVERSATIONS_SQL = """
SELECT thread_id, created_at, last_activity, status, total_tokens, total_cost_usd
FROM conversations
ORDER BY last_activity DESC
LIMIT %(limit)s OFFSET %(offset)s
"""
_GET_CHECKPOINTS_SQL = """
SELECT thread_id, checkpoint_id, checkpoint, metadata
FROM checkpoints
WHERE thread_id = %(thread_id)s
ORDER BY checkpoint_id ASC
"""
async def get_pool(request: Request) -> AsyncConnectionPool:
    """Return the shared connection pool stored on ``request.app.state``.

    The route handlers in this module await this to obtain the pool.
    """
    app_state = request.app.state
    return app_state.pool
def _envelope(data: Any, *, success: bool = True, error: str | None = None) -> dict:
return {"success": success, "data": data, "error": error}
@router.get("/conversations")
async def list_conversations(
    request: Request,
    page: Annotated[int, Query(ge=1)] = 1,
    per_page: Annotated[int, Query(ge=1, le=100)] = 20,
) -> dict:
    """List conversations with pagination.

    ``page`` is 1-based and ``per_page`` is capped at 100 by the query
    validation. The response ``data`` is a list of row dicts with the columns
    selected by ``_LIST_CONVERSATIONS_SQL``.
    """
    # Pages are 1-based, so page 1 skips nothing.
    params = {"limit": per_page, "offset": per_page * (page - 1)}
    pool = await get_pool(request)
    async with pool.connection() as conn:
        result = await conn.execute(_LIST_CONVERSATIONS_SQL, params)
        records = await result.fetchall()
    # NOTE(review): dict(record) assumes the pool is configured with a mapping
    # row factory (e.g. psycopg's dict_row) -- confirm where the pool is built.
    conversations = [dict(record) for record in records]
    return _envelope(conversations)
def _serialize_step(step: Any) -> dict:
    """Convert one step produced by ``transform_checkpoints`` into a plain dict.

    ``step.type`` is emitted via its ``.value`` (an enum-like member).
    """
    return {
        "step": step.step,
        "type": step.type.value,
        "timestamp": step.timestamp,
        "content": step.content,
        "agent": step.agent,
        "tool": step.tool,
        "params": step.params,
        "result": step.result,
        "reasoning": step.reasoning,
        "tokens": step.tokens,
        "duration_ms": step.duration_ms,
    }


@router.get("/replay/{thread_id}")
async def get_replay(
    thread_id: str,
    request: Request,
    page: Annotated[int, Query(ge=1)] = 1,
    per_page: Annotated[int, Query(ge=1, le=100)] = 20,
) -> dict:
    """Return paginated replay steps for a conversation thread.

    All checkpoints of the thread are loaded and transformed into steps, then
    the requested 1-based page is sliced out. A page past the end yields an
    empty ``steps`` list.

    Raises:
        HTTPException: 400 if ``thread_id`` is malformed, 404 if the thread
            has no checkpoints.
    """
    from app.replay.transformer import transform_checkpoints

    # fullmatch() requires the whole id to match; a plain ``$`` anchor with
    # .match() would also accept an id with a trailing newline.
    if _THREAD_ID_PATTERN.fullmatch(thread_id) is None:
        raise HTTPException(status_code=400, detail="Invalid thread_id format")

    pool = await get_pool(request)
    async with pool.connection() as conn:
        cursor = await conn.execute(_GET_CHECKPOINTS_SQL, {"thread_id": thread_id})
        rows = await cursor.fetchall()
    if not rows:
        raise HTTPException(status_code=404, detail="Thread not found")

    # Pagination happens after the transform because the transform receives
    # the full checkpoint history (presumably step numbering depends on it).
    all_steps = transform_checkpoints([dict(row) for row in rows])
    start = (page - 1) * per_page
    page_steps = all_steps[start : start + per_page]

    data = {
        "thread_id": thread_id,
        "total_steps": len(all_steps),
        "page": page,
        "per_page": per_page,
        "steps": [_serialize_step(s) for s in page_steps],
    }
    return _envelope(data)