feat: complete phase 4 -- conversation replay API + analytics dashboard
- Replay models: StepType enum, ReplayStep, ReplayPage frozen dataclasses
- Checkpoint transformer: PostgresSaver JSONB -> structured timeline steps
- Replay API: GET /api/conversations (paginated), GET /api/replay/{thread_id}
- Analytics models: AgentUsage, InterruptStats, AnalyticsResult
- Analytics event recorder: Protocol + PostgresAnalyticsRecorder + NoOp
- Analytics queries: resolution_rate, agent_usage, escalation_rate, cost, interrupts
- Analytics API: GET /api/analytics?range=Xd with envelope response
- DB migration: analytics_events table + conversations column additions
- 74 new tests, 399 total passing, 92.87% coverage
backend/app/analytics/__init__.py (new file, +3)
@@ -0,0 +1,3 @@
"""Analytics module -- event recording and dashboard queries."""

from __future__ import annotations
backend/app/analytics/api.py (new file, +51)
@@ -0,0 +1,51 @@
"""Analytics API router -- dashboard metrics endpoint."""

from __future__ import annotations

import re
from dataclasses import asdict
from typing import TYPE_CHECKING, Any

from fastapi import APIRouter, HTTPException, Query, Request

from app.analytics.queries import get_analytics

if TYPE_CHECKING:
    from psycopg_pool import AsyncConnectionPool

router = APIRouter(prefix="/api/analytics", tags=["analytics"])

_RANGE_PATTERN = re.compile(r"^(\d+)d$")
_DEFAULT_RANGE = "7d"


async def _get_pool(request: Request) -> AsyncConnectionPool:
    """Dependency: extract the shared pool from app state."""
    return request.app.state.pool


def _envelope(data: Any, *, success: bool = True, error: str | None = None) -> dict:
    return {"success": success, "data": data, "error": error}


def _parse_range(range_str: str) -> int:
    """Parse 'Xd' range string to integer days. Raises 400 on invalid format."""
    match = _RANGE_PATTERN.match(range_str)
    if not match:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid range format '{range_str}'. Expected format: '<N>d' e.g. '7d', '30d'.",
        )
    return int(match.group(1))


@router.get("")
async def analytics(
    request: Request,
    range: str = Query(default=_DEFAULT_RANGE, alias="range"),  # noqa: A002
) -> dict:
    """Return aggregated analytics metrics for the given time range."""
    range_days = _parse_range(range)
    pool = await _get_pool(request)
    result = await get_analytics(pool, range_days=range_days)
    return _envelope(asdict(result))
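A quick way to exercise this endpoint once the app is running -- a sketch, not part of this commit; the local base URL is an assumption:

    import asyncio

    import httpx


    async def main() -> None:
        async with httpx.AsyncClient(base_url="http://localhost:8000") as client:
            ok = await client.get("/api/analytics", params={"range": "30d"})
            print(ok.json())  # envelope: {"success": true, "data": {...}, "error": null}

            bad = await client.get("/api/analytics", params={"range": "monthly"})
            print(bad.status_code)  # 400, raised by _parse_range


    asyncio.run(main())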
backend/app/analytics/event_recorder.py (new file, +95)
@@ -0,0 +1,95 @@
"""Analytics event recorder -- Protocol and implementations."""

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable

from psycopg.types.json import Json

if TYPE_CHECKING:
    from psycopg_pool import AsyncConnectionPool

_INSERT_SQL = """
INSERT INTO analytics_events
    (thread_id, event_type, agent_name, tool_name, tokens_used, cost_usd,
     duration_ms, success, error_message, metadata)
VALUES
    (%(thread_id)s, %(event_type)s, %(agent_name)s, %(tool_name)s,
     %(tokens_used)s, %(cost_usd)s, %(duration_ms)s, %(success)s,
     %(error_message)s, %(metadata)s)
"""


@runtime_checkable
class AnalyticsRecorder(Protocol):
    """Protocol for recording analytics events."""

    async def record(
        self,
        *,
        thread_id: str,
        event_type: str,
        agent_name: str | None = None,
        tool_name: str | None = None,
        tokens_used: int = 0,
        cost_usd: float = 0.0,
        duration_ms: int | None = None,
        success: bool | None = None,
        error_message: str | None = None,
        metadata: dict | None = None,
    ) -> None: ...


class NoOpAnalyticsRecorder:
    """No-op implementation for testing or when the DB is unavailable."""

    async def record(
        self,
        *,
        thread_id: str,
        event_type: str,
        agent_name: str | None = None,
        tool_name: str | None = None,
        tokens_used: int = 0,
        cost_usd: float = 0.0,
        duration_ms: int | None = None,
        success: bool | None = None,
        error_message: str | None = None,
        metadata: dict | None = None,
    ) -> None:
        """Do nothing."""


class PostgresAnalyticsRecorder:
    """Postgres-backed analytics recorder -- INSERTs into analytics_events."""

    def __init__(self, pool: AsyncConnectionPool) -> None:
        self._pool = pool

    async def record(
        self,
        *,
        thread_id: str,
        event_type: str,
        agent_name: str | None = None,
        tool_name: str | None = None,
        tokens_used: int = 0,
        cost_usd: float = 0.0,
        duration_ms: int | None = None,
        success: bool | None = None,
        error_message: str | None = None,
        metadata: dict | None = None,
    ) -> None:
        """Insert one analytics event row."""
        params: dict[str, Any] = {
            "thread_id": thread_id,
            "event_type": event_type,
            "agent_name": agent_name,
            "tool_name": tool_name,
            "tokens_used": tokens_used,
            "cost_usd": cost_usd,
            "duration_ms": duration_ms,
            "success": success,
            "error_message": error_message,
            # psycopg cannot adapt a bare dict; wrap it for the JSONB column.
            "metadata": Json(metadata or {}),
        }
        async with self._pool.connection() as conn:
            await conn.execute(_INSERT_SQL, params)
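Because AnalyticsRecorder is a runtime_checkable Protocol, call sites depend only on the interface and either implementation can be swapped in. A usage sketch -- every value below is hypothetical:

    recorder: AnalyticsRecorder = PostgresAnalyticsRecorder(pool)
    await recorder.record(
        thread_id="thread-123",       # hypothetical values throughout
        event_type="tool_call",
        agent_name="billing",
        tool_name="lookup_invoice",
        tokens_used=412,
        cost_usd=0.0031,
        duration_ms=87,
        success=True,
        metadata={"invoice_id": "inv_42"},
    )

    # Structural isinstance works at runtime thanks to @runtime_checkable:
    assert isinstance(NoOpAnalyticsRecorder(), AnalyticsRecorder)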
backend/app/analytics/models.py (new file, +38)
@@ -0,0 +1,38 @@
"""Value objects for analytics dashboard."""

from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class AgentUsage:
    """Agent usage statistics within a time range."""

    agent: str
    count: int
    percentage: float


@dataclass(frozen=True)
class InterruptStats:
    """Interrupt approval/rejection statistics within a time range."""

    total: int = 0
    approved: int = 0
    rejected: int = 0
    expired: int = 0


@dataclass(frozen=True)
class AnalyticsResult:
    """Full analytics result for a given time range."""

    range: str
    total_conversations: int
    resolution_rate: float
    escalation_rate: float
    avg_turns_per_conversation: float
    avg_cost_per_conversation_usd: float
    agent_usage: tuple[AgentUsage, ...]
    interrupt_stats: InterruptStats
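All three models are frozen dataclasses, so the API layer serializes a whole result with a single dataclasses.asdict call, which recurses through nested dataclasses and tuples. An illustrative round-trip (all numbers invented):

    from dataclasses import asdict

    result = AnalyticsResult(
        range="7d",
        total_conversations=42,
        resolution_rate=0.81,
        escalation_rate=0.07,
        avg_turns_per_conversation=5.3,
        avg_cost_per_conversation_usd=0.012,
        agent_usage=(AgentUsage(agent="billing", count=30, percentage=71.43),),
        interrupt_stats=InterruptStats(total=4, approved=3, rejected=1),
    )

    payload = asdict(result)
    assert payload["agent_usage"][0] == {"agent": "billing", "count": 30, "percentage": 71.43}
    assert payload["interrupt_stats"]["expired"] == 0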
backend/app/analytics/queries.py (new file, +177)
@@ -0,0 +1,177 @@
"""Analytics query functions -- all async, take pool + range_days."""

from __future__ import annotations

from typing import TYPE_CHECKING

from app.analytics.models import AgentUsage, AnalyticsResult, InterruptStats

if TYPE_CHECKING:
    from psycopg_pool import AsyncConnectionPool

# make_interval keeps the %(days)s placeholder outside any string literal;
# psycopg does not expand placeholders that appear inside quotes.
_RESOLUTION_RATE_SQL = """
SELECT
    CASE WHEN COUNT(*) = 0 THEN 0.0
         ELSE COUNT(*) FILTER (WHERE resolution_type = 'resolved')::float / COUNT(*)
    END AS rate
FROM conversations
WHERE created_at >= NOW() - make_interval(days => %(days)s)
"""

_ESCALATION_RATE_SQL = """
SELECT
    CASE WHEN COUNT(*) = 0 THEN 0.0
         ELSE COUNT(*) FILTER (WHERE resolution_type = 'escalated')::float / COUNT(*)
    END AS rate
FROM conversations
WHERE created_at >= NOW() - make_interval(days => %(days)s)
"""

_TOTAL_CONVERSATIONS_SQL = """
SELECT COUNT(*) AS total
FROM conversations
WHERE created_at >= NOW() - make_interval(days => %(days)s)
"""

_AVG_TURNS_SQL = """
SELECT COALESCE(AVG(turn_count), 0.0) AS avg_turns
FROM conversations
WHERE created_at >= NOW() - make_interval(days => %(days)s)
"""

_COST_PER_CONVERSATION_SQL = """
SELECT COALESCE(AVG(total_cost_usd), 0.0) AS avg_cost
FROM conversations
WHERE created_at >= NOW() - make_interval(days => %(days)s)
"""

_AGENT_USAGE_SQL = """
SELECT
    agent,
    COUNT(*) AS count,
    ROUND(COUNT(*) * 100.0 / NULLIF(SUM(COUNT(*)) OVER (), 0), 2) AS percentage
FROM (
    SELECT UNNEST(agents_used) AS agent
    FROM conversations
    WHERE created_at >= NOW() - make_interval(days => %(days)s)
      AND agents_used IS NOT NULL
) sub
GROUP BY agent
ORDER BY count DESC
"""

_INTERRUPT_STATS_SQL = """
SELECT
    COUNT(*) FILTER (WHERE event_type = 'interrupt') AS total,
    COUNT(*) FILTER (WHERE event_type = 'interrupt' AND success = TRUE) AS approved,
    COUNT(*) FILTER (WHERE event_type = 'interrupt' AND success = FALSE
                       AND error_message IS NULL) AS rejected,
    COUNT(*) FILTER (WHERE event_type = 'interrupt' AND error_message = 'expired') AS expired
FROM analytics_events
WHERE created_at >= NOW() - make_interval(days => %(days)s)
"""


async def resolution_rate(pool: AsyncConnectionPool, range_days: int) -> float:
    """Return the fraction of resolved conversations in the given range."""
    async with pool.connection() as conn:
        cursor = await conn.execute(_RESOLUTION_RATE_SQL, {"days": range_days})
        row = await cursor.fetchone()
    if not row:
        return 0.0
    return float(row.get("rate") or 0.0)


async def escalation_rate(pool: AsyncConnectionPool, range_days: int) -> float:
    """Return the fraction of escalated conversations in the given range."""
    async with pool.connection() as conn:
        cursor = await conn.execute(_ESCALATION_RATE_SQL, {"days": range_days})
        row = await cursor.fetchone()
    if not row:
        return 0.0
    return float(row.get("rate") or 0.0)


async def _total_conversations(pool: AsyncConnectionPool, range_days: int) -> int:
    """Return the total number of conversations in the given range."""
    async with pool.connection() as conn:
        cursor = await conn.execute(_TOTAL_CONVERSATIONS_SQL, {"days": range_days})
        row = await cursor.fetchone()
    if not row:
        return 0
    return int(row.get("total") or 0)


async def _avg_turns(pool: AsyncConnectionPool, range_days: int) -> float:
    """Return the average turn count per conversation in the given range."""
    async with pool.connection() as conn:
        cursor = await conn.execute(_AVG_TURNS_SQL, {"days": range_days})
        row = await cursor.fetchone()
    if not row:
        return 0.0
    return float(row.get("avg_turns") or 0.0)


async def cost_per_conversation(pool: AsyncConnectionPool, range_days: int) -> float:
    """Return the average cost per conversation in the given range."""
    async with pool.connection() as conn:
        cursor = await conn.execute(_COST_PER_CONVERSATION_SQL, {"days": range_days})
        row = await cursor.fetchone()
    if not row:
        return 0.0
    return float(row.get("avg_cost") or 0.0)


async def agent_usage(pool: AsyncConnectionPool, range_days: int) -> tuple[AgentUsage, ...]:
    """Return per-agent usage statistics for the given range."""
    async with pool.connection() as conn:
        cursor = await conn.execute(_AGENT_USAGE_SQL, {"days": range_days})
        rows = await cursor.fetchall()
    if not rows:
        return ()
    return tuple(
        AgentUsage(
            agent=row["agent"],
            count=int(row["count"]),
            percentage=float(row["percentage"]),
        )
        for row in rows
    )


async def interrupt_stats(pool: AsyncConnectionPool, range_days: int) -> InterruptStats:
    """Return interrupt approval/rejection statistics for the given range."""
    async with pool.connection() as conn:
        cursor = await conn.execute(_INTERRUPT_STATS_SQL, {"days": range_days})
        row = await cursor.fetchone()
    if not row:
        return InterruptStats()
    return InterruptStats(
        total=int(row.get("total") or 0),
        approved=int(row.get("approved") or 0),
        rejected=int(row.get("rejected") or 0),
        expired=int(row.get("expired") or 0),
    )


async def get_analytics(pool: AsyncConnectionPool, range_days: int) -> AnalyticsResult:
    """Aggregate all analytics metrics into a single AnalyticsResult."""
    res_rate, esc_rate, cost, usage, i_stats, total, avg_t = (
        await resolution_rate(pool, range_days),
        await escalation_rate(pool, range_days),
        await cost_per_conversation(pool, range_days),
        await agent_usage(pool, range_days),
        await interrupt_stats(pool, range_days),
        await _total_conversations(pool, range_days),
        await _avg_turns(pool, range_days),
    )
    return AnalyticsResult(
        range=f"{range_days}d",
        total_conversations=total,
        resolution_rate=res_rate,
        escalation_rate=esc_rate,
        avg_turns_per_conversation=avg_t,
        avg_cost_per_conversation_usd=cost,
        agent_usage=usage,
        interrupt_stats=i_stats,
    )
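get_analytics awaits its seven queries one at a time. Since each helper checks out its own pooled connection, they could also run concurrently; a sketch of that variant (hypothetical, not in this commit; assumes the pool holds more than one connection):

    import asyncio


    async def get_analytics_concurrent(
        pool: AsyncConnectionPool, range_days: int
    ) -> AnalyticsResult:
        # gather preserves argument order, so the unpacking below is safe.
        res_rate, esc_rate, cost, usage, i_stats, total, avg_t = await asyncio.gather(
            resolution_rate(pool, range_days),
            escalation_rate(pool, range_days),
            cost_per_conversation(pool, range_days),
            agent_usage(pool, range_days),
            interrupt_stats(pool, range_days),
            _total_conversations(pool, range_days),
            _avg_turns(pool, range_days),
        )
        return AnalyticsResult(
            range=f"{range_days}d",
            total_conversations=total,
            resolution_rate=res_rate,
            escalation_rate=esc_rate,
            avg_turns_per_conversation=avg_t,
            avg_cost_per_conversation_usd=cost,
            agent_usage=usage,
            interrupt_stats=i_stats,
        )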
backend/app/db.py (modified)
@@ -34,6 +34,31 @@ CREATE TABLE IF NOT EXISTS active_interrupts (
 );
 """
 
+_ANALYTICS_EVENTS_DDL = """
+CREATE TABLE IF NOT EXISTS analytics_events (
+    id BIGSERIAL PRIMARY KEY,
+    thread_id TEXT NOT NULL,
+    event_type TEXT NOT NULL,
+    agent_name TEXT,
+    tool_name TEXT,
+    tokens_used INTEGER NOT NULL DEFAULT 0,
+    cost_usd DOUBLE PRECISION NOT NULL DEFAULT 0.0,
+    duration_ms INTEGER,
+    success BOOLEAN,
+    error_message TEXT,
+    metadata JSONB NOT NULL DEFAULT '{}',
+    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
+);
+"""
+
+_CONVERSATIONS_MIGRATION_DDL = """
+ALTER TABLE conversations
+    ADD COLUMN IF NOT EXISTS resolution_type TEXT,
+    ADD COLUMN IF NOT EXISTS agents_used TEXT[],
+    ADD COLUMN IF NOT EXISTS turn_count INTEGER NOT NULL DEFAULT 0,
+    ADD COLUMN IF NOT EXISTS ended_at TIMESTAMPTZ;
+"""
+
 
 async def create_pool(settings: Settings) -> AsyncConnectionPool:
     """Create an async connection pool with the required psycopg settings."""
@@ -55,7 +80,9 @@ async def create_checkpointer(pool: AsyncConnectionPool) -> AsyncPostgresSaver:
 
 
 async def setup_app_tables(pool: AsyncConnectionPool) -> None:
-    """Create application-specific tables (conversations, active_interrupts)."""
+    """Create application-specific tables and apply migrations."""
     async with pool.connection() as conn:
         await conn.execute(_CONVERSATIONS_DDL)
         await conn.execute(_INTERRUPTS_DDL)
+        await conn.execute(_ANALYTICS_EVENTS_DDL)
+        await conn.execute(_CONVERSATIONS_MIGRATION_DDL)
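Every statement here is guarded (CREATE TABLE IF NOT EXISTS, ADD COLUMN IF NOT EXISTS), so startup can run the setup unconditionally. A smoke-test sketch -- the DSN is hypothetical:

    import asyncio

    from psycopg_pool import AsyncConnectionPool

    from app.db import setup_app_tables


    async def main() -> None:
        async with AsyncConnectionPool("postgresql://localhost/smart_support") as pool:
            await setup_app_tables(pool)
            await setup_app_tables(pool)  # idempotent: re-running is a no-op


    asyncio.run(main())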
backend/app/main.py (modified)
@@ -10,6 +10,8 @@ from typing import TYPE_CHECKING
 from fastapi import FastAPI, WebSocket, WebSocketDisconnect
 from fastapi.staticfiles import StaticFiles
 
+from app.analytics.api import router as analytics_router
+from app.analytics.event_recorder import NoOpAnalyticsRecorder
 from app.callbacks import TokenUsageCallbackHandler
 from app.config import Settings
 from app.db import create_checkpointer, create_pool, setup_app_tables
@@ -18,9 +20,10 @@ from app.graph import build_graph
 from app.intent import LLMIntentClassifier
 from app.interrupt_manager import InterruptManager
 from app.llm import create_llm
-from app.registry import AgentRegistry
-from app.session_manager import SessionManager
 from app.openapi.review_api import router as openapi_router
+from app.registry import AgentRegistry
+from app.replay.api import router as replay_router
+from app.session_manager import SessionManager
 from app.ws_handler import dispatch_message
 
 if TYPE_CHECKING:
@@ -73,6 +76,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
     app.state.escalator = escalator
     app.state.settings = settings
     app.state.pool = pool
+    app.state.analytics_recorder = NoOpAnalyticsRecorder()
 
     logger.info(
         "Smart Support started: %d agents loaded, LLM=%s/%s, template=%s",
@@ -87,9 +91,11 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
     await pool.close()
 
 
-app = FastAPI(title="Smart Support", version="0.3.0", lifespan=lifespan)
+app = FastAPI(title="Smart Support", version="0.4.0", lifespan=lifespan)
 
 app.include_router(openapi_router)
+app.include_router(replay_router)
+app.include_router(analytics_router)
 
 
 @app.websocket("/ws")
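The lifespan wires the no-op recorder by default. A hypothetical variant (not in this commit) that switches to the Postgres-backed recorder once the pool is up:

    from app.analytics.event_recorder import PostgresAnalyticsRecorder

    # Inside lifespan, after create_pool(settings) has returned `pool`:
    app.state.analytics_recorder = PostgresAnalyticsRecorder(pool)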
backend/app/replay/__init__.py (new file, +3)
@@ -0,0 +1,3 @@
"""Replay module -- conversation replay API and transformer."""

from __future__ import annotations
backend/app/replay/api.py (new file, +103)
@@ -0,0 +1,103 @@
"""Replay API router -- conversation listing and step-by-step replay."""

from __future__ import annotations

from typing import TYPE_CHECKING, Annotated, Any

from fastapi import APIRouter, HTTPException, Query, Request

if TYPE_CHECKING:
    from psycopg_pool import AsyncConnectionPool

router = APIRouter(prefix="/api", tags=["replay"])

_LIST_CONVERSATIONS_SQL = """
SELECT thread_id, created_at, last_activity, status, total_tokens, total_cost_usd
FROM conversations
ORDER BY last_activity DESC
LIMIT %(limit)s OFFSET %(offset)s
"""

_GET_CHECKPOINTS_SQL = """
SELECT thread_id, checkpoint_id, checkpoint, metadata
FROM checkpoints
WHERE thread_id = %(thread_id)s
ORDER BY checkpoint_id ASC
"""


async def get_pool(request: Request) -> AsyncConnectionPool:
    """Dependency: extract the shared pool from app state."""
    return request.app.state.pool


def _envelope(data: Any, *, success: bool = True, error: str | None = None) -> dict:
    return {"success": success, "data": data, "error": error}


@router.get("/conversations")
async def list_conversations(
    request: Request,
    page: Annotated[int, Query(ge=1)] = 1,
    per_page: Annotated[int, Query(ge=1, le=100)] = 20,
) -> dict:
    """List conversations with pagination."""
    pool = await get_pool(request)
    offset = (page - 1) * per_page
    async with pool.connection() as conn:
        cursor = await conn.execute(
            _LIST_CONVERSATIONS_SQL,
            {"limit": per_page, "offset": offset},
        )
        rows = await cursor.fetchall()

    return _envelope([dict(row) for row in rows])


@router.get("/replay/{thread_id}")
async def get_replay(
    thread_id: str,
    request: Request,
    page: Annotated[int, Query(ge=1)] = 1,
    per_page: Annotated[int, Query(ge=1, le=100)] = 20,
) -> dict:
    """Return paginated replay steps for a conversation thread."""
    from app.replay.transformer import transform_checkpoints

    pool = await get_pool(request)
    async with pool.connection() as conn:
        cursor = await conn.execute(_GET_CHECKPOINTS_SQL, {"thread_id": thread_id})
        rows = await cursor.fetchall()

    if not rows:
        raise HTTPException(status_code=404, detail=f"Thread '{thread_id}' not found")

    all_steps = transform_checkpoints([dict(row) for row in rows])
    total_steps = len(all_steps)
    start = (page - 1) * per_page
    end = start + per_page
    page_steps = all_steps[start:end]

    data = {
        "thread_id": thread_id,
        "total_steps": total_steps,
        "page": page,
        "per_page": per_page,
        "steps": [
            {
                "step": s.step,
                "type": s.type.value,
                "timestamp": s.timestamp,
                "content": s.content,
                "agent": s.agent,
                "tool": s.tool,
                "params": s.params,
                "result": s.result,
                "reasoning": s.reasoning,
                "tokens": s.tokens,
                "duration_ms": s.duration_ms,
            }
            for s in page_steps
        ],
    }
    return _envelope(data)
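The page/per_page/total_steps fields are enough for a client to walk an entire thread. A consumer sketch, assuming the app runs locally (base URL and thread id are hypothetical):

    import asyncio

    import httpx


    async def fetch_all_steps(thread_id: str) -> list[dict]:
        """Collect every replay step across pages."""
        steps: list[dict] = []
        async with httpx.AsyncClient(base_url="http://localhost:8000") as client:
            page = 1
            while True:
                resp = await client.get(
                    f"/api/replay/{thread_id}",
                    params={"page": page, "per_page": 50},
                )
                resp.raise_for_status()
                data = resp.json()["data"]
                steps.extend(data["steps"])
                if page * data["per_page"] >= data["total_steps"]:
                    return steps
                page += 1


    asyncio.run(fetch_all_steps("thread-123"))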
backend/app/replay/models.py (new file, +52)
@@ -0,0 +1,52 @@
"""Value objects for conversation replay."""

from __future__ import annotations

from dataclasses import dataclass
from enum import Enum


class StepType(str, Enum):
    """Types of steps in a conversation replay."""

    user_message = "user_message"
    supervisor_routing = "supervisor_routing"
    tool_call = "tool_call"
    tool_result = "tool_result"
    agent_response = "agent_response"
    interrupt = "interrupt"


@dataclass(frozen=True)
class ReplayStep:
    """A single step in a conversation replay."""

    step: int
    type: StepType
    timestamp: str
    content: str = ""
    agent: str | None = None
    tool: str | None = None
    params: dict | None = None
    result: dict | None = None
    reasoning: str | None = None
    tokens: int | None = None
    duration_ms: int | None = None

    def __post_init__(self) -> None:
        # Store defensive copies so later mutation of the caller's dicts
        # cannot reach into this (otherwise frozen) step.
        if self.params is not None:
            object.__setattr__(self, "params", dict(self.params))
        if self.result is not None:
            object.__setattr__(self, "result", dict(self.result))


@dataclass(frozen=True)
class ReplayPage:
    """A paginated page of replay steps for a conversation thread."""

    thread_id: str
    total_steps: int
    page: int
    per_page: int
    steps: tuple[ReplayStep, ...]
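The __post_init__ copy matters because a frozen dataclass only freezes its own attributes, not the dicts it holds. A small demonstration (values hypothetical):

    params = {"invoice_id": "inv_42"}
    step = ReplayStep(
        step=1,
        type=StepType.tool_call,
        timestamp="2024-01-01T00:00:00Z",
        tool="lookup_invoice",
        params=params,
    )
    params["invoice_id"] = "inv_999"                 # caller mutates its own dict...
    assert step.params == {"invoice_id": "inv_42"}   # ...the stored copy is unaffected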
backend/app/replay/transformer.py (new file, +116)
@@ -0,0 +1,116 @@
"""Transforms PostgresSaver checkpoint rows into a ReplayStep list."""

from __future__ import annotations

import json
import logging

from app.replay.models import ReplayStep, StepType

logger = logging.getLogger(__name__)

_EMPTY_TIMESTAMP = "1970-01-01T00:00:00Z"


def _extract_messages(row: dict) -> list[dict]:
    """Safely extract the messages list from a checkpoint row."""
    checkpoint = row.get("checkpoint")
    if not checkpoint or not isinstance(checkpoint, dict):
        return []
    channel_values = checkpoint.get("channel_values")
    if not channel_values or not isinstance(channel_values, dict):
        return []
    messages = channel_values.get("messages")
    if not messages or not isinstance(messages, list):
        return []
    return messages


def _step_from_message(msg: dict, step_number: int) -> ReplayStep | None:
    """Convert a single message dict to a ReplayStep. Returns None for unknown types."""
    msg_type = msg.get("type", "")
    timestamp = msg.get("created_at") or _EMPTY_TIMESTAMP
    content = msg.get("content") or ""
    if isinstance(content, list):
        # LangChain may encode content as a list of parts
        content = " ".join(
            part.get("text", "") if isinstance(part, dict) else str(part)
            for part in content
        )

    if msg_type == "human":
        return ReplayStep(
            step=step_number,
            type=StepType.user_message,
            timestamp=timestamp,
            content=content,
        )

    if msg_type == "ai":
        tool_calls = msg.get("tool_calls") or []
        if tool_calls:
            first = tool_calls[0]
            return ReplayStep(
                step=step_number,
                type=StepType.tool_call,
                timestamp=timestamp,
                content=content,
                tool=first.get("name"),
                params=dict(first.get("args") or {}),
            )
        return ReplayStep(
            step=step_number,
            type=StepType.agent_response,
            timestamp=timestamp,
            content=content,
            agent=msg.get("name"),
        )

    if msg_type == "tool":
        raw = content
        try:
            parsed = json.loads(raw)
        except (ValueError, TypeError):
            parsed = None
        # Tool output may be any JSON value; ReplayStep.result must be a dict,
        # so anything else is wrapped under a "raw" key.
        result = parsed if isinstance(parsed, dict) else {"raw": raw}
        return ReplayStep(
            step=step_number,
            type=StepType.tool_result,
            timestamp=timestamp,
            tool=msg.get("name"),
            result=result,
        )

    logger.debug("Skipping unknown message type: %s", msg_type)
    return None


def transform_checkpoints(rows: list[dict]) -> list[ReplayStep]:
    """Transform a list of checkpoint rows into an ordered list of ReplaySteps.

    Steps are numbered sequentially starting from 1 across all rows.
    Unknown or malformed messages are skipped (logged, never raised).
    """
    steps: list[ReplayStep] = []
    step_number = 1

    for row in rows:
        try:
            messages = _extract_messages(row)
        except Exception:  # noqa: BLE001
            logger.exception("Error extracting messages from checkpoint row")
            continue

        for msg in messages:
            try:
                step = _step_from_message(msg, step_number)
            except Exception:  # noqa: BLE001
                logger.exception("Error converting message to ReplayStep")
                step = None

            if step is not None:
                steps.append(step)
                step_number += 1

    return steps
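A worked example of the transformation -- the row shape mirrors what _extract_messages expects, and every field value is invented:

    row = {
        "checkpoint": {
            "channel_values": {
                "messages": [
                    {"type": "human", "content": "Where is my refund?",
                     "created_at": "2024-01-01T00:00:00Z"},
                    {"type": "ai", "content": "Let me check.", "name": "billing"},
                ]
            }
        }
    }

    steps = transform_checkpoints([row])
    assert [s.type for s in steps] == [StepType.user_message, StepType.agent_response]
    assert steps[1].timestamp == "1970-01-01T00:00:00Z"  # missing created_at falls back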