fix: address security findings in Phase 4 analytics and replay

- Fix CRITICAL: use parameterized INTERVAL arithmetic (%(days)s * INTERVAL '1 day')
  instead of embedding the %(days)s placeholder inside a quoted SQL literal
  (INTERVAL '%(days)s days')
- Use asyncio.gather() for parallel query execution in get_analytics()
- Add range bounds (min 1, max 365 days); the upper bound prevents DoS via full-table scans
- Add thread_id validation (alphanumeric plus '-' and '_', 1-128 chars) in replay API
- Sanitize error messages to not reflect user input
This commit is contained in:
Yaojia Wang
2026-03-31 13:38:09 +02:00
parent 33db5aeb10
commit ef6e5ac2be
3 changed files with 41 additions and 21 deletions

View File

@@ -17,6 +17,7 @@ router = APIRouter(prefix="/api/analytics", tags=["analytics"])
_RANGE_PATTERN = re.compile(r"^(\d+)d$") _RANGE_PATTERN = re.compile(r"^(\d+)d$")
_DEFAULT_RANGE = "7d" _DEFAULT_RANGE = "7d"
_MAX_RANGE_DAYS = 365
async def _get_pool(request: Request) -> AsyncConnectionPool: async def _get_pool(request: Request) -> AsyncConnectionPool:
@@ -34,9 +35,15 @@ def _parse_range(range_str: str) -> int:
if not match: if not match:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail=f"Invalid range format '{range_str}'. Expected format: '<N>d' e.g. '7d', '30d'.", detail="Invalid range format. Expected: '<N>d' e.g. '7d', '30d'.",
) )
return int(match.group(1)) days = int(match.group(1))
if days < 1 or days > _MAX_RANGE_DAYS:
raise HTTPException(
status_code=400,
detail=f"Range must be between 1 and {_MAX_RANGE_DAYS} days.",
)
return days
@router.get("") @router.get("")

View File

@@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from app.analytics.models import AgentUsage, AnalyticsResult, InterruptStats from app.analytics.models import AgentUsage, AnalyticsResult, InterruptStats
@@ -15,7 +16,7 @@ SELECT
ELSE COUNT(*) FILTER (WHERE resolution_type = 'resolved')::float / COUNT(*) ELSE COUNT(*) FILTER (WHERE resolution_type = 'resolved')::float / COUNT(*)
END AS rate END AS rate
FROM conversations FROM conversations
WHERE created_at >= NOW() - INTERVAL '%(days)s days' WHERE created_at >= NOW() - (%(days)s * INTERVAL '1 day')
""" """
_ESCALATION_RATE_SQL = """ _ESCALATION_RATE_SQL = """
@@ -24,25 +25,25 @@ SELECT
ELSE COUNT(*) FILTER (WHERE resolution_type = 'escalated')::float / COUNT(*) ELSE COUNT(*) FILTER (WHERE resolution_type = 'escalated')::float / COUNT(*)
END AS rate END AS rate
FROM conversations FROM conversations
WHERE created_at >= NOW() - INTERVAL '%(days)s days' WHERE created_at >= NOW() - (%(days)s * INTERVAL '1 day')
""" """
_TOTAL_CONVERSATIONS_SQL = """ _TOTAL_CONVERSATIONS_SQL = """
SELECT COUNT(*) AS total SELECT COUNT(*) AS total
FROM conversations FROM conversations
WHERE created_at >= NOW() - INTERVAL '%(days)s days' WHERE created_at >= NOW() - (%(days)s * INTERVAL '1 day')
""" """
_AVG_TURNS_SQL = """ _AVG_TURNS_SQL = """
SELECT COALESCE(AVG(turn_count), 0.0) AS avg_turns SELECT COALESCE(AVG(turn_count), 0.0) AS avg_turns
FROM conversations FROM conversations
WHERE created_at >= NOW() - INTERVAL '%(days)s days' WHERE created_at >= NOW() - (%(days)s * INTERVAL '1 day')
""" """
_COST_PER_CONVERSATION_SQL = """ _COST_PER_CONVERSATION_SQL = """
SELECT COALESCE(AVG(total_cost_usd), 0.0) AS avg_cost SELECT COALESCE(AVG(total_cost_usd), 0.0) AS avg_cost
FROM conversations FROM conversations
WHERE created_at >= NOW() - INTERVAL '%(days)s days' WHERE created_at >= NOW() - (%(days)s * INTERVAL '1 day')
""" """
_AGENT_USAGE_SQL = """ _AGENT_USAGE_SQL = """
@@ -53,7 +54,7 @@ SELECT
FROM ( FROM (
SELECT UNNEST(agents_used) AS agent SELECT UNNEST(agents_used) AS agent
FROM conversations FROM conversations
WHERE created_at >= NOW() - INTERVAL '%(days)s days' WHERE created_at >= NOW() - (%(days)s * INTERVAL '1 day')
AND agents_used IS NOT NULL AND agents_used IS NOT NULL
) sub ) sub
GROUP BY agent GROUP BY agent
@@ -68,7 +69,7 @@ SELECT
AND error_message IS NULL) AS rejected, AND error_message IS NULL) AS rejected,
COUNT(*) FILTER (WHERE event_type = 'interrupt' AND error_message = 'expired') AS expired COUNT(*) FILTER (WHERE event_type = 'interrupt' AND error_message = 'expired') AS expired
FROM analytics_events FROM analytics_events
WHERE created_at >= NOW() - INTERVAL '%(days)s days' WHERE created_at >= NOW() - (%(days)s * INTERVAL '1 day')
""" """
@@ -122,7 +123,9 @@ async def cost_per_conversation(pool: AsyncConnectionPool, range_days: int) -> f
return float(row.get("avg_cost") or 0.0) return float(row.get("avg_cost") or 0.0)
async def agent_usage(pool: AsyncConnectionPool, range_days: int) -> tuple[AgentUsage, ...]: async def agent_usage(
pool: AsyncConnectionPool, range_days: int
) -> tuple[AgentUsage, ...]:
"""Return per-agent usage statistics for the given range.""" """Return per-agent usage statistics for the given range."""
async with pool.connection() as conn: async with pool.connection() as conn:
cursor = await conn.execute(_AGENT_USAGE_SQL, {"days": range_days}) cursor = await conn.execute(_AGENT_USAGE_SQL, {"days": range_days})
@@ -139,7 +142,9 @@ async def agent_usage(pool: AsyncConnectionPool, range_days: int) -> tuple[Agent
) )
async def interrupt_stats(pool: AsyncConnectionPool, range_days: int) -> InterruptStats: async def interrupt_stats(
pool: AsyncConnectionPool, range_days: int
) -> InterruptStats:
"""Return interrupt approval/rejection statistics for the given range.""" """Return interrupt approval/rejection statistics for the given range."""
async with pool.connection() as conn: async with pool.connection() as conn:
cursor = await conn.execute(_INTERRUPT_STATS_SQL, {"days": range_days}) cursor = await conn.execute(_INTERRUPT_STATS_SQL, {"days": range_days})
@@ -154,16 +159,18 @@ async def interrupt_stats(pool: AsyncConnectionPool, range_days: int) -> Interru
) )
async def get_analytics(pool: AsyncConnectionPool, range_days: int) -> AnalyticsResult: async def get_analytics(
pool: AsyncConnectionPool, range_days: int
) -> AnalyticsResult:
"""Aggregate all analytics metrics into a single AnalyticsResult.""" """Aggregate all analytics metrics into a single AnalyticsResult."""
res_rate, esc_rate, cost, usage, i_stats, total, avg_t = ( res_rate, esc_rate, cost, usage, i_stats, total, avg_t = await asyncio.gather(
await resolution_rate(pool, range_days), resolution_rate(pool, range_days),
await escalation_rate(pool, range_days), escalation_rate(pool, range_days),
await cost_per_conversation(pool, range_days), cost_per_conversation(pool, range_days),
await agent_usage(pool, range_days), agent_usage(pool, range_days),
await interrupt_stats(pool, range_days), interrupt_stats(pool, range_days),
await _total_conversations(pool, range_days), _total_conversations(pool, range_days),
await _avg_turns(pool, range_days), _avg_turns(pool, range_days),
) )
return AnalyticsResult( return AnalyticsResult(
range=f"{range_days}d", range=f"{range_days}d",

View File

@@ -2,10 +2,13 @@
from __future__ import annotations from __future__ import annotations
import re
from typing import TYPE_CHECKING, Annotated, Any from typing import TYPE_CHECKING, Annotated, Any
from fastapi import APIRouter, HTTPException, Query, Request from fastapi import APIRouter, HTTPException, Query, Request
_THREAD_ID_PATTERN = re.compile(r"^[a-zA-Z0-9\-_]{1,128}$")
if TYPE_CHECKING: if TYPE_CHECKING:
from psycopg_pool import AsyncConnectionPool from psycopg_pool import AsyncConnectionPool
@@ -64,13 +67,16 @@ async def get_replay(
"""Return paginated replay steps for a conversation thread.""" """Return paginated replay steps for a conversation thread."""
from app.replay.transformer import transform_checkpoints from app.replay.transformer import transform_checkpoints
if not _THREAD_ID_PATTERN.match(thread_id):
raise HTTPException(status_code=400, detail="Invalid thread_id format")
pool = await get_pool(request) pool = await get_pool(request)
async with pool.connection() as conn: async with pool.connection() as conn:
cursor = await conn.execute(_GET_CHECKPOINTS_SQL, {"thread_id": thread_id}) cursor = await conn.execute(_GET_CHECKPOINTS_SQL, {"thread_id": thread_id})
rows = await cursor.fetchall() rows = await cursor.fetchall()
if not rows: if not rows:
raise HTTPException(status_code=404, detail=f"Thread '{thread_id}' not found") raise HTTPException(status_code=404, detail="Thread not found")
all_steps = transform_checkpoints([dict(row) for row in rows]) all_steps = transform_checkpoints([dict(row) for row in rows])
total_steps = len(all_steps) total_steps = len(all_steps)