refactor: engineering improvements -- API versioning, structured logging, Alembic, error standardization, test coverage

- API versioning: all REST endpoints prefixed with /api/v1/
- Structured logging: replaced stdlib logging with structlog (console/JSON modes)
- Alembic migrations: versioned DB schema with initial migration
- Error standardization: global exception handlers for consistent envelope format
- Interrupt cleanup: asyncio background task for expired interrupt removal
- Integration tests: +30 tests (analytics, replay, openapi, error, session APIs)
- Frontend tests: +57 tests (all components, pages, useWebSocket hook)
- Backend: 557 tests, 89.75% coverage | Frontend: 80 tests, 16 test files
Yaojia Wang
2026-04-06 23:19:29 +02:00
parent af53111928
commit f0699436c5
59 changed files with 2846 additions and 149 deletions

View File

@@ -26,6 +26,10 @@ WEBHOOK_URL=
 SESSION_TTL_MINUTES=30
 INTERRUPT_TTL_MINUTES=30
+# Optional: API key for admin endpoints (analytics, replay, openapi, websocket)
+# Leave empty to disable authentication (dev mode)
+ADMIN_API_KEY=
 # Optional: load a named agent template instead of agents.yaml
 # Available templates: ecommerce, saas, generic
 TEMPLATE_NAME=

View File

@@ -30,7 +30,7 @@ pytest --cov=app --cov-report=term-missing
 # - If any test fails, fix it before starting the new phase
 # 3. Create checkpoint to snapshot the starting state
-/everything-claude-code:checkpoint create "phase-name"
+/ecc:checkpoint create "phase-name"
 # 4. Create the phase branch
 git checkout main
@@ -50,25 +50,32 @@ git checkout -b phase-{N}/{short-description}
 3. Identify all tasks, acceptance criteria, and dependencies for this phase
 4. Create a phase dev log **skeleton** at `docs/phases/phase-{N}-dev-log.md` (date, branch name, plan link only -- content filled in Step 5)
-### Step 2: Develop Using Orchestrate Skill
+### Step 2: Develop Using ECC Skills
-Route to the correct orchestration mode based on work type:
+Route to the correct skill based on work type:
-| Work Type | Skill Command |
-|-----------|---------------|
-| New feature | `/everything-claude-code:orchestrate feature` |
-| Bug fix | `/everything-claude-code:orchestrate bugfix` |
-| Refactor | `/everything-claude-code:orchestrate refactor` |
+| Work Type | Skill Command | What It Does |
+|-----------|---------------|--------------|
+| New feature | `/ecc:feature-dev <desc>` | Discovery -> Exploration -> Architecture -> TDD -> Review -> Summary |
+| Bug fix | `/ecc:tdd` then `/ecc:code-review` | RED -> GREEN -> REFACTOR cycle, then review |
+| Refactor | `/ecc:plan` then `/ecc:tdd` then `/ecc:code-review` | Plan refactor scope, TDD, review |
+| Security-sensitive | Add `/ecc:security-review` after code-review | Auth, payments, user input, external APIs |
+| Final verification | `/ecc:verify` | Build + tests + lint + coverage + security scan |
-ALWAYS use the appropriate orchestrate skill. Never develop without it.
-A single phase may contain mixed work types (e.g., Phase 5 has feature + bugfix + refactor). Call the orchestrate skill **per sub-task** with the matching mode. Example:
+A single phase may contain mixed work types. Call the appropriate skill **per sub-task**:
 ```
-# Within Phase 5:
-/everything-claude-code:orchestrate feature   # for demo script
-/everything-claude-code:orchestrate bugfix    # for error handling fixes
-/everything-claude-code:orchestrate refactor  # for code cleanup
+# Within a phase:
+/ecc:feature-dev "demo script"          # for new features
+/ecc:tdd                                # for bug fixes (write failing test, then fix)
+/ecc:plan "consolidate error handling"  # for refactors (plan first, then TDD)
+```
+For full multi-phase autonomous execution, use GSD:
+```
+/gsd:autonomous        # execute all remaining phases
+/gsd:execute-phase 6   # execute a specific phase
 ```
 ### Step 3: Module Independence (CRITICAL)
@@ -171,10 +178,10 @@ After all development and testing, run verification in this exact order:
 ```
 # 1. Run the verification skill -- must pass
-/everything-claude-code:verify
+/ecc:verify
 # 2. Verify the checkpoint -- validates all phase deliverables
-/everything-claude-code:checkpoint verify "phase-name"
+/ecc:checkpoint verify "phase-name"
 ```
 The checkpoint verify validates:
@@ -222,11 +229,11 @@ git push origin main --tags
 All four markers must be consistent. If any is missed, the next phase's Step 0 regression gate will catch the discrepancy.
 A checkpoint includes:
-- `/everything-claude-code:checkpoint create` at phase start
-- `/everything-claude-code:checkpoint verify` at phase end
+- `/ecc:checkpoint create` at phase start
+- `/ecc:checkpoint verify` at phase end
 - All tests passing (80%+ coverage)
 - Phase dev log written and linked
-- `/everything-claude-code:verify` passed
+- `/ecc:verify` passed
 - Git tag `checkpoint/phase-{N}` created
 - Phase marked COMPLETED in four locations
 - Branch merged to main
@@ -264,7 +271,7 @@ This project inherits from `~/.claude/rules/`. CLAUDE.md only contains project-s
 ### Hooks (ECC Plugin -- No Custom Hooks)
-All hooks come from the ECC plugin (`everything-claude-code`). No project-level hooks in `.claude/settings.local.json`.
+All hooks come from the ECC plugin (`ecc`). No project-level hooks in `.claude/settings.local.json`.
 | ECC Hook | Type | What It Does |
 |----------|------|-------------|
@@ -290,7 +297,7 @@ Controlled by `ECC_HOOK_PROFILE` env var in `~/.claude/settings.json` (currently
 - Architecture doc: `docs/ARCHITECTURE.md`
 - Phase dev logs: `docs/phases/phase-{N}-dev-log.md`
 - Test command: `pytest --cov=app --cov-report=term-missing`
-- **Phase start:** `/everything-claude-code:checkpoint create "phase-name"`
-- **Phase end:** `/everything-claude-code:checkpoint verify "phase-name"`
-- Verify command: `/everything-claude-code:verify`
-- Orchestrate: `/everything-claude-code:orchestrate {feature|bugfix|refactor}`
+- **Phase start:** `/ecc:checkpoint create "phase-name"`
+- **Phase end:** `/ecc:checkpoint verify "phase-name"`
+- Verify command: `/ecc:verify`
+- Orchestrate: `/ecc:orchestrate {feature|bugfix|refactor}`

View File

@@ -99,8 +99,12 @@ smart-support/
 ├── backend/
 │   ├── app/
 │   │   ├── main.py            # FastAPI + WebSocket entry point
-│   │   ├── graph.py           # LangGraph Supervisor
+│   │   ├── graph.py           # LangGraph Supervisor construction
+│   │   ├── graph_context.py   # Typed wrapper for graph + classifier + registry
 │   │   ├── ws_handler.py      # WebSocket message dispatch + rate limiting
+│   │   ├── ws_context.py      # WebSocket dependency bundle
+│   │   ├── auth.py            # API key authentication middleware
+│   │   ├── api_utils.py       # Shared API response helpers
 │   │   ├── safety.py          # Confirmation rules + MCP error taxonomy
 │   │   ├── agents/            # Agent definitions and tools
 │   │   ├── registry.py        # YAML agent registry loader
@@ -124,18 +128,21 @@ smart-support/
 ## API Endpoints
-| Method | Path | Description |
-|--------|------|-------------|
-| WS | `/ws` | Main WebSocket chat endpoint |
-| GET | `/api/health` | Health check |
-| GET | `/api/conversations` | List conversations (paginated) |
-| GET | `/api/replay/{thread_id}` | Replay conversation steps (paginated) |
-| GET | `/api/analytics` | Analytics summary (`?range=7d`) |
-| POST | `/api/openapi/import` | Start OpenAPI import job |
-| GET | `/api/openapi/jobs/{id}` | Check import job status |
-| GET | `/api/openapi/jobs/{id}/classifications` | Get endpoint classifications |
-| PUT | `/api/openapi/jobs/{id}/classifications/{idx}` | Update a classification |
-| POST | `/api/openapi/jobs/{id}/approve` | Approve and generate tools |
+| Method | Path | Auth | Description |
+|--------|------|------|-------------|
+| WS | `/ws` | Token | Main WebSocket chat endpoint (`?token=<key>`) |
+| GET | `/api/health` | No | Health check |
+| GET | `/api/conversations` | API Key | List conversations (paginated) |
+| GET | `/api/replay/{thread_id}` | API Key | Replay conversation steps (paginated) |
+| GET | `/api/analytics` | API Key | Analytics summary (`?range=7d`) |
+| POST | `/api/openapi/import` | API Key | Start OpenAPI import job |
+| GET | `/api/openapi/jobs/{id}` | API Key | Check import job status |
+| GET | `/api/openapi/jobs/{id}/classifications` | API Key | Get endpoint classifications |
+| PUT | `/api/openapi/jobs/{id}/classifications/{idx}` | API Key | Update a classification |
+| POST | `/api/openapi/jobs/{id}/approve` | API Key | Approve and generate tools |
+Authentication is controlled by the `ADMIN_API_KEY` environment variable.
+API Key endpoints require the `X-API-Key` header. When `ADMIN_API_KEY` is unset, auth is disabled.
 ## Running Tests
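For illustration, a client call against one of the API-Key endpoints might look like the sketch below (assumptions: a local server on port 8000 and a placeholder key value; the path and header name come from the table above, and `httpx` is already a project dependency):

```python
import httpx

# Placeholder key -- must match the server's ADMIN_API_KEY.
resp = httpx.get(
    "http://localhost:8000/api/conversations",
    headers={"X-API-Key": "change-me"},
    params={"page": 1},
)
resp.raise_for_status()
print(resp.json())  # standard envelope: {"success": ..., "data": ..., "error": ...}
```

If `ADMIN_API_KEY` is left empty on the server, the same request succeeds without the header.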

backend/alembic.ini (new file, 149 lines)

@@ -0,0 +1,149 @@
# A generic, single database configuration.
[alembic]
# path to migration scripts.
# this is typically a path given in POSIX (e.g. forward slashes)
# format, relative to the token %(here)s which refers to the location of this
# ini file
script_location = %(here)s/alembic
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
# for all available tokens
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
# Or organize into date-based subdirectories (requires recursive_version_locations = true)
# file_template = %%(year)d/%%(month).2d/%%(day).2d_%%(hour).2d%%(minute).2d_%%(second).2d_%%(rev)s_%%(slug)s
# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory. for multiple paths, the path separator
# is defined by "path_separator" below.
prepend_sys_path = .
# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires the tzdata library which can be installed by adding
# `alembic[tz]` to the pip requirements.
# string value is passed to ZoneInfo()
# leave blank for localtime
# timezone =
# max length of characters to apply to the "slug" field
# truncate_slug_length = 40
# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false
# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false
# version location specification; This defaults
# to <script_location>/versions. When using multiple version
# directories, initial revisions must be specified with --version-path.
# The path separator used here should be the separator specified by "path_separator"
# below.
# version_locations = %(here)s/bar:%(here)s/bat:%(here)s/alembic/versions
# path_separator; This indicates what character is used to split lists of file
# paths, including version_locations and prepend_sys_path within configparser
# files such as alembic.ini.
# The default rendered in new alembic.ini files is "os", which uses os.pathsep
# to provide os-dependent path splitting.
#
# Note that in order to support legacy alembic.ini files, this default does NOT
# take place if path_separator is not present in alembic.ini. If this
# option is omitted entirely, fallback logic is as follows:
#
# 1. Parsing of the version_locations option falls back to using the legacy
# "version_path_separator" key, which if absent then falls back to the legacy
# behavior of splitting on spaces and/or commas.
# 2. Parsing of the prepend_sys_path option falls back to the legacy
# behavior of splitting on spaces, commas, or colons.
#
# Valid values for path_separator are:
#
# path_separator = :
# path_separator = ;
# path_separator = space
# path_separator = newline
#
# Use os.pathsep. Default configuration used for new projects.
path_separator = os
# set to 'true' to search source files recursively
# in each "version_locations" directory
# new in Alembic version 1.10
# recursive_version_locations = false
# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8
# database URL. This is consumed by the user-maintained env.py script only.
# other means of configuring database URLs may be customized within the env.py
# file.
sqlalchemy.url =
[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples
# format using "black" - use the console_scripts runner, against the "black" entrypoint
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME
# lint with attempts to fix using "ruff" - use the module runner, against the "ruff" module
# hooks = ruff
# ruff.type = module
# ruff.module = ruff
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
# Alternatively, use the exec runner to execute a binary found on your PATH
# hooks = ruff
# ruff.type = exec
# ruff.executable = ruff
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
# Logging configuration. This is also consumed by the user-maintained
# env.py script only.
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARNING
handlers = console
qualname =
[logger_sqlalchemy]
level = WARNING
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S

backend/alembic/README (new file, 1 line)

@@ -0,0 +1 @@
Generic single-database configuration.

backend/alembic/env.py (new file, 67 lines)

@@ -0,0 +1,67 @@
"""Alembic environment configuration for smart-support."""
from __future__ import annotations
import os
from logging.config import fileConfig
from sqlalchemy import engine_from_config, pool
from alembic import context
config = context.config
if config.config_file_name is not None:
fileConfig(config.config_file_name)
# No SQLAlchemy ORM models -- we use raw DDL migrations
target_metadata = None
def _get_url() -> str:
"""Read DATABASE_URL from environment, falling back to alembic.ini."""
return os.environ.get("DATABASE_URL", "") or config.get_main_option(
"sqlalchemy.url", ""
)
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode.
Configures the context with just a URL so that an Engine
is not required.
"""
url = _get_url()
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
context.run_migrations()
def run_migrations_online() -> None:
"""Run migrations in 'online' mode with a live database connection."""
configuration = config.get_section(config.config_ini_section, {})
configuration["sqlalchemy.url"] = _get_url()
connectable = engine_from_config(
configuration,
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
with connectable.connect() as connection:
context.configure(connection=connection, target_metadata=target_metadata)
with context.begin_transaction():
context.run_migrations()
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()
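Offline mode above exists so migrations can be rendered as SQL without a live database. A minimal sketch of driving it programmatically (placeholder connection URL; assumes it is run from `backend/`, where `alembic.ini` lives):

```python
from alembic import command
from alembic.config import Config

cfg = Config("alembic.ini")
cfg.set_main_option("sqlalchemy.url", "postgresql://app:app@localhost/smart_support")  # placeholder

# sql=True routes through run_migrations_offline() and emits the DDL
# instead of executing it against a database.
command.upgrade(cfg, "head", sql=True)
```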

View File

@@ -0,0 +1,28 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, Sequence[str], None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
def upgrade() -> None:
"""Upgrade schema."""
${upgrades if upgrades else "pass"}
def downgrade() -> None:
"""Downgrade schema."""
${downgrades if downgrades else "pass"}

View File

@@ -0,0 +1,92 @@
"""Initial schema -- all application tables.
Revision ID: a1b2c3d4e5f6
Revises:
Create Date: 2026-04-06
"""
from __future__ import annotations
from alembic import op
revision: str = "a1b2c3d4e5f6"
down_revision: str | None = None
branch_labels: tuple[str, ...] | None = None
depends_on: tuple[str, ...] | None = None
def upgrade() -> None:
op.execute(
"""
CREATE TABLE IF NOT EXISTS conversations (
thread_id TEXT PRIMARY KEY,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
last_activity TIMESTAMPTZ NOT NULL DEFAULT NOW(),
total_tokens INTEGER NOT NULL DEFAULT 0,
total_cost_usd DOUBLE PRECISION NOT NULL DEFAULT 0.0,
status TEXT NOT NULL DEFAULT 'active'
)
"""
)
op.execute(
"""
CREATE TABLE IF NOT EXISTS active_interrupts (
interrupt_id TEXT PRIMARY KEY,
thread_id TEXT NOT NULL REFERENCES conversations(thread_id),
action TEXT NOT NULL,
params JSONB NOT NULL DEFAULT '{}',
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
resolved_at TIMESTAMPTZ,
resolution TEXT
)
"""
)
op.execute(
"""
CREATE TABLE IF NOT EXISTS sessions (
thread_id TEXT PRIMARY KEY,
last_activity TIMESTAMPTZ NOT NULL DEFAULT NOW(),
has_pending_interrupt BOOLEAN NOT NULL DEFAULT FALSE,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
)
"""
)
op.execute(
"""
CREATE TABLE IF NOT EXISTS analytics_events (
id BIGSERIAL PRIMARY KEY,
thread_id TEXT NOT NULL,
event_type TEXT NOT NULL,
agent_name TEXT,
tool_name TEXT,
tokens_used INTEGER NOT NULL DEFAULT 0,
cost_usd DOUBLE PRECISION NOT NULL DEFAULT 0.0,
duration_ms INTEGER,
success BOOLEAN,
error_message TEXT,
metadata JSONB NOT NULL DEFAULT '{}',
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
)
"""
)
# Migration columns added in Phase 4
op.execute(
"""
ALTER TABLE conversations
ADD COLUMN IF NOT EXISTS resolution_type TEXT,
ADD COLUMN IF NOT EXISTS agents_used TEXT[],
ADD COLUMN IF NOT EXISTS turn_count INTEGER NOT NULL DEFAULT 0,
ADD COLUMN IF NOT EXISTS ended_at TIMESTAMPTZ
"""
)
def downgrade() -> None:
op.execute("DROP TABLE IF EXISTS analytics_events")
op.execute("DROP TABLE IF EXISTS sessions")
op.execute("DROP TABLE IF EXISTS active_interrupts")
op.execute("DROP TABLE IF EXISTS conversations")

View File

@@ -16,7 +16,7 @@ if TYPE_CHECKING:
 from psycopg_pool import AsyncConnectionPool
 router = APIRouter(
-    prefix="/api/analytics",
+    prefix="/api/v1/analytics",
     tags=["analytics"],
     dependencies=[Depends(require_admin_api_key)],
 )

View File

@@ -2,14 +2,14 @@
 from __future__ import annotations
-import logging
 import secrets
 from typing import Annotated
+import structlog
 from fastapi import Depends, HTTPException, Query, Request, WebSocket, status
 from fastapi.security import APIKeyHeader
-logger = logging.getLogger(__name__)
+logger = structlog.get_logger()
 _API_KEY_HEADER = APIKeyHeader(name="X-API-Key", auto_error=False)

View File

@@ -32,6 +32,8 @@ class Settings(BaseSettings):
     template_name: str = ""
+    log_format: str = "console"  # "console" for dev, "json" for production
     admin_api_key: str = ""
     anthropic_api_key: str = ""

View File

@@ -2,6 +2,7 @@
 from __future__ import annotations
+from pathlib import Path
 from typing import TYPE_CHECKING
 from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
@@ -88,6 +89,17 @@ async def create_checkpointer(pool: AsyncConnectionPool) -> AsyncPostgresSaver:
     return checkpointer
+def run_alembic_migrations(database_url: str) -> None:
+    """Run Alembic migrations to head."""
+    from alembic.config import Config
+    from alembic import command
+    alembic_cfg = Config(str(Path(__file__).parent.parent / "alembic.ini"))
+    alembic_cfg.set_main_option("sqlalchemy.url", database_url)
+    command.upgrade(alembic_cfg, "head")
 async def setup_app_tables(pool: AsyncConnectionPool) -> None:
     """Create application-specific tables and apply migrations."""
     async with pool.connection() as conn:
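A usage sketch of the new helper (placeholder connection URL; in the application it is invoked once from `lifespan()` at startup, as the `main.py` diff below shows):

```python
from app.db import run_alembic_migrations

# Applies every pending Alembic revision up to "head" before the app serves traffic.
run_alembic_migrations("postgresql://app:app@localhost/smart_support")  # placeholder URL
```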

View File

@@ -3,14 +3,14 @@
 from __future__ import annotations
 import asyncio
-import logging
 from dataclasses import dataclass
 from typing import Protocol
 import httpx
+import structlog
 from pydantic import BaseModel
-logger = logging.getLogger(__name__)
+logger = structlog.get_logger()
 class EscalationPayload(BaseModel, frozen=True):
class EscalationPayload(BaseModel, frozen=True): class EscalationPayload(BaseModel, frozen=True):

View File

@@ -2,7 +2,6 @@
 from __future__ import annotations
-import logging
 from typing import TYPE_CHECKING
 from langchain.agents import create_agent
@@ -18,7 +17,9 @@ if TYPE_CHECKING:
 from app.intent import IntentClassifier
 from app.registry import AgentRegistry
-logger = logging.getLogger(__name__)
+import structlog
+logger = structlog.get_logger()
 SUPERVISOR_PROMPT = (
     "You are a customer support supervisor. "

View File

@@ -2,7 +2,6 @@
 from __future__ import annotations
-import logging
 from typing import TYPE_CHECKING, Protocol
 from pydantic import BaseModel
@@ -12,7 +11,9 @@ if TYPE_CHECKING:
 from app.registry import AgentConfig
-logger = logging.getLogger(__name__)
+import structlog
+logger = structlog.get_logger()
 CLASSIFICATION_PROMPT = (
     "You are an intent classifier for a customer support system.\n"
"You are an intent classifier for a customer support system.\n" "You are an intent classifier for a customer support system.\n"

View File

@@ -0,0 +1,57 @@
"""Structured logging configuration using structlog."""
from __future__ import annotations
import logging
import sys
import structlog
def configure_logging(log_format: str = "console") -> None:
"""Configure structlog with stdlib integration.
Args:
log_format: "console" for human-readable dev output,
"json" for machine-parseable production output.
"""
shared_processors: list[structlog.types.Processor] = [
structlog.contextvars.merge_contextvars,
structlog.stdlib.filter_by_level,
structlog.stdlib.add_logger_name,
structlog.stdlib.add_log_level,
structlog.processors.TimeStamper(fmt="iso"),
structlog.processors.StackInfoRenderer(),
structlog.processors.format_exc_info,
structlog.processors.UnicodeDecoder(),
]
if log_format == "json":
renderer: structlog.types.Processor = structlog.processors.JSONRenderer()
else:
renderer = structlog.dev.ConsoleRenderer()
structlog.configure(
processors=[
*shared_processors,
structlog.stdlib.ProcessorFormatter.wrap_for_formatter,
],
logger_factory=structlog.stdlib.LoggerFactory(),
wrapper_class=structlog.stdlib.BoundLogger,
cache_logger_on_first_use=True,
)
formatter = structlog.stdlib.ProcessorFormatter(
processors=[
structlog.stdlib.ProcessorFormatter.remove_processors_meta,
renderer,
],
)
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(formatter)
root_logger = logging.getLogger()
root_logger.handlers.clear()
root_logger.addHandler(handler)
root_logger.setLevel(logging.INFO)
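A quick usage sketch of the configured logger; the event name and fields here are illustrative, but the calls are standard structlog API:

```python
import structlog

from app.logging_config import configure_logging

configure_logging("json")  # or "console" for local development
logger = structlog.get_logger()

# Bound key-value pairs become fields in the rendered output, e.g. in JSON mode:
# {"event": "interrupt_expired", "thread_id": "t-123", "count": 2, "level": "info", ...}
logger.info("interrupt_expired", thread_id="t-123", count=2)
```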

View File

@@ -2,25 +2,30 @@
 from __future__ import annotations
-import logging
+import asyncio
+import contextlib
 from contextlib import asynccontextmanager
 from pathlib import Path
 from typing import TYPE_CHECKING
-from fastapi import Depends, FastAPI, Query, WebSocket, WebSocketDisconnect
+from fastapi import FastAPI, HTTPException, Query, WebSocket, WebSocketDisconnect
+from fastapi.exceptions import RequestValidationError
+from fastapi.responses import JSONResponse
 from fastapi.staticfiles import StaticFiles
 from app.analytics.api import router as analytics_router
 from app.analytics.event_recorder import PostgresAnalyticsRecorder
+from app.api_utils import envelope
 from app.callbacks import TokenUsageCallbackHandler
 from app.config import Settings
 from app.conversation_tracker import PostgresConversationTracker
-from app.db import create_checkpointer, create_pool, setup_app_tables
+from app.db import create_checkpointer, create_pool, run_alembic_migrations
 from app.escalation import NoOpEscalator, WebhookEscalator
 from app.graph import build_graph
 from app.intent import LLMIntentClassifier
 from app.interrupt_manager import InterruptManager
 from app.llm import create_llm
+from app.logging_config import configure_logging
 from app.openapi.review_api import router as openapi_router
 from app.registry import AgentRegistry
 from app.replay.api import router as replay_router
@@ -31,19 +36,44 @@ from app.ws_handler import dispatch_message
 if TYPE_CHECKING:
     from collections.abc import AsyncGenerator
-logger = logging.getLogger(__name__)
+import structlog
+logger = structlog.get_logger()
 AGENTS_YAML = Path(__file__).parent.parent / "agents.yaml"
 FRONTEND_DIST = Path(__file__).parent.parent.parent / "frontend" / "dist"
+async def _interrupt_cleanup_loop(
+    interrupt_manager: InterruptManager,
+    interval: int = 60,
+) -> None:
+    """Periodically remove expired interrupts in the background.
+    Runs until cancelled. Catches all exceptions to prevent the task
+    from dying unexpectedly.
+    """
+    while True:
+        await asyncio.sleep(interval)
+        try:
+            expired = interrupt_manager.cleanup_expired()
+            if expired:
+                logger.info(
+                    "Cleaned up %d expired interrupt(s)",
+                    len(expired),
+                )
+        except Exception:
+            logger.exception("Error during interrupt cleanup")
 @asynccontextmanager
 async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
     settings = Settings()
+    configure_logging(settings.log_format)
     pool = await create_pool(settings)
     checkpointer = await create_checkpointer(pool)
-    await setup_app_tables(pool)
+    run_alembic_migrations(settings.database_url)
     # Load agents from template or default YAML
     if settings.template_name:
@@ -89,8 +119,16 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
         settings.template_name or "(default)",
     )
+    cleanup_task = asyncio.create_task(
+        _interrupt_cleanup_loop(interrupt_manager),
+    )
     yield
+    cleanup_task.cancel()
+    with contextlib.suppress(asyncio.CancelledError):
+        await cleanup_task
     await pool.close()
@@ -103,7 +141,35 @@ app.include_router(replay_router)
 app.include_router(analytics_router)
-@app.get("/api/health")
+@app.exception_handler(HTTPException)
+async def http_exception_handler(request, exc):  # type: ignore[no-untyped-def]
+    """Wrap HTTPException in standard envelope format."""
+    return JSONResponse(
+        status_code=exc.status_code,
+        content=envelope(None, success=False, error=exc.detail),
+    )
+@app.exception_handler(RequestValidationError)
+async def validation_exception_handler(request, exc):  # type: ignore[no-untyped-def]
+    """Wrap validation errors in standard envelope format."""
+    return JSONResponse(
+        status_code=422,
+        content=envelope(None, success=False, error=str(exc)),
+    )
+@app.exception_handler(Exception)
+async def general_exception_handler(request, exc):  # type: ignore[no-untyped-def]
+    """Catch-all handler -- never leak stack traces."""
+    logger.exception("Unhandled exception: %s", exc)
+    return JSONResponse(
+        status_code=500,
+        content=envelope(None, success=False, error="Internal server error"),
+    )
+@app.get("/api/v1/health")
 def health_check() -> dict:
     """Health check endpoint for load balancers and monitoring."""
     return {"status": "ok", "version": _VERSION}

View File

@@ -8,13 +8,14 @@ classifier and an LLM-backed classifier with heuristic fallback.
 from __future__ import annotations
 import json
-import logging
 import re
 from typing import Protocol
+import structlog
 from app.openapi.models import ClassificationResult, EndpointInfo
-logger = logging.getLogger(__name__)
+logger = structlog.get_logger()
 _WRITE_METHODS = frozenset({"POST", "PUT", "PATCH", "DELETE"})
 _INTERRUPT_METHODS = frozenset({"POST", "PUT", "PATCH", "DELETE"})

View File

@@ -6,10 +6,11 @@ Each stage updates the job status and calls the on_progress callback.
 from __future__ import annotations
-import logging
 from collections.abc import Callable
 from dataclasses import replace
+import structlog
 from app.openapi.classifier import ClassifierProtocol, HeuristicClassifier
 from app.openapi.fetcher import fetch_spec
 from app.openapi.models import ImportJob
@@ -17,7 +18,7 @@ from app.openapi.parser import parse_endpoints
 from app.openapi.ssrf import DEFAULT_POLICY, SSRFPolicy
 from app.openapi.validator import validate_spec
-logger = logging.getLogger(__name__)
+logger = structlog.get_logger()
 ProgressCallback = Callable[[str, ImportJob], None] | None

View File

@@ -10,11 +10,11 @@ Exposes endpoints for:
 from __future__ import annotations
 import asyncio
-import logging
 import re
 import uuid
 from typing import Literal
+import structlog
 from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
 from pydantic import BaseModel, field_validator
@@ -23,10 +23,10 @@ from app.openapi.generator import generate_agent_yaml, generate_tool_code
 from app.openapi.importer import ImportOrchestrator
 from app.openapi.models import ClassificationResult, ImportJob
-logger = logging.getLogger(__name__)
+logger = structlog.get_logger()
 router = APIRouter(
-    prefix="/api/openapi",
+    prefix="/api/v1/openapi",
     tags=["openapi"],
     dependencies=[Depends(require_admin_api_key)],
 )

View File

@@ -16,7 +16,7 @@ if TYPE_CHECKING:
 from psycopg_pool import AsyncConnectionPool
 router = APIRouter(
-    prefix="/api",
+    prefix="/api/v1",
     tags=["replay"],
     dependencies=[Depends(require_admin_api_key)],
 )

View File

@@ -2,11 +2,11 @@
 from __future__ import annotations
-import logging
+import structlog
 from app.replay.models import ReplayStep, StepType
-logger = logging.getLogger(__name__)
+logger = structlog.get_logger()
 _EMPTY_TIMESTAMP = "1970-01-01T00:00:00Z"

View File

@@ -3,7 +3,6 @@
 from __future__ import annotations
 import json
-import logging
 import re
 import time
 from collections import defaultdict
@@ -21,7 +20,9 @@ if TYPE_CHECKING:
 from app.session_manager import SessionManager
 from app.ws_context import WebSocketContext
-logger = logging.getLogger(__name__)
+import structlog
+logger = structlog.get_logger()
 MAX_MESSAGE_SIZE = 32_768  # 32 KB
 MAX_CONTENT_LENGTH = 10_000  # characters

View File

@@ -21,6 +21,8 @@ dependencies = [
"python-dotenv>=1.0,<2.0", "python-dotenv>=1.0,<2.0",
"httpx>=0.28,<1.0", "httpx>=0.28,<1.0",
"openapi-spec-validator>=0.7,<1.0", "openapi-spec-validator>=0.7,<1.0",
"alembic>=1.13,<2.0",
"structlog>=24.0,<26.0",
] ]
[project.optional-dependencies] [project.optional-dependencies]

View File

@@ -174,7 +174,7 @@ def create_e2e_app(
     app.state.analytics_recorder = AsyncMock()
     app.state.conversation_tracker = AsyncMock()
-    @app.get("/api/health")
+    @app.get("/api/v1/health")
     def health_check() -> dict:
         return {"status": "ok", "version": "test"}

View File

@@ -341,7 +341,7 @@ class TestChatEdgeCases:
     def test_health_endpoint(self) -> None:
         app = create_e2e_app()
         with TestClient(app) as client:
-            resp = client.get("/api/health")
+            resp = client.get("/api/v1/health")
             assert resp.status_code == 200
             assert resp.json()["status"] == "ok"

View File

@@ -62,7 +62,7 @@ class TestFlow5OpenAPIImport:
         with TestClient(app) as client:
             # Step 1: Start import job
             resp = client.post(
-                "/api/openapi/import",
+                "/api/v1/openapi/import",
                 json={"url": "https://api.example.com/openapi.json"},
             )
             assert resp.status_code == 202
@@ -71,7 +71,7 @@ class TestFlow5OpenAPIImport:
             job_id = body["job_id"]
             # Step 2: Check job status (still pending since background task hasn't run)
-            resp = client.get(f"/api/openapi/jobs/{job_id}")
+            resp = client.get(f"/api/v1/openapi/jobs/{job_id}")
             assert resp.status_code == 200
             assert resp.json()["job_id"] == job_id
@@ -99,7 +99,7 @@ class TestFlow5OpenAPIImport:
         with TestClient(app) as client:
             # Step 1: Get classifications
-            resp = client.get(f"/api/openapi/jobs/{job_id}/classifications")
+            resp = client.get(f"/api/v1/openapi/jobs/{job_id}/classifications")
             assert resp.status_code == 200
             classifications = resp.json()
             assert len(classifications) == 2
@@ -118,7 +118,7 @@ class TestFlow5OpenAPIImport:
             # Step 2: Update a classification
             resp = client.put(
-                f"/api/openapi/jobs/{job_id}/classifications/0",
+                f"/api/v1/openapi/jobs/{job_id}/classifications/0",
                 json={
                     "access_type": "write",
                     "needs_interrupt": True,
@@ -132,7 +132,7 @@ class TestFlow5OpenAPIImport:
             assert updated["agent_group"] == "order_actions"
             # Step 3: Approve the job
-            resp = client.post(f"/api/openapi/jobs/{job_id}/approve")
+            resp = client.post(f"/api/v1/openapi/jobs/{job_id}/approve")
             assert resp.status_code == 200
             assert resp.json()["status"] == "approved"
@@ -140,14 +140,14 @@ class TestFlow5OpenAPIImport:
         app = create_e2e_app()
         with TestClient(app) as client:
-            resp = client.get("/api/openapi/jobs/nonexistent")
+            resp = client.get("/api/v1/openapi/jobs/nonexistent")
             assert resp.status_code == 404
     def test_import_invalid_url_returns_422(self) -> None:
         app = create_e2e_app()
         with TestClient(app) as client:
-            resp = client.post("/api/openapi/import", json={"url": "not-a-url"})
+            resp = client.post("/api/v1/openapi/import", json={"url": "not-a-url"})
             assert resp.status_code == 422
     def test_classification_index_out_of_range(self) -> None:
@@ -166,7 +166,7 @@ class TestFlow5OpenAPIImport:
         with TestClient(app) as client:
             resp = client.put(
-                f"/api/openapi/jobs/{job_id}/classifications/99",
+                f"/api/v1/openapi/jobs/{job_id}/classifications/99",
                 json={
                     "access_type": "read",
                     "needs_interrupt": False,
@@ -191,7 +191,7 @@ class TestFlow5OpenAPIImport:
         with TestClient(app) as client:
             resp = client.put(
-                f"/api/openapi/jobs/{job_id}/classifications/0",
+                f"/api/v1/openapi/jobs/{job_id}/classifications/0",
                 json={
                     "access_type": "read",
                     "needs_interrupt": False,

View File

@@ -98,7 +98,7 @@ class TestFlow6ReplayConversation:
         app = create_e2e_app(pool=pool)
         with TestClient(app) as client:
-            resp = client.get("/api/conversations")
+            resp = client.get("/api/v1/conversations")
             assert resp.status_code == 200
             body = resp.json()
             assert body["success"] is True
@@ -124,7 +124,7 @@ class TestFlow6ReplayConversation:
         app = create_e2e_app(pool=pool)
         with TestClient(app) as client:
-            resp = client.get("/api/conversations", params={"page": 1, "per_page": 2})
+            resp = client.get("/api/v1/conversations", params={"page": 1, "per_page": 2})
             assert resp.status_code == 200
             body = resp.json()
             assert body["success"] is True
@@ -139,7 +139,7 @@ class TestFlow6ReplayConversation:
         app = create_e2e_app(pool=pool)
         with TestClient(app) as client:
-            resp = client.get("/api/replay/nonexistent-thread")
+            resp = client.get("/api/v1/replay/nonexistent-thread")
             assert resp.status_code == 404
     def test_replay_invalid_thread_id_format(self) -> None:
@@ -147,7 +147,7 @@ class TestFlow6ReplayConversation:
         with TestClient(app) as client:
             # Thread ID with special chars fails regex validation
-            resp = client.get("/api/replay/invalid%20thread%21%40")
+            resp = client.get("/api/v1/replay/invalid%20thread%21%40")
             assert resp.status_code == 400
@@ -158,21 +158,21 @@ class TestAnalyticsDashboard:
         app = create_e2e_app()
         with TestClient(app) as client:
-            resp = client.get("/api/analytics", params={"range": "invalid"})
+            resp = client.get("/api/v1/analytics", params={"range": "invalid"})
             assert resp.status_code == 400
     def test_analytics_range_too_large(self) -> None:
         app = create_e2e_app()
         with TestClient(app) as client:
-            resp = client.get("/api/analytics", params={"range": "999d"})
+            resp = client.get("/api/v1/analytics", params={"range": "999d"})
             assert resp.status_code == 400
     def test_analytics_range_zero_rejected(self) -> None:
         app = create_e2e_app()
         with TestClient(app) as client:
-            resp = client.get("/api/analytics", params={"range": "0d"})
+            resp = client.get("/api/v1/analytics", params={"range": "0d"})
             assert resp.status_code == 400
@@ -216,7 +216,7 @@ class TestFullUserJourney:
assert any(m["type"] == "message_complete" for m in messages) assert any(m["type"] == "message_complete" for m in messages)
# Step 2: Check conversations endpoint # Step 2: Check conversations endpoint
resp = client.get("/api/conversations") resp = client.get("/api/v1/conversations")
assert resp.status_code == 200 assert resp.status_code == 200
body = resp.json() body = resp.json()
assert body["success"] is True assert body["success"] is True
@@ -226,5 +226,5 @@ class TestFullUserJourney:
             )
             # Step 3: Health check still works
-            resp = client.get("/api/health")
+            resp = client.get("/api/v1/health")
             assert resp.status_code == 200

View File

@@ -0,0 +1,183 @@
"""Integration tests for the /api/v1/analytics endpoint.
Tests the full API layer (routing, parameter validation, serialization,
error handling) with a mocked database pool.
"""
from __future__ import annotations
from dataclasses import asdict
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from httpx import ASGITransport, AsyncClient
from app.analytics.models import AnalyticsResult, InterruptStats
pytestmark = pytest.mark.integration
_SAMPLE_RESULT = AnalyticsResult(
range="7d",
total_conversations=42,
resolution_rate=0.85,
escalation_rate=0.05,
avg_turns_per_conversation=3.2,
avg_cost_per_conversation_usd=0.012,
agent_usage=(),
interrupt_stats=InterruptStats(total=10, approved=7, rejected=2, expired=1),
)
def _build_app():
"""Build a minimal FastAPI app with the analytics router and mocked deps."""
from fastapi import FastAPI
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from app.analytics.api import router as analytics_router
from app.api_utils import envelope
test_app = FastAPI()
test_app.include_router(analytics_router)
@test_app.exception_handler(Exception)
async def _catch_all(request, exc):
return JSONResponse(
status_code=500,
content=envelope(None, success=False, error="Internal server error"),
)
from fastapi import HTTPException
@test_app.exception_handler(HTTPException)
async def _http_exc(request, exc):
return JSONResponse(
status_code=exc.status_code,
content=envelope(None, success=False, error=exc.detail),
)
@test_app.exception_handler(RequestValidationError)
async def _validation_exc(request, exc):
return JSONResponse(
status_code=422,
content=envelope(None, success=False, error=str(exc)),
)
# No admin_api_key set -> auth is skipped
test_app.state.settings = MagicMock(admin_api_key="")
test_app.state.pool = MagicMock()
return test_app
class TestAnalyticsValidRange:
"""Test analytics endpoint with valid range parameters."""
async def test_valid_range_7d_returns_envelope(self) -> None:
"""GET /api/v1/analytics?range=7d returns success envelope with data."""
test_app = _build_app()
with patch(
"app.analytics.api.get_analytics",
new_callable=AsyncMock,
return_value=_SAMPLE_RESULT,
):
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
resp = await client.get("/api/v1/analytics", params={"range": "7d"})
assert resp.status_code == 200
body = resp.json()
assert body["success"] is True
assert body["error"] is None
assert body["data"]["total_conversations"] == 42
assert body["data"]["resolution_rate"] == 0.85
async def test_default_range_returns_success(self) -> None:
"""GET /api/v1/analytics with no range param defaults to 7d."""
test_app = _build_app()
with patch(
"app.analytics.api.get_analytics",
new_callable=AsyncMock,
return_value=_SAMPLE_RESULT,
) as mock_get:
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
resp = await client.get("/api/v1/analytics")
assert resp.status_code == 200
        # Verify the default range of 7 days was passed (positionally or as a kwarg)
        mock_get.assert_called_once()
        args, kwargs = mock_get.call_args
        assert kwargs.get("range_days", args[1] if len(args) > 1 else None) == 7
    async def test_large_range_365d_works(self) -> None:
        """GET /api/v1/analytics?range=365d is accepted (max boundary)."""
        test_app = _build_app()
        result = AnalyticsResult(
            range="365d",
            total_conversations=1000,
            resolution_rate=0.9,
            escalation_rate=0.02,
            avg_turns_per_conversation=4.0,
            avg_cost_per_conversation_usd=0.01,
            agent_usage=(),
            interrupt_stats=InterruptStats(),
        )
        with patch(
            "app.analytics.api.get_analytics",
            new_callable=AsyncMock,
            return_value=result,
        ):
            async with AsyncClient(
                transport=ASGITransport(app=test_app), base_url="http://test"
            ) as client:
                resp = await client.get("/api/v1/analytics", params={"range": "365d"})

        assert resp.status_code == 200
        assert resp.json()["success"] is True


class TestAnalyticsInvalidRange:
    """Test analytics endpoint with invalid range parameters."""

    async def test_invalid_range_format_returns_400(self) -> None:
        """GET /api/v1/analytics?range=abc returns 400 error envelope."""
        test_app = _build_app()
        async with AsyncClient(
            transport=ASGITransport(app=test_app), base_url="http://test"
        ) as client:
            resp = await client.get("/api/v1/analytics", params={"range": "abc"})

        assert resp.status_code == 400
        body = resp.json()
        assert body["success"] is False
        assert body["data"] is None
        assert "Invalid range format" in body["error"]

    async def test_zero_day_range_returns_400(self) -> None:
        """GET /api/v1/analytics?range=0d returns 400 because 0 is below minimum."""
        test_app = _build_app()
        async with AsyncClient(
            transport=ASGITransport(app=test_app), base_url="http://test"
        ) as client:
            resp = await client.get("/api/v1/analytics", params={"range": "0d"})

        assert resp.status_code == 400
        body = resp.json()
        assert body["success"] is False
        assert "between 1 and 365" in body["error"]

    async def test_range_exceeding_max_returns_400(self) -> None:
        """GET /api/v1/analytics?range=999d returns 400 because it exceeds 365."""
        test_app = _build_app()
        async with AsyncClient(
            transport=ASGITransport(app=test_app), base_url="http://test"
        ) as client:
            resp = await client.get("/api/v1/analytics", params={"range": "999d"})

        assert resp.status_code == 400
        body = resp.json()
        assert body["success"] is False
        assert "between 1 and 365" in body["error"]

View File

@@ -0,0 +1,128 @@
"""Integration tests for global error handling and envelope format consistency.
Tests that all error responses from the FastAPI app conform to the
standard envelope: {"success": false, "data": null, "error": "..."}.
"""
from __future__ import annotations
from unittest.mock import MagicMock
import pytest
from httpx import ASGITransport, AsyncClient
pytestmark = pytest.mark.integration
def _build_app():
"""Build the actual FastAPI app with exception handlers but mocked state."""
from fastapi import FastAPI, HTTPException
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from app.analytics.api import router as analytics_router
from app.api_utils import envelope
from app.replay.api import router as replay_router
test_app = FastAPI()
test_app.include_router(analytics_router)
test_app.include_router(replay_router)
@test_app.exception_handler(HTTPException)
async def _http_exc(request, exc):
return JSONResponse(
status_code=exc.status_code,
content=envelope(None, success=False, error=exc.detail),
)
@test_app.exception_handler(RequestValidationError)
async def _validation_exc(request, exc):
return JSONResponse(
status_code=422,
content=envelope(None, success=False, error=str(exc)),
)
@test_app.exception_handler(Exception)
async def _catch_all(request, exc):
return JSONResponse(
status_code=500,
content=envelope(None, success=False, error="Internal server error"),
)
@test_app.get("/api/v1/health")
def health_check():
return {"status": "ok", "version": "0.6.0"}
test_app.state.settings = MagicMock(admin_api_key="")
test_app.state.pool = MagicMock()
return test_app
class TestEnvelopeFormat:
"""Tests that error responses consistently follow envelope format."""
async def test_http_400_produces_envelope(self) -> None:
"""A 400 error returns standard envelope with success=false."""
test_app = _build_app()
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
resp = await client.get("/api/v1/analytics", params={"range": "invalid"})
assert resp.status_code == 400
body = resp.json()
assert body["success"] is False
assert body["data"] is None
assert isinstance(body["error"], str)
assert len(body["error"]) > 0
async def test_validation_error_produces_422_envelope(self) -> None:
"""Invalid query param type returns 422 with envelope format."""
test_app = _build_app()
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
# page must be >= 1; passing 0 triggers validation error
resp = await client.get("/api/v1/conversations", params={"page": 0})
assert resp.status_code == 422
body = resp.json()
assert body["success"] is False
assert body["data"] is None
assert isinstance(body["error"], str)
async def test_all_error_fields_present(self) -> None:
"""Error envelope contains exactly success, data, and error keys."""
test_app = _build_app()
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
resp = await client.get("/api/v1/analytics", params={"range": "bad"})
body = resp.json()
assert set(body.keys()) == {"success", "data", "error"}
async def test_health_endpoint_returns_200(self) -> None:
"""Health check returns 200 with status ok."""
test_app = _build_app()
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
resp = await client.get("/api/v1/health")
assert resp.status_code == 200
body = resp.json()
assert body["status"] == "ok"
assert "version" in body
async def test_unknown_endpoint_returns_404(self) -> None:
"""Requesting a non-existent path returns 404."""
test_app = _build_app()
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
resp = await client.get("/api/v1/nonexistent-path")
# FastAPI returns 404 for unknown routes; may or may not be wrapped
assert resp.status_code == 404

View File

@@ -0,0 +1,164 @@
"""Integration tests for /api/v1/openapi/ endpoints.
Tests the full API layer for the OpenAPI import review workflow,
including job creation, status retrieval, classification updates,
and approval triggering.
"""
from __future__ import annotations
from unittest.mock import MagicMock, patch
import pytest
from httpx import ASGITransport, AsyncClient
pytestmark = pytest.mark.integration
def _build_app():
"""Build a minimal FastAPI app with the openapi router and mocked deps."""
from fastapi import FastAPI, HTTPException
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from app.api_utils import envelope
from app.openapi.review_api import router as openapi_router
test_app = FastAPI()
test_app.include_router(openapi_router)
@test_app.exception_handler(HTTPException)
async def _http_exc(request, exc):
return JSONResponse(
status_code=exc.status_code,
content=envelope(None, success=False, error=exc.detail),
)
@test_app.exception_handler(RequestValidationError)
async def _validation_exc(request, exc):
return JSONResponse(
status_code=422,
content=envelope(None, success=False, error=str(exc)),
)
test_app.state.settings = MagicMock(admin_api_key="")
return test_app
@pytest.fixture(autouse=True)
def _clear_job_store():
"""Clear the in-memory job store between tests."""
from app.openapi.review_api import _job_store
_job_store.clear()
yield
_job_store.clear()
class TestImportEndpoint:
"""Tests for POST /api/v1/openapi/import."""
async def test_import_returns_202_with_job_id(self) -> None:
"""Starting an import returns 202 with a job_id."""
test_app = _build_app()
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
resp = await client.post(
"/api/v1/openapi/import",
json={"url": "https://example.com/api/spec.json"},
)
assert resp.status_code == 202
body = resp.json()
assert "job_id" in body
assert body["status"] == "pending"
assert body["spec_url"] == "https://example.com/api/spec.json"
async def test_import_invalid_url_returns_422(self) -> None:
"""POST with invalid URL (no http/https) returns 422."""
test_app = _build_app()
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
resp = await client.post(
"/api/v1/openapi/import",
json={"url": "ftp://example.com/spec.json"},
)
assert resp.status_code == 422
body = resp.json()
assert body["success"] is False
class TestJobStatusEndpoint:
"""Tests for GET /api/v1/openapi/jobs/{job_id}."""
async def test_get_existing_job_returns_status(self) -> None:
"""Retrieving an existing job returns its status."""
from app.openapi.review_api import _job_store
_job_store["test-job-1"] = {
"job_id": "test-job-1",
"status": "done",
"spec_url": "https://example.com/spec.json",
"total_endpoints": 5,
"classified_count": 5,
"error_message": None,
"classifications": [],
}
test_app = _build_app()
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
resp = await client.get("/api/v1/openapi/jobs/test-job-1")
assert resp.status_code == 200
body = resp.json()
assert body["job_id"] == "test-job-1"
assert body["status"] == "done"
assert body["total_endpoints"] == 5
async def test_get_unknown_job_returns_404(self) -> None:
"""Retrieving a non-existent job returns 404 error envelope."""
test_app = _build_app()
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
resp = await client.get("/api/v1/openapi/jobs/unknown-id-999")
assert resp.status_code == 404
body = resp.json()
assert body["success"] is False
assert "not found" in body["error"].lower()
class TestApproveEndpoint:
"""Tests for POST /api/v1/openapi/jobs/{job_id}/approve."""
async def test_approve_with_no_classifications_returns_400(self) -> None:
"""Approving a job with no classifications returns 400."""
from app.openapi.review_api import _job_store
_job_store["empty-job"] = {
"job_id": "empty-job",
"status": "done",
"spec_url": "https://example.com/spec.json",
"total_endpoints": 0,
"classified_count": 0,
"error_message": None,
"classifications": [],
}
test_app = _build_app()
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
resp = await client.post("/api/v1/openapi/jobs/empty-job/approve")
assert resp.status_code == 400
body = resp.json()
assert body["success"] is False
assert "no classifications" in body["error"].lower()

View File

@@ -0,0 +1,213 @@
"""Integration tests for /api/v1/conversations and /api/v1/replay/{thread_id}.
Tests the full API layer with a mocked database pool, verifying routing,
serialization, pagination, and error handling in envelope format.
"""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock
import pytest
from httpx import ASGITransport, AsyncClient
pytestmark = pytest.mark.integration
def _make_fake_cursor(rows, *, fetchone_value=None):
"""Build a fake async cursor returning the given rows on fetchall."""
cursor = AsyncMock()
cursor.fetchall = AsyncMock(return_value=rows)
if fetchone_value is not None:
cursor.fetchone = AsyncMock(return_value=fetchone_value)
return cursor
class _FakeConnection:
"""Fake async connection that returns pre-configured cursors in order."""
def __init__(self, cursors: list) -> None:
self._cursors = list(cursors)
self._idx = 0
async def execute(self, sql, params=None):
cursor = self._cursors[self._idx]
self._idx += 1
return cursor
async def __aenter__(self):
return self
async def __aexit__(self, *args):
pass
class _FakePool:
"""Fake connection pool that yields a fake connection."""
def __init__(self, conn: _FakeConnection) -> None:
self._conn = conn
def connection(self):
return self._conn
def _build_app(pool=None):
"""Build a minimal FastAPI app with the replay router and mocked deps."""
from fastapi import FastAPI, HTTPException
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from app.api_utils import envelope
from app.replay.api import router as replay_router
test_app = FastAPI()
test_app.include_router(replay_router)
@test_app.exception_handler(HTTPException)
async def _http_exc(request, exc):
return JSONResponse(
status_code=exc.status_code,
content=envelope(None, success=False, error=exc.detail),
)
@test_app.exception_handler(RequestValidationError)
async def _validation_exc(request, exc):
return JSONResponse(
status_code=422,
content=envelope(None, success=False, error=str(exc)),
)
test_app.state.settings = MagicMock(admin_api_key="")
test_app.state.pool = pool or MagicMock()
return test_app
class TestListConversations:
"""Tests for GET /api/v1/conversations endpoint."""
async def test_returns_paginated_envelope(self) -> None:
"""Conversations list returns envelope with pagination metadata."""
count_cursor = _make_fake_cursor([], fetchone_value=(3,))
rows = [
{"thread_id": "t1", "created_at": "2026-01-01", "last_activity": "2026-01-01",
"status": "active", "total_tokens": 100, "total_cost_usd": 0.01},
{"thread_id": "t2", "created_at": "2026-01-02", "last_activity": "2026-01-02",
"status": "resolved", "total_tokens": 200, "total_cost_usd": 0.02},
]
list_cursor = _make_fake_cursor(rows)
conn = _FakeConnection([count_cursor, list_cursor])
pool = _FakePool(conn)
test_app = _build_app(pool)
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
resp = await client.get("/api/v1/conversations")
assert resp.status_code == 200
body = resp.json()
assert body["success"] is True
assert body["data"]["total"] == 3
assert len(body["data"]["conversations"]) == 2
assert body["data"]["page"] == 1
assert body["data"]["per_page"] == 20
async def test_custom_page_and_per_page(self) -> None:
"""Custom page/per_page params are reflected in the response."""
count_cursor = _make_fake_cursor([], fetchone_value=(50,))
list_cursor = _make_fake_cursor([])
conn = _FakeConnection([count_cursor, list_cursor])
pool = _FakePool(conn)
test_app = _build_app(pool)
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
resp = await client.get("/api/v1/conversations", params={"page": 3, "per_page": 10})
assert resp.status_code == 200
body = resp.json()
assert body["data"]["page"] == 3
assert body["data"]["per_page"] == 10
async def test_invalid_page_returns_422(self) -> None:
"""page=0 violates ge=1 constraint and returns 422 error envelope."""
test_app = _build_app()
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
resp = await client.get("/api/v1/conversations", params={"page": 0})
assert resp.status_code == 422
body = resp.json()
assert body["success"] is False
class TestReplayEndpoint:
"""Tests for GET /api/v1/replay/{thread_id} endpoint."""
async def test_valid_thread_returns_timeline(self) -> None:
"""Replay with valid thread_id returns steps in envelope format."""
checkpoint_rows = [
{
"thread_id": "abc123",
"checkpoint_id": "cp1",
"checkpoint": {
"channel_values": {
"messages": [
{"type": "human", "content": "Hello", "created_at": "2026-01-01T00:00:00Z"},
{"type": "ai", "content": "Hi there!", "created_at": "2026-01-01T00:00:01Z"},
]
}
},
"metadata": {},
}
]
cursor = _make_fake_cursor(checkpoint_rows)
conn = _FakeConnection([cursor])
pool = _FakePool(conn)
test_app = _build_app(pool)
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
resp = await client.get("/api/v1/replay/abc123")
assert resp.status_code == 200
body = resp.json()
assert body["success"] is True
assert body["data"]["thread_id"] == "abc123"
assert body["data"]["total_steps"] == 2
assert len(body["data"]["steps"]) == 2
assert body["data"]["steps"][0]["type"] == "user_message"
assert body["data"]["steps"][1]["type"] == "agent_response"
async def test_invalid_thread_id_format_returns_400(self) -> None:
"""Thread IDs with path traversal characters are rejected with 400."""
test_app = _build_app()
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
resp = await client.get("/api/v1/replay/../../etc/passwd")
# Depending on routing, this may hit our 400 handler, FastAPI's 404, or 422 validation
assert resp.status_code in (400, 404, 422)
async def test_nonexistent_thread_returns_404(self) -> None:
"""Replay with a thread_id that has no checkpoints returns 404."""
cursor = _make_fake_cursor([])
conn = _FakeConnection([cursor])
pool = _FakePool(conn)
test_app = _build_app(pool)
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
resp = await client.get("/api/v1/replay/nonexistent-thread")
assert resp.status_code == 404
body = resp.json()
assert body["success"] is False
assert "not found" in body["error"].lower()

View File

@@ -0,0 +1,159 @@
"""Integration tests for SessionManager + InterruptManager lifecycle.
These tests exercise the in-memory managers together, verifying the full
lifecycle of sessions and interrupts: creation, TTL sliding, interrupt
registration/resolution, and expired-interrupt cleanup.
No database required -- both managers are in-memory.
"""
from __future__ import annotations
import time
from unittest.mock import patch
import pytest
from app.interrupt_manager import InterruptManager
from app.session_manager import SessionManager
pytestmark = pytest.mark.integration
class TestSessionInterruptLifecycle:
"""Tests for the combined session + interrupt lifecycle."""
def test_create_session_register_interrupt_check_status(self) -> None:
"""Full lifecycle: create session, register interrupt, verify both states."""
sm = SessionManager(session_ttl_seconds=3600)
im = InterruptManager(ttl_seconds=300)
# Create a session
state = sm.touch("thread-1")
assert state.thread_id == "thread-1"
assert not state.has_pending_interrupt
assert not sm.is_expired("thread-1")
# Register an interrupt
record = im.register("thread-1", "cancel_order", {"order_id": "1042"})
sm.extend_for_interrupt("thread-1")
assert im.has_pending("thread-1")
session_state = sm.get_state("thread-1")
assert session_state is not None
assert session_state.has_pending_interrupt
# Session should not expire while interrupt is pending
assert not sm.is_expired("thread-1")
def test_interrupt_expiry_after_ttl(self) -> None:
"""Interrupt expires when TTL elapses, even if session is alive."""
im = InterruptManager(ttl_seconds=5)
record = im.register("thread-2", "refund", {"amount": 50})
assert im.has_pending("thread-2")
# Simulate time passing beyond TTL
with patch("app.interrupt_manager.time") as mock_time:
mock_time.time.return_value = record.created_at + 10
assert not im.has_pending("thread-2")
status = im.check_status("thread-2")
assert status is not None
assert status.is_expired
assert status.remaining_seconds == 0.0
def test_interrupt_resolve_flow(self) -> None:
"""Resolving an interrupt removes it from pending and resets session."""
sm = SessionManager(session_ttl_seconds=3600)
im = InterruptManager(ttl_seconds=300)
sm.touch("thread-3")
im.register("thread-3", "delete_account", {"user_id": "u1"})
sm.extend_for_interrupt("thread-3")
# Verify pending state
assert im.has_pending("thread-3")
assert sm.get_state("thread-3").has_pending_interrupt
# Resolve
im.resolve("thread-3")
sm.resolve_interrupt("thread-3")
assert not im.has_pending("thread-3")
session_state = sm.get_state("thread-3")
assert session_state is not None
assert not session_state.has_pending_interrupt
def test_cleanup_expired_removes_old_interrupts(self) -> None:
"""cleanup_expired removes only expired interrupts, keeping active ones."""
im = InterruptManager(ttl_seconds=10)
# Register two interrupts at different times
old_record = im.register("thread-old", "action_old", {})
new_record = im.register("thread-new", "action_new", {})
# Simulate time where only old one expired
with patch("app.interrupt_manager.time") as mock_time:
# Move old record's creation to the past
im._interrupts["thread-old"] = old_record.__class__(
interrupt_id=old_record.interrupt_id,
thread_id=old_record.thread_id,
action=old_record.action,
params=old_record.params,
created_at=time.time() - 20,
ttl_seconds=old_record.ttl_seconds,
)
mock_time.time.return_value = time.time()
expired = im.cleanup_expired()
assert len(expired) == 1
assert expired[0].thread_id == "thread-old"
# New one should still be pending
assert im.has_pending("thread-new")
assert not im.has_pending("thread-old")
def test_session_ttl_sliding_window(self) -> None:
"""Touching a session resets the sliding window TTL."""
sm = SessionManager(session_ttl_seconds=3600)
state1 = sm.touch("thread-5")
first_activity = state1.last_activity
time.sleep(0.01)
state2 = sm.touch("thread-5")
second_activity = state2.last_activity
assert second_activity > first_activity
assert not sm.is_expired("thread-5")
def test_session_expires_after_ttl_without_activity(self) -> None:
"""Session expires when TTL passes without a touch or interrupt."""
sm = SessionManager(session_ttl_seconds=0)
sm.touch("thread-6")
# TTL is 0 so session is immediately expired
assert sm.is_expired("thread-6")
def test_pending_interrupt_prevents_session_expiry(self) -> None:
"""A session with pending interrupt does not expire even with TTL=0."""
sm = SessionManager(session_ttl_seconds=0)
sm.touch("thread-7")
sm.extend_for_interrupt("thread-7")
# Even with TTL=0, session should not expire because of pending interrupt
assert not sm.is_expired("thread-7")
def test_retry_prompt_for_expired_interrupt(self) -> None:
"""InterruptManager generates a retry prompt for expired interrupts."""
im = InterruptManager(ttl_seconds=300)
record = im.register("thread-8", "cancel_order", {"order_id": "1042"})
prompt = im.generate_retry_prompt(record)
assert prompt["type"] == "interrupt_expired"
assert prompt["thread_id"] == "thread-8"
assert "cancel_order" in prompt["action"]
assert "cancel_order" in prompt["message"]
assert "expired" in prompt["message"].lower()

View File

@@ -44,7 +44,7 @@ def _make_analytics_result() -> object:
     )

-def _get_analytics(app: FastAPI, path: str = "/api/analytics", **patch_kwargs: object) -> object:
+def _get_analytics(app: FastAPI, path: str = "/api/v1/analytics", **patch_kwargs: object) -> object:
     """Helper: patch get_analytics, make request, return (response, mock)."""
     analytics_result = _make_analytics_result()
     with (
@@ -84,7 +84,7 @@ class TestAnalyticsEndpoint:
     def test_custom_range_7d(self) -> None:
         app = _build_app()
         app.state.pool = _make_mock_pool()
-        resp, mock_ga = _get_analytics(app, "/api/analytics?range=7d")
+        resp, mock_ga = _get_analytics(app, "/api/v1/analytics?range=7d")
         assert resp.status_code == 200
         mock_ga.assert_called_once()
@@ -94,7 +94,7 @@ class TestAnalyticsEndpoint:
     def test_custom_range_30d(self) -> None:
         app = _build_app()
         app.state.pool = _make_mock_pool()
-        resp, mock_ga = _get_analytics(app, "/api/analytics?range=30d")
+        resp, mock_ga = _get_analytics(app, "/api/v1/analytics?range=30d")
         assert resp.status_code == 200
         call_kwargs = mock_ga.call_args
@@ -107,7 +107,7 @@ class TestAnalyticsEndpoint:
         app.state.pool = _make_mock_pool()
         with TestClient(app) as client:
-            resp = client.get("/api/analytics?range=invalid")
+            resp = client.get("/api/v1/analytics?range=invalid")
         assert resp.status_code == 400
@@ -116,7 +116,7 @@ class TestAnalyticsEndpoint:
         app.state.pool = _make_mock_pool()
         with TestClient(app) as client:
-            resp = client.get("/api/analytics?range=7")
+            resp = client.get("/api/v1/analytics?range=7")
         assert resp.status_code == 400

View File

@@ -28,7 +28,7 @@ def client():
 @pytest.fixture
 def job_id(client):
     """Create a job and return its ID."""
-    response = client.post("/api/openapi/import", json={"url": _SAMPLE_URL})
+    response = client.post("/api/v1/openapi/import", json={"url": _SAMPLE_URL})
     assert response.status_code == 202
     return response.json()["job_id"]
@@ -61,11 +61,11 @@ def job_with_classifications(client, job_id):
 class TestImportEndpoint:
-    """Tests for POST /api/openapi/import."""
+    """Tests for POST /api/v1/openapi/import."""

     def test_post_import_returns_job_id(self, client) -> None:
         """POST /import returns 202 with a job_id."""
-        response = client.post("/api/openapi/import", json={"url": _SAMPLE_URL})
+        response = client.post("/api/v1/openapi/import", json={"url": _SAMPLE_URL})
         assert response.status_code == 202
         data = response.json()
         assert "job_id" in data
@@ -73,38 +73,38 @@ class TestImportEndpoint:
     def test_post_import_empty_url_returns_422(self, client) -> None:
         """POST /import with empty URL returns 422 validation error."""
-        response = client.post("/api/openapi/import", json={"url": ""})
+        response = client.post("/api/v1/openapi/import", json={"url": ""})
         assert response.status_code == 422

     def test_post_import_missing_url_returns_422(self, client) -> None:
         """POST /import with missing URL field returns 422."""
-        response = client.post("/api/openapi/import", json={})
+        response = client.post("/api/v1/openapi/import", json={})
         assert response.status_code == 422

     def test_post_import_invalid_scheme_returns_422(self, client) -> None:
         """POST /import with non-http URL returns 422."""
-        response = client.post("/api/openapi/import", json={"url": "ftp://evil.com/spec"})
+        response = client.post("/api/v1/openapi/import", json={"url": "ftp://evil.com/spec"})
         assert response.status_code == 422

     def test_post_import_returns_pending_status(self, client) -> None:
         """Newly created job has pending status."""
-        response = client.post("/api/openapi/import", json={"url": _SAMPLE_URL})
+        response = client.post("/api/v1/openapi/import", json={"url": _SAMPLE_URL})
         data = response.json()
         assert data["status"] == "pending"

     def test_post_import_returns_spec_url(self, client) -> None:
         """Response includes the original spec URL."""
-        response = client.post("/api/openapi/import", json={"url": _SAMPLE_URL})
+        response = client.post("/api/v1/openapi/import", json={"url": _SAMPLE_URL})
         data = response.json()
         assert data["spec_url"] == _SAMPLE_URL

 class TestGetJobEndpoint:
-    """Tests for GET /api/openapi/jobs/{job_id}."""
+    """Tests for GET /api/v1/openapi/jobs/{job_id}."""

     def test_get_job_returns_status(self, client, job_id) -> None:
         """GET /jobs/{id} returns job status."""
-        response = client.get(f"/api/openapi/jobs/{job_id}")
+        response = client.get(f"/api/v1/openapi/jobs/{job_id}")
         assert response.status_code == 200
         data = response.json()
         assert "status" in data
@@ -112,23 +112,23 @@ class TestGetJobEndpoint:
     def test_get_unknown_job_returns_404(self, client) -> None:
         """GET /jobs/nonexistent returns 404."""
-        response = client.get("/api/openapi/jobs/nonexistent-id")
+        response = client.get("/api/v1/openapi/jobs/nonexistent-id")
         assert response.status_code == 404

     def test_get_job_includes_spec_url(self, client, job_id) -> None:
         """Job response includes the spec URL."""
-        response = client.get(f"/api/openapi/jobs/{job_id}")
+        response = client.get(f"/api/v1/openapi/jobs/{job_id}")
         data = response.json()
         assert data["spec_url"] == _SAMPLE_URL

 class TestGetClassificationsEndpoint:
-    """Tests for GET /api/openapi/jobs/{job_id}/classifications."""
+    """Tests for GET /api/v1/openapi/jobs/{job_id}/classifications."""

     def test_get_classifications_returns_list(self, client, job_with_classifications) -> None:
         """GET /classifications returns a list."""
         response = client.get(
-            f"/api/openapi/jobs/{job_with_classifications}/classifications"
+            f"/api/v1/openapi/jobs/{job_with_classifications}/classifications"
         )
         assert response.status_code == 200
         data = response.json()
@@ -137,13 +137,13 @@ class TestGetClassificationsEndpoint:
     def test_get_classifications_unknown_job_returns_404(self, client) -> None:
         """GET /classifications for unknown job returns 404."""
-        response = client.get("/api/openapi/jobs/unknown/classifications")
+        response = client.get("/api/v1/openapi/jobs/unknown/classifications")
         assert response.status_code == 404

     def test_classification_has_expected_fields(self, client, job_with_classifications) -> None:
         """Each classification item has access_type and endpoint fields."""
         response = client.get(
-            f"/api/openapi/jobs/{job_with_classifications}/classifications"
+            f"/api/v1/openapi/jobs/{job_with_classifications}/classifications"
         )
         item = response.json()[0]
         assert "access_type" in item
@@ -152,12 +152,12 @@ class TestGetClassificationsEndpoint:
 class TestUpdateClassificationEndpoint:
-    """Tests for PUT /api/openapi/jobs/{job_id}/classifications/{idx}."""
+    """Tests for PUT /api/v1/openapi/jobs/{job_id}/classifications/{idx}."""

     def test_update_classification_succeeds(self, client, job_with_classifications) -> None:
         """PUT /classifications/0 updates the classification."""
         response = client.put(
-            f"/api/openapi/jobs/{job_with_classifications}/classifications/0",
+            f"/api/v1/openapi/jobs/{job_with_classifications}/classifications/0",
             json={"access_type": "write", "needs_interrupt": True, "agent_group": "write_agent"},
         )
         assert response.status_code == 200
@@ -165,7 +165,7 @@ class TestUpdateClassificationEndpoint:
     def test_update_unknown_job_returns_404(self, client) -> None:
         """PUT /classifications/0 for unknown job returns 404."""
         response = client.put(
-            "/api/openapi/jobs/unknown/classifications/0",
+            "/api/v1/openapi/jobs/unknown/classifications/0",
             json={"access_type": "write", "needs_interrupt": True, "agent_group": "write_agent"},
         )
         assert response.status_code == 404
@@ -173,7 +173,7 @@ class TestUpdateClassificationEndpoint:
     def test_update_invalid_access_type_returns_422(self, client, job_with_classifications) -> None:
         """PUT /classifications/0 with invalid access_type returns 422."""
         response = client.put(
-            f"/api/openapi/jobs/{job_with_classifications}/classifications/0",
+            f"/api/v1/openapi/jobs/{job_with_classifications}/classifications/0",
             json={"access_type": "admin", "needs_interrupt": True, "agent_group": "x"},
         )
         assert response.status_code == 422
@@ -181,7 +181,7 @@ class TestUpdateClassificationEndpoint:
     def test_update_invalid_agent_group_returns_422(self, client, job_with_classifications) -> None:
         """PUT /classifications/0 with invalid agent_group returns 422."""
         response = client.put(
-            f"/api/openapi/jobs/{job_with_classifications}/classifications/0",
+            f"/api/v1/openapi/jobs/{job_with_classifications}/classifications/0",
             json={"access_type": "read", "needs_interrupt": False, "agent_group": "evil group!"},
         )
         assert response.status_code == 422
@@ -189,31 +189,31 @@ class TestUpdateClassificationEndpoint:
     def test_update_out_of_range_index_returns_404(self, client, job_with_classifications) -> None:
         """PUT /classifications/999 returns 404 for out-of-range index."""
         response = client.put(
-            f"/api/openapi/jobs/{job_with_classifications}/classifications/999",
+            f"/api/v1/openapi/jobs/{job_with_classifications}/classifications/999",
             json={"access_type": "read", "needs_interrupt": False, "agent_group": "read_agent"},
         )
         assert response.status_code == 404

 class TestApproveEndpoint:
-    """Tests for POST /api/openapi/jobs/{job_id}/approve."""
+    """Tests for POST /api/v1/openapi/jobs/{job_id}/approve."""

     def test_approve_job_succeeds(self, client, job_with_classifications) -> None:
         """POST /approve transitions job to approved status."""
         response = client.post(
-            f"/api/openapi/jobs/{job_with_classifications}/approve"
+            f"/api/v1/openapi/jobs/{job_with_classifications}/approve"
         )
         assert response.status_code == 200

     def test_approve_unknown_job_returns_404(self, client) -> None:
         """POST /approve for unknown job returns 404."""
-        response = client.post("/api/openapi/jobs/unknown/approve")
+        response = client.post("/api/v1/openapi/jobs/unknown/approve")
         assert response.status_code == 404

     def test_approve_returns_job_status(self, client, job_with_classifications) -> None:
         """POST /approve returns updated job status."""
         response = client.post(
-            f"/api/openapi/jobs/{job_with_classifications}/approve"
+            f"/api/v1/openapi/jobs/{job_with_classifications}/approve"
         )
         data = response.json()
         assert "status" in data

View File

@@ -5,9 +5,12 @@ from __future__ import annotations
 from unittest.mock import AsyncMock, MagicMock

 import pytest
-from fastapi import FastAPI
+from fastapi import FastAPI, HTTPException
+from fastapi.responses import JSONResponse
 from fastapi.testclient import TestClient

+from app.api_utils import envelope
+
 pytestmark = pytest.mark.unit
@@ -16,6 +19,14 @@ def _build_app() -> FastAPI:
     app = FastAPI()
     app.include_router(router)

+    @app.exception_handler(HTTPException)
+    async def _http_exc(request, exc):  # type: ignore[no-untyped-def]
+        return JSONResponse(
+            status_code=exc.status_code,
+            content=envelope(None, success=False, error=exc.detail),
+        )
+
     return app
@@ -64,7 +75,7 @@ class TestListConversations:
         app.state.pool = _make_mock_pool([], count=0)
         with TestClient(app) as client:
-            resp = client.get("/api/conversations")
+            resp = client.get("/api/v1/conversations")
         assert resp.status_code == 200
         body = resp.json()
         assert body["success"] is True
@@ -89,7 +100,7 @@ class TestListConversations:
         app.state.pool = _make_mock_pool(mock_rows, count=1)
         with TestClient(app) as client:
-            resp = client.get("/api/conversations")
+            resp = client.get("/api/v1/conversations")
         body = resp.json()
         assert resp.status_code == 200
         data = body["data"]
@@ -102,7 +113,7 @@ class TestListConversations:
         app.state.pool = _make_mock_pool([], count=0)
         with TestClient(app) as client:
-            resp = client.get("/api/conversations")
+            resp = client.get("/api/v1/conversations")
         assert resp.status_code == 200

     def test_pagination_custom_params(self) -> None:
@@ -110,7 +121,7 @@ class TestListConversations:
         app.state.pool = _make_mock_pool([], count=0)
         with TestClient(app) as client:
-            resp = client.get("/api/conversations?page=2&per_page=10")
+            resp = client.get("/api/v1/conversations?page=2&per_page=10")
         assert resp.status_code == 200

     def test_per_page_max_capped_at_100(self) -> None:
@@ -118,7 +129,7 @@ class TestListConversations:
         app.state.pool = _make_mock_pool([], count=0)
         with TestClient(app) as client:
-            resp = client.get("/api/conversations?per_page=200")
+            resp = client.get("/api/v1/conversations?per_page=200")
         # FastAPI Query(le=100) rejects values > 100
         assert resp.status_code == 422
@@ -129,7 +140,7 @@ class TestGetReplay:
         app.state.pool = _make_mock_pool([])
         with TestClient(app) as client:
-            resp = client.get("/api/replay/nonexistent-thread")
+            resp = client.get("/api/v1/replay/nonexistent-thread")
         assert resp.status_code == 404

     def test_returns_replay_page_for_existing_thread(self) -> None:
@@ -149,7 +160,7 @@ class TestGetReplay:
         app.state.pool = _make_mock_pool(mock_rows)
         with TestClient(app) as client:
-            resp = client.get("/api/replay/thread-123")
+            resp = client.get("/api/v1/replay/thread-123")
         assert resp.status_code == 200
         body = resp.json()
         assert body["success"] is True
@@ -174,7 +185,7 @@ class TestGetReplay:
         app.state.pool = _make_mock_pool(mock_rows)
         with TestClient(app) as client:
-            resp = client.get("/api/replay/t1?page=1&per_page=5")
+            resp = client.get("/api/v1/replay/t1?page=1&per_page=5")
         assert resp.status_code == 200

     def test_error_response_has_envelope(self) -> None:
@@ -182,16 +193,19 @@ class TestGetReplay:
         app.state.pool = _make_mock_pool([])
         with TestClient(app) as client:
-            resp = client.get("/api/replay/missing")
+            resp = client.get("/api/v1/replay/missing")
         assert resp.status_code == 404
-        assert "detail" in resp.json()
+        body = resp.json()
+        assert body["success"] is False
+        assert body["data"] is None
+        assert body["error"] is not None

     def test_invalid_thread_id_returns_400(self) -> None:
         app = _build_app()
         app.state.pool = _make_mock_pool([])
         with TestClient(app) as client:
-            resp = client.get("/api/replay/id%20with%20spaces")
+            resp = client.get("/api/v1/replay/id%20with%20spaces")
         assert resp.status_code == 400

     def test_thread_id_special_chars_returns_400(self) -> None:
@@ -199,5 +213,5 @@ class TestGetReplay:
         app.state.pool = _make_mock_pool([])
         with TestClient(app) as client:
-            resp = client.get("/api/replay/id;DROP TABLE")
+            resp = client.get("/api/v1/replay/id;DROP TABLE")
         assert resp.status_code == 400

View File

@@ -0,0 +1,142 @@
"""Tests for standardized error response envelope format."""
from __future__ import annotations
import pytest
from fastapi import FastAPI, HTTPException
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from fastapi.testclient import TestClient
from pydantic import BaseModel, Field
from app.api_utils import envelope
pytestmark = pytest.mark.unit
def _build_test_app() -> FastAPI:
"""Build a minimal FastAPI app with the standard exception handlers."""
app = FastAPI()
@app.exception_handler(HTTPException)
async def http_exception_handler(request, exc): # type: ignore[no-untyped-def]
return JSONResponse(
status_code=exc.status_code,
content=envelope(None, success=False, error=exc.detail),
)
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc): # type: ignore[no-untyped-def]
return JSONResponse(
status_code=422,
content=envelope(None, success=False, error=str(exc)),
)
@app.exception_handler(Exception)
async def general_exception_handler(request, exc): # type: ignore[no-untyped-def]
return JSONResponse(
status_code=500,
content=envelope(None, success=False, error="Internal server error"),
)
class ItemRequest(BaseModel):
name: str = Field(..., min_length=1)
count: int = Field(..., gt=0)
@app.get("/items/{item_id}")
def get_item(item_id: int) -> dict:
if item_id == 0:
raise HTTPException(status_code=400, detail="Invalid item ID")
if item_id == 999:
raise HTTPException(status_code=404, detail="Item not found")
if item_id == 401:
raise HTTPException(status_code=401, detail="Not authenticated")
return envelope({"id": item_id, "name": "test"})
@app.post("/items")
def create_item(item: ItemRequest) -> dict:
return envelope({"id": 1, "name": item.name})
@app.get("/crash")
def crash() -> dict:
msg = "unexpected failure"
raise RuntimeError(msg)
return app
class TestHttpExceptionEnvelope:
"""HTTPException responses use the standard envelope format."""
def test_400_returns_envelope(self) -> None:
app = _build_test_app()
with TestClient(app, raise_server_exceptions=False) as client:
resp = client.get("/items/0")
assert resp.status_code == 400
body = resp.json()
assert body["success"] is False
assert body["data"] is None
assert body["error"] == "Invalid item ID"
def test_404_returns_envelope(self) -> None:
app = _build_test_app()
with TestClient(app, raise_server_exceptions=False) as client:
resp = client.get("/items/999")
assert resp.status_code == 404
body = resp.json()
assert body["success"] is False
assert body["data"] is None
assert body["error"] == "Item not found"
def test_401_returns_envelope(self) -> None:
app = _build_test_app()
with TestClient(app, raise_server_exceptions=False) as client:
resp = client.get("/items/401")
assert resp.status_code == 401
body = resp.json()
assert body["success"] is False
assert body["data"] is None
assert body["error"] == "Not authenticated"
class TestValidationErrorEnvelope:
"""Validation errors return 422 with envelope format."""
def test_validation_error_returns_envelope(self) -> None:
app = _build_test_app()
with TestClient(app, raise_server_exceptions=False) as client:
resp = client.post("/items", json={"name": "", "count": -1})
assert resp.status_code == 422
body = resp.json()
assert body["success"] is False
assert body["data"] is None
assert isinstance(body["error"], str)
assert len(body["error"]) > 0
class TestGeneralExceptionEnvelope:
"""Unhandled exceptions return 500 with safe envelope."""
def test_unhandled_exception_returns_500_envelope(self) -> None:
app = _build_test_app()
with TestClient(app, raise_server_exceptions=False) as client:
resp = client.get("/crash")
assert resp.status_code == 500
body = resp.json()
assert body["success"] is False
assert body["data"] is None
assert body["error"] == "Internal server error"
class TestSuccessResponseUnchanged:
"""Success responses still work normally."""
def test_success_returns_envelope(self) -> None:
app = _build_test_app()
with TestClient(app) as client:
resp = client.get("/items/42")
assert resp.status_code == 200
body = resp.json()
assert body["success"] is True
assert body["data"]["id"] == 42
assert body["error"] is None

View File

@@ -0,0 +1,86 @@
"""Tests for the interrupt cleanup background loop in main.py."""
from __future__ import annotations
import asyncio
import logging
from unittest.mock import MagicMock, patch
import pytest
from app.main import _interrupt_cleanup_loop
@pytest.mark.unit
@pytest.mark.asyncio
async def test_cleanup_loop_calls_cleanup_expired() -> None:
"""The loop should call cleanup_expired after each sleep interval."""
manager = MagicMock()
manager.cleanup_expired.return_value = ()
call_count = 0
original_sleep = asyncio.sleep
async def _fake_sleep(seconds: float) -> None:
nonlocal call_count
call_count += 1
if call_count >= 2:
raise asyncio.CancelledError
await original_sleep(0)
with patch("app.main.asyncio.sleep", side_effect=_fake_sleep):
with pytest.raises(asyncio.CancelledError):
await _interrupt_cleanup_loop(manager, interval=60)
assert manager.cleanup_expired.call_count >= 1
@pytest.mark.unit
@pytest.mark.asyncio
async def test_cleanup_loop_survives_exceptions() -> None:
"""The loop should not die when cleanup_expired raises an exception."""
manager = MagicMock()
manager.cleanup_expired.side_effect = [RuntimeError("db gone"), ()]
call_count = 0
original_sleep = asyncio.sleep
async def _fake_sleep(seconds: float) -> None:
nonlocal call_count
call_count += 1
if call_count >= 3:
raise asyncio.CancelledError
await original_sleep(0)
with patch("app.main.asyncio.sleep", side_effect=_fake_sleep):
with pytest.raises(asyncio.CancelledError):
await _interrupt_cleanup_loop(manager, interval=60)
# Should have been called twice: once raising, once returning ()
assert manager.cleanup_expired.call_count == 2
@pytest.mark.unit
@pytest.mark.asyncio
async def test_cleanup_loop_logs_expired_count(capsys: pytest.CaptureFixture[str]) -> None:
"""The loop should log when expired interrupts are found."""
fake_record = MagicMock()
manager = MagicMock()
manager.cleanup_expired.return_value = (fake_record, fake_record)
call_count = 0
original_sleep = asyncio.sleep
async def _fake_sleep(seconds: float) -> None:
nonlocal call_count
call_count += 1
if call_count >= 2:
raise asyncio.CancelledError
await original_sleep(0)
with patch("app.main.asyncio.sleep", side_effect=_fake_sleep):
with pytest.raises(asyncio.CancelledError):
await _interrupt_cleanup_loop(manager, interval=60)
captured = capsys.readouterr()
assert "2 expired interrupt" in captured.out

View File

@@ -0,0 +1,20 @@
"""Tests for structured logging configuration."""
from __future__ import annotations
import pytest
from app.logging_config import configure_logging
pytestmark = pytest.mark.unit
def test_configure_logging_console_mode() -> None:
"""Console mode configures without error."""
configure_logging("console")
def test_configure_logging_json_mode() -> None:
"""JSON mode configures without error."""
configure_logging("json")

View File

@@ -36,7 +36,7 @@ class TestMainModule:
     def test_health_route_registered(self) -> None:
         routes = [r.path for r in app.routes if hasattr(r, "path")]
-        assert "/api/health" in routes
+        assert "/api/v1/health" in routes

     def test_app_version_is_0_5_0(self) -> None:
         assert app.version == "0.6.0"

View File

@@ -41,7 +41,7 @@ services:
       postgres:
         condition: service_healthy
     healthcheck:
-      test: ["CMD-SHELL", "curl -f http://localhost:8000/api/health || exit 1"]
+      test: ["CMD-SHELL", "curl -f http://localhost:8000/api/v1/health || exit 1"]
       interval: 10s
       timeout: 5s
       retries: 5

View File

@@ -99,7 +99,12 @@ smart-support/
 ├── backend/
 │   ├── app/
 │   │   ├── main.py            # FastAPI + WebSocket entry point
-│   │   ├── graph.py           # LangGraph Supervisor configuration
+│   │   ├── graph.py           # LangGraph Supervisor construction
+│   │   ├── graph_context.py   # GraphContext: typed wrapper for graph + classifier + registry
+│   │   ├── ws_handler.py      # WebSocket message dispatch + rate limiting
+│   │   ├── ws_context.py      # WebSocketContext: bundle of WS dependencies
+│   │   ├── auth.py            # API key authentication middleware
+│   │   ├── api_utils.py       # Shared API response helpers (envelope)
 │   │   ├── agents/            # Agent definitions + tool bindings
 │   │   ├── registry.py        # YAML agent registry loader
 │   │   ├── openapi/           # OpenAPI parsing + MCP server generation
@@ -139,7 +144,11 @@ smart-support/
 | Module | Responsibility |
 |--------|----------------|
 | main.py | Application entry point, WebSocket endpoint, static file serving |
-| WebSocket Handler | Bidirectional communication: receive user messages, stream tokens back, handle interrupt responses |
+| auth.py | API key authentication: admin endpoints via the `X-API-Key` header, WebSocket via the `?token=` query param |
+| ws_handler.py | Bidirectional communication: receive user messages, stream tokens back, handle interrupt responses |
+| graph_context.py | Typed wrapper: binds the compiled graph to its classifier and registry, replacing monkey-patching |
+| ws_context.py | Dependency bundle: packs the 9 dependencies of WebSocket handling into a single immutable object |
+| api_utils.py | Shared response format: the unified `envelope()` function |

 ### 2.3 Agent Orchestration Layer (LangGraph)
@@ -427,6 +436,19 @@ CREATE INDEX idx_interrupts_ttl ON interrupts(ttl_expires_at)
 WHERE status = 'pending';
 ```

+#### sessions (custom -- session state persistence)
+```sql
+-- PostgreSQL-backed session state for multi-worker deployments;
+-- PgSessionManager uses this table instead of the in-memory dict
+CREATE TABLE sessions (
+    thread_id TEXT PRIMARY KEY,
+    last_activity TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+    has_pending_interrupt BOOLEAN NOT NULL DEFAULT FALSE,
+    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
+);
+```
+
 #### analytics_events (custom -- analytics event stream)
 ```sql

View File

@@ -54,11 +54,19 @@ Set these in production (never commit secrets):
 | `ANTHROPIC_API_KEY` | Yes* | LLM provider API key |
 | `LLM_PROVIDER` | Yes | `anthropic`, `openai`, or `google` |
 | `LLM_MODEL` | Yes | Model name for your provider |
+| `ADMIN_API_KEY` | Recommended | API key for admin endpoints (analytics, replay, openapi, WS). Leave empty to disable auth (dev mode only) |
 | `WEBHOOK_URL` | No | Escalation notification endpoint |
 | `SESSION_TTL_MINUTES` | No | Session timeout (default: 30) |

 *Or `OPENAI_API_KEY` / `GOOGLE_API_KEY` depending on `LLM_PROVIDER`.

+### Authentication
+
+When `ADMIN_API_KEY` is set, all admin REST endpoints require the `X-API-Key` header,
+and WebSocket connections require a `?token=<key>` query parameter.
+When unset or empty, authentication is disabled (suitable for local development only).
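+
+A minimal request sketch using httpx (the key value is a placeholder for
+whatever you set in `ADMIN_API_KEY`):
+
+```python
+import httpx
+
+resp = httpx.get(
+    "http://localhost:8000/api/v1/analytics?range=7d",
+    headers={"X-API-Key": "your-admin-key"},
+)
+resp.raise_for_status()
+
+# WebSocket clients pass the same key as a query parameter instead:
+# ws://localhost:8000/ws?token=your-admin-key
+```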

 ### HTTPS

 For production, place a reverse proxy (nginx, Caddy, or a load balancer) in
@@ -87,10 +95,12 @@ cat backup.sql | docker compose exec -T postgres psql -U smart_support smart_sup

 ### Scaling

-The backend is stateless (session state is in PostgreSQL via LangGraph's
-PostgresSaver). You can run multiple backend replicas behind a load balancer.
+The backend supports multi-worker deployments. LangGraph session state is
+persisted in PostgreSQL via PostgresSaver. For full horizontal scaling, use
+`PgSessionManager` and `PgInterruptManager` (instead of the default in-memory
+managers) to share session and interrupt state across workers.

-The WebSocket connections are session-specific. Use sticky sessions or a shared
+WebSocket connections are session-specific. Use sticky sessions or a shared
 session backend if load balancing WebSockets across multiple instances.

 ## Manual / Development Setup
@@ -139,7 +149,7 @@ GET /api/health

 Response:

 ```json
-{"status": "ok", "version": "0.5.0"}
+{"status": "ok", "version": "0.6.0"}
 ```

 ### WebSocket health

View File

@@ -86,7 +86,21 @@ Content-Type: application/json
 POST /api/openapi/jobs/{job_id}/approve
 ```

-No request body. Changes the job status to `approved`.
+No request body. Generates tool code for each classified endpoint and produces
+an agent YAML configuration. Response includes `generated_tools_count`.
+
+Response:
+
+```json
+{
+  "job_id": "abc123",
+  "status": "approved",
+  "spec_url": "https://api.example.com/openapi.yaml",
+  "total_endpoints": 5,
+  "classified_count": 5,
+  "error_message": null,
+  "generated_tools_count": 5
+}
+```

 ## Access Type Classification

View File

@@ -0,0 +1,76 @@
# Engineering Improvements -- Development Log
> Status: COMPLETED
> Branch: `eng/engineering-improvements`
> Date started: 2026-04-06
> Date completed: 2026-04-06
## What Was Built
### Phase 1: Quick Wins (no new deps)
1. **Interrupt Cleanup Background Task** -- Added asyncio background task in lifespan that calls `interrupt_manager.cleanup_expired()` every 60 seconds. Prevents unbounded memory growth from expired interrupts (sketched after this list).
2. **API Versioning** -- All REST endpoints prefixed with `/api/v1/` (was `/api/`). Updated 4 router prefixes, Docker healthcheck, all frontend fetch URLs, and all test assertions. WebSocket `/ws` endpoint unchanged.
3. **Error Response Standardization** -- Added global exception handlers for `HTTPException`, `RequestValidationError`, and `Exception`. All error responses now use the same envelope format as success responses: `{"success": false, "data": null, "error": "..."}`.
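
Items 1 and 3 are small enough to sketch from their tests alone. These are minimal reconstructions, not the exact code -- see `app/main.py` and `app/api_utils.py` for the real versions:

```python
import asyncio

import structlog

log = structlog.get_logger()

async def _interrupt_cleanup_loop(manager, interval: float = 60) -> None:
    """Purge expired interrupts every `interval` seconds, surviving cleanup errors."""
    while True:
        await asyncio.sleep(interval)  # CancelledError propagates on shutdown
        try:
            expired = manager.cleanup_expired()
        except Exception:
            log.exception("interrupt cleanup failed")
            continue
        if expired:
            log.info(f"removed {len(expired)} expired interrupts")

def envelope(data, *, success=True, error=None):
    """Wrap any payload in the shared {success, data, error} response format."""
    return {"success": success, "data": data, "error": error}
```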
### Phase 2: Medium Items (new deps)
4. **Alembic Database Migrations** -- Replaced inline DDL in `setup_app_tables()` with versioned Alembic migrations. Initial migration `001_initial_schema.py` captures all 4 tables plus the ALTER TABLE changes. `setup_app_tables()` preserved for tests. Production uses `run_alembic_migrations()` (sketched after this list).
5. **Structured Logging** -- Replaced stdlib `logging.getLogger()` with `structlog.get_logger()` across 10 files. Added `logging_config.py` with console (dev) and JSON (production) modes. Configurable via `LOG_FORMAT` env var.
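
Hedged sketches of the Phase 2 plumbing, reconstructed from the public Alembic and structlog APIs -- the exact signatures in `app/db.py` and `app/logging_config.py` may differ:

```python
import logging

import structlog
from alembic import command
from alembic.config import Config

def run_alembic_migrations(database_url: str) -> None:
    """Upgrade the schema to the latest Alembic revision at startup."""
    cfg = Config("alembic.ini")
    cfg.set_main_option("sqlalchemy.url", database_url)
    command.upgrade(cfg, "head")

def configure_logging(mode: str = "console") -> None:
    """Console renderer for dev, JSON renderer for production."""
    renderer = (
        structlog.processors.JSONRenderer()
        if mode == "json"
        else structlog.dev.ConsoleRenderer()
    )
    structlog.configure(
        processors=[
            structlog.processors.add_log_level,
            structlog.processors.TimeStamper(fmt="iso"),
            renderer,
        ],
        wrapper_class=structlog.make_filtering_bound_logger(logging.INFO),
    )
```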
### Phase 3: Test Coverage
7. **Integration Tests (+30)** -- Created 5 new test files: analytics API, replay API, OpenAPI API, error responses, session/interrupt lifecycle. Uses httpx.AsyncClient with ASGITransport for full API layer testing.
8. **Frontend Tests (+57)** -- Created 12 new test files covering all components (ChatInput, ChatMessages, InterruptPrompt, ErrorBanner, NavBar, MetricCard, ReplayTimeline, AgentAction, Layout), pages (ChatPage, ReviewPage), and hooks (useWebSocket).
## Code Structure
### New files created
- `backend/app/logging_config.py` -- structlog configuration
- `backend/alembic.ini` -- Alembic config
- `backend/alembic/env.py` -- Migration environment
- `backend/alembic/versions/001_initial_schema.py` -- Initial migration
- `backend/tests/unit/test_interrupt_cleanup.py` (3 tests)
- `backend/tests/unit/test_error_responses.py` (6 tests)
- `backend/tests/unit/test_logging_config.py` (2 tests)
- `backend/tests/integration/test_analytics_api.py` (6 tests)
- `backend/tests/integration/test_replay_api.py` (6 tests)
- `backend/tests/integration/test_openapi_api.py` (5 tests)
- `backend/tests/integration/test_error_responses.py` (5 tests)
- `backend/tests/integration/test_session_interrupt_lifecycle.py` (8 tests)
- 12 frontend test files (57 tests total)
### Modified files
- `backend/app/main.py` -- cleanup task, exception handlers, alembic, structlog
- `backend/app/db.py` -- added run_alembic_migrations()
- `backend/app/config.py` -- added log_format setting
- `backend/pyproject.toml` -- added alembic, structlog deps
- 4 router files -- `/api/v1/` prefix
- 10 files -- structlog migration
- `docker-compose.yml` -- healthcheck URL
- `frontend/src/api.ts` -- `/api/v1/` URLs
- All existing test files -- API path updates + error envelope assertions
## Test Coverage
- Backend: 557 tests (was 516), 89.75% coverage
- Unit: ~490 tests
- Integration: ~60 tests
- E2E: ~7 tests
- Frontend: 80 tests (was 23), 16 test files (was 4)
## Deviations from Plan
- Redis rate limiting deferred (single-worker sufficient for now)
- ConversationTracker verified correct by design (pool per-method), skipped
- Coverage dropped slightly from 90.26% to 89.75% due to new alembic/logging modules with partial test coverage (still well above 80% threshold)
## Known Issues / Tech Debt
- Rate limiting remains process-global (needs Redis for multi-worker)
- Alembic migrations not tested against real PostgreSQL in CI (would need running DB)
- Frontend test coverage could be deeper (e.g., WebSocket reconnect edge cases)

View File

@@ -31,14 +31,14 @@ describe("fetchConversations", () => {
     const result = await fetchConversations();
     expect(result.conversations).toHaveLength(1);
     expect(result.total).toBe(1);
-    expect(mockFetch).toHaveBeenCalledWith("/api/conversations?page=1&per_page=20");
+    expect(mockFetch).toHaveBeenCalledWith("/api/v1/conversations?page=1&per_page=20");
   });

   it("passes custom page and perPage", async () => {
     mockFetch.mockResolvedValue(jsonResponse({ success: true, data: { conversations: [], total: 0, page: 2, per_page: 10 }, error: null }));
     await fetchConversations(2, 10);
-    expect(mockFetch).toHaveBeenCalledWith("/api/conversations?page=2&per_page=10");
+    expect(mockFetch).toHaveBeenCalledWith("/api/v1/conversations?page=2&per_page=10");
   });

   it("throws on HTTP error", async () => {
@@ -80,7 +80,7 @@ describe("fetchReplay", () => {
     mockFetch.mockResolvedValue(jsonResponse({ success: true, data: { thread_id: "a/b", total_steps: 0, page: 1, per_page: 20, steps: [] }, error: null }));
     await fetchReplay("a/b");
-    expect(mockFetch).toHaveBeenCalledWith("/api/replay/a%2Fb?page=1&per_page=20");
+    expect(mockFetch).toHaveBeenCalledWith("/api/v1/replay/a%2Fb?page=1&per_page=20");
   });

   it("throws on HTTP error", async () => {
@@ -112,6 +112,6 @@ describe("fetchAnalytics", () => {
     mockFetch.mockResolvedValue(jsonResponse({ success: true, data: { range: "7d" }, error: null }));
     await fetchAnalytics();
-    expect(mockFetch).toHaveBeenCalledWith("/api/analytics?range=7d");
+    expect(mockFetch).toHaveBeenCalledWith("/api/v1/analytics?range=7d");
   });
 });

View File

@@ -84,7 +84,7 @@ export async function fetchConversations(
   perPage = 20
 ): Promise<ConversationsPage> {
   return apiFetch<ConversationsPage>(
-    `/api/conversations?page=${page}&per_page=${perPage}`
+    `/api/v1/conversations?page=${page}&per_page=${perPage}`
   );
 }
@@ -94,12 +94,12 @@ export async function fetchReplay(
   perPage = 20
 ): Promise<ReplayPage> {
   return apiFetch<ReplayPage>(
-    `/api/replay/${encodeURIComponent(threadId)}?page=${page}&per_page=${perPage}`
+    `/api/v1/replay/${encodeURIComponent(threadId)}?page=${page}&per_page=${perPage}`
   );
 }

 export async function fetchAnalytics(range = "7d"): Promise<AnalyticsData> {
-  return apiFetch<AnalyticsData>(`/api/analytics?range=${range}`);
+  return apiFetch<AnalyticsData>(`/api/v1/analytics?range=${range}`);
 }

 // -- OpenAPI import --
@@ -143,11 +143,11 @@ async function apiPost<T>(path: string, body: unknown): Promise<T> {
 }

 export async function startImport(url: string): Promise<ImportJobResponse> {
-  return apiPost<ImportJobResponse>("/api/openapi/import", { url });
+  return apiPost<ImportJobResponse>("/api/v1/openapi/import", { url });
 }

 export async function fetchImportJob(jobId: string): Promise<ImportJobResponse> {
-  const res = await fetch(`${API_BASE}/api/openapi/jobs/${encodeURIComponent(jobId)}`);
+  const res = await fetch(`${API_BASE}/api/v1/openapi/jobs/${encodeURIComponent(jobId)}`);
   if (!res.ok) {
     throw new Error(`API error ${res.status}: ${res.statusText}`);
   }
@@ -158,7 +158,7 @@ export async function fetchClassifications(
   jobId: string
 ): Promise<EndpointClassification[]> {
   const res = await fetch(
-    `${API_BASE}/api/openapi/jobs/${encodeURIComponent(jobId)}/classifications`
+    `${API_BASE}/api/v1/openapi/jobs/${encodeURIComponent(jobId)}/classifications`
   );
   if (!res.ok) {
     throw new Error(`API error ${res.status}: ${res.statusText}`);
@@ -168,7 +168,7 @@ export async function fetchClassifications(
 export async function approveJob(jobId: string): Promise<ImportJobResponse> {
   return apiPost<ImportJobResponse>(
-    `/api/openapi/jobs/${encodeURIComponent(jobId)}/approve`,
+    `/api/v1/openapi/jobs/${encodeURIComponent(jobId)}/approve`,
     {}
   );
 }

View File

@@ -0,0 +1,47 @@
import { describe, it, expect } from "vitest";
import { render, screen, fireEvent } from "@testing-library/react";
import { AgentAction } from "./AgentAction";
import type { ToolAction } from "../types";
function makeAction(overrides: Partial<ToolAction> = {}): ToolAction {
return {
id: "action-1",
agent: "OrderAgent",
tool: "get_order",
args: { order_id: "ORD-100" },
timestamp: Date.now(),
...overrides,
};
}
describe("AgentAction", () => {
it("renders agent name and tool name", () => {
render(<AgentAction action={makeAction()} />);
expect(screen.getByText("OrderAgent")).toBeInTheDocument();
expect(screen.getByText("get_order")).toBeInTheDocument();
});
it("shows args and result when expanded", () => {
const action = makeAction({ result: { status: "shipped" } });
render(<AgentAction action={action} />);
// Click header to expand
fireEvent.click(screen.getByText("OrderAgent"));
expect(screen.getByText("Args:")).toBeInTheDocument();
expect(screen.getByText("Result:")).toBeInTheDocument();
expect(screen.getByText(/"order_id": "ORD-100"/)).toBeInTheDocument();
expect(screen.getByText(/"status": "shipped"/)).toBeInTheDocument();
});
it("does not show result section when result is undefined", () => {
render(<AgentAction action={makeAction()} />);
// Expand
fireEvent.click(screen.getByText("OrderAgent"));
expect(screen.getByText("Args:")).toBeInTheDocument();
expect(screen.queryByText("Result:")).not.toBeInTheDocument();
});
});

View File

@@ -0,0 +1,53 @@
import { describe, it, expect, vi } from "vitest";
import { render, screen, fireEvent } from "@testing-library/react";
import { ChatInput } from "./ChatInput";
describe("ChatInput", () => {
it("renders input field and send button", () => {
render(<ChatInput onSend={vi.fn()} disabled={false} />);
expect(screen.getByPlaceholderText("Message Smart Support...")).toBeInTheDocument();
expect(screen.getByRole("button", { name: "Send Message" })).toBeInTheDocument();
});
it("calls onSend with trimmed content when form is submitted via Enter", () => {
const onSend = vi.fn();
render(<ChatInput onSend={onSend} disabled={false} />);
const input = screen.getByPlaceholderText("Message Smart Support...");
fireEvent.change(input, { target: { value: " Hello world " } });
fireEvent.keyDown(input, { key: "Enter" });
expect(onSend).toHaveBeenCalledWith("Hello world");
});
it("clears input after successful send", () => {
const onSend = vi.fn();
render(<ChatInput onSend={onSend} disabled={false} />);
const input = screen.getByPlaceholderText("Message Smart Support...") as HTMLInputElement;
fireEvent.change(input, { target: { value: "Test message" } });
fireEvent.keyDown(input, { key: "Enter" });
expect(input.value).toBe("");
});
it("shows disabled placeholder and disables input when disabled", () => {
render(<ChatInput onSend={vi.fn()} disabled={true} />);
const input = screen.getByPlaceholderText("Agent is working...") as HTMLInputElement;
expect(input.disabled).toBe(true);
expect(screen.getByRole("button", { name: "Send Message" })).toBeDisabled();
});
it("does not call onSend when input is empty or whitespace", () => {
const onSend = vi.fn();
render(<ChatInput onSend={onSend} disabled={false} />);
const input = screen.getByPlaceholderText("Message Smart Support...");
fireEvent.change(input, { target: { value: " " } });
fireEvent.keyDown(input, { key: "Enter" });
expect(onSend).not.toHaveBeenCalled();
});
});

View File

@@ -0,0 +1,59 @@
import { describe, it, expect, vi } from "vitest";
import { render, screen } from "@testing-library/react";
import { ChatMessages } from "./ChatMessages";
import type { ChatMessage } from "../types";
// Mock react-markdown to avoid complex rendering
vi.mock("react-markdown", () => ({
default: ({ children }: { children: string }) => <span>{children}</span>,
}));
describe("ChatMessages", () => {
it("renders welcome message when messages array is empty", () => {
render(<ChatMessages messages={[]} />);
expect(screen.getByText("Hello! How can I help you today?")).toBeInTheDocument();
expect(screen.getByText("Smart Support")).toBeInTheDocument();
});
it("renders user messages with correct sender label", () => {
const messages: ChatMessage[] = [
{ id: "1", sender: "user", content: "I need help", timestamp: Date.now() },
];
render(<ChatMessages messages={messages} />);
expect(screen.getByText("You")).toBeInTheDocument();
expect(screen.getByText("I need help")).toBeInTheDocument();
expect(screen.getByText("Me")).toBeInTheDocument();
});
it("renders agent messages with agent name", () => {
const messages: ChatMessage[] = [
{ id: "2", sender: "agent", agent: "OrderBot", content: "Sure, let me check.", timestamp: Date.now() },
];
render(<ChatMessages messages={messages} />);
expect(screen.getByText("OrderBot")).toBeInTheDocument();
expect(screen.getByText("Sure, let me check.")).toBeInTheDocument();
expect(screen.getByText("AI")).toBeInTheDocument();
});
it("shows streaming cursor for messages being streamed", () => {
const messages: ChatMessage[] = [
{ id: "3", sender: "agent", agent: "Bot", content: "Processing", timestamp: Date.now(), isStreaming: true },
];
render(<ChatMessages messages={messages} />);
expect(screen.getByText("|")).toBeInTheDocument();
expect(document.querySelector(".cursor-blink")).toBeTruthy();
});
it("shows fallback agent label when agent field is missing", () => {
const messages: ChatMessage[] = [
{ id: "4", sender: "agent", content: "Generic response", timestamp: Date.now() },
];
render(<ChatMessages messages={messages} />);
expect(screen.getByText("Agent")).toBeInTheDocument();
});
});

View File

@@ -0,0 +1,33 @@
import { describe, it, expect, vi } from "vitest";
import { render, screen, fireEvent } from "@testing-library/react";
import { ErrorBanner } from "./ErrorBanner";
describe("ErrorBanner", () => {
it("returns null when status is connected", () => {
const { container } = render(<ErrorBanner status="connected" />);
expect(container.innerHTML).toBe("");
});
it("shows disconnection message when status is disconnected", () => {
render(<ErrorBanner status="disconnected" onReconnect={vi.fn()} />);
expect(screen.getByText("Disconnected from server. Retrying...")).toBeInTheDocument();
expect(screen.getByRole("alert")).toBeInTheDocument();
});
it("shows connecting message when status is connecting", () => {
render(<ErrorBanner status="connecting" />);
expect(screen.getByText("Connecting to server...")).toBeInTheDocument();
// No reconnect button while connecting
expect(screen.queryByText("Reconnect")).not.toBeInTheDocument();
});
it("calls onReconnect when reconnect button is clicked", () => {
const onReconnect = vi.fn();
render(<ErrorBanner status="disconnected" onReconnect={onReconnect} />);
fireEvent.click(screen.getByText("Reconnect"));
expect(onReconnect).toHaveBeenCalledTimes(1);
});
});

View File

@@ -0,0 +1,58 @@
import { describe, it, expect, vi } from "vitest";
import { render, screen, fireEvent } from "@testing-library/react";
import { InterruptPrompt } from "./InterruptPrompt";
import type { InterruptMessage } from "../types";
describe("InterruptPrompt", () => {
const baseInterrupt: InterruptMessage = {
type: "interrupt",
thread_id: "t1",
action: "cancel_order",
params: {},
};
it("renders action name and approval title", () => {
render(<InterruptPrompt interrupt={baseInterrupt} onRespond={vi.fn()} />);
expect(screen.getByText("Action Requires Approval")).toBeInTheDocument();
expect(screen.getByText("cancel_order")).toBeInTheDocument();
});
it("calls onRespond with true when Approve button is clicked", () => {
const onRespond = vi.fn();
render(<InterruptPrompt interrupt={baseInterrupt} onRespond={onRespond} />);
fireEvent.click(screen.getByText("Approve Action"));
expect(onRespond).toHaveBeenCalledWith(true);
});
it("calls onRespond with false when Reject button is clicked", () => {
const onRespond = vi.fn();
render(<InterruptPrompt interrupt={baseInterrupt} onRespond={onRespond} />);
fireEvent.click(screen.getByText("Reject & Escalate"));
expect(onRespond).toHaveBeenCalledWith(false);
});
it("displays order_id parameter when present", () => {
const interrupt: InterruptMessage = {
...baseInterrupt,
params: { order_id: "ORD-12345" },
};
render(<InterruptPrompt interrupt={interrupt} onRespond={vi.fn()} />);
expect(screen.getByText("Target Order ID")).toBeInTheDocument();
expect(screen.getByText("ORD-12345")).toBeInTheDocument();
});
it("displays message parameter when present", () => {
const interrupt: InterruptMessage = {
...baseInterrupt,
params: { message: "This will refund $50" },
};
render(<InterruptPrompt interrupt={interrupt} onRespond={vi.fn()} />);
expect(screen.getByText("Detail Message")).toBeInTheDocument();
expect(screen.getByText("This will refund $50")).toBeInTheDocument();
});
});

View File

@@ -0,0 +1,39 @@
import { describe, it, expect, vi } from "vitest";
import { render, screen } from "@testing-library/react";
import { MemoryRouter, Routes, Route } from "react-router-dom";
import { Layout } from "./Layout";
// Mock NavBar to simplify layout tests
vi.mock("./NavBar", () => ({
NavBar: () => <nav data-testid="navbar">NavBar</nav>,
}));
function renderLayout(path = "/") {
return render(
<MemoryRouter initialEntries={[path]}>
<Routes>
<Route element={<Layout />}>
<Route path="/" element={<div>Home Content</div>} />
<Route path="/dashboard" element={<div>Dashboard Content</div>} />
</Route>
</Routes>
</MemoryRouter>
);
}
describe("Layout", () => {
it("renders NavBar component", () => {
renderLayout();
expect(screen.getByTestId("navbar")).toBeInTheDocument();
});
it("renders child route content via Outlet", () => {
renderLayout("/");
expect(screen.getByText("Home Content")).toBeInTheDocument();
});
it("renders correct content for different routes", () => {
renderLayout("/dashboard");
expect(screen.getByText("Dashboard Content")).toBeInTheDocument();
});
});

View File

@@ -0,0 +1,28 @@
import { describe, it, expect } from "vitest";
import { render, screen } from "@testing-library/react";
import { MetricCard } from "./MetricCard";
describe("MetricCard", () => {
it("renders label and value", () => {
render(<MetricCard label="Total Users" value={42} />);
expect(screen.getByText("Total Users")).toBeInTheDocument();
expect(screen.getByText("42")).toBeInTheDocument();
});
it("renders with unit prefix and suffix", () => {
render(<MetricCard label="Cost" value="3.50" unit="$" suffix="/mo" />);
expect(screen.getByText("Cost")).toBeInTheDocument();
expect(screen.getByText("$")).toBeInTheDocument();
expect(screen.getByText("3.50")).toBeInTheDocument();
expect(screen.getByText("/mo")).toBeInTheDocument();
});
it("handles zero value correctly", () => {
render(<MetricCard label="Errors" value={0} />);
expect(screen.getByText("Errors")).toBeInTheDocument();
expect(screen.getByText("0")).toBeInTheDocument();
});
});

View File

@@ -0,0 +1,54 @@
import { describe, it, expect } from "vitest";
import { render, screen } from "@testing-library/react";
import { MemoryRouter } from "react-router-dom";
import { NavBar } from "./NavBar";
function renderNavBar(initialPath = "/") {
return render(
<MemoryRouter initialEntries={[initialPath]}>
<NavBar />
</MemoryRouter>
);
}
describe("NavBar", () => {
it("renders all navigation links", () => {
renderNavBar();
expect(screen.getByText("Dashboard")).toBeInTheDocument();
expect(screen.getByText("Inbox")).toBeInTheDocument();
expect(screen.getByText("Conversation Replay")).toBeInTheDocument();
expect(screen.getByText("Agents & Tools")).toBeInTheDocument();
});
it("navigation links point to correct routes", () => {
renderNavBar();
const dashboardLink = screen.getByText("Dashboard").closest("a");
expect(dashboardLink).toHaveAttribute("href", "/dashboard");
const inboxLink = screen.getByText("Inbox").closest("a");
expect(inboxLink).toHaveAttribute("href", "/");
const replayLink = screen.getByText("Conversation Replay").closest("a");
expect(replayLink).toHaveAttribute("href", "/replay");
const reviewLink = screen.getByText("Agents & Tools").closest("a");
expect(reviewLink).toHaveAttribute("href", "/review");
});
it("active link has active class when on matching route", () => {
renderNavBar("/dashboard");
const dashboardLink = screen.getByText("Dashboard").closest("a");
expect(dashboardLink?.className).toContain("active");
const inboxLink = screen.getByText("Inbox").closest("a");
expect(inboxLink?.className).not.toContain("active");
});
it("renders brand name", () => {
renderNavBar();
expect(screen.getByText("Nexus AI")).toBeInTheDocument();
});
});

View File

@@ -0,0 +1,69 @@
import { describe, it, expect } from "vitest";
import { render, screen, fireEvent } from "@testing-library/react";
import { ReplayTimeline } from "./ReplayTimeline";
import type { ReplayStep } from "../api";
function makeStep(overrides: Partial<ReplayStep> = {}): ReplayStep {
return {
step: 1,
type: "message",
content: "Hello",
agent: null,
tool: null,
params: null,
result: null,
timestamp: "2026-04-01T12:00:00Z",
...overrides,
};
}
describe("ReplayTimeline", () => {
it("returns null when steps array is empty", () => {
const { container } = render(<ReplayTimeline steps={[]} />);
expect(container.innerHTML).toBe("");
});
it("renders a list of steps with type badges", () => {
const steps = [
makeStep({ step: 1, type: "message", content: "User said hi" }),
makeStep({ step: 2, type: "tool_call", content: "Calling API", agent: "OrderBot", tool: "get_order" }),
];
render(<ReplayTimeline steps={steps} />);
expect(screen.getByText("message")).toBeInTheDocument();
expect(screen.getByText("tool call")).toBeInTheDocument();
expect(screen.getByText("User said hi")).toBeInTheDocument();
expect(screen.getByText("OrderBot")).toBeInTheDocument();
expect(screen.getByText("get_order()")).toBeInTheDocument();
});
it("expands step details when View JSON Payload button is clicked", () => {
const steps = [
makeStep({
step: 1,
type: "tool_call",
params: { order_id: "123" },
result: { status: "ok" },
}),
];
render(<ReplayTimeline steps={steps} />);
const expandButton = screen.getByText("View JSON Payload", { exact: false });
expect(expandButton).toBeInTheDocument();
fireEvent.click(expandButton);
// After expanding, the JSON payload should be visible
expect(screen.getByText("Hide JSON Payload", { exact: false })).toBeInTheDocument();
expect(screen.getByText(/"order_id": "123"/)).toBeInTheDocument();
});
it("does not show expand button when step has no params or result", () => {
const steps = [
makeStep({ step: 1, type: "message", params: null, result: null }),
];
render(<ReplayTimeline steps={steps} />);
expect(screen.queryByText("View JSON Payload", { exact: false })).not.toBeInTheDocument();
});
});

View File

@@ -0,0 +1,221 @@
import { describe, it, expect, vi, beforeEach, afterEach } from "vitest";
import { renderHook, act } from "@testing-library/react";
import { useWebSocket } from "./useWebSocket";
// Mock sessionStorage
const mockSessionStorage: Record<string, string> = {};
vi.stubGlobal("sessionStorage", {
getItem: (key: string) => mockSessionStorage[key] ?? null,
setItem: (key: string, value: string) => {
mockSessionStorage[key] = value;
},
});
// Mock crypto.randomUUID
vi.stubGlobal("crypto", { randomUUID: () => "test-uuid-1234" });
// Mock WebSocket
class MockWebSocket {
static OPEN = 1;
static CLOSED = 3;
static instances: MockWebSocket[] = [];
url: string;
readyState = 0;
onopen: (() => void) | null = null;
onclose: (() => void) | null = null;
onmessage: ((event: { data: string }) => void) | null = null;
onerror: (() => void) | null = null;
send = vi.fn();
close = vi.fn().mockImplementation(() => {
this.readyState = MockWebSocket.CLOSED;
// Trigger onclose asynchronously like real WebSocket
setTimeout(() => this.onclose?.(), 0);
});
constructor(url: string) {
this.url = url;
MockWebSocket.instances.push(this);
}
simulateOpen() {
this.readyState = MockWebSocket.OPEN;
this.onopen?.();
}
simulateMessage(data: unknown) {
this.onmessage?.({ data: JSON.stringify(data) });
}
simulateClose() {
this.readyState = MockWebSocket.CLOSED;
this.onclose?.();
}
simulateError() {
this.onerror?.();
}
}
vi.stubGlobal("WebSocket", MockWebSocket);
beforeEach(() => {
MockWebSocket.instances = [];
delete mockSessionStorage["smart_support_thread_id"];
vi.useFakeTimers();
});
afterEach(() => {
vi.useRealTimers();
});
describe("useWebSocket", () => {
it("establishes connection with correct URL on mount", () => {
const onMessage = vi.fn();
renderHook(() => useWebSocket(onMessage));
expect(MockWebSocket.instances).toHaveLength(1);
expect(MockWebSocket.instances[0].url).toContain("/ws");
});
it("sets status to connected when WebSocket opens", () => {
const onMessage = vi.fn();
const { result } = renderHook(() => useWebSocket(onMessage));
expect(result.current.status).toBe("connecting");
act(() => {
MockWebSocket.instances[0].simulateOpen();
});
expect(result.current.status).toBe("connected");
});
it("parses incoming JSON messages and dispatches to callback", () => {
const onMessage = vi.fn();
renderHook(() => useWebSocket(onMessage));
act(() => {
MockWebSocket.instances[0].simulateOpen();
});
const serverMsg = { type: "token", agent: "bot", content: "Hello" };
act(() => {
MockWebSocket.instances[0].simulateMessage(serverMsg);
});
expect(onMessage).toHaveBeenCalledWith(serverMsg);
});
it("sends JSON through WebSocket via sendMessage", () => {
const onMessage = vi.fn();
const { result } = renderHook(() => useWebSocket(onMessage));
act(() => {
MockWebSocket.instances[0].simulateOpen();
});
act(() => {
result.current.sendMessage("Hi there");
});
expect(MockWebSocket.instances[0].send).toHaveBeenCalledTimes(1);
const sent = JSON.parse(MockWebSocket.instances[0].send.mock.calls[0][0]);
expect(sent.type).toBe("message");
expect(sent.content).toBe("Hi there");
expect(sent.thread_id).toBeDefined();
});
it("calls onDisconnect when WebSocket closes", () => {
const onMessage = vi.fn();
const onDisconnect = vi.fn();
renderHook(() => useWebSocket(onMessage, { onDisconnect }));
act(() => {
MockWebSocket.instances[0].simulateOpen();
});
act(() => {
MockWebSocket.instances[0].simulateClose();
});
expect(onDisconnect).toHaveBeenCalledTimes(1);
});
it("sets status to disconnected on close and attempts reconnect", () => {
const onMessage = vi.fn();
const { result } = renderHook(() => useWebSocket(onMessage));
act(() => {
MockWebSocket.instances[0].simulateOpen();
});
act(() => {
MockWebSocket.instances[0].simulateClose();
});
expect(result.current.status).toBe("disconnected");
// After timeout, a new WebSocket should be created (reconnect attempt)
act(() => {
vi.advanceTimersByTime(1500);
});
expect(MockWebSocket.instances.length).toBeGreaterThanOrEqual(2);
});
it("closes WebSocket on error event", () => {
const onMessage = vi.fn();
renderHook(() => useWebSocket(onMessage));
const ws = MockWebSocket.instances[0];
act(() => {
ws.simulateError();
});
expect(ws.close).toHaveBeenCalledTimes(1);
});
it("reconnect resets retries and creates a new connection", () => {
const onMessage = vi.fn();
const { result } = renderHook(() => useWebSocket(onMessage));
act(() => {
MockWebSocket.instances[0].simulateOpen();
});
const wsBeforeReconnect = MockWebSocket.instances[0];
act(() => {
result.current.reconnect();
});
// The old socket should have been closed
expect(wsBeforeReconnect.close).toHaveBeenCalled();
// Let the close callback fire and reconnect timer run
act(() => {
vi.advanceTimersByTime(100);
});
// A new WebSocket should have been created
expect(MockWebSocket.instances.length).toBeGreaterThan(1);
});
it("sends interrupt response with approved flag", () => {
const onMessage = vi.fn();
const { result } = renderHook(() => useWebSocket(onMessage));
act(() => {
MockWebSocket.instances[0].simulateOpen();
});
act(() => {
result.current.sendInterruptResponse(true);
});
const sent = JSON.parse(MockWebSocket.instances[0].send.mock.calls[0][0]);
expect(sent.type).toBe("interrupt_response");
expect(sent.approved).toBe(true);
});
});
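Each reconnect-style test above repeats the same two steps: close a socket, then advance the fake clock so the pending reconnect timer fires. A helper could fold them together; a sketch under the same `MockWebSocket` and fake-timer setup, where the 1500 ms default mirrors the delay the tests advance by (an assumption about the hook's actual backoff, not a verified constant):

```ts
// Hypothetical test helper -- assumes vi.useFakeTimers() is active, as in the
// suite above, and that the hook schedules its reconnect within ~1.5s of close.
function closeAndAwaitReconnect(ws: MockWebSocket, reconnectDelayMs = 1500): void {
  act(() => {
    ws.simulateClose();                       // fires onclose synchronously
    vi.advanceTimersByTime(reconnectDelayMs); // runs the pending reconnect timer
  });
}
```

With it, the reconnect test reduces to `closeAndAwaitReconnect(MockWebSocket.instances[0])` followed by the instance-count assertion.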

View File

@@ -0,0 +1,106 @@
import { describe, it, expect, vi, beforeEach } from "vitest";
import { render, screen, fireEvent, waitFor, act } from "@testing-library/react";
import { ChatPage } from "./ChatPage";
// Mock react-markdown
vi.mock("react-markdown", () => ({
default: ({ children }: { children: string }) => <span>{children}</span>,
}));
// Mock crypto.randomUUID so each generated message gets a unique ID in jsdom
vi.stubGlobal("crypto", { randomUUID: () => `uuid-${Date.now()}-${Math.random()}` });
// Capture the onMessage callback from the hook
let capturedOnMessage: ((msg: unknown) => void) | null = null;
const mockSendMessage = vi.fn();
const mockSendInterruptResponse = vi.fn();
const mockReconnect = vi.fn();
let mockStatus = "connected";
vi.mock("../hooks/useWebSocket", () => ({
useWebSocket: (onMessage: (msg: unknown) => void) => {
capturedOnMessage = onMessage;
return {
status: mockStatus,
threadId: "test-thread",
sendMessage: mockSendMessage,
sendInterruptResponse: mockSendInterruptResponse,
reconnect: mockReconnect,
};
},
}));
beforeEach(() => {
capturedOnMessage = null;
mockSendMessage.mockReset();
mockSendInterruptResponse.mockReset();
mockReconnect.mockReset();
mockStatus = "connected";
});
describe("ChatPage", () => {
it("renders chat interface with input field and header", () => {
render(<ChatPage />);
expect(screen.getByText("Inbox")).toBeInTheDocument();
expect(screen.getByPlaceholderText("Message Smart Support...")).toBeInTheDocument();
expect(screen.getByRole("button", { name: "Send Message" })).toBeInTheDocument();
});
it("user can type and submit a message", () => {
render(<ChatPage />);
const input = screen.getByPlaceholderText("Message Smart Support...");
fireEvent.change(input, { target: { value: "Hello bot" } });
fireEvent.keyDown(input, { key: "Enter" });
expect(mockSendMessage).toHaveBeenCalledWith("Hello bot");
expect(screen.getByText("Hello bot")).toBeInTheDocument();
expect(screen.getByText("You")).toBeInTheDocument();
});
it("displays streaming tokens as they arrive", () => {
render(<ChatPage />);
act(() => {
capturedOnMessage?.({ type: "token", agent: "Bot", content: "Hello " });
});
act(() => {
capturedOnMessage?.({ type: "token", agent: "Bot", content: "world" });
});
expect(screen.getByText("Hello world")).toBeInTheDocument();
});
it("shows interrupt prompt when interrupt message received", () => {
render(<ChatPage />);
act(() => {
capturedOnMessage?.({
type: "interrupt",
thread_id: "t1",
action: "cancel_order",
params: { order_id: "ORD-999" },
});
});
expect(screen.getByText("Action Requires Approval")).toBeInTheDocument();
expect(screen.getByText("cancel_order")).toBeInTheDocument();
});
it("shows error message when server sends error", () => {
render(<ChatPage />);
act(() => {
capturedOnMessage?.({ type: "error", message: "Something went wrong" });
});
expect(screen.getByText("Error: Something went wrong")).toBeInTheDocument();
});
it("renders welcome message in empty state", () => {
render(<ChatPage />);
expect(screen.getByText("Hello! How can I help you today?")).toBeInTheDocument();
});
});

View File

@@ -0,0 +1,200 @@
import { describe, it, expect, vi, beforeEach } from "vitest";
import { render, screen, fireEvent, waitFor } from "@testing-library/react";
import { ReviewPage } from "./ReviewPage";
vi.mock("../api", () => ({
startImport: vi.fn(),
fetchImportJob: vi.fn(),
fetchClassifications: vi.fn(),
approveJob: vi.fn(),
}));
import { startImport, fetchImportJob, fetchClassifications, approveJob } from "../api";
const mockStartImport = vi.mocked(startImport);
const mockFetchImportJob = vi.mocked(fetchImportJob);
const mockFetchClassifications = vi.mocked(fetchClassifications);
const mockApproveJob = vi.mocked(approveJob);
beforeEach(() => {
mockStartImport.mockReset();
mockFetchImportJob.mockReset();
mockFetchClassifications.mockReset();
mockApproveJob.mockReset();
});
describe("ReviewPage", () => {
it("renders the OpenAPI URL input form", () => {
render(<ReviewPage />);
expect(screen.getByText("Agents & Tools Registry")).toBeInTheDocument();
expect(screen.getByPlaceholderText("https://example.com/openapi.yaml")).toBeInTheDocument();
expect(screen.getByText("Scan Tools")).toBeInTheDocument();
});
it("submit form triggers API call with entered URL", async () => {
mockStartImport.mockResolvedValue({
job_id: "job-1",
status: "processing",
spec_url: "https://example.com/openapi.yaml",
total_endpoints: 5,
classified_count: 0,
error_message: null,
});
mockFetchImportJob.mockResolvedValue({
job_id: "job-1",
status: "processing",
spec_url: "https://example.com/openapi.yaml",
total_endpoints: 5,
classified_count: 0,
error_message: null,
});
render(<ReviewPage />);
const input = screen.getByPlaceholderText("https://example.com/openapi.yaml");
fireEvent.change(input, { target: { value: "https://example.com/openapi.yaml" } });
fireEvent.click(screen.getByText("Scan Tools"));
await waitFor(() => {
expect(mockStartImport).toHaveBeenCalledWith("https://example.com/openapi.yaml");
});
});
it("shows loading state during import", async () => {
mockStartImport.mockReturnValue(new Promise(() => {})); // never resolves
render(<ReviewPage />);
const input = screen.getByPlaceholderText("https://example.com/openapi.yaml");
fireEvent.change(input, { target: { value: "https://api.test.com/spec.json" } });
fireEvent.click(screen.getByText("Scan Tools"));
await waitFor(() => {
expect(screen.getByText("Importing...")).toBeInTheDocument();
});
});
it("displays classification results after job completes", async () => {
mockStartImport.mockResolvedValue({
job_id: "job-1",
status: "done",
spec_url: "https://example.com/openapi.yaml",
total_endpoints: 2,
classified_count: 2,
error_message: null,
});
mockFetchImportJob.mockResolvedValue({
job_id: "job-1",
status: "done",
spec_url: "https://example.com/openapi.yaml",
total_endpoints: 2,
classified_count: 2,
error_message: null,
});
mockFetchClassifications.mockResolvedValue([
{
index: 0,
access_type: "read",
needs_interrupt: false,
agent_group: "OrderAgent",
confidence: 0.95,
customer_params: [],
endpoint: { path: "/orders", method: "get", operation_id: "getOrders", summary: "List orders", description: "" },
},
{
index: 1,
access_type: "write",
needs_interrupt: true,
agent_group: "OrderAgent",
confidence: 0.9,
customer_params: ["order_id"],
endpoint: { path: "/orders/{id}/cancel", method: "post", operation_id: "cancelOrder", summary: "Cancel an order", description: "" },
},
]);
render(<ReviewPage />);
const input = screen.getByPlaceholderText("https://example.com/openapi.yaml");
fireEvent.change(input, { target: { value: "https://example.com/openapi.yaml" } });
fireEvent.click(screen.getByText("Scan Tools"));
await waitFor(() => {
expect(screen.getByText("Assigned Capabilities (2)")).toBeInTheDocument();
});
expect(screen.getByText("OrderAgent")).toBeInTheDocument();
expect(screen.getByText("/orders")).toBeInTheDocument();
expect(screen.getByText("List orders")).toBeInTheDocument();
});
it("shows error on API failure", async () => {
mockStartImport.mockRejectedValue(new Error("Network timeout"));
render(<ReviewPage />);
const input = screen.getByPlaceholderText("https://example.com/openapi.yaml");
fireEvent.change(input, { target: { value: "https://example.com/openapi.yaml" } });
fireEvent.click(screen.getByText("Scan Tools"));
await waitFor(() => {
expect(screen.getByText("Error: Network timeout")).toBeInTheDocument();
});
});
it("shows success message after approval", async () => {
// Set up initial state with classifications
mockStartImport.mockResolvedValue({
job_id: "job-1",
status: "done",
spec_url: "https://example.com/openapi.yaml",
total_endpoints: 1,
classified_count: 1,
error_message: null,
});
mockFetchImportJob.mockResolvedValue({
job_id: "job-1",
status: "done",
spec_url: "https://example.com/openapi.yaml",
total_endpoints: 1,
classified_count: 1,
error_message: null,
});
mockFetchClassifications.mockResolvedValue([
{
index: 0,
access_type: "read",
needs_interrupt: false,
agent_group: "TestAgent",
confidence: 0.9,
customer_params: [],
endpoint: { path: "/test", method: "get", operation_id: "test", summary: "Test endpoint", description: "" },
},
]);
mockApproveJob.mockResolvedValue({
job_id: "job-1",
status: "approved",
spec_url: "https://example.com/openapi.yaml",
total_endpoints: 1,
classified_count: 1,
error_message: null,
generated_tools_count: 3,
});
render(<ReviewPage />);
// Import first
const input = screen.getByPlaceholderText("https://example.com/openapi.yaml");
fireEvent.change(input, { target: { value: "https://example.com/openapi.yaml" } });
fireEvent.click(screen.getByText("Scan Tools"));
await waitFor(() => {
expect(screen.getByText("Save Configuration")).toBeInTheDocument();
});
// Approve
fireEvent.click(screen.getByText("Save Configuration"));
await waitFor(() => {
expect(screen.getByText("Configuration saved. 3 tools generated.")).toBeInTheDocument();
});
});
});
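The suite above repeats the same `ImportJobResponse` literal across four tests. A fixture factory in the style of `makeAction`/`makeStep` from the component suites would keep each test to its deltas; a minimal sketch, assuming the `ImportJobResponse` type is exported from `../api`:

```ts
import type { ImportJobResponse } from "../api"; // assumed export

// Hypothetical fixture factory, mirroring makeAction/makeStep above. Field
// values match the defaults this file already uses; override per test.
function makeJob(overrides: Partial<ImportJobResponse> = {}): ImportJobResponse {
  return {
    job_id: "job-1",
    status: "done",
    spec_url: "https://example.com/openapi.yaml",
    total_endpoints: 1,
    classified_count: 1,
    error_message: null,
    ...overrides,
  };
}

// Usage: mockStartImport.mockResolvedValue(makeJob({ status: "processing" }));
```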