From f0699436c5b106283f68964b24747b0548f8deea Mon Sep 17 00:00:00 2001 From: Yaojia Wang Date: Mon, 6 Apr 2026 23:19:29 +0200 Subject: [PATCH] refactor: engineering improvements -- API versioning, structured logging, Alembic, error standardization, test coverage - API versioning: all REST endpoints prefixed with /api/v1/ - Structured logging: replaced stdlib logging with structlog (console/JSON modes) - Alembic migrations: versioned DB schema with initial migration - Error standardization: global exception handlers for consistent envelope format - Interrupt cleanup: asyncio background task for expired interrupt removal - Integration tests: +30 tests (analytics, replay, openapi, error, session APIs) - Frontend tests: +57 tests (all components, pages, useWebSocket hook) - Backend: 557 tests, 89.75% coverage | Frontend: 80 tests, 16 test files --- .env.example | 4 + CLAUDE.md | 57 +++-- README.md | 33 +-- backend/alembic.ini | 149 ++++++++++++ backend/alembic/README | 1 + backend/alembic/env.py | 67 ++++++ backend/alembic/script.py.mako | 28 +++ .../alembic/versions/001_initial_schema.py | 92 ++++++++ backend/app/analytics/api.py | 2 +- backend/app/auth.py | 4 +- backend/app/config.py | 2 + backend/app/db.py | 12 + backend/app/escalation.py | 4 +- backend/app/graph.py | 5 +- backend/app/intent.py | 5 +- backend/app/logging_config.py | 57 +++++ backend/app/main.py | 78 ++++++- backend/app/openapi/classifier.py | 5 +- backend/app/openapi/importer.py | 5 +- backend/app/openapi/review_api.py | 6 +- backend/app/replay/api.py | 2 +- backend/app/replay/transformer.py | 4 +- backend/app/ws_handler.py | 5 +- backend/pyproject.toml | 2 + backend/tests/e2e/conftest.py | 2 +- backend/tests/e2e/test_chat_flows.py | 2 +- backend/tests/e2e/test_openapi_import.py | 18 +- backend/tests/e2e/test_replay_analytics.py | 18 +- .../tests/integration/test_analytics_api.py | 183 +++++++++++++++ .../tests/integration/test_error_responses.py | 128 ++++++++++ 
backend/tests/integration/test_openapi_api.py | 164 +++++++++++++ backend/tests/integration/test_replay_api.py | 213 +++++++++++++++++ .../test_session_interrupt_lifecycle.py | 159 +++++++++++++ backend/tests/unit/analytics/test_api.py | 10 +- backend/tests/unit/openapi/test_review_api.py | 52 ++--- backend/tests/unit/replay/test_api.py | 40 ++-- backend/tests/unit/test_error_responses.py | 142 +++++++++++ backend/tests/unit/test_interrupt_cleanup.py | 86 +++++++ backend/tests/unit/test_logging_config.py | 20 ++ backend/tests/unit/test_main.py | 2 +- docker-compose.yml | 2 +- docs/ARCHITECTURE.md | 26 ++- docs/deployment.md | 18 +- docs/openapi-import-guide.md | 16 +- docs/phases/eng-improvements-dev-log.md | 76 ++++++ frontend/src/api.test.ts | 8 +- frontend/src/api.ts | 14 +- frontend/src/components/AgentAction.test.tsx | 47 ++++ frontend/src/components/ChatInput.test.tsx | 53 +++++ frontend/src/components/ChatMessages.test.tsx | 59 +++++ frontend/src/components/ErrorBanner.test.tsx | 33 +++ .../src/components/InterruptPrompt.test.tsx | 58 +++++ frontend/src/components/Layout.test.tsx | 39 ++++ frontend/src/components/MetricCard.test.tsx | 28 +++ frontend/src/components/NavBar.test.tsx | 54 +++++ .../src/components/ReplayTimeline.test.tsx | 69 ++++++ frontend/src/hooks/useWebSocket.test.ts | 221 ++++++++++++++++++ frontend/src/pages/ChatPage.test.tsx | 106 +++++++++ frontend/src/pages/ReviewPage.test.tsx | 200 ++++++++++++++++ 59 files changed, 2846 insertions(+), 149 deletions(-) create mode 100644 backend/alembic.ini create mode 100644 backend/alembic/README create mode 100644 backend/alembic/env.py create mode 100644 backend/alembic/script.py.mako create mode 100644 backend/alembic/versions/001_initial_schema.py create mode 100644 backend/app/logging_config.py create mode 100644 backend/tests/integration/test_analytics_api.py create mode 100644 backend/tests/integration/test_error_responses.py create mode 100644 backend/tests/integration/test_openapi_api.py 
create mode 100644 backend/tests/integration/test_replay_api.py create mode 100644 backend/tests/integration/test_session_interrupt_lifecycle.py create mode 100644 backend/tests/unit/test_error_responses.py create mode 100644 backend/tests/unit/test_interrupt_cleanup.py create mode 100644 backend/tests/unit/test_logging_config.py create mode 100644 docs/phases/eng-improvements-dev-log.md create mode 100644 frontend/src/components/AgentAction.test.tsx create mode 100644 frontend/src/components/ChatInput.test.tsx create mode 100644 frontend/src/components/ChatMessages.test.tsx create mode 100644 frontend/src/components/ErrorBanner.test.tsx create mode 100644 frontend/src/components/InterruptPrompt.test.tsx create mode 100644 frontend/src/components/Layout.test.tsx create mode 100644 frontend/src/components/MetricCard.test.tsx create mode 100644 frontend/src/components/NavBar.test.tsx create mode 100644 frontend/src/components/ReplayTimeline.test.tsx create mode 100644 frontend/src/hooks/useWebSocket.test.ts create mode 100644 frontend/src/pages/ChatPage.test.tsx create mode 100644 frontend/src/pages/ReviewPage.test.tsx diff --git a/.env.example b/.env.example index 9acf1e9..6dabb1d 100644 --- a/.env.example +++ b/.env.example @@ -26,6 +26,10 @@ WEBHOOK_URL= SESSION_TTL_MINUTES=30 INTERRUPT_TTL_MINUTES=30 +# Optional: API key for admin endpoints (analytics, replay, openapi, websocket) +# Leave empty to disable authentication (dev mode) +ADMIN_API_KEY= + # Optional: load a named agent template instead of agents.yaml # Available templates: ecommerce, saas, generic TEMPLATE_NAME= diff --git a/CLAUDE.md b/CLAUDE.md index f4fd077..4beacd7 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -30,7 +30,7 @@ pytest --cov=app --cov-report=term-missing # - If any test fails, fix it before starting the new phase # 3. Create checkpoint to snapshot the starting state -/everything-claude-code:checkpoint create "phase-name" +/ecc:checkpoint create "phase-name" # 4. 
Create the phase branch git checkout main @@ -50,25 +50,32 @@ git checkout -b phase-{N}/{short-description} 3. Identify all tasks, acceptance criteria, and dependencies for this phase 4. Create a phase dev log **skeleton** at `docs/phases/phase-{N}-dev-log.md` (date, branch name, plan link only -- content filled in Step 5) -### Step 2: Develop Using Orchestrate Skill +### Step 2: Develop Using ECC Skills -Route to the correct orchestration mode based on work type: +Route to the correct skill based on work type: -| Work Type | Skill Command | -|-----------|---------------| -| New feature | `/everything-claude-code:orchestrate feature` | -| Bug fix | `/everything-claude-code:orchestrate bugfix` | -| Refactor | `/everything-claude-code:orchestrate refactor` | +| Work Type | Skill Command | What It Does | +|-----------|---------------|--------------| +| New feature | `/ecc:feature-dev ` | Discovery -> Exploration -> Architecture -> TDD -> Review -> Summary | +| Bug fix | `/ecc:tdd` then `/ecc:code-review` | RED -> GREEN -> REFACTOR cycle, then review | +| Refactor | `/ecc:plan` then `/ecc:tdd` then `/ecc:code-review` | Plan refactor scope, TDD, review | +| Security-sensitive | Add `/ecc:security-review` after code-review | Auth, payments, user input, external APIs | +| Final verification | `/ecc:verify` | Build + tests + lint + coverage + security scan | -ALWAYS use the appropriate orchestrate skill. Never develop without it. - -A single phase may contain mixed work types (e.g., Phase 5 has feature + bugfix + refactor). Call the orchestrate skill **per sub-task** with the matching mode. Example: +A single phase may contain mixed work types. 
Call the appropriate skill **per sub-task**: ``` -# Within Phase 5: -/everything-claude-code:orchestrate feature # for demo script -/everything-claude-code:orchestrate bugfix # for error handling fixes -/everything-claude-code:orchestrate refactor # for code cleanup +# Within a phase: +/ecc:feature-dev "demo script" # for new features +/ecc:tdd # for bug fixes (write failing test, then fix) +/ecc:plan "consolidate error handling" # for refactors (plan first, then TDD) +``` + +For full multi-phase autonomous execution, use GSD: + +``` +/gsd:autonomous # execute all remaining phases +/gsd:execute-phase 6 # execute a specific phase ``` ### Step 3: Module Independence (CRITICAL) @@ -171,10 +178,10 @@ After all development and testing, run verification in this exact order: ``` # 1. Run the verification skill -- must pass -/everything-claude-code:verify +/ecc:verify # 2. Verify the checkpoint -- validates all phase deliverables -/everything-claude-code:checkpoint verify "phase-name" +/ecc:checkpoint verify "phase-name" ``` The checkpoint verify validates: @@ -222,11 +229,11 @@ git push origin main --tags All four markers must be consistent. If any is missed, the next phase's Step 0 regression gate will catch the discrepancy. A checkpoint includes: -- `/everything-claude-code:checkpoint create` at phase start -- `/everything-claude-code:checkpoint verify` at phase end +- `/ecc:checkpoint create` at phase start +- `/ecc:checkpoint verify` at phase end - All tests passing (80%+ coverage) - Phase dev log written and linked -- `/everything-claude-code:verify` passed +- `/ecc:verify` passed - Git tag `checkpoint/phase-{N}` created - Phase marked COMPLETED in four locations - Branch merged to main @@ -264,7 +271,7 @@ This project inherits from `~/.claude/rules/`. CLAUDE.md only contains project-s ### Hooks (ECC Plugin -- No Custom Hooks) -All hooks come from the ECC plugin (`everything-claude-code`). No project-level hooks in `.claude/settings.local.json`. 
+All hooks come from the ECC plugin (`ecc`). No project-level hooks in `.claude/settings.local.json`. | ECC Hook | Type | What It Does | |----------|------|-------------| @@ -290,7 +297,7 @@ Controlled by `ECC_HOOK_PROFILE` env var in `~/.claude/settings.json` (currently - Architecture doc: `docs/ARCHITECTURE.md` - Phase dev logs: `docs/phases/phase-{N}-dev-log.md` - Test command: `pytest --cov=app --cov-report=term-missing` -- **Phase start:** `/everything-claude-code:checkpoint create "phase-name"` -- **Phase end:** `/everything-claude-code:checkpoint verify "phase-name"` -- Verify command: `/everything-claude-code:verify` -- Orchestrate: `/everything-claude-code:orchestrate {feature|bugfix|refactor}` +- **Phase start:** `/ecc:checkpoint create "phase-name"` +- **Phase end:** `/ecc:checkpoint verify "phase-name"` +- Verify command: `/ecc:verify` +- Orchestrate: `/ecc:orchestrate {feature|bugfix|refactor}` diff --git a/README.md b/README.md index c2123eb..eb5182e 100644 --- a/README.md +++ b/README.md @@ -99,8 +99,12 @@ smart-support/ ├── backend/ │ ├── app/ │ │ ├── main.py # FastAPI + WebSocket entry point -│ │ ├── graph.py # LangGraph Supervisor +│ │ ├── graph.py # LangGraph Supervisor construction +│ │ ├── graph_context.py # Typed wrapper for graph + classifier + registry │ │ ├── ws_handler.py # WebSocket message dispatch + rate limiting +│ │ ├── ws_context.py # WebSocket dependency bundle +│ │ ├── auth.py # API key authentication middleware +│ │ ├── api_utils.py # Shared API response helpers │ │ ├── safety.py # Confirmation rules + MCP error taxonomy │ │ ├── agents/ # Agent definitions and tools │ │ ├── registry.py # YAML agent registry loader @@ -124,18 +128,21 @@ smart-support/ ## API Endpoints -| Method | Path | Description | -|--------|------|-------------| -| WS | `/ws` | Main WebSocket chat endpoint | -| GET | `/api/health` | Health check | -| GET | `/api/conversations` | List conversations (paginated) | -| GET | `/api/replay/{thread_id}` | Replay 
conversation steps (paginated) | -| GET | `/api/analytics` | Analytics summary (`?range=7d`) | -| POST | `/api/openapi/import` | Start OpenAPI import job | -| GET | `/api/openapi/jobs/{id}` | Check import job status | -| GET | `/api/openapi/jobs/{id}/classifications` | Get endpoint classifications | -| PUT | `/api/openapi/jobs/{id}/classifications/{idx}` | Update a classification | -| POST | `/api/openapi/jobs/{id}/approve` | Approve and generate tools | +| Method | Path | Auth | Description | +|--------|------|------|-------------| +| WS | `/ws` | Token | Main WebSocket chat endpoint (`?token=`) | +| GET | `/api/health` | No | Health check | +| GET | `/api/conversations` | API Key | List conversations (paginated) | +| GET | `/api/replay/{thread_id}` | API Key | Replay conversation steps (paginated) | +| GET | `/api/analytics` | API Key | Analytics summary (`?range=7d`) | +| POST | `/api/openapi/import` | API Key | Start OpenAPI import job | +| GET | `/api/openapi/jobs/{id}` | API Key | Check import job status | +| GET | `/api/openapi/jobs/{id}/classifications` | API Key | Get endpoint classifications | +| PUT | `/api/openapi/jobs/{id}/classifications/{idx}` | API Key | Update a classification | +| POST | `/api/openapi/jobs/{id}/approve` | API Key | Approve and generate tools | + +Authentication is controlled by the `ADMIN_API_KEY` environment variable. +API Key endpoints require the `X-API-Key` header. When `ADMIN_API_KEY` is unset, auth is disabled. ## Running Tests diff --git a/backend/alembic.ini b/backend/alembic.ini new file mode 100644 index 0000000..f0da7e5 --- /dev/null +++ b/backend/alembic.ini @@ -0,0 +1,149 @@ +# A generic, single database configuration. + +[alembic] +# path to migration scripts. +# this is typically a path given in POSIX (e.g. 
forward slashes) +# format, relative to the token %(here)s which refers to the location of this +# ini file +script_location = %(here)s/alembic + +# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s +# Uncomment the line below if you want the files to be prepended with date and time +# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file +# for all available tokens +# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s +# Or organize into date-based subdirectories (requires recursive_version_locations = true) +# file_template = %%(year)d/%%(month).2d/%%(day).2d_%%(hour).2d%%(minute).2d_%%(second).2d_%%(rev)s_%%(slug)s + +# sys.path path, will be prepended to sys.path if present. +# defaults to the current working directory. for multiple paths, the path separator +# is defined by "path_separator" below. +prepend_sys_path = . + + +# timezone to use when rendering the date within the migration file +# as well as the filename. +# If specified, requires the tzdata library which can be installed by adding +# `alembic[tz]` to the pip requirements. +# string value is passed to ZoneInfo() +# leave blank for localtime +# timezone = + +# max length of characters to apply to the "slug" field +# truncate_slug_length = 40 + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + +# set to 'true' to allow .pyc and .pyo files without +# a source .py file to be detected as revisions in the +# versions/ directory +# sourceless = false + +# version location specification; This defaults +# to /versions. When using multiple version +# directories, initial revisions must be specified with --version-path. +# The path separator used here should be the separator specified by "path_separator" +# below. 
+# version_locations = %(here)s/bar:%(here)s/bat:%(here)s/alembic/versions + +# path_separator; This indicates what character is used to split lists of file +# paths, including version_locations and prepend_sys_path within configparser +# files such as alembic.ini. +# The default rendered in new alembic.ini files is "os", which uses os.pathsep +# to provide os-dependent path splitting. +# +# Note that in order to support legacy alembic.ini files, this default does NOT +# take place if path_separator is not present in alembic.ini. If this +# option is omitted entirely, fallback logic is as follows: +# +# 1. Parsing of the version_locations option falls back to using the legacy +# "version_path_separator" key, which if absent then falls back to the legacy +# behavior of splitting on spaces and/or commas. +# 2. Parsing of the prepend_sys_path option falls back to the legacy +# behavior of splitting on spaces, commas, or colons. +# +# Valid values for path_separator are: +# +# path_separator = : +# path_separator = ; +# path_separator = space +# path_separator = newline +# +# Use os.pathsep. Default configuration used for new projects. +path_separator = os + +# set to 'true' to search source files recursively +# in each "version_locations" directory +# new in Alembic version 1.10 +# recursive_version_locations = false + +# the output encoding used when revision files +# are written from script.py.mako +# output_encoding = utf-8 + +# database URL. This is consumed by the user-maintained env.py script only. +# other means of configuring database URLs may be customized within the env.py +# file. +sqlalchemy.url = + + +[post_write_hooks] +# post_write_hooks defines scripts or Python functions that are run +# on newly generated revision scripts. 
See the documentation for further +# detail and examples + +# format using "black" - use the console_scripts runner, against the "black" entrypoint +# hooks = black +# black.type = console_scripts +# black.entrypoint = black +# black.options = -l 79 REVISION_SCRIPT_FILENAME + +# lint with attempts to fix using "ruff" - use the module runner, against the "ruff" module +# hooks = ruff +# ruff.type = module +# ruff.module = ruff +# ruff.options = check --fix REVISION_SCRIPT_FILENAME + +# Alternatively, use the exec runner to execute a binary found on your PATH +# hooks = ruff +# ruff.type = exec +# ruff.executable = ruff +# ruff.options = check --fix REVISION_SCRIPT_FILENAME + +# Logging configuration. This is also consumed by the user-maintained +# env.py script only. +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARNING +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARNING +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/backend/alembic/README b/backend/alembic/README new file mode 100644 index 0000000..98e4f9c --- /dev/null +++ b/backend/alembic/README @@ -0,0 +1 @@ +Generic single-database configuration. 
\ No newline at end of file diff --git a/backend/alembic/env.py b/backend/alembic/env.py new file mode 100644 index 0000000..5fa1c72 --- /dev/null +++ b/backend/alembic/env.py @@ -0,0 +1,67 @@ +"""Alembic environment configuration for smart-support.""" + +from __future__ import annotations + +import os +from logging.config import fileConfig + +from sqlalchemy import engine_from_config, pool + +from alembic import context + +config = context.config + +if config.config_file_name is not None: + fileConfig(config.config_file_name) + +# No SQLAlchemy ORM models -- we use raw DDL migrations +target_metadata = None + + +def _get_url() -> str: + """Read DATABASE_URL from environment, falling back to alembic.ini.""" + return os.environ.get("DATABASE_URL", "") or config.get_main_option( + "sqlalchemy.url", "" + ) + + +def run_migrations_offline() -> None: + """Run migrations in 'offline' mode. + + Configures the context with just a URL so that an Engine + is not required. + """ + url = _get_url() + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + ) + + with context.begin_transaction(): + context.run_migrations() + + +def run_migrations_online() -> None: + """Run migrations in 'online' mode with a live database connection.""" + configuration = config.get_section(config.config_ini_section, {}) + configuration["sqlalchemy.url"] = _get_url() + + connectable = engine_from_config( + configuration, + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + + with connectable.connect() as connection: + context.configure(connection=connection, target_metadata=target_metadata) + + with context.begin_transaction(): + context.run_migrations() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/backend/alembic/script.py.mako b/backend/alembic/script.py.mako new file mode 100644 index 0000000..1101630 --- /dev/null +++ b/backend/alembic/script.py.mako @@ 
-0,0 +1,28 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision: str = ${repr(up_revision)} +down_revision: Union[str, Sequence[str], None] = ${repr(down_revision)} +branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} +depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} + + +def upgrade() -> None: + """Upgrade schema.""" + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + """Downgrade schema.""" + ${downgrades if downgrades else "pass"} diff --git a/backend/alembic/versions/001_initial_schema.py b/backend/alembic/versions/001_initial_schema.py new file mode 100644 index 0000000..ec22402 --- /dev/null +++ b/backend/alembic/versions/001_initial_schema.py @@ -0,0 +1,92 @@ +"""Initial schema -- all application tables. + +Revision ID: a1b2c3d4e5f6 +Revises: +Create Date: 2026-04-06 +""" + +from __future__ import annotations + +from alembic import op + +revision: str = "a1b2c3d4e5f6" +down_revision: str | None = None +branch_labels: tuple[str, ...] | None = None +depends_on: tuple[str, ...] 
| None = None + + +def upgrade() -> None: + op.execute( + """ + CREATE TABLE IF NOT EXISTS conversations ( + thread_id TEXT PRIMARY KEY, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + last_activity TIMESTAMPTZ NOT NULL DEFAULT NOW(), + total_tokens INTEGER NOT NULL DEFAULT 0, + total_cost_usd DOUBLE PRECISION NOT NULL DEFAULT 0.0, + status TEXT NOT NULL DEFAULT 'active' + ) + """ + ) + + op.execute( + """ + CREATE TABLE IF NOT EXISTS active_interrupts ( + interrupt_id TEXT PRIMARY KEY, + thread_id TEXT NOT NULL REFERENCES conversations(thread_id), + action TEXT NOT NULL, + params JSONB NOT NULL DEFAULT '{}', + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + resolved_at TIMESTAMPTZ, + resolution TEXT + ) + """ + ) + + op.execute( + """ + CREATE TABLE IF NOT EXISTS sessions ( + thread_id TEXT PRIMARY KEY, + last_activity TIMESTAMPTZ NOT NULL DEFAULT NOW(), + has_pending_interrupt BOOLEAN NOT NULL DEFAULT FALSE, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() + ) + """ + ) + + op.execute( + """ + CREATE TABLE IF NOT EXISTS analytics_events ( + id BIGSERIAL PRIMARY KEY, + thread_id TEXT NOT NULL, + event_type TEXT NOT NULL, + agent_name TEXT, + tool_name TEXT, + tokens_used INTEGER NOT NULL DEFAULT 0, + cost_usd DOUBLE PRECISION NOT NULL DEFAULT 0.0, + duration_ms INTEGER, + success BOOLEAN, + error_message TEXT, + metadata JSONB NOT NULL DEFAULT '{}', + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() + ) + """ + ) + + # Migration columns added in Phase 4 + op.execute( + """ + ALTER TABLE conversations + ADD COLUMN IF NOT EXISTS resolution_type TEXT, + ADD COLUMN IF NOT EXISTS agents_used TEXT[], + ADD COLUMN IF NOT EXISTS turn_count INTEGER NOT NULL DEFAULT 0, + ADD COLUMN IF NOT EXISTS ended_at TIMESTAMPTZ + """ + ) + + +def downgrade() -> None: + op.execute("DROP TABLE IF EXISTS analytics_events") + op.execute("DROP TABLE IF EXISTS sessions") + op.execute("DROP TABLE IF EXISTS active_interrupts") + op.execute("DROP TABLE IF EXISTS conversations") diff --git 
a/backend/app/analytics/api.py b/backend/app/analytics/api.py index 7f7d890..50037de 100644 --- a/backend/app/analytics/api.py +++ b/backend/app/analytics/api.py @@ -16,7 +16,7 @@ if TYPE_CHECKING: from psycopg_pool import AsyncConnectionPool router = APIRouter( - prefix="/api/analytics", + prefix="/api/v1/analytics", tags=["analytics"], dependencies=[Depends(require_admin_api_key)], ) diff --git a/backend/app/auth.py b/backend/app/auth.py index 3a186d4..d2f55be 100644 --- a/backend/app/auth.py +++ b/backend/app/auth.py @@ -2,14 +2,14 @@ from __future__ import annotations -import logging import secrets from typing import Annotated +import structlog from fastapi import Depends, HTTPException, Query, Request, WebSocket, status from fastapi.security import APIKeyHeader -logger = logging.getLogger(__name__) +logger = structlog.get_logger() _API_KEY_HEADER = APIKeyHeader(name="X-API-Key", auto_error=False) diff --git a/backend/app/config.py b/backend/app/config.py index 6857506..0271a09 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -32,6 +32,8 @@ class Settings(BaseSettings): template_name: str = "" + log_format: str = "console" # "console" for dev, "json" for production + admin_api_key: str = "" anthropic_api_key: str = "" diff --git a/backend/app/db.py b/backend/app/db.py index 01aa95e..244ada0 100644 --- a/backend/app/db.py +++ b/backend/app/db.py @@ -2,6 +2,7 @@ from __future__ import annotations +from pathlib import Path from typing import TYPE_CHECKING from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver @@ -88,6 +89,17 @@ async def create_checkpointer(pool: AsyncConnectionPool) -> AsyncPostgresSaver: return checkpointer +def run_alembic_migrations(database_url: str) -> None: + """Run Alembic migrations to head.""" + from alembic.config import Config + + from alembic import command + + alembic_cfg = Config(str(Path(__file__).parent.parent / "alembic.ini")) + alembic_cfg.set_main_option("sqlalchemy.url", database_url) + 
command.upgrade(alembic_cfg, "head") + + async def setup_app_tables(pool: AsyncConnectionPool) -> None: """Create application-specific tables and apply migrations.""" async with pool.connection() as conn: diff --git a/backend/app/escalation.py b/backend/app/escalation.py index 0fdb858..07868d2 100644 --- a/backend/app/escalation.py +++ b/backend/app/escalation.py @@ -3,14 +3,14 @@ from __future__ import annotations import asyncio -import logging from dataclasses import dataclass from typing import Protocol import httpx +import structlog from pydantic import BaseModel -logger = logging.getLogger(__name__) +logger = structlog.get_logger() class EscalationPayload(BaseModel, frozen=True): diff --git a/backend/app/graph.py b/backend/app/graph.py index ee11b43..704b399 100644 --- a/backend/app/graph.py +++ b/backend/app/graph.py @@ -2,7 +2,6 @@ from __future__ import annotations -import logging from typing import TYPE_CHECKING from langchain.agents import create_agent @@ -18,7 +17,9 @@ if TYPE_CHECKING: from app.intent import IntentClassifier from app.registry import AgentRegistry -logger = logging.getLogger(__name__) +import structlog + +logger = structlog.get_logger() SUPERVISOR_PROMPT = ( "You are a customer support supervisor. 
" diff --git a/backend/app/intent.py b/backend/app/intent.py index f39bc97..2dce799 100644 --- a/backend/app/intent.py +++ b/backend/app/intent.py @@ -2,7 +2,6 @@ from __future__ import annotations -import logging from typing import TYPE_CHECKING, Protocol from pydantic import BaseModel @@ -12,7 +11,9 @@ if TYPE_CHECKING: from app.registry import AgentConfig -logger = logging.getLogger(__name__) +import structlog + +logger = structlog.get_logger() CLASSIFICATION_PROMPT = ( "You are an intent classifier for a customer support system.\n" diff --git a/backend/app/logging_config.py b/backend/app/logging_config.py new file mode 100644 index 0000000..ea08949 --- /dev/null +++ b/backend/app/logging_config.py @@ -0,0 +1,57 @@ +"""Structured logging configuration using structlog.""" + +from __future__ import annotations + +import logging +import sys + +import structlog + + +def configure_logging(log_format: str = "console") -> None: + """Configure structlog with stdlib integration. + + Args: + log_format: "console" for human-readable dev output, + "json" for machine-parseable production output. 
+ """ + shared_processors: list[structlog.types.Processor] = [ + structlog.contextvars.merge_contextvars, + structlog.stdlib.filter_by_level, + structlog.stdlib.add_logger_name, + structlog.stdlib.add_log_level, + structlog.processors.TimeStamper(fmt="iso"), + structlog.processors.StackInfoRenderer(), + structlog.processors.format_exc_info, + structlog.processors.UnicodeDecoder(), + ] + + if log_format == "json": + renderer: structlog.types.Processor = structlog.processors.JSONRenderer() + else: + renderer = structlog.dev.ConsoleRenderer() + + structlog.configure( + processors=[ + *shared_processors, + structlog.stdlib.ProcessorFormatter.wrap_for_formatter, + ], + logger_factory=structlog.stdlib.LoggerFactory(), + wrapper_class=structlog.stdlib.BoundLogger, + cache_logger_on_first_use=True, + ) + + formatter = structlog.stdlib.ProcessorFormatter( + processors=[ + structlog.stdlib.ProcessorFormatter.remove_processors_meta, + renderer, + ], + ) + + handler = logging.StreamHandler(sys.stdout) + handler.setFormatter(formatter) + + root_logger = logging.getLogger() + root_logger.handlers.clear() + root_logger.addHandler(handler) + root_logger.setLevel(logging.INFO) diff --git a/backend/app/main.py b/backend/app/main.py index a69f92e..469a96b 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -2,25 +2,30 @@ from __future__ import annotations -import logging +import asyncio +import contextlib from contextlib import asynccontextmanager from pathlib import Path from typing import TYPE_CHECKING -from fastapi import Depends, FastAPI, Query, WebSocket, WebSocketDisconnect +from fastapi import FastAPI, HTTPException, Query, WebSocket, WebSocketDisconnect +from fastapi.exceptions import RequestValidationError +from fastapi.responses import JSONResponse from fastapi.staticfiles import StaticFiles from app.analytics.api import router as analytics_router from app.analytics.event_recorder import PostgresAnalyticsRecorder +from app.api_utils import envelope from 
app.callbacks import TokenUsageCallbackHandler from app.config import Settings from app.conversation_tracker import PostgresConversationTracker -from app.db import create_checkpointer, create_pool, setup_app_tables +from app.db import create_checkpointer, create_pool, run_alembic_migrations from app.escalation import NoOpEscalator, WebhookEscalator from app.graph import build_graph from app.intent import LLMIntentClassifier from app.interrupt_manager import InterruptManager from app.llm import create_llm +from app.logging_config import configure_logging from app.openapi.review_api import router as openapi_router from app.registry import AgentRegistry from app.replay.api import router as replay_router @@ -31,19 +36,44 @@ from app.ws_handler import dispatch_message if TYPE_CHECKING: from collections.abc import AsyncGenerator -logger = logging.getLogger(__name__) +import structlog + +logger = structlog.get_logger() AGENTS_YAML = Path(__file__).parent.parent / "agents.yaml" FRONTEND_DIST = Path(__file__).parent.parent.parent / "frontend" / "dist" +async def _interrupt_cleanup_loop( + interrupt_manager: InterruptManager, + interval: int = 60, +) -> None: + """Periodically remove expired interrupts in the background. + + Runs until cancelled. Catches all exceptions to prevent the task + from dying unexpectedly. 
+ """ + while True: + await asyncio.sleep(interval) + try: + expired = interrupt_manager.cleanup_expired() + if expired: + logger.info( + "Cleaned up %d expired interrupt(s)", + len(expired), + ) + except Exception: + logger.exception("Error during interrupt cleanup") + + @asynccontextmanager async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: settings = Settings() + configure_logging(settings.log_format) pool = await create_pool(settings) checkpointer = await create_checkpointer(pool) - await setup_app_tables(pool) + run_alembic_migrations(settings.database_url) # Load agents from template or default YAML if settings.template_name: @@ -89,8 +119,16 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: settings.template_name or "(default)", ) + cleanup_task = asyncio.create_task( + _interrupt_cleanup_loop(interrupt_manager), + ) + yield + cleanup_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await cleanup_task + await pool.close() @@ -103,7 +141,35 @@ app.include_router(replay_router) app.include_router(analytics_router) -@app.get("/api/health") +@app.exception_handler(HTTPException) +async def http_exception_handler(request, exc): # type: ignore[no-untyped-def] + """Wrap HTTPException in standard envelope format.""" + return JSONResponse( + status_code=exc.status_code, + content=envelope(None, success=False, error=exc.detail), + ) + + +@app.exception_handler(RequestValidationError) +async def validation_exception_handler(request, exc): # type: ignore[no-untyped-def] + """Wrap validation errors in standard envelope format.""" + return JSONResponse( + status_code=422, + content=envelope(None, success=False, error=str(exc)), + ) + + +@app.exception_handler(Exception) +async def general_exception_handler(request, exc): # type: ignore[no-untyped-def] + """Catch-all handler -- never leak stack traces.""" + logger.exception("Unhandled exception: %s", exc) + return JSONResponse( + status_code=500, + content=envelope(None, 
success=False, error="Internal server error"), + ) + + +@app.get("/api/v1/health") def health_check() -> dict: """Health check endpoint for load balancers and monitoring.""" return {"status": "ok", "version": _VERSION} diff --git a/backend/app/openapi/classifier.py b/backend/app/openapi/classifier.py index 6ad103a..369e04d 100644 --- a/backend/app/openapi/classifier.py +++ b/backend/app/openapi/classifier.py @@ -8,13 +8,14 @@ classifier and an LLM-backed classifier with heuristic fallback. from __future__ import annotations import json -import logging import re from typing import Protocol +import structlog + from app.openapi.models import ClassificationResult, EndpointInfo -logger = logging.getLogger(__name__) +logger = structlog.get_logger() _WRITE_METHODS = frozenset({"POST", "PUT", "PATCH", "DELETE"}) _INTERRUPT_METHODS = frozenset({"POST", "PUT", "PATCH", "DELETE"}) diff --git a/backend/app/openapi/importer.py b/backend/app/openapi/importer.py index 521ca9d..61beb37 100644 --- a/backend/app/openapi/importer.py +++ b/backend/app/openapi/importer.py @@ -6,10 +6,11 @@ Each stage updates the job status and calls the on_progress callback. 
from __future__ import annotations -import logging from collections.abc import Callable from dataclasses import replace +import structlog + from app.openapi.classifier import ClassifierProtocol, HeuristicClassifier from app.openapi.fetcher import fetch_spec from app.openapi.models import ImportJob @@ -17,7 +18,7 @@ from app.openapi.parser import parse_endpoints from app.openapi.ssrf import DEFAULT_POLICY, SSRFPolicy from app.openapi.validator import validate_spec -logger = logging.getLogger(__name__) +logger = structlog.get_logger() ProgressCallback = Callable[[str, ImportJob], None] | None diff --git a/backend/app/openapi/review_api.py b/backend/app/openapi/review_api.py index 6a8a452..b6871cf 100644 --- a/backend/app/openapi/review_api.py +++ b/backend/app/openapi/review_api.py @@ -10,11 +10,11 @@ Exposes endpoints for: from __future__ import annotations import asyncio -import logging import re import uuid from typing import Literal +import structlog from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException from pydantic import BaseModel, field_validator @@ -23,10 +23,10 @@ from app.openapi.generator import generate_agent_yaml, generate_tool_code from app.openapi.importer import ImportOrchestrator from app.openapi.models import ClassificationResult, ImportJob -logger = logging.getLogger(__name__) +logger = structlog.get_logger() router = APIRouter( - prefix="/api/openapi", + prefix="/api/v1/openapi", tags=["openapi"], dependencies=[Depends(require_admin_api_key)], ) diff --git a/backend/app/replay/api.py b/backend/app/replay/api.py index 7f1d94e..d071a75 100644 --- a/backend/app/replay/api.py +++ b/backend/app/replay/api.py @@ -16,7 +16,7 @@ if TYPE_CHECKING: from psycopg_pool import AsyncConnectionPool router = APIRouter( - prefix="/api", + prefix="/api/v1", tags=["replay"], dependencies=[Depends(require_admin_api_key)], ) diff --git a/backend/app/replay/transformer.py b/backend/app/replay/transformer.py index 1fafa4e..f388ef3 100644 --- 
a/backend/app/replay/transformer.py +++ b/backend/app/replay/transformer.py @@ -2,11 +2,11 @@ from __future__ import annotations -import logging +import structlog from app.replay.models import ReplayStep, StepType -logger = logging.getLogger(__name__) +logger = structlog.get_logger() _EMPTY_TIMESTAMP = "1970-01-01T00:00:00Z" diff --git a/backend/app/ws_handler.py b/backend/app/ws_handler.py index 9eaa29d..70430db 100644 --- a/backend/app/ws_handler.py +++ b/backend/app/ws_handler.py @@ -3,7 +3,6 @@ from __future__ import annotations import json -import logging import re import time from collections import defaultdict @@ -21,7 +20,9 @@ if TYPE_CHECKING: from app.session_manager import SessionManager from app.ws_context import WebSocketContext -logger = logging.getLogger(__name__) +import structlog + +logger = structlog.get_logger() MAX_MESSAGE_SIZE = 32_768 # 32 KB MAX_CONTENT_LENGTH = 10_000 # characters diff --git a/backend/pyproject.toml b/backend/pyproject.toml index ab47c93..08825fc 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -21,6 +21,8 @@ dependencies = [ "python-dotenv>=1.0,<2.0", "httpx>=0.28,<1.0", "openapi-spec-validator>=0.7,<1.0", + "alembic>=1.13,<2.0", + "structlog>=24.0,<26.0", ] [project.optional-dependencies] diff --git a/backend/tests/e2e/conftest.py b/backend/tests/e2e/conftest.py index 5bc87c1..cf13463 100644 --- a/backend/tests/e2e/conftest.py +++ b/backend/tests/e2e/conftest.py @@ -174,7 +174,7 @@ def create_e2e_app( app.state.analytics_recorder = AsyncMock() app.state.conversation_tracker = AsyncMock() - @app.get("/api/health") + @app.get("/api/v1/health") def health_check() -> dict: return {"status": "ok", "version": "test"} diff --git a/backend/tests/e2e/test_chat_flows.py b/backend/tests/e2e/test_chat_flows.py index 06f578d..567bdc7 100644 --- a/backend/tests/e2e/test_chat_flows.py +++ b/backend/tests/e2e/test_chat_flows.py @@ -341,7 +341,7 @@ class TestChatEdgeCases: def test_health_endpoint(self) -> None: app = 
create_e2e_app() with TestClient(app) as client: - resp = client.get("/api/health") + resp = client.get("/api/v1/health") assert resp.status_code == 200 assert resp.json()["status"] == "ok" diff --git a/backend/tests/e2e/test_openapi_import.py b/backend/tests/e2e/test_openapi_import.py index ceebb00..011f257 100644 --- a/backend/tests/e2e/test_openapi_import.py +++ b/backend/tests/e2e/test_openapi_import.py @@ -62,7 +62,7 @@ class TestFlow5OpenAPIImport: with TestClient(app) as client: # Step 1: Start import job resp = client.post( - "/api/openapi/import", + "/api/v1/openapi/import", json={"url": "https://api.example.com/openapi.json"}, ) assert resp.status_code == 202 @@ -71,7 +71,7 @@ class TestFlow5OpenAPIImport: job_id = body["job_id"] # Step 2: Check job status (still pending since background task hasn't run) - resp = client.get(f"/api/openapi/jobs/{job_id}") + resp = client.get(f"/api/v1/openapi/jobs/{job_id}") assert resp.status_code == 200 assert resp.json()["job_id"] == job_id @@ -99,7 +99,7 @@ class TestFlow5OpenAPIImport: with TestClient(app) as client: # Step 1: Get classifications - resp = client.get(f"/api/openapi/jobs/{job_id}/classifications") + resp = client.get(f"/api/v1/openapi/jobs/{job_id}/classifications") assert resp.status_code == 200 classifications = resp.json() assert len(classifications) == 2 @@ -118,7 +118,7 @@ class TestFlow5OpenAPIImport: # Step 2: Update a classification resp = client.put( - f"/api/openapi/jobs/{job_id}/classifications/0", + f"/api/v1/openapi/jobs/{job_id}/classifications/0", json={ "access_type": "write", "needs_interrupt": True, @@ -132,7 +132,7 @@ class TestFlow5OpenAPIImport: assert updated["agent_group"] == "order_actions" # Step 3: Approve the job - resp = client.post(f"/api/openapi/jobs/{job_id}/approve") + resp = client.post(f"/api/v1/openapi/jobs/{job_id}/approve") assert resp.status_code == 200 assert resp.json()["status"] == "approved" @@ -140,14 +140,14 @@ class TestFlow5OpenAPIImport: app = 
create_e2e_app() with TestClient(app) as client: - resp = client.get("/api/openapi/jobs/nonexistent") + resp = client.get("/api/v1/openapi/jobs/nonexistent") assert resp.status_code == 404 def test_import_invalid_url_returns_422(self) -> None: app = create_e2e_app() with TestClient(app) as client: - resp = client.post("/api/openapi/import", json={"url": "not-a-url"}) + resp = client.post("/api/v1/openapi/import", json={"url": "not-a-url"}) assert resp.status_code == 422 def test_classification_index_out_of_range(self) -> None: @@ -166,7 +166,7 @@ class TestFlow5OpenAPIImport: with TestClient(app) as client: resp = client.put( - f"/api/openapi/jobs/{job_id}/classifications/99", + f"/api/v1/openapi/jobs/{job_id}/classifications/99", json={ "access_type": "read", "needs_interrupt": False, @@ -191,7 +191,7 @@ class TestFlow5OpenAPIImport: with TestClient(app) as client: resp = client.put( - f"/api/openapi/jobs/{job_id}/classifications/0", + f"/api/v1/openapi/jobs/{job_id}/classifications/0", json={ "access_type": "read", "needs_interrupt": False, diff --git a/backend/tests/e2e/test_replay_analytics.py b/backend/tests/e2e/test_replay_analytics.py index 82a6c9f..b7aa103 100644 --- a/backend/tests/e2e/test_replay_analytics.py +++ b/backend/tests/e2e/test_replay_analytics.py @@ -98,7 +98,7 @@ class TestFlow6ReplayConversation: app = create_e2e_app(pool=pool) with TestClient(app) as client: - resp = client.get("/api/conversations") + resp = client.get("/api/v1/conversations") assert resp.status_code == 200 body = resp.json() assert body["success"] is True @@ -124,7 +124,7 @@ class TestFlow6ReplayConversation: app = create_e2e_app(pool=pool) with TestClient(app) as client: - resp = client.get("/api/conversations", params={"page": 1, "per_page": 2}) + resp = client.get("/api/v1/conversations", params={"page": 1, "per_page": 2}) assert resp.status_code == 200 body = resp.json() assert body["success"] is True @@ -139,7 +139,7 @@ class TestFlow6ReplayConversation: app = 
create_e2e_app(pool=pool) with TestClient(app) as client: - resp = client.get("/api/replay/nonexistent-thread") + resp = client.get("/api/v1/replay/nonexistent-thread") assert resp.status_code == 404 def test_replay_invalid_thread_id_format(self) -> None: @@ -147,7 +147,7 @@ class TestFlow6ReplayConversation: with TestClient(app) as client: # Thread ID with special chars fails regex validation - resp = client.get("/api/replay/invalid%20thread%21%40") + resp = client.get("/api/v1/replay/invalid%20thread%21%40") assert resp.status_code == 400 @@ -158,21 +158,21 @@ class TestAnalyticsDashboard: app = create_e2e_app() with TestClient(app) as client: - resp = client.get("/api/analytics", params={"range": "invalid"}) + resp = client.get("/api/v1/analytics", params={"range": "invalid"}) assert resp.status_code == 400 def test_analytics_range_too_large(self) -> None: app = create_e2e_app() with TestClient(app) as client: - resp = client.get("/api/analytics", params={"range": "999d"}) + resp = client.get("/api/v1/analytics", params={"range": "999d"}) assert resp.status_code == 400 def test_analytics_range_zero_rejected(self) -> None: app = create_e2e_app() with TestClient(app) as client: - resp = client.get("/api/analytics", params={"range": "0d"}) + resp = client.get("/api/v1/analytics", params={"range": "0d"}) assert resp.status_code == 400 @@ -216,7 +216,7 @@ class TestFullUserJourney: assert any(m["type"] == "message_complete" for m in messages) # Step 2: Check conversations endpoint - resp = client.get("/api/conversations") + resp = client.get("/api/v1/conversations") assert resp.status_code == 200 body = resp.json() assert body["success"] is True @@ -226,5 +226,5 @@ class TestFullUserJourney: ) # Step 3: Health check still works - resp = client.get("/api/health") + resp = client.get("/api/v1/health") assert resp.status_code == 200 diff --git a/backend/tests/integration/test_analytics_api.py b/backend/tests/integration/test_analytics_api.py new file mode 100644 index 
0000000..291ff69 --- /dev/null +++ b/backend/tests/integration/test_analytics_api.py @@ -0,0 +1,183 @@ +"""Integration tests for the /api/v1/analytics endpoint. + +Tests the full API layer (routing, parameter validation, serialization, +error handling) with a mocked database pool. +""" + +from __future__ import annotations + +from dataclasses import asdict +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from httpx import ASGITransport, AsyncClient + +from app.analytics.models import AnalyticsResult, InterruptStats + +pytestmark = pytest.mark.integration + +_SAMPLE_RESULT = AnalyticsResult( + range="7d", + total_conversations=42, + resolution_rate=0.85, + escalation_rate=0.05, + avg_turns_per_conversation=3.2, + avg_cost_per_conversation_usd=0.012, + agent_usage=(), + interrupt_stats=InterruptStats(total=10, approved=7, rejected=2, expired=1), +) + + +def _build_app(): + """Build a minimal FastAPI app with the analytics router and mocked deps.""" + from fastapi import FastAPI + from fastapi.exceptions import RequestValidationError + from fastapi.responses import JSONResponse + + from app.analytics.api import router as analytics_router + from app.api_utils import envelope + + test_app = FastAPI() + test_app.include_router(analytics_router) + + @test_app.exception_handler(Exception) + async def _catch_all(request, exc): + return JSONResponse( + status_code=500, + content=envelope(None, success=False, error="Internal server error"), + ) + + from fastapi import HTTPException + + @test_app.exception_handler(HTTPException) + async def _http_exc(request, exc): + return JSONResponse( + status_code=exc.status_code, + content=envelope(None, success=False, error=exc.detail), + ) + + @test_app.exception_handler(RequestValidationError) + async def _validation_exc(request, exc): + return JSONResponse( + status_code=422, + content=envelope(None, success=False, error=str(exc)), + ) + + # No admin_api_key set -> auth is skipped + test_app.state.settings = 
MagicMock(admin_api_key="") + test_app.state.pool = MagicMock() + + return test_app + + +class TestAnalyticsValidRange: + """Test analytics endpoint with valid range parameters.""" + + async def test_valid_range_7d_returns_envelope(self) -> None: + """GET /api/v1/analytics?range=7d returns success envelope with data.""" + test_app = _build_app() + with patch( + "app.analytics.api.get_analytics", + new_callable=AsyncMock, + return_value=_SAMPLE_RESULT, + ): + async with AsyncClient( + transport=ASGITransport(app=test_app), base_url="http://test" + ) as client: + resp = await client.get("/api/v1/analytics", params={"range": "7d"}) + + assert resp.status_code == 200 + body = resp.json() + assert body["success"] is True + assert body["error"] is None + assert body["data"]["total_conversations"] == 42 + assert body["data"]["resolution_rate"] == 0.85 + + async def test_default_range_returns_success(self) -> None: + """GET /api/v1/analytics with no range param defaults to 7d.""" + test_app = _build_app() + with patch( + "app.analytics.api.get_analytics", + new_callable=AsyncMock, + return_value=_SAMPLE_RESULT, + ) as mock_get: + async with AsyncClient( + transport=ASGITransport(app=test_app), base_url="http://test" + ) as client: + resp = await client.get("/api/v1/analytics") + + assert resp.status_code == 200 + # Verify default range of 7 days was passed + mock_get.assert_called_once() + call_args = mock_get.call_args + assert call_args[1].get("range_days", call_args[0][1] if len(call_args[0]) > 1 else None) in (7, None) or call_args[0][1] == 7 + + async def test_large_range_365d_works(self) -> None: + """GET /api/v1/analytics?range=365d is accepted (max boundary).""" + test_app = _build_app() + result = AnalyticsResult( + range="365d", + total_conversations=1000, + resolution_rate=0.9, + escalation_rate=0.02, + avg_turns_per_conversation=4.0, + avg_cost_per_conversation_usd=0.01, + agent_usage=(), + interrupt_stats=InterruptStats(), + ) + with patch( + 
"app.analytics.api.get_analytics", + new_callable=AsyncMock, + return_value=result, + ): + async with AsyncClient( + transport=ASGITransport(app=test_app), base_url="http://test" + ) as client: + resp = await client.get("/api/v1/analytics", params={"range": "365d"}) + + assert resp.status_code == 200 + assert resp.json()["success"] is True + + +class TestAnalyticsInvalidRange: + """Test analytics endpoint with invalid range parameters.""" + + async def test_invalid_range_format_returns_400(self) -> None: + """GET /api/v1/analytics?range=abc returns 400 error envelope.""" + test_app = _build_app() + async with AsyncClient( + transport=ASGITransport(app=test_app), base_url="http://test" + ) as client: + resp = await client.get("/api/v1/analytics", params={"range": "abc"}) + + assert resp.status_code == 400 + body = resp.json() + assert body["success"] is False + assert body["data"] is None + assert "Invalid range format" in body["error"] + + async def test_zero_day_range_returns_400(self) -> None: + """GET /api/v1/analytics?range=0d returns 400 because 0 is below minimum.""" + test_app = _build_app() + async with AsyncClient( + transport=ASGITransport(app=test_app), base_url="http://test" + ) as client: + resp = await client.get("/api/v1/analytics", params={"range": "0d"}) + + assert resp.status_code == 400 + body = resp.json() + assert body["success"] is False + assert "between 1 and 365" in body["error"] + + async def test_range_exceeding_max_returns_400(self) -> None: + """GET /api/v1/analytics?range=999d returns 400 because it exceeds 365.""" + test_app = _build_app() + async with AsyncClient( + transport=ASGITransport(app=test_app), base_url="http://test" + ) as client: + resp = await client.get("/api/v1/analytics", params={"range": "999d"}) + + assert resp.status_code == 400 + body = resp.json() + assert body["success"] is False + assert "between 1 and 365" in body["error"] diff --git a/backend/tests/integration/test_error_responses.py 
b/backend/tests/integration/test_error_responses.py new file mode 100644 index 0000000..137c566 --- /dev/null +++ b/backend/tests/integration/test_error_responses.py @@ -0,0 +1,128 @@ +"""Integration tests for global error handling and envelope format consistency. + +Tests that all error responses from the FastAPI app conform to the +standard envelope: {"success": false, "data": null, "error": "..."}. +""" + +from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest +from httpx import ASGITransport, AsyncClient + +pytestmark = pytest.mark.integration + + +def _build_app(): + """Build the actual FastAPI app with exception handlers but mocked state.""" + from fastapi import FastAPI, HTTPException + from fastapi.exceptions import RequestValidationError + from fastapi.responses import JSONResponse + + from app.analytics.api import router as analytics_router + from app.api_utils import envelope + from app.replay.api import router as replay_router + + test_app = FastAPI() + test_app.include_router(analytics_router) + test_app.include_router(replay_router) + + @test_app.exception_handler(HTTPException) + async def _http_exc(request, exc): + return JSONResponse( + status_code=exc.status_code, + content=envelope(None, success=False, error=exc.detail), + ) + + @test_app.exception_handler(RequestValidationError) + async def _validation_exc(request, exc): + return JSONResponse( + status_code=422, + content=envelope(None, success=False, error=str(exc)), + ) + + @test_app.exception_handler(Exception) + async def _catch_all(request, exc): + return JSONResponse( + status_code=500, + content=envelope(None, success=False, error="Internal server error"), + ) + + @test_app.get("/api/v1/health") + def health_check(): + return {"status": "ok", "version": "0.6.0"} + + test_app.state.settings = MagicMock(admin_api_key="") + test_app.state.pool = MagicMock() + + return test_app + + +class TestEnvelopeFormat: + """Tests that error responses consistently 
follow envelope format.""" + + async def test_http_400_produces_envelope(self) -> None: + """A 400 error returns standard envelope with success=false.""" + test_app = _build_app() + async with AsyncClient( + transport=ASGITransport(app=test_app), base_url="http://test" + ) as client: + resp = await client.get("/api/v1/analytics", params={"range": "invalid"}) + + assert resp.status_code == 400 + body = resp.json() + assert body["success"] is False + assert body["data"] is None + assert isinstance(body["error"], str) + assert len(body["error"]) > 0 + + async def test_validation_error_produces_422_envelope(self) -> None: + """Invalid query param type returns 422 with envelope format.""" + test_app = _build_app() + async with AsyncClient( + transport=ASGITransport(app=test_app), base_url="http://test" + ) as client: + # page must be >= 1; passing 0 triggers validation error + resp = await client.get("/api/v1/conversations", params={"page": 0}) + + assert resp.status_code == 422 + body = resp.json() + assert body["success"] is False + assert body["data"] is None + assert isinstance(body["error"], str) + + async def test_all_error_fields_present(self) -> None: + """Error envelope contains exactly success, data, and error keys.""" + test_app = _build_app() + async with AsyncClient( + transport=ASGITransport(app=test_app), base_url="http://test" + ) as client: + resp = await client.get("/api/v1/analytics", params={"range": "bad"}) + + body = resp.json() + assert set(body.keys()) == {"success", "data", "error"} + + async def test_health_endpoint_returns_200(self) -> None: + """Health check returns 200 with status ok.""" + test_app = _build_app() + async with AsyncClient( + transport=ASGITransport(app=test_app), base_url="http://test" + ) as client: + resp = await client.get("/api/v1/health") + + assert resp.status_code == 200 + body = resp.json() + assert body["status"] == "ok" + assert "version" in body + + async def test_unknown_endpoint_returns_404(self) -> None: + 
"""Requesting a non-existent path returns 404.""" + test_app = _build_app() + async with AsyncClient( + transport=ASGITransport(app=test_app), base_url="http://test" + ) as client: + resp = await client.get("/api/v1/nonexistent-path") + + # FastAPI returns 404 for unknown routes; may or may not be wrapped + assert resp.status_code == 404 diff --git a/backend/tests/integration/test_openapi_api.py b/backend/tests/integration/test_openapi_api.py new file mode 100644 index 0000000..cde31bf --- /dev/null +++ b/backend/tests/integration/test_openapi_api.py @@ -0,0 +1,164 @@ +"""Integration tests for /api/v1/openapi/ endpoints. + +Tests the full API layer for the OpenAPI import review workflow, +including job creation, status retrieval, classification updates, +and approval triggering. +""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest +from httpx import ASGITransport, AsyncClient + +pytestmark = pytest.mark.integration + + +def _build_app(): + """Build a minimal FastAPI app with the openapi router and mocked deps.""" + from fastapi import FastAPI, HTTPException + from fastapi.exceptions import RequestValidationError + from fastapi.responses import JSONResponse + + from app.api_utils import envelope + from app.openapi.review_api import router as openapi_router + + test_app = FastAPI() + test_app.include_router(openapi_router) + + @test_app.exception_handler(HTTPException) + async def _http_exc(request, exc): + return JSONResponse( + status_code=exc.status_code, + content=envelope(None, success=False, error=exc.detail), + ) + + @test_app.exception_handler(RequestValidationError) + async def _validation_exc(request, exc): + return JSONResponse( + status_code=422, + content=envelope(None, success=False, error=str(exc)), + ) + + test_app.state.settings = MagicMock(admin_api_key="") + + return test_app + + +@pytest.fixture(autouse=True) +def _clear_job_store(): + """Clear the in-memory job store between tests.""" + from 
app.openapi.review_api import _job_store + + _job_store.clear() + yield + _job_store.clear() + + +class TestImportEndpoint: + """Tests for POST /api/v1/openapi/import.""" + + async def test_import_returns_202_with_job_id(self) -> None: + """Starting an import returns 202 with a job_id.""" + test_app = _build_app() + async with AsyncClient( + transport=ASGITransport(app=test_app), base_url="http://test" + ) as client: + resp = await client.post( + "/api/v1/openapi/import", + json={"url": "https://example.com/api/spec.json"}, + ) + + assert resp.status_code == 202 + body = resp.json() + assert "job_id" in body + assert body["status"] == "pending" + assert body["spec_url"] == "https://example.com/api/spec.json" + + async def test_import_invalid_url_returns_422(self) -> None: + """POST with invalid URL (no http/https) returns 422.""" + test_app = _build_app() + async with AsyncClient( + transport=ASGITransport(app=test_app), base_url="http://test" + ) as client: + resp = await client.post( + "/api/v1/openapi/import", + json={"url": "ftp://example.com/spec.json"}, + ) + + assert resp.status_code == 422 + body = resp.json() + assert body["success"] is False + + +class TestJobStatusEndpoint: + """Tests for GET /api/v1/openapi/jobs/{job_id}.""" + + async def test_get_existing_job_returns_status(self) -> None: + """Retrieving an existing job returns its status.""" + from app.openapi.review_api import _job_store + + _job_store["test-job-1"] = { + "job_id": "test-job-1", + "status": "done", + "spec_url": "https://example.com/spec.json", + "total_endpoints": 5, + "classified_count": 5, + "error_message": None, + "classifications": [], + } + + test_app = _build_app() + async with AsyncClient( + transport=ASGITransport(app=test_app), base_url="http://test" + ) as client: + resp = await client.get("/api/v1/openapi/jobs/test-job-1") + + assert resp.status_code == 200 + body = resp.json() + assert body["job_id"] == "test-job-1" + assert body["status"] == "done" + assert 
body["total_endpoints"] == 5 + + async def test_get_unknown_job_returns_404(self) -> None: + """Retrieving a non-existent job returns 404 error envelope.""" + test_app = _build_app() + async with AsyncClient( + transport=ASGITransport(app=test_app), base_url="http://test" + ) as client: + resp = await client.get("/api/v1/openapi/jobs/unknown-id-999") + + assert resp.status_code == 404 + body = resp.json() + assert body["success"] is False + assert "not found" in body["error"].lower() + + +class TestApproveEndpoint: + """Tests for POST /api/v1/openapi/jobs/{job_id}/approve.""" + + async def test_approve_with_no_classifications_returns_400(self) -> None: + """Approving a job with no classifications returns 400.""" + from app.openapi.review_api import _job_store + + _job_store["empty-job"] = { + "job_id": "empty-job", + "status": "done", + "spec_url": "https://example.com/spec.json", + "total_endpoints": 0, + "classified_count": 0, + "error_message": None, + "classifications": [], + } + + test_app = _build_app() + async with AsyncClient( + transport=ASGITransport(app=test_app), base_url="http://test" + ) as client: + resp = await client.post("/api/v1/openapi/jobs/empty-job/approve") + + assert resp.status_code == 400 + body = resp.json() + assert body["success"] is False + assert "no classifications" in body["error"].lower() diff --git a/backend/tests/integration/test_replay_api.py b/backend/tests/integration/test_replay_api.py new file mode 100644 index 0000000..f4636d3 --- /dev/null +++ b/backend/tests/integration/test_replay_api.py @@ -0,0 +1,213 @@ +"""Integration tests for /api/v1/conversations and /api/v1/replay/{thread_id}. + +Tests the full API layer with a mocked database pool, verifying routing, +serialization, pagination, and error handling in envelope format. 
+""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock + +import pytest +from httpx import ASGITransport, AsyncClient + +pytestmark = pytest.mark.integration + + +def _make_fake_cursor(rows, *, fetchone_value=None): + """Build a fake async cursor returning the given rows on fetchall.""" + cursor = AsyncMock() + cursor.fetchall = AsyncMock(return_value=rows) + if fetchone_value is not None: + cursor.fetchone = AsyncMock(return_value=fetchone_value) + return cursor + + +class _FakeConnection: + """Fake async connection that returns pre-configured cursors in order.""" + + def __init__(self, cursors: list) -> None: + self._cursors = list(cursors) + self._idx = 0 + + async def execute(self, sql, params=None): + cursor = self._cursors[self._idx] + self._idx += 1 + return cursor + + async def __aenter__(self): + return self + + async def __aexit__(self, *args): + pass + + +class _FakePool: + """Fake connection pool that yields a fake connection.""" + + def __init__(self, conn: _FakeConnection) -> None: + self._conn = conn + + def connection(self): + return self._conn + + +def _build_app(pool=None): + """Build a minimal FastAPI app with the replay router and mocked deps.""" + from fastapi import FastAPI, HTTPException + from fastapi.exceptions import RequestValidationError + from fastapi.responses import JSONResponse + + from app.api_utils import envelope + from app.replay.api import router as replay_router + + test_app = FastAPI() + test_app.include_router(replay_router) + + @test_app.exception_handler(HTTPException) + async def _http_exc(request, exc): + return JSONResponse( + status_code=exc.status_code, + content=envelope(None, success=False, error=exc.detail), + ) + + @test_app.exception_handler(RequestValidationError) + async def _validation_exc(request, exc): + return JSONResponse( + status_code=422, + content=envelope(None, success=False, error=str(exc)), + ) + + test_app.state.settings = MagicMock(admin_api_key="") + 
test_app.state.pool = pool or MagicMock() + + return test_app + + +class TestListConversations: + """Tests for GET /api/v1/conversations endpoint.""" + + async def test_returns_paginated_envelope(self) -> None: + """Conversations list returns envelope with pagination metadata.""" + count_cursor = _make_fake_cursor([], fetchone_value=(3,)) + rows = [ + {"thread_id": "t1", "created_at": "2026-01-01", "last_activity": "2026-01-01", + "status": "active", "total_tokens": 100, "total_cost_usd": 0.01}, + {"thread_id": "t2", "created_at": "2026-01-02", "last_activity": "2026-01-02", + "status": "resolved", "total_tokens": 200, "total_cost_usd": 0.02}, + ] + list_cursor = _make_fake_cursor(rows) + conn = _FakeConnection([count_cursor, list_cursor]) + pool = _FakePool(conn) + test_app = _build_app(pool) + + async with AsyncClient( + transport=ASGITransport(app=test_app), base_url="http://test" + ) as client: + resp = await client.get("/api/v1/conversations") + + assert resp.status_code == 200 + body = resp.json() + assert body["success"] is True + assert body["data"]["total"] == 3 + assert len(body["data"]["conversations"]) == 2 + assert body["data"]["page"] == 1 + assert body["data"]["per_page"] == 20 + + async def test_custom_page_and_per_page(self) -> None: + """Custom page/per_page params are reflected in the response.""" + count_cursor = _make_fake_cursor([], fetchone_value=(50,)) + list_cursor = _make_fake_cursor([]) + conn = _FakeConnection([count_cursor, list_cursor]) + pool = _FakePool(conn) + test_app = _build_app(pool) + + async with AsyncClient( + transport=ASGITransport(app=test_app), base_url="http://test" + ) as client: + resp = await client.get("/api/v1/conversations", params={"page": 3, "per_page": 10}) + + assert resp.status_code == 200 + body = resp.json() + assert body["data"]["page"] == 3 + assert body["data"]["per_page"] == 10 + + async def test_invalid_page_returns_422(self) -> None: + """page=0 violates ge=1 constraint and returns 422 error 
envelope.""" + test_app = _build_app() + async with AsyncClient( + transport=ASGITransport(app=test_app), base_url="http://test" + ) as client: + resp = await client.get("/api/v1/conversations", params={"page": 0}) + + assert resp.status_code == 422 + body = resp.json() + assert body["success"] is False + + +class TestReplayEndpoint: + """Tests for GET /api/v1/replay/{thread_id} endpoint.""" + + async def test_valid_thread_returns_timeline(self) -> None: + """Replay with valid thread_id returns steps in envelope format.""" + checkpoint_rows = [ + { + "thread_id": "abc123", + "checkpoint_id": "cp1", + "checkpoint": { + "channel_values": { + "messages": [ + {"type": "human", "content": "Hello", "created_at": "2026-01-01T00:00:00Z"}, + {"type": "ai", "content": "Hi there!", "created_at": "2026-01-01T00:00:01Z"}, + ] + } + }, + "metadata": {}, + } + ] + cursor = _make_fake_cursor(checkpoint_rows) + conn = _FakeConnection([cursor]) + pool = _FakePool(conn) + test_app = _build_app(pool) + + async with AsyncClient( + transport=ASGITransport(app=test_app), base_url="http://test" + ) as client: + resp = await client.get("/api/v1/replay/abc123") + + assert resp.status_code == 200 + body = resp.json() + assert body["success"] is True + assert body["data"]["thread_id"] == "abc123" + assert body["data"]["total_steps"] == 2 + assert len(body["data"]["steps"]) == 2 + assert body["data"]["steps"][0]["type"] == "user_message" + assert body["data"]["steps"][1]["type"] == "agent_response" + + async def test_invalid_thread_id_format_returns_400(self) -> None: + """Thread IDs with path traversal characters are rejected with 400.""" + test_app = _build_app() + async with AsyncClient( + transport=ASGITransport(app=test_app), base_url="http://test" + ) as client: + resp = await client.get("/api/v1/replay/../../etc/passwd") + + # FastAPI may return 400 from our handler or 404 from routing + assert resp.status_code in (400, 404, 422) + + async def test_nonexistent_thread_returns_404(self) 
-> None: + """Replay with a thread_id that has no checkpoints returns 404.""" + cursor = _make_fake_cursor([]) + conn = _FakeConnection([cursor]) + pool = _FakePool(conn) + test_app = _build_app(pool) + + async with AsyncClient( + transport=ASGITransport(app=test_app), base_url="http://test" + ) as client: + resp = await client.get("/api/v1/replay/nonexistent-thread") + + assert resp.status_code == 404 + body = resp.json() + assert body["success"] is False + assert "not found" in body["error"].lower() diff --git a/backend/tests/integration/test_session_interrupt_lifecycle.py b/backend/tests/integration/test_session_interrupt_lifecycle.py new file mode 100644 index 0000000..4b4835c --- /dev/null +++ b/backend/tests/integration/test_session_interrupt_lifecycle.py @@ -0,0 +1,159 @@ +"""Integration tests for SessionManager + InterruptManager lifecycle. + +These tests exercise the in-memory managers together, verifying the full +lifecycle of sessions and interrupts: creation, TTL sliding, interrupt +registration/resolution, and expired-interrupt cleanup. + +No database required -- both managers are in-memory. 
+""" + +from __future__ import annotations + +import time +from unittest.mock import patch + +import pytest + +from app.interrupt_manager import InterruptManager +from app.session_manager import SessionManager + +pytestmark = pytest.mark.integration + + +class TestSessionInterruptLifecycle: + """Tests for the combined session + interrupt lifecycle.""" + + def test_create_session_register_interrupt_check_status(self) -> None: + """Full lifecycle: create session, register interrupt, verify both states.""" + sm = SessionManager(session_ttl_seconds=3600) + im = InterruptManager(ttl_seconds=300) + + # Create a session + state = sm.touch("thread-1") + assert state.thread_id == "thread-1" + assert not state.has_pending_interrupt + assert not sm.is_expired("thread-1") + + # Register an interrupt + record = im.register("thread-1", "cancel_order", {"order_id": "1042"}) + sm.extend_for_interrupt("thread-1") + + assert im.has_pending("thread-1") + session_state = sm.get_state("thread-1") + assert session_state is not None + assert session_state.has_pending_interrupt + + # Session should not expire while interrupt is pending + assert not sm.is_expired("thread-1") + + def test_interrupt_expiry_after_ttl(self) -> None: + """Interrupt expires when TTL elapses, even if session is alive.""" + im = InterruptManager(ttl_seconds=5) + + record = im.register("thread-2", "refund", {"amount": 50}) + assert im.has_pending("thread-2") + + # Simulate time passing beyond TTL + with patch("app.interrupt_manager.time") as mock_time: + mock_time.time.return_value = record.created_at + 10 + assert not im.has_pending("thread-2") + + status = im.check_status("thread-2") + assert status is not None + assert status.is_expired + assert status.remaining_seconds == 0.0 + + def test_interrupt_resolve_flow(self) -> None: + """Resolving an interrupt removes it from pending and resets session.""" + sm = SessionManager(session_ttl_seconds=3600) + im = InterruptManager(ttl_seconds=300) + + sm.touch("thread-3") 
+ im.register("thread-3", "delete_account", {"user_id": "u1"}) + sm.extend_for_interrupt("thread-3") + + # Verify pending state + assert im.has_pending("thread-3") + assert sm.get_state("thread-3").has_pending_interrupt + + # Resolve + im.resolve("thread-3") + sm.resolve_interrupt("thread-3") + + assert not im.has_pending("thread-3") + session_state = sm.get_state("thread-3") + assert session_state is not None + assert not session_state.has_pending_interrupt + + def test_cleanup_expired_removes_old_interrupts(self) -> None: + """cleanup_expired removes only expired interrupts, keeping active ones.""" + im = InterruptManager(ttl_seconds=10) + + # Register two interrupts at different times + old_record = im.register("thread-old", "action_old", {}) + new_record = im.register("thread-new", "action_new", {}) + + # Simulate time where only old one expired + with patch("app.interrupt_manager.time") as mock_time: + # Move old record's creation to the past + im._interrupts["thread-old"] = old_record.__class__( + interrupt_id=old_record.interrupt_id, + thread_id=old_record.thread_id, + action=old_record.action, + params=old_record.params, + created_at=time.time() - 20, + ttl_seconds=old_record.ttl_seconds, + ) + mock_time.time.return_value = time.time() + + expired = im.cleanup_expired() + assert len(expired) == 1 + assert expired[0].thread_id == "thread-old" + + # New one should still be pending + assert im.has_pending("thread-new") + assert not im.has_pending("thread-old") + + def test_session_ttl_sliding_window(self) -> None: + """Touching a session resets the sliding window TTL.""" + sm = SessionManager(session_ttl_seconds=3600) + + state1 = sm.touch("thread-5") + first_activity = state1.last_activity + + time.sleep(0.01) + state2 = sm.touch("thread-5") + second_activity = state2.last_activity + + assert second_activity > first_activity + assert not sm.is_expired("thread-5") + + def test_session_expires_after_ttl_without_activity(self) -> None: + """Session expires when 
TTL passes without a touch or interrupt.""" + sm = SessionManager(session_ttl_seconds=0) + sm.touch("thread-6") + + # TTL is 0 so session is immediately expired + assert sm.is_expired("thread-6") + + def test_pending_interrupt_prevents_session_expiry(self) -> None: + """A session with pending interrupt does not expire even with TTL=0.""" + sm = SessionManager(session_ttl_seconds=0) + sm.touch("thread-7") + sm.extend_for_interrupt("thread-7") + + # Even with TTL=0, session should not expire because of pending interrupt + assert not sm.is_expired("thread-7") + + def test_retry_prompt_for_expired_interrupt(self) -> None: + """InterruptManager generates a retry prompt for expired interrupts.""" + im = InterruptManager(ttl_seconds=300) + record = im.register("thread-8", "cancel_order", {"order_id": "1042"}) + + prompt = im.generate_retry_prompt(record) + + assert prompt["type"] == "interrupt_expired" + assert prompt["thread_id"] == "thread-8" + assert "cancel_order" in prompt["action"] + assert "cancel_order" in prompt["message"] + assert "expired" in prompt["message"].lower() diff --git a/backend/tests/unit/analytics/test_api.py b/backend/tests/unit/analytics/test_api.py index 5a74411..a23b239 100644 --- a/backend/tests/unit/analytics/test_api.py +++ b/backend/tests/unit/analytics/test_api.py @@ -44,7 +44,7 @@ def _make_analytics_result() -> object: ) -def _get_analytics(app: FastAPI, path: str = "/api/analytics", **patch_kwargs: object) -> object: +def _get_analytics(app: FastAPI, path: str = "/api/v1/analytics", **patch_kwargs: object) -> object: """Helper: patch get_analytics, make request, return (response, mock).""" analytics_result = _make_analytics_result() with ( @@ -84,7 +84,7 @@ class TestAnalyticsEndpoint: def test_custom_range_7d(self) -> None: app = _build_app() app.state.pool = _make_mock_pool() - resp, mock_ga = _get_analytics(app, "/api/analytics?range=7d") + resp, mock_ga = _get_analytics(app, "/api/v1/analytics?range=7d") assert resp.status_code == 
200 mock_ga.assert_called_once() @@ -94,7 +94,7 @@ class TestAnalyticsEndpoint: def test_custom_range_30d(self) -> None: app = _build_app() app.state.pool = _make_mock_pool() - resp, mock_ga = _get_analytics(app, "/api/analytics?range=30d") + resp, mock_ga = _get_analytics(app, "/api/v1/analytics?range=30d") assert resp.status_code == 200 call_kwargs = mock_ga.call_args @@ -107,7 +107,7 @@ class TestAnalyticsEndpoint: app.state.pool = _make_mock_pool() with TestClient(app) as client: - resp = client.get("/api/analytics?range=invalid") + resp = client.get("/api/v1/analytics?range=invalid") assert resp.status_code == 400 @@ -116,7 +116,7 @@ class TestAnalyticsEndpoint: app.state.pool = _make_mock_pool() with TestClient(app) as client: - resp = client.get("/api/analytics?range=7") + resp = client.get("/api/v1/analytics?range=7") assert resp.status_code == 400 diff --git a/backend/tests/unit/openapi/test_review_api.py b/backend/tests/unit/openapi/test_review_api.py index c7cc4df..fee25ad 100644 --- a/backend/tests/unit/openapi/test_review_api.py +++ b/backend/tests/unit/openapi/test_review_api.py @@ -28,7 +28,7 @@ def client(): @pytest.fixture def job_id(client): """Create a job and return its ID.""" - response = client.post("/api/openapi/import", json={"url": _SAMPLE_URL}) + response = client.post("/api/v1/openapi/import", json={"url": _SAMPLE_URL}) assert response.status_code == 202 return response.json()["job_id"] @@ -61,11 +61,11 @@ def job_with_classifications(client, job_id): class TestImportEndpoint: - """Tests for POST /api/openapi/import.""" + """Tests for POST /api/v1/openapi/import.""" def test_post_import_returns_job_id(self, client) -> None: """POST /import returns 202 with a job_id.""" - response = client.post("/api/openapi/import", json={"url": _SAMPLE_URL}) + response = client.post("/api/v1/openapi/import", json={"url": _SAMPLE_URL}) assert response.status_code == 202 data = response.json() assert "job_id" in data @@ -73,38 +73,38 @@ class 
TestImportEndpoint: def test_post_import_empty_url_returns_422(self, client) -> None: """POST /import with empty URL returns 422 validation error.""" - response = client.post("/api/openapi/import", json={"url": ""}) + response = client.post("/api/v1/openapi/import", json={"url": ""}) assert response.status_code == 422 def test_post_import_missing_url_returns_422(self, client) -> None: """POST /import with missing URL field returns 422.""" - response = client.post("/api/openapi/import", json={}) + response = client.post("/api/v1/openapi/import", json={}) assert response.status_code == 422 def test_post_import_invalid_scheme_returns_422(self, client) -> None: """POST /import with non-http URL returns 422.""" - response = client.post("/api/openapi/import", json={"url": "ftp://evil.com/spec"}) + response = client.post("/api/v1/openapi/import", json={"url": "ftp://evil.com/spec"}) assert response.status_code == 422 def test_post_import_returns_pending_status(self, client) -> None: """Newly created job has pending status.""" - response = client.post("/api/openapi/import", json={"url": _SAMPLE_URL}) + response = client.post("/api/v1/openapi/import", json={"url": _SAMPLE_URL}) data = response.json() assert data["status"] == "pending" def test_post_import_returns_spec_url(self, client) -> None: """Response includes the original spec URL.""" - response = client.post("/api/openapi/import", json={"url": _SAMPLE_URL}) + response = client.post("/api/v1/openapi/import", json={"url": _SAMPLE_URL}) data = response.json() assert data["spec_url"] == _SAMPLE_URL class TestGetJobEndpoint: - """Tests for GET /api/openapi/jobs/{job_id}.""" + """Tests for GET /api/v1/openapi/jobs/{job_id}.""" def test_get_job_returns_status(self, client, job_id) -> None: """GET /jobs/{id} returns job status.""" - response = client.get(f"/api/openapi/jobs/{job_id}") + response = client.get(f"/api/v1/openapi/jobs/{job_id}") assert response.status_code == 200 data = response.json() assert "status" in data @@ 
-112,23 +112,23 @@ class TestGetJobEndpoint: def test_get_unknown_job_returns_404(self, client) -> None: """GET /jobs/nonexistent returns 404.""" - response = client.get("/api/openapi/jobs/nonexistent-id") + response = client.get("/api/v1/openapi/jobs/nonexistent-id") assert response.status_code == 404 def test_get_job_includes_spec_url(self, client, job_id) -> None: """Job response includes the spec URL.""" - response = client.get(f"/api/openapi/jobs/{job_id}") + response = client.get(f"/api/v1/openapi/jobs/{job_id}") data = response.json() assert data["spec_url"] == _SAMPLE_URL class TestGetClassificationsEndpoint: - """Tests for GET /api/openapi/jobs/{job_id}/classifications.""" + """Tests for GET /api/v1/openapi/jobs/{job_id}/classifications.""" def test_get_classifications_returns_list(self, client, job_with_classifications) -> None: """GET /classifications returns a list.""" response = client.get( - f"/api/openapi/jobs/{job_with_classifications}/classifications" + f"/api/v1/openapi/jobs/{job_with_classifications}/classifications" ) assert response.status_code == 200 data = response.json() @@ -137,13 +137,13 @@ class TestGetClassificationsEndpoint: def test_get_classifications_unknown_job_returns_404(self, client) -> None: """GET /classifications for unknown job returns 404.""" - response = client.get("/api/openapi/jobs/unknown/classifications") + response = client.get("/api/v1/openapi/jobs/unknown/classifications") assert response.status_code == 404 def test_classification_has_expected_fields(self, client, job_with_classifications) -> None: """Each classification item has access_type and endpoint fields.""" response = client.get( - f"/api/openapi/jobs/{job_with_classifications}/classifications" + f"/api/v1/openapi/jobs/{job_with_classifications}/classifications" ) item = response.json()[0] assert "access_type" in item @@ -152,12 +152,12 @@ class TestGetClassificationsEndpoint: class TestUpdateClassificationEndpoint: - """Tests for PUT 
/api/openapi/jobs/{job_id}/classifications/{idx}.""" + """Tests for PUT /api/v1/openapi/jobs/{job_id}/classifications/{idx}.""" def test_update_classification_succeeds(self, client, job_with_classifications) -> None: """PUT /classifications/0 updates the classification.""" response = client.put( - f"/api/openapi/jobs/{job_with_classifications}/classifications/0", + f"/api/v1/openapi/jobs/{job_with_classifications}/classifications/0", json={"access_type": "write", "needs_interrupt": True, "agent_group": "write_agent"}, ) assert response.status_code == 200 @@ -165,7 +165,7 @@ class TestUpdateClassificationEndpoint: def test_update_unknown_job_returns_404(self, client) -> None: """PUT /classifications/0 for unknown job returns 404.""" response = client.put( - "/api/openapi/jobs/unknown/classifications/0", + "/api/v1/openapi/jobs/unknown/classifications/0", json={"access_type": "write", "needs_interrupt": True, "agent_group": "write_agent"}, ) assert response.status_code == 404 @@ -173,7 +173,7 @@ class TestUpdateClassificationEndpoint: def test_update_invalid_access_type_returns_422(self, client, job_with_classifications) -> None: """PUT /classifications/0 with invalid access_type returns 422.""" response = client.put( - f"/api/openapi/jobs/{job_with_classifications}/classifications/0", + f"/api/v1/openapi/jobs/{job_with_classifications}/classifications/0", json={"access_type": "admin", "needs_interrupt": True, "agent_group": "x"}, ) assert response.status_code == 422 @@ -181,7 +181,7 @@ class TestUpdateClassificationEndpoint: def test_update_invalid_agent_group_returns_422(self, client, job_with_classifications) -> None: """PUT /classifications/0 with invalid agent_group returns 422.""" response = client.put( - f"/api/openapi/jobs/{job_with_classifications}/classifications/0", + f"/api/v1/openapi/jobs/{job_with_classifications}/classifications/0", json={"access_type": "read", "needs_interrupt": False, "agent_group": "evil group!"}, ) assert response.status_code == 
422 @@ -189,31 +189,31 @@ class TestUpdateClassificationEndpoint: def test_update_out_of_range_index_returns_404(self, client, job_with_classifications) -> None: """PUT /classifications/999 returns 404 for out-of-range index.""" response = client.put( - f"/api/openapi/jobs/{job_with_classifications}/classifications/999", + f"/api/v1/openapi/jobs/{job_with_classifications}/classifications/999", json={"access_type": "read", "needs_interrupt": False, "agent_group": "read_agent"}, ) assert response.status_code == 404 class TestApproveEndpoint: - """Tests for POST /api/openapi/jobs/{job_id}/approve.""" + """Tests for POST /api/v1/openapi/jobs/{job_id}/approve.""" def test_approve_job_succeeds(self, client, job_with_classifications) -> None: """POST /approve transitions job to approved status.""" response = client.post( - f"/api/openapi/jobs/{job_with_classifications}/approve" + f"/api/v1/openapi/jobs/{job_with_classifications}/approve" ) assert response.status_code == 200 def test_approve_unknown_job_returns_404(self, client) -> None: """POST /approve for unknown job returns 404.""" - response = client.post("/api/openapi/jobs/unknown/approve") + response = client.post("/api/v1/openapi/jobs/unknown/approve") assert response.status_code == 404 def test_approve_returns_job_status(self, client, job_with_classifications) -> None: """POST /approve returns updated job status.""" response = client.post( - f"/api/openapi/jobs/{job_with_classifications}/approve" + f"/api/v1/openapi/jobs/{job_with_classifications}/approve" ) data = response.json() assert "status" in data diff --git a/backend/tests/unit/replay/test_api.py b/backend/tests/unit/replay/test_api.py index 1acf55e..50900bf 100644 --- a/backend/tests/unit/replay/test_api.py +++ b/backend/tests/unit/replay/test_api.py @@ -5,9 +5,12 @@ from __future__ import annotations from unittest.mock import AsyncMock, MagicMock import pytest -from fastapi import FastAPI +from fastapi import FastAPI, HTTPException +from 
fastapi.responses import JSONResponse from fastapi.testclient import TestClient +from app.api_utils import envelope + pytestmark = pytest.mark.unit @@ -16,6 +19,14 @@ def _build_app() -> FastAPI: app = FastAPI() app.include_router(router) + + @app.exception_handler(HTTPException) + async def _http_exc(request, exc): # type: ignore[no-untyped-def] + return JSONResponse( + status_code=exc.status_code, + content=envelope(None, success=False, error=exc.detail), + ) + return app @@ -64,7 +75,7 @@ class TestListConversations: app.state.pool = _make_mock_pool([], count=0) with TestClient(app) as client: - resp = client.get("/api/conversations") + resp = client.get("/api/v1/conversations") assert resp.status_code == 200 body = resp.json() assert body["success"] is True @@ -89,7 +100,7 @@ class TestListConversations: app.state.pool = _make_mock_pool(mock_rows, count=1) with TestClient(app) as client: - resp = client.get("/api/conversations") + resp = client.get("/api/v1/conversations") body = resp.json() assert resp.status_code == 200 data = body["data"] @@ -102,7 +113,7 @@ class TestListConversations: app.state.pool = _make_mock_pool([], count=0) with TestClient(app) as client: - resp = client.get("/api/conversations") + resp = client.get("/api/v1/conversations") assert resp.status_code == 200 def test_pagination_custom_params(self) -> None: @@ -110,7 +121,7 @@ class TestListConversations: app.state.pool = _make_mock_pool([], count=0) with TestClient(app) as client: - resp = client.get("/api/conversations?page=2&per_page=10") + resp = client.get("/api/v1/conversations?page=2&per_page=10") assert resp.status_code == 200 def test_per_page_max_capped_at_100(self) -> None: @@ -118,7 +129,7 @@ class TestListConversations: app.state.pool = _make_mock_pool([], count=0) with TestClient(app) as client: - resp = client.get("/api/conversations?per_page=200") + resp = client.get("/api/v1/conversations?per_page=200") # FastAPI Query(le=100) rejects values > 100 assert resp.status_code 
== 422 @@ -129,7 +140,7 @@ class TestGetReplay: app.state.pool = _make_mock_pool([]) with TestClient(app) as client: - resp = client.get("/api/replay/nonexistent-thread") + resp = client.get("/api/v1/replay/nonexistent-thread") assert resp.status_code == 404 def test_returns_replay_page_for_existing_thread(self) -> None: @@ -149,7 +160,7 @@ class TestGetReplay: app.state.pool = _make_mock_pool(mock_rows) with TestClient(app) as client: - resp = client.get("/api/replay/thread-123") + resp = client.get("/api/v1/replay/thread-123") assert resp.status_code == 200 body = resp.json() assert body["success"] is True @@ -174,7 +185,7 @@ class TestGetReplay: app.state.pool = _make_mock_pool(mock_rows) with TestClient(app) as client: - resp = client.get("/api/replay/t1?page=1&per_page=5") + resp = client.get("/api/v1/replay/t1?page=1&per_page=5") assert resp.status_code == 200 def test_error_response_has_envelope(self) -> None: @@ -182,16 +193,19 @@ class TestGetReplay: app.state.pool = _make_mock_pool([]) with TestClient(app) as client: - resp = client.get("/api/replay/missing") + resp = client.get("/api/v1/replay/missing") assert resp.status_code == 404 - assert "detail" in resp.json() + body = resp.json() + assert body["success"] is False + assert body["data"] is None + assert body["error"] is not None def test_invalid_thread_id_returns_400(self) -> None: app = _build_app() app.state.pool = _make_mock_pool([]) with TestClient(app) as client: - resp = client.get("/api/replay/id%20with%20spaces") + resp = client.get("/api/v1/replay/id%20with%20spaces") assert resp.status_code == 400 def test_thread_id_special_chars_returns_400(self) -> None: @@ -199,5 +213,5 @@ class TestGetReplay: app.state.pool = _make_mock_pool([]) with TestClient(app) as client: - resp = client.get("/api/replay/id;DROP TABLE") + resp = client.get("/api/v1/replay/id;DROP TABLE") assert resp.status_code == 400 diff --git a/backend/tests/unit/test_error_responses.py 
b/backend/tests/unit/test_error_responses.py new file mode 100644 index 0000000..e30e7ff --- /dev/null +++ b/backend/tests/unit/test_error_responses.py @@ -0,0 +1,142 @@ +"""Tests for standardized error response envelope format.""" + +from __future__ import annotations + +import pytest +from fastapi import FastAPI, HTTPException +from fastapi.exceptions import RequestValidationError +from fastapi.responses import JSONResponse +from fastapi.testclient import TestClient +from pydantic import BaseModel, Field + +from app.api_utils import envelope + +pytestmark = pytest.mark.unit + + +def _build_test_app() -> FastAPI: + """Build a minimal FastAPI app with the standard exception handlers.""" + app = FastAPI() + + @app.exception_handler(HTTPException) + async def http_exception_handler(request, exc): # type: ignore[no-untyped-def] + return JSONResponse( + status_code=exc.status_code, + content=envelope(None, success=False, error=exc.detail), + ) + + @app.exception_handler(RequestValidationError) + async def validation_exception_handler(request, exc): # type: ignore[no-untyped-def] + return JSONResponse( + status_code=422, + content=envelope(None, success=False, error=str(exc)), + ) + + @app.exception_handler(Exception) + async def general_exception_handler(request, exc): # type: ignore[no-untyped-def] + return JSONResponse( + status_code=500, + content=envelope(None, success=False, error="Internal server error"), + ) + + class ItemRequest(BaseModel): + name: str = Field(..., min_length=1) + count: int = Field(..., gt=0) + + @app.get("/items/{item_id}") + def get_item(item_id: int) -> dict: + if item_id == 0: + raise HTTPException(status_code=400, detail="Invalid item ID") + if item_id == 999: + raise HTTPException(status_code=404, detail="Item not found") + if item_id == 401: + raise HTTPException(status_code=401, detail="Not authenticated") + return envelope({"id": item_id, "name": "test"}) + + @app.post("/items") + def create_item(item: ItemRequest) -> dict: + return 
envelope({"id": 1, "name": item.name}) + + @app.get("/crash") + def crash() -> dict: + msg = "unexpected failure" + raise RuntimeError(msg) + + return app + + +class TestHttpExceptionEnvelope: + """HTTPException responses use the standard envelope format.""" + + def test_400_returns_envelope(self) -> None: + app = _build_test_app() + with TestClient(app, raise_server_exceptions=False) as client: + resp = client.get("/items/0") + assert resp.status_code == 400 + body = resp.json() + assert body["success"] is False + assert body["data"] is None + assert body["error"] == "Invalid item ID" + + def test_404_returns_envelope(self) -> None: + app = _build_test_app() + with TestClient(app, raise_server_exceptions=False) as client: + resp = client.get("/items/999") + assert resp.status_code == 404 + body = resp.json() + assert body["success"] is False + assert body["data"] is None + assert body["error"] == "Item not found" + + def test_401_returns_envelope(self) -> None: + app = _build_test_app() + with TestClient(app, raise_server_exceptions=False) as client: + resp = client.get("/items/401") + assert resp.status_code == 401 + body = resp.json() + assert body["success"] is False + assert body["data"] is None + assert body["error"] == "Not authenticated" + + +class TestValidationErrorEnvelope: + """Validation errors return 422 with envelope format.""" + + def test_validation_error_returns_envelope(self) -> None: + app = _build_test_app() + with TestClient(app, raise_server_exceptions=False) as client: + resp = client.post("/items", json={"name": "", "count": -1}) + assert resp.status_code == 422 + body = resp.json() + assert body["success"] is False + assert body["data"] is None + assert isinstance(body["error"], str) + assert len(body["error"]) > 0 + + +class TestGeneralExceptionEnvelope: + """Unhandled exceptions return 500 with safe envelope.""" + + def test_unhandled_exception_returns_500_envelope(self) -> None: + app = _build_test_app() + with TestClient(app, 
raise_server_exceptions=False) as client: + resp = client.get("/crash") + assert resp.status_code == 500 + body = resp.json() + assert body["success"] is False + assert body["data"] is None + assert body["error"] == "Internal server error" + + +class TestSuccessResponseUnchanged: + """Success responses still work normally.""" + + def test_success_returns_envelope(self) -> None: + app = _build_test_app() + with TestClient(app) as client: + resp = client.get("/items/42") + assert resp.status_code == 200 + body = resp.json() + assert body["success"] is True + assert body["data"]["id"] == 42 + assert body["error"] is None diff --git a/backend/tests/unit/test_interrupt_cleanup.py b/backend/tests/unit/test_interrupt_cleanup.py new file mode 100644 index 0000000..6a45303 --- /dev/null +++ b/backend/tests/unit/test_interrupt_cleanup.py @@ -0,0 +1,86 @@ +"""Tests for the interrupt cleanup background loop in main.py.""" + +from __future__ import annotations + +import asyncio +import logging +from unittest.mock import MagicMock, patch + +import pytest + +from app.main import _interrupt_cleanup_loop + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_cleanup_loop_calls_cleanup_expired() -> None: + """The loop should call cleanup_expired after each sleep interval.""" + manager = MagicMock() + manager.cleanup_expired.return_value = () + + call_count = 0 + original_sleep = asyncio.sleep + + async def _fake_sleep(seconds: float) -> None: + nonlocal call_count + call_count += 1 + if call_count >= 2: + raise asyncio.CancelledError + await original_sleep(0) + + with patch("app.main.asyncio.sleep", side_effect=_fake_sleep): + with pytest.raises(asyncio.CancelledError): + await _interrupt_cleanup_loop(manager, interval=60) + + assert manager.cleanup_expired.call_count >= 1 + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_cleanup_loop_survives_exceptions() -> None: + """The loop should not die when cleanup_expired raises an exception.""" + manager = MagicMock() + 
manager.cleanup_expired.side_effect = [RuntimeError("db gone"), ()] + + call_count = 0 + original_sleep = asyncio.sleep + + async def _fake_sleep(seconds: float) -> None: + nonlocal call_count + call_count += 1 + if call_count >= 3: + raise asyncio.CancelledError + await original_sleep(0) + + with patch("app.main.asyncio.sleep", side_effect=_fake_sleep): + with pytest.raises(asyncio.CancelledError): + await _interrupt_cleanup_loop(manager, interval=60) + + # Should have been called twice: once raising, once returning () + assert manager.cleanup_expired.call_count == 2 + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_cleanup_loop_logs_expired_count(capsys: pytest.CaptureFixture[str]) -> None: + """The loop should log when expired interrupts are found.""" + fake_record = MagicMock() + manager = MagicMock() + manager.cleanup_expired.return_value = (fake_record, fake_record) + + call_count = 0 + original_sleep = asyncio.sleep + + async def _fake_sleep(seconds: float) -> None: + nonlocal call_count + call_count += 1 + if call_count >= 2: + raise asyncio.CancelledError + await original_sleep(0) + + with patch("app.main.asyncio.sleep", side_effect=_fake_sleep): + with pytest.raises(asyncio.CancelledError): + await _interrupt_cleanup_loop(manager, interval=60) + + captured = capsys.readouterr() + assert "2 expired interrupt" in captured.out diff --git a/backend/tests/unit/test_logging_config.py b/backend/tests/unit/test_logging_config.py new file mode 100644 index 0000000..cac293a --- /dev/null +++ b/backend/tests/unit/test_logging_config.py @@ -0,0 +1,20 @@ +"""Tests for structured logging configuration.""" + +from __future__ import annotations + +import pytest + +from app.logging_config import configure_logging + + +pytestmark = pytest.mark.unit + + +def test_configure_logging_console_mode() -> None: + """Console mode configures without error.""" + configure_logging("console") + + +def test_configure_logging_json_mode() -> None: + """JSON mode configures 
without error.""" + configure_logging("json") diff --git a/backend/tests/unit/test_main.py b/backend/tests/unit/test_main.py index 8a02f53..b13c83e 100644 --- a/backend/tests/unit/test_main.py +++ b/backend/tests/unit/test_main.py @@ -36,7 +36,7 @@ class TestMainModule: def test_health_route_registered(self) -> None: routes = [r.path for r in app.routes if hasattr(r, "path")] - assert "/api/health" in routes + assert "/api/v1/health" in routes def test_app_version_is_0_5_0(self) -> None: assert app.version == "0.6.0" diff --git a/docker-compose.yml b/docker-compose.yml index 7aa5daf..41b8e11 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -41,7 +41,7 @@ services: postgres: condition: service_healthy healthcheck: - test: ["CMD-SHELL", "curl -f http://localhost:8000/api/health || exit 1"] + test: ["CMD-SHELL", "curl -f http://localhost:8000/api/v1/health || exit 1"] interval: 10s timeout: 5s retries: 5 diff --git a/docs/ARCHITECTURE.md b/docs/ARCHITECTURE.md index 2ceba77..1a8cc72 100644 --- a/docs/ARCHITECTURE.md +++ b/docs/ARCHITECTURE.md @@ -99,7 +99,12 @@ smart-support/ ├── backend/ │ ├── app/ │ │ ├── main.py # FastAPI + WebSocket 入口 -│ │ ├── graph.py # LangGraph Supervisor 配置 +│ │ ├── graph.py # LangGraph Supervisor 构建 +│ │ ├── graph_context.py # GraphContext: 图 + 分类器 + 注册表的类型化封装 +│ │ ├── ws_handler.py # WebSocket 消息分发 + 速率限制 +│ │ ├── ws_context.py # WebSocketContext: WS 依赖包 +│ │ ├── auth.py # API Key 认证中间件 +│ │ ├── api_utils.py # 共享 API 响应工具 (envelope) │ │ ├── agents/ # Agent 定义 + 工具绑定 │ │ ├── registry.py # YAML Agent 注册表加载器 │ │ ├── openapi/ # OpenAPI 解析 + MCP 服务器生成 @@ -139,7 +144,11 @@ smart-support/ | 模块 | 职责 | |------|------| | main.py | 应用入口, WebSocket 端点, 静态文件服务 | -| WebSocket Handler | 双向通信: 接收用户消息, 流式返回 token, 处理 interrupt 响应 | +| auth.py | API Key 认证: 管理端点通过 `X-API-Key` header, WebSocket 通过 `?token=` query param | +| ws_handler.py | 双向通信: 接收用户消息, 流式返回 token, 处理 interrupt 响应 | +| graph_context.py | 类型化封装: 将编译后的图与分类器、注册表绑定, 替代猴子补丁 | +| 
ws_context.py | 依赖包: 将 WebSocket 处理所需的 9 个依赖打包为单一不可变对象 | +| api_utils.py | 共享响应格式: 统一的 `envelope()` 函数 | ### 2.3 Agent 编排层 (LangGraph) @@ -427,6 +436,19 @@ CREATE INDEX idx_interrupts_ttl ON interrupts(ttl_expires_at) WHERE status = 'pending'; ``` +#### sessions (自定义 - 会话状态持久化) + +```sql +-- 用于多 worker 部署的 PostgreSQL 会话状态管理 +-- PgSessionManager 使用此表替代内存中的 dict +CREATE TABLE sessions ( + thread_id TEXT PRIMARY KEY, + last_activity TIMESTAMPTZ NOT NULL DEFAULT NOW(), + has_pending_interrupt BOOLEAN NOT NULL DEFAULT FALSE, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); +``` + #### analytics_events (自定义 - 分析事件流) ```sql diff --git a/docs/deployment.md b/docs/deployment.md index f3a5da5..760fbaa 100644 --- a/docs/deployment.md +++ b/docs/deployment.md @@ -54,11 +54,19 @@ Set these in production (never commit secrets): | `ANTHROPIC_API_KEY` | Yes* | LLM provider API key | | `LLM_PROVIDER` | Yes | `anthropic`, `openai`, or `google` | | `LLM_MODEL` | Yes | Model name for your provider | +| `ADMIN_API_KEY` | Recommended | API key for admin endpoints (analytics, replay, openapi, WS). Leave empty to disable auth (dev mode only) | | `WEBHOOK_URL` | No | Escalation notification endpoint | | `SESSION_TTL_MINUTES` | No | Session timeout (default: 30) | *Or `OPENAI_API_KEY` / `GOOGLE_API_KEY` depending on `LLM_PROVIDER`. +### Authentication + +When `ADMIN_API_KEY` is set, all admin REST endpoints require the `X-API-Key` header, +and WebSocket connections require a `?token=` query parameter. + +When unset or empty, authentication is disabled (suitable for local development only). + ### HTTPS For production, place a reverse proxy (nginx, Caddy, or a load balancer) in @@ -87,10 +95,12 @@ cat backup.sql | docker compose exec -T postgres psql -U smart_support smart_sup ### Scaling -The backend is stateless (session state is in PostgreSQL via LangGraph's -PostgresSaver). You can run multiple backend replicas behind a load balancer. +The backend supports multi-worker deployments. 
LangGraph session state is +persisted in PostgreSQL via PostgresSaver. For full horizontal scaling, use +`PgSessionManager` and `PgInterruptManager` (instead of the default in-memory +managers) to share session and interrupt state across workers. -The WebSocket connections are session-specific. Use sticky sessions or a shared +WebSocket connections are session-specific. Use sticky sessions or a shared session backend if load balancing WebSockets across multiple instances. ## Manual / Development Setup @@ -139,7 +149,7 @@ GET /api/health Response: ```json -{"status": "ok", "version": "0.5.0"} +{"status": "ok", "version": "0.6.0"} ``` ### WebSocket health diff --git a/docs/openapi-import-guide.md b/docs/openapi-import-guide.md index 7f85cef..27d0ec5 100644 --- a/docs/openapi-import-guide.md +++ b/docs/openapi-import-guide.md @@ -86,7 +86,21 @@ Content-Type: application/json POST /api/openapi/jobs/{job_id}/approve ``` -No request body. Changes the job status to `approved`. +No request body. Generates tool code for each classified endpoint and produces +an agent YAML configuration. Response includes `generated_tools_count`. + +Response: +```json +{ + "job_id": "abc123", + "status": "approved", + "spec_url": "https://api.example.com/openapi.yaml", + "total_endpoints": 5, + "classified_count": 5, + "error_message": null, + "generated_tools_count": 5 +} +``` ## Access Type Classification diff --git a/docs/phases/eng-improvements-dev-log.md b/docs/phases/eng-improvements-dev-log.md new file mode 100644 index 0000000..1e12cb2 --- /dev/null +++ b/docs/phases/eng-improvements-dev-log.md @@ -0,0 +1,76 @@ +# Engineering Improvements -- Development Log + +> Status: COMPLETED +> Branch: `eng/engineering-improvements` +> Date started: 2026-04-06 +> Date completed: 2026-04-06 + +## What Was Built + +### Phase 1: Quick Wins (no new deps) + +1. 
**Interrupt Cleanup Background Task** -- Added asyncio background task in lifespan that calls `interrupt_manager.cleanup_expired()` every 60 seconds. Prevents unbounded memory growth from expired interrupts. + +2. **API Versioning** -- All REST endpoints prefixed with `/api/v1/` (was `/api/`). Updated 4 router prefixes, Docker healthcheck, all frontend fetch URLs, and all test assertions. WebSocket `/ws` endpoint unchanged. + +3. **Error Response Standardization** -- Added global exception handlers for `HTTPException`, `RequestValidationError`, and `Exception`. All error responses now use the same envelope format as success responses: `{"success": false, "data": null, "error": "..."}`. + +### Phase 2: Medium Items (new deps) + +4. **Alembic Database Migrations** -- Replaced inline DDL in `setup_app_tables()` with versioned Alembic migrations. Initial migration `001_initial_schema.py` captures all 4 tables + ALTER TABLE migration. `setup_app_tables()` preserved for tests. Production uses `run_alembic_migrations()`. + +5. **Structured Logging** -- Replaced stdlib `logging.getLogger()` with `structlog.get_logger()` across 10 files. Added `logging_config.py` with console (dev) and JSON (production) modes. Configurable via `LOG_FORMAT` env var. + +### Phase 3: Test Coverage + +6. **Integration Tests (+30)** -- Created 5 new test files: analytics API, replay API, OpenAPI API, error responses, session/interrupt lifecycle. Uses httpx.AsyncClient with ASGITransport for full API layer testing. + +7. **Frontend Tests (+57)** -- Created 12 new test files covering all components (ChatInput, ChatMessages, InterruptPrompt, ErrorBanner, NavBar, MetricCard, ReplayTimeline, AgentAction, Layout), pages (ChatPage, ReviewPage), and hooks (useWebSocket).
+ +## Code Structure + +### New files created +- `backend/app/logging_config.py` -- structlog configuration +- `backend/alembic.ini` -- Alembic config +- `backend/alembic/env.py` -- Migration environment +- `backend/alembic/versions/001_initial_schema.py` -- Initial migration +- `backend/tests/unit/test_interrupt_cleanup.py` (3 tests) +- `backend/tests/unit/test_error_responses.py` (6 tests) +- `backend/tests/unit/test_logging_config.py` (2 tests) +- `backend/tests/integration/test_analytics_api.py` (6 tests) +- `backend/tests/integration/test_replay_api.py` (6 tests) +- `backend/tests/integration/test_openapi_api.py` (5 tests) +- `backend/tests/integration/test_error_responses.py` (5 tests) +- `backend/tests/integration/test_session_interrupt_lifecycle.py` (8 tests) +- 12 frontend test files (57 tests total) + +### Modified files +- `backend/app/main.py` -- cleanup task, exception handlers, alembic, structlog +- `backend/app/db.py` -- added run_alembic_migrations() +- `backend/app/config.py` -- added log_format setting +- `backend/pyproject.toml` -- added alembic, structlog deps +- 4 router files -- `/api/v1/` prefix +- 10 files -- structlog migration +- `docker-compose.yml` -- healthcheck URL +- `frontend/src/api.ts` -- `/api/v1/` URLs +- All existing test files -- API path updates + error envelope assertions + +## Test Coverage + +- Backend: 557 tests (was 516), 89.75% coverage + - Unit: ~490 tests + - Integration: ~60 tests + - E2E: ~7 tests +- Frontend: 80 tests (was 23), 16 test files (was 4) + +## Deviations from Plan + +- Redis rate limiting deferred (single-worker sufficient for now) +- ConversationTracker verified correct by design (pool per-method), skipped +- Coverage dropped slightly from 90.26% to 89.75% due to new alembic/logging modules with partial test coverage (still well above 80% threshold) + +## Known Issues / Tech Debt + +- Rate limiting remains process-global (needs Redis for multi-worker) +- Alembic migrations not tested against real 
PostgreSQL in CI (would need running DB) +- Frontend test coverage could be deeper (e.g., WebSocket reconnect edge cases) diff --git a/frontend/src/api.test.ts b/frontend/src/api.test.ts index da6859f..dc1d4bf 100644 --- a/frontend/src/api.test.ts +++ b/frontend/src/api.test.ts @@ -31,14 +31,14 @@ describe("fetchConversations", () => { const result = await fetchConversations(); expect(result.conversations).toHaveLength(1); expect(result.total).toBe(1); - expect(mockFetch).toHaveBeenCalledWith("/api/conversations?page=1&per_page=20"); + expect(mockFetch).toHaveBeenCalledWith("/api/v1/conversations?page=1&per_page=20"); }); it("passes custom page and perPage", async () => { mockFetch.mockResolvedValue(jsonResponse({ success: true, data: { conversations: [], total: 0, page: 2, per_page: 10 }, error: null })); await fetchConversations(2, 10); - expect(mockFetch).toHaveBeenCalledWith("/api/conversations?page=2&per_page=10"); + expect(mockFetch).toHaveBeenCalledWith("/api/v1/conversations?page=2&per_page=10"); }); it("throws on HTTP error", async () => { @@ -80,7 +80,7 @@ describe("fetchReplay", () => { mockFetch.mockResolvedValue(jsonResponse({ success: true, data: { thread_id: "a/b", total_steps: 0, page: 1, per_page: 20, steps: [] }, error: null })); await fetchReplay("a/b"); - expect(mockFetch).toHaveBeenCalledWith("/api/replay/a%2Fb?page=1&per_page=20"); + expect(mockFetch).toHaveBeenCalledWith("/api/v1/replay/a%2Fb?page=1&per_page=20"); }); it("throws on HTTP error", async () => { @@ -112,6 +112,6 @@ describe("fetchAnalytics", () => { mockFetch.mockResolvedValue(jsonResponse({ success: true, data: { range: "7d" }, error: null })); await fetchAnalytics(); - expect(mockFetch).toHaveBeenCalledWith("/api/analytics?range=7d"); + expect(mockFetch).toHaveBeenCalledWith("/api/v1/analytics?range=7d"); }); }); diff --git a/frontend/src/api.ts b/frontend/src/api.ts index b49263d..2e9dc49 100644 --- a/frontend/src/api.ts +++ b/frontend/src/api.ts @@ -84,7 +84,7 @@ export async 
function fetchConversations( perPage = 20 ): Promise { return apiFetch( - `/api/conversations?page=${page}&per_page=${perPage}` + `/api/v1/conversations?page=${page}&per_page=${perPage}` ); } @@ -94,12 +94,12 @@ export async function fetchReplay( perPage = 20 ): Promise { return apiFetch( - `/api/replay/${encodeURIComponent(threadId)}?page=${page}&per_page=${perPage}` + `/api/v1/replay/${encodeURIComponent(threadId)}?page=${page}&per_page=${perPage}` ); } export async function fetchAnalytics(range = "7d"): Promise { - return apiFetch(`/api/analytics?range=${range}`); + return apiFetch(`/api/v1/analytics?range=${range}`); } // -- OpenAPI import -- @@ -143,11 +143,11 @@ async function apiPost(path: string, body: unknown): Promise { } export async function startImport(url: string): Promise { - return apiPost("/api/openapi/import", { url }); + return apiPost("/api/v1/openapi/import", { url }); } export async function fetchImportJob(jobId: string): Promise { - const res = await fetch(`${API_BASE}/api/openapi/jobs/${encodeURIComponent(jobId)}`); + const res = await fetch(`${API_BASE}/api/v1/openapi/jobs/${encodeURIComponent(jobId)}`); if (!res.ok) { throw new Error(`API error ${res.status}: ${res.statusText}`); } @@ -158,7 +158,7 @@ export async function fetchClassifications( jobId: string ): Promise { const res = await fetch( - `${API_BASE}/api/openapi/jobs/${encodeURIComponent(jobId)}/classifications` + `${API_BASE}/api/v1/openapi/jobs/${encodeURIComponent(jobId)}/classifications` ); if (!res.ok) { throw new Error(`API error ${res.status}: ${res.statusText}`); @@ -168,7 +168,7 @@ export async function fetchClassifications( export async function approveJob(jobId: string): Promise { return apiPost( - `/api/openapi/jobs/${encodeURIComponent(jobId)}/approve`, + `/api/v1/openapi/jobs/${encodeURIComponent(jobId)}/approve`, {} ); } diff --git a/frontend/src/components/AgentAction.test.tsx b/frontend/src/components/AgentAction.test.tsx new file mode 100644 index 
0000000..b895139 --- /dev/null +++ b/frontend/src/components/AgentAction.test.tsx @@ -0,0 +1,47 @@ +import { describe, it, expect } from "vitest"; +import { render, screen, fireEvent } from "@testing-library/react"; +import { AgentAction } from "./AgentAction"; +import type { ToolAction } from "../types"; + +function makeAction(overrides: Partial = {}): ToolAction { + return { + id: "action-1", + agent: "OrderAgent", + tool: "get_order", + args: { order_id: "ORD-100" }, + timestamp: Date.now(), + ...overrides, + }; +} + +describe("AgentAction", () => { + it("renders agent name and tool name", () => { + render(); + + expect(screen.getByText("OrderAgent")).toBeInTheDocument(); + expect(screen.getByText("get_order")).toBeInTheDocument(); + }); + + it("shows args and result when expanded", () => { + const action = makeAction({ result: { status: "shipped" } }); + render(); + + // Click header to expand + fireEvent.click(screen.getByText("OrderAgent")); + + expect(screen.getByText("Args:")).toBeInTheDocument(); + expect(screen.getByText("Result:")).toBeInTheDocument(); + expect(screen.getByText(/"order_id": "ORD-100"/)).toBeInTheDocument(); + expect(screen.getByText(/"status": "shipped"/)).toBeInTheDocument(); + }); + + it("does not show result section when result is undefined", () => { + render(); + + // Expand + fireEvent.click(screen.getByText("OrderAgent")); + + expect(screen.getByText("Args:")).toBeInTheDocument(); + expect(screen.queryByText("Result:")).not.toBeInTheDocument(); + }); +}); diff --git a/frontend/src/components/ChatInput.test.tsx b/frontend/src/components/ChatInput.test.tsx new file mode 100644 index 0000000..c3b40a7 --- /dev/null +++ b/frontend/src/components/ChatInput.test.tsx @@ -0,0 +1,53 @@ +import { describe, it, expect, vi } from "vitest"; +import { render, screen, fireEvent } from "@testing-library/react"; +import { ChatInput } from "./ChatInput"; + +describe("ChatInput", () => { + it("renders input field and send button", () => { + render(); 
+ + expect(screen.getByPlaceholderText("Message Smart Support...")).toBeInTheDocument(); + expect(screen.getByRole("button", { name: "Send Message" })).toBeInTheDocument(); + }); + + it("calls onSend with trimmed content when form is submitted via Enter", () => { + const onSend = vi.fn(); + render(); + + const input = screen.getByPlaceholderText("Message Smart Support..."); + fireEvent.change(input, { target: { value: " Hello world " } }); + fireEvent.keyDown(input, { key: "Enter" }); + + expect(onSend).toHaveBeenCalledWith("Hello world"); + }); + + it("clears input after successful send", () => { + const onSend = vi.fn(); + render(); + + const input = screen.getByPlaceholderText("Message Smart Support...") as HTMLInputElement; + fireEvent.change(input, { target: { value: "Test message" } }); + fireEvent.keyDown(input, { key: "Enter" }); + + expect(input.value).toBe(""); + }); + + it("shows disabled placeholder and disables input when disabled", () => { + render(); + + const input = screen.getByPlaceholderText("Agent is working...") as HTMLInputElement; + expect(input.disabled).toBe(true); + expect(screen.getByRole("button", { name: "Send Message" })).toBeDisabled(); + }); + + it("does not call onSend when input is empty or whitespace", () => { + const onSend = vi.fn(); + render(); + + const input = screen.getByPlaceholderText("Message Smart Support..."); + fireEvent.change(input, { target: { value: " " } }); + fireEvent.keyDown(input, { key: "Enter" }); + + expect(onSend).not.toHaveBeenCalled(); + }); +}); diff --git a/frontend/src/components/ChatMessages.test.tsx b/frontend/src/components/ChatMessages.test.tsx new file mode 100644 index 0000000..6d92e74 --- /dev/null +++ b/frontend/src/components/ChatMessages.test.tsx @@ -0,0 +1,59 @@ +import { describe, it, expect, vi } from "vitest"; +import { render, screen } from "@testing-library/react"; +import { ChatMessages } from "./ChatMessages"; +import type { ChatMessage } from "../types"; + +// Mock react-markdown to 
avoid complex rendering +vi.mock("react-markdown", () => ({ + default: ({ children }: { children: string }) => {children}, +})); + +describe("ChatMessages", () => { + it("renders welcome message when messages array is empty", () => { + render(); + + expect(screen.getByText("Hello! How can I help you today?")).toBeInTheDocument(); + expect(screen.getByText("Smart Support")).toBeInTheDocument(); + }); + + it("renders user messages with correct sender label", () => { + const messages: ChatMessage[] = [ + { id: "1", sender: "user", content: "I need help", timestamp: Date.now() }, + ]; + render(); + + expect(screen.getByText("You")).toBeInTheDocument(); + expect(screen.getByText("I need help")).toBeInTheDocument(); + expect(screen.getByText("Me")).toBeInTheDocument(); + }); + + it("renders agent messages with agent name", () => { + const messages: ChatMessage[] = [ + { id: "2", sender: "agent", agent: "OrderBot", content: "Sure, let me check.", timestamp: Date.now() }, + ]; + render(); + + expect(screen.getByText("OrderBot")).toBeInTheDocument(); + expect(screen.getByText("Sure, let me check.")).toBeInTheDocument(); + expect(screen.getByText("AI")).toBeInTheDocument(); + }); + + it("shows streaming cursor for messages being streamed", () => { + const messages: ChatMessage[] = [ + { id: "3", sender: "agent", agent: "Bot", content: "Processing", timestamp: Date.now(), isStreaming: true }, + ]; + render(); + + expect(screen.getByText("|")).toBeInTheDocument(); + expect(document.querySelector(".cursor-blink")).toBeTruthy(); + }); + + it("shows fallback agent label when agent field is missing", () => { + const messages: ChatMessage[] = [ + { id: "4", sender: "agent", content: "Generic response", timestamp: Date.now() }, + ]; + render(); + + expect(screen.getByText("Agent")).toBeInTheDocument(); + }); +}); diff --git a/frontend/src/components/ErrorBanner.test.tsx b/frontend/src/components/ErrorBanner.test.tsx new file mode 100644 index 0000000..570be20 --- /dev/null +++ 
b/frontend/src/components/ErrorBanner.test.tsx @@ -0,0 +1,33 @@ +import { describe, it, expect, vi } from "vitest"; +import { render, screen, fireEvent } from "@testing-library/react"; +import { ErrorBanner } from "./ErrorBanner"; + +describe("ErrorBanner", () => { + it("returns null when status is connected", () => { + const { container } = render(); + expect(container.innerHTML).toBe(""); + }); + + it("shows disconnection message when status is disconnected", () => { + render(); + + expect(screen.getByText("Disconnected from server. Retrying...")).toBeInTheDocument(); + expect(screen.getByRole("alert")).toBeInTheDocument(); + }); + + it("shows connecting message when status is connecting", () => { + render(); + + expect(screen.getByText("Connecting to server...")).toBeInTheDocument(); + // No reconnect button while connecting + expect(screen.queryByText("Reconnect")).not.toBeInTheDocument(); + }); + + it("calls onReconnect when reconnect button is clicked", () => { + const onReconnect = vi.fn(); + render(); + + fireEvent.click(screen.getByText("Reconnect")); + expect(onReconnect).toHaveBeenCalledTimes(1); + }); +}); diff --git a/frontend/src/components/InterruptPrompt.test.tsx b/frontend/src/components/InterruptPrompt.test.tsx new file mode 100644 index 0000000..9b9d841 --- /dev/null +++ b/frontend/src/components/InterruptPrompt.test.tsx @@ -0,0 +1,58 @@ +import { describe, it, expect, vi } from "vitest"; +import { render, screen, fireEvent } from "@testing-library/react"; +import { InterruptPrompt } from "./InterruptPrompt"; +import type { InterruptMessage } from "../types"; + +describe("InterruptPrompt", () => { + const baseInterrupt: InterruptMessage = { + type: "interrupt", + thread_id: "t1", + action: "cancel_order", + params: {}, + }; + + it("renders action name and approval title", () => { + render(); + + expect(screen.getByText("Action Requires Approval")).toBeInTheDocument(); + expect(screen.getByText("cancel_order")).toBeInTheDocument(); + }); + + 
it("calls onRespond with true when Approve button is clicked", () => { + const onRespond = vi.fn(); + render(); + + fireEvent.click(screen.getByText("Approve Action")); + expect(onRespond).toHaveBeenCalledWith(true); + }); + + it("calls onRespond with false when Reject button is clicked", () => { + const onRespond = vi.fn(); + render(); + + fireEvent.click(screen.getByText("Reject & Escalate")); + expect(onRespond).toHaveBeenCalledWith(false); + }); + + it("displays order_id parameter when present", () => { + const interrupt: InterruptMessage = { + ...baseInterrupt, + params: { order_id: "ORD-12345" }, + }; + render(); + + expect(screen.getByText("Target Order ID")).toBeInTheDocument(); + expect(screen.getByText("ORD-12345")).toBeInTheDocument(); + }); + + it("displays message parameter when present", () => { + const interrupt: InterruptMessage = { + ...baseInterrupt, + params: { message: "This will refund $50" }, + }; + render(); + + expect(screen.getByText("Detail Message")).toBeInTheDocument(); + expect(screen.getByText("This will refund $50")).toBeInTheDocument(); + }); +}); diff --git a/frontend/src/components/Layout.test.tsx b/frontend/src/components/Layout.test.tsx new file mode 100644 index 0000000..d51bd05 --- /dev/null +++ b/frontend/src/components/Layout.test.tsx @@ -0,0 +1,39 @@ +import { describe, it, expect, vi } from "vitest"; +import { render, screen } from "@testing-library/react"; +import { MemoryRouter, Routes, Route } from "react-router-dom"; +import { Layout } from "./Layout"; + +// Mock NavBar to simplify layout tests +vi.mock("./NavBar", () => ({ + NavBar: () => , +})); + +function renderLayout(path = "/") { + return render( + + + }> + Home Content} /> + Dashboard Content} /> + + + + ); +} + +describe("Layout", () => { + it("renders NavBar component", () => { + renderLayout(); + expect(screen.getByTestId("navbar")).toBeInTheDocument(); + }); + + it("renders child route content via Outlet", () => { + renderLayout("/"); + 
expect(screen.getByText("Home Content")).toBeInTheDocument(); + }); + + it("renders correct content for different routes", () => { + renderLayout("/dashboard"); + expect(screen.getByText("Dashboard Content")).toBeInTheDocument(); + }); +}); diff --git a/frontend/src/components/MetricCard.test.tsx b/frontend/src/components/MetricCard.test.tsx new file mode 100644 index 0000000..5fc2d1d --- /dev/null +++ b/frontend/src/components/MetricCard.test.tsx @@ -0,0 +1,28 @@ +import { describe, it, expect } from "vitest"; +import { render, screen } from "@testing-library/react"; +import { MetricCard } from "./MetricCard"; + +describe("MetricCard", () => { + it("renders label and value", () => { + render(); + + expect(screen.getByText("Total Users")).toBeInTheDocument(); + expect(screen.getByText("42")).toBeInTheDocument(); + }); + + it("renders with unit prefix and suffix", () => { + render(); + + expect(screen.getByText("Cost")).toBeInTheDocument(); + expect(screen.getByText("$")).toBeInTheDocument(); + expect(screen.getByText("3.50")).toBeInTheDocument(); + expect(screen.getByText("/mo")).toBeInTheDocument(); + }); + + it("handles zero value correctly", () => { + render(); + + expect(screen.getByText("Errors")).toBeInTheDocument(); + expect(screen.getByText("0")).toBeInTheDocument(); + }); +}); diff --git a/frontend/src/components/NavBar.test.tsx b/frontend/src/components/NavBar.test.tsx new file mode 100644 index 0000000..8bd9606 --- /dev/null +++ b/frontend/src/components/NavBar.test.tsx @@ -0,0 +1,54 @@ +import { describe, it, expect } from "vitest"; +import { render, screen } from "@testing-library/react"; +import { MemoryRouter } from "react-router-dom"; +import { NavBar } from "./NavBar"; + +function renderNavBar(initialPath = "/") { + return render( + + + + ); +} + +describe("NavBar", () => { + it("renders all navigation links", () => { + renderNavBar(); + + expect(screen.getByText("Dashboard")).toBeInTheDocument(); + 
expect(screen.getByText("Inbox")).toBeInTheDocument(); + expect(screen.getByText("Conversation Replay")).toBeInTheDocument(); + expect(screen.getByText("Agents & Tools")).toBeInTheDocument(); + }); + + it("navigation links point to correct routes", () => { + renderNavBar(); + + const dashboardLink = screen.getByText("Dashboard").closest("a"); + expect(dashboardLink).toHaveAttribute("href", "/dashboard"); + + const inboxLink = screen.getByText("Inbox").closest("a"); + expect(inboxLink).toHaveAttribute("href", "/"); + + const replayLink = screen.getByText("Conversation Replay").closest("a"); + expect(replayLink).toHaveAttribute("href", "/replay"); + + const reviewLink = screen.getByText("Agents & Tools").closest("a"); + expect(reviewLink).toHaveAttribute("href", "/review"); + }); + + it("active link has active class when on matching route", () => { + renderNavBar("/dashboard"); + + const dashboardLink = screen.getByText("Dashboard").closest("a"); + expect(dashboardLink?.className).toContain("active"); + + const inboxLink = screen.getByText("Inbox").closest("a"); + expect(inboxLink?.className).not.toContain("active"); + }); + + it("renders brand name", () => { + renderNavBar(); + expect(screen.getByText("Nexus AI")).toBeInTheDocument(); + }); +}); diff --git a/frontend/src/components/ReplayTimeline.test.tsx b/frontend/src/components/ReplayTimeline.test.tsx new file mode 100644 index 0000000..2f62d86 --- /dev/null +++ b/frontend/src/components/ReplayTimeline.test.tsx @@ -0,0 +1,69 @@ +import { describe, it, expect } from "vitest"; +import { render, screen, fireEvent } from "@testing-library/react"; +import { ReplayTimeline } from "./ReplayTimeline"; +import type { ReplayStep } from "../api"; + +function makeStep(overrides: Partial = {}): ReplayStep { + return { + step: 1, + type: "message", + content: "Hello", + agent: null, + tool: null, + params: null, + result: null, + timestamp: "2026-04-01T12:00:00Z", + ...overrides, + }; +} + +describe("ReplayTimeline", () => { + 
it("returns null when steps array is empty", () => { + const { container } = render(); + expect(container.innerHTML).toBe(""); + }); + + it("renders a list of steps with type badges", () => { + const steps = [ + makeStep({ step: 1, type: "message", content: "User said hi" }), + makeStep({ step: 2, type: "tool_call", content: "Calling API", agent: "OrderBot", tool: "get_order" }), + ]; + render(); + + expect(screen.getByText("message")).toBeInTheDocument(); + expect(screen.getByText("tool call")).toBeInTheDocument(); + expect(screen.getByText("User said hi")).toBeInTheDocument(); + expect(screen.getByText("OrderBot")).toBeInTheDocument(); + expect(screen.getByText("get_order()")).toBeInTheDocument(); + }); + + it("expands step details when View JSON Payload button is clicked", () => { + const steps = [ + makeStep({ + step: 1, + type: "tool_call", + params: { order_id: "123" }, + result: { status: "ok" }, + }), + ]; + render(); + + const expandButton = screen.getByText("View JSON Payload", { exact: false }); + expect(expandButton).toBeInTheDocument(); + + fireEvent.click(expandButton); + + // After expanding, the JSON payload should be visible + expect(screen.getByText("Hide JSON Payload", { exact: false })).toBeInTheDocument(); + expect(screen.getByText(/"order_id": "123"/)).toBeInTheDocument(); + }); + + it("does not show expand button when step has no params or result", () => { + const steps = [ + makeStep({ step: 1, type: "message", params: null, result: null }), + ]; + render(); + + expect(screen.queryByText("View JSON Payload", { exact: false })).not.toBeInTheDocument(); + }); +}); diff --git a/frontend/src/hooks/useWebSocket.test.ts b/frontend/src/hooks/useWebSocket.test.ts new file mode 100644 index 0000000..707b64d --- /dev/null +++ b/frontend/src/hooks/useWebSocket.test.ts @@ -0,0 +1,221 @@ +import { describe, it, expect, vi, beforeEach, afterEach } from "vitest"; +import { renderHook, act } from "@testing-library/react"; +import { useWebSocket } from 
"./useWebSocket"; + +// Mock sessionStorage +const mockSessionStorage: Record = {}; +vi.stubGlobal("sessionStorage", { + getItem: (key: string) => mockSessionStorage[key] ?? null, + setItem: (key: string, value: string) => { + mockSessionStorage[key] = value; + }, +}); + +// Mock crypto.randomUUID +vi.stubGlobal("crypto", { randomUUID: () => "test-uuid-1234" }); + +// Mock WebSocket +class MockWebSocket { + static OPEN = 1; + static CLOSED = 3; + static instances: MockWebSocket[] = []; + + url: string; + readyState = 0; + onopen: (() => void) | null = null; + onclose: (() => void) | null = null; + onmessage: ((event: { data: string }) => void) | null = null; + onerror: (() => void) | null = null; + send = vi.fn(); + close = vi.fn().mockImplementation(() => { + this.readyState = MockWebSocket.CLOSED; + // Trigger onclose asynchronously like real WebSocket + setTimeout(() => this.onclose?.(), 0); + }); + + constructor(url: string) { + this.url = url; + MockWebSocket.instances.push(this); + } + + simulateOpen() { + this.readyState = MockWebSocket.OPEN; + this.onopen?.(); + } + + simulateMessage(data: unknown) { + this.onmessage?.({ data: JSON.stringify(data) }); + } + + simulateClose() { + this.readyState = MockWebSocket.CLOSED; + this.onclose?.(); + } + + simulateError() { + this.onerror?.(); + } +} + +vi.stubGlobal("WebSocket", MockWebSocket); + +beforeEach(() => { + MockWebSocket.instances = []; + delete mockSessionStorage["smart_support_thread_id"]; + vi.useFakeTimers(); +}); + +afterEach(() => { + vi.useRealTimers(); +}); + +describe("useWebSocket", () => { + it("establishes connection with correct URL on mount", () => { + const onMessage = vi.fn(); + renderHook(() => useWebSocket(onMessage)); + + expect(MockWebSocket.instances).toHaveLength(1); + expect(MockWebSocket.instances[0].url).toContain("/ws"); + }); + + it("sets status to connected when WebSocket opens", () => { + const onMessage = vi.fn(); + const { result } = renderHook(() => useWebSocket(onMessage)); 
+ + expect(result.current.status).toBe("connecting"); + + act(() => { + MockWebSocket.instances[0].simulateOpen(); + }); + + expect(result.current.status).toBe("connected"); + }); + + it("parses incoming JSON messages and dispatches to callback", () => { + const onMessage = vi.fn(); + renderHook(() => useWebSocket(onMessage)); + + act(() => { + MockWebSocket.instances[0].simulateOpen(); + }); + + const serverMsg = { type: "token", agent: "bot", content: "Hello" }; + act(() => { + MockWebSocket.instances[0].simulateMessage(serverMsg); + }); + + expect(onMessage).toHaveBeenCalledWith(serverMsg); + }); + + it("sends JSON through WebSocket via sendMessage", () => { + const onMessage = vi.fn(); + const { result } = renderHook(() => useWebSocket(onMessage)); + + act(() => { + MockWebSocket.instances[0].simulateOpen(); + }); + + act(() => { + result.current.sendMessage("Hi there"); + }); + + expect(MockWebSocket.instances[0].send).toHaveBeenCalledTimes(1); + const sent = JSON.parse(MockWebSocket.instances[0].send.mock.calls[0][0]); + expect(sent.type).toBe("message"); + expect(sent.content).toBe("Hi there"); + expect(sent.thread_id).toBeDefined(); + }); + + it("calls onDisconnect when WebSocket closes", () => { + const onMessage = vi.fn(); + const onDisconnect = vi.fn(); + renderHook(() => useWebSocket(onMessage, { onDisconnect })); + + act(() => { + MockWebSocket.instances[0].simulateOpen(); + }); + + act(() => { + MockWebSocket.instances[0].simulateClose(); + }); + + expect(onDisconnect).toHaveBeenCalledTimes(1); + }); + + it("sets status to disconnected on close and attempts reconnect", () => { + const onMessage = vi.fn(); + const { result } = renderHook(() => useWebSocket(onMessage)); + + act(() => { + MockWebSocket.instances[0].simulateOpen(); + }); + + act(() => { + MockWebSocket.instances[0].simulateClose(); + }); + + expect(result.current.status).toBe("disconnected"); + + // After timeout, a new WebSocket should be created (reconnect attempt) + act(() => { + 
vi.advanceTimersByTime(1500); + }); + + expect(MockWebSocket.instances.length).toBeGreaterThanOrEqual(2); + }); + + it("closes WebSocket on error event", () => { + const onMessage = vi.fn(); + renderHook(() => useWebSocket(onMessage)); + + const ws = MockWebSocket.instances[0]; + act(() => { + ws.simulateError(); + }); + + expect(ws.close).toHaveBeenCalledTimes(1); + }); + + it("reconnect resets retries and creates a new connection", () => { + const onMessage = vi.fn(); + const { result } = renderHook(() => useWebSocket(onMessage)); + + act(() => { + MockWebSocket.instances[0].simulateOpen(); + }); + + const wsBeforeReconnect = MockWebSocket.instances[0]; + + act(() => { + result.current.reconnect(); + }); + + // The old socket should have been closed + expect(wsBeforeReconnect.close).toHaveBeenCalled(); + + // Let the close callback fire and reconnect timer run + act(() => { + vi.advanceTimersByTime(100); + }); + + // A new WebSocket should have been created + expect(MockWebSocket.instances.length).toBeGreaterThan(1); + }); + + it("sends interrupt response with approved flag", () => { + const onMessage = vi.fn(); + const { result } = renderHook(() => useWebSocket(onMessage)); + + act(() => { + MockWebSocket.instances[0].simulateOpen(); + }); + + act(() => { + result.current.sendInterruptResponse(true); + }); + + const sent = JSON.parse(MockWebSocket.instances[0].send.mock.calls[0][0]); + expect(sent.type).toBe("interrupt_response"); + expect(sent.approved).toBe(true); + }); +}); diff --git a/frontend/src/pages/ChatPage.test.tsx b/frontend/src/pages/ChatPage.test.tsx new file mode 100644 index 0000000..06caddc --- /dev/null +++ b/frontend/src/pages/ChatPage.test.tsx @@ -0,0 +1,106 @@ +import { describe, it, expect, vi, beforeEach } from "vitest"; +import { render, screen, fireEvent, waitFor, act } from "@testing-library/react"; +import { ChatPage } from "./ChatPage"; + +// Mock react-markdown +vi.mock("react-markdown", () => ({ + default: ({ children }: { children: 
string }) => {children}, +})); + +// Mock crypto.randomUUID for stable IDs +vi.stubGlobal("crypto", { randomUUID: () => `uuid-${Date.now()}-${Math.random()}` }); + +// Capture the onMessage callback from the hook +let capturedOnMessage: ((msg: unknown) => void) | null = null; +const mockSendMessage = vi.fn(); +const mockSendInterruptResponse = vi.fn(); +const mockReconnect = vi.fn(); +let mockStatus = "connected"; + +vi.mock("../hooks/useWebSocket", () => ({ + useWebSocket: (onMessage: (msg: unknown) => void) => { + capturedOnMessage = onMessage; + return { + status: mockStatus, + threadId: "test-thread", + sendMessage: mockSendMessage, + sendInterruptResponse: mockSendInterruptResponse, + reconnect: mockReconnect, + }; + }, +})); + +beforeEach(() => { + capturedOnMessage = null; + mockSendMessage.mockReset(); + mockSendInterruptResponse.mockReset(); + mockReconnect.mockReset(); + mockStatus = "connected"; +}); + +describe("ChatPage", () => { + it("renders chat interface with input field and header", () => { + render(); + + expect(screen.getByText("Inbox")).toBeInTheDocument(); + expect(screen.getByPlaceholderText("Message Smart Support...")).toBeInTheDocument(); + expect(screen.getByRole("button", { name: "Send Message" })).toBeInTheDocument(); + }); + + it("user can type and submit a message", () => { + render(); + + const input = screen.getByPlaceholderText("Message Smart Support..."); + fireEvent.change(input, { target: { value: "Hello bot" } }); + fireEvent.keyDown(input, { key: "Enter" }); + + expect(mockSendMessage).toHaveBeenCalledWith("Hello bot"); + expect(screen.getByText("Hello bot")).toBeInTheDocument(); + expect(screen.getByText("You")).toBeInTheDocument(); + }); + + it("displays streaming tokens as they arrive", () => { + render(); + + act(() => { + capturedOnMessage?.({ type: "token", agent: "Bot", content: "Hello " }); + }); + act(() => { + capturedOnMessage?.({ type: "token", agent: "Bot", content: "world" }); + }); + + 
expect(screen.getByText("Hello world")).toBeInTheDocument(); + }); + + it("shows interrupt prompt when interrupt message received", () => { + render(); + + act(() => { + capturedOnMessage?.({ + type: "interrupt", + thread_id: "t1", + action: "cancel_order", + params: { order_id: "ORD-999" }, + }); + }); + + expect(screen.getByText("Action Requires Approval")).toBeInTheDocument(); + expect(screen.getByText("cancel_order")).toBeInTheDocument(); + }); + + it("shows error message when server sends error", () => { + render(); + + act(() => { + capturedOnMessage?.({ type: "error", message: "Something went wrong" }); + }); + + expect(screen.getByText("Error: Something went wrong")).toBeInTheDocument(); + }); + + it("renders welcome message in empty state", () => { + render(); + + expect(screen.getByText("Hello! How can I help you today?")).toBeInTheDocument(); + }); +}); diff --git a/frontend/src/pages/ReviewPage.test.tsx b/frontend/src/pages/ReviewPage.test.tsx new file mode 100644 index 0000000..57437b8 --- /dev/null +++ b/frontend/src/pages/ReviewPage.test.tsx @@ -0,0 +1,200 @@ +import { describe, it, expect, vi, beforeEach } from "vitest"; +import { render, screen, fireEvent, waitFor } from "@testing-library/react"; +import { ReviewPage } from "./ReviewPage"; + +vi.mock("../api", () => ({ + startImport: vi.fn(), + fetchImportJob: vi.fn(), + fetchClassifications: vi.fn(), + approveJob: vi.fn(), +})); + +import { startImport, fetchImportJob, fetchClassifications, approveJob } from "../api"; +const mockStartImport = vi.mocked(startImport); +const mockFetchImportJob = vi.mocked(fetchImportJob); +const mockFetchClassifications = vi.mocked(fetchClassifications); +const mockApproveJob = vi.mocked(approveJob); + +beforeEach(() => { + mockStartImport.mockReset(); + mockFetchImportJob.mockReset(); + mockFetchClassifications.mockReset(); + mockApproveJob.mockReset(); +}); + +describe("ReviewPage", () => { + it("renders the OpenAPI URL input form", () => { + render(); + + 
expect(screen.getByText("Agents & Tools Registry")).toBeInTheDocument(); + expect(screen.getByPlaceholderText("https://example.com/openapi.yaml")).toBeInTheDocument(); + expect(screen.getByText("Scan Tools")).toBeInTheDocument(); + }); + + it("submit form triggers API call with entered URL", async () => { + mockStartImport.mockResolvedValue({ + job_id: "job-1", + status: "processing", + spec_url: "https://example.com/openapi.yaml", + total_endpoints: 5, + classified_count: 0, + error_message: null, + }); + mockFetchImportJob.mockResolvedValue({ + job_id: "job-1", + status: "processing", + spec_url: "https://example.com/openapi.yaml", + total_endpoints: 5, + classified_count: 0, + error_message: null, + }); + + render(); + + const input = screen.getByPlaceholderText("https://example.com/openapi.yaml"); + fireEvent.change(input, { target: { value: "https://example.com/openapi.yaml" } }); + fireEvent.click(screen.getByText("Scan Tools")); + + await waitFor(() => { + expect(mockStartImport).toHaveBeenCalledWith("https://example.com/openapi.yaml"); + }); + }); + + it("shows loading state during import", async () => { + mockStartImport.mockReturnValue(new Promise(() => {})); // never resolves + + render(); + + const input = screen.getByPlaceholderText("https://example.com/openapi.yaml"); + fireEvent.change(input, { target: { value: "https://api.test.com/spec.json" } }); + fireEvent.click(screen.getByText("Scan Tools")); + + await waitFor(() => { + expect(screen.getByText("Importing...")).toBeInTheDocument(); + }); + }); + + it("displays classification results after job completes", async () => { + mockStartImport.mockResolvedValue({ + job_id: "job-1", + status: "done", + spec_url: "https://example.com/openapi.yaml", + total_endpoints: 2, + classified_count: 2, + error_message: null, + }); + mockFetchImportJob.mockResolvedValue({ + job_id: "job-1", + status: "done", + spec_url: "https://example.com/openapi.yaml", + total_endpoints: 2, + classified_count: 2, + 
error_message: null, + }); + mockFetchClassifications.mockResolvedValue([ + { + index: 0, + access_type: "read", + needs_interrupt: false, + agent_group: "OrderAgent", + confidence: 0.95, + customer_params: [], + endpoint: { path: "/orders", method: "get", operation_id: "getOrders", summary: "List orders", description: "" }, + }, + { + index: 1, + access_type: "write", + needs_interrupt: true, + agent_group: "OrderAgent", + confidence: 0.9, + customer_params: ["order_id"], + endpoint: { path: "/orders/{id}/cancel", method: "post", operation_id: "cancelOrder", summary: "Cancel an order", description: "" }, + }, + ]); + + render(); + + const input = screen.getByPlaceholderText("https://example.com/openapi.yaml"); + fireEvent.change(input, { target: { value: "https://example.com/openapi.yaml" } }); + fireEvent.click(screen.getByText("Scan Tools")); + + await waitFor(() => { + expect(screen.getByText("Assigned Capabilities (2)")).toBeInTheDocument(); + }); + expect(screen.getByText("OrderAgent")).toBeInTheDocument(); + expect(screen.getByText("/orders")).toBeInTheDocument(); + expect(screen.getByText("List orders")).toBeInTheDocument(); + }); + + it("shows error on API failure", async () => { + mockStartImport.mockRejectedValue(new Error("Network timeout")); + + render(); + + const input = screen.getByPlaceholderText("https://example.com/openapi.yaml"); + fireEvent.change(input, { target: { value: "https://example.com/openapi.yaml" } }); + fireEvent.click(screen.getByText("Scan Tools")); + + await waitFor(() => { + expect(screen.getByText("Error: Network timeout")).toBeInTheDocument(); + }); + }); + + it("shows success message after approval", async () => { + // Set up initial state with classifications + mockStartImport.mockResolvedValue({ + job_id: "job-1", + status: "done", + spec_url: "https://example.com/openapi.yaml", + total_endpoints: 1, + classified_count: 1, + error_message: null, + }); + mockFetchImportJob.mockResolvedValue({ + job_id: "job-1", + status: 
"done", + spec_url: "https://example.com/openapi.yaml", + total_endpoints: 1, + classified_count: 1, + error_message: null, + }); + mockFetchClassifications.mockResolvedValue([ + { + index: 0, + access_type: "read", + needs_interrupt: false, + agent_group: "TestAgent", + confidence: 0.9, + customer_params: [], + endpoint: { path: "/test", method: "get", operation_id: "test", summary: "Test endpoint", description: "" }, + }, + ]); + mockApproveJob.mockResolvedValue({ + job_id: "job-1", + status: "approved", + spec_url: "https://example.com/openapi.yaml", + total_endpoints: 1, + classified_count: 1, + error_message: null, + generated_tools_count: 3, + }); + + render(); + + // Import first + const input = screen.getByPlaceholderText("https://example.com/openapi.yaml"); + fireEvent.change(input, { target: { value: "https://example.com/openapi.yaml" } }); + fireEvent.click(screen.getByText("Scan Tools")); + + await waitFor(() => { + expect(screen.getByText("Save Configuration")).toBeInTheDocument(); + }); + + // Approve + fireEvent.click(screen.getByText("Save Configuration")); + + await waitFor(() => { + expect(screen.getByText("Configuration saved. 3 tools generated.")).toBeInTheDocument(); + }); + }); +});