"""Transforms PostgresSaver checkpoint rows into ReplayStep list."""

from __future__ import annotations

import json

import structlog

from app.replay.models import ReplayStep, StepType

logger = structlog.get_logger()

# Fallback timestamp for messages that carry no ``created_at`` value.
_EMPTY_TIMESTAMP = "1970-01-01T00:00:00Z"


def _extract_messages(row: dict) -> list[dict]:
    """Safely extract the messages list from a checkpoint row.

    Walks ``row["checkpoint"]["channel_values"]["messages"]`` and returns
    ``[]`` as soon as any level is missing, empty, or of an unexpected type.
    """
    if not isinstance(row, dict):
        return []
    checkpoint = row.get("checkpoint")
    if not checkpoint or not isinstance(checkpoint, dict):
        return []
    channel_values = checkpoint.get("channel_values")
    if not channel_values or not isinstance(channel_values, dict):
        return []
    messages = channel_values.get("messages")
    if not messages or not isinstance(messages, list):
        return []
    return messages


def _normalize_content(raw: object) -> str:
    """Coerce a message ``content`` value into a plain string.

    - Falsy values (``None``, ``""``, ``[]``) become ``""``.
    - Strings are returned unchanged.
    - Lists (LangChain may encode content as a list of parts) are joined with
      single spaces: dict parts contribute their ``text`` value, other parts
      are stringified, and empty fragments are dropped so parts without text
      do not leave stray separators.
    - Anything else is stringified.
    """
    if not raw:
        return ""
    if isinstance(raw, str):
        return raw
    if isinstance(raw, list):
        fragments = (
            str(part.get("text") or "") if isinstance(part, dict) else str(part)
            for part in raw
        )
        return " ".join(fragment for fragment in fragments if fragment)
    return str(raw)


def _parse_tool_result(raw: str) -> dict | None:
    """Decode a tool message's content for ``ReplayStep.result``.

    JSON objects are returned as dicts and JSON ``null`` as ``None``. Anything
    else is wrapped as ``{"raw": raw}``. That covers invalid JSON, and also
    valid JSON that is not an object (array, number, string), so the result
    always matches the ``dict | None`` shape.
    """
    try:
        parsed = json.loads(raw)
    except (ValueError, TypeError):
        return {"raw": raw}
    if parsed is None or isinstance(parsed, dict):
        return parsed
    return {"raw": raw}


def _step_from_message(msg: dict, step_number: int) -> ReplayStep | None:
    """Convert a single message dict to a ReplayStep. Returns None for unknown types.

    Mapping by ``msg["type"]``:
      - ``"human"`` -> ``StepType.user_message``
      - ``"ai"`` with ``tool_calls`` -> ``StepType.tool_call`` (first call only)
      - ``"ai"`` without tool calls -> ``StepType.agent_response``
      - ``"tool"`` -> ``StepType.tool_result``
    Non-dict messages and any other type are logged at debug level and skipped.
    """
    if not isinstance(msg, dict):
        logger.debug("Skipping non-dict message: %s", type(msg).__name__)
        return None

    msg_type = msg.get("type", "")
    timestamp = msg.get("created_at") or _EMPTY_TIMESTAMP
    content = _normalize_content(msg.get("content"))

    if msg_type == "human":
        return ReplayStep(
            step=step_number,
            type=StepType.user_message,
            timestamp=timestamp,
            content=content,
        )

    if msg_type == "ai":
        tool_calls = msg.get("tool_calls") or []
        if tool_calls:
            # Only the first tool call is recorded. Any additional (parallel)
            # calls on the same AI message do not produce their own steps.
            first = tool_calls[0]
            return ReplayStep(
                step=step_number,
                type=StepType.tool_call,
                timestamp=timestamp,
                content=content,
                tool=first.get("name"),
                params=dict(first.get("args") or {}),
            )
        return ReplayStep(
            step=step_number,
            type=StepType.agent_response,
            timestamp=timestamp,
            content=content,
            agent=msg.get("name"),
        )

    if msg_type == "tool":
        return ReplayStep(
            step=step_number,
            type=StepType.tool_result,
            timestamp=timestamp,
            tool=msg.get("name"),
            result=_parse_tool_result(content),
        )

    logger.debug("Skipping unknown message type: %s", msg_type)
    return None


def transform_checkpoints(rows: list[dict]) -> list[ReplayStep]:
    """Transform a list of checkpoint rows into an ordered list of ReplaySteps.

    Steps are numbered sequentially starting from 1 across all rows.
    Unknown or malformed messages are skipped. Conversion errors are logged
    with a traceback; unknown types are logged at debug level.

    Messages with a string ``id`` are emitted only once, at their first
    occurrence; later occurrences in the same or subsequent rows are skipped.
    NOTE(review): LangGraph checkpoints typically store a cumulative snapshot
    of the ``messages`` channel, so without this a message would be replayed
    once per checkpoint row. Messages without an ``id`` are never
    deduplicated.
    """
    steps: list[ReplayStep] = []
    seen_ids: set[str] = set()
    step_number = 1
    for row in rows or ():
        try:
            messages = _extract_messages(row)
        except Exception:  # noqa: BLE001
            logger.exception("Error extracting messages from checkpoint row")
            continue
        for msg in messages:
            msg_id = msg.get("id") if isinstance(msg, dict) else None
            if isinstance(msg_id, str):
                if msg_id in seen_ids:
                    continue
                seen_ids.add(msg_id)
            try:
                step = _step_from_message(msg, step_number)
            except Exception:  # noqa: BLE001
                logger.exception("Error converting message to ReplayStep")
                step = None
            if step is not None:
                steps.append(step)
                step_number += 1
    return steps