From ef6e5ac2becbb7f8d7cec96f50cab028e1da347f Mon Sep 17 00:00:00 2001 From: Yaojia Wang Date: Tue, 31 Mar 2026 13:38:09 +0200 Subject: [PATCH] fix: address security findings in Phase 4 analytics and replay - Fix CRITICAL: use parameterized INTERVAL arithmetic (%(days)s * INTERVAL '1 day') instead of string interpolation inside SQL literal - Use asyncio.gather() for parallel query execution in get_analytics() - Add range upper bound (max 365 days) to prevent DoS via full-table scans - Add thread_id validation (alphanumeric, max 128 chars) in replay API - Sanitize error messages to not reflect user input --- backend/app/analytics/api.py | 11 ++++++-- backend/app/analytics/queries.py | 43 +++++++++++++++++++------------- backend/app/replay/api.py | 8 +++++- 3 files changed, 41 insertions(+), 21 deletions(-) diff --git a/backend/app/analytics/api.py b/backend/app/analytics/api.py index c86b2eb..d8ed694 100644 --- a/backend/app/analytics/api.py +++ b/backend/app/analytics/api.py @@ -17,6 +17,7 @@ router = APIRouter(prefix="/api/analytics", tags=["analytics"]) _RANGE_PATTERN = re.compile(r"^(\d+)d$") _DEFAULT_RANGE = "7d" +_MAX_RANGE_DAYS = 365 async def _get_pool(request: Request) -> AsyncConnectionPool: @@ -34,9 +35,15 @@ def _parse_range(range_str: str) -> int: if not match: raise HTTPException( status_code=400, - detail=f"Invalid range format '{range_str}'. Expected format: '<N>d' e.g. '7d', '30d'.", + detail="Invalid range format. Expected: '<N>d' e.g. 
'7d', '30d'.", ) - return int(match.group(1)) + days = int(match.group(1)) + if days < 1 or days > _MAX_RANGE_DAYS: + raise HTTPException( + status_code=400, + detail=f"Range must be between 1 and {_MAX_RANGE_DAYS} days.", + ) + return days @router.get("") diff --git a/backend/app/analytics/queries.py b/backend/app/analytics/queries.py index fe28d19..bf80d74 100644 --- a/backend/app/analytics/queries.py +++ b/backend/app/analytics/queries.py @@ -2,6 +2,7 @@ from __future__ import annotations +import asyncio from typing import TYPE_CHECKING from app.analytics.models import AgentUsage, AnalyticsResult, InterruptStats @@ -15,7 +16,7 @@ SELECT ELSE COUNT(*) FILTER (WHERE resolution_type = 'resolved')::float / COUNT(*) END AS rate FROM conversations -WHERE created_at >= NOW() - INTERVAL '%(days)s days' +WHERE created_at >= NOW() - (%(days)s * INTERVAL '1 day') """ _ESCALATION_RATE_SQL = """ @@ -24,25 +25,25 @@ SELECT ELSE COUNT(*) FILTER (WHERE resolution_type = 'escalated')::float / COUNT(*) END AS rate FROM conversations -WHERE created_at >= NOW() - INTERVAL '%(days)s days' +WHERE created_at >= NOW() - (%(days)s * INTERVAL '1 day') """ _TOTAL_CONVERSATIONS_SQL = """ SELECT COUNT(*) AS total FROM conversations -WHERE created_at >= NOW() - INTERVAL '%(days)s days' +WHERE created_at >= NOW() - (%(days)s * INTERVAL '1 day') """ _AVG_TURNS_SQL = """ SELECT COALESCE(AVG(turn_count), 0.0) AS avg_turns FROM conversations -WHERE created_at >= NOW() - INTERVAL '%(days)s days' +WHERE created_at >= NOW() - (%(days)s * INTERVAL '1 day') """ _COST_PER_CONVERSATION_SQL = """ SELECT COALESCE(AVG(total_cost_usd), 0.0) AS avg_cost FROM conversations -WHERE created_at >= NOW() - INTERVAL '%(days)s days' +WHERE created_at >= NOW() - (%(days)s * INTERVAL '1 day') """ _AGENT_USAGE_SQL = """ @@ -53,7 +54,7 @@ SELECT FROM ( SELECT UNNEST(agents_used) AS agent FROM conversations - WHERE created_at >= NOW() - INTERVAL '%(days)s days' + WHERE created_at >= NOW() - (%(days)s * INTERVAL '1 day') 
AND agents_used IS NOT NULL ) sub GROUP BY agent @@ -68,7 +69,7 @@ SELECT AND error_message IS NULL) AS rejected, COUNT(*) FILTER (WHERE event_type = 'interrupt' AND error_message = 'expired') AS expired FROM analytics_events -WHERE created_at >= NOW() - INTERVAL '%(days)s days' +WHERE created_at >= NOW() - (%(days)s * INTERVAL '1 day') """ @@ -122,7 +123,9 @@ async def cost_per_conversation(pool: AsyncConnectionPool, range_days: int) -> f return float(row.get("avg_cost") or 0.0) -async def agent_usage(pool: AsyncConnectionPool, range_days: int) -> tuple[AgentUsage, ...]: +async def agent_usage( + pool: AsyncConnectionPool, range_days: int +) -> tuple[AgentUsage, ...]: """Return per-agent usage statistics for the given range.""" async with pool.connection() as conn: cursor = await conn.execute(_AGENT_USAGE_SQL, {"days": range_days}) @@ -139,7 +142,9 @@ async def agent_usage(pool: AsyncConnectionPool, range_days: int) -> tuple[Agent ) -async def interrupt_stats(pool: AsyncConnectionPool, range_days: int) -> InterruptStats: +async def interrupt_stats( + pool: AsyncConnectionPool, range_days: int +) -> InterruptStats: """Return interrupt approval/rejection statistics for the given range.""" async with pool.connection() as conn: cursor = await conn.execute(_INTERRUPT_STATS_SQL, {"days": range_days}) @@ -154,16 +159,18 @@ async def interrupt_stats(pool: AsyncConnectionPool, range_days: int) -> Interru ) -async def get_analytics(pool: AsyncConnectionPool, range_days: int) -> AnalyticsResult: +async def get_analytics( + pool: AsyncConnectionPool, range_days: int +) -> AnalyticsResult: """Aggregate all analytics metrics into a single AnalyticsResult.""" - res_rate, esc_rate, cost, usage, i_stats, total, avg_t = ( - await resolution_rate(pool, range_days), - await escalation_rate(pool, range_days), - await cost_per_conversation(pool, range_days), - await agent_usage(pool, range_days), - await interrupt_stats(pool, range_days), - await _total_conversations(pool, range_days), 
- await _avg_turns(pool, range_days), + res_rate, esc_rate, cost, usage, i_stats, total, avg_t = await asyncio.gather( + resolution_rate(pool, range_days), + escalation_rate(pool, range_days), + cost_per_conversation(pool, range_days), + agent_usage(pool, range_days), + interrupt_stats(pool, range_days), + _total_conversations(pool, range_days), + _avg_turns(pool, range_days), ) return AnalyticsResult( range=f"{range_days}d", diff --git a/backend/app/replay/api.py b/backend/app/replay/api.py index f8e38ee..bb1a8d4 100644 --- a/backend/app/replay/api.py +++ b/backend/app/replay/api.py @@ -2,10 +2,13 @@ from __future__ import annotations +import re from typing import TYPE_CHECKING, Annotated, Any from fastapi import APIRouter, HTTPException, Query, Request +_THREAD_ID_PATTERN = re.compile(r"^[a-zA-Z0-9\-_]{1,128}$") + if TYPE_CHECKING: from psycopg_pool import AsyncConnectionPool @@ -64,13 +67,16 @@ async def get_replay( """Return paginated replay steps for a conversation thread.""" from app.replay.transformer import transform_checkpoints + if not _THREAD_ID_PATTERN.match(thread_id): + raise HTTPException(status_code=400, detail="Invalid thread_id format") + pool = await get_pool(request) async with pool.connection() as conn: cursor = await conn.execute(_GET_CHECKPOINTS_SQL, {"thread_id": thread_id}) rows = await cursor.fetchall() if not rows: - raise HTTPException(status_code=404, detail=f"Thread '{thread_id}' not found") + raise HTTPException(status_code=404, detail="Thread not found") all_steps = transform_checkpoints([dict(row) for row in rows]) total_steps = len(all_steps)