refactor: engineering improvements -- API versioning, structured logging, Alembic, error standardization, test coverage

- API versioning: all REST endpoints prefixed with /api/v1/
- Structured logging: replaced stdlib logging with structlog (console/JSON modes)
- Alembic migrations: versioned DB schema with initial migration
- Error standardization: global exception handlers for consistent envelope format
- Interrupt cleanup: asyncio background task for expired interrupt removal
- Integration tests: +30 tests (analytics, replay, openapi, error, session APIs)
- Frontend tests: +57 tests (all components, pages, useWebSocket hook)
- Backend: 557 tests, 89.75% coverage | Frontend: 80 tests, 16 test files
This commit is contained in:
Yaojia Wang
2026-04-06 23:19:29 +02:00
parent af53111928
commit f0699436c5
59 changed files with 2846 additions and 149 deletions

149
backend/alembic.ini Normal file
View File

@@ -0,0 +1,149 @@
# A generic, single database configuration.
[alembic]
# path to migration scripts.
# this is typically a path given in POSIX (e.g. forward slashes)
# format, relative to the token %(here)s which refers to the location of this
# ini file
script_location = %(here)s/alembic
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
# for all available tokens
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
# Or organize into date-based subdirectories (requires recursive_version_locations = true)
# file_template = %%(year)d/%%(month).2d/%%(day).2d_%%(hour).2d%%(minute).2d_%%(second).2d_%%(rev)s_%%(slug)s
# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory. for multiple paths, the path separator
# is defined by "path_separator" below.
prepend_sys_path = .
# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires the tzdata library which can be installed by adding
# `alembic[tz]` to the pip requirements.
# string value is passed to ZoneInfo()
# leave blank for localtime
# timezone =
# max length of characters to apply to the "slug" field
# truncate_slug_length = 40
# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false
# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false
# version location specification; This defaults
# to <script_location>/versions. When using multiple version
# directories, initial revisions must be specified with --version-path.
# The path separator used here should be the separator specified by "path_separator"
# below.
# version_locations = %(here)s/bar:%(here)s/bat:%(here)s/alembic/versions
# path_separator; This indicates what character is used to split lists of file
# paths, including version_locations and prepend_sys_path within configparser
# files such as alembic.ini.
# The default rendered in new alembic.ini files is "os", which uses os.pathsep
# to provide os-dependent path splitting.
#
# Note that in order to support legacy alembic.ini files, this default does NOT
# take place if path_separator is not present in alembic.ini. If this
# option is omitted entirely, fallback logic is as follows:
#
# 1. Parsing of the version_locations option falls back to using the legacy
# "version_path_separator" key, which if absent then falls back to the legacy
# behavior of splitting on spaces and/or commas.
# 2. Parsing of the prepend_sys_path option falls back to the legacy
# behavior of splitting on spaces, commas, or colons.
#
# Valid values for path_separator are:
#
# path_separator = :
# path_separator = ;
# path_separator = space
# path_separator = newline
#
# Use os.pathsep. Default configuration used for new projects.
path_separator = os
# set to 'true' to search source files recursively
# in each "version_locations" directory
# new in Alembic version 1.10
# recursive_version_locations = false
# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8
# database URL. This is consumed by the user-maintained env.py script only.
# other means of configuring database URLs may be customized within the env.py
# file.
sqlalchemy.url =
[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples
# format using "black" - use the console_scripts runner, against the "black" entrypoint
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME
# lint, attempting to auto-fix, with "ruff" - use the module runner, against the "ruff" module
# hooks = ruff
# ruff.type = module
# ruff.module = ruff
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
# Alternatively, use the exec runner to execute a binary found on your PATH
# hooks = ruff
# ruff.type = exec
# ruff.executable = ruff
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
# Logging configuration. This is also consumed by the user-maintained
# env.py script only.
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARNING
handlers = console
qualname =
[logger_sqlalchemy]
level = WARNING
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S

1
backend/alembic/README Normal file
View File

@@ -0,0 +1 @@
Generic single-database configuration.

67
backend/alembic/env.py Normal file
View File

@@ -0,0 +1,67 @@
"""Alembic environment configuration for smart-support."""
from __future__ import annotations
import os
from logging.config import fileConfig
from sqlalchemy import engine_from_config, pool
from alembic import context
# Alembic's runtime configuration object, populated from alembic.ini.
config = context.config
# Wire stdlib logging to the [loggers]/[handlers]/[formatters] sections of the
# ini file; config_file_name is None when Alembic is invoked programmatically.
if config.config_file_name is not None:
    fileConfig(config.config_file_name)
# No SQLAlchemy ORM models -- we use raw DDL migrations
target_metadata = None
def _get_url() -> str:
    """Return the database URL, preferring the DATABASE_URL env variable.

    Falls back to the ``sqlalchemy.url`` option from alembic.ini when the
    environment variable is unset or empty.
    """
    env_url = os.environ.get("DATABASE_URL", "")
    if env_url:
        return env_url
    return config.get_main_option("sqlalchemy.url", "")
def run_migrations_offline() -> None:
    """Emit migration SQL in 'offline' mode.

    Alembic renders statements against a URL only; no Engine or live
    database connection is required.
    """
    context.configure(
        url=_get_url(),
        target_metadata=target_metadata,
        literal_binds=True,
        dialect_opts={"paramstyle": "named"},
    )
    with context.begin_transaction():
        context.run_migrations()
def run_migrations_online() -> None:
    """Apply migrations in 'online' mode over a live database connection."""
    ini_section = config.get_section(config.config_ini_section, {})
    ini_section["sqlalchemy.url"] = _get_url()
    engine = engine_from_config(
        ini_section,
        prefix="sqlalchemy.",
        poolclass=pool.NullPool,
    )
    with engine.connect() as connection:
        context.configure(connection=connection, target_metadata=target_metadata)
        with context.begin_transaction():
            context.run_migrations()
# Entry point: Alembic imports this module and runs it directly, so dispatch
# on the requested mode at import time.
if context.is_offline_mode():
    run_migrations_offline()
else:
    run_migrations_online()

View File

@@ -0,0 +1,28 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, Sequence[str], None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
def upgrade() -> None:
"""Upgrade schema."""
${upgrades if upgrades else "pass"}
def downgrade() -> None:
"""Downgrade schema."""
${downgrades if downgrades else "pass"}

View File

@@ -0,0 +1,92 @@
"""Initial schema -- all application tables.
Revision ID: a1b2c3d4e5f6
Revises:
Create Date: 2026-04-06
"""
from __future__ import annotations
from alembic import op
revision: str = "a1b2c3d4e5f6"
down_revision: str | None = None
branch_labels: tuple[str, ...] | None = None
depends_on: tuple[str, ...] | None = None
def upgrade() -> None:
    """Create all application tables using idempotent raw DDL.

    Statement order matters: ``active_interrupts`` carries a foreign key
    into ``conversations``, so the latter must exist first.
    """
    # Core bookkeeping: one row per chat thread, with running token/cost totals.
    op.execute(
        """
        CREATE TABLE IF NOT EXISTS conversations (
            thread_id TEXT PRIMARY KEY,
            created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
            last_activity TIMESTAMPTZ NOT NULL DEFAULT NOW(),
            total_tokens INTEGER NOT NULL DEFAULT 0,
            total_cost_usd DOUBLE PRECISION NOT NULL DEFAULT 0.0,
            status TEXT NOT NULL DEFAULT 'active'
        )
        """
    )
    # Human-in-the-loop interrupts awaiting (or having received) a resolution.
    op.execute(
        """
        CREATE TABLE IF NOT EXISTS active_interrupts (
            interrupt_id TEXT PRIMARY KEY,
            thread_id TEXT NOT NULL REFERENCES conversations(thread_id),
            action TEXT NOT NULL,
            params JSONB NOT NULL DEFAULT '{}',
            created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
            resolved_at TIMESTAMPTZ,
            resolution TEXT
        )
        """
    )
    # Lightweight per-thread session state for the WebSocket layer.
    op.execute(
        """
        CREATE TABLE IF NOT EXISTS sessions (
            thread_id TEXT PRIMARY KEY,
            last_activity TIMESTAMPTZ NOT NULL DEFAULT NOW(),
            has_pending_interrupt BOOLEAN NOT NULL DEFAULT FALSE,
            created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
        )
        """
    )
    # Append-only event log consumed by the analytics endpoints.
    op.execute(
        """
        CREATE TABLE IF NOT EXISTS analytics_events (
            id BIGSERIAL PRIMARY KEY,
            thread_id TEXT NOT NULL,
            event_type TEXT NOT NULL,
            agent_name TEXT,
            tool_name TEXT,
            tokens_used INTEGER NOT NULL DEFAULT 0,
            cost_usd DOUBLE PRECISION NOT NULL DEFAULT 0.0,
            duration_ms INTEGER,
            success BOOLEAN,
            error_message TEXT,
            metadata JSONB NOT NULL DEFAULT '{}',
            created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
        )
        """
    )
    # Migration columns added in Phase 4
    op.execute(
        """
        ALTER TABLE conversations
        ADD COLUMN IF NOT EXISTS resolution_type TEXT,
        ADD COLUMN IF NOT EXISTS agents_used TEXT[],
        ADD COLUMN IF NOT EXISTS turn_count INTEGER NOT NULL DEFAULT 0,
        ADD COLUMN IF NOT EXISTS ended_at TIMESTAMPTZ
        """
    )
def downgrade() -> None:
    """Drop every application table, children before parents.

    Order is the reverse of creation so the foreign key from
    ``active_interrupts`` to ``conversations`` never dangles.
    """
    for table in ("analytics_events", "sessions", "active_interrupts", "conversations"):
        op.execute(f"DROP TABLE IF EXISTS {table}")

View File

@@ -16,7 +16,7 @@ if TYPE_CHECKING:
from psycopg_pool import AsyncConnectionPool
router = APIRouter(
prefix="/api/analytics",
prefix="/api/v1/analytics",
tags=["analytics"],
dependencies=[Depends(require_admin_api_key)],
)

View File

@@ -2,14 +2,14 @@
from __future__ import annotations
import logging
import secrets
from typing import Annotated
import structlog
from fastapi import Depends, HTTPException, Query, Request, WebSocket, status
from fastapi.security import APIKeyHeader
logger = logging.getLogger(__name__)
logger = structlog.get_logger()
_API_KEY_HEADER = APIKeyHeader(name="X-API-Key", auto_error=False)

View File

@@ -32,6 +32,8 @@ class Settings(BaseSettings):
template_name: str = ""
log_format: str = "console" # "console" for dev, "json" for production
admin_api_key: str = ""
anthropic_api_key: str = ""

View File

@@ -2,6 +2,7 @@
from __future__ import annotations
from pathlib import Path
from typing import TYPE_CHECKING
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
@@ -88,6 +89,17 @@ async def create_checkpointer(pool: AsyncConnectionPool) -> AsyncPostgresSaver:
return checkpointer
def run_alembic_migrations(database_url: str) -> None:
"""Run Alembic migrations to head."""
from alembic.config import Config
from alembic import command
alembic_cfg = Config(str(Path(__file__).parent.parent / "alembic.ini"))
alembic_cfg.set_main_option("sqlalchemy.url", database_url)
command.upgrade(alembic_cfg, "head")
async def setup_app_tables(pool: AsyncConnectionPool) -> None:
"""Create application-specific tables and apply migrations."""
async with pool.connection() as conn:

View File

@@ -3,14 +3,14 @@
from __future__ import annotations
import asyncio
import logging
from dataclasses import dataclass
from typing import Protocol
import httpx
import structlog
from pydantic import BaseModel
logger = logging.getLogger(__name__)
logger = structlog.get_logger()
class EscalationPayload(BaseModel, frozen=True):

View File

@@ -2,7 +2,6 @@
from __future__ import annotations
import logging
from typing import TYPE_CHECKING
from langchain.agents import create_agent
@@ -18,7 +17,9 @@ if TYPE_CHECKING:
from app.intent import IntentClassifier
from app.registry import AgentRegistry
logger = logging.getLogger(__name__)
import structlog
logger = structlog.get_logger()
SUPERVISOR_PROMPT = (
"You are a customer support supervisor. "

View File

@@ -2,7 +2,6 @@
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Protocol
from pydantic import BaseModel
@@ -12,7 +11,9 @@ if TYPE_CHECKING:
from app.registry import AgentConfig
logger = logging.getLogger(__name__)
import structlog
logger = structlog.get_logger()
CLASSIFICATION_PROMPT = (
"You are an intent classifier for a customer support system.\n"

View File

@@ -0,0 +1,57 @@
"""Structured logging configuration using structlog."""
from __future__ import annotations
import logging
import sys
import structlog
def configure_logging(log_format: str = "console") -> None:
    """Configure structlog with stdlib integration.

    Args:
        log_format: "console" for human-readable dev output,
            "json" for machine-parseable production output.

    All log records -- whether emitted through structlog or through plain
    stdlib loggers (e.g. third-party libraries) -- are routed through the
    root handler and rendered uniformly.
    """
    # Processors shared by structlog-native events and (via foreign_pre_chain
    # below) records from plain stdlib loggers. PositionalArgumentsFormatter
    # interpolates %-style positional args (logger.info("x=%s", x)), which
    # several modules in this codebase still use.
    shared_processors: list[structlog.types.Processor] = [
        structlog.contextvars.merge_contextvars,
        structlog.stdlib.add_logger_name,
        structlog.stdlib.add_log_level,
        structlog.stdlib.PositionalArgumentsFormatter(),
        structlog.processors.TimeStamper(fmt="iso"),
        structlog.processors.StackInfoRenderer(),
        structlog.processors.format_exc_info,
        structlog.processors.UnicodeDecoder(),
    ]
    if log_format == "json":
        renderer: structlog.types.Processor = structlog.processors.JSONRenderer()
    else:
        renderer = structlog.dev.ConsoleRenderer()
    structlog.configure(
        processors=[
            # Drop events below the stdlib logger's level before doing work.
            structlog.stdlib.filter_by_level,
            *shared_processors,
            structlog.stdlib.ProcessorFormatter.wrap_for_formatter,
        ],
        logger_factory=structlog.stdlib.LoggerFactory(),
        wrapper_class=structlog.stdlib.BoundLogger,
        cache_logger_on_first_use=True,
    )
    # foreign_pre_chain applies the shared processors to records that did NOT
    # originate from structlog, so stdlib/third-party logs also get the
    # timestamp, level, and logger name. (filter_by_level is deliberately
    # excluded here: it expects a structlog-bound logger.)
    formatter = structlog.stdlib.ProcessorFormatter(
        processors=[
            structlog.stdlib.ProcessorFormatter.remove_processors_meta,
            renderer,
        ],
        foreign_pre_chain=shared_processors,
    )
    handler = logging.StreamHandler(sys.stdout)
    handler.setFormatter(formatter)
    root_logger = logging.getLogger()
    root_logger.handlers.clear()
    root_logger.addHandler(handler)
    root_logger.setLevel(logging.INFO)

View File

@@ -2,25 +2,30 @@
from __future__ import annotations
import logging
import asyncio
import contextlib
from contextlib import asynccontextmanager
from pathlib import Path
from typing import TYPE_CHECKING
from fastapi import Depends, FastAPI, Query, WebSocket, WebSocketDisconnect
from fastapi import FastAPI, HTTPException, Query, WebSocket, WebSocketDisconnect
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles
from app.analytics.api import router as analytics_router
from app.analytics.event_recorder import PostgresAnalyticsRecorder
from app.api_utils import envelope
from app.callbacks import TokenUsageCallbackHandler
from app.config import Settings
from app.conversation_tracker import PostgresConversationTracker
from app.db import create_checkpointer, create_pool, setup_app_tables
from app.db import create_checkpointer, create_pool, run_alembic_migrations
from app.escalation import NoOpEscalator, WebhookEscalator
from app.graph import build_graph
from app.intent import LLMIntentClassifier
from app.interrupt_manager import InterruptManager
from app.llm import create_llm
from app.logging_config import configure_logging
from app.openapi.review_api import router as openapi_router
from app.registry import AgentRegistry
from app.replay.api import router as replay_router
@@ -31,19 +36,44 @@ from app.ws_handler import dispatch_message
if TYPE_CHECKING:
from collections.abc import AsyncGenerator
logger = logging.getLogger(__name__)
import structlog
logger = structlog.get_logger()
AGENTS_YAML = Path(__file__).parent.parent / "agents.yaml"
FRONTEND_DIST = Path(__file__).parent.parent.parent / "frontend" / "dist"
async def _interrupt_cleanup_loop(
    interrupt_manager: InterruptManager,
    interval: int = 60,
) -> None:
    """Periodically remove expired interrupts in the background.

    Args:
        interrupt_manager: Manager whose ``cleanup_expired()`` is invoked
            once per pass.
        interval: Seconds to sleep between passes.

    Runs until cancelled: CancelledError propagates out of the sleep so the
    task can be shut down, while every other exception is logged and
    swallowed to keep the loop alive.
    """
    while True:
        await asyncio.sleep(interval)
        try:
            expired = interrupt_manager.cleanup_expired()
            if expired:
                # structlog idiom: event string plus key-value pairs.
                # (%-style positional args are not interpolated unless a
                # PositionalArgumentsFormatter processor is configured.)
                logger.info("Cleaned up expired interrupts", count=len(expired))
        except Exception:
            logger.exception("Error during interrupt cleanup")
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
settings = Settings()
configure_logging(settings.log_format)
pool = await create_pool(settings)
checkpointer = await create_checkpointer(pool)
await setup_app_tables(pool)
run_alembic_migrations(settings.database_url)
# Load agents from template or default YAML
if settings.template_name:
@@ -89,8 +119,16 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
settings.template_name or "(default)",
)
cleanup_task = asyncio.create_task(
_interrupt_cleanup_loop(interrupt_manager),
)
yield
cleanup_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await cleanup_task
await pool.close()
@@ -103,7 +141,35 @@ app.include_router(replay_router)
app.include_router(analytics_router)
@app.get("/api/health")
@app.exception_handler(HTTPException)
async def http_exception_handler(request, exc):  # type: ignore[no-untyped-def]
    """Wrap HTTPException in standard envelope format."""
    payload = envelope(None, success=False, error=exc.detail)
    return JSONResponse(status_code=exc.status_code, content=payload)
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc):  # type: ignore[no-untyped-def]
    """Wrap validation errors in standard envelope format."""
    payload = envelope(None, success=False, error=str(exc))
    return JSONResponse(status_code=422, content=payload)
@app.exception_handler(Exception)
async def general_exception_handler(request, exc):  # type: ignore[no-untyped-def]
    """Catch-all handler -- never leak stack traces.

    The full traceback is recorded server-side via logger.exception; the
    client only ever sees a generic message in the standard envelope.
    """
    # structlog idiom: event string plus key-value pairs. The previous
    # %-style call ("Unhandled exception: %s", exc) is not interpolated by
    # structlog unless a PositionalArgumentsFormatter processor is configured.
    logger.exception("Unhandled exception", error=str(exc))
    return JSONResponse(
        status_code=500,
        content=envelope(None, success=False, error="Internal server error"),
    )
@app.get("/api/v1/health")
def health_check() -> dict:
"""Health check endpoint for load balancers and monitoring."""
return {"status": "ok", "version": _VERSION}

View File

@@ -8,13 +8,14 @@ classifier and an LLM-backed classifier with heuristic fallback.
from __future__ import annotations
import json
import logging
import re
from typing import Protocol
import structlog
from app.openapi.models import ClassificationResult, EndpointInfo
logger = logging.getLogger(__name__)
logger = structlog.get_logger()
_WRITE_METHODS = frozenset({"POST", "PUT", "PATCH", "DELETE"})
_INTERRUPT_METHODS = frozenset({"POST", "PUT", "PATCH", "DELETE"})

View File

@@ -6,10 +6,11 @@ Each stage updates the job status and calls the on_progress callback.
from __future__ import annotations
import logging
from collections.abc import Callable
from dataclasses import replace
import structlog
from app.openapi.classifier import ClassifierProtocol, HeuristicClassifier
from app.openapi.fetcher import fetch_spec
from app.openapi.models import ImportJob
@@ -17,7 +18,7 @@ from app.openapi.parser import parse_endpoints
from app.openapi.ssrf import DEFAULT_POLICY, SSRFPolicy
from app.openapi.validator import validate_spec
logger = logging.getLogger(__name__)
logger = structlog.get_logger()
ProgressCallback = Callable[[str, ImportJob], None] | None

View File

@@ -10,11 +10,11 @@ Exposes endpoints for:
from __future__ import annotations
import asyncio
import logging
import re
import uuid
from typing import Literal
import structlog
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
from pydantic import BaseModel, field_validator
@@ -23,10 +23,10 @@ from app.openapi.generator import generate_agent_yaml, generate_tool_code
from app.openapi.importer import ImportOrchestrator
from app.openapi.models import ClassificationResult, ImportJob
logger = logging.getLogger(__name__)
logger = structlog.get_logger()
router = APIRouter(
prefix="/api/openapi",
prefix="/api/v1/openapi",
tags=["openapi"],
dependencies=[Depends(require_admin_api_key)],
)

View File

@@ -16,7 +16,7 @@ if TYPE_CHECKING:
from psycopg_pool import AsyncConnectionPool
router = APIRouter(
prefix="/api",
prefix="/api/v1",
tags=["replay"],
dependencies=[Depends(require_admin_api_key)],
)

View File

@@ -2,11 +2,11 @@
from __future__ import annotations
import logging
import structlog
from app.replay.models import ReplayStep, StepType
logger = logging.getLogger(__name__)
logger = structlog.get_logger()
_EMPTY_TIMESTAMP = "1970-01-01T00:00:00Z"

View File

@@ -3,7 +3,6 @@
from __future__ import annotations
import json
import logging
import re
import time
from collections import defaultdict
@@ -21,7 +20,9 @@ if TYPE_CHECKING:
from app.session_manager import SessionManager
from app.ws_context import WebSocketContext
logger = logging.getLogger(__name__)
import structlog
logger = structlog.get_logger()
MAX_MESSAGE_SIZE = 32_768 # 32 KB
MAX_CONTENT_LENGTH = 10_000 # characters

View File

@@ -21,6 +21,8 @@ dependencies = [
"python-dotenv>=1.0,<2.0",
"httpx>=0.28,<1.0",
"openapi-spec-validator>=0.7,<1.0",
"alembic>=1.13,<2.0",
"structlog>=24.0,<26.0",
]
[project.optional-dependencies]

View File

@@ -174,7 +174,7 @@ def create_e2e_app(
app.state.analytics_recorder = AsyncMock()
app.state.conversation_tracker = AsyncMock()
@app.get("/api/health")
@app.get("/api/v1/health")
def health_check() -> dict:
return {"status": "ok", "version": "test"}

View File

@@ -341,7 +341,7 @@ class TestChatEdgeCases:
def test_health_endpoint(self) -> None:
app = create_e2e_app()
with TestClient(app) as client:
resp = client.get("/api/health")
resp = client.get("/api/v1/health")
assert resp.status_code == 200
assert resp.json()["status"] == "ok"

View File

@@ -62,7 +62,7 @@ class TestFlow5OpenAPIImport:
with TestClient(app) as client:
# Step 1: Start import job
resp = client.post(
"/api/openapi/import",
"/api/v1/openapi/import",
json={"url": "https://api.example.com/openapi.json"},
)
assert resp.status_code == 202
@@ -71,7 +71,7 @@ class TestFlow5OpenAPIImport:
job_id = body["job_id"]
# Step 2: Check job status (still pending since background task hasn't run)
resp = client.get(f"/api/openapi/jobs/{job_id}")
resp = client.get(f"/api/v1/openapi/jobs/{job_id}")
assert resp.status_code == 200
assert resp.json()["job_id"] == job_id
@@ -99,7 +99,7 @@ class TestFlow5OpenAPIImport:
with TestClient(app) as client:
# Step 1: Get classifications
resp = client.get(f"/api/openapi/jobs/{job_id}/classifications")
resp = client.get(f"/api/v1/openapi/jobs/{job_id}/classifications")
assert resp.status_code == 200
classifications = resp.json()
assert len(classifications) == 2
@@ -118,7 +118,7 @@ class TestFlow5OpenAPIImport:
# Step 2: Update a classification
resp = client.put(
f"/api/openapi/jobs/{job_id}/classifications/0",
f"/api/v1/openapi/jobs/{job_id}/classifications/0",
json={
"access_type": "write",
"needs_interrupt": True,
@@ -132,7 +132,7 @@ class TestFlow5OpenAPIImport:
assert updated["agent_group"] == "order_actions"
# Step 3: Approve the job
resp = client.post(f"/api/openapi/jobs/{job_id}/approve")
resp = client.post(f"/api/v1/openapi/jobs/{job_id}/approve")
assert resp.status_code == 200
assert resp.json()["status"] == "approved"
@@ -140,14 +140,14 @@ class TestFlow5OpenAPIImport:
app = create_e2e_app()
with TestClient(app) as client:
resp = client.get("/api/openapi/jobs/nonexistent")
resp = client.get("/api/v1/openapi/jobs/nonexistent")
assert resp.status_code == 404
def test_import_invalid_url_returns_422(self) -> None:
app = create_e2e_app()
with TestClient(app) as client:
resp = client.post("/api/openapi/import", json={"url": "not-a-url"})
resp = client.post("/api/v1/openapi/import", json={"url": "not-a-url"})
assert resp.status_code == 422
def test_classification_index_out_of_range(self) -> None:
@@ -166,7 +166,7 @@ class TestFlow5OpenAPIImport:
with TestClient(app) as client:
resp = client.put(
f"/api/openapi/jobs/{job_id}/classifications/99",
f"/api/v1/openapi/jobs/{job_id}/classifications/99",
json={
"access_type": "read",
"needs_interrupt": False,
@@ -191,7 +191,7 @@ class TestFlow5OpenAPIImport:
with TestClient(app) as client:
resp = client.put(
f"/api/openapi/jobs/{job_id}/classifications/0",
f"/api/v1/openapi/jobs/{job_id}/classifications/0",
json={
"access_type": "read",
"needs_interrupt": False,

View File

@@ -98,7 +98,7 @@ class TestFlow6ReplayConversation:
app = create_e2e_app(pool=pool)
with TestClient(app) as client:
resp = client.get("/api/conversations")
resp = client.get("/api/v1/conversations")
assert resp.status_code == 200
body = resp.json()
assert body["success"] is True
@@ -124,7 +124,7 @@ class TestFlow6ReplayConversation:
app = create_e2e_app(pool=pool)
with TestClient(app) as client:
resp = client.get("/api/conversations", params={"page": 1, "per_page": 2})
resp = client.get("/api/v1/conversations", params={"page": 1, "per_page": 2})
assert resp.status_code == 200
body = resp.json()
assert body["success"] is True
@@ -139,7 +139,7 @@ class TestFlow6ReplayConversation:
app = create_e2e_app(pool=pool)
with TestClient(app) as client:
resp = client.get("/api/replay/nonexistent-thread")
resp = client.get("/api/v1/replay/nonexistent-thread")
assert resp.status_code == 404
def test_replay_invalid_thread_id_format(self) -> None:
@@ -147,7 +147,7 @@ class TestFlow6ReplayConversation:
with TestClient(app) as client:
# Thread ID with special chars fails regex validation
resp = client.get("/api/replay/invalid%20thread%21%40")
resp = client.get("/api/v1/replay/invalid%20thread%21%40")
assert resp.status_code == 400
@@ -158,21 +158,21 @@ class TestAnalyticsDashboard:
app = create_e2e_app()
with TestClient(app) as client:
resp = client.get("/api/analytics", params={"range": "invalid"})
resp = client.get("/api/v1/analytics", params={"range": "invalid"})
assert resp.status_code == 400
def test_analytics_range_too_large(self) -> None:
app = create_e2e_app()
with TestClient(app) as client:
resp = client.get("/api/analytics", params={"range": "999d"})
resp = client.get("/api/v1/analytics", params={"range": "999d"})
assert resp.status_code == 400
def test_analytics_range_zero_rejected(self) -> None:
app = create_e2e_app()
with TestClient(app) as client:
resp = client.get("/api/analytics", params={"range": "0d"})
resp = client.get("/api/v1/analytics", params={"range": "0d"})
assert resp.status_code == 400
@@ -216,7 +216,7 @@ class TestFullUserJourney:
assert any(m["type"] == "message_complete" for m in messages)
# Step 2: Check conversations endpoint
resp = client.get("/api/conversations")
resp = client.get("/api/v1/conversations")
assert resp.status_code == 200
body = resp.json()
assert body["success"] is True
@@ -226,5 +226,5 @@ class TestFullUserJourney:
)
# Step 3: Health check still works
resp = client.get("/api/health")
resp = client.get("/api/v1/health")
assert resp.status_code == 200

View File

@@ -0,0 +1,183 @@
"""Integration tests for the /api/v1/analytics endpoint.
Tests the full API layer (routing, parameter validation, serialization,
error handling) with a mocked database pool.
"""
from __future__ import annotations
from dataclasses import asdict
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from httpx import ASGITransport, AsyncClient
from app.analytics.models import AnalyticsResult, InterruptStats
pytestmark = pytest.mark.integration
# Canonical AnalyticsResult fixture returned by the mocked data layer in the
# valid-range tests below; its field values are asserted against verbatim.
_SAMPLE_RESULT = AnalyticsResult(
    range="7d",
    total_conversations=42,
    resolution_rate=0.85,
    escalation_rate=0.05,
    avg_turns_per_conversation=3.2,
    avg_cost_per_conversation_usd=0.012,
    agent_usage=(),
    interrupt_stats=InterruptStats(total=10, approved=7, rejected=2, expired=1),
)
def _build_app():
    """Build a minimal FastAPI app with the analytics router and mocked deps."""
    from fastapi import FastAPI, HTTPException
    from fastapi.exceptions import RequestValidationError
    from fastapi.responses import JSONResponse

    from app.analytics.api import router as analytics_router
    from app.api_utils import envelope

    test_app = FastAPI()
    test_app.include_router(analytics_router)

    def _error_response(status_code, message):
        # Shared envelope builder so all three handlers stay consistent.
        return JSONResponse(
            status_code=status_code,
            content=envelope(None, success=False, error=message),
        )

    @test_app.exception_handler(HTTPException)
    async def _http_exc(request, exc):
        return _error_response(exc.status_code, exc.detail)

    @test_app.exception_handler(RequestValidationError)
    async def _validation_exc(request, exc):
        return _error_response(422, str(exc))

    @test_app.exception_handler(Exception)
    async def _catch_all(request, exc):
        return _error_response(500, "Internal server error")

    # No admin_api_key set -> auth is skipped
    test_app.state.settings = MagicMock(admin_api_key="")
    test_app.state.pool = MagicMock()
    return test_app
class TestAnalyticsValidRange:
    """Test analytics endpoint with valid range parameters."""

    async def test_valid_range_7d_returns_envelope(self) -> None:
        """GET /api/v1/analytics?range=7d returns success envelope with data."""
        test_app = _build_app()
        with patch(
            "app.analytics.api.get_analytics",
            new_callable=AsyncMock,
            return_value=_SAMPLE_RESULT,
        ):
            async with AsyncClient(
                transport=ASGITransport(app=test_app), base_url="http://test"
            ) as client:
                resp = await client.get("/api/v1/analytics", params={"range": "7d"})
            assert resp.status_code == 200
            body = resp.json()
            assert body["success"] is True
            assert body["error"] is None
            assert body["data"]["total_conversations"] == 42
            assert body["data"]["resolution_rate"] == 0.85

    async def test_default_range_returns_success(self) -> None:
        """GET /api/v1/analytics with no range param defaults to 7d."""
        test_app = _build_app()
        with patch(
            "app.analytics.api.get_analytics",
            new_callable=AsyncMock,
            return_value=_SAMPLE_RESULT,
        ) as mock_get:
            async with AsyncClient(
                transport=ASGITransport(app=test_app), base_url="http://test"
            ) as client:
                resp = await client.get("/api/v1/analytics")
            assert resp.status_code == 200
            # Verify the 7-day default actually reached the data layer,
            # whether it was passed positionally or as a keyword. (The
            # previous assertion accepted None when the kwarg was missing,
            # making it pass vacuously.)
            mock_get.assert_called_once()
            args, kwargs = mock_get.call_args
            assert kwargs.get("range_days") == 7 or 7 in args

    async def test_large_range_365d_works(self) -> None:
        """GET /api/v1/analytics?range=365d is accepted (max boundary)."""
        test_app = _build_app()
        result = AnalyticsResult(
            range="365d",
            total_conversations=1000,
            resolution_rate=0.9,
            escalation_rate=0.02,
            avg_turns_per_conversation=4.0,
            avg_cost_per_conversation_usd=0.01,
            agent_usage=(),
            interrupt_stats=InterruptStats(),
        )
        with patch(
            "app.analytics.api.get_analytics",
            new_callable=AsyncMock,
            return_value=result,
        ):
            async with AsyncClient(
                transport=ASGITransport(app=test_app), base_url="http://test"
            ) as client:
                resp = await client.get("/api/v1/analytics", params={"range": "365d"})
            assert resp.status_code == 200
            assert resp.json()["success"] is True
class TestAnalyticsInvalidRange:
    """Invalid ``range`` query parameters must yield 400 error envelopes."""

    async def _request_range(self, range_value: str):
        # Shared driver: build the app and issue one GET with the given range.
        test_app = _build_app()
        async with AsyncClient(
            transport=ASGITransport(app=test_app), base_url="http://test"
        ) as client:
            return await client.get("/api/v1/analytics", params={"range": range_value})

    async def test_invalid_range_format_returns_400(self) -> None:
        """GET /api/v1/analytics?range=abc returns 400 error envelope."""
        resp = await self._request_range("abc")
        assert resp.status_code == 400
        body = resp.json()
        assert body["success"] is False
        assert body["data"] is None
        assert "Invalid range format" in body["error"]

    async def test_zero_day_range_returns_400(self) -> None:
        """GET /api/v1/analytics?range=0d returns 400 because 0 is below minimum."""
        resp = await self._request_range("0d")
        assert resp.status_code == 400
        body = resp.json()
        assert body["success"] is False
        assert "between 1 and 365" in body["error"]

    async def test_range_exceeding_max_returns_400(self) -> None:
        """GET /api/v1/analytics?range=999d returns 400 because it exceeds 365."""
        resp = await self._request_range("999d")
        assert resp.status_code == 400
        body = resp.json()
        assert body["success"] is False
        assert "between 1 and 365" in body["error"]

View File

@@ -0,0 +1,128 @@
"""Integration tests for global error handling and envelope format consistency.
Tests that all error responses from the FastAPI app conform to the
standard envelope: {"success": false, "data": null, "error": "..."}.
"""
from __future__ import annotations
from unittest.mock import MagicMock
import pytest
from httpx import ASGITransport, AsyncClient
pytestmark = pytest.mark.integration
def _build_app():
    """Build the actual FastAPI app with exception handlers but mocked state."""
    from fastapi import FastAPI, HTTPException
    from fastapi.exceptions import RequestValidationError
    from fastapi.responses import JSONResponse

    from app.analytics.api import router as analytics_router
    from app.api_utils import envelope
    from app.replay.api import router as replay_router

    application = FastAPI()
    application.include_router(analytics_router)
    application.include_router(replay_router)

    def _error_payload(message):
        # Standard error envelope: {"success": false, "data": null, "error": ...}
        return envelope(None, success=False, error=message)

    @application.exception_handler(HTTPException)
    async def _http_exc(request, exc):
        return JSONResponse(
            status_code=exc.status_code,
            content=_error_payload(exc.detail),
        )

    @application.exception_handler(RequestValidationError)
    async def _validation_exc(request, exc):
        return JSONResponse(status_code=422, content=_error_payload(str(exc)))

    @application.exception_handler(Exception)
    async def _catch_all(request, exc):
        # Unexpected failures are masked behind a generic 500 envelope.
        return JSONResponse(status_code=500, content=_error_payload("Internal server error"))

    @application.get("/api/v1/health")
    def health_check():
        return {"status": "ok", "version": "0.6.0"}

    # Mocked settings/pool so the routers can run without real infrastructure.
    application.state.settings = MagicMock(admin_api_key="")
    application.state.pool = MagicMock()
    return application
class TestEnvelopeFormat:
    """Tests that error responses consistently follow envelope format."""

    @staticmethod
    def _client():
        """Fresh app wrapped in an AsyncClient (used as an async context manager)."""
        return AsyncClient(
            transport=ASGITransport(app=_build_app()), base_url="http://test"
        )

    async def test_http_400_produces_envelope(self) -> None:
        """A 400 error returns standard envelope with success=false."""
        async with self._client() as client:
            response = await client.get("/api/v1/analytics", params={"range": "invalid"})
        assert response.status_code == 400
        payload = response.json()
        assert payload["success"] is False
        assert payload["data"] is None
        assert isinstance(payload["error"], str)
        assert len(payload["error"]) > 0

    async def test_validation_error_produces_422_envelope(self) -> None:
        """Invalid query param type returns 422 with envelope format."""
        async with self._client() as client:
            # page must be >= 1; passing 0 triggers a validation error.
            response = await client.get("/api/v1/conversations", params={"page": 0})
        assert response.status_code == 422
        payload = response.json()
        assert payload["success"] is False
        assert payload["data"] is None
        assert isinstance(payload["error"], str)

    async def test_all_error_fields_present(self) -> None:
        """Error envelope contains exactly success, data, and error keys."""
        async with self._client() as client:
            response = await client.get("/api/v1/analytics", params={"range": "bad"})
        assert set(response.json().keys()) == {"success", "data", "error"}

    async def test_health_endpoint_returns_200(self) -> None:
        """Health check returns 200 with status ok."""
        async with self._client() as client:
            response = await client.get("/api/v1/health")
        assert response.status_code == 200
        payload = response.json()
        assert payload["status"] == "ok"
        assert "version" in payload

    async def test_unknown_endpoint_returns_404(self) -> None:
        """Requesting a non-existent path returns 404."""
        async with self._client() as client:
            response = await client.get("/api/v1/nonexistent-path")
        # FastAPI returns 404 for unknown routes; may or may not be wrapped.
        assert response.status_code == 404

View File

@@ -0,0 +1,164 @@
"""Integration tests for /api/v1/openapi/ endpoints.
Tests the full API layer for the OpenAPI import review workflow,
including job creation, status retrieval, classification updates,
and approval triggering.
"""
from __future__ import annotations
from unittest.mock import MagicMock, patch
import pytest
from httpx import ASGITransport, AsyncClient
pytestmark = pytest.mark.integration
def _build_app():
    """Build a minimal FastAPI app with the openapi router and mocked deps."""
    from fastapi import FastAPI, HTTPException
    from fastapi.exceptions import RequestValidationError
    from fastapi.responses import JSONResponse

    from app.api_utils import envelope
    from app.openapi.review_api import router as openapi_router

    application = FastAPI()
    application.include_router(openapi_router)

    @application.exception_handler(HTTPException)
    async def _http_exc(request, exc):
        # Wrap HTTP errors in the standard envelope format.
        body = envelope(None, success=False, error=exc.detail)
        return JSONResponse(status_code=exc.status_code, content=body)

    @application.exception_handler(RequestValidationError)
    async def _validation_exc(request, exc):
        body = envelope(None, success=False, error=str(exc))
        return JSONResponse(status_code=422, content=body)

    # Mocked settings so the router runs without real configuration.
    application.state.settings = MagicMock(admin_api_key="")
    return application
@pytest.fixture(autouse=True)
def _clear_job_store():
    """Clear the in-memory job store between tests."""
    from app.openapi.review_api import _job_store

    _job_store.clear()
    try:
        yield
    finally:
        # Teardown: leave no jobs behind for the next test.
        _job_store.clear()
class TestImportEndpoint:
    """Tests for POST /api/v1/openapi/import."""

    @staticmethod
    def _client():
        """Fresh app wrapped in an AsyncClient (used as an async context manager)."""
        return AsyncClient(
            transport=ASGITransport(app=_build_app()), base_url="http://test"
        )

    async def test_import_returns_202_with_job_id(self) -> None:
        """Starting an import returns 202 with a job_id."""
        async with self._client() as client:
            response = await client.post(
                "/api/v1/openapi/import",
                json={"url": "https://example.com/api/spec.json"},
            )
        assert response.status_code == 202
        payload = response.json()
        assert "job_id" in payload
        assert payload["status"] == "pending"
        assert payload["spec_url"] == "https://example.com/api/spec.json"

    async def test_import_invalid_url_returns_422(self) -> None:
        """POST with invalid URL (no http/https) returns 422."""
        async with self._client() as client:
            response = await client.post(
                "/api/v1/openapi/import",
                json={"url": "ftp://example.com/spec.json"},
            )
        assert response.status_code == 422
        assert response.json()["success"] is False
class TestJobStatusEndpoint:
    """Tests for GET /api/v1/openapi/jobs/{job_id}."""

    async def test_get_existing_job_returns_status(self) -> None:
        """Retrieving an existing job returns its status."""
        from app.openapi.review_api import _job_store

        # Seed a completed job directly into the in-memory store.
        seeded_job = {
            "job_id": "test-job-1",
            "status": "done",
            "spec_url": "https://example.com/spec.json",
            "total_endpoints": 5,
            "classified_count": 5,
            "error_message": None,
            "classifications": [],
        }
        _job_store["test-job-1"] = seeded_job
        app_under_test = _build_app()
        transport = ASGITransport(app=app_under_test)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            response = await client.get("/api/v1/openapi/jobs/test-job-1")
        assert response.status_code == 200
        payload = response.json()
        assert payload["job_id"] == "test-job-1"
        assert payload["status"] == "done"
        assert payload["total_endpoints"] == 5

    async def test_get_unknown_job_returns_404(self) -> None:
        """Retrieving a non-existent job returns 404 error envelope."""
        app_under_test = _build_app()
        transport = ASGITransport(app=app_under_test)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            response = await client.get("/api/v1/openapi/jobs/unknown-id-999")
        assert response.status_code == 404
        payload = response.json()
        assert payload["success"] is False
        assert "not found" in payload["error"].lower()
class TestApproveEndpoint:
    """Tests for POST /api/v1/openapi/jobs/{job_id}/approve."""

    async def test_approve_with_no_classifications_returns_400(self) -> None:
        """Approving a job with no classifications returns 400."""
        from app.openapi.review_api import _job_store

        # Seed a finished job that produced zero classifications.
        _job_store["empty-job"] = {
            "job_id": "empty-job",
            "status": "done",
            "spec_url": "https://example.com/spec.json",
            "total_endpoints": 0,
            "classified_count": 0,
            "error_message": None,
            "classifications": [],
        }
        app_under_test = _build_app()
        transport = ASGITransport(app=app_under_test)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            response = await client.post("/api/v1/openapi/jobs/empty-job/approve")
        assert response.status_code == 400
        payload = response.json()
        assert payload["success"] is False
        assert "no classifications" in payload["error"].lower()

View File

@@ -0,0 +1,213 @@
"""Integration tests for /api/v1/conversations and /api/v1/replay/{thread_id}.
Tests the full API layer with a mocked database pool, verifying routing,
serialization, pagination, and error handling in envelope format.
"""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock
import pytest
from httpx import ASGITransport, AsyncClient
pytestmark = pytest.mark.integration
_UNSET = object()  # sentinel: distinguishes "not provided" from an explicit None


def _make_fake_cursor(rows, *, fetchone_value=_UNSET):
    """Build a fake async cursor returning the given rows on fetchall.

    Args:
        rows: value returned by ``await cursor.fetchall()``.
        fetchone_value: if supplied (even an explicit ``None``), the value
            returned by ``await cursor.fetchone()``; when omitted, the
            AsyncMock default applies. A sentinel is used instead of the
            previous ``is not None`` check so callers can configure a
            ``None`` fetchone result (e.g. an empty COUNT query).
    """
    cursor = AsyncMock()
    cursor.fetchall = AsyncMock(return_value=rows)
    if fetchone_value is not _UNSET:
        cursor.fetchone = AsyncMock(return_value=fetchone_value)
    return cursor
class _FakeConnection:
    """Fake async connection that returns pre-configured cursors in order."""

    def __init__(self, cursors: list) -> None:
        # Copy so the caller's list cannot be mutated from outside.
        self._queued = list(cursors)
        self._position = 0

    async def execute(self, sql, params=None):
        """Hand out the next pre-configured cursor (one per executed statement)."""
        next_cursor = self._queued[self._position]
        self._position += 1
        return next_cursor

    async def __aenter__(self):
        # Supports `async with pool.connection() as conn`.
        return self

    async def __aexit__(self, *args):
        pass
class _FakePool:
    """Fake connection pool that yields a fake connection."""

    def __init__(self, conn: _FakeConnection) -> None:
        self._connection = conn

    def connection(self):
        """Return the single pre-configured fake connection."""
        return self._connection
def _build_app(pool=None):
    """Build a minimal FastAPI app with the replay router and mocked deps."""
    from fastapi import FastAPI, HTTPException
    from fastapi.exceptions import RequestValidationError
    from fastapi.responses import JSONResponse

    from app.api_utils import envelope
    from app.replay.api import router as replay_router

    application = FastAPI()
    application.include_router(replay_router)

    def _error_body(message):
        # Standard error envelope shared by both handlers.
        return envelope(None, success=False, error=message)

    @application.exception_handler(HTTPException)
    async def _http_exc(request, exc):
        return JSONResponse(
            status_code=exc.status_code,
            content=_error_body(exc.detail),
        )

    @application.exception_handler(RequestValidationError)
    async def _validation_exc(request, exc):
        return JSONResponse(status_code=422, content=_error_body(str(exc)))

    # Mocked settings; the pool defaults to a MagicMock when not provided.
    application.state.settings = MagicMock(admin_api_key="")
    application.state.pool = pool or MagicMock()
    return application
class TestListConversations:
    """Tests for GET /api/v1/conversations endpoint."""

    @staticmethod
    def _app_with_cursors(*cursors):
        """Build the app backed by a fake pool that yields the given cursors in order."""
        return _build_app(_FakePool(_FakeConnection(list(cursors))))

    async def test_returns_paginated_envelope(self) -> None:
        """Conversations list returns envelope with pagination metadata."""
        sample_rows = [
            {"thread_id": "t1", "created_at": "2026-01-01", "last_activity": "2026-01-01",
             "status": "active", "total_tokens": 100, "total_cost_usd": 0.01},
            {"thread_id": "t2", "created_at": "2026-01-02", "last_activity": "2026-01-02",
             "status": "resolved", "total_tokens": 200, "total_cost_usd": 0.02},
        ]
        app_under_test = self._app_with_cursors(
            _make_fake_cursor([], fetchone_value=(3,)),  # COUNT(*) query
            _make_fake_cursor(sample_rows),              # page query
        )
        async with AsyncClient(
            transport=ASGITransport(app=app_under_test), base_url="http://test"
        ) as client:
            response = await client.get("/api/v1/conversations")
        assert response.status_code == 200
        payload = response.json()
        assert payload["success"] is True
        data = payload["data"]
        assert data["total"] == 3
        assert len(data["conversations"]) == 2
        assert data["page"] == 1
        assert data["per_page"] == 20

    async def test_custom_page_and_per_page(self) -> None:
        """Custom page/per_page params are reflected in the response."""
        app_under_test = self._app_with_cursors(
            _make_fake_cursor([], fetchone_value=(50,)),
            _make_fake_cursor([]),
        )
        async with AsyncClient(
            transport=ASGITransport(app=app_under_test), base_url="http://test"
        ) as client:
            response = await client.get(
                "/api/v1/conversations", params={"page": 3, "per_page": 10}
            )
        assert response.status_code == 200
        data = response.json()["data"]
        assert data["page"] == 3
        assert data["per_page"] == 10

    async def test_invalid_page_returns_422(self) -> None:
        """page=0 violates ge=1 constraint and returns 422 error envelope."""
        app_under_test = _build_app()
        async with AsyncClient(
            transport=ASGITransport(app=app_under_test), base_url="http://test"
        ) as client:
            response = await client.get("/api/v1/conversations", params={"page": 0})
        assert response.status_code == 422
        assert response.json()["success"] is False
class TestReplayEndpoint:
    """Tests for GET /api/v1/replay/{thread_id} endpoint."""

    @staticmethod
    def _app_with_checkpoints(rows):
        """Build the app backed by a pool whose single query returns *rows*."""
        return _build_app(_FakePool(_FakeConnection([_make_fake_cursor(rows)])))

    async def test_valid_thread_returns_timeline(self) -> None:
        """Replay with valid thread_id returns steps in envelope format."""
        checkpoint_rows = [
            {
                "thread_id": "abc123",
                "checkpoint_id": "cp1",
                "checkpoint": {
                    "channel_values": {
                        "messages": [
                            {"type": "human", "content": "Hello", "created_at": "2026-01-01T00:00:00Z"},
                            {"type": "ai", "content": "Hi there!", "created_at": "2026-01-01T00:00:01Z"},
                        ]
                    }
                },
                "metadata": {},
            }
        ]
        app_under_test = self._app_with_checkpoints(checkpoint_rows)
        async with AsyncClient(
            transport=ASGITransport(app=app_under_test), base_url="http://test"
        ) as client:
            response = await client.get("/api/v1/replay/abc123")
        assert response.status_code == 200
        payload = response.json()
        assert payload["success"] is True
        data = payload["data"]
        assert data["thread_id"] == "abc123"
        assert data["total_steps"] == 2
        assert len(data["steps"]) == 2
        assert data["steps"][0]["type"] == "user_message"
        assert data["steps"][1]["type"] == "agent_response"

    async def test_invalid_thread_id_format_returns_400(self) -> None:
        """Thread IDs with path traversal characters are rejected with 400."""
        app_under_test = _build_app()
        async with AsyncClient(
            transport=ASGITransport(app=app_under_test), base_url="http://test"
        ) as client:
            response = await client.get("/api/v1/replay/../../etc/passwd")
        # FastAPI may return 400 from our handler or 404 from routing.
        assert response.status_code in (400, 404, 422)

    async def test_nonexistent_thread_returns_404(self) -> None:
        """Replay with a thread_id that has no checkpoints returns 404."""
        app_under_test = self._app_with_checkpoints([])
        async with AsyncClient(
            transport=ASGITransport(app=app_under_test), base_url="http://test"
        ) as client:
            response = await client.get("/api/v1/replay/nonexistent-thread")
        assert response.status_code == 404
        payload = response.json()
        assert payload["success"] is False
        assert "not found" in payload["error"].lower()

View File

@@ -0,0 +1,159 @@
"""Integration tests for SessionManager + InterruptManager lifecycle.
These tests exercise the in-memory managers together, verifying the full
lifecycle of sessions and interrupts: creation, TTL sliding, interrupt
registration/resolution, and expired-interrupt cleanup.
No database required -- both managers are in-memory.
"""
from __future__ import annotations
import time
from unittest.mock import patch
import pytest
from app.interrupt_manager import InterruptManager
from app.session_manager import SessionManager
pytestmark = pytest.mark.integration
class TestSessionInterruptLifecycle:
    """Tests for the combined session + interrupt lifecycle.

    Exercises SessionManager and InterruptManager together: creation,
    TTL sliding, interrupt registration/resolution, and expired-interrupt
    cleanup. Both managers are in-memory, so no database is required.
    """

    def test_create_session_register_interrupt_check_status(self) -> None:
        """Full lifecycle: create session, register interrupt, verify both states."""
        sm = SessionManager(session_ttl_seconds=3600)
        im = InterruptManager(ttl_seconds=300)
        # Create a session
        state = sm.touch("thread-1")
        assert state.thread_id == "thread-1"
        assert not state.has_pending_interrupt
        assert not sm.is_expired("thread-1")
        # Register an interrupt (return value unused; state is verified
        # through the managers' query methods instead).
        im.register("thread-1", "cancel_order", {"order_id": "1042"})
        sm.extend_for_interrupt("thread-1")
        assert im.has_pending("thread-1")
        session_state = sm.get_state("thread-1")
        assert session_state is not None
        assert session_state.has_pending_interrupt
        # Session should not expire while interrupt is pending
        assert not sm.is_expired("thread-1")

    def test_interrupt_expiry_after_ttl(self) -> None:
        """Interrupt expires when TTL elapses, even if session is alive."""
        im = InterruptManager(ttl_seconds=5)
        record = im.register("thread-2", "refund", {"amount": 50})
        assert im.has_pending("thread-2")
        # Simulate time passing beyond TTL by mocking the manager's clock.
        with patch("app.interrupt_manager.time") as mock_time:
            mock_time.time.return_value = record.created_at + 10
            assert not im.has_pending("thread-2")
            status = im.check_status("thread-2")
            assert status is not None
            assert status.is_expired
            assert status.remaining_seconds == 0.0

    def test_interrupt_resolve_flow(self) -> None:
        """Resolving an interrupt removes it from pending and resets session."""
        sm = SessionManager(session_ttl_seconds=3600)
        im = InterruptManager(ttl_seconds=300)
        sm.touch("thread-3")
        im.register("thread-3", "delete_account", {"user_id": "u1"})
        sm.extend_for_interrupt("thread-3")
        # Verify pending state
        assert im.has_pending("thread-3")
        assert sm.get_state("thread-3").has_pending_interrupt
        # Resolve on both managers
        im.resolve("thread-3")
        sm.resolve_interrupt("thread-3")
        assert not im.has_pending("thread-3")
        session_state = sm.get_state("thread-3")
        assert session_state is not None
        assert not session_state.has_pending_interrupt

    def test_cleanup_expired_removes_old_interrupts(self) -> None:
        """cleanup_expired removes only expired interrupts, keeping active ones."""
        im = InterruptManager(ttl_seconds=10)
        # Register two interrupts; only the first will be aged past the TTL.
        # (The second record binding was unused and has been dropped.)
        old_record = im.register("thread-old", "action_old", {})
        im.register("thread-new", "action_new", {})
        # Simulate time where only old one expired
        with patch("app.interrupt_manager.time") as mock_time:
            # Rebuild the old record with a creation time 20s in the past
            # (beyond the 10s TTL); the record type appears to be immutable.
            im._interrupts["thread-old"] = old_record.__class__(
                interrupt_id=old_record.interrupt_id,
                thread_id=old_record.thread_id,
                action=old_record.action,
                params=old_record.params,
                created_at=time.time() - 20,
                ttl_seconds=old_record.ttl_seconds,
            )
            mock_time.time.return_value = time.time()
            expired = im.cleanup_expired()
            assert len(expired) == 1
            assert expired[0].thread_id == "thread-old"
            # New one should still be pending
            assert im.has_pending("thread-new")
            assert not im.has_pending("thread-old")

    def test_session_ttl_sliding_window(self) -> None:
        """Touching a session resets the sliding window TTL."""
        sm = SessionManager(session_ttl_seconds=3600)
        state1 = sm.touch("thread-5")
        first_activity = state1.last_activity
        # Small real delay so the second touch gets a strictly later timestamp.
        time.sleep(0.01)
        state2 = sm.touch("thread-5")
        second_activity = state2.last_activity
        assert second_activity > first_activity
        assert not sm.is_expired("thread-5")

    def test_session_expires_after_ttl_without_activity(self) -> None:
        """Session expires when TTL passes without a touch or interrupt."""
        sm = SessionManager(session_ttl_seconds=0)
        sm.touch("thread-6")
        # TTL is 0 so session is immediately expired
        assert sm.is_expired("thread-6")

    def test_pending_interrupt_prevents_session_expiry(self) -> None:
        """A session with pending interrupt does not expire even with TTL=0."""
        sm = SessionManager(session_ttl_seconds=0)
        sm.touch("thread-7")
        sm.extend_for_interrupt("thread-7")
        # Even with TTL=0, session should not expire because of pending interrupt
        assert not sm.is_expired("thread-7")

    def test_retry_prompt_for_expired_interrupt(self) -> None:
        """InterruptManager generates a retry prompt for expired interrupts."""
        im = InterruptManager(ttl_seconds=300)
        record = im.register("thread-8", "cancel_order", {"order_id": "1042"})
        prompt = im.generate_retry_prompt(record)
        assert prompt["type"] == "interrupt_expired"
        assert prompt["thread_id"] == "thread-8"
        assert "cancel_order" in prompt["action"]
        assert "cancel_order" in prompt["message"]
        assert "expired" in prompt["message"].lower()

View File

@@ -44,7 +44,7 @@ def _make_analytics_result() -> object:
)
def _get_analytics(app: FastAPI, path: str = "/api/analytics", **patch_kwargs: object) -> object:
def _get_analytics(app: FastAPI, path: str = "/api/v1/analytics", **patch_kwargs: object) -> object:
"""Helper: patch get_analytics, make request, return (response, mock)."""
analytics_result = _make_analytics_result()
with (
@@ -84,7 +84,7 @@ class TestAnalyticsEndpoint:
def test_custom_range_7d(self) -> None:
app = _build_app()
app.state.pool = _make_mock_pool()
resp, mock_ga = _get_analytics(app, "/api/analytics?range=7d")
resp, mock_ga = _get_analytics(app, "/api/v1/analytics?range=7d")
assert resp.status_code == 200
mock_ga.assert_called_once()
@@ -94,7 +94,7 @@ class TestAnalyticsEndpoint:
def test_custom_range_30d(self) -> None:
app = _build_app()
app.state.pool = _make_mock_pool()
resp, mock_ga = _get_analytics(app, "/api/analytics?range=30d")
resp, mock_ga = _get_analytics(app, "/api/v1/analytics?range=30d")
assert resp.status_code == 200
call_kwargs = mock_ga.call_args
@@ -107,7 +107,7 @@ class TestAnalyticsEndpoint:
app.state.pool = _make_mock_pool()
with TestClient(app) as client:
resp = client.get("/api/analytics?range=invalid")
resp = client.get("/api/v1/analytics?range=invalid")
assert resp.status_code == 400
@@ -116,7 +116,7 @@ class TestAnalyticsEndpoint:
app.state.pool = _make_mock_pool()
with TestClient(app) as client:
resp = client.get("/api/analytics?range=7")
resp = client.get("/api/v1/analytics?range=7")
assert resp.status_code == 400

View File

@@ -28,7 +28,7 @@ def client():
@pytest.fixture
def job_id(client):
"""Create a job and return its ID."""
response = client.post("/api/openapi/import", json={"url": _SAMPLE_URL})
response = client.post("/api/v1/openapi/import", json={"url": _SAMPLE_URL})
assert response.status_code == 202
return response.json()["job_id"]
@@ -61,11 +61,11 @@ def job_with_classifications(client, job_id):
class TestImportEndpoint:
"""Tests for POST /api/openapi/import."""
"""Tests for POST /api/v1/openapi/import."""
def test_post_import_returns_job_id(self, client) -> None:
"""POST /import returns 202 with a job_id."""
response = client.post("/api/openapi/import", json={"url": _SAMPLE_URL})
response = client.post("/api/v1/openapi/import", json={"url": _SAMPLE_URL})
assert response.status_code == 202
data = response.json()
assert "job_id" in data
@@ -73,38 +73,38 @@ class TestImportEndpoint:
def test_post_import_empty_url_returns_422(self, client) -> None:
"""POST /import with empty URL returns 422 validation error."""
response = client.post("/api/openapi/import", json={"url": ""})
response = client.post("/api/v1/openapi/import", json={"url": ""})
assert response.status_code == 422
def test_post_import_missing_url_returns_422(self, client) -> None:
"""POST /import with missing URL field returns 422."""
response = client.post("/api/openapi/import", json={})
response = client.post("/api/v1/openapi/import", json={})
assert response.status_code == 422
def test_post_import_invalid_scheme_returns_422(self, client) -> None:
"""POST /import with non-http URL returns 422."""
response = client.post("/api/openapi/import", json={"url": "ftp://evil.com/spec"})
response = client.post("/api/v1/openapi/import", json={"url": "ftp://evil.com/spec"})
assert response.status_code == 422
def test_post_import_returns_pending_status(self, client) -> None:
"""Newly created job has pending status."""
response = client.post("/api/openapi/import", json={"url": _SAMPLE_URL})
response = client.post("/api/v1/openapi/import", json={"url": _SAMPLE_URL})
data = response.json()
assert data["status"] == "pending"
def test_post_import_returns_spec_url(self, client) -> None:
"""Response includes the original spec URL."""
response = client.post("/api/openapi/import", json={"url": _SAMPLE_URL})
response = client.post("/api/v1/openapi/import", json={"url": _SAMPLE_URL})
data = response.json()
assert data["spec_url"] == _SAMPLE_URL
class TestGetJobEndpoint:
"""Tests for GET /api/openapi/jobs/{job_id}."""
"""Tests for GET /api/v1/openapi/jobs/{job_id}."""
def test_get_job_returns_status(self, client, job_id) -> None:
"""GET /jobs/{id} returns job status."""
response = client.get(f"/api/openapi/jobs/{job_id}")
response = client.get(f"/api/v1/openapi/jobs/{job_id}")
assert response.status_code == 200
data = response.json()
assert "status" in data
@@ -112,23 +112,23 @@ class TestGetJobEndpoint:
def test_get_unknown_job_returns_404(self, client) -> None:
"""GET /jobs/nonexistent returns 404."""
response = client.get("/api/openapi/jobs/nonexistent-id")
response = client.get("/api/v1/openapi/jobs/nonexistent-id")
assert response.status_code == 404
def test_get_job_includes_spec_url(self, client, job_id) -> None:
"""Job response includes the spec URL."""
response = client.get(f"/api/openapi/jobs/{job_id}")
response = client.get(f"/api/v1/openapi/jobs/{job_id}")
data = response.json()
assert data["spec_url"] == _SAMPLE_URL
class TestGetClassificationsEndpoint:
"""Tests for GET /api/openapi/jobs/{job_id}/classifications."""
"""Tests for GET /api/v1/openapi/jobs/{job_id}/classifications."""
def test_get_classifications_returns_list(self, client, job_with_classifications) -> None:
"""GET /classifications returns a list."""
response = client.get(
f"/api/openapi/jobs/{job_with_classifications}/classifications"
f"/api/v1/openapi/jobs/{job_with_classifications}/classifications"
)
assert response.status_code == 200
data = response.json()
@@ -137,13 +137,13 @@ class TestGetClassificationsEndpoint:
def test_get_classifications_unknown_job_returns_404(self, client) -> None:
"""GET /classifications for unknown job returns 404."""
response = client.get("/api/openapi/jobs/unknown/classifications")
response = client.get("/api/v1/openapi/jobs/unknown/classifications")
assert response.status_code == 404
def test_classification_has_expected_fields(self, client, job_with_classifications) -> None:
"""Each classification item has access_type and endpoint fields."""
response = client.get(
f"/api/openapi/jobs/{job_with_classifications}/classifications"
f"/api/v1/openapi/jobs/{job_with_classifications}/classifications"
)
item = response.json()[0]
assert "access_type" in item
@@ -152,12 +152,12 @@ class TestGetClassificationsEndpoint:
class TestUpdateClassificationEndpoint:
"""Tests for PUT /api/openapi/jobs/{job_id}/classifications/{idx}."""
"""Tests for PUT /api/v1/openapi/jobs/{job_id}/classifications/{idx}."""
def test_update_classification_succeeds(self, client, job_with_classifications) -> None:
"""PUT /classifications/0 updates the classification."""
response = client.put(
f"/api/openapi/jobs/{job_with_classifications}/classifications/0",
f"/api/v1/openapi/jobs/{job_with_classifications}/classifications/0",
json={"access_type": "write", "needs_interrupt": True, "agent_group": "write_agent"},
)
assert response.status_code == 200
@@ -165,7 +165,7 @@ class TestUpdateClassificationEndpoint:
def test_update_unknown_job_returns_404(self, client) -> None:
"""PUT /classifications/0 for unknown job returns 404."""
response = client.put(
"/api/openapi/jobs/unknown/classifications/0",
"/api/v1/openapi/jobs/unknown/classifications/0",
json={"access_type": "write", "needs_interrupt": True, "agent_group": "write_agent"},
)
assert response.status_code == 404
@@ -173,7 +173,7 @@ class TestUpdateClassificationEndpoint:
def test_update_invalid_access_type_returns_422(self, client, job_with_classifications) -> None:
"""PUT /classifications/0 with invalid access_type returns 422."""
response = client.put(
f"/api/openapi/jobs/{job_with_classifications}/classifications/0",
f"/api/v1/openapi/jobs/{job_with_classifications}/classifications/0",
json={"access_type": "admin", "needs_interrupt": True, "agent_group": "x"},
)
assert response.status_code == 422
@@ -181,7 +181,7 @@ class TestUpdateClassificationEndpoint:
def test_update_invalid_agent_group_returns_422(self, client, job_with_classifications) -> None:
"""PUT /classifications/0 with invalid agent_group returns 422."""
response = client.put(
f"/api/openapi/jobs/{job_with_classifications}/classifications/0",
f"/api/v1/openapi/jobs/{job_with_classifications}/classifications/0",
json={"access_type": "read", "needs_interrupt": False, "agent_group": "evil group!"},
)
assert response.status_code == 422
@@ -189,31 +189,31 @@ class TestUpdateClassificationEndpoint:
def test_update_out_of_range_index_returns_404(self, client, job_with_classifications) -> None:
"""PUT /classifications/999 returns 404 for out-of-range index."""
response = client.put(
f"/api/openapi/jobs/{job_with_classifications}/classifications/999",
f"/api/v1/openapi/jobs/{job_with_classifications}/classifications/999",
json={"access_type": "read", "needs_interrupt": False, "agent_group": "read_agent"},
)
assert response.status_code == 404
class TestApproveEndpoint:
"""Tests for POST /api/openapi/jobs/{job_id}/approve."""
"""Tests for POST /api/v1/openapi/jobs/{job_id}/approve."""
def test_approve_job_succeeds(self, client, job_with_classifications) -> None:
"""POST /approve transitions job to approved status."""
response = client.post(
f"/api/openapi/jobs/{job_with_classifications}/approve"
f"/api/v1/openapi/jobs/{job_with_classifications}/approve"
)
assert response.status_code == 200
def test_approve_unknown_job_returns_404(self, client) -> None:
"""POST /approve for unknown job returns 404."""
response = client.post("/api/openapi/jobs/unknown/approve")
response = client.post("/api/v1/openapi/jobs/unknown/approve")
assert response.status_code == 404
def test_approve_returns_job_status(self, client, job_with_classifications) -> None:
"""POST /approve returns updated job status."""
response = client.post(
f"/api/openapi/jobs/{job_with_classifications}/approve"
f"/api/v1/openapi/jobs/{job_with_classifications}/approve"
)
data = response.json()
assert "status" in data

View File

@@ -5,9 +5,12 @@ from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock
import pytest
from fastapi import FastAPI
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from fastapi.testclient import TestClient
from app.api_utils import envelope
pytestmark = pytest.mark.unit
@@ -16,6 +19,14 @@ def _build_app() -> FastAPI:
app = FastAPI()
app.include_router(router)
@app.exception_handler(HTTPException)
async def _http_exc(request, exc): # type: ignore[no-untyped-def]
return JSONResponse(
status_code=exc.status_code,
content=envelope(None, success=False, error=exc.detail),
)
return app
@@ -64,7 +75,7 @@ class TestListConversations:
app.state.pool = _make_mock_pool([], count=0)
with TestClient(app) as client:
resp = client.get("/api/conversations")
resp = client.get("/api/v1/conversations")
assert resp.status_code == 200
body = resp.json()
assert body["success"] is True
@@ -89,7 +100,7 @@ class TestListConversations:
app.state.pool = _make_mock_pool(mock_rows, count=1)
with TestClient(app) as client:
resp = client.get("/api/conversations")
resp = client.get("/api/v1/conversations")
body = resp.json()
assert resp.status_code == 200
data = body["data"]
@@ -102,7 +113,7 @@ class TestListConversations:
app.state.pool = _make_mock_pool([], count=0)
with TestClient(app) as client:
resp = client.get("/api/conversations")
resp = client.get("/api/v1/conversations")
assert resp.status_code == 200
def test_pagination_custom_params(self) -> None:
@@ -110,7 +121,7 @@ class TestListConversations:
app.state.pool = _make_mock_pool([], count=0)
with TestClient(app) as client:
resp = client.get("/api/conversations?page=2&per_page=10")
resp = client.get("/api/v1/conversations?page=2&per_page=10")
assert resp.status_code == 200
def test_per_page_max_capped_at_100(self) -> None:
@@ -118,7 +129,7 @@ class TestListConversations:
app.state.pool = _make_mock_pool([], count=0)
with TestClient(app) as client:
resp = client.get("/api/conversations?per_page=200")
resp = client.get("/api/v1/conversations?per_page=200")
# FastAPI Query(le=100) rejects values > 100
assert resp.status_code == 422
@@ -129,7 +140,7 @@ class TestGetReplay:
app.state.pool = _make_mock_pool([])
with TestClient(app) as client:
resp = client.get("/api/replay/nonexistent-thread")
resp = client.get("/api/v1/replay/nonexistent-thread")
assert resp.status_code == 404
def test_returns_replay_page_for_existing_thread(self) -> None:
@@ -149,7 +160,7 @@ class TestGetReplay:
app.state.pool = _make_mock_pool(mock_rows)
with TestClient(app) as client:
resp = client.get("/api/replay/thread-123")
resp = client.get("/api/v1/replay/thread-123")
assert resp.status_code == 200
body = resp.json()
assert body["success"] is True
@@ -174,7 +185,7 @@ class TestGetReplay:
app.state.pool = _make_mock_pool(mock_rows)
with TestClient(app) as client:
resp = client.get("/api/replay/t1?page=1&per_page=5")
resp = client.get("/api/v1/replay/t1?page=1&per_page=5")
assert resp.status_code == 200
def test_error_response_has_envelope(self) -> None:
@@ -182,16 +193,19 @@ class TestGetReplay:
app.state.pool = _make_mock_pool([])
with TestClient(app) as client:
resp = client.get("/api/replay/missing")
resp = client.get("/api/v1/replay/missing")
assert resp.status_code == 404
assert "detail" in resp.json()
body = resp.json()
assert body["success"] is False
assert body["data"] is None
assert body["error"] is not None
def test_invalid_thread_id_returns_400(self) -> None:
app = _build_app()
app.state.pool = _make_mock_pool([])
with TestClient(app) as client:
resp = client.get("/api/replay/id%20with%20spaces")
resp = client.get("/api/v1/replay/id%20with%20spaces")
assert resp.status_code == 400
def test_thread_id_special_chars_returns_400(self) -> None:
@@ -199,5 +213,5 @@ class TestGetReplay:
app.state.pool = _make_mock_pool([])
with TestClient(app) as client:
resp = client.get("/api/replay/id;DROP TABLE")
resp = client.get("/api/v1/replay/id;DROP TABLE")
assert resp.status_code == 400

View File

@@ -0,0 +1,142 @@
"""Tests for standardized error response envelope format."""
from __future__ import annotations
import pytest
from fastapi import FastAPI, HTTPException
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from fastapi.testclient import TestClient
from pydantic import BaseModel, Field
from app.api_utils import envelope
pytestmark = pytest.mark.unit
def _build_test_app() -> FastAPI:
    """Build a minimal FastAPI app wired with the standard exception handlers.

    Mirrors production error handling: HTTPException, request-validation
    failures, and unhandled exceptions are all converted to the shared
    envelope format produced by ``app.api_utils.envelope``.
    """
    test_app = FastAPI()

    class ItemRequest(BaseModel):
        name: str = Field(..., min_length=1)
        count: int = Field(..., gt=0)

    @test_app.exception_handler(HTTPException)
    async def _on_http_error(request, exc):  # type: ignore[no-untyped-def]
        # Preserve the original status code; the detail becomes the error text.
        return JSONResponse(
            status_code=exc.status_code,
            content=envelope(None, success=False, error=exc.detail),
        )

    @test_app.exception_handler(RequestValidationError)
    async def _on_validation_error(request, exc):  # type: ignore[no-untyped-def]
        return JSONResponse(
            status_code=422,
            content=envelope(None, success=False, error=str(exc)),
        )

    @test_app.exception_handler(Exception)
    async def _on_unhandled_error(request, exc):  # type: ignore[no-untyped-def]
        # Never leak internals: every unhandled error maps to a generic message.
        return JSONResponse(
            status_code=500,
            content=envelope(None, success=False, error="Internal server error"),
        )

    # Sentinel item IDs that trigger specific HTTP errors for the tests.
    _error_routes = {
        0: (400, "Invalid item ID"),
        999: (404, "Item not found"),
        401: (401, "Not authenticated"),
    }

    @test_app.get("/items/{item_id}")
    def get_item(item_id: int) -> dict:
        if item_id in _error_routes:
            status_code, detail = _error_routes[item_id]
            raise HTTPException(status_code=status_code, detail=detail)
        return envelope({"id": item_id, "name": "test"})

    @test_app.post("/items")
    def create_item(item: ItemRequest) -> dict:
        return envelope({"id": 1, "name": item.name})

    @test_app.get("/crash")
    def crash() -> dict:
        msg = "unexpected failure"
        raise RuntimeError(msg)

    return test_app
class TestHttpExceptionEnvelope:
    """HTTPException responses use the standard envelope format."""

    @staticmethod
    def _fetch(path: str):
        """GET *path* against a fresh app; return (status_code, parsed body)."""
        with TestClient(_build_test_app(), raise_server_exceptions=False) as client:
            resp = client.get(path)
        return resp.status_code, resp.json()

    def test_400_returns_envelope(self) -> None:
        status, body = self._fetch("/items/0")
        assert status == 400
        assert body["success"] is False
        assert body["data"] is None
        assert body["error"] == "Invalid item ID"

    def test_404_returns_envelope(self) -> None:
        status, body = self._fetch("/items/999")
        assert status == 404
        assert body["success"] is False
        assert body["data"] is None
        assert body["error"] == "Item not found"

    def test_401_returns_envelope(self) -> None:
        status, body = self._fetch("/items/401")
        assert status == 401
        assert body["success"] is False
        assert body["data"] is None
        assert body["error"] == "Not authenticated"
class TestValidationErrorEnvelope:
    """Validation errors return 422 with envelope format."""

    def test_validation_error_returns_envelope(self) -> None:
        app = _build_test_app()
        # Violates both constraints: name min_length=1 and count gt=0.
        bad_payload = {"name": "", "count": -1}
        with TestClient(app, raise_server_exceptions=False) as client:
            resp = client.post("/items", json=bad_payload)
        assert resp.status_code == 422
        body = resp.json()
        assert body["success"] is False
        assert body["data"] is None
        # The error field is a non-empty human-readable string.
        assert isinstance(body["error"], str)
        assert body["error"]
class TestGeneralExceptionEnvelope:
    """Unhandled exceptions return 500 with safe envelope."""

    def test_unhandled_exception_returns_500_envelope(self) -> None:
        app = _build_test_app()
        # raise_server_exceptions=False lets the 500 response reach the client
        # instead of re-raising the RuntimeError into the test.
        with TestClient(app, raise_server_exceptions=False) as client:
            resp = client.get("/crash")
        assert resp.status_code == 500
        body = resp.json()
        assert body["success"] is False
        assert body["data"] is None
        # The generic message must not leak the underlying exception text.
        assert body["error"] == "Internal server error"
class TestSuccessResponseUnchanged:
    """Success responses still work normally."""

    def test_success_returns_envelope(self) -> None:
        with TestClient(_build_test_app()) as client:
            body = client.get("/items/42").json()
            status = 200 if client.get("/items/42").status_code == 200 else None
        assert status == 200
        assert body["success"] is True
        assert body["data"]["id"] == 42
        assert body["error"] is None

View File

@@ -0,0 +1,86 @@
"""Tests for the interrupt cleanup background loop in main.py."""
from __future__ import annotations
import asyncio
import logging
from unittest.mock import MagicMock, patch
import pytest
from app.main import _interrupt_cleanup_loop
@pytest.mark.unit
@pytest.mark.asyncio
async def test_cleanup_loop_calls_cleanup_expired() -> None:
    """The loop should call cleanup_expired after each sleep interval."""
    manager = MagicMock()
    manager.cleanup_expired.return_value = ()

    real_sleep = asyncio.sleep
    sleep_calls = 0

    async def _counting_sleep(seconds: float) -> None:
        # Cancel on the second sleep so the infinite loop terminates.
        nonlocal sleep_calls
        sleep_calls += 1
        if sleep_calls >= 2:
            raise asyncio.CancelledError
        await real_sleep(0)

    with patch("app.main.asyncio.sleep", side_effect=_counting_sleep), \
         pytest.raises(asyncio.CancelledError):
        await _interrupt_cleanup_loop(manager, interval=60)

    assert manager.cleanup_expired.call_count >= 1
@pytest.mark.unit
@pytest.mark.asyncio
async def test_cleanup_loop_survives_exceptions() -> None:
    """The loop should not die when cleanup_expired raises an exception."""
    manager = MagicMock()
    # First sweep raises, second succeeds with no expired records.
    manager.cleanup_expired.side_effect = [RuntimeError("db gone"), ()]

    real_sleep = asyncio.sleep
    sleep_calls = 0

    async def _counting_sleep(seconds: float) -> None:
        # Cancel on the third sleep: enough for two cleanup sweeps.
        nonlocal sleep_calls
        sleep_calls += 1
        if sleep_calls >= 3:
            raise asyncio.CancelledError
        await real_sleep(0)

    with patch("app.main.asyncio.sleep", side_effect=_counting_sleep), \
         pytest.raises(asyncio.CancelledError):
        await _interrupt_cleanup_loop(manager, interval=60)

    # Called twice: once raising, once returning () — proving the loop survived.
    assert manager.cleanup_expired.call_count == 2
@pytest.mark.unit
@pytest.mark.asyncio
async def test_cleanup_loop_logs_expired_count(capsys: pytest.CaptureFixture[str]) -> None:
    """The loop should log when expired interrupts are found."""
    expired_record = MagicMock()
    manager = MagicMock()
    # Two expired records per sweep — the log line should mention the count.
    manager.cleanup_expired.return_value = (expired_record, expired_record)

    real_sleep = asyncio.sleep
    sleep_calls = 0

    async def _counting_sleep(seconds: float) -> None:
        nonlocal sleep_calls
        sleep_calls += 1
        if sleep_calls >= 2:
            raise asyncio.CancelledError
        await real_sleep(0)

    with patch("app.main.asyncio.sleep", side_effect=_counting_sleep), \
         pytest.raises(asyncio.CancelledError):
        await _interrupt_cleanup_loop(manager, interval=60)

    # NOTE(review): assumes the loop logs to stdout (console renderer) — the
    # captured stream would need changing if logging goes to stderr instead.
    captured = capsys.readouterr()
    assert "2 expired interrupt" in captured.out

View File

@@ -0,0 +1,20 @@
"""Tests for structured logging configuration."""
from __future__ import annotations
import pytest
from app.logging_config import configure_logging
pytestmark = pytest.mark.unit
def test_configure_logging_console_mode() -> None:
    """Configuring the human-readable console renderer does not raise."""
    # Smoke test: a broken processor chain would raise here.
    configure_logging("console")
def test_configure_logging_json_mode() -> None:
    """Configuring the machine-readable JSON renderer does not raise."""
    # Smoke test: a broken processor chain would raise here.
    configure_logging("json")

View File

@@ -36,7 +36,7 @@ class TestMainModule:
def test_health_route_registered(self) -> None:
    """The health endpoint is registered under the versioned API prefix.

    After the API-versioning refactor only /api/v1/health exists; the
    stale assertion on the removed /api/health route is dropped.
    """
    routes = [r.path for r in app.routes if hasattr(r, "path")]
    assert "/api/v1/health" in routes
def test_app_version_is_0_6_0(self) -> None:
    """App version matches the released 0.6.0.

    Renamed from test_app_version_is_0_5_0: the old name contradicted
    the asserted value and would mislead anyone reading test reports.
    """
    assert app.version == "0.6.0"