Compare commits


8 Commits

Author SHA1 Message Date
Yaojia Wang
f0699436c5 refactor: engineering improvements -- API versioning, structured logging, Alembic, error standardization, test coverage
- API versioning: all REST endpoints prefixed with /api/v1/
- Structured logging: replaced stdlib logging with structlog (console/JSON modes)
- Alembic migrations: versioned DB schema with initial migration
- Error standardization: global exception handlers for consistent envelope format
- Interrupt cleanup: asyncio background task for expired interrupt removal
- Integration tests: +30 tests (analytics, replay, openapi, error, session APIs)
- Frontend tests: +57 tests (all components, pages, useWebSocket hook)
- Backend: 557 tests, 89.75% coverage | Frontend: 80 tests, 16 test files
2026-04-06 23:19:29 +02:00
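For reference, the standardized envelope this commit introduces is the `envelope()` helper in `backend/app/api_utils.py` (full file in the diff below); a minimal sketch of what callers get back, with illustrative values:

```python
# Minimal sketch using the envelope() helper added in this commit
# (see backend/app/api_utils.py below); the data values are illustrative.
from app.api_utils import envelope

envelope({"total_conversations": 42})
# -> {"success": True, "data": {"total_conversations": 42}, "error": None}

envelope(None, success=False, error="Invalid API key")
# -> {"success": False, "data": None, "error": "Invalid API key"}
```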
Yaojia Wang
af53111928 refactor: fix architectural issues across frontend and backend
Address all architecture review findings:

P0 fixes:
- Add API key authentication for admin endpoints (analytics, replay, openapi)
  and WebSocket connections via ADMIN_API_KEY env var
- Add PostgreSQL-backed PgSessionManager and PgInterruptManager for
  multi-worker production deployments (in-memory defaults preserved)

P1 fixes:
- Implement actual tool generation in OpenAPI approve_job endpoint
  using generate_tool_code() and generate_agent_yaml()
- Add missing clarification, interrupt_expired, and tool_result message
  handlers in frontend ChatPage

P2 fixes:
- Replace monkey-patching on CompiledStateGraph with typed GraphContext
- Replace 9-param dispatch_message with WebSocketContext dataclass
- Extract duplicate _envelope() into shared app/api_utils.py
- Replace mutable module-level counter with crypto.randomUUID()
- Remove hardcoded mock data from ReviewPage, use api.ts wrappers
- Remove `as any` type escape from ReplayPage

All 516 tests passing, 0 TypeScript errors.
2026-04-06 15:59:14 +02:00
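As a rough sketch of the new auth surface (the key value and host are illustrative; the header and query-parameter names come from `backend/app/auth.py` in the diff below):

```python
# Hedged sketch: admin REST calls send the key in the X-API-Key header;
# WebSocket clients pass the same key as ?token=<key>. Assumes a local
# server on :8001 and ADMIN_API_KEY="s3cret" -- both illustrative.
import httpx

resp = httpx.get(
    "http://localhost:8001/api/analytics?range=7d",
    headers={"X-API-Key": "s3cret"},
)
# Missing header -> 401, wrong key -> 403; with ADMIN_API_KEY unset, auth is skipped.

ws_url = "ws://localhost:8001/ws?token=s3cret"  # WebSocket equivalent
```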
Yaojia Wang
b8654aa31f feat: upgrade LangGraph to 1.x and migrate deprecated APIs
- Bump langgraph from 0.4 to 1.0+, langgraph-supervisor from 0.0.12 to 0.0.30+
- Bump langchain-core, langchain-anthropic, langchain-openai to 1.x
- Add langchain>=1.0 dependency for new create_agent location
- Migrate create_react_agent -> create_agent (prompt -> system_prompt)
- Fix create_supervisor positional arg to named agents= parameter
- Replace AsyncMock checkpointer with InMemorySaver in tests (v1 type validation)
- Update version references in README, ARCHITECTURE, eng-review-plan
2026-04-06 14:51:51 +02:00
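The core API change, as a hedged before/after sketch (parameter names match the `graph.py` diff below; `llm`, `tools`, and `system_prompt` stand in for the values built there):

```python
# Before (LangGraph 0.x):
#   from langgraph.prebuilt import create_react_agent
#   agent = create_react_agent(model=llm, tools=tools, name="order_lookup",
#                              prompt=system_prompt)

# After (LangGraph 1.x), per this commit:
from langchain.agents import create_agent

agent = create_agent(
    model=llm,                    # any BaseChatModel, as before
    tools=tools,
    name="order_lookup",
    system_prompt=system_prompt,  # renamed from `prompt`
)
# create_supervisor(...) likewise now takes agents= as a named parameter.
```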
Yaojia Wang
be5c84bcff docs: reconcile README and docs with actual codebase
README:
- Remove duplicated agent config, safety, security sections (covered by docs)
- Add ux_design_system.md and safety.py to project structure and doc links
- Convert doc links to descriptive table

agent-config-guide.md:
- Replace fictional agents/tools with real ones from agents.yaml
- Remove nonexistent 'admin' permission level (only read/write)
- Fix template names (e-commerce, saas, fintech)
- List all available built-in tools

openapi-import-guide.md:
- Fix /result -> /classifications endpoint
- Fix POST /approve to show no request body
- Remove nonexistent 'admin' access type
- Update response examples to match actual API

demo-script.md:
- Fix agent names (order_agent -> order_lookup)
- Replace fictional refund scenario with real lookup+cancel flow

ARCHITECTURE.md:
- Fix langgraph-supervisor version (v1.1 -> 0.0.12+)

docker-compose.yml:
- Expose postgres on port 5433 for local dev
2026-04-06 13:55:45 +02:00
Yaojia Wang
19fc9f3289 test: close coverage gaps and add frontend test infrastructure
Backend (516 tests, 94% coverage):
- Add azure_openai endpoint/deployment validation tests (config.py -> 100%)
- Add _total_conversations and _avg_turns direct tests (queries.py -> 100%)
- Add transformer edge cases: list content, string checkpoint, invalid JSON,
  malformed message graceful skip (transformer.py -> 93%)
- Add safety combined status_code+error_message interaction tests
- Fix ambiguous 200/422 assertion to strict 422
- Add E2E pagination shape assertions (total, page, per_page, row count)
- Fix ReplayPool mock to respect LIMIT/OFFSET params

Frontend (23 tests, vitest + happy-dom + @testing-library/react):
- Add vitest infrastructure with happy-dom environment
- Add api.ts tests: success, HTTP error, success=false, URL encoding
- Add DashboardPage tests: loading, data, error, empty states
- Add ReplayListPage tests: loading, empty, data, error, status badge classes
- Add ReplayPage tests: loading, steps, empty, error states
2026-04-06 13:32:10 +02:00
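For flavor, the shape of the pagination assertions this commit adds, as an illustrative pytest sketch (the test name, fixture, and query string are assumptions, not copied from the suite):

```python
# Illustrative sketch only -- asserts the paginated shape and row count.
def test_paginated_conversations_shape(client):  # `client`: FastAPI TestClient fixture
    resp = client.get("/api/conversations?page=1&per_page=20")
    body = resp.json()
    assert set(body) >= {"conversations", "total", "page", "per_page"}
    assert len(body["conversations"]) <= body["per_page"]
```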
Yaojia Wang
036e12349d refactor: formalize safety rules, extract shared styles, reconcile docs (P2)
- Add backend/app/safety.py with explicit confirmation policy, multi-intent
  semantics, and MCP error taxonomy with retry classification
- Add 26 unit tests for safety module (confirmation rules, error taxonomy)
- Extract repeated inline styles into shared CSS classes in index.css
  (section-card, stat-label, status-badge, data-table, empty/error-state,
  pagination-bar)
- Refactor DashboardPage, ReplayListPage, ReplayPage to use shared classes
- Update README: add missing API endpoints, document safety/confirmation rules
- Use proper HTML entities for arrow/dash characters to fix encoding glitches
2026-04-05 23:10:50 +02:00
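Since `safety.py` itself is not shown in this diff view, here is a purely hypothetical sketch of what an explicit confirmation policy with retry classification can look like (all names invented for illustration):

```python
# Hypothetical illustration -- NOT the actual backend/app/safety.py API.
from dataclasses import dataclass

@dataclass(frozen=True)
class SafetyRule:
    action: str                # e.g. "cancel_order"
    requires_confirmation: bool
    retryable: bool            # MCP error taxonomy: retry vs. fail fast

def needs_confirmation(rules: dict[str, SafetyRule], action: str) -> bool:
    rule = rules.get(action)
    # Fail safe: unknown actions always require explicit confirmation.
    return rule.requires_confirmation if rule is not None else True
```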
Yaojia Wang
e0931daece feat: wire frontend pages to live APIs and standardize response contracts (P1)
- Backend: Add COUNT query and paginated response shape to conversations endpoint
  Returns { conversations: [...], total, page, per_page } instead of flat array
- Frontend: Replace mock data in DashboardPage with fetchAnalytics() API calls
- Frontend: Replace mock data in ReplayListPage with fetchConversations() API calls
- Frontend: Replace mock data in ReplayPage with fetchReplay() API calls
- Add proper loading, empty, and error states to all three pages
- Align ConversationSummary type with actual DB columns (created_at, status)
- Update unit and E2E tests for new paginated conversation response shape
- Add fetchone() to FakeCursor for COUNT query support in E2E tests
2026-04-05 23:06:00 +02:00
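The resulting contract, sketched as Python typings (the field names come from this commit message; the wrapper types themselves are illustrative):

```python
# Sketch of the paginated response shape; summary fields follow the
# "actual DB columns" note above, everything else is an assumption.
from typing import TypedDict

class ConversationSummary(TypedDict):
    thread_id: str
    created_at: str   # ISO timestamp
    status: str

class ConversationPage(TypedDict):
    conversations: list[ConversationSummary]
    total: int
    page: int
    per_page: int
```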
Yaojia Wang
e55ec42ae5 fix: restore green builds and align frontend-backend contracts (P0)
- Isolate Settings tests from .env and process env leakage
- Fix analytics metadata test to unwrap psycopg Json wrapper
- Remove unused state variables causing frontend build failures
- Fix ReviewPage to use /classifications endpoint instead of nonexistent /result
- Normalize ReviewPage status enums (failed not error) and access_type values
- Align api.ts types with backend response shapes (ReplayPage, AnalyticsData, AgentUsage)
2026-04-05 23:00:39 +02:00
96 changed files with 6166 additions and 1032 deletions

.env.example

@@ -26,6 +26,10 @@ WEBHOOK_URL=
 SESSION_TTL_MINUTES=30
 INTERRUPT_TTL_MINUTES=30
+# Optional: API key for admin endpoints (analytics, replay, openapi, websocket)
+# Leave empty to disable authentication (dev mode)
+ADMIN_API_KEY=
 # Optional: load a named agent template instead of agents.yaml
 # Available templates: ecommerce, saas, generic
 TEMPLATE_NAME=

CLAUDE.md

@@ -30,7 +30,7 @@ pytest --cov=app --cov-report=term-missing
 # - If any test fails, fix it before starting the new phase
 # 3. Create checkpoint to snapshot the starting state
-/everything-claude-code:checkpoint create "phase-name"
+/ecc:checkpoint create "phase-name"
 # 4. Create the phase branch
 git checkout main
@@ -50,25 +50,32 @@ git checkout -b phase-{N}/{short-description}
 3. Identify all tasks, acceptance criteria, and dependencies for this phase
 4. Create a phase dev log **skeleton** at `docs/phases/phase-{N}-dev-log.md` (date, branch name, plan link only -- content filled in Step 5)
-### Step 2: Develop Using Orchestrate Skill
+### Step 2: Develop Using ECC Skills
-Route to the correct orchestration mode based on work type:
+Route to the correct skill based on work type:
-| Work Type | Skill Command |
-|-----------|---------------|
-| New feature | `/everything-claude-code:orchestrate feature` |
-| Bug fix | `/everything-claude-code:orchestrate bugfix` |
-| Refactor | `/everything-claude-code:orchestrate refactor` |
+| Work Type | Skill Command | What It Does |
+|-----------|---------------|--------------|
+| New feature | `/ecc:feature-dev <desc>` | Discovery -> Exploration -> Architecture -> TDD -> Review -> Summary |
+| Bug fix | `/ecc:tdd` then `/ecc:code-review` | RED -> GREEN -> REFACTOR cycle, then review |
+| Refactor | `/ecc:plan` then `/ecc:tdd` then `/ecc:code-review` | Plan refactor scope, TDD, review |
+| Security-sensitive | Add `/ecc:security-review` after code-review | Auth, payments, user input, external APIs |
+| Final verification | `/ecc:verify` | Build + tests + lint + coverage + security scan |
-ALWAYS use the appropriate orchestrate skill. Never develop without it.
-A single phase may contain mixed work types (e.g., Phase 5 has feature + bugfix + refactor). Call the orchestrate skill **per sub-task** with the matching mode. Example:
+A single phase may contain mixed work types. Call the appropriate skill **per sub-task**:
 ```
-# Within Phase 5:
-/everything-claude-code:orchestrate feature   # for demo script
-/everything-claude-code:orchestrate bugfix    # for error handling fixes
-/everything-claude-code:orchestrate refactor  # for code cleanup
+# Within a phase:
+/ecc:feature-dev "demo script"           # for new features
+/ecc:tdd                                 # for bug fixes (write failing test, then fix)
+/ecc:plan "consolidate error handling"   # for refactors (plan first, then TDD)
+```
+For full multi-phase autonomous execution, use GSD:
+```
+/gsd:autonomous        # execute all remaining phases
+/gsd:execute-phase 6   # execute a specific phase
 ```
 ### Step 3: Module Independence (CRITICAL)
@@ -171,10 +178,10 @@ After all development and testing, run verification in this exact order:
 ```
 # 1. Run the verification skill -- must pass
-/everything-claude-code:verify
+/ecc:verify
 # 2. Verify the checkpoint -- validates all phase deliverables
-/everything-claude-code:checkpoint verify "phase-name"
+/ecc:checkpoint verify "phase-name"
 ```
 The checkpoint verify validates:
@@ -222,11 +229,11 @@ git push origin main --tags
 All four markers must be consistent. If any is missed, the next phase's Step 0 regression gate will catch the discrepancy.
 A checkpoint includes:
-- `/everything-claude-code:checkpoint create` at phase start
-- `/everything-claude-code:checkpoint verify` at phase end
+- `/ecc:checkpoint create` at phase start
+- `/ecc:checkpoint verify` at phase end
 - All tests passing (80%+ coverage)
 - Phase dev log written and linked
-- `/everything-claude-code:verify` passed
+- `/ecc:verify` passed
 - Git tag `checkpoint/phase-{N}` created
 - Phase marked COMPLETED in four locations
 - Branch merged to main
@@ -264,7 +271,7 @@ This project inherits from `~/.claude/rules/`. CLAUDE.md only contains project-s
 ### Hooks (ECC Plugin -- No Custom Hooks)
-All hooks come from the ECC plugin (`everything-claude-code`). No project-level hooks in `.claude/settings.local.json`.
+All hooks come from the ECC plugin (`ecc`). No project-level hooks in `.claude/settings.local.json`.
 | ECC Hook | Type | What It Does |
 |----------|------|-------------|
@@ -290,7 +297,7 @@ Controlled by `ECC_HOOK_PROFILE` env var in `~/.claude/settings.json` (currently
 - Architecture doc: `docs/ARCHITECTURE.md`
 - Phase dev logs: `docs/phases/phase-{N}-dev-log.md`
 - Test command: `pytest --cov=app --cov-report=term-missing`
-- **Phase start:** `/everything-claude-code:checkpoint create "phase-name"`
-- **Phase end:** `/everything-claude-code:checkpoint verify "phase-name"`
-- Verify command: `/everything-claude-code:verify`
-- Orchestrate: `/everything-claude-code:orchestrate {feature|bugfix|refactor}`
+- **Phase start:** `/ecc:checkpoint create "phase-name"`
+- **Phase end:** `/ecc:checkpoint verify "phase-name"`
+- Verify command: `/ecc:verify`
+- Orchestrate: `/ecc:orchestrate {feature|bugfix|refactor}`

README.md (123 lines changed)

@@ -45,10 +45,11 @@ User message -> Chat UI -> FastAPI WebSocket -> LangGraph Supervisor -> Speciali
 | Component | Technology |
 |-----------|-----------|
 | Backend | Python 3.11+, FastAPI |
-| Agent orchestration | LangGraph v1.1 |
-| Session state | PostgreSQL + langgraph-checkpoint-postgres |
-| LLM | Claude Sonnet 4.6 (configurable: OpenAI, Google) |
+| Agent orchestration | LangGraph 1.x, langgraph-supervisor |
+| Session state | PostgreSQL 16 + langgraph-checkpoint-postgres |
+| LLM | Claude Sonnet 4.6 (configurable: OpenAI, Azure OpenAI, Google) |
 | Frontend | React 19, TypeScript, Vite |
+| Testing | pytest (backend), vitest + happy-dom (frontend) |
 | Deployment | Docker Compose |
 ## Quick Start
@@ -59,7 +60,11 @@ cd smart-support
 # Configure your LLM API key
 cp .env.example .env
-# Edit .env: set ANTHROPIC_API_KEY (or OPENAI_API_KEY)
+# Edit .env: set LLM_PROVIDER and the corresponding API key
+#   anthropic    -> ANTHROPIC_API_KEY
+#   openai       -> OPENAI_API_KEY
+#   azure_openai -> AZURE_OPENAI_API_KEY + AZURE_OPENAI_ENDPOINT + AZURE_OPENAI_DEPLOYMENT
+#   google       -> GOOGLE_API_KEY
 # Start all services
 docker compose up -d
@@ -68,6 +73,25 @@ docker compose up -d
 open http://localhost
 ```
+### Local Development
+```bash
+# Start only PostgreSQL via Docker (exposed on port 5433)
+docker compose up postgres -d
+# Backend (in one terminal)
+cd backend
+pip install -e ".[dev]"
+uvicorn app.main:app --host 0.0.0.0 --port 8001 --reload
+# Frontend (in another terminal)
+cd frontend
+npm install
+npm run dev  # http://localhost:5173 (proxies /api and /ws to :8001)
+```
+See [Deployment Guide](docs/deployment.md) for production setup, HTTPS, and scaling.
 ## Project Structure
 ```
@@ -75,17 +99,20 @@ smart-support/
 ├── backend/
 │   ├── app/
 │   │   ├── main.py                  # FastAPI + WebSocket entry point
-│   │   ├── graph.py                 # LangGraph Supervisor
+│   │   ├── graph.py                 # LangGraph Supervisor construction
+│   │   ├── graph_context.py         # Typed wrapper for graph + classifier + registry
 │   │   ├── ws_handler.py            # WebSocket message dispatch + rate limiting
-│   │   ├── conversation_tracker.py  # Conversation lifecycle tracking
+│   │   ├── ws_context.py            # WebSocket dependency bundle
+│   │   ├── auth.py                  # API key authentication middleware
+│   │   ├── api_utils.py             # Shared API response helpers
+│   │   ├── safety.py                # Confirmation rules + MCP error taxonomy
 │   │   ├── agents/                  # Agent definitions and tools
 │   │   ├── registry.py              # YAML agent registry loader
-│   │   ├── openapi/                 # OpenAPI parser and review API
+│   │   ├── openapi/                 # OpenAPI parser, classifier, and review API
 │   │   ├── replay/                  # Conversation replay API
-│   │   └── analytics/               # Analytics queries and API
+│   │   ├── analytics/               # Analytics queries and API
+│   │   └── tools/                   # Error handling and retry utilities
 │   ├── agents.yaml                  # Agent registry configuration
-│   ├── fixtures/                    # Demo data and sample OpenAPI spec
+│   ├── templates/                   # Vertical industry templates
 │   └── tests/                       # Unit, integration, and E2E tests
 ├── frontend/
 │   ├── src/
@@ -99,67 +126,49 @@ smart-support/
 └── .env.example                     # Environment variable template
 ```
-## Agent Configuration
-```yaml
-# agents.yaml
-agents:
-  - name: order_agent
-    description: "Handles order status, tracking, and cancellations."
-    permission: write
-    tools:
-      - get_order_status
-      - cancel_order
-    personality:
-      tone: friendly
-      greeting: "I can help with your order. What is the order number?"
-      escalation_message: "I'm escalating this to a human agent."
-  - name: general_agent
-    description: "Answers general questions and FAQs."
-    permission: read
-    tools:
-      - search_faq
-```
 ## API Endpoints
-| Method | Path | Description |
-|--------|------|-------------|
-| WS | `/ws` | Main WebSocket chat endpoint |
-| GET | `/api/health` | Health check |
-| GET | `/api/conversations` | List conversations |
-| GET | `/api/replay/{thread_id}` | Replay conversation |
-| GET | `/api/analytics` | Analytics summary |
-| POST | `/api/openapi/import` | Import OpenAPI spec |
-| GET | `/api/openapi/jobs/{id}` | Check import job status |
+| Method | Path | Auth | Description |
+|--------|------|------|-------------|
+| WS | `/ws` | Token | Main WebSocket chat endpoint (`?token=<key>`) |
+| GET | `/api/health` | No | Health check |
+| GET | `/api/conversations` | API Key | List conversations (paginated) |
+| GET | `/api/replay/{thread_id}` | API Key | Replay conversation steps (paginated) |
+| GET | `/api/analytics` | API Key | Analytics summary (`?range=7d`) |
+| POST | `/api/openapi/import` | API Key | Start OpenAPI import job |
+| GET | `/api/openapi/jobs/{id}` | API Key | Check import job status |
+| GET | `/api/openapi/jobs/{id}/classifications` | API Key | Get endpoint classifications |
+| PUT | `/api/openapi/jobs/{id}/classifications/{idx}` | API Key | Update a classification |
+| POST | `/api/openapi/jobs/{id}/approve` | API Key | Approve and generate tools |
-## Security
-- **SSRF protection** -- OpenAPI import blocks private IPs and metadata service URLs
-- **Input validation** -- messages validated for size (32 KB), content length (10 KB), thread ID format
-- **Rate limiting** -- 10 messages per 10 seconds per session
-- **Audit trail** -- every tool call logged with agent, params, result, timestamp
-- **Permission isolation** -- each agent only accesses its configured tools
-- **Interrupt TTL** -- unanswered approval prompts expire after 30 minutes
+Authentication is controlled by the `ADMIN_API_KEY` environment variable.
+API Key endpoints require the `X-API-Key` header. When `ADMIN_API_KEY` is unset, auth is disabled.
 ## Running Tests
 ```bash
+# Backend (516 tests, 94% coverage)
 cd backend
 pytest --cov=app --cov-report=term-missing
+# Frontend (23 tests, vitest + happy-dom)
+cd frontend
+npm test
 ```
-Coverage is enforced at 80%+.
+Backend coverage is enforced at 80%+.
 ## Documentation
-- [Architecture](docs/ARCHITECTURE.md) -- System design and component diagram
-- [Development Plan](docs/DEVELOPMENT-PLAN.md) -- Phase breakdown and status
-- [Agent Config Guide](docs/agent-config-guide.md) -- How to configure agents
-- [OpenAPI Import Guide](docs/openapi-import-guide.md) -- Auto-discovery workflow
-- [Deployment Guide](docs/deployment.md) -- Docker and production deployment
-- [Demo Script](docs/demo-script.md) -- Step-by-step live demo walkthrough
+| Document | Description |
+|----------|-------------|
+| [Architecture](docs/ARCHITECTURE.md) | System design, component diagram, data flow, ADRs |
+| [Development Plan](docs/DEVELOPMENT-PLAN.md) | Phase breakdown, task checklists, and status |
+| [Agent Config Guide](docs/agent-config-guide.md) | agents.yaml format, fields, templates, routing logic |
+| [OpenAPI Import Guide](docs/openapi-import-guide.md) | Auto-discovery workflow, REST API, SSRF protection |
+| [Deployment Guide](docs/deployment.md) | Docker, local dev, production, HTTPS, backups, scaling |
+| [Demo Script](docs/demo-script.md) | Step-by-step live demo walkthrough (5 scenes) |
+| [UX Design System](docs/ux_design_system.md) | Color palette, typography, component patterns, CSS tokens |
 ## License

backend/alembic.ini (new file, 149 lines)

@@ -0,0 +1,149 @@
# A generic, single database configuration.
[alembic]
# path to migration scripts.
# this is typically a path given in POSIX (e.g. forward slashes)
# format, relative to the token %(here)s which refers to the location of this
# ini file
script_location = %(here)s/alembic
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
# for all available tokens
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
# Or organize into date-based subdirectories (requires recursive_version_locations = true)
# file_template = %%(year)d/%%(month).2d/%%(day).2d_%%(hour).2d%%(minute).2d_%%(second).2d_%%(rev)s_%%(slug)s
# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory. for multiple paths, the path separator
# is defined by "path_separator" below.
prepend_sys_path = .
# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires the tzdata library which can be installed by adding
# `alembic[tz]` to the pip requirements.
# string value is passed to ZoneInfo()
# leave blank for localtime
# timezone =
# max length of characters to apply to the "slug" field
# truncate_slug_length = 40
# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false
# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false
# version location specification; This defaults
# to <script_location>/versions. When using multiple version
# directories, initial revisions must be specified with --version-path.
# The path separator used here should be the separator specified by "path_separator"
# below.
# version_locations = %(here)s/bar:%(here)s/bat:%(here)s/alembic/versions
# path_separator; This indicates what character is used to split lists of file
# paths, including version_locations and prepend_sys_path within configparser
# files such as alembic.ini.
# The default rendered in new alembic.ini files is "os", which uses os.pathsep
# to provide os-dependent path splitting.
#
# Note that in order to support legacy alembic.ini files, this default does NOT
# take place if path_separator is not present in alembic.ini. If this
# option is omitted entirely, fallback logic is as follows:
#
# 1. Parsing of the version_locations option falls back to using the legacy
# "version_path_separator" key, which if absent then falls back to the legacy
# behavior of splitting on spaces and/or commas.
# 2. Parsing of the prepend_sys_path option falls back to the legacy
# behavior of splitting on spaces, commas, or colons.
#
# Valid values for path_separator are:
#
# path_separator = :
# path_separator = ;
# path_separator = space
# path_separator = newline
#
# Use os.pathsep. Default configuration used for new projects.
path_separator = os
# set to 'true' to search source files recursively
# in each "version_locations" directory
# new in Alembic version 1.10
# recursive_version_locations = false
# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8
# database URL. This is consumed by the user-maintained env.py script only.
# other means of configuring database URLs may be customized within the env.py
# file.
sqlalchemy.url =
[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples
# format using "black" - use the console_scripts runner, against the "black" entrypoint
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME
# lint with attempts to fix using "ruff" - use the module runner, against the "ruff" module
# hooks = ruff
# ruff.type = module
# ruff.module = ruff
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
# Alternatively, use the exec runner to execute a binary found on your PATH
# hooks = ruff
# ruff.type = exec
# ruff.executable = ruff
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
# Logging configuration. This is also consumed by the user-maintained
# env.py script only.
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARNING
handlers = console
qualname =
[logger_sqlalchemy]
level = WARNING
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S

backend/alembic/README (new file, 1 line)

@@ -0,0 +1 @@
Generic single-database configuration.

backend/alembic/env.py (new file, 67 lines)

@@ -0,0 +1,67 @@
"""Alembic environment configuration for smart-support."""
from __future__ import annotations
import os
from logging.config import fileConfig
from sqlalchemy import engine_from_config, pool
from alembic import context
config = context.config
if config.config_file_name is not None:
fileConfig(config.config_file_name)
# No SQLAlchemy ORM models -- we use raw DDL migrations
target_metadata = None
def _get_url() -> str:
"""Read DATABASE_URL from environment, falling back to alembic.ini."""
return os.environ.get("DATABASE_URL", "") or config.get_main_option(
"sqlalchemy.url", ""
)
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode.
Configures the context with just a URL so that an Engine
is not required.
"""
url = _get_url()
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
context.run_migrations()
def run_migrations_online() -> None:
"""Run migrations in 'online' mode with a live database connection."""
configuration = config.get_section(config.config_ini_section, {})
configuration["sqlalchemy.url"] = _get_url()
connectable = engine_from_config(
configuration,
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
with connectable.connect() as connection:
context.configure(connection=connection, target_metadata=target_metadata)
with context.begin_transaction():
context.run_migrations()
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()

backend/alembic/script.py.mako (new file, 28 lines)

@@ -0,0 +1,28 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, Sequence[str], None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
def upgrade() -> None:
"""Upgrade schema."""
${upgrades if upgrades else "pass"}
def downgrade() -> None:
"""Downgrade schema."""
${downgrades if downgrades else "pass"}

backend/alembic/versions/a1b2c3d4e5f6_*.py (initial migration, new file, 92 lines)

@@ -0,0 +1,92 @@
"""Initial schema -- all application tables.
Revision ID: a1b2c3d4e5f6
Revises:
Create Date: 2026-04-06
"""
from __future__ import annotations
from alembic import op
revision: str = "a1b2c3d4e5f6"
down_revision: str | None = None
branch_labels: tuple[str, ...] | None = None
depends_on: tuple[str, ...] | None = None
def upgrade() -> None:
op.execute(
"""
CREATE TABLE IF NOT EXISTS conversations (
thread_id TEXT PRIMARY KEY,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
last_activity TIMESTAMPTZ NOT NULL DEFAULT NOW(),
total_tokens INTEGER NOT NULL DEFAULT 0,
total_cost_usd DOUBLE PRECISION NOT NULL DEFAULT 0.0,
status TEXT NOT NULL DEFAULT 'active'
)
"""
)
op.execute(
"""
CREATE TABLE IF NOT EXISTS active_interrupts (
interrupt_id TEXT PRIMARY KEY,
thread_id TEXT NOT NULL REFERENCES conversations(thread_id),
action TEXT NOT NULL,
params JSONB NOT NULL DEFAULT '{}',
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
resolved_at TIMESTAMPTZ,
resolution TEXT
)
"""
)
op.execute(
"""
CREATE TABLE IF NOT EXISTS sessions (
thread_id TEXT PRIMARY KEY,
last_activity TIMESTAMPTZ NOT NULL DEFAULT NOW(),
has_pending_interrupt BOOLEAN NOT NULL DEFAULT FALSE,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
)
"""
)
op.execute(
"""
CREATE TABLE IF NOT EXISTS analytics_events (
id BIGSERIAL PRIMARY KEY,
thread_id TEXT NOT NULL,
event_type TEXT NOT NULL,
agent_name TEXT,
tool_name TEXT,
tokens_used INTEGER NOT NULL DEFAULT 0,
cost_usd DOUBLE PRECISION NOT NULL DEFAULT 0.0,
duration_ms INTEGER,
success BOOLEAN,
error_message TEXT,
metadata JSONB NOT NULL DEFAULT '{}',
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
)
"""
)
# Migration columns added in Phase 4
op.execute(
"""
ALTER TABLE conversations
ADD COLUMN IF NOT EXISTS resolution_type TEXT,
ADD COLUMN IF NOT EXISTS agents_used TEXT[],
ADD COLUMN IF NOT EXISTS turn_count INTEGER NOT NULL DEFAULT 0,
ADD COLUMN IF NOT EXISTS ended_at TIMESTAMPTZ
"""
)
def downgrade() -> None:
op.execute("DROP TABLE IF EXISTS analytics_events")
op.execute("DROP TABLE IF EXISTS sessions")
op.execute("DROP TABLE IF EXISTS active_interrupts")
op.execute("DROP TABLE IF EXISTS conversations")

backend/app/analytics/api.py

@@ -4,16 +4,22 @@ from __future__ import annotations
 import re
 from dataclasses import asdict
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING
-from fastapi import APIRouter, HTTPException, Query, Request
+from fastapi import APIRouter, Depends, HTTPException, Query, Request
 from app.analytics.queries import get_analytics
+from app.api_utils import envelope
+from app.auth import require_admin_api_key
 if TYPE_CHECKING:
     from psycopg_pool import AsyncConnectionPool
-router = APIRouter(prefix="/api/analytics", tags=["analytics"])
+router = APIRouter(
+    prefix="/api/v1/analytics",
+    tags=["analytics"],
+    dependencies=[Depends(require_admin_api_key)],
+)
 _RANGE_PATTERN = re.compile(r"^(\d+)d$")
 _DEFAULT_RANGE = "7d"
@@ -25,10 +31,6 @@ async def _get_pool(request: Request) -> AsyncConnectionPool:
     return request.app.state.pool
-def _envelope(data: Any, *, success: bool = True, error: str | None = None) -> dict:
-    return {"success": success, "data": data, "error": error}
 def _parse_range(range_str: str) -> int:
     """Parse 'Xd' range string to integer days. Raises 400 on invalid format."""
     match = _RANGE_PATTERN.match(range_str)
@@ -55,4 +57,4 @@ async def analytics(
     range_days = _parse_range(range)
     pool = await _get_pool(request)
     result = await get_analytics(pool, range_days=range_days)
-    return _envelope(asdict(result))
+    return envelope(asdict(result))

backend/app/api_utils.py (new file, 10 lines)

@@ -0,0 +1,10 @@
"""Shared API response helpers."""
from __future__ import annotations
from typing import Any
def envelope(data: Any, *, success: bool = True, error: str | None = None) -> dict:
"""Wrap API response data in a standard envelope format."""
return {"success": success, "data": data, "error": error}

backend/app/auth.py (new file, 72 lines)

@@ -0,0 +1,72 @@
"""API key authentication for admin endpoints and WebSocket connections."""
from __future__ import annotations
import secrets
from typing import Annotated
import structlog
from fastapi import Depends, HTTPException, Query, Request, WebSocket, status
from fastapi.security import APIKeyHeader
logger = structlog.get_logger()
_API_KEY_HEADER = APIKeyHeader(name="X-API-Key", auto_error=False)
def _get_admin_api_key(request: Request) -> str:
"""Retrieve the configured admin API key from app settings.
Returns empty string if settings are not configured (test/dev mode).
"""
settings = getattr(request.app.state, "settings", None)
if settings is None:
return ""
key = getattr(settings, "admin_api_key", "")
return key if isinstance(key, str) else ""
async def require_admin_api_key(
request: Request,
api_key: Annotated[str | None, Depends(_API_KEY_HEADER)] = None,
) -> None:
"""Dependency that enforces API key authentication on admin endpoints.
Skips validation when no admin_api_key is configured (dev mode).
"""
expected = _get_admin_api_key(request)
if not expected:
return
if api_key is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Missing X-API-Key header",
)
if not secrets.compare_digest(api_key, expected):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Invalid API key",
)
async def verify_ws_token(
ws: WebSocket,
token: str | None = Query(default=None),
) -> None:
"""Verify WebSocket connection token from query parameter.
Skips validation when no admin_api_key is configured (dev mode).
Usage: ws://host/ws?token=<api_key>
"""
settings = ws.app.state.settings
expected = settings.admin_api_key
if not expected:
return
if token is None or not secrets.compare_digest(token, expected):
await ws.close(code=4001, reason="Unauthorized")
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Invalid or missing WebSocket token",
)

backend/app/config.py

@@ -32,6 +32,10 @@ class Settings(BaseSettings):
     template_name: str = ""
+    log_format: str = "console"  # "console" for dev, "json" for production
+    admin_api_key: str = ""
     anthropic_api_key: str = ""
     openai_api_key: str = ""
     azure_openai_api_key: str = ""

backend/app/db.py

@@ -2,6 +2,7 @@
 from __future__ import annotations
+from pathlib import Path
 from typing import TYPE_CHECKING
 from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
@@ -51,6 +52,15 @@ CREATE TABLE IF NOT EXISTS analytics_events (
 );
 """
+_SESSIONS_DDL = """
+CREATE TABLE IF NOT EXISTS sessions (
+    thread_id TEXT PRIMARY KEY,
+    last_activity TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+    has_pending_interrupt BOOLEAN NOT NULL DEFAULT FALSE,
+    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
+);
+"""
 _CONVERSATIONS_MIGRATION_DDL = """
 ALTER TABLE conversations
 ADD COLUMN IF NOT EXISTS resolution_type TEXT,
@@ -79,10 +89,22 @@ async def create_checkpointer(pool: AsyncConnectionPool) -> AsyncPostgresSaver:
     return checkpointer
+def run_alembic_migrations(database_url: str) -> None:
+    """Run Alembic migrations to head."""
+    from alembic.config import Config
+    from alembic import command
+
+    alembic_cfg = Config(str(Path(__file__).parent.parent / "alembic.ini"))
+    alembic_cfg.set_main_option("sqlalchemy.url", database_url)
+    command.upgrade(alembic_cfg, "head")
 async def setup_app_tables(pool: AsyncConnectionPool) -> None:
     """Create application-specific tables and apply migrations."""
     async with pool.connection() as conn:
         await conn.execute(_CONVERSATIONS_DDL)
         await conn.execute(_INTERRUPTS_DDL)
+        await conn.execute(_SESSIONS_DDL)
         await conn.execute(_ANALYTICS_EVENTS_DDL)
         await conn.execute(_CONVERSATIONS_MIGRATION_DDL)

backend/app/escalation.py

@@ -3,14 +3,14 @@
 from __future__ import annotations
 import asyncio
-import logging
 from dataclasses import dataclass
 from typing import Protocol
 import httpx
+import structlog
 from pydantic import BaseModel
-logger = logging.getLogger(__name__)
+logger = structlog.get_logger()
 class EscalationPayload(BaseModel, frozen=True):

backend/app/graph.py

@@ -2,23 +2,24 @@
 from __future__ import annotations
-import logging
 from typing import TYPE_CHECKING
-from langgraph.prebuilt import create_react_agent
+from langchain.agents import create_agent
 from langgraph_supervisor import create_supervisor
 from app.agents import get_tools_by_names
+from app.graph_context import GraphContext
 if TYPE_CHECKING:
     from langchain_core.language_models import BaseChatModel
     from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
-    from langgraph.graph.state import CompiledStateGraph
-    from app.intent import ClassificationResult, IntentClassifier
+    from app.intent import IntentClassifier
     from app.registry import AgentRegistry
-logger = logging.getLogger(__name__)
+import structlog
+logger = structlog.get_logger()
 SUPERVISOR_PROMPT = (
     "You are a customer support supervisor. "
@@ -59,11 +60,11 @@ def build_agent_nodes(
             f"Permission level: {agent_config.permission}."
         )
-        agent_node = create_react_agent(
+        agent_node = create_agent(
             model=llm,
             tools=tools,
             name=agent_config.name,
-            prompt=system_prompt,
+            system_prompt=system_prompt,
         )
         agent_nodes.append(agent_node)
@@ -75,12 +76,11 @@ def build_graph(
     llm: BaseChatModel,
     checkpointer: AsyncPostgresSaver,
     intent_classifier: IntentClassifier | None = None,
-) -> CompiledStateGraph:
+) -> GraphContext:
     """Build and compile the LangGraph supervisor graph.
-    If an intent_classifier is provided, the supervisor prompt is enhanced
-    with agent descriptions for better routing. The classifier is stored
-    for use by the routing layer (ws_handler).
+    Returns a GraphContext that bundles the compiled graph with its
+    associated registry and intent classifier.
     """
     agent_nodes = build_agent_nodes(registry, llm)
     agent_descriptions = _format_agent_descriptions(registry)
@@ -88,34 +88,16 @@ def build_graph(
     prompt = SUPERVISOR_PROMPT.format(agent_descriptions=agent_descriptions)
     workflow = create_supervisor(
-        agent_nodes,
+        agents=agent_nodes,
         model=llm,
         prompt=prompt,
         output_mode="full_history",
     )
-    graph = workflow.compile(checkpointer=checkpointer)
-    # Attach classifier and registry to graph for use by ws_handler
-    graph.intent_classifier = intent_classifier  # type: ignore[attr-defined]
-    graph.agent_registry = registry  # type: ignore[attr-defined]
-    return graph
+    compiled = workflow.compile(checkpointer=checkpointer)
+    return GraphContext(
+        graph=compiled,
+        registry=registry,
+        intent_classifier=intent_classifier,
+    )
-async def classify_intent(
-    graph: CompiledStateGraph,
-    message: str,
-) -> ClassificationResult | None:
-    """Classify user intent using the graph's attached classifier.
-    Returns None if no classifier is configured.
-    """
-    classifier = getattr(graph, "intent_classifier", None)
-    registry = getattr(graph, "agent_registry", None)
-    if classifier is None or registry is None:
-        return None
-    agents = registry.list_agents()
-    return await classifier.classify(message, agents)

backend/app/graph_context.py (new file, 36 lines)

@@ -0,0 +1,36 @@
"""GraphContext -- typed wrapper around the compiled graph and its dependencies."""
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from langgraph.graph.state import CompiledStateGraph
from app.intent import ClassificationResult, IntentClassifier
from app.registry import AgentRegistry
@dataclass(frozen=True)
class GraphContext:
"""Bundles the compiled LangGraph graph with its associated services.
Replaces the previous pattern of monkey-patching attributes onto the
third-party CompiledStateGraph instance.
"""
graph: CompiledStateGraph
registry: AgentRegistry
intent_classifier: IntentClassifier | None = None
async def classify_intent(self, message: str) -> ClassificationResult | None:
"""Classify user intent using the attached classifier.
Returns None if no classifier is configured.
"""
if self.intent_classifier is None:
return None
agents = self.registry.list_agents()
return await self.intent_classifier.classify(message, agents)

backend/app/intent.py

@@ -2,7 +2,6 @@
 from __future__ import annotations
-import logging
 from typing import TYPE_CHECKING, Protocol
 from pydantic import BaseModel
@@ -12,7 +11,9 @@ if TYPE_CHECKING:
     from app.registry import AgentConfig
-logger = logging.getLogger(__name__)
+import structlog
+logger = structlog.get_logger()
 CLASSIFICATION_PROMPT = (
     "You are an intent classifier for a customer support system.\n"

backend/app/interrupt_manager.py

@@ -1,10 +1,18 @@
-"""Interrupt TTL management -- tracks pending interrupts with auto-expiration."""
+"""Interrupt TTL management -- tracks pending interrupts with auto-expiration.
+
+Provides both in-memory (InterruptManager) and PostgreSQL-backed
+(PgInterruptManager) implementations behind a common Protocol.
+"""
 from __future__ import annotations
 import time
 import uuid
 from dataclasses import dataclass
+from typing import TYPE_CHECKING, Protocol
+
+if TYPE_CHECKING:
+    from psycopg_pool import AsyncConnectionPool
 @dataclass(frozen=True)
@@ -28,8 +36,32 @@ class InterruptStatus:
     record: InterruptRecord
+class InterruptManagerProtocol(Protocol):
+    """Protocol for interrupt TTL management."""
+
+    def register(self, thread_id: str, action: str, params: dict) -> InterruptRecord: ...
+    def check_status(self, thread_id: str) -> InterruptStatus | None: ...
+    def resolve(self, thread_id: str) -> None: ...
+    def has_pending(self, thread_id: str) -> bool: ...
+    def generate_retry_prompt(self, expired_record: InterruptRecord) -> dict: ...
+def _build_retry_prompt(expired_record: InterruptRecord) -> dict:
+    """Generate a WebSocket message prompting the user to retry an expired action."""
+    return {
+        "type": "interrupt_expired",
+        "thread_id": expired_record.thread_id,
+        "action": expired_record.action,
+        "message": (
+            f"The approval request for '{expired_record.action}' has expired "
+            f"after {expired_record.ttl_seconds // 60} minutes. "
+            f"Would you like to try again?"
+        ),
+    }
 class InterruptManager:
-    """Manages interrupt TTL with auto-expiration.
+    """In-memory interrupt manager for single-worker development.
     Complements SessionManager -- this tracks interrupt-specific TTL
     while SessionManager handles session-level TTL.
@@ -62,11 +94,9 @@ class InterruptManager:
         record = self._interrupts.get(thread_id)
         if record is None:
             return None
         elapsed = time.time() - record.created_at
         remaining = max(0.0, record.ttl_seconds - elapsed)
         is_expired = elapsed > record.ttl_seconds
         return InterruptStatus(
             is_expired=is_expired,
             remaining_seconds=remaining,
@@ -84,28 +114,17 @@
         now = time.time()
         expired: list[InterruptRecord] = []
         active: dict[str, InterruptRecord] = {}
         for thread_id, record in self._interrupts.items():
             if now - record.created_at > record.ttl_seconds:
                 expired.append(record)
             else:
                 active[thread_id] = record
         self._interrupts = active
         return tuple(expired)
     def generate_retry_prompt(self, expired_record: InterruptRecord) -> dict:
         """Generate a WebSocket message prompting the user to retry an expired action."""
-        return {
-            "type": "interrupt_expired",
-            "thread_id": expired_record.thread_id,
-            "action": expired_record.action,
-            "message": (
-                f"The approval request for '{expired_record.action}' has expired "
-                f"after {expired_record.ttl_seconds // 60} minutes. "
-                f"Would you like to try again?"
-            ),
-        }
+        return _build_retry_prompt(expired_record)
     def has_pending(self, thread_id: str) -> bool:
         """Check if a thread has a pending (non-expired) interrupt."""
@@ -113,3 +132,137 @@ class InterruptManager:
         if status is None:
             return False
         return not status.is_expired
+
+# Alias for explicit naming
+InMemoryInterruptManager = InterruptManager
+
+
+class PgInterruptManager:
+    """PostgreSQL-backed interrupt manager for multi-worker production.
+
+    Uses the existing active_interrupts table defined in db.py.
+    """
+
+    def __init__(
+        self,
+        pool: AsyncConnectionPool,
+        ttl_seconds: int = 1800,
+    ) -> None:
+        self._pool = pool
+        self._ttl_seconds = ttl_seconds
+
+    def register(
+        self,
+        thread_id: str,
+        action: str,
+        params: dict,
+    ) -> InterruptRecord:
+        import asyncio
+
+        return asyncio.get_event_loop().run_until_complete(
+            self._register(thread_id, action, params)
+        )
+
+    async def _register(
+        self, thread_id: str, action: str, params: dict
+    ) -> InterruptRecord:
+        import json
+
+        record = InterruptRecord(
+            interrupt_id=uuid.uuid4().hex,
+            thread_id=thread_id,
+            action=action,
+            params=dict(params),
+            created_at=time.time(),
+            ttl_seconds=self._ttl_seconds,
+        )
+        async with self._pool.connection() as conn:
+            await conn.execute(
+                """
+                INSERT INTO active_interrupts (interrupt_id, thread_id, action, params)
+                VALUES (%(iid)s, %(tid)s, %(action)s, %(params)s)
+                ON CONFLICT (thread_id) WHERE resolved_at IS NULL
+                DO UPDATE SET
+                    interrupt_id = %(iid)s,
+                    action = %(action)s,
+                    params = %(params)s,
+                    created_at = NOW(),
+                    resolved_at = NULL
+                """,
+                {
+                    "iid": record.interrupt_id,
+                    "tid": thread_id,
+                    "action": action,
+                    "params": json.dumps(params),
+                },
+            )
+        return record
+
+    def check_status(self, thread_id: str) -> InterruptStatus | None:
+        import asyncio
+
+        return asyncio.get_event_loop().run_until_complete(
+            self._check_status(thread_id)
+        )
+
+    async def _check_status(self, thread_id: str) -> InterruptStatus | None:
+        async with self._pool.connection() as conn:
+            cursor = await conn.execute(
+                """
+                SELECT interrupt_id, action, params, created_at
+                FROM active_interrupts
+                WHERE thread_id = %(tid)s AND resolved_at IS NULL
+                ORDER BY created_at DESC LIMIT 1
+                """,
+                {"tid": thread_id},
+            )
+            row = await cursor.fetchone()
+            if row is None:
+                return None
+            created_at = row["created_at"].timestamp()
+            elapsed = time.time() - created_at
+            remaining = max(0.0, self._ttl_seconds - elapsed)
+            is_expired = elapsed > self._ttl_seconds
+            record = InterruptRecord(
+                interrupt_id=row["interrupt_id"],
+                thread_id=thread_id,
+                action=row["action"],
+                params=row["params"] if isinstance(row["params"], dict) else {},
+                created_at=created_at,
+                ttl_seconds=self._ttl_seconds,
+            )
+            return InterruptStatus(
+                is_expired=is_expired,
+                remaining_seconds=remaining,
+                record=record,
+            )
+
+    def resolve(self, thread_id: str) -> None:
+        import asyncio
+
+        asyncio.get_event_loop().run_until_complete(self._resolve(thread_id))
+
+    async def _resolve(self, thread_id: str) -> None:
+        async with self._pool.connection() as conn:
+            await conn.execute(
+                """
+                UPDATE active_interrupts
+                SET resolved_at = NOW(), resolution = 'resolved'
+                WHERE thread_id = %(tid)s AND resolved_at IS NULL
+                """,
+                {"tid": thread_id},
+            )
+
+    def generate_retry_prompt(self, expired_record: InterruptRecord) -> dict:
+        return _build_retry_prompt(expired_record)
+
+    def has_pending(self, thread_id: str) -> bool:
+        status = self.check_status(thread_id)
+        if status is None:
+            return False
+        return not status.is_expired

backend/app/logging_config.py (new file, 57 lines)

@@ -0,0 +1,57 @@
"""Structured logging configuration using structlog."""
from __future__ import annotations
import logging
import sys
import structlog
def configure_logging(log_format: str = "console") -> None:
"""Configure structlog with stdlib integration.
Args:
log_format: "console" for human-readable dev output,
"json" for machine-parseable production output.
"""
shared_processors: list[structlog.types.Processor] = [
structlog.contextvars.merge_contextvars,
structlog.stdlib.filter_by_level,
structlog.stdlib.add_logger_name,
structlog.stdlib.add_log_level,
structlog.processors.TimeStamper(fmt="iso"),
structlog.processors.StackInfoRenderer(),
structlog.processors.format_exc_info,
structlog.processors.UnicodeDecoder(),
]
if log_format == "json":
renderer: structlog.types.Processor = structlog.processors.JSONRenderer()
else:
renderer = structlog.dev.ConsoleRenderer()
structlog.configure(
processors=[
*shared_processors,
structlog.stdlib.ProcessorFormatter.wrap_for_formatter,
],
logger_factory=structlog.stdlib.LoggerFactory(),
wrapper_class=structlog.stdlib.BoundLogger,
cache_logger_on_first_use=True,
)
formatter = structlog.stdlib.ProcessorFormatter(
processors=[
structlog.stdlib.ProcessorFormatter.remove_processors_meta,
renderer,
],
)
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(formatter)
root_logger = logging.getLogger()
root_logger.handlers.clear()
root_logger.addHandler(handler)
root_logger.setLevel(logging.INFO)
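A quick usage sketch for the module above (the event name and bound fields are illustrative):

```python
# Assumes the app calls configure_logging() once at startup, as main.py does below.
from app.logging_config import configure_logging
import structlog

configure_logging("json")  # "console" for local dev
logger = structlog.get_logger()
logger.info("ws_message_dispatched", thread_id="t-123", agent="order_lookup")
```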

backend/app/main.py

@@ -2,47 +2,78 @@
from __future__ import annotations from __future__ import annotations
import logging import asyncio
import contextlib
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from fastapi import FastAPI, WebSocket, WebSocketDisconnect from fastapi import FastAPI, HTTPException, Query, WebSocket, WebSocketDisconnect
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from app.analytics.api import router as analytics_router from app.analytics.api import router as analytics_router
from app.analytics.event_recorder import PostgresAnalyticsRecorder from app.analytics.event_recorder import PostgresAnalyticsRecorder
from app.api_utils import envelope
from app.callbacks import TokenUsageCallbackHandler from app.callbacks import TokenUsageCallbackHandler
from app.config import Settings from app.config import Settings
from app.conversation_tracker import PostgresConversationTracker from app.conversation_tracker import PostgresConversationTracker
from app.db import create_checkpointer, create_pool, setup_app_tables from app.db import create_checkpointer, create_pool, run_alembic_migrations
from app.escalation import NoOpEscalator, WebhookEscalator from app.escalation import NoOpEscalator, WebhookEscalator
from app.graph import build_graph from app.graph import build_graph
from app.intent import LLMIntentClassifier from app.intent import LLMIntentClassifier
from app.interrupt_manager import InterruptManager from app.interrupt_manager import InterruptManager
from app.llm import create_llm from app.llm import create_llm
from app.logging_config import configure_logging
from app.openapi.review_api import router as openapi_router from app.openapi.review_api import router as openapi_router
from app.registry import AgentRegistry from app.registry import AgentRegistry
from app.replay.api import router as replay_router from app.replay.api import router as replay_router
from app.session_manager import SessionManager from app.session_manager import SessionManager
from app.ws_context import WebSocketContext
from app.ws_handler import dispatch_message from app.ws_handler import dispatch_message
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
logger = logging.getLogger(__name__) import structlog
logger = structlog.get_logger()
AGENTS_YAML = Path(__file__).parent.parent / "agents.yaml" AGENTS_YAML = Path(__file__).parent.parent / "agents.yaml"
FRONTEND_DIST = Path(__file__).parent.parent.parent / "frontend" / "dist" FRONTEND_DIST = Path(__file__).parent.parent.parent / "frontend" / "dist"
async def _interrupt_cleanup_loop(
interrupt_manager: InterruptManager,
interval: int = 60,
) -> None:
"""Periodically remove expired interrupts in the background.
Runs until cancelled. Catches all exceptions to prevent the task
from dying unexpectedly.
"""
while True:
await asyncio.sleep(interval)
try:
expired = interrupt_manager.cleanup_expired()
if expired:
logger.info(
"Cleaned up %d expired interrupt(s)",
len(expired),
)
except Exception:
logger.exception("Error during interrupt cleanup")
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    settings = Settings()
+    configure_logging(settings.log_format)
    pool = await create_pool(settings)
    checkpointer = await create_checkpointer(pool)
-    await setup_app_tables(pool)
+    run_alembic_migrations(settings.database_url)
    # Load agents from template or default YAML
    if settings.template_name:

@@ -52,7 +83,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:

    llm = create_llm(settings)
    intent_classifier = LLMIntentClassifier(llm)
-    graph = build_graph(registry, llm, checkpointer, intent_classifier=intent_classifier)
+    graph_ctx = build_graph(registry, llm, checkpointer, intent_classifier=intent_classifier)
    session_manager = SessionManager(
        session_ttl_seconds=settings.session_ttl_minutes * 60,

@@ -71,7 +102,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:

    else:
        escalator = NoOpEscalator()
-    app.state.graph = graph
+    app.state.graph_ctx = graph_ctx
    app.state.session_manager = session_manager
    app.state.interrupt_manager = interrupt_manager
    app.state.escalator = escalator

@@ -88,12 +119,20 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:

        settings.template_name or "(default)",
    )
+    cleanup_task = asyncio.create_task(
+        _interrupt_cleanup_loop(interrupt_manager),
+    )
    yield
+    cleanup_task.cancel()
+    with contextlib.suppress(asyncio.CancelledError):
+        await cleanup_task
    await pool.close()

-_VERSION = "0.5.0"
+_VERSION = "0.6.0"
app = FastAPI(title="Smart Support", version=_VERSION, lifespan=lifespan)
@@ -102,35 +141,72 @@ app.include_router(replay_router)
app.include_router(analytics_router)

+@app.exception_handler(HTTPException)
+async def http_exception_handler(request, exc):  # type: ignore[no-untyped-def]
+    """Wrap HTTPException in standard envelope format."""
+    return JSONResponse(
+        status_code=exc.status_code,
+        content=envelope(None, success=False, error=exc.detail),
+    )

+@app.exception_handler(RequestValidationError)
+async def validation_exception_handler(request, exc):  # type: ignore[no-untyped-def]
+    """Wrap validation errors in standard envelope format."""
+    return JSONResponse(
+        status_code=422,
+        content=envelope(None, success=False, error=str(exc)),
+    )

+@app.exception_handler(Exception)
+async def general_exception_handler(request, exc):  # type: ignore[no-untyped-def]
+    """Catch-all handler -- never leak stack traces."""
+    logger.exception("Unhandled exception: %s", exc)
+    return JSONResponse(
+        status_code=500,
+        content=envelope(None, success=False, error="Internal server error"),
+    )

-@app.get("/api/health")
+@app.get("/api/v1/health")
def health_check() -> dict:
    """Health check endpoint for load balancers and monitoring."""
    return {"status": "ok", "version": _VERSION}
@app.websocket("/ws") @app.websocket("/ws")
async def websocket_endpoint(ws: WebSocket) -> None: async def websocket_endpoint(
await ws.accept() ws: WebSocket,
graph = app.state.graph token: str | None = Query(default=None),
session_manager = app.state.session_manager ) -> None:
interrupt_manager = app.state.interrupt_manager
settings = app.state.settings settings = app.state.settings
# Verify WebSocket token when admin_api_key is configured
if settings.admin_api_key:
import secrets as _secrets
if token is None or not _secrets.compare_digest(token, settings.admin_api_key):
await ws.close(code=4001, reason="Unauthorized")
return
await ws.accept()
callback_handler = TokenUsageCallbackHandler(model_name=settings.llm_model) callback_handler = TokenUsageCallbackHandler(model_name=settings.llm_model)
analytics_recorder = app.state.analytics_recorder ws_ctx = WebSocketContext(
conversation_tracker = app.state.conversation_tracker graph_ctx=app.state.graph_ctx,
pool = app.state.pool session_manager=app.state.session_manager,
callback_handler=callback_handler,
interrupt_manager=app.state.interrupt_manager,
analytics_recorder=app.state.analytics_recorder,
conversation_tracker=app.state.conversation_tracker,
pool=app.state.pool,
)
try: try:
while True: while True:
raw_data = await ws.receive_text() raw_data = await ws.receive_text()
await dispatch_message( await dispatch_message(ws, ws_ctx, raw_data)
ws, graph, session_manager, callback_handler, raw_data,
interrupt_manager=interrupt_manager,
analytics_recorder=analytics_recorder,
conversation_tracker=conversation_tracker,
pool=pool,
)
except WebSocketDisconnect: except WebSocketDisconnect:
logger.info("WebSocket client disconnected") logger.info("WebSocket client disconnected")


@@ -8,13 +8,14 @@ classifier and an LLM-backed classifier with heuristic fallback.
from __future__ import annotations

import json
-import logging
import re
from typing import Protocol

+import structlog
+
from app.openapi.models import ClassificationResult, EndpointInfo

-logger = logging.getLogger(__name__)
+logger = structlog.get_logger()

_WRITE_METHODS = frozenset({"POST", "PUT", "PATCH", "DELETE"})
_INTERRUPT_METHODS = frozenset({"POST", "PUT", "PATCH", "DELETE"})


@@ -6,10 +6,11 @@ Each stage updates the job status and calls the on_progress callback.
from __future__ import annotations

-import logging
from collections.abc import Callable
from dataclasses import replace

+import structlog
+
from app.openapi.classifier import ClassifierProtocol, HeuristicClassifier
from app.openapi.fetcher import fetch_spec
from app.openapi.models import ImportJob

@@ -17,7 +18,7 @@ from app.openapi.parser import parse_endpoints

from app.openapi.ssrf import DEFAULT_POLICY, SSRFPolicy
from app.openapi.validator import validate_spec

-logger = logging.getLogger(__name__)
+logger = structlog.get_logger()

ProgressCallback = Callable[[str, ImportJob], None] | None


@@ -10,20 +10,26 @@ Exposes endpoints for:
from __future__ import annotations

import asyncio
-import logging
import re
import uuid
from typing import Literal

-from fastapi import APIRouter, BackgroundTasks, HTTPException
+import structlog
+from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
from pydantic import BaseModel, field_validator

+from app.auth import require_admin_api_key
+from app.openapi.generator import generate_agent_yaml, generate_tool_code
from app.openapi.importer import ImportOrchestrator
from app.openapi.models import ClassificationResult, ImportJob

-logger = logging.getLogger(__name__)
+logger = structlog.get_logger()

-router = APIRouter(prefix="/api/openapi", tags=["openapi"])
+router = APIRouter(
+    prefix="/api/v1/openapi",
+    tags=["openapi"],
+    dependencies=[Depends(require_admin_api_key)],
+)

# In-memory store: job_id -> job dict, guarded by async lock
_job_store: dict[str, dict] = {}

@@ -235,11 +241,42 @@ async def update_classification(

@router.post("/jobs/{job_id}/approve")
async def approve_job(job_id: str) -> dict:
-    """Approve a job's classifications and trigger tool generation."""
+    """Approve a job's classifications and trigger tool generation.
+
+    Generates Python tool code for each classified endpoint and
+    produces an agent YAML configuration snippet.
+    """
    job = _job_store.get(job_id)
    if job is None:
        raise HTTPException(status_code=404, detail=f"Job '{job_id}' not found")

-    updated_job = {**job, "status": "approved"}
+    classifications: list[ClassificationResult] = job.get("classifications", [])
+    if not classifications:
+        raise HTTPException(
+            status_code=400,
+            detail="No classifications to approve. Import must complete first.",
+        )
+
+    base_url = job["spec_url"].rsplit("/", 1)[0]
+    generated_tools = []
+    for clf in classifications:
+        tool = generate_tool_code(clf, base_url)
+        generated_tools.append({
+            "function_name": tool.function_name,
+            "agent_group": clf.agent_group,
+            "code": tool.code,
+        })
+    agent_yaml = generate_agent_yaml(tuple(classifications), base_url)
+
+    updated_job = {
+        **job,
+        "status": "approved",
+        "generated_tools": generated_tools,
+        "agent_yaml": agent_yaml,
+    }
    _job_store[job_id] = updated_job
-    return _job_to_response(updated_job)
+
+    response = _job_to_response(updated_job)
+    response["generated_tools_count"] = len(generated_tools)
+    return response
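The require_admin_api_key dependency comes from the new app/auth.py, which this diff does not include. Based on how the WebSocket check behaves and on the test fixtures (auth is skipped when admin_api_key is empty), a sketch could look like the following; the header name is an assumption:

import secrets

from fastapi import Header, HTTPException, Request

async def require_admin_api_key(
    request: Request,
    x_admin_api_key: str | None = Header(default=None),  # header name assumed
) -> None:
    expected = request.app.state.settings.admin_api_key
    if not expected:
        return  # auth disabled when no key is configured
    if x_admin_api_key is None or not secrets.compare_digest(x_admin_api_key, expected):
        raise HTTPException(status_code=401, detail="Unauthorized")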


@@ -3,16 +3,27 @@
from __future__ import annotations

import re
-from typing import TYPE_CHECKING, Annotated, Any
+from typing import TYPE_CHECKING, Annotated

-from fastapi import APIRouter, HTTPException, Query, Request
+from fastapi import APIRouter, Depends, HTTPException, Query, Request
+
+from app.api_utils import envelope
+from app.auth import require_admin_api_key

_THREAD_ID_PATTERN = re.compile(r"^[a-zA-Z0-9\-_]{1,128}$")

if TYPE_CHECKING:
    from psycopg_pool import AsyncConnectionPool

-router = APIRouter(prefix="/api", tags=["replay"])
+router = APIRouter(
+    prefix="/api/v1",
+    tags=["replay"],
+    dependencies=[Depends(require_admin_api_key)],
+)

+_COUNT_CONVERSATIONS_SQL = """
+SELECT COUNT(*) FROM conversations
+"""
+
_LIST_CONVERSATIONS_SQL = """
SELECT thread_id, created_at, last_activity, status, total_tokens, total_cost_usd
@@ -34,10 +45,6 @@ async def get_pool(request: Request) -> AsyncConnectionPool:
    return request.app.state.pool

-def _envelope(data: Any, *, success: bool = True, error: str | None = None) -> dict:
-    return {"success": success, "data": data, "error": error}

@router.get("/conversations")
async def list_conversations(
    request: Request,

@@ -48,13 +55,22 @@ async def list_conversations(

    pool = await get_pool(request)
    offset = (page - 1) * per_page
    async with pool.connection() as conn:
+        count_cursor = await conn.execute(_COUNT_CONVERSATIONS_SQL)
+        count_row = await count_cursor.fetchone()
+        total = count_row[0] if count_row else 0
        cursor = await conn.execute(
            _LIST_CONVERSATIONS_SQL,
            {"limit": per_page, "offset": offset},
        )
        rows = await cursor.fetchall()
-    return _envelope([dict(row) for row in rows])
+    return envelope({
+        "conversations": [dict(row) for row in rows],
+        "total": total,
+        "page": page,
+        "per_page": per_page,
+    })

@router.get("/replay/{thread_id}")

@@ -106,4 +122,4 @@ async def get_replay(

            for s in page_steps
        ],
    }
-    return _envelope(data)
+    return envelope(data)


@@ -2,11 +2,11 @@
from __future__ import annotations

-import logging
+import structlog

from app.replay.models import ReplayStep, StepType

-logger = logging.getLogger(__name__)
+logger = structlog.get_logger()

_EMPTY_TIMESTAMP = "1970-01-01T00:00:00Z"

backend/app/safety.py Normal file

@@ -0,0 +1,131 @@
"""Safety policy for destructive-action confirmation rules.
This module makes the confirmation rules explicit and auditable. Every tool
call passes through ``requires_confirmation`` before execution to decide
whether human-in-the-loop approval is needed.
Policy summary
--------------
- ``read`` actions: execute immediately, no confirmation required.
- ``write`` actions: require human approval via interrupt gate.
- OpenAPI-imported endpoints: use ``needs_interrupt`` from classification.
- If both the agent permission AND the endpoint classification agree
the action is read-only, it executes without confirmation.
Multi-intent semantics
----------------------
When a user message contains multiple intents (e.g. "cancel my order and
apply a refund"), the supervisor routes them sequentially. Each action is
evaluated independently:
- If a write action is blocked by an interrupt, subsequent actions in the
same message are paused until the interrupt is resolved.
- Read actions that follow a blocked write are also paused (sequential,
not best-effort) to preserve causal ordering.
- If an interrupt is rejected, the remaining actions are skipped and the
agent informs the user.
MCP error taxonomy
------------------
Tool execution errors are classified into categories for retry decisions:
- ``transient``: network timeouts, rate limits, 5xx -- retryable up to 3 times.
- ``validation``: bad parameters, 4xx -- not retryable, report to user.
- ``auth``: 401/403 -- not retryable, escalate.
- ``unknown``: unclassified -- not retryable, log and escalate.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Literal
@dataclass(frozen=True)
class ConfirmationPolicy:
"""Result of evaluating whether an action needs confirmation."""
requires_confirmation: bool
reason: str
def requires_confirmation(
*,
agent_permission: Literal["read", "write"],
needs_interrupt: bool | None = None,
) -> ConfirmationPolicy:
"""Determine whether an action requires human confirmation.
Parameters
----------
agent_permission:
The permission level of the agent executing the action.
needs_interrupt:
Override from OpenAPI classification. When ``None``, the decision
is based solely on ``agent_permission``.
"""
if needs_interrupt is not None:
if needs_interrupt:
return ConfirmationPolicy(
requires_confirmation=True,
reason="Endpoint classified as requiring human approval",
)
return ConfirmationPolicy(
requires_confirmation=False,
reason="Endpoint classified as safe (no interrupt needed)",
)
if agent_permission == "write":
return ConfirmationPolicy(
requires_confirmation=True,
reason="Write-permission agent actions require confirmation",
)
return ConfirmationPolicy(
requires_confirmation=False,
reason="Read-only agent actions execute immediately",
)
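Usage of the policy function, following directly from the rules above:

# Write-permission agents always require confirmation...
policy = requires_confirmation(agent_permission="write")
assert policy.requires_confirmation

# ...unless the OpenAPI classification explicitly marks the endpoint safe.
policy = requires_confirmation(agent_permission="write", needs_interrupt=False)
assert not policy.requires_confirmation

# Read-only agents execute immediately.
assert not requires_confirmation(agent_permission="read").requires_confirmation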
# --- MCP Error Taxonomy ---
MCP_ERROR_CATEGORY = Literal["transient", "validation", "auth", "unknown"]
_TRANSIENT_STATUS_CODES = frozenset({408, 429, 500, 502, 503, 504})
_AUTH_STATUS_CODES = frozenset({401, 403})
_MAX_RETRIES = 3
def classify_mcp_error(
*,
status_code: int | None = None,
error_message: str = "",
) -> MCP_ERROR_CATEGORY:
"""Classify an MCP tool error for retry decisions."""
if status_code is not None:
if status_code in _TRANSIENT_STATUS_CODES:
return "transient"
if status_code in _AUTH_STATUS_CODES:
return "auth"
if 400 <= status_code < 500:
return "validation"
lower_msg = error_message.lower()
if any(kw in lower_msg for kw in ("timeout", "timed out", "rate limit")):
return "transient"
if any(kw in lower_msg for kw in ("unauthorized", "forbidden")):
return "auth"
if any(kw in lower_msg for kw in ("invalid", "missing", "bad request")):
return "validation"
return "unknown"
def is_retryable(category: MCP_ERROR_CATEGORY) -> bool:
"""Return whether a given error category is retryable."""
return category == "transient"
def max_retries() -> int:
"""Maximum retry attempts for transient errors."""
return _MAX_RETRIES
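A retry wrapper built on this taxonomy might look like the sketch below; ToolError and call_tool are illustrative stand-ins, not names from this codebase:

class ToolError(Exception):
    # Illustrative error type carrying an optional HTTP status code.
    def __init__(self, message: str, status_code: int | None = None) -> None:
        super().__init__(message)
        self.status_code = status_code

def execute_with_retries(call_tool):
    # Retry only "transient" errors, up to max_retries() attempts in total.
    for attempt in range(1, max_retries() + 1):
        try:
            return call_tool()
        except ToolError as exc:
            category = classify_mcp_error(
                status_code=exc.status_code, error_message=str(exc)
            )
            if not is_retryable(category) or attempt == max_retries():
                raise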


@@ -1,9 +1,18 @@
"""Session TTL management with sliding window and interrupt extension.""" """Session TTL management with sliding window and interrupt extension.
Provides both in-memory (SessionManager) and PostgreSQL-backed
(PgSessionManager) implementations behind a common Protocol.
"""
from __future__ import annotations from __future__ import annotations
import time import time
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Protocol
if TYPE_CHECKING:
from psycopg_pool import AsyncConnectionPool
@dataclass(frozen=True) @dataclass(frozen=True)
@@ -13,8 +22,19 @@ class SessionState:
has_pending_interrupt: bool has_pending_interrupt: bool
class SessionManagerProtocol(Protocol):
"""Protocol for session TTL management."""
def touch(self, thread_id: str) -> SessionState: ...
def is_expired(self, thread_id: str) -> bool: ...
def extend_for_interrupt(self, thread_id: str) -> SessionState: ...
def resolve_interrupt(self, thread_id: str) -> SessionState: ...
def get_state(self, thread_id: str) -> SessionState | None: ...
def remove(self, thread_id: str) -> None: ...
class SessionManager:
-    """Manages session TTL with sliding window and interrupt extensions.
+    """In-memory session manager for single-worker development.

    - Each message resets the TTL (sliding window).
    - A pending interrupt suspends expiration until resolved.

@@ -40,10 +60,8 @@ class SessionManager:

        state = self._sessions.get(thread_id)
        if state is None:
            return True
        if state.has_pending_interrupt:
            return False
        elapsed = time.time() - state.last_activity
        return elapsed > self._session_ttl

@@ -52,7 +70,6 @@ class SessionManager:

        existing = self._sessions.get(thread_id)
        if existing is None:
            return self.touch(thread_id)
        new_state = SessionState(
            thread_id=thread_id,
            last_activity=existing.last_activity,

@@ -76,3 +93,120 @@ class SessionManager:

    def remove(self, thread_id: str) -> None:
        self._sessions = {k: v for k, v in self._sessions.items() if k != thread_id}
# Alias for explicit naming
InMemorySessionManager = SessionManager
class PgSessionManager:
"""PostgreSQL-backed session manager for multi-worker production."""
def __init__(
self,
pool: AsyncConnectionPool,
session_ttl_seconds: int = 1800,
) -> None:
self._pool = pool
self._session_ttl = session_ttl_seconds
def touch(self, thread_id: str) -> SessionState:
import asyncio
return asyncio.get_event_loop().run_until_complete(self._touch(thread_id))
async def _touch(self, thread_id: str) -> SessionState:
now = datetime.now(timezone.utc)
async with self._pool.connection() as conn:
await conn.execute(
"""
INSERT INTO sessions (thread_id, last_activity, has_pending_interrupt)
VALUES (%(tid)s, %(now)s, FALSE)
ON CONFLICT (thread_id) DO UPDATE
SET last_activity = %(now)s
""",
{"tid": thread_id, "now": now},
)
return SessionState(
thread_id=thread_id,
last_activity=now.timestamp(),
has_pending_interrupt=False,
)
def is_expired(self, thread_id: str) -> bool:
state = self.get_state(thread_id)
if state is None:
return True
if state.has_pending_interrupt:
return False
elapsed = time.time() - state.last_activity
return elapsed > self._session_ttl
def extend_for_interrupt(self, thread_id: str) -> SessionState:
import asyncio
return asyncio.get_event_loop().run_until_complete(
self._set_interrupt(thread_id, True)
)
def resolve_interrupt(self, thread_id: str) -> SessionState:
import asyncio
return asyncio.get_event_loop().run_until_complete(
self._set_interrupt(thread_id, False)
)
async def _set_interrupt(
self, thread_id: str, has_interrupt: bool
) -> SessionState:
now = datetime.now(timezone.utc)
async with self._pool.connection() as conn:
await conn.execute(
"""
INSERT INTO sessions (thread_id, last_activity, has_pending_interrupt)
VALUES (%(tid)s, %(now)s, %(interrupt)s)
ON CONFLICT (thread_id) DO UPDATE
SET last_activity = %(now)s,
has_pending_interrupt = %(interrupt)s
""",
{"tid": thread_id, "now": now, "interrupt": has_interrupt},
)
return SessionState(
thread_id=thread_id,
last_activity=now.timestamp(),
has_pending_interrupt=has_interrupt,
)
def get_state(self, thread_id: str) -> SessionState | None:
import asyncio
return asyncio.get_event_loop().run_until_complete(
self._get_state(thread_id)
)
async def _get_state(self, thread_id: str) -> SessionState | None:
async with self._pool.connection() as conn:
cursor = await conn.execute(
"SELECT last_activity, has_pending_interrupt FROM sessions WHERE thread_id = %(tid)s",
{"tid": thread_id},
)
row = await cursor.fetchone()
if row is None:
return None
return SessionState(
thread_id=thread_id,
last_activity=row["last_activity"].timestamp(),
has_pending_interrupt=row["has_pending_interrupt"],
)
def remove(self, thread_id: str) -> None:
import asyncio
asyncio.get_event_loop().run_until_complete(self._remove(thread_id))
async def _remove(self, thread_id: str) -> None:
async with self._pool.connection() as conn:
await conn.execute(
"DELETE FROM sessions WHERE thread_id = %(tid)s",
{"tid": thread_id},
)
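How the application chooses between the two implementations is not part of this diff; a hypothetical wiring (the use_pg_sessions setting is invented for illustration) would be:

def make_session_manager(settings, pool):
    if getattr(settings, "use_pg_sessions", False):
        # Note: PgSessionManager's sync facades call run_until_complete,
        # so they must run in a thread without an active event loop.
        return PgSessionManager(
            pool, session_ttl_seconds=settings.session_ttl_minutes * 60
        )
    return SessionManager(session_ttl_seconds=settings.session_ttl_minutes * 60)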

backend/app/ws_context.py Normal file

@@ -0,0 +1,30 @@
"""WebSocketContext -- bundles all dependencies needed by dispatch_message."""
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from app.analytics.event_recorder import AnalyticsRecorder
from app.callbacks import TokenUsageCallbackHandler
from app.conversation_tracker import ConversationTrackerProtocol
from app.graph_context import GraphContext
from app.interrupt_manager import InterruptManager
from app.session_manager import SessionManager
@dataclass(frozen=True)
class WebSocketContext:
"""All dependencies required for WebSocket message processing.
Replaces the previous 9-parameter function signature in dispatch_message.
"""
graph_ctx: GraphContext
session_manager: SessionManager
callback_handler: TokenUsageCallbackHandler
interrupt_manager: InterruptManager | None = None
analytics_recorder: AnalyticsRecorder | None = None
conversation_tracker: ConversationTrackerProtocol | None = None
pool: Any = None


@@ -3,28 +3,26 @@
from __future__ import annotations

import json
-import logging
import re
import time
from collections import defaultdict
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING

from langchain_core.messages import HumanMessage
from langgraph.types import Command

-from app.graph import classify_intent

if TYPE_CHECKING:
    from fastapi import WebSocket
-    from langgraph.graph.state import CompiledStateGraph

-    from app.analytics.event_recorder import AnalyticsRecorder
    from app.callbacks import TokenUsageCallbackHandler
-    from app.conversation_tracker import ConversationTrackerProtocol
+    from app.graph_context import GraphContext
    from app.interrupt_manager import InterruptManager
    from app.session_manager import SessionManager
+    from app.ws_context import WebSocketContext

-logger = logging.getLogger(__name__)
+import structlog
+
+logger = structlog.get_logger()

MAX_MESSAGE_SIZE = 32_768  # 32 KB
MAX_CONTENT_LENGTH = 10_000  # characters
@@ -46,7 +44,7 @@ def _evict_stale_threads(cutoff: float) -> None:
async def handle_user_message(
    ws: WebSocket,
-    graph: CompiledStateGraph,
+    ctx: GraphContext,
    session_manager: SessionManager,
    callback_handler: TokenUsageCallbackHandler,
    thread_id: str,

@@ -54,8 +52,6 @@ async def handle_user_message(

    interrupt_manager: InterruptManager | None = None,
) -> None:
    """Process a user message through the graph and stream results back."""
-    # Touch first so new sessions are created before expiry check.
-    # For existing sessions, touch resets the sliding window.
    existing = session_manager.get_state(thread_id)
    if existing is not None and session_manager.is_expired(thread_id):
        msg = "Session expired. Please start a new conversation."

@@ -64,8 +60,7 @@ async def handle_user_message(

    session_manager.touch(thread_id)

-    # Run intent classification if available (for logging/future multi-intent)
-    classification = await classify_intent(graph, content)
+    classification = await ctx.classify_intent(content)
    if classification is not None:
        logger.info(
            "Intent classification for thread %s: ambiguous=%s, intents=%s",

@@ -74,7 +69,6 @@ async def handle_user_message(

            [i.agent_name for i in classification.intents],
        )
-        # If ambiguous, send clarification and return
        if classification.is_ambiguous and classification.clarification_question:
            await _send_json(
                ws,

@@ -89,7 +83,6 @@ async def handle_user_message(

    config = {"configurable": {"thread_id": thread_id}, "callbacks": [callback_handler]}

-    # If multi-intent detected, add routing hint to the message
    if classification and len(classification.intents) > 1:
        agent_names = [i.agent_name for i in classification.intents]
        hint = (
@@ -101,7 +94,7 @@ async def handle_user_message(

    input_msg = {"messages": [HumanMessage(content=content)]}

    try:
-        async for chunk in graph.astream(input_msg, config=config, stream_mode="messages"):
+        async for chunk in ctx.graph.astream(input_msg, config=config, stream_mode="messages"):
            msg_chunk, metadata = chunk
            node = metadata.get("langgraph_node", "")

@@ -126,12 +119,11 @@ async def handle_user_message(

            },
        )

-    state = await graph.aget_state(config)
+    state = await ctx.graph.aget_state(config)
    if _has_interrupt(state):
        interrupt_data = _extract_interrupt(state)
        session_manager.extend_for_interrupt(thread_id)
-        # Register interrupt with TTL tracking
        if interrupt_manager is not None:
            interrupt_manager.register(
                thread_id=thread_id,
@@ -158,7 +150,7 @@ async def handle_user_message(

async def handle_interrupt_response(
    ws: WebSocket,
-    graph: CompiledStateGraph,
+    ctx: GraphContext,
    session_manager: SessionManager,
    callback_handler: TokenUsageCallbackHandler,
    thread_id: str,

@@ -166,7 +158,6 @@ async def handle_interrupt_response(

    interrupt_manager: InterruptManager | None = None,
) -> None:
    """Resume graph execution after interrupt approval/rejection."""
-    # Check interrupt TTL before resuming
    if interrupt_manager is not None:
        status = interrupt_manager.check_status(thread_id)
        if status is not None and status.is_expired:

@@ -184,7 +175,7 @@ async def handle_interrupt_response(

    config = {"configurable": {"thread_id": thread_id}, "callbacks": [callback_handler]}

    try:
-        async for chunk in graph.astream(
+        async for chunk in ctx.graph.astream(
            Command(resume=approved),
            config=config,
            stream_mode="messages",
@@ -212,14 +203,8 @@ async def handle_interrupt_response(

async def dispatch_message(
    ws: WebSocket,
-    graph: CompiledStateGraph,
-    session_manager: SessionManager,
-    callback_handler: TokenUsageCallbackHandler,
+    ctx: WebSocketContext,
    raw_data: str,
-    interrupt_manager: InterruptManager | None = None,
-    analytics_recorder: AnalyticsRecorder | None = None,
-    conversation_tracker: ConversationTrackerProtocol | None = None,
-    pool: Any = None,
) -> None:
    """Parse and route an incoming WebSocket message."""
    if len(raw_data) > MAX_MESSAGE_SIZE:

@@ -268,14 +253,15 @@ async def dispatch_message(

        _thread_timestamps[thread_id] = [*recent, now]

        await handle_user_message(
-            ws, graph, session_manager, callback_handler, thread_id, content,
-            interrupt_manager=interrupt_manager,
+            ws, ctx.graph_ctx, ctx.session_manager, ctx.callback_handler,
+            thread_id, content,
+            interrupt_manager=ctx.interrupt_manager,
        )
        await _fire_and_forget_tracking(
            thread_id=thread_id,
-            pool=pool,
-            analytics_recorder=analytics_recorder,
-            conversation_tracker=conversation_tracker,
+            pool=ctx.pool,
+            analytics_recorder=ctx.analytics_recorder,
+            conversation_tracker=ctx.conversation_tracker,
            agent_name=None,
            tokens=0,
            cost=0.0,

@@ -284,8 +270,9 @@ async def dispatch_message(

    elif msg_type == "interrupt_response":
        approved = data.get("approved", False)
        await handle_interrupt_response(
-            ws, graph, session_manager, callback_handler, thread_id, approved,
-            interrupt_manager=interrupt_manager,
+            ws, ctx.graph_ctx, ctx.session_manager, ctx.callback_handler,
+            thread_id, approved,
+            interrupt_manager=ctx.interrupt_manager,
        )
    else:

@@ -294,9 +281,9 @@ async def dispatch_message(

async def _fire_and_forget_tracking(
    thread_id: str,
-    pool: Any,
-    analytics_recorder: Any | None,
-    conversation_tracker: Any | None,
+    pool: object,
+    analytics_recorder: object | None,
+    conversation_tracker: object | None,
    agent_name: str | None,
    tokens: int,
    cost: float,


@@ -6,12 +6,13 @@ requires-python = ">=3.11"
dependencies = [
    "fastapi>=0.115,<1.0",
    "uvicorn[standard]>=0.34,<1.0",
-    "langgraph>=0.4,<1.0",
-    "langgraph-supervisor>=0.0.12,<1.0",
+    "langgraph>=1.0,<2.0",
+    "langgraph-supervisor>=0.0.30,<1.0",
    "langgraph-checkpoint-postgres>=3.0,<4.0",
-    "langchain-core>=0.3,<1.0",
-    "langchain-anthropic>=0.3,<2.0",
-    "langchain-openai>=0.3,<1.0",
+    "langchain>=1.0,<2.0",
+    "langchain-core>=1.0,<2.0",
+    "langchain-anthropic>=1.0,<2.0",
+    "langchain-openai>=1.0,<2.0",
    "langchain-google-genai>=2.1,<3.0",
    "psycopg[binary,pool]>=3.2,<4.0",
    "pydantic>=2.10,<3.0",

@@ -20,6 +21,8 @@ dependencies = [

    "python-dotenv>=1.0,<2.0",
    "httpx>=0.28,<1.0",
    "openapi-spec-validator>=0.7,<1.0",
+    "alembic>=1.13,<2.0",
+    "structlog>=24.0,<26.0",
]

[project.optional-dependencies]
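alembic becomes a runtime dependency because lifespan now calls run_alembic_migrations(settings.database_url) at startup. That helper (in app/db.py) is not shown in this diff; a minimal sketch using Alembic's command API, assuming an alembic.ini at the backend root, would be:

from pathlib import Path

from alembic import command
from alembic.config import Config

def run_alembic_migrations(database_url: str) -> None:
    # Hypothetical sketch; the ini location is an assumption.
    cfg = Config(str(Path(__file__).parent.parent / "alembic.ini"))
    cfg.set_main_option("sqlalchemy.url", database_url)
    command.upgrade(cfg, "head")  # apply all pending migrations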


@@ -13,10 +13,12 @@ from httpx import ASGITransport, AsyncClient
from app.analytics.api import router as analytics_router
from app.callbacks import TokenUsageCallbackHandler
+from app.graph_context import GraphContext
from app.interrupt_manager import InterruptManager
from app.openapi.review_api import _job_store, router as openapi_router
from app.replay.api import router as replay_router
from app.session_manager import SessionManager
+from app.ws_context import WebSocketContext
from app.ws_handler import dispatch_message
@@ -74,8 +76,6 @@ def make_graph(
) -> MagicMock:
    """Build a mock LangGraph CompiledStateGraph."""
    g = MagicMock()
-    g.intent_classifier = None
-    g.agent_registry = None

    if state is None:
        state = make_state()
@@ -93,6 +93,14 @@ def make_graph(
    return g

+def make_graph_ctx(graph: MagicMock | None = None) -> GraphContext:
+    """Build a GraphContext wrapping a mock graph."""
+    g = graph or make_graph()
+    registry = MagicMock()
+    registry.list_agents = MagicMock(return_value=())
+    return GraphContext(graph=g, registry=registry, intent_classifier=None)

# ---------------------------------------------------------------------------
# Fake database pool
# ---------------------------------------------------------------------------
@@ -107,6 +115,9 @@ class FakeCursor:
    async def fetchall(self) -> list[dict]:
        return self._rows

+    async def fetchone(self) -> tuple | dict | None:
+        return self._rows[0] if self._rows else None

class FakeConnection:
    """Fake async connection that returns a FakeCursor."""
@@ -145,6 +156,7 @@ def create_e2e_app(
) -> FastAPI:
    """Create a FastAPI app wired with mocked dependencies for E2E testing."""
    g = graph or make_graph()
+    graph_ctx = make_graph_ctx(g)
    p = pool or FakePool()
    sm = SessionManager(session_ttl_seconds=session_ttl)
    im = InterruptManager(ttl_seconds=interrupt_ttl)
@@ -154,7 +166,7 @@ def create_e2e_app(
    app.include_router(replay_router)
    app.include_router(analytics_router)

-    app.state.graph = g
+    app.state.graph_ctx = graph_ctx
    app.state.session_manager = sm
    app.state.interrupt_manager = im
    app.state.pool = p
@@ -162,7 +174,7 @@ def create_e2e_app(
    app.state.analytics_recorder = AsyncMock()
    app.state.conversation_tracker = AsyncMock()

-    @app.get("/api/health")
+    @app.get("/api/v1/health")
    def health_check() -> dict:
        return {"status": "ok", "version": "test"}
@@ -172,17 +184,16 @@ def create_e2e_app(
        try:
            while True:
                raw_data = await ws.receive_text()
-                await dispatch_message(
-                    ws,
-                    app.state.graph,
-                    app.state.session_manager,
-                    TokenUsageCallbackHandler(model_name="test-model"),
-                    raw_data,
-                    interrupt_manager=app.state.interrupt_manager,
-                    analytics_recorder=app.state.analytics_recorder,
-                    conversation_tracker=app.state.conversation_tracker,
-                    pool=app.state.pool,
-                )
+                ws_ctx = WebSocketContext(
+                    graph_ctx=app.state.graph_ctx,
+                    session_manager=app.state.session_manager,
+                    callback_handler=TokenUsageCallbackHandler(model_name="test-model"),
+                    interrupt_manager=app.state.interrupt_manager,
+                    analytics_recorder=app.state.analytics_recorder,
+                    conversation_tracker=app.state.conversation_tracker,
+                    pool=app.state.pool,
+                )
+                await dispatch_message(ws, ws_ctx, raw_data)
        except WebSocketDisconnect:
            pass


@@ -341,7 +341,7 @@ class TestChatEdgeCases:
    def test_health_endpoint(self) -> None:
        app = create_e2e_app()
        with TestClient(app) as client:
-            resp = client.get("/api/health")
+            resp = client.get("/api/v1/health")
            assert resp.status_code == 200
            assert resp.json()["status"] == "ok"


@@ -62,7 +62,7 @@ class TestFlow5OpenAPIImport:
        with TestClient(app) as client:
            # Step 1: Start import job
            resp = client.post(
-                "/api/openapi/import",
+                "/api/v1/openapi/import",
                json={"url": "https://api.example.com/openapi.json"},
            )
            assert resp.status_code == 202

@@ -71,7 +71,7 @@ class TestFlow5OpenAPIImport:

            job_id = body["job_id"]

            # Step 2: Check job status (still pending since background task hasn't run)
-            resp = client.get(f"/api/openapi/jobs/{job_id}")
+            resp = client.get(f"/api/v1/openapi/jobs/{job_id}")
            assert resp.status_code == 200
            assert resp.json()["job_id"] == job_id

@@ -99,7 +99,7 @@ class TestFlow5OpenAPIImport:

        with TestClient(app) as client:
            # Step 1: Get classifications
-            resp = client.get(f"/api/openapi/jobs/{job_id}/classifications")
+            resp = client.get(f"/api/v1/openapi/jobs/{job_id}/classifications")
            assert resp.status_code == 200
            classifications = resp.json()
            assert len(classifications) == 2

@@ -118,7 +118,7 @@ class TestFlow5OpenAPIImport:

            # Step 2: Update a classification
            resp = client.put(
-                f"/api/openapi/jobs/{job_id}/classifications/0",
+                f"/api/v1/openapi/jobs/{job_id}/classifications/0",
                json={
                    "access_type": "write",
                    "needs_interrupt": True,

@@ -132,7 +132,7 @@ class TestFlow5OpenAPIImport:

            assert updated["agent_group"] == "order_actions"

            # Step 3: Approve the job
-            resp = client.post(f"/api/openapi/jobs/{job_id}/approve")
+            resp = client.post(f"/api/v1/openapi/jobs/{job_id}/approve")
            assert resp.status_code == 200
            assert resp.json()["status"] == "approved"

@@ -140,14 +140,14 @@ class TestFlow5OpenAPIImport:

        app = create_e2e_app()
        with TestClient(app) as client:
-            resp = client.get("/api/openapi/jobs/nonexistent")
+            resp = client.get("/api/v1/openapi/jobs/nonexistent")
            assert resp.status_code == 404

    def test_import_invalid_url_returns_422(self) -> None:
        app = create_e2e_app()
        with TestClient(app) as client:
-            resp = client.post("/api/openapi/import", json={"url": "not-a-url"})
+            resp = client.post("/api/v1/openapi/import", json={"url": "not-a-url"})
            assert resp.status_code == 422

    def test_classification_index_out_of_range(self) -> None:

@@ -166,7 +166,7 @@ class TestFlow5OpenAPIImport:

        with TestClient(app) as client:
            resp = client.put(
-                f"/api/openapi/jobs/{job_id}/classifications/99",
+                f"/api/v1/openapi/jobs/{job_id}/classifications/99",
                json={
                    "access_type": "read",
                    "needs_interrupt": False,

@@ -191,7 +191,7 @@ class TestFlow5OpenAPIImport:

        with TestClient(app) as client:
            resp = client.put(
-                f"/api/openapi/jobs/{job_id}/classifications/0",
+                f"/api/v1/openapi/jobs/{job_id}/classifications/0",
                json={
                    "access_type": "read",
                    "needs_interrupt": False,


@@ -44,8 +44,16 @@ class ReplayPool(FakePool):
    async def execute(self, query: str, params=None):
        from tests.e2e.conftest import FakeCursor

+        if "COUNT" in query and "conversations" in query:
+            return FakeCursor([(len(self._convos),)])
        if "conversations" in query and "SELECT" in query:
-            return FakeCursor(self._convos)
+            # Respect LIMIT/OFFSET from params if provided
+            rows = self._convos
+            if params:
+                offset = params.get("offset", 0)
+                limit = params.get("limit", len(rows))
+                rows = rows[offset : offset + limit]
+            return FakeCursor(rows)
        if "checkpoints" in query:
            return FakeCursor(self._checkpoints)
        # Analytics queries

@@ -90,13 +98,15 @@ class TestFlow6ReplayConversation:

        app = create_e2e_app(pool=pool)
        with TestClient(app) as client:
-            resp = client.get("/api/conversations")
+            resp = client.get("/api/v1/conversations")
            assert resp.status_code == 200
            body = resp.json()
            assert body["success"] is True
-            assert len(body["data"]) == 2
-            assert body["data"][0]["thread_id"] == "conv-001"
-            assert body["data"][1]["thread_id"] == "conv-002"
+            data = body["data"]
+            assert len(data["conversations"]) == 2
+            assert data["conversations"][0]["thread_id"] == "conv-001"
+            assert data["conversations"][1]["thread_id"] == "conv-002"
+            assert data["total"] == 2

    def test_list_conversations_pagination(self) -> None:
        conversations = [

@@ -114,17 +124,22 @@ class TestFlow6ReplayConversation:

        app = create_e2e_app(pool=pool)
        with TestClient(app) as client:
-            resp = client.get("/api/conversations", params={"page": 1, "per_page": 2})
+            resp = client.get("/api/v1/conversations", params={"page": 1, "per_page": 2})
            assert resp.status_code == 200
            body = resp.json()
            assert body["success"] is True
+            data = body["data"]
+            assert data["total"] == 5
+            assert data["page"] == 1
+            assert data["per_page"] == 2
+            assert len(data["conversations"]) == 2

    def test_replay_thread_not_found(self) -> None:
        pool = ReplayPool(checkpoints=[])
        app = create_e2e_app(pool=pool)
        with TestClient(app) as client:
-            resp = client.get("/api/replay/nonexistent-thread")
+            resp = client.get("/api/v1/replay/nonexistent-thread")
            assert resp.status_code == 404

    def test_replay_invalid_thread_id_format(self) -> None:

@@ -132,7 +147,7 @@ class TestFlow6ReplayConversation:

        with TestClient(app) as client:
            # Thread ID with special chars fails regex validation
-            resp = client.get("/api/replay/invalid%20thread%21%40")
+            resp = client.get("/api/v1/replay/invalid%20thread%21%40")
            assert resp.status_code == 400

@@ -143,21 +158,21 @@ class TestAnalyticsDashboard:

        app = create_e2e_app()
        with TestClient(app) as client:
-            resp = client.get("/api/analytics", params={"range": "invalid"})
+            resp = client.get("/api/v1/analytics", params={"range": "invalid"})
            assert resp.status_code == 400

    def test_analytics_range_too_large(self) -> None:
        app = create_e2e_app()
        with TestClient(app) as client:
-            resp = client.get("/api/analytics", params={"range": "999d"})
+            resp = client.get("/api/v1/analytics", params={"range": "999d"})
            assert resp.status_code == 400

    def test_analytics_range_zero_rejected(self) -> None:
        app = create_e2e_app()
        with TestClient(app) as client:
-            resp = client.get("/api/analytics", params={"range": "0d"})
+            resp = client.get("/api/v1/analytics", params={"range": "0d"})
            assert resp.status_code == 400

@@ -201,14 +216,15 @@ class TestFullUserJourney:

        assert any(m["type"] == "message_complete" for m in messages)

        # Step 2: Check conversations endpoint
-        resp = client.get("/api/conversations")
+        resp = client.get("/api/v1/conversations")
        assert resp.status_code == 200
        body = resp.json()
        assert body["success"] is True
        assert any(
-            c["thread_id"] == "e2e-journey-1" for c in body["data"]
+            c["thread_id"] == "e2e-journey-1"
+            for c in body["data"]["conversations"]
        )

        # Step 3: Health check still works
-        resp = client.get("/api/health")
+        resp = client.get("/api/v1/health")
        assert resp.status_code == 200


@@ -0,0 +1,183 @@
"""Integration tests for the /api/v1/analytics endpoint.
Tests the full API layer (routing, parameter validation, serialization,
error handling) with a mocked database pool.
"""
from __future__ import annotations
from dataclasses import asdict
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from httpx import ASGITransport, AsyncClient
from app.analytics.models import AnalyticsResult, InterruptStats
pytestmark = pytest.mark.integration
_SAMPLE_RESULT = AnalyticsResult(
range="7d",
total_conversations=42,
resolution_rate=0.85,
escalation_rate=0.05,
avg_turns_per_conversation=3.2,
avg_cost_per_conversation_usd=0.012,
agent_usage=(),
interrupt_stats=InterruptStats(total=10, approved=7, rejected=2, expired=1),
)
def _build_app():
"""Build a minimal FastAPI app with the analytics router and mocked deps."""
from fastapi import FastAPI
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from app.analytics.api import router as analytics_router
from app.api_utils import envelope
test_app = FastAPI()
test_app.include_router(analytics_router)
@test_app.exception_handler(Exception)
async def _catch_all(request, exc):
return JSONResponse(
status_code=500,
content=envelope(None, success=False, error="Internal server error"),
)
from fastapi import HTTPException
@test_app.exception_handler(HTTPException)
async def _http_exc(request, exc):
return JSONResponse(
status_code=exc.status_code,
content=envelope(None, success=False, error=exc.detail),
)
@test_app.exception_handler(RequestValidationError)
async def _validation_exc(request, exc):
return JSONResponse(
status_code=422,
content=envelope(None, success=False, error=str(exc)),
)
# No admin_api_key set -> auth is skipped
test_app.state.settings = MagicMock(admin_api_key="")
test_app.state.pool = MagicMock()
return test_app
class TestAnalyticsValidRange:
"""Test analytics endpoint with valid range parameters."""
async def test_valid_range_7d_returns_envelope(self) -> None:
"""GET /api/v1/analytics?range=7d returns success envelope with data."""
test_app = _build_app()
with patch(
"app.analytics.api.get_analytics",
new_callable=AsyncMock,
return_value=_SAMPLE_RESULT,
):
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
resp = await client.get("/api/v1/analytics", params={"range": "7d"})
assert resp.status_code == 200
body = resp.json()
assert body["success"] is True
assert body["error"] is None
assert body["data"]["total_conversations"] == 42
assert body["data"]["resolution_rate"] == 0.85
async def test_default_range_returns_success(self) -> None:
"""GET /api/v1/analytics with no range param defaults to 7d."""
test_app = _build_app()
with patch(
"app.analytics.api.get_analytics",
new_callable=AsyncMock,
return_value=_SAMPLE_RESULT,
) as mock_get:
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
resp = await client.get("/api/v1/analytics")
assert resp.status_code == 200
            # Verify the default range of 7 days was passed (positionally or by keyword)
            mock_get.assert_called_once()
            args, kwargs = mock_get.call_args
            range_days = kwargs.get("range_days", args[1] if len(args) > 1 else None)
            assert range_days in (7, None)
async def test_large_range_365d_works(self) -> None:
"""GET /api/v1/analytics?range=365d is accepted (max boundary)."""
test_app = _build_app()
result = AnalyticsResult(
range="365d",
total_conversations=1000,
resolution_rate=0.9,
escalation_rate=0.02,
avg_turns_per_conversation=4.0,
avg_cost_per_conversation_usd=0.01,
agent_usage=(),
interrupt_stats=InterruptStats(),
)
with patch(
"app.analytics.api.get_analytics",
new_callable=AsyncMock,
return_value=result,
):
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
resp = await client.get("/api/v1/analytics", params={"range": "365d"})
assert resp.status_code == 200
assert resp.json()["success"] is True
class TestAnalyticsInvalidRange:
"""Test analytics endpoint with invalid range parameters."""
async def test_invalid_range_format_returns_400(self) -> None:
"""GET /api/v1/analytics?range=abc returns 400 error envelope."""
test_app = _build_app()
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
resp = await client.get("/api/v1/analytics", params={"range": "abc"})
assert resp.status_code == 400
body = resp.json()
assert body["success"] is False
assert body["data"] is None
assert "Invalid range format" in body["error"]
async def test_zero_day_range_returns_400(self) -> None:
"""GET /api/v1/analytics?range=0d returns 400 because 0 is below minimum."""
test_app = _build_app()
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
resp = await client.get("/api/v1/analytics", params={"range": "0d"})
assert resp.status_code == 400
body = resp.json()
assert body["success"] is False
assert "between 1 and 365" in body["error"]
async def test_range_exceeding_max_returns_400(self) -> None:
"""GET /api/v1/analytics?range=999d returns 400 because it exceeds 365."""
test_app = _build_app()
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
resp = await client.get("/api/v1/analytics", params={"range": "999d"})
assert resp.status_code == 400
body = resp.json()
assert body["success"] is False
assert "between 1 and 365" in body["error"]


@@ -0,0 +1,128 @@
"""Integration tests for global error handling and envelope format consistency.
Tests that all error responses from the FastAPI app conform to the
standard envelope: {"success": false, "data": null, "error": "..."}.
"""
from __future__ import annotations
from unittest.mock import MagicMock
import pytest
from httpx import ASGITransport, AsyncClient
pytestmark = pytest.mark.integration
def _build_app():
"""Build the actual FastAPI app with exception handlers but mocked state."""
from fastapi import FastAPI, HTTPException
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from app.analytics.api import router as analytics_router
from app.api_utils import envelope
from app.replay.api import router as replay_router
test_app = FastAPI()
test_app.include_router(analytics_router)
test_app.include_router(replay_router)
@test_app.exception_handler(HTTPException)
async def _http_exc(request, exc):
return JSONResponse(
status_code=exc.status_code,
content=envelope(None, success=False, error=exc.detail),
)
@test_app.exception_handler(RequestValidationError)
async def _validation_exc(request, exc):
return JSONResponse(
status_code=422,
content=envelope(None, success=False, error=str(exc)),
)
@test_app.exception_handler(Exception)
async def _catch_all(request, exc):
return JSONResponse(
status_code=500,
content=envelope(None, success=False, error="Internal server error"),
)
@test_app.get("/api/v1/health")
def health_check():
return {"status": "ok", "version": "0.6.0"}
test_app.state.settings = MagicMock(admin_api_key="")
test_app.state.pool = MagicMock()
return test_app
class TestEnvelopeFormat:
"""Tests that error responses consistently follow envelope format."""
async def test_http_400_produces_envelope(self) -> None:
"""A 400 error returns standard envelope with success=false."""
test_app = _build_app()
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
resp = await client.get("/api/v1/analytics", params={"range": "invalid"})
assert resp.status_code == 400
body = resp.json()
assert body["success"] is False
assert body["data"] is None
assert isinstance(body["error"], str)
assert len(body["error"]) > 0
async def test_validation_error_produces_422_envelope(self) -> None:
"""Invalid query param type returns 422 with envelope format."""
test_app = _build_app()
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
# page must be >= 1; passing 0 triggers validation error
resp = await client.get("/api/v1/conversations", params={"page": 0})
assert resp.status_code == 422
body = resp.json()
assert body["success"] is False
assert body["data"] is None
assert isinstance(body["error"], str)
async def test_all_error_fields_present(self) -> None:
"""Error envelope contains exactly success, data, and error keys."""
test_app = _build_app()
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
resp = await client.get("/api/v1/analytics", params={"range": "bad"})
body = resp.json()
assert set(body.keys()) == {"success", "data", "error"}
async def test_health_endpoint_returns_200(self) -> None:
"""Health check returns 200 with status ok."""
test_app = _build_app()
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
resp = await client.get("/api/v1/health")
assert resp.status_code == 200
body = resp.json()
assert body["status"] == "ok"
assert "version" in body
async def test_unknown_endpoint_returns_404(self) -> None:
"""Requesting a non-existent path returns 404."""
test_app = _build_app()
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
resp = await client.get("/api/v1/nonexistent-path")
# FastAPI returns 404 for unknown routes; may or may not be wrapped
assert resp.status_code == 404


@@ -0,0 +1,164 @@
"""Integration tests for /api/v1/openapi/ endpoints.
Tests the full API layer for the OpenAPI import review workflow,
including job creation, status retrieval, classification updates,
and approval triggering.
"""
from __future__ import annotations
from unittest.mock import MagicMock, patch
import pytest
from httpx import ASGITransport, AsyncClient
pytestmark = pytest.mark.integration
def _build_app():
"""Build a minimal FastAPI app with the openapi router and mocked deps."""
from fastapi import FastAPI, HTTPException
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from app.api_utils import envelope
from app.openapi.review_api import router as openapi_router
test_app = FastAPI()
test_app.include_router(openapi_router)
@test_app.exception_handler(HTTPException)
async def _http_exc(request, exc):
return JSONResponse(
status_code=exc.status_code,
content=envelope(None, success=False, error=exc.detail),
)
@test_app.exception_handler(RequestValidationError)
async def _validation_exc(request, exc):
return JSONResponse(
status_code=422,
content=envelope(None, success=False, error=str(exc)),
)
test_app.state.settings = MagicMock(admin_api_key="")
return test_app
@pytest.fixture(autouse=True)
def _clear_job_store():
"""Clear the in-memory job store between tests."""
from app.openapi.review_api import _job_store
_job_store.clear()
yield
_job_store.clear()
class TestImportEndpoint:
"""Tests for POST /api/v1/openapi/import."""
async def test_import_returns_202_with_job_id(self) -> None:
"""Starting an import returns 202 with a job_id."""
test_app = _build_app()
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
resp = await client.post(
"/api/v1/openapi/import",
json={"url": "https://example.com/api/spec.json"},
)
assert resp.status_code == 202
body = resp.json()
assert "job_id" in body
assert body["status"] == "pending"
assert body["spec_url"] == "https://example.com/api/spec.json"
async def test_import_invalid_url_returns_422(self) -> None:
"""POST with invalid URL (no http/https) returns 422."""
test_app = _build_app()
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
resp = await client.post(
"/api/v1/openapi/import",
json={"url": "ftp://example.com/spec.json"},
)
assert resp.status_code == 422
body = resp.json()
assert body["success"] is False
class TestJobStatusEndpoint:
"""Tests for GET /api/v1/openapi/jobs/{job_id}."""
async def test_get_existing_job_returns_status(self) -> None:
"""Retrieving an existing job returns its status."""
from app.openapi.review_api import _job_store
_job_store["test-job-1"] = {
"job_id": "test-job-1",
"status": "done",
"spec_url": "https://example.com/spec.json",
"total_endpoints": 5,
"classified_count": 5,
"error_message": None,
"classifications": [],
}
test_app = _build_app()
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
resp = await client.get("/api/v1/openapi/jobs/test-job-1")
assert resp.status_code == 200
body = resp.json()
assert body["job_id"] == "test-job-1"
assert body["status"] == "done"
assert body["total_endpoints"] == 5
async def test_get_unknown_job_returns_404(self) -> None:
"""Retrieving a non-existent job returns 404 error envelope."""
test_app = _build_app()
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
resp = await client.get("/api/v1/openapi/jobs/unknown-id-999")
assert resp.status_code == 404
body = resp.json()
assert body["success"] is False
assert "not found" in body["error"].lower()
class TestApproveEndpoint:
"""Tests for POST /api/v1/openapi/jobs/{job_id}/approve."""
async def test_approve_with_no_classifications_returns_400(self) -> None:
"""Approving a job with no classifications returns 400."""
from app.openapi.review_api import _job_store
_job_store["empty-job"] = {
"job_id": "empty-job",
"status": "done",
"spec_url": "https://example.com/spec.json",
"total_endpoints": 0,
"classified_count": 0,
"error_message": None,
"classifications": [],
}
test_app = _build_app()
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
resp = await client.post("/api/v1/openapi/jobs/empty-job/approve")
assert resp.status_code == 400
body = resp.json()
assert body["success"] is False
assert "no classifications" in body["error"].lower()

View File

@@ -20,10 +20,12 @@ import pytest
from app.callbacks import TokenUsageCallbackHandler
from app.escalation import EscalationPayload, NoOpEscalator, WebhookEscalator
+from app.graph_context import GraphContext
from app.intent import ClassificationResult, IntentTarget, LLMIntentClassifier
from app.interrupt_manager import InterruptManager
from app.registry import AgentConfig, AgentRegistry
from app.session_manager import SessionManager
+from app.ws_context import WebSocketContext
from app.ws_handler import dispatch_message
TEMPLATES_DIR = Path(__file__).parent.parent.parent / "templates"
@@ -128,10 +130,8 @@ class TestCheckpoint1OrderQueryRouting:
mock_classifier.classify = AsyncMock(return_value=ClassificationResult(
intents=(IntentTarget(agent_name="order_lookup", confidence=0.95, reasoning="query"),),
))
-graph.intent_classifier = mock_classifier
mock_registry = MagicMock()
mock_registry.list_agents = MagicMock(return_value=())
-graph.agent_registry = mock_registry
# Graph streams order_lookup response
graph.astream = MagicMock(return_value=AsyncIterHelper([
@@ -140,14 +140,21 @@
]))
graph.aget_state = AsyncMock(return_value=_state())
+graph_ctx = GraphContext(
+    graph=graph, registry=mock_registry, intent_classifier=mock_classifier,
+)
sm = SessionManager()
sm.touch("t1")
im = InterruptManager()
cb = TokenUsageCallbackHandler()
ws = FakeWS()
+ws_ctx = WebSocketContext(
+    graph_ctx=graph_ctx, session_manager=sm,
+    callback_handler=cb, interrupt_manager=im,
+)
raw = json.dumps({"type": "message", "thread_id": "t1", "content": "查询订单 1042"})
-await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im)
+await dispatch_message(ws, ws_ctx, raw)
tool_msgs = [m for m in ws.sent if m["type"] == "tool_call"]
assert any(m["tool"] == "get_order_status" for m in tool_msgs)
@@ -201,25 +208,30 @@ class TestCheckpoint2MultiIntentSequential:
IntentTarget(agent_name="discount", confidence=0.85, reasoning="discount"),
),
))
-graph.intent_classifier = mock_classifier
mock_registry = MagicMock()
mock_registry.list_agents = MagicMock(return_value=())
-graph.agent_registry = mock_registry
graph.astream = MagicMock(return_value=AsyncIterHelper([]))
graph.aget_state = AsyncMock(return_value=_state())
+graph_ctx = GraphContext(
+    graph=graph, registry=mock_registry, intent_classifier=mock_classifier,
+)
sm = SessionManager()
sm.touch("t1")
im = InterruptManager()
cb = TokenUsageCallbackHandler()
ws = FakeWS()
+ws_ctx = WebSocketContext(
+    graph_ctx=graph_ctx, session_manager=sm,
+    callback_handler=cb, interrupt_manager=im,
+)
raw = json.dumps({
"type": "message",
"thread_id": "t1",
"content": "取消订单 1042 并给我一个 10% 折扣",
})
-await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im)
+await dispatch_message(ws, ws_ctx, raw)
# Verify the graph was called with the routing hint in the message
call_args = graph.astream.call_args
@@ -267,21 +279,26 @@ class TestCheckpoint3AmbiguousClarification:
"Could you please provide more details about what you need help with?"
),
))
-graph.intent_classifier = mock_classifier
mock_registry = MagicMock()
mock_registry.list_agents = MagicMock(return_value=())
-graph.agent_registry = mock_registry
graph.astream = MagicMock(return_value=AsyncIterHelper([]))
graph.aget_state = AsyncMock(return_value=_state())
+graph_ctx = GraphContext(
+    graph=graph, registry=mock_registry, intent_classifier=mock_classifier,
+)
sm = SessionManager()
sm.touch("t1")
im = InterruptManager()
cb = TokenUsageCallbackHandler()
ws = FakeWS()
+ws_ctx = WebSocketContext(
+    graph_ctx=graph_ctx, session_manager=sm,
+    callback_handler=cb, interrupt_manager=im,
+)
raw = json.dumps({"type": "message", "thread_id": "t1", "content": "嗯..."})
-await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im)
+await dispatch_message(ws, ws_ctx, raw)
clarifications = [m for m in ws.sent if m["type"] == "clarification"]
assert len(clarifications) == 1
@@ -303,20 +320,26 @@ class TestCheckpoint4InterruptTTLAutoCancel:
async def test_30min_expired_interrupt_auto_cancels(self) -> None:
st = _state(interrupt=True, data={"action": "cancel_order", "order_id": "1042"})
graph = MagicMock()
-graph.intent_classifier = None
-graph.agent_registry = None
graph.astream = MagicMock(return_value=AsyncIterHelper([]))
graph.aget_state = AsyncMock(return_value=st)
+mock_registry = MagicMock()
+mock_registry.list_agents = MagicMock(return_value=())
+graph_ctx = GraphContext(graph=graph, registry=mock_registry, intent_classifier=None)
sm = SessionManager()
sm.touch("t1")
im = InterruptManager(ttl_seconds=1800)  # 30 minutes
cb = TokenUsageCallbackHandler()
ws = FakeWS()
+ws_ctx = WebSocketContext(
+    graph_ctx=graph_ctx, session_manager=sm,
+    callback_handler=cb, interrupt_manager=im,
+)
# Trigger interrupt
raw = json.dumps({"type": "message", "thread_id": "t1", "content": "Cancel order 1042"})
-await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im)
+await dispatch_message(ws, ws_ctx, raw)
interrupts = [m for m in ws.sent if m["type"] == "interrupt"]
assert len(interrupts) == 1
@@ -333,7 +356,7 @@ class TestCheckpoint4InterruptTTLAutoCancel:
"thread_id": "t1",
"approved": True,
})
-await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im)
+await dispatch_message(ws, ws_ctx, raw)
# Should get retry prompt, NOT resume the graph
expired_msgs = [m for m in ws.sent if m["type"] == "interrupt_expired"]
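
For orientation, the hunks above construct two context objects in place of the removed monkey-patched attributes. A minimal sketch of what app/graph_context.py and app/ws_context.py presumably define, inferred from the keyword arguments used in these tests (field types are assumptions):

from dataclasses import dataclass
from typing import Any

@dataclass(frozen=True)
class GraphContext:
    graph: Any                     # compiled LangGraph graph
    registry: Any                  # AgentRegistry
    intent_classifier: Any | None  # LLMIntentClassifier, or None for supervisor routing

@dataclass(frozen=True)
class WebSocketContext:
    graph_ctx: GraphContext
    session_manager: Any           # SessionManager
    callback_handler: Any          # TokenUsageCallbackHandler
    interrupt_manager: Any         # InterruptManager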

View File

@@ -0,0 +1,213 @@
"""Integration tests for /api/v1/conversations and /api/v1/replay/{thread_id}.
Tests the full API layer with a mocked database pool, verifying routing,
serialization, pagination, and error handling in envelope format.
"""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock
import pytest
from httpx import ASGITransport, AsyncClient
pytestmark = pytest.mark.integration
def _make_fake_cursor(rows, *, fetchone_value=None):
"""Build a fake async cursor returning the given rows on fetchall."""
cursor = AsyncMock()
cursor.fetchall = AsyncMock(return_value=rows)
if fetchone_value is not None:
cursor.fetchone = AsyncMock(return_value=fetchone_value)
return cursor
class _FakeConnection:
"""Fake async connection that returns pre-configured cursors in order."""
def __init__(self, cursors: list) -> None:
self._cursors = list(cursors)
self._idx = 0
async def execute(self, sql, params=None):
cursor = self._cursors[self._idx]
self._idx += 1
return cursor
async def __aenter__(self):
return self
async def __aexit__(self, *args):
pass
class _FakePool:
"""Fake connection pool that yields a fake connection."""
def __init__(self, conn: _FakeConnection) -> None:
self._conn = conn
def connection(self):
return self._conn
def _build_app(pool=None):
"""Build a minimal FastAPI app with the replay router and mocked deps."""
from fastapi import FastAPI, HTTPException
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from app.api_utils import envelope
from app.replay.api import router as replay_router
test_app = FastAPI()
test_app.include_router(replay_router)
@test_app.exception_handler(HTTPException)
async def _http_exc(request, exc):
return JSONResponse(
status_code=exc.status_code,
content=envelope(None, success=False, error=exc.detail),
)
@test_app.exception_handler(RequestValidationError)
async def _validation_exc(request, exc):
return JSONResponse(
status_code=422,
content=envelope(None, success=False, error=str(exc)),
)
test_app.state.settings = MagicMock(admin_api_key="")
test_app.state.pool = pool or MagicMock()
return test_app
class TestListConversations:
"""Tests for GET /api/v1/conversations endpoint."""
async def test_returns_paginated_envelope(self) -> None:
"""Conversations list returns envelope with pagination metadata."""
count_cursor = _make_fake_cursor([], fetchone_value=(3,))
rows = [
{"thread_id": "t1", "created_at": "2026-01-01", "last_activity": "2026-01-01",
"status": "active", "total_tokens": 100, "total_cost_usd": 0.01},
{"thread_id": "t2", "created_at": "2026-01-02", "last_activity": "2026-01-02",
"status": "resolved", "total_tokens": 200, "total_cost_usd": 0.02},
]
list_cursor = _make_fake_cursor(rows)
conn = _FakeConnection([count_cursor, list_cursor])
pool = _FakePool(conn)
test_app = _build_app(pool)
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
resp = await client.get("/api/v1/conversations")
assert resp.status_code == 200
body = resp.json()
assert body["success"] is True
assert body["data"]["total"] == 3
assert len(body["data"]["conversations"]) == 2
assert body["data"]["page"] == 1
assert body["data"]["per_page"] == 20
async def test_custom_page_and_per_page(self) -> None:
"""Custom page/per_page params are reflected in the response."""
count_cursor = _make_fake_cursor([], fetchone_value=(50,))
list_cursor = _make_fake_cursor([])
conn = _FakeConnection([count_cursor, list_cursor])
pool = _FakePool(conn)
test_app = _build_app(pool)
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
resp = await client.get("/api/v1/conversations", params={"page": 3, "per_page": 10})
assert resp.status_code == 200
body = resp.json()
assert body["data"]["page"] == 3
assert body["data"]["per_page"] == 10
async def test_invalid_page_returns_422(self) -> None:
"""page=0 violates ge=1 constraint and returns 422 error envelope."""
test_app = _build_app()
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
resp = await client.get("/api/v1/conversations", params={"page": 0})
assert resp.status_code == 422
body = resp.json()
assert body["success"] is False
class TestReplayEndpoint:
"""Tests for GET /api/v1/replay/{thread_id} endpoint."""
async def test_valid_thread_returns_timeline(self) -> None:
"""Replay with valid thread_id returns steps in envelope format."""
checkpoint_rows = [
{
"thread_id": "abc123",
"checkpoint_id": "cp1",
"checkpoint": {
"channel_values": {
"messages": [
{"type": "human", "content": "Hello", "created_at": "2026-01-01T00:00:00Z"},
{"type": "ai", "content": "Hi there!", "created_at": "2026-01-01T00:00:01Z"},
]
}
},
"metadata": {},
}
]
cursor = _make_fake_cursor(checkpoint_rows)
conn = _FakeConnection([cursor])
pool = _FakePool(conn)
test_app = _build_app(pool)
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
resp = await client.get("/api/v1/replay/abc123")
assert resp.status_code == 200
body = resp.json()
assert body["success"] is True
assert body["data"]["thread_id"] == "abc123"
assert body["data"]["total_steps"] == 2
assert len(body["data"]["steps"]) == 2
assert body["data"]["steps"][0]["type"] == "user_message"
assert body["data"]["steps"][1]["type"] == "agent_response"
async def test_invalid_thread_id_format_returns_400(self) -> None:
"""Thread IDs with path traversal characters are rejected with 400."""
test_app = _build_app()
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
resp = await client.get("/api/v1/replay/../../etc/passwd")
# Depending on how the path resolves, this may be 400 from our handler,
# 404 from routing, or 422 from validation
assert resp.status_code in (400, 404, 422)
async def test_nonexistent_thread_returns_404(self) -> None:
"""Replay with a thread_id that has no checkpoints returns 404."""
cursor = _make_fake_cursor([])
conn = _FakeConnection([cursor])
pool = _FakePool(conn)
test_app = _build_app(pool)
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
resp = await client.get("/api/v1/replay/nonexistent-thread")
assert resp.status_code == 404
body = resp.json()
assert body["success"] is False
assert "not found" in body["error"].lower()

View File

@@ -18,10 +18,12 @@ from unittest.mock import AsyncMock, MagicMock
import pytest
from app.callbacks import TokenUsageCallbackHandler
+from app.graph_context import GraphContext
from app.intent import ClassificationResult, IntentTarget, LLMIntentClassifier
from app.interrupt_manager import InterruptManager
from app.registry import AgentConfig
from app.session_manager import SessionManager
+from app.ws_context import WebSocketContext
from app.ws_handler import dispatch_message
# ---------------------------------------------------------------------------
@@ -103,36 +105,45 @@ def _make_classifier(result: ClassificationResult) -> AsyncMock:
return classifier
-def _make_graph(
+def _make_graph_and_ctx(
classifier_result: ClassificationResult | None,
chunks: list,
state=None,
-) -> MagicMock:
-"""Build a graph mock with optional intent classifier."""
+) -> tuple[MagicMock, GraphContext]:
+"""Build a graph mock and GraphContext with optional intent classifier."""
graph = MagicMock()
-if classifier_result is not None:
-graph.intent_classifier = _make_classifier(classifier_result)
-mock_registry = MagicMock()
-mock_registry.list_agents = MagicMock(return_value=AGENTS)
-graph.agent_registry = mock_registry
-else:
-graph.intent_classifier = None
-graph.agent_registry = None
graph.astream = MagicMock(return_value=AsyncIterHelper(list(chunks)))
graph.aget_state = AsyncMock(return_value=state or _state())
-return graph
+if classifier_result is not None:
+classifier = _make_classifier(classifier_result)
+mock_registry = MagicMock()
+mock_registry.list_agents = MagicMock(return_value=AGENTS)
+graph_ctx = GraphContext(
+    graph=graph, registry=mock_registry, intent_classifier=classifier,
+)
+else:
+mock_registry = MagicMock()
+mock_registry.list_agents = MagicMock(return_value=())
+graph_ctx = GraphContext(
+    graph=graph, registry=mock_registry, intent_classifier=None,
+)
+return graph, graph_ctx
-async def _dispatch(graph, content: str, thread_id: str = "t1") -> list[dict]:
+async def _dispatch(graph_ctx: GraphContext, content: str, thread_id: str = "t1") -> list[dict]:
sm = SessionManager()
sm.touch(thread_id)
im = InterruptManager()
cb = TokenUsageCallbackHandler()
ws = FakeWS()
+ws_ctx = WebSocketContext(
+    graph_ctx=graph_ctx, session_manager=sm,
+    callback_handler=cb, interrupt_manager=im,
+)
raw = json.dumps({"type": "message", "thread_id": thread_id, "content": content})
-await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im)
+await dispatch_message(ws, ws_ctx, raw)
return ws.sent
@@ -151,12 +162,12 @@
agent_name="order_lookup", confidence=0.95, reasoning="status query",
),),
)
-graph = _make_graph(result, [
+graph, graph_ctx = _make_graph_and_ctx(result, [
_tool_chunk("get_order_status", {"order_id": "1042"}, "order_lookup"),
_chunk("Order 1042 is shipped.", "order_lookup"),
])
-msgs = await _dispatch(graph, "What is the status of order 1042?")
+msgs = await _dispatch(graph_ctx, "What is the status of order 1042?")
tools = [m for m in msgs if m["type"] == "tool_call"]
assert len(tools) == 1
@@ -171,13 +182,13 @@
result = ClassificationResult(
intents=(IntentTarget(agent_name="order_actions", confidence=0.9, reasoning="cancel"),),
)
-graph = _make_graph(
+graph, graph_ctx = _make_graph_and_ctx(
result,
[_tool_chunk("cancel_order", {"order_id": "1042"}, "order_actions")],
state=_state(interrupt=True, data={"action": "cancel_order", "order_id": "1042"}),
)
-msgs = await _dispatch(graph, "Cancel order 1042")
+msgs = await _dispatch(graph_ctx, "Cancel order 1042")
tools = [m for m in msgs if m["type"] == "tool_call"]
assert tools[0]["tool"] == "cancel_order"
@@ -191,12 +202,12 @@
result = ClassificationResult(
intents=(IntentTarget(agent_name="discount", confidence=0.88, reasoning="coupon"),),
)
-graph = _make_graph(result, [
+graph, graph_ctx = _make_graph_and_ctx(result, [
_tool_chunk("generate_coupon", {"discount_percent": 15, "expiry_days": 30}, "discount"),
_chunk("Here is your coupon: SAVE15-ABC12345", "discount"),
])
-msgs = await _dispatch(graph, "Give me a 15% coupon")
+msgs = await _dispatch(graph_ctx, "Give me a 15% coupon")
tools = [m for m in msgs if m["type"] == "tool_call"]
assert tools[0]["tool"] == "generate_coupon"
@@ -207,11 +218,11 @@
result = ClassificationResult(
intents=(IntentTarget(agent_name="fallback", confidence=0.7, reasoning="general"),),
)
-graph = _make_graph(result, [
+graph, graph_ctx = _make_graph_and_ctx(result, [
_chunk("I can help with order inquiries.", "fallback"),
])
-msgs = await _dispatch(graph, "What can you do?")
+msgs = await _dispatch(graph_ctx, "What can you do?")
tokens = [m for m in msgs if m["type"] == "token"]
assert tokens[0]["agent"] == "fallback"
@@ -233,7 +244,7 @@
IntentTarget(agent_name="discount", confidence=0.85, reasoning="discount"),
),
)
-graph = _make_graph(result, [
+graph, graph_ctx = _make_graph_and_ctx(result, [
_tool_chunk("cancel_order", {"order_id": "1042"}, "order_actions"),
_tool_chunk("apply_discount", {"order_id": "1042", "discount_percent": 10}, "discount"),
])
@@ -243,13 +254,17 @@
im = InterruptManager()
cb = TokenUsageCallbackHandler()
ws = FakeWS()
+ws_ctx = WebSocketContext(
+    graph_ctx=graph_ctx, session_manager=sm,
+    callback_handler=cb, interrupt_manager=im,
+)
raw = json.dumps({
"type": "message",
"thread_id": "t1",
"content": "取消订单 1042 并给我一个 10% 折扣",
})
-await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im)
+await dispatch_message(ws, ws_ctx, raw)
# Verify routing hint was injected
call_args = graph.astream.call_args[0][0]
@@ -269,16 +284,20 @@
result = ClassificationResult(
intents=(IntentTarget(agent_name="order_lookup", confidence=0.95, reasoning="query"),),
)
-graph = _make_graph(result, [_chunk("Order shipped.", "order_lookup")])
+graph, graph_ctx = _make_graph_and_ctx(result, [_chunk("Order shipped.", "order_lookup")])
sm = SessionManager()
sm.touch("t1")
im = InterruptManager()
cb = TokenUsageCallbackHandler()
ws = FakeWS()
+ws_ctx = WebSocketContext(
+    graph_ctx=graph_ctx, session_manager=sm,
+    callback_handler=cb, interrupt_manager=im,
+)
raw = json.dumps({"type": "message", "thread_id": "t1", "content": "查询订单 1042"})
-await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im)
+await dispatch_message(ws, ws_ctx, raw)
msg_content = graph.astream.call_args[0][0]["messages"][0].content
assert "[System:" not in msg_content
@@ -299,9 +318,9 @@
is_ambiguous=True,
clarification_question="Could you please clarify what you need?",
)
-graph = _make_graph(result, [])
+graph, graph_ctx = _make_graph_and_ctx(result, [])
-msgs = await _dispatch(graph, "嗯...")
+msgs = await _dispatch(graph_ctx, "嗯...")
clarifications = [m for m in msgs if m["type"] == "clarification"]
assert len(clarifications) == 1
@@ -339,12 +358,12 @@
@pytest.mark.asyncio
async def test_no_classifier_routes_via_supervisor(self) -> None:
-graph = _make_graph(
+graph, graph_ctx = _make_graph_and_ctx(
classifier_result=None,
chunks=[_chunk("Order 1042 is shipped.", "order_lookup")],
)
-msgs = await _dispatch(graph, "What is order 1042 status?")
+msgs = await _dispatch(graph_ctx, "What is order 1042 status?")
tokens = [m for m in msgs if m["type"] == "token"]
assert len(tokens) == 1

View File

@@ -0,0 +1,159 @@
"""Integration tests for SessionManager + InterruptManager lifecycle.
These tests exercise the in-memory managers together, verifying the full
lifecycle of sessions and interrupts: creation, TTL sliding, interrupt
registration/resolution, and expired-interrupt cleanup.
No database required -- both managers are in-memory.
"""
from __future__ import annotations
import time
from unittest.mock import patch
import pytest
from app.interrupt_manager import InterruptManager
from app.session_manager import SessionManager
pytestmark = pytest.mark.integration
class TestSessionInterruptLifecycle:
"""Tests for the combined session + interrupt lifecycle."""
def test_create_session_register_interrupt_check_status(self) -> None:
"""Full lifecycle: create session, register interrupt, verify both states."""
sm = SessionManager(session_ttl_seconds=3600)
im = InterruptManager(ttl_seconds=300)
# Create a session
state = sm.touch("thread-1")
assert state.thread_id == "thread-1"
assert not state.has_pending_interrupt
assert not sm.is_expired("thread-1")
# Register an interrupt
record = im.register("thread-1", "cancel_order", {"order_id": "1042"})
sm.extend_for_interrupt("thread-1")
assert im.has_pending("thread-1")
session_state = sm.get_state("thread-1")
assert session_state is not None
assert session_state.has_pending_interrupt
# Session should not expire while interrupt is pending
assert not sm.is_expired("thread-1")
def test_interrupt_expiry_after_ttl(self) -> None:
"""Interrupt expires when TTL elapses, even if session is alive."""
im = InterruptManager(ttl_seconds=5)
record = im.register("thread-2", "refund", {"amount": 50})
assert im.has_pending("thread-2")
# Simulate time passing beyond TTL
with patch("app.interrupt_manager.time") as mock_time:
mock_time.time.return_value = record.created_at + 10
assert not im.has_pending("thread-2")
status = im.check_status("thread-2")
assert status is not None
assert status.is_expired
assert status.remaining_seconds == 0.0
def test_interrupt_resolve_flow(self) -> None:
"""Resolving an interrupt removes it from pending and resets session."""
sm = SessionManager(session_ttl_seconds=3600)
im = InterruptManager(ttl_seconds=300)
sm.touch("thread-3")
im.register("thread-3", "delete_account", {"user_id": "u1"})
sm.extend_for_interrupt("thread-3")
# Verify pending state
assert im.has_pending("thread-3")
assert sm.get_state("thread-3").has_pending_interrupt
# Resolve
im.resolve("thread-3")
sm.resolve_interrupt("thread-3")
assert not im.has_pending("thread-3")
session_state = sm.get_state("thread-3")
assert session_state is not None
assert not session_state.has_pending_interrupt
def test_cleanup_expired_removes_old_interrupts(self) -> None:
"""cleanup_expired removes only expired interrupts, keeping active ones."""
im = InterruptManager(ttl_seconds=10)
# Register two interrupts at different times
old_record = im.register("thread-old", "action_old", {})
new_record = im.register("thread-new", "action_new", {})
# Simulate time where only old one expired
with patch("app.interrupt_manager.time") as mock_time:
# Move old record's creation to the past
im._interrupts["thread-old"] = old_record.__class__(
interrupt_id=old_record.interrupt_id,
thread_id=old_record.thread_id,
action=old_record.action,
params=old_record.params,
created_at=time.time() - 20,
ttl_seconds=old_record.ttl_seconds,
)
mock_time.time.return_value = time.time()
expired = im.cleanup_expired()
assert len(expired) == 1
assert expired[0].thread_id == "thread-old"
# New one should still be pending
assert im.has_pending("thread-new")
assert not im.has_pending("thread-old")
def test_session_ttl_sliding_window(self) -> None:
"""Touching a session resets the sliding window TTL."""
sm = SessionManager(session_ttl_seconds=3600)
state1 = sm.touch("thread-5")
first_activity = state1.last_activity
time.sleep(0.01)
state2 = sm.touch("thread-5")
second_activity = state2.last_activity
assert second_activity > first_activity
assert not sm.is_expired("thread-5")
def test_session_expires_after_ttl_without_activity(self) -> None:
"""Session expires when TTL passes without a touch or interrupt."""
sm = SessionManager(session_ttl_seconds=0)
sm.touch("thread-6")
# TTL is 0 so session is immediately expired
assert sm.is_expired("thread-6")
def test_pending_interrupt_prevents_session_expiry(self) -> None:
"""A session with pending interrupt does not expire even with TTL=0."""
sm = SessionManager(session_ttl_seconds=0)
sm.touch("thread-7")
sm.extend_for_interrupt("thread-7")
# Even with TTL=0, session should not expire because of pending interrupt
assert not sm.is_expired("thread-7")
def test_retry_prompt_for_expired_interrupt(self) -> None:
"""InterruptManager generates a retry prompt for expired interrupts."""
im = InterruptManager(ttl_seconds=300)
record = im.register("thread-8", "cancel_order", {"order_id": "1042"})
prompt = im.generate_retry_prompt(record)
assert prompt["type"] == "interrupt_expired"
assert prompt["thread_id"] == "thread-8"
assert "cancel_order" in prompt["action"]
assert "cancel_order" in prompt["message"]
assert "expired" in prompt["message"].lower()

View File

@@ -15,8 +15,10 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from app.callbacks import TokenUsageCallbackHandler
+from app.graph_context import GraphContext
from app.interrupt_manager import InterruptManager
from app.session_manager import SessionManager
+from app.ws_context import WebSocketContext
from app.ws_handler import dispatch_message
# ---------------------------------------------------------------------------
@@ -81,8 +83,6 @@ def _graph(
resume_chunks: list | None = None,
) -> MagicMock:
g = MagicMock()
-g.intent_classifier = None
-g.agent_registry = None
if st is None:
st = _state()
@@ -100,6 +100,13 @@
return g
+def _make_graph_ctx(graph: MagicMock | None = None) -> GraphContext:
+g = graph or _graph()
+registry = MagicMock()
+registry.list_agents = MagicMock(return_value=())
+return GraphContext(graph=g, registry=registry, intent_classifier=None)
def _setup(
graph=None,
session_ttl: int = 1800,
@@ -109,23 +116,28 @@
):
"""Create test dependencies. Pre-touches session by default."""
g = graph or _graph()
+graph_ctx = _make_graph_ctx(g)
sm = SessionManager(session_ttl_seconds=session_ttl)
im = InterruptManager(ttl_seconds=interrupt_ttl)
cb = TokenUsageCallbackHandler()
ws = FakeWS()
+ws_ctx = WebSocketContext(
+    graph_ctx=graph_ctx, session_manager=sm,
+    callback_handler=cb, interrupt_manager=im,
+)
if touch:
sm.touch(thread_id)
-return g, sm, im, cb, ws
+return g, sm, im, cb, ws, ws_ctx
-async def _send(ws, g, sm, im, cb, *, thread_id="t1", content="hello", msg_type="message"):
+async def _send(ws, ws_ctx, *, thread_id="t1", content="hello", msg_type="message"):
raw = json.dumps({"type": msg_type, "thread_id": thread_id, "content": content})
-await dispatch_message(ws, g, sm, cb, raw, interrupt_manager=im)
+await dispatch_message(ws, ws_ctx, raw)
-async def _respond(ws, g, sm, im, cb, *, thread_id="t1", approved=True):
+async def _respond(ws, ws_ctx, *, thread_id="t1", approved=True):
raw = json.dumps({"type": "interrupt_response", "thread_id": thread_id, "approved": approved})
-await dispatch_message(ws, g, sm, cb, raw, interrupt_manager=im)
+await dispatch_message(ws, ws_ctx, raw)
# ---------------------------------------------------------------------------
@@ -136,10 +148,10 @@
class TestWebSocketHappyPath:
@pytest.mark.asyncio
async def test_send_message_receives_tokens_and_complete(self) -> None:
-g, sm, im, cb, ws = _setup(
+g, sm, im, cb, ws, ws_ctx = _setup(
graph=_graph(chunks=[_chunk("Order 1042 is "), _chunk("shipped.")])
)
-await _send(ws, g, sm, im, cb, content="What is the status of order 1042?")
+await _send(ws, ws_ctx, content="What is the status of order 1042?")
tokens = [m for m in ws.sent if m["type"] == "token"]
assert len(tokens) == 2
@@ -153,13 +165,13 @@
@pytest.mark.asyncio
async def test_tool_call_streamed(self) -> None:
-g, sm, im, cb, ws = _setup(
+g, sm, im, cb, ws, ws_ctx = _setup(
graph=_graph(chunks=[
_tool_chunk("get_order_status", {"order_id": "1042"}),
_chunk("Order shipped."),
])
)
-await _send(ws, g, sm, im, cb, content="Check order 1042")
+await _send(ws, ws_ctx, content="Check order 1042")
tools = [m for m in ws.sent if m["type"] == "tool_call"]
assert len(tools) == 1
@@ -168,9 +180,9 @@
@pytest.mark.asyncio
async def test_multiple_messages_same_session(self) -> None:
-g, sm, im, cb, ws = _setup()
+g, sm, im, cb, ws, ws_ctx = _setup()
for i in range(3):
-await _send(ws, g, sm, im, cb, content=f"msg {i}")
+await _send(ws, ws_ctx, content=f"msg {i}")
completes = [m for m in ws.sent if m["type"] == "message_complete"]
assert len(completes) == 3
@@ -183,10 +195,10 @@
st_int = _state(interrupt=True, data={"action": "cancel_order", "order_id": "1042"})
resume = [_chunk("Order 1042 cancelled.", "order_actions")]
g = _graph(chunks=[], st=st_int, resume_chunks=resume)
-g_, sm, im, cb, ws = _setup(graph=g)
+g_, sm, im, cb, ws, ws_ctx = _setup(graph=g)
# Send message -> triggers interrupt
-await _send(ws, g_, sm, im, cb, content="Cancel order 1042")
+await _send(ws, ws_ctx, content="Cancel order 1042")
interrupts = [m for m in ws.sent if m["type"] == "interrupt"]
assert len(interrupts) == 1
@@ -196,7 +208,7 @@
# Approve
ws.sent.clear()
-await _respond(ws, g_, sm, im, cb, approved=True)
+await _respond(ws, ws_ctx, approved=True)
tokens = [m for m in ws.sent if m["type"] == "token"]
assert len(tokens) == 1
@@ -211,12 +223,12 @@
st_int = _state(interrupt=True)
resume = [_chunk("Order remains active.", "order_actions")]
g = _graph(chunks=[], st=st_int, resume_chunks=resume)
-g_, sm, im, cb, ws = _setup(graph=g)
+g_, sm, im, cb, ws, ws_ctx = _setup(graph=g)
-await _send(ws, g_, sm, im, cb, content="Cancel order 1042")
+await _send(ws, ws_ctx, content="Cancel order 1042")
ws.sent.clear()
-await _respond(ws, g_, sm, im, cb, approved=False)
+await _respond(ws, ws_ctx, approved=False)
tokens = [m for m in ws.sent if m["type"] == "token"]
assert "remains active" in tokens[0]["content"]
@@ -226,28 +238,28 @@
class TestWebSocketSessionTTL:
@pytest.mark.asyncio
async def test_expired_session_returns_error(self) -> None:
-g, sm, im, cb, ws = _setup(session_ttl=0)
+g, sm, im, cb, ws, ws_ctx = _setup(session_ttl=0)
# Session was touched in _setup, but TTL is 0 so it's already expired
-await _send(ws, g, sm, im, cb, content="hello")
+await _send(ws, ws_ctx, content="hello")
assert ws.sent[0]["type"] == "error"
assert "expired" in ws.sent[0]["message"].lower()
@pytest.mark.asyncio
async def test_new_session_not_expired(self) -> None:
-g, sm, im, cb, ws = _setup(session_ttl=3600)
+g, sm, im, cb, ws, ws_ctx = _setup(session_ttl=3600)
-await _send(ws, g, sm, im, cb, content="hello")
+await _send(ws, ws_ctx, content="hello")
completes = [m for m in ws.sent if m["type"] == "message_complete"]
assert len(completes) == 1
@pytest.mark.asyncio
async def test_sliding_window_resets_on_message(self) -> None:
-g, sm, im, cb, ws = _setup(session_ttl=3600)
+g, sm, im, cb, ws, ws_ctx = _setup(session_ttl=3600)
-await _send(ws, g, sm, im, cb, content="hello")
+await _send(ws, ws_ctx, content="hello")
first_activity = sm.get_state("t1").last_activity
time.sleep(0.01)
-await _send(ws, g, sm, im, cb, content="hello again")
+await _send(ws, ws_ctx, content="hello again")
second_activity = sm.get_state("t1").last_activity
assert second_activity > first_activity
@@ -256,9 +268,9 @@
async def test_interrupt_extends_session_ttl(self) -> None:
st_int = _state(interrupt=True)
g = _graph(chunks=[], st=st_int)
-g_, sm, im, cb, ws = _setup(graph=g, session_ttl=3600)
+g_, sm, im, cb, ws, ws_ctx = _setup(graph=g, session_ttl=3600)
-await _send(ws, g_, sm, im, cb, content="cancel order")
+await _send(ws, ws_ctx, content="cancel order")
state = sm.get_state("t1")
assert state is not None
@@ -270,53 +282,53 @@
class TestWebSocketValidation:
@pytest.mark.asyncio
async def test_invalid_json(self) -> None:
-g, sm, im, cb, ws = _setup()
+g, sm, im, cb, ws, ws_ctx = _setup()
-await dispatch_message(ws, g, sm, cb, "not json", interrupt_manager=im)
+await dispatch_message(ws, ws_ctx, "not json")
assert ws.sent[0]["type"] == "error"
assert "Invalid JSON" in ws.sent[0]["message"]
@pytest.mark.asyncio
async def test_missing_thread_id(self) -> None:
-g, sm, im, cb, ws = _setup()
+g, sm, im, cb, ws, ws_ctx = _setup()
raw = json.dumps({"type": "message", "content": "hi"})
-await dispatch_message(ws, g, sm, cb, raw, interrupt_manager=im)
+await dispatch_message(ws, ws_ctx, raw)
assert ws.sent[0]["type"] == "error"
assert "thread_id" in ws.sent[0]["message"]
@pytest.mark.asyncio
async def test_invalid_thread_id_format(self) -> None:
-g, sm, im, cb, ws = _setup()
+g, sm, im, cb, ws, ws_ctx = _setup()
raw = json.dumps({"type": "message", "thread_id": "../../etc/passwd", "content": "hi"})
-await dispatch_message(ws, g, sm, cb, raw, interrupt_manager=im)
+await dispatch_message(ws, ws_ctx, raw)
assert ws.sent[0]["type"] == "error"
@pytest.mark.asyncio
async def test_missing_content(self) -> None:
-g, sm, im, cb, ws = _setup()
+g, sm, im, cb, ws, ws_ctx = _setup()
raw = json.dumps({"type": "message", "thread_id": "t1"})
-await dispatch_message(ws, g, sm, cb, raw, interrupt_manager=im)
+await dispatch_message(ws, ws_ctx, raw)
assert ws.sent[0]["type"] == "error"
@pytest.mark.asyncio
async def test_unknown_message_type(self) -> None:
-g, sm, im, cb, ws = _setup()
+g, sm, im, cb, ws, ws_ctx = _setup()
raw = json.dumps({"type": "foobar", "thread_id": "t1"})
-await dispatch_message(ws, g, sm, cb, raw, interrupt_manager=im)
+await dispatch_message(ws, ws_ctx, raw)
assert ws.sent[0]["type"] == "error"
assert "Unknown" in ws.sent[0]["message"]
@pytest.mark.asyncio
async def test_message_too_large(self) -> None:
-g, sm, im, cb, ws = _setup()
+g, sm, im, cb, ws, ws_ctx = _setup()
-await dispatch_message(ws, g, sm, cb, "x" * 40_000, interrupt_manager=im)
+await dispatch_message(ws, ws_ctx, "x" * 40_000)
assert ws.sent[0]["type"] == "error"
assert "too large" in ws.sent[0]["message"].lower()
@pytest.mark.asyncio
async def test_content_too_long(self) -> None:
-g, sm, im, cb, ws = _setup()
+g, sm, im, cb, ws, ws_ctx = _setup()
raw = json.dumps({"type": "message", "thread_id": "t1", "content": "x" * 10001})
-await dispatch_message(ws, g, sm, cb, raw, interrupt_manager=im)
+await dispatch_message(ws, ws_ctx, raw)
assert ws.sent[0]["type"] == "error"
assert "too long" in ws.sent[0]["message"].lower()
@@ -327,10 +339,10 @@
async def test_expired_interrupt_sends_retry_prompt(self) -> None:
st_int = _state(interrupt=True, data={"action": "cancel_order", "order_id": "1042"})
g = _graph(chunks=[], st=st_int)
-g_, sm, im, cb, ws = _setup(graph=g, interrupt_ttl=5)
+g_, sm, im, cb, ws, ws_ctx = _setup(graph=g, interrupt_ttl=5)
# Trigger interrupt
-await _send(ws, g_, sm, im, cb, content="Cancel order 1042")
+await _send(ws, ws_ctx, content="Cancel order 1042")
interrupts = [m for m in ws.sent if m["type"] == "interrupt"]
assert len(interrupts) == 1
@@ -341,7 +353,7 @@
with patch("app.interrupt_manager.time") as mock_time:
mock_time.time.return_value = record.created_at + 10
-await _respond(ws, g_, sm, im, cb, approved=True)
+await _respond(ws, ws_ctx, approved=True)
assert ws.sent[0]["type"] == "interrupt_expired"
assert "cancel_order" in ws.sent[0]["message"]

View File

@@ -44,7 +44,7 @@ def _make_analytics_result() -> object:
)
-def _get_analytics(app: FastAPI, path: str = "/api/analytics", **patch_kwargs: object) -> object:
+def _get_analytics(app: FastAPI, path: str = "/api/v1/analytics", **patch_kwargs: object) -> object:
"""Helper: patch get_analytics, make request, return (response, mock)."""
analytics_result = _make_analytics_result()
with (
@@ -84,7 +84,7 @@
def test_custom_range_7d(self) -> None:
app = _build_app()
app.state.pool = _make_mock_pool()
-resp, mock_ga = _get_analytics(app, "/api/analytics?range=7d")
+resp, mock_ga = _get_analytics(app, "/api/v1/analytics?range=7d")
assert resp.status_code == 200
mock_ga.assert_called_once()
@@ -94,7 +94,7 @@
def test_custom_range_30d(self) -> None:
app = _build_app()
app.state.pool = _make_mock_pool()
-resp, mock_ga = _get_analytics(app, "/api/analytics?range=30d")
+resp, mock_ga = _get_analytics(app, "/api/v1/analytics?range=30d")
assert resp.status_code == 200
call_kwargs = mock_ga.call_args
@@ -107,7 +107,7 @@
app.state.pool = _make_mock_pool()
with TestClient(app) as client:
-resp = client.get("/api/analytics?range=invalid")
+resp = client.get("/api/v1/analytics?range=invalid")
assert resp.status_code == 400
@@ -116,7 +116,7 @@
app.state.pool = _make_mock_pool()
with TestClient(app) as client:
-resp = client.get("/api/analytics?range=7")
+resp = client.get("/api/v1/analytics?range=7")
assert resp.status_code == 400

View File

@@ -145,4 +145,11 @@ class TestPostgresAnalyticsRecorder:
)
call_args = mock_conn.execute.call_args
params = call_args[0][1]
-assert params["metadata"] == {"key": "val"}
+# PostgresAnalyticsRecorder wraps metadata with psycopg Json() adapter.
+# Unwrap to compare the inner dict.
+from psycopg.types.json import Json
+meta = params["metadata"]
+if isinstance(meta, Json):
+    meta = meta.obj
+assert meta == {"key": "val"}

View File

@@ -158,6 +158,42 @@ class TestInterruptStatsQuery:
assert result.expired == 0
class TestTotalConversations:
@pytest.mark.asyncio
async def test_returns_count(self) -> None:
from app.analytics.queries import _total_conversations
pool = _make_pool_with_fetchone({"total": 42})
result = await _total_conversations(pool, range_days=7)
assert result == 42
@pytest.mark.asyncio
async def test_zero_state_returns_zero(self) -> None:
from app.analytics.queries import _total_conversations
pool = _make_pool_with_fetchone(None)
result = await _total_conversations(pool, range_days=7)
assert result == 0
class TestAvgTurns:
@pytest.mark.asyncio
async def test_returns_float(self) -> None:
from app.analytics.queries import _avg_turns
pool = _make_pool_with_fetchone({"avg_turns": 3.5})
result = await _avg_turns(pool, range_days=7)
assert result == 3.5
@pytest.mark.asyncio
async def test_zero_state_returns_zero(self) -> None:
from app.analytics.queries import _avg_turns
pool = _make_pool_with_fetchone(None)
result = await _avg_turns(pool, range_days=7)
assert result == 0.0
class TestGetAnalytics:
@pytest.mark.asyncio
async def test_returns_analytics_result(self) -> None:

View File

@@ -28,7 +28,7 @@ def client():
@pytest.fixture
def job_id(client):
"""Create a job and return its ID."""
-response = client.post("/api/openapi/import", json={"url": _SAMPLE_URL})
+response = client.post("/api/v1/openapi/import", json={"url": _SAMPLE_URL})
assert response.status_code == 202
return response.json()["job_id"]
@@ -61,11 +61,11 @@ def job_with_classifications(client, job_id):
class TestImportEndpoint:
-"""Tests for POST /api/openapi/import."""
+"""Tests for POST /api/v1/openapi/import."""
def test_post_import_returns_job_id(self, client) -> None:
"""POST /import returns 202 with a job_id."""
-response = client.post("/api/openapi/import", json={"url": _SAMPLE_URL})
+response = client.post("/api/v1/openapi/import", json={"url": _SAMPLE_URL})
assert response.status_code == 202
data = response.json()
assert "job_id" in data
@@ -73,38 +73,38 @@
def test_post_import_empty_url_returns_422(self, client) -> None:
"""POST /import with empty URL returns 422 validation error."""
-response = client.post("/api/openapi/import", json={"url": ""})
+response = client.post("/api/v1/openapi/import", json={"url": ""})
assert response.status_code == 422
def test_post_import_missing_url_returns_422(self, client) -> None:
"""POST /import with missing URL field returns 422."""
-response = client.post("/api/openapi/import", json={})
+response = client.post("/api/v1/openapi/import", json={})
assert response.status_code == 422
def test_post_import_invalid_scheme_returns_422(self, client) -> None:
"""POST /import with non-http URL returns 422."""
-response = client.post("/api/openapi/import", json={"url": "ftp://evil.com/spec"})
+response = client.post("/api/v1/openapi/import", json={"url": "ftp://evil.com/spec"})
assert response.status_code == 422
def test_post_import_returns_pending_status(self, client) -> None:
"""Newly created job has pending status."""
-response = client.post("/api/openapi/import", json={"url": _SAMPLE_URL})
+response = client.post("/api/v1/openapi/import", json={"url": _SAMPLE_URL})
data = response.json()
assert data["status"] == "pending"
def test_post_import_returns_spec_url(self, client) -> None:
"""Response includes the original spec URL."""
-response = client.post("/api/openapi/import", json={"url": _SAMPLE_URL})
+response = client.post("/api/v1/openapi/import", json={"url": _SAMPLE_URL})
data = response.json()
assert data["spec_url"] == _SAMPLE_URL
class TestGetJobEndpoint:
-"""Tests for GET /api/openapi/jobs/{job_id}."""
+"""Tests for GET /api/v1/openapi/jobs/{job_id}."""
def test_get_job_returns_status(self, client, job_id) -> None:
"""GET /jobs/{id} returns job status."""
-response = client.get(f"/api/openapi/jobs/{job_id}")
+response = client.get(f"/api/v1/openapi/jobs/{job_id}")
assert response.status_code == 200
data = response.json()
assert "status" in data
@@ -112,23 +112,23 @@ class TestGetJobEndpoint:
def test_get_unknown_job_returns_404(self, client) -> None: def test_get_unknown_job_returns_404(self, client) -> None:
"""GET /jobs/nonexistent returns 404.""" """GET /jobs/nonexistent returns 404."""
response = client.get("/api/openapi/jobs/nonexistent-id") response = client.get("/api/v1/openapi/jobs/nonexistent-id")
assert response.status_code == 404 assert response.status_code == 404
def test_get_job_includes_spec_url(self, client, job_id) -> None: def test_get_job_includes_spec_url(self, client, job_id) -> None:
"""Job response includes the spec URL.""" """Job response includes the spec URL."""
response = client.get(f"/api/openapi/jobs/{job_id}") response = client.get(f"/api/v1/openapi/jobs/{job_id}")
data = response.json() data = response.json()
assert data["spec_url"] == _SAMPLE_URL assert data["spec_url"] == _SAMPLE_URL
class TestGetClassificationsEndpoint: class TestGetClassificationsEndpoint:
"""Tests for GET /api/openapi/jobs/{job_id}/classifications.""" """Tests for GET /api/v1/openapi/jobs/{job_id}/classifications."""
def test_get_classifications_returns_list(self, client, job_with_classifications) -> None: def test_get_classifications_returns_list(self, client, job_with_classifications) -> None:
"""GET /classifications returns a list.""" """GET /classifications returns a list."""
response = client.get( response = client.get(
f"/api/openapi/jobs/{job_with_classifications}/classifications" f"/api/v1/openapi/jobs/{job_with_classifications}/classifications"
) )
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
@@ -137,13 +137,13 @@ class TestGetClassificationsEndpoint:
def test_get_classifications_unknown_job_returns_404(self, client) -> None: def test_get_classifications_unknown_job_returns_404(self, client) -> None:
"""GET /classifications for unknown job returns 404.""" """GET /classifications for unknown job returns 404."""
response = client.get("/api/openapi/jobs/unknown/classifications") response = client.get("/api/v1/openapi/jobs/unknown/classifications")
assert response.status_code == 404 assert response.status_code == 404
def test_classification_has_expected_fields(self, client, job_with_classifications) -> None: def test_classification_has_expected_fields(self, client, job_with_classifications) -> None:
"""Each classification item has access_type and endpoint fields.""" """Each classification item has access_type and endpoint fields."""
response = client.get( response = client.get(
f"/api/openapi/jobs/{job_with_classifications}/classifications" f"/api/v1/openapi/jobs/{job_with_classifications}/classifications"
) )
item = response.json()[0] item = response.json()[0]
assert "access_type" in item assert "access_type" in item
@@ -152,12 +152,12 @@ class TestGetClassificationsEndpoint:
class TestUpdateClassificationEndpoint: class TestUpdateClassificationEndpoint:
"""Tests for PUT /api/openapi/jobs/{job_id}/classifications/{idx}.""" """Tests for PUT /api/v1/openapi/jobs/{job_id}/classifications/{idx}."""
def test_update_classification_succeeds(self, client, job_with_classifications) -> None: def test_update_classification_succeeds(self, client, job_with_classifications) -> None:
"""PUT /classifications/0 updates the classification.""" """PUT /classifications/0 updates the classification."""
response = client.put( response = client.put(
f"/api/openapi/jobs/{job_with_classifications}/classifications/0", f"/api/v1/openapi/jobs/{job_with_classifications}/classifications/0",
json={"access_type": "write", "needs_interrupt": True, "agent_group": "write_agent"}, json={"access_type": "write", "needs_interrupt": True, "agent_group": "write_agent"},
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -165,7 +165,7 @@ class TestUpdateClassificationEndpoint:
def test_update_unknown_job_returns_404(self, client) -> None: def test_update_unknown_job_returns_404(self, client) -> None:
"""PUT /classifications/0 for unknown job returns 404.""" """PUT /classifications/0 for unknown job returns 404."""
response = client.put( response = client.put(
"/api/openapi/jobs/unknown/classifications/0", "/api/v1/openapi/jobs/unknown/classifications/0",
json={"access_type": "write", "needs_interrupt": True, "agent_group": "write_agent"}, json={"access_type": "write", "needs_interrupt": True, "agent_group": "write_agent"},
) )
assert response.status_code == 404 assert response.status_code == 404
@@ -173,7 +173,7 @@ class TestUpdateClassificationEndpoint:
def test_update_invalid_access_type_returns_422(self, client, job_with_classifications) -> None: def test_update_invalid_access_type_returns_422(self, client, job_with_classifications) -> None:
"""PUT /classifications/0 with invalid access_type returns 422.""" """PUT /classifications/0 with invalid access_type returns 422."""
response = client.put( response = client.put(
f"/api/openapi/jobs/{job_with_classifications}/classifications/0", f"/api/v1/openapi/jobs/{job_with_classifications}/classifications/0",
json={"access_type": "admin", "needs_interrupt": True, "agent_group": "x"}, json={"access_type": "admin", "needs_interrupt": True, "agent_group": "x"},
) )
assert response.status_code == 422 assert response.status_code == 422
@@ -181,7 +181,7 @@ class TestUpdateClassificationEndpoint:
def test_update_invalid_agent_group_returns_422(self, client, job_with_classifications) -> None: def test_update_invalid_agent_group_returns_422(self, client, job_with_classifications) -> None:
"""PUT /classifications/0 with invalid agent_group returns 422.""" """PUT /classifications/0 with invalid agent_group returns 422."""
response = client.put( response = client.put(
f"/api/openapi/jobs/{job_with_classifications}/classifications/0", f"/api/v1/openapi/jobs/{job_with_classifications}/classifications/0",
json={"access_type": "read", "needs_interrupt": False, "agent_group": "evil group!"}, json={"access_type": "read", "needs_interrupt": False, "agent_group": "evil group!"},
) )
assert response.status_code == 422 assert response.status_code == 422
@@ -189,31 +189,31 @@ class TestUpdateClassificationEndpoint:
def test_update_out_of_range_index_returns_404(self, client, job_with_classifications) -> None: def test_update_out_of_range_index_returns_404(self, client, job_with_classifications) -> None:
"""PUT /classifications/999 returns 404 for out-of-range index.""" """PUT /classifications/999 returns 404 for out-of-range index."""
response = client.put( response = client.put(
f"/api/openapi/jobs/{job_with_classifications}/classifications/999", f"/api/v1/openapi/jobs/{job_with_classifications}/classifications/999",
json={"access_type": "read", "needs_interrupt": False, "agent_group": "read_agent"}, json={"access_type": "read", "needs_interrupt": False, "agent_group": "read_agent"},
) )
assert response.status_code == 404 assert response.status_code == 404
class TestApproveEndpoint: class TestApproveEndpoint:
"""Tests for POST /api/openapi/jobs/{job_id}/approve.""" """Tests for POST /api/v1/openapi/jobs/{job_id}/approve."""
def test_approve_job_succeeds(self, client, job_with_classifications) -> None: def test_approve_job_succeeds(self, client, job_with_classifications) -> None:
"""POST /approve transitions job to approved status.""" """POST /approve transitions job to approved status."""
response = client.post( response = client.post(
f"/api/openapi/jobs/{job_with_classifications}/approve" f"/api/v1/openapi/jobs/{job_with_classifications}/approve"
) )
assert response.status_code == 200 assert response.status_code == 200
def test_approve_unknown_job_returns_404(self, client) -> None: def test_approve_unknown_job_returns_404(self, client) -> None:
"""POST /approve for unknown job returns 404.""" """POST /approve for unknown job returns 404."""
response = client.post("/api/openapi/jobs/unknown/approve") response = client.post("/api/v1/openapi/jobs/unknown/approve")
assert response.status_code == 404 assert response.status_code == 404
def test_approve_returns_job_status(self, client, job_with_classifications) -> None: def test_approve_returns_job_status(self, client, job_with_classifications) -> None:
"""POST /approve returns updated job status.""" """POST /approve returns updated job status."""
response = client.post( response = client.post(
f"/api/openapi/jobs/{job_with_classifications}/approve" f"/api/v1/openapi/jobs/{job_with_classifications}/approve"
) )
data = response.json() data = response.json()
assert "status" in data assert "status" in data

View File

@@ -5,9 +5,12 @@ from __future__ import annotations
 from unittest.mock import AsyncMock, MagicMock

 import pytest
-from fastapi import FastAPI
+from fastapi import FastAPI, HTTPException
+from fastapi.responses import JSONResponse
 from fastapi.testclient import TestClient

+from app.api_utils import envelope
+
 pytestmark = pytest.mark.unit
@@ -16,16 +19,46 @@ def _build_app() -> FastAPI:
     app = FastAPI()
     app.include_router(router)

+    @app.exception_handler(HTTPException)
+    async def _http_exc(request, exc):  # type: ignore[no-untyped-def]
+        return JSONResponse(
+            status_code=exc.status_code,
+            content=envelope(None, success=False, error=exc.detail),
+        )
+
     return app

-def _make_mock_pool(fetchall_result: list[dict]) -> MagicMock:
-    """Build a mock pool that returns the given rows from fetchall."""
-    mock_cursor = AsyncMock()
-    mock_cursor.fetchall = AsyncMock(return_value=fetchall_result)
-
-    mock_conn = AsyncMock()
-    mock_conn.execute = AsyncMock(return_value=mock_cursor)
+def _make_mock_pool(
+    fetchall_result: list[dict],
+    *,
+    count: int | None = None,
+) -> MagicMock:
+    """Build a mock pool that returns the given rows from fetchall.
+
+    When *count* is provided, the first execute() call returns a cursor
+    whose fetchone() yields ``(count,)`` (for the COUNT query) and the
+    second call returns the rows via fetchall(). When *count* is None
+    (the default), a single cursor backed by *fetchall_result* is used
+    for all calls.
+    """
+    if count is not None:
+        count_cursor = AsyncMock()
+        count_cursor.fetchone = AsyncMock(return_value=(count,))
+        rows_cursor = AsyncMock()
+        rows_cursor.fetchall = AsyncMock(return_value=fetchall_result)
+        mock_conn = AsyncMock()
+        mock_conn.execute = AsyncMock(side_effect=[count_cursor, rows_cursor])
+    else:
+        mock_cursor = AsyncMock()
+        mock_cursor.fetchall = AsyncMock(return_value=fetchall_result)
+        mock_cursor.fetchone = AsyncMock(return_value=None)
+        mock_conn = AsyncMock()
+        mock_conn.execute = AsyncMock(return_value=mock_cursor)

     mock_ctx = AsyncMock()
     mock_ctx.__aenter__ = AsyncMock(return_value=mock_conn)
@@ -39,14 +72,17 @@ def _make_mock_pool(fetchall_result: list[dict]) -> MagicMock:
 class TestListConversations:
     def test_returns_200_with_empty_list(self) -> None:
         app = _build_app()
-        app.state.pool = _make_mock_pool([])
+        app.state.pool = _make_mock_pool([], count=0)
         with TestClient(app) as client:
-            resp = client.get("/api/conversations")
+            resp = client.get("/api/v1/conversations")
         assert resp.status_code == 200
         body = resp.json()
         assert body["success"] is True
-        assert isinstance(body["data"], list)
+        data = body["data"]
+        assert isinstance(data["conversations"], list)
+        assert data["total"] == 0
+        assert data["page"] == 1
         assert body["error"] is None

     def test_returns_conversations_list(self) -> None:
@@ -61,39 +97,41 @@ class TestListConversations:
                 "total_cost_usd": 0.01,
             }
         ]
-        app.state.pool = _make_mock_pool(mock_rows)
+        app.state.pool = _make_mock_pool(mock_rows, count=1)
         with TestClient(app) as client:
-            resp = client.get("/api/conversations")
+            resp = client.get("/api/v1/conversations")
         body = resp.json()
         assert resp.status_code == 200
-        assert len(body["data"]) == 1
-        assert body["data"][0]["thread_id"] == "t1"
+        data = body["data"]
+        assert len(data["conversations"]) == 1
+        assert data["conversations"][0]["thread_id"] == "t1"
+        assert data["total"] == 1

     def test_pagination_defaults(self) -> None:
         app = _build_app()
-        app.state.pool = _make_mock_pool([])
+        app.state.pool = _make_mock_pool([], count=0)
         with TestClient(app) as client:
-            resp = client.get("/api/conversations")
+            resp = client.get("/api/v1/conversations")
         assert resp.status_code == 200

     def test_pagination_custom_params(self) -> None:
         app = _build_app()
-        app.state.pool = _make_mock_pool([])
+        app.state.pool = _make_mock_pool([], count=0)
         with TestClient(app) as client:
-            resp = client.get("/api/conversations?page=2&per_page=10")
+            resp = client.get("/api/v1/conversations?page=2&per_page=10")
         assert resp.status_code == 200

     def test_per_page_max_capped_at_100(self) -> None:
         app = _build_app()
-        app.state.pool = _make_mock_pool([])
+        app.state.pool = _make_mock_pool([], count=0)
         with TestClient(app) as client:
-            resp = client.get("/api/conversations?per_page=200")
-        # FastAPI validation rejects values > 100
-        assert resp.status_code in (200, 422)
+            resp = client.get("/api/v1/conversations?per_page=200")
+        # FastAPI Query(le=100) rejects values > 100
+        assert resp.status_code == 422

 class TestGetReplay:
@@ -102,7 +140,7 @@ class TestGetReplay:
         app.state.pool = _make_mock_pool([])
         with TestClient(app) as client:
-            resp = client.get("/api/replay/nonexistent-thread")
+            resp = client.get("/api/v1/replay/nonexistent-thread")
         assert resp.status_code == 404

     def test_returns_replay_page_for_existing_thread(self) -> None:
@@ -122,7 +160,7 @@ class TestGetReplay:
         app.state.pool = _make_mock_pool(mock_rows)
         with TestClient(app) as client:
-            resp = client.get("/api/replay/thread-123")
+            resp = client.get("/api/v1/replay/thread-123")
         assert resp.status_code == 200
         body = resp.json()
         assert body["success"] is True
@@ -147,7 +185,7 @@ class TestGetReplay:
         app.state.pool = _make_mock_pool(mock_rows)
         with TestClient(app) as client:
-            resp = client.get("/api/replay/t1?page=1&per_page=5")
+            resp = client.get("/api/v1/replay/t1?page=1&per_page=5")
         assert resp.status_code == 200

     def test_error_response_has_envelope(self) -> None:
@@ -155,16 +193,19 @@ class TestGetReplay:
         app.state.pool = _make_mock_pool([])
         with TestClient(app) as client:
-            resp = client.get("/api/replay/missing")
+            resp = client.get("/api/v1/replay/missing")
         assert resp.status_code == 404
-        assert "detail" in resp.json()
+        body = resp.json()
+        assert body["success"] is False
+        assert body["data"] is None
+        assert body["error"] is not None

     def test_invalid_thread_id_returns_400(self) -> None:
         app = _build_app()
         app.state.pool = _make_mock_pool([])
         with TestClient(app) as client:
-            resp = client.get("/api/replay/id%20with%20spaces")
+            resp = client.get("/api/v1/replay/id%20with%20spaces")
         assert resp.status_code == 400

     def test_thread_id_special_chars_returns_400(self) -> None:
@@ -172,5 +213,5 @@ class TestGetReplay:
         app.state.pool = _make_mock_pool([])
         with TestClient(app) as client:
-            resp = client.get("/api/replay/id;DROP TABLE")
+            resp = client.get("/api/v1/replay/id;DROP TABLE")
         assert resp.status_code == 400
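(For orientation: the count-aware mock above mirrors a handler that runs a COUNT query and then the page query on the same connection. A sketch of that shape only; the SQL, column names, and handler name are assumptions, not the actual app code.)

async def list_conversations(pool, page: int = 1, per_page: int = 20) -> dict:
    offset = (page - 1) * per_page
    async with pool.connection() as conn:
        # First execute(): COUNT, consumed via fetchone() -> (total,)
        cur = await conn.execute("SELECT COUNT(*) FROM conversations")
        (total,) = await cur.fetchone()
        # Second execute(): the page itself, consumed via fetchall()
        cur = await conn.execute(
            "SELECT * FROM conversations ORDER BY started_at DESC LIMIT %s OFFSET %s",
            (per_page, offset),
        )
        rows = await cur.fetchall()
    return {"conversations": rows, "total": total, "page": page}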

View File

@@ -153,3 +153,105 @@ class TestTransformCheckpoints:
         rows = [_make_row([{"type": "human", "content": "Hi"}])]
         steps = transform_checkpoints(rows)
         assert isinstance(steps[0].timestamp, str)
+
+    def test_list_content_joined_to_string(self) -> None:
+        from app.replay.transformer import transform_checkpoints
+
+        rows = [
+            _make_row(
+                [
+                    {
+                        "type": "human",
+                        "content": [
+                            {"text": "Hello"},
+                            {"text": " world"},
+                        ],
+                    }
+                ]
+            )
+        ]
+        steps = transform_checkpoints(rows)
+        assert len(steps) == 1
+        assert steps[0].content == "Hello world"
+
+    def test_checkpoint_as_string_skipped(self) -> None:
+        from app.replay.transformer import transform_checkpoints
+
+        rows = [
+            {
+                "thread_id": "t1",
+                "checkpoint_id": "cp1",
+                "checkpoint": "not-a-dict",
+                "metadata": {},
+            }
+        ]
+        steps = transform_checkpoints(rows)
+        assert steps == []
+
+    def test_channel_values_not_dict_skipped(self) -> None:
+        from app.replay.transformer import transform_checkpoints
+
+        rows = [
+            {
+                "thread_id": "t1",
+                "checkpoint_id": "cp1",
+                "checkpoint": {"channel_values": "bad"},
+                "metadata": {},
+            }
+        ]
+        steps = transform_checkpoints(rows)
+        assert steps == []
+
+    def test_tool_result_valid_json_parsed(self) -> None:
+        from app.replay.transformer import transform_checkpoints
+
+        rows = [
+            _make_row(
+                [
+                    {
+                        "type": "tool",
+                        "content": '{"order_id": "123", "status": "shipped"}',
+                        "name": "get_order_status",
+                    }
+                ]
+            )
+        ]
+        steps = transform_checkpoints(rows)
+        assert len(steps) == 1
+        assert steps[0].result == {"order_id": "123", "status": "shipped"}
+
+    def test_tool_result_invalid_json_wrapped(self) -> None:
+        from app.replay.transformer import transform_checkpoints
+
+        rows = [
+            _make_row(
+                [
+                    {
+                        "type": "tool",
+                        "content": "not valid json",
+                        "name": "some_tool",
+                    }
+                ]
+            )
+        ]
+        steps = transform_checkpoints(rows)
+        assert len(steps) == 1
+        assert steps[0].result == {"raw": "not valid json"}
+
+    def test_malformed_message_skipped_gracefully(self) -> None:
+        from app.replay.transformer import transform_checkpoints
+
+        rows = [
+            _make_row(
+                [
+                    {"type": "human", "content": "Good message"},
+                    42,  # not a dict -- will raise in _step_from_message
+                    {"type": "ai", "content": "Response", "tool_calls": []},
+                ]
+            )
+        ]
+        steps = transform_checkpoints(rows)
+        # The malformed message is skipped; the other two produce steps.
+        assert len(steps) == 2
+        assert steps[0].step == 1
+        assert steps[1].step == 2
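(The list-join, JSON-parse, and raw-wrap rules these tests encode, as a standalone sketch; _normalize_content and _parse_tool_result are hypothetical names, not the transformer's actual helpers.)

import json

def _normalize_content(content) -> str:
    # List content (e.g. block-style message parts) is joined on the "text" field.
    if isinstance(content, list):
        return "".join(part.get("text", "") for part in content)
    return content

def _parse_tool_result(content: str) -> dict:
    # Valid JSON objects pass through; anything else is wrapped as {"raw": ...}.
    try:
        parsed = json.loads(content)
    except json.JSONDecodeError:
        return {"raw": content}
    return parsed if isinstance(parsed, dict) else {"raw": content}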

View File

@@ -7,10 +7,41 @@ import pytest
 from app.config import Settings

+def _isolated_settings(**kwargs: object) -> Settings:
+    """Create a Settings instance that ignores .env files and process env vars.
+
+    pydantic-settings reads from env_file and environment by default, which
+    causes test results to depend on the machine they run on. We override
+    model_config at the class level temporarily so that every test gets
+    deterministic results.
+    """
+    # Build a throwaway subclass that disables env-file and env-var loading.
+    class _IsolatedSettings(Settings):
+        model_config = Settings.model_config.copy()
+        model_config["env_file"] = None  # type: ignore[assignment]
+        model_config["env_ignore_empty"] = True
+
+    # _env_parse_none_str makes pydantic-settings treat missing env vars as
+    # absent rather than empty-string, so required fields will raise.
+    import os
+
+    env_backup = os.environ.copy()
+    # Strip all env vars that Settings knows about so they can't leak in.
+    settings_fields = set(Settings.model_fields)
+    for key in list(os.environ):
+        if key.lower() in settings_fields:
+            del os.environ[key]
+    try:
+        return _IsolatedSettings(**kwargs)  # type: ignore[return-value]
+    finally:
+        os.environ.clear()
+        os.environ.update(env_backup)
+
 @pytest.mark.unit
 class TestSettings:
     def test_default_values(self) -> None:
-        settings = Settings(
+        settings = _isolated_settings(
             database_url="postgresql://x:x@localhost/db",
             anthropic_api_key="key",
         )
@@ -20,7 +51,7 @@ class TestSettings:
         assert settings.interrupt_ttl_minutes == 30

     def test_custom_values(self) -> None:
-        settings = Settings(
+        settings = _isolated_settings(
             database_url="postgresql://x:x@localhost/db",
             llm_provider="openai",
             llm_model="gpt-4o",
@@ -33,18 +64,18 @@ class TestSettings:
     def test_invalid_provider_rejected(self) -> None:
         with pytest.raises(Exception):
-            Settings(
+            _isolated_settings(
                 database_url="postgresql://x:x@localhost/db",
                 llm_provider="invalid",
             )

     def test_missing_database_url_rejected(self) -> None:
         with pytest.raises(Exception):
-            Settings(anthropic_api_key="key")
+            _isolated_settings(anthropic_api_key="key")

     def test_empty_api_key_for_provider_rejected(self) -> None:
         with pytest.raises(ValueError, match="API key"):
-            Settings(
+            _isolated_settings(
                 database_url="postgresql://x:x@localhost/db",
                 llm_provider="anthropic",
                 anthropic_api_key="",
@@ -52,9 +83,27 @@ class TestSettings:
     def test_wrong_provider_key_rejected(self) -> None:
         with pytest.raises(ValueError, match="API key"):
-            Settings(
+            _isolated_settings(
                 database_url="postgresql://x:x@localhost/db",
                 llm_provider="openai",
                 anthropic_api_key="key",
                 openai_api_key="",
             )
+
+    def test_azure_openai_missing_endpoint_rejected(self) -> None:
+        with pytest.raises(ValueError, match="AZURE_OPENAI_ENDPOINT"):
+            _isolated_settings(
+                database_url="postgresql://x:x@localhost/db",
+                llm_provider="azure_openai",
+                azure_openai_api_key="key",
+                azure_openai_deployment="my-deploy",
+            )
+
+    def test_azure_openai_missing_deployment_rejected(self) -> None:
+        with pytest.raises(ValueError, match="AZURE_OPENAI_DEPLOYMENT"):
+            _isolated_settings(
+                database_url="postgresql://x:x@localhost/db",
+                llm_provider="azure_openai",
+                azure_openai_api_key="key",
+                azure_openai_endpoint="https://example.openai.azure.com",
+            )
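(What the isolation buys in practice, as an illustrative test; the assertion assumes the default provider is anthropic, which the default-values test above implies but does not show.)

def test_provider_default_is_machine_independent() -> None:
    settings = _isolated_settings(
        database_url="postgresql://x:x@localhost/db",
        anthropic_api_key="key",
    )
    # Holds even if the developer's .env sets LLM_PROVIDER=openai.
    assert settings.llm_provider == "anthropic"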

View File

@@ -55,7 +55,7 @@ class TestDbModule:
         from app.db import setup_app_tables

         await setup_app_tables(mock_pool)
-        assert mock_conn.execute.await_count == 4
+        assert mock_conn.execute.await_count == 5

     def test_ddl_statements_valid(self) -> None:
         assert "CREATE TABLE IF NOT EXISTS conversations" in _CONVERSATIONS_DDL

View File

@@ -51,5 +51,5 @@ class TestAnalyticsEventsDDL:
         from app.db import setup_app_tables

         await setup_app_tables(mock_pool)
-        # Now expects 4 statements: conversations, interrupts, analytics_events, migrations
-        assert mock_conn.execute.await_count == 4
+        # Now expects 5 statements: conversations, interrupts, sessions, analytics_events, migrations
+        assert mock_conn.execute.await_count == 5

View File

@@ -8,7 +8,9 @@
 from unittest.mock import AsyncMock, MagicMock

 import pytest

 from app.callbacks import TokenUsageCallbackHandler
+from app.graph_context import GraphContext
 from app.session_manager import SessionManager
+from app.ws_context import WebSocketContext
 from app.ws_handler import dispatch_message

 pytestmark = pytest.mark.unit
@@ -20,7 +22,7 @@ def _make_ws() -> AsyncMock:
     return ws

-def _make_graph() -> AsyncMock:
+def _make_graph() -> MagicMock:
     graph = AsyncMock()

     class AsyncIterHelper:
@@ -34,23 +36,32 @@ def _make_graph() -> AsyncMock:
     state = MagicMock()
     state.tasks = ()
     graph.aget_state = AsyncMock(return_value=state)
-    graph.intent_classifier = None
-    graph.agent_registry = None
     return graph

+def _make_ws_ctx(sm: SessionManager | None = None) -> WebSocketContext:
+    graph = _make_graph()
+    registry = MagicMock()
+    registry.list_agents = MagicMock(return_value=())
+    graph_ctx = GraphContext(graph=graph, registry=registry, intent_classifier=None)
+    return WebSocketContext(
+        graph_ctx=graph_ctx,
+        session_manager=sm or SessionManager(),
+        callback_handler=TokenUsageCallbackHandler(),
+    )

 @pytest.mark.unit
 class TestEmptyMessageHandling:
     @pytest.mark.asyncio
     async def test_empty_message_content_returns_error(self) -> None:
         ws = _make_ws()
-        graph = _make_graph()
         sm = SessionManager()
-        cb = TokenUsageCallbackHandler()
+        ws_ctx = _make_ws_ctx(sm=sm)
         sm.touch("t1")

         msg = json.dumps({"type": "message", "thread_id": "t1", "content": ""})
-        await dispatch_message(ws, graph, sm, cb, msg)
+        await dispatch_message(ws, ws_ctx, msg)

         call_data = ws.send_json.call_args[0][0]
         assert call_data["type"] == "error"
@@ -60,13 +71,12 @@ class TestEmptyMessageHandling:
     @pytest.mark.asyncio
     async def test_whitespace_only_message_treated_as_empty(self) -> None:
         ws = _make_ws()
-        graph = _make_graph()
         sm = SessionManager()
-        cb = TokenUsageCallbackHandler()
+        ws_ctx = _make_ws_ctx(sm=sm)
         sm.touch("t1")

         msg = json.dumps({"type": "message", "thread_id": "t1", "content": " "})
-        await dispatch_message(ws, graph, sm, cb, msg)
+        await dispatch_message(ws, ws_ctx, msg)

         call_data = ws.send_json.call_args[0][0]
         assert call_data["type"] == "error"
@@ -77,14 +87,13 @@ class TestOversizedMessageHandling:
     @pytest.mark.asyncio
     async def test_content_over_10000_chars_returns_error(self) -> None:
         ws = _make_ws()
-        graph = _make_graph()
         sm = SessionManager()
-        cb = TokenUsageCallbackHandler()
+        ws_ctx = _make_ws_ctx(sm=sm)
         sm.touch("t1")

         content = "x" * 10001
         msg = json.dumps({"type": "message", "thread_id": "t1", "content": content})
-        await dispatch_message(ws, graph, sm, cb, msg)
+        await dispatch_message(ws, ws_ctx, msg)

         call_data = ws.send_json.call_args[0][0]
         assert call_data["type"] == "error"
@@ -93,14 +102,13 @@ class TestOversizedMessageHandling:
     @pytest.mark.asyncio
     async def test_content_exactly_10000_chars_is_accepted(self) -> None:
         ws = _make_ws()
-        graph = _make_graph()
         sm = SessionManager()
-        cb = TokenUsageCallbackHandler()
+        ws_ctx = _make_ws_ctx(sm=sm)
         sm.touch("t1")

         content = "x" * 10000
         msg = json.dumps({"type": "message", "thread_id": "t1", "content": content})
-        await dispatch_message(ws, graph, sm, cb, msg)
+        await dispatch_message(ws, ws_ctx, msg)

         last_call = ws.send_json.call_args[0][0]
         # Should be processed, not an error about length
@@ -110,12 +118,10 @@ class TestOversizedMessageHandling:
     @pytest.mark.asyncio
     async def test_raw_message_over_32kb_returns_error(self) -> None:
         ws = _make_ws()
-        graph = _make_graph()
-        sm = SessionManager()
-        cb = TokenUsageCallbackHandler()
+        ws_ctx = _make_ws_ctx()

         large_msg = "x" * 40_000
-        await dispatch_message(ws, graph, sm, cb, large_msg)
+        await dispatch_message(ws, ws_ctx, large_msg)

         call_data = ws.send_json.call_args[0][0]
         assert call_data["type"] == "error"
@@ -127,11 +133,9 @@ class TestInvalidJsonHandling:
     @pytest.mark.asyncio
     async def test_invalid_json_returns_error(self) -> None:
         ws = _make_ws()
-        graph = _make_graph()
-        sm = SessionManager()
-        cb = TokenUsageCallbackHandler()
+        ws_ctx = _make_ws_ctx()

-        await dispatch_message(ws, graph, sm, cb, "not valid json {{")
+        await dispatch_message(ws, ws_ctx, "not valid json {{")

         call_data = ws.send_json.call_args[0][0]
         assert call_data["type"] == "error"
@@ -140,11 +144,9 @@ class TestInvalidJsonHandling:
     @pytest.mark.asyncio
     async def test_empty_string_returns_json_error(self) -> None:
         ws = _make_ws()
-        graph = _make_graph()
-        sm = SessionManager()
-        cb = TokenUsageCallbackHandler()
+        ws_ctx = _make_ws_ctx()

-        await dispatch_message(ws, graph, sm, cb, "")
+        await dispatch_message(ws, ws_ctx, "")

         call_data = ws.send_json.call_args[0][0]
         assert call_data["type"] == "error"
@@ -152,11 +154,9 @@ class TestInvalidJsonHandling:
     @pytest.mark.asyncio
     async def test_json_array_not_object_returns_error(self) -> None:
         ws = _make_ws()
-        graph = _make_graph()
-        sm = SessionManager()
-        cb = TokenUsageCallbackHandler()
+        ws_ctx = _make_ws_ctx()

-        await dispatch_message(ws, graph, sm, cb, '["not", "an", "object"]')
+        await dispatch_message(ws, ws_ctx, '["not", "an", "object"]')

         call_data = ws.send_json.call_args[0][0]
         assert call_data["type"] == "error"
@@ -167,17 +167,15 @@ class TestRateLimiting:
     @pytest.mark.asyncio
     async def test_rapid_fire_messages_rate_limited(self) -> None:
         ws = _make_ws()
-        _make_graph()  # ensure graph factory works, not needed directly
         sm = SessionManager()
-        cb = TokenUsageCallbackHandler()
         sm.touch("t1")

         # Simulate 11 rapid messages (exceeds 10 per 10 seconds limit)
         rate_limit_triggered = False
         for i in range(11):
-            graph2 = _make_graph()  # fresh graph each time
-            await dispatch_message(ws, graph2, sm, cb, json.dumps({
+            ws_ctx = _make_ws_ctx(sm=sm)
+            await dispatch_message(ws, ws_ctx, json.dumps({
                 "type": "message",
                 "thread_id": "t1",
                 "content": f"message {i}",
@@ -193,19 +191,18 @@ class TestRateLimiting:
     async def test_different_threads_have_separate_rate_limits(self) -> None:
         ws = _make_ws()
         sm = SessionManager()
-        cb = TokenUsageCallbackHandler()
         sm.touch("t1")
         sm.touch("t2")

         # Send 5 messages on t1 and 5 on t2 -- neither should be rate limited
         for i in range(5):
-            graph1 = _make_graph()
-            graph2 = _make_graph()
-            await dispatch_message(ws, graph1, sm, cb, json.dumps({
+            ws_ctx1 = _make_ws_ctx(sm=sm)
+            ws_ctx2 = _make_ws_ctx(sm=sm)
+            await dispatch_message(ws, ws_ctx1, json.dumps({
                 "type": "message", "thread_id": "t1", "content": f"msg {i}",
             }))
-            await dispatch_message(ws, graph2, sm, cb, json.dumps({
+            await dispatch_message(ws, ws_ctx2, json.dumps({
                 "type": "message", "thread_id": "t2", "content": f"msg {i}",
            }))
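(These tests build the context with three required collaborators; the fuller test file later in this diff also passes interrupt_manager, analytics_recorder, conversation_tracker, and pool. A plausible minimal shape for app/ws_context.py, with the dataclass layout being an assumption:)

from dataclasses import dataclass
from typing import Any

from app.callbacks import TokenUsageCallbackHandler
from app.graph_context import GraphContext
from app.session_manager import SessionManager

@dataclass
class WebSocketContext:
    graph_ctx: GraphContext
    session_manager: SessionManager
    callback_handler: TokenUsageCallbackHandler
    # Optional collaborators default to None, matching the test constructors.
    interrupt_manager: Any = None
    analytics_recorder: Any = None
    conversation_tracker: Any = None
    pool: Any = None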

View File

@@ -0,0 +1,142 @@
"""Tests for standardized error response envelope format."""
from __future__ import annotations
import pytest
from fastapi import FastAPI, HTTPException
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from fastapi.testclient import TestClient
from pydantic import BaseModel, Field
from app.api_utils import envelope
pytestmark = pytest.mark.unit
def _build_test_app() -> FastAPI:
"""Build a minimal FastAPI app with the standard exception handlers."""
app = FastAPI()
@app.exception_handler(HTTPException)
async def http_exception_handler(request, exc): # type: ignore[no-untyped-def]
return JSONResponse(
status_code=exc.status_code,
content=envelope(None, success=False, error=exc.detail),
)
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc): # type: ignore[no-untyped-def]
return JSONResponse(
status_code=422,
content=envelope(None, success=False, error=str(exc)),
)
@app.exception_handler(Exception)
async def general_exception_handler(request, exc): # type: ignore[no-untyped-def]
return JSONResponse(
status_code=500,
content=envelope(None, success=False, error="Internal server error"),
)
class ItemRequest(BaseModel):
name: str = Field(..., min_length=1)
count: int = Field(..., gt=0)
@app.get("/items/{item_id}")
def get_item(item_id: int) -> dict:
if item_id == 0:
raise HTTPException(status_code=400, detail="Invalid item ID")
if item_id == 999:
raise HTTPException(status_code=404, detail="Item not found")
if item_id == 401:
raise HTTPException(status_code=401, detail="Not authenticated")
return envelope({"id": item_id, "name": "test"})
@app.post("/items")
def create_item(item: ItemRequest) -> dict:
return envelope({"id": 1, "name": item.name})
@app.get("/crash")
def crash() -> dict:
msg = "unexpected failure"
raise RuntimeError(msg)
return app
class TestHttpExceptionEnvelope:
"""HTTPException responses use the standard envelope format."""
def test_400_returns_envelope(self) -> None:
app = _build_test_app()
with TestClient(app, raise_server_exceptions=False) as client:
resp = client.get("/items/0")
assert resp.status_code == 400
body = resp.json()
assert body["success"] is False
assert body["data"] is None
assert body["error"] == "Invalid item ID"
def test_404_returns_envelope(self) -> None:
app = _build_test_app()
with TestClient(app, raise_server_exceptions=False) as client:
resp = client.get("/items/999")
assert resp.status_code == 404
body = resp.json()
assert body["success"] is False
assert body["data"] is None
assert body["error"] == "Item not found"
def test_401_returns_envelope(self) -> None:
app = _build_test_app()
with TestClient(app, raise_server_exceptions=False) as client:
resp = client.get("/items/401")
assert resp.status_code == 401
body = resp.json()
assert body["success"] is False
assert body["data"] is None
assert body["error"] == "Not authenticated"
class TestValidationErrorEnvelope:
"""Validation errors return 422 with envelope format."""
def test_validation_error_returns_envelope(self) -> None:
app = _build_test_app()
with TestClient(app, raise_server_exceptions=False) as client:
resp = client.post("/items", json={"name": "", "count": -1})
assert resp.status_code == 422
body = resp.json()
assert body["success"] is False
assert body["data"] is None
assert isinstance(body["error"], str)
assert len(body["error"]) > 0
class TestGeneralExceptionEnvelope:
"""Unhandled exceptions return 500 with safe envelope."""
def test_unhandled_exception_returns_500_envelope(self) -> None:
app = _build_test_app()
with TestClient(app, raise_server_exceptions=False) as client:
resp = client.get("/crash")
assert resp.status_code == 500
body = resp.json()
assert body["success"] is False
assert body["data"] is None
assert body["error"] == "Internal server error"
class TestSuccessResponseUnchanged:
"""Success responses still work normally."""
def test_success_returns_envelope(self) -> None:
app = _build_test_app()
with TestClient(app) as client:
resp = client.get("/items/42")
assert resp.status_code == 200
body = resp.json()
assert body["success"] is True
assert body["data"]["id"] == 42
assert body["error"] is None

View File

@@ -6,8 +6,10 @@ from typing import TYPE_CHECKING
 from unittest.mock import AsyncMock, MagicMock

 import pytest

+from langgraph.checkpoint.memory import InMemorySaver

-from app.graph import build_agent_nodes, build_graph, classify_intent
+from app.graph import build_agent_nodes, build_graph
+from app.graph_context import GraphContext
 from app.intent import ClassificationResult, IntentTarget

 if TYPE_CHECKING:
@@ -34,41 +36,43 @@ class TestBuildGraph:
         mock_llm = MagicMock()
         mock_llm.bind_tools = MagicMock(return_value=mock_llm)
         mock_llm.with_structured_output = MagicMock(return_value=mock_llm)
-        mock_checkpointer = AsyncMock()
+        checkpointer = InMemorySaver()

-        graph = build_graph(sample_registry, mock_llm, mock_checkpointer)
-        assert graph is not None
+        graph_ctx = build_graph(sample_registry, mock_llm, checkpointer)
+        assert graph_ctx is not None
+        assert graph_ctx.graph is not None

     def test_graph_has_classifier_attached(self, sample_registry: AgentRegistry) -> None:
         mock_llm = MagicMock()
         mock_llm.bind_tools = MagicMock(return_value=mock_llm)
         mock_llm.with_structured_output = MagicMock(return_value=mock_llm)
-        mock_checkpointer = AsyncMock()
+        checkpointer = InMemorySaver()
         mock_classifier = MagicMock()

-        graph = build_graph(
-            sample_registry, mock_llm, mock_checkpointer, intent_classifier=mock_classifier
+        graph_ctx = build_graph(
+            sample_registry, mock_llm, checkpointer, intent_classifier=mock_classifier
         )
-        assert graph.intent_classifier is mock_classifier
-        assert graph.agent_registry is sample_registry
+        assert graph_ctx.intent_classifier is mock_classifier
+        assert graph_ctx.registry is sample_registry

     def test_graph_without_classifier(self, sample_registry: AgentRegistry) -> None:
         mock_llm = MagicMock()
         mock_llm.bind_tools = MagicMock(return_value=mock_llm)
         mock_llm.with_structured_output = MagicMock(return_value=mock_llm)
-        mock_checkpointer = AsyncMock()
+        checkpointer = InMemorySaver()

-        graph = build_graph(sample_registry, mock_llm, mock_checkpointer)
-        assert graph.intent_classifier is None
+        graph_ctx = build_graph(sample_registry, mock_llm, checkpointer)
+        assert graph_ctx.intent_classifier is None

 @pytest.mark.unit
 class TestClassifyIntent:
     @pytest.mark.asyncio
     async def test_returns_none_without_classifier(self) -> None:
-        graph = MagicMock()
-        graph.intent_classifier = None
-        result = await classify_intent(graph, "hello")
+        mock_registry = MagicMock()
+        mock_registry.list_agents = MagicMock(return_value=())
+        graph_ctx = GraphContext(graph=MagicMock(), registry=mock_registry, intent_classifier=None)
+        result = await graph_ctx.classify_intent("hello")
         assert result is None

     @pytest.mark.asyncio
@@ -79,11 +83,12 @@ class TestClassifyIntent:
         mock_classifier = AsyncMock()
         mock_classifier.classify = AsyncMock(return_value=expected)

-        graph = MagicMock()
-        graph.intent_classifier = mock_classifier
-        graph.agent_registry = MagicMock()
-        graph.agent_registry.list_agents = MagicMock(return_value=())
+        mock_registry = MagicMock()
+        mock_registry.list_agents = MagicMock(return_value=())
+        graph_ctx = GraphContext(
+            graph=MagicMock(), registry=mock_registry, intent_classifier=mock_classifier,
+        )

-        result = await classify_intent(graph, "check order")
+        result = await graph_ctx.classify_intent("check order")
         assert result is not None
         assert result.intents[0].agent_name == "order_lookup"
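(GraphContext replaces the attributes that used to be monkey-patched onto the compiled graph. A sketch consistent with these tests; how classify_intent hands the registry to the classifier is an assumption, since the AsyncMock accepts any arguments.)

from dataclasses import dataclass
from typing import Any

@dataclass
class GraphContext:
    graph: Any
    registry: Any
    intent_classifier: Any = None

    async def classify_intent(self, text: str):
        # No classifier configured -> classification is skipped entirely.
        if self.intent_classifier is None:
            return None
        return await self.intent_classifier.classify(
            text, agents=self.registry.list_agents()
        )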

View File

@@ -0,0 +1,86 @@
"""Tests for the interrupt cleanup background loop in main.py."""
from __future__ import annotations
import asyncio
import logging
from unittest.mock import MagicMock, patch
import pytest
from app.main import _interrupt_cleanup_loop
@pytest.mark.unit
@pytest.mark.asyncio
async def test_cleanup_loop_calls_cleanup_expired() -> None:
"""The loop should call cleanup_expired after each sleep interval."""
manager = MagicMock()
manager.cleanup_expired.return_value = ()
call_count = 0
original_sleep = asyncio.sleep
async def _fake_sleep(seconds: float) -> None:
nonlocal call_count
call_count += 1
if call_count >= 2:
raise asyncio.CancelledError
await original_sleep(0)
with patch("app.main.asyncio.sleep", side_effect=_fake_sleep):
with pytest.raises(asyncio.CancelledError):
await _interrupt_cleanup_loop(manager, interval=60)
assert manager.cleanup_expired.call_count >= 1
@pytest.mark.unit
@pytest.mark.asyncio
async def test_cleanup_loop_survives_exceptions() -> None:
"""The loop should not die when cleanup_expired raises an exception."""
manager = MagicMock()
manager.cleanup_expired.side_effect = [RuntimeError("db gone"), ()]
call_count = 0
original_sleep = asyncio.sleep
async def _fake_sleep(seconds: float) -> None:
nonlocal call_count
call_count += 1
if call_count >= 3:
raise asyncio.CancelledError
await original_sleep(0)
with patch("app.main.asyncio.sleep", side_effect=_fake_sleep):
with pytest.raises(asyncio.CancelledError):
await _interrupt_cleanup_loop(manager, interval=60)
# Should have been called twice: once raising, once returning ()
assert manager.cleanup_expired.call_count == 2
@pytest.mark.unit
@pytest.mark.asyncio
async def test_cleanup_loop_logs_expired_count(capsys: pytest.CaptureFixture[str]) -> None:
"""The loop should log when expired interrupts are found."""
fake_record = MagicMock()
manager = MagicMock()
manager.cleanup_expired.return_value = (fake_record, fake_record)
call_count = 0
original_sleep = asyncio.sleep
async def _fake_sleep(seconds: float) -> None:
nonlocal call_count
call_count += 1
if call_count >= 2:
raise asyncio.CancelledError
await original_sleep(0)
with patch("app.main.asyncio.sleep", side_effect=_fake_sleep):
with pytest.raises(asyncio.CancelledError):
await _interrupt_cleanup_loop(manager, interval=60)
captured = capsys.readouterr()
assert "2 expired interrupt" in captured.out

View File

@@ -0,0 +1,20 @@
"""Tests for structured logging configuration."""
from __future__ import annotations
import pytest
from app.logging_config import configure_logging
pytestmark = pytest.mark.unit
def test_configure_logging_console_mode() -> None:
"""Console mode configures without error."""
configure_logging("console")
def test_configure_logging_json_mode() -> None:
"""JSON mode configures without error."""
configure_logging("json")

View File

@@ -13,7 +13,7 @@ class TestMainModule:
assert app.title == "Smart Support" assert app.title == "Smart Support"
def test_app_version(self) -> None: def test_app_version(self) -> None:
assert app.version == "0.5.0" assert app.version == "0.6.0"
def test_agents_yaml_path_exists(self) -> None: def test_agents_yaml_path_exists(self) -> None:
assert AGENTS_YAML.name == "agents.yaml" assert AGENTS_YAML.name == "agents.yaml"
@@ -36,7 +36,7 @@ class TestMainModule:
def test_health_route_registered(self) -> None: def test_health_route_registered(self) -> None:
routes = [r.path for r in app.routes if hasattr(r, "path")] routes = [r.path for r in app.routes if hasattr(r, "path")]
assert "/api/health" in routes assert "/api/v1/health" in routes
def test_app_version_is_0_5_0(self) -> None: def test_app_version_is_0_5_0(self) -> None:
assert app.version == "0.5.0" assert app.version == "0.6.0"

View File

@@ -0,0 +1,96 @@
"""Tests for app.safety module -- confirmation rules and MCP error taxonomy."""
from __future__ import annotations
import pytest
from app.safety import (
classify_mcp_error,
is_retryable,
max_retries,
requires_confirmation,
)
pytestmark = pytest.mark.unit
class TestRequiresConfirmation:
def test_read_agent_no_override(self) -> None:
result = requires_confirmation(agent_permission="read")
assert result.requires_confirmation is False
def test_write_agent_no_override(self) -> None:
result = requires_confirmation(agent_permission="write")
assert result.requires_confirmation is True
def test_interrupt_override_true(self) -> None:
result = requires_confirmation(
agent_permission="read", needs_interrupt=True,
)
assert result.requires_confirmation is True
def test_interrupt_override_false(self) -> None:
result = requires_confirmation(
agent_permission="write", needs_interrupt=False,
)
assert result.requires_confirmation is False
class TestClassifyMcpError:
@pytest.mark.parametrize("code", [408, 429, 500, 502, 503, 504])
def test_transient_status_codes(self, code: int) -> None:
assert classify_mcp_error(status_code=code) == "transient"
@pytest.mark.parametrize("code", [401, 403])
def test_auth_status_codes(self, code: int) -> None:
assert classify_mcp_error(status_code=code) == "auth"
@pytest.mark.parametrize("code", [400, 404, 422])
def test_validation_status_codes(self, code: int) -> None:
assert classify_mcp_error(status_code=code) == "validation"
def test_unknown_status_code(self) -> None:
assert classify_mcp_error(status_code=200) == "unknown"
def test_timeout_message(self) -> None:
assert classify_mcp_error(error_message="Connection timed out") == "transient"
def test_rate_limit_message(self) -> None:
assert classify_mcp_error(error_message="Rate limit exceeded") == "transient"
def test_unauthorized_message(self) -> None:
assert classify_mcp_error(error_message="Unauthorized access") == "auth"
def test_invalid_message(self) -> None:
assert classify_mcp_error(error_message="Invalid parameter") == "validation"
def test_unknown_message(self) -> None:
assert classify_mcp_error(error_message="Something happened") == "unknown"
def test_status_code_takes_precedence_over_message(self) -> None:
# 429 is transient by code; message would classify as validation
assert classify_mcp_error(status_code=429, error_message="invalid param") == "transient"
def test_non_classified_status_falls_through_to_message(self) -> None:
# 200 is not in any status set, so message classification takes over
assert classify_mcp_error(status_code=200, error_message="timed out") == "transient"
def test_no_args_returns_unknown(self) -> None:
assert classify_mcp_error() == "unknown"
class TestRetryPolicy:
def test_transient_is_retryable(self) -> None:
assert is_retryable("transient") is True
def test_validation_not_retryable(self) -> None:
assert is_retryable("validation") is False
def test_auth_not_retryable(self) -> None:
assert is_retryable("auth") is False
def test_unknown_not_retryable(self) -> None:
assert is_retryable("unknown") is False
def test_max_retries_value(self) -> None:
assert max_retries() == 3
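(The taxonomy is fully pinned by the cases above; a sketch that satisfies every assertion. The set and keyword names are illustrative: status codes take precedence, falling through to message keywords only when the code matches no set.)

_TRANSIENT_CODES = {408, 429, 500, 502, 503, 504}
_AUTH_CODES = {401, 403}
_VALIDATION_CODES = {400, 404, 422}

def classify_mcp_error(status_code: int | None = None,
                       error_message: str | None = None) -> str:
    if status_code in _TRANSIENT_CODES:
        return "transient"
    if status_code in _AUTH_CODES:
        return "auth"
    if status_code in _VALIDATION_CODES:
        return "validation"
    msg = (error_message or "").lower()
    if "timed out" in msg or "rate limit" in msg:
        return "transient"
    if "unauthorized" in msg:
        return "auth"
    if "invalid" in msg:
        return "validation"
    return "unknown"

def is_retryable(category: str) -> bool:
    # Only transient failures are worth retrying.
    return category == "transient"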

View File

@@ -8,8 +8,10 @@
 from unittest.mock import AsyncMock, MagicMock

 import pytest

 from app.callbacks import TokenUsageCallbackHandler
+from app.graph_context import GraphContext
 from app.interrupt_manager import InterruptManager
 from app.session_manager import SessionManager
+from app.ws_context import WebSocketContext
 from app.ws_handler import (
     _extract_interrupt,
     _has_interrupt,
@@ -25,18 +27,42 @@ def _make_ws() -> AsyncMock:
     return ws

-def _make_graph() -> AsyncMock:
+def _make_graph() -> MagicMock:
     graph = AsyncMock()
     graph.astream = MagicMock(return_value=AsyncIterHelper([]))
     state = MagicMock()
     state.tasks = ()
     graph.aget_state = AsyncMock(return_value=state)
-    # Phase 2: graph needs intent_classifier and agent_registry attrs
-    graph.intent_classifier = None
-    graph.agent_registry = None
     return graph

+def _make_graph_ctx(graph: MagicMock | None = None) -> GraphContext:
+    g = graph or _make_graph()
+    registry = MagicMock()
+    registry.list_agents = MagicMock(return_value=())
+    return GraphContext(graph=g, registry=registry, intent_classifier=None)
+
+def _make_ws_ctx(
+    graph_ctx: GraphContext | None = None,
+    sm: SessionManager | None = None,
+    cb: TokenUsageCallbackHandler | None = None,
+    interrupt_manager: InterruptManager | None = None,
+    analytics_recorder=None,
+    conversation_tracker=None,
+    pool=None,
+) -> WebSocketContext:
+    return WebSocketContext(
+        graph_ctx=graph_ctx or _make_graph_ctx(),
+        session_manager=sm or SessionManager(),
+        callback_handler=cb or TokenUsageCallbackHandler(),
+        interrupt_manager=interrupt_manager,
+        analytics_recorder=analytics_recorder,
+        conversation_tracker=conversation_tracker,
+        pool=pool,
+    )

 class AsyncIterHelper:
     """Helper to make a list behave as an async iterator."""
@@ -57,11 +83,9 @@ class TestDispatchMessage:
     @pytest.mark.asyncio
     async def test_invalid_json(self) -> None:
         ws = _make_ws()
-        graph = _make_graph()
-        sm = SessionManager()
-        cb = TokenUsageCallbackHandler()
+        ws_ctx = _make_ws_ctx()

-        await dispatch_message(ws, graph, sm, cb, "not json")
+        await dispatch_message(ws, ws_ctx, "not json")

         ws.send_json.assert_awaited_once()
         call_data = ws.send_json.call_args[0][0]
         assert call_data["type"] == "error"
@@ -70,12 +94,10 @@ class TestDispatchMessage:
     @pytest.mark.asyncio
     async def test_missing_thread_id(self) -> None:
         ws = _make_ws()
-        graph = _make_graph()
-        sm = SessionManager()
-        cb = TokenUsageCallbackHandler()
+        ws_ctx = _make_ws_ctx()

         msg = json.dumps({"type": "message", "content": "hello"})
-        await dispatch_message(ws, graph, sm, cb, msg)
+        await dispatch_message(ws, ws_ctx, msg)

         call_data = ws.send_json.call_args[0][0]
         assert call_data["type"] == "error"
         assert "thread_id" in call_data["message"]
@@ -83,24 +105,20 @@ class TestDispatchMessage:
     @pytest.mark.asyncio
     async def test_missing_content(self) -> None:
         ws = _make_ws()
-        graph = _make_graph()
-        sm = SessionManager()
-        cb = TokenUsageCallbackHandler()
+        ws_ctx = _make_ws_ctx()

         msg = json.dumps({"type": "message", "thread_id": "t1"})
-        await dispatch_message(ws, graph, sm, cb, msg)
+        await dispatch_message(ws, ws_ctx, msg)

         call_data = ws.send_json.call_args[0][0]
         assert call_data["type"] == "error"

     @pytest.mark.asyncio
     async def test_unknown_message_type(self) -> None:
         ws = _make_ws()
-        graph = _make_graph()
-        sm = SessionManager()
-        cb = TokenUsageCallbackHandler()
+        ws_ctx = _make_ws_ctx()

         msg = json.dumps({"type": "unknown", "thread_id": "t1"})
-        await dispatch_message(ws, graph, sm, cb, msg)
+        await dispatch_message(ws, ws_ctx, msg)

         call_data = ws.send_json.call_args[0][0]
         assert call_data["type"] == "error"
         assert "Unknown" in call_data["message"]
@@ -108,12 +126,10 @@ class TestDispatchMessage:
     @pytest.mark.asyncio
     async def test_message_too_large(self) -> None:
         ws = _make_ws()
-        graph = _make_graph()
-        sm = SessionManager()
-        cb = TokenUsageCallbackHandler()
+        ws_ctx = _make_ws_ctx()

         large_msg = "x" * 40_000
-        await dispatch_message(ws, graph, sm, cb, large_msg)
+        await dispatch_message(ws, ws_ctx, large_msg)

         call_data = ws.send_json.call_args[0][0]
         assert call_data["type"] == "error"
         assert "too large" in call_data["message"].lower()
@@ -121,12 +137,10 @@ class TestDispatchMessage:
     @pytest.mark.asyncio
     async def test_invalid_thread_id_format(self) -> None:
         ws = _make_ws()
-        graph = _make_graph()
-        sm = SessionManager()
-        cb = TokenUsageCallbackHandler()
+        ws_ctx = _make_ws_ctx()

         msg = json.dumps({"type": "message", "thread_id": "../../../etc", "content": "hi"})
-        await dispatch_message(ws, graph, sm, cb, msg)
+        await dispatch_message(ws, ws_ctx, msg)

         call_data = ws.send_json.call_args[0][0]
         assert call_data["type"] == "error"
         assert "thread_id" in call_data["message"].lower()
@@ -134,12 +148,10 @@ class TestDispatchMessage:
     @pytest.mark.asyncio
     async def test_content_too_long(self) -> None:
         ws = _make_ws()
-        graph = _make_graph()
-        sm = SessionManager()
-        cb = TokenUsageCallbackHandler()
+        ws_ctx = _make_ws_ctx()

         msg = json.dumps({"type": "message", "thread_id": "t1", "content": "x" * 10001})
-        await dispatch_message(ws, graph, sm, cb, msg)
+        await dispatch_message(ws, ws_ctx, msg)

         call_data = ws.send_json.call_args[0][0]
         assert call_data["type"] == "error"
         assert "too long" in call_data["message"].lower()
@@ -147,14 +159,13 @@ class TestDispatchMessage:
     @pytest.mark.asyncio
     async def test_dispatch_with_interrupt_manager(self) -> None:
         ws = _make_ws()
-        graph = _make_graph()
         sm = SessionManager()
-        cb = TokenUsageCallbackHandler()
         im = InterruptManager()
+        ws_ctx = _make_ws_ctx(sm=sm, interrupt_manager=im)
         sm.touch("t1")

         msg = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"})
-        await dispatch_message(ws, graph, sm, cb, msg, interrupt_manager=im)
+        await dispatch_message(ws, ws_ctx, msg)

         last_call = ws.send_json.call_args[0][0]
         assert last_call["type"] == "message_complete"
@@ -164,14 +175,14 @@ class TestHandleUserMessage:
     @pytest.mark.asyncio
     async def test_expired_session(self) -> None:
         ws = _make_ws()
-        graph = _make_graph()
+        graph_ctx = _make_graph_ctx()
         sm = SessionManager(session_ttl_seconds=0)
         cb = TokenUsageCallbackHandler()

         # First call creates the session (TTL=0)
-        await handle_user_message(ws, graph, sm, cb, "t1", "hello")
+        await handle_user_message(ws, graph_ctx, sm, cb, "t1", "hello")
         # Second call finds it expired
-        await handle_user_message(ws, graph, sm, cb, "t1", "hello again")
+        await handle_user_message(ws, graph_ctx, sm, cb, "t1", "hello again")

         call_data = ws.send_json.call_args[0][0]
         assert call_data["type"] == "error"
         assert "expired" in call_data["message"].lower()
@@ -179,12 +190,12 @@ class TestHandleUserMessage:
     @pytest.mark.asyncio
     async def test_successful_message(self) -> None:
         ws = _make_ws()
-        graph = _make_graph()
+        graph_ctx = _make_graph_ctx()
         sm = SessionManager()
         cb = TokenUsageCallbackHandler()
         sm.touch("t1")

-        await handle_user_message(ws, graph, sm, cb, "t1", "hello")
+        await handle_user_message(ws, graph_ctx, sm, cb, "t1", "hello")

         last_call = ws.send_json.call_args[0][0]
         assert last_call["type"] == "message_complete"
@@ -193,13 +204,12 @@ class TestHandleUserMessage:
ws = _make_ws() ws = _make_ws()
graph = AsyncMock() graph = AsyncMock()
graph.astream = MagicMock(side_effect=RuntimeError("boom")) graph.astream = MagicMock(side_effect=RuntimeError("boom"))
graph.intent_classifier = None graph_ctx = _make_graph_ctx(graph=graph)
graph.agent_registry = None
sm = SessionManager() sm = SessionManager()
cb = TokenUsageCallbackHandler() cb = TokenUsageCallbackHandler()
sm.touch("t1") sm.touch("t1")
await handle_user_message(ws, graph, sm, cb, "t1", "hello") await handle_user_message(ws, graph_ctx, sm, cb, "t1", "hello")
call_data = ws.send_json.call_args[0][0] call_data = ws.send_json.call_args[0][0]
assert call_data["type"] == "error" assert call_data["type"] == "error"
@@ -207,8 +217,6 @@ class TestHandleUserMessage:
async def test_interrupt_registered_with_manager(self) -> None: async def test_interrupt_registered_with_manager(self) -> None:
ws = _make_ws() ws = _make_ws()
graph = AsyncMock() graph = AsyncMock()
graph.intent_classifier = None
graph.agent_registry = None
graph.astream = MagicMock(return_value=AsyncIterHelper([])) graph.astream = MagicMock(return_value=AsyncIterHelper([]))
# Simulate interrupt in state # Simulate interrupt in state
@@ -220,13 +228,14 @@ class TestHandleUserMessage:
state.tasks = (task,) state.tasks = (task,)
graph.aget_state = AsyncMock(return_value=state) graph.aget_state = AsyncMock(return_value=state)
graph_ctx = _make_graph_ctx(graph=graph)
sm = SessionManager() sm = SessionManager()
cb = TokenUsageCallbackHandler() cb = TokenUsageCallbackHandler()
im = InterruptManager() im = InterruptManager()
sm.touch("t1") sm.touch("t1")
await handle_user_message( await handle_user_message(
ws, graph, sm, cb, "t1", "cancel order 1042", interrupt_manager=im, ws, graph_ctx, sm, cb, "t1", "cancel order 1042", interrupt_manager=im,
) )
# Interrupt should be registered # Interrupt should be registered
@@ -257,16 +266,17 @@ class TestHandleUserMessage:
clarification_question="What do you mean?", clarification_question="What do you mean?",
) )
) )
graph.intent_classifier = mock_classifier
mock_registry = MagicMock() mock_registry = MagicMock()
mock_registry.list_agents = MagicMock(return_value=()) mock_registry.list_agents = MagicMock(return_value=())
graph.agent_registry = mock_registry graph_ctx = GraphContext(
graph=graph, registry=mock_registry, intent_classifier=mock_classifier,
)
sm = SessionManager() sm = SessionManager()
cb = TokenUsageCallbackHandler() cb = TokenUsageCallbackHandler()
sm.touch("t1") sm.touch("t1")
await handle_user_message(ws, graph, sm, cb, "t1", "hmm") await handle_user_message(ws, graph_ctx, sm, cb, "t1", "hmm")
calls = [c[0][0] for c in ws.send_json.call_args_list] calls = [c[0][0] for c in ws.send_json.call_args_list]
clarification_msgs = [c for c in calls if c.get("type") == "clarification"] clarification_msgs = [c for c in calls if c.get("type") == "clarification"]
@@ -279,13 +289,13 @@ class TestHandleInterruptResponse:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_approved_interrupt(self) -> None: async def test_approved_interrupt(self) -> None:
ws = _make_ws() ws = _make_ws()
graph = _make_graph() graph_ctx = _make_graph_ctx()
sm = SessionManager() sm = SessionManager()
cb = TokenUsageCallbackHandler() cb = TokenUsageCallbackHandler()
sm.touch("t1") sm.touch("t1")
sm.extend_for_interrupt("t1") sm.extend_for_interrupt("t1")
await handle_interrupt_response(ws, graph, sm, cb, "t1", True) await handle_interrupt_response(ws, graph_ctx, sm, cb, "t1", True)
last_call = ws.send_json.call_args[0][0] last_call = ws.send_json.call_args[0][0]
assert last_call["type"] == "message_complete" assert last_call["type"] == "message_complete"
@@ -294,7 +304,7 @@ class TestHandleInterruptResponse:
from unittest.mock import patch from unittest.mock import patch
ws = _make_ws() ws = _make_ws()
graph = _make_graph() graph_ctx = _make_graph_ctx()
sm = SessionManager() sm = SessionManager()
cb = TokenUsageCallbackHandler() cb = TokenUsageCallbackHandler()
im = InterruptManager(ttl_seconds=5) im = InterruptManager(ttl_seconds=5)
@@ -307,7 +317,7 @@ class TestHandleInterruptResponse:
with patch("app.interrupt_manager.time") as mock_time: with patch("app.interrupt_manager.time") as mock_time:
mock_time.time.return_value = im._interrupts["t1"].created_at + 10 mock_time.time.return_value = im._interrupts["t1"].created_at + 10
await handle_interrupt_response( await handle_interrupt_response(
ws, graph, sm, cb, "t1", True, interrupt_manager=im ws, graph_ctx, sm, cb, "t1", True, interrupt_manager=im
) )
call_data = ws.send_json.call_args[0][0] call_data = ws.send_json.call_args[0][0]
@@ -317,7 +327,7 @@ class TestHandleInterruptResponse:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_valid_interrupt_resolves(self) -> None: async def test_valid_interrupt_resolves(self) -> None:
ws = _make_ws() ws = _make_ws()
graph = _make_graph() graph_ctx = _make_graph_ctx()
sm = SessionManager() sm = SessionManager()
cb = TokenUsageCallbackHandler() cb = TokenUsageCallbackHandler()
im = InterruptManager(ttl_seconds=1800) im = InterruptManager(ttl_seconds=1800)
@@ -327,7 +337,7 @@ class TestHandleInterruptResponse:
im.register("t1", "cancel_order", {}) im.register("t1", "cancel_order", {})
await handle_interrupt_response( await handle_interrupt_response(
ws, graph, sm, cb, "t1", True, interrupt_manager=im ws, graph_ctx, sm, cb, "t1", True, interrupt_manager=im
) )
# Interrupt should be resolved # Interrupt should be resolved
@@ -374,19 +384,14 @@ class TestDispatchMessageWithTracking:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_conversation_tracker_called_on_message(self) -> None: async def test_conversation_tracker_called_on_message(self) -> None:
ws = _make_ws() ws = _make_ws()
graph = _make_graph()
sm = SessionManager() sm = SessionManager()
cb = TokenUsageCallbackHandler()
tracker = AsyncMock() tracker = AsyncMock()
pool = MagicMock() pool = MagicMock()
ws_ctx = _make_ws_ctx(sm=sm, conversation_tracker=tracker, pool=pool)
sm.touch("t1") sm.touch("t1")
msg = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"}) msg = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"})
await dispatch_message( await dispatch_message(ws, ws_ctx, msg)
ws, graph, sm, cb, msg,
conversation_tracker=tracker,
pool=pool,
)
tracker.ensure_conversation.assert_awaited_once_with(pool, "t1") tracker.ensure_conversation.assert_awaited_once_with(pool, "t1")
tracker.record_turn.assert_awaited_once() tracker.record_turn.assert_awaited_once()
@@ -394,53 +399,42 @@ class TestDispatchMessageWithTracking:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_analytics_recorder_called_on_message(self) -> None: async def test_analytics_recorder_called_on_message(self) -> None:
ws = _make_ws() ws = _make_ws()
graph = _make_graph()
sm = SessionManager() sm = SessionManager()
cb = TokenUsageCallbackHandler()
recorder = AsyncMock() recorder = AsyncMock()
pool = MagicMock() pool = MagicMock()
ws_ctx = _make_ws_ctx(sm=sm, analytics_recorder=recorder, pool=pool)
sm.touch("t1") sm.touch("t1")
msg = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"}) msg = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"})
await dispatch_message( await dispatch_message(ws, ws_ctx, msg)
ws, graph, sm, cb, msg,
analytics_recorder=recorder,
pool=pool,
)
recorder.record.assert_awaited_once() recorder.record.assert_awaited_once()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_tracker_failure_does_not_break_chat(self) -> None: async def test_tracker_failure_does_not_break_chat(self) -> None:
ws = _make_ws() ws = _make_ws()
graph = _make_graph()
sm = SessionManager() sm = SessionManager()
cb = TokenUsageCallbackHandler()
tracker = AsyncMock() tracker = AsyncMock()
tracker.ensure_conversation.side_effect = RuntimeError("DB down") tracker.ensure_conversation.side_effect = RuntimeError("DB down")
pool = MagicMock() pool = MagicMock()
ws_ctx = _make_ws_ctx(sm=sm, conversation_tracker=tracker, pool=pool)
sm.touch("t1") sm.touch("t1")
msg = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"}) msg = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"})
# Should not raise despite tracker failure # Should not raise despite tracker failure
await dispatch_message( await dispatch_message(ws, ws_ctx, msg)
ws, graph, sm, cb, msg,
conversation_tracker=tracker,
pool=pool,
)
last_call = ws.send_json.call_args[0][0] last_call = ws.send_json.call_args[0][0]
assert last_call["type"] == "message_complete" assert last_call["type"] == "message_complete"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_no_tracker_no_error(self) -> None: async def test_no_tracker_no_error(self) -> None:
ws = _make_ws() ws = _make_ws()
graph = _make_graph()
sm = SessionManager() sm = SessionManager()
cb = TokenUsageCallbackHandler() ws_ctx = _make_ws_ctx(sm=sm)
sm.touch("t1") sm.touch("t1")
msg = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"}) msg = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"})
# No tracker or recorder passed -- should work fine # No tracker or recorder passed -- should work fine
await dispatch_message(ws, graph, sm, cb, msg) await dispatch_message(ws, ws_ctx, msg)
last_call = ws.send_json.call_args[0][0] last_call = ws.send_json.call_args[0][0]
assert last_call["type"] == "message_complete" assert last_call["type"] == "message_complete"
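For orientation while reading the updated tests: a minimal sketch of what `_make_graph_ctx` / `_make_ws_ctx` presumably build. The frozen-dataclass shape and any field names beyond `graph`, `registry`, `intent_classifier`, and the keyword arguments used above are guesses, not the project's actual definitions.

```python
from dataclasses import dataclass
from typing import Any, Optional

@dataclass(frozen=True)
class GraphContext:
    """Typed bundle replacing monkey-patched attributes on the compiled graph."""
    graph: Any
    registry: Optional[Any] = None
    intent_classifier: Optional[Any] = None

@dataclass(frozen=True)
class WebSocketContext:
    """Single immutable object replacing the former 9-parameter dispatch signature."""
    graph_ctx: GraphContext
    sm: Any                                   # SessionManager (or Pg variant)
    cb: Any                                   # TokenUsageCallbackHandler
    interrupt_manager: Optional[Any] = None
    conversation_tracker: Optional[Any] = None
    analytics_recorder: Optional[Any] = None
    pool: Optional[Any] = None

def _make_ws_ctx(sm: Any = None, **deps: Any) -> WebSocketContext:
    # Test helper (sketch): sensible defaults, selective overrides per test.
    return WebSocketContext(
        graph_ctx=GraphContext(graph=object()),
        sm=sm if sm is not None else object(),
        cb=object(),
        **deps,
    )
```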

View File

@@ -5,7 +5,7 @@ services:
      POSTGRES_DB: smart_support
      POSTGRES_USER: smart_support
      POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-dev_password}
-   # ports: ["5432:5432"]  # Uncomment for local dev DB access only
+   ports: ["5433:5432"]  # Local dev: expose on 5433 to match backend/.env
    volumes:
      - pgdata:/var/lib/postgresql/data
    healthcheck:
@@ -41,7 +41,7 @@ services:
      postgres:
        condition: service_healthy
    healthcheck:
-     test: ["CMD-SHELL", "curl -f http://localhost:8000/api/health || exit 1"]
+     test: ["CMD-SHELL", "curl -f http://localhost:8000/api/v1/health || exit 1"]
      interval: 10s
      timeout: 5s
      retries: 5

View File

@@ -63,7 +63,7 @@ Smart Support connects to internal systems over the MCP protocol, raising the automation rate above 60%+
         v
+--------+--------------------+
|   LangGraph Supervisor      |  (agent orchestration + intent routing)
-|  langgraph-supervisor v1.1  |
+| langgraph-supervisor 0.0.30+|
+--------+--------------------+
         |
    +----+----+----+----+
@@ -99,7 +99,12 @@ smart-support/
├── backend/
│   ├── app/
│   │   ├── main.py            # FastAPI + WebSocket entry point
-│   │   ├── graph.py           # LangGraph Supervisor configuration
+│   │   ├── graph.py           # LangGraph Supervisor construction
+│   │   ├── graph_context.py   # GraphContext: typed wrapper for graph + classifier + registry
+│   │   ├── ws_handler.py      # WebSocket message dispatch + rate limiting
+│   │   ├── ws_context.py      # WebSocketContext: WS dependency bundle
+│   │   ├── auth.py            # API key authentication middleware
+│   │   ├── api_utils.py       # Shared API response helpers (envelope)
│   │   ├── agents/            # Agent definitions + tool bindings
│   │   ├── registry.py        # YAML agent registry loader
│   │   ├── openapi/           # OpenAPI parsing + MCP server generation
@@ -139,7 +144,11 @@ smart-support/
| Module | Responsibility |
|------|------|
| main.py | Application entry point, WebSocket endpoint, static file serving |
-| WebSocket Handler | Bidirectional communication: receive user messages, stream tokens back, handle interrupt responses |
+| auth.py | API key authentication: admin endpoints via the `X-API-Key` header, WebSocket via the `?token=` query param |
+| ws_handler.py | Bidirectional communication: receive user messages, stream tokens back, handle interrupt responses |
+| graph_context.py | Typed wrapper: binds the compiled graph to its classifier and registry, replacing monkey-patching |
+| ws_context.py | Dependency bundle: packs the 9 dependencies of WebSocket handling into a single immutable object |
+| api_utils.py | Shared response format: unified `envelope()` function |

### 2.3 Agent Orchestration Layer (LangGraph)
@@ -315,7 +324,7 @@ Agent invokes a write-operation tool
|------|------|------|
| Language | Python 3.11+ | Preferred language of the LangGraph/LangChain ecosystem, most mature for agent development |
| Web framework | FastAPI | Native async, WebSocket support, excellent performance |
-| Agent orchestration | LangGraph v1.1 + langgraph-supervisor | Built-in supervisor pattern, middleware support, no reinvented wheels |
+| Agent orchestration | LangGraph 1.x + langgraph-supervisor | Built-in supervisor pattern, middleware support, no reinvented wheels |
| MCP integration | langchain-mcp-adapters | MultiServerMCPClient manages multiple MCP connections |
| Local tools | LangChain @tool decorator | Plain Python functions become tools, no MCP overhead |
| State persistence | PostgresSaver (langgraph-checkpoint-postgres v3.0.5) | PostgreSQL from day one, supports replay/analytics queries |
@@ -427,6 +436,19 @@ CREATE INDEX idx_interrupts_ttl ON interrupts(ttl_expires_at)
WHERE status = 'pending';
```

+#### sessions (custom - session state persistence)
+```sql
+-- PostgreSQL-backed session state for multi-worker deployments
+-- PgSessionManager uses this table in place of the in-memory dict
+CREATE TABLE sessions (
+    thread_id TEXT PRIMARY KEY,
+    last_activity TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+    has_pending_interrupt BOOLEAN NOT NULL DEFAULT FALSE,
+    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
+);
+```
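To make the table concrete, a sketch of how a PostgreSQL-backed session manager could use it via `asyncpg`. The real `PgSessionManager` interface may differ (the in-memory manager exercised in the tests is synchronous, for one), so treat the names and signatures here as assumptions.

```python
import asyncpg  # assumed driver; the project may use a different pool type

class PgSessionManagerSketch:
    """Illustrative only: sessions-table-backed variant of SessionManager."""

    def __init__(self, pool: asyncpg.Pool, session_ttl_seconds: int = 1800) -> None:
        self._pool = pool
        self._ttl = session_ttl_seconds

    async def touch(self, thread_id: str) -> None:
        # Upsert keeps last_activity fresh and visible to every worker.
        await self._pool.execute(
            "INSERT INTO sessions (thread_id) VALUES ($1) "
            "ON CONFLICT (thread_id) DO UPDATE SET last_activity = NOW()",
            thread_id,
        )

    async def is_expired(self, thread_id: str) -> bool:
        row = await self._pool.fetchrow(
            "SELECT last_activity < NOW() - make_interval(secs => $2) AS expired "
            "FROM sessions WHERE thread_id = $1",
            thread_id,
            float(self._ttl),
        )
        return bool(row and row["expired"])
```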
#### analytics_events (custom - analytics event stream)
```sql

View File

@@ -9,33 +9,38 @@ specialist with a specific role, permission level, and set of tools it can call.
```yaml
agents:
-  - name: order_agent
-    description: "Handles order status, tracking, and cancellations."
-    permission: write
+  - name: order_lookup
+    description: "Looks up order status and tracking information."
+    permission: read
     tools:
       - get_order_status
-      - cancel_order
+      - get_tracking_info
     personality:
-      tone: friendly
-      greeting: "I can help with your order. What is your order number?"
-      escalation_message: "I'm escalating this to a human agent now."
+      tone: "friendly and informative"
+      greeting: "I can help you check your order status!"
+      escalation_message: "Let me connect you with our support team."

-  - name: refund_agent
-    description: "Processes refund requests."
+  - name: order_actions
+    description: "Performs order modifications like cancellations."
     permission: write
     tools:
-      - process_refund
-      - check_refund_eligibility
+      - cancel_order
     personality:
-      tone: empathetic
-      greeting: "I'm the refund specialist. How can I help?"
-      escalation_message: "I need to escalate this refund request."
+      tone: "careful and reassuring"
+      greeting: "I can help you with order changes."
+      escalation_message: "I'll connect you with a specialist."
+
+  - name: discount
+    description: "Applies discounts and generates coupon codes."
+    permission: write
+    tools:
+      - apply_discount
+      - generate_coupon

-  - name: general_agent
-    description: "Answers general questions and FAQs."
+  - name: fallback
+    description: "Handles general questions and unclear requests."
     permission: read
     tools:
-      - search_faq
       - fallback_respond
```
@@ -52,31 +57,38 @@ user messages to the right agent. Be specific.
Controls the interrupt threshold:

- `read` -- no interrupt required. Agent can act immediately.
- `write` -- requires human approval via interrupt before executing tools.
-- `admin` -- requires human approval and is logged for audit.

### `tools` (required)

-List of tool names this agent can use. Tools are registered in the agent factory.
+List of tool names this agent can use. Tools are registered in `backend/app/agents/`.
Each tool name must match a registered LangChain tool.

+Available built-in tools:
+
+- `get_order_status` -- look up order details
+- `get_tracking_info` -- get shipping/tracking info
+- `cancel_order` -- cancel an order (write)
+- `apply_discount` -- apply a discount code (write)
+- `generate_coupon` -- generate a new coupon (write)
+- `fallback_respond` -- generic fallback response
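Since each name above must resolve to a registered LangChain tool, a registration might look like the following sketch. The `@tool` decorator usage is standard `langchain_core`; the function body is an illustrative stub, not the project's implementation.

```python
from langchain_core.tools import tool

@tool
def get_tracking_info(order_id: str) -> str:
    """Get shipping/tracking info for an order."""
    # Stub: the real tool would query the order system.
    return f"Order {order_id}: shipped, estimated delivery in 2 days"
```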
### `personality` (optional)

Customizes agent behavior:

-- `tone` -- `friendly`, `formal`, `empathetic`, `technical`
-- `greeting` -- Opening message injected at session start.
-- `escalation_message` -- Message sent when the agent escalates.
+- `tone` -- free-text description of the agent's communication style
+- `greeting` -- opening message injected at session start
+- `escalation_message` -- message sent when the agent escalates

## Built-in Templates

Use `TEMPLATE_NAME` environment variable to load a pre-built agent configuration:

-| Template | Description |
-|----------|-------------|
-| `ecommerce` | Orders, refunds, shipping, product questions |
-| `saas` | Account management, billing, technical support |
-| `generic` | General-purpose FAQ and escalation |
+| Template | Filename | Description |
+|----------|----------|-------------|
+| `e-commerce` | `templates/e-commerce.yaml` | Orders, shipping, discounts |
+| `saas` | `templates/saas.yaml` | Account management, billing, support |
+| `fintech` | `templates/fintech.yaml` | Financial services support |

Example:

```bash
-TEMPLATE_NAME=ecommerce uvicorn app.main:app
+TEMPLATE_NAME=e-commerce uvicorn app.main:app
```

## Adding New Agents
@@ -98,7 +110,7 @@ matches the agent's description.
## Escalation

-Any agent can trigger escalation by calling the `escalate` tool. This:
+Any agent can trigger escalation. This:

1. Sends a webhook notification (if `WEBHOOK_URL` is configured).
2. Marks the conversation with `resolution_type = escalated`.
3. Sends the agent's `escalation_message` to the user.
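As a sketch of step 1, the webhook call could be as simple as the following; the payload fields are assumptions, and only `WEBHOOK_URL` comes from the docs:

```python
import os
import httpx

def notify_escalation(thread_id: str, reason: str) -> None:
    """Fire the escalation webhook if configured (sketch; payload shape assumed)."""
    url = os.environ.get("WEBHOOK_URL")
    if not url:
        return  # escalation still proceeds, just without the notification
    httpx.post(url, json={"thread_id": thread_id, "reason": reason}, timeout=5.0)
```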

View File

@@ -46,7 +46,7 @@ Navigate to http://localhost in your browser.
1. Open the Chat tab (default).
2. Send: **"What is the status of order 12345?"**
-   - Observe the `tool_call` indicator appear in the sidebar (order_agent calling `get_order_status`).
+   - Observe the `tool_call` indicator appear in the sidebar (`order_lookup` calling `get_order_status`).
   - The agent responds with order status.
3. Send: **"Can you cancel that order?"**
   - The system detects a write operation and shows an **Interrupt Prompt**.
@@ -61,15 +61,15 @@ Key points to highlight:
### Scene 2: Multi-Agent Routing (2 minutes)

1. Start a new browser tab (new session) or clear session storage.
-2. Send: **"I need to track my order AND request a refund for a previous order"**
-   - The supervisor detects two intents: `order_agent` and `refund_agent`.
+2. Send: **"I need to check order 12345 AND cancel order 67890"**
+   - The supervisor detects two intents: `order_lookup` (read) and `order_actions` (write).
   - Both agents run in sequence.
-   - Two interrupt prompts may appear if both operations are write-level.
+   - The cancellation triggers an interrupt prompt for human approval.

Key points to highlight:
- Intent classification detecting multiple actions
- Automatic routing to appropriate specialist agents
-- Sequential execution with confirmation gates
+- Sequential execution with confirmation gates for write operations

### Scene 3: Conversation Replay (2 minutes)

View File

@@ -54,11 +54,19 @@ Set these in production (never commit secrets):
| `ANTHROPIC_API_KEY` | Yes* | LLM provider API key |
| `LLM_PROVIDER` | Yes | `anthropic`, `openai`, or `google` |
| `LLM_MODEL` | Yes | Model name for your provider |
+| `ADMIN_API_KEY` | Recommended | API key for admin endpoints (analytics, replay, openapi, WS). Leave empty to disable auth (dev mode only) |
| `WEBHOOK_URL` | No | Escalation notification endpoint |
| `SESSION_TTL_MINUTES` | No | Session timeout (default: 30) |

*Or `OPENAI_API_KEY` / `GOOGLE_API_KEY` depending on `LLM_PROVIDER`.

+### Authentication
+
+When `ADMIN_API_KEY` is set, all admin REST endpoints require the `X-API-Key` header,
+and WebSocket connections require a `?token=<key>` query parameter.
+When unset or empty, authentication is disabled (suitable for local development only).
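Concretely, a client attaches the key as follows. This is a sketch: the analytics path is illustrative, while the `X-API-Key` header and `?token=` parameter are as documented above.

```python
import os
import httpx

api_key = os.environ["ADMIN_API_KEY"]

# Admin REST endpoint: the key travels in the X-API-Key header.
resp = httpx.get(
    "http://localhost:8000/api/v1/analytics/summary",  # illustrative path
    headers={"X-API-Key": api_key},
)
resp.raise_for_status()

# WebSocket endpoint: the key travels as a query parameter.
ws_url = f"ws://localhost:8000/ws?token={api_key}"
```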
### HTTPS

For production, place a reverse proxy (nginx, Caddy, or a load balancer) in
@@ -87,10 +95,12 @@ cat backup.sql | docker compose exec -T postgres psql -U smart_support smart_sup
### Scaling

-The backend is stateless (session state is in PostgreSQL via LangGraph's
-PostgresSaver). You can run multiple backend replicas behind a load balancer.
+The backend supports multi-worker deployments. LangGraph session state is
+persisted in PostgreSQL via PostgresSaver. For full horizontal scaling, use
+`PgSessionManager` and `PgInterruptManager` (instead of the default in-memory
+managers) to share session and interrupt state across workers.

-The WebSocket connections are session-specific. Use sticky sessions or a shared
+WebSocket connections are session-specific. Use sticky sessions or a shared
session backend if load balancing WebSockets across multiple instances.
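Wiring this up could be an environment-driven switch at startup; a sketch under the assumption that both manager pairs share an interface and the Pg variants accept a connection pool (module paths and constructor arguments are guesses):

```python
import os

from app.session_manager import SessionManager, PgSessionManager        # assumed paths
from app.interrupt_manager import InterruptManager, PgInterruptManager  # assumed paths

def build_state_managers(pool):
    """Sketch: shared-state managers for multi-worker runs, in-memory otherwise."""
    if os.environ.get("WORKERS", "1") != "1":  # hypothetical flag
        return PgSessionManager(pool), PgInterruptManager(pool)
    return SessionManager(), InterruptManager()
```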
## Manual / Development Setup
@@ -139,7 +149,7 @@ GET /api/health
Response:

```json
-{"status": "ok", "version": "0.5.0"}
+{"status": "ok", "version": "0.6.0"}
```

### WebSocket health

View File

@@ -10,7 +10,8 @@ Import a URL, review the AI-classified endpoints, approve, and your agents are l
1. **Import** -- Provide a URL to an OpenAPI 3.0 spec (JSON or YAML).
2. **Parse** -- The system downloads and parses the spec.
3. **Classify** -- An LLM classifies each endpoint's:
-   - `access_type`: `read`, `write`, or `admin`
+   - `access_type`: `read` or `write`
+   - `needs_interrupt`: whether human approval is required
   - `agent_group`: which specialist agent should handle this endpoint
4. **Review** -- You inspect and edit the classifications in the UI.
5. **Approve** -- Approved endpoints are registered as tools on the appropriate agents.
@@ -19,12 +20,12 @@ Import a URL, review the AI-classified endpoints, approve, and your agents are l
1. Navigate to the **API Review** tab.
2. Paste your OpenAPI spec URL into the import form.
-3. Click **Import**.
+3. Click **Scan Tools**.
4. Wait for the job to complete (status: `pending` -> `processing` -> `done`).
-5. Review the endpoint table:
-   - Edit `access_type` if the AI misclassified sensitivity.
-   - Edit `agent_group` to reassign an endpoint to a different agent.
-6. Click **Approve & Save** when satisfied.
+5. Review the endpoint cards grouped by agent:
+   - Edit `access_type` (Read Only / Write) if the AI misclassified sensitivity.
+   - Edit the agent assignment to reassign an endpoint to a different agent.
+6. Click **Save Configuration** when satisfied.

## Using the REST API
@@ -39,11 +40,15 @@ Content-Type: application/json
}
```

-Response:
+Response (202):

```json
{
-  "success": true,
-  "data": { "job_id": "abc123", "status": "pending" }
+  "job_id": "abc123",
+  "status": "pending",
+  "spec_url": "https://api.example.com/openapi.yaml",
+  "total_endpoints": 0,
+  "classified_count": 0,
+  "error_message": null
}
```
@@ -53,27 +58,47 @@ Response:
GET /api/openapi/jobs/{job_id}
```

-### Get job results
+### Get classifications

```http
-GET /api/openapi/jobs/{job_id}/result
+GET /api/openapi/jobs/{job_id}/classifications
```

+Response: array of classification objects with `index`, `access_type`,
+`needs_interrupt`, `agent_group`, `confidence`, `customer_params`, and `endpoint`.
+
+### Update a classification
+
+```http
+PUT /api/openapi/jobs/{job_id}/classifications/{index}
+Content-Type: application/json
+
+{
+  "access_type": "write",
+  "needs_interrupt": true,
+  "agent_group": "order_actions"
+}
+```

### Approve job

```http
POST /api/openapi/jobs/{job_id}/approve
-Content-Type: application/json
+```

+No request body. Generates tool code for each classified endpoint and produces
+an agent YAML configuration. Response includes `generated_tools_count`.
+
+Response:
+
+```json
{
-  "endpoints": [
-    {
-      "path": "/orders/{order_id}",
-      "method": "get",
-      "access_type": "read",
-      "agent_group": "order_agent"
-    }
-  ]
+  "job_id": "abc123",
+  "status": "approved",
+  "spec_url": "https://api.example.com/openapi.yaml",
+  "total_endpoints": 5,
+  "classified_count": 5,
+  "error_message": null,
+  "generated_tools_count": 5
}
```
@@ -82,18 +107,17 @@ Content-Type: application/json
| Access Type | Description | Interrupt Required |
|-------------|-------------|-------------------|
| `read` | GET operations, no side effects | No |
-| `write` | POST/PUT/PATCH that modify data | Yes |
-| `admin` | DELETE, bulk operations, sensitive writes | Yes |
+| `write` | POST/PUT/PATCH/DELETE that modify data | Yes (by default) |
+
+The `needs_interrupt` flag can be overridden per-endpoint during review.
## SSRF Protection

-All import requests are validated against an allowlist:
+All import requests are validated:
+
- Private IP ranges are blocked (10.x, 172.16.x, 192.168.x, 127.x)
-- Localhost and metadata service URLs are blocked
+- Localhost and cloud metadata service URLs are blocked
- Only `http://` and `https://` schemes are permitted
+
+To allow internal URLs (e.g., in development), set `SSRF_ALLOWLIST_HOSTS` in your environment.
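For intuition, the checks above can be expressed with the standard library alone. This sketch is not the project's implementation, but it blocks the same ranges:

```python
import ipaddress
import socket
from urllib.parse import urlparse

def is_url_allowed(url: str, allowlist: frozenset[str] = frozenset()) -> bool:
    """Reject non-HTTP schemes and private/loopback/link-local targets (sketch)."""
    parsed = urlparse(url)
    if parsed.scheme not in ("http", "https") or not parsed.hostname:
        return False
    if parsed.hostname in allowlist:  # e.g. parsed from SSRF_ALLOWLIST_HOSTS
        return True
    try:
        ip = ipaddress.ip_address(socket.gethostbyname(parsed.hostname))
    except (socket.gaierror, ValueError):
        return False
    # Covers 10.x, 172.16.x, 192.168.x, 127.x, and 169.254.x (metadata service).
    return not (ip.is_private or ip.is_loopback or ip.is_link_local)
```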
## Supported Spec Formats

- OpenAPI 3.0.x (JSON or YAML)

View File

@@ -0,0 +1,76 @@
# Engineering Improvements -- Development Log
> Status: COMPLETED
> Branch: `eng/engineering-improvements`
> Date started: 2026-04-06
> Date completed: 2026-04-06
## What Was Built
### Phase 1: Quick Wins (no new deps)
1. **Interrupt Cleanup Background Task** -- Added asyncio background task in lifespan that calls `interrupt_manager.cleanup_expired()` every 60 seconds. Prevents unbounded memory growth from expired interrupts.
2. **API Versioning** -- All REST endpoints prefixed with `/api/v1/` (was `/api/`). Updated 4 router prefixes, Docker healthcheck, all frontend fetch URLs, and all test assertions. WebSocket `/ws` endpoint unchanged.
3. **Error Response Standardization** -- Added global exception handlers for `HTTPException`, `RequestValidationError`, and `Exception`. All error responses now use the same envelope format as success responses: `{"success": false, "data": null, "error": "..."}`.
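The envelope is concrete enough to sketch one of these handlers (illustrative; the actual registrations live in `backend/app/main.py`):

```python
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse

app = FastAPI()

@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
    # Mirror the success envelope so clients parse one shape everywhere.
    return JSONResponse(
        status_code=exc.status_code,
        content={"success": False, "data": None, "error": str(exc.detail)},
    )
```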
### Phase 2: Medium Items (new deps)
4. **Alembic Database Migrations** -- Replaced inline DDL in `setup_app_tables()` with versioned Alembic migrations. Initial migration `001_initial_schema.py` captures all 4 tables + ALTER TABLE migration. `setup_app_tables()` preserved for tests. Production uses `run_alembic_migrations()`.
5. **Structured Logging** -- Replaced stdlib `logging.getLogger()` with `structlog.get_logger()` across 10 files. Added `logging_config.py` with console (dev) and JSON (production) modes. Configurable via `LOG_FORMAT` env var.
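A minimal sketch of the switch item 5 describes, under the assumption that `logging_config.py` looks roughly like this (`LOG_FORMAT` is documented; the processor list is a plausible default, not the project's exact one):

```python
import os
import structlog

def configure_logging() -> None:
    """Console renderer for dev, JSON lines for production (sketch)."""
    renderer = (
        structlog.processors.JSONRenderer()
        if os.environ.get("LOG_FORMAT") == "json"
        else structlog.dev.ConsoleRenderer()
    )
    structlog.configure(
        processors=[
            structlog.processors.add_log_level,
            structlog.processors.TimeStamper(fmt="iso"),
            renderer,
        ]
    )

logger = structlog.get_logger()
```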
### Phase 3: Test Coverage
7. **Integration Tests (+30)** -- Created 5 new test files: analytics API, replay API, OpenAPI API, error responses, session/interrupt lifecycle. Uses httpx.AsyncClient with ASGITransport for full API layer testing.
8. **Frontend Tests (+57)** -- Created 12 new test files covering all components (ChatInput, ChatMessages, InterruptPrompt, ErrorBanner, NavBar, MetricCard, ReplayTimeline, AgentAction, Layout), pages (ChatPage, ReviewPage), and hooks (useWebSocket).
## Code Structure
### New files created
- `backend/app/logging_config.py` -- structlog configuration
- `backend/alembic.ini` -- Alembic config
- `backend/alembic/env.py` -- Migration environment
- `backend/alembic/versions/001_initial_schema.py` -- Initial migration
- `backend/tests/unit/test_interrupt_cleanup.py` (3 tests)
- `backend/tests/unit/test_error_responses.py` (6 tests)
- `backend/tests/unit/test_logging_config.py` (2 tests)
- `backend/tests/integration/test_analytics_api.py` (6 tests)
- `backend/tests/integration/test_replay_api.py` (6 tests)
- `backend/tests/integration/test_openapi_api.py` (5 tests)
- `backend/tests/integration/test_error_responses.py` (5 tests)
- `backend/tests/integration/test_session_interrupt_lifecycle.py` (8 tests)
- 12 frontend test files (57 tests total)
### Modified files
- `backend/app/main.py` -- cleanup task, exception handlers, alembic, structlog
- `backend/app/db.py` -- added run_alembic_migrations()
- `backend/app/config.py` -- added log_format setting
- `backend/pyproject.toml` -- added alembic, structlog deps
- 4 router files -- `/api/v1/` prefix
- 10 files -- structlog migration
- `docker-compose.yml` -- healthcheck URL
- `frontend/src/api.ts` -- `/api/v1/` URLs
- All existing test files -- API path updates + error envelope assertions
## Test Coverage
- Backend: 557 tests (was 516), 89.75% coverage
- Unit: ~490 tests
- Integration: ~60 tests
- E2E: ~7 tests
- Frontend: 80 tests (was 23), 16 test files (was 4)
## Deviations from Plan
- Redis rate limiting deferred (single-worker sufficient for now)
- ConversationTracker verified correct by design (pool passed per-method); refactor skipped
- Coverage dropped slightly from 90.26% to 89.75% due to new alembic/logging modules with partial test coverage (still well above 80% threshold)
## Known Issues / Tech Debt
- Rate limiting remains process-global (needs Redis for multi-worker)
- Alembic migrations not tested against real PostgreSQL in CI (would need running DB)
- Frontend test coverage could be deeper (e.g., WebSocket reconnect edge cases)

View File

@@ -104,8 +104,8 @@ smart-support/
## Tech Stack

-- Python 3.11+, FastAPI, LangGraph v1.1.0
-- langgraph-supervisor, langchain-mcp-adapters, langgraph-checkpoint-postgres v3.0.5
+- Python 3.11+, FastAPI, LangGraph 1.x (currently 1.1.6)
+- langgraph-supervisor 0.0.31, langchain-mcp-adapters, langgraph-checkpoint-postgres v3.0.5
- React (frontend), PostgreSQL 16 (via Docker Compose)
- Claude Sonnet 4.6 via `ChatAnthropic` (configurable via env)
- pytest + FastAPI TestClient for backend tests

View File

@@ -14,13 +14,24 @@
"react-router-dom": "^7.13.2" "react-router-dom": "^7.13.2"
}, },
"devDependencies": { "devDependencies": {
"@testing-library/jest-dom": "^6.9.1",
"@testing-library/react": "^16.3.2",
"@types/react": "^19.0.0", "@types/react": "^19.0.0",
"@types/react-dom": "^19.0.0", "@types/react-dom": "^19.0.0",
"@vitejs/plugin-react": "^4.3.0", "@vitejs/plugin-react": "^4.3.0",
"happy-dom": "^20.8.9",
"typescript": "~5.7.0", "typescript": "~5.7.0",
"vite": "^6.2.0" "vite": "^6.2.0",
"vitest": "^4.1.2"
} }
}, },
"node_modules/@adobe/css-tools": {
"version": "4.4.4",
"resolved": "https://registry.npmjs.org/@adobe/css-tools/-/css-tools-4.4.4.tgz",
"integrity": "sha512-Elp+iwUx5rN5+Y8xLt5/GRoG20WGoDCQ/1Fb+1LiGtvwbDavuSk0jhD/eZdckHAuzcDzccnkv+rEjyWfRx18gg==",
"dev": true,
"license": "MIT"
},
"node_modules/@babel/code-frame": { "node_modules/@babel/code-frame": {
"version": "7.29.0", "version": "7.29.0",
"resolved": "https://registry.npmjs.org/@babel/code-frame/-/code-frame-7.29.0.tgz", "resolved": "https://registry.npmjs.org/@babel/code-frame/-/code-frame-7.29.0.tgz",
@@ -255,6 +266,16 @@
"@babel/core": "^7.0.0-0" "@babel/core": "^7.0.0-0"
} }
}, },
"node_modules/@babel/runtime": {
"version": "7.29.2",
"resolved": "https://registry.npmjs.org/@babel/runtime/-/runtime-7.29.2.tgz",
"integrity": "sha512-JiDShH45zKHWyGe4ZNVRrCjBz8Nh9TMmZG1kh4QTK8hCBTWBi8Da+i7s1fJw7/lYpM4ccepSNfqzZ/QvABBi5g==",
"dev": true,
"license": "MIT",
"engines": {
"node": ">=6.9.0"
}
},
"node_modules/@babel/template": { "node_modules/@babel/template": {
"version": "7.28.6", "version": "7.28.6",
"resolved": "https://registry.npmjs.org/@babel/template/-/template-7.28.6.tgz", "resolved": "https://registry.npmjs.org/@babel/template/-/template-7.28.6.tgz",
@@ -1152,6 +1173,97 @@
"win32" "win32"
] ]
}, },
"node_modules/@standard-schema/spec": {
"version": "1.1.0",
"resolved": "https://registry.npmjs.org/@standard-schema/spec/-/spec-1.1.0.tgz",
"integrity": "sha512-l2aFy5jALhniG5HgqrD6jXLi/rUWrKvqN/qJx6yoJsgKhblVd+iqqU4RCXavm/jPityDo5TCvKMnpjKnOriy0w==",
"dev": true,
"license": "MIT"
},
"node_modules/@testing-library/dom": {
"version": "10.4.1",
"resolved": "https://registry.npmjs.org/@testing-library/dom/-/dom-10.4.1.tgz",
"integrity": "sha512-o4PXJQidqJl82ckFaXUeoAW+XysPLauYI43Abki5hABd853iMhitooc6znOnczgbTYmEP6U6/y1ZyKAIsvMKGg==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@babel/code-frame": "^7.10.4",
"@babel/runtime": "^7.12.5",
"@types/aria-query": "^5.0.1",
"aria-query": "5.3.0",
"dom-accessibility-api": "^0.5.9",
"lz-string": "^1.5.0",
"picocolors": "1.1.1",
"pretty-format": "^27.0.2"
},
"engines": {
"node": ">=18"
}
},
"node_modules/@testing-library/jest-dom": {
"version": "6.9.1",
"resolved": "https://registry.npmjs.org/@testing-library/jest-dom/-/jest-dom-6.9.1.tgz",
"integrity": "sha512-zIcONa+hVtVSSep9UT3jZ5rizo2BsxgyDYU7WFD5eICBE7no3881HGeb/QkGfsJs6JTkY1aQhT7rIPC7e+0nnA==",
"dev": true,
"license": "MIT",
"dependencies": {
"@adobe/css-tools": "^4.4.0",
"aria-query": "^5.0.0",
"css.escape": "^1.5.1",
"dom-accessibility-api": "^0.6.3",
"picocolors": "^1.1.1",
"redent": "^3.0.0"
},
"engines": {
"node": ">=14",
"npm": ">=6",
"yarn": ">=1"
}
},
"node_modules/@testing-library/jest-dom/node_modules/dom-accessibility-api": {
"version": "0.6.3",
"resolved": "https://registry.npmjs.org/dom-accessibility-api/-/dom-accessibility-api-0.6.3.tgz",
"integrity": "sha512-7ZgogeTnjuHbo+ct10G9Ffp0mif17idi0IyWNVA/wcwcm7NPOD/WEHVP3n7n3MhXqxoIYm8d6MuZohYWIZ4T3w==",
"dev": true,
"license": "MIT"
},
"node_modules/@testing-library/react": {
"version": "16.3.2",
"resolved": "https://registry.npmjs.org/@testing-library/react/-/react-16.3.2.tgz",
"integrity": "sha512-XU5/SytQM+ykqMnAnvB2umaJNIOsLF3PVv//1Ew4CTcpz0/BRyy/af40qqrt7SjKpDdT1saBMc42CUok5gaw+g==",
"dev": true,
"license": "MIT",
"dependencies": {
"@babel/runtime": "^7.12.5"
},
"engines": {
"node": ">=18"
},
"peerDependencies": {
"@testing-library/dom": "^10.0.0",
"@types/react": "^18.0.0 || ^19.0.0",
"@types/react-dom": "^18.0.0 || ^19.0.0",
"react": "^18.0.0 || ^19.0.0",
"react-dom": "^18.0.0 || ^19.0.0"
},
"peerDependenciesMeta": {
"@types/react": {
"optional": true
},
"@types/react-dom": {
"optional": true
}
}
},
"node_modules/@types/aria-query": {
"version": "5.0.4",
"resolved": "https://registry.npmjs.org/@types/aria-query/-/aria-query-5.0.4.tgz",
"integrity": "sha512-rfT93uj5s0PRL7EzccGMs3brplhcrghnDoV26NqKhCAS1hVo+WdNsPvE/yb6ilfr5hi2MEk6d5EWJTKdxg8jVw==",
"dev": true,
"license": "MIT",
"peer": true
},
"node_modules/@types/babel__core": { "node_modules/@types/babel__core": {
"version": "7.20.5", "version": "7.20.5",
"resolved": "https://registry.npmjs.org/@types/babel__core/-/babel__core-7.20.5.tgz", "resolved": "https://registry.npmjs.org/@types/babel__core/-/babel__core-7.20.5.tgz",
@@ -1197,6 +1309,17 @@
"@babel/types": "^7.28.2" "@babel/types": "^7.28.2"
} }
}, },
"node_modules/@types/chai": {
"version": "5.2.3",
"resolved": "https://registry.npmjs.org/@types/chai/-/chai-5.2.3.tgz",
"integrity": "sha512-Mw558oeA9fFbv65/y4mHtXDs9bPnFMZAL/jxdPFUpOHHIXX91mcgEHbS5Lahr+pwZFR8A7GQleRWeI6cGFC2UA==",
"dev": true,
"license": "MIT",
"dependencies": {
"@types/deep-eql": "*",
"assertion-error": "^2.0.1"
}
},
"node_modules/@types/debug": { "node_modules/@types/debug": {
"version": "4.1.13", "version": "4.1.13",
"resolved": "https://registry.npmjs.org/@types/debug/-/debug-4.1.13.tgz", "resolved": "https://registry.npmjs.org/@types/debug/-/debug-4.1.13.tgz",
@@ -1206,6 +1329,13 @@
"@types/ms": "*" "@types/ms": "*"
} }
}, },
"node_modules/@types/deep-eql": {
"version": "4.0.2",
"resolved": "https://registry.npmjs.org/@types/deep-eql/-/deep-eql-4.0.2.tgz",
"integrity": "sha512-c9h9dVVMigMPc4bwTvC5dxqtqJZwQPePsWjPlpSOnojbor6pGqdk541lfA7AqFQr5pB1BRdq0juY9db81BwyFw==",
"dev": true,
"license": "MIT"
},
"node_modules/@types/estree": { "node_modules/@types/estree": {
"version": "1.0.8", "version": "1.0.8",
"resolved": "https://registry.npmjs.org/@types/estree/-/estree-1.0.8.tgz", "resolved": "https://registry.npmjs.org/@types/estree/-/estree-1.0.8.tgz",
@@ -1245,6 +1375,16 @@
"integrity": "sha512-GsCCIZDE/p3i96vtEqx+7dBUGXrc7zeSK3wwPHIaRThS+9OhWIXRqzs4d6k1SVU8g91DrNRWxWUGhp5KXQb2VA==", "integrity": "sha512-GsCCIZDE/p3i96vtEqx+7dBUGXrc7zeSK3wwPHIaRThS+9OhWIXRqzs4d6k1SVU8g91DrNRWxWUGhp5KXQb2VA==",
"license": "MIT" "license": "MIT"
}, },
"node_modules/@types/node": {
"version": "25.5.2",
"resolved": "https://registry.npmjs.org/@types/node/-/node-25.5.2.tgz",
"integrity": "sha512-tO4ZIRKNC+MDWV4qKVZe3Ql/woTnmHDr5JD8UI5hn2pwBrHEwOEMZK7WlNb5RKB6EoJ02gwmQS9OrjuFnZYdpg==",
"dev": true,
"license": "MIT",
"dependencies": {
"undici-types": "~7.18.0"
}
},
"node_modules/@types/react": { "node_modules/@types/react": {
"version": "19.2.14", "version": "19.2.14",
"resolved": "https://registry.npmjs.org/@types/react/-/react-19.2.14.tgz", "resolved": "https://registry.npmjs.org/@types/react/-/react-19.2.14.tgz",
@@ -1270,6 +1410,23 @@
"integrity": "sha512-ko/gIFJRv177XgZsZcBwnqJN5x/Gien8qNOn0D5bQU/zAzVf9Zt3BlcUiLqhV9y4ARk0GbT3tnUiPNgnTXzc/Q==", "integrity": "sha512-ko/gIFJRv177XgZsZcBwnqJN5x/Gien8qNOn0D5bQU/zAzVf9Zt3BlcUiLqhV9y4ARk0GbT3tnUiPNgnTXzc/Q==",
"license": "MIT" "license": "MIT"
}, },
"node_modules/@types/whatwg-mimetype": {
"version": "3.0.2",
"resolved": "https://registry.npmjs.org/@types/whatwg-mimetype/-/whatwg-mimetype-3.0.2.tgz",
"integrity": "sha512-c2AKvDT8ToxLIOUlN51gTiHXflsfIFisS4pO7pDPoKouJCESkhZnEy623gwP9laCy5lnLDAw1vAzu2vM2YLOrA==",
"dev": true,
"license": "MIT"
},
"node_modules/@types/ws": {
"version": "8.18.1",
"resolved": "https://registry.npmjs.org/@types/ws/-/ws-8.18.1.tgz",
"integrity": "sha512-ThVF6DCVhA8kUGy+aazFQ4kXQ7E1Ty7A3ypFOe0IcJV8O/M511G99AW24irKrW56Wt44yG9+ij8FaqoBGkuBXg==",
"dev": true,
"license": "MIT",
"dependencies": {
"@types/node": "*"
}
},
"node_modules/@ungap/structured-clone": { "node_modules/@ungap/structured-clone": {
"version": "1.3.0", "version": "1.3.0",
"resolved": "https://registry.npmjs.org/@ungap/structured-clone/-/structured-clone-1.3.0.tgz", "resolved": "https://registry.npmjs.org/@ungap/structured-clone/-/structured-clone-1.3.0.tgz",
@@ -1297,6 +1454,164 @@
"vite": "^4.2.0 || ^5.0.0 || ^6.0.0 || ^7.0.0" "vite": "^4.2.0 || ^5.0.0 || ^6.0.0 || ^7.0.0"
} }
}, },
"node_modules/@vitest/expect": {
"version": "4.1.2",
"resolved": "https://registry.npmjs.org/@vitest/expect/-/expect-4.1.2.tgz",
"integrity": "sha512-gbu+7B0YgUJ2nkdsRJrFFW6X7NTP44WlhiclHniUhxADQJH5Szt9mZ9hWnJPJ8YwOK5zUOSSlSvyzRf0u1DSBQ==",
"dev": true,
"license": "MIT",
"dependencies": {
"@standard-schema/spec": "^1.1.0",
"@types/chai": "^5.2.2",
"@vitest/spy": "4.1.2",
"@vitest/utils": "4.1.2",
"chai": "^6.2.2",
"tinyrainbow": "^3.1.0"
},
"funding": {
"url": "https://opencollective.com/vitest"
}
},
"node_modules/@vitest/mocker": {
"version": "4.1.2",
"resolved": "https://registry.npmjs.org/@vitest/mocker/-/mocker-4.1.2.tgz",
"integrity": "sha512-Ize4iQtEALHDttPRCmN+FKqOl2vxTiNUhzobQFFt/BM1lRUTG7zRCLOykG/6Vo4E4hnUdfVLo5/eqKPukcWW7Q==",
"dev": true,
"license": "MIT",
"dependencies": {
"@vitest/spy": "4.1.2",
"estree-walker": "^3.0.3",
"magic-string": "^0.30.21"
},
"funding": {
"url": "https://opencollective.com/vitest"
},
"peerDependencies": {
"msw": "^2.4.9",
"vite": "^6.0.0 || ^7.0.0 || ^8.0.0"
},
"peerDependenciesMeta": {
"msw": {
"optional": true
},
"vite": {
"optional": true
}
}
},
"node_modules/@vitest/pretty-format": {
"version": "4.1.2",
"resolved": "https://registry.npmjs.org/@vitest/pretty-format/-/pretty-format-4.1.2.tgz",
"integrity": "sha512-dwQga8aejqeuB+TvXCMzSQemvV9hNEtDDpgUKDzOmNQayl2OG241PSWeJwKRH3CiC+sESrmoFd49rfnq7T4RnA==",
"dev": true,
"license": "MIT",
"dependencies": {
"tinyrainbow": "^3.1.0"
},
"funding": {
"url": "https://opencollective.com/vitest"
}
},
"node_modules/@vitest/runner": {
"version": "4.1.2",
"resolved": "https://registry.npmjs.org/@vitest/runner/-/runner-4.1.2.tgz",
"integrity": "sha512-Gr+FQan34CdiYAwpGJmQG8PgkyFVmARK8/xSijia3eTFgVfpcpztWLuP6FttGNfPLJhaZVP/euvujeNYar36OQ==",
"dev": true,
"license": "MIT",
"dependencies": {
"@vitest/utils": "4.1.2",
"pathe": "^2.0.3"
},
"funding": {
"url": "https://opencollective.com/vitest"
}
},
"node_modules/@vitest/snapshot": {
"version": "4.1.2",
"resolved": "https://registry.npmjs.org/@vitest/snapshot/-/snapshot-4.1.2.tgz",
"integrity": "sha512-g7yfUmxYS4mNxk31qbOYsSt2F4m1E02LFqO53Xpzg3zKMhLAPZAjjfyl9e6z7HrW6LvUdTwAQR3HHfLjpko16A==",
"dev": true,
"license": "MIT",
"dependencies": {
"@vitest/pretty-format": "4.1.2",
"@vitest/utils": "4.1.2",
"magic-string": "^0.30.21",
"pathe": "^2.0.3"
},
"funding": {
"url": "https://opencollective.com/vitest"
}
},
"node_modules/@vitest/spy": {
"version": "4.1.2",
"resolved": "https://registry.npmjs.org/@vitest/spy/-/spy-4.1.2.tgz",
"integrity": "sha512-DU4fBnbVCJGNBwVA6xSToNXrkZNSiw59H8tcuUspVMsBDBST4nfvsPsEHDHGtWRRnqBERBQu7TrTKskmjqTXKA==",
"dev": true,
"license": "MIT",
"funding": {
"url": "https://opencollective.com/vitest"
}
},
"node_modules/@vitest/utils": {
"version": "4.1.2",
"resolved": "https://registry.npmjs.org/@vitest/utils/-/utils-4.1.2.tgz",
"integrity": "sha512-xw2/TiX82lQHA06cgbqRKFb5lCAy3axQ4H4SoUFhUsg+wztiet+co86IAMDtF6Vm1hc7J6j09oh/rgDn+JdKIQ==",
"dev": true,
"license": "MIT",
"dependencies": {
"@vitest/pretty-format": "4.1.2",
"convert-source-map": "^2.0.0",
"tinyrainbow": "^3.1.0"
},
"funding": {
"url": "https://opencollective.com/vitest"
}
},
"node_modules/ansi-regex": {
"version": "5.0.1",
"resolved": "https://registry.npmjs.org/ansi-regex/-/ansi-regex-5.0.1.tgz",
"integrity": "sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ==",
"dev": true,
"license": "MIT",
"peer": true,
"engines": {
"node": ">=8"
}
},
"node_modules/ansi-styles": {
"version": "5.2.0",
"resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-5.2.0.tgz",
"integrity": "sha512-Cxwpt2SfTzTtXcfOlzGEee8O+c+MmUgGrNiBcXnuWxuFJHe6a5Hz7qwhwe5OgaSYI0IJvkLqWX1ASG+cJOkEiA==",
"dev": true,
"license": "MIT",
"peer": true,
"engines": {
"node": ">=10"
},
"funding": {
"url": "https://github.com/chalk/ansi-styles?sponsor=1"
}
},
"node_modules/aria-query": {
"version": "5.3.0",
"resolved": "https://registry.npmjs.org/aria-query/-/aria-query-5.3.0.tgz",
"integrity": "sha512-b0P0sZPKtyu8HkeRAfCq0IfURZK+SuwMjY1UXGBU27wpAiTwQAIlq56IbIO+ytk/JjS1fMR14ee5WBBfKi5J6A==",
"dev": true,
"license": "Apache-2.0",
"dependencies": {
"dequal": "^2.0.3"
}
},
"node_modules/assertion-error": {
"version": "2.0.1",
"resolved": "https://registry.npmjs.org/assertion-error/-/assertion-error-2.0.1.tgz",
"integrity": "sha512-Izi8RQcffqCeNVgFigKli1ssklIbpHnCYc6AknXGYoB6grJqyeby7jv12JUQgmTAnIDnbck1uxksT4dzN3PWBA==",
"dev": true,
"license": "MIT",
"engines": {
"node": ">=12"
}
},
"node_modules/bail": { "node_modules/bail": {
"version": "2.0.2", "version": "2.0.2",
"resolved": "https://registry.npmjs.org/bail/-/bail-2.0.2.tgz", "resolved": "https://registry.npmjs.org/bail/-/bail-2.0.2.tgz",
@@ -1385,6 +1700,16 @@
"url": "https://github.com/sponsors/wooorm" "url": "https://github.com/sponsors/wooorm"
} }
}, },
"node_modules/chai": {
"version": "6.2.2",
"resolved": "https://registry.npmjs.org/chai/-/chai-6.2.2.tgz",
"integrity": "sha512-NUPRluOfOiTKBKvWPtSD4PhFvWCqOi0BGStNWs57X9js7XGTprSmFoz5F0tWhR4WPjNeR9jXqdC7/UpSJTnlRg==",
"dev": true,
"license": "MIT",
"engines": {
"node": ">=18"
}
},
"node_modules/character-entities": { "node_modules/character-entities": {
"version": "2.0.2", "version": "2.0.2",
"resolved": "https://registry.npmjs.org/character-entities/-/character-entities-2.0.2.tgz", "resolved": "https://registry.npmjs.org/character-entities/-/character-entities-2.0.2.tgz",
@@ -1455,6 +1780,13 @@
"url": "https://opencollective.com/express" "url": "https://opencollective.com/express"
} }
}, },
"node_modules/css.escape": {
"version": "1.5.1",
"resolved": "https://registry.npmjs.org/css.escape/-/css.escape-1.5.1.tgz",
"integrity": "sha512-YUifsXXuknHlUsmlgyY0PKzgPOr7/FjCePfHNt0jxm83wHZi44VDMQ7/fGNkjY3/jV1MC+1CmZbaHzugyeRtpg==",
"dev": true,
"license": "MIT"
},
"node_modules/csstype": { "node_modules/csstype": {
"version": "3.2.3", "version": "3.2.3",
"resolved": "https://registry.npmjs.org/csstype/-/csstype-3.2.3.tgz", "resolved": "https://registry.npmjs.org/csstype/-/csstype-3.2.3.tgz",
@@ -1513,6 +1845,14 @@
"url": "https://github.com/sponsors/wooorm" "url": "https://github.com/sponsors/wooorm"
} }
}, },
"node_modules/dom-accessibility-api": {
"version": "0.5.16",
"resolved": "https://registry.npmjs.org/dom-accessibility-api/-/dom-accessibility-api-0.5.16.tgz",
"integrity": "sha512-X7BJ2yElsnOJ30pZF4uIIDfBEVgF4XEBxL9Bxhy6dnrm5hkzqmsWHGTiHqRiITNhMyFLyAiWndIJP7Z1NTteDg==",
"dev": true,
"license": "MIT",
"peer": true
},
"node_modules/electron-to-chromium": { "node_modules/electron-to-chromium": {
"version": "1.5.328", "version": "1.5.328",
"resolved": "https://registry.npmjs.org/electron-to-chromium/-/electron-to-chromium-1.5.328.tgz", "resolved": "https://registry.npmjs.org/electron-to-chromium/-/electron-to-chromium-1.5.328.tgz",
@@ -1520,6 +1860,26 @@
"dev": true, "dev": true,
"license": "ISC" "license": "ISC"
}, },
"node_modules/entities": {
"version": "7.0.1",
"resolved": "https://registry.npmjs.org/entities/-/entities-7.0.1.tgz",
"integrity": "sha512-TWrgLOFUQTH994YUyl1yT4uyavY5nNB5muff+RtWaqNVCAK408b5ZnnbNAUEWLTCpum9w6arT70i1XdQ4UeOPA==",
"dev": true,
"license": "BSD-2-Clause",
"engines": {
"node": ">=0.12"
},
"funding": {
"url": "https://github.com/fb55/entities?sponsor=1"
}
},
"node_modules/es-module-lexer": {
"version": "2.0.0",
"resolved": "https://registry.npmjs.org/es-module-lexer/-/es-module-lexer-2.0.0.tgz",
"integrity": "sha512-5POEcUuZybH7IdmGsD8wlf0AI55wMecM9rVBTI/qEAy2c1kTOm3DjFYjrBdI2K3BaJjJYfYFeRtM0t9ssnRuxw==",
"dev": true,
"license": "MIT"
},
"node_modules/esbuild": { "node_modules/esbuild": {
"version": "0.25.12", "version": "0.25.12",
"resolved": "https://registry.npmjs.org/esbuild/-/esbuild-0.25.12.tgz", "resolved": "https://registry.npmjs.org/esbuild/-/esbuild-0.25.12.tgz",
@@ -1582,6 +1942,26 @@
"url": "https://opencollective.com/unified" "url": "https://opencollective.com/unified"
} }
}, },
"node_modules/estree-walker": {
"version": "3.0.3",
"resolved": "https://registry.npmjs.org/estree-walker/-/estree-walker-3.0.3.tgz",
"integrity": "sha512-7RUKfXgSMMkzt6ZuXmqapOurLGPPfgj6l9uRZ7lRGolvk0y2yocc35LdcxKC5PQZdn2DMqioAQ2NoWcrTKmm6g==",
"dev": true,
"license": "MIT",
"dependencies": {
"@types/estree": "^1.0.0"
}
},
"node_modules/expect-type": {
"version": "1.3.0",
"resolved": "https://registry.npmjs.org/expect-type/-/expect-type-1.3.0.tgz",
"integrity": "sha512-knvyeauYhqjOYvQ66MznSMs83wmHrCycNEN6Ao+2AeYEfxUIkuiVxdEa1qlGEPK+We3n0THiDciYSsCcgW/DoA==",
"dev": true,
"license": "Apache-2.0",
"engines": {
"node": ">=12.0.0"
}
},
"node_modules/extend": { "node_modules/extend": {
"version": "3.0.2", "version": "3.0.2",
"resolved": "https://registry.npmjs.org/extend/-/extend-3.0.2.tgz", "resolved": "https://registry.npmjs.org/extend/-/extend-3.0.2.tgz",
@@ -1631,6 +2011,24 @@
"node": ">=6.9.0" "node": ">=6.9.0"
} }
}, },
"node_modules/happy-dom": {
"version": "20.8.9",
"resolved": "https://registry.npmjs.org/happy-dom/-/happy-dom-20.8.9.tgz",
"integrity": "sha512-Tz23LR9T9jOGVZm2x1EPdXqwA37G/owYMxRwU0E4miurAtFsPMQ1d2Jc2okUaSjZqAFz2oEn3FLXC5a0a+siyA==",
"dev": true,
"license": "MIT",
"dependencies": {
"@types/node": ">=20.0.0",
"@types/whatwg-mimetype": "^3.0.2",
"@types/ws": "^8.18.1",
"entities": "^7.0.1",
"whatwg-mimetype": "^3.0.0",
"ws": "^8.18.3"
},
"engines": {
"node": ">=20.0.0"
}
},
"node_modules/hast-util-to-jsx-runtime": { "node_modules/hast-util-to-jsx-runtime": {
"version": "2.3.6", "version": "2.3.6",
"resolved": "https://registry.npmjs.org/hast-util-to-jsx-runtime/-/hast-util-to-jsx-runtime-2.3.6.tgz", "resolved": "https://registry.npmjs.org/hast-util-to-jsx-runtime/-/hast-util-to-jsx-runtime-2.3.6.tgz",
@@ -1681,6 +2079,16 @@
"url": "https://opencollective.com/unified" "url": "https://opencollective.com/unified"
} }
}, },
"node_modules/indent-string": {
"version": "4.0.0",
"resolved": "https://registry.npmjs.org/indent-string/-/indent-string-4.0.0.tgz",
"integrity": "sha512-EdDDZu4A2OyIK7Lr/2zG+w5jmbuk1DVBnEwREQvBzspBJkCEbRa8GxU1lghYcaGJCnRWibjDXlq779X1/y5xwg==",
"dev": true,
"license": "MIT",
"engines": {
"node": ">=8"
}
},
"node_modules/inline-style-parser": { "node_modules/inline-style-parser": {
"version": "0.2.7", "version": "0.2.7",
"resolved": "https://registry.npmjs.org/inline-style-parser/-/inline-style-parser-0.2.7.tgz", "resolved": "https://registry.npmjs.org/inline-style-parser/-/inline-style-parser-0.2.7.tgz",
@@ -1796,6 +2204,27 @@
"yallist": "^3.0.2" "yallist": "^3.0.2"
} }
}, },
"node_modules/lz-string": {
"version": "1.5.0",
"resolved": "https://registry.npmjs.org/lz-string/-/lz-string-1.5.0.tgz",
"integrity": "sha512-h5bgJWpxJNswbU7qCrV0tIKQCaS3blPDrqKWx+QxzuzL1zGUzij9XCWLrSLsJPu5t+eWA/ycetzYAO5IOMcWAQ==",
"dev": true,
"license": "MIT",
"peer": true,
"bin": {
"lz-string": "bin/bin.js"
}
},
"node_modules/magic-string": {
"version": "0.30.21",
"resolved": "https://registry.npmjs.org/magic-string/-/magic-string-0.30.21.tgz",
"integrity": "sha512-vd2F4YUyEXKGcLHoq+TEyCjxueSeHnFxyyjNp80yg0XV4vUhnDer/lvvlqM/arB5bXQN5K2/3oinyCRyx8T2CQ==",
"dev": true,
"license": "MIT",
"dependencies": {
"@jridgewell/sourcemap-codec": "^1.5.5"
}
},
"node_modules/mdast-util-from-markdown": { "node_modules/mdast-util-from-markdown": {
"version": "2.0.3", "version": "2.0.3",
"resolved": "https://registry.npmjs.org/mdast-util-from-markdown/-/mdast-util-from-markdown-2.0.3.tgz", "resolved": "https://registry.npmjs.org/mdast-util-from-markdown/-/mdast-util-from-markdown-2.0.3.tgz",
@@ -2391,6 +2820,16 @@
], ],
"license": "MIT" "license": "MIT"
}, },
"node_modules/min-indent": {
"version": "1.0.1",
"resolved": "https://registry.npmjs.org/min-indent/-/min-indent-1.0.1.tgz",
"integrity": "sha512-I9jwMn07Sy/IwOj3zVkVik2JTvgpaykDZEigL6Rx6N9LbMywwUSMtxET+7lVoDLLd3O3IXwJwvuuns8UB/HeAg==",
"dev": true,
"license": "MIT",
"engines": {
"node": ">=4"
}
},
"node_modules/ms": { "node_modules/ms": {
"version": "2.1.3", "version": "2.1.3",
"resolved": "https://registry.npmjs.org/ms/-/ms-2.1.3.tgz", "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.3.tgz",
@@ -2423,6 +2862,17 @@
"dev": true, "dev": true,
"license": "MIT" "license": "MIT"
}, },
"node_modules/obug": {
"version": "2.1.1",
"resolved": "https://registry.npmjs.org/obug/-/obug-2.1.1.tgz",
"integrity": "sha512-uTqF9MuPraAQ+IsnPf366RG4cP9RtUi7MLO1N3KEc+wb0a6yKpeL0lmk2IB1jY5KHPAlTc6T/JRdC/YqxHNwkQ==",
"dev": true,
"funding": [
"https://github.com/sponsors/sxzz",
"https://opencollective.com/debug"
],
"license": "MIT"
},
"node_modules/parse-entities": { "node_modules/parse-entities": {
"version": "4.0.2", "version": "4.0.2",
"resolved": "https://registry.npmjs.org/parse-entities/-/parse-entities-4.0.2.tgz", "resolved": "https://registry.npmjs.org/parse-entities/-/parse-entities-4.0.2.tgz",
@@ -2448,6 +2898,13 @@
"integrity": "sha512-CmBKiL6NNo/OqgmMn95Fk9Whlp2mtvIv+KNpQKN2F4SjvrEesubTRWGYSg+BnWZOnlCaSTU1sMpsBOzgbYhnsA==", "integrity": "sha512-CmBKiL6NNo/OqgmMn95Fk9Whlp2mtvIv+KNpQKN2F4SjvrEesubTRWGYSg+BnWZOnlCaSTU1sMpsBOzgbYhnsA==",
"license": "MIT" "license": "MIT"
}, },
"node_modules/pathe": {
"version": "2.0.3",
"resolved": "https://registry.npmjs.org/pathe/-/pathe-2.0.3.tgz",
"integrity": "sha512-WUjGcAqP1gQacoQe+OBJsFA7Ld4DyXuUIjZ5cc75cLHvJ7dtNsTugphxIADwspS+AraAUePCKrSVtPLFj/F88w==",
"dev": true,
"license": "MIT"
},
"node_modules/picocolors": { "node_modules/picocolors": {
"version": "1.1.1", "version": "1.1.1",
"resolved": "https://registry.npmjs.org/picocolors/-/picocolors-1.1.1.tgz", "resolved": "https://registry.npmjs.org/picocolors/-/picocolors-1.1.1.tgz",
@@ -2497,6 +2954,22 @@
"node": "^10 || ^12 || >=14" "node": "^10 || ^12 || >=14"
} }
}, },
"node_modules/pretty-format": {
"version": "27.5.1",
"resolved": "https://registry.npmjs.org/pretty-format/-/pretty-format-27.5.1.tgz",
"integrity": "sha512-Qb1gy5OrP5+zDf2Bvnzdl3jsTf1qXVMazbvCoKhtKqVs4/YK4ozX4gKQJJVyNe+cajNPn0KoC0MC3FUmaHWEmQ==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"ansi-regex": "^5.0.1",
"ansi-styles": "^5.0.0",
"react-is": "^17.0.1"
},
"engines": {
"node": "^10.13.0 || ^12.13.0 || ^14.15.0 || >=15.0.0"
}
},
"node_modules/property-information": { "node_modules/property-information": {
"version": "7.1.0", "version": "7.1.0",
"resolved": "https://registry.npmjs.org/property-information/-/property-information-7.1.0.tgz", "resolved": "https://registry.npmjs.org/property-information/-/property-information-7.1.0.tgz",
@@ -2528,6 +3001,14 @@
"react": "^19.2.4" "react": "^19.2.4"
} }
}, },
"node_modules/react-is": {
"version": "17.0.2",
"resolved": "https://registry.npmjs.org/react-is/-/react-is-17.0.2.tgz",
"integrity": "sha512-w2GsyukL62IJnlaff/nRegPQR94C/XXamvMWmSHRJ4y7Ts/4ocGRmTHvOs8PSE6pB3dWOrD/nueuU5sduBsQ4w==",
"dev": true,
"license": "MIT",
"peer": true
},
"node_modules/react-markdown": { "node_modules/react-markdown": {
"version": "10.1.0", "version": "10.1.0",
"resolved": "https://registry.npmjs.org/react-markdown/-/react-markdown-10.1.0.tgz", "resolved": "https://registry.npmjs.org/react-markdown/-/react-markdown-10.1.0.tgz",
@@ -2603,6 +3084,20 @@
"react-dom": ">=18" "react-dom": ">=18"
} }
}, },
"node_modules/redent": {
"version": "3.0.0",
"resolved": "https://registry.npmjs.org/redent/-/redent-3.0.0.tgz",
"integrity": "sha512-6tDA8g98We0zd0GvVeMT9arEOnTw9qM03L9cJXaCjrip1OO764RDBLBfrB4cwzNGDj5OA5ioymC9GkizgWJDUg==",
"dev": true,
"license": "MIT",
"dependencies": {
"indent-string": "^4.0.0",
"strip-indent": "^3.0.0"
},
"engines": {
"node": ">=8"
}
},
"node_modules/remark-parse": { "node_modules/remark-parse": {
"version": "11.0.0", "version": "11.0.0",
"resolved": "https://registry.npmjs.org/remark-parse/-/remark-parse-11.0.0.tgz", "resolved": "https://registry.npmjs.org/remark-parse/-/remark-parse-11.0.0.tgz",
@@ -2703,6 +3198,13 @@
"integrity": "sha512-oeM1lpU/UvhTxw+g3cIfxXHyJRc/uidd3yK1P242gzHds0udQBYzs3y8j4gCCW+ZJ7ad0yctld8RYO+bdurlvw==", "integrity": "sha512-oeM1lpU/UvhTxw+g3cIfxXHyJRc/uidd3yK1P242gzHds0udQBYzs3y8j4gCCW+ZJ7ad0yctld8RYO+bdurlvw==",
"license": "MIT" "license": "MIT"
}, },
"node_modules/siginfo": {
"version": "2.0.0",
"resolved": "https://registry.npmjs.org/siginfo/-/siginfo-2.0.0.tgz",
"integrity": "sha512-ybx0WO1/8bSBLEWXZvEd7gMW3Sn3JFlW3TvX1nREbDLRNQNaeNN8WK0meBwPdAaOI7TtRRRJn/Es1zhrrCHu7g==",
"dev": true,
"license": "ISC"
},
"node_modules/source-map-js": { "node_modules/source-map-js": {
"version": "1.2.1", "version": "1.2.1",
"resolved": "https://registry.npmjs.org/source-map-js/-/source-map-js-1.2.1.tgz", "resolved": "https://registry.npmjs.org/source-map-js/-/source-map-js-1.2.1.tgz",
@@ -2723,6 +3225,20 @@
"url": "https://github.com/sponsors/wooorm" "url": "https://github.com/sponsors/wooorm"
} }
}, },
"node_modules/stackback": {
"version": "0.0.2",
"resolved": "https://registry.npmjs.org/stackback/-/stackback-0.0.2.tgz",
"integrity": "sha512-1XMJE5fQo1jGH6Y/7ebnwPOBEkIEnT4QF32d5R1+VXdXveM0IBMJt8zfaxX1P3QhVwrYe+576+jkANtSS2mBbw==",
"dev": true,
"license": "MIT"
},
"node_modules/std-env": {
"version": "4.0.0",
"resolved": "https://registry.npmjs.org/std-env/-/std-env-4.0.0.tgz",
"integrity": "sha512-zUMPtQ/HBY3/50VbpkupYHbRroTRZJPRLvreamgErJVys0ceuzMkD44J/QjqhHjOzK42GQ3QZIeFG1OYfOtKqQ==",
"dev": true,
"license": "MIT"
},
"node_modules/stringify-entities": { "node_modules/stringify-entities": {
"version": "4.0.4", "version": "4.0.4",
"resolved": "https://registry.npmjs.org/stringify-entities/-/stringify-entities-4.0.4.tgz", "resolved": "https://registry.npmjs.org/stringify-entities/-/stringify-entities-4.0.4.tgz",
@@ -2737,6 +3253,19 @@
"url": "https://github.com/sponsors/wooorm" "url": "https://github.com/sponsors/wooorm"
} }
}, },
"node_modules/strip-indent": {
"version": "3.0.0",
"resolved": "https://registry.npmjs.org/strip-indent/-/strip-indent-3.0.0.tgz",
"integrity": "sha512-laJTa3Jb+VQpaC6DseHhF7dXVqHTfJPCRDaEbid/drOhgitgYku/letMUqOXFoWV0zIIUbjpdH2t+tYj4bQMRQ==",
"dev": true,
"license": "MIT",
"dependencies": {
"min-indent": "^1.0.0"
},
"engines": {
"node": ">=8"
}
},
"node_modules/style-to-js": { "node_modules/style-to-js": {
"version": "1.1.21", "version": "1.1.21",
"resolved": "https://registry.npmjs.org/style-to-js/-/style-to-js-1.1.21.tgz", "resolved": "https://registry.npmjs.org/style-to-js/-/style-to-js-1.1.21.tgz",
@@ -2755,6 +3284,23 @@
"inline-style-parser": "0.2.7" "inline-style-parser": "0.2.7"
} }
}, },
"node_modules/tinybench": {
"version": "2.9.0",
"resolved": "https://registry.npmjs.org/tinybench/-/tinybench-2.9.0.tgz",
"integrity": "sha512-0+DUvqWMValLmha6lr4kD8iAMK1HzV0/aKnCtWb9v9641TnP/MFb7Pc2bxoxQjTXAErryXVgUOfv2YqNllqGeg==",
"dev": true,
"license": "MIT"
},
"node_modules/tinyexec": {
"version": "1.0.4",
"resolved": "https://registry.npmjs.org/tinyexec/-/tinyexec-1.0.4.tgz",
"integrity": "sha512-u9r3uZC0bdpGOXtlxUIdwf9pkmvhqJdrVCH9fapQtgy/OeTTMZ1nqH7agtvEfmGui6e1XxjcdrlxvxJvc3sMqw==",
"dev": true,
"license": "MIT",
"engines": {
"node": ">=18"
}
},
"node_modules/tinyglobby": { "node_modules/tinyglobby": {
"version": "0.2.15", "version": "0.2.15",
"resolved": "https://registry.npmjs.org/tinyglobby/-/tinyglobby-0.2.15.tgz", "resolved": "https://registry.npmjs.org/tinyglobby/-/tinyglobby-0.2.15.tgz",
@@ -2772,6 +3318,16 @@
"url": "https://github.com/sponsors/SuperchupuDev" "url": "https://github.com/sponsors/SuperchupuDev"
} }
}, },
"node_modules/tinyrainbow": {
"version": "3.1.0",
"resolved": "https://registry.npmjs.org/tinyrainbow/-/tinyrainbow-3.1.0.tgz",
"integrity": "sha512-Bf+ILmBgretUrdJxzXM0SgXLZ3XfiaUuOj/IKQHuTXip+05Xn+uyEYdVg0kYDipTBcLrCVyUzAPz7QmArb0mmw==",
"dev": true,
"license": "MIT",
"engines": {
"node": ">=14.0.0"
}
},
"node_modules/trim-lines": { "node_modules/trim-lines": {
"version": "3.0.1", "version": "3.0.1",
"resolved": "https://registry.npmjs.org/trim-lines/-/trim-lines-3.0.1.tgz", "resolved": "https://registry.npmjs.org/trim-lines/-/trim-lines-3.0.1.tgz",
@@ -2806,6 +3362,13 @@
"node": ">=14.17" "node": ">=14.17"
} }
}, },
"node_modules/undici-types": {
"version": "7.18.2",
"resolved": "https://registry.npmjs.org/undici-types/-/undici-types-7.18.2.tgz",
"integrity": "sha512-AsuCzffGHJybSaRrmr5eHr81mwJU3kjw6M+uprWvCXiNeN9SOGwQ3Jn8jb8m3Z6izVgknn1R0FTCEAP2QrLY/w==",
"dev": true,
"license": "MIT"
},
"node_modules/unified": { "node_modules/unified": {
"version": "11.0.5", "version": "11.0.5",
"resolved": "https://registry.npmjs.org/unified/-/unified-11.0.5.tgz", "resolved": "https://registry.npmjs.org/unified/-/unified-11.0.5.tgz",
@@ -3027,6 +3590,137 @@
}
}
},
"node_modules/vitest": {
"version": "4.1.2",
"resolved": "https://registry.npmjs.org/vitest/-/vitest-4.1.2.tgz",
"integrity": "sha512-xjR1dMTVHlFLh98JE3i/f/WePqJsah4A0FK9cc8Ehp9Udk0AZk6ccpIZhh1qJ/yxVWRZ+Q54ocnD8TXmkhspGg==",
"dev": true,
"license": "MIT",
"dependencies": {
"@vitest/expect": "4.1.2",
"@vitest/mocker": "4.1.2",
"@vitest/pretty-format": "4.1.2",
"@vitest/runner": "4.1.2",
"@vitest/snapshot": "4.1.2",
"@vitest/spy": "4.1.2",
"@vitest/utils": "4.1.2",
"es-module-lexer": "^2.0.0",
"expect-type": "^1.3.0",
"magic-string": "^0.30.21",
"obug": "^2.1.1",
"pathe": "^2.0.3",
"picomatch": "^4.0.3",
"std-env": "^4.0.0-rc.1",
"tinybench": "^2.9.0",
"tinyexec": "^1.0.2",
"tinyglobby": "^0.2.15",
"tinyrainbow": "^3.1.0",
"vite": "^6.0.0 || ^7.0.0 || ^8.0.0",
"why-is-node-running": "^2.3.0"
},
"bin": {
"vitest": "vitest.mjs"
},
"engines": {
"node": "^20.0.0 || ^22.0.0 || >=24.0.0"
},
"funding": {
"url": "https://opencollective.com/vitest"
},
"peerDependencies": {
"@edge-runtime/vm": "*",
"@opentelemetry/api": "^1.9.0",
"@types/node": "^20.0.0 || ^22.0.0 || >=24.0.0",
"@vitest/browser-playwright": "4.1.2",
"@vitest/browser-preview": "4.1.2",
"@vitest/browser-webdriverio": "4.1.2",
"@vitest/ui": "4.1.2",
"happy-dom": "*",
"jsdom": "*",
"vite": "^6.0.0 || ^7.0.0 || ^8.0.0"
},
"peerDependenciesMeta": {
"@edge-runtime/vm": {
"optional": true
},
"@opentelemetry/api": {
"optional": true
},
"@types/node": {
"optional": true
},
"@vitest/browser-playwright": {
"optional": true
},
"@vitest/browser-preview": {
"optional": true
},
"@vitest/browser-webdriverio": {
"optional": true
},
"@vitest/ui": {
"optional": true
},
"happy-dom": {
"optional": true
},
"jsdom": {
"optional": true
},
"vite": {
"optional": false
}
}
},
"node_modules/whatwg-mimetype": {
"version": "3.0.0",
"resolved": "https://registry.npmjs.org/whatwg-mimetype/-/whatwg-mimetype-3.0.0.tgz",
"integrity": "sha512-nt+N2dzIutVRxARx1nghPKGv1xHikU7HKdfafKkLNLindmPU/ch3U31NOCGGA/dmPcmb1VlofO0vnKAcsm0o/Q==",
"dev": true,
"license": "MIT",
"engines": {
"node": ">=12"
}
},
"node_modules/why-is-node-running": {
"version": "2.3.0",
"resolved": "https://registry.npmjs.org/why-is-node-running/-/why-is-node-running-2.3.0.tgz",
"integrity": "sha512-hUrmaWBdVDcxvYqnyh09zunKzROWjbZTiNy8dBEjkS7ehEDQibXJ7XvlmtbwuTclUiIyN+CyXQD4Vmko8fNm8w==",
"dev": true,
"license": "MIT",
"dependencies": {
"siginfo": "^2.0.0",
"stackback": "0.0.2"
},
"bin": {
"why-is-node-running": "cli.js"
},
"engines": {
"node": ">=8"
}
},
"node_modules/ws": {
"version": "8.20.0",
"resolved": "https://registry.npmjs.org/ws/-/ws-8.20.0.tgz",
"integrity": "sha512-sAt8BhgNbzCtgGbt2OxmpuryO63ZoDk/sqaB/znQm94T4fCEsy/yV+7CdC1kJhOU9lboAEU7R3kquuycDoibVA==",
"dev": true,
"license": "MIT",
"engines": {
"node": ">=10.0.0"
},
"peerDependencies": {
"bufferutil": "^4.0.1",
"utf-8-validate": ">=5.0.2"
},
"peerDependenciesMeta": {
"bufferutil": {
"optional": true
},
"utf-8-validate": {
"optional": true
}
}
},
"node_modules/yallist": { "node_modules/yallist": {
"version": "3.1.1", "version": "3.1.1",
"resolved": "https://registry.npmjs.org/yallist/-/yallist-3.1.1.tgz", "resolved": "https://registry.npmjs.org/yallist/-/yallist-3.1.1.tgz",


@@ -6,7 +6,9 @@
   "scripts": {
     "dev": "vite",
     "build": "tsc -b && vite build",
-    "preview": "vite preview"
+    "preview": "vite preview",
+    "test": "vitest run",
+    "test:watch": "vitest"
   },
   "dependencies": {
     "react": "^19.0.0",
@@ -15,10 +17,14 @@
     "react-router-dom": "^7.13.2"
   },
   "devDependencies": {
+    "@testing-library/jest-dom": "^6.9.1",
+    "@testing-library/react": "^16.3.2",
     "@types/react": "^19.0.0",
     "@types/react-dom": "^19.0.0",
     "@vitejs/plugin-react": "^4.3.0",
+    "happy-dom": "^20.8.9",
     "typescript": "~5.7.0",
-    "vite": "^6.2.0"
+    "vite": "^6.2.0",
+    "vitest": "^4.1.2"
   }
 }
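The new test scripts imply a Vitest configuration wired to a DOM environment. A minimal sketch consistent with the happy-dom and Testing Library devDependencies above; the vitest.config.ts contents and the src/setupTests.ts path are assumptions, not part of this diff:

// vitest.config.ts -- hypothetical sketch; the repository's actual config is not shown here
import { defineConfig } from "vitest/config";
import react from "@vitejs/plugin-react";

export default defineConfig({
  plugins: [react()],
  test: {
    environment: "happy-dom", // matches the happy-dom devDependency
    // Assumed setup file; it would import "@testing-library/jest-dom/vitest"
    // to provide matchers like toBeInTheDocument used in the tests below.
    setupFiles: ["./src/setupTests.ts"],
  },
});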

frontend/src/api.test.ts Normal file

@@ -0,0 +1,117 @@
import { describe, it, expect, vi, beforeEach } from "vitest";
import { fetchConversations, fetchReplay, fetchAnalytics } from "./api";
// Mock global fetch
const mockFetch = vi.fn();
vi.stubGlobal("fetch", mockFetch);
beforeEach(() => {
mockFetch.mockReset();
});
function jsonResponse(body: unknown, status = 200): Response {
return {
ok: status >= 200 && status < 300,
status,
statusText: status === 200 ? "OK" : "Error",
json: () => Promise.resolve(body),
} as Response;
}
describe("fetchConversations", () => {
it("returns conversations page on success", async () => {
const data = {
conversations: [{ thread_id: "t1", created_at: "", last_activity: "", status: "active", total_tokens: 0, total_cost_usd: 0 }],
total: 1,
page: 1,
per_page: 20,
};
mockFetch.mockResolvedValue(jsonResponse({ success: true, data, error: null }));
const result = await fetchConversations();
expect(result.conversations).toHaveLength(1);
expect(result.total).toBe(1);
expect(mockFetch).toHaveBeenCalledWith("/api/v1/conversations?page=1&per_page=20");
});
it("passes custom page and perPage", async () => {
mockFetch.mockResolvedValue(jsonResponse({ success: true, data: { conversations: [], total: 0, page: 2, per_page: 10 }, error: null }));
await fetchConversations(2, 10);
expect(mockFetch).toHaveBeenCalledWith("/api/v1/conversations?page=2&per_page=10");
});
it("throws on HTTP error", async () => {
mockFetch.mockResolvedValue(jsonResponse({}, 500));
await expect(fetchConversations()).rejects.toThrow("API error 500");
});
it("throws on success=false with error message", async () => {
mockFetch.mockResolvedValue(jsonResponse({ success: false, data: null, error: "Database unavailable" }));
await expect(fetchConversations()).rejects.toThrow("Database unavailable");
});
it("throws unknown error when success=false with no message", async () => {
mockFetch.mockResolvedValue(jsonResponse({ success: false, data: null, error: null }));
await expect(fetchConversations()).rejects.toThrow("Unknown API error");
});
});
describe("fetchReplay", () => {
it("returns replay page on success", async () => {
const data = {
thread_id: "t1",
total_steps: 3,
page: 1,
per_page: 20,
steps: [{ step: 1, type: "message", content: "Hello", agent: null, tool: null, params: null, result: null, timestamp: "" }],
};
mockFetch.mockResolvedValue(jsonResponse({ success: true, data, error: null }));
const result = await fetchReplay("t1");
expect(result.total_steps).toBe(3);
expect(result.steps).toHaveLength(1);
});
it("encodes thread_id in URL", async () => {
mockFetch.mockResolvedValue(jsonResponse({ success: true, data: { thread_id: "a/b", total_steps: 0, page: 1, per_page: 20, steps: [] }, error: null }));
await fetchReplay("a/b");
expect(mockFetch).toHaveBeenCalledWith("/api/v1/replay/a%2Fb?page=1&per_page=20");
});
it("throws on HTTP error", async () => {
mockFetch.mockResolvedValue(jsonResponse({}, 404));
await expect(fetchReplay("missing")).rejects.toThrow("API error 404");
});
});
describe("fetchAnalytics", () => {
it("returns analytics data on success", async () => {
const data = {
range: "7d",
total_conversations: 100,
resolution_rate: 0.75,
escalation_rate: 0.25,
avg_turns_per_conversation: 3.5,
avg_cost_per_conversation_usd: 0.03,
agent_usage: [],
interrupt_stats: { total: 0, approved: 0, rejected: 0, expired: 0 },
};
mockFetch.mockResolvedValue(jsonResponse({ success: true, data, error: null }));
const result = await fetchAnalytics("7d");
expect(result.total_conversations).toBe(100);
expect(result.range).toBe("7d");
});
it("uses default range", async () => {
mockFetch.mockResolvedValue(jsonResponse({ success: true, data: { range: "7d" }, error: null }));
await fetchAnalytics();
expect(mockFetch).toHaveBeenCalledWith("/api/v1/analytics?range=7d");
});
});
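These tests also pin down the envelope contract that api.ts is expected to unwrap. A minimal sketch of an apiFetch helper consistent with the assertions above (the real implementation lives in frontend/src/api.ts and is not shown in this hunk):

// Hypothetical reconstruction of apiFetch, derived only from the test assertions above.
interface ApiResponse<T> {
  success: boolean;
  data: T | null;
  error: string | null;
}

async function apiFetch<T>(path: string): Promise<T> {
  const res = await fetch(path);
  if (!res.ok) {
    // Produces the "API error 500" / "API error 404" messages the tests expect
    throw new Error(`API error ${res.status}: ${res.statusText}`);
  }
  const envelope = (await res.json()) as ApiResponse<T>;
  if (!envelope.success || envelope.data === null) {
    // success=false surfaces the server-provided message, with a generic fallback
    throw new Error(envelope.error ?? "Unknown API error");
  }
  return envelope.data;
}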


@@ -10,13 +10,11 @@ export interface ApiResponse<T> {
 export interface ConversationSummary {
   thread_id: string;
-  started_at: string;
+  created_at: string;
   last_activity: string;
-  turn_count: number;
-  agents_used: string[];
+  status: string | null;
   total_tokens: number;
   total_cost_usd: number;
-  resolution_type: string | null;
 }

 export interface ConversationsPage {
@@ -39,17 +37,16 @@ export interface ReplayStep {
 export interface ReplayPage {
   thread_id: string;
-  steps: ReplayStep[];
-  total: number;
+  total_steps: number;
   page: number;
   per_page: number;
+  steps: ReplayStep[];
 }

 export interface AgentUsage {
-  agent_name: string;
-  message_count: number;
-  total_tokens: number;
-  total_cost_usd: number;
+  agent: string;
+  count: number;
+  percentage: number;
 }

 export interface InterruptStats {
@@ -60,14 +57,12 @@ export interface InterruptStats {
 }

 export interface AnalyticsData {
+  range: string;
   total_conversations: number;
-  resolved_conversations: number;
-  escalated_conversations: number;
   resolution_rate: number;
   escalation_rate: number;
-  total_tokens: number;
-  total_cost_usd: number;
   avg_turns_per_conversation: number;
+  avg_cost_per_conversation_usd: number;
   agent_usage: AgentUsage[];
   interrupt_stats: InterruptStats;
 }
@@ -89,7 +84,7 @@ export async function fetchConversations(
   perPage = 20
 ): Promise<ConversationsPage> {
   return apiFetch<ConversationsPage>(
-    `/api/conversations?page=${page}&per_page=${perPage}`
+    `/api/v1/conversations?page=${page}&per_page=${perPage}`
   );
 }
@@ -99,10 +94,81 @@ export async function fetchReplay(
   perPage = 20
 ): Promise<ReplayPage> {
   return apiFetch<ReplayPage>(
-    `/api/replay/${encodeURIComponent(threadId)}?page=${page}&per_page=${perPage}`
+    `/api/v1/replay/${encodeURIComponent(threadId)}?page=${page}&per_page=${perPage}`
   );
 }

 export async function fetchAnalytics(range = "7d"): Promise<AnalyticsData> {
-  return apiFetch<AnalyticsData>(`/api/analytics?range=${range}`);
+  return apiFetch<AnalyticsData>(`/api/v1/analytics?range=${range}`);
 }
// -- OpenAPI import --
export interface ImportJobResponse {
job_id: string;
status: string;
spec_url: string;
total_endpoints: number;
classified_count: number;
error_message: string | null;
generated_tools_count?: number;
}
export interface EndpointClassification {
index: number;
access_type: string;
needs_interrupt: boolean;
agent_group: string;
confidence: number;
customer_params: string[];
endpoint: {
path: string;
method: string;
operation_id: string;
summary: string;
description: string;
};
}
async function apiPost<T>(path: string, body: unknown): Promise<T> {
const res = await fetch(`${API_BASE}${path}`, {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify(body),
});
if (!res.ok) {
throw new Error(`API error ${res.status}: ${res.statusText}`);
}
return res.json();
}
export async function startImport(url: string): Promise<ImportJobResponse> {
return apiPost<ImportJobResponse>("/api/v1/openapi/import", { url });
}
export async function fetchImportJob(jobId: string): Promise<ImportJobResponse> {
const res = await fetch(`${API_BASE}/api/v1/openapi/jobs/${encodeURIComponent(jobId)}`);
if (!res.ok) {
throw new Error(`API error ${res.status}: ${res.statusText}`);
}
return res.json();
}
export async function fetchClassifications(
jobId: string
): Promise<EndpointClassification[]> {
const res = await fetch(
`${API_BASE}/api/v1/openapi/jobs/${encodeURIComponent(jobId)}/classifications`
);
if (!res.ok) {
throw new Error(`API error ${res.status}: ${res.statusText}`);
}
return res.json();
}
export async function approveJob(jobId: string): Promise<ImportJobResponse> {
return apiPost<ImportJobResponse>(
`/api/v1/openapi/jobs/${encodeURIComponent(jobId)}/approve`,
{}
);
}
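Taken together, the new wrappers cover the full import flow. A hedged usage sketch (the "completed" status string and the one-second polling interval are assumptions; the diff does not show the job status values):

// Hypothetical end-to-end use of the OpenAPI import client; status handling is assumed.
async function runImport(specUrl: string): Promise<ImportJobResponse> {
  let job = await startImport(specUrl);
  while (job.status !== "completed" && job.error_message === null) {
    await new Promise((resolve) => setTimeout(resolve, 1000)); // arbitrary poll interval
    job = await fetchImportJob(job.job_id);
  }
  const classifications = await fetchClassifications(job.job_id);
  console.log(`${job.classified_count}/${job.total_endpoints} endpoints classified`, classifications);
  return approveJob(job.job_id); // server-side tool generation happens on approval
}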


@@ -0,0 +1,47 @@
import { describe, it, expect } from "vitest";
import { render, screen, fireEvent } from "@testing-library/react";
import { AgentAction } from "./AgentAction";
import type { ToolAction } from "../types";
function makeAction(overrides: Partial<ToolAction> = {}): ToolAction {
return {
id: "action-1",
agent: "OrderAgent",
tool: "get_order",
args: { order_id: "ORD-100" },
timestamp: Date.now(),
...overrides,
};
}
describe("AgentAction", () => {
it("renders agent name and tool name", () => {
render(<AgentAction action={makeAction()} />);
expect(screen.getByText("OrderAgent")).toBeInTheDocument();
expect(screen.getByText("get_order")).toBeInTheDocument();
});
it("shows args and result when expanded", () => {
const action = makeAction({ result: { status: "shipped" } });
render(<AgentAction action={action} />);
// Click header to expand
fireEvent.click(screen.getByText("OrderAgent"));
expect(screen.getByText("Args:")).toBeInTheDocument();
expect(screen.getByText("Result:")).toBeInTheDocument();
expect(screen.getByText(/"order_id": "ORD-100"/)).toBeInTheDocument();
expect(screen.getByText(/"status": "shipped"/)).toBeInTheDocument();
});
it("does not show result section when result is undefined", () => {
render(<AgentAction action={makeAction()} />);
// Expand
fireEvent.click(screen.getByText("OrderAgent"));
expect(screen.getByText("Args:")).toBeInTheDocument();
expect(screen.queryByText("Result:")).not.toBeInTheDocument();
});
});


@@ -0,0 +1,53 @@
import { describe, it, expect, vi } from "vitest";
import { render, screen, fireEvent } from "@testing-library/react";
import { ChatInput } from "./ChatInput";
describe("ChatInput", () => {
it("renders input field and send button", () => {
render(<ChatInput onSend={vi.fn()} disabled={false} />);
expect(screen.getByPlaceholderText("Message Smart Support...")).toBeInTheDocument();
expect(screen.getByRole("button", { name: "Send Message" })).toBeInTheDocument();
});
it("calls onSend with trimmed content when form is submitted via Enter", () => {
const onSend = vi.fn();
render(<ChatInput onSend={onSend} disabled={false} />);
const input = screen.getByPlaceholderText("Message Smart Support...");
fireEvent.change(input, { target: { value: " Hello world " } });
fireEvent.keyDown(input, { key: "Enter" });
expect(onSend).toHaveBeenCalledWith("Hello world");
});
it("clears input after successful send", () => {
const onSend = vi.fn();
render(<ChatInput onSend={onSend} disabled={false} />);
const input = screen.getByPlaceholderText("Message Smart Support...") as HTMLInputElement;
fireEvent.change(input, { target: { value: "Test message" } });
fireEvent.keyDown(input, { key: "Enter" });
expect(input.value).toBe("");
});
it("shows disabled placeholder and disables input when disabled", () => {
render(<ChatInput onSend={vi.fn()} disabled={true} />);
const input = screen.getByPlaceholderText("Agent is working...") as HTMLInputElement;
expect(input.disabled).toBe(true);
expect(screen.getByRole("button", { name: "Send Message" })).toBeDisabled();
});
it("does not call onSend when input is empty or whitespace", () => {
const onSend = vi.fn();
render(<ChatInput onSend={onSend} disabled={false} />);
const input = screen.getByPlaceholderText("Message Smart Support...");
fireEvent.change(input, { target: { value: " " } });
fireEvent.keyDown(input, { key: "Enter" });
expect(onSend).not.toHaveBeenCalled();
});
});


@@ -0,0 +1,59 @@
import { describe, it, expect, vi } from "vitest";
import { render, screen } from "@testing-library/react";
import { ChatMessages } from "./ChatMessages";
import type { ChatMessage } from "../types";
// Mock react-markdown to avoid complex rendering
vi.mock("react-markdown", () => ({
default: ({ children }: { children: string }) => <span>{children}</span>,
}));
describe("ChatMessages", () => {
it("renders welcome message when messages array is empty", () => {
render(<ChatMessages messages={[]} />);
expect(screen.getByText("Hello! How can I help you today?")).toBeInTheDocument();
expect(screen.getByText("Smart Support")).toBeInTheDocument();
});
it("renders user messages with correct sender label", () => {
const messages: ChatMessage[] = [
{ id: "1", sender: "user", content: "I need help", timestamp: Date.now() },
];
render(<ChatMessages messages={messages} />);
expect(screen.getByText("You")).toBeInTheDocument();
expect(screen.getByText("I need help")).toBeInTheDocument();
expect(screen.getByText("Me")).toBeInTheDocument();
});
it("renders agent messages with agent name", () => {
const messages: ChatMessage[] = [
{ id: "2", sender: "agent", agent: "OrderBot", content: "Sure, let me check.", timestamp: Date.now() },
];
render(<ChatMessages messages={messages} />);
expect(screen.getByText("OrderBot")).toBeInTheDocument();
expect(screen.getByText("Sure, let me check.")).toBeInTheDocument();
expect(screen.getByText("AI")).toBeInTheDocument();
});
it("shows streaming cursor for messages being streamed", () => {
const messages: ChatMessage[] = [
{ id: "3", sender: "agent", agent: "Bot", content: "Processing", timestamp: Date.now(), isStreaming: true },
];
render(<ChatMessages messages={messages} />);
expect(screen.getByText("|")).toBeInTheDocument();
expect(document.querySelector(".cursor-blink")).toBeTruthy();
});
it("shows fallback agent label when agent field is missing", () => {
const messages: ChatMessage[] = [
{ id: "4", sender: "agent", content: "Generic response", timestamp: Date.now() },
];
render(<ChatMessages messages={messages} />);
expect(screen.getByText("Agent")).toBeInTheDocument();
});
});


@@ -0,0 +1,33 @@
import { describe, it, expect, vi } from "vitest";
import { render, screen, fireEvent } from "@testing-library/react";
import { ErrorBanner } from "./ErrorBanner";
describe("ErrorBanner", () => {
it("returns null when status is connected", () => {
const { container } = render(<ErrorBanner status="connected" />);
expect(container.innerHTML).toBe("");
});
it("shows disconnection message when status is disconnected", () => {
render(<ErrorBanner status="disconnected" onReconnect={vi.fn()} />);
expect(screen.getByText("Disconnected from server. Retrying...")).toBeInTheDocument();
expect(screen.getByRole("alert")).toBeInTheDocument();
});
it("shows connecting message when status is connecting", () => {
render(<ErrorBanner status="connecting" />);
expect(screen.getByText("Connecting to server...")).toBeInTheDocument();
// No reconnect button while connecting
expect(screen.queryByText("Reconnect")).not.toBeInTheDocument();
});
it("calls onReconnect when reconnect button is clicked", () => {
const onReconnect = vi.fn();
render(<ErrorBanner status="disconnected" onReconnect={onReconnect} />);
fireEvent.click(screen.getByText("Reconnect"));
expect(onReconnect).toHaveBeenCalledTimes(1);
});
});


@@ -0,0 +1,58 @@
import { describe, it, expect, vi } from "vitest";
import { render, screen, fireEvent } from "@testing-library/react";
import { InterruptPrompt } from "./InterruptPrompt";
import type { InterruptMessage } from "../types";
describe("InterruptPrompt", () => {
const baseInterrupt: InterruptMessage = {
type: "interrupt",
thread_id: "t1",
action: "cancel_order",
params: {},
};
it("renders action name and approval title", () => {
render(<InterruptPrompt interrupt={baseInterrupt} onRespond={vi.fn()} />);
expect(screen.getByText("Action Requires Approval")).toBeInTheDocument();
expect(screen.getByText("cancel_order")).toBeInTheDocument();
});
it("calls onRespond with true when Approve button is clicked", () => {
const onRespond = vi.fn();
render(<InterruptPrompt interrupt={baseInterrupt} onRespond={onRespond} />);
fireEvent.click(screen.getByText("Approve Action"));
expect(onRespond).toHaveBeenCalledWith(true);
});
it("calls onRespond with false when Reject button is clicked", () => {
const onRespond = vi.fn();
render(<InterruptPrompt interrupt={baseInterrupt} onRespond={onRespond} />);
fireEvent.click(screen.getByText("Reject & Escalate"));
expect(onRespond).toHaveBeenCalledWith(false);
});
it("displays order_id parameter when present", () => {
const interrupt: InterruptMessage = {
...baseInterrupt,
params: { order_id: "ORD-12345" },
};
render(<InterruptPrompt interrupt={interrupt} onRespond={vi.fn()} />);
expect(screen.getByText("Target Order ID")).toBeInTheDocument();
expect(screen.getByText("ORD-12345")).toBeInTheDocument();
});
it("displays message parameter when present", () => {
const interrupt: InterruptMessage = {
...baseInterrupt,
params: { message: "This will refund $50" },
};
render(<InterruptPrompt interrupt={interrupt} onRespond={vi.fn()} />);
expect(screen.getByText("Detail Message")).toBeInTheDocument();
expect(screen.getByText("This will refund $50")).toBeInTheDocument();
});
});


@@ -0,0 +1,39 @@
import { describe, it, expect, vi } from "vitest";
import { render, screen } from "@testing-library/react";
import { MemoryRouter, Routes, Route } from "react-router-dom";
import { Layout } from "./Layout";
// Mock NavBar to simplify layout tests
vi.mock("./NavBar", () => ({
NavBar: () => <nav data-testid="navbar">NavBar</nav>,
}));
function renderLayout(path = "/") {
return render(
<MemoryRouter initialEntries={[path]}>
<Routes>
<Route element={<Layout />}>
<Route path="/" element={<div>Home Content</div>} />
<Route path="/dashboard" element={<div>Dashboard Content</div>} />
</Route>
</Routes>
</MemoryRouter>
);
}
describe("Layout", () => {
it("renders NavBar component", () => {
renderLayout();
expect(screen.getByTestId("navbar")).toBeInTheDocument();
});
it("renders child route content via Outlet", () => {
renderLayout("/");
expect(screen.getByText("Home Content")).toBeInTheDocument();
});
it("renders correct content for different routes", () => {
renderLayout("/dashboard");
expect(screen.getByText("Dashboard Content")).toBeInTheDocument();
});
});


@@ -0,0 +1,28 @@
import { describe, it, expect } from "vitest";
import { render, screen } from "@testing-library/react";
import { MetricCard } from "./MetricCard";
describe("MetricCard", () => {
it("renders label and value", () => {
render(<MetricCard label="Total Users" value={42} />);
expect(screen.getByText("Total Users")).toBeInTheDocument();
expect(screen.getByText("42")).toBeInTheDocument();
});
it("renders with unit prefix and suffix", () => {
render(<MetricCard label="Cost" value="3.50" unit="$" suffix="/mo" />);
expect(screen.getByText("Cost")).toBeInTheDocument();
expect(screen.getByText("$")).toBeInTheDocument();
expect(screen.getByText("3.50")).toBeInTheDocument();
expect(screen.getByText("/mo")).toBeInTheDocument();
});
it("handles zero value correctly", () => {
render(<MetricCard label="Errors" value={0} />);
expect(screen.getByText("Errors")).toBeInTheDocument();
expect(screen.getByText("0")).toBeInTheDocument();
});
});


@@ -0,0 +1,54 @@
import { describe, it, expect } from "vitest";
import { render, screen } from "@testing-library/react";
import { MemoryRouter } from "react-router-dom";
import { NavBar } from "./NavBar";
function renderNavBar(initialPath = "/") {
return render(
<MemoryRouter initialEntries={[initialPath]}>
<NavBar />
</MemoryRouter>
);
}
describe("NavBar", () => {
it("renders all navigation links", () => {
renderNavBar();
expect(screen.getByText("Dashboard")).toBeInTheDocument();
expect(screen.getByText("Inbox")).toBeInTheDocument();
expect(screen.getByText("Conversation Replay")).toBeInTheDocument();
expect(screen.getByText("Agents & Tools")).toBeInTheDocument();
});
it("navigation links point to correct routes", () => {
renderNavBar();
const dashboardLink = screen.getByText("Dashboard").closest("a");
expect(dashboardLink).toHaveAttribute("href", "/dashboard");
const inboxLink = screen.getByText("Inbox").closest("a");
expect(inboxLink).toHaveAttribute("href", "/");
const replayLink = screen.getByText("Conversation Replay").closest("a");
expect(replayLink).toHaveAttribute("href", "/replay");
const reviewLink = screen.getByText("Agents & Tools").closest("a");
expect(reviewLink).toHaveAttribute("href", "/review");
});
it("active link has active class when on matching route", () => {
renderNavBar("/dashboard");
const dashboardLink = screen.getByText("Dashboard").closest("a");
expect(dashboardLink?.className).toContain("active");
const inboxLink = screen.getByText("Inbox").closest("a");
expect(inboxLink?.className).not.toContain("active");
});
it("renders brand name", () => {
renderNavBar();
expect(screen.getByText("Nexus AI")).toBeInTheDocument();
});
});


@@ -0,0 +1,69 @@
import { describe, it, expect } from "vitest";
import { render, screen, fireEvent } from "@testing-library/react";
import { ReplayTimeline } from "./ReplayTimeline";
import type { ReplayStep } from "../api";
function makeStep(overrides: Partial<ReplayStep> = {}): ReplayStep {
return {
step: 1,
type: "message",
content: "Hello",
agent: null,
tool: null,
params: null,
result: null,
timestamp: "2026-04-01T12:00:00Z",
...overrides,
};
}
describe("ReplayTimeline", () => {
it("returns null when steps array is empty", () => {
const { container } = render(<ReplayTimeline steps={[]} />);
expect(container.innerHTML).toBe("");
});
it("renders a list of steps with type badges", () => {
const steps = [
makeStep({ step: 1, type: "message", content: "User said hi" }),
makeStep({ step: 2, type: "tool_call", content: "Calling API", agent: "OrderBot", tool: "get_order" }),
];
render(<ReplayTimeline steps={steps} />);
expect(screen.getByText("message")).toBeInTheDocument();
expect(screen.getByText("tool call")).toBeInTheDocument();
expect(screen.getByText("User said hi")).toBeInTheDocument();
expect(screen.getByText("OrderBot")).toBeInTheDocument();
expect(screen.getByText("get_order()")).toBeInTheDocument();
});
it("expands step details when View JSON Payload button is clicked", () => {
const steps = [
makeStep({
step: 1,
type: "tool_call",
params: { order_id: "123" },
result: { status: "ok" },
}),
];
render(<ReplayTimeline steps={steps} />);
const expandButton = screen.getByText("View JSON Payload", { exact: false });
expect(expandButton).toBeInTheDocument();
fireEvent.click(expandButton);
// After expanding, the JSON payload should be visible
expect(screen.getByText("Hide JSON Payload", { exact: false })).toBeInTheDocument();
expect(screen.getByText(/"order_id": "123"/)).toBeInTheDocument();
});
it("does not show expand button when step has no params or result", () => {
const steps = [
makeStep({ step: 1, type: "message", params: null, result: null }),
];
render(<ReplayTimeline steps={steps} />);
expect(screen.queryByText("View JSON Payload", { exact: false })).not.toBeInTheDocument();
});
});


@@ -0,0 +1,221 @@
import { describe, it, expect, vi, beforeEach, afterEach } from "vitest";
import { renderHook, act } from "@testing-library/react";
import { useWebSocket } from "./useWebSocket";
// Mock sessionStorage
const mockSessionStorage: Record<string, string> = {};
vi.stubGlobal("sessionStorage", {
getItem: (key: string) => mockSessionStorage[key] ?? null,
setItem: (key: string, value: string) => {
mockSessionStorage[key] = value;
},
});
// Mock crypto.randomUUID
vi.stubGlobal("crypto", { randomUUID: () => "test-uuid-1234" });
// Mock WebSocket
class MockWebSocket {
static OPEN = 1;
static CLOSED = 3;
static instances: MockWebSocket[] = [];
url: string;
readyState = 0;
onopen: (() => void) | null = null;
onclose: (() => void) | null = null;
onmessage: ((event: { data: string }) => void) | null = null;
onerror: (() => void) | null = null;
send = vi.fn();
close = vi.fn().mockImplementation(() => {
this.readyState = MockWebSocket.CLOSED;
// Trigger onclose asynchronously like real WebSocket
setTimeout(() => this.onclose?.(), 0);
});
constructor(url: string) {
this.url = url;
MockWebSocket.instances.push(this);
}
simulateOpen() {
this.readyState = MockWebSocket.OPEN;
this.onopen?.();
}
simulateMessage(data: unknown) {
this.onmessage?.({ data: JSON.stringify(data) });
}
simulateClose() {
this.readyState = MockWebSocket.CLOSED;
this.onclose?.();
}
simulateError() {
this.onerror?.();
}
}
vi.stubGlobal("WebSocket", MockWebSocket);
beforeEach(() => {
MockWebSocket.instances = [];
delete mockSessionStorage["smart_support_thread_id"];
vi.useFakeTimers();
});
afterEach(() => {
vi.useRealTimers();
});
describe("useWebSocket", () => {
it("establishes connection with correct URL on mount", () => {
const onMessage = vi.fn();
renderHook(() => useWebSocket(onMessage));
expect(MockWebSocket.instances).toHaveLength(1);
expect(MockWebSocket.instances[0].url).toContain("/ws");
});
it("sets status to connected when WebSocket opens", () => {
const onMessage = vi.fn();
const { result } = renderHook(() => useWebSocket(onMessage));
expect(result.current.status).toBe("connecting");
act(() => {
MockWebSocket.instances[0].simulateOpen();
});
expect(result.current.status).toBe("connected");
});
it("parses incoming JSON messages and dispatches to callback", () => {
const onMessage = vi.fn();
renderHook(() => useWebSocket(onMessage));
act(() => {
MockWebSocket.instances[0].simulateOpen();
});
const serverMsg = { type: "token", agent: "bot", content: "Hello" };
act(() => {
MockWebSocket.instances[0].simulateMessage(serverMsg);
});
expect(onMessage).toHaveBeenCalledWith(serverMsg);
});
it("sends JSON through WebSocket via sendMessage", () => {
const onMessage = vi.fn();
const { result } = renderHook(() => useWebSocket(onMessage));
act(() => {
MockWebSocket.instances[0].simulateOpen();
});
act(() => {
result.current.sendMessage("Hi there");
});
expect(MockWebSocket.instances[0].send).toHaveBeenCalledTimes(1);
const sent = JSON.parse(MockWebSocket.instances[0].send.mock.calls[0][0]);
expect(sent.type).toBe("message");
expect(sent.content).toBe("Hi there");
expect(sent.thread_id).toBeDefined();
});
it("calls onDisconnect when WebSocket closes", () => {
const onMessage = vi.fn();
const onDisconnect = vi.fn();
renderHook(() => useWebSocket(onMessage, { onDisconnect }));
act(() => {
MockWebSocket.instances[0].simulateOpen();
});
act(() => {
MockWebSocket.instances[0].simulateClose();
});
expect(onDisconnect).toHaveBeenCalledTimes(1);
});
it("sets status to disconnected on close and attempts reconnect", () => {
const onMessage = vi.fn();
const { result } = renderHook(() => useWebSocket(onMessage));
act(() => {
MockWebSocket.instances[0].simulateOpen();
});
act(() => {
MockWebSocket.instances[0].simulateClose();
});
expect(result.current.status).toBe("disconnected");
// After timeout, a new WebSocket should be created (reconnect attempt)
act(() => {
vi.advanceTimersByTime(1500);
});
expect(MockWebSocket.instances.length).toBeGreaterThanOrEqual(2);
});
it("closes WebSocket on error event", () => {
const onMessage = vi.fn();
renderHook(() => useWebSocket(onMessage));
const ws = MockWebSocket.instances[0];
act(() => {
ws.simulateError();
});
expect(ws.close).toHaveBeenCalledTimes(1);
});
it("reconnect resets retries and creates a new connection", () => {
const onMessage = vi.fn();
const { result } = renderHook(() => useWebSocket(onMessage));
act(() => {
MockWebSocket.instances[0].simulateOpen();
});
const wsBeforeReconnect = MockWebSocket.instances[0];
act(() => {
result.current.reconnect();
});
// The old socket should have been closed
expect(wsBeforeReconnect.close).toHaveBeenCalled();
// Let the close callback fire and reconnect timer run
act(() => {
vi.advanceTimersByTime(100);
});
// A new WebSocket should have been created
expect(MockWebSocket.instances.length).toBeGreaterThan(1);
});
it("sends interrupt response with approved flag", () => {
const onMessage = vi.fn();
const { result } = renderHook(() => useWebSocket(onMessage));
act(() => {
MockWebSocket.instances[0].simulateOpen();
});
act(() => {
result.current.sendInterruptResponse(true);
});
const sent = JSON.parse(MockWebSocket.instances[0].send.mock.calls[0][0]);
expect(sent.type).toBe("interrupt_response");
expect(sent.approved).toBe(true);
});
});
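The mock and timer choreography above encode the behavior the hook is expected to have. A compressed sketch of the connect/retry loop these tests imply (the sessionStorage key and the approximate retry delay are taken from the test file; the exact backoff policy in useWebSocket is otherwise an assumption):

// Hypothetical skeleton of the hook's connection logic, inferred from the tests only.
function getThreadId(): string {
  const key = "smart_support_thread_id"; // same key the sessionStorage mock uses
  let id = sessionStorage.getItem(key);
  if (!id) {
    id = crypto.randomUUID();
    sessionStorage.setItem(key, id);
  }
  return id;
}

function connect(url: string, onMessage: (msg: unknown) => void, setStatus: (s: string) => void): WebSocket {
  setStatus("connecting");
  const ws = new WebSocket(url);
  ws.onopen = () => setStatus("connected");
  ws.onmessage = (event) => onMessage(JSON.parse(event.data)); // tests feed JSON strings through here
  ws.onerror = () => ws.close(); // matches "closes WebSocket on error event"
  ws.onclose = () => {
    setStatus("disconnected");
    setTimeout(() => connect(url, onMessage, setStatus), 1000); // tests advance timers by 1500 ms
  };
  return ws;
}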


@@ -658,6 +658,140 @@ body {
  border-color: var(--text-primary);
}
/* --- Shared Data Display Components --- */
.section-card {
background-color: var(--bg-surface);
border-radius: var(--radius-xl);
padding: 1.5rem;
border: 1px solid var(--border-light);
}
.stat-label {
font-size: 0.75rem;
text-transform: uppercase;
color: var(--text-secondary);
font-weight: 600;
letter-spacing: 0.05em;
}
.stat-value {
font-weight: 600;
font-size: 0.9375rem;
color: var(--text-primary);
}
.status-badge {
display: inline-block;
font-size: 0.75rem;
padding: 4px 10px;
border-radius: 6px;
font-weight: 600;
}
.status-badge--resolved {
background-color: #DEF7EC;
color: #03543F;
}
.status-badge--escalated {
background-color: #FDE8E8;
color: #9B1C1C;
}
.status-badge--active {
background-color: var(--bg-hover);
color: var(--text-secondary);
}
.data-table {
width: 100%;
border-collapse: collapse;
text-align: left;
}
.data-table th {
padding: 0.75rem 1.5rem;
font-size: 0.75rem;
text-transform: uppercase;
color: var(--text-secondary);
font-weight: 600;
}
.data-table td {
padding: 1.25rem 1.5rem;
font-size: 0.9375rem;
}
.data-table thead tr {
border-bottom: 2px solid var(--border-light);
}
.data-table tbody tr {
border-bottom: 1px solid var(--border-light);
transition: background-color 0.2s;
}
.data-table tbody tr:last-child {
border-bottom: none;
}
.data-table tbody tr:hover {
background-color: var(--bg-hover);
}
.empty-state {
padding: 3rem;
text-align: center;
color: var(--text-secondary);
}
.empty-state__title {
font-size: 1.125rem;
font-weight: 600;
margin: 0;
}
.empty-state__description {
margin-top: 0.5rem;
}
.error-state {
padding: 3rem;
text-align: center;
color: var(--text-secondary);
}
.error-state__title {
font-size: 1.125rem;
font-weight: 600;
color: var(--brand-accent);
margin: 0;
}
.error-state__description {
margin-top: 0.5rem;
}
.pagination-bar {
padding: 1.25rem 1.5rem;
border-top: 1px solid var(--border-light);
display: flex;
justify-content: space-between;
align-items: center;
background-color: var(--bg-surface-inner);
}
.pagination-bar__info {
font-size: 0.875rem;
color: var(--text-secondary);
}
.pagination-bar__controls {
display: flex;
gap: 0.5rem;
}
/* --- Skeleton Loading Animation --- */
@keyframes pulse-skeleton {
  0% { opacity: 0.5; background-color: var(--bg-hover); }


@@ -0,0 +1,106 @@
import { describe, it, expect, vi, beforeEach } from "vitest";
import { render, screen, fireEvent, waitFor, act } from "@testing-library/react";
import { ChatPage } from "./ChatPage";
// Mock react-markdown
vi.mock("react-markdown", () => ({
default: ({ children }: { children: string }) => <span>{children}</span>,
}));
// Mock crypto.randomUUID for stable IDs
vi.stubGlobal("crypto", { randomUUID: () => `uuid-${Date.now()}-${Math.random()}` });
// Capture the onMessage callback from the hook
let capturedOnMessage: ((msg: unknown) => void) | null = null;
const mockSendMessage = vi.fn();
const mockSendInterruptResponse = vi.fn();
const mockReconnect = vi.fn();
let mockStatus = "connected";
vi.mock("../hooks/useWebSocket", () => ({
useWebSocket: (onMessage: (msg: unknown) => void) => {
capturedOnMessage = onMessage;
return {
status: mockStatus,
threadId: "test-thread",
sendMessage: mockSendMessage,
sendInterruptResponse: mockSendInterruptResponse,
reconnect: mockReconnect,
};
},
}));
beforeEach(() => {
capturedOnMessage = null;
mockSendMessage.mockReset();
mockSendInterruptResponse.mockReset();
mockReconnect.mockReset();
mockStatus = "connected";
});
describe("ChatPage", () => {
it("renders chat interface with input field and header", () => {
render(<ChatPage />);
expect(screen.getByText("Inbox")).toBeInTheDocument();
expect(screen.getByPlaceholderText("Message Smart Support...")).toBeInTheDocument();
expect(screen.getByRole("button", { name: "Send Message" })).toBeInTheDocument();
});
it("user can type and submit a message", () => {
render(<ChatPage />);
const input = screen.getByPlaceholderText("Message Smart Support...");
fireEvent.change(input, { target: { value: "Hello bot" } });
fireEvent.keyDown(input, { key: "Enter" });
expect(mockSendMessage).toHaveBeenCalledWith("Hello bot");
expect(screen.getByText("Hello bot")).toBeInTheDocument();
expect(screen.getByText("You")).toBeInTheDocument();
});
it("displays streaming tokens as they arrive", () => {
render(<ChatPage />);
act(() => {
capturedOnMessage?.({ type: "token", agent: "Bot", content: "Hello " });
});
act(() => {
capturedOnMessage?.({ type: "token", agent: "Bot", content: "world" });
});
expect(screen.getByText("Hello world")).toBeInTheDocument();
});
it("shows interrupt prompt when interrupt message received", () => {
render(<ChatPage />);
act(() => {
capturedOnMessage?.({
type: "interrupt",
thread_id: "t1",
action: "cancel_order",
params: { order_id: "ORD-999" },
});
});
expect(screen.getByText("Action Requires Approval")).toBeInTheDocument();
expect(screen.getByText("cancel_order")).toBeInTheDocument();
});
it("shows error message when server sends error", () => {
render(<ChatPage />);
act(() => {
capturedOnMessage?.({ type: "error", message: "Something went wrong" });
});
expect(screen.getByText("Error: Something went wrong")).toBeInTheDocument();
});
it("renders welcome message in empty state", () => {
render(<ChatPage />);
expect(screen.getByText("Hello! How can I help you today?")).toBeInTheDocument();
});
});


@@ -13,10 +13,8 @@ import type {
   ToolAction,
 } from "../types";

-let msgCounter = 0;
 function nextId(): string {
-  msgCounter += 1;
-  return `msg-${msgCounter}`;
+  return crypto.randomUUID();
 }

 export function ChatPage() {
@@ -68,6 +66,48 @@ export function ChatPage() {
        setIsWaiting(false);
        break;
      }
case "clarification": {
setMessages((prev) => [
...prev,
{
id: nextId(),
sender: "agent",
agent: "System",
content: msg.message,
timestamp: Date.now(),
},
]);
setIsWaiting(false);
break;
}
case "interrupt_expired": {
setCurrentInterrupt(null);
setMessages((prev) => [
...prev,
{
id: nextId(),
sender: "agent",
agent: "System",
content: msg.message,
timestamp: Date.now(),
},
]);
setIsWaiting(false);
break;
}
case "tool_result": {
setToolActions((prev) => {
const last = prev[prev.length - 1];
if (last && last.tool === msg.tool && last.agent === msg.agent) {
return [
...prev.slice(0, -1),
{ ...last, result: msg.result },
];
}
return prev;
});
break;
}
case "message_complete": { case "message_complete": {
setMessages((prev) => { setMessages((prev) => {
const last = prev[prev.length - 1]; const last = prev[prev.length - 1];


@@ -0,0 +1,70 @@
import { describe, it, expect, vi, beforeEach } from "vitest";
import { render, screen, waitFor } from "@testing-library/react";
import { DashboardPage } from "./DashboardPage";
vi.mock("../api", () => ({
fetchAnalytics: vi.fn(),
}));
import { fetchAnalytics } from "../api";
const mockFetchAnalytics = vi.mocked(fetchAnalytics);
beforeEach(() => {
mockFetchAnalytics.mockReset();
});
const MOCK_DATA = {
range: "30d",
total_conversations: 100,
resolution_rate: 0.75,
escalation_rate: 0.25,
avg_turns_per_conversation: 3.5,
avg_cost_per_conversation_usd: 0.03,
agent_usage: [{ agent: "order_agent", count: 50, percentage: 0.5 }],
interrupt_stats: { total: 10, approved: 8, rejected: 2, expired: 0 },
};
describe("DashboardPage", () => {
it("renders loading state initially", () => {
mockFetchAnalytics.mockReturnValue(new Promise(() => {})); // never resolves
render(<DashboardPage />);
expect(document.querySelector(".skeleton-box")).toBeTruthy();
});
it("renders data after successful fetch", async () => {
mockFetchAnalytics.mockResolvedValue(MOCK_DATA);
render(<DashboardPage />);
await waitFor(() => {
expect(screen.getByText("100")).toBeInTheDocument();
});
expect(screen.getByText("75.0%")).toBeInTheDocument();
expect(screen.getByText("$0.03")).toBeInTheDocument();
});
it("renders error state on fetch failure", async () => {
mockFetchAnalytics.mockRejectedValue(new Error("Network error"));
render(<DashboardPage />);
await waitFor(() => {
expect(screen.getByText("Failed to load analytics")).toBeInTheDocument();
});
expect(screen.getByText("Network error")).toBeInTheDocument();
});
it("renders empty state when data has zero conversations", async () => {
mockFetchAnalytics.mockResolvedValue({
...MOCK_DATA,
total_conversations: 0,
agent_usage: [],
interrupt_stats: { total: 0, approved: 0, rejected: 0, expired: 0 },
});
render(<DashboardPage />);
await waitFor(() => {
expect(screen.getByText("0")).toBeInTheDocument();
});
expect(screen.getByText("No agent activity recorded yet.")).toBeInTheDocument();
expect(screen.getByText("No interrupt events recorded yet.")).toBeInTheDocument();
});
});
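The formatting assertions ("75.0%", "$0.03") fully determine the two display helpers; a sketch matching them (the real pct and formatCost live in DashboardPage.tsx and are only partially visible in the diff below):

// Hypothetical helpers consistent with the "75.0%" and "$0.03" assertions above.
function pct(value: number): string {
  return `${(value * 100).toFixed(1)}%`; // 0.75 -> "75.0%"
}

function formatCost(value: number): string {
  return `$${value.toFixed(2)}`; // 0.03 -> "$0.03"
}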


@@ -1,4 +1,5 @@
 import { useState, useEffect } from "react";
+import { fetchAnalytics, AnalyticsData } from "../api";

 const RANGE_OPTIONS = [
   { value: "7d", label: "7 days" },
@@ -6,36 +7,19 @@
   { value: "30d", label: "30 days" },
 ];

-// Mock Data
-const MOCK_DATA = {
-  total_conversations: 4208,
-  resolution_rate: 0.724,
-  escalation_rate: 0.276,
-  avg_turns_per_conversation: 3.4,
-  total_tokens: 1450200,
-  total_cost_usd: 12.45,
-  agent_usage: [
-    { agent_name: "Order Specialist", message_count: 8540, total_tokens: 854000, total_cost_usd: 7.20 },
-    { agent_name: "Billing Assistant", message_count: 3120, total_tokens: 412000, total_cost_usd: 3.50 },
-    { agent_name: "Router & Orchestrator", message_count: 4208, total_tokens: 184200, total_cost_usd: 1.75 },
-  ],
-  interrupt_stats: {
-    total: 412,
-    approved: 380,
-    rejected: 28,
-    expired: 4,
-  }
-};
-
 export function DashboardPage() {
   const [range, setRange] = useState("30d");
   const [isLoading, setIsLoading] = useState(true);
-  const data = MOCK_DATA;
+  const [data, setData] = useState<AnalyticsData | null>(null);
+  const [error, setError] = useState<string | null>(null);

   useEffect(() => {
     setIsLoading(true);
-    const timer = setTimeout(() => setIsLoading(false), 1200);
-    return () => clearTimeout(timer);
+    setError(null);
+    fetchAnalytics(range)
+      .then((result) => setData(result))
+      .catch((err: Error) => setError(err.message))
+      .finally(() => setIsLoading(false));
   }, [range]);

   function pct(value: number): string {
@@ -80,8 +64,8 @@
       {isLoading ? (
         <>
          <div style={{ display: "grid", gridTemplateColumns: "repeat(auto-fit, minmax(200px, 1fr))", gap: "1.5rem", marginBottom: "2.5rem" }}>
-            {[1, 2, 3, 4, 5].map(i => (
-              <div key={i} className="skeleton-box" style={{ height: "120px", padding: "1.5rem", borderRadius: "var(--radius-xl)", border: "1px solid var(--border-light)", background: "var(--bg-surface)" }}>
+            {[1, 2, 3, 4].map(i => (
+              <div key={i} className="skeleton-box section-card" style={{ height: "120px" }}>
                <div className="skeleton-text" style={{ width: "60%", height: "12px", marginBottom: "1.5rem" }}></div>
                <div className="skeleton-text" style={{ width: "40%", height: "30px", marginBottom: "1rem" }}></div>
                <div className="skeleton-text" style={{ width: "80%", height: "12px" }}></div>
@@ -93,42 +77,56 @@
          <div className="skeleton-box" style={{ height: "300px", borderRadius: "var(--radius-xl)", background: "var(--bg-surface)" }}></div>
         </div>
       </>
+      ) : error ? (
+        <div className="error-state">
+          <p className="error-state__title">Failed to load analytics</p>
+          <p className="error-state__description">{error}</p>
+          <button onClick={() => setRange(range)} className="btn btn-secondary" style={{ marginTop: "1rem" }}>Retry</button>
+        </div>
+      ) : !data ? (
+        <div className="empty-state">
+          <p className="empty-state__title">No analytics data available</p>
+          <p className="empty-state__description">Start some conversations to see metrics here.</p>
+        </div>
       ) : (
         <>
          <div style={{ display: "grid", gridTemplateColumns: "repeat(auto-fit, minmax(200px, 1fr))", gap: "1.5rem", marginBottom: "2.5rem" }}>
-            <MetricBox label="Tickets Processed" value={data.total_conversations.toLocaleString()} trend="+12% vs last month" />
-            <MetricBox label="Auto-Resolution Rate" value={pct(data.resolution_rate)} trend="Target: 70%" positive />
-            <MetricBox label="Human Escalations" value={pct(data.escalation_rate)} trend="Avg 28%" />
-            <MetricBox label="Human-in-the-Loop Prompts" value={data.interrupt_stats.total.toLocaleString()} trend="High Risk Actions Intercepted" />
-            <MetricBox label="LLM Intelligence Cost" value={formatCost(data.total_cost_usd)} trend={`${(data.total_tokens / 1000).toLocaleString()}k Tokens`} />
+            <MetricBox label="Tickets Processed" value={data.total_conversations.toLocaleString()} trend={`Range: ${data.range}`} />
+            <MetricBox label="Auto-Resolution Rate" value={pct(data.resolution_rate)} trend="Target: 70%" positive={data.resolution_rate >= 0.7} />
+            <MetricBox label="Human Escalations" value={pct(data.escalation_rate)} trend="Lower is better" />
+            <MetricBox label="Avg Cost / Conversation" value={formatCost(data.avg_cost_per_conversation_usd)} trend={`${data.avg_turns_per_conversation.toFixed(1)} avg turns`} />
          </div>

          <div style={{ display: "grid", gridTemplateColumns: "2fr 1fr", gap: "1.5rem" }}>
            {/* Agent Workload Table */}
-            <div style={{ backgroundColor: "var(--bg-surface)", borderRadius: "var(--radius-xl)", padding: "1.5rem", border: "1px solid var(--border-light)" }}>
+            <div className="section-card">
              <h3 style={{ fontSize: "1.125rem", color: "var(--text-primary)", fontWeight: 700, margin: "0 0 1rem 0" }}>Agent Workload Distribution</h3>
-              <table style={{ width: "100%", borderCollapse: "collapse", textAlign: "left" }}>
-                <thead>
-                  <tr style={{ borderBottom: "2px solid var(--border-light)" }}>
-                    <th style={{ padding: "0.75rem 0", fontSize: "0.75rem", textTransform: "uppercase", color: "var(--text-secondary)" }}>Agent Name</th>
-                    <th style={{ padding: "0.75rem 0", fontSize: "0.75rem", textTransform: "uppercase", color: "var(--text-secondary)" }}>Actions Handled</th>
-                    <th style={{ padding: "0.75rem 0", fontSize: "0.75rem", textTransform: "uppercase", color: "var(--text-secondary)" }}>Cost Footprint</th>
-                  </tr>
-                </thead>
-                <tbody>
-                  {data.agent_usage.map((a) => (
-                    <tr key={a.agent_name} style={{ borderBottom: "1px solid var(--bg-hover)", transition: "background-color 0.2s" }} onMouseEnter={(e) => e.currentTarget.style.backgroundColor = 'var(--bg-hover)'} onMouseLeave={(e) => e.currentTarget.style.backgroundColor = 'transparent'}>
-                      <td style={{ padding: "1rem 0 1rem 1rem", fontWeight: 600, fontSize: "0.9375rem" }}>{a.agent_name}</td>
-                      <td style={{ padding: "1rem 0", fontSize: "0.9375rem" }}>{a.message_count.toLocaleString()}</td>
-                      <td style={{ padding: "1rem 1rem 1rem 0", fontSize: "0.9375rem" }}>{formatCost(a.total_cost_usd)}</td>
-                    </tr>
-                  ))}
-                </tbody>
-              </table>
+              {data.agent_usage.length === 0 ? (
+                <p style={{ color: "var(--text-secondary)", fontSize: "0.875rem" }}>No agent activity recorded yet.</p>
+              ) : (
+                <table className="data-table">
+                  <thead>
+                    <tr>
+                      <th style={{ paddingLeft: 0 }}>Agent Name</th>
+                      <th>Message Count</th>
+                      <th>Share</th>
+                    </tr>
+                  </thead>
+                  <tbody>
+                    {data.agent_usage.map((a) => (
+                      <tr key={a.agent}>
+                        <td style={{ paddingLeft: 0, fontWeight: 600 }}>{a.agent}</td>
+                        <td>{a.count.toLocaleString()}</td>
+                        <td>{pct(a.percentage)}</td>
+                      </tr>
+                    ))}
+                  </tbody>
+                </table>
+              )}
            </div>

            {/* Human in the loop card */}
-            <div style={{ backgroundColor: "var(--bg-surface)", borderRadius: "var(--radius-xl)", padding: "1.5rem", border: "1px solid var(--border-light)" }}>
+            <div className="section-card">
              <div style={{ display: "flex", alignItems: "center", gap: "0.5rem", marginBottom: "1rem" }}>
                <h3 style={{ fontSize: "1.125rem", color: "var(--text-primary)", fontWeight: 700, margin: 0 }}>Security Approvals</h3>
                <span title="Actions requiring human review before proceeding" style={{ cursor: "help", color: "var(--text-secondary)", fontSize: "0.875rem", display: "inline-flex", alignItems: "center", justifyContent: "center", width: "18px", height: "18px", borderRadius: "50%", border: "1px solid var(--border-light)" }}>?</span>
@@ -136,24 +134,28 @@
              <p style={{ fontSize: "0.875rem", color: "var(--text-secondary)", marginBottom: "1.5rem", lineHeight: 1.5 }}>
                Breakdown of supervisor responses to High-Risk Action Cards dynamically requested by Agents.
              </p>
-              <div style={{ display: "flex", flexDirection: "column", gap: "1rem" }}>
-                <div style={{ display: "flex", justifyContent: "space-between", alignItems: "center" }}>
-                  <span style={{ fontWeight: 600, fontSize: "0.875rem" }}>Action Approved</span>
-                  <span style={{ color: "#059669", fontWeight: 700 }}>{data.interrupt_stats.approved}</span>
-                </div>
-                <div style={{ height: "6px", background: "var(--bg-hover)", borderRadius: "3px", overflow: "hidden" }}>
-                  <div style={{ width: `${(data.interrupt_stats.approved / data.interrupt_stats.total) * 100}%`, height: "100%", background: "#059669" }} />
-                </div>
-                <div style={{ display: "flex", justifyContent: "space-between", alignItems: "center", marginTop: "0.5rem" }}>
-                  <span style={{ fontWeight: 600, fontSize: "0.875rem" }}>Action Rejected (Escalated)</span>
-                  <span style={{ color: "#DC2626", fontWeight: 700 }}>{data.interrupt_stats.rejected}</span>
+              {data.interrupt_stats.total === 0 ? (
+                <p style={{ color: "var(--text-secondary)", fontSize: "0.875rem" }}>No interrupt events recorded yet.</p>
+              ) : (
+                <div style={{ display: "flex", flexDirection: "column", gap: "1rem" }}>
+                  <div style={{ display: "flex", justifyContent: "space-between", alignItems: "center" }}>
+                    <span style={{ fontWeight: 600, fontSize: "0.875rem" }}>Action Approved</span>
+                    <span style={{ color: "#059669", fontWeight: 700 }}>{data.interrupt_stats.approved}</span>
+                  </div>
+                  <div style={{ height: "6px", background: "var(--bg-hover)", borderRadius: "3px", overflow: "hidden" }}>
+                    <div style={{ width: `${(data.interrupt_stats.approved / data.interrupt_stats.total) * 100}%`, height: "100%", background: "#059669" }} />
+                  </div>
+                  <div style={{ display: "flex", justifyContent: "space-between", alignItems: "center", marginTop: "0.5rem" }}>
+                    <span style={{ fontWeight: 600, fontSize: "0.875rem" }}>Action Rejected (Escalated)</span>
+                    <span style={{ color: "#DC2626", fontWeight: 700 }}>{data.interrupt_stats.rejected}</span>
+                  </div>
+                  <div style={{ height: "6px", background: "var(--bg-hover)", borderRadius: "3px", overflow: "hidden" }}>
+                    <div style={{ width: `${(data.interrupt_stats.rejected / data.interrupt_stats.total) * 100}%`, height: "100%", background: "#DC2626" }} />
</div>
</div> </div>
<div style={{ height: "6px", background: "var(--bg-hover)", borderRadius: "3px", overflow: "hidden" }}> )}
<div style={{ width: `${(data.interrupt_stats.rejected / data.interrupt_stats.total) * 100}%`, height: "100%", background: "#DC2626" }} />
</div>
</div>
</div> </div>
</div> </div>
</> </>
@@ -165,24 +167,10 @@ export function DashboardPage() {
function MetricBox({ label, value, trend, positive }: { label: string, value: string | number, trend: string, positive?: boolean }) { function MetricBox({ label, value, trend, positive }: { label: string, value: string | number, trend: string, positive?: boolean }) {
return ( return (
<div style={{ <div className="section-card" style={{ display: "flex", flexDirection: "column", gap: "0.5rem" }}>
backgroundColor: "var(--bg-surface)", <div className="stat-label">{label}</div>
padding: "1.5rem", <div style={{ fontSize: "2rem", fontWeight: 700, color: "var(--text-primary)" }}>{value}</div>
borderRadius: "var(--radius-xl)", <div style={{ fontSize: "0.8125rem", color: positive ? "#059669" : "var(--text-secondary)", fontWeight: positive ? 600 : 400 }}>{trend}</div>
border: "1px solid var(--border-light)",
display: "flex",
flexDirection: "column",
gap: "0.5rem"
}}>
<div style={{ fontSize: "0.8125rem", color: "var(--text-secondary)", textTransform: "uppercase", letterSpacing: "0.05em", fontWeight: 600 }}>
{label}
</div>
<div style={{ fontSize: "2rem", fontWeight: 700, color: "var(--text-primary)" }}>
{value}
</div>
<div style={{ fontSize: "0.8125rem", color: positive ? "#059669" : "var(--text-secondary)", fontWeight: positive ? 600 : 400 }}>
{trend}
</div>
</div> </div>
); );
} }
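
The refactored dashboard above leans on two small formatting helpers, pct and formatCost, whose definitions sit outside the displayed hunks. A minimal sketch of what they plausibly look like, assuming pct receives 0-1 fractions (consistent with positive={data.resolution_rate >= 0.7}) and formatCost matches the $X.XX helper visible in ReplayListPage further down; these are inferred, not taken from the diff:

// Hypothetical helpers assumed by DashboardPage; the real definitions are
// not shown in this compare view. pct assumes a 0-1 fraction input.
function pct(fraction: number): string {
  return `${(fraction * 100).toFixed(1)}%`;
}

function formatCost(usd: number): string {
  return `$${usd.toFixed(2)}`;
}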

View File

@@ -0,0 +1,106 @@
import { describe, it, expect, vi, beforeEach } from "vitest";
import { render, screen, waitFor } from "@testing-library/react";
import { MemoryRouter } from "react-router-dom";
import { ReplayListPage } from "./ReplayListPage";
vi.mock("../api", () => ({
fetchConversations: vi.fn(),
}));
import { fetchConversations } from "../api";
const mockFetchConversations = vi.mocked(fetchConversations);
beforeEach(() => {
mockFetchConversations.mockReset();
});
function renderWithRouter() {
return render(
<MemoryRouter>
<ReplayListPage />
</MemoryRouter>
);
}
describe("ReplayListPage", () => {
it("renders loading state initially", () => {
mockFetchConversations.mockReturnValue(new Promise(() => {}));
renderWithRouter();
expect(document.querySelector(".skeleton-box")).toBeTruthy();
});
it("renders empty state when no conversations", async () => {
mockFetchConversations.mockResolvedValue({
conversations: [],
total: 0,
page: 1,
per_page: 20,
});
renderWithRouter();
await waitFor(() => {
expect(screen.getByText("No conversations yet")).toBeInTheDocument();
});
});
it("renders conversation list on success", async () => {
mockFetchConversations.mockResolvedValue({
conversations: [
{
thread_id: "t1",
created_at: "2026-04-01T00:00:00Z",
last_activity: "2026-04-01T00:01:00Z",
status: "resolved",
total_tokens: 100,
total_cost_usd: 0.01,
},
],
total: 1,
page: 1,
per_page: 20,
});
renderWithRouter();
await waitFor(() => {
expect(screen.getByText("t1")).toBeInTheDocument();
});
expect(screen.getByText("resolved")).toBeInTheDocument();
});
it("renders error state on fetch failure", async () => {
mockFetchConversations.mockRejectedValue(new Error("Server down"));
renderWithRouter();
await waitFor(() => {
expect(screen.getByText("Failed to load conversations")).toBeInTheDocument();
});
expect(screen.getByText("Server down")).toBeInTheDocument();
});
it("applies correct status badge classes", async () => {
mockFetchConversations.mockResolvedValue({
conversations: [
{ thread_id: "t1", created_at: "", last_activity: "", status: "resolved", total_tokens: 0, total_cost_usd: 0 },
{ thread_id: "t2", created_at: "", last_activity: "", status: "escalated", total_tokens: 0, total_cost_usd: 0 },
{ thread_id: "t3", created_at: "", last_activity: "", status: null, total_tokens: 0, total_cost_usd: 0 },
],
total: 3,
page: 1,
per_page: 20,
});
renderWithRouter();
await waitFor(() => {
expect(screen.getByText("resolved")).toBeInTheDocument();
});
const resolvedBadge = screen.getByText("resolved");
expect(resolvedBadge.className).toContain("status-badge--resolved");
const escalatedBadge = screen.getByText("escalated");
expect(escalatedBadge.className).toContain("status-badge--escalated");
const activeBadge = screen.getByText("active");
expect(activeBadge.className).toContain("status-badge--active");
});
});
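
One detail worth noting in the loading-state test above: a Promise whose executor never calls resolve or reject stays pending forever, so the component is pinned in its loading branch and the skeleton can be asserted synchronously. The trick in isolation, using the same names as the test:

// The executor body is empty, so this promise never settles and the mocked
// fetchConversations appears permanently "in flight" for this render.
const pending = new Promise<never>(() => {});
mockFetchConversations.mockReturnValue(pending);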

View File

@@ -1,130 +1,132 @@
-import { useState } from "react";
+import { useState, useEffect } from "react";
 import { useNavigate } from "react-router-dom";
+import { fetchConversations, ConversationSummary } from "../api";

-// Mock Data
-const MOCK_CONVERSATIONS = [
-  { thread_id: "th_9281ja8s9", user: "Maria G.", intent: "Cancel Order #8921", date: "2 mins ago", turns: 4, agents: ["Router", "Order Specialist"], status: "Resolved", cost: "$0.02" },
-  { thread_id: "th_1092jf8u1", user: "David C.", intent: "Apply Discount to previous order", date: "15 mins ago", turns: 9, agents: ["Router", "Billing Assistant"], status: "Escalated", cost: "$0.08", hitl: true },
-  { thread_id: "th_0099ab7x2", user: "Sarah L.", intent: "Where is my package?", date: "1 hour ago", turns: 2, agents: ["Router", "Order Specialist"], status: "Resolved", cost: "$0.01" },
-  { thread_id: "th_5518kc3p0", user: "John M.", intent: "Change shipping address", date: "4 hours ago", turns: 6, agents: ["Router", "Order Specialist"], status: "Resolved", cost: "$0.04" },
-  { thread_id: "th_1102po9m4", user: "Elena P.", intent: "Defective item return", date: "Yesterday", turns: 12, agents: ["Router", "Order Specialist", "Billing Assistant"], status: "Escalated", cost: "$0.15", hitl: true },
-];

 export function ReplayListPage() {
   const navigate = useNavigate();
   const [page, setPage] = useState(1);
-  const totalPages = 24;
+  const [perPage] = useState(20);
+  const [total, setTotal] = useState(0);
+  const [conversations, setConversations] = useState<ConversationSummary[]>([]);
+  const [isLoading, setIsLoading] = useState(true);
+  const [error, setError] = useState<string | null>(null);
+
+  useEffect(() => {
+    setIsLoading(true);
+    setError(null);
+    fetchConversations(page, perPage)
+      .then((result) => {
+        setConversations(result.conversations);
+        setTotal(result.total);
+      })
+      .catch((err: Error) => setError(err.message))
+      .finally(() => setIsLoading(false));
+  }, [page, perPage]);
+
+  const totalPages = Math.max(1, Math.ceil(total / perPage));
+
+  function formatDate(iso: string): string {
+    try {
+      return new Date(iso).toLocaleString();
+    } catch {
+      return iso;
+    }
+  }
+
+  function formatCost(usd: number): string {
+    return `$${usd.toFixed(2)}`;
+  }
+
+  function statusClass(status: string | null): string {
+    if (status === "resolved") return "status-badge status-badge--resolved";
+    if (status === "escalated") return "status-badge status-badge--escalated";
+    return "status-badge status-badge--active";
+  }

   return (
     <div className="page-container">
-      <div className="page-header" style={{ display: "flex", justifyContent: "space-between", alignItems: "flex-end", marginBottom: "2rem" }}>
-        <div>
-          <h2>Conversation Replay</h2>
-          <p>Review autonomous agent sessions and audit MCP action execution trails.</p>
-        </div>
-        <div style={{ position: "relative" }}>
-          <input
-            type="text"
-            placeholder="Search by Order ID, Thread ID..."
-            style={{
-              padding: "0.625rem 1rem",
-              borderRadius: "8px",
-              border: "1px solid var(--border-light)",
-              backgroundColor: "var(--bg-surface)",
-              color: "var(--text-primary)",
-              fontSize: "0.875rem",
-              width: "280px"
-            }}
-          />
-        </div>
-      </div>
+      <div className="page-header">
+        <h2>Conversation Replay</h2>
+        <p>Review autonomous agent sessions and audit MCP action execution trails.</p>
+      </div>

-      <div style={{ backgroundColor: "var(--bg-surface)", borderRadius: "var(--radius-xl)", overflow: "hidden", border: "1px solid var(--border-light)" }}>
-        <table style={{ width: "100%", borderCollapse: "collapse", textAlign: "left" }}>
-          <thead>
-            <tr style={{ backgroundColor: "var(--bg-surface-inner)", borderBottom: "1px solid var(--border-light)" }}>
-              <th style={{ padding: "1rem 1.5rem", fontSize: "0.75rem", textTransform: "uppercase", color: "var(--text-secondary)", fontWeight: 600 }}>Thread</th>
-              <th style={{ padding: "1rem 1.5rem", fontSize: "0.75rem", textTransform: "uppercase", color: "var(--text-secondary)", fontWeight: 600 }}>Detected Intent</th>
-              <th style={{ padding: "1rem 1.5rem", fontSize: "0.75rem", textTransform: "uppercase", color: "var(--text-secondary)", fontWeight: 600 }}>Agents Invoked</th>
-              <th style={{ padding: "1rem 1.5rem", fontSize: "0.75rem", textTransform: "uppercase", color: "var(--text-secondary)", fontWeight: 600 }}>Outcome</th>
-              <th style={{ padding: "1rem 1.5rem", fontSize: "0.75rem", textTransform: "uppercase", color: "var(--text-secondary)", fontWeight: 600 }}>Performance</th>
-            </tr>
-          </thead>
-          <tbody>
-            {MOCK_CONVERSATIONS.map((c, i) => (
-              <tr
-                key={c.thread_id}
-                onClick={() => navigate(`/replay/${c.thread_id}`)}
-                style={{
-                  borderBottom: i === MOCK_CONVERSATIONS.length - 1 ? "none" : "1px solid var(--border-light)",
-                  cursor: "pointer",
-                  transition: "background-color 0.2s"
-                }}
-                className="replay-row-hover"
-              >
-                <td style={{ padding: "1.25rem 1.5rem" }}>
-                  <div style={{ fontWeight: 600, color: "var(--text-primary)", fontSize: "0.9375rem" }}>{c.user}</div>
-                  <div style={{ fontSize: "0.75rem", color: "var(--text-secondary)", fontFamily: "monospace", marginTop: "4px" }}>{c.thread_id}</div>
-                </td>
-                <td style={{ padding: "1.25rem 1.5rem" }}>
-                  <div style={{ fontWeight: 500, color: "var(--text-primary)", fontSize: "0.9375rem" }}>{c.intent}</div>
-                  <div style={{ fontSize: "0.75rem", color: "var(--text-secondary)", marginTop: "4px" }}>{c.date}</div>
-                </td>
-                <td style={{ padding: "1.25rem 1.5rem" }}>
-                  <div style={{ display: "flex", flexWrap: "wrap", gap: "6px" }}>
-                    {c.agents.map(a => (
-                      <span key={a} style={{ fontSize: "0.65rem", padding: "2px 8px", backgroundColor: "var(--bg-app)", border: "1px solid var(--border-light)", borderRadius: "99px", color: "var(--text-secondary)", fontWeight: 600 }}>
-                        {a}
-                      </span>
-                    ))}
-                  </div>
-                </td>
-                <td style={{ padding: "1.25rem 1.5rem" }}>
-                  <span style={{
-                    fontSize: "0.75rem",
-                    padding: "4px 10px",
-                    borderRadius: "6px",
-                    fontWeight: 600,
-                    backgroundColor: c.status === "Resolved" ? "#DEF7EC" : "#FDE8E8",
-                    color: c.status === "Resolved" ? "#03543F" : "#9B1C1C",
-                  }}>
-                    {c.status}
-                  </span>
-                  {c.hitl && <span style={{ marginLeft: "8px", fontSize: "1.25rem" }} title="Human in the loop invoked">🔒</span>}
-                </td>
-                <td style={{ padding: "1.25rem 1.5rem", fontSize: "0.875rem", color: "var(--text-secondary)" }}>
-                  {c.turns} turns {c.cost}
-                </td>
-              </tr>
-            ))}
-          </tbody>
-        </table>
-
-        <div style={{ padding: "1.25rem 1.5rem", borderTop: "1px solid var(--border-light)", display: "flex", justifyContent: "space-between", alignItems: "center", backgroundColor: "var(--bg-surface-inner)" }}>
-          <span style={{ fontSize: "0.875rem", color: "var(--text-secondary)" }}>Showing 1-5 of 120 sessions</span>
-          <div style={{ display: "flex", gap: "0.5rem" }}>
-            <button
-              onClick={(e) => { e.stopPropagation(); setPage(p => Math.max(1, p - 1)) }}
-              disabled={page === 1}
-              className="btn btn-secondary"
-            >
-              Previous
-            </button>
-            <button
-              onClick={(e) => { e.stopPropagation(); setPage(p => Math.min(totalPages, p + 1)) }}
-              disabled={page >= totalPages}
-              className="btn btn-secondary"
-            >
-              Next
-            </button>
-          </div>
-        </div>
-      </div>
-      <style>{`
-        .replay-row-hover:hover {
-          background-color: var(--bg-hover) !important;
-        }
-      `}</style>
+      {error ? (
+        <div className="error-state">
+          <p className="error-state__title">Failed to load conversations</p>
+          <p className="error-state__description">{error}</p>
+          <button onClick={() => setPage(1)} className="btn btn-secondary" style={{ marginTop: "1rem" }}>Retry</button>
+        </div>
+      ) : isLoading ? (
+        <div className="section-card" style={{ padding: "2rem" }}>
+          {[1, 2, 3, 4, 5].map(i => (
+            <div key={i} className="skeleton-box" style={{ height: "60px", marginBottom: "1rem", borderRadius: "8px" }}>
+              <div className="skeleton-text" style={{ width: "30%", height: "14px", margin: "12px 16px" }}></div>
+            </div>
+          ))}
+        </div>
+      ) : conversations.length === 0 ? (
+        <div className="empty-state">
+          <p className="empty-state__title">No conversations yet</p>
+          <p className="empty-state__description">Start a chat session to see conversations here.</p>
+        </div>
+      ) : (
+        <div className="section-card" style={{ padding: 0, overflow: "hidden" }}>
+          <table className="data-table">
+            <thead>
+              <tr style={{ backgroundColor: "var(--bg-surface-inner)" }}>
+                <th>Thread</th>
+                <th>Created</th>
+                <th>Last Activity</th>
+                <th>Status</th>
+                <th>Cost</th>
+              </tr>
+            </thead>
+            <tbody>
+              {conversations.map((c) => (
+                <tr
+                  key={c.thread_id}
+                  onClick={() => navigate(`/replay/${c.thread_id}`)}
+                  style={{ cursor: "pointer" }}
+                >
+                  <td>
+                    <span style={{ fontWeight: 600, fontFamily: "monospace" }}>{c.thread_id}</span>
+                  </td>
+                  <td style={{ color: "var(--text-secondary)" }}>{formatDate(c.created_at)}</td>
+                  <td style={{ color: "var(--text-secondary)" }}>{formatDate(c.last_activity)}</td>
+                  <td>
+                    <span className={statusClass(c.status)}>{c.status ?? "active"}</span>
+                  </td>
+                  <td style={{ color: "var(--text-secondary)" }}>
+                    {c.total_tokens.toLocaleString()} tokens / {formatCost(c.total_cost_usd)}
+                  </td>
+                </tr>
+              ))}
+            </tbody>
+          </table>
+          <div className="pagination-bar">
+            <span className="pagination-bar__info">
+              Showing {(page - 1) * perPage + 1}-{Math.min(page * perPage, total)} of {total} sessions
+            </span>
+            <div className="pagination-bar__controls">
+              <button
+                onClick={(e) => { e.stopPropagation(); setPage(p => Math.max(1, p - 1)) }}
+                disabled={page === 1}
+                className="btn btn-secondary"
+              >
+                Previous
+              </button>
+              <button
+                onClick={(e) => { e.stopPropagation(); setPage(p => Math.min(totalPages, p + 1)) }}
+                disabled={page >= totalPages}
+                className="btn btn-secondary"
+              >
+                Next
+              </button>
+            </div>
+          </div>
+        </div>
+      )}
    </div>
  );
}
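
For orientation, the response shape ReplayListPage now consumes can be read off the test fixtures above; a sketch of the corresponding api.ts declarations (ConversationListResponse is an inferred name, not taken from the diff):

// Inferred from the mocks in ReplayListPage.test.tsx; the real definitions
// live in src/api.ts, which this compare view does not display.
export interface ConversationSummary {
  thread_id: string;
  created_at: string;
  last_activity: string;
  status: string | null;
  total_tokens: number;
  total_cost_usd: number;
}

export interface ConversationListResponse {
  conversations: ConversationSummary[];
  total: number;
  page: number;
  per_page: number;
}

export declare function fetchConversations(page: number, perPage: number): Promise<ConversationListResponse>;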

View File

@@ -0,0 +1,85 @@
import { describe, it, expect, vi, beforeEach } from "vitest";
import { render, screen, waitFor } from "@testing-library/react";
import { MemoryRouter, Route, Routes } from "react-router-dom";
import { ReplayPage } from "./ReplayPage";
vi.mock("../api", () => ({
fetchReplay: vi.fn(),
}));
vi.mock("../components/ReplayTimeline", () => ({
ReplayTimeline: ({ steps }: { steps: unknown[] }) => (
<div data-testid="replay-timeline">{steps.length} steps</div>
),
}));
import { fetchReplay } from "../api";
const mockFetchReplay = vi.mocked(fetchReplay);
beforeEach(() => {
mockFetchReplay.mockReset();
});
function renderWithRoute(threadId: string) {
return render(
<MemoryRouter initialEntries={[`/replay/${threadId}`]}>
<Routes>
<Route path="/replay/:threadId" element={<ReplayPage />} />
</Routes>
</MemoryRouter>
);
}
describe("ReplayPage", () => {
it("renders loading state initially", () => {
mockFetchReplay.mockReturnValue(new Promise(() => {}));
renderWithRoute("t1");
expect(document.querySelector(".skeleton-box")).toBeTruthy();
});
it("renders replay steps on success", async () => {
mockFetchReplay.mockResolvedValue({
thread_id: "t1",
total_steps: 2,
page: 1,
per_page: 100,
steps: [
{ step: 1, type: "message", content: "Hello", agent: null, tool: null, params: null, result: null, timestamp: "2026-04-01T00:00:00Z" },
{ step: 2, type: "response", content: "Hi!", agent: "bot", tool: null, params: null, result: null, timestamp: "2026-04-01T00:00:01Z" },
],
});
renderWithRoute("t1");
await waitFor(() => {
expect(screen.getByTestId("replay-timeline")).toBeInTheDocument();
});
expect(screen.getByText("2 steps")).toBeInTheDocument();
// Thread ID appears in multiple places (header + sidebar)
expect(screen.getAllByText("t1").length).toBeGreaterThan(0);
});
it("renders empty state when no steps", async () => {
mockFetchReplay.mockResolvedValue({
thread_id: "t1",
total_steps: 0,
page: 1,
per_page: 100,
steps: [],
});
renderWithRoute("t1");
await waitFor(() => {
expect(screen.getByText("No replay steps found")).toBeInTheDocument();
});
});
it("renders error state on fetch failure", async () => {
mockFetchReplay.mockRejectedValue(new Error("Not found"));
renderWithRoute("t1");
await waitFor(() => {
expect(screen.getByText("Failed to load replay")).toBeInTheDocument();
});
expect(screen.getByText("Not found")).toBeInTheDocument();
});
});

View File

@@ -1,70 +1,90 @@
-import { useState } from "react";
+import { useState, useEffect } from "react";
 import { useParams, useNavigate } from "react-router-dom";
 import { ReplayTimeline } from "../components/ReplayTimeline";
+import { fetchReplay, ReplayStep } from "../api";

-const MOCK_STEPS = [
-  { step: 1, type: "message", timestamp: "2026-04-05T10:00:00Z", agent: "Customer", content: "My laptop arrived with a shattered screen. I need a replacement immediately! Order #8921." },
-  { step: 2, type: "token", timestamp: "2026-04-05T10:00:02Z", agent: "Router", content: "Intent detected: 'return_request'. Routing to Order Specialist." },
-  { step: 3, type: "tool_call", timestamp: "2026-04-05T10:00:03Z", agent: "Order Specialist", tool: "get_order_details", params: { order_id: "8921" } },
-  { step: 4, type: "tool_result", timestamp: "2026-04-05T10:00:04Z", tool: "get_order_details", result: { status: "Delivered", items: ["MacBook Pro 16", "USB-C Hub"], total_value: 2499.00 } },
-  { step: 5, type: "tool_call", timestamp: "2026-04-05T10:00:06Z", agent: "Order Specialist", tool: "initiate_return", params: { order_id: "8921", reason: "Damaged in transit", replacement: true } },
-  { step: 6, type: "interrupt", timestamp: "2026-04-05T10:00:06Z", agent: "System", content: "SECURITY POLICY TRIGGERED: High-Value Return (>$1000). Human approval required before initiating RMS workflow." },
-  { step: 7, type: "interrupt_response", timestamp: "2026-04-05T10:15:22Z", agent: "Alex Thompson (Supervisor)", content: "REJECTED. Standard policy for shattered screens requires photo evidence before dispatching replacement unit." },
-  { step: 8, type: "message", timestamp: "2026-04-05T10:15:25Z", agent: "Order Specialist", content: "I'm so sorry to hear your laptop screen was shattered! Because this is a high-value item, our policy requires a photo of the damage before we can dispatch your replacement unit. Could you please take a quick picture and upload it here?" }
-];

 export function ReplayPage() {
   const { threadId } = useParams<{ threadId: string }>();
   const navigate = useNavigate();
-  const [page, setPage] = useState(1);
+  const [steps, setSteps] = useState<ReplayStep[]>([]);
+  const [totalSteps, setTotalSteps] = useState(0);
+  const [isLoading, setIsLoading] = useState(true);
+  const [error, setError] = useState<string | null>(null);
+
+  useEffect(() => {
+    if (!threadId) return;
+    setIsLoading(true);
+    setError(null);
+    fetchReplay(threadId, 1, 100)
+      .then((result) => {
+        setSteps(result.steps);
+        setTotalSteps(result.total_steps);
+      })
+      .catch((err: Error) => setError(err.message))
+      .finally(() => setIsLoading(false));
+  }, [threadId]);

   if (!threadId) return null;

   return (
     <div className="page-container">
-      <div className="page-header" style={{ display: "flex", justifyContent: "space-between", alignItems: "flex-end", marginBottom: "2rem" }}>
-        <div>
-          <button
-            onClick={() => navigate("/replay")}
-            style={{ background: "none", border: "none", color: "var(--text-secondary)", fontSize: "0.875rem", cursor: "pointer", padding: "0 0 0.5rem 0", display: "flex", alignItems: "center", gap: "0.25rem" }}
-          >
-            Back to All Replays
-          </button>
-          <h2>Audit Trail: <span style={{ fontFamily: "monospace", color: "var(--brand-primary)" }}>{threadId}</span></h2>
-          <p>Detailed temporal log of agent reflections, MCP tool calls, and human overrides.</p>
-        </div>
+      <div className="page-header" style={{ marginBottom: "2rem" }}>
+        <button
+          onClick={() => navigate("/replay")}
+          style={{ background: "none", border: "none", color: "var(--text-secondary)", fontSize: "0.875rem", cursor: "pointer", padding: "0 0 0.5rem 0", display: "flex", alignItems: "center", gap: "0.25rem" }}
+        >
+          &larr; Back to All Replays
+        </button>
+        <h2>Audit Trail: <span style={{ fontFamily: "monospace", color: "var(--brand-primary)" }}>{threadId}</span></h2>
+        <p>Detailed temporal log of agent reflections, MCP tool calls, and human overrides.</p>
      </div>

-      <div style={{ display: "grid", gridTemplateColumns: "1fr 3fr", gap: "2rem" }}>
-        {/* Sidebar Summary Info */}
-        <div style={{ backgroundColor: "var(--bg-surface)", padding: "1.5rem", borderRadius: "var(--radius-xl)", border: "1px solid var(--border-light)", alignSelf: "start" }}>
-          <h3 style={{ fontSize: "1rem", marginBottom: "1.25rem", color: "var(--text-primary)" }}>Session Context</h3>
-
-          <div style={{ display: "flex", flexDirection: "column", gap: "1rem" }}>
-            <div>
-              <div style={{ fontSize: "0.75rem", textTransform: "uppercase", color: "var(--text-secondary)", fontWeight: 600 }}>Customer</div>
-              <div style={{ fontWeight: 600, fontSize: "0.9375rem" }}>Maria G.</div>
-            </div>
-            <div>
-              <div style={{ fontSize: "0.75rem", textTransform: "uppercase", color: "var(--text-secondary)", fontWeight: 600 }}>Final Outcome</div>
-              <div style={{ display: "inline-block", backgroundColor: "#FDE8E8", color: "#9B1C1C", padding: "4px 8px", borderRadius: "6px", fontSize: "0.75rem", fontWeight: 700, marginTop: "4px" }}>ESCALATED 🔒</div>
-            </div>
-            <div>
-              <div style={{ fontSize: "0.75rem", textTransform: "uppercase", color: "var(--text-secondary)", fontWeight: 600 }}>Time Elapsed</div>
-              <div style={{ fontSize: "0.9375rem" }}>15m 25s</div>
-            </div>
-            <div>
-              <div style={{ fontSize: "0.75rem", textTransform: "uppercase", color: "var(--text-secondary)", fontWeight: 600 }}>Total Tokens</div>
-              <div style={{ fontSize: "0.9375rem" }}>3,402 ($0.15)</div>
-            </div>
-          </div>
-        </div>
-
-        {/* Timeline */}
-        <div style={{ backgroundColor: "var(--bg-surface)", padding: "2rem", borderRadius: "var(--radius-xl)", border: "1px solid var(--border-light)" }}>
-          <ReplayTimeline steps={MOCK_STEPS as any} />
-        </div>
-      </div>
+      {error ? (
+        <div className="error-state">
+          <p className="error-state__title">Failed to load replay</p>
+          <p className="error-state__description">{error}</p>
+        </div>
+      ) : isLoading ? (
+        <div style={{ display: "grid", gridTemplateColumns: "1fr 3fr", gap: "2rem" }}>
+          <div className="skeleton-box" style={{ height: "250px", borderRadius: "var(--radius-xl)", background: "var(--bg-surface)" }}></div>
+          <div className="skeleton-box" style={{ height: "400px", borderRadius: "var(--radius-xl)", background: "var(--bg-surface)" }}></div>
+        </div>
+      ) : steps.length === 0 ? (
+        <div className="empty-state">
+          <p className="empty-state__title">No replay steps found</p>
+          <p className="empty-state__description">This conversation has no recorded checkpoints.</p>
+        </div>
+      ) : (
+        <div style={{ display: "grid", gridTemplateColumns: "1fr 3fr", gap: "2rem" }}>
+          {/* Sidebar Summary Info */}
+          <div className="section-card" style={{ alignSelf: "start" }}>
+            <h3 style={{ fontSize: "1rem", marginBottom: "1.25rem", color: "var(--text-primary)" }}>Session Context</h3>
+            <div style={{ display: "flex", flexDirection: "column", gap: "1rem" }}>
+              <div>
+                <div className="stat-label">Thread ID</div>
+                <div className="stat-value" style={{ fontSize: "0.8125rem", fontFamily: "monospace", wordBreak: "break-all" }}>{threadId}</div>
+              </div>
+              <div>
+                <div className="stat-label">Total Steps</div>
+                <div className="stat-value">{totalSteps}</div>
+              </div>
+              <div>
+                <div className="stat-label">Time Range</div>
+                <div style={{ fontSize: "0.8125rem" }}>
+                  {steps[0]?.timestamp ? new Date(steps[0].timestamp).toLocaleString() : "N/A"}
+                  {" \u2013 "}
+                  {steps[steps.length - 1]?.timestamp ? new Date(steps[steps.length - 1].timestamp).toLocaleString() : "N/A"}
+                </div>
+              </div>
+            </div>
+          </div>

+          {/* Timeline */}
+          <div className="section-card" style={{ padding: "2rem" }}>
+            <ReplayTimeline steps={steps} />
+          </div>
+        </div>
+      )}
    </div>
  );
}
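
The same pattern applies here: the ReplayStep shape and fetchReplay signature this page relies on live in src/api.ts, outside the view. A sketch inferred from the ReplayPage.test.tsx fixtures above:

// Inferred from the test fixtures; field nullability follows the mocked steps.
export interface ReplayStep {
  step: number;
  type: string;
  content: string | null;
  agent: string | null;
  tool: string | null;
  params: Record<string, unknown> | null;
  result: Record<string, unknown> | null;
  timestamp: string;
}

export declare function fetchReplay(
  threadId: string,
  page: number,
  perPage: number
): Promise<{ thread_id: string; total_steps: number; page: number; per_page: number; steps: ReplayStep[] }>;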

View File

@@ -0,0 +1,200 @@
import { describe, it, expect, vi, beforeEach } from "vitest";
import { render, screen, fireEvent, waitFor } from "@testing-library/react";
import { ReviewPage } from "./ReviewPage";
vi.mock("../api", () => ({
startImport: vi.fn(),
fetchImportJob: vi.fn(),
fetchClassifications: vi.fn(),
approveJob: vi.fn(),
}));
import { startImport, fetchImportJob, fetchClassifications, approveJob } from "../api";
const mockStartImport = vi.mocked(startImport);
const mockFetchImportJob = vi.mocked(fetchImportJob);
const mockFetchClassifications = vi.mocked(fetchClassifications);
const mockApproveJob = vi.mocked(approveJob);
beforeEach(() => {
mockStartImport.mockReset();
mockFetchImportJob.mockReset();
mockFetchClassifications.mockReset();
mockApproveJob.mockReset();
});
describe("ReviewPage", () => {
it("renders the OpenAPI URL input form", () => {
render(<ReviewPage />);
expect(screen.getByText("Agents & Tools Registry")).toBeInTheDocument();
expect(screen.getByPlaceholderText("https://example.com/openapi.yaml")).toBeInTheDocument();
expect(screen.getByText("Scan Tools")).toBeInTheDocument();
});
it("submit form triggers API call with entered URL", async () => {
mockStartImport.mockResolvedValue({
job_id: "job-1",
status: "processing",
spec_url: "https://example.com/openapi.yaml",
total_endpoints: 5,
classified_count: 0,
error_message: null,
});
mockFetchImportJob.mockResolvedValue({
job_id: "job-1",
status: "processing",
spec_url: "https://example.com/openapi.yaml",
total_endpoints: 5,
classified_count: 0,
error_message: null,
});
render(<ReviewPage />);
const input = screen.getByPlaceholderText("https://example.com/openapi.yaml");
fireEvent.change(input, { target: { value: "https://example.com/openapi.yaml" } });
fireEvent.click(screen.getByText("Scan Tools"));
await waitFor(() => {
expect(mockStartImport).toHaveBeenCalledWith("https://example.com/openapi.yaml");
});
});
it("shows loading state during import", async () => {
mockStartImport.mockReturnValue(new Promise(() => {})); // never resolves
render(<ReviewPage />);
const input = screen.getByPlaceholderText("https://example.com/openapi.yaml");
fireEvent.change(input, { target: { value: "https://api.test.com/spec.json" } });
fireEvent.click(screen.getByText("Scan Tools"));
await waitFor(() => {
expect(screen.getByText("Importing...")).toBeInTheDocument();
});
});
it("displays classification results after job completes", async () => {
mockStartImport.mockResolvedValue({
job_id: "job-1",
status: "done",
spec_url: "https://example.com/openapi.yaml",
total_endpoints: 2,
classified_count: 2,
error_message: null,
});
mockFetchImportJob.mockResolvedValue({
job_id: "job-1",
status: "done",
spec_url: "https://example.com/openapi.yaml",
total_endpoints: 2,
classified_count: 2,
error_message: null,
});
mockFetchClassifications.mockResolvedValue([
{
index: 0,
access_type: "read",
needs_interrupt: false,
agent_group: "OrderAgent",
confidence: 0.95,
customer_params: [],
endpoint: { path: "/orders", method: "get", operation_id: "getOrders", summary: "List orders", description: "" },
},
{
index: 1,
access_type: "write",
needs_interrupt: true,
agent_group: "OrderAgent",
confidence: 0.9,
customer_params: ["order_id"],
endpoint: { path: "/orders/{id}/cancel", method: "post", operation_id: "cancelOrder", summary: "Cancel an order", description: "" },
},
]);
render(<ReviewPage />);
const input = screen.getByPlaceholderText("https://example.com/openapi.yaml");
fireEvent.change(input, { target: { value: "https://example.com/openapi.yaml" } });
fireEvent.click(screen.getByText("Scan Tools"));
await waitFor(() => {
expect(screen.getByText("Assigned Capabilities (2)")).toBeInTheDocument();
});
expect(screen.getByText("OrderAgent")).toBeInTheDocument();
expect(screen.getByText("/orders")).toBeInTheDocument();
expect(screen.getByText("List orders")).toBeInTheDocument();
});
it("shows error on API failure", async () => {
mockStartImport.mockRejectedValue(new Error("Network timeout"));
render(<ReviewPage />);
const input = screen.getByPlaceholderText("https://example.com/openapi.yaml");
fireEvent.change(input, { target: { value: "https://example.com/openapi.yaml" } });
fireEvent.click(screen.getByText("Scan Tools"));
await waitFor(() => {
expect(screen.getByText("Error: Network timeout")).toBeInTheDocument();
});
});
it("shows success message after approval", async () => {
// Set up initial state with classifications
mockStartImport.mockResolvedValue({
job_id: "job-1",
status: "done",
spec_url: "https://example.com/openapi.yaml",
total_endpoints: 1,
classified_count: 1,
error_message: null,
});
mockFetchImportJob.mockResolvedValue({
job_id: "job-1",
status: "done",
spec_url: "https://example.com/openapi.yaml",
total_endpoints: 1,
classified_count: 1,
error_message: null,
});
mockFetchClassifications.mockResolvedValue([
{
index: 0,
access_type: "read",
needs_interrupt: false,
agent_group: "TestAgent",
confidence: 0.9,
customer_params: [],
endpoint: { path: "/test", method: "get", operation_id: "test", summary: "Test endpoint", description: "" },
},
]);
mockApproveJob.mockResolvedValue({
job_id: "job-1",
status: "approved",
spec_url: "https://example.com/openapi.yaml",
total_endpoints: 1,
classified_count: 1,
error_message: null,
generated_tools_count: 3,
});
render(<ReviewPage />);
// Import first
const input = screen.getByPlaceholderText("https://example.com/openapi.yaml");
fireEvent.change(input, { target: { value: "https://example.com/openapi.yaml" } });
fireEvent.click(screen.getByText("Scan Tools"));
await waitFor(() => {
expect(screen.getByText("Save Configuration")).toBeInTheDocument();
});
// Approve
fireEvent.click(screen.getByText("Save Configuration"));
await waitFor(() => {
expect(screen.getByText("Configuration saved. 3 tools generated.")).toBeInTheDocument();
});
});
});

View File

@@ -1,12 +1,14 @@
 import { useEffect, useRef, useState } from "react";
+import {
+  approveJob,
+  fetchClassifications,
+  fetchImportJob,
+  startImport,
+  type EndpointClassification,
+  type ImportJobResponse,
+} from "../api";

-interface ImportJob {
-  job_id: string;
-  status: "pending" | "processing" | "done" | "error";
-  error?: string;
-}
-interface EndpointClassification {
+interface FlatClassification {
   path: string;
   method: string;
   summary: string;
@@ -14,55 +16,23 @@ interface EndpointClassification {
   agent_group: string;
 }

-interface JobResult {
-  job_id: string;
-  status: string;
-  endpoints: EndpointClassification[];
+function flattenClassification(c: EndpointClassification): FlatClassification {
+  return {
+    path: c.endpoint?.path ?? "",
+    method: c.endpoint?.method ?? "",
+    summary: c.endpoint?.summary ?? "",
+    access_type: c.access_type ?? "read",
+    agent_group: c.agent_group ?? "Unassigned",
+  };
 }

 export function ReviewPage() {
   const [url, setUrl] = useState("");
-  const [job, setJob] = useState<ImportJob | null>(null);
-  const [result, setResult] = useState<JobResult | null>(null);
+  const [job, setJob] = useState<ImportJobResponse | null>(null);
   const [submitting, setSubmitting] = useState(false);
   const [submitError, setSubmitError] = useState<string | null>(null);
-  const [classifications, setClassifications] = useState<EndpointClassification[]>([
-    {
-      path: "/api/v1/orders/{order_id}/cancel",
-      method: "post",
-      summary: "Cancel an active Shopify order",
-      access_type: "write",
-      agent_group: "Order Specialist",
-    },
-    {
-      path: "/api/v1/orders/{order_id}",
-      method: "get",
-      summary: "Retrieve detailed information about an order",
-      access_type: "read",
-      agent_group: "Order Specialist",
-    },
-    {
-      path: "/api/v1/payments/{charge_id}/refund",
-      method: "post",
-      summary: "Issue a full or partial refund for a charge",
-      access_type: "admin",
-      agent_group: "Billing Assistant",
-    },
-    {
-      path: "/api/v1/customers/{email}/discounts",
-      method: "post",
-      summary: "Apply a loyalty discount to a customer account",
-      access_type: "write",
-      agent_group: "Billing Assistant",
-    },
-    {
-      path: "/api/v1/inventory/check",
-      method: "get",
-      summary: "Query realtime stock levels across warehouses",
-      access_type: "read",
-      agent_group: "Unassigned",
-    }
-  ]);
+  const [approveStatus, setApproveStatus] = useState<string | null>(null);
+  const [classifications, setClassifications] = useState<FlatClassification[]>([]);

   const pollRef = useRef<ReturnType<typeof setTimeout> | null>(null);

   useEffect(() => {
@@ -72,20 +42,14 @@ export function ReviewPage() {
   }, []);

   function pollJob(jobId: string) {
-    fetch(`/api/openapi/jobs/${encodeURIComponent(jobId)}`)
-      .then((r) => r.json())
-      .then((data) => {
-        const j: ImportJob = data.data ?? data;
+    fetchImportJob(jobId)
+      .then((j) => {
         setJob(j);
         if (j.status === "done") {
-          return fetch(`/api/openapi/jobs/${encodeURIComponent(jobId)}/result`)
-            .then((r) => r.json())
-            .then((rdata) => {
-              const res: JobResult = rdata.data ?? rdata;
-              setResult(res);
-              setClassifications(res.endpoints ?? []);
-            });
-        } else if (j.status === "error") {
+          return fetchClassifications(jobId).then((clfs) => {
+            setClassifications(clfs.map(flattenClassification));
+          });
+        } else if (j.status === "failed") {
           return;
         } else {
           pollRef.current = setTimeout(() => pollJob(jobId), 2000);
@@ -101,18 +65,12 @@ export function ReviewPage() {
     if (!url.trim()) return;
     setSubmitting(true);
     setSubmitError(null);
+    setApproveStatus(null);
     setJob(null);
-    setResult(null);
     setClassifications([]);
-    fetch("/api/openapi/import", {
-      method: "POST",
-      headers: { "Content-Type": "application/json" },
-      body: JSON.stringify({ url }),
-    })
-      .then((r) => r.json())
-      .then((data) => {
-        const j: ImportJob = data.data ?? data;
+    startImport(url)
+      .then((j) => {
         setJob(j);
         if (j.job_id) pollJob(j.job_id);
       })
@@ -122,7 +80,7 @@ export function ReviewPage() {
   function handleFieldChange(
     idx: number,
-    field: keyof EndpointClassification,
+    field: keyof FlatClassification,
     value: string
   ) {
     setClassifications((prev) =>
@@ -132,21 +90,26 @@ export function ReviewPage() {
   function handleApprove() {
     if (!job?.job_id) return;
-    fetch(`/api/openapi/jobs/${encodeURIComponent(job.job_id)}/approve`, {
-      method: "POST",
-      headers: { "Content-Type": "application/json" },
-      body: JSON.stringify({ endpoints: classifications }),
-    }).then(() => {
-      alert("Approved and saved.");
-    });
+    setApproveStatus(null);
+    approveJob(job.job_id)
+      .then((result) => {
+        setJob(result);
+        setApproveStatus(
+          `Configuration saved. ${result.generated_tools_count ?? 0} tools generated.`
+        );
+      })
+      .catch((err: Error) => setApproveStatus(`Error: ${err.message}`));
   }

-  const groupedByAgent = classifications.reduce((acc, c, idx) => {
-    const group = c.agent_group || "Unassigned";
-    if (!acc[group]) acc[group] = [];
-    acc[group].push({ ...c, originalIdx: idx });
-    return acc;
-  }, {} as Record<string, (EndpointClassification & { originalIdx: number })[]>);
+  const groupedByAgent = classifications.reduce(
+    (acc, c, idx) => {
+      const group = c.agent_group || "Unassigned";
+      if (!acc[group]) acc[group] = [];
+      acc[group].push({ ...c, originalIdx: idx });
+      return acc;
+    },
+    {} as Record<string, (FlatClassification & { originalIdx: number })[]>
+  );

   return (
     <div className="page-container">
@@ -169,35 +132,105 @@ export function ReviewPage() {
         </button>
       </form>

-      {submitError && <div style={{ color: "var(--brand-accent)", marginBottom: "1rem" }}>Error: {submitError}</div>}
+      {submitError && (
+        <div style={{ color: "var(--brand-accent)", marginBottom: "1rem" }}>
+          Error: {submitError}
+        </div>
+      )}

       {job && (
-        <div style={{ padding: "1rem", background: "var(--bg-surface)", border: "1px solid var(--border-light)", borderRadius: "var(--radius-md)", marginBottom: "1.5rem" }}>
+        <div
+          style={{
+            padding: "1rem",
+            background: "var(--bg-surface)",
+            border: "1px solid var(--border-light)",
+            borderRadius: "var(--radius-md)",
+            marginBottom: "1.5rem",
+          }}
+        >
           <strong>Job:</strong> {job.job_id} &mdash; Status:{" "}
-          <span style={{ fontWeight: 600, color: job.status === "done" ? "#10b981" : job.status === "error" ? "var(--brand-accent)" : "#f59e0b" }}>
+          <span
+            style={{
+              fontWeight: 600,
+              color:
+                job.status === "done" || job.status === "approved"
+                  ? "#10b981"
+                  : job.status === "failed"
+                  ? "var(--brand-accent)"
+                  : "#f59e0b",
+            }}
+          >
             {job.status}
           </span>
-          {job.error && <div style={{ marginTop: "4px", color: "var(--brand-accent)" }}>{job.error}</div>}
+          {job.error_message && (
+            <div style={{ marginTop: "4px", color: "var(--brand-accent)" }}>
+              {job.error_message}
+            </div>
+          )}
         </div>
       )}

+      {approveStatus && (
+        <div
+          style={{
+            padding: "0.75rem 1rem",
+            background: approveStatus.startsWith("Error")
+              ? "#fef2f2"
+              : "#f0fdf4",
+            border: `1px solid ${approveStatus.startsWith("Error") ? "#fecaca" : "#bbf7d0"}`,
+            borderRadius: "var(--radius-md)",
+            marginBottom: "1rem",
+            fontSize: "0.875rem",
+          }}
+        >
+          {approveStatus}
+        </div>
+      )}
+
       {classifications.length > 0 && (
         <>
-          <div style={{ display: "flex", justifyContent: "space-between", alignItems: "center", marginBottom: "1rem" }}>
-            <div>
-              <h3 style={{ margin: 0, fontSize: "1.25rem", color: "var(--text-primary)" }}>Assigned Capabilities ({classifications.length})</h3>
-              <p style={{ margin: "0.25rem 0 0 0", fontSize: "0.875rem", color: "var(--text-secondary)" }}>Grouped by target Agent.</p>
-            </div>
-            <button onClick={handleApprove} className="btn btn-primary">
-              Save Configuration
-            </button>
+          <div
+            style={{
+              display: "flex",
+              justifyContent: "space-between",
+              alignItems: "center",
+              marginBottom: "1rem",
+            }}
+          >
+            <div>
+              <h3
+                style={{
+                  margin: 0,
+                  fontSize: "1.25rem",
+                  color: "var(--text-primary)",
+                }}
+              >
+                Assigned Capabilities ({classifications.length})
+              </h3>
+              <p
+                style={{
+                  margin: "0.25rem 0 0 0",
+                  fontSize: "0.875rem",
+                  color: "var(--text-secondary)",
+                }}
+              >
+                Grouped by target Agent.
+              </p>
+            </div>
+            <button onClick={handleApprove} className="btn btn-primary">
+              Save Configuration
+            </button>
           </div>

           <div className="agent-grid">
             {Object.entries(groupedByAgent).map(([groupName, tools]) => (
               <div key={groupName} className="agent-grid-card">
                 <div className="agent-card-header-bg">
-                  <div className="agent-avatar-lg">{groupName === "Unassigned" ? "?" : groupName.charAt(0).toUpperCase()}</div>
+                  <div className="agent-avatar-lg">
+                    {groupName === "Unassigned"
+                      ? "?"
+                      : groupName.charAt(0).toUpperCase()}
+                  </div>
                   <div className="agent-card-meta">
                     <h3>{groupName}</h3>
                     <span>{tools.length} Attached Tools</span>
@@ -207,26 +240,51 @@ export function ReviewPage() {
                 {tools.map((t) => (
                   <div key={t.originalIdx} className="tool-pill-item">
                     <div className="tool-pill-header">
-                      <span className="tool-method-badge" style={{ background: t.method === "get" ? "#3b82f6" : t.method === "post" ? "#10b981" : t.method === "delete" ? "#ef4444" : "#f59e0b" }}>
+                      <span
+                        className="tool-method-badge"
+                        style={{
+                          background:
+                            t.method === "get"
+                              ? "#3b82f6"
+                              : t.method === "post"
+                              ? "#10b981"
+                              : t.method === "delete"
+                              ? "#ef4444"
+                              : "#f59e0b",
+                        }}
+                      >
                         {t.method}
                       </span>
-                      <span className="tool-path-text" title={t.path}>{t.path}</span>
+                      <span className="tool-path-text" title={t.path}>
+                        {t.path}
+                      </span>
                     </div>
                     <div className="tool-summary-text">{t.summary}</div>
                     <div className="tool-pill-controls">
                       <select
                         value={t.access_type}
-                        onChange={(e) => handleFieldChange(t.originalIdx, "access_type", e.target.value)}
+                        onChange={(e) =>
+                          handleFieldChange(
+                            t.originalIdx,
+                            "access_type",
+                            e.target.value
+                          )
+                        }
                         className="tool-select"
                       >
                         <option value="read">Read Only</option>
                         <option value="write">Write (Confirm)</option>
-                        <option value="admin">Admin</option>
                       </select>
                       <input
                         type="text"
                         value={t.agent_group}
-                        onChange={(e) => handleFieldChange(t.originalIdx, "agent_group", e.target.value)}
+                        onChange={(e) =>
+                          handleFieldChange(
+                            t.originalIdx,
+                            "agent_group",
+                            e.target.value
+                          )
+                        }
                         className="tool-input"
                         placeholder="Agent Name"
                       />
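
For reference, the nested EndpointClassification and ImportJobResponse shapes that flattenClassification and the polling logic assume can be reconstructed from the ReviewPage.test.tsx mocks; a sketch (the actual definitions live in src/api.ts and are not shown here):

// Inferred from ReviewPage.test.tsx; endpoint is kept optional because
// flattenClassification guards every access with optional chaining.
export interface ImportJobResponse {
  job_id: string;
  status: string;
  spec_url: string;
  total_endpoints: number;
  classified_count: number;
  error_message: string | null;
  generated_tools_count?: number;
}

export interface EndpointClassification {
  index: number;
  access_type: string;
  needs_interrupt: boolean;
  agent_group: string;
  confidence: number;
  customer_params: string[];
  endpoint?: {
    path: string;
    method: string;
    operation_id: string;
    summary: string;
    description: string;
  };
}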

View File

@@ -0,0 +1 @@
import "@testing-library/jest-dom/vitest";

View File

@@ -39,13 +39,28 @@ export interface ErrorMessage {
   message: string;
 }

+export interface ClarificationMessage {
+  type: "clarification";
+  thread_id: string;
+  message: string;
+}
+
+export interface InterruptExpiredMessage {
+  type: "interrupt_expired";
+  thread_id: string;
+  action: string;
+  message: string;
+}
+
 export type ServerMessage =
   | TokenMessage
   | InterruptMessage
   | ToolCallMessage
   | ToolResultMessage
   | MessageCompleteMessage
-  | ErrorMessage;
+  | ErrorMessage
+  | ClarificationMessage
+  | InterruptExpiredMessage;

 // -- Client -> Server messages --
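
With the union extended, consumers can narrow on the type discriminant and receive the matching payload fields in each branch; a minimal sketch (handleServerMessage is an illustrative name, not one from the diff):

// Narrowing on the "type" field: inside each case, TypeScript knows which
// variant of ServerMessage it is holding.
function handleServerMessage(msg: ServerMessage): void {
  switch (msg.type) {
    case "clarification":
      console.log(`clarification on ${msg.thread_id}: ${msg.message}`);
      break;
    case "interrupt_expired":
      console.log(`interrupt for action ${msg.action} expired: ${msg.message}`);
      break;
    default:
      // Token, tool, interrupt, completion, and error variants handled elsewhere.
      break;
  }
}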

View File

@@ -1 +1 @@
{"root":["./src/app.tsx","./src/api.ts","./src/main.tsx","./src/types.ts","./src/components/agentaction.tsx","./src/components/chatinput.tsx","./src/components/chatmessages.tsx","./src/components/errorbanner.tsx","./src/components/interruptprompt.tsx","./src/components/layout.tsx","./src/components/metriccard.tsx","./src/components/navbar.tsx","./src/components/replaytimeline.tsx","./src/hooks/usewebsocket.ts","./src/pages/chatpage.tsx","./src/pages/dashboardpage.tsx","./src/pages/replaylistpage.tsx","./src/pages/replaypage.tsx","./src/pages/reviewpage.tsx"],"version":"5.7.3"} {"root":["./src/app.tsx","./src/api.test.ts","./src/api.ts","./src/main.tsx","./src/test-setup.ts","./src/types.ts","./src/vite-env.d.ts","./src/components/agentaction.tsx","./src/components/chatinput.tsx","./src/components/chatmessages.tsx","./src/components/errorbanner.tsx","./src/components/interruptprompt.tsx","./src/components/layout.tsx","./src/components/metriccard.tsx","./src/components/navbar.tsx","./src/components/replaytimeline.tsx","./src/hooks/usewebsocket.ts","./src/pages/chatpage.tsx","./src/pages/dashboardpage.test.tsx","./src/pages/dashboardpage.tsx","./src/pages/replaylistpage.test.tsx","./src/pages/replaylistpage.tsx","./src/pages/replaypage.test.tsx","./src/pages/replaypage.tsx","./src/pages/reviewpage.tsx"],"version":"5.7.3"}

View File

@@ -1,8 +1,14 @@
+/// <reference types="vitest" />
 import react from "@vitejs/plugin-react";
 import { defineConfig } from "vite";

 export default defineConfig({
   plugins: [react()],
+  test: {
+    environment: "happy-dom",
+    globals: true,
+    setupFiles: ["./src/test-setup.ts"],
+  },
   server: {
     port: 5173,
     proxy: {