"""Conversation tracker -- Protocol and implementations for tracking conversation state."""

from __future__ import annotations

from typing import TYPE_CHECKING, Protocol, runtime_checkable

if TYPE_CHECKING:
    from psycopg_pool import AsyncConnectionPool

# Create the conversation row the first time a thread is seen. Re-running it
# for an existing thread_id is a no-op thanks to ON CONFLICT DO NOTHING.
_ENSURE_SQL = """
INSERT INTO conversations (thread_id, created_at, last_activity)
VALUES (%(thread_id)s, NOW(), NOW())
ON CONFLICT (thread_id) DO NOTHING
"""

# Bump turn_count, append agent_name to agents_used when it is non-NULL and
# not already present, accumulate token/cost totals and touch last_activity.
#
# Every occurrence of %(agent_name)s carries an explicit ::text cast. The
# driver reuses one server-side placeholder ($n) for a repeated name. Without
# a type, the occurrence in "IS NOT NULL" stays "unknown" while the ARRAY[...]
# occurrences resolve to text. PostgreSQL then rejects the statement with
# "could not determine data type of parameter $n".
#
# NOTE(review): if agents_used can be NULL, the @> test yields NULL and the
# CASE keeps the NULL value -- consider COALESCE if the column is nullable.
# The UPDATE affects no rows when the thread has no row yet.
_RECORD_TURN_SQL = """
UPDATE conversations
SET turn_count = turn_count + 1,
    agents_used = CASE
        WHEN %(agent_name)s::text IS NOT NULL
             AND NOT (agents_used @> ARRAY[%(agent_name)s::text])
        THEN agents_used || ARRAY[%(agent_name)s::text]
        ELSE agents_used
    END,
    total_tokens = total_tokens + %(tokens)s,
    total_cost_usd = total_cost_usd + %(cost)s,
    last_activity = NOW()
WHERE thread_id = %(thread_id)s
"""

# Stamp the resolution type and end time. Affects no rows when the thread has
# no row yet.
_RESOLVE_SQL = """
UPDATE conversations
SET resolution_type = %(resolution_type)s,
    ended_at = NOW()
WHERE thread_id = %(thread_id)s
"""


@runtime_checkable
class ConversationTrackerProtocol(Protocol):
    """Structural interface for tracking conversation lifecycle and metrics.

    Being ``runtime_checkable``, ``isinstance(obj, ConversationTrackerProtocol)``
    only checks that the three methods exist, not their signatures.
    """

    async def ensure_conversation(self, pool: AsyncConnectionPool, thread_id: str) -> None:
        """Create the conversation row for ``thread_id`` if it does not already exist.

        Args:
            pool: Connection pool used to reach the database.
            thread_id: Identifier of the conversation thread.
        """
        ...

    async def record_turn(
        self,
        pool: AsyncConnectionPool,
        thread_id: str,
        agent_name: str | None,
        tokens: int,
        cost: float,
    ) -> None:
        """Record one turn and fold its metrics into the conversation totals.

        Args:
            pool: Connection pool used to reach the database.
            thread_id: Identifier of the conversation thread.
            agent_name: Agent that handled the turn, or ``None`` if there was none.
            tokens: Tokens consumed by this turn.
            cost: Cost of this turn (the Postgres implementation adds it to
                ``total_cost_usd``).
        """
        ...

    async def resolve(
        self,
        pool: AsyncConnectionPool,
        thread_id: str,
        resolution_type: str,
    ) -> None:
        """Mark the conversation as resolved with the given resolution type.

        Args:
            pool: Connection pool used to reach the database.
            thread_id: Identifier of the conversation thread.
            resolution_type: Label describing how the conversation ended.
        """
        ...
class NoOpConversationTracker:
    """Tracker that silently discards every call.

    Intended for tests or for running without a database. Each method accepts
    the same arguments as the real tracker and returns ``None`` right away.
    """

    async def ensure_conversation(self, pool: AsyncConnectionPool, thread_id: str) -> None:
        """Ignore the request to create a conversation row."""
        return None

    async def record_turn(
        self,
        pool: AsyncConnectionPool,
        thread_id: str,
        agent_name: str | None,
        tokens: int,
        cost: float,
    ) -> None:
        """Ignore the turn and its metrics."""
        return None

    async def resolve(
        self,
        pool: AsyncConnectionPool,
        thread_id: str,
        resolution_type: str,
    ) -> None:
        """Ignore the resolution."""
        return None


class PostgresConversationTracker:
    """Conversation tracker backed by the Postgres ``conversations`` table.

    Each public method borrows one connection from ``pool`` and runs exactly
    one module-level SQL statement with named parameters.
    """

    @staticmethod
    async def _execute(pool: AsyncConnectionPool, sql: str, params: dict[str, object]) -> None:
        """Borrow a pooled connection and execute a single parameterized statement."""
        async with pool.connection() as conn:
            await conn.execute(sql, params)

    async def ensure_conversation(self, pool: AsyncConnectionPool, thread_id: str) -> None:
        """Insert the row for ``thread_id``; an existing row is left untouched."""
        await self._execute(pool, _ENSURE_SQL, {"thread_id": thread_id})

    async def record_turn(
        self,
        pool: AsyncConnectionPool,
        thread_id: str,
        agent_name: str | None,
        tokens: int,
        cost: float,
    ) -> None:
        """Count one more turn, note a newly seen agent, and add tokens/cost to the totals."""
        await self._execute(
            pool,
            _RECORD_TURN_SQL,
            {
                "thread_id": thread_id,
                "agent_name": agent_name,
                "tokens": tokens,
                "cost": cost,
            },
        )

    async def resolve(
        self,
        pool: AsyncConnectionPool,
        thread_id: str,
        resolution_type: str,
    ) -> None:
        """Store ``resolution_type`` and the end timestamp for the conversation."""
        await self._execute(
            pool,
            _RESOLVE_SQL,
            {"thread_id": thread_id, "resolution_type": resolution_type},
        )