136 lines
3.7 KiB
Python
136 lines
3.7 KiB
Python
"""Conversation tracker -- Protocol and implementations for tracking conversation state."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import TYPE_CHECKING, Protocol, runtime_checkable
|
|
|
|
if TYPE_CHECKING:
|
|
from psycopg_pool import AsyncConnectionPool
|
|
|
|
_ENSURE_SQL = """
|
|
INSERT INTO conversations
|
|
(thread_id, created_at, last_activity)
|
|
VALUES
|
|
(%(thread_id)s, NOW(), NOW())
|
|
ON CONFLICT (thread_id) DO NOTHING
|
|
"""
|
|
|
|
_RECORD_TURN_SQL = """
|
|
UPDATE conversations
|
|
SET
|
|
turn_count = turn_count + 1,
|
|
agents_used = CASE
|
|
WHEN %(agent_name)s IS NOT NULL AND NOT (agents_used @> ARRAY[%(agent_name)s]::text[])
|
|
THEN agents_used || ARRAY[%(agent_name)s]::text[]
|
|
ELSE agents_used
|
|
END,
|
|
total_tokens = total_tokens + %(tokens)s,
|
|
total_cost_usd = total_cost_usd + %(cost)s,
|
|
last_activity = NOW()
|
|
WHERE thread_id = %(thread_id)s
|
|
"""
|
|
|
|
_RESOLVE_SQL = """
|
|
UPDATE conversations
|
|
SET
|
|
resolution_type = %(resolution_type)s,
|
|
ended_at = NOW()
|
|
WHERE thread_id = %(thread_id)s
|
|
"""
|
|
|
|
|
|
@runtime_checkable
class ConversationTrackerProtocol(Protocol):
    """Structural interface for objects that track a conversation's lifecycle.

    Implementations are handed the connection pool on every call instead of
    owning one. Because the protocol is ``@runtime_checkable``, an
    ``isinstance`` check only confirms that the three method names exist; it
    does not validate their signatures.
    """

    async def ensure_conversation(self, pool: AsyncConnectionPool, thread_id: str) -> None:
        """Guarantee a conversation record exists for ``thread_id``, creating it if absent.

        Args:
            pool: Connection pool to use.
            thread_id: Identifier of the conversation.
        """

    async def record_turn(
        self,
        pool: AsyncConnectionPool,
        thread_id: str,
        agent_name: str | None,
        tokens: int,
        cost: float,
    ) -> None:
        """Account for one more turn and fold its metrics into the aggregates.

        Args:
            pool: Connection pool to use.
            thread_id: Identifier of the conversation.
            agent_name: Agent that handled the turn, or ``None`` if none applies.
            tokens: Tokens consumed by the turn.
            cost: Cost attributed to the turn.
        """

    async def resolve(
        self,
        pool: AsyncConnectionPool,
        thread_id: str,
        resolution_type: str,
    ) -> None:
        """Flag the conversation as resolved, recording how it was resolved.

        Args:
            pool: Connection pool to use.
            thread_id: Identifier of the conversation.
            resolution_type: Label describing the resolution.
        """
|
|
|
|
|
|
class NoOpConversationTracker:
    """Tracker that accepts every call and records nothing.

    Each method ignores its arguments, never touches ``pool``, and returns
    ``None``. Intended for tests or for running when the database is
    unavailable.
    """

    async def ensure_conversation(self, pool: AsyncConnectionPool, thread_id: str) -> None:
        """Ignore the request; no conversation record is created."""
        return None

    async def record_turn(
        self,
        pool: AsyncConnectionPool,
        thread_id: str,
        agent_name: str | None,
        tokens: int,
        cost: float,
    ) -> None:
        """Ignore the turn; no metrics are stored."""
        return None

    async def resolve(
        self,
        pool: AsyncConnectionPool,
        thread_id: str,
        resolution_type: str,
    ) -> None:
        """Ignore the resolution; nothing is persisted."""
        return None
|
|
|
|
|
|
class PostgresConversationTracker:
    """Conversation tracker that persists state in the Postgres ``conversations`` table.

    Every public method borrows one connection from ``pool``, executes a single
    parameterized statement on it, and returns the connection when the
    ``async with`` block exits. Commit/rollback is whatever
    ``pool.connection()``'s context manager does on exit (per the psycopg_pool
    docs: commit on success, rollback on error).
    """

    @staticmethod
    async def _execute(pool: AsyncConnectionPool, sql: str, params: dict[str, object]) -> None:
        """Run one parameterized statement on a connection borrowed from ``pool``."""
        async with pool.connection() as conn:
            await conn.execute(sql, params)

    async def ensure_conversation(self, pool: AsyncConnectionPool, thread_id: str) -> None:
        """Create the row for ``thread_id`` unless it already exists (ON CONFLICT DO NOTHING)."""
        await self._execute(pool, _ENSURE_SQL, {"thread_id": thread_id})

    async def record_turn(
        self,
        pool: AsyncConnectionPool,
        thread_id: str,
        agent_name: str | None,
        tokens: int,
        cost: float,
    ) -> None:
        """Bump the turn count, add a not-yet-seen agent, and accumulate tokens and cost.

        Args:
            pool: Connection pool to borrow a connection from.
            thread_id: Conversation whose row is updated.
            agent_name: Agent to append to ``agents_used``; ``None`` leaves it unchanged.
            tokens: Amount added to ``total_tokens``.
            cost: Amount added to ``total_cost_usd``.
        """
        await self._execute(
            pool,
            _RECORD_TURN_SQL,
            {
                "thread_id": thread_id,
                "agent_name": agent_name,
                "tokens": tokens,
                "cost": cost,
            },
        )

    async def resolve(
        self,
        pool: AsyncConnectionPool,
        thread_id: str,
        resolution_type: str,
    ) -> None:
        """Write ``resolution_type`` and set ``ended_at`` to NOW() on the conversation row."""
        await self._execute(
            pool,
            _RESOLVE_SQL,
            {
                "thread_id": thread_id,
                "resolution_type": resolution_type,
            },
        )
|