fix: address security findings in Phase 4 analytics and replay
- Fix CRITICAL: use parameterized INTERVAL arithmetic (%(days)s * INTERVAL '1 day') instead of string interpolation inside a quoted SQL literal
- Use asyncio.gather() to run the queries in get_analytics() concurrently
- Add range upper bound (max 365 days) to prevent DoS via full-table scans
- Add thread_id validation (letters, digits, '-' and '_'; 1 to 128 chars) in replay API
- Sanitize error messages so they do not reflect user input
This commit is contained in:
@@ -17,6 +17,7 @@ router = APIRouter(prefix="/api/analytics", tags=["analytics"])
|
|||||||
|
|
||||||
_RANGE_PATTERN = re.compile(r"^(\d+)d$")
|
_RANGE_PATTERN = re.compile(r"^(\d+)d$")
|
||||||
_DEFAULT_RANGE = "7d"
|
_DEFAULT_RANGE = "7d"
|
||||||
|
_MAX_RANGE_DAYS = 365
|
||||||
|
|
||||||
|
|
||||||
async def _get_pool(request: Request) -> AsyncConnectionPool:
|
async def _get_pool(request: Request) -> AsyncConnectionPool:
|
||||||
@@ -34,9 +35,15 @@ def _parse_range(range_str: str) -> int:
|
|||||||
if not match:
|
if not match:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400,
|
status_code=400,
|
||||||
detail=f"Invalid range format '{range_str}'. Expected format: '<N>d' e.g. '7d', '30d'.",
|
detail="Invalid range format. Expected: '<N>d' e.g. '7d', '30d'.",
|
||||||
)
|
)
|
||||||
return int(match.group(1))
|
days = int(match.group(1))
|
||||||
|
if days < 1 or days > _MAX_RANGE_DAYS:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail=f"Range must be between 1 and {_MAX_RANGE_DAYS} days.",
|
||||||
|
)
|
||||||
|
return days
|
||||||
|
|
||||||
|
|
||||||
@router.get("")
|
@router.get("")
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from app.analytics.models import AgentUsage, AnalyticsResult, InterruptStats
|
from app.analytics.models import AgentUsage, AnalyticsResult, InterruptStats
|
||||||
@@ -15,7 +16,7 @@ SELECT
|
|||||||
ELSE COUNT(*) FILTER (WHERE resolution_type = 'resolved')::float / COUNT(*)
|
ELSE COUNT(*) FILTER (WHERE resolution_type = 'resolved')::float / COUNT(*)
|
||||||
END AS rate
|
END AS rate
|
||||||
FROM conversations
|
FROM conversations
|
||||||
WHERE created_at >= NOW() - INTERVAL '%(days)s days'
|
WHERE created_at >= NOW() - (%(days)s * INTERVAL '1 day')
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_ESCALATION_RATE_SQL = """
|
_ESCALATION_RATE_SQL = """
|
||||||
@@ -24,25 +25,25 @@ SELECT
|
|||||||
ELSE COUNT(*) FILTER (WHERE resolution_type = 'escalated')::float / COUNT(*)
|
ELSE COUNT(*) FILTER (WHERE resolution_type = 'escalated')::float / COUNT(*)
|
||||||
END AS rate
|
END AS rate
|
||||||
FROM conversations
|
FROM conversations
|
||||||
WHERE created_at >= NOW() - INTERVAL '%(days)s days'
|
WHERE created_at >= NOW() - (%(days)s * INTERVAL '1 day')
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_TOTAL_CONVERSATIONS_SQL = """
|
_TOTAL_CONVERSATIONS_SQL = """
|
||||||
SELECT COUNT(*) AS total
|
SELECT COUNT(*) AS total
|
||||||
FROM conversations
|
FROM conversations
|
||||||
WHERE created_at >= NOW() - INTERVAL '%(days)s days'
|
WHERE created_at >= NOW() - (%(days)s * INTERVAL '1 day')
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_AVG_TURNS_SQL = """
|
_AVG_TURNS_SQL = """
|
||||||
SELECT COALESCE(AVG(turn_count), 0.0) AS avg_turns
|
SELECT COALESCE(AVG(turn_count), 0.0) AS avg_turns
|
||||||
FROM conversations
|
FROM conversations
|
||||||
WHERE created_at >= NOW() - INTERVAL '%(days)s days'
|
WHERE created_at >= NOW() - (%(days)s * INTERVAL '1 day')
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_COST_PER_CONVERSATION_SQL = """
|
_COST_PER_CONVERSATION_SQL = """
|
||||||
SELECT COALESCE(AVG(total_cost_usd), 0.0) AS avg_cost
|
SELECT COALESCE(AVG(total_cost_usd), 0.0) AS avg_cost
|
||||||
FROM conversations
|
FROM conversations
|
||||||
WHERE created_at >= NOW() - INTERVAL '%(days)s days'
|
WHERE created_at >= NOW() - (%(days)s * INTERVAL '1 day')
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_AGENT_USAGE_SQL = """
|
_AGENT_USAGE_SQL = """
|
||||||
@@ -53,7 +54,7 @@ SELECT
|
|||||||
FROM (
|
FROM (
|
||||||
SELECT UNNEST(agents_used) AS agent
|
SELECT UNNEST(agents_used) AS agent
|
||||||
FROM conversations
|
FROM conversations
|
||||||
WHERE created_at >= NOW() - INTERVAL '%(days)s days'
|
WHERE created_at >= NOW() - (%(days)s * INTERVAL '1 day')
|
||||||
AND agents_used IS NOT NULL
|
AND agents_used IS NOT NULL
|
||||||
) sub
|
) sub
|
||||||
GROUP BY agent
|
GROUP BY agent
|
||||||
@@ -68,7 +69,7 @@ SELECT
|
|||||||
AND error_message IS NULL) AS rejected,
|
AND error_message IS NULL) AS rejected,
|
||||||
COUNT(*) FILTER (WHERE event_type = 'interrupt' AND error_message = 'expired') AS expired
|
COUNT(*) FILTER (WHERE event_type = 'interrupt' AND error_message = 'expired') AS expired
|
||||||
FROM analytics_events
|
FROM analytics_events
|
||||||
WHERE created_at >= NOW() - INTERVAL '%(days)s days'
|
WHERE created_at >= NOW() - (%(days)s * INTERVAL '1 day')
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@@ -122,7 +123,9 @@ async def cost_per_conversation(pool: AsyncConnectionPool, range_days: int) -> f
|
|||||||
return float(row.get("avg_cost") or 0.0)
|
return float(row.get("avg_cost") or 0.0)
|
||||||
|
|
||||||
|
|
||||||
async def agent_usage(pool: AsyncConnectionPool, range_days: int) -> tuple[AgentUsage, ...]:
|
async def agent_usage(
|
||||||
|
pool: AsyncConnectionPool, range_days: int
|
||||||
|
) -> tuple[AgentUsage, ...]:
|
||||||
"""Return per-agent usage statistics for the given range."""
|
"""Return per-agent usage statistics for the given range."""
|
||||||
async with pool.connection() as conn:
|
async with pool.connection() as conn:
|
||||||
cursor = await conn.execute(_AGENT_USAGE_SQL, {"days": range_days})
|
cursor = await conn.execute(_AGENT_USAGE_SQL, {"days": range_days})
|
||||||
@@ -139,7 +142,9 @@ async def agent_usage(pool: AsyncConnectionPool, range_days: int) -> tuple[Agent
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def interrupt_stats(pool: AsyncConnectionPool, range_days: int) -> InterruptStats:
|
async def interrupt_stats(
|
||||||
|
pool: AsyncConnectionPool, range_days: int
|
||||||
|
) -> InterruptStats:
|
||||||
"""Return interrupt approval/rejection statistics for the given range."""
|
"""Return interrupt approval/rejection statistics for the given range."""
|
||||||
async with pool.connection() as conn:
|
async with pool.connection() as conn:
|
||||||
cursor = await conn.execute(_INTERRUPT_STATS_SQL, {"days": range_days})
|
cursor = await conn.execute(_INTERRUPT_STATS_SQL, {"days": range_days})
|
||||||
@@ -154,16 +159,18 @@ async def interrupt_stats(pool: AsyncConnectionPool, range_days: int) -> Interru
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def get_analytics(pool: AsyncConnectionPool, range_days: int) -> AnalyticsResult:
|
async def get_analytics(
|
||||||
|
pool: AsyncConnectionPool, range_days: int
|
||||||
|
) -> AnalyticsResult:
|
||||||
"""Aggregate all analytics metrics into a single AnalyticsResult."""
|
"""Aggregate all analytics metrics into a single AnalyticsResult."""
|
||||||
res_rate, esc_rate, cost, usage, i_stats, total, avg_t = (
|
res_rate, esc_rate, cost, usage, i_stats, total, avg_t = await asyncio.gather(
|
||||||
await resolution_rate(pool, range_days),
|
resolution_rate(pool, range_days),
|
||||||
await escalation_rate(pool, range_days),
|
escalation_rate(pool, range_days),
|
||||||
await cost_per_conversation(pool, range_days),
|
cost_per_conversation(pool, range_days),
|
||||||
await agent_usage(pool, range_days),
|
agent_usage(pool, range_days),
|
||||||
await interrupt_stats(pool, range_days),
|
interrupt_stats(pool, range_days),
|
||||||
await _total_conversations(pool, range_days),
|
_total_conversations(pool, range_days),
|
||||||
await _avg_turns(pool, range_days),
|
_avg_turns(pool, range_days),
|
||||||
)
|
)
|
||||||
return AnalyticsResult(
|
return AnalyticsResult(
|
||||||
range=f"{range_days}d",
|
range=f"{range_days}d",
|
||||||
|
|||||||
@@ -2,10 +2,13 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
from typing import TYPE_CHECKING, Annotated, Any
|
from typing import TYPE_CHECKING, Annotated, Any
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException, Query, Request
|
from fastapi import APIRouter, HTTPException, Query, Request
|
||||||
|
|
||||||
|
_THREAD_ID_PATTERN = re.compile(r"^[a-zA-Z0-9\-_]{1,128}$")
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from psycopg_pool import AsyncConnectionPool
|
from psycopg_pool import AsyncConnectionPool
|
||||||
|
|
||||||
@@ -64,13 +67,16 @@ async def get_replay(
|
|||||||
"""Return paginated replay steps for a conversation thread."""
|
"""Return paginated replay steps for a conversation thread."""
|
||||||
from app.replay.transformer import transform_checkpoints
|
from app.replay.transformer import transform_checkpoints
|
||||||
|
|
||||||
|
if not _THREAD_ID_PATTERN.match(thread_id):
|
||||||
|
raise HTTPException(status_code=400, detail="Invalid thread_id format")
|
||||||
|
|
||||||
pool = await get_pool(request)
|
pool = await get_pool(request)
|
||||||
async with pool.connection() as conn:
|
async with pool.connection() as conn:
|
||||||
cursor = await conn.execute(_GET_CHECKPOINTS_SQL, {"thread_id": thread_id})
|
cursor = await conn.execute(_GET_CHECKPOINTS_SQL, {"thread_id": thread_id})
|
||||||
rows = await cursor.fetchall()
|
rows = await cursor.fetchall()
|
||||||
|
|
||||||
if not rows:
|
if not rows:
|
||||||
raise HTTPException(status_code=404, detail=f"Thread '{thread_id}' not found")
|
raise HTTPException(status_code=404, detail="Thread not found")
|
||||||
|
|
||||||
all_steps = transform_checkpoints([dict(row) for row in rows])
|
all_steps = transform_checkpoints([dict(row) for row in rows])
|
||||||
total_steps = len(all_steps)
|
total_steps = len(all_steps)
|
||||||
|
|||||||
Reference in New Issue
Block a user