Compare commits
9 Commits
d2b4610df9
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f0699436c5 | ||
|
|
af53111928 | ||
|
|
b8654aa31f | ||
|
|
be5c84bcff | ||
|
|
19fc9f3289 | ||
|
|
036e12349d | ||
|
|
e0931daece | ||
|
|
e55ec42ae5 | ||
|
|
189a0fad34 |
12
.env.example
12
.env.example
@@ -4,7 +4,7 @@
|
|||||||
# PostgreSQL password (used by both postgres and backend services)
|
# PostgreSQL password (used by both postgres and backend services)
|
||||||
POSTGRES_PASSWORD=dev_password
|
POSTGRES_PASSWORD=dev_password
|
||||||
|
|
||||||
# LLM provider: anthropic | openai | google
|
# LLM provider: anthropic | openai | azure_openai | google
|
||||||
LLM_PROVIDER=anthropic
|
LLM_PROVIDER=anthropic
|
||||||
LLM_MODEL=claude-sonnet-4-6
|
LLM_MODEL=claude-sonnet-4-6
|
||||||
|
|
||||||
@@ -13,6 +13,12 @@ ANTHROPIC_API_KEY=
|
|||||||
OPENAI_API_KEY=
|
OPENAI_API_KEY=
|
||||||
GOOGLE_API_KEY=
|
GOOGLE_API_KEY=
|
||||||
|
|
||||||
|
# Azure OpenAI (required when LLM_PROVIDER=azure_openai)
|
||||||
|
AZURE_OPENAI_API_KEY=
|
||||||
|
AZURE_OPENAI_ENDPOINT=
|
||||||
|
AZURE_OPENAI_DEPLOYMENT=
|
||||||
|
AZURE_OPENAI_API_VERSION=2024-12-01-preview
|
||||||
|
|
||||||
# Optional: webhook URL for escalation notifications
|
# Optional: webhook URL for escalation notifications
|
||||||
WEBHOOK_URL=
|
WEBHOOK_URL=
|
||||||
|
|
||||||
@@ -20,6 +26,10 @@ WEBHOOK_URL=
|
|||||||
SESSION_TTL_MINUTES=30
|
SESSION_TTL_MINUTES=30
|
||||||
INTERRUPT_TTL_MINUTES=30
|
INTERRUPT_TTL_MINUTES=30
|
||||||
|
|
||||||
|
# Optional: API key for admin endpoints (analytics, replay, openapi, websocket)
|
||||||
|
# Leave empty to disable authentication (dev mode)
|
||||||
|
ADMIN_API_KEY=
|
||||||
|
|
||||||
# Optional: load a named agent template instead of agents.yaml
|
# Optional: load a named agent template instead of agents.yaml
|
||||||
# Available templates: ecommerce, saas, generic
|
# Available templates: ecommerce, saas, generic
|
||||||
TEMPLATE_NAME=
|
TEMPLATE_NAME=
|
||||||
|
|||||||
57
CLAUDE.md
57
CLAUDE.md
@@ -30,7 +30,7 @@ pytest --cov=app --cov-report=term-missing
|
|||||||
# - If any test fails, fix it before starting the new phase
|
# - If any test fails, fix it before starting the new phase
|
||||||
|
|
||||||
# 3. Create checkpoint to snapshot the starting state
|
# 3. Create checkpoint to snapshot the starting state
|
||||||
/everything-claude-code:checkpoint create "phase-name"
|
/ecc:checkpoint create "phase-name"
|
||||||
|
|
||||||
# 4. Create the phase branch
|
# 4. Create the phase branch
|
||||||
git checkout main
|
git checkout main
|
||||||
@@ -50,25 +50,32 @@ git checkout -b phase-{N}/{short-description}
|
|||||||
3. Identify all tasks, acceptance criteria, and dependencies for this phase
|
3. Identify all tasks, acceptance criteria, and dependencies for this phase
|
||||||
4. Create a phase dev log **skeleton** at `docs/phases/phase-{N}-dev-log.md` (date, branch name, plan link only -- content filled in Step 5)
|
4. Create a phase dev log **skeleton** at `docs/phases/phase-{N}-dev-log.md` (date, branch name, plan link only -- content filled in Step 5)
|
||||||
|
|
||||||
### Step 2: Develop Using Orchestrate Skill
|
### Step 2: Develop Using ECC Skills
|
||||||
|
|
||||||
Route to the correct orchestration mode based on work type:
|
Route to the correct skill based on work type:
|
||||||
|
|
||||||
| Work Type | Skill Command |
|
| Work Type | Skill Command | What It Does |
|
||||||
|-----------|---------------|
|
|-----------|---------------|--------------|
|
||||||
| New feature | `/everything-claude-code:orchestrate feature` |
|
| New feature | `/ecc:feature-dev <desc>` | Discovery -> Exploration -> Architecture -> TDD -> Review -> Summary |
|
||||||
| Bug fix | `/everything-claude-code:orchestrate bugfix` |
|
| Bug fix | `/ecc:tdd` then `/ecc:code-review` | RED -> GREEN -> REFACTOR cycle, then review |
|
||||||
| Refactor | `/everything-claude-code:orchestrate refactor` |
|
| Refactor | `/ecc:plan` then `/ecc:tdd` then `/ecc:code-review` | Plan refactor scope, TDD, review |
|
||||||
|
| Security-sensitive | Add `/ecc:security-review` after code-review | Auth, payments, user input, external APIs |
|
||||||
|
| Final verification | `/ecc:verify` | Build + tests + lint + coverage + security scan |
|
||||||
|
|
||||||
ALWAYS use the appropriate orchestrate skill. Never develop without it.
|
A single phase may contain mixed work types. Call the appropriate skill **per sub-task**:
|
||||||
|
|
||||||
A single phase may contain mixed work types (e.g., Phase 5 has feature + bugfix + refactor). Call the orchestrate skill **per sub-task** with the matching mode. Example:
|
|
||||||
|
|
||||||
```
|
```
|
||||||
# Within Phase 5:
|
# Within a phase:
|
||||||
/everything-claude-code:orchestrate feature # for demo script
|
/ecc:feature-dev "demo script" # for new features
|
||||||
/everything-claude-code:orchestrate bugfix # for error handling fixes
|
/ecc:tdd # for bug fixes (write failing test, then fix)
|
||||||
/everything-claude-code:orchestrate refactor # for code cleanup
|
/ecc:plan "consolidate error handling" # for refactors (plan first, then TDD)
|
||||||
|
```
|
||||||
|
|
||||||
|
For full multi-phase autonomous execution, use GSD:
|
||||||
|
|
||||||
|
```
|
||||||
|
/gsd:autonomous # execute all remaining phases
|
||||||
|
/gsd:execute-phase 6 # execute a specific phase
|
||||||
```
|
```
|
||||||
|
|
||||||
### Step 3: Module Independence (CRITICAL)
|
### Step 3: Module Independence (CRITICAL)
|
||||||
@@ -171,10 +178,10 @@ After all development and testing, run verification in this exact order:
|
|||||||
|
|
||||||
```
|
```
|
||||||
# 1. Run the verification skill -- must pass
|
# 1. Run the verification skill -- must pass
|
||||||
/everything-claude-code:verify
|
/ecc:verify
|
||||||
|
|
||||||
# 2. Verify the checkpoint -- validates all phase deliverables
|
# 2. Verify the checkpoint -- validates all phase deliverables
|
||||||
/everything-claude-code:checkpoint verify "phase-name"
|
/ecc:checkpoint verify "phase-name"
|
||||||
```
|
```
|
||||||
|
|
||||||
The checkpoint verify validates:
|
The checkpoint verify validates:
|
||||||
@@ -222,11 +229,11 @@ git push origin main --tags
|
|||||||
All four markers must be consistent. If any is missed, the next phase's Step 0 regression gate will catch the discrepancy.
|
All four markers must be consistent. If any is missed, the next phase's Step 0 regression gate will catch the discrepancy.
|
||||||
|
|
||||||
A checkpoint includes:
|
A checkpoint includes:
|
||||||
- `/everything-claude-code:checkpoint create` at phase start
|
- `/ecc:checkpoint create` at phase start
|
||||||
- `/everything-claude-code:checkpoint verify` at phase end
|
- `/ecc:checkpoint verify` at phase end
|
||||||
- All tests passing (80%+ coverage)
|
- All tests passing (80%+ coverage)
|
||||||
- Phase dev log written and linked
|
- Phase dev log written and linked
|
||||||
- `/everything-claude-code:verify` passed
|
- `/ecc:verify` passed
|
||||||
- Git tag `checkpoint/phase-{N}` created
|
- Git tag `checkpoint/phase-{N}` created
|
||||||
- Phase marked COMPLETED in four locations
|
- Phase marked COMPLETED in four locations
|
||||||
- Branch merged to main
|
- Branch merged to main
|
||||||
@@ -264,7 +271,7 @@ This project inherits from `~/.claude/rules/`. CLAUDE.md only contains project-s
|
|||||||
|
|
||||||
### Hooks (ECC Plugin -- No Custom Hooks)
|
### Hooks (ECC Plugin -- No Custom Hooks)
|
||||||
|
|
||||||
All hooks come from the ECC plugin (`everything-claude-code`). No project-level hooks in `.claude/settings.local.json`.
|
All hooks come from the ECC plugin (`ecc`). No project-level hooks in `.claude/settings.local.json`.
|
||||||
|
|
||||||
| ECC Hook | Type | What It Does |
|
| ECC Hook | Type | What It Does |
|
||||||
|----------|------|-------------|
|
|----------|------|-------------|
|
||||||
@@ -290,7 +297,7 @@ Controlled by `ECC_HOOK_PROFILE` env var in `~/.claude/settings.json` (currently
|
|||||||
- Architecture doc: `docs/ARCHITECTURE.md`
|
- Architecture doc: `docs/ARCHITECTURE.md`
|
||||||
- Phase dev logs: `docs/phases/phase-{N}-dev-log.md`
|
- Phase dev logs: `docs/phases/phase-{N}-dev-log.md`
|
||||||
- Test command: `pytest --cov=app --cov-report=term-missing`
|
- Test command: `pytest --cov=app --cov-report=term-missing`
|
||||||
- **Phase start:** `/everything-claude-code:checkpoint create "phase-name"`
|
- **Phase start:** `/ecc:checkpoint create "phase-name"`
|
||||||
- **Phase end:** `/everything-claude-code:checkpoint verify "phase-name"`
|
- **Phase end:** `/ecc:checkpoint verify "phase-name"`
|
||||||
- Verify command: `/everything-claude-code:verify`
|
- Verify command: `/ecc:verify`
|
||||||
- Orchestrate: `/everything-claude-code:orchestrate {feature|bugfix|refactor}`
|
- Orchestrate: `/ecc:orchestrate {feature|bugfix|refactor}`
|
||||||
|
|||||||
123
README.md
123
README.md
@@ -45,10 +45,11 @@ User message -> Chat UI -> FastAPI WebSocket -> LangGraph Supervisor -> Speciali
|
|||||||
| Component | Technology |
|
| Component | Technology |
|
||||||
|-----------|-----------|
|
|-----------|-----------|
|
||||||
| Backend | Python 3.11+, FastAPI |
|
| Backend | Python 3.11+, FastAPI |
|
||||||
| Agent orchestration | LangGraph v1.1 |
|
| Agent orchestration | LangGraph 1.x, langgraph-supervisor |
|
||||||
| Session state | PostgreSQL + langgraph-checkpoint-postgres |
|
| Session state | PostgreSQL 16 + langgraph-checkpoint-postgres |
|
||||||
| LLM | Claude Sonnet 4.6 (configurable: OpenAI, Google) |
|
| LLM | Claude Sonnet 4.6 (configurable: OpenAI, Azure OpenAI, Google) |
|
||||||
| Frontend | React 19, TypeScript, Vite |
|
| Frontend | React 19, TypeScript, Vite |
|
||||||
|
| Testing | pytest (backend), vitest + happy-dom (frontend) |
|
||||||
| Deployment | Docker Compose |
|
| Deployment | Docker Compose |
|
||||||
|
|
||||||
## Quick Start
|
## Quick Start
|
||||||
@@ -59,7 +60,11 @@ cd smart-support
|
|||||||
|
|
||||||
# Configure your LLM API key
|
# Configure your LLM API key
|
||||||
cp .env.example .env
|
cp .env.example .env
|
||||||
# Edit .env: set ANTHROPIC_API_KEY (or OPENAI_API_KEY)
|
# Edit .env: set LLM_PROVIDER and the corresponding API key
|
||||||
|
# anthropic -> ANTHROPIC_API_KEY
|
||||||
|
# openai -> OPENAI_API_KEY
|
||||||
|
# azure_openai -> AZURE_OPENAI_API_KEY + AZURE_OPENAI_ENDPOINT + AZURE_OPENAI_DEPLOYMENT
|
||||||
|
# google -> GOOGLE_API_KEY
|
||||||
|
|
||||||
# Start all services
|
# Start all services
|
||||||
docker compose up -d
|
docker compose up -d
|
||||||
@@ -68,6 +73,25 @@ docker compose up -d
|
|||||||
open http://localhost
|
open http://localhost
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Local Development
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Start only PostgreSQL via Docker (exposed on port 5433)
|
||||||
|
docker compose up postgres -d
|
||||||
|
|
||||||
|
# Backend (in one terminal)
|
||||||
|
cd backend
|
||||||
|
pip install -e ".[dev]"
|
||||||
|
uvicorn app.main:app --host 0.0.0.0 --port 8001 --reload
|
||||||
|
|
||||||
|
# Frontend (in another terminal)
|
||||||
|
cd frontend
|
||||||
|
npm install
|
||||||
|
npm run dev # http://localhost:5173 (proxies /api and /ws to :8001)
|
||||||
|
```
|
||||||
|
|
||||||
|
See [Deployment Guide](docs/deployment.md) for production setup, HTTPS, and scaling.
|
||||||
|
|
||||||
## Project Structure
|
## Project Structure
|
||||||
|
|
||||||
```
|
```
|
||||||
@@ -75,17 +99,20 @@ smart-support/
|
|||||||
├── backend/
|
├── backend/
|
||||||
│ ├── app/
|
│ ├── app/
|
||||||
│ │ ├── main.py # FastAPI + WebSocket entry point
|
│ │ ├── main.py # FastAPI + WebSocket entry point
|
||||||
│ │ ├── graph.py # LangGraph Supervisor
|
│ │ ├── graph.py # LangGraph Supervisor construction
|
||||||
|
│ │ ├── graph_context.py # Typed wrapper for graph + classifier + registry
|
||||||
│ │ ├── ws_handler.py # WebSocket message dispatch + rate limiting
|
│ │ ├── ws_handler.py # WebSocket message dispatch + rate limiting
|
||||||
│ │ ├── conversation_tracker.py # Conversation lifecycle tracking
|
│ │ ├── ws_context.py # WebSocket dependency bundle
|
||||||
|
│ │ ├── auth.py # API key authentication middleware
|
||||||
|
│ │ ├── api_utils.py # Shared API response helpers
|
||||||
|
│ │ ├── safety.py # Confirmation rules + MCP error taxonomy
|
||||||
│ │ ├── agents/ # Agent definitions and tools
|
│ │ ├── agents/ # Agent definitions and tools
|
||||||
│ │ ├── registry.py # YAML agent registry loader
|
│ │ ├── registry.py # YAML agent registry loader
|
||||||
│ │ ├── openapi/ # OpenAPI parser and review API
|
│ │ ├── openapi/ # OpenAPI parser, classifier, and review API
|
||||||
│ │ ├── replay/ # Conversation replay API
|
│ │ ├── replay/ # Conversation replay API
|
||||||
│ │ ├── analytics/ # Analytics queries and API
|
│ │ └── analytics/ # Analytics queries and API
|
||||||
│ │ └── tools/ # Error handling and retry utilities
|
|
||||||
│ ├── agents.yaml # Agent registry configuration
|
│ ├── agents.yaml # Agent registry configuration
|
||||||
│ ├── fixtures/ # Demo data and sample OpenAPI spec
|
│ ├── templates/ # Vertical industry templates
|
||||||
│ └── tests/ # Unit, integration, and E2E tests
|
│ └── tests/ # Unit, integration, and E2E tests
|
||||||
├── frontend/
|
├── frontend/
|
||||||
│ ├── src/
|
│ ├── src/
|
||||||
@@ -99,67 +126,49 @@ smart-support/
|
|||||||
└── .env.example # Environment variable template
|
└── .env.example # Environment variable template
|
||||||
```
|
```
|
||||||
|
|
||||||
## Agent Configuration
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
# agents.yaml
|
|
||||||
agents:
|
|
||||||
- name: order_agent
|
|
||||||
description: "Handles order status, tracking, and cancellations."
|
|
||||||
permission: write
|
|
||||||
tools:
|
|
||||||
- get_order_status
|
|
||||||
- cancel_order
|
|
||||||
personality:
|
|
||||||
tone: friendly
|
|
||||||
greeting: "I can help with your order. What is the order number?"
|
|
||||||
escalation_message: "I'm escalating this to a human agent."
|
|
||||||
|
|
||||||
- name: general_agent
|
|
||||||
description: "Answers general questions and FAQs."
|
|
||||||
permission: read
|
|
||||||
tools:
|
|
||||||
- search_faq
|
|
||||||
```
|
|
||||||
|
|
||||||
## API Endpoints
|
## API Endpoints
|
||||||
|
|
||||||
| Method | Path | Description |
|
| Method | Path | Auth | Description |
|
||||||
|--------|------|-------------|
|
|--------|------|------|-------------|
|
||||||
| WS | `/ws` | Main WebSocket chat endpoint |
|
| WS | `/ws` | Token | Main WebSocket chat endpoint (`?token=<key>`) |
|
||||||
| GET | `/api/health` | Health check |
|
| GET | `/api/health` | No | Health check |
|
||||||
| GET | `/api/conversations` | List conversations |
|
| GET | `/api/conversations` | API Key | List conversations (paginated) |
|
||||||
| GET | `/api/replay/{thread_id}` | Replay conversation |
|
| GET | `/api/replay/{thread_id}` | API Key | Replay conversation steps (paginated) |
|
||||||
| GET | `/api/analytics` | Analytics summary |
|
| GET | `/api/analytics` | API Key | Analytics summary (`?range=7d`) |
|
||||||
| POST | `/api/openapi/import` | Import OpenAPI spec |
|
| POST | `/api/openapi/import` | API Key | Start OpenAPI import job |
|
||||||
| GET | `/api/openapi/jobs/{id}` | Check import job status |
|
| GET | `/api/openapi/jobs/{id}` | API Key | Check import job status |
|
||||||
|
| GET | `/api/openapi/jobs/{id}/classifications` | API Key | Get endpoint classifications |
|
||||||
|
| PUT | `/api/openapi/jobs/{id}/classifications/{idx}` | API Key | Update a classification |
|
||||||
|
| POST | `/api/openapi/jobs/{id}/approve` | API Key | Approve and generate tools |
|
||||||
|
|
||||||
## Security
|
Authentication is controlled by the `ADMIN_API_KEY` environment variable.
|
||||||
|
API Key endpoints require the `X-API-Key` header. When `ADMIN_API_KEY` is unset, auth is disabled.
|
||||||
- **SSRF protection** -- OpenAPI import blocks private IPs and metadata service URLs
|
|
||||||
- **Input validation** -- messages validated for size (32 KB), content length (10 KB), thread ID format
|
|
||||||
- **Rate limiting** -- 10 messages per 10 seconds per session
|
|
||||||
- **Audit trail** -- every tool call logged with agent, params, result, timestamp
|
|
||||||
- **Permission isolation** -- each agent only accesses its configured tools
|
|
||||||
- **Interrupt TTL** -- unanswered approval prompts expire after 30 minutes
|
|
||||||
|
|
||||||
## Running Tests
|
## Running Tests
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
# Backend (516 tests, 94% coverage)
|
||||||
cd backend
|
cd backend
|
||||||
pytest --cov=app --cov-report=term-missing
|
pytest --cov=app --cov-report=term-missing
|
||||||
|
|
||||||
|
# Frontend (23 tests, vitest + happy-dom)
|
||||||
|
cd frontend
|
||||||
|
npm test
|
||||||
```
|
```
|
||||||
|
|
||||||
Coverage is enforced at 80%+.
|
Backend coverage is enforced at 80%+.
|
||||||
|
|
||||||
## Documentation
|
## Documentation
|
||||||
|
|
||||||
- [Architecture](docs/ARCHITECTURE.md) -- System design and component diagram
|
| Document | Description |
|
||||||
- [Development Plan](docs/DEVELOPMENT-PLAN.md) -- Phase breakdown and status
|
|----------|-------------|
|
||||||
- [Agent Config Guide](docs/agent-config-guide.md) -- How to configure agents
|
| [Architecture](docs/ARCHITECTURE.md) | System design, component diagram, data flow, ADRs |
|
||||||
- [OpenAPI Import Guide](docs/openapi-import-guide.md) -- Auto-discovery workflow
|
| [Development Plan](docs/DEVELOPMENT-PLAN.md) | Phase breakdown, task checklists, and status |
|
||||||
- [Deployment Guide](docs/deployment.md) -- Docker and production deployment
|
| [Agent Config Guide](docs/agent-config-guide.md) | agents.yaml format, fields, templates, routing logic |
|
||||||
- [Demo Script](docs/demo-script.md) -- Step-by-step live demo walkthrough
|
| [OpenAPI Import Guide](docs/openapi-import-guide.md) | Auto-discovery workflow, REST API, SSRF protection |
|
||||||
|
| [Deployment Guide](docs/deployment.md) | Docker, local dev, production, HTTPS, backups, scaling |
|
||||||
|
| [Demo Script](docs/demo-script.md) | Step-by-step live demo walkthrough (5 scenes) |
|
||||||
|
| [UX Design System](docs/ux_design_system.md) | Color palette, typography, component patterns, CSS tokens |
|
||||||
|
|
||||||
## License
|
## License
|
||||||
|
|
||||||
|
|||||||
149
backend/alembic.ini
Normal file
149
backend/alembic.ini
Normal file
@@ -0,0 +1,149 @@
|
|||||||
|
# A generic, single database configuration.
|
||||||
|
|
||||||
|
[alembic]
|
||||||
|
# path to migration scripts.
|
||||||
|
# this is typically a path given in POSIX (e.g. forward slashes)
|
||||||
|
# format, relative to the token %(here)s which refers to the location of this
|
||||||
|
# ini file
|
||||||
|
script_location = %(here)s/alembic
|
||||||
|
|
||||||
|
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
|
||||||
|
# Uncomment the line below if you want the files to be prepended with date and time
|
||||||
|
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
|
||||||
|
# for all available tokens
|
||||||
|
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
|
||||||
|
# Or organize into date-based subdirectories (requires recursive_version_locations = true)
|
||||||
|
# file_template = %%(year)d/%%(month).2d/%%(day).2d_%%(hour).2d%%(minute).2d_%%(second).2d_%%(rev)s_%%(slug)s
|
||||||
|
|
||||||
|
# sys.path path, will be prepended to sys.path if present.
|
||||||
|
# defaults to the current working directory. for multiple paths, the path separator
|
||||||
|
# is defined by "path_separator" below.
|
||||||
|
prepend_sys_path = .
|
||||||
|
|
||||||
|
|
||||||
|
# timezone to use when rendering the date within the migration file
|
||||||
|
# as well as the filename.
|
||||||
|
# If specified, requires the tzdata library which can be installed by adding
|
||||||
|
# `alembic[tz]` to the pip requirements.
|
||||||
|
# string value is passed to ZoneInfo()
|
||||||
|
# leave blank for localtime
|
||||||
|
# timezone =
|
||||||
|
|
||||||
|
# max length of characters to apply to the "slug" field
|
||||||
|
# truncate_slug_length = 40
|
||||||
|
|
||||||
|
# set to 'true' to run the environment during
|
||||||
|
# the 'revision' command, regardless of autogenerate
|
||||||
|
# revision_environment = false
|
||||||
|
|
||||||
|
# set to 'true' to allow .pyc and .pyo files without
|
||||||
|
# a source .py file to be detected as revisions in the
|
||||||
|
# versions/ directory
|
||||||
|
# sourceless = false
|
||||||
|
|
||||||
|
# version location specification; This defaults
|
||||||
|
# to <script_location>/versions. When using multiple version
|
||||||
|
# directories, initial revisions must be specified with --version-path.
|
||||||
|
# The path separator used here should be the separator specified by "path_separator"
|
||||||
|
# below.
|
||||||
|
# version_locations = %(here)s/bar:%(here)s/bat:%(here)s/alembic/versions
|
||||||
|
|
||||||
|
# path_separator; This indicates what character is used to split lists of file
|
||||||
|
# paths, including version_locations and prepend_sys_path within configparser
|
||||||
|
# files such as alembic.ini.
|
||||||
|
# The default rendered in new alembic.ini files is "os", which uses os.pathsep
|
||||||
|
# to provide os-dependent path splitting.
|
||||||
|
#
|
||||||
|
# Note that in order to support legacy alembic.ini files, this default does NOT
|
||||||
|
# take place if path_separator is not present in alembic.ini. If this
|
||||||
|
# option is omitted entirely, fallback logic is as follows:
|
||||||
|
#
|
||||||
|
# 1. Parsing of the version_locations option falls back to using the legacy
|
||||||
|
# "version_path_separator" key, which if absent then falls back to the legacy
|
||||||
|
# behavior of splitting on spaces and/or commas.
|
||||||
|
# 2. Parsing of the prepend_sys_path option falls back to the legacy
|
||||||
|
# behavior of splitting on spaces, commas, or colons.
|
||||||
|
#
|
||||||
|
# Valid values for path_separator are:
|
||||||
|
#
|
||||||
|
# path_separator = :
|
||||||
|
# path_separator = ;
|
||||||
|
# path_separator = space
|
||||||
|
# path_separator = newline
|
||||||
|
#
|
||||||
|
# Use os.pathsep. Default configuration used for new projects.
|
||||||
|
path_separator = os
|
||||||
|
|
||||||
|
# set to 'true' to search source files recursively
|
||||||
|
# in each "version_locations" directory
|
||||||
|
# new in Alembic version 1.10
|
||||||
|
# recursive_version_locations = false
|
||||||
|
|
||||||
|
# the output encoding used when revision files
|
||||||
|
# are written from script.py.mako
|
||||||
|
# output_encoding = utf-8
|
||||||
|
|
||||||
|
# database URL. This is consumed by the user-maintained env.py script only.
|
||||||
|
# other means of configuring database URLs may be customized within the env.py
|
||||||
|
# file.
|
||||||
|
sqlalchemy.url =
|
||||||
|
|
||||||
|
|
||||||
|
[post_write_hooks]
|
||||||
|
# post_write_hooks defines scripts or Python functions that are run
|
||||||
|
# on newly generated revision scripts. See the documentation for further
|
||||||
|
# detail and examples
|
||||||
|
|
||||||
|
# format using "black" - use the console_scripts runner, against the "black" entrypoint
|
||||||
|
# hooks = black
|
||||||
|
# black.type = console_scripts
|
||||||
|
# black.entrypoint = black
|
||||||
|
# black.options = -l 79 REVISION_SCRIPT_FILENAME
|
||||||
|
|
||||||
|
# lint with attempts to fix using "ruff" - use the module runner, against the "ruff" module
|
||||||
|
# hooks = ruff
|
||||||
|
# ruff.type = module
|
||||||
|
# ruff.module = ruff
|
||||||
|
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
|
||||||
|
|
||||||
|
# Alternatively, use the exec runner to execute a binary found on your PATH
|
||||||
|
# hooks = ruff
|
||||||
|
# ruff.type = exec
|
||||||
|
# ruff.executable = ruff
|
||||||
|
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
|
||||||
|
|
||||||
|
# Logging configuration. This is also consumed by the user-maintained
|
||||||
|
# env.py script only.
|
||||||
|
[loggers]
|
||||||
|
keys = root,sqlalchemy,alembic
|
||||||
|
|
||||||
|
[handlers]
|
||||||
|
keys = console
|
||||||
|
|
||||||
|
[formatters]
|
||||||
|
keys = generic
|
||||||
|
|
||||||
|
[logger_root]
|
||||||
|
level = WARNING
|
||||||
|
handlers = console
|
||||||
|
qualname =
|
||||||
|
|
||||||
|
[logger_sqlalchemy]
|
||||||
|
level = WARNING
|
||||||
|
handlers =
|
||||||
|
qualname = sqlalchemy.engine
|
||||||
|
|
||||||
|
[logger_alembic]
|
||||||
|
level = INFO
|
||||||
|
handlers =
|
||||||
|
qualname = alembic
|
||||||
|
|
||||||
|
[handler_console]
|
||||||
|
class = StreamHandler
|
||||||
|
args = (sys.stderr,)
|
||||||
|
level = NOTSET
|
||||||
|
formatter = generic
|
||||||
|
|
||||||
|
[formatter_generic]
|
||||||
|
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||||
|
datefmt = %H:%M:%S
|
||||||
1
backend/alembic/README
Normal file
1
backend/alembic/README
Normal file
@@ -0,0 +1 @@
|
|||||||
|
Generic single-database configuration.
|
||||||
67
backend/alembic/env.py
Normal file
67
backend/alembic/env.py
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
"""Alembic environment configuration for smart-support."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
from logging.config import fileConfig
|
||||||
|
|
||||||
|
from sqlalchemy import engine_from_config, pool
|
||||||
|
|
||||||
|
from alembic import context
|
||||||
|
|
||||||
|
config = context.config
|
||||||
|
|
||||||
|
if config.config_file_name is not None:
|
||||||
|
fileConfig(config.config_file_name)
|
||||||
|
|
||||||
|
# No SQLAlchemy ORM models -- we use raw DDL migrations
|
||||||
|
target_metadata = None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_url() -> str:
|
||||||
|
"""Read DATABASE_URL from environment, falling back to alembic.ini."""
|
||||||
|
return os.environ.get("DATABASE_URL", "") or config.get_main_option(
|
||||||
|
"sqlalchemy.url", ""
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def run_migrations_offline() -> None:
|
||||||
|
"""Run migrations in 'offline' mode.
|
||||||
|
|
||||||
|
Configures the context with just a URL so that an Engine
|
||||||
|
is not required.
|
||||||
|
"""
|
||||||
|
url = _get_url()
|
||||||
|
context.configure(
|
||||||
|
url=url,
|
||||||
|
target_metadata=target_metadata,
|
||||||
|
literal_binds=True,
|
||||||
|
dialect_opts={"paramstyle": "named"},
|
||||||
|
)
|
||||||
|
|
||||||
|
with context.begin_transaction():
|
||||||
|
context.run_migrations()
|
||||||
|
|
||||||
|
|
||||||
|
def run_migrations_online() -> None:
|
||||||
|
"""Run migrations in 'online' mode with a live database connection."""
|
||||||
|
configuration = config.get_section(config.config_ini_section, {})
|
||||||
|
configuration["sqlalchemy.url"] = _get_url()
|
||||||
|
|
||||||
|
connectable = engine_from_config(
|
||||||
|
configuration,
|
||||||
|
prefix="sqlalchemy.",
|
||||||
|
poolclass=pool.NullPool,
|
||||||
|
)
|
||||||
|
|
||||||
|
with connectable.connect() as connection:
|
||||||
|
context.configure(connection=connection, target_metadata=target_metadata)
|
||||||
|
|
||||||
|
with context.begin_transaction():
|
||||||
|
context.run_migrations()
|
||||||
|
|
||||||
|
|
||||||
|
if context.is_offline_mode():
|
||||||
|
run_migrations_offline()
|
||||||
|
else:
|
||||||
|
run_migrations_online()
|
||||||
28
backend/alembic/script.py.mako
Normal file
28
backend/alembic/script.py.mako
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
"""${message}
|
||||||
|
|
||||||
|
Revision ID: ${up_revision}
|
||||||
|
Revises: ${down_revision | comma,n}
|
||||||
|
Create Date: ${create_date}
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
${imports if imports else ""}
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = ${repr(up_revision)}
|
||||||
|
down_revision: Union[str, Sequence[str], None] = ${repr(down_revision)}
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
|
||||||
|
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Upgrade schema."""
|
||||||
|
${upgrades if upgrades else "pass"}
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Downgrade schema."""
|
||||||
|
${downgrades if downgrades else "pass"}
|
||||||
92
backend/alembic/versions/001_initial_schema.py
Normal file
92
backend/alembic/versions/001_initial_schema.py
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
"""Initial schema -- all application tables.
|
||||||
|
|
||||||
|
Revision ID: a1b2c3d4e5f6
|
||||||
|
Revises:
|
||||||
|
Create Date: 2026-04-06
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
revision: str = "a1b2c3d4e5f6"
|
||||||
|
down_revision: str | None = None
|
||||||
|
branch_labels: tuple[str, ...] | None = None
|
||||||
|
depends_on: tuple[str, ...] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE IF NOT EXISTS conversations (
|
||||||
|
thread_id TEXT PRIMARY KEY,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||||
|
last_activity TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||||
|
total_tokens INTEGER NOT NULL DEFAULT 0,
|
||||||
|
total_cost_usd DOUBLE PRECISION NOT NULL DEFAULT 0.0,
|
||||||
|
status TEXT NOT NULL DEFAULT 'active'
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE IF NOT EXISTS active_interrupts (
|
||||||
|
interrupt_id TEXT PRIMARY KEY,
|
||||||
|
thread_id TEXT NOT NULL REFERENCES conversations(thread_id),
|
||||||
|
action TEXT NOT NULL,
|
||||||
|
params JSONB NOT NULL DEFAULT '{}',
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||||
|
resolved_at TIMESTAMPTZ,
|
||||||
|
resolution TEXT
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE IF NOT EXISTS sessions (
|
||||||
|
thread_id TEXT PRIMARY KEY,
|
||||||
|
last_activity TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||||
|
has_pending_interrupt BOOLEAN NOT NULL DEFAULT FALSE,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE IF NOT EXISTS analytics_events (
|
||||||
|
id BIGSERIAL PRIMARY KEY,
|
||||||
|
thread_id TEXT NOT NULL,
|
||||||
|
event_type TEXT NOT NULL,
|
||||||
|
agent_name TEXT,
|
||||||
|
tool_name TEXT,
|
||||||
|
tokens_used INTEGER NOT NULL DEFAULT 0,
|
||||||
|
cost_usd DOUBLE PRECISION NOT NULL DEFAULT 0.0,
|
||||||
|
duration_ms INTEGER,
|
||||||
|
success BOOLEAN,
|
||||||
|
error_message TEXT,
|
||||||
|
metadata JSONB NOT NULL DEFAULT '{}',
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
# Migration columns added in Phase 4
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
ALTER TABLE conversations
|
||||||
|
ADD COLUMN IF NOT EXISTS resolution_type TEXT,
|
||||||
|
ADD COLUMN IF NOT EXISTS agents_used TEXT[],
|
||||||
|
ADD COLUMN IF NOT EXISTS turn_count INTEGER NOT NULL DEFAULT 0,
|
||||||
|
ADD COLUMN IF NOT EXISTS ended_at TIMESTAMPTZ
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.execute("DROP TABLE IF EXISTS analytics_events")
|
||||||
|
op.execute("DROP TABLE IF EXISTS sessions")
|
||||||
|
op.execute("DROP TABLE IF EXISTS active_interrupts")
|
||||||
|
op.execute("DROP TABLE IF EXISTS conversations")
|
||||||
@@ -4,16 +4,22 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import re
|
import re
|
||||||
from dataclasses import asdict
|
from dataclasses import asdict
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException, Query, Request
|
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||||
|
|
||||||
from app.analytics.queries import get_analytics
|
from app.analytics.queries import get_analytics
|
||||||
|
from app.api_utils import envelope
|
||||||
|
from app.auth import require_admin_api_key
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from psycopg_pool import AsyncConnectionPool
|
from psycopg_pool import AsyncConnectionPool
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/analytics", tags=["analytics"])
|
router = APIRouter(
|
||||||
|
prefix="/api/v1/analytics",
|
||||||
|
tags=["analytics"],
|
||||||
|
dependencies=[Depends(require_admin_api_key)],
|
||||||
|
)
|
||||||
|
|
||||||
_RANGE_PATTERN = re.compile(r"^(\d+)d$")
|
_RANGE_PATTERN = re.compile(r"^(\d+)d$")
|
||||||
_DEFAULT_RANGE = "7d"
|
_DEFAULT_RANGE = "7d"
|
||||||
@@ -25,10 +31,6 @@ async def _get_pool(request: Request) -> AsyncConnectionPool:
|
|||||||
return request.app.state.pool
|
return request.app.state.pool
|
||||||
|
|
||||||
|
|
||||||
def _envelope(data: Any, *, success: bool = True, error: str | None = None) -> dict:
|
|
||||||
return {"success": success, "data": data, "error": error}
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_range(range_str: str) -> int:
|
def _parse_range(range_str: str) -> int:
|
||||||
"""Parse 'Xd' range string to integer days. Raises 400 on invalid format."""
|
"""Parse 'Xd' range string to integer days. Raises 400 on invalid format."""
|
||||||
match = _RANGE_PATTERN.match(range_str)
|
match = _RANGE_PATTERN.match(range_str)
|
||||||
@@ -55,4 +57,4 @@ async def analytics(
|
|||||||
range_days = _parse_range(range)
|
range_days = _parse_range(range)
|
||||||
pool = await _get_pool(request)
|
pool = await _get_pool(request)
|
||||||
result = await get_analytics(pool, range_days=range_days)
|
result = await get_analytics(pool, range_days=range_days)
|
||||||
return _envelope(asdict(result))
|
return envelope(asdict(result))
|
||||||
|
|||||||
@@ -4,6 +4,8 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
|
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
|
||||||
|
|
||||||
|
from psycopg.types.json import Json
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from psycopg_pool import AsyncConnectionPool
|
from psycopg_pool import AsyncConnectionPool
|
||||||
|
|
||||||
@@ -89,7 +91,7 @@ class PostgresAnalyticsRecorder:
|
|||||||
"duration_ms": duration_ms,
|
"duration_ms": duration_ms,
|
||||||
"success": success,
|
"success": success,
|
||||||
"error_message": error_message,
|
"error_message": error_message,
|
||||||
"metadata": metadata or {},
|
"metadata": Json(metadata or {}),
|
||||||
}
|
}
|
||||||
async with self._pool.connection() as conn:
|
async with self._pool.connection() as conn:
|
||||||
await conn.execute(_INSERT_SQL, params)
|
await conn.execute(_INSERT_SQL, params)
|
||||||
|
|||||||
10
backend/app/api_utils.py
Normal file
10
backend/app/api_utils.py
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
"""Shared API response helpers."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
def envelope(data: Any, *, success: bool = True, error: str | None = None) -> dict:
|
||||||
|
"""Wrap API response data in a standard envelope format."""
|
||||||
|
return {"success": success, "data": data, "error": error}
|
||||||
72
backend/app/auth.py
Normal file
72
backend/app/auth.py
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
"""API key authentication for admin endpoints and WebSocket connections."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import secrets
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
import structlog
|
||||||
|
from fastapi import Depends, HTTPException, Query, Request, WebSocket, status
|
||||||
|
from fastapi.security import APIKeyHeader
|
||||||
|
|
||||||
|
logger = structlog.get_logger()
|
||||||
|
|
||||||
|
_API_KEY_HEADER = APIKeyHeader(name="X-API-Key", auto_error=False)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_admin_api_key(request: Request) -> str:
|
||||||
|
"""Retrieve the configured admin API key from app settings.
|
||||||
|
|
||||||
|
Returns empty string if settings are not configured (test/dev mode).
|
||||||
|
"""
|
||||||
|
settings = getattr(request.app.state, "settings", None)
|
||||||
|
if settings is None:
|
||||||
|
return ""
|
||||||
|
key = getattr(settings, "admin_api_key", "")
|
||||||
|
return key if isinstance(key, str) else ""
|
||||||
|
|
||||||
|
|
||||||
|
async def require_admin_api_key(
|
||||||
|
request: Request,
|
||||||
|
api_key: Annotated[str | None, Depends(_API_KEY_HEADER)] = None,
|
||||||
|
) -> None:
|
||||||
|
"""Dependency that enforces API key authentication on admin endpoints.
|
||||||
|
|
||||||
|
Skips validation when no admin_api_key is configured (dev mode).
|
||||||
|
"""
|
||||||
|
expected = _get_admin_api_key(request)
|
||||||
|
if not expected:
|
||||||
|
return
|
||||||
|
|
||||||
|
if api_key is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Missing X-API-Key header",
|
||||||
|
)
|
||||||
|
if not secrets.compare_digest(api_key, expected):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="Invalid API key",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def verify_ws_token(
|
||||||
|
ws: WebSocket,
|
||||||
|
token: str | None = Query(default=None),
|
||||||
|
) -> None:
|
||||||
|
"""Verify WebSocket connection token from query parameter.
|
||||||
|
|
||||||
|
Skips validation when no admin_api_key is configured (dev mode).
|
||||||
|
Usage: ws://host/ws?token=<api_key>
|
||||||
|
"""
|
||||||
|
settings = ws.app.state.settings
|
||||||
|
expected = settings.admin_api_key
|
||||||
|
if not expected:
|
||||||
|
return
|
||||||
|
|
||||||
|
if token is None or not secrets.compare_digest(token, expected):
|
||||||
|
await ws.close(code=4001, reason="Unauthorized")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="Invalid or missing WebSocket token",
|
||||||
|
)
|
||||||
@@ -17,7 +17,7 @@ class Settings(BaseSettings):
|
|||||||
|
|
||||||
database_url: str
|
database_url: str
|
||||||
|
|
||||||
llm_provider: Literal["anthropic", "openai", "google"] = "anthropic"
|
llm_provider: Literal["anthropic", "openai", "azure_openai", "google"] = "anthropic"
|
||||||
llm_model: str = "claude-sonnet-4-6"
|
llm_model: str = "claude-sonnet-4-6"
|
||||||
|
|
||||||
session_ttl_minutes: int = 30
|
session_ttl_minutes: int = 30
|
||||||
@@ -32,8 +32,16 @@ class Settings(BaseSettings):
|
|||||||
|
|
||||||
template_name: str = ""
|
template_name: str = ""
|
||||||
|
|
||||||
|
log_format: str = "console" # "console" for dev, "json" for production
|
||||||
|
|
||||||
|
admin_api_key: str = ""
|
||||||
|
|
||||||
anthropic_api_key: str = ""
|
anthropic_api_key: str = ""
|
||||||
openai_api_key: str = ""
|
openai_api_key: str = ""
|
||||||
|
azure_openai_api_key: str = ""
|
||||||
|
azure_openai_endpoint: str = ""
|
||||||
|
azure_openai_api_version: str = "2024-12-01-preview"
|
||||||
|
azure_openai_deployment: str = ""
|
||||||
google_api_key: str = ""
|
google_api_key: str = ""
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
@@ -41,6 +49,7 @@ class Settings(BaseSettings):
|
|||||||
key_map = {
|
key_map = {
|
||||||
"anthropic": self.anthropic_api_key,
|
"anthropic": self.anthropic_api_key,
|
||||||
"openai": self.openai_api_key,
|
"openai": self.openai_api_key,
|
||||||
|
"azure_openai": self.azure_openai_api_key,
|
||||||
"google": self.google_api_key,
|
"google": self.google_api_key,
|
||||||
}
|
}
|
||||||
key = key_map.get(self.llm_provider, "")
|
key = key_map.get(self.llm_provider, "")
|
||||||
@@ -49,4 +58,13 @@ class Settings(BaseSettings):
|
|||||||
f"API key for provider '{self.llm_provider}' is required. "
|
f"API key for provider '{self.llm_provider}' is required. "
|
||||||
f"Set the corresponding environment variable."
|
f"Set the corresponding environment variable."
|
||||||
)
|
)
|
||||||
|
if self.llm_provider == "azure_openai":
|
||||||
|
if not self.azure_openai_endpoint:
|
||||||
|
raise ValueError(
|
||||||
|
"AZURE_OPENAI_ENDPOINT is required for azure_openai provider."
|
||||||
|
)
|
||||||
|
if not self.azure_openai_deployment:
|
||||||
|
raise ValueError(
|
||||||
|
"AZURE_OPENAI_DEPLOYMENT is required for azure_openai provider."
|
||||||
|
)
|
||||||
return self
|
return self
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
_ENSURE_SQL = """
|
_ENSURE_SQL = """
|
||||||
INSERT INTO conversations
|
INSERT INTO conversations
|
||||||
(thread_id, started_at, last_activity)
|
(thread_id, created_at, last_activity)
|
||||||
VALUES
|
VALUES
|
||||||
(%(thread_id)s, NOW(), NOW())
|
(%(thread_id)s, NOW(), NOW())
|
||||||
ON CONFLICT (thread_id) DO NOTHING
|
ON CONFLICT (thread_id) DO NOTHING
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||||
@@ -51,6 +52,15 @@ CREATE TABLE IF NOT EXISTS analytics_events (
|
|||||||
);
|
);
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
_SESSIONS_DDL = """
|
||||||
|
CREATE TABLE IF NOT EXISTS sessions (
|
||||||
|
thread_id TEXT PRIMARY KEY,
|
||||||
|
last_activity TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||||
|
has_pending_interrupt BOOLEAN NOT NULL DEFAULT FALSE,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
|
||||||
_CONVERSATIONS_MIGRATION_DDL = """
|
_CONVERSATIONS_MIGRATION_DDL = """
|
||||||
ALTER TABLE conversations
|
ALTER TABLE conversations
|
||||||
ADD COLUMN IF NOT EXISTS resolution_type TEXT,
|
ADD COLUMN IF NOT EXISTS resolution_type TEXT,
|
||||||
@@ -79,10 +89,22 @@ async def create_checkpointer(pool: AsyncConnectionPool) -> AsyncPostgresSaver:
|
|||||||
return checkpointer
|
return checkpointer
|
||||||
|
|
||||||
|
|
||||||
|
def run_alembic_migrations(database_url: str) -> None:
|
||||||
|
"""Run Alembic migrations to head."""
|
||||||
|
from alembic.config import Config
|
||||||
|
|
||||||
|
from alembic import command
|
||||||
|
|
||||||
|
alembic_cfg = Config(str(Path(__file__).parent.parent / "alembic.ini"))
|
||||||
|
alembic_cfg.set_main_option("sqlalchemy.url", database_url)
|
||||||
|
command.upgrade(alembic_cfg, "head")
|
||||||
|
|
||||||
|
|
||||||
async def setup_app_tables(pool: AsyncConnectionPool) -> None:
|
async def setup_app_tables(pool: AsyncConnectionPool) -> None:
|
||||||
"""Create application-specific tables and apply migrations."""
|
"""Create application-specific tables and apply migrations."""
|
||||||
async with pool.connection() as conn:
|
async with pool.connection() as conn:
|
||||||
await conn.execute(_CONVERSATIONS_DDL)
|
await conn.execute(_CONVERSATIONS_DDL)
|
||||||
await conn.execute(_INTERRUPTS_DDL)
|
await conn.execute(_INTERRUPTS_DDL)
|
||||||
|
await conn.execute(_SESSIONS_DDL)
|
||||||
await conn.execute(_ANALYTICS_EVENTS_DDL)
|
await conn.execute(_ANALYTICS_EVENTS_DDL)
|
||||||
await conn.execute(_CONVERSATIONS_MIGRATION_DDL)
|
await conn.execute(_CONVERSATIONS_MIGRATION_DDL)
|
||||||
|
|||||||
@@ -3,14 +3,14 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Protocol
|
from typing import Protocol
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
import structlog
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = structlog.get_logger()
|
||||||
|
|
||||||
|
|
||||||
class EscalationPayload(BaseModel, frozen=True):
|
class EscalationPayload(BaseModel, frozen=True):
|
||||||
|
|||||||
@@ -2,23 +2,24 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from langgraph.prebuilt import create_react_agent
|
from langchain.agents import create_agent
|
||||||
from langgraph_supervisor import create_supervisor
|
from langgraph_supervisor import create_supervisor
|
||||||
|
|
||||||
from app.agents import get_tools_by_names
|
from app.agents import get_tools_by_names
|
||||||
|
from app.graph_context import GraphContext
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from langchain_core.language_models import BaseChatModel
|
from langchain_core.language_models import BaseChatModel
|
||||||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||||
from langgraph.graph.state import CompiledStateGraph
|
|
||||||
|
|
||||||
from app.intent import ClassificationResult, IntentClassifier
|
from app.intent import IntentClassifier
|
||||||
from app.registry import AgentRegistry
|
from app.registry import AgentRegistry
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
import structlog
|
||||||
|
|
||||||
|
logger = structlog.get_logger()
|
||||||
|
|
||||||
SUPERVISOR_PROMPT = (
|
SUPERVISOR_PROMPT = (
|
||||||
"You are a customer support supervisor. "
|
"You are a customer support supervisor. "
|
||||||
@@ -59,11 +60,11 @@ def build_agent_nodes(
|
|||||||
f"Permission level: {agent_config.permission}."
|
f"Permission level: {agent_config.permission}."
|
||||||
)
|
)
|
||||||
|
|
||||||
agent_node = create_react_agent(
|
agent_node = create_agent(
|
||||||
model=llm,
|
model=llm,
|
||||||
tools=tools,
|
tools=tools,
|
||||||
name=agent_config.name,
|
name=agent_config.name,
|
||||||
prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
)
|
)
|
||||||
agent_nodes.append(agent_node)
|
agent_nodes.append(agent_node)
|
||||||
|
|
||||||
@@ -75,12 +76,11 @@ def build_graph(
|
|||||||
llm: BaseChatModel,
|
llm: BaseChatModel,
|
||||||
checkpointer: AsyncPostgresSaver,
|
checkpointer: AsyncPostgresSaver,
|
||||||
intent_classifier: IntentClassifier | None = None,
|
intent_classifier: IntentClassifier | None = None,
|
||||||
) -> CompiledStateGraph:
|
) -> GraphContext:
|
||||||
"""Build and compile the LangGraph supervisor graph.
|
"""Build and compile the LangGraph supervisor graph.
|
||||||
|
|
||||||
If an intent_classifier is provided, the supervisor prompt is enhanced
|
Returns a GraphContext that bundles the compiled graph with its
|
||||||
with agent descriptions for better routing. The classifier is stored
|
associated registry and intent classifier.
|
||||||
for use by the routing layer (ws_handler).
|
|
||||||
"""
|
"""
|
||||||
agent_nodes = build_agent_nodes(registry, llm)
|
agent_nodes = build_agent_nodes(registry, llm)
|
||||||
agent_descriptions = _format_agent_descriptions(registry)
|
agent_descriptions = _format_agent_descriptions(registry)
|
||||||
@@ -88,34 +88,16 @@ def build_graph(
|
|||||||
prompt = SUPERVISOR_PROMPT.format(agent_descriptions=agent_descriptions)
|
prompt = SUPERVISOR_PROMPT.format(agent_descriptions=agent_descriptions)
|
||||||
|
|
||||||
workflow = create_supervisor(
|
workflow = create_supervisor(
|
||||||
agent_nodes,
|
agents=agent_nodes,
|
||||||
model=llm,
|
model=llm,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
output_mode="full_history",
|
output_mode="full_history",
|
||||||
)
|
)
|
||||||
|
|
||||||
graph = workflow.compile(checkpointer=checkpointer)
|
compiled = workflow.compile(checkpointer=checkpointer)
|
||||||
|
|
||||||
# Attach classifier and registry to graph for use by ws_handler
|
return GraphContext(
|
||||||
graph.intent_classifier = intent_classifier # type: ignore[attr-defined]
|
graph=compiled,
|
||||||
graph.agent_registry = registry # type: ignore[attr-defined]
|
registry=registry,
|
||||||
|
intent_classifier=intent_classifier,
|
||||||
return graph
|
)
|
||||||
|
|
||||||
|
|
||||||
async def classify_intent(
|
|
||||||
graph: CompiledStateGraph,
|
|
||||||
message: str,
|
|
||||||
) -> ClassificationResult | None:
|
|
||||||
"""Classify user intent using the graph's attached classifier.
|
|
||||||
|
|
||||||
Returns None if no classifier is configured.
|
|
||||||
"""
|
|
||||||
classifier = getattr(graph, "intent_classifier", None)
|
|
||||||
registry = getattr(graph, "agent_registry", None)
|
|
||||||
|
|
||||||
if classifier is None or registry is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
agents = registry.list_agents()
|
|
||||||
return await classifier.classify(message, agents)
|
|
||||||
|
|||||||
36
backend/app/graph_context.py
Normal file
36
backend/app/graph_context.py
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
"""GraphContext -- typed wrapper around the compiled graph and its dependencies."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from langgraph.graph.state import CompiledStateGraph
|
||||||
|
|
||||||
|
from app.intent import ClassificationResult, IntentClassifier
|
||||||
|
from app.registry import AgentRegistry
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class GraphContext:
|
||||||
|
"""Bundles the compiled LangGraph graph with its associated services.
|
||||||
|
|
||||||
|
Replaces the previous pattern of monkey-patching attributes onto the
|
||||||
|
third-party CompiledStateGraph instance.
|
||||||
|
"""
|
||||||
|
|
||||||
|
graph: CompiledStateGraph
|
||||||
|
registry: AgentRegistry
|
||||||
|
intent_classifier: IntentClassifier | None = None
|
||||||
|
|
||||||
|
async def classify_intent(self, message: str) -> ClassificationResult | None:
|
||||||
|
"""Classify user intent using the attached classifier.
|
||||||
|
|
||||||
|
Returns None if no classifier is configured.
|
||||||
|
"""
|
||||||
|
if self.intent_classifier is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
agents = self.registry.list_agents()
|
||||||
|
return await self.intent_classifier.classify(message, agents)
|
||||||
@@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import TYPE_CHECKING, Protocol
|
from typing import TYPE_CHECKING, Protocol
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
@@ -12,7 +11,9 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
from app.registry import AgentConfig
|
from app.registry import AgentConfig
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
import structlog
|
||||||
|
|
||||||
|
logger = structlog.get_logger()
|
||||||
|
|
||||||
CLASSIFICATION_PROMPT = (
|
CLASSIFICATION_PROMPT = (
|
||||||
"You are an intent classifier for a customer support system.\n"
|
"You are an intent classifier for a customer support system.\n"
|
||||||
|
|||||||
@@ -1,10 +1,18 @@
|
|||||||
"""Interrupt TTL management -- tracks pending interrupts with auto-expiration."""
|
"""Interrupt TTL management -- tracks pending interrupts with auto-expiration.
|
||||||
|
|
||||||
|
Provides both in-memory (InterruptManager) and PostgreSQL-backed
|
||||||
|
(PgInterruptManager) implementations behind a common Protocol.
|
||||||
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from typing import TYPE_CHECKING, Protocol
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from psycopg_pool import AsyncConnectionPool
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
@@ -28,8 +36,32 @@ class InterruptStatus:
|
|||||||
record: InterruptRecord
|
record: InterruptRecord
|
||||||
|
|
||||||
|
|
||||||
|
class InterruptManagerProtocol(Protocol):
|
||||||
|
"""Protocol for interrupt TTL management."""
|
||||||
|
|
||||||
|
def register(self, thread_id: str, action: str, params: dict) -> InterruptRecord: ...
|
||||||
|
def check_status(self, thread_id: str) -> InterruptStatus | None: ...
|
||||||
|
def resolve(self, thread_id: str) -> None: ...
|
||||||
|
def has_pending(self, thread_id: str) -> bool: ...
|
||||||
|
def generate_retry_prompt(self, expired_record: InterruptRecord) -> dict: ...
|
||||||
|
|
||||||
|
|
||||||
|
def _build_retry_prompt(expired_record: InterruptRecord) -> dict:
|
||||||
|
"""Generate a WebSocket message prompting the user to retry an expired action."""
|
||||||
|
return {
|
||||||
|
"type": "interrupt_expired",
|
||||||
|
"thread_id": expired_record.thread_id,
|
||||||
|
"action": expired_record.action,
|
||||||
|
"message": (
|
||||||
|
f"The approval request for '{expired_record.action}' has expired "
|
||||||
|
f"after {expired_record.ttl_seconds // 60} minutes. "
|
||||||
|
f"Would you like to try again?"
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class InterruptManager:
|
class InterruptManager:
|
||||||
"""Manages interrupt TTL with auto-expiration.
|
"""In-memory interrupt manager for single-worker development.
|
||||||
|
|
||||||
Complements SessionManager -- this tracks interrupt-specific TTL
|
Complements SessionManager -- this tracks interrupt-specific TTL
|
||||||
while SessionManager handles session-level TTL.
|
while SessionManager handles session-level TTL.
|
||||||
@@ -62,11 +94,9 @@ class InterruptManager:
|
|||||||
record = self._interrupts.get(thread_id)
|
record = self._interrupts.get(thread_id)
|
||||||
if record is None:
|
if record is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
elapsed = time.time() - record.created_at
|
elapsed = time.time() - record.created_at
|
||||||
remaining = max(0.0, record.ttl_seconds - elapsed)
|
remaining = max(0.0, record.ttl_seconds - elapsed)
|
||||||
is_expired = elapsed > record.ttl_seconds
|
is_expired = elapsed > record.ttl_seconds
|
||||||
|
|
||||||
return InterruptStatus(
|
return InterruptStatus(
|
||||||
is_expired=is_expired,
|
is_expired=is_expired,
|
||||||
remaining_seconds=remaining,
|
remaining_seconds=remaining,
|
||||||
@@ -84,28 +114,17 @@ class InterruptManager:
|
|||||||
now = time.time()
|
now = time.time()
|
||||||
expired: list[InterruptRecord] = []
|
expired: list[InterruptRecord] = []
|
||||||
active: dict[str, InterruptRecord] = {}
|
active: dict[str, InterruptRecord] = {}
|
||||||
|
|
||||||
for thread_id, record in self._interrupts.items():
|
for thread_id, record in self._interrupts.items():
|
||||||
if now - record.created_at > record.ttl_seconds:
|
if now - record.created_at > record.ttl_seconds:
|
||||||
expired.append(record)
|
expired.append(record)
|
||||||
else:
|
else:
|
||||||
active[thread_id] = record
|
active[thread_id] = record
|
||||||
|
|
||||||
self._interrupts = active
|
self._interrupts = active
|
||||||
return tuple(expired)
|
return tuple(expired)
|
||||||
|
|
||||||
def generate_retry_prompt(self, expired_record: InterruptRecord) -> dict:
|
def generate_retry_prompt(self, expired_record: InterruptRecord) -> dict:
|
||||||
"""Generate a WebSocket message prompting the user to retry an expired action."""
|
"""Generate a WebSocket message prompting the user to retry an expired action."""
|
||||||
return {
|
return _build_retry_prompt(expired_record)
|
||||||
"type": "interrupt_expired",
|
|
||||||
"thread_id": expired_record.thread_id,
|
|
||||||
"action": expired_record.action,
|
|
||||||
"message": (
|
|
||||||
f"The approval request for '{expired_record.action}' has expired "
|
|
||||||
f"after {expired_record.ttl_seconds // 60} minutes. "
|
|
||||||
f"Would you like to try again?"
|
|
||||||
),
|
|
||||||
}
|
|
||||||
|
|
||||||
def has_pending(self, thread_id: str) -> bool:
|
def has_pending(self, thread_id: str) -> bool:
|
||||||
"""Check if a thread has a pending (non-expired) interrupt."""
|
"""Check if a thread has a pending (non-expired) interrupt."""
|
||||||
@@ -113,3 +132,137 @@ class InterruptManager:
|
|||||||
if status is None:
|
if status is None:
|
||||||
return False
|
return False
|
||||||
return not status.is_expired
|
return not status.is_expired
|
||||||
|
|
||||||
|
|
||||||
|
# Alias for explicit naming
|
||||||
|
InMemoryInterruptManager = InterruptManager
|
||||||
|
|
||||||
|
|
||||||
|
class PgInterruptManager:
|
||||||
|
"""PostgreSQL-backed interrupt manager for multi-worker production.
|
||||||
|
|
||||||
|
Uses the existing active_interrupts table defined in db.py.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
pool: AsyncConnectionPool,
|
||||||
|
ttl_seconds: int = 1800,
|
||||||
|
) -> None:
|
||||||
|
self._pool = pool
|
||||||
|
self._ttl_seconds = ttl_seconds
|
||||||
|
|
||||||
|
def register(
|
||||||
|
self,
|
||||||
|
thread_id: str,
|
||||||
|
action: str,
|
||||||
|
params: dict,
|
||||||
|
) -> InterruptRecord:
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
return asyncio.get_event_loop().run_until_complete(
|
||||||
|
self._register(thread_id, action, params)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _register(
|
||||||
|
self, thread_id: str, action: str, params: dict
|
||||||
|
) -> InterruptRecord:
|
||||||
|
import json
|
||||||
|
|
||||||
|
record = InterruptRecord(
|
||||||
|
interrupt_id=uuid.uuid4().hex,
|
||||||
|
thread_id=thread_id,
|
||||||
|
action=action,
|
||||||
|
params=dict(params),
|
||||||
|
created_at=time.time(),
|
||||||
|
ttl_seconds=self._ttl_seconds,
|
||||||
|
)
|
||||||
|
async with self._pool.connection() as conn:
|
||||||
|
await conn.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO active_interrupts (interrupt_id, thread_id, action, params)
|
||||||
|
VALUES (%(iid)s, %(tid)s, %(action)s, %(params)s)
|
||||||
|
ON CONFLICT (thread_id) WHERE resolved_at IS NULL
|
||||||
|
DO UPDATE SET
|
||||||
|
interrupt_id = %(iid)s,
|
||||||
|
action = %(action)s,
|
||||||
|
params = %(params)s,
|
||||||
|
created_at = NOW(),
|
||||||
|
resolved_at = NULL
|
||||||
|
""",
|
||||||
|
{
|
||||||
|
"iid": record.interrupt_id,
|
||||||
|
"tid": thread_id,
|
||||||
|
"action": action,
|
||||||
|
"params": json.dumps(params),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return record
|
||||||
|
|
||||||
|
def check_status(self, thread_id: str) -> InterruptStatus | None:
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
return asyncio.get_event_loop().run_until_complete(
|
||||||
|
self._check_status(thread_id)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _check_status(self, thread_id: str) -> InterruptStatus | None:
|
||||||
|
async with self._pool.connection() as conn:
|
||||||
|
cursor = await conn.execute(
|
||||||
|
"""
|
||||||
|
SELECT interrupt_id, action, params, created_at
|
||||||
|
FROM active_interrupts
|
||||||
|
WHERE thread_id = %(tid)s AND resolved_at IS NULL
|
||||||
|
ORDER BY created_at DESC LIMIT 1
|
||||||
|
""",
|
||||||
|
{"tid": thread_id},
|
||||||
|
)
|
||||||
|
row = await cursor.fetchone()
|
||||||
|
|
||||||
|
if row is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
created_at = row["created_at"].timestamp()
|
||||||
|
elapsed = time.time() - created_at
|
||||||
|
remaining = max(0.0, self._ttl_seconds - elapsed)
|
||||||
|
is_expired = elapsed > self._ttl_seconds
|
||||||
|
|
||||||
|
record = InterruptRecord(
|
||||||
|
interrupt_id=row["interrupt_id"],
|
||||||
|
thread_id=thread_id,
|
||||||
|
action=row["action"],
|
||||||
|
params=row["params"] if isinstance(row["params"], dict) else {},
|
||||||
|
created_at=created_at,
|
||||||
|
ttl_seconds=self._ttl_seconds,
|
||||||
|
)
|
||||||
|
|
||||||
|
return InterruptStatus(
|
||||||
|
is_expired=is_expired,
|
||||||
|
remaining_seconds=remaining,
|
||||||
|
record=record,
|
||||||
|
)
|
||||||
|
|
||||||
|
def resolve(self, thread_id: str) -> None:
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
asyncio.get_event_loop().run_until_complete(self._resolve(thread_id))
|
||||||
|
|
||||||
|
async def _resolve(self, thread_id: str) -> None:
|
||||||
|
async with self._pool.connection() as conn:
|
||||||
|
await conn.execute(
|
||||||
|
"""
|
||||||
|
UPDATE active_interrupts
|
||||||
|
SET resolved_at = NOW(), resolution = 'resolved'
|
||||||
|
WHERE thread_id = %(tid)s AND resolved_at IS NULL
|
||||||
|
""",
|
||||||
|
{"tid": thread_id},
|
||||||
|
)
|
||||||
|
|
||||||
|
def generate_retry_prompt(self, expired_record: InterruptRecord) -> dict:
|
||||||
|
return _build_retry_prompt(expired_record)
|
||||||
|
|
||||||
|
def has_pending(self, thread_id: str) -> bool:
|
||||||
|
status = self.check_status(thread_id)
|
||||||
|
if status is None:
|
||||||
|
return False
|
||||||
|
return not status.is_expired
|
||||||
|
|||||||
@@ -31,6 +31,16 @@ def create_llm(settings: Settings) -> BaseChatModel:
|
|||||||
api_key=settings.openai_api_key,
|
api_key=settings.openai_api_key,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if provider == "azure_openai":
|
||||||
|
from langchain_openai import AzureChatOpenAI
|
||||||
|
|
||||||
|
return AzureChatOpenAI(
|
||||||
|
azure_deployment=settings.azure_openai_deployment,
|
||||||
|
azure_endpoint=settings.azure_openai_endpoint,
|
||||||
|
api_key=settings.azure_openai_api_key,
|
||||||
|
api_version=settings.azure_openai_api_version,
|
||||||
|
)
|
||||||
|
|
||||||
if provider == "google":
|
if provider == "google":
|
||||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||||
|
|
||||||
@@ -39,4 +49,7 @@ def create_llm(settings: Settings) -> BaseChatModel:
|
|||||||
google_api_key=settings.google_api_key,
|
google_api_key=settings.google_api_key,
|
||||||
)
|
)
|
||||||
|
|
||||||
raise ValueError(f"Unknown LLM provider: '{provider}'. Use 'anthropic', 'openai', or 'google'.")
|
raise ValueError(
|
||||||
|
f"Unknown LLM provider: '{provider}'. "
|
||||||
|
"Use 'anthropic', 'openai', 'azure_openai', or 'google'."
|
||||||
|
)
|
||||||
|
|||||||
57
backend/app/logging_config.py
Normal file
57
backend/app/logging_config.py
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
"""Structured logging configuration using structlog."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import structlog
|
||||||
|
|
||||||
|
|
||||||
|
def configure_logging(log_format: str = "console") -> None:
|
||||||
|
"""Configure structlog with stdlib integration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
log_format: "console" for human-readable dev output,
|
||||||
|
"json" for machine-parseable production output.
|
||||||
|
"""
|
||||||
|
shared_processors: list[structlog.types.Processor] = [
|
||||||
|
structlog.contextvars.merge_contextvars,
|
||||||
|
structlog.stdlib.filter_by_level,
|
||||||
|
structlog.stdlib.add_logger_name,
|
||||||
|
structlog.stdlib.add_log_level,
|
||||||
|
structlog.processors.TimeStamper(fmt="iso"),
|
||||||
|
structlog.processors.StackInfoRenderer(),
|
||||||
|
structlog.processors.format_exc_info,
|
||||||
|
structlog.processors.UnicodeDecoder(),
|
||||||
|
]
|
||||||
|
|
||||||
|
if log_format == "json":
|
||||||
|
renderer: structlog.types.Processor = structlog.processors.JSONRenderer()
|
||||||
|
else:
|
||||||
|
renderer = structlog.dev.ConsoleRenderer()
|
||||||
|
|
||||||
|
structlog.configure(
|
||||||
|
processors=[
|
||||||
|
*shared_processors,
|
||||||
|
structlog.stdlib.ProcessorFormatter.wrap_for_formatter,
|
||||||
|
],
|
||||||
|
logger_factory=structlog.stdlib.LoggerFactory(),
|
||||||
|
wrapper_class=structlog.stdlib.BoundLogger,
|
||||||
|
cache_logger_on_first_use=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
formatter = structlog.stdlib.ProcessorFormatter(
|
||||||
|
processors=[
|
||||||
|
structlog.stdlib.ProcessorFormatter.remove_processors_meta,
|
||||||
|
renderer,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
handler = logging.StreamHandler(sys.stdout)
|
||||||
|
handler.setFormatter(formatter)
|
||||||
|
|
||||||
|
root_logger = logging.getLogger()
|
||||||
|
root_logger.handlers.clear()
|
||||||
|
root_logger.addHandler(handler)
|
||||||
|
root_logger.setLevel(logging.INFO)
|
||||||
@@ -2,47 +2,78 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import asyncio
|
||||||
|
import contextlib
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
from fastapi import FastAPI, HTTPException, Query, WebSocket, WebSocketDisconnect
|
||||||
|
from fastapi.exceptions import RequestValidationError
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
|
|
||||||
from app.analytics.api import router as analytics_router
|
from app.analytics.api import router as analytics_router
|
||||||
from app.analytics.event_recorder import PostgresAnalyticsRecorder
|
from app.analytics.event_recorder import PostgresAnalyticsRecorder
|
||||||
|
from app.api_utils import envelope
|
||||||
from app.callbacks import TokenUsageCallbackHandler
|
from app.callbacks import TokenUsageCallbackHandler
|
||||||
from app.config import Settings
|
from app.config import Settings
|
||||||
from app.conversation_tracker import PostgresConversationTracker
|
from app.conversation_tracker import PostgresConversationTracker
|
||||||
from app.db import create_checkpointer, create_pool, setup_app_tables
|
from app.db import create_checkpointer, create_pool, run_alembic_migrations
|
||||||
from app.escalation import NoOpEscalator, WebhookEscalator
|
from app.escalation import NoOpEscalator, WebhookEscalator
|
||||||
from app.graph import build_graph
|
from app.graph import build_graph
|
||||||
from app.intent import LLMIntentClassifier
|
from app.intent import LLMIntentClassifier
|
||||||
from app.interrupt_manager import InterruptManager
|
from app.interrupt_manager import InterruptManager
|
||||||
from app.llm import create_llm
|
from app.llm import create_llm
|
||||||
|
from app.logging_config import configure_logging
|
||||||
from app.openapi.review_api import router as openapi_router
|
from app.openapi.review_api import router as openapi_router
|
||||||
from app.registry import AgentRegistry
|
from app.registry import AgentRegistry
|
||||||
from app.replay.api import router as replay_router
|
from app.replay.api import router as replay_router
|
||||||
from app.session_manager import SessionManager
|
from app.session_manager import SessionManager
|
||||||
|
from app.ws_context import WebSocketContext
|
||||||
from app.ws_handler import dispatch_message
|
from app.ws_handler import dispatch_message
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
import structlog
|
||||||
|
|
||||||
|
logger = structlog.get_logger()
|
||||||
|
|
||||||
AGENTS_YAML = Path(__file__).parent.parent / "agents.yaml"
|
AGENTS_YAML = Path(__file__).parent.parent / "agents.yaml"
|
||||||
FRONTEND_DIST = Path(__file__).parent.parent.parent / "frontend" / "dist"
|
FRONTEND_DIST = Path(__file__).parent.parent.parent / "frontend" / "dist"
|
||||||
|
|
||||||
|
|
||||||
|
async def _interrupt_cleanup_loop(
|
||||||
|
interrupt_manager: InterruptManager,
|
||||||
|
interval: int = 60,
|
||||||
|
) -> None:
|
||||||
|
"""Periodically remove expired interrupts in the background.
|
||||||
|
|
||||||
|
Runs until cancelled. Catches all exceptions to prevent the task
|
||||||
|
from dying unexpectedly.
|
||||||
|
"""
|
||||||
|
while True:
|
||||||
|
await asyncio.sleep(interval)
|
||||||
|
try:
|
||||||
|
expired = interrupt_manager.cleanup_expired()
|
||||||
|
if expired:
|
||||||
|
logger.info(
|
||||||
|
"Cleaned up %d expired interrupt(s)",
|
||||||
|
len(expired),
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Error during interrupt cleanup")
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||||
settings = Settings()
|
settings = Settings()
|
||||||
|
configure_logging(settings.log_format)
|
||||||
|
|
||||||
pool = await create_pool(settings)
|
pool = await create_pool(settings)
|
||||||
checkpointer = await create_checkpointer(pool)
|
checkpointer = await create_checkpointer(pool)
|
||||||
await setup_app_tables(pool)
|
run_alembic_migrations(settings.database_url)
|
||||||
|
|
||||||
# Load agents from template or default YAML
|
# Load agents from template or default YAML
|
||||||
if settings.template_name:
|
if settings.template_name:
|
||||||
@@ -52,7 +83,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
|||||||
|
|
||||||
llm = create_llm(settings)
|
llm = create_llm(settings)
|
||||||
intent_classifier = LLMIntentClassifier(llm)
|
intent_classifier = LLMIntentClassifier(llm)
|
||||||
graph = build_graph(registry, llm, checkpointer, intent_classifier=intent_classifier)
|
graph_ctx = build_graph(registry, llm, checkpointer, intent_classifier=intent_classifier)
|
||||||
|
|
||||||
session_manager = SessionManager(
|
session_manager = SessionManager(
|
||||||
session_ttl_seconds=settings.session_ttl_minutes * 60,
|
session_ttl_seconds=settings.session_ttl_minutes * 60,
|
||||||
@@ -71,7 +102,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
|||||||
else:
|
else:
|
||||||
escalator = NoOpEscalator()
|
escalator = NoOpEscalator()
|
||||||
|
|
||||||
app.state.graph = graph
|
app.state.graph_ctx = graph_ctx
|
||||||
app.state.session_manager = session_manager
|
app.state.session_manager = session_manager
|
||||||
app.state.interrupt_manager = interrupt_manager
|
app.state.interrupt_manager = interrupt_manager
|
||||||
app.state.escalator = escalator
|
app.state.escalator = escalator
|
||||||
@@ -88,12 +119,20 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
|||||||
settings.template_name or "(default)",
|
settings.template_name or "(default)",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cleanup_task = asyncio.create_task(
|
||||||
|
_interrupt_cleanup_loop(interrupt_manager),
|
||||||
|
)
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
cleanup_task.cancel()
|
||||||
|
with contextlib.suppress(asyncio.CancelledError):
|
||||||
|
await cleanup_task
|
||||||
|
|
||||||
await pool.close()
|
await pool.close()
|
||||||
|
|
||||||
|
|
||||||
_VERSION = "0.5.0"
|
_VERSION = "0.6.0"
|
||||||
|
|
||||||
app = FastAPI(title="Smart Support", version=_VERSION, lifespan=lifespan)
|
app = FastAPI(title="Smart Support", version=_VERSION, lifespan=lifespan)
|
||||||
|
|
||||||
@@ -102,35 +141,72 @@ app.include_router(replay_router)
|
|||||||
app.include_router(analytics_router)
|
app.include_router(analytics_router)
|
||||||
|
|
||||||
|
|
||||||
@app.get("/api/health")
|
@app.exception_handler(HTTPException)
|
||||||
|
async def http_exception_handler(request, exc): # type: ignore[no-untyped-def]
|
||||||
|
"""Wrap HTTPException in standard envelope format."""
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=exc.status_code,
|
||||||
|
content=envelope(None, success=False, error=exc.detail),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.exception_handler(RequestValidationError)
|
||||||
|
async def validation_exception_handler(request, exc): # type: ignore[no-untyped-def]
|
||||||
|
"""Wrap validation errors in standard envelope format."""
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=422,
|
||||||
|
content=envelope(None, success=False, error=str(exc)),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.exception_handler(Exception)
|
||||||
|
async def general_exception_handler(request, exc): # type: ignore[no-untyped-def]
|
||||||
|
"""Catch-all handler -- never leak stack traces."""
|
||||||
|
logger.exception("Unhandled exception: %s", exc)
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=500,
|
||||||
|
content=envelope(None, success=False, error="Internal server error"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/api/v1/health")
|
||||||
def health_check() -> dict:
|
def health_check() -> dict:
|
||||||
"""Health check endpoint for load balancers and monitoring."""
|
"""Health check endpoint for load balancers and monitoring."""
|
||||||
return {"status": "ok", "version": _VERSION}
|
return {"status": "ok", "version": _VERSION}
|
||||||
|
|
||||||
|
|
||||||
@app.websocket("/ws")
|
@app.websocket("/ws")
|
||||||
async def websocket_endpoint(ws: WebSocket) -> None:
|
async def websocket_endpoint(
|
||||||
await ws.accept()
|
ws: WebSocket,
|
||||||
graph = app.state.graph
|
token: str | None = Query(default=None),
|
||||||
session_manager = app.state.session_manager
|
) -> None:
|
||||||
interrupt_manager = app.state.interrupt_manager
|
|
||||||
settings = app.state.settings
|
settings = app.state.settings
|
||||||
|
|
||||||
|
# Verify WebSocket token when admin_api_key is configured
|
||||||
|
if settings.admin_api_key:
|
||||||
|
import secrets as _secrets
|
||||||
|
|
||||||
|
if token is None or not _secrets.compare_digest(token, settings.admin_api_key):
|
||||||
|
await ws.close(code=4001, reason="Unauthorized")
|
||||||
|
return
|
||||||
|
|
||||||
|
await ws.accept()
|
||||||
callback_handler = TokenUsageCallbackHandler(model_name=settings.llm_model)
|
callback_handler = TokenUsageCallbackHandler(model_name=settings.llm_model)
|
||||||
|
|
||||||
analytics_recorder = app.state.analytics_recorder
|
ws_ctx = WebSocketContext(
|
||||||
conversation_tracker = app.state.conversation_tracker
|
graph_ctx=app.state.graph_ctx,
|
||||||
pool = app.state.pool
|
session_manager=app.state.session_manager,
|
||||||
|
callback_handler=callback_handler,
|
||||||
|
interrupt_manager=app.state.interrupt_manager,
|
||||||
|
analytics_recorder=app.state.analytics_recorder,
|
||||||
|
conversation_tracker=app.state.conversation_tracker,
|
||||||
|
pool=app.state.pool,
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
raw_data = await ws.receive_text()
|
raw_data = await ws.receive_text()
|
||||||
await dispatch_message(
|
await dispatch_message(ws, ws_ctx, raw_data)
|
||||||
ws, graph, session_manager, callback_handler, raw_data,
|
|
||||||
interrupt_manager=interrupt_manager,
|
|
||||||
analytics_recorder=analytics_recorder,
|
|
||||||
conversation_tracker=conversation_tracker,
|
|
||||||
pool=pool,
|
|
||||||
)
|
|
||||||
except WebSocketDisconnect:
|
except WebSocketDisconnect:
|
||||||
logger.info("WebSocket client disconnected")
|
logger.info("WebSocket client disconnected")
|
||||||
|
|
||||||
|
|||||||
@@ -8,13 +8,14 @@ classifier and an LLM-backed classifier with heuristic fallback.
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
|
||||||
import re
|
import re
|
||||||
from typing import Protocol
|
from typing import Protocol
|
||||||
|
|
||||||
|
import structlog
|
||||||
|
|
||||||
from app.openapi.models import ClassificationResult, EndpointInfo
|
from app.openapi.models import ClassificationResult, EndpointInfo
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = structlog.get_logger()
|
||||||
|
|
||||||
_WRITE_METHODS = frozenset({"POST", "PUT", "PATCH", "DELETE"})
|
_WRITE_METHODS = frozenset({"POST", "PUT", "PATCH", "DELETE"})
|
||||||
_INTERRUPT_METHODS = frozenset({"POST", "PUT", "PATCH", "DELETE"})
|
_INTERRUPT_METHODS = frozenset({"POST", "PUT", "PATCH", "DELETE"})
|
||||||
|
|||||||
@@ -6,10 +6,11 @@ Each stage updates the job status and calls the on_progress callback.
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from dataclasses import replace
|
from dataclasses import replace
|
||||||
|
|
||||||
|
import structlog
|
||||||
|
|
||||||
from app.openapi.classifier import ClassifierProtocol, HeuristicClassifier
|
from app.openapi.classifier import ClassifierProtocol, HeuristicClassifier
|
||||||
from app.openapi.fetcher import fetch_spec
|
from app.openapi.fetcher import fetch_spec
|
||||||
from app.openapi.models import ImportJob
|
from app.openapi.models import ImportJob
|
||||||
@@ -17,7 +18,7 @@ from app.openapi.parser import parse_endpoints
|
|||||||
from app.openapi.ssrf import DEFAULT_POLICY, SSRFPolicy
|
from app.openapi.ssrf import DEFAULT_POLICY, SSRFPolicy
|
||||||
from app.openapi.validator import validate_spec
|
from app.openapi.validator import validate_spec
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = structlog.get_logger()
|
||||||
|
|
||||||
ProgressCallback = Callable[[str, ImportJob], None] | None
|
ProgressCallback = Callable[[str, ImportJob], None] | None
|
||||||
|
|
||||||
|
|||||||
@@ -10,20 +10,26 @@ Exposes endpoints for:
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
|
||||||
import re
|
import re
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
from fastapi import APIRouter, BackgroundTasks, HTTPException
|
import structlog
|
||||||
|
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
|
||||||
from pydantic import BaseModel, field_validator
|
from pydantic import BaseModel, field_validator
|
||||||
|
|
||||||
|
from app.auth import require_admin_api_key
|
||||||
|
from app.openapi.generator import generate_agent_yaml, generate_tool_code
|
||||||
from app.openapi.importer import ImportOrchestrator
|
from app.openapi.importer import ImportOrchestrator
|
||||||
from app.openapi.models import ClassificationResult, ImportJob
|
from app.openapi.models import ClassificationResult, ImportJob
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = structlog.get_logger()
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/openapi", tags=["openapi"])
|
router = APIRouter(
|
||||||
|
prefix="/api/v1/openapi",
|
||||||
|
tags=["openapi"],
|
||||||
|
dependencies=[Depends(require_admin_api_key)],
|
||||||
|
)
|
||||||
|
|
||||||
# In-memory store: job_id -> job dict, guarded by async lock
|
# In-memory store: job_id -> job dict, guarded by async lock
|
||||||
_job_store: dict[str, dict] = {}
|
_job_store: dict[str, dict] = {}
|
||||||
@@ -235,11 +241,42 @@ async def update_classification(
|
|||||||
|
|
||||||
@router.post("/jobs/{job_id}/approve")
|
@router.post("/jobs/{job_id}/approve")
|
||||||
async def approve_job(job_id: str) -> dict:
|
async def approve_job(job_id: str) -> dict:
|
||||||
"""Approve a job's classifications and trigger tool generation."""
|
"""Approve a job's classifications and trigger tool generation.
|
||||||
|
|
||||||
|
Generates Python tool code for each classified endpoint and
|
||||||
|
produces an agent YAML configuration snippet.
|
||||||
|
"""
|
||||||
job = _job_store.get(job_id)
|
job = _job_store.get(job_id)
|
||||||
if job is None:
|
if job is None:
|
||||||
raise HTTPException(status_code=404, detail=f"Job '{job_id}' not found")
|
raise HTTPException(status_code=404, detail=f"Job '{job_id}' not found")
|
||||||
|
|
||||||
updated_job = {**job, "status": "approved"}
|
classifications: list[ClassificationResult] = job.get("classifications", [])
|
||||||
|
if not classifications:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail="No classifications to approve. Import must complete first.",
|
||||||
|
)
|
||||||
|
|
||||||
|
base_url = job["spec_url"].rsplit("/", 1)[0]
|
||||||
|
generated_tools = []
|
||||||
|
for clf in classifications:
|
||||||
|
tool = generate_tool_code(clf, base_url)
|
||||||
|
generated_tools.append({
|
||||||
|
"function_name": tool.function_name,
|
||||||
|
"agent_group": clf.agent_group,
|
||||||
|
"code": tool.code,
|
||||||
|
})
|
||||||
|
|
||||||
|
agent_yaml = generate_agent_yaml(tuple(classifications), base_url)
|
||||||
|
|
||||||
|
updated_job = {
|
||||||
|
**job,
|
||||||
|
"status": "approved",
|
||||||
|
"generated_tools": generated_tools,
|
||||||
|
"agent_yaml": agent_yaml,
|
||||||
|
}
|
||||||
_job_store[job_id] = updated_job
|
_job_store[job_id] = updated_job
|
||||||
return _job_to_response(updated_job)
|
|
||||||
|
response = _job_to_response(updated_job)
|
||||||
|
response["generated_tools_count"] = len(generated_tools)
|
||||||
|
return response
|
||||||
|
|||||||
@@ -3,16 +3,27 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import re
|
import re
|
||||||
from typing import TYPE_CHECKING, Annotated, Any
|
from typing import TYPE_CHECKING, Annotated
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException, Query, Request
|
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||||
|
|
||||||
|
from app.api_utils import envelope
|
||||||
|
from app.auth import require_admin_api_key
|
||||||
|
|
||||||
_THREAD_ID_PATTERN = re.compile(r"^[a-zA-Z0-9\-_]{1,128}$")
|
_THREAD_ID_PATTERN = re.compile(r"^[a-zA-Z0-9\-_]{1,128}$")
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from psycopg_pool import AsyncConnectionPool
|
from psycopg_pool import AsyncConnectionPool
|
||||||
|
|
||||||
router = APIRouter(prefix="/api", tags=["replay"])
|
router = APIRouter(
|
||||||
|
prefix="/api/v1",
|
||||||
|
tags=["replay"],
|
||||||
|
dependencies=[Depends(require_admin_api_key)],
|
||||||
|
)
|
||||||
|
|
||||||
|
_COUNT_CONVERSATIONS_SQL = """
|
||||||
|
SELECT COUNT(*) FROM conversations
|
||||||
|
"""
|
||||||
|
|
||||||
_LIST_CONVERSATIONS_SQL = """
|
_LIST_CONVERSATIONS_SQL = """
|
||||||
SELECT thread_id, created_at, last_activity, status, total_tokens, total_cost_usd
|
SELECT thread_id, created_at, last_activity, status, total_tokens, total_cost_usd
|
||||||
@@ -34,10 +45,6 @@ async def get_pool(request: Request) -> AsyncConnectionPool:
|
|||||||
return request.app.state.pool
|
return request.app.state.pool
|
||||||
|
|
||||||
|
|
||||||
def _envelope(data: Any, *, success: bool = True, error: str | None = None) -> dict:
|
|
||||||
return {"success": success, "data": data, "error": error}
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/conversations")
|
@router.get("/conversations")
|
||||||
async def list_conversations(
|
async def list_conversations(
|
||||||
request: Request,
|
request: Request,
|
||||||
@@ -48,13 +55,22 @@ async def list_conversations(
|
|||||||
pool = await get_pool(request)
|
pool = await get_pool(request)
|
||||||
offset = (page - 1) * per_page
|
offset = (page - 1) * per_page
|
||||||
async with pool.connection() as conn:
|
async with pool.connection() as conn:
|
||||||
|
count_cursor = await conn.execute(_COUNT_CONVERSATIONS_SQL)
|
||||||
|
count_row = await count_cursor.fetchone()
|
||||||
|
total = count_row[0] if count_row else 0
|
||||||
|
|
||||||
cursor = await conn.execute(
|
cursor = await conn.execute(
|
||||||
_LIST_CONVERSATIONS_SQL,
|
_LIST_CONVERSATIONS_SQL,
|
||||||
{"limit": per_page, "offset": offset},
|
{"limit": per_page, "offset": offset},
|
||||||
)
|
)
|
||||||
rows = await cursor.fetchall()
|
rows = await cursor.fetchall()
|
||||||
|
|
||||||
return _envelope([dict(row) for row in rows])
|
return envelope({
|
||||||
|
"conversations": [dict(row) for row in rows],
|
||||||
|
"total": total,
|
||||||
|
"page": page,
|
||||||
|
"per_page": per_page,
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
@router.get("/replay/{thread_id}")
|
@router.get("/replay/{thread_id}")
|
||||||
@@ -106,4 +122,4 @@ async def get_replay(
|
|||||||
for s in page_steps
|
for s in page_steps
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
return _envelope(data)
|
return envelope(data)
|
||||||
|
|||||||
@@ -2,11 +2,11 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import structlog
|
||||||
|
|
||||||
from app.replay.models import ReplayStep, StepType
|
from app.replay.models import ReplayStep, StepType
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = structlog.get_logger()
|
||||||
|
|
||||||
_EMPTY_TIMESTAMP = "1970-01-01T00:00:00Z"
|
_EMPTY_TIMESTAMP = "1970-01-01T00:00:00Z"
|
||||||
|
|
||||||
|
|||||||
131
backend/app/safety.py
Normal file
131
backend/app/safety.py
Normal file
@@ -0,0 +1,131 @@
|
|||||||
|
"""Safety policy for destructive-action confirmation rules.
|
||||||
|
|
||||||
|
This module makes the confirmation rules explicit and auditable. Every tool
|
||||||
|
call passes through ``requires_confirmation`` before execution to decide
|
||||||
|
whether human-in-the-loop approval is needed.
|
||||||
|
|
||||||
|
Policy summary
|
||||||
|
--------------
|
||||||
|
- ``read`` actions: execute immediately, no confirmation required.
|
||||||
|
- ``write`` actions: require human approval via interrupt gate.
|
||||||
|
- OpenAPI-imported endpoints: use ``needs_interrupt`` from classification.
|
||||||
|
- If both the agent permission AND the endpoint classification agree
|
||||||
|
the action is read-only, it executes without confirmation.
|
||||||
|
|
||||||
|
Multi-intent semantics
|
||||||
|
----------------------
|
||||||
|
When a user message contains multiple intents (e.g. "cancel my order and
|
||||||
|
apply a refund"), the supervisor routes them sequentially. Each action is
|
||||||
|
evaluated independently:
|
||||||
|
- If a write action is blocked by an interrupt, subsequent actions in the
|
||||||
|
same message are paused until the interrupt is resolved.
|
||||||
|
- Read actions that follow a blocked write are also paused (sequential,
|
||||||
|
not best-effort) to preserve causal ordering.
|
||||||
|
- If an interrupt is rejected, the remaining actions are skipped and the
|
||||||
|
agent informs the user.
|
||||||
|
|
||||||
|
MCP error taxonomy
|
||||||
|
------------------
|
||||||
|
Tool execution errors are classified into categories for retry decisions:
|
||||||
|
|
||||||
|
- ``transient``: network timeouts, rate limits, 5xx -- retryable up to 3 times.
|
||||||
|
- ``validation``: bad parameters, 4xx -- not retryable, report to user.
|
||||||
|
- ``auth``: 401/403 -- not retryable, escalate.
|
||||||
|
- ``unknown``: unclassified -- not retryable, log and escalate.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class ConfirmationPolicy:
|
||||||
|
"""Result of evaluating whether an action needs confirmation."""
|
||||||
|
|
||||||
|
requires_confirmation: bool
|
||||||
|
reason: str
|
||||||
|
|
||||||
|
|
||||||
|
def requires_confirmation(
|
||||||
|
*,
|
||||||
|
agent_permission: Literal["read", "write"],
|
||||||
|
needs_interrupt: bool | None = None,
|
||||||
|
) -> ConfirmationPolicy:
|
||||||
|
"""Determine whether an action requires human confirmation.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
agent_permission:
|
||||||
|
The permission level of the agent executing the action.
|
||||||
|
needs_interrupt:
|
||||||
|
Override from OpenAPI classification. When ``None``, the decision
|
||||||
|
is based solely on ``agent_permission``.
|
||||||
|
"""
|
||||||
|
if needs_interrupt is not None:
|
||||||
|
if needs_interrupt:
|
||||||
|
return ConfirmationPolicy(
|
||||||
|
requires_confirmation=True,
|
||||||
|
reason="Endpoint classified as requiring human approval",
|
||||||
|
)
|
||||||
|
return ConfirmationPolicy(
|
||||||
|
requires_confirmation=False,
|
||||||
|
reason="Endpoint classified as safe (no interrupt needed)",
|
||||||
|
)
|
||||||
|
|
||||||
|
if agent_permission == "write":
|
||||||
|
return ConfirmationPolicy(
|
||||||
|
requires_confirmation=True,
|
||||||
|
reason="Write-permission agent actions require confirmation",
|
||||||
|
)
|
||||||
|
|
||||||
|
return ConfirmationPolicy(
|
||||||
|
requires_confirmation=False,
|
||||||
|
reason="Read-only agent actions execute immediately",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# --- MCP Error Taxonomy ---
|
||||||
|
|
||||||
|
|
||||||
|
MCP_ERROR_CATEGORY = Literal["transient", "validation", "auth", "unknown"]
|
||||||
|
|
||||||
|
_TRANSIENT_STATUS_CODES = frozenset({408, 429, 500, 502, 503, 504})
|
||||||
|
_AUTH_STATUS_CODES = frozenset({401, 403})
|
||||||
|
_MAX_RETRIES = 3
|
||||||
|
|
||||||
|
|
||||||
|
def classify_mcp_error(
|
||||||
|
*,
|
||||||
|
status_code: int | None = None,
|
||||||
|
error_message: str = "",
|
||||||
|
) -> MCP_ERROR_CATEGORY:
|
||||||
|
"""Classify an MCP tool error for retry decisions."""
|
||||||
|
if status_code is not None:
|
||||||
|
if status_code in _TRANSIENT_STATUS_CODES:
|
||||||
|
return "transient"
|
||||||
|
if status_code in _AUTH_STATUS_CODES:
|
||||||
|
return "auth"
|
||||||
|
if 400 <= status_code < 500:
|
||||||
|
return "validation"
|
||||||
|
|
||||||
|
lower_msg = error_message.lower()
|
||||||
|
if any(kw in lower_msg for kw in ("timeout", "timed out", "rate limit")):
|
||||||
|
return "transient"
|
||||||
|
if any(kw in lower_msg for kw in ("unauthorized", "forbidden")):
|
||||||
|
return "auth"
|
||||||
|
if any(kw in lower_msg for kw in ("invalid", "missing", "bad request")):
|
||||||
|
return "validation"
|
||||||
|
|
||||||
|
return "unknown"
|
||||||
|
|
||||||
|
|
||||||
|
def is_retryable(category: MCP_ERROR_CATEGORY) -> bool:
|
||||||
|
"""Return whether a given error category is retryable."""
|
||||||
|
return category == "transient"
|
||||||
|
|
||||||
|
|
||||||
|
def max_retries() -> int:
|
||||||
|
"""Maximum retry attempts for transient errors."""
|
||||||
|
return _MAX_RETRIES
|
||||||
@@ -1,9 +1,18 @@
|
|||||||
"""Session TTL management with sliding window and interrupt extension."""
|
"""Session TTL management with sliding window and interrupt extension.
|
||||||
|
|
||||||
|
Provides both in-memory (SessionManager) and PostgreSQL-backed
|
||||||
|
(PgSessionManager) implementations behind a common Protocol.
|
||||||
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import time
|
import time
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import TYPE_CHECKING, Protocol
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from psycopg_pool import AsyncConnectionPool
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
@@ -13,8 +22,19 @@ class SessionState:
|
|||||||
has_pending_interrupt: bool
|
has_pending_interrupt: bool
|
||||||
|
|
||||||
|
|
||||||
|
class SessionManagerProtocol(Protocol):
|
||||||
|
"""Protocol for session TTL management."""
|
||||||
|
|
||||||
|
def touch(self, thread_id: str) -> SessionState: ...
|
||||||
|
def is_expired(self, thread_id: str) -> bool: ...
|
||||||
|
def extend_for_interrupt(self, thread_id: str) -> SessionState: ...
|
||||||
|
def resolve_interrupt(self, thread_id: str) -> SessionState: ...
|
||||||
|
def get_state(self, thread_id: str) -> SessionState | None: ...
|
||||||
|
def remove(self, thread_id: str) -> None: ...
|
||||||
|
|
||||||
|
|
||||||
class SessionManager:
|
class SessionManager:
|
||||||
"""Manages session TTL with sliding window and interrupt extensions.
|
"""In-memory session manager for single-worker development.
|
||||||
|
|
||||||
- Each message resets the TTL (sliding window).
|
- Each message resets the TTL (sliding window).
|
||||||
- A pending interrupt suspends expiration until resolved.
|
- A pending interrupt suspends expiration until resolved.
|
||||||
@@ -40,10 +60,8 @@ class SessionManager:
|
|||||||
state = self._sessions.get(thread_id)
|
state = self._sessions.get(thread_id)
|
||||||
if state is None:
|
if state is None:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
if state.has_pending_interrupt:
|
if state.has_pending_interrupt:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
elapsed = time.time() - state.last_activity
|
elapsed = time.time() - state.last_activity
|
||||||
return elapsed > self._session_ttl
|
return elapsed > self._session_ttl
|
||||||
|
|
||||||
@@ -52,7 +70,6 @@ class SessionManager:
|
|||||||
existing = self._sessions.get(thread_id)
|
existing = self._sessions.get(thread_id)
|
||||||
if existing is None:
|
if existing is None:
|
||||||
return self.touch(thread_id)
|
return self.touch(thread_id)
|
||||||
|
|
||||||
new_state = SessionState(
|
new_state = SessionState(
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
last_activity=existing.last_activity,
|
last_activity=existing.last_activity,
|
||||||
@@ -76,3 +93,120 @@ class SessionManager:
|
|||||||
|
|
||||||
def remove(self, thread_id: str) -> None:
|
def remove(self, thread_id: str) -> None:
|
||||||
self._sessions = {k: v for k, v in self._sessions.items() if k != thread_id}
|
self._sessions = {k: v for k, v in self._sessions.items() if k != thread_id}
|
||||||
|
|
||||||
|
|
||||||
|
# Alias for explicit naming
|
||||||
|
InMemorySessionManager = SessionManager
|
||||||
|
|
||||||
|
|
||||||
|
class PgSessionManager:
|
||||||
|
"""PostgreSQL-backed session manager for multi-worker production."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
pool: AsyncConnectionPool,
|
||||||
|
session_ttl_seconds: int = 1800,
|
||||||
|
) -> None:
|
||||||
|
self._pool = pool
|
||||||
|
self._session_ttl = session_ttl_seconds
|
||||||
|
|
||||||
|
def touch(self, thread_id: str) -> SessionState:
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
return asyncio.get_event_loop().run_until_complete(self._touch(thread_id))
|
||||||
|
|
||||||
|
async def _touch(self, thread_id: str) -> SessionState:
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
async with self._pool.connection() as conn:
|
||||||
|
await conn.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO sessions (thread_id, last_activity, has_pending_interrupt)
|
||||||
|
VALUES (%(tid)s, %(now)s, FALSE)
|
||||||
|
ON CONFLICT (thread_id) DO UPDATE
|
||||||
|
SET last_activity = %(now)s
|
||||||
|
""",
|
||||||
|
{"tid": thread_id, "now": now},
|
||||||
|
)
|
||||||
|
return SessionState(
|
||||||
|
thread_id=thread_id,
|
||||||
|
last_activity=now.timestamp(),
|
||||||
|
has_pending_interrupt=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def is_expired(self, thread_id: str) -> bool:
|
||||||
|
state = self.get_state(thread_id)
|
||||||
|
if state is None:
|
||||||
|
return True
|
||||||
|
if state.has_pending_interrupt:
|
||||||
|
return False
|
||||||
|
elapsed = time.time() - state.last_activity
|
||||||
|
return elapsed > self._session_ttl
|
||||||
|
|
||||||
|
def extend_for_interrupt(self, thread_id: str) -> SessionState:
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
return asyncio.get_event_loop().run_until_complete(
|
||||||
|
self._set_interrupt(thread_id, True)
|
||||||
|
)
|
||||||
|
|
||||||
|
def resolve_interrupt(self, thread_id: str) -> SessionState:
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
return asyncio.get_event_loop().run_until_complete(
|
||||||
|
self._set_interrupt(thread_id, False)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _set_interrupt(
|
||||||
|
self, thread_id: str, has_interrupt: bool
|
||||||
|
) -> SessionState:
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
async with self._pool.connection() as conn:
|
||||||
|
await conn.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO sessions (thread_id, last_activity, has_pending_interrupt)
|
||||||
|
VALUES (%(tid)s, %(now)s, %(interrupt)s)
|
||||||
|
ON CONFLICT (thread_id) DO UPDATE
|
||||||
|
SET last_activity = %(now)s,
|
||||||
|
has_pending_interrupt = %(interrupt)s
|
||||||
|
""",
|
||||||
|
{"tid": thread_id, "now": now, "interrupt": has_interrupt},
|
||||||
|
)
|
||||||
|
return SessionState(
|
||||||
|
thread_id=thread_id,
|
||||||
|
last_activity=now.timestamp(),
|
||||||
|
has_pending_interrupt=has_interrupt,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_state(self, thread_id: str) -> SessionState | None:
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
return asyncio.get_event_loop().run_until_complete(
|
||||||
|
self._get_state(thread_id)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _get_state(self, thread_id: str) -> SessionState | None:
|
||||||
|
async with self._pool.connection() as conn:
|
||||||
|
cursor = await conn.execute(
|
||||||
|
"SELECT last_activity, has_pending_interrupt FROM sessions WHERE thread_id = %(tid)s",
|
||||||
|
{"tid": thread_id},
|
||||||
|
)
|
||||||
|
row = await cursor.fetchone()
|
||||||
|
if row is None:
|
||||||
|
return None
|
||||||
|
return SessionState(
|
||||||
|
thread_id=thread_id,
|
||||||
|
last_activity=row["last_activity"].timestamp(),
|
||||||
|
has_pending_interrupt=row["has_pending_interrupt"],
|
||||||
|
)
|
||||||
|
|
||||||
|
def remove(self, thread_id: str) -> None:
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
asyncio.get_event_loop().run_until_complete(self._remove(thread_id))
|
||||||
|
|
||||||
|
async def _remove(self, thread_id: str) -> None:
|
||||||
|
async with self._pool.connection() as conn:
|
||||||
|
await conn.execute(
|
||||||
|
"DELETE FROM sessions WHERE thread_id = %(tid)s",
|
||||||
|
{"tid": thread_id},
|
||||||
|
)
|
||||||
|
|||||||
30
backend/app/ws_context.py
Normal file
30
backend/app/ws_context.py
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
"""WebSocketContext -- bundles all dependencies needed by dispatch_message."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from app.analytics.event_recorder import AnalyticsRecorder
|
||||||
|
from app.callbacks import TokenUsageCallbackHandler
|
||||||
|
from app.conversation_tracker import ConversationTrackerProtocol
|
||||||
|
from app.graph_context import GraphContext
|
||||||
|
from app.interrupt_manager import InterruptManager
|
||||||
|
from app.session_manager import SessionManager
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class WebSocketContext:
|
||||||
|
"""All dependencies required for WebSocket message processing.
|
||||||
|
|
||||||
|
Replaces the previous 9-parameter function signature in dispatch_message.
|
||||||
|
"""
|
||||||
|
|
||||||
|
graph_ctx: GraphContext
|
||||||
|
session_manager: SessionManager
|
||||||
|
callback_handler: TokenUsageCallbackHandler
|
||||||
|
interrupt_manager: InterruptManager | None = None
|
||||||
|
analytics_recorder: AnalyticsRecorder | None = None
|
||||||
|
conversation_tracker: ConversationTrackerProtocol | None = None
|
||||||
|
pool: Any = None
|
||||||
@@ -3,28 +3,26 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from langchain_core.messages import HumanMessage
|
from langchain_core.messages import HumanMessage
|
||||||
from langgraph.types import Command
|
from langgraph.types import Command
|
||||||
|
|
||||||
from app.graph import classify_intent
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from fastapi import WebSocket
|
from fastapi import WebSocket
|
||||||
from langgraph.graph.state import CompiledStateGraph
|
|
||||||
|
|
||||||
from app.analytics.event_recorder import AnalyticsRecorder
|
|
||||||
from app.callbacks import TokenUsageCallbackHandler
|
from app.callbacks import TokenUsageCallbackHandler
|
||||||
from app.conversation_tracker import ConversationTrackerProtocol
|
from app.graph_context import GraphContext
|
||||||
from app.interrupt_manager import InterruptManager
|
from app.interrupt_manager import InterruptManager
|
||||||
from app.session_manager import SessionManager
|
from app.session_manager import SessionManager
|
||||||
|
from app.ws_context import WebSocketContext
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
import structlog
|
||||||
|
|
||||||
|
logger = structlog.get_logger()
|
||||||
|
|
||||||
MAX_MESSAGE_SIZE = 32_768 # 32 KB
|
MAX_MESSAGE_SIZE = 32_768 # 32 KB
|
||||||
MAX_CONTENT_LENGTH = 10_000 # characters
|
MAX_CONTENT_LENGTH = 10_000 # characters
|
||||||
@@ -46,7 +44,7 @@ def _evict_stale_threads(cutoff: float) -> None:
|
|||||||
|
|
||||||
async def handle_user_message(
|
async def handle_user_message(
|
||||||
ws: WebSocket,
|
ws: WebSocket,
|
||||||
graph: CompiledStateGraph,
|
ctx: GraphContext,
|
||||||
session_manager: SessionManager,
|
session_manager: SessionManager,
|
||||||
callback_handler: TokenUsageCallbackHandler,
|
callback_handler: TokenUsageCallbackHandler,
|
||||||
thread_id: str,
|
thread_id: str,
|
||||||
@@ -54,15 +52,15 @@ async def handle_user_message(
|
|||||||
interrupt_manager: InterruptManager | None = None,
|
interrupt_manager: InterruptManager | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Process a user message through the graph and stream results back."""
|
"""Process a user message through the graph and stream results back."""
|
||||||
if session_manager.is_expired(thread_id):
|
existing = session_manager.get_state(thread_id)
|
||||||
|
if existing is not None and session_manager.is_expired(thread_id):
|
||||||
msg = "Session expired. Please start a new conversation."
|
msg = "Session expired. Please start a new conversation."
|
||||||
await _send_json(ws, {"type": "error", "message": msg})
|
await _send_json(ws, {"type": "error", "message": msg})
|
||||||
return
|
return
|
||||||
|
|
||||||
session_manager.touch(thread_id)
|
session_manager.touch(thread_id)
|
||||||
|
|
||||||
# Run intent classification if available (for logging/future multi-intent)
|
classification = await ctx.classify_intent(content)
|
||||||
classification = await classify_intent(graph, content)
|
|
||||||
if classification is not None:
|
if classification is not None:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Intent classification for thread %s: ambiguous=%s, intents=%s",
|
"Intent classification for thread %s: ambiguous=%s, intents=%s",
|
||||||
@@ -71,7 +69,6 @@ async def handle_user_message(
|
|||||||
[i.agent_name for i in classification.intents],
|
[i.agent_name for i in classification.intents],
|
||||||
)
|
)
|
||||||
|
|
||||||
# If ambiguous, send clarification and return
|
|
||||||
if classification.is_ambiguous and classification.clarification_question:
|
if classification.is_ambiguous and classification.clarification_question:
|
||||||
await _send_json(
|
await _send_json(
|
||||||
ws,
|
ws,
|
||||||
@@ -86,7 +83,6 @@ async def handle_user_message(
|
|||||||
|
|
||||||
config = {"configurable": {"thread_id": thread_id}, "callbacks": [callback_handler]}
|
config = {"configurable": {"thread_id": thread_id}, "callbacks": [callback_handler]}
|
||||||
|
|
||||||
# If multi-intent detected, add routing hint to the message
|
|
||||||
if classification and len(classification.intents) > 1:
|
if classification and len(classification.intents) > 1:
|
||||||
agent_names = [i.agent_name for i in classification.intents]
|
agent_names = [i.agent_name for i in classification.intents]
|
||||||
hint = (
|
hint = (
|
||||||
@@ -98,7 +94,7 @@ async def handle_user_message(
|
|||||||
input_msg = {"messages": [HumanMessage(content=content)]}
|
input_msg = {"messages": [HumanMessage(content=content)]}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async for chunk in graph.astream(input_msg, config=config, stream_mode="messages"):
|
async for chunk in ctx.graph.astream(input_msg, config=config, stream_mode="messages"):
|
||||||
msg_chunk, metadata = chunk
|
msg_chunk, metadata = chunk
|
||||||
node = metadata.get("langgraph_node", "")
|
node = metadata.get("langgraph_node", "")
|
||||||
|
|
||||||
@@ -123,12 +119,11 @@ async def handle_user_message(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
state = await graph.aget_state(config)
|
state = await ctx.graph.aget_state(config)
|
||||||
if _has_interrupt(state):
|
if _has_interrupt(state):
|
||||||
interrupt_data = _extract_interrupt(state)
|
interrupt_data = _extract_interrupt(state)
|
||||||
session_manager.extend_for_interrupt(thread_id)
|
session_manager.extend_for_interrupt(thread_id)
|
||||||
|
|
||||||
# Register interrupt with TTL tracking
|
|
||||||
if interrupt_manager is not None:
|
if interrupt_manager is not None:
|
||||||
interrupt_manager.register(
|
interrupt_manager.register(
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
@@ -155,7 +150,7 @@ async def handle_user_message(
|
|||||||
|
|
||||||
async def handle_interrupt_response(
|
async def handle_interrupt_response(
|
||||||
ws: WebSocket,
|
ws: WebSocket,
|
||||||
graph: CompiledStateGraph,
|
ctx: GraphContext,
|
||||||
session_manager: SessionManager,
|
session_manager: SessionManager,
|
||||||
callback_handler: TokenUsageCallbackHandler,
|
callback_handler: TokenUsageCallbackHandler,
|
||||||
thread_id: str,
|
thread_id: str,
|
||||||
@@ -163,7 +158,6 @@ async def handle_interrupt_response(
|
|||||||
interrupt_manager: InterruptManager | None = None,
|
interrupt_manager: InterruptManager | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Resume graph execution after interrupt approval/rejection."""
|
"""Resume graph execution after interrupt approval/rejection."""
|
||||||
# Check interrupt TTL before resuming
|
|
||||||
if interrupt_manager is not None:
|
if interrupt_manager is not None:
|
||||||
status = interrupt_manager.check_status(thread_id)
|
status = interrupt_manager.check_status(thread_id)
|
||||||
if status is not None and status.is_expired:
|
if status is not None and status.is_expired:
|
||||||
@@ -181,7 +175,7 @@ async def handle_interrupt_response(
|
|||||||
config = {"configurable": {"thread_id": thread_id}, "callbacks": [callback_handler]}
|
config = {"configurable": {"thread_id": thread_id}, "callbacks": [callback_handler]}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async for chunk in graph.astream(
|
async for chunk in ctx.graph.astream(
|
||||||
Command(resume=approved),
|
Command(resume=approved),
|
||||||
config=config,
|
config=config,
|
||||||
stream_mode="messages",
|
stream_mode="messages",
|
||||||
@@ -209,14 +203,8 @@ async def handle_interrupt_response(
|
|||||||
|
|
||||||
async def dispatch_message(
|
async def dispatch_message(
|
||||||
ws: WebSocket,
|
ws: WebSocket,
|
||||||
graph: CompiledStateGraph,
|
ctx: WebSocketContext,
|
||||||
session_manager: SessionManager,
|
|
||||||
callback_handler: TokenUsageCallbackHandler,
|
|
||||||
raw_data: str,
|
raw_data: str,
|
||||||
interrupt_manager: InterruptManager | None = None,
|
|
||||||
analytics_recorder: AnalyticsRecorder | None = None,
|
|
||||||
conversation_tracker: ConversationTrackerProtocol | None = None,
|
|
||||||
pool: Any = None,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Parse and route an incoming WebSocket message."""
|
"""Parse and route an incoming WebSocket message."""
|
||||||
if len(raw_data) > MAX_MESSAGE_SIZE:
|
if len(raw_data) > MAX_MESSAGE_SIZE:
|
||||||
@@ -265,14 +253,15 @@ async def dispatch_message(
|
|||||||
_thread_timestamps[thread_id] = [*recent, now]
|
_thread_timestamps[thread_id] = [*recent, now]
|
||||||
|
|
||||||
await handle_user_message(
|
await handle_user_message(
|
||||||
ws, graph, session_manager, callback_handler, thread_id, content,
|
ws, ctx.graph_ctx, ctx.session_manager, ctx.callback_handler,
|
||||||
interrupt_manager=interrupt_manager,
|
thread_id, content,
|
||||||
|
interrupt_manager=ctx.interrupt_manager,
|
||||||
)
|
)
|
||||||
await _fire_and_forget_tracking(
|
await _fire_and_forget_tracking(
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
pool=pool,
|
pool=ctx.pool,
|
||||||
analytics_recorder=analytics_recorder,
|
analytics_recorder=ctx.analytics_recorder,
|
||||||
conversation_tracker=conversation_tracker,
|
conversation_tracker=ctx.conversation_tracker,
|
||||||
agent_name=None,
|
agent_name=None,
|
||||||
tokens=0,
|
tokens=0,
|
||||||
cost=0.0,
|
cost=0.0,
|
||||||
@@ -281,8 +270,9 @@ async def dispatch_message(
|
|||||||
elif msg_type == "interrupt_response":
|
elif msg_type == "interrupt_response":
|
||||||
approved = data.get("approved", False)
|
approved = data.get("approved", False)
|
||||||
await handle_interrupt_response(
|
await handle_interrupt_response(
|
||||||
ws, graph, session_manager, callback_handler, thread_id, approved,
|
ws, ctx.graph_ctx, ctx.session_manager, ctx.callback_handler,
|
||||||
interrupt_manager=interrupt_manager,
|
thread_id, approved,
|
||||||
|
interrupt_manager=ctx.interrupt_manager,
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
@@ -291,9 +281,9 @@ async def dispatch_message(
|
|||||||
|
|
||||||
async def _fire_and_forget_tracking(
|
async def _fire_and_forget_tracking(
|
||||||
thread_id: str,
|
thread_id: str,
|
||||||
pool: Any,
|
pool: object,
|
||||||
analytics_recorder: Any | None,
|
analytics_recorder: object | None,
|
||||||
conversation_tracker: Any | None,
|
conversation_tracker: object | None,
|
||||||
agent_name: str | None,
|
agent_name: str | None,
|
||||||
tokens: int,
|
tokens: int,
|
||||||
cost: float,
|
cost: float,
|
||||||
|
|||||||
@@ -6,12 +6,13 @@ requires-python = ">=3.11"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"fastapi>=0.115,<1.0",
|
"fastapi>=0.115,<1.0",
|
||||||
"uvicorn[standard]>=0.34,<1.0",
|
"uvicorn[standard]>=0.34,<1.0",
|
||||||
"langgraph>=0.4,<1.0",
|
"langgraph>=1.0,<2.0",
|
||||||
"langgraph-supervisor>=0.0.12,<1.0",
|
"langgraph-supervisor>=0.0.30,<1.0",
|
||||||
"langgraph-checkpoint-postgres>=3.0,<4.0",
|
"langgraph-checkpoint-postgres>=3.0,<4.0",
|
||||||
"langchain-core>=0.3,<1.0",
|
"langchain>=1.0,<2.0",
|
||||||
"langchain-anthropic>=0.3,<2.0",
|
"langchain-core>=1.0,<2.0",
|
||||||
"langchain-openai>=0.3,<1.0",
|
"langchain-anthropic>=1.0,<2.0",
|
||||||
|
"langchain-openai>=1.0,<2.0",
|
||||||
"langchain-google-genai>=2.1,<3.0",
|
"langchain-google-genai>=2.1,<3.0",
|
||||||
"psycopg[binary,pool]>=3.2,<4.0",
|
"psycopg[binary,pool]>=3.2,<4.0",
|
||||||
"pydantic>=2.10,<3.0",
|
"pydantic>=2.10,<3.0",
|
||||||
@@ -20,6 +21,8 @@ dependencies = [
|
|||||||
"python-dotenv>=1.0,<2.0",
|
"python-dotenv>=1.0,<2.0",
|
||||||
"httpx>=0.28,<1.0",
|
"httpx>=0.28,<1.0",
|
||||||
"openapi-spec-validator>=0.7,<1.0",
|
"openapi-spec-validator>=0.7,<1.0",
|
||||||
|
"alembic>=1.13,<2.0",
|
||||||
|
"structlog>=24.0,<26.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
|
|||||||
230
backend/tests/e2e/conftest.py
Normal file
230
backend/tests/e2e/conftest.py
Normal file
@@ -0,0 +1,230 @@
|
|||||||
|
"""E2E test fixtures -- full FastAPI app with mocked LLM and database."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from typing import Any
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||||||
|
from httpx import ASGITransport, AsyncClient
|
||||||
|
|
||||||
|
from app.analytics.api import router as analytics_router
|
||||||
|
from app.callbacks import TokenUsageCallbackHandler
|
||||||
|
from app.graph_context import GraphContext
|
||||||
|
from app.interrupt_manager import InterruptManager
|
||||||
|
from app.openapi.review_api import _job_store, router as openapi_router
|
||||||
|
from app.replay.api import router as replay_router
|
||||||
|
from app.session_manager import SessionManager
|
||||||
|
from app.ws_context import WebSocketContext
|
||||||
|
from app.ws_handler import dispatch_message
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Graph helpers -- simulate LangGraph streaming behaviour
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncIterHelper:
|
||||||
|
"""Make a list behave as an async iterator."""
|
||||||
|
|
||||||
|
def __init__(self, items: list) -> None:
|
||||||
|
self._items = list(items)
|
||||||
|
|
||||||
|
def __aiter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __anext__(self):
|
||||||
|
if not self._items:
|
||||||
|
raise StopAsyncIteration
|
||||||
|
return self._items.pop(0)
|
||||||
|
|
||||||
|
|
||||||
|
def make_chunk(content: str, node: str = "order_lookup") -> tuple:
|
||||||
|
c = MagicMock()
|
||||||
|
c.content = content
|
||||||
|
c.tool_calls = []
|
||||||
|
return (c, {"langgraph_node": node})
|
||||||
|
|
||||||
|
|
||||||
|
def make_tool_chunk(name: str, args: dict, node: str = "order_lookup") -> tuple:
|
||||||
|
c = MagicMock()
|
||||||
|
c.content = ""
|
||||||
|
c.tool_calls = [{"name": name, "args": args}]
|
||||||
|
return (c, {"langgraph_node": node})
|
||||||
|
|
||||||
|
|
||||||
|
def make_state(*, interrupt: bool = False, data: dict | None = None) -> Any:
|
||||||
|
s = MagicMock()
|
||||||
|
if interrupt:
|
||||||
|
obj = MagicMock()
|
||||||
|
obj.value = data or {"action": "cancel_order", "order_id": "1042"}
|
||||||
|
t = MagicMock()
|
||||||
|
t.interrupts = (obj,)
|
||||||
|
s.tasks = (t,)
|
||||||
|
else:
|
||||||
|
s.tasks = ()
|
||||||
|
return s
|
||||||
|
|
||||||
|
|
||||||
|
def make_graph(
|
||||||
|
chunks: list | None = None,
|
||||||
|
state: Any = None,
|
||||||
|
resume_chunks: list | None = None,
|
||||||
|
) -> MagicMock:
|
||||||
|
"""Build a mock LangGraph CompiledStateGraph."""
|
||||||
|
g = MagicMock()
|
||||||
|
|
||||||
|
if state is None:
|
||||||
|
state = make_state()
|
||||||
|
|
||||||
|
streams = [chunks or [], resume_chunks or []]
|
||||||
|
idx = {"n": 0}
|
||||||
|
|
||||||
|
def astream_side_effect(*a, **kw):
|
||||||
|
i = min(idx["n"], len(streams) - 1)
|
||||||
|
idx["n"] += 1
|
||||||
|
return AsyncIterHelper(list(streams[i]))
|
||||||
|
|
||||||
|
g.astream = MagicMock(side_effect=astream_side_effect)
|
||||||
|
g.aget_state = AsyncMock(return_value=state)
|
||||||
|
return g
|
||||||
|
|
||||||
|
|
||||||
|
def make_graph_ctx(graph: MagicMock | None = None) -> GraphContext:
|
||||||
|
"""Build a GraphContext wrapping a mock graph."""
|
||||||
|
g = graph or make_graph()
|
||||||
|
registry = MagicMock()
|
||||||
|
registry.list_agents = MagicMock(return_value=())
|
||||||
|
return GraphContext(graph=g, registry=registry, intent_classifier=None)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Fake database pool
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class FakeCursor:
|
||||||
|
"""Minimal async cursor returning pre-configured rows."""
|
||||||
|
|
||||||
|
def __init__(self, rows: list[dict]) -> None:
|
||||||
|
self._rows = rows
|
||||||
|
|
||||||
|
async def fetchall(self) -> list[dict]:
|
||||||
|
return self._rows
|
||||||
|
|
||||||
|
async def fetchone(self) -> tuple | dict | None:
|
||||||
|
return self._rows[0] if self._rows else None
|
||||||
|
|
||||||
|
|
||||||
|
class FakeConnection:
|
||||||
|
"""Fake async connection that returns a FakeCursor."""
|
||||||
|
|
||||||
|
def __init__(self, rows: list[dict]) -> None:
|
||||||
|
self._rows = rows
|
||||||
|
|
||||||
|
async def execute(self, query: str, params: dict | None = None) -> FakeCursor:
|
||||||
|
return FakeCursor(self._rows)
|
||||||
|
|
||||||
|
|
||||||
|
class FakePool:
|
||||||
|
"""Minimal pool that yields a fake connection."""
|
||||||
|
|
||||||
|
def __init__(self, rows: list[dict] | None = None) -> None:
|
||||||
|
self._rows = rows or []
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def connection(self):
|
||||||
|
yield FakeConnection(self._rows)
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# App factory
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def create_e2e_app(
|
||||||
|
graph: MagicMock | None = None,
|
||||||
|
pool: FakePool | None = None,
|
||||||
|
session_ttl: int = 3600,
|
||||||
|
interrupt_ttl: int = 1800,
|
||||||
|
) -> FastAPI:
|
||||||
|
"""Create a FastAPI app wired with mocked dependencies for E2E testing."""
|
||||||
|
g = graph or make_graph()
|
||||||
|
graph_ctx = make_graph_ctx(g)
|
||||||
|
p = pool or FakePool()
|
||||||
|
sm = SessionManager(session_ttl_seconds=session_ttl)
|
||||||
|
im = InterruptManager(ttl_seconds=interrupt_ttl)
|
||||||
|
|
||||||
|
app = FastAPI(title="Smart Support E2E Test")
|
||||||
|
app.include_router(openapi_router)
|
||||||
|
app.include_router(replay_router)
|
||||||
|
app.include_router(analytics_router)
|
||||||
|
|
||||||
|
app.state.graph_ctx = graph_ctx
|
||||||
|
app.state.session_manager = sm
|
||||||
|
app.state.interrupt_manager = im
|
||||||
|
app.state.pool = p
|
||||||
|
app.state.settings = MagicMock(llm_model="test-model")
|
||||||
|
app.state.analytics_recorder = AsyncMock()
|
||||||
|
app.state.conversation_tracker = AsyncMock()
|
||||||
|
|
||||||
|
@app.get("/api/v1/health")
|
||||||
|
def health_check() -> dict:
|
||||||
|
return {"status": "ok", "version": "test"}
|
||||||
|
|
||||||
|
@app.websocket("/ws")
|
||||||
|
async def websocket_endpoint(ws: WebSocket) -> None:
|
||||||
|
await ws.accept()
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
raw_data = await ws.receive_text()
|
||||||
|
ws_ctx = WebSocketContext(
|
||||||
|
graph_ctx=app.state.graph_ctx,
|
||||||
|
session_manager=app.state.session_manager,
|
||||||
|
callback_handler=TokenUsageCallbackHandler(model_name="test-model"),
|
||||||
|
interrupt_manager=app.state.interrupt_manager,
|
||||||
|
analytics_recorder=app.state.analytics_recorder,
|
||||||
|
conversation_tracker=app.state.conversation_tracker,
|
||||||
|
pool=app.state.pool,
|
||||||
|
)
|
||||||
|
await dispatch_message(ws, ws_ctx, raw_data)
|
||||||
|
except WebSocketDisconnect:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def e2e_graph():
|
||||||
|
"""Default graph fixture -- returns tokens and message_complete."""
|
||||||
|
return make_graph(
|
||||||
|
chunks=[make_chunk("Order 1042 is "), make_chunk("shipped.")]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def e2e_app(e2e_graph):
|
||||||
|
"""Default E2E app fixture."""
|
||||||
|
return create_e2e_app(graph=e2e_graph)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def e2e_client(e2e_app):
|
||||||
|
"""Async HTTP client for E2E tests."""
|
||||||
|
transport = ASGITransport(app=e2e_app)
|
||||||
|
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||||
|
yield client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def clear_openapi_job_store():
|
||||||
|
"""Clear the in-memory job store between tests."""
|
||||||
|
_job_store.clear()
|
||||||
|
yield
|
||||||
|
_job_store.clear()
|
||||||
384
backend/tests/e2e/test_chat_flows.py
Normal file
384
backend/tests/e2e/test_chat_flows.py
Normal file
@@ -0,0 +1,384 @@
|
|||||||
|
"""E2E tests for critical chat user flows (flows 1-4).
|
||||||
|
|
||||||
|
Flow 1: Happy path -- query order, get answer
|
||||||
|
Flow 2: Approval flow -- write operation, interrupt, approve, execute
|
||||||
|
Flow 3: Rejection flow -- write operation, interrupt, reject, no execution
|
||||||
|
Flow 4: Multi-turn context -- sequential messages in same session
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from starlette.testclient import TestClient
|
||||||
|
|
||||||
|
from tests.e2e.conftest import (
|
||||||
|
create_e2e_app,
|
||||||
|
make_chunk,
|
||||||
|
make_graph,
|
||||||
|
make_state,
|
||||||
|
make_tool_chunk,
|
||||||
|
)
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.e2e
|
||||||
|
|
||||||
|
|
||||||
|
class TestFlow1HappyPath:
|
||||||
|
"""Flow 1: query order -> get answer with streaming tokens."""
|
||||||
|
|
||||||
|
def test_websocket_happy_path_order_query(self) -> None:
|
||||||
|
graph = make_graph(
|
||||||
|
chunks=[
|
||||||
|
make_tool_chunk("get_order_status", {"order_id": "1042"}),
|
||||||
|
make_chunk("Order 1042 has been shipped and is on its way."),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
app = create_e2e_app(graph=graph)
|
||||||
|
|
||||||
|
with TestClient(app) as client:
|
||||||
|
with client.websocket_connect("/ws") as ws:
|
||||||
|
ws.send_json({
|
||||||
|
"type": "message",
|
||||||
|
"thread_id": "e2e-happy-1",
|
||||||
|
"content": "What is the status of order 1042?",
|
||||||
|
})
|
||||||
|
|
||||||
|
messages = []
|
||||||
|
while True:
|
||||||
|
msg = ws.receive_json()
|
||||||
|
messages.append(msg)
|
||||||
|
if msg["type"] in ("message_complete", "error"):
|
||||||
|
break
|
||||||
|
|
||||||
|
tool_calls = [m for m in messages if m["type"] == "tool_call"]
|
||||||
|
assert len(tool_calls) == 1
|
||||||
|
assert tool_calls[0]["tool"] == "get_order_status"
|
||||||
|
assert tool_calls[0]["args"] == {"order_id": "1042"}
|
||||||
|
|
||||||
|
tokens = [m for m in messages if m["type"] == "token"]
|
||||||
|
assert len(tokens) == 1
|
||||||
|
assert "shipped" in tokens[0]["content"]
|
||||||
|
|
||||||
|
completes = [m for m in messages if m["type"] == "message_complete"]
|
||||||
|
assert len(completes) == 1
|
||||||
|
assert completes[0]["thread_id"] == "e2e-happy-1"
|
||||||
|
|
||||||
|
def test_websocket_multiple_token_stream(self) -> None:
|
||||||
|
"""Verify streaming returns multiple token chunks."""
|
||||||
|
graph = make_graph(
|
||||||
|
chunks=[
|
||||||
|
make_chunk("Your order "),
|
||||||
|
make_chunk("1042 "),
|
||||||
|
make_chunk("was delivered "),
|
||||||
|
make_chunk("yesterday."),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
app = create_e2e_app(graph=graph)
|
||||||
|
|
||||||
|
with TestClient(app) as client:
|
||||||
|
with client.websocket_connect("/ws") as ws:
|
||||||
|
ws.send_json({
|
||||||
|
"type": "message",
|
||||||
|
"thread_id": "e2e-stream-1",
|
||||||
|
"content": "Where is my order?",
|
||||||
|
})
|
||||||
|
|
||||||
|
messages = _collect_until_complete(ws)
|
||||||
|
|
||||||
|
tokens = [m for m in messages if m["type"] == "token"]
|
||||||
|
assert len(tokens) == 4
|
||||||
|
full_text = "".join(t["content"] for t in tokens)
|
||||||
|
assert "1042" in full_text
|
||||||
|
assert "delivered" in full_text
|
||||||
|
|
||||||
|
|
||||||
|
class TestFlow2ApprovalFlow:
|
||||||
|
"""Flow 2: write operation -> interrupt -> approve -> execute."""
|
||||||
|
|
||||||
|
def test_interrupt_approve_executes_action(self) -> None:
|
||||||
|
interrupt_state = make_state(
|
||||||
|
interrupt=True,
|
||||||
|
data={"action": "cancel_order", "order_id": "1042"},
|
||||||
|
)
|
||||||
|
graph = make_graph(
|
||||||
|
chunks=[],
|
||||||
|
state=interrupt_state,
|
||||||
|
resume_chunks=[
|
||||||
|
make_chunk("Order 1042 has been cancelled successfully.", "order_actions"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
app = create_e2e_app(graph=graph)
|
||||||
|
|
||||||
|
with TestClient(app) as client:
|
||||||
|
with client.websocket_connect("/ws") as ws:
|
||||||
|
# Step 1: Send cancel request
|
||||||
|
ws.send_json({
|
||||||
|
"type": "message",
|
||||||
|
"thread_id": "e2e-approve-1",
|
||||||
|
"content": "Cancel order 1042",
|
||||||
|
})
|
||||||
|
|
||||||
|
messages = _collect_until_type(ws, "interrupt")
|
||||||
|
|
||||||
|
interrupts = [m for m in messages if m["type"] == "interrupt"]
|
||||||
|
assert len(interrupts) == 1
|
||||||
|
assert interrupts[0]["action"] == "cancel_order"
|
||||||
|
assert interrupts[0]["thread_id"] == "e2e-approve-1"
|
||||||
|
|
||||||
|
# Step 2: Approve the interrupt
|
||||||
|
ws.send_json({
|
||||||
|
"type": "interrupt_response",
|
||||||
|
"thread_id": "e2e-approve-1",
|
||||||
|
"approved": True,
|
||||||
|
})
|
||||||
|
|
||||||
|
resume_messages = _collect_until_complete(ws)
|
||||||
|
|
||||||
|
tokens = [m for m in resume_messages if m["type"] == "token"]
|
||||||
|
assert len(tokens) == 1
|
||||||
|
assert "cancelled" in tokens[0]["content"]
|
||||||
|
assert tokens[0]["agent"] == "order_actions"
|
||||||
|
|
||||||
|
completes = [m for m in resume_messages if m["type"] == "message_complete"]
|
||||||
|
assert len(completes) == 1
|
||||||
|
|
||||||
|
|
||||||
|
class TestFlow3RejectionFlow:
    """Flow 3: write operation -> interrupt -> reject -> no execution."""

    def test_interrupt_reject_does_not_execute(self) -> None:
        """Rejecting a pending interrupt resumes the graph with resume=False."""
        pending = make_state(
            interrupt=True,
            data={"action": "cancel_order", "order_id": "1042"},
        )
        graph = make_graph(
            chunks=[],
            state=pending,
            resume_chunks=[
                make_chunk("Understood. Order 1042 will remain active.", "order_actions"),
            ],
        )
        app = create_e2e_app(graph=graph)

        with TestClient(app) as client, client.websocket_connect("/ws") as ws:
            # Step 1: Trigger interrupt
            ws.send_json(
                {
                    "type": "message",
                    "thread_id": "e2e-reject-1",
                    "content": "Cancel order 1042",
                }
            )

            pre_reject = _collect_until_type(ws, "interrupt")
            assert any(m["type"] == "interrupt" for m in pre_reject)

            # Step 2: Reject
            ws.send_json(
                {
                    "type": "interrupt_response",
                    "thread_id": "e2e-reject-1",
                    "approved": False,
                }
            )

            post_reject = _collect_until_complete(ws)

            token_frames = [m for m in post_reject if m["type"] == "token"]
            assert len(token_frames) == 1
            assert "remain active" in token_frames[0]["content"]

            # Verify graph.astream was called with resume=False
            last_call = graph.astream.call_args_list[-1]
            resume_command = last_call[0][0]
            assert resume_command.resume is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestFlow4MultiTurnContext:
    """Flow 4: multi-turn conversation in the same session."""

    def test_multi_turn_messages_share_session(self) -> None:
        """Multiple messages in the same thread_id maintain session context."""
        graph = make_graph(
            chunks=[make_chunk("Order 1042 status: shipped.")],
        )
        app = create_e2e_app(graph=graph)

        prompts = (
            "What is the status of order 1042?",
            "When will it arrive?",
            "Can you track it?",
        )

        with TestClient(app) as client, client.websocket_connect("/ws") as ws:
            # Send each turn and require a message_complete before the next one.
            for prompt in prompts:
                ws.send_json(
                    {
                        "type": "message",
                        "thread_id": "e2e-multi-1",
                        "content": prompt,
                    }
                )
                turn = _collect_until_complete(ws)
                assert any(m["type"] == "message_complete" for m in turn)

            # Verify all turns used the same thread_id in graph calls
            for call in graph.astream.call_args_list:
                config = call[1].get("config", call[0][1] if len(call[0]) > 1 else {})
                assert config["configurable"]["thread_id"] == "e2e-multi-1"

    def test_separate_threads_are_independent(self) -> None:
        """Different thread_ids have independent sessions."""
        graph = make_graph(
            chunks=[make_chunk("Response.")],
        )
        app = create_e2e_app(graph=graph)

        with TestClient(app) as client, client.websocket_connect("/ws") as ws:
            for thread_id, text in (
                ("e2e-thread-a", "Hello from thread A"),
                ("e2e-thread-b", "Hello from thread B"),
            ):
                ws.send_json(
                    {
                        "type": "message",
                        "thread_id": thread_id,
                        "content": text,
                    }
                )
                _collect_until_complete(ws)

            # Both threads should exist as separate sessions
            sm = app.state.session_manager
            assert sm.get_state("e2e-thread-a") is not None
            assert sm.get_state("e2e-thread-b") is not None
|
||||||
|
|
||||||
|
|
||||||
|
class TestChatEdgeCases:
    """Edge cases and error handling for the chat WebSocket."""

    def test_invalid_json_returns_error(self) -> None:
        """Non-JSON text frames are answered with an error message."""
        app = create_e2e_app()

        with TestClient(app) as client, client.websocket_connect("/ws") as ws:
            ws.send_text("not valid json")
            reply = ws.receive_json()
            assert reply["type"] == "error"
            assert "Invalid JSON" in reply["message"]

    def test_missing_thread_id_returns_error(self) -> None:
        """Messages without a thread_id are rejected."""
        app = create_e2e_app()

        with TestClient(app) as client, client.websocket_connect("/ws") as ws:
            ws.send_json({"type": "message", "content": "hello"})
            reply = ws.receive_json()
            assert reply["type"] == "error"
            assert "thread_id" in reply["message"]

    def test_empty_content_returns_error(self) -> None:
        """An empty content string is rejected."""
        app = create_e2e_app()

        with TestClient(app) as client, client.websocket_connect("/ws") as ws:
            ws.send_json(
                {
                    "type": "message",
                    "thread_id": "e2e-err-1",
                    "content": "",
                }
            )
            reply = ws.receive_json()
            assert reply["type"] == "error"

    def test_expired_session_returns_error(self) -> None:
        """With TTL=0 the second message finds its session already expired."""
        graph = make_graph(chunks=[make_chunk("Response.")])
        app = create_e2e_app(graph=graph, session_ttl=0)

        with TestClient(app) as client, client.websocket_connect("/ws") as ws:
            # First message creates the session (TTL=0)
            ws.send_json(
                {
                    "type": "message",
                    "thread_id": "e2e-expired-1",
                    "content": "hello",
                }
            )
            _collect_until_complete_or_error(ws)

            # Second message finds the session expired (TTL=0)
            ws.send_json(
                {
                    "type": "message",
                    "thread_id": "e2e-expired-1",
                    "content": "hello again",
                }
            )
            frames = _collect_until_complete_or_error(ws)
            error_frames = [m for m in frames if m["type"] == "error"]
            assert len(error_frames) >= 1
            assert "expired" in error_frames[0]["message"].lower()

    def test_oversized_message_returns_error(self) -> None:
        """Frames above the size limit are rejected."""
        app = create_e2e_app()

        with TestClient(app) as client, client.websocket_connect("/ws") as ws:
            ws.send_text("x" * 40_000)
            reply = ws.receive_json()
            assert reply["type"] == "error"
            assert "too large" in reply["message"].lower()

    def test_health_endpoint(self) -> None:
        """The health endpoint reports ok."""
        app = create_e2e_app()
        with TestClient(app) as client:
            resp = client.get("/api/v1/health")
            assert resp.status_code == 200
            assert resp.json()["status"] == "ok"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _collect_until_complete(ws, *, max_messages: int = 50) -> list[dict]:
|
||||||
|
"""Receive WebSocket messages until message_complete or error."""
|
||||||
|
messages = []
|
||||||
|
for _ in range(max_messages):
|
||||||
|
msg = ws.receive_json()
|
||||||
|
messages.append(msg)
|
||||||
|
if msg["type"] in ("message_complete", "error"):
|
||||||
|
break
|
||||||
|
return messages
|
||||||
|
|
||||||
|
|
||||||
|
def _collect_until_type(ws, msg_type: str, *, max_messages: int = 50) -> list[dict]:
|
||||||
|
"""Receive until a specific message type is received."""
|
||||||
|
messages = []
|
||||||
|
for _ in range(max_messages):
|
||||||
|
msg = ws.receive_json()
|
||||||
|
messages.append(msg)
|
||||||
|
if msg["type"] == msg_type:
|
||||||
|
break
|
||||||
|
return messages
|
||||||
|
|
||||||
|
|
||||||
|
def _collect_until_complete_or_error(ws, *, max_messages: int = 50) -> list[dict]:
    """Receive until message_complete or error.

    This is semantically identical to :func:`_collect_until_complete`; the
    separate name exists so call sites that *expect* an error outcome read
    clearly. Delegates instead of duplicating the receive loop, so the two
    helpers cannot drift apart.
    """
    return _collect_until_complete(ws, max_messages=max_messages)
|
||||||
201
backend/tests/e2e/test_openapi_import.py
Normal file
201
backend/tests/e2e/test_openapi_import.py
Normal file
@@ -0,0 +1,201 @@
|
|||||||
|
"""E2E tests for OpenAPI import flow (flow 5).
|
||||||
|
|
||||||
|
Flow 5: paste OpenAPI spec URL -> import job -> classify endpoints ->
|
||||||
|
review classifications -> approve -> tool generation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from starlette.testclient import TestClient
|
||||||
|
|
||||||
|
from app.openapi.models import ClassificationResult, EndpointInfo
|
||||||
|
from app.openapi.review_api import _job_store
|
||||||
|
from tests.e2e.conftest import create_e2e_app
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.e2e
|
||||||
|
|
||||||
|
|
||||||
|
def _fake_endpoint(
    path: str = "/orders/{id}",
    method: str = "GET",
    operation_id: str = "getOrder",
    summary: str = "Get order details",
) -> EndpointInfo:
    """Build a minimal EndpointInfo; defaults model a read-only order lookup."""
    fields = {
        "path": path,
        "method": method,
        "operation_id": operation_id,
        "summary": summary,
        "description": "",
        "parameters": (),
        "request_body_schema": None,
        "response_schema": None,
    }
    return EndpointInfo(**fields)
|
||||||
|
|
||||||
|
|
||||||
|
def _fake_classification(
    endpoint: EndpointInfo | None = None,
    access_type: str = "read",
    needs_interrupt: bool = False,
    agent_group: str = "order_lookup",
) -> ClassificationResult:
    """Build a ClassificationResult around ``endpoint`` (or a default fake one)."""
    target = endpoint if endpoint is not None else _fake_endpoint()
    return ClassificationResult(
        endpoint=target,
        access_type=access_type,
        customer_params=["order_id"],
        agent_group=agent_group,
        confidence=0.95,
        needs_interrupt=needs_interrupt,
    )
|
||||||
|
|
||||||
|
|
||||||
|
class TestFlow5OpenAPIImport:
    """Flow 5: full OpenAPI import lifecycle.

    These tests talk to the module-global ``_job_store``; every job created
    or seeded here is removed in a ``finally`` block so state cannot leak
    into other tests in the run (the original version left entries behind).
    """

    def test_import_job_lifecycle(self) -> None:
        """Start import -> check status -> review classifications -> approve."""
        app = create_e2e_app()

        with TestClient(app) as client:
            # Step 1: Start import job
            resp = client.post(
                "/api/v1/openapi/import",
                json={"url": "https://api.example.com/openapi.json"},
            )
            assert resp.status_code == 202
            body = resp.json()
            assert body["status"] == "pending"
            job_id = body["job_id"]

            try:
                # Step 2: Check job status (still pending since background task hasn't run)
                resp = client.get(f"/api/v1/openapi/jobs/{job_id}")
                assert resp.status_code == 200
                assert resp.json()["job_id"] == job_id
            finally:
                # Clean up so the job does not leak into other tests.
                _job_store.pop(job_id, None)

    def test_import_job_with_classifications(self) -> None:
        """Simulate completed import and review classified endpoints."""
        app = create_e2e_app()

        # Seed a completed job directly
        ep_read = _fake_endpoint("/orders/{id}", "GET", "getOrder", "Get order")
        ep_write = _fake_endpoint("/orders/{id}/cancel", "POST", "cancelOrder", "Cancel order")

        clf_read = _fake_classification(ep_read, "read", False, "order_lookup")
        clf_write = _fake_classification(ep_write, "write", True, "order_actions")

        job_id = "test-job-001"
        _job_store[job_id] = {
            "job_id": job_id,
            "status": "done",
            "spec_url": "https://api.example.com/openapi.json",
            "total_endpoints": 2,
            "classified_count": 2,
            "error_message": None,
            "classifications": [clf_read, clf_write],
        }

        try:
            with TestClient(app) as client:
                # Step 1: Get classifications
                resp = client.get(f"/api/v1/openapi/jobs/{job_id}/classifications")
                assert resp.status_code == 200
                classifications = resp.json()
                assert len(classifications) == 2

                # Verify read endpoint
                read_clf = classifications[0]
                assert read_clf["access_type"] == "read"
                assert read_clf["needs_interrupt"] is False
                assert read_clf["endpoint"]["path"] == "/orders/{id}"

                # Verify write endpoint
                write_clf = classifications[1]
                assert write_clf["access_type"] == "write"
                assert write_clf["needs_interrupt"] is True
                assert write_clf["endpoint"]["path"] == "/orders/{id}/cancel"

                # Step 2: Update a classification
                resp = client.put(
                    f"/api/v1/openapi/jobs/{job_id}/classifications/0",
                    json={
                        "access_type": "write",
                        "needs_interrupt": True,
                        "agent_group": "order_actions",
                    },
                )
                assert resp.status_code == 200
                updated = resp.json()
                assert updated["access_type"] == "write"
                assert updated["needs_interrupt"] is True
                assert updated["agent_group"] == "order_actions"

                # Step 3: Approve the job
                resp = client.post(f"/api/v1/openapi/jobs/{job_id}/approve")
                assert resp.status_code == 200
                assert resp.json()["status"] == "approved"
        finally:
            # Clean up the seeded job so other tests see a pristine store.
            _job_store.pop(job_id, None)

    def test_import_nonexistent_job_returns_404(self) -> None:
        """Unknown job ids yield 404."""
        app = create_e2e_app()

        with TestClient(app) as client:
            resp = client.get("/api/v1/openapi/jobs/nonexistent")
            assert resp.status_code == 404

    def test_import_invalid_url_returns_422(self) -> None:
        """A malformed spec URL fails request validation."""
        app = create_e2e_app()

        with TestClient(app) as client:
            resp = client.post("/api/v1/openapi/import", json={"url": "not-a-url"})
            assert resp.status_code == 422

    def test_classification_index_out_of_range(self) -> None:
        """Updating a classification index beyond the list yields 404."""
        app = create_e2e_app()

        job_id = "test-job-range"
        _job_store[job_id] = {
            "job_id": job_id,
            "status": "done",
            "spec_url": "https://example.com/spec.json",
            "total_endpoints": 1,
            "classified_count": 1,
            "error_message": None,
            "classifications": [_fake_classification()],
        }

        try:
            with TestClient(app) as client:
                resp = client.put(
                    f"/api/v1/openapi/jobs/{job_id}/classifications/99",
                    json={
                        "access_type": "read",
                        "needs_interrupt": False,
                        "agent_group": "order_lookup",
                    },
                )
                assert resp.status_code == 404
        finally:
            _job_store.pop(job_id, None)

    def test_update_classification_invalid_agent_group(self) -> None:
        """Agent group names with spaces/special chars fail validation."""
        app = create_e2e_app()

        job_id = "test-job-invalid"
        _job_store[job_id] = {
            "job_id": job_id,
            "status": "done",
            "spec_url": "https://example.com/spec.json",
            "total_endpoints": 1,
            "classified_count": 1,
            "error_message": None,
            "classifications": [_fake_classification()],
        }

        try:
            with TestClient(app) as client:
                resp = client.put(
                    f"/api/v1/openapi/jobs/{job_id}/classifications/0",
                    json={
                        "access_type": "read",
                        "needs_interrupt": False,
                        "agent_group": "invalid group!",  # spaces and special chars
                    },
                )
                assert resp.status_code == 422
        finally:
            _job_store.pop(job_id, None)
|
||||||
230
backend/tests/e2e/test_replay_analytics.py
Normal file
230
backend/tests/e2e/test_replay_analytics.py
Normal file
@@ -0,0 +1,230 @@
|
|||||||
|
"""E2E tests for replay and analytics flows (flow 6).
|
||||||
|
|
||||||
|
Flow 6: list conversations -> select one -> step-by-step replay.
|
||||||
|
Also tests the analytics dashboard endpoint.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from starlette.testclient import TestClient
|
||||||
|
|
||||||
|
from tests.e2e.conftest import FakePool, create_e2e_app
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.e2e
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Custom pool that returns specific data per query
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class ReplayPool(FakePool):
    """Pool that returns different data depending on the SQL query.

    Routing is by crude substring matching on the SQL text: COUNT over
    conversations, conversation listings (honouring limit/offset params),
    checkpoint reads, and everything else falls through to analytics rows.
    """

    def __init__(
        self,
        conversations: list[dict] | None = None,
        checkpoints: list[dict] | None = None,
        analytics_rows: list[dict] | None = None,
    ) -> None:
        super().__init__()
        self._conversations = [] if conversations is None else conversations
        self._checkpoints = [] if checkpoints is None else checkpoints
        self._analytics = [] if analytics_rows is None else analytics_rows

    class _Conn:
        def __init__(self, convos, checkpoints, analytics):
            self._convos = convos
            self._checkpoints = checkpoints
            self._analytics = analytics

        async def execute(self, query: str, params=None):
            from tests.e2e.conftest import FakeCursor

            is_conversations = "conversations" in query
            if is_conversations and "COUNT" in query:
                return FakeCursor([(len(self._convos),)])
            if is_conversations and "SELECT" in query:
                # Respect LIMIT/OFFSET from params if provided
                selected = self._convos
                if params:
                    start = params.get("offset", 0)
                    count = params.get("limit", len(selected))
                    selected = selected[start : start + count]
                return FakeCursor(selected)
            if "checkpoints" in query:
                return FakeCursor(self._checkpoints)
            # Analytics queries
            return FakeCursor(self._analytics)

    def connection(self):
        from contextlib import asynccontextmanager

        conn = self._Conn(self._conversations, self._checkpoints, self._analytics)

        @asynccontextmanager
        async def _ctx():
            yield conn

        return _ctx()
|
||||||
|
|
||||||
|
|
||||||
|
class TestFlow6ReplayConversation:
    """Flow 6: list conversations -> select one -> step replay."""

    def test_list_conversations(self) -> None:
        """Seeded conversations come back through /api/v1/conversations."""
        now = datetime.now(tz=timezone.utc).isoformat()
        seeded = [
            {
                "thread_id": "conv-001",
                "created_at": now,
                "last_activity": now,
                "status": "active",
                "total_tokens": 150,
                "total_cost_usd": 0.003,
            },
            {
                "thread_id": "conv-002",
                "created_at": now,
                "last_activity": now,
                "status": "completed",
                "total_tokens": 300,
                "total_cost_usd": 0.006,
            },
        ]
        app = create_e2e_app(pool=ReplayPool(conversations=seeded))

        with TestClient(app) as client:
            resp = client.get("/api/v1/conversations")
            assert resp.status_code == 200
            body = resp.json()
            assert body["success"] is True
            data = body["data"]
            assert len(data["conversations"]) == 2
            assert data["conversations"][0]["thread_id"] == "conv-001"
            assert data["conversations"][1]["thread_id"] == "conv-002"
            assert data["total"] == 2

    def test_list_conversations_pagination(self) -> None:
        """page/per_page query params page through the seeded rows."""
        seeded = [
            {
                "thread_id": f"conv-{i:03d}",
                "created_at": "2026-04-01T00:00:00Z",
                "last_activity": "2026-04-01T00:00:00Z",
                "status": "active",
                "total_tokens": 100,
                "total_cost_usd": 0.001,
            }
            for i in range(5)
        ]
        app = create_e2e_app(pool=ReplayPool(conversations=seeded))

        with TestClient(app) as client:
            resp = client.get("/api/v1/conversations", params={"page": 1, "per_page": 2})
            assert resp.status_code == 200
            body = resp.json()
            assert body["success"] is True
            data = body["data"]
            assert data["total"] == 5
            assert data["page"] == 1
            assert data["per_page"] == 2
            assert len(data["conversations"]) == 2

    def test_replay_thread_not_found(self) -> None:
        """Replaying a thread with no checkpoints yields 404."""
        app = create_e2e_app(pool=ReplayPool(checkpoints=[]))

        with TestClient(app) as client:
            resp = client.get("/api/v1/replay/nonexistent-thread")
            assert resp.status_code == 404

    def test_replay_invalid_thread_id_format(self) -> None:
        """Thread IDs failing the format regex yield 400."""
        app = create_e2e_app()

        with TestClient(app) as client:
            # Thread ID with special chars fails regex validation
            resp = client.get("/api/v1/replay/invalid%20thread%21%40")
            assert resp.status_code == 400
|
||||||
|
|
||||||
|
|
||||||
|
class TestAnalyticsDashboard:
    """Analytics endpoint tests."""

    def test_analytics_invalid_range_format(self) -> None:
        """A non-numeric range string is rejected with 400."""
        app = create_e2e_app()
        with TestClient(app) as client:
            resp = client.get("/api/v1/analytics", params={"range": "invalid"})
            assert resp.status_code == 400

    def test_analytics_range_too_large(self) -> None:
        """A range beyond the allowed maximum is rejected with 400."""
        app = create_e2e_app()
        with TestClient(app) as client:
            resp = client.get("/api/v1/analytics", params={"range": "999d"})
            assert resp.status_code == 400

    def test_analytics_range_zero_rejected(self) -> None:
        """A zero-day range is below the minimum and rejected with 400."""
        app = create_e2e_app()
        with TestClient(app) as client:
            resp = client.get("/api/v1/analytics", params={"range": "0d"})
            assert resp.status_code == 400
|
||||||
|
|
||||||
|
|
||||||
|
class TestFullUserJourney:
    """End-to-end journey: chat -> then check replay list shows the conversation."""

    def test_chat_then_check_conversations_endpoint(self) -> None:
        """After chatting via WebSocket, the conversations endpoint is reachable."""
        from tests.e2e.conftest import make_chunk, make_graph

        graph = make_graph(chunks=[make_chunk("Your order is shipped.")])
        now = datetime.now(tz=timezone.utc).isoformat()
        pool = ReplayPool(
            conversations=[
                {
                    "thread_id": "e2e-journey-1",
                    "created_at": now,
                    "last_activity": now,
                    "status": "active",
                    "total_tokens": 50,
                    "total_cost_usd": 0.001,
                },
            ],
        )
        app = create_e2e_app(graph=graph, pool=pool)

        with TestClient(app) as client:
            # Step 1: Chat via WebSocket
            with client.websocket_connect("/ws") as ws:
                ws.send_json(
                    {
                        "type": "message",
                        "thread_id": "e2e-journey-1",
                        "content": "Where is my order?",
                    }
                )
                frames = []
                for _ in range(20):
                    frame = ws.receive_json()
                    frames.append(frame)
                    if frame["type"] in ("message_complete", "error"):
                        break
                assert any(f["type"] == "message_complete" for f in frames)

            # Step 2: Check conversations endpoint
            resp = client.get("/api/v1/conversations")
            assert resp.status_code == 200
            body = resp.json()
            assert body["success"] is True
            assert any(
                c["thread_id"] == "e2e-journey-1"
                for c in body["data"]["conversations"]
            )

            # Step 3: Health check still works
            resp = client.get("/api/v1/health")
            assert resp.status_code == 200
|
||||||
183
backend/tests/integration/test_analytics_api.py
Normal file
183
backend/tests/integration/test_analytics_api.py
Normal file
@@ -0,0 +1,183 @@
|
|||||||
|
"""Integration tests for the /api/v1/analytics endpoint.
|
||||||
|
|
||||||
|
Tests the full API layer (routing, parameter validation, serialization,
|
||||||
|
error handling) with a mocked database pool.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import asdict
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from httpx import ASGITransport, AsyncClient
|
||||||
|
|
||||||
|
from app.analytics.models import AnalyticsResult, InterruptStats
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.integration
|
||||||
|
|
||||||
|
# Canonical analytics payload shared by the valid-range tests below; the
# endpoint under test is patched to return this object verbatim.
_SAMPLE_RESULT = AnalyticsResult(
    range="7d",
    total_conversations=42,
    resolution_rate=0.85,
    escalation_rate=0.05,
    avg_turns_per_conversation=3.2,
    avg_cost_per_conversation_usd=0.012,
    agent_usage=(),
    interrupt_stats=InterruptStats(total=10, approved=7, rejected=2, expired=1),
)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_app():
    """Build a minimal FastAPI app with the analytics router and mocked deps.

    Installs the three exception handlers the real app uses so that error
    responses are wrapped in the standard envelope, and mocks ``settings``
    (empty admin_api_key -> auth skipped) and ``pool`` on ``app.state``.
    """
    # All fastapi names imported once up front (the original re-imported
    # HTTPException mid-function).
    from fastapi import FastAPI, HTTPException
    from fastapi.exceptions import RequestValidationError
    from fastapi.responses import JSONResponse

    from app.analytics.api import router as analytics_router
    from app.api_utils import envelope

    test_app = FastAPI()
    test_app.include_router(analytics_router)

    @test_app.exception_handler(Exception)
    async def _catch_all(request, exc):
        # Unexpected failures still return the standard error envelope.
        return JSONResponse(
            status_code=500,
            content=envelope(None, success=False, error="Internal server error"),
        )

    @test_app.exception_handler(HTTPException)
    async def _http_exc(request, exc):
        # Propagate the handler-raised status code, enveloped.
        return JSONResponse(
            status_code=exc.status_code,
            content=envelope(None, success=False, error=exc.detail),
        )

    @test_app.exception_handler(RequestValidationError)
    async def _validation_exc(request, exc):
        return JSONResponse(
            status_code=422,
            content=envelope(None, success=False, error=str(exc)),
        )

    # No admin_api_key set -> auth is skipped
    test_app.state.settings = MagicMock(admin_api_key="")
    test_app.state.pool = MagicMock()

    return test_app
|
||||||
|
|
||||||
|
|
||||||
|
class TestAnalyticsValidRange:
    """Test analytics endpoint with valid range parameters."""

    async def test_valid_range_7d_returns_envelope(self) -> None:
        """GET /api/v1/analytics?range=7d returns success envelope with data."""
        test_app = _build_app()
        patched = patch(
            "app.analytics.api.get_analytics",
            new_callable=AsyncMock,
            return_value=_SAMPLE_RESULT,
        )
        with patched:
            transport = ASGITransport(app=test_app)
            async with AsyncClient(transport=transport, base_url="http://test") as client:
                resp = await client.get("/api/v1/analytics", params={"range": "7d"})

        assert resp.status_code == 200
        body = resp.json()
        assert body["success"] is True
        assert body["error"] is None
        assert body["data"]["total_conversations"] == 42
        assert body["data"]["resolution_rate"] == 0.85

    async def test_default_range_returns_success(self) -> None:
        """GET /api/v1/analytics with no range param defaults to 7d."""
        test_app = _build_app()
        patched = patch(
            "app.analytics.api.get_analytics",
            new_callable=AsyncMock,
            return_value=_SAMPLE_RESULT,
        )
        with patched as mock_get:
            transport = ASGITransport(app=test_app)
            async with AsyncClient(transport=transport, base_url="http://test") as client:
                resp = await client.get("/api/v1/analytics")

        assert resp.status_code == 200
        # Verify default range of 7 days was passed
        mock_get.assert_called_once()
        call_args = mock_get.call_args
        # NOTE(review): this assertion can pass vacuously when range_days is
        # absent from the call (the .get() default of None is accepted) —
        # worth tightening once the get_analytics call signature is pinned.
        assert call_args[1].get("range_days", call_args[0][1] if len(call_args[0]) > 1 else None) in (7, None) or call_args[0][1] == 7

    async def test_large_range_365d_works(self) -> None:
        """GET /api/v1/analytics?range=365d is accepted (max boundary)."""
        test_app = _build_app()
        boundary_result = AnalyticsResult(
            range="365d",
            total_conversations=1000,
            resolution_rate=0.9,
            escalation_rate=0.02,
            avg_turns_per_conversation=4.0,
            avg_cost_per_conversation_usd=0.01,
            agent_usage=(),
            interrupt_stats=InterruptStats(),
        )
        patched = patch(
            "app.analytics.api.get_analytics",
            new_callable=AsyncMock,
            return_value=boundary_result,
        )
        with patched:
            transport = ASGITransport(app=test_app)
            async with AsyncClient(transport=transport, base_url="http://test") as client:
                resp = await client.get("/api/v1/analytics", params={"range": "365d"})

        assert resp.status_code == 200
        assert resp.json()["success"] is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestAnalyticsInvalidRange:
    """Test analytics endpoint with invalid range parameters."""

    async def test_invalid_range_format_returns_400(self) -> None:
        """GET /api/v1/analytics?range=abc returns 400 error envelope."""
        test_app = _build_app()
        transport = ASGITransport(app=test_app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            resp = await client.get("/api/v1/analytics", params={"range": "abc"})

        assert resp.status_code == 400
        body = resp.json()
        assert body["success"] is False
        assert body["data"] is None
        assert "Invalid range format" in body["error"]

    async def test_zero_day_range_returns_400(self) -> None:
        """GET /api/v1/analytics?range=0d returns 400 because 0 is below minimum."""
        test_app = _build_app()
        transport = ASGITransport(app=test_app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            resp = await client.get("/api/v1/analytics", params={"range": "0d"})

        assert resp.status_code == 400
        body = resp.json()
        assert body["success"] is False
        assert "between 1 and 365" in body["error"]

    async def test_range_exceeding_max_returns_400(self) -> None:
        """GET /api/v1/analytics?range=999d returns 400 because it exceeds 365."""
        test_app = _build_app()
        transport = ASGITransport(app=test_app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            resp = await client.get("/api/v1/analytics", params={"range": "999d"})

        assert resp.status_code == 400
        body = resp.json()
        assert body["success"] is False
        assert "between 1 and 365" in body["error"]
|
||||||
128
backend/tests/integration/test_error_responses.py
Normal file
128
backend/tests/integration/test_error_responses.py
Normal file
@@ -0,0 +1,128 @@
|
|||||||
|
"""Integration tests for global error handling and envelope format consistency.
|
||||||
|
|
||||||
|
Tests that all error responses from the FastAPI app conform to the
|
||||||
|
standard envelope: {"success": false, "data": null, "error": "..."}.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from httpx import ASGITransport, AsyncClient
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.integration
|
||||||
|
|
||||||
|
|
||||||
|
def _build_app():
|
||||||
|
"""Build the actual FastAPI app with exception handlers but mocked state."""
|
||||||
|
from fastapi import FastAPI, HTTPException
|
||||||
|
from fastapi.exceptions import RequestValidationError
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
|
from app.analytics.api import router as analytics_router
|
||||||
|
from app.api_utils import envelope
|
||||||
|
from app.replay.api import router as replay_router
|
||||||
|
|
||||||
|
test_app = FastAPI()
|
||||||
|
test_app.include_router(analytics_router)
|
||||||
|
test_app.include_router(replay_router)
|
||||||
|
|
||||||
|
@test_app.exception_handler(HTTPException)
|
||||||
|
async def _http_exc(request, exc):
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=exc.status_code,
|
||||||
|
content=envelope(None, success=False, error=exc.detail),
|
||||||
|
)
|
||||||
|
|
||||||
|
@test_app.exception_handler(RequestValidationError)
|
||||||
|
async def _validation_exc(request, exc):
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=422,
|
||||||
|
content=envelope(None, success=False, error=str(exc)),
|
||||||
|
)
|
||||||
|
|
||||||
|
@test_app.exception_handler(Exception)
|
||||||
|
async def _catch_all(request, exc):
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=500,
|
||||||
|
content=envelope(None, success=False, error="Internal server error"),
|
||||||
|
)
|
||||||
|
|
||||||
|
@test_app.get("/api/v1/health")
|
||||||
|
def health_check():
|
||||||
|
return {"status": "ok", "version": "0.6.0"}
|
||||||
|
|
||||||
|
test_app.state.settings = MagicMock(admin_api_key="")
|
||||||
|
test_app.state.pool = MagicMock()
|
||||||
|
|
||||||
|
return test_app
|
||||||
|
|
||||||
|
|
||||||
|
class TestEnvelopeFormat:
|
||||||
|
"""Tests that error responses consistently follow envelope format."""
|
||||||
|
|
||||||
|
async def test_http_400_produces_envelope(self) -> None:
|
||||||
|
"""A 400 error returns standard envelope with success=false."""
|
||||||
|
test_app = _build_app()
|
||||||
|
async with AsyncClient(
|
||||||
|
transport=ASGITransport(app=test_app), base_url="http://test"
|
||||||
|
) as client:
|
||||||
|
resp = await client.get("/api/v1/analytics", params={"range": "invalid"})
|
||||||
|
|
||||||
|
assert resp.status_code == 400
|
||||||
|
body = resp.json()
|
||||||
|
assert body["success"] is False
|
||||||
|
assert body["data"] is None
|
||||||
|
assert isinstance(body["error"], str)
|
||||||
|
assert len(body["error"]) > 0
|
||||||
|
|
||||||
|
async def test_validation_error_produces_422_envelope(self) -> None:
|
||||||
|
"""Invalid query param type returns 422 with envelope format."""
|
||||||
|
test_app = _build_app()
|
||||||
|
async with AsyncClient(
|
||||||
|
transport=ASGITransport(app=test_app), base_url="http://test"
|
||||||
|
) as client:
|
||||||
|
# page must be >= 1; passing 0 triggers validation error
|
||||||
|
resp = await client.get("/api/v1/conversations", params={"page": 0})
|
||||||
|
|
||||||
|
assert resp.status_code == 422
|
||||||
|
body = resp.json()
|
||||||
|
assert body["success"] is False
|
||||||
|
assert body["data"] is None
|
||||||
|
assert isinstance(body["error"], str)
|
||||||
|
|
||||||
|
async def test_all_error_fields_present(self) -> None:
|
||||||
|
"""Error envelope contains exactly success, data, and error keys."""
|
||||||
|
test_app = _build_app()
|
||||||
|
async with AsyncClient(
|
||||||
|
transport=ASGITransport(app=test_app), base_url="http://test"
|
||||||
|
) as client:
|
||||||
|
resp = await client.get("/api/v1/analytics", params={"range": "bad"})
|
||||||
|
|
||||||
|
body = resp.json()
|
||||||
|
assert set(body.keys()) == {"success", "data", "error"}
|
||||||
|
|
||||||
|
async def test_health_endpoint_returns_200(self) -> None:
|
||||||
|
"""Health check returns 200 with status ok."""
|
||||||
|
test_app = _build_app()
|
||||||
|
async with AsyncClient(
|
||||||
|
transport=ASGITransport(app=test_app), base_url="http://test"
|
||||||
|
) as client:
|
||||||
|
resp = await client.get("/api/v1/health")
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
body = resp.json()
|
||||||
|
assert body["status"] == "ok"
|
||||||
|
assert "version" in body
|
||||||
|
|
||||||
|
async def test_unknown_endpoint_returns_404(self) -> None:
|
||||||
|
"""Requesting a non-existent path returns 404."""
|
||||||
|
test_app = _build_app()
|
||||||
|
async with AsyncClient(
|
||||||
|
transport=ASGITransport(app=test_app), base_url="http://test"
|
||||||
|
) as client:
|
||||||
|
resp = await client.get("/api/v1/nonexistent-path")
|
||||||
|
|
||||||
|
# FastAPI returns 404 for unknown routes; may or may not be wrapped
|
||||||
|
assert resp.status_code == 404
|
||||||
164
backend/tests/integration/test_openapi_api.py
Normal file
164
backend/tests/integration/test_openapi_api.py
Normal file
@@ -0,0 +1,164 @@
|
|||||||
|
"""Integration tests for /api/v1/openapi/ endpoints.
|
||||||
|
|
||||||
|
Tests the full API layer for the OpenAPI import review workflow,
|
||||||
|
including job creation, status retrieval, classification updates,
|
||||||
|
and approval triggering.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from httpx import ASGITransport, AsyncClient
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.integration
|
||||||
|
|
||||||
|
|
||||||
|
def _build_app():
|
||||||
|
"""Build a minimal FastAPI app with the openapi router and mocked deps."""
|
||||||
|
from fastapi import FastAPI, HTTPException
|
||||||
|
from fastapi.exceptions import RequestValidationError
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
|
from app.api_utils import envelope
|
||||||
|
from app.openapi.review_api import router as openapi_router
|
||||||
|
|
||||||
|
test_app = FastAPI()
|
||||||
|
test_app.include_router(openapi_router)
|
||||||
|
|
||||||
|
@test_app.exception_handler(HTTPException)
|
||||||
|
async def _http_exc(request, exc):
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=exc.status_code,
|
||||||
|
content=envelope(None, success=False, error=exc.detail),
|
||||||
|
)
|
||||||
|
|
||||||
|
@test_app.exception_handler(RequestValidationError)
|
||||||
|
async def _validation_exc(request, exc):
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=422,
|
||||||
|
content=envelope(None, success=False, error=str(exc)),
|
||||||
|
)
|
||||||
|
|
||||||
|
test_app.state.settings = MagicMock(admin_api_key="")
|
||||||
|
|
||||||
|
return test_app
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _clear_job_store():
|
||||||
|
"""Clear the in-memory job store between tests."""
|
||||||
|
from app.openapi.review_api import _job_store
|
||||||
|
|
||||||
|
_job_store.clear()
|
||||||
|
yield
|
||||||
|
_job_store.clear()
|
||||||
|
|
||||||
|
|
||||||
|
class TestImportEndpoint:
|
||||||
|
"""Tests for POST /api/v1/openapi/import."""
|
||||||
|
|
||||||
|
async def test_import_returns_202_with_job_id(self) -> None:
|
||||||
|
"""Starting an import returns 202 with a job_id."""
|
||||||
|
test_app = _build_app()
|
||||||
|
async with AsyncClient(
|
||||||
|
transport=ASGITransport(app=test_app), base_url="http://test"
|
||||||
|
) as client:
|
||||||
|
resp = await client.post(
|
||||||
|
"/api/v1/openapi/import",
|
||||||
|
json={"url": "https://example.com/api/spec.json"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert resp.status_code == 202
|
||||||
|
body = resp.json()
|
||||||
|
assert "job_id" in body
|
||||||
|
assert body["status"] == "pending"
|
||||||
|
assert body["spec_url"] == "https://example.com/api/spec.json"
|
||||||
|
|
||||||
|
async def test_import_invalid_url_returns_422(self) -> None:
|
||||||
|
"""POST with invalid URL (no http/https) returns 422."""
|
||||||
|
test_app = _build_app()
|
||||||
|
async with AsyncClient(
|
||||||
|
transport=ASGITransport(app=test_app), base_url="http://test"
|
||||||
|
) as client:
|
||||||
|
resp = await client.post(
|
||||||
|
"/api/v1/openapi/import",
|
||||||
|
json={"url": "ftp://example.com/spec.json"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert resp.status_code == 422
|
||||||
|
body = resp.json()
|
||||||
|
assert body["success"] is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestJobStatusEndpoint:
|
||||||
|
"""Tests for GET /api/v1/openapi/jobs/{job_id}."""
|
||||||
|
|
||||||
|
async def test_get_existing_job_returns_status(self) -> None:
|
||||||
|
"""Retrieving an existing job returns its status."""
|
||||||
|
from app.openapi.review_api import _job_store
|
||||||
|
|
||||||
|
_job_store["test-job-1"] = {
|
||||||
|
"job_id": "test-job-1",
|
||||||
|
"status": "done",
|
||||||
|
"spec_url": "https://example.com/spec.json",
|
||||||
|
"total_endpoints": 5,
|
||||||
|
"classified_count": 5,
|
||||||
|
"error_message": None,
|
||||||
|
"classifications": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
test_app = _build_app()
|
||||||
|
async with AsyncClient(
|
||||||
|
transport=ASGITransport(app=test_app), base_url="http://test"
|
||||||
|
) as client:
|
||||||
|
resp = await client.get("/api/v1/openapi/jobs/test-job-1")
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
body = resp.json()
|
||||||
|
assert body["job_id"] == "test-job-1"
|
||||||
|
assert body["status"] == "done"
|
||||||
|
assert body["total_endpoints"] == 5
|
||||||
|
|
||||||
|
async def test_get_unknown_job_returns_404(self) -> None:
|
||||||
|
"""Retrieving a non-existent job returns 404 error envelope."""
|
||||||
|
test_app = _build_app()
|
||||||
|
async with AsyncClient(
|
||||||
|
transport=ASGITransport(app=test_app), base_url="http://test"
|
||||||
|
) as client:
|
||||||
|
resp = await client.get("/api/v1/openapi/jobs/unknown-id-999")
|
||||||
|
|
||||||
|
assert resp.status_code == 404
|
||||||
|
body = resp.json()
|
||||||
|
assert body["success"] is False
|
||||||
|
assert "not found" in body["error"].lower()
|
||||||
|
|
||||||
|
|
||||||
|
class TestApproveEndpoint:
|
||||||
|
"""Tests for POST /api/v1/openapi/jobs/{job_id}/approve."""
|
||||||
|
|
||||||
|
async def test_approve_with_no_classifications_returns_400(self) -> None:
|
||||||
|
"""Approving a job with no classifications returns 400."""
|
||||||
|
from app.openapi.review_api import _job_store
|
||||||
|
|
||||||
|
_job_store["empty-job"] = {
|
||||||
|
"job_id": "empty-job",
|
||||||
|
"status": "done",
|
||||||
|
"spec_url": "https://example.com/spec.json",
|
||||||
|
"total_endpoints": 0,
|
||||||
|
"classified_count": 0,
|
||||||
|
"error_message": None,
|
||||||
|
"classifications": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
test_app = _build_app()
|
||||||
|
async with AsyncClient(
|
||||||
|
transport=ASGITransport(app=test_app), base_url="http://test"
|
||||||
|
) as client:
|
||||||
|
resp = await client.post("/api/v1/openapi/jobs/empty-job/approve")
|
||||||
|
|
||||||
|
assert resp.status_code == 400
|
||||||
|
body = resp.json()
|
||||||
|
assert body["success"] is False
|
||||||
|
assert "no classifications" in body["error"].lower()
|
||||||
@@ -20,10 +20,12 @@ import pytest
|
|||||||
|
|
||||||
from app.callbacks import TokenUsageCallbackHandler
|
from app.callbacks import TokenUsageCallbackHandler
|
||||||
from app.escalation import EscalationPayload, NoOpEscalator, WebhookEscalator
|
from app.escalation import EscalationPayload, NoOpEscalator, WebhookEscalator
|
||||||
|
from app.graph_context import GraphContext
|
||||||
from app.intent import ClassificationResult, IntentTarget, LLMIntentClassifier
|
from app.intent import ClassificationResult, IntentTarget, LLMIntentClassifier
|
||||||
from app.interrupt_manager import InterruptManager
|
from app.interrupt_manager import InterruptManager
|
||||||
from app.registry import AgentConfig, AgentRegistry
|
from app.registry import AgentConfig, AgentRegistry
|
||||||
from app.session_manager import SessionManager
|
from app.session_manager import SessionManager
|
||||||
|
from app.ws_context import WebSocketContext
|
||||||
from app.ws_handler import dispatch_message
|
from app.ws_handler import dispatch_message
|
||||||
|
|
||||||
TEMPLATES_DIR = Path(__file__).parent.parent.parent / "templates"
|
TEMPLATES_DIR = Path(__file__).parent.parent.parent / "templates"
|
||||||
@@ -128,10 +130,8 @@ class TestCheckpoint1OrderQueryRouting:
|
|||||||
mock_classifier.classify = AsyncMock(return_value=ClassificationResult(
|
mock_classifier.classify = AsyncMock(return_value=ClassificationResult(
|
||||||
intents=(IntentTarget(agent_name="order_lookup", confidence=0.95, reasoning="query"),),
|
intents=(IntentTarget(agent_name="order_lookup", confidence=0.95, reasoning="query"),),
|
||||||
))
|
))
|
||||||
graph.intent_classifier = mock_classifier
|
|
||||||
mock_registry = MagicMock()
|
mock_registry = MagicMock()
|
||||||
mock_registry.list_agents = MagicMock(return_value=())
|
mock_registry.list_agents = MagicMock(return_value=())
|
||||||
graph.agent_registry = mock_registry
|
|
||||||
|
|
||||||
# Graph streams order_lookup response
|
# Graph streams order_lookup response
|
||||||
graph.astream = MagicMock(return_value=AsyncIterHelper([
|
graph.astream = MagicMock(return_value=AsyncIterHelper([
|
||||||
@@ -140,14 +140,21 @@ class TestCheckpoint1OrderQueryRouting:
|
|||||||
]))
|
]))
|
||||||
graph.aget_state = AsyncMock(return_value=_state())
|
graph.aget_state = AsyncMock(return_value=_state())
|
||||||
|
|
||||||
|
graph_ctx = GraphContext(
|
||||||
|
graph=graph, registry=mock_registry, intent_classifier=mock_classifier,
|
||||||
|
)
|
||||||
sm = SessionManager()
|
sm = SessionManager()
|
||||||
sm.touch("t1")
|
sm.touch("t1")
|
||||||
im = InterruptManager()
|
im = InterruptManager()
|
||||||
cb = TokenUsageCallbackHandler()
|
cb = TokenUsageCallbackHandler()
|
||||||
ws = FakeWS()
|
ws = FakeWS()
|
||||||
|
ws_ctx = WebSocketContext(
|
||||||
|
graph_ctx=graph_ctx, session_manager=sm,
|
||||||
|
callback_handler=cb, interrupt_manager=im,
|
||||||
|
)
|
||||||
|
|
||||||
raw = json.dumps({"type": "message", "thread_id": "t1", "content": "查询订单 1042"})
|
raw = json.dumps({"type": "message", "thread_id": "t1", "content": "查询订单 1042"})
|
||||||
await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im)
|
await dispatch_message(ws, ws_ctx, raw)
|
||||||
|
|
||||||
tool_msgs = [m for m in ws.sent if m["type"] == "tool_call"]
|
tool_msgs = [m for m in ws.sent if m["type"] == "tool_call"]
|
||||||
assert any(m["tool"] == "get_order_status" for m in tool_msgs)
|
assert any(m["tool"] == "get_order_status" for m in tool_msgs)
|
||||||
@@ -201,25 +208,30 @@ class TestCheckpoint2MultiIntentSequential:
|
|||||||
IntentTarget(agent_name="discount", confidence=0.85, reasoning="discount"),
|
IntentTarget(agent_name="discount", confidence=0.85, reasoning="discount"),
|
||||||
),
|
),
|
||||||
))
|
))
|
||||||
graph.intent_classifier = mock_classifier
|
|
||||||
mock_registry = MagicMock()
|
mock_registry = MagicMock()
|
||||||
mock_registry.list_agents = MagicMock(return_value=())
|
mock_registry.list_agents = MagicMock(return_value=())
|
||||||
graph.agent_registry = mock_registry
|
|
||||||
graph.astream = MagicMock(return_value=AsyncIterHelper([]))
|
graph.astream = MagicMock(return_value=AsyncIterHelper([]))
|
||||||
graph.aget_state = AsyncMock(return_value=_state())
|
graph.aget_state = AsyncMock(return_value=_state())
|
||||||
|
|
||||||
|
graph_ctx = GraphContext(
|
||||||
|
graph=graph, registry=mock_registry, intent_classifier=mock_classifier,
|
||||||
|
)
|
||||||
sm = SessionManager()
|
sm = SessionManager()
|
||||||
sm.touch("t1")
|
sm.touch("t1")
|
||||||
im = InterruptManager()
|
im = InterruptManager()
|
||||||
cb = TokenUsageCallbackHandler()
|
cb = TokenUsageCallbackHandler()
|
||||||
ws = FakeWS()
|
ws = FakeWS()
|
||||||
|
ws_ctx = WebSocketContext(
|
||||||
|
graph_ctx=graph_ctx, session_manager=sm,
|
||||||
|
callback_handler=cb, interrupt_manager=im,
|
||||||
|
)
|
||||||
|
|
||||||
raw = json.dumps({
|
raw = json.dumps({
|
||||||
"type": "message",
|
"type": "message",
|
||||||
"thread_id": "t1",
|
"thread_id": "t1",
|
||||||
"content": "取消订单 1042 并给我一个 10% 折扣",
|
"content": "取消订单 1042 并给我一个 10% 折扣",
|
||||||
})
|
})
|
||||||
await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im)
|
await dispatch_message(ws, ws_ctx, raw)
|
||||||
|
|
||||||
# Verify the graph was called with the routing hint in the message
|
# Verify the graph was called with the routing hint in the message
|
||||||
call_args = graph.astream.call_args
|
call_args = graph.astream.call_args
|
||||||
@@ -267,21 +279,26 @@ class TestCheckpoint3AmbiguousClarification:
|
|||||||
"Could you please provide more details about what you need help with?"
|
"Could you please provide more details about what you need help with?"
|
||||||
),
|
),
|
||||||
))
|
))
|
||||||
graph.intent_classifier = mock_classifier
|
|
||||||
mock_registry = MagicMock()
|
mock_registry = MagicMock()
|
||||||
mock_registry.list_agents = MagicMock(return_value=())
|
mock_registry.list_agents = MagicMock(return_value=())
|
||||||
graph.agent_registry = mock_registry
|
|
||||||
graph.astream = MagicMock(return_value=AsyncIterHelper([]))
|
graph.astream = MagicMock(return_value=AsyncIterHelper([]))
|
||||||
graph.aget_state = AsyncMock(return_value=_state())
|
graph.aget_state = AsyncMock(return_value=_state())
|
||||||
|
|
||||||
|
graph_ctx = GraphContext(
|
||||||
|
graph=graph, registry=mock_registry, intent_classifier=mock_classifier,
|
||||||
|
)
|
||||||
sm = SessionManager()
|
sm = SessionManager()
|
||||||
sm.touch("t1")
|
sm.touch("t1")
|
||||||
im = InterruptManager()
|
im = InterruptManager()
|
||||||
cb = TokenUsageCallbackHandler()
|
cb = TokenUsageCallbackHandler()
|
||||||
ws = FakeWS()
|
ws = FakeWS()
|
||||||
|
ws_ctx = WebSocketContext(
|
||||||
|
graph_ctx=graph_ctx, session_manager=sm,
|
||||||
|
callback_handler=cb, interrupt_manager=im,
|
||||||
|
)
|
||||||
|
|
||||||
raw = json.dumps({"type": "message", "thread_id": "t1", "content": "嗯..."})
|
raw = json.dumps({"type": "message", "thread_id": "t1", "content": "嗯..."})
|
||||||
await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im)
|
await dispatch_message(ws, ws_ctx, raw)
|
||||||
|
|
||||||
clarifications = [m for m in ws.sent if m["type"] == "clarification"]
|
clarifications = [m for m in ws.sent if m["type"] == "clarification"]
|
||||||
assert len(clarifications) == 1
|
assert len(clarifications) == 1
|
||||||
@@ -303,20 +320,26 @@ class TestCheckpoint4InterruptTTLAutoCancel:
|
|||||||
async def test_30min_expired_interrupt_auto_cancels(self) -> None:
|
async def test_30min_expired_interrupt_auto_cancels(self) -> None:
|
||||||
st = _state(interrupt=True, data={"action": "cancel_order", "order_id": "1042"})
|
st = _state(interrupt=True, data={"action": "cancel_order", "order_id": "1042"})
|
||||||
graph = MagicMock()
|
graph = MagicMock()
|
||||||
graph.intent_classifier = None
|
|
||||||
graph.agent_registry = None
|
|
||||||
graph.astream = MagicMock(return_value=AsyncIterHelper([]))
|
graph.astream = MagicMock(return_value=AsyncIterHelper([]))
|
||||||
graph.aget_state = AsyncMock(return_value=st)
|
graph.aget_state = AsyncMock(return_value=st)
|
||||||
|
|
||||||
|
mock_registry = MagicMock()
|
||||||
|
mock_registry.list_agents = MagicMock(return_value=())
|
||||||
|
graph_ctx = GraphContext(graph=graph, registry=mock_registry, intent_classifier=None)
|
||||||
|
|
||||||
sm = SessionManager()
|
sm = SessionManager()
|
||||||
sm.touch("t1")
|
sm.touch("t1")
|
||||||
im = InterruptManager(ttl_seconds=1800) # 30 minutes
|
im = InterruptManager(ttl_seconds=1800) # 30 minutes
|
||||||
cb = TokenUsageCallbackHandler()
|
cb = TokenUsageCallbackHandler()
|
||||||
ws = FakeWS()
|
ws = FakeWS()
|
||||||
|
ws_ctx = WebSocketContext(
|
||||||
|
graph_ctx=graph_ctx, session_manager=sm,
|
||||||
|
callback_handler=cb, interrupt_manager=im,
|
||||||
|
)
|
||||||
|
|
||||||
# Trigger interrupt
|
# Trigger interrupt
|
||||||
raw = json.dumps({"type": "message", "thread_id": "t1", "content": "Cancel order 1042"})
|
raw = json.dumps({"type": "message", "thread_id": "t1", "content": "Cancel order 1042"})
|
||||||
await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im)
|
await dispatch_message(ws, ws_ctx, raw)
|
||||||
|
|
||||||
interrupts = [m for m in ws.sent if m["type"] == "interrupt"]
|
interrupts = [m for m in ws.sent if m["type"] == "interrupt"]
|
||||||
assert len(interrupts) == 1
|
assert len(interrupts) == 1
|
||||||
@@ -333,7 +356,7 @@ class TestCheckpoint4InterruptTTLAutoCancel:
|
|||||||
"thread_id": "t1",
|
"thread_id": "t1",
|
||||||
"approved": True,
|
"approved": True,
|
||||||
})
|
})
|
||||||
await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im)
|
await dispatch_message(ws, ws_ctx, raw)
|
||||||
|
|
||||||
# Should get retry prompt, NOT resume the graph
|
# Should get retry prompt, NOT resume the graph
|
||||||
expired_msgs = [m for m in ws.sent if m["type"] == "interrupt_expired"]
|
expired_msgs = [m for m in ws.sent if m["type"] == "interrupt_expired"]
|
||||||
|
|||||||
213
backend/tests/integration/test_replay_api.py
Normal file
213
backend/tests/integration/test_replay_api.py
Normal file
@@ -0,0 +1,213 @@
|
|||||||
|
"""Integration tests for /api/v1/conversations and /api/v1/replay/{thread_id}.
|
||||||
|
|
||||||
|
Tests the full API layer with a mocked database pool, verifying routing,
|
||||||
|
serialization, pagination, and error handling in envelope format.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from httpx import ASGITransport, AsyncClient
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.integration
|
||||||
|
|
||||||
|
|
||||||
|
def _make_fake_cursor(rows, *, fetchone_value=None):
|
||||||
|
"""Build a fake async cursor returning the given rows on fetchall."""
|
||||||
|
cursor = AsyncMock()
|
||||||
|
cursor.fetchall = AsyncMock(return_value=rows)
|
||||||
|
if fetchone_value is not None:
|
||||||
|
cursor.fetchone = AsyncMock(return_value=fetchone_value)
|
||||||
|
return cursor
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeConnection:
|
||||||
|
"""Fake async connection that returns pre-configured cursors in order."""
|
||||||
|
|
||||||
|
def __init__(self, cursors: list) -> None:
|
||||||
|
self._cursors = list(cursors)
|
||||||
|
self._idx = 0
|
||||||
|
|
||||||
|
async def execute(self, sql, params=None):
|
||||||
|
cursor = self._cursors[self._idx]
|
||||||
|
self._idx += 1
|
||||||
|
return cursor
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, *args):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class _FakePool:
|
||||||
|
"""Fake connection pool that yields a fake connection."""
|
||||||
|
|
||||||
|
def __init__(self, conn: _FakeConnection) -> None:
|
||||||
|
self._conn = conn
|
||||||
|
|
||||||
|
def connection(self):
|
||||||
|
return self._conn
|
||||||
|
|
||||||
|
|
||||||
|
def _build_app(pool=None):
|
||||||
|
"""Build a minimal FastAPI app with the replay router and mocked deps."""
|
||||||
|
from fastapi import FastAPI, HTTPException
|
||||||
|
from fastapi.exceptions import RequestValidationError
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
|
from app.api_utils import envelope
|
||||||
|
from app.replay.api import router as replay_router
|
||||||
|
|
||||||
|
test_app = FastAPI()
|
||||||
|
test_app.include_router(replay_router)
|
||||||
|
|
||||||
|
@test_app.exception_handler(HTTPException)
|
||||||
|
async def _http_exc(request, exc):
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=exc.status_code,
|
||||||
|
content=envelope(None, success=False, error=exc.detail),
|
||||||
|
)
|
||||||
|
|
||||||
|
@test_app.exception_handler(RequestValidationError)
|
||||||
|
async def _validation_exc(request, exc):
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=422,
|
||||||
|
content=envelope(None, success=False, error=str(exc)),
|
||||||
|
)
|
||||||
|
|
||||||
|
test_app.state.settings = MagicMock(admin_api_key="")
|
||||||
|
test_app.state.pool = pool or MagicMock()
|
||||||
|
|
||||||
|
return test_app
|
||||||
|
|
||||||
|
|
||||||
|
class TestListConversations:
|
||||||
|
"""Tests for GET /api/v1/conversations endpoint."""
|
||||||
|
|
||||||
|
async def test_returns_paginated_envelope(self) -> None:
|
||||||
|
"""Conversations list returns envelope with pagination metadata."""
|
||||||
|
count_cursor = _make_fake_cursor([], fetchone_value=(3,))
|
||||||
|
rows = [
|
||||||
|
{"thread_id": "t1", "created_at": "2026-01-01", "last_activity": "2026-01-01",
|
||||||
|
"status": "active", "total_tokens": 100, "total_cost_usd": 0.01},
|
||||||
|
{"thread_id": "t2", "created_at": "2026-01-02", "last_activity": "2026-01-02",
|
||||||
|
"status": "resolved", "total_tokens": 200, "total_cost_usd": 0.02},
|
||||||
|
]
|
||||||
|
list_cursor = _make_fake_cursor(rows)
|
||||||
|
conn = _FakeConnection([count_cursor, list_cursor])
|
||||||
|
pool = _FakePool(conn)
|
||||||
|
test_app = _build_app(pool)
|
||||||
|
|
||||||
|
async with AsyncClient(
|
||||||
|
transport=ASGITransport(app=test_app), base_url="http://test"
|
||||||
|
) as client:
|
||||||
|
resp = await client.get("/api/v1/conversations")
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
body = resp.json()
|
||||||
|
assert body["success"] is True
|
||||||
|
assert body["data"]["total"] == 3
|
||||||
|
assert len(body["data"]["conversations"]) == 2
|
||||||
|
assert body["data"]["page"] == 1
|
||||||
|
assert body["data"]["per_page"] == 20
|
||||||
|
|
||||||
|
async def test_custom_page_and_per_page(self) -> None:
|
||||||
|
"""Custom page/per_page params are reflected in the response."""
|
||||||
|
count_cursor = _make_fake_cursor([], fetchone_value=(50,))
|
||||||
|
list_cursor = _make_fake_cursor([])
|
||||||
|
conn = _FakeConnection([count_cursor, list_cursor])
|
||||||
|
pool = _FakePool(conn)
|
||||||
|
test_app = _build_app(pool)
|
||||||
|
|
||||||
|
async with AsyncClient(
|
||||||
|
transport=ASGITransport(app=test_app), base_url="http://test"
|
||||||
|
) as client:
|
||||||
|
resp = await client.get("/api/v1/conversations", params={"page": 3, "per_page": 10})
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
body = resp.json()
|
||||||
|
assert body["data"]["page"] == 3
|
||||||
|
assert body["data"]["per_page"] == 10
|
||||||
|
|
||||||
|
async def test_invalid_page_returns_422(self) -> None:
|
||||||
|
"""page=0 violates ge=1 constraint and returns 422 error envelope."""
|
||||||
|
test_app = _build_app()
|
||||||
|
async with AsyncClient(
|
||||||
|
transport=ASGITransport(app=test_app), base_url="http://test"
|
||||||
|
) as client:
|
||||||
|
resp = await client.get("/api/v1/conversations", params={"page": 0})
|
||||||
|
|
||||||
|
assert resp.status_code == 422
|
||||||
|
body = resp.json()
|
||||||
|
assert body["success"] is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestReplayEndpoint:
|
||||||
|
"""Tests for GET /api/v1/replay/{thread_id} endpoint."""
|
||||||
|
|
||||||
|
async def test_valid_thread_returns_timeline(self) -> None:
|
||||||
|
"""Replay with valid thread_id returns steps in envelope format."""
|
||||||
|
checkpoint_rows = [
|
||||||
|
{
|
||||||
|
"thread_id": "abc123",
|
||||||
|
"checkpoint_id": "cp1",
|
||||||
|
"checkpoint": {
|
||||||
|
"channel_values": {
|
||||||
|
"messages": [
|
||||||
|
{"type": "human", "content": "Hello", "created_at": "2026-01-01T00:00:00Z"},
|
||||||
|
{"type": "ai", "content": "Hi there!", "created_at": "2026-01-01T00:00:01Z"},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
cursor = _make_fake_cursor(checkpoint_rows)
|
||||||
|
conn = _FakeConnection([cursor])
|
||||||
|
pool = _FakePool(conn)
|
||||||
|
test_app = _build_app(pool)
|
||||||
|
|
||||||
|
async with AsyncClient(
|
||||||
|
transport=ASGITransport(app=test_app), base_url="http://test"
|
||||||
|
) as client:
|
||||||
|
resp = await client.get("/api/v1/replay/abc123")
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
body = resp.json()
|
||||||
|
assert body["success"] is True
|
||||||
|
assert body["data"]["thread_id"] == "abc123"
|
||||||
|
assert body["data"]["total_steps"] == 2
|
||||||
|
assert len(body["data"]["steps"]) == 2
|
||||||
|
assert body["data"]["steps"][0]["type"] == "user_message"
|
||||||
|
assert body["data"]["steps"][1]["type"] == "agent_response"
|
||||||
|
|
||||||
|
async def test_invalid_thread_id_format_returns_400(self) -> None:
|
||||||
|
"""Thread IDs with path traversal characters are rejected with 400."""
|
||||||
|
test_app = _build_app()
|
||||||
|
async with AsyncClient(
|
||||||
|
transport=ASGITransport(app=test_app), base_url="http://test"
|
||||||
|
) as client:
|
||||||
|
resp = await client.get("/api/v1/replay/../../etc/passwd")
|
||||||
|
|
||||||
|
# FastAPI may return 400 from our handler or 404 from routing
|
||||||
|
assert resp.status_code in (400, 404, 422)
|
||||||
|
|
||||||
|
async def test_nonexistent_thread_returns_404(self) -> None:
|
||||||
|
"""Replay with a thread_id that has no checkpoints returns 404."""
|
||||||
|
cursor = _make_fake_cursor([])
|
||||||
|
conn = _FakeConnection([cursor])
|
||||||
|
pool = _FakePool(conn)
|
||||||
|
test_app = _build_app(pool)
|
||||||
|
|
||||||
|
async with AsyncClient(
|
||||||
|
transport=ASGITransport(app=test_app), base_url="http://test"
|
||||||
|
) as client:
|
||||||
|
resp = await client.get("/api/v1/replay/nonexistent-thread")
|
||||||
|
|
||||||
|
assert resp.status_code == 404
|
||||||
|
body = resp.json()
|
||||||
|
assert body["success"] is False
|
||||||
|
assert "not found" in body["error"].lower()
|
||||||
@@ -18,10 +18,12 @@ from unittest.mock import AsyncMock, MagicMock
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from app.callbacks import TokenUsageCallbackHandler
|
from app.callbacks import TokenUsageCallbackHandler
|
||||||
|
from app.graph_context import GraphContext
|
||||||
from app.intent import ClassificationResult, IntentTarget, LLMIntentClassifier
|
from app.intent import ClassificationResult, IntentTarget, LLMIntentClassifier
|
||||||
from app.interrupt_manager import InterruptManager
|
from app.interrupt_manager import InterruptManager
|
||||||
from app.registry import AgentConfig
|
from app.registry import AgentConfig
|
||||||
from app.session_manager import SessionManager
|
from app.session_manager import SessionManager
|
||||||
|
from app.ws_context import WebSocketContext
|
||||||
from app.ws_handler import dispatch_message
|
from app.ws_handler import dispatch_message
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -103,36 +105,45 @@ def _make_classifier(result: ClassificationResult) -> AsyncMock:
|
|||||||
return classifier
|
return classifier
|
||||||
|
|
||||||
|
|
||||||
def _make_graph(
|
def _make_graph_and_ctx(
|
||||||
classifier_result: ClassificationResult | None,
|
classifier_result: ClassificationResult | None,
|
||||||
chunks: list,
|
chunks: list,
|
||||||
state=None,
|
state=None,
|
||||||
) -> MagicMock:
|
) -> tuple[MagicMock, GraphContext]:
|
||||||
"""Build a graph mock with optional intent classifier."""
|
"""Build a graph mock and GraphContext with optional intent classifier."""
|
||||||
graph = MagicMock()
|
graph = MagicMock()
|
||||||
|
|
||||||
if classifier_result is not None:
|
|
||||||
graph.intent_classifier = _make_classifier(classifier_result)
|
|
||||||
mock_registry = MagicMock()
|
|
||||||
mock_registry.list_agents = MagicMock(return_value=AGENTS)
|
|
||||||
graph.agent_registry = mock_registry
|
|
||||||
else:
|
|
||||||
graph.intent_classifier = None
|
|
||||||
graph.agent_registry = None
|
|
||||||
|
|
||||||
graph.astream = MagicMock(return_value=AsyncIterHelper(list(chunks)))
|
graph.astream = MagicMock(return_value=AsyncIterHelper(list(chunks)))
|
||||||
graph.aget_state = AsyncMock(return_value=state or _state())
|
graph.aget_state = AsyncMock(return_value=state or _state())
|
||||||
return graph
|
|
||||||
|
if classifier_result is not None:
|
||||||
|
classifier = _make_classifier(classifier_result)
|
||||||
|
mock_registry = MagicMock()
|
||||||
|
mock_registry.list_agents = MagicMock(return_value=AGENTS)
|
||||||
|
graph_ctx = GraphContext(
|
||||||
|
graph=graph, registry=mock_registry, intent_classifier=classifier,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
mock_registry = MagicMock()
|
||||||
|
mock_registry.list_agents = MagicMock(return_value=())
|
||||||
|
graph_ctx = GraphContext(
|
||||||
|
graph=graph, registry=mock_registry, intent_classifier=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
return graph, graph_ctx
|
||||||
|
|
||||||
|
|
||||||
async def _dispatch(graph, content: str, thread_id: str = "t1") -> list[dict]:
|
async def _dispatch(graph_ctx: GraphContext, content: str, thread_id: str = "t1") -> list[dict]:
|
||||||
sm = SessionManager()
|
sm = SessionManager()
|
||||||
sm.touch(thread_id)
|
sm.touch(thread_id)
|
||||||
im = InterruptManager()
|
im = InterruptManager()
|
||||||
cb = TokenUsageCallbackHandler()
|
cb = TokenUsageCallbackHandler()
|
||||||
ws = FakeWS()
|
ws = FakeWS()
|
||||||
|
ws_ctx = WebSocketContext(
|
||||||
|
graph_ctx=graph_ctx, session_manager=sm,
|
||||||
|
callback_handler=cb, interrupt_manager=im,
|
||||||
|
)
|
||||||
raw = json.dumps({"type": "message", "thread_id": thread_id, "content": content})
|
raw = json.dumps({"type": "message", "thread_id": thread_id, "content": content})
|
||||||
await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im)
|
await dispatch_message(ws, ws_ctx, raw)
|
||||||
return ws.sent
|
return ws.sent
|
||||||
|
|
||||||
|
|
||||||
@@ -151,12 +162,12 @@ class TestSingleIntentRouting:
|
|||||||
agent_name="order_lookup", confidence=0.95, reasoning="status query",
|
agent_name="order_lookup", confidence=0.95, reasoning="status query",
|
||||||
),),
|
),),
|
||||||
)
|
)
|
||||||
graph = _make_graph(result, [
|
graph, graph_ctx = _make_graph_and_ctx(result, [
|
||||||
_tool_chunk("get_order_status", {"order_id": "1042"}, "order_lookup"),
|
_tool_chunk("get_order_status", {"order_id": "1042"}, "order_lookup"),
|
||||||
_chunk("Order 1042 is shipped.", "order_lookup"),
|
_chunk("Order 1042 is shipped.", "order_lookup"),
|
||||||
])
|
])
|
||||||
|
|
||||||
msgs = await _dispatch(graph, "What is the status of order 1042?")
|
msgs = await _dispatch(graph_ctx, "What is the status of order 1042?")
|
||||||
|
|
||||||
tools = [m for m in msgs if m["type"] == "tool_call"]
|
tools = [m for m in msgs if m["type"] == "tool_call"]
|
||||||
assert len(tools) == 1
|
assert len(tools) == 1
|
||||||
@@ -171,13 +182,13 @@ class TestSingleIntentRouting:
|
|||||||
result = ClassificationResult(
|
result = ClassificationResult(
|
||||||
intents=(IntentTarget(agent_name="order_actions", confidence=0.9, reasoning="cancel"),),
|
intents=(IntentTarget(agent_name="order_actions", confidence=0.9, reasoning="cancel"),),
|
||||||
)
|
)
|
||||||
graph = _make_graph(
|
graph, graph_ctx = _make_graph_and_ctx(
|
||||||
result,
|
result,
|
||||||
[_tool_chunk("cancel_order", {"order_id": "1042"}, "order_actions")],
|
[_tool_chunk("cancel_order", {"order_id": "1042"}, "order_actions")],
|
||||||
state=_state(interrupt=True, data={"action": "cancel_order", "order_id": "1042"}),
|
state=_state(interrupt=True, data={"action": "cancel_order", "order_id": "1042"}),
|
||||||
)
|
)
|
||||||
|
|
||||||
msgs = await _dispatch(graph, "Cancel order 1042")
|
msgs = await _dispatch(graph_ctx, "Cancel order 1042")
|
||||||
|
|
||||||
tools = [m for m in msgs if m["type"] == "tool_call"]
|
tools = [m for m in msgs if m["type"] == "tool_call"]
|
||||||
assert tools[0]["tool"] == "cancel_order"
|
assert tools[0]["tool"] == "cancel_order"
|
||||||
@@ -191,12 +202,12 @@ class TestSingleIntentRouting:
|
|||||||
result = ClassificationResult(
|
result = ClassificationResult(
|
||||||
intents=(IntentTarget(agent_name="discount", confidence=0.88, reasoning="coupon"),),
|
intents=(IntentTarget(agent_name="discount", confidence=0.88, reasoning="coupon"),),
|
||||||
)
|
)
|
||||||
graph = _make_graph(result, [
|
graph, graph_ctx = _make_graph_and_ctx(result, [
|
||||||
_tool_chunk("generate_coupon", {"discount_percent": 15, "expiry_days": 30}, "discount"),
|
_tool_chunk("generate_coupon", {"discount_percent": 15, "expiry_days": 30}, "discount"),
|
||||||
_chunk("Here is your coupon: SAVE15-ABC12345", "discount"),
|
_chunk("Here is your coupon: SAVE15-ABC12345", "discount"),
|
||||||
])
|
])
|
||||||
|
|
||||||
msgs = await _dispatch(graph, "Give me a 15% coupon")
|
msgs = await _dispatch(graph_ctx, "Give me a 15% coupon")
|
||||||
|
|
||||||
tools = [m for m in msgs if m["type"] == "tool_call"]
|
tools = [m for m in msgs if m["type"] == "tool_call"]
|
||||||
assert tools[0]["tool"] == "generate_coupon"
|
assert tools[0]["tool"] == "generate_coupon"
|
||||||
@@ -207,11 +218,11 @@ class TestSingleIntentRouting:
|
|||||||
result = ClassificationResult(
|
result = ClassificationResult(
|
||||||
intents=(IntentTarget(agent_name="fallback", confidence=0.7, reasoning="general"),),
|
intents=(IntentTarget(agent_name="fallback", confidence=0.7, reasoning="general"),),
|
||||||
)
|
)
|
||||||
graph = _make_graph(result, [
|
graph, graph_ctx = _make_graph_and_ctx(result, [
|
||||||
_chunk("I can help with order inquiries.", "fallback"),
|
_chunk("I can help with order inquiries.", "fallback"),
|
||||||
])
|
])
|
||||||
|
|
||||||
msgs = await _dispatch(graph, "What can you do?")
|
msgs = await _dispatch(graph_ctx, "What can you do?")
|
||||||
|
|
||||||
tokens = [m for m in msgs if m["type"] == "token"]
|
tokens = [m for m in msgs if m["type"] == "token"]
|
||||||
assert tokens[0]["agent"] == "fallback"
|
assert tokens[0]["agent"] == "fallback"
|
||||||
@@ -233,7 +244,7 @@ class TestMultiIntentRouting:
|
|||||||
IntentTarget(agent_name="discount", confidence=0.85, reasoning="discount"),
|
IntentTarget(agent_name="discount", confidence=0.85, reasoning="discount"),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
graph = _make_graph(result, [
|
graph, graph_ctx = _make_graph_and_ctx(result, [
|
||||||
_tool_chunk("cancel_order", {"order_id": "1042"}, "order_actions"),
|
_tool_chunk("cancel_order", {"order_id": "1042"}, "order_actions"),
|
||||||
_tool_chunk("apply_discount", {"order_id": "1042", "discount_percent": 10}, "discount"),
|
_tool_chunk("apply_discount", {"order_id": "1042", "discount_percent": 10}, "discount"),
|
||||||
])
|
])
|
||||||
@@ -243,13 +254,17 @@ class TestMultiIntentRouting:
|
|||||||
im = InterruptManager()
|
im = InterruptManager()
|
||||||
cb = TokenUsageCallbackHandler()
|
cb = TokenUsageCallbackHandler()
|
||||||
ws = FakeWS()
|
ws = FakeWS()
|
||||||
|
ws_ctx = WebSocketContext(
|
||||||
|
graph_ctx=graph_ctx, session_manager=sm,
|
||||||
|
callback_handler=cb, interrupt_manager=im,
|
||||||
|
)
|
||||||
|
|
||||||
raw = json.dumps({
|
raw = json.dumps({
|
||||||
"type": "message",
|
"type": "message",
|
||||||
"thread_id": "t1",
|
"thread_id": "t1",
|
||||||
"content": "取消订单 1042 并给我一个 10% 折扣",
|
"content": "取消订单 1042 并给我一个 10% 折扣",
|
||||||
})
|
})
|
||||||
await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im)
|
await dispatch_message(ws, ws_ctx, raw)
|
||||||
|
|
||||||
# Verify routing hint was injected
|
# Verify routing hint was injected
|
||||||
call_args = graph.astream.call_args[0][0]
|
call_args = graph.astream.call_args[0][0]
|
||||||
@@ -269,16 +284,20 @@ class TestMultiIntentRouting:
|
|||||||
result = ClassificationResult(
|
result = ClassificationResult(
|
||||||
intents=(IntentTarget(agent_name="order_lookup", confidence=0.95, reasoning="query"),),
|
intents=(IntentTarget(agent_name="order_lookup", confidence=0.95, reasoning="query"),),
|
||||||
)
|
)
|
||||||
graph = _make_graph(result, [_chunk("Order shipped.", "order_lookup")])
|
graph, graph_ctx = _make_graph_and_ctx(result, [_chunk("Order shipped.", "order_lookup")])
|
||||||
|
|
||||||
sm = SessionManager()
|
sm = SessionManager()
|
||||||
sm.touch("t1")
|
sm.touch("t1")
|
||||||
im = InterruptManager()
|
im = InterruptManager()
|
||||||
cb = TokenUsageCallbackHandler()
|
cb = TokenUsageCallbackHandler()
|
||||||
ws = FakeWS()
|
ws = FakeWS()
|
||||||
|
ws_ctx = WebSocketContext(
|
||||||
|
graph_ctx=graph_ctx, session_manager=sm,
|
||||||
|
callback_handler=cb, interrupt_manager=im,
|
||||||
|
)
|
||||||
|
|
||||||
raw = json.dumps({"type": "message", "thread_id": "t1", "content": "查询订单 1042"})
|
raw = json.dumps({"type": "message", "thread_id": "t1", "content": "查询订单 1042"})
|
||||||
await dispatch_message(ws, graph, sm, cb, raw, interrupt_manager=im)
|
await dispatch_message(ws, ws_ctx, raw)
|
||||||
|
|
||||||
msg_content = graph.astream.call_args[0][0]["messages"][0].content
|
msg_content = graph.astream.call_args[0][0]["messages"][0].content
|
||||||
assert "[System:" not in msg_content
|
assert "[System:" not in msg_content
|
||||||
@@ -299,9 +318,9 @@ class TestAmbiguityRouting:
|
|||||||
is_ambiguous=True,
|
is_ambiguous=True,
|
||||||
clarification_question="Could you please clarify what you need?",
|
clarification_question="Could you please clarify what you need?",
|
||||||
)
|
)
|
||||||
graph = _make_graph(result, [])
|
graph, graph_ctx = _make_graph_and_ctx(result, [])
|
||||||
|
|
||||||
msgs = await _dispatch(graph, "嗯...")
|
msgs = await _dispatch(graph_ctx, "嗯...")
|
||||||
|
|
||||||
clarifications = [m for m in msgs if m["type"] == "clarification"]
|
clarifications = [m for m in msgs if m["type"] == "clarification"]
|
||||||
assert len(clarifications) == 1
|
assert len(clarifications) == 1
|
||||||
@@ -339,12 +358,12 @@ class TestNoClassifierFallback:
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_no_classifier_routes_via_supervisor(self) -> None:
|
async def test_no_classifier_routes_via_supervisor(self) -> None:
|
||||||
graph = _make_graph(
|
graph, graph_ctx = _make_graph_and_ctx(
|
||||||
classifier_result=None,
|
classifier_result=None,
|
||||||
chunks=[_chunk("Order 1042 is shipped.", "order_lookup")],
|
chunks=[_chunk("Order 1042 is shipped.", "order_lookup")],
|
||||||
)
|
)
|
||||||
|
|
||||||
msgs = await _dispatch(graph, "What is order 1042 status?")
|
msgs = await _dispatch(graph_ctx, "What is order 1042 status?")
|
||||||
|
|
||||||
tokens = [m for m in msgs if m["type"] == "token"]
|
tokens = [m for m in msgs if m["type"] == "token"]
|
||||||
assert len(tokens) == 1
|
assert len(tokens) == 1
|
||||||
|
|||||||
159
backend/tests/integration/test_session_interrupt_lifecycle.py
Normal file
159
backend/tests/integration/test_session_interrupt_lifecycle.py
Normal file
@@ -0,0 +1,159 @@
|
|||||||
|
"""Integration tests for SessionManager + InterruptManager lifecycle.
|
||||||
|
|
||||||
|
These tests exercise the in-memory managers together, verifying the full
|
||||||
|
lifecycle of sessions and interrupts: creation, TTL sliding, interrupt
|
||||||
|
registration/resolution, and expired-interrupt cleanup.
|
||||||
|
|
||||||
|
No database required -- both managers are in-memory.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.interrupt_manager import InterruptManager
|
||||||
|
from app.session_manager import SessionManager
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.integration
|
||||||
|
|
||||||
|
|
||||||
|
class TestSessionInterruptLifecycle:
|
||||||
|
"""Tests for the combined session + interrupt lifecycle."""
|
||||||
|
|
||||||
|
def test_create_session_register_interrupt_check_status(self) -> None:
|
||||||
|
"""Full lifecycle: create session, register interrupt, verify both states."""
|
||||||
|
sm = SessionManager(session_ttl_seconds=3600)
|
||||||
|
im = InterruptManager(ttl_seconds=300)
|
||||||
|
|
||||||
|
# Create a session
|
||||||
|
state = sm.touch("thread-1")
|
||||||
|
assert state.thread_id == "thread-1"
|
||||||
|
assert not state.has_pending_interrupt
|
||||||
|
assert not sm.is_expired("thread-1")
|
||||||
|
|
||||||
|
# Register an interrupt
|
||||||
|
record = im.register("thread-1", "cancel_order", {"order_id": "1042"})
|
||||||
|
sm.extend_for_interrupt("thread-1")
|
||||||
|
|
||||||
|
assert im.has_pending("thread-1")
|
||||||
|
session_state = sm.get_state("thread-1")
|
||||||
|
assert session_state is not None
|
||||||
|
assert session_state.has_pending_interrupt
|
||||||
|
|
||||||
|
# Session should not expire while interrupt is pending
|
||||||
|
assert not sm.is_expired("thread-1")
|
||||||
|
|
||||||
|
def test_interrupt_expiry_after_ttl(self) -> None:
|
||||||
|
"""Interrupt expires when TTL elapses, even if session is alive."""
|
||||||
|
im = InterruptManager(ttl_seconds=5)
|
||||||
|
|
||||||
|
record = im.register("thread-2", "refund", {"amount": 50})
|
||||||
|
assert im.has_pending("thread-2")
|
||||||
|
|
||||||
|
# Simulate time passing beyond TTL
|
||||||
|
with patch("app.interrupt_manager.time") as mock_time:
|
||||||
|
mock_time.time.return_value = record.created_at + 10
|
||||||
|
assert not im.has_pending("thread-2")
|
||||||
|
|
||||||
|
status = im.check_status("thread-2")
|
||||||
|
assert status is not None
|
||||||
|
assert status.is_expired
|
||||||
|
assert status.remaining_seconds == 0.0
|
||||||
|
|
||||||
|
def test_interrupt_resolve_flow(self) -> None:
|
||||||
|
"""Resolving an interrupt removes it from pending and resets session."""
|
||||||
|
sm = SessionManager(session_ttl_seconds=3600)
|
||||||
|
im = InterruptManager(ttl_seconds=300)
|
||||||
|
|
||||||
|
sm.touch("thread-3")
|
||||||
|
im.register("thread-3", "delete_account", {"user_id": "u1"})
|
||||||
|
sm.extend_for_interrupt("thread-3")
|
||||||
|
|
||||||
|
# Verify pending state
|
||||||
|
assert im.has_pending("thread-3")
|
||||||
|
assert sm.get_state("thread-3").has_pending_interrupt
|
||||||
|
|
||||||
|
# Resolve
|
||||||
|
im.resolve("thread-3")
|
||||||
|
sm.resolve_interrupt("thread-3")
|
||||||
|
|
||||||
|
assert not im.has_pending("thread-3")
|
||||||
|
session_state = sm.get_state("thread-3")
|
||||||
|
assert session_state is not None
|
||||||
|
assert not session_state.has_pending_interrupt
|
||||||
|
|
||||||
|
def test_cleanup_expired_removes_old_interrupts(self) -> None:
|
||||||
|
"""cleanup_expired removes only expired interrupts, keeping active ones."""
|
||||||
|
im = InterruptManager(ttl_seconds=10)
|
||||||
|
|
||||||
|
# Register two interrupts at different times
|
||||||
|
old_record = im.register("thread-old", "action_old", {})
|
||||||
|
new_record = im.register("thread-new", "action_new", {})
|
||||||
|
|
||||||
|
# Simulate time where only old one expired
|
||||||
|
with patch("app.interrupt_manager.time") as mock_time:
|
||||||
|
# Move old record's creation to the past
|
||||||
|
im._interrupts["thread-old"] = old_record.__class__(
|
||||||
|
interrupt_id=old_record.interrupt_id,
|
||||||
|
thread_id=old_record.thread_id,
|
||||||
|
action=old_record.action,
|
||||||
|
params=old_record.params,
|
||||||
|
created_at=time.time() - 20,
|
||||||
|
ttl_seconds=old_record.ttl_seconds,
|
||||||
|
)
|
||||||
|
mock_time.time.return_value = time.time()
|
||||||
|
|
||||||
|
expired = im.cleanup_expired()
|
||||||
|
assert len(expired) == 1
|
||||||
|
assert expired[0].thread_id == "thread-old"
|
||||||
|
|
||||||
|
# New one should still be pending
|
||||||
|
assert im.has_pending("thread-new")
|
||||||
|
assert not im.has_pending("thread-old")
|
||||||
|
|
||||||
|
def test_session_ttl_sliding_window(self) -> None:
|
||||||
|
"""Touching a session resets the sliding window TTL."""
|
||||||
|
sm = SessionManager(session_ttl_seconds=3600)
|
||||||
|
|
||||||
|
state1 = sm.touch("thread-5")
|
||||||
|
first_activity = state1.last_activity
|
||||||
|
|
||||||
|
time.sleep(0.01)
|
||||||
|
state2 = sm.touch("thread-5")
|
||||||
|
second_activity = state2.last_activity
|
||||||
|
|
||||||
|
assert second_activity > first_activity
|
||||||
|
assert not sm.is_expired("thread-5")
|
||||||
|
|
||||||
|
def test_session_expires_after_ttl_without_activity(self) -> None:
|
||||||
|
"""Session expires when TTL passes without a touch or interrupt."""
|
||||||
|
sm = SessionManager(session_ttl_seconds=0)
|
||||||
|
sm.touch("thread-6")
|
||||||
|
|
||||||
|
# TTL is 0 so session is immediately expired
|
||||||
|
assert sm.is_expired("thread-6")
|
||||||
|
|
||||||
|
def test_pending_interrupt_prevents_session_expiry(self) -> None:
|
||||||
|
"""A session with pending interrupt does not expire even with TTL=0."""
|
||||||
|
sm = SessionManager(session_ttl_seconds=0)
|
||||||
|
sm.touch("thread-7")
|
||||||
|
sm.extend_for_interrupt("thread-7")
|
||||||
|
|
||||||
|
# Even with TTL=0, session should not expire because of pending interrupt
|
||||||
|
assert not sm.is_expired("thread-7")
|
||||||
|
|
||||||
|
def test_retry_prompt_for_expired_interrupt(self) -> None:
|
||||||
|
"""InterruptManager generates a retry prompt for expired interrupts."""
|
||||||
|
im = InterruptManager(ttl_seconds=300)
|
||||||
|
record = im.register("thread-8", "cancel_order", {"order_id": "1042"})
|
||||||
|
|
||||||
|
prompt = im.generate_retry_prompt(record)
|
||||||
|
|
||||||
|
assert prompt["type"] == "interrupt_expired"
|
||||||
|
assert prompt["thread_id"] == "thread-8"
|
||||||
|
assert "cancel_order" in prompt["action"]
|
||||||
|
assert "cancel_order" in prompt["message"]
|
||||||
|
assert "expired" in prompt["message"].lower()
|
||||||
@@ -15,8 +15,10 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from app.callbacks import TokenUsageCallbackHandler
|
from app.callbacks import TokenUsageCallbackHandler
|
||||||
|
from app.graph_context import GraphContext
|
||||||
from app.interrupt_manager import InterruptManager
|
from app.interrupt_manager import InterruptManager
|
||||||
from app.session_manager import SessionManager
|
from app.session_manager import SessionManager
|
||||||
|
from app.ws_context import WebSocketContext
|
||||||
from app.ws_handler import dispatch_message
|
from app.ws_handler import dispatch_message
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -81,8 +83,6 @@ def _graph(
|
|||||||
resume_chunks: list | None = None,
|
resume_chunks: list | None = None,
|
||||||
) -> MagicMock:
|
) -> MagicMock:
|
||||||
g = MagicMock()
|
g = MagicMock()
|
||||||
g.intent_classifier = None
|
|
||||||
g.agent_registry = None
|
|
||||||
|
|
||||||
if st is None:
|
if st is None:
|
||||||
st = _state()
|
st = _state()
|
||||||
@@ -100,6 +100,13 @@ def _graph(
|
|||||||
return g
|
return g
|
||||||
|
|
||||||
|
|
||||||
|
def _make_graph_ctx(graph: MagicMock | None = None) -> GraphContext:
|
||||||
|
g = graph or _graph()
|
||||||
|
registry = MagicMock()
|
||||||
|
registry.list_agents = MagicMock(return_value=())
|
||||||
|
return GraphContext(graph=g, registry=registry, intent_classifier=None)
|
||||||
|
|
||||||
|
|
||||||
def _setup(
|
def _setup(
|
||||||
graph=None,
|
graph=None,
|
||||||
session_ttl: int = 1800,
|
session_ttl: int = 1800,
|
||||||
@@ -109,23 +116,28 @@ def _setup(
|
|||||||
):
|
):
|
||||||
"""Create test dependencies. Pre-touches session by default."""
|
"""Create test dependencies. Pre-touches session by default."""
|
||||||
g = graph or _graph()
|
g = graph or _graph()
|
||||||
|
graph_ctx = _make_graph_ctx(g)
|
||||||
sm = SessionManager(session_ttl_seconds=session_ttl)
|
sm = SessionManager(session_ttl_seconds=session_ttl)
|
||||||
im = InterruptManager(ttl_seconds=interrupt_ttl)
|
im = InterruptManager(ttl_seconds=interrupt_ttl)
|
||||||
cb = TokenUsageCallbackHandler()
|
cb = TokenUsageCallbackHandler()
|
||||||
ws = FakeWS()
|
ws = FakeWS()
|
||||||
|
ws_ctx = WebSocketContext(
|
||||||
|
graph_ctx=graph_ctx, session_manager=sm,
|
||||||
|
callback_handler=cb, interrupt_manager=im,
|
||||||
|
)
|
||||||
if touch:
|
if touch:
|
||||||
sm.touch(thread_id)
|
sm.touch(thread_id)
|
||||||
return g, sm, im, cb, ws
|
return g, sm, im, cb, ws, ws_ctx
|
||||||
|
|
||||||
|
|
||||||
async def _send(ws, g, sm, im, cb, *, thread_id="t1", content="hello", msg_type="message"):
|
async def _send(ws, ws_ctx, *, thread_id="t1", content="hello", msg_type="message"):
|
||||||
raw = json.dumps({"type": msg_type, "thread_id": thread_id, "content": content})
|
raw = json.dumps({"type": msg_type, "thread_id": thread_id, "content": content})
|
||||||
await dispatch_message(ws, g, sm, cb, raw, interrupt_manager=im)
|
await dispatch_message(ws, ws_ctx, raw)
|
||||||
|
|
||||||
|
|
||||||
async def _respond(ws, g, sm, im, cb, *, thread_id="t1", approved=True):
|
async def _respond(ws, ws_ctx, *, thread_id="t1", approved=True):
|
||||||
raw = json.dumps({"type": "interrupt_response", "thread_id": thread_id, "approved": approved})
|
raw = json.dumps({"type": "interrupt_response", "thread_id": thread_id, "approved": approved})
|
||||||
await dispatch_message(ws, g, sm, cb, raw, interrupt_manager=im)
|
await dispatch_message(ws, ws_ctx, raw)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -136,10 +148,10 @@ async def _respond(ws, g, sm, im, cb, *, thread_id="t1", approved=True):
|
|||||||
class TestWebSocketHappyPath:
|
class TestWebSocketHappyPath:
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_send_message_receives_tokens_and_complete(self) -> None:
|
async def test_send_message_receives_tokens_and_complete(self) -> None:
|
||||||
g, sm, im, cb, ws = _setup(
|
g, sm, im, cb, ws, ws_ctx = _setup(
|
||||||
graph=_graph(chunks=[_chunk("Order 1042 is "), _chunk("shipped.")])
|
graph=_graph(chunks=[_chunk("Order 1042 is "), _chunk("shipped.")])
|
||||||
)
|
)
|
||||||
await _send(ws, g, sm, im, cb, content="What is the status of order 1042?")
|
await _send(ws, ws_ctx, content="What is the status of order 1042?")
|
||||||
|
|
||||||
tokens = [m for m in ws.sent if m["type"] == "token"]
|
tokens = [m for m in ws.sent if m["type"] == "token"]
|
||||||
assert len(tokens) == 2
|
assert len(tokens) == 2
|
||||||
@@ -153,13 +165,13 @@ class TestWebSocketHappyPath:
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_tool_call_streamed(self) -> None:
|
async def test_tool_call_streamed(self) -> None:
|
||||||
g, sm, im, cb, ws = _setup(
|
g, sm, im, cb, ws, ws_ctx = _setup(
|
||||||
graph=_graph(chunks=[
|
graph=_graph(chunks=[
|
||||||
_tool_chunk("get_order_status", {"order_id": "1042"}),
|
_tool_chunk("get_order_status", {"order_id": "1042"}),
|
||||||
_chunk("Order shipped."),
|
_chunk("Order shipped."),
|
||||||
])
|
])
|
||||||
)
|
)
|
||||||
await _send(ws, g, sm, im, cb, content="Check order 1042")
|
await _send(ws, ws_ctx, content="Check order 1042")
|
||||||
|
|
||||||
tools = [m for m in ws.sent if m["type"] == "tool_call"]
|
tools = [m for m in ws.sent if m["type"] == "tool_call"]
|
||||||
assert len(tools) == 1
|
assert len(tools) == 1
|
||||||
@@ -168,9 +180,9 @@ class TestWebSocketHappyPath:
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_multiple_messages_same_session(self) -> None:
|
async def test_multiple_messages_same_session(self) -> None:
|
||||||
g, sm, im, cb, ws = _setup()
|
g, sm, im, cb, ws, ws_ctx = _setup()
|
||||||
for i in range(3):
|
for i in range(3):
|
||||||
await _send(ws, g, sm, im, cb, content=f"msg {i}")
|
await _send(ws, ws_ctx, content=f"msg {i}")
|
||||||
|
|
||||||
completes = [m for m in ws.sent if m["type"] == "message_complete"]
|
completes = [m for m in ws.sent if m["type"] == "message_complete"]
|
||||||
assert len(completes) == 3
|
assert len(completes) == 3
|
||||||
@@ -183,10 +195,10 @@ class TestWebSocketInterruptApproval:
|
|||||||
st_int = _state(interrupt=True, data={"action": "cancel_order", "order_id": "1042"})
|
st_int = _state(interrupt=True, data={"action": "cancel_order", "order_id": "1042"})
|
||||||
resume = [_chunk("Order 1042 cancelled.", "order_actions")]
|
resume = [_chunk("Order 1042 cancelled.", "order_actions")]
|
||||||
g = _graph(chunks=[], st=st_int, resume_chunks=resume)
|
g = _graph(chunks=[], st=st_int, resume_chunks=resume)
|
||||||
g_, sm, im, cb, ws = _setup(graph=g)
|
g_, sm, im, cb, ws, ws_ctx = _setup(graph=g)
|
||||||
|
|
||||||
# Send message -> triggers interrupt
|
# Send message -> triggers interrupt
|
||||||
await _send(ws, g_, sm, im, cb, content="Cancel order 1042")
|
await _send(ws, ws_ctx, content="Cancel order 1042")
|
||||||
|
|
||||||
interrupts = [m for m in ws.sent if m["type"] == "interrupt"]
|
interrupts = [m for m in ws.sent if m["type"] == "interrupt"]
|
||||||
assert len(interrupts) == 1
|
assert len(interrupts) == 1
|
||||||
@@ -196,7 +208,7 @@ class TestWebSocketInterruptApproval:
|
|||||||
|
|
||||||
# Approve
|
# Approve
|
||||||
ws.sent.clear()
|
ws.sent.clear()
|
||||||
await _respond(ws, g_, sm, im, cb, approved=True)
|
await _respond(ws, ws_ctx, approved=True)
|
||||||
|
|
||||||
tokens = [m for m in ws.sent if m["type"] == "token"]
|
tokens = [m for m in ws.sent if m["type"] == "token"]
|
||||||
assert len(tokens) == 1
|
assert len(tokens) == 1
|
||||||
@@ -211,12 +223,12 @@ class TestWebSocketInterruptApproval:
|
|||||||
st_int = _state(interrupt=True)
|
st_int = _state(interrupt=True)
|
||||||
resume = [_chunk("Order remains active.", "order_actions")]
|
resume = [_chunk("Order remains active.", "order_actions")]
|
||||||
g = _graph(chunks=[], st=st_int, resume_chunks=resume)
|
g = _graph(chunks=[], st=st_int, resume_chunks=resume)
|
||||||
g_, sm, im, cb, ws = _setup(graph=g)
|
g_, sm, im, cb, ws, ws_ctx = _setup(graph=g)
|
||||||
|
|
||||||
await _send(ws, g_, sm, im, cb, content="Cancel order 1042")
|
await _send(ws, ws_ctx, content="Cancel order 1042")
|
||||||
ws.sent.clear()
|
ws.sent.clear()
|
||||||
|
|
||||||
await _respond(ws, g_, sm, im, cb, approved=False)
|
await _respond(ws, ws_ctx, approved=False)
|
||||||
|
|
||||||
tokens = [m for m in ws.sent if m["type"] == "token"]
|
tokens = [m for m in ws.sent if m["type"] == "token"]
|
||||||
assert "remains active" in tokens[0]["content"]
|
assert "remains active" in tokens[0]["content"]
|
||||||
@@ -226,28 +238,28 @@ class TestWebSocketInterruptApproval:
|
|||||||
class TestWebSocketSessionTTL:
|
class TestWebSocketSessionTTL:
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_expired_session_returns_error(self) -> None:
|
async def test_expired_session_returns_error(self) -> None:
|
||||||
g, sm, im, cb, ws = _setup(session_ttl=0)
|
g, sm, im, cb, ws, ws_ctx = _setup(session_ttl=0)
|
||||||
# Session was touched in _setup, but TTL is 0 so it's already expired
|
# Session was touched in _setup, but TTL is 0 so it's already expired
|
||||||
await _send(ws, g, sm, im, cb, content="hello")
|
await _send(ws, ws_ctx, content="hello")
|
||||||
assert ws.sent[0]["type"] == "error"
|
assert ws.sent[0]["type"] == "error"
|
||||||
assert "expired" in ws.sent[0]["message"].lower()
|
assert "expired" in ws.sent[0]["message"].lower()
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_new_session_not_expired(self) -> None:
|
async def test_new_session_not_expired(self) -> None:
|
||||||
g, sm, im, cb, ws = _setup(session_ttl=3600)
|
g, sm, im, cb, ws, ws_ctx = _setup(session_ttl=3600)
|
||||||
await _send(ws, g, sm, im, cb, content="hello")
|
await _send(ws, ws_ctx, content="hello")
|
||||||
completes = [m for m in ws.sent if m["type"] == "message_complete"]
|
completes = [m for m in ws.sent if m["type"] == "message_complete"]
|
||||||
assert len(completes) == 1
|
assert len(completes) == 1
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_sliding_window_resets_on_message(self) -> None:
|
async def test_sliding_window_resets_on_message(self) -> None:
|
||||||
g, sm, im, cb, ws = _setup(session_ttl=3600)
|
g, sm, im, cb, ws, ws_ctx = _setup(session_ttl=3600)
|
||||||
|
|
||||||
await _send(ws, g, sm, im, cb, content="hello")
|
await _send(ws, ws_ctx, content="hello")
|
||||||
first_activity = sm.get_state("t1").last_activity
|
first_activity = sm.get_state("t1").last_activity
|
||||||
|
|
||||||
time.sleep(0.01)
|
time.sleep(0.01)
|
||||||
await _send(ws, g, sm, im, cb, content="hello again")
|
await _send(ws, ws_ctx, content="hello again")
|
||||||
second_activity = sm.get_state("t1").last_activity
|
second_activity = sm.get_state("t1").last_activity
|
||||||
|
|
||||||
assert second_activity > first_activity
|
assert second_activity > first_activity
|
||||||
@@ -256,9 +268,9 @@ class TestWebSocketSessionTTL:
|
|||||||
async def test_interrupt_extends_session_ttl(self) -> None:
|
async def test_interrupt_extends_session_ttl(self) -> None:
|
||||||
st_int = _state(interrupt=True)
|
st_int = _state(interrupt=True)
|
||||||
g = _graph(chunks=[], st=st_int)
|
g = _graph(chunks=[], st=st_int)
|
||||||
g_, sm, im, cb, ws = _setup(graph=g, session_ttl=3600)
|
g_, sm, im, cb, ws, ws_ctx = _setup(graph=g, session_ttl=3600)
|
||||||
|
|
||||||
await _send(ws, g_, sm, im, cb, content="cancel order")
|
await _send(ws, ws_ctx, content="cancel order")
|
||||||
|
|
||||||
state = sm.get_state("t1")
|
state = sm.get_state("t1")
|
||||||
assert state is not None
|
assert state is not None
|
||||||
@@ -270,53 +282,53 @@ class TestWebSocketSessionTTL:
|
|||||||
class TestWebSocketValidation:
|
class TestWebSocketValidation:
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_invalid_json(self) -> None:
|
async def test_invalid_json(self) -> None:
|
||||||
g, sm, im, cb, ws = _setup()
|
g, sm, im, cb, ws, ws_ctx = _setup()
|
||||||
await dispatch_message(ws, g, sm, cb, "not json", interrupt_manager=im)
|
await dispatch_message(ws, ws_ctx, "not json")
|
||||||
assert ws.sent[0]["type"] == "error"
|
assert ws.sent[0]["type"] == "error"
|
||||||
assert "Invalid JSON" in ws.sent[0]["message"]
|
assert "Invalid JSON" in ws.sent[0]["message"]
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_missing_thread_id(self) -> None:
|
async def test_missing_thread_id(self) -> None:
|
||||||
g, sm, im, cb, ws = _setup()
|
g, sm, im, cb, ws, ws_ctx = _setup()
|
||||||
raw = json.dumps({"type": "message", "content": "hi"})
|
raw = json.dumps({"type": "message", "content": "hi"})
|
||||||
await dispatch_message(ws, g, sm, cb, raw, interrupt_manager=im)
|
await dispatch_message(ws, ws_ctx, raw)
|
||||||
assert ws.sent[0]["type"] == "error"
|
assert ws.sent[0]["type"] == "error"
|
||||||
assert "thread_id" in ws.sent[0]["message"]
|
assert "thread_id" in ws.sent[0]["message"]
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_invalid_thread_id_format(self) -> None:
|
async def test_invalid_thread_id_format(self) -> None:
|
||||||
g, sm, im, cb, ws = _setup()
|
g, sm, im, cb, ws, ws_ctx = _setup()
|
||||||
raw = json.dumps({"type": "message", "thread_id": "../../etc/passwd", "content": "hi"})
|
raw = json.dumps({"type": "message", "thread_id": "../../etc/passwd", "content": "hi"})
|
||||||
await dispatch_message(ws, g, sm, cb, raw, interrupt_manager=im)
|
await dispatch_message(ws, ws_ctx, raw)
|
||||||
assert ws.sent[0]["type"] == "error"
|
assert ws.sent[0]["type"] == "error"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_missing_content(self) -> None:
|
async def test_missing_content(self) -> None:
|
||||||
g, sm, im, cb, ws = _setup()
|
g, sm, im, cb, ws, ws_ctx = _setup()
|
||||||
raw = json.dumps({"type": "message", "thread_id": "t1"})
|
raw = json.dumps({"type": "message", "thread_id": "t1"})
|
||||||
await dispatch_message(ws, g, sm, cb, raw, interrupt_manager=im)
|
await dispatch_message(ws, ws_ctx, raw)
|
||||||
assert ws.sent[0]["type"] == "error"
|
assert ws.sent[0]["type"] == "error"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_unknown_message_type(self) -> None:
|
async def test_unknown_message_type(self) -> None:
|
||||||
g, sm, im, cb, ws = _setup()
|
g, sm, im, cb, ws, ws_ctx = _setup()
|
||||||
raw = json.dumps({"type": "foobar", "thread_id": "t1"})
|
raw = json.dumps({"type": "foobar", "thread_id": "t1"})
|
||||||
await dispatch_message(ws, g, sm, cb, raw, interrupt_manager=im)
|
await dispatch_message(ws, ws_ctx, raw)
|
||||||
assert ws.sent[0]["type"] == "error"
|
assert ws.sent[0]["type"] == "error"
|
||||||
assert "Unknown" in ws.sent[0]["message"]
|
assert "Unknown" in ws.sent[0]["message"]
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_message_too_large(self) -> None:
|
async def test_message_too_large(self) -> None:
|
||||||
g, sm, im, cb, ws = _setup()
|
g, sm, im, cb, ws, ws_ctx = _setup()
|
||||||
await dispatch_message(ws, g, sm, cb, "x" * 40_000, interrupt_manager=im)
|
await dispatch_message(ws, ws_ctx, "x" * 40_000)
|
||||||
assert ws.sent[0]["type"] == "error"
|
assert ws.sent[0]["type"] == "error"
|
||||||
assert "too large" in ws.sent[0]["message"].lower()
|
assert "too large" in ws.sent[0]["message"].lower()
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_content_too_long(self) -> None:
|
async def test_content_too_long(self) -> None:
|
||||||
g, sm, im, cb, ws = _setup()
|
g, sm, im, cb, ws, ws_ctx = _setup()
|
||||||
raw = json.dumps({"type": "message", "thread_id": "t1", "content": "x" * 10001})
|
raw = json.dumps({"type": "message", "thread_id": "t1", "content": "x" * 10001})
|
||||||
await dispatch_message(ws, g, sm, cb, raw, interrupt_manager=im)
|
await dispatch_message(ws, ws_ctx, raw)
|
||||||
assert ws.sent[0]["type"] == "error"
|
assert ws.sent[0]["type"] == "error"
|
||||||
assert "too long" in ws.sent[0]["message"].lower()
|
assert "too long" in ws.sent[0]["message"].lower()
|
||||||
|
|
||||||
@@ -327,10 +339,10 @@ class TestWebSocketInterruptTTL:
|
|||||||
async def test_expired_interrupt_sends_retry_prompt(self) -> None:
|
async def test_expired_interrupt_sends_retry_prompt(self) -> None:
|
||||||
st_int = _state(interrupt=True, data={"action": "cancel_order", "order_id": "1042"})
|
st_int = _state(interrupt=True, data={"action": "cancel_order", "order_id": "1042"})
|
||||||
g = _graph(chunks=[], st=st_int)
|
g = _graph(chunks=[], st=st_int)
|
||||||
g_, sm, im, cb, ws = _setup(graph=g, interrupt_ttl=5)
|
g_, sm, im, cb, ws, ws_ctx = _setup(graph=g, interrupt_ttl=5)
|
||||||
|
|
||||||
# Trigger interrupt
|
# Trigger interrupt
|
||||||
await _send(ws, g_, sm, im, cb, content="Cancel order 1042")
|
await _send(ws, ws_ctx, content="Cancel order 1042")
|
||||||
|
|
||||||
interrupts = [m for m in ws.sent if m["type"] == "interrupt"]
|
interrupts = [m for m in ws.sent if m["type"] == "interrupt"]
|
||||||
assert len(interrupts) == 1
|
assert len(interrupts) == 1
|
||||||
@@ -341,7 +353,7 @@ class TestWebSocketInterruptTTL:
|
|||||||
|
|
||||||
with patch("app.interrupt_manager.time") as mock_time:
|
with patch("app.interrupt_manager.time") as mock_time:
|
||||||
mock_time.time.return_value = record.created_at + 10
|
mock_time.time.return_value = record.created_at + 10
|
||||||
await _respond(ws, g_, sm, im, cb, approved=True)
|
await _respond(ws, ws_ctx, approved=True)
|
||||||
|
|
||||||
assert ws.sent[0]["type"] == "interrupt_expired"
|
assert ws.sent[0]["type"] == "interrupt_expired"
|
||||||
assert "cancel_order" in ws.sent[0]["message"]
|
assert "cancel_order" in ws.sent[0]["message"]
|
||||||
|
|||||||
@@ -44,7 +44,7 @@ def _make_analytics_result() -> object:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _get_analytics(app: FastAPI, path: str = "/api/analytics", **patch_kwargs: object) -> object:
|
def _get_analytics(app: FastAPI, path: str = "/api/v1/analytics", **patch_kwargs: object) -> object:
|
||||||
"""Helper: patch get_analytics, make request, return (response, mock)."""
|
"""Helper: patch get_analytics, make request, return (response, mock)."""
|
||||||
analytics_result = _make_analytics_result()
|
analytics_result = _make_analytics_result()
|
||||||
with (
|
with (
|
||||||
@@ -84,7 +84,7 @@ class TestAnalyticsEndpoint:
|
|||||||
def test_custom_range_7d(self) -> None:
|
def test_custom_range_7d(self) -> None:
|
||||||
app = _build_app()
|
app = _build_app()
|
||||||
app.state.pool = _make_mock_pool()
|
app.state.pool = _make_mock_pool()
|
||||||
resp, mock_ga = _get_analytics(app, "/api/analytics?range=7d")
|
resp, mock_ga = _get_analytics(app, "/api/v1/analytics?range=7d")
|
||||||
|
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
mock_ga.assert_called_once()
|
mock_ga.assert_called_once()
|
||||||
@@ -94,7 +94,7 @@ class TestAnalyticsEndpoint:
|
|||||||
def test_custom_range_30d(self) -> None:
|
def test_custom_range_30d(self) -> None:
|
||||||
app = _build_app()
|
app = _build_app()
|
||||||
app.state.pool = _make_mock_pool()
|
app.state.pool = _make_mock_pool()
|
||||||
resp, mock_ga = _get_analytics(app, "/api/analytics?range=30d")
|
resp, mock_ga = _get_analytics(app, "/api/v1/analytics?range=30d")
|
||||||
|
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
call_kwargs = mock_ga.call_args
|
call_kwargs = mock_ga.call_args
|
||||||
@@ -107,7 +107,7 @@ class TestAnalyticsEndpoint:
|
|||||||
app.state.pool = _make_mock_pool()
|
app.state.pool = _make_mock_pool()
|
||||||
|
|
||||||
with TestClient(app) as client:
|
with TestClient(app) as client:
|
||||||
resp = client.get("/api/analytics?range=invalid")
|
resp = client.get("/api/v1/analytics?range=invalid")
|
||||||
|
|
||||||
assert resp.status_code == 400
|
assert resp.status_code == 400
|
||||||
|
|
||||||
@@ -116,7 +116,7 @@ class TestAnalyticsEndpoint:
|
|||||||
app.state.pool = _make_mock_pool()
|
app.state.pool = _make_mock_pool()
|
||||||
|
|
||||||
with TestClient(app) as client:
|
with TestClient(app) as client:
|
||||||
resp = client.get("/api/analytics?range=7")
|
resp = client.get("/api/v1/analytics?range=7")
|
||||||
|
|
||||||
assert resp.status_code == 400
|
assert resp.status_code == 400
|
||||||
|
|
||||||
|
|||||||
@@ -145,4 +145,11 @@ class TestPostgresAnalyticsRecorder:
|
|||||||
)
|
)
|
||||||
call_args = mock_conn.execute.call_args
|
call_args = mock_conn.execute.call_args
|
||||||
params = call_args[0][1]
|
params = call_args[0][1]
|
||||||
assert params["metadata"] == {"key": "val"}
|
# PostgresAnalyticsRecorder wraps metadata with psycopg Json() adapter.
|
||||||
|
# Unwrap to compare the inner dict.
|
||||||
|
from psycopg.types.json import Json
|
||||||
|
|
||||||
|
meta = params["metadata"]
|
||||||
|
if isinstance(meta, Json):
|
||||||
|
meta = meta.obj
|
||||||
|
assert meta == {"key": "val"}
|
||||||
|
|||||||
@@ -158,6 +158,42 @@ class TestInterruptStatsQuery:
|
|||||||
assert result.expired == 0
|
assert result.expired == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestTotalConversations:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_returns_count(self) -> None:
|
||||||
|
from app.analytics.queries import _total_conversations
|
||||||
|
|
||||||
|
pool = _make_pool_with_fetchone({"total": 42})
|
||||||
|
result = await _total_conversations(pool, range_days=7)
|
||||||
|
assert result == 42
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_zero_state_returns_zero(self) -> None:
|
||||||
|
from app.analytics.queries import _total_conversations
|
||||||
|
|
||||||
|
pool = _make_pool_with_fetchone(None)
|
||||||
|
result = await _total_conversations(pool, range_days=7)
|
||||||
|
assert result == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestAvgTurns:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_returns_float(self) -> None:
|
||||||
|
from app.analytics.queries import _avg_turns
|
||||||
|
|
||||||
|
pool = _make_pool_with_fetchone({"avg_turns": 3.5})
|
||||||
|
result = await _avg_turns(pool, range_days=7)
|
||||||
|
assert result == 3.5
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_zero_state_returns_zero(self) -> None:
|
||||||
|
from app.analytics.queries import _avg_turns
|
||||||
|
|
||||||
|
pool = _make_pool_with_fetchone(None)
|
||||||
|
result = await _avg_turns(pool, range_days=7)
|
||||||
|
assert result == 0.0
|
||||||
|
|
||||||
|
|
||||||
class TestGetAnalytics:
|
class TestGetAnalytics:
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_returns_analytics_result(self) -> None:
|
async def test_returns_analytics_result(self) -> None:
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ def client():
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def job_id(client):
|
def job_id(client):
|
||||||
"""Create a job and return its ID."""
|
"""Create a job and return its ID."""
|
||||||
response = client.post("/api/openapi/import", json={"url": _SAMPLE_URL})
|
response = client.post("/api/v1/openapi/import", json={"url": _SAMPLE_URL})
|
||||||
assert response.status_code == 202
|
assert response.status_code == 202
|
||||||
return response.json()["job_id"]
|
return response.json()["job_id"]
|
||||||
|
|
||||||
@@ -61,11 +61,11 @@ def job_with_classifications(client, job_id):
|
|||||||
|
|
||||||
|
|
||||||
class TestImportEndpoint:
|
class TestImportEndpoint:
|
||||||
"""Tests for POST /api/openapi/import."""
|
"""Tests for POST /api/v1/openapi/import."""
|
||||||
|
|
||||||
def test_post_import_returns_job_id(self, client) -> None:
|
def test_post_import_returns_job_id(self, client) -> None:
|
||||||
"""POST /import returns 202 with a job_id."""
|
"""POST /import returns 202 with a job_id."""
|
||||||
response = client.post("/api/openapi/import", json={"url": _SAMPLE_URL})
|
response = client.post("/api/v1/openapi/import", json={"url": _SAMPLE_URL})
|
||||||
assert response.status_code == 202
|
assert response.status_code == 202
|
||||||
data = response.json()
|
data = response.json()
|
||||||
assert "job_id" in data
|
assert "job_id" in data
|
||||||
@@ -73,38 +73,38 @@ class TestImportEndpoint:
|
|||||||
|
|
||||||
def test_post_import_empty_url_returns_422(self, client) -> None:
|
def test_post_import_empty_url_returns_422(self, client) -> None:
|
||||||
"""POST /import with empty URL returns 422 validation error."""
|
"""POST /import with empty URL returns 422 validation error."""
|
||||||
response = client.post("/api/openapi/import", json={"url": ""})
|
response = client.post("/api/v1/openapi/import", json={"url": ""})
|
||||||
assert response.status_code == 422
|
assert response.status_code == 422
|
||||||
|
|
||||||
def test_post_import_missing_url_returns_422(self, client) -> None:
|
def test_post_import_missing_url_returns_422(self, client) -> None:
|
||||||
"""POST /import with missing URL field returns 422."""
|
"""POST /import with missing URL field returns 422."""
|
||||||
response = client.post("/api/openapi/import", json={})
|
response = client.post("/api/v1/openapi/import", json={})
|
||||||
assert response.status_code == 422
|
assert response.status_code == 422
|
||||||
|
|
||||||
def test_post_import_invalid_scheme_returns_422(self, client) -> None:
|
def test_post_import_invalid_scheme_returns_422(self, client) -> None:
|
||||||
"""POST /import with non-http URL returns 422."""
|
"""POST /import with non-http URL returns 422."""
|
||||||
response = client.post("/api/openapi/import", json={"url": "ftp://evil.com/spec"})
|
response = client.post("/api/v1/openapi/import", json={"url": "ftp://evil.com/spec"})
|
||||||
assert response.status_code == 422
|
assert response.status_code == 422
|
||||||
|
|
||||||
def test_post_import_returns_pending_status(self, client) -> None:
|
def test_post_import_returns_pending_status(self, client) -> None:
|
||||||
"""Newly created job has pending status."""
|
"""Newly created job has pending status."""
|
||||||
response = client.post("/api/openapi/import", json={"url": _SAMPLE_URL})
|
response = client.post("/api/v1/openapi/import", json={"url": _SAMPLE_URL})
|
||||||
data = response.json()
|
data = response.json()
|
||||||
assert data["status"] == "pending"
|
assert data["status"] == "pending"
|
||||||
|
|
||||||
def test_post_import_returns_spec_url(self, client) -> None:
|
def test_post_import_returns_spec_url(self, client) -> None:
|
||||||
"""Response includes the original spec URL."""
|
"""Response includes the original spec URL."""
|
||||||
response = client.post("/api/openapi/import", json={"url": _SAMPLE_URL})
|
response = client.post("/api/v1/openapi/import", json={"url": _SAMPLE_URL})
|
||||||
data = response.json()
|
data = response.json()
|
||||||
assert data["spec_url"] == _SAMPLE_URL
|
assert data["spec_url"] == _SAMPLE_URL
|
||||||
|
|
||||||
|
|
||||||
class TestGetJobEndpoint:
|
class TestGetJobEndpoint:
|
||||||
"""Tests for GET /api/openapi/jobs/{job_id}."""
|
"""Tests for GET /api/v1/openapi/jobs/{job_id}."""
|
||||||
|
|
||||||
def test_get_job_returns_status(self, client, job_id) -> None:
|
def test_get_job_returns_status(self, client, job_id) -> None:
|
||||||
"""GET /jobs/{id} returns job status."""
|
"""GET /jobs/{id} returns job status."""
|
||||||
response = client.get(f"/api/openapi/jobs/{job_id}")
|
response = client.get(f"/api/v1/openapi/jobs/{job_id}")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.json()
|
data = response.json()
|
||||||
assert "status" in data
|
assert "status" in data
|
||||||
@@ -112,23 +112,23 @@ class TestGetJobEndpoint:
|
|||||||
|
|
||||||
def test_get_unknown_job_returns_404(self, client) -> None:
|
def test_get_unknown_job_returns_404(self, client) -> None:
|
||||||
"""GET /jobs/nonexistent returns 404."""
|
"""GET /jobs/nonexistent returns 404."""
|
||||||
response = client.get("/api/openapi/jobs/nonexistent-id")
|
response = client.get("/api/v1/openapi/jobs/nonexistent-id")
|
||||||
assert response.status_code == 404
|
assert response.status_code == 404
|
||||||
|
|
||||||
def test_get_job_includes_spec_url(self, client, job_id) -> None:
|
def test_get_job_includes_spec_url(self, client, job_id) -> None:
|
||||||
"""Job response includes the spec URL."""
|
"""Job response includes the spec URL."""
|
||||||
response = client.get(f"/api/openapi/jobs/{job_id}")
|
response = client.get(f"/api/v1/openapi/jobs/{job_id}")
|
||||||
data = response.json()
|
data = response.json()
|
||||||
assert data["spec_url"] == _SAMPLE_URL
|
assert data["spec_url"] == _SAMPLE_URL
|
||||||
|
|
||||||
|
|
||||||
class TestGetClassificationsEndpoint:
|
class TestGetClassificationsEndpoint:
|
||||||
"""Tests for GET /api/openapi/jobs/{job_id}/classifications."""
|
"""Tests for GET /api/v1/openapi/jobs/{job_id}/classifications."""
|
||||||
|
|
||||||
def test_get_classifications_returns_list(self, client, job_with_classifications) -> None:
|
def test_get_classifications_returns_list(self, client, job_with_classifications) -> None:
|
||||||
"""GET /classifications returns a list."""
|
"""GET /classifications returns a list."""
|
||||||
response = client.get(
|
response = client.get(
|
||||||
f"/api/openapi/jobs/{job_with_classifications}/classifications"
|
f"/api/v1/openapi/jobs/{job_with_classifications}/classifications"
|
||||||
)
|
)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.json()
|
data = response.json()
|
||||||
@@ -137,13 +137,13 @@ class TestGetClassificationsEndpoint:
|
|||||||
|
|
||||||
def test_get_classifications_unknown_job_returns_404(self, client) -> None:
|
def test_get_classifications_unknown_job_returns_404(self, client) -> None:
|
||||||
"""GET /classifications for unknown job returns 404."""
|
"""GET /classifications for unknown job returns 404."""
|
||||||
response = client.get("/api/openapi/jobs/unknown/classifications")
|
response = client.get("/api/v1/openapi/jobs/unknown/classifications")
|
||||||
assert response.status_code == 404
|
assert response.status_code == 404
|
||||||
|
|
||||||
def test_classification_has_expected_fields(self, client, job_with_classifications) -> None:
|
def test_classification_has_expected_fields(self, client, job_with_classifications) -> None:
|
||||||
"""Each classification item has access_type and endpoint fields."""
|
"""Each classification item has access_type and endpoint fields."""
|
||||||
response = client.get(
|
response = client.get(
|
||||||
f"/api/openapi/jobs/{job_with_classifications}/classifications"
|
f"/api/v1/openapi/jobs/{job_with_classifications}/classifications"
|
||||||
)
|
)
|
||||||
item = response.json()[0]
|
item = response.json()[0]
|
||||||
assert "access_type" in item
|
assert "access_type" in item
|
||||||
@@ -152,12 +152,12 @@ class TestGetClassificationsEndpoint:
|
|||||||
|
|
||||||
|
|
||||||
class TestUpdateClassificationEndpoint:
|
class TestUpdateClassificationEndpoint:
|
||||||
"""Tests for PUT /api/openapi/jobs/{job_id}/classifications/{idx}."""
|
"""Tests for PUT /api/v1/openapi/jobs/{job_id}/classifications/{idx}."""
|
||||||
|
|
||||||
def test_update_classification_succeeds(self, client, job_with_classifications) -> None:
|
def test_update_classification_succeeds(self, client, job_with_classifications) -> None:
|
||||||
"""PUT /classifications/0 updates the classification."""
|
"""PUT /classifications/0 updates the classification."""
|
||||||
response = client.put(
|
response = client.put(
|
||||||
f"/api/openapi/jobs/{job_with_classifications}/classifications/0",
|
f"/api/v1/openapi/jobs/{job_with_classifications}/classifications/0",
|
||||||
json={"access_type": "write", "needs_interrupt": True, "agent_group": "write_agent"},
|
json={"access_type": "write", "needs_interrupt": True, "agent_group": "write_agent"},
|
||||||
)
|
)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -165,7 +165,7 @@ class TestUpdateClassificationEndpoint:
|
|||||||
def test_update_unknown_job_returns_404(self, client) -> None:
|
def test_update_unknown_job_returns_404(self, client) -> None:
|
||||||
"""PUT /classifications/0 for unknown job returns 404."""
|
"""PUT /classifications/0 for unknown job returns 404."""
|
||||||
response = client.put(
|
response = client.put(
|
||||||
"/api/openapi/jobs/unknown/classifications/0",
|
"/api/v1/openapi/jobs/unknown/classifications/0",
|
||||||
json={"access_type": "write", "needs_interrupt": True, "agent_group": "write_agent"},
|
json={"access_type": "write", "needs_interrupt": True, "agent_group": "write_agent"},
|
||||||
)
|
)
|
||||||
assert response.status_code == 404
|
assert response.status_code == 404
|
||||||
@@ -173,7 +173,7 @@ class TestUpdateClassificationEndpoint:
|
|||||||
def test_update_invalid_access_type_returns_422(self, client, job_with_classifications) -> None:
|
def test_update_invalid_access_type_returns_422(self, client, job_with_classifications) -> None:
|
||||||
"""PUT /classifications/0 with invalid access_type returns 422."""
|
"""PUT /classifications/0 with invalid access_type returns 422."""
|
||||||
response = client.put(
|
response = client.put(
|
||||||
f"/api/openapi/jobs/{job_with_classifications}/classifications/0",
|
f"/api/v1/openapi/jobs/{job_with_classifications}/classifications/0",
|
||||||
json={"access_type": "admin", "needs_interrupt": True, "agent_group": "x"},
|
json={"access_type": "admin", "needs_interrupt": True, "agent_group": "x"},
|
||||||
)
|
)
|
||||||
assert response.status_code == 422
|
assert response.status_code == 422
|
||||||
@@ -181,7 +181,7 @@ class TestUpdateClassificationEndpoint:
|
|||||||
def test_update_invalid_agent_group_returns_422(self, client, job_with_classifications) -> None:
|
def test_update_invalid_agent_group_returns_422(self, client, job_with_classifications) -> None:
|
||||||
"""PUT /classifications/0 with invalid agent_group returns 422."""
|
"""PUT /classifications/0 with invalid agent_group returns 422."""
|
||||||
response = client.put(
|
response = client.put(
|
||||||
f"/api/openapi/jobs/{job_with_classifications}/classifications/0",
|
f"/api/v1/openapi/jobs/{job_with_classifications}/classifications/0",
|
||||||
json={"access_type": "read", "needs_interrupt": False, "agent_group": "evil group!"},
|
json={"access_type": "read", "needs_interrupt": False, "agent_group": "evil group!"},
|
||||||
)
|
)
|
||||||
assert response.status_code == 422
|
assert response.status_code == 422
|
||||||
@@ -189,31 +189,31 @@ class TestUpdateClassificationEndpoint:
|
|||||||
def test_update_out_of_range_index_returns_404(self, client, job_with_classifications) -> None:
|
def test_update_out_of_range_index_returns_404(self, client, job_with_classifications) -> None:
|
||||||
"""PUT /classifications/999 returns 404 for out-of-range index."""
|
"""PUT /classifications/999 returns 404 for out-of-range index."""
|
||||||
response = client.put(
|
response = client.put(
|
||||||
f"/api/openapi/jobs/{job_with_classifications}/classifications/999",
|
f"/api/v1/openapi/jobs/{job_with_classifications}/classifications/999",
|
||||||
json={"access_type": "read", "needs_interrupt": False, "agent_group": "read_agent"},
|
json={"access_type": "read", "needs_interrupt": False, "agent_group": "read_agent"},
|
||||||
)
|
)
|
||||||
assert response.status_code == 404
|
assert response.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
class TestApproveEndpoint:
|
class TestApproveEndpoint:
|
||||||
"""Tests for POST /api/openapi/jobs/{job_id}/approve."""
|
"""Tests for POST /api/v1/openapi/jobs/{job_id}/approve."""
|
||||||
|
|
||||||
def test_approve_job_succeeds(self, client, job_with_classifications) -> None:
|
def test_approve_job_succeeds(self, client, job_with_classifications) -> None:
|
||||||
"""POST /approve transitions job to approved status."""
|
"""POST /approve transitions job to approved status."""
|
||||||
response = client.post(
|
response = client.post(
|
||||||
f"/api/openapi/jobs/{job_with_classifications}/approve"
|
f"/api/v1/openapi/jobs/{job_with_classifications}/approve"
|
||||||
)
|
)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
|
||||||
def test_approve_unknown_job_returns_404(self, client) -> None:
|
def test_approve_unknown_job_returns_404(self, client) -> None:
|
||||||
"""POST /approve for unknown job returns 404."""
|
"""POST /approve for unknown job returns 404."""
|
||||||
response = client.post("/api/openapi/jobs/unknown/approve")
|
response = client.post("/api/v1/openapi/jobs/unknown/approve")
|
||||||
assert response.status_code == 404
|
assert response.status_code == 404
|
||||||
|
|
||||||
def test_approve_returns_job_status(self, client, job_with_classifications) -> None:
|
def test_approve_returns_job_status(self, client, job_with_classifications) -> None:
|
||||||
"""POST /approve returns updated job status."""
|
"""POST /approve returns updated job status."""
|
||||||
response = client.post(
|
response = client.post(
|
||||||
f"/api/openapi/jobs/{job_with_classifications}/approve"
|
f"/api/v1/openapi/jobs/{job_with_classifications}/approve"
|
||||||
)
|
)
|
||||||
data = response.json()
|
data = response.json()
|
||||||
assert "status" in data
|
assert "status" in data
|
||||||
|
|||||||
@@ -5,9 +5,12 @@ from __future__ import annotations
|
|||||||
from unittest.mock import AsyncMock, MagicMock
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI, HTTPException
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
from app.api_utils import envelope
|
||||||
|
|
||||||
pytestmark = pytest.mark.unit
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
@@ -16,13 +19,43 @@ def _build_app() -> FastAPI:
|
|||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
app.include_router(router)
|
app.include_router(router)
|
||||||
|
|
||||||
|
@app.exception_handler(HTTPException)
|
||||||
|
async def _http_exc(request, exc): # type: ignore[no-untyped-def]
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=exc.status_code,
|
||||||
|
content=envelope(None, success=False, error=exc.detail),
|
||||||
|
)
|
||||||
|
|
||||||
return app
|
return app
|
||||||
|
|
||||||
|
|
||||||
def _make_mock_pool(fetchall_result: list[dict]) -> MagicMock:
|
def _make_mock_pool(
|
||||||
"""Build a mock pool that returns the given rows from fetchall."""
|
fetchall_result: list[dict],
|
||||||
|
*,
|
||||||
|
count: int | None = None,
|
||||||
|
) -> MagicMock:
|
||||||
|
"""Build a mock pool that returns the given rows from fetchall.
|
||||||
|
|
||||||
|
When *count* is provided, the first execute() call returns a cursor
|
||||||
|
whose fetchone() yields ``(count,)`` (for the COUNT query) and the
|
||||||
|
second call returns the rows via fetchall(). When *count* is None
|
||||||
|
(the default), a single cursor backed by *fetchall_result* is used
|
||||||
|
for all calls.
|
||||||
|
"""
|
||||||
|
if count is not None:
|
||||||
|
count_cursor = AsyncMock()
|
||||||
|
count_cursor.fetchone = AsyncMock(return_value=(count,))
|
||||||
|
|
||||||
|
rows_cursor = AsyncMock()
|
||||||
|
rows_cursor.fetchall = AsyncMock(return_value=fetchall_result)
|
||||||
|
|
||||||
|
mock_conn = AsyncMock()
|
||||||
|
mock_conn.execute = AsyncMock(side_effect=[count_cursor, rows_cursor])
|
||||||
|
else:
|
||||||
mock_cursor = AsyncMock()
|
mock_cursor = AsyncMock()
|
||||||
mock_cursor.fetchall = AsyncMock(return_value=fetchall_result)
|
mock_cursor.fetchall = AsyncMock(return_value=fetchall_result)
|
||||||
|
mock_cursor.fetchone = AsyncMock(return_value=None)
|
||||||
|
|
||||||
mock_conn = AsyncMock()
|
mock_conn = AsyncMock()
|
||||||
mock_conn.execute = AsyncMock(return_value=mock_cursor)
|
mock_conn.execute = AsyncMock(return_value=mock_cursor)
|
||||||
@@ -39,14 +72,17 @@ def _make_mock_pool(fetchall_result: list[dict]) -> MagicMock:
|
|||||||
class TestListConversations:
|
class TestListConversations:
|
||||||
def test_returns_200_with_empty_list(self) -> None:
|
def test_returns_200_with_empty_list(self) -> None:
|
||||||
app = _build_app()
|
app = _build_app()
|
||||||
app.state.pool = _make_mock_pool([])
|
app.state.pool = _make_mock_pool([], count=0)
|
||||||
|
|
||||||
with TestClient(app) as client:
|
with TestClient(app) as client:
|
||||||
resp = client.get("/api/conversations")
|
resp = client.get("/api/v1/conversations")
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
body = resp.json()
|
body = resp.json()
|
||||||
assert body["success"] is True
|
assert body["success"] is True
|
||||||
assert isinstance(body["data"], list)
|
data = body["data"]
|
||||||
|
assert isinstance(data["conversations"], list)
|
||||||
|
assert data["total"] == 0
|
||||||
|
assert data["page"] == 1
|
||||||
assert body["error"] is None
|
assert body["error"] is None
|
||||||
|
|
||||||
def test_returns_conversations_list(self) -> None:
|
def test_returns_conversations_list(self) -> None:
|
||||||
@@ -61,39 +97,41 @@ class TestListConversations:
|
|||||||
"total_cost_usd": 0.01,
|
"total_cost_usd": 0.01,
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
app.state.pool = _make_mock_pool(mock_rows)
|
app.state.pool = _make_mock_pool(mock_rows, count=1)
|
||||||
|
|
||||||
with TestClient(app) as client:
|
with TestClient(app) as client:
|
||||||
resp = client.get("/api/conversations")
|
resp = client.get("/api/v1/conversations")
|
||||||
body = resp.json()
|
body = resp.json()
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
assert len(body["data"]) == 1
|
data = body["data"]
|
||||||
assert body["data"][0]["thread_id"] == "t1"
|
assert len(data["conversations"]) == 1
|
||||||
|
assert data["conversations"][0]["thread_id"] == "t1"
|
||||||
|
assert data["total"] == 1
|
||||||
|
|
||||||
def test_pagination_defaults(self) -> None:
|
def test_pagination_defaults(self) -> None:
|
||||||
app = _build_app()
|
app = _build_app()
|
||||||
app.state.pool = _make_mock_pool([])
|
app.state.pool = _make_mock_pool([], count=0)
|
||||||
|
|
||||||
with TestClient(app) as client:
|
with TestClient(app) as client:
|
||||||
resp = client.get("/api/conversations")
|
resp = client.get("/api/v1/conversations")
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
|
|
||||||
def test_pagination_custom_params(self) -> None:
|
def test_pagination_custom_params(self) -> None:
|
||||||
app = _build_app()
|
app = _build_app()
|
||||||
app.state.pool = _make_mock_pool([])
|
app.state.pool = _make_mock_pool([], count=0)
|
||||||
|
|
||||||
with TestClient(app) as client:
|
with TestClient(app) as client:
|
||||||
resp = client.get("/api/conversations?page=2&per_page=10")
|
resp = client.get("/api/v1/conversations?page=2&per_page=10")
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
|
|
||||||
def test_per_page_max_capped_at_100(self) -> None:
|
def test_per_page_max_capped_at_100(self) -> None:
|
||||||
app = _build_app()
|
app = _build_app()
|
||||||
app.state.pool = _make_mock_pool([])
|
app.state.pool = _make_mock_pool([], count=0)
|
||||||
|
|
||||||
with TestClient(app) as client:
|
with TestClient(app) as client:
|
||||||
resp = client.get("/api/conversations?per_page=200")
|
resp = client.get("/api/v1/conversations?per_page=200")
|
||||||
# FastAPI validation rejects values > 100
|
# FastAPI Query(le=100) rejects values > 100
|
||||||
assert resp.status_code in (200, 422)
|
assert resp.status_code == 422
|
||||||
|
|
||||||
|
|
||||||
class TestGetReplay:
|
class TestGetReplay:
|
||||||
@@ -102,7 +140,7 @@ class TestGetReplay:
|
|||||||
app.state.pool = _make_mock_pool([])
|
app.state.pool = _make_mock_pool([])
|
||||||
|
|
||||||
with TestClient(app) as client:
|
with TestClient(app) as client:
|
||||||
resp = client.get("/api/replay/nonexistent-thread")
|
resp = client.get("/api/v1/replay/nonexistent-thread")
|
||||||
assert resp.status_code == 404
|
assert resp.status_code == 404
|
||||||
|
|
||||||
def test_returns_replay_page_for_existing_thread(self) -> None:
|
def test_returns_replay_page_for_existing_thread(self) -> None:
|
||||||
@@ -122,7 +160,7 @@ class TestGetReplay:
|
|||||||
app.state.pool = _make_mock_pool(mock_rows)
|
app.state.pool = _make_mock_pool(mock_rows)
|
||||||
|
|
||||||
with TestClient(app) as client:
|
with TestClient(app) as client:
|
||||||
resp = client.get("/api/replay/thread-123")
|
resp = client.get("/api/v1/replay/thread-123")
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
body = resp.json()
|
body = resp.json()
|
||||||
assert body["success"] is True
|
assert body["success"] is True
|
||||||
@@ -147,7 +185,7 @@ class TestGetReplay:
|
|||||||
app.state.pool = _make_mock_pool(mock_rows)
|
app.state.pool = _make_mock_pool(mock_rows)
|
||||||
|
|
||||||
with TestClient(app) as client:
|
with TestClient(app) as client:
|
||||||
resp = client.get("/api/replay/t1?page=1&per_page=5")
|
resp = client.get("/api/v1/replay/t1?page=1&per_page=5")
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
|
|
||||||
def test_error_response_has_envelope(self) -> None:
|
def test_error_response_has_envelope(self) -> None:
|
||||||
@@ -155,16 +193,19 @@ class TestGetReplay:
|
|||||||
app.state.pool = _make_mock_pool([])
|
app.state.pool = _make_mock_pool([])
|
||||||
|
|
||||||
with TestClient(app) as client:
|
with TestClient(app) as client:
|
||||||
resp = client.get("/api/replay/missing")
|
resp = client.get("/api/v1/replay/missing")
|
||||||
assert resp.status_code == 404
|
assert resp.status_code == 404
|
||||||
assert "detail" in resp.json()
|
body = resp.json()
|
||||||
|
assert body["success"] is False
|
||||||
|
assert body["data"] is None
|
||||||
|
assert body["error"] is not None
|
||||||
|
|
||||||
def test_invalid_thread_id_returns_400(self) -> None:
|
def test_invalid_thread_id_returns_400(self) -> None:
|
||||||
app = _build_app()
|
app = _build_app()
|
||||||
app.state.pool = _make_mock_pool([])
|
app.state.pool = _make_mock_pool([])
|
||||||
|
|
||||||
with TestClient(app) as client:
|
with TestClient(app) as client:
|
||||||
resp = client.get("/api/replay/id%20with%20spaces")
|
resp = client.get("/api/v1/replay/id%20with%20spaces")
|
||||||
assert resp.status_code == 400
|
assert resp.status_code == 400
|
||||||
|
|
||||||
def test_thread_id_special_chars_returns_400(self) -> None:
|
def test_thread_id_special_chars_returns_400(self) -> None:
|
||||||
@@ -172,5 +213,5 @@ class TestGetReplay:
|
|||||||
app.state.pool = _make_mock_pool([])
|
app.state.pool = _make_mock_pool([])
|
||||||
|
|
||||||
with TestClient(app) as client:
|
with TestClient(app) as client:
|
||||||
resp = client.get("/api/replay/id;DROP TABLE")
|
resp = client.get("/api/v1/replay/id;DROP TABLE")
|
||||||
assert resp.status_code == 400
|
assert resp.status_code == 400
|
||||||
|
|||||||
@@ -153,3 +153,105 @@ class TestTransformCheckpoints:
|
|||||||
rows = [_make_row([{"type": "human", "content": "Hi"}])]
|
rows = [_make_row([{"type": "human", "content": "Hi"}])]
|
||||||
steps = transform_checkpoints(rows)
|
steps = transform_checkpoints(rows)
|
||||||
assert isinstance(steps[0].timestamp, str)
|
assert isinstance(steps[0].timestamp, str)
|
||||||
|
|
||||||
|
def test_list_content_joined_to_string(self) -> None:
|
||||||
|
from app.replay.transformer import transform_checkpoints
|
||||||
|
|
||||||
|
rows = [
|
||||||
|
_make_row(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"type": "human",
|
||||||
|
"content": [
|
||||||
|
{"text": "Hello"},
|
||||||
|
{"text": " world"},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
]
|
||||||
|
steps = transform_checkpoints(rows)
|
||||||
|
assert len(steps) == 1
|
||||||
|
assert steps[0].content == "Hello world"
|
||||||
|
|
||||||
|
def test_checkpoint_as_string_skipped(self) -> None:
|
||||||
|
from app.replay.transformer import transform_checkpoints
|
||||||
|
|
||||||
|
rows = [
|
||||||
|
{
|
||||||
|
"thread_id": "t1",
|
||||||
|
"checkpoint_id": "cp1",
|
||||||
|
"checkpoint": "not-a-dict",
|
||||||
|
"metadata": {},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
steps = transform_checkpoints(rows)
|
||||||
|
assert steps == []
|
||||||
|
|
||||||
|
def test_channel_values_not_dict_skipped(self) -> None:
|
||||||
|
from app.replay.transformer import transform_checkpoints
|
||||||
|
|
||||||
|
rows = [
|
||||||
|
{
|
||||||
|
"thread_id": "t1",
|
||||||
|
"checkpoint_id": "cp1",
|
||||||
|
"checkpoint": {"channel_values": "bad"},
|
||||||
|
"metadata": {},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
steps = transform_checkpoints(rows)
|
||||||
|
assert steps == []
|
||||||
|
|
||||||
|
def test_tool_result_valid_json_parsed(self) -> None:
|
||||||
|
from app.replay.transformer import transform_checkpoints
|
||||||
|
|
||||||
|
rows = [
|
||||||
|
_make_row(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"type": "tool",
|
||||||
|
"content": '{"order_id": "123", "status": "shipped"}',
|
||||||
|
"name": "get_order_status",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
]
|
||||||
|
steps = transform_checkpoints(rows)
|
||||||
|
assert len(steps) == 1
|
||||||
|
assert steps[0].result == {"order_id": "123", "status": "shipped"}
|
||||||
|
|
||||||
|
def test_tool_result_invalid_json_wrapped(self) -> None:
|
||||||
|
from app.replay.transformer import transform_checkpoints
|
||||||
|
|
||||||
|
rows = [
|
||||||
|
_make_row(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"type": "tool",
|
||||||
|
"content": "not valid json",
|
||||||
|
"name": "some_tool",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
]
|
||||||
|
steps = transform_checkpoints(rows)
|
||||||
|
assert len(steps) == 1
|
||||||
|
assert steps[0].result == {"raw": "not valid json"}
|
||||||
|
|
||||||
|
def test_malformed_message_skipped_gracefully(self) -> None:
|
||||||
|
from app.replay.transformer import transform_checkpoints
|
||||||
|
|
||||||
|
rows = [
|
||||||
|
_make_row(
|
||||||
|
[
|
||||||
|
{"type": "human", "content": "Good message"},
|
||||||
|
42, # not a dict -- will raise in _step_from_message
|
||||||
|
{"type": "ai", "content": "Response", "tool_calls": []},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
]
|
||||||
|
steps = transform_checkpoints(rows)
|
||||||
|
# The malformed message is skipped; the other two produce steps.
|
||||||
|
assert len(steps) == 2
|
||||||
|
assert steps[0].step == 1
|
||||||
|
assert steps[1].step == 2
|
||||||
|
|||||||
@@ -7,10 +7,41 @@ import pytest
|
|||||||
from app.config import Settings
|
from app.config import Settings
|
||||||
|
|
||||||
|
|
||||||
|
def _isolated_settings(**kwargs: object) -> Settings:
|
||||||
|
"""Create a Settings instance that ignores .env files and process env vars.
|
||||||
|
|
||||||
|
pydantic-settings reads from env_file and environment by default, which
|
||||||
|
causes test results to depend on the machine they run on. We override
|
||||||
|
model_config at the class level temporarily so that every test gets
|
||||||
|
deterministic results.
|
||||||
|
"""
|
||||||
|
# Build a throwaway subclass that disables env-file and env-var loading.
|
||||||
|
class _IsolatedSettings(Settings):
|
||||||
|
model_config = Settings.model_config.copy()
|
||||||
|
model_config["env_file"] = None # type: ignore[assignment]
|
||||||
|
model_config["env_ignore_empty"] = True
|
||||||
|
|
||||||
|
# _env_parse_none_str makes pydantic-settings treat missing env vars as
|
||||||
|
# absent rather than empty-string, so required fields will raise.
|
||||||
|
import os
|
||||||
|
|
||||||
|
env_backup = os.environ.copy()
|
||||||
|
# Strip all env vars that Settings knows about so they can't leak in.
|
||||||
|
settings_fields = set(Settings.model_fields)
|
||||||
|
for key in list(os.environ):
|
||||||
|
if key.lower() in settings_fields:
|
||||||
|
del os.environ[key]
|
||||||
|
try:
|
||||||
|
return _IsolatedSettings(**kwargs) # type: ignore[return-value]
|
||||||
|
finally:
|
||||||
|
os.environ.clear()
|
||||||
|
os.environ.update(env_backup)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.unit
|
@pytest.mark.unit
|
||||||
class TestSettings:
|
class TestSettings:
|
||||||
def test_default_values(self) -> None:
|
def test_default_values(self) -> None:
|
||||||
settings = Settings(
|
settings = _isolated_settings(
|
||||||
database_url="postgresql://x:x@localhost/db",
|
database_url="postgresql://x:x@localhost/db",
|
||||||
anthropic_api_key="key",
|
anthropic_api_key="key",
|
||||||
)
|
)
|
||||||
@@ -20,7 +51,7 @@ class TestSettings:
|
|||||||
assert settings.interrupt_ttl_minutes == 30
|
assert settings.interrupt_ttl_minutes == 30
|
||||||
|
|
||||||
def test_custom_values(self) -> None:
|
def test_custom_values(self) -> None:
|
||||||
settings = Settings(
|
settings = _isolated_settings(
|
||||||
database_url="postgresql://x:x@localhost/db",
|
database_url="postgresql://x:x@localhost/db",
|
||||||
llm_provider="openai",
|
llm_provider="openai",
|
||||||
llm_model="gpt-4o",
|
llm_model="gpt-4o",
|
||||||
@@ -33,18 +64,18 @@ class TestSettings:
|
|||||||
|
|
||||||
def test_invalid_provider_rejected(self) -> None:
|
def test_invalid_provider_rejected(self) -> None:
|
||||||
with pytest.raises(Exception):
|
with pytest.raises(Exception):
|
||||||
Settings(
|
_isolated_settings(
|
||||||
database_url="postgresql://x:x@localhost/db",
|
database_url="postgresql://x:x@localhost/db",
|
||||||
llm_provider="invalid",
|
llm_provider="invalid",
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_missing_database_url_rejected(self) -> None:
|
def test_missing_database_url_rejected(self) -> None:
|
||||||
with pytest.raises(Exception):
|
with pytest.raises(Exception):
|
||||||
Settings(anthropic_api_key="key")
|
_isolated_settings(anthropic_api_key="key")
|
||||||
|
|
||||||
def test_empty_api_key_for_provider_rejected(self) -> None:
|
def test_empty_api_key_for_provider_rejected(self) -> None:
|
||||||
with pytest.raises(ValueError, match="API key"):
|
with pytest.raises(ValueError, match="API key"):
|
||||||
Settings(
|
_isolated_settings(
|
||||||
database_url="postgresql://x:x@localhost/db",
|
database_url="postgresql://x:x@localhost/db",
|
||||||
llm_provider="anthropic",
|
llm_provider="anthropic",
|
||||||
anthropic_api_key="",
|
anthropic_api_key="",
|
||||||
@@ -52,9 +83,27 @@ class TestSettings:
|
|||||||
|
|
||||||
def test_wrong_provider_key_rejected(self) -> None:
|
def test_wrong_provider_key_rejected(self) -> None:
|
||||||
with pytest.raises(ValueError, match="API key"):
|
with pytest.raises(ValueError, match="API key"):
|
||||||
Settings(
|
_isolated_settings(
|
||||||
database_url="postgresql://x:x@localhost/db",
|
database_url="postgresql://x:x@localhost/db",
|
||||||
llm_provider="openai",
|
llm_provider="openai",
|
||||||
anthropic_api_key="key",
|
anthropic_api_key="key",
|
||||||
openai_api_key="",
|
openai_api_key="",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_azure_openai_missing_endpoint_rejected(self) -> None:
|
||||||
|
with pytest.raises(ValueError, match="AZURE_OPENAI_ENDPOINT"):
|
||||||
|
_isolated_settings(
|
||||||
|
database_url="postgresql://x:x@localhost/db",
|
||||||
|
llm_provider="azure_openai",
|
||||||
|
azure_openai_api_key="key",
|
||||||
|
azure_openai_deployment="my-deploy",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_azure_openai_missing_deployment_rejected(self) -> None:
|
||||||
|
with pytest.raises(ValueError, match="AZURE_OPENAI_DEPLOYMENT"):
|
||||||
|
_isolated_settings(
|
||||||
|
database_url="postgresql://x:x@localhost/db",
|
||||||
|
llm_provider="azure_openai",
|
||||||
|
azure_openai_api_key="key",
|
||||||
|
azure_openai_endpoint="https://example.openai.azure.com",
|
||||||
|
)
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ class TestDbModule:
|
|||||||
from app.db import setup_app_tables
|
from app.db import setup_app_tables
|
||||||
|
|
||||||
await setup_app_tables(mock_pool)
|
await setup_app_tables(mock_pool)
|
||||||
assert mock_conn.execute.await_count == 4
|
assert mock_conn.execute.await_count == 5
|
||||||
|
|
||||||
def test_ddl_statements_valid(self) -> None:
|
def test_ddl_statements_valid(self) -> None:
|
||||||
assert "CREATE TABLE IF NOT EXISTS conversations" in _CONVERSATIONS_DDL
|
assert "CREATE TABLE IF NOT EXISTS conversations" in _CONVERSATIONS_DDL
|
||||||
|
|||||||
@@ -51,5 +51,5 @@ class TestAnalyticsEventsDDL:
|
|||||||
from app.db import setup_app_tables
|
from app.db import setup_app_tables
|
||||||
|
|
||||||
await setup_app_tables(mock_pool)
|
await setup_app_tables(mock_pool)
|
||||||
# Now expects 4 statements: conversations, interrupts, analytics_events, migrations
|
# Now expects 5 statements: conversations, interrupts, sessions, analytics_events, migrations
|
||||||
assert mock_conn.execute.await_count == 4
|
assert mock_conn.execute.await_count == 5
|
||||||
|
|||||||
@@ -8,7 +8,9 @@ from unittest.mock import AsyncMock, MagicMock
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from app.callbacks import TokenUsageCallbackHandler
|
from app.callbacks import TokenUsageCallbackHandler
|
||||||
|
from app.graph_context import GraphContext
|
||||||
from app.session_manager import SessionManager
|
from app.session_manager import SessionManager
|
||||||
|
from app.ws_context import WebSocketContext
|
||||||
from app.ws_handler import dispatch_message
|
from app.ws_handler import dispatch_message
|
||||||
|
|
||||||
pytestmark = pytest.mark.unit
|
pytestmark = pytest.mark.unit
|
||||||
@@ -20,7 +22,7 @@ def _make_ws() -> AsyncMock:
|
|||||||
return ws
|
return ws
|
||||||
|
|
||||||
|
|
||||||
def _make_graph() -> AsyncMock:
|
def _make_graph() -> MagicMock:
|
||||||
graph = AsyncMock()
|
graph = AsyncMock()
|
||||||
|
|
||||||
class AsyncIterHelper:
|
class AsyncIterHelper:
|
||||||
@@ -34,23 +36,32 @@ def _make_graph() -> AsyncMock:
|
|||||||
state = MagicMock()
|
state = MagicMock()
|
||||||
state.tasks = ()
|
state.tasks = ()
|
||||||
graph.aget_state = AsyncMock(return_value=state)
|
graph.aget_state = AsyncMock(return_value=state)
|
||||||
graph.intent_classifier = None
|
|
||||||
graph.agent_registry = None
|
|
||||||
return graph
|
return graph
|
||||||
|
|
||||||
|
|
||||||
|
def _make_ws_ctx(sm: SessionManager | None = None) -> WebSocketContext:
|
||||||
|
graph = _make_graph()
|
||||||
|
registry = MagicMock()
|
||||||
|
registry.list_agents = MagicMock(return_value=())
|
||||||
|
graph_ctx = GraphContext(graph=graph, registry=registry, intent_classifier=None)
|
||||||
|
return WebSocketContext(
|
||||||
|
graph_ctx=graph_ctx,
|
||||||
|
session_manager=sm or SessionManager(),
|
||||||
|
callback_handler=TokenUsageCallbackHandler(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.unit
|
@pytest.mark.unit
|
||||||
class TestEmptyMessageHandling:
|
class TestEmptyMessageHandling:
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_empty_message_content_returns_error(self) -> None:
|
async def test_empty_message_content_returns_error(self) -> None:
|
||||||
ws = _make_ws()
|
ws = _make_ws()
|
||||||
graph = _make_graph()
|
|
||||||
sm = SessionManager()
|
sm = SessionManager()
|
||||||
cb = TokenUsageCallbackHandler()
|
ws_ctx = _make_ws_ctx(sm=sm)
|
||||||
|
|
||||||
sm.touch("t1")
|
sm.touch("t1")
|
||||||
msg = json.dumps({"type": "message", "thread_id": "t1", "content": ""})
|
msg = json.dumps({"type": "message", "thread_id": "t1", "content": ""})
|
||||||
await dispatch_message(ws, graph, sm, cb, msg)
|
await dispatch_message(ws, ws_ctx, msg)
|
||||||
|
|
||||||
call_data = ws.send_json.call_args[0][0]
|
call_data = ws.send_json.call_args[0][0]
|
||||||
assert call_data["type"] == "error"
|
assert call_data["type"] == "error"
|
||||||
@@ -60,13 +71,12 @@ class TestEmptyMessageHandling:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_whitespace_only_message_treated_as_empty(self) -> None:
|
async def test_whitespace_only_message_treated_as_empty(self) -> None:
|
||||||
ws = _make_ws()
|
ws = _make_ws()
|
||||||
graph = _make_graph()
|
|
||||||
sm = SessionManager()
|
sm = SessionManager()
|
||||||
cb = TokenUsageCallbackHandler()
|
ws_ctx = _make_ws_ctx(sm=sm)
|
||||||
|
|
||||||
sm.touch("t1")
|
sm.touch("t1")
|
||||||
msg = json.dumps({"type": "message", "thread_id": "t1", "content": " "})
|
msg = json.dumps({"type": "message", "thread_id": "t1", "content": " "})
|
||||||
await dispatch_message(ws, graph, sm, cb, msg)
|
await dispatch_message(ws, ws_ctx, msg)
|
||||||
|
|
||||||
call_data = ws.send_json.call_args[0][0]
|
call_data = ws.send_json.call_args[0][0]
|
||||||
assert call_data["type"] == "error"
|
assert call_data["type"] == "error"
|
||||||
@@ -77,14 +87,13 @@ class TestOversizedMessageHandling:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_content_over_10000_chars_returns_error(self) -> None:
|
async def test_content_over_10000_chars_returns_error(self) -> None:
|
||||||
ws = _make_ws()
|
ws = _make_ws()
|
||||||
graph = _make_graph()
|
|
||||||
sm = SessionManager()
|
sm = SessionManager()
|
||||||
cb = TokenUsageCallbackHandler()
|
ws_ctx = _make_ws_ctx(sm=sm)
|
||||||
|
|
||||||
sm.touch("t1")
|
sm.touch("t1")
|
||||||
content = "x" * 10001
|
content = "x" * 10001
|
||||||
msg = json.dumps({"type": "message", "thread_id": "t1", "content": content})
|
msg = json.dumps({"type": "message", "thread_id": "t1", "content": content})
|
||||||
await dispatch_message(ws, graph, sm, cb, msg)
|
await dispatch_message(ws, ws_ctx, msg)
|
||||||
|
|
||||||
call_data = ws.send_json.call_args[0][0]
|
call_data = ws.send_json.call_args[0][0]
|
||||||
assert call_data["type"] == "error"
|
assert call_data["type"] == "error"
|
||||||
@@ -93,14 +102,13 @@ class TestOversizedMessageHandling:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_content_exactly_10000_chars_is_accepted(self) -> None:
|
async def test_content_exactly_10000_chars_is_accepted(self) -> None:
|
||||||
ws = _make_ws()
|
ws = _make_ws()
|
||||||
graph = _make_graph()
|
|
||||||
sm = SessionManager()
|
sm = SessionManager()
|
||||||
cb = TokenUsageCallbackHandler()
|
ws_ctx = _make_ws_ctx(sm=sm)
|
||||||
|
|
||||||
sm.touch("t1")
|
sm.touch("t1")
|
||||||
content = "x" * 10000
|
content = "x" * 10000
|
||||||
msg = json.dumps({"type": "message", "thread_id": "t1", "content": content})
|
msg = json.dumps({"type": "message", "thread_id": "t1", "content": content})
|
||||||
await dispatch_message(ws, graph, sm, cb, msg)
|
await dispatch_message(ws, ws_ctx, msg)
|
||||||
|
|
||||||
last_call = ws.send_json.call_args[0][0]
|
last_call = ws.send_json.call_args[0][0]
|
||||||
# Should be processed, not an error about length
|
# Should be processed, not an error about length
|
||||||
@@ -110,12 +118,10 @@ class TestOversizedMessageHandling:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_raw_message_over_32kb_returns_error(self) -> None:
|
async def test_raw_message_over_32kb_returns_error(self) -> None:
|
||||||
ws = _make_ws()
|
ws = _make_ws()
|
||||||
graph = _make_graph()
|
ws_ctx = _make_ws_ctx()
|
||||||
sm = SessionManager()
|
|
||||||
cb = TokenUsageCallbackHandler()
|
|
||||||
|
|
||||||
large_msg = "x" * 40_000
|
large_msg = "x" * 40_000
|
||||||
await dispatch_message(ws, graph, sm, cb, large_msg)
|
await dispatch_message(ws, ws_ctx, large_msg)
|
||||||
|
|
||||||
call_data = ws.send_json.call_args[0][0]
|
call_data = ws.send_json.call_args[0][0]
|
||||||
assert call_data["type"] == "error"
|
assert call_data["type"] == "error"
|
||||||
@@ -127,11 +133,9 @@ class TestInvalidJsonHandling:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_invalid_json_returns_error(self) -> None:
|
async def test_invalid_json_returns_error(self) -> None:
|
||||||
ws = _make_ws()
|
ws = _make_ws()
|
||||||
graph = _make_graph()
|
ws_ctx = _make_ws_ctx()
|
||||||
sm = SessionManager()
|
|
||||||
cb = TokenUsageCallbackHandler()
|
|
||||||
|
|
||||||
await dispatch_message(ws, graph, sm, cb, "not valid json {{")
|
await dispatch_message(ws, ws_ctx, "not valid json {{")
|
||||||
|
|
||||||
call_data = ws.send_json.call_args[0][0]
|
call_data = ws.send_json.call_args[0][0]
|
||||||
assert call_data["type"] == "error"
|
assert call_data["type"] == "error"
|
||||||
@@ -140,11 +144,9 @@ class TestInvalidJsonHandling:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_empty_string_returns_json_error(self) -> None:
|
async def test_empty_string_returns_json_error(self) -> None:
|
||||||
ws = _make_ws()
|
ws = _make_ws()
|
||||||
graph = _make_graph()
|
ws_ctx = _make_ws_ctx()
|
||||||
sm = SessionManager()
|
|
||||||
cb = TokenUsageCallbackHandler()
|
|
||||||
|
|
||||||
await dispatch_message(ws, graph, sm, cb, "")
|
await dispatch_message(ws, ws_ctx, "")
|
||||||
|
|
||||||
call_data = ws.send_json.call_args[0][0]
|
call_data = ws.send_json.call_args[0][0]
|
||||||
assert call_data["type"] == "error"
|
assert call_data["type"] == "error"
|
||||||
@@ -152,11 +154,9 @@ class TestInvalidJsonHandling:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_json_array_not_object_returns_error(self) -> None:
|
async def test_json_array_not_object_returns_error(self) -> None:
|
||||||
ws = _make_ws()
|
ws = _make_ws()
|
||||||
graph = _make_graph()
|
ws_ctx = _make_ws_ctx()
|
||||||
sm = SessionManager()
|
|
||||||
cb = TokenUsageCallbackHandler()
|
|
||||||
|
|
||||||
await dispatch_message(ws, graph, sm, cb, '["not", "an", "object"]')
|
await dispatch_message(ws, ws_ctx, '["not", "an", "object"]')
|
||||||
|
|
||||||
call_data = ws.send_json.call_args[0][0]
|
call_data = ws.send_json.call_args[0][0]
|
||||||
assert call_data["type"] == "error"
|
assert call_data["type"] == "error"
|
||||||
@@ -167,17 +167,15 @@ class TestRateLimiting:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_rapid_fire_messages_rate_limited(self) -> None:
|
async def test_rapid_fire_messages_rate_limited(self) -> None:
|
||||||
ws = _make_ws()
|
ws = _make_ws()
|
||||||
_make_graph() # ensure graph factory works, not needed directly
|
|
||||||
sm = SessionManager()
|
sm = SessionManager()
|
||||||
cb = TokenUsageCallbackHandler()
|
|
||||||
|
|
||||||
sm.touch("t1")
|
sm.touch("t1")
|
||||||
|
|
||||||
# Simulate 11 rapid messages (exceeds 10 per 10 seconds limit)
|
# Simulate 11 rapid messages (exceeds 10 per 10 seconds limit)
|
||||||
rate_limit_triggered = False
|
rate_limit_triggered = False
|
||||||
for i in range(11):
|
for i in range(11):
|
||||||
graph2 = _make_graph() # fresh graph each time
|
ws_ctx = _make_ws_ctx(sm=sm)
|
||||||
await dispatch_message(ws, graph2, sm, cb, json.dumps({
|
await dispatch_message(ws, ws_ctx, json.dumps({
|
||||||
"type": "message",
|
"type": "message",
|
||||||
"thread_id": "t1",
|
"thread_id": "t1",
|
||||||
"content": f"message {i}",
|
"content": f"message {i}",
|
||||||
@@ -193,19 +191,18 @@ class TestRateLimiting:
|
|||||||
async def test_different_threads_have_separate_rate_limits(self) -> None:
|
async def test_different_threads_have_separate_rate_limits(self) -> None:
|
||||||
ws = _make_ws()
|
ws = _make_ws()
|
||||||
sm = SessionManager()
|
sm = SessionManager()
|
||||||
cb = TokenUsageCallbackHandler()
|
|
||||||
|
|
||||||
sm.touch("t1")
|
sm.touch("t1")
|
||||||
sm.touch("t2")
|
sm.touch("t2")
|
||||||
|
|
||||||
# Send 5 messages on t1 and 5 on t2 -- neither should be rate limited
|
# Send 5 messages on t1 and 5 on t2 -- neither should be rate limited
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
graph1 = _make_graph()
|
ws_ctx1 = _make_ws_ctx(sm=sm)
|
||||||
graph2 = _make_graph()
|
ws_ctx2 = _make_ws_ctx(sm=sm)
|
||||||
await dispatch_message(ws, graph1, sm, cb, json.dumps({
|
await dispatch_message(ws, ws_ctx1, json.dumps({
|
||||||
"type": "message", "thread_id": "t1", "content": f"msg {i}",
|
"type": "message", "thread_id": "t1", "content": f"msg {i}",
|
||||||
}))
|
}))
|
||||||
await dispatch_message(ws, graph2, sm, cb, json.dumps({
|
await dispatch_message(ws, ws_ctx2, json.dumps({
|
||||||
"type": "message", "thread_id": "t2", "content": f"msg {i}",
|
"type": "message", "thread_id": "t2", "content": f"msg {i}",
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
|||||||
142
backend/tests/unit/test_error_responses.py
Normal file
142
backend/tests/unit/test_error_responses.py
Normal file
@@ -0,0 +1,142 @@
|
|||||||
|
"""Tests for standardized error response envelope format."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import FastAPI, HTTPException
|
||||||
|
from fastapi.exceptions import RequestValidationError
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from app.api_utils import envelope
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
def _build_test_app() -> FastAPI:
|
||||||
|
"""Build a minimal FastAPI app with the standard exception handlers."""
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
@app.exception_handler(HTTPException)
|
||||||
|
async def http_exception_handler(request, exc): # type: ignore[no-untyped-def]
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=exc.status_code,
|
||||||
|
content=envelope(None, success=False, error=exc.detail),
|
||||||
|
)
|
||||||
|
|
||||||
|
@app.exception_handler(RequestValidationError)
|
||||||
|
async def validation_exception_handler(request, exc): # type: ignore[no-untyped-def]
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=422,
|
||||||
|
content=envelope(None, success=False, error=str(exc)),
|
||||||
|
)
|
||||||
|
|
||||||
|
@app.exception_handler(Exception)
|
||||||
|
async def general_exception_handler(request, exc): # type: ignore[no-untyped-def]
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=500,
|
||||||
|
content=envelope(None, success=False, error="Internal server error"),
|
||||||
|
)
|
||||||
|
|
||||||
|
class ItemRequest(BaseModel):
|
||||||
|
name: str = Field(..., min_length=1)
|
||||||
|
count: int = Field(..., gt=0)
|
||||||
|
|
||||||
|
@app.get("/items/{item_id}")
|
||||||
|
def get_item(item_id: int) -> dict:
|
||||||
|
if item_id == 0:
|
||||||
|
raise HTTPException(status_code=400, detail="Invalid item ID")
|
||||||
|
if item_id == 999:
|
||||||
|
raise HTTPException(status_code=404, detail="Item not found")
|
||||||
|
if item_id == 401:
|
||||||
|
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||||
|
return envelope({"id": item_id, "name": "test"})
|
||||||
|
|
||||||
|
@app.post("/items")
|
||||||
|
def create_item(item: ItemRequest) -> dict:
|
||||||
|
return envelope({"id": 1, "name": item.name})
|
||||||
|
|
||||||
|
@app.get("/crash")
|
||||||
|
def crash() -> dict:
|
||||||
|
msg = "unexpected failure"
|
||||||
|
raise RuntimeError(msg)
|
||||||
|
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
class TestHttpExceptionEnvelope:
|
||||||
|
"""HTTPException responses use the standard envelope format."""
|
||||||
|
|
||||||
|
def test_400_returns_envelope(self) -> None:
|
||||||
|
app = _build_test_app()
|
||||||
|
with TestClient(app, raise_server_exceptions=False) as client:
|
||||||
|
resp = client.get("/items/0")
|
||||||
|
assert resp.status_code == 400
|
||||||
|
body = resp.json()
|
||||||
|
assert body["success"] is False
|
||||||
|
assert body["data"] is None
|
||||||
|
assert body["error"] == "Invalid item ID"
|
||||||
|
|
||||||
|
def test_404_returns_envelope(self) -> None:
|
||||||
|
app = _build_test_app()
|
||||||
|
with TestClient(app, raise_server_exceptions=False) as client:
|
||||||
|
resp = client.get("/items/999")
|
||||||
|
assert resp.status_code == 404
|
||||||
|
body = resp.json()
|
||||||
|
assert body["success"] is False
|
||||||
|
assert body["data"] is None
|
||||||
|
assert body["error"] == "Item not found"
|
||||||
|
|
||||||
|
def test_401_returns_envelope(self) -> None:
|
||||||
|
app = _build_test_app()
|
||||||
|
with TestClient(app, raise_server_exceptions=False) as client:
|
||||||
|
resp = client.get("/items/401")
|
||||||
|
assert resp.status_code == 401
|
||||||
|
body = resp.json()
|
||||||
|
assert body["success"] is False
|
||||||
|
assert body["data"] is None
|
||||||
|
assert body["error"] == "Not authenticated"
|
||||||
|
|
||||||
|
|
||||||
|
class TestValidationErrorEnvelope:
|
||||||
|
"""Validation errors return 422 with envelope format."""
|
||||||
|
|
||||||
|
def test_validation_error_returns_envelope(self) -> None:
|
||||||
|
app = _build_test_app()
|
||||||
|
with TestClient(app, raise_server_exceptions=False) as client:
|
||||||
|
resp = client.post("/items", json={"name": "", "count": -1})
|
||||||
|
assert resp.status_code == 422
|
||||||
|
body = resp.json()
|
||||||
|
assert body["success"] is False
|
||||||
|
assert body["data"] is None
|
||||||
|
assert isinstance(body["error"], str)
|
||||||
|
assert len(body["error"]) > 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestGeneralExceptionEnvelope:
|
||||||
|
"""Unhandled exceptions return 500 with safe envelope."""
|
||||||
|
|
||||||
|
def test_unhandled_exception_returns_500_envelope(self) -> None:
|
||||||
|
app = _build_test_app()
|
||||||
|
with TestClient(app, raise_server_exceptions=False) as client:
|
||||||
|
resp = client.get("/crash")
|
||||||
|
assert resp.status_code == 500
|
||||||
|
body = resp.json()
|
||||||
|
assert body["success"] is False
|
||||||
|
assert body["data"] is None
|
||||||
|
assert body["error"] == "Internal server error"
|
||||||
|
|
||||||
|
|
||||||
|
class TestSuccessResponseUnchanged:
|
||||||
|
"""Success responses still work normally."""
|
||||||
|
|
||||||
|
def test_success_returns_envelope(self) -> None:
|
||||||
|
app = _build_test_app()
|
||||||
|
with TestClient(app) as client:
|
||||||
|
resp = client.get("/items/42")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
body = resp.json()
|
||||||
|
assert body["success"] is True
|
||||||
|
assert body["data"]["id"] == 42
|
||||||
|
assert body["error"] is None
|
||||||
@@ -6,8 +6,10 @@ from typing import TYPE_CHECKING
|
|||||||
from unittest.mock import AsyncMock, MagicMock
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from langgraph.checkpoint.memory import InMemorySaver
|
||||||
|
|
||||||
from app.graph import build_agent_nodes, build_graph, classify_intent
|
from app.graph import build_agent_nodes, build_graph
|
||||||
|
from app.graph_context import GraphContext
|
||||||
from app.intent import ClassificationResult, IntentTarget
|
from app.intent import ClassificationResult, IntentTarget
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -34,41 +36,43 @@ class TestBuildGraph:
|
|||||||
mock_llm = MagicMock()
|
mock_llm = MagicMock()
|
||||||
mock_llm.bind_tools = MagicMock(return_value=mock_llm)
|
mock_llm.bind_tools = MagicMock(return_value=mock_llm)
|
||||||
mock_llm.with_structured_output = MagicMock(return_value=mock_llm)
|
mock_llm.with_structured_output = MagicMock(return_value=mock_llm)
|
||||||
mock_checkpointer = AsyncMock()
|
checkpointer = InMemorySaver()
|
||||||
|
|
||||||
graph = build_graph(sample_registry, mock_llm, mock_checkpointer)
|
graph_ctx = build_graph(sample_registry, mock_llm, checkpointer)
|
||||||
assert graph is not None
|
assert graph_ctx is not None
|
||||||
|
assert graph_ctx.graph is not None
|
||||||
|
|
||||||
def test_graph_has_classifier_attached(self, sample_registry: AgentRegistry) -> None:
|
def test_graph_has_classifier_attached(self, sample_registry: AgentRegistry) -> None:
|
||||||
mock_llm = MagicMock()
|
mock_llm = MagicMock()
|
||||||
mock_llm.bind_tools = MagicMock(return_value=mock_llm)
|
mock_llm.bind_tools = MagicMock(return_value=mock_llm)
|
||||||
mock_llm.with_structured_output = MagicMock(return_value=mock_llm)
|
mock_llm.with_structured_output = MagicMock(return_value=mock_llm)
|
||||||
mock_checkpointer = AsyncMock()
|
checkpointer = InMemorySaver()
|
||||||
mock_classifier = MagicMock()
|
mock_classifier = MagicMock()
|
||||||
|
|
||||||
graph = build_graph(
|
graph_ctx = build_graph(
|
||||||
sample_registry, mock_llm, mock_checkpointer, intent_classifier=mock_classifier
|
sample_registry, mock_llm, checkpointer, intent_classifier=mock_classifier
|
||||||
)
|
)
|
||||||
assert graph.intent_classifier is mock_classifier
|
assert graph_ctx.intent_classifier is mock_classifier
|
||||||
assert graph.agent_registry is sample_registry
|
assert graph_ctx.registry is sample_registry
|
||||||
|
|
||||||
def test_graph_without_classifier(self, sample_registry: AgentRegistry) -> None:
|
def test_graph_without_classifier(self, sample_registry: AgentRegistry) -> None:
|
||||||
mock_llm = MagicMock()
|
mock_llm = MagicMock()
|
||||||
mock_llm.bind_tools = MagicMock(return_value=mock_llm)
|
mock_llm.bind_tools = MagicMock(return_value=mock_llm)
|
||||||
mock_llm.with_structured_output = MagicMock(return_value=mock_llm)
|
mock_llm.with_structured_output = MagicMock(return_value=mock_llm)
|
||||||
mock_checkpointer = AsyncMock()
|
checkpointer = InMemorySaver()
|
||||||
|
|
||||||
graph = build_graph(sample_registry, mock_llm, mock_checkpointer)
|
graph_ctx = build_graph(sample_registry, mock_llm, checkpointer)
|
||||||
assert graph.intent_classifier is None
|
assert graph_ctx.intent_classifier is None
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.unit
|
@pytest.mark.unit
|
||||||
class TestClassifyIntent:
|
class TestClassifyIntent:
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_returns_none_without_classifier(self) -> None:
|
async def test_returns_none_without_classifier(self) -> None:
|
||||||
graph = MagicMock()
|
mock_registry = MagicMock()
|
||||||
graph.intent_classifier = None
|
mock_registry.list_agents = MagicMock(return_value=())
|
||||||
result = await classify_intent(graph, "hello")
|
graph_ctx = GraphContext(graph=MagicMock(), registry=mock_registry, intent_classifier=None)
|
||||||
|
result = await graph_ctx.classify_intent("hello")
|
||||||
assert result is None
|
assert result is None
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -79,11 +83,12 @@ class TestClassifyIntent:
|
|||||||
mock_classifier = AsyncMock()
|
mock_classifier = AsyncMock()
|
||||||
mock_classifier.classify = AsyncMock(return_value=expected)
|
mock_classifier.classify = AsyncMock(return_value=expected)
|
||||||
|
|
||||||
graph = MagicMock()
|
mock_registry = MagicMock()
|
||||||
graph.intent_classifier = mock_classifier
|
mock_registry.list_agents = MagicMock(return_value=())
|
||||||
graph.agent_registry = MagicMock()
|
graph_ctx = GraphContext(
|
||||||
graph.agent_registry.list_agents = MagicMock(return_value=())
|
graph=MagicMock(), registry=mock_registry, intent_classifier=mock_classifier,
|
||||||
|
)
|
||||||
|
|
||||||
result = await classify_intent(graph, "check order")
|
result = await graph_ctx.classify_intent("check order")
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert result.intents[0].agent_name == "order_lookup"
|
assert result.intents[0].agent_name == "order_lookup"
|
||||||
|
|||||||
86
backend/tests/unit/test_interrupt_cleanup.py
Normal file
86
backend/tests/unit/test_interrupt_cleanup.py
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
"""Tests for the interrupt cleanup background loop in main.py."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.main import _interrupt_cleanup_loop
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cleanup_loop_calls_cleanup_expired() -> None:
|
||||||
|
"""The loop should call cleanup_expired after each sleep interval."""
|
||||||
|
manager = MagicMock()
|
||||||
|
manager.cleanup_expired.return_value = ()
|
||||||
|
|
||||||
|
call_count = 0
|
||||||
|
original_sleep = asyncio.sleep
|
||||||
|
|
||||||
|
async def _fake_sleep(seconds: float) -> None:
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
if call_count >= 2:
|
||||||
|
raise asyncio.CancelledError
|
||||||
|
await original_sleep(0)
|
||||||
|
|
||||||
|
with patch("app.main.asyncio.sleep", side_effect=_fake_sleep):
|
||||||
|
with pytest.raises(asyncio.CancelledError):
|
||||||
|
await _interrupt_cleanup_loop(manager, interval=60)
|
||||||
|
|
||||||
|
assert manager.cleanup_expired.call_count >= 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cleanup_loop_survives_exceptions() -> None:
|
||||||
|
"""The loop should not die when cleanup_expired raises an exception."""
|
||||||
|
manager = MagicMock()
|
||||||
|
manager.cleanup_expired.side_effect = [RuntimeError("db gone"), ()]
|
||||||
|
|
||||||
|
call_count = 0
|
||||||
|
original_sleep = asyncio.sleep
|
||||||
|
|
||||||
|
async def _fake_sleep(seconds: float) -> None:
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
if call_count >= 3:
|
||||||
|
raise asyncio.CancelledError
|
||||||
|
await original_sleep(0)
|
||||||
|
|
||||||
|
with patch("app.main.asyncio.sleep", side_effect=_fake_sleep):
|
||||||
|
with pytest.raises(asyncio.CancelledError):
|
||||||
|
await _interrupt_cleanup_loop(manager, interval=60)
|
||||||
|
|
||||||
|
# Should have been called twice: once raising, once returning ()
|
||||||
|
assert manager.cleanup_expired.call_count == 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cleanup_loop_logs_expired_count(capsys: pytest.CaptureFixture[str]) -> None:
|
||||||
|
"""The loop should log when expired interrupts are found."""
|
||||||
|
fake_record = MagicMock()
|
||||||
|
manager = MagicMock()
|
||||||
|
manager.cleanup_expired.return_value = (fake_record, fake_record)
|
||||||
|
|
||||||
|
call_count = 0
|
||||||
|
original_sleep = asyncio.sleep
|
||||||
|
|
||||||
|
async def _fake_sleep(seconds: float) -> None:
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
if call_count >= 2:
|
||||||
|
raise asyncio.CancelledError
|
||||||
|
await original_sleep(0)
|
||||||
|
|
||||||
|
with patch("app.main.asyncio.sleep", side_effect=_fake_sleep):
|
||||||
|
with pytest.raises(asyncio.CancelledError):
|
||||||
|
await _interrupt_cleanup_loop(manager, interval=60)
|
||||||
|
|
||||||
|
captured = capsys.readouterr()
|
||||||
|
assert "2 expired interrupt" in captured.out
|
||||||
20
backend/tests/unit/test_logging_config.py
Normal file
20
backend/tests/unit/test_logging_config.py
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
"""Tests for structured logging configuration."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.logging_config import configure_logging
|
||||||
|
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
def test_configure_logging_console_mode() -> None:
|
||||||
|
"""Console mode configures without error."""
|
||||||
|
configure_logging("console")
|
||||||
|
|
||||||
|
|
||||||
|
def test_configure_logging_json_mode() -> None:
|
||||||
|
"""JSON mode configures without error."""
|
||||||
|
configure_logging("json")
|
||||||
@@ -13,7 +13,7 @@ class TestMainModule:
|
|||||||
assert app.title == "Smart Support"
|
assert app.title == "Smart Support"
|
||||||
|
|
||||||
def test_app_version(self) -> None:
|
def test_app_version(self) -> None:
|
||||||
assert app.version == "0.5.0"
|
assert app.version == "0.6.0"
|
||||||
|
|
||||||
def test_agents_yaml_path_exists(self) -> None:
|
def test_agents_yaml_path_exists(self) -> None:
|
||||||
assert AGENTS_YAML.name == "agents.yaml"
|
assert AGENTS_YAML.name == "agents.yaml"
|
||||||
@@ -36,7 +36,7 @@ class TestMainModule:
|
|||||||
|
|
||||||
def test_health_route_registered(self) -> None:
|
def test_health_route_registered(self) -> None:
|
||||||
routes = [r.path for r in app.routes if hasattr(r, "path")]
|
routes = [r.path for r in app.routes if hasattr(r, "path")]
|
||||||
assert "/api/health" in routes
|
assert "/api/v1/health" in routes
|
||||||
|
|
||||||
def test_app_version_is_0_5_0(self) -> None:
|
def test_app_version_is_0_5_0(self) -> None:
|
||||||
assert app.version == "0.5.0"
|
assert app.version == "0.6.0"
|
||||||
|
|||||||
96
backend/tests/unit/test_safety.py
Normal file
96
backend/tests/unit/test_safety.py
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
"""Tests for app.safety module -- confirmation rules and MCP error taxonomy."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.safety import (
|
||||||
|
classify_mcp_error,
|
||||||
|
is_retryable,
|
||||||
|
max_retries,
|
||||||
|
requires_confirmation,
|
||||||
|
)
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
class TestRequiresConfirmation:
|
||||||
|
def test_read_agent_no_override(self) -> None:
|
||||||
|
result = requires_confirmation(agent_permission="read")
|
||||||
|
assert result.requires_confirmation is False
|
||||||
|
|
||||||
|
def test_write_agent_no_override(self) -> None:
|
||||||
|
result = requires_confirmation(agent_permission="write")
|
||||||
|
assert result.requires_confirmation is True
|
||||||
|
|
||||||
|
def test_interrupt_override_true(self) -> None:
|
||||||
|
result = requires_confirmation(
|
||||||
|
agent_permission="read", needs_interrupt=True,
|
||||||
|
)
|
||||||
|
assert result.requires_confirmation is True
|
||||||
|
|
||||||
|
def test_interrupt_override_false(self) -> None:
|
||||||
|
result = requires_confirmation(
|
||||||
|
agent_permission="write", needs_interrupt=False,
|
||||||
|
)
|
||||||
|
assert result.requires_confirmation is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestClassifyMcpError:
|
||||||
|
@pytest.mark.parametrize("code", [408, 429, 500, 502, 503, 504])
|
||||||
|
def test_transient_status_codes(self, code: int) -> None:
|
||||||
|
assert classify_mcp_error(status_code=code) == "transient"
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("code", [401, 403])
|
||||||
|
def test_auth_status_codes(self, code: int) -> None:
|
||||||
|
assert classify_mcp_error(status_code=code) == "auth"
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("code", [400, 404, 422])
|
||||||
|
def test_validation_status_codes(self, code: int) -> None:
|
||||||
|
assert classify_mcp_error(status_code=code) == "validation"
|
||||||
|
|
||||||
|
def test_unknown_status_code(self) -> None:
|
||||||
|
assert classify_mcp_error(status_code=200) == "unknown"
|
||||||
|
|
||||||
|
def test_timeout_message(self) -> None:
|
||||||
|
assert classify_mcp_error(error_message="Connection timed out") == "transient"
|
||||||
|
|
||||||
|
def test_rate_limit_message(self) -> None:
|
||||||
|
assert classify_mcp_error(error_message="Rate limit exceeded") == "transient"
|
||||||
|
|
||||||
|
def test_unauthorized_message(self) -> None:
|
||||||
|
assert classify_mcp_error(error_message="Unauthorized access") == "auth"
|
||||||
|
|
||||||
|
def test_invalid_message(self) -> None:
|
||||||
|
assert classify_mcp_error(error_message="Invalid parameter") == "validation"
|
||||||
|
|
||||||
|
def test_unknown_message(self) -> None:
|
||||||
|
assert classify_mcp_error(error_message="Something happened") == "unknown"
|
||||||
|
|
||||||
|
def test_status_code_takes_precedence_over_message(self) -> None:
|
||||||
|
# 429 is transient by code; message would classify as validation
|
||||||
|
assert classify_mcp_error(status_code=429, error_message="invalid param") == "transient"
|
||||||
|
|
||||||
|
def test_non_classified_status_falls_through_to_message(self) -> None:
|
||||||
|
# 200 is not in any status set, so message classification takes over
|
||||||
|
assert classify_mcp_error(status_code=200, error_message="timed out") == "transient"
|
||||||
|
|
||||||
|
def test_no_args_returns_unknown(self) -> None:
|
||||||
|
assert classify_mcp_error() == "unknown"
|
||||||
|
|
||||||
|
|
||||||
|
class TestRetryPolicy:
|
||||||
|
def test_transient_is_retryable(self) -> None:
|
||||||
|
assert is_retryable("transient") is True
|
||||||
|
|
||||||
|
def test_validation_not_retryable(self) -> None:
|
||||||
|
assert is_retryable("validation") is False
|
||||||
|
|
||||||
|
def test_auth_not_retryable(self) -> None:
|
||||||
|
assert is_retryable("auth") is False
|
||||||
|
|
||||||
|
def test_unknown_not_retryable(self) -> None:
|
||||||
|
assert is_retryable("unknown") is False
|
||||||
|
|
||||||
|
def test_max_retries_value(self) -> None:
|
||||||
|
assert max_retries() == 3
|
||||||
@@ -8,8 +8,10 @@ from unittest.mock import AsyncMock, MagicMock
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from app.callbacks import TokenUsageCallbackHandler
|
from app.callbacks import TokenUsageCallbackHandler
|
||||||
|
from app.graph_context import GraphContext
|
||||||
from app.interrupt_manager import InterruptManager
|
from app.interrupt_manager import InterruptManager
|
||||||
from app.session_manager import SessionManager
|
from app.session_manager import SessionManager
|
||||||
|
from app.ws_context import WebSocketContext
|
||||||
from app.ws_handler import (
|
from app.ws_handler import (
|
||||||
_extract_interrupt,
|
_extract_interrupt,
|
||||||
_has_interrupt,
|
_has_interrupt,
|
||||||
@@ -25,18 +27,42 @@ def _make_ws() -> AsyncMock:
|
|||||||
return ws
|
return ws
|
||||||
|
|
||||||
|
|
||||||
def _make_graph() -> AsyncMock:
|
def _make_graph() -> MagicMock:
|
||||||
graph = AsyncMock()
|
graph = AsyncMock()
|
||||||
graph.astream = MagicMock(return_value=AsyncIterHelper([]))
|
graph.astream = MagicMock(return_value=AsyncIterHelper([]))
|
||||||
state = MagicMock()
|
state = MagicMock()
|
||||||
state.tasks = ()
|
state.tasks = ()
|
||||||
graph.aget_state = AsyncMock(return_value=state)
|
graph.aget_state = AsyncMock(return_value=state)
|
||||||
# Phase 2: graph needs intent_classifier and agent_registry attrs
|
|
||||||
graph.intent_classifier = None
|
|
||||||
graph.agent_registry = None
|
|
||||||
return graph
|
return graph
|
||||||
|
|
||||||
|
|
||||||
|
def _make_graph_ctx(graph: MagicMock | None = None) -> GraphContext:
|
||||||
|
g = graph or _make_graph()
|
||||||
|
registry = MagicMock()
|
||||||
|
registry.list_agents = MagicMock(return_value=())
|
||||||
|
return GraphContext(graph=g, registry=registry, intent_classifier=None)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_ws_ctx(
|
||||||
|
graph_ctx: GraphContext | None = None,
|
||||||
|
sm: SessionManager | None = None,
|
||||||
|
cb: TokenUsageCallbackHandler | None = None,
|
||||||
|
interrupt_manager: InterruptManager | None = None,
|
||||||
|
analytics_recorder=None,
|
||||||
|
conversation_tracker=None,
|
||||||
|
pool=None,
|
||||||
|
) -> WebSocketContext:
|
||||||
|
return WebSocketContext(
|
||||||
|
graph_ctx=graph_ctx or _make_graph_ctx(),
|
||||||
|
session_manager=sm or SessionManager(),
|
||||||
|
callback_handler=cb or TokenUsageCallbackHandler(),
|
||||||
|
interrupt_manager=interrupt_manager,
|
||||||
|
analytics_recorder=analytics_recorder,
|
||||||
|
conversation_tracker=conversation_tracker,
|
||||||
|
pool=pool,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class AsyncIterHelper:
|
class AsyncIterHelper:
|
||||||
"""Helper to make a list behave as an async iterator."""
|
"""Helper to make a list behave as an async iterator."""
|
||||||
|
|
||||||
@@ -57,11 +83,9 @@ class TestDispatchMessage:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_invalid_json(self) -> None:
|
async def test_invalid_json(self) -> None:
|
||||||
ws = _make_ws()
|
ws = _make_ws()
|
||||||
graph = _make_graph()
|
ws_ctx = _make_ws_ctx()
|
||||||
sm = SessionManager()
|
|
||||||
cb = TokenUsageCallbackHandler()
|
|
||||||
|
|
||||||
await dispatch_message(ws, graph, sm, cb, "not json")
|
await dispatch_message(ws, ws_ctx, "not json")
|
||||||
ws.send_json.assert_awaited_once()
|
ws.send_json.assert_awaited_once()
|
||||||
call_data = ws.send_json.call_args[0][0]
|
call_data = ws.send_json.call_args[0][0]
|
||||||
assert call_data["type"] == "error"
|
assert call_data["type"] == "error"
|
||||||
@@ -70,12 +94,10 @@ class TestDispatchMessage:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_missing_thread_id(self) -> None:
|
async def test_missing_thread_id(self) -> None:
|
||||||
ws = _make_ws()
|
ws = _make_ws()
|
||||||
graph = _make_graph()
|
ws_ctx = _make_ws_ctx()
|
||||||
sm = SessionManager()
|
|
||||||
cb = TokenUsageCallbackHandler()
|
|
||||||
|
|
||||||
msg = json.dumps({"type": "message", "content": "hello"})
|
msg = json.dumps({"type": "message", "content": "hello"})
|
||||||
await dispatch_message(ws, graph, sm, cb, msg)
|
await dispatch_message(ws, ws_ctx, msg)
|
||||||
call_data = ws.send_json.call_args[0][0]
|
call_data = ws.send_json.call_args[0][0]
|
||||||
assert call_data["type"] == "error"
|
assert call_data["type"] == "error"
|
||||||
assert "thread_id" in call_data["message"]
|
assert "thread_id" in call_data["message"]
|
||||||
@@ -83,24 +105,20 @@ class TestDispatchMessage:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_missing_content(self) -> None:
|
async def test_missing_content(self) -> None:
|
||||||
ws = _make_ws()
|
ws = _make_ws()
|
||||||
graph = _make_graph()
|
ws_ctx = _make_ws_ctx()
|
||||||
sm = SessionManager()
|
|
||||||
cb = TokenUsageCallbackHandler()
|
|
||||||
|
|
||||||
msg = json.dumps({"type": "message", "thread_id": "t1"})
|
msg = json.dumps({"type": "message", "thread_id": "t1"})
|
||||||
await dispatch_message(ws, graph, sm, cb, msg)
|
await dispatch_message(ws, ws_ctx, msg)
|
||||||
call_data = ws.send_json.call_args[0][0]
|
call_data = ws.send_json.call_args[0][0]
|
||||||
assert call_data["type"] == "error"
|
assert call_data["type"] == "error"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_unknown_message_type(self) -> None:
|
async def test_unknown_message_type(self) -> None:
|
||||||
ws = _make_ws()
|
ws = _make_ws()
|
||||||
graph = _make_graph()
|
ws_ctx = _make_ws_ctx()
|
||||||
sm = SessionManager()
|
|
||||||
cb = TokenUsageCallbackHandler()
|
|
||||||
|
|
||||||
msg = json.dumps({"type": "unknown", "thread_id": "t1"})
|
msg = json.dumps({"type": "unknown", "thread_id": "t1"})
|
||||||
await dispatch_message(ws, graph, sm, cb, msg)
|
await dispatch_message(ws, ws_ctx, msg)
|
||||||
call_data = ws.send_json.call_args[0][0]
|
call_data = ws.send_json.call_args[0][0]
|
||||||
assert call_data["type"] == "error"
|
assert call_data["type"] == "error"
|
||||||
assert "Unknown" in call_data["message"]
|
assert "Unknown" in call_data["message"]
|
||||||
@@ -108,12 +126,10 @@ class TestDispatchMessage:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_message_too_large(self) -> None:
|
async def test_message_too_large(self) -> None:
|
||||||
ws = _make_ws()
|
ws = _make_ws()
|
||||||
graph = _make_graph()
|
ws_ctx = _make_ws_ctx()
|
||||||
sm = SessionManager()
|
|
||||||
cb = TokenUsageCallbackHandler()
|
|
||||||
|
|
||||||
large_msg = "x" * 40_000
|
large_msg = "x" * 40_000
|
||||||
await dispatch_message(ws, graph, sm, cb, large_msg)
|
await dispatch_message(ws, ws_ctx, large_msg)
|
||||||
call_data = ws.send_json.call_args[0][0]
|
call_data = ws.send_json.call_args[0][0]
|
||||||
assert call_data["type"] == "error"
|
assert call_data["type"] == "error"
|
||||||
assert "too large" in call_data["message"].lower()
|
assert "too large" in call_data["message"].lower()
|
||||||
@@ -121,12 +137,10 @@ class TestDispatchMessage:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_invalid_thread_id_format(self) -> None:
|
async def test_invalid_thread_id_format(self) -> None:
|
||||||
ws = _make_ws()
|
ws = _make_ws()
|
||||||
graph = _make_graph()
|
ws_ctx = _make_ws_ctx()
|
||||||
sm = SessionManager()
|
|
||||||
cb = TokenUsageCallbackHandler()
|
|
||||||
|
|
||||||
msg = json.dumps({"type": "message", "thread_id": "../../../etc", "content": "hi"})
|
msg = json.dumps({"type": "message", "thread_id": "../../../etc", "content": "hi"})
|
||||||
await dispatch_message(ws, graph, sm, cb, msg)
|
await dispatch_message(ws, ws_ctx, msg)
|
||||||
call_data = ws.send_json.call_args[0][0]
|
call_data = ws.send_json.call_args[0][0]
|
||||||
assert call_data["type"] == "error"
|
assert call_data["type"] == "error"
|
||||||
assert "thread_id" in call_data["message"].lower()
|
assert "thread_id" in call_data["message"].lower()
|
||||||
@@ -134,12 +148,10 @@ class TestDispatchMessage:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_content_too_long(self) -> None:
|
async def test_content_too_long(self) -> None:
|
||||||
ws = _make_ws()
|
ws = _make_ws()
|
||||||
graph = _make_graph()
|
ws_ctx = _make_ws_ctx()
|
||||||
sm = SessionManager()
|
|
||||||
cb = TokenUsageCallbackHandler()
|
|
||||||
|
|
||||||
msg = json.dumps({"type": "message", "thread_id": "t1", "content": "x" * 10001})
|
msg = json.dumps({"type": "message", "thread_id": "t1", "content": "x" * 10001})
|
||||||
await dispatch_message(ws, graph, sm, cb, msg)
|
await dispatch_message(ws, ws_ctx, msg)
|
||||||
call_data = ws.send_json.call_args[0][0]
|
call_data = ws.send_json.call_args[0][0]
|
||||||
assert call_data["type"] == "error"
|
assert call_data["type"] == "error"
|
||||||
assert "too long" in call_data["message"].lower()
|
assert "too long" in call_data["message"].lower()
|
||||||
@@ -147,14 +159,13 @@ class TestDispatchMessage:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_dispatch_with_interrupt_manager(self) -> None:
|
async def test_dispatch_with_interrupt_manager(self) -> None:
|
||||||
ws = _make_ws()
|
ws = _make_ws()
|
||||||
graph = _make_graph()
|
|
||||||
sm = SessionManager()
|
sm = SessionManager()
|
||||||
cb = TokenUsageCallbackHandler()
|
|
||||||
im = InterruptManager()
|
im = InterruptManager()
|
||||||
|
ws_ctx = _make_ws_ctx(sm=sm, interrupt_manager=im)
|
||||||
|
|
||||||
sm.touch("t1")
|
sm.touch("t1")
|
||||||
msg = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"})
|
msg = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"})
|
||||||
await dispatch_message(ws, graph, sm, cb, msg, interrupt_manager=im)
|
await dispatch_message(ws, ws_ctx, msg)
|
||||||
last_call = ws.send_json.call_args[0][0]
|
last_call = ws.send_json.call_args[0][0]
|
||||||
assert last_call["type"] == "message_complete"
|
assert last_call["type"] == "message_complete"
|
||||||
|
|
||||||
@@ -164,11 +175,14 @@ class TestHandleUserMessage:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_expired_session(self) -> None:
|
async def test_expired_session(self) -> None:
|
||||||
ws = _make_ws()
|
ws = _make_ws()
|
||||||
graph = _make_graph()
|
graph_ctx = _make_graph_ctx()
|
||||||
sm = SessionManager(session_ttl_seconds=0)
|
sm = SessionManager(session_ttl_seconds=0)
|
||||||
cb = TokenUsageCallbackHandler()
|
cb = TokenUsageCallbackHandler()
|
||||||
|
|
||||||
await handle_user_message(ws, graph, sm, cb, "t1", "hello")
|
# First call creates the session (TTL=0)
|
||||||
|
await handle_user_message(ws, graph_ctx, sm, cb, "t1", "hello")
|
||||||
|
# Second call finds it expired
|
||||||
|
await handle_user_message(ws, graph_ctx, sm, cb, "t1", "hello again")
|
||||||
call_data = ws.send_json.call_args[0][0]
|
call_data = ws.send_json.call_args[0][0]
|
||||||
assert call_data["type"] == "error"
|
assert call_data["type"] == "error"
|
||||||
assert "expired" in call_data["message"].lower()
|
assert "expired" in call_data["message"].lower()
|
||||||
@@ -176,12 +190,12 @@ class TestHandleUserMessage:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_successful_message(self) -> None:
|
async def test_successful_message(self) -> None:
|
||||||
ws = _make_ws()
|
ws = _make_ws()
|
||||||
graph = _make_graph()
|
graph_ctx = _make_graph_ctx()
|
||||||
sm = SessionManager()
|
sm = SessionManager()
|
||||||
cb = TokenUsageCallbackHandler()
|
cb = TokenUsageCallbackHandler()
|
||||||
|
|
||||||
sm.touch("t1")
|
sm.touch("t1")
|
||||||
await handle_user_message(ws, graph, sm, cb, "t1", "hello")
|
await handle_user_message(ws, graph_ctx, sm, cb, "t1", "hello")
|
||||||
last_call = ws.send_json.call_args[0][0]
|
last_call = ws.send_json.call_args[0][0]
|
||||||
assert last_call["type"] == "message_complete"
|
assert last_call["type"] == "message_complete"
|
||||||
|
|
||||||
@@ -190,13 +204,12 @@ class TestHandleUserMessage:
|
|||||||
ws = _make_ws()
|
ws = _make_ws()
|
||||||
graph = AsyncMock()
|
graph = AsyncMock()
|
||||||
graph.astream = MagicMock(side_effect=RuntimeError("boom"))
|
graph.astream = MagicMock(side_effect=RuntimeError("boom"))
|
||||||
graph.intent_classifier = None
|
graph_ctx = _make_graph_ctx(graph=graph)
|
||||||
graph.agent_registry = None
|
|
||||||
sm = SessionManager()
|
sm = SessionManager()
|
||||||
cb = TokenUsageCallbackHandler()
|
cb = TokenUsageCallbackHandler()
|
||||||
|
|
||||||
sm.touch("t1")
|
sm.touch("t1")
|
||||||
await handle_user_message(ws, graph, sm, cb, "t1", "hello")
|
await handle_user_message(ws, graph_ctx, sm, cb, "t1", "hello")
|
||||||
call_data = ws.send_json.call_args[0][0]
|
call_data = ws.send_json.call_args[0][0]
|
||||||
assert call_data["type"] == "error"
|
assert call_data["type"] == "error"
|
||||||
|
|
||||||
@@ -204,8 +217,6 @@ class TestHandleUserMessage:
|
|||||||
async def test_interrupt_registered_with_manager(self) -> None:
|
async def test_interrupt_registered_with_manager(self) -> None:
|
||||||
ws = _make_ws()
|
ws = _make_ws()
|
||||||
graph = AsyncMock()
|
graph = AsyncMock()
|
||||||
graph.intent_classifier = None
|
|
||||||
graph.agent_registry = None
|
|
||||||
graph.astream = MagicMock(return_value=AsyncIterHelper([]))
|
graph.astream = MagicMock(return_value=AsyncIterHelper([]))
|
||||||
|
|
||||||
# Simulate interrupt in state
|
# Simulate interrupt in state
|
||||||
@@ -217,13 +228,14 @@ class TestHandleUserMessage:
|
|||||||
state.tasks = (task,)
|
state.tasks = (task,)
|
||||||
graph.aget_state = AsyncMock(return_value=state)
|
graph.aget_state = AsyncMock(return_value=state)
|
||||||
|
|
||||||
|
graph_ctx = _make_graph_ctx(graph=graph)
|
||||||
sm = SessionManager()
|
sm = SessionManager()
|
||||||
cb = TokenUsageCallbackHandler()
|
cb = TokenUsageCallbackHandler()
|
||||||
im = InterruptManager()
|
im = InterruptManager()
|
||||||
|
|
||||||
sm.touch("t1")
|
sm.touch("t1")
|
||||||
await handle_user_message(
|
await handle_user_message(
|
||||||
ws, graph, sm, cb, "t1", "cancel order 1042", interrupt_manager=im,
|
ws, graph_ctx, sm, cb, "t1", "cancel order 1042", interrupt_manager=im,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Interrupt should be registered
|
# Interrupt should be registered
|
||||||
@@ -254,16 +266,17 @@ class TestHandleUserMessage:
|
|||||||
clarification_question="What do you mean?",
|
clarification_question="What do you mean?",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
graph.intent_classifier = mock_classifier
|
|
||||||
mock_registry = MagicMock()
|
mock_registry = MagicMock()
|
||||||
mock_registry.list_agents = MagicMock(return_value=())
|
mock_registry.list_agents = MagicMock(return_value=())
|
||||||
graph.agent_registry = mock_registry
|
graph_ctx = GraphContext(
|
||||||
|
graph=graph, registry=mock_registry, intent_classifier=mock_classifier,
|
||||||
|
)
|
||||||
|
|
||||||
sm = SessionManager()
|
sm = SessionManager()
|
||||||
cb = TokenUsageCallbackHandler()
|
cb = TokenUsageCallbackHandler()
|
||||||
|
|
||||||
sm.touch("t1")
|
sm.touch("t1")
|
||||||
await handle_user_message(ws, graph, sm, cb, "t1", "hmm")
|
await handle_user_message(ws, graph_ctx, sm, cb, "t1", "hmm")
|
||||||
|
|
||||||
calls = [c[0][0] for c in ws.send_json.call_args_list]
|
calls = [c[0][0] for c in ws.send_json.call_args_list]
|
||||||
clarification_msgs = [c for c in calls if c.get("type") == "clarification"]
|
clarification_msgs = [c for c in calls if c.get("type") == "clarification"]
|
||||||
@@ -276,13 +289,13 @@ class TestHandleInterruptResponse:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_approved_interrupt(self) -> None:
|
async def test_approved_interrupt(self) -> None:
|
||||||
ws = _make_ws()
|
ws = _make_ws()
|
||||||
graph = _make_graph()
|
graph_ctx = _make_graph_ctx()
|
||||||
sm = SessionManager()
|
sm = SessionManager()
|
||||||
cb = TokenUsageCallbackHandler()
|
cb = TokenUsageCallbackHandler()
|
||||||
|
|
||||||
sm.touch("t1")
|
sm.touch("t1")
|
||||||
sm.extend_for_interrupt("t1")
|
sm.extend_for_interrupt("t1")
|
||||||
await handle_interrupt_response(ws, graph, sm, cb, "t1", True)
|
await handle_interrupt_response(ws, graph_ctx, sm, cb, "t1", True)
|
||||||
last_call = ws.send_json.call_args[0][0]
|
last_call = ws.send_json.call_args[0][0]
|
||||||
assert last_call["type"] == "message_complete"
|
assert last_call["type"] == "message_complete"
|
||||||
|
|
||||||
@@ -291,7 +304,7 @@ class TestHandleInterruptResponse:
|
|||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
ws = _make_ws()
|
ws = _make_ws()
|
||||||
graph = _make_graph()
|
graph_ctx = _make_graph_ctx()
|
||||||
sm = SessionManager()
|
sm = SessionManager()
|
||||||
cb = TokenUsageCallbackHandler()
|
cb = TokenUsageCallbackHandler()
|
||||||
im = InterruptManager(ttl_seconds=5)
|
im = InterruptManager(ttl_seconds=5)
|
||||||
@@ -304,7 +317,7 @@ class TestHandleInterruptResponse:
|
|||||||
with patch("app.interrupt_manager.time") as mock_time:
|
with patch("app.interrupt_manager.time") as mock_time:
|
||||||
mock_time.time.return_value = im._interrupts["t1"].created_at + 10
|
mock_time.time.return_value = im._interrupts["t1"].created_at + 10
|
||||||
await handle_interrupt_response(
|
await handle_interrupt_response(
|
||||||
ws, graph, sm, cb, "t1", True, interrupt_manager=im
|
ws, graph_ctx, sm, cb, "t1", True, interrupt_manager=im
|
||||||
)
|
)
|
||||||
|
|
||||||
call_data = ws.send_json.call_args[0][0]
|
call_data = ws.send_json.call_args[0][0]
|
||||||
@@ -314,7 +327,7 @@ class TestHandleInterruptResponse:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_valid_interrupt_resolves(self) -> None:
|
async def test_valid_interrupt_resolves(self) -> None:
|
||||||
ws = _make_ws()
|
ws = _make_ws()
|
||||||
graph = _make_graph()
|
graph_ctx = _make_graph_ctx()
|
||||||
sm = SessionManager()
|
sm = SessionManager()
|
||||||
cb = TokenUsageCallbackHandler()
|
cb = TokenUsageCallbackHandler()
|
||||||
im = InterruptManager(ttl_seconds=1800)
|
im = InterruptManager(ttl_seconds=1800)
|
||||||
@@ -324,7 +337,7 @@ class TestHandleInterruptResponse:
|
|||||||
im.register("t1", "cancel_order", {})
|
im.register("t1", "cancel_order", {})
|
||||||
|
|
||||||
await handle_interrupt_response(
|
await handle_interrupt_response(
|
||||||
ws, graph, sm, cb, "t1", True, interrupt_manager=im
|
ws, graph_ctx, sm, cb, "t1", True, interrupt_manager=im
|
||||||
)
|
)
|
||||||
|
|
||||||
# Interrupt should be resolved
|
# Interrupt should be resolved
|
||||||
@@ -371,19 +384,14 @@ class TestDispatchMessageWithTracking:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_conversation_tracker_called_on_message(self) -> None:
|
async def test_conversation_tracker_called_on_message(self) -> None:
|
||||||
ws = _make_ws()
|
ws = _make_ws()
|
||||||
graph = _make_graph()
|
|
||||||
sm = SessionManager()
|
sm = SessionManager()
|
||||||
cb = TokenUsageCallbackHandler()
|
|
||||||
tracker = AsyncMock()
|
tracker = AsyncMock()
|
||||||
pool = MagicMock()
|
pool = MagicMock()
|
||||||
|
ws_ctx = _make_ws_ctx(sm=sm, conversation_tracker=tracker, pool=pool)
|
||||||
|
|
||||||
sm.touch("t1")
|
sm.touch("t1")
|
||||||
msg = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"})
|
msg = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"})
|
||||||
await dispatch_message(
|
await dispatch_message(ws, ws_ctx, msg)
|
||||||
ws, graph, sm, cb, msg,
|
|
||||||
conversation_tracker=tracker,
|
|
||||||
pool=pool,
|
|
||||||
)
|
|
||||||
|
|
||||||
tracker.ensure_conversation.assert_awaited_once_with(pool, "t1")
|
tracker.ensure_conversation.assert_awaited_once_with(pool, "t1")
|
||||||
tracker.record_turn.assert_awaited_once()
|
tracker.record_turn.assert_awaited_once()
|
||||||
@@ -391,53 +399,42 @@ class TestDispatchMessageWithTracking:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_analytics_recorder_called_on_message(self) -> None:
|
async def test_analytics_recorder_called_on_message(self) -> None:
|
||||||
ws = _make_ws()
|
ws = _make_ws()
|
||||||
graph = _make_graph()
|
|
||||||
sm = SessionManager()
|
sm = SessionManager()
|
||||||
cb = TokenUsageCallbackHandler()
|
|
||||||
recorder = AsyncMock()
|
recorder = AsyncMock()
|
||||||
pool = MagicMock()
|
pool = MagicMock()
|
||||||
|
ws_ctx = _make_ws_ctx(sm=sm, analytics_recorder=recorder, pool=pool)
|
||||||
|
|
||||||
sm.touch("t1")
|
sm.touch("t1")
|
||||||
msg = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"})
|
msg = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"})
|
||||||
await dispatch_message(
|
await dispatch_message(ws, ws_ctx, msg)
|
||||||
ws, graph, sm, cb, msg,
|
|
||||||
analytics_recorder=recorder,
|
|
||||||
pool=pool,
|
|
||||||
)
|
|
||||||
|
|
||||||
recorder.record.assert_awaited_once()
|
recorder.record.assert_awaited_once()
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_tracker_failure_does_not_break_chat(self) -> None:
|
async def test_tracker_failure_does_not_break_chat(self) -> None:
|
||||||
ws = _make_ws()
|
ws = _make_ws()
|
||||||
graph = _make_graph()
|
|
||||||
sm = SessionManager()
|
sm = SessionManager()
|
||||||
cb = TokenUsageCallbackHandler()
|
|
||||||
tracker = AsyncMock()
|
tracker = AsyncMock()
|
||||||
tracker.ensure_conversation.side_effect = RuntimeError("DB down")
|
tracker.ensure_conversation.side_effect = RuntimeError("DB down")
|
||||||
pool = MagicMock()
|
pool = MagicMock()
|
||||||
|
ws_ctx = _make_ws_ctx(sm=sm, conversation_tracker=tracker, pool=pool)
|
||||||
|
|
||||||
sm.touch("t1")
|
sm.touch("t1")
|
||||||
msg = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"})
|
msg = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"})
|
||||||
# Should not raise despite tracker failure
|
# Should not raise despite tracker failure
|
||||||
await dispatch_message(
|
await dispatch_message(ws, ws_ctx, msg)
|
||||||
ws, graph, sm, cb, msg,
|
|
||||||
conversation_tracker=tracker,
|
|
||||||
pool=pool,
|
|
||||||
)
|
|
||||||
last_call = ws.send_json.call_args[0][0]
|
last_call = ws.send_json.call_args[0][0]
|
||||||
assert last_call["type"] == "message_complete"
|
assert last_call["type"] == "message_complete"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_no_tracker_no_error(self) -> None:
|
async def test_no_tracker_no_error(self) -> None:
|
||||||
ws = _make_ws()
|
ws = _make_ws()
|
||||||
graph = _make_graph()
|
|
||||||
sm = SessionManager()
|
sm = SessionManager()
|
||||||
cb = TokenUsageCallbackHandler()
|
ws_ctx = _make_ws_ctx(sm=sm)
|
||||||
|
|
||||||
sm.touch("t1")
|
sm.touch("t1")
|
||||||
msg = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"})
|
msg = json.dumps({"type": "message", "thread_id": "t1", "content": "hello"})
|
||||||
# No tracker or recorder passed -- should work fine
|
# No tracker or recorder passed -- should work fine
|
||||||
await dispatch_message(ws, graph, sm, cb, msg)
|
await dispatch_message(ws, ws_ctx, msg)
|
||||||
last_call = ws.send_json.call_args[0][0]
|
last_call = ws.send_json.call_args[0][0]
|
||||||
assert last_call["type"] == "message_complete"
|
assert last_call["type"] == "message_complete"
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ services:
|
|||||||
POSTGRES_DB: smart_support
|
POSTGRES_DB: smart_support
|
||||||
POSTGRES_USER: smart_support
|
POSTGRES_USER: smart_support
|
||||||
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-dev_password}
|
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-dev_password}
|
||||||
# ports: ["5432:5432"] # Uncomment for local dev DB access only
|
ports: ["5433:5432"] # Local dev: expose on 5433 to match backend/.env
|
||||||
volumes:
|
volumes:
|
||||||
- pgdata:/var/lib/postgresql/data
|
- pgdata:/var/lib/postgresql/data
|
||||||
healthcheck:
|
healthcheck:
|
||||||
@@ -28,6 +28,10 @@ services:
|
|||||||
LLM_MODEL: ${LLM_MODEL:-claude-sonnet-4-6}
|
LLM_MODEL: ${LLM_MODEL:-claude-sonnet-4-6}
|
||||||
ANTHROPIC_API_KEY: ${ANTHROPIC_API_KEY:-}
|
ANTHROPIC_API_KEY: ${ANTHROPIC_API_KEY:-}
|
||||||
OPENAI_API_KEY: ${OPENAI_API_KEY:-}
|
OPENAI_API_KEY: ${OPENAI_API_KEY:-}
|
||||||
|
AZURE_OPENAI_API_KEY: ${AZURE_OPENAI_API_KEY:-}
|
||||||
|
AZURE_OPENAI_ENDPOINT: ${AZURE_OPENAI_ENDPOINT:-}
|
||||||
|
AZURE_OPENAI_DEPLOYMENT: ${AZURE_OPENAI_DEPLOYMENT:-}
|
||||||
|
AZURE_OPENAI_API_VERSION: ${AZURE_OPENAI_API_VERSION:-2024-12-01-preview}
|
||||||
GOOGLE_API_KEY: ${GOOGLE_API_KEY:-}
|
GOOGLE_API_KEY: ${GOOGLE_API_KEY:-}
|
||||||
WEBHOOK_URL: ${WEBHOOK_URL:-}
|
WEBHOOK_URL: ${WEBHOOK_URL:-}
|
||||||
SESSION_TTL_MINUTES: ${SESSION_TTL_MINUTES:-30}
|
SESSION_TTL_MINUTES: ${SESSION_TTL_MINUTES:-30}
|
||||||
@@ -37,7 +41,7 @@ services:
|
|||||||
postgres:
|
postgres:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
healthcheck:
|
healthcheck:
|
||||||
test: ["CMD-SHELL", "curl -f http://localhost:8000/api/health || exit 1"]
|
test: ["CMD-SHELL", "curl -f http://localhost:8000/api/v1/health || exit 1"]
|
||||||
interval: 10s
|
interval: 10s
|
||||||
timeout: 5s
|
timeout: 5s
|
||||||
retries: 5
|
retries: 5
|
||||||
|
|||||||
@@ -63,7 +63,7 @@ Smart Support 通过 MCP 协议连接内部系统,将自动化率提升到 60%+
|
|||||||
v
|
v
|
||||||
+--------+--------------------+
|
+--------+--------------------+
|
||||||
| LangGraph Supervisor | (Agent 编排 + 意图路由)
|
| LangGraph Supervisor | (Agent 编排 + 意图路由)
|
||||||
| langgraph-supervisor v1.1 |
|
| langgraph-supervisor 0.0.30+|
|
||||||
+--------+--------------------+
|
+--------+--------------------+
|
||||||
|
|
|
|
||||||
+----+----+----+----+
|
+----+----+----+----+
|
||||||
@@ -99,7 +99,12 @@ smart-support/
|
|||||||
├── backend/
|
├── backend/
|
||||||
│ ├── app/
|
│ ├── app/
|
||||||
│ │ ├── main.py # FastAPI + WebSocket 入口
|
│ │ ├── main.py # FastAPI + WebSocket 入口
|
||||||
│ │ ├── graph.py # LangGraph Supervisor 配置
|
│ │ ├── graph.py # LangGraph Supervisor 构建
|
||||||
|
│ │ ├── graph_context.py # GraphContext: 图 + 分类器 + 注册表的类型化封装
|
||||||
|
│ │ ├── ws_handler.py # WebSocket 消息分发 + 速率限制
|
||||||
|
│ │ ├── ws_context.py # WebSocketContext: WS 依赖包
|
||||||
|
│ │ ├── auth.py # API Key 认证中间件
|
||||||
|
│ │ ├── api_utils.py # 共享 API 响应工具 (envelope)
|
||||||
│ │ ├── agents/ # Agent 定义 + 工具绑定
|
│ │ ├── agents/ # Agent 定义 + 工具绑定
|
||||||
│ │ ├── registry.py # YAML Agent 注册表加载器
|
│ │ ├── registry.py # YAML Agent 注册表加载器
|
||||||
│ │ ├── openapi/ # OpenAPI 解析 + MCP 服务器生成
|
│ │ ├── openapi/ # OpenAPI 解析 + MCP 服务器生成
|
||||||
@@ -139,7 +144,11 @@ smart-support/
|
|||||||
| 模块 | 职责 |
|
| 模块 | 职责 |
|
||||||
|------|------|
|
|------|------|
|
||||||
| main.py | 应用入口, WebSocket 端点, 静态文件服务 |
|
| main.py | 应用入口, WebSocket 端点, 静态文件服务 |
|
||||||
| WebSocket Handler | 双向通信: 接收用户消息, 流式返回 token, 处理 interrupt 响应 |
|
| auth.py | API Key 认证: 管理端点通过 `X-API-Key` header, WebSocket 通过 `?token=` query param |
|
||||||
|
| ws_handler.py | 双向通信: 接收用户消息, 流式返回 token, 处理 interrupt 响应 |
|
||||||
|
| graph_context.py | 类型化封装: 将编译后的图与分类器、注册表绑定, 替代猴子补丁 |
|
||||||
|
| ws_context.py | 依赖包: 将 WebSocket 处理所需的 9 个依赖打包为单一不可变对象 |
|
||||||
|
| api_utils.py | 共享响应格式: 统一的 `envelope()` 函数 |
|
||||||
|
|
||||||
### 2.3 Agent 编排层 (LangGraph)
|
### 2.3 Agent 编排层 (LangGraph)
|
||||||
|
|
||||||
@@ -315,7 +324,7 @@ Agent 调用写操作工具
|
|||||||
|------|------|------|
|
|------|------|------|
|
||||||
| 语言 | Python 3.11+ | LangGraph/LangChain 生态首选语言, Agent 开发最成熟 |
|
| 语言 | Python 3.11+ | LangGraph/LangChain 生态首选语言, Agent 开发最成熟 |
|
||||||
| Web 框架 | FastAPI | 原生 async, WebSocket 支持, 性能优秀 |
|
| Web 框架 | FastAPI | 原生 async, WebSocket 支持, 性能优秀 |
|
||||||
| Agent 编排 | LangGraph v1.1 + langgraph-supervisor | 内置 supervisor 模式, 中间件支持, 不重复造轮子 |
|
| Agent 编排 | LangGraph 1.x + langgraph-supervisor | 内置 supervisor 模式, 中间件支持, 不重复造轮子 |
|
||||||
| MCP 集成 | langchain-mcp-adapters | MultiServerMCPClient 管理多 MCP 连接 |
|
| MCP 集成 | langchain-mcp-adapters | MultiServerMCPClient 管理多 MCP 连接 |
|
||||||
| 本地工具 | LangChain @tool 装饰器 | 简单 Python 函数即工具, 无 MCP 开销 |
|
| 本地工具 | LangChain @tool 装饰器 | 简单 Python 函数即工具, 无 MCP 开销 |
|
||||||
| 状态持久化 | PostgresSaver (langgraph-checkpoint-postgres v3.0.5) | 从第一天起用 PostgreSQL, 支持回放/分析查询 |
|
| 状态持久化 | PostgresSaver (langgraph-checkpoint-postgres v3.0.5) | 从第一天起用 PostgreSQL, 支持回放/分析查询 |
|
||||||
@@ -427,6 +436,19 @@ CREATE INDEX idx_interrupts_ttl ON interrupts(ttl_expires_at)
|
|||||||
WHERE status = 'pending';
|
WHERE status = 'pending';
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### sessions (自定义 - 会话状态持久化)
|
||||||
|
|
||||||
|
```sql
|
||||||
|
-- 用于多 worker 部署的 PostgreSQL 会话状态管理
|
||||||
|
-- PgSessionManager 使用此表替代内存中的 dict
|
||||||
|
CREATE TABLE sessions (
|
||||||
|
thread_id TEXT PRIMARY KEY,
|
||||||
|
last_activity TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||||
|
has_pending_interrupt BOOLEAN NOT NULL DEFAULT FALSE,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||||
|
);
|
||||||
|
```
|
||||||
|
|
||||||
#### analytics_events (自定义 - 分析事件流)
|
#### analytics_events (自定义 - 分析事件流)
|
||||||
|
|
||||||
```sql
|
```sql
|
||||||
|
|||||||
@@ -9,33 +9,38 @@ specialist with a specific role, permission level, and set of tools it can call.
|
|||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
agents:
|
agents:
|
||||||
- name: order_agent
|
- name: order_lookup
|
||||||
description: "Handles order status, tracking, and cancellations."
|
description: "Looks up order status and tracking information."
|
||||||
permission: write
|
permission: read
|
||||||
tools:
|
tools:
|
||||||
- get_order_status
|
- get_order_status
|
||||||
- cancel_order
|
- get_tracking_info
|
||||||
personality:
|
personality:
|
||||||
tone: friendly
|
tone: "friendly and informative"
|
||||||
greeting: "I can help with your order. What is your order number?"
|
greeting: "I can help you check your order status!"
|
||||||
escalation_message: "I'm escalating this to a human agent now."
|
escalation_message: "Let me connect you with our support team."
|
||||||
|
|
||||||
- name: refund_agent
|
- name: order_actions
|
||||||
description: "Processes refund requests."
|
description: "Performs order modifications like cancellations."
|
||||||
permission: write
|
permission: write
|
||||||
tools:
|
tools:
|
||||||
- process_refund
|
- cancel_order
|
||||||
- check_refund_eligibility
|
personality:
|
||||||
personality:
|
tone: "careful and reassuring"
|
||||||
tone: empathetic
|
greeting: "I can help you with order changes."
|
||||||
greeting: "I'm the refund specialist. How can I help?"
|
escalation_message: "I'll connect you with a specialist."
|
||||||
escalation_message: "I need to escalate this refund request."
|
|
||||||
|
- name: discount
|
||||||
- name: general_agent
|
description: "Applies discounts and generates coupon codes."
|
||||||
description: "Answers general questions and FAQs."
|
permission: write
|
||||||
|
tools:
|
||||||
|
- apply_discount
|
||||||
|
- generate_coupon
|
||||||
|
|
||||||
|
- name: fallback
|
||||||
|
description: "Handles general questions and unclear requests."
|
||||||
permission: read
|
permission: read
|
||||||
tools:
|
tools:
|
||||||
- search_faq
|
|
||||||
- fallback_respond
|
- fallback_respond
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -52,31 +57,38 @@ user messages to the right agent. Be specific.
|
|||||||
Controls the interrupt threshold:
|
Controls the interrupt threshold:
|
||||||
- `read` -- no interrupt required. Agent can act immediately.
|
- `read` -- no interrupt required. Agent can act immediately.
|
||||||
- `write` -- requires human approval via interrupt before executing tools.
|
- `write` -- requires human approval via interrupt before executing tools.
|
||||||
- `admin` -- requires human approval and is logged for audit.
|
|
||||||
|
|
||||||
### `tools` (required)
|
### `tools` (required)
|
||||||
List of tool names this agent can use. Tools are registered in the agent factory.
|
List of tool names this agent can use. Tools are registered in `backend/app/agents/`.
|
||||||
Each tool name must match a registered LangChain tool.
|
Each tool name must match a registered LangChain tool.
|
||||||
|
|
||||||
|
Available built-in tools:
|
||||||
|
- `get_order_status` -- look up order details
|
||||||
|
- `get_tracking_info` -- get shipping/tracking info
|
||||||
|
- `cancel_order` -- cancel an order (write)
|
||||||
|
- `apply_discount` -- apply a discount code (write)
|
||||||
|
- `generate_coupon` -- generate a new coupon (write)
|
||||||
|
- `fallback_respond` -- generic fallback response
|
||||||
|
|
||||||
### `personality` (optional)
|
### `personality` (optional)
|
||||||
Customizes agent behavior:
|
Customizes agent behavior:
|
||||||
- `tone` -- `friendly`, `formal`, `empathetic`, `technical`
|
- `tone` -- free-text description of the agent's communication style
|
||||||
- `greeting` -- Opening message injected at session start.
|
- `greeting` -- opening message injected at session start
|
||||||
- `escalation_message` -- Message sent when the agent escalates.
|
- `escalation_message` -- message sent when the agent escalates
|
||||||
|
|
||||||
## Built-in Templates
|
## Built-in Templates
|
||||||
|
|
||||||
Use `TEMPLATE_NAME` environment variable to load a pre-built agent configuration:
|
Use `TEMPLATE_NAME` environment variable to load a pre-built agent configuration:
|
||||||
|
|
||||||
| Template | Description |
|
| Template | Filename | Description |
|
||||||
|----------|-------------|
|
|----------|----------|-------------|
|
||||||
| `ecommerce` | Orders, refunds, shipping, product questions |
|
| `e-commerce` | `templates/e-commerce.yaml` | Orders, shipping, discounts |
|
||||||
| `saas` | Account management, billing, technical support |
|
| `saas` | `templates/saas.yaml` | Account management, billing, support |
|
||||||
| `generic` | General-purpose FAQ and escalation |
|
| `fintech` | `templates/fintech.yaml` | Financial services support |
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
```bash
|
```bash
|
||||||
TEMPLATE_NAME=ecommerce uvicorn app.main:app
|
TEMPLATE_NAME=e-commerce uvicorn app.main:app
|
||||||
```
|
```
|
||||||
|
|
||||||
## Adding New Agents
|
## Adding New Agents
|
||||||
@@ -98,7 +110,7 @@ matches the agent's description.
|
|||||||
|
|
||||||
## Escalation
|
## Escalation
|
||||||
|
|
||||||
Any agent can trigger escalation by calling the `escalate` tool. This:
|
Any agent can trigger escalation. This:
|
||||||
1. Sends a webhook notification (if `WEBHOOK_URL` is configured).
|
1. Sends a webhook notification (if `WEBHOOK_URL` is configured).
|
||||||
2. Marks the conversation with `resolution_type = escalated`.
|
2. Marks the conversation with `resolution_type = escalated`.
|
||||||
3. Sends the agent's `escalation_message` to the user.
|
3. Sends the agent's `escalation_message` to the user.
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ Navigate to http://localhost in your browser.
|
|||||||
|
|
||||||
1. Open the Chat tab (default).
|
1. Open the Chat tab (default).
|
||||||
2. Send: **"What is the status of order 12345?"**
|
2. Send: **"What is the status of order 12345?"**
|
||||||
- Observe the `tool_call` indicator appear in the sidebar (order_agent calling `get_order_status`).
|
- Observe the `tool_call` indicator appear in the sidebar (`order_lookup` calling `get_order_status`).
|
||||||
- The agent responds with order status.
|
- The agent responds with order status.
|
||||||
3. Send: **"Can you cancel that order?"**
|
3. Send: **"Can you cancel that order?"**
|
||||||
- The system detects a write operation and shows an **Interrupt Prompt**.
|
- The system detects a write operation and shows an **Interrupt Prompt**.
|
||||||
@@ -61,15 +61,15 @@ Key points to highlight:
|
|||||||
### Scene 2: Multi-Agent Routing (2 minutes)
|
### Scene 2: Multi-Agent Routing (2 minutes)
|
||||||
|
|
||||||
1. Start a new browser tab (new session) or clear session storage.
|
1. Start a new browser tab (new session) or clear session storage.
|
||||||
2. Send: **"I need to track my order AND request a refund for a previous order"**
|
2. Send: **"I need to check order 12345 AND cancel order 67890"**
|
||||||
- The supervisor detects two intents: `order_agent` and `refund_agent`.
|
- The supervisor detects two intents: `order_lookup` (read) and `order_actions` (write).
|
||||||
- Both agents run in sequence.
|
- Both agents run in sequence.
|
||||||
- Two interrupt prompts may appear if both operations are write-level.
|
- The cancellation triggers an interrupt prompt for human approval.
|
||||||
|
|
||||||
Key points to highlight:
|
Key points to highlight:
|
||||||
- Intent classification detecting multiple actions
|
- Intent classification detecting multiple actions
|
||||||
- Automatic routing to appropriate specialist agents
|
- Automatic routing to appropriate specialist agents
|
||||||
- Sequential execution with confirmation gates
|
- Sequential execution with confirmation gates for write operations
|
||||||
|
|
||||||
### Scene 3: Conversation Replay (2 minutes)
|
### Scene 3: Conversation Replay (2 minutes)
|
||||||
|
|
||||||
|
|||||||
@@ -54,11 +54,19 @@ Set these in production (never commit secrets):
|
|||||||
| `ANTHROPIC_API_KEY` | Yes* | LLM provider API key |
|
| `ANTHROPIC_API_KEY` | Yes* | LLM provider API key |
|
||||||
| `LLM_PROVIDER` | Yes | `anthropic`, `openai`, or `google` |
|
| `LLM_PROVIDER` | Yes | `anthropic`, `openai`, or `google` |
|
||||||
| `LLM_MODEL` | Yes | Model name for your provider |
|
| `LLM_MODEL` | Yes | Model name for your provider |
|
||||||
|
| `ADMIN_API_KEY` | Recommended | API key for admin endpoints (analytics, replay, openapi, WS). Leave empty to disable auth (dev mode only) |
|
||||||
| `WEBHOOK_URL` | No | Escalation notification endpoint |
|
| `WEBHOOK_URL` | No | Escalation notification endpoint |
|
||||||
| `SESSION_TTL_MINUTES` | No | Session timeout (default: 30) |
|
| `SESSION_TTL_MINUTES` | No | Session timeout (default: 30) |
|
||||||
|
|
||||||
*Or `OPENAI_API_KEY` / `GOOGLE_API_KEY` depending on `LLM_PROVIDER`.
|
*Or `OPENAI_API_KEY` / `GOOGLE_API_KEY` depending on `LLM_PROVIDER`.
|
||||||
|
|
||||||
|
### Authentication
|
||||||
|
|
||||||
|
When `ADMIN_API_KEY` is set, all admin REST endpoints require the `X-API-Key` header,
|
||||||
|
and WebSocket connections require a `?token=<key>` query parameter.
|
||||||
|
|
||||||
|
When unset or empty, authentication is disabled (suitable for local development only).
|
||||||
|
|
||||||
### HTTPS
|
### HTTPS
|
||||||
|
|
||||||
For production, place a reverse proxy (nginx, Caddy, or a load balancer) in
|
For production, place a reverse proxy (nginx, Caddy, or a load balancer) in
|
||||||
@@ -87,10 +95,12 @@ cat backup.sql | docker compose exec -T postgres psql -U smart_support smart_sup
|
|||||||
|
|
||||||
### Scaling
|
### Scaling
|
||||||
|
|
||||||
The backend is stateless (session state is in PostgreSQL via LangGraph's
|
The backend supports multi-worker deployments. LangGraph session state is
|
||||||
PostgresSaver). You can run multiple backend replicas behind a load balancer.
|
persisted in PostgreSQL via PostgresSaver. For full horizontal scaling, use
|
||||||
|
`PgSessionManager` and `PgInterruptManager` (instead of the default in-memory
|
||||||
|
managers) to share session and interrupt state across workers.
|
||||||
|
|
||||||
The WebSocket connections are session-specific. Use sticky sessions or a shared
|
WebSocket connections are session-specific. Use sticky sessions or a shared
|
||||||
session backend if load balancing WebSockets across multiple instances.
|
session backend if load balancing WebSockets across multiple instances.
|
||||||
|
|
||||||
## Manual / Development Setup
|
## Manual / Development Setup
|
||||||
@@ -139,7 +149,7 @@ GET /api/health
|
|||||||
|
|
||||||
Response:
|
Response:
|
||||||
```json
|
```json
|
||||||
{"status": "ok", "version": "0.5.0"}
|
{"status": "ok", "version": "0.6.0"}
|
||||||
```
|
```
|
||||||
|
|
||||||
### WebSocket health
|
### WebSocket health
|
||||||
|
|||||||
@@ -10,7 +10,8 @@ Import a URL, review the AI-classified endpoints, approve, and your agents are l
|
|||||||
1. **Import** -- Provide a URL to an OpenAPI 3.0 spec (JSON or YAML).
|
1. **Import** -- Provide a URL to an OpenAPI 3.0 spec (JSON or YAML).
|
||||||
2. **Parse** -- The system downloads and parses the spec.
|
2. **Parse** -- The system downloads and parses the spec.
|
||||||
3. **Classify** -- An LLM classifies each endpoint's:
|
3. **Classify** -- An LLM classifies each endpoint's:
|
||||||
- `access_type`: `read`, `write`, or `admin`
|
- `access_type`: `read` or `write`
|
||||||
|
- `needs_interrupt`: whether human approval is required
|
||||||
- `agent_group`: which specialist agent should handle this endpoint
|
- `agent_group`: which specialist agent should handle this endpoint
|
||||||
4. **Review** -- You inspect and edit the classifications in the UI.
|
4. **Review** -- You inspect and edit the classifications in the UI.
|
||||||
5. **Approve** -- Approved endpoints are registered as tools on the appropriate agents.
|
5. **Approve** -- Approved endpoints are registered as tools on the appropriate agents.
|
||||||
@@ -19,12 +20,12 @@ Import a URL, review the AI-classified endpoints, approve, and your agents are l
|
|||||||
|
|
||||||
1. Navigate to the **API Review** tab.
|
1. Navigate to the **API Review** tab.
|
||||||
2. Paste your OpenAPI spec URL into the import form.
|
2. Paste your OpenAPI spec URL into the import form.
|
||||||
3. Click **Import**.
|
3. Click **Scan Tools**.
|
||||||
4. Wait for the job to complete (status: `pending` -> `processing` -> `done`).
|
4. Wait for the job to complete (status: `pending` -> `processing` -> `done`).
|
||||||
5. Review the endpoint table:
|
5. Review the endpoint cards grouped by agent:
|
||||||
- Edit `access_type` if the AI misclassified sensitivity.
|
- Edit `access_type` (Read Only / Write) if the AI misclassified sensitivity.
|
||||||
- Edit `agent_group` to reassign an endpoint to a different agent.
|
- Edit the agent assignment to reassign an endpoint to a different agent.
|
||||||
6. Click **Approve & Save** when satisfied.
|
6. Click **Save Configuration** when satisfied.
|
||||||
|
|
||||||
## Using the REST API
|
## Using the REST API
|
||||||
|
|
||||||
@@ -39,11 +40,15 @@ Content-Type: application/json
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
Response:
|
Response (202):
|
||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
"success": true,
|
"job_id": "abc123",
|
||||||
"data": { "job_id": "abc123", "status": "pending" }
|
"status": "pending",
|
||||||
|
"spec_url": "https://api.example.com/openapi.yaml",
|
||||||
|
"total_endpoints": 0,
|
||||||
|
"classified_count": 0,
|
||||||
|
"error_message": null
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -53,27 +58,47 @@ Response:
|
|||||||
GET /api/openapi/jobs/{job_id}
|
GET /api/openapi/jobs/{job_id}
|
||||||
```
|
```
|
||||||
|
|
||||||
### Get job results
|
### Get classifications
|
||||||
|
|
||||||
```http
|
```http
|
||||||
GET /api/openapi/jobs/{job_id}/result
|
GET /api/openapi/jobs/{job_id}/classifications
|
||||||
|
```
|
||||||
|
|
||||||
|
Response: array of classification objects with `index`, `access_type`,
|
||||||
|
`needs_interrupt`, `agent_group`, `confidence`, `customer_params`, and `endpoint`.
|
||||||
|
|
||||||
|
### Update a classification
|
||||||
|
|
||||||
|
```http
|
||||||
|
PUT /api/openapi/jobs/{job_id}/classifications/{index}
|
||||||
|
Content-Type: application/json
|
||||||
|
|
||||||
|
{
|
||||||
|
"access_type": "write",
|
||||||
|
"needs_interrupt": true,
|
||||||
|
"agent_group": "order_actions"
|
||||||
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
### Approve job
|
### Approve job
|
||||||
|
|
||||||
```http
|
```http
|
||||||
POST /api/openapi/jobs/{job_id}/approve
|
POST /api/openapi/jobs/{job_id}/approve
|
||||||
Content-Type: application/json
|
```
|
||||||
|
|
||||||
|
No request body. Generates tool code for each classified endpoint and produces
|
||||||
|
an agent YAML configuration. Response includes `generated_tools_count`.
|
||||||
|
|
||||||
|
Response:
|
||||||
|
```json
|
||||||
{
|
{
|
||||||
"endpoints": [
|
"job_id": "abc123",
|
||||||
{
|
"status": "approved",
|
||||||
"path": "/orders/{order_id}",
|
"spec_url": "https://api.example.com/openapi.yaml",
|
||||||
"method": "get",
|
"total_endpoints": 5,
|
||||||
"access_type": "read",
|
"classified_count": 5,
|
||||||
"agent_group": "order_agent"
|
"error_message": null,
|
||||||
}
|
"generated_tools_count": 5
|
||||||
]
|
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -82,18 +107,17 @@ Content-Type: application/json
|
|||||||
| Access Type | Description | Interrupt Required |
|
| Access Type | Description | Interrupt Required |
|
||||||
|-------------|-------------|-------------------|
|
|-------------|-------------|-------------------|
|
||||||
| `read` | GET operations, no side effects | No |
|
| `read` | GET operations, no side effects | No |
|
||||||
| `write` | POST/PUT/PATCH that modify data | Yes |
|
| `write` | POST/PUT/PATCH/DELETE that modify data | Yes (by default) |
|
||||||
| `admin` | DELETE, bulk operations, sensitive writes | Yes |
|
|
||||||
|
The `needs_interrupt` flag can be overridden per-endpoint during review.
|
||||||
|
|
||||||
## SSRF Protection
|
## SSRF Protection
|
||||||
|
|
||||||
All import requests are validated against an allowlist:
|
All import requests are validated:
|
||||||
- Private IP ranges are blocked (10.x, 172.16.x, 192.168.x, 127.x)
|
- Private IP ranges are blocked (10.x, 172.16.x, 192.168.x, 127.x)
|
||||||
- Localhost and metadata service URLs are blocked
|
- Localhost and cloud metadata service URLs are blocked
|
||||||
- Only `http://` and `https://` schemes are permitted
|
- Only `http://` and `https://` schemes are permitted
|
||||||
|
|
||||||
To allow internal URLs (e.g., in development), set `SSRF_ALLOWLIST_HOSTS` in your environment.
|
|
||||||
|
|
||||||
## Supported Spec Formats
|
## Supported Spec Formats
|
||||||
|
|
||||||
- OpenAPI 3.0.x (JSON or YAML)
|
- OpenAPI 3.0.x (JSON or YAML)
|
||||||
|
|||||||
76
docs/phases/eng-improvements-dev-log.md
Normal file
76
docs/phases/eng-improvements-dev-log.md
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
# Engineering Improvements -- Development Log
|
||||||
|
|
||||||
|
> Status: COMPLETED
|
||||||
|
> Branch: `eng/engineering-improvements`
|
||||||
|
> Date started: 2026-04-06
|
||||||
|
> Date completed: 2026-04-06
|
||||||
|
|
||||||
|
## What Was Built
|
||||||
|
|
||||||
|
### Phase 1: Quick Wins (no new deps)
|
||||||
|
|
||||||
|
1. **Interrupt Cleanup Background Task** -- Added asyncio background task in lifespan that calls `interrupt_manager.cleanup_expired()` every 60 seconds. Prevents unbounded memory growth from expired interrupts.
|
||||||
|
|
||||||
|
2. **API Versioning** -- All REST endpoints prefixed with `/api/v1/` (was `/api/`). Updated 4 router prefixes, Docker healthcheck, all frontend fetch URLs, and all test assertions. WebSocket `/ws` endpoint unchanged.
|
||||||
|
|
||||||
|
3. **Error Response Standardization** -- Added global exception handlers for `HTTPException`, `RequestValidationError`, and `Exception`. All error responses now use the same envelope format as success responses: `{"success": false, "data": null, "error": "..."}`.
|
||||||
|
|
||||||
|
### Phase 2: Medium Items (new deps)
|
||||||
|
|
||||||
|
4. **Alembic Database Migrations** -- Replaced inline DDL in `setup_app_tables()` with versioned Alembic migrations. Initial migration `001_initial_schema.py` captures all 4 tables + ALTER TABLE migration. `setup_app_tables()` preserved for tests. Production uses `run_alembic_migrations()`.
|
||||||
|
|
||||||
|
5. **Structured Logging** -- Replaced stdlib `logging.getLogger()` with `structlog.get_logger()` across 10 files. Added `logging_config.py` with console (dev) and JSON (production) modes. Configurable via `LOG_FORMAT` env var.
|
||||||
|
|
||||||
|
### Phase 3: Test Coverage
|
||||||
|
|
||||||
|
7. **Integration Tests (+30)** -- Created 5 new test files: analytics API, replay API, OpenAPI API, error responses, session/interrupt lifecycle. Uses httpx.AsyncClient with ASGITransport for full API layer testing.
|
||||||
|
|
||||||
|
8. **Frontend Tests (+57)** -- Created 12 new test files covering all components (ChatInput, ChatMessages, InterruptPrompt, ErrorBanner, NavBar, MetricCard, ReplayTimeline, AgentAction, Layout), pages (ChatPage, ReviewPage), and hooks (useWebSocket).
|
||||||
|
|
||||||
|
## Code Structure
|
||||||
|
|
||||||
|
### New files created
|
||||||
|
- `backend/app/logging_config.py` -- structlog configuration
|
||||||
|
- `backend/alembic.ini` -- Alembic config
|
||||||
|
- `backend/alembic/env.py` -- Migration environment
|
||||||
|
- `backend/alembic/versions/001_initial_schema.py` -- Initial migration
|
||||||
|
- `backend/tests/unit/test_interrupt_cleanup.py` (3 tests)
|
||||||
|
- `backend/tests/unit/test_error_responses.py` (6 tests)
|
||||||
|
- `backend/tests/unit/test_logging_config.py` (2 tests)
|
||||||
|
- `backend/tests/integration/test_analytics_api.py` (6 tests)
|
||||||
|
- `backend/tests/integration/test_replay_api.py` (6 tests)
|
||||||
|
- `backend/tests/integration/test_openapi_api.py` (5 tests)
|
||||||
|
- `backend/tests/integration/test_error_responses.py` (5 tests)
|
||||||
|
- `backend/tests/integration/test_session_interrupt_lifecycle.py` (8 tests)
|
||||||
|
- 12 frontend test files (57 tests total)
|
||||||
|
|
||||||
|
### Modified files
|
||||||
|
- `backend/app/main.py` -- cleanup task, exception handlers, alembic, structlog
|
||||||
|
- `backend/app/db.py` -- added run_alembic_migrations()
|
||||||
|
- `backend/app/config.py` -- added log_format setting
|
||||||
|
- `backend/pyproject.toml` -- added alembic, structlog deps
|
||||||
|
- 4 router files -- `/api/v1/` prefix
|
||||||
|
- 10 files -- structlog migration
|
||||||
|
- `docker-compose.yml` -- healthcheck URL
|
||||||
|
- `frontend/src/api.ts` -- `/api/v1/` URLs
|
||||||
|
- All existing test files -- API path updates + error envelope assertions
|
||||||
|
|
||||||
|
## Test Coverage
|
||||||
|
|
||||||
|
- Backend: 557 tests (was 516), 89.75% coverage
|
||||||
|
- Unit: ~490 tests
|
||||||
|
- Integration: ~60 tests
|
||||||
|
- E2E: ~7 tests
|
||||||
|
- Frontend: 80 tests (was 23), 16 test files (was 4)
|
||||||
|
|
||||||
|
## Deviations from Plan
|
||||||
|
|
||||||
|
- Redis rate limiting deferred (single-worker sufficient for now)
|
||||||
|
- ConversationTracker verified correct by design (pool per-method), skipped
|
||||||
|
- Coverage dropped slightly from 90.26% to 89.75% due to new alembic/logging modules with partial test coverage (still well above 80% threshold)
|
||||||
|
|
||||||
|
## Known Issues / Tech Debt
|
||||||
|
|
||||||
|
- Rate limiting remains process-global (needs Redis for multi-worker)
|
||||||
|
- Alembic migrations not tested against real PostgreSQL in CI (would need running DB)
|
||||||
|
- Frontend test coverage could be deeper (e.g., WebSocket reconnect edge cases)
|
||||||
92
docs/ux_design_system.md
Normal file
92
docs/ux_design_system.md
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
# Smart Support UX Design System
|
||||||
|
|
||||||
|
This document outlines the core User Experience (UX) and User Interface (UI) design standards for the Smart Support platform. Our visual identity departs from the generic "tech cold blue/white" default, leaning into a premium, trustworthy, and organic "Warm Beige" aesthetic targeted at high-end B2B SaaS buyers.
|
||||||
|
|
||||||
|
## 1. Core Philosophy
|
||||||
|
|
||||||
|
* **Trust Through Warmth:** Customer support tools need to inspire confidence. We use an organic "Rich Warm Beige" canvas paired with "Deep Slate/Walnut" typography to feel more like a premium workspace (e.g., Notion, high-end interior design) rather than a sterile terminal.
|
||||||
|
* **Action over Text:** This is an *Action Layer*, not just a chatbot. Destructive or high-risk actions (refunds, cancellations) must visually "jump out" from the conversation flow via elevated cards.
|
||||||
|
* **Expansive Workspace:** Leverage horizontal screen space. Instead of a narrow 800px ChatGPT-style centered column, our workspace flows fluidly to the edges, similar to Slack or Zendesk.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 2. Color Palette (Design Tokens)
|
||||||
|
|
||||||
|
All colors are strictly mapped to CSS Variables in `index.css`. **Do not use hardcoded hex values in components.**
|
||||||
|
|
||||||
|
### Backgrounds & Surfaces
|
||||||
|
| Token | Hex | Usage |
|
||||||
|
| :--- | :--- | :--- |
|
||||||
|
| `App Wrapper` | `#DBD2C6` | The absolute outermost canvas (the "Dribbble presentation frame"). Visible only on large screens as a dark beige border. |
|
||||||
|
| `--bg-app` | `#F4EFE7` | The primary background color for the application shell and main content areas. |
|
||||||
|
| `--bg-surface` | `#EBE4D8` | Slightly darker beige. Used for elevated cards, the sidebar, and inputs to create depth. |
|
||||||
|
| `--bg-surface-inner` | `#F6F2EC` | A lighter inner container fill, often used as table headers or secondary nested boxes. |
|
||||||
|
| `--bg-hover` | `#E1D9CC` | Hover state backgrounds, active navigation item pills, and disabled button states. |
|
||||||
|
|
||||||
|
### Typography & Ink
|
||||||
|
| Token | Hex | Usage |
|
||||||
|
| :--- | :--- | :--- |
|
||||||
|
| `--text-primary` | `#1C1917` | Primary text (Headings, body copy). A deep brownish-slate, entirely avoiding harsh #000000 black. |
|
||||||
|
| `--text-secondary` | `#5C554D` | Secondary UI text, metadata, table column headers, and timestamps. |
|
||||||
|
|
||||||
|
### Brand & Interactive Elements
|
||||||
|
| Token | Hex | Usage |
|
||||||
|
| :--- | :--- | :--- |
|
||||||
|
| `--brand-primary` | `#3B342D` | Primary buttons, brand icons, and active UI states. |
|
||||||
|
| `--brand-hover` | `#26211C` | Hover states for primary interactive elements. |
|
||||||
|
| `--border-light` | `#D5CCC0` | Dividers, subtle borders around cards and tables. |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 3. Typography
|
||||||
|
|
||||||
|
* **Font Family:** `'Inter', system-ui, -apple-system, sans-serif`
|
||||||
|
* **Scale:** We rely on sharp, structural typography rather than excess lines to create hierarchy.
|
||||||
|
* **Headers (h2/h3):** `700` (Bold), tight letter-spacing (`-0.01em`).
|
||||||
|
* **Nav & Buttons:** `600` (Semi-bold), `0.9375rem` (15px) or `0.875rem` (14px).
|
||||||
|
* **Micro-text (Badges/Labels):** `0.75rem` (12px), uppercase, generous letter-spacing (`0.05em`).
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 4. The "Framed Window" Layout Paradigm
|
||||||
|
|
||||||
|
Rather than a UI that bleeds indefinitely to the edges of an ultrawide monitor, the Smart Support UI employs a **Responsive Window Frame**, while maintaining a flat visual hierarchy:
|
||||||
|
|
||||||
|
* **Small Screens / Mobile (< 768px):** The `.app-layout` merges with the browser edges (`100vw/100vh`, `0px` border-radius).
|
||||||
|
* **Large Screens (>= 768px):** The App shrinks slightly, creating a `1.5rem` (24px) margin on all sides against a slightly darker background. The app window gets a luxury `20px` border-radius and a soft, diffused drop shadow.
|
||||||
|
* **Flat Visual Hierarchy:** The Sidebar background is slightly darker (`--bg-surface`) than the main work area (`--bg-app`). They sit adjacent to each other without inner dividing boxed margins. The border line is implicitly created by the tone difference.
|
||||||
|
* **Content Alignment:** The main `app-main` area does *not* center its content in a narrow channel. It uses full-width fluid layouts with standard left and right paddings (e.g., `3rem`).
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 5. Component Signatures
|
||||||
|
|
||||||
|
### Micro-interactions & Loading States (New)
|
||||||
|
* **Skeleton Loading:** Never use harsh unstyled "Loading..." text strings. Utilize the `.skeleton-box` and `.skeleton-text` CSS classes which provide a smooth 1.5s pulse animation looping between `--bg-hover` and `--border-light`.
|
||||||
|
* **Graceful Rendering:** Content blocks should be replaced fully by matching structured skeletons outlining the UI during any data fetch or mock delay.
|
||||||
|
|
||||||
|
### Information Visual Hierarchy & Audit Trails (New)
|
||||||
|
* **Visual Noise Reduction:** Do not treat all logs equally. On Audit or Timeline screens (e.g. Conversation Replay), raw system logs like Tool Calls or Intent extractions must be rendered quietly as muted, italic text without background bubbles.
|
||||||
|
* **Focus Highlighting:** The highest visual weight in logs is reserved strictly for Human-to-AI interaction messages, Human-in-the-Loop Interventions, and critical overrides. Use distinctive background panels (e.g. pale red, soft lavender) only for these elevated actions.
|
||||||
|
|
||||||
|
### The Sidebar (Nav)
|
||||||
|
* **Tone-on-Tone:** Active navigation item pills should rely strictly on capsule background fills (`--bg-hover`) rather than font color switches or jarring left-bars.
|
||||||
|
|
||||||
|
### Action Cards (Human-in-the-Loop)
|
||||||
|
When an agent stops to ask for human confirmation (e.g., "This refund is >$1,000"):
|
||||||
|
1. **Isolate:** It must render as a distinct UI card (`.action-card`), jumping out from the standard Markdown text flow.
|
||||||
|
2. **Color Stripe:** It uses a high-contrast left border (e.g., Red `#DC2626` for security approvals) to signal importance.
|
||||||
|
3. **Shadows:** Elevated using `box-shadow: var(--shadow-lg)` to hover above the conversation.
|
||||||
|
|
||||||
|
### Data Tables & Analytics
|
||||||
|
* **No Vertical Borders:** Tables should only use horizontal lines (`border-bottom`) to separate rows. Vertical lines feel too rigid and clunky.
|
||||||
|
* **Hover Rows:** Wrap standard rows in a hover transition (`background-color: var(--bg-hover)`) to help the eye track long data strings.
|
||||||
|
* **Metric Boxes:** Important KPI statistics (like those on Dashboard) are housed in thick, rounded boxes (`--radius-xl`) to look like physical widgets.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 6. CSS Best Practices for the Project
|
||||||
|
|
||||||
|
1. **Avoid Inline Styles:** All recurring UI patterns (like `btn`, `page-header`, `metricsGrid`) should map to CSS classes in `index.css`.
|
||||||
|
2. **Use REM for Spacing/Sizing:** Prefer `rem` over `px` for paddings, margins, and font sizes to ensure accessibility scaling.
|
||||||
|
3. **Soft Shadows Only:** Shadows should have high blur radiuses and low opacity. *Bad: `rgba(0,0,0,0.5) 0px 5px`.* *Good: `rgba(0,0,0,0.06) 0px 10px 30px`*.
|
||||||
@@ -104,8 +104,8 @@ smart-support/
|
|||||||
|
|
||||||
## Tech Stack
|
## Tech Stack
|
||||||
|
|
||||||
- Python 3.11+, FastAPI, LangGraph v1.1.0
|
- Python 3.11+, FastAPI, LangGraph 1.x (currently 1.1.6)
|
||||||
- langgraph-supervisor, langchain-mcp-adapters, langgraph-checkpoint-postgres v3.0.5
|
- langgraph-supervisor 0.0.31, langchain-mcp-adapters, langgraph-checkpoint-postgres v3.0.5
|
||||||
- React (frontend), PostgreSQL 16 (via Docker Compose)
|
- React (frontend), PostgreSQL 16 (via Docker Compose)
|
||||||
- Claude Sonnet 4.6 via `ChatAnthropic` (configurable via env)
|
- Claude Sonnet 4.6 via `ChatAnthropic` (configurable via env)
|
||||||
- pytest + FastAPI TestClient for backend tests
|
- pytest + FastAPI TestClient for backend tests
|
||||||
|
|||||||
1876
frontend/package-lock.json
generated
1876
frontend/package-lock.json
generated
File diff suppressed because it is too large
Load Diff
@@ -6,18 +6,25 @@
|
|||||||
"scripts": {
|
"scripts": {
|
||||||
"dev": "vite",
|
"dev": "vite",
|
||||||
"build": "tsc -b && vite build",
|
"build": "tsc -b && vite build",
|
||||||
"preview": "vite preview"
|
"preview": "vite preview",
|
||||||
|
"test": "vitest run",
|
||||||
|
"test:watch": "vitest"
|
||||||
},
|
},
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
"react": "^19.0.0",
|
"react": "^19.0.0",
|
||||||
"react-dom": "^19.0.0",
|
"react-dom": "^19.0.0",
|
||||||
|
"react-markdown": "^10.1.0",
|
||||||
"react-router-dom": "^7.13.2"
|
"react-router-dom": "^7.13.2"
|
||||||
},
|
},
|
||||||
"devDependencies": {
|
"devDependencies": {
|
||||||
|
"@testing-library/jest-dom": "^6.9.1",
|
||||||
|
"@testing-library/react": "^16.3.2",
|
||||||
"@types/react": "^19.0.0",
|
"@types/react": "^19.0.0",
|
||||||
"@types/react-dom": "^19.0.0",
|
"@types/react-dom": "^19.0.0",
|
||||||
"@vitejs/plugin-react": "^4.3.0",
|
"@vitejs/plugin-react": "^4.3.0",
|
||||||
|
"happy-dom": "^20.8.9",
|
||||||
"typescript": "~5.7.0",
|
"typescript": "~5.7.0",
|
||||||
"vite": "^6.2.0"
|
"vite": "^6.2.0",
|
||||||
|
"vitest": "^4.1.2"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
117
frontend/src/api.test.ts
Normal file
117
frontend/src/api.test.ts
Normal file
@@ -0,0 +1,117 @@
|
|||||||
|
import { describe, it, expect, vi, beforeEach } from "vitest";
|
||||||
|
import { fetchConversations, fetchReplay, fetchAnalytics } from "./api";
|
||||||
|
|
||||||
|
// Mock global fetch
|
||||||
|
const mockFetch = vi.fn();
|
||||||
|
vi.stubGlobal("fetch", mockFetch);
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
mockFetch.mockReset();
|
||||||
|
});
|
||||||
|
|
||||||
|
function jsonResponse(body: unknown, status = 200): Response {
|
||||||
|
return {
|
||||||
|
ok: status >= 200 && status < 300,
|
||||||
|
status,
|
||||||
|
statusText: status === 200 ? "OK" : "Error",
|
||||||
|
json: () => Promise.resolve(body),
|
||||||
|
} as Response;
|
||||||
|
}
|
||||||
|
|
||||||
|
describe("fetchConversations", () => {
|
||||||
|
it("returns conversations page on success", async () => {
|
||||||
|
const data = {
|
||||||
|
conversations: [{ thread_id: "t1", created_at: "", last_activity: "", status: "active", total_tokens: 0, total_cost_usd: 0 }],
|
||||||
|
total: 1,
|
||||||
|
page: 1,
|
||||||
|
per_page: 20,
|
||||||
|
};
|
||||||
|
mockFetch.mockResolvedValue(jsonResponse({ success: true, data, error: null }));
|
||||||
|
|
||||||
|
const result = await fetchConversations();
|
||||||
|
expect(result.conversations).toHaveLength(1);
|
||||||
|
expect(result.total).toBe(1);
|
||||||
|
expect(mockFetch).toHaveBeenCalledWith("/api/v1/conversations?page=1&per_page=20");
|
||||||
|
});
|
||||||
|
|
||||||
|
it("passes custom page and perPage", async () => {
|
||||||
|
mockFetch.mockResolvedValue(jsonResponse({ success: true, data: { conversations: [], total: 0, page: 2, per_page: 10 }, error: null }));
|
||||||
|
|
||||||
|
await fetchConversations(2, 10);
|
||||||
|
expect(mockFetch).toHaveBeenCalledWith("/api/v1/conversations?page=2&per_page=10");
|
||||||
|
});
|
||||||
|
|
||||||
|
it("throws on HTTP error", async () => {
|
||||||
|
mockFetch.mockResolvedValue(jsonResponse({}, 500));
|
||||||
|
|
||||||
|
await expect(fetchConversations()).rejects.toThrow("API error 500");
|
||||||
|
});
|
||||||
|
|
||||||
|
it("throws on success=false with error message", async () => {
|
||||||
|
mockFetch.mockResolvedValue(jsonResponse({ success: false, data: null, error: "Database unavailable" }));
|
||||||
|
|
||||||
|
await expect(fetchConversations()).rejects.toThrow("Database unavailable");
|
||||||
|
});
|
||||||
|
|
||||||
|
it("throws unknown error when success=false with no message", async () => {
|
||||||
|
mockFetch.mockResolvedValue(jsonResponse({ success: false, data: null, error: null }));
|
||||||
|
|
||||||
|
await expect(fetchConversations()).rejects.toThrow("Unknown API error");
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("fetchReplay", () => {
|
||||||
|
it("returns replay page on success", async () => {
|
||||||
|
const data = {
|
||||||
|
thread_id: "t1",
|
||||||
|
total_steps: 3,
|
||||||
|
page: 1,
|
||||||
|
per_page: 20,
|
||||||
|
steps: [{ step: 1, type: "message", content: "Hello", agent: null, tool: null, params: null, result: null, timestamp: "" }],
|
||||||
|
};
|
||||||
|
mockFetch.mockResolvedValue(jsonResponse({ success: true, data, error: null }));
|
||||||
|
|
||||||
|
const result = await fetchReplay("t1");
|
||||||
|
expect(result.total_steps).toBe(3);
|
||||||
|
expect(result.steps).toHaveLength(1);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("encodes thread_id in URL", async () => {
|
||||||
|
mockFetch.mockResolvedValue(jsonResponse({ success: true, data: { thread_id: "a/b", total_steps: 0, page: 1, per_page: 20, steps: [] }, error: null }));
|
||||||
|
|
||||||
|
await fetchReplay("a/b");
|
||||||
|
expect(mockFetch).toHaveBeenCalledWith("/api/v1/replay/a%2Fb?page=1&per_page=20");
|
||||||
|
});
|
||||||
|
|
||||||
|
it("throws on HTTP error", async () => {
|
||||||
|
mockFetch.mockResolvedValue(jsonResponse({}, 404));
|
||||||
|
await expect(fetchReplay("missing")).rejects.toThrow("API error 404");
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("fetchAnalytics", () => {
|
||||||
|
it("returns analytics data on success", async () => {
|
||||||
|
const data = {
|
||||||
|
range: "7d",
|
||||||
|
total_conversations: 100,
|
||||||
|
resolution_rate: 0.75,
|
||||||
|
escalation_rate: 0.25,
|
||||||
|
avg_turns_per_conversation: 3.5,
|
||||||
|
avg_cost_per_conversation_usd: 0.03,
|
||||||
|
agent_usage: [],
|
||||||
|
interrupt_stats: { total: 0, approved: 0, rejected: 0, expired: 0 },
|
||||||
|
};
|
||||||
|
mockFetch.mockResolvedValue(jsonResponse({ success: true, data, error: null }));
|
||||||
|
|
||||||
|
const result = await fetchAnalytics("7d");
|
||||||
|
expect(result.total_conversations).toBe(100);
|
||||||
|
expect(result.range).toBe("7d");
|
||||||
|
});
|
||||||
|
|
||||||
|
it("uses default range", async () => {
|
||||||
|
mockFetch.mockResolvedValue(jsonResponse({ success: true, data: { range: "7d" }, error: null }));
|
||||||
|
|
||||||
|
await fetchAnalytics();
|
||||||
|
expect(mockFetch).toHaveBeenCalledWith("/api/v1/analytics?range=7d");
|
||||||
|
});
|
||||||
|
});
|
||||||
@@ -10,13 +10,11 @@ export interface ApiResponse<T> {
|
|||||||
|
|
||||||
export interface ConversationSummary {
|
export interface ConversationSummary {
|
||||||
thread_id: string;
|
thread_id: string;
|
||||||
started_at: string;
|
created_at: string;
|
||||||
last_activity: string;
|
last_activity: string;
|
||||||
turn_count: number;
|
status: string | null;
|
||||||
agents_used: string[];
|
|
||||||
total_tokens: number;
|
total_tokens: number;
|
||||||
total_cost_usd: number;
|
total_cost_usd: number;
|
||||||
resolution_type: string | null;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface ConversationsPage {
|
export interface ConversationsPage {
|
||||||
@@ -39,17 +37,16 @@ export interface ReplayStep {
|
|||||||
|
|
||||||
export interface ReplayPage {
|
export interface ReplayPage {
|
||||||
thread_id: string;
|
thread_id: string;
|
||||||
steps: ReplayStep[];
|
total_steps: number;
|
||||||
total: number;
|
|
||||||
page: number;
|
page: number;
|
||||||
per_page: number;
|
per_page: number;
|
||||||
|
steps: ReplayStep[];
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface AgentUsage {
|
export interface AgentUsage {
|
||||||
agent_name: string;
|
agent: string;
|
||||||
message_count: number;
|
count: number;
|
||||||
total_tokens: number;
|
percentage: number;
|
||||||
total_cost_usd: number;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface InterruptStats {
|
export interface InterruptStats {
|
||||||
@@ -60,14 +57,12 @@ export interface InterruptStats {
|
|||||||
}
|
}
|
||||||
|
|
||||||
export interface AnalyticsData {
|
export interface AnalyticsData {
|
||||||
|
range: string;
|
||||||
total_conversations: number;
|
total_conversations: number;
|
||||||
resolved_conversations: number;
|
|
||||||
escalated_conversations: number;
|
|
||||||
resolution_rate: number;
|
resolution_rate: number;
|
||||||
escalation_rate: number;
|
escalation_rate: number;
|
||||||
total_tokens: number;
|
|
||||||
total_cost_usd: number;
|
|
||||||
avg_turns_per_conversation: number;
|
avg_turns_per_conversation: number;
|
||||||
|
avg_cost_per_conversation_usd: number;
|
||||||
agent_usage: AgentUsage[];
|
agent_usage: AgentUsage[];
|
||||||
interrupt_stats: InterruptStats;
|
interrupt_stats: InterruptStats;
|
||||||
}
|
}
|
||||||
@@ -89,7 +84,7 @@ export async function fetchConversations(
|
|||||||
perPage = 20
|
perPage = 20
|
||||||
): Promise<ConversationsPage> {
|
): Promise<ConversationsPage> {
|
||||||
return apiFetch<ConversationsPage>(
|
return apiFetch<ConversationsPage>(
|
||||||
`/api/conversations?page=${page}&per_page=${perPage}`
|
`/api/v1/conversations?page=${page}&per_page=${perPage}`
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -99,10 +94,81 @@ export async function fetchReplay(
|
|||||||
perPage = 20
|
perPage = 20
|
||||||
): Promise<ReplayPage> {
|
): Promise<ReplayPage> {
|
||||||
return apiFetch<ReplayPage>(
|
return apiFetch<ReplayPage>(
|
||||||
`/api/replay/${encodeURIComponent(threadId)}?page=${page}&per_page=${perPage}`
|
`/api/v1/replay/${encodeURIComponent(threadId)}?page=${page}&per_page=${perPage}`
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
export async function fetchAnalytics(range = "7d"): Promise<AnalyticsData> {
|
export async function fetchAnalytics(range = "7d"): Promise<AnalyticsData> {
|
||||||
return apiFetch<AnalyticsData>(`/api/analytics?range=${range}`);
|
return apiFetch<AnalyticsData>(`/api/v1/analytics?range=${range}`);
|
||||||
|
}
|
||||||
|
|
||||||
|
// -- OpenAPI import --
|
||||||
|
|
||||||
|
export interface ImportJobResponse {
|
||||||
|
job_id: string;
|
||||||
|
status: string;
|
||||||
|
spec_url: string;
|
||||||
|
total_endpoints: number;
|
||||||
|
classified_count: number;
|
||||||
|
error_message: string | null;
|
||||||
|
generated_tools_count?: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface EndpointClassification {
|
||||||
|
index: number;
|
||||||
|
access_type: string;
|
||||||
|
needs_interrupt: boolean;
|
||||||
|
agent_group: string;
|
||||||
|
confidence: number;
|
||||||
|
customer_params: string[];
|
||||||
|
endpoint: {
|
||||||
|
path: string;
|
||||||
|
method: string;
|
||||||
|
operation_id: string;
|
||||||
|
summary: string;
|
||||||
|
description: string;
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
async function apiPost<T>(path: string, body: unknown): Promise<T> {
|
||||||
|
const res = await fetch(`${API_BASE}${path}`, {
|
||||||
|
method: "POST",
|
||||||
|
headers: { "Content-Type": "application/json" },
|
||||||
|
body: JSON.stringify(body),
|
||||||
|
});
|
||||||
|
if (!res.ok) {
|
||||||
|
throw new Error(`API error ${res.status}: ${res.statusText}`);
|
||||||
|
}
|
||||||
|
return res.json();
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function startImport(url: string): Promise<ImportJobResponse> {
|
||||||
|
return apiPost<ImportJobResponse>("/api/v1/openapi/import", { url });
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function fetchImportJob(jobId: string): Promise<ImportJobResponse> {
|
||||||
|
const res = await fetch(`${API_BASE}/api/v1/openapi/jobs/${encodeURIComponent(jobId)}`);
|
||||||
|
if (!res.ok) {
|
||||||
|
throw new Error(`API error ${res.status}: ${res.statusText}`);
|
||||||
|
}
|
||||||
|
return res.json();
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function fetchClassifications(
|
||||||
|
jobId: string
|
||||||
|
): Promise<EndpointClassification[]> {
|
||||||
|
const res = await fetch(
|
||||||
|
`${API_BASE}/api/v1/openapi/jobs/${encodeURIComponent(jobId)}/classifications`
|
||||||
|
);
|
||||||
|
if (!res.ok) {
|
||||||
|
throw new Error(`API error ${res.status}: ${res.statusText}`);
|
||||||
|
}
|
||||||
|
return res.json();
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function approveJob(jobId: string): Promise<ImportJobResponse> {
|
||||||
|
return apiPost<ImportJobResponse>(
|
||||||
|
`/api/v1/openapi/jobs/${encodeURIComponent(jobId)}/approve`,
|
||||||
|
{}
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|||||||
47
frontend/src/components/AgentAction.test.tsx
Normal file
47
frontend/src/components/AgentAction.test.tsx
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
import { describe, it, expect } from "vitest";
|
||||||
|
import { render, screen, fireEvent } from "@testing-library/react";
|
||||||
|
import { AgentAction } from "./AgentAction";
|
||||||
|
import type { ToolAction } from "../types";
|
||||||
|
|
||||||
|
function makeAction(overrides: Partial<ToolAction> = {}): ToolAction {
|
||||||
|
return {
|
||||||
|
id: "action-1",
|
||||||
|
agent: "OrderAgent",
|
||||||
|
tool: "get_order",
|
||||||
|
args: { order_id: "ORD-100" },
|
||||||
|
timestamp: Date.now(),
|
||||||
|
...overrides,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
describe("AgentAction", () => {
|
||||||
|
it("renders agent name and tool name", () => {
|
||||||
|
render(<AgentAction action={makeAction()} />);
|
||||||
|
|
||||||
|
expect(screen.getByText("OrderAgent")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("get_order")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
it("shows args and result when expanded", () => {
|
||||||
|
const action = makeAction({ result: { status: "shipped" } });
|
||||||
|
render(<AgentAction action={action} />);
|
||||||
|
|
||||||
|
// Click header to expand
|
||||||
|
fireEvent.click(screen.getByText("OrderAgent"));
|
||||||
|
|
||||||
|
expect(screen.getByText("Args:")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("Result:")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText(/"order_id": "ORD-100"/)).toBeInTheDocument();
|
||||||
|
expect(screen.getByText(/"status": "shipped"/)).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
it("does not show result section when result is undefined", () => {
|
||||||
|
render(<AgentAction action={makeAction()} />);
|
||||||
|
|
||||||
|
// Expand
|
||||||
|
fireEvent.click(screen.getByText("OrderAgent"));
|
||||||
|
|
||||||
|
expect(screen.getByText("Args:")).toBeInTheDocument();
|
||||||
|
expect(screen.queryByText("Result:")).not.toBeInTheDocument();
|
||||||
|
});
|
||||||
|
});
|
||||||
53
frontend/src/components/ChatInput.test.tsx
Normal file
53
frontend/src/components/ChatInput.test.tsx
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
import { describe, it, expect, vi } from "vitest";
|
||||||
|
import { render, screen, fireEvent } from "@testing-library/react";
|
||||||
|
import { ChatInput } from "./ChatInput";
|
||||||
|
|
||||||
|
describe("ChatInput", () => {
|
||||||
|
it("renders input field and send button", () => {
|
||||||
|
render(<ChatInput onSend={vi.fn()} disabled={false} />);
|
||||||
|
|
||||||
|
expect(screen.getByPlaceholderText("Message Smart Support...")).toBeInTheDocument();
|
||||||
|
expect(screen.getByRole("button", { name: "Send Message" })).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
it("calls onSend with trimmed content when form is submitted via Enter", () => {
|
||||||
|
const onSend = vi.fn();
|
||||||
|
render(<ChatInput onSend={onSend} disabled={false} />);
|
||||||
|
|
||||||
|
const input = screen.getByPlaceholderText("Message Smart Support...");
|
||||||
|
fireEvent.change(input, { target: { value: " Hello world " } });
|
||||||
|
fireEvent.keyDown(input, { key: "Enter" });
|
||||||
|
|
||||||
|
expect(onSend).toHaveBeenCalledWith("Hello world");
|
||||||
|
});
|
||||||
|
|
||||||
|
it("clears input after successful send", () => {
|
||||||
|
const onSend = vi.fn();
|
||||||
|
render(<ChatInput onSend={onSend} disabled={false} />);
|
||||||
|
|
||||||
|
const input = screen.getByPlaceholderText("Message Smart Support...") as HTMLInputElement;
|
||||||
|
fireEvent.change(input, { target: { value: "Test message" } });
|
||||||
|
fireEvent.keyDown(input, { key: "Enter" });
|
||||||
|
|
||||||
|
expect(input.value).toBe("");
|
||||||
|
});
|
||||||
|
|
||||||
|
it("shows disabled placeholder and disables input when disabled", () => {
|
||||||
|
render(<ChatInput onSend={vi.fn()} disabled={true} />);
|
||||||
|
|
||||||
|
const input = screen.getByPlaceholderText("Agent is working...") as HTMLInputElement;
|
||||||
|
expect(input.disabled).toBe(true);
|
||||||
|
expect(screen.getByRole("button", { name: "Send Message" })).toBeDisabled();
|
||||||
|
});
|
||||||
|
|
||||||
|
it("does not call onSend when input is empty or whitespace", () => {
|
||||||
|
const onSend = vi.fn();
|
||||||
|
render(<ChatInput onSend={onSend} disabled={false} />);
|
||||||
|
|
||||||
|
const input = screen.getByPlaceholderText("Message Smart Support...");
|
||||||
|
fireEvent.change(input, { target: { value: " " } });
|
||||||
|
fireEvent.keyDown(input, { key: "Enter" });
|
||||||
|
|
||||||
|
expect(onSend).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
});
|
||||||
@@ -23,46 +23,23 @@ export function ChatInput({ onSend, disabled }: Props) {
|
|||||||
};
|
};
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div style={styles.container}>
|
<div className="chat-input-container">
|
||||||
|
<div className="chat-input-wrapper">
|
||||||
<input
|
<input
|
||||||
type="text"
|
type="text"
|
||||||
value={value}
|
value={value}
|
||||||
onChange={(e) => setValue(e.target.value)}
|
onChange={(e) => setValue(e.target.value)}
|
||||||
onKeyDown={handleKeyDown}
|
onKeyDown={handleKeyDown}
|
||||||
placeholder={disabled ? "Waiting for response..." : "Type a message..."}
|
placeholder={disabled ? "Agent is working..." : "Message Smart Support..."}
|
||||||
disabled={disabled}
|
disabled={disabled}
|
||||||
style={styles.input}
|
|
||||||
/>
|
/>
|
||||||
<button onClick={handleSubmit} disabled={disabled || !value.trim()} style={styles.button}>
|
<button className="chat-send-btn" onClick={handleSubmit} disabled={disabled || !value.trim()} aria-label="Send Message">
|
||||||
Send
|
<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" strokeWidth="2" strokeLinecap="round" strokeLinejoin="round">
|
||||||
|
<line x1="22" y1="2" x2="11" y2="13"></line>
|
||||||
|
<polygon points="22 2 15 22 11 13 2 9 22 2"></polygon>
|
||||||
|
</svg>
|
||||||
</button>
|
</button>
|
||||||
</div>
|
</div>
|
||||||
|
</div>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
const styles: Record<string, React.CSSProperties> = {
|
|
||||||
container: {
|
|
||||||
display: "flex",
|
|
||||||
gap: "8px",
|
|
||||||
padding: "12px 16px",
|
|
||||||
borderTop: "1px solid #e0e0e0",
|
|
||||||
background: "white",
|
|
||||||
},
|
|
||||||
input: {
|
|
||||||
flex: 1,
|
|
||||||
padding: "10px 14px",
|
|
||||||
border: "1px solid #ccc",
|
|
||||||
borderRadius: "8px",
|
|
||||||
fontSize: "14px",
|
|
||||||
outline: "none",
|
|
||||||
},
|
|
||||||
button: {
|
|
||||||
padding: "10px 20px",
|
|
||||||
background: "#0066cc",
|
|
||||||
color: "white",
|
|
||||||
border: "none",
|
|
||||||
borderRadius: "8px",
|
|
||||||
fontSize: "14px",
|
|
||||||
cursor: "pointer",
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|||||||
59
frontend/src/components/ChatMessages.test.tsx
Normal file
59
frontend/src/components/ChatMessages.test.tsx
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
import { describe, it, expect, vi } from "vitest";
|
||||||
|
import { render, screen } from "@testing-library/react";
|
||||||
|
import { ChatMessages } from "./ChatMessages";
|
||||||
|
import type { ChatMessage } from "../types";
|
||||||
|
|
||||||
|
// Mock react-markdown to avoid complex rendering
|
||||||
|
vi.mock("react-markdown", () => ({
|
||||||
|
default: ({ children }: { children: string }) => <span>{children}</span>,
|
||||||
|
}));
|
||||||
|
|
||||||
|
describe("ChatMessages", () => {
|
||||||
|
it("renders welcome message when messages array is empty", () => {
|
||||||
|
render(<ChatMessages messages={[]} />);
|
||||||
|
|
||||||
|
expect(screen.getByText("Hello! How can I help you today?")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("Smart Support")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
it("renders user messages with correct sender label", () => {
|
||||||
|
const messages: ChatMessage[] = [
|
||||||
|
{ id: "1", sender: "user", content: "I need help", timestamp: Date.now() },
|
||||||
|
];
|
||||||
|
render(<ChatMessages messages={messages} />);
|
||||||
|
|
||||||
|
expect(screen.getByText("You")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("I need help")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("Me")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
it("renders agent messages with agent name", () => {
|
||||||
|
const messages: ChatMessage[] = [
|
||||||
|
{ id: "2", sender: "agent", agent: "OrderBot", content: "Sure, let me check.", timestamp: Date.now() },
|
||||||
|
];
|
||||||
|
render(<ChatMessages messages={messages} />);
|
||||||
|
|
||||||
|
expect(screen.getByText("OrderBot")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("Sure, let me check.")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("AI")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
it("shows streaming cursor for messages being streamed", () => {
|
||||||
|
const messages: ChatMessage[] = [
|
||||||
|
{ id: "3", sender: "agent", agent: "Bot", content: "Processing", timestamp: Date.now(), isStreaming: true },
|
||||||
|
];
|
||||||
|
render(<ChatMessages messages={messages} />);
|
||||||
|
|
||||||
|
expect(screen.getByText("|")).toBeInTheDocument();
|
||||||
|
expect(document.querySelector(".cursor-blink")).toBeTruthy();
|
||||||
|
});
|
||||||
|
|
||||||
|
it("shows fallback agent label when agent field is missing", () => {
|
||||||
|
const messages: ChatMessage[] = [
|
||||||
|
{ id: "4", sender: "agent", content: "Generic response", timestamp: Date.now() },
|
||||||
|
];
|
||||||
|
render(<ChatMessages messages={messages} />);
|
||||||
|
|
||||||
|
expect(screen.getByText("Agent")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
});
|
||||||
@@ -1,4 +1,5 @@
|
|||||||
import { useEffect, useRef } from "react";
|
import { useEffect, useRef } from "react";
|
||||||
|
import ReactMarkdown from "react-markdown";
|
||||||
import type { ChatMessage } from "../types";
|
import type { ChatMessage } from "../types";
|
||||||
|
|
||||||
interface Props {
|
interface Props {
|
||||||
@@ -13,70 +14,33 @@ export function ChatMessages({ messages }: Props) {
|
|||||||
}, [messages]);
|
}, [messages]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div style={styles.container}>
|
<div className="chat-messages-container">
|
||||||
{messages.map((msg) => (
|
{messages.map((msg) => (
|
||||||
<div
|
<div key={msg.id} className="chat-message-row">
|
||||||
key={msg.id}
|
<div className={`avatar ${msg.sender === "user" ? "user" : "agent"}`}>
|
||||||
style={{
|
{msg.sender === "user" ? "Me" : "AI"}
|
||||||
...styles.message,
|
</div>
|
||||||
...(msg.sender === "user" ? styles.userMessage : styles.agentMessage),
|
<div className="message-body">
|
||||||
}}
|
<div className="message-sender">
|
||||||
>
|
{msg.sender === "user" ? "You" : msg.agent || "Agent"}
|
||||||
<div style={styles.header}>
|
</div>
|
||||||
<span style={styles.sender}>
|
<div className="message-content md-prose">
|
||||||
{msg.sender === "user" ? "You" : msg.agent || "Agent"}
|
<ReactMarkdown>{msg.content}</ReactMarkdown>
|
||||||
</span>
|
{msg.isStreaming && <span className="cursor-blink">|</span>}
|
||||||
</div>
|
</div>
|
||||||
<div style={styles.content}>
|
|
||||||
{msg.content}
|
|
||||||
{msg.isStreaming && <span style={styles.cursor}>|</span>}
|
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
))}
|
))}
|
||||||
|
{messages.length === 0 && (
|
||||||
|
<div className="chat-message-row">
|
||||||
|
<div className="avatar agent">AI</div>
|
||||||
|
<div className="message-body">
|
||||||
|
<div className="message-sender">Smart Support</div>
|
||||||
|
<div className="message-content">Hello! How can I help you today?</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
<div ref={bottomRef} />
|
<div ref={bottomRef} />
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
const styles: Record<string, React.CSSProperties> = {
|
|
||||||
container: {
|
|
||||||
flex: 1,
|
|
||||||
overflowY: "auto",
|
|
||||||
padding: "16px",
|
|
||||||
display: "flex",
|
|
||||||
flexDirection: "column",
|
|
||||||
gap: "12px",
|
|
||||||
},
|
|
||||||
message: {
|
|
||||||
maxWidth: "80%",
|
|
||||||
padding: "10px 14px",
|
|
||||||
borderRadius: "12px",
|
|
||||||
lineHeight: 1.5,
|
|
||||||
},
|
|
||||||
userMessage: {
|
|
||||||
alignSelf: "flex-end",
|
|
||||||
background: "#0066cc",
|
|
||||||
color: "white",
|
|
||||||
},
|
|
||||||
agentMessage: {
|
|
||||||
alignSelf: "flex-start",
|
|
||||||
background: "#f0f0f0",
|
|
||||||
color: "#333",
|
|
||||||
},
|
|
||||||
header: {
|
|
||||||
marginBottom: "4px",
|
|
||||||
},
|
|
||||||
sender: {
|
|
||||||
fontSize: "12px",
|
|
||||||
fontWeight: 600,
|
|
||||||
opacity: 0.8,
|
|
||||||
},
|
|
||||||
content: {
|
|
||||||
fontSize: "14px",
|
|
||||||
whiteSpace: "pre-wrap",
|
|
||||||
},
|
|
||||||
cursor: {
|
|
||||||
animation: "blink 1s infinite",
|
|
||||||
opacity: 0.7,
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|||||||
33
frontend/src/components/ErrorBanner.test.tsx
Normal file
33
frontend/src/components/ErrorBanner.test.tsx
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
import { describe, it, expect, vi } from "vitest";
|
||||||
|
import { render, screen, fireEvent } from "@testing-library/react";
|
||||||
|
import { ErrorBanner } from "./ErrorBanner";
|
||||||
|
|
||||||
|
describe("ErrorBanner", () => {
|
||||||
|
it("returns null when status is connected", () => {
|
||||||
|
const { container } = render(<ErrorBanner status="connected" />);
|
||||||
|
expect(container.innerHTML).toBe("");
|
||||||
|
});
|
||||||
|
|
||||||
|
it("shows disconnection message when status is disconnected", () => {
|
||||||
|
render(<ErrorBanner status="disconnected" onReconnect={vi.fn()} />);
|
||||||
|
|
||||||
|
expect(screen.getByText("Disconnected from server. Retrying...")).toBeInTheDocument();
|
||||||
|
expect(screen.getByRole("alert")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
it("shows connecting message when status is connecting", () => {
|
||||||
|
render(<ErrorBanner status="connecting" />);
|
||||||
|
|
||||||
|
expect(screen.getByText("Connecting to server...")).toBeInTheDocument();
|
||||||
|
// No reconnect button while connecting
|
||||||
|
expect(screen.queryByText("Reconnect")).not.toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
it("calls onReconnect when reconnect button is clicked", () => {
|
||||||
|
const onReconnect = vi.fn();
|
||||||
|
render(<ErrorBanner status="disconnected" onReconnect={onReconnect} />);
|
||||||
|
|
||||||
|
fireEvent.click(screen.getByText("Reconnect"));
|
||||||
|
expect(onReconnect).toHaveBeenCalledTimes(1);
|
||||||
|
});
|
||||||
|
});
|
||||||
58
frontend/src/components/InterruptPrompt.test.tsx
Normal file
58
frontend/src/components/InterruptPrompt.test.tsx
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
import { describe, it, expect, vi } from "vitest";
|
||||||
|
import { render, screen, fireEvent } from "@testing-library/react";
|
||||||
|
import { InterruptPrompt } from "./InterruptPrompt";
|
||||||
|
import type { InterruptMessage } from "../types";
|
||||||
|
|
||||||
|
describe("InterruptPrompt", () => {
|
||||||
|
const baseInterrupt: InterruptMessage = {
|
||||||
|
type: "interrupt",
|
||||||
|
thread_id: "t1",
|
||||||
|
action: "cancel_order",
|
||||||
|
params: {},
|
||||||
|
};
|
||||||
|
|
||||||
|
it("renders action name and approval title", () => {
|
||||||
|
render(<InterruptPrompt interrupt={baseInterrupt} onRespond={vi.fn()} />);
|
||||||
|
|
||||||
|
expect(screen.getByText("Action Requires Approval")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("cancel_order")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
it("calls onRespond with true when Approve button is clicked", () => {
|
||||||
|
const onRespond = vi.fn();
|
||||||
|
render(<InterruptPrompt interrupt={baseInterrupt} onRespond={onRespond} />);
|
||||||
|
|
||||||
|
fireEvent.click(screen.getByText("Approve Action"));
|
||||||
|
expect(onRespond).toHaveBeenCalledWith(true);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("calls onRespond with false when Reject button is clicked", () => {
|
||||||
|
const onRespond = vi.fn();
|
||||||
|
render(<InterruptPrompt interrupt={baseInterrupt} onRespond={onRespond} />);
|
||||||
|
|
||||||
|
fireEvent.click(screen.getByText("Reject & Escalate"));
|
||||||
|
expect(onRespond).toHaveBeenCalledWith(false);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("displays order_id parameter when present", () => {
|
||||||
|
const interrupt: InterruptMessage = {
|
||||||
|
...baseInterrupt,
|
||||||
|
params: { order_id: "ORD-12345" },
|
||||||
|
};
|
||||||
|
render(<InterruptPrompt interrupt={interrupt} onRespond={vi.fn()} />);
|
||||||
|
|
||||||
|
expect(screen.getByText("Target Order ID")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("ORD-12345")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
it("displays message parameter when present", () => {
|
||||||
|
const interrupt: InterruptMessage = {
|
||||||
|
...baseInterrupt,
|
||||||
|
params: { message: "This will refund $50" },
|
||||||
|
};
|
||||||
|
render(<InterruptPrompt interrupt={interrupt} onRespond={vi.fn()} />);
|
||||||
|
|
||||||
|
expect(screen.getByText("Detail Message")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("This will refund $50")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
});
|
||||||
@@ -7,75 +7,49 @@ interface Props {
|
|||||||
|
|
||||||
export function InterruptPrompt({ interrupt, onRespond }: Props) {
|
export function InterruptPrompt({ interrupt, onRespond }: Props) {
|
||||||
return (
|
return (
|
||||||
<div style={styles.container}>
|
<div className="action-card-container">
|
||||||
<div style={styles.header}>Action Requires Approval</div>
|
<div className="action-card">
|
||||||
<div style={styles.action}>
|
<div className="action-card-header">
|
||||||
<strong>Action:</strong> {interrupt.action}
|
<svg xmlns="http://www.w3.org/2000/svg" width="20" height="20" viewBox="0 0 24 24" fill="none" stroke="var(--brand-accent)" strokeWidth="2" strokeLinecap="round" strokeLinejoin="round">
|
||||||
|
<path d="M10.29 3.86L1.82 18a2 2 0 0 0 1.71 3h16.94a2 2 0 0 0 1.71-3L13.71 3.86a2 2 0 0 0-3.42 0z"></path>
|
||||||
|
<line x1="12" y1="9" x2="12" y2="13"></line>
|
||||||
|
<line x1="12" y1="17" x2="12.01" y2="17"></line>
|
||||||
|
</svg>
|
||||||
|
<h3 className="action-card-title">Action Requires Approval</h3>
|
||||||
|
<div style={{ flex: 1 }} />
|
||||||
|
<span className="action-card-badge">Pending</span>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<div className="action-card-body">
|
||||||
|
<div className="action-detail-row">
|
||||||
|
<span className="action-detail-label">Action Name</span>
|
||||||
|
<span className="action-detail-value" style={{ fontWeight: 600 }}>{interrupt.action}</span>
|
||||||
|
</div>
|
||||||
|
|
||||||
{"message" in interrupt.params && interrupt.params.message != null && (
|
{"message" in interrupt.params && interrupt.params.message != null && (
|
||||||
<div style={styles.detail}>{String(interrupt.params.message)}</div>
|
<div className="action-detail-row">
|
||||||
)}
|
<span className="action-detail-label">Detail Message</span>
|
||||||
{"order_id" in interrupt.params && interrupt.params.order_id != null && (
|
<span className="action-detail-value">{String(interrupt.params.message)}</span>
|
||||||
<div style={styles.detail}>
|
|
||||||
<strong>Order:</strong> {String(interrupt.params.order_id)}
|
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
<div style={styles.buttons}>
|
|
||||||
<button onClick={() => onRespond(true)} style={styles.approveBtn}>
|
{"order_id" in interrupt.params && interrupt.params.order_id != null && (
|
||||||
Approve
|
<div className="action-detail-row">
|
||||||
|
<span className="action-detail-label">Target Order ID</span>
|
||||||
|
<span className="action-detail-value">{String(interrupt.params.order_id)}</span>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div className="action-card-footer">
|
||||||
|
<button className="btn btn-secondary" onClick={() => onRespond(false)}>
|
||||||
|
Reject & Escalate
|
||||||
</button>
|
</button>
|
||||||
<button onClick={() => onRespond(false)} style={styles.rejectBtn}>
|
<button className="btn btn-primary" onClick={() => onRespond(true)}>
|
||||||
Reject
|
Approve Action
|
||||||
</button>
|
</button>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
</div>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
const styles: Record<string, React.CSSProperties> = {
|
|
||||||
container: {
|
|
||||||
margin: "12px 16px",
|
|
||||||
padding: "16px",
|
|
||||||
border: "2px solid #ff9800",
|
|
||||||
borderRadius: "12px",
|
|
||||||
background: "#fff8e1",
|
|
||||||
},
|
|
||||||
header: {
|
|
||||||
fontWeight: 700,
|
|
||||||
fontSize: "14px",
|
|
||||||
color: "#e65100",
|
|
||||||
marginBottom: "8px",
|
|
||||||
},
|
|
||||||
action: {
|
|
||||||
fontSize: "14px",
|
|
||||||
marginBottom: "4px",
|
|
||||||
},
|
|
||||||
detail: {
|
|
||||||
fontSize: "13px",
|
|
||||||
color: "#555",
|
|
||||||
marginBottom: "4px",
|
|
||||||
},
|
|
||||||
buttons: {
|
|
||||||
display: "flex",
|
|
||||||
gap: "8px",
|
|
||||||
marginTop: "12px",
|
|
||||||
},
|
|
||||||
approveBtn: {
|
|
||||||
padding: "8px 20px",
|
|
||||||
background: "#4caf50",
|
|
||||||
color: "white",
|
|
||||||
border: "none",
|
|
||||||
borderRadius: "6px",
|
|
||||||
cursor: "pointer",
|
|
||||||
fontWeight: 600,
|
|
||||||
},
|
|
||||||
rejectBtn: {
|
|
||||||
padding: "8px 20px",
|
|
||||||
background: "#f44336",
|
|
||||||
color: "white",
|
|
||||||
border: "none",
|
|
||||||
borderRadius: "6px",
|
|
||||||
cursor: "pointer",
|
|
||||||
fontWeight: 600,
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|||||||
39
frontend/src/components/Layout.test.tsx
Normal file
39
frontend/src/components/Layout.test.tsx
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
import { describe, it, expect, vi } from "vitest";
|
||||||
|
import { render, screen } from "@testing-library/react";
|
||||||
|
import { MemoryRouter, Routes, Route } from "react-router-dom";
|
||||||
|
import { Layout } from "./Layout";
|
||||||
|
|
||||||
|
// Mock NavBar to simplify layout tests
|
||||||
|
vi.mock("./NavBar", () => ({
|
||||||
|
NavBar: () => <nav data-testid="navbar">NavBar</nav>,
|
||||||
|
}));
|
||||||
|
|
||||||
|
function renderLayout(path = "/") {
|
||||||
|
return render(
|
||||||
|
<MemoryRouter initialEntries={[path]}>
|
||||||
|
<Routes>
|
||||||
|
<Route element={<Layout />}>
|
||||||
|
<Route path="/" element={<div>Home Content</div>} />
|
||||||
|
<Route path="/dashboard" element={<div>Dashboard Content</div>} />
|
||||||
|
</Route>
|
||||||
|
</Routes>
|
||||||
|
</MemoryRouter>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
describe("Layout", () => {
|
||||||
|
it("renders NavBar component", () => {
|
||||||
|
renderLayout();
|
||||||
|
expect(screen.getByTestId("navbar")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
it("renders child route content via Outlet", () => {
|
||||||
|
renderLayout("/");
|
||||||
|
expect(screen.getByText("Home Content")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
it("renders correct content for different routes", () => {
|
||||||
|
renderLayout("/dashboard");
|
||||||
|
expect(screen.getByText("Dashboard Content")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
});
|
||||||
@@ -3,9 +3,9 @@ import { NavBar } from "./NavBar";
|
|||||||
|
|
||||||
export function Layout() {
|
export function Layout() {
|
||||||
return (
|
return (
|
||||||
<div style={{ display: "flex", flexDirection: "column", height: "100vh" }}>
|
<div className="app-layout">
|
||||||
<NavBar />
|
<NavBar />
|
||||||
<main style={{ flex: 1, overflow: "auto" }}>
|
<main className="app-main">
|
||||||
<Outlet />
|
<Outlet />
|
||||||
</main>
|
</main>
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
28
frontend/src/components/MetricCard.test.tsx
Normal file
28
frontend/src/components/MetricCard.test.tsx
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
import { describe, it, expect } from "vitest";
|
||||||
|
import { render, screen } from "@testing-library/react";
|
||||||
|
import { MetricCard } from "./MetricCard";
|
||||||
|
|
||||||
|
describe("MetricCard", () => {
|
||||||
|
it("renders label and value", () => {
|
||||||
|
render(<MetricCard label="Total Users" value={42} />);
|
||||||
|
|
||||||
|
expect(screen.getByText("Total Users")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("42")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
it("renders with unit prefix and suffix", () => {
|
||||||
|
render(<MetricCard label="Cost" value="3.50" unit="$" suffix="/mo" />);
|
||||||
|
|
||||||
|
expect(screen.getByText("Cost")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("$")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("3.50")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("/mo")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
it("handles zero value correctly", () => {
|
||||||
|
render(<MetricCard label="Errors" value={0} />);
|
||||||
|
|
||||||
|
expect(screen.getByText("Errors")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("0")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
});
|
||||||
54
frontend/src/components/NavBar.test.tsx
Normal file
54
frontend/src/components/NavBar.test.tsx
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
import { describe, it, expect } from "vitest";
|
||||||
|
import { render, screen } from "@testing-library/react";
|
||||||
|
import { MemoryRouter } from "react-router-dom";
|
||||||
|
import { NavBar } from "./NavBar";
|
||||||
|
|
||||||
|
function renderNavBar(initialPath = "/") {
|
||||||
|
return render(
|
||||||
|
<MemoryRouter initialEntries={[initialPath]}>
|
||||||
|
<NavBar />
|
||||||
|
</MemoryRouter>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
describe("NavBar", () => {
|
||||||
|
it("renders all navigation links", () => {
|
||||||
|
renderNavBar();
|
||||||
|
|
||||||
|
expect(screen.getByText("Dashboard")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("Inbox")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("Conversation Replay")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("Agents & Tools")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
it("navigation links point to correct routes", () => {
|
||||||
|
renderNavBar();
|
||||||
|
|
||||||
|
const dashboardLink = screen.getByText("Dashboard").closest("a");
|
||||||
|
expect(dashboardLink).toHaveAttribute("href", "/dashboard");
|
||||||
|
|
||||||
|
const inboxLink = screen.getByText("Inbox").closest("a");
|
||||||
|
expect(inboxLink).toHaveAttribute("href", "/");
|
||||||
|
|
||||||
|
const replayLink = screen.getByText("Conversation Replay").closest("a");
|
||||||
|
expect(replayLink).toHaveAttribute("href", "/replay");
|
||||||
|
|
||||||
|
const reviewLink = screen.getByText("Agents & Tools").closest("a");
|
||||||
|
expect(reviewLink).toHaveAttribute("href", "/review");
|
||||||
|
});
|
||||||
|
|
||||||
|
it("active link has active class when on matching route", () => {
|
||||||
|
renderNavBar("/dashboard");
|
||||||
|
|
||||||
|
const dashboardLink = screen.getByText("Dashboard").closest("a");
|
||||||
|
expect(dashboardLink?.className).toContain("active");
|
||||||
|
|
||||||
|
const inboxLink = screen.getByText("Inbox").closest("a");
|
||||||
|
expect(inboxLink?.className).not.toContain("active");
|
||||||
|
});
|
||||||
|
|
||||||
|
it("renders brand name", () => {
|
||||||
|
renderNavBar();
|
||||||
|
expect(screen.getByText("Nexus AI")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
});
|
||||||
@@ -1,64 +1,56 @@
|
|||||||
import { NavLink } from "react-router-dom";
|
import { NavLink } from "react-router-dom";
|
||||||
|
|
||||||
const navLinks = [
|
const navLinks = [
|
||||||
{ to: "/", label: "Chat", exact: true },
|
{ to: "/dashboard", label: "Dashboard", icon: "grid" },
|
||||||
{ to: "/replay", label: "Replay" },
|
{ to: "/", label: "Inbox", icon: "inbox" },
|
||||||
{ to: "/dashboard", label: "Dashboard" },
|
{ to: "/replay", label: "Conversation Replay", icon: "play" },
|
||||||
{ to: "/review", label: "API Review" },
|
{ to: "/review", label: "Agents & Tools", icon: "cpu" },
|
||||||
];
|
];
|
||||||
|
|
||||||
const styles: Record<string, React.CSSProperties> = {
|
function getIcon(name: string) {
|
||||||
nav: {
|
switch (name) {
|
||||||
display: "flex",
|
case "grid": return <svg width="18" height="18" fill="none" stroke="currentColor" strokeWidth="2" viewBox="0 0 24 24"><rect x="3" y="3" width="7" height="7"></rect><rect x="14" y="3" width="7" height="7"></rect><rect x="14" y="14" width="7" height="7"></rect><rect x="3" y="14" width="7" height="7"></rect></svg>;
|
||||||
alignItems: "center",
|
case "inbox": return <svg width="18" height="18" fill="none" stroke="currentColor" strokeWidth="2" viewBox="0 0 24 24"><polyline points="22 12 16 12 14 15 10 15 8 12 2 12"></polyline><path d="M5.45 5.11L2 12v6a2 2 0 0 0 2 2h16a2 2 0 0 0 2-2v-6l-3.45-6.89A2 2 0 0 0 16.76 4H7.24a2 2 0 0 0-1.79 1.11z"></path></svg>;
|
||||||
gap: "0",
|
case "play": return <svg width="18" height="18" fill="none" stroke="currentColor" strokeWidth="2" viewBox="0 0 24 24"><polygon points="5 3 19 12 5 21 5 3"></polygon></svg>;
|
||||||
padding: "0 16px",
|
case "cpu": return <svg width="18" height="18" fill="none" stroke="currentColor" strokeWidth="2" viewBox="0 0 24 24"><rect x="4" y="4" width="16" height="16" rx="2" ry="2"></rect><rect x="9" y="9" width="6" height="6"></rect><line x1="9" y1="1" x2="9" y2="4"></line><line x1="15" y1="1" x2="15" y2="4"></line><line x1="9" y1="20" x2="9" y2="23"></line><line x1="15" y1="20" x2="15" y2="23"></line><line x1="20" y1="9" x2="23" y2="9"></line><line x1="20" y1="14" x2="23" y2="14"></line><line x1="1" y1="9" x2="4" y2="9"></line><line x1="1" y1="14" x2="4" y2="14"></line></svg>;
|
||||||
borderBottom: "1px solid #e0e0e0",
|
default: return null;
|
||||||
background: "#fff",
|
}
|
||||||
height: "48px",
|
}
|
||||||
boxShadow: "0 1px 4px rgba(0,0,0,0.06)",
|
|
||||||
},
|
|
||||||
brand: {
|
|
||||||
fontWeight: 700,
|
|
||||||
fontSize: "16px",
|
|
||||||
color: "#1a1a1a",
|
|
||||||
marginRight: "24px",
|
|
||||||
textDecoration: "none",
|
|
||||||
},
|
|
||||||
link: {
|
|
||||||
padding: "0 14px",
|
|
||||||
height: "48px",
|
|
||||||
display: "flex",
|
|
||||||
alignItems: "center",
|
|
||||||
fontSize: "14px",
|
|
||||||
color: "#555",
|
|
||||||
textDecoration: "none",
|
|
||||||
borderBottom: "2px solid transparent",
|
|
||||||
transition: "color 0.15s, border-color 0.15s",
|
|
||||||
},
|
|
||||||
activeLink: {
|
|
||||||
color: "#1976d2",
|
|
||||||
borderBottom: "2px solid #1976d2",
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|
||||||
export function NavBar() {
|
export function NavBar() {
|
||||||
return (
|
return (
|
||||||
<nav style={styles.nav}>
|
<nav className="app-sidebar">
|
||||||
<span style={styles.brand}>Smart Support</span>
|
<div className="brand-header">
|
||||||
{navLinks.map(({ to, label }) => (
|
<svg width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" strokeWidth="2" className="brand-logo-svg">
|
||||||
|
<circle cx="12" cy="12" r="10"></circle>
|
||||||
|
<path d="M8 14s1.5 2 4 2 4-2 4-2"></path>
|
||||||
|
<line x1="9" y1="9" x2="9.01" y2="9"></line>
|
||||||
|
<line x1="15" y1="9" x2="15.01" y2="9"></line>
|
||||||
|
</svg>
|
||||||
|
<span style={{ fontSize: "1.25rem", letterSpacing: "-0.03em" }}>Nexus AI</span>
|
||||||
|
</div>
|
||||||
|
<div className="nav-links" style={{ marginTop: "1rem" }}>
|
||||||
|
{navLinks.map(({ to, label, icon }) => (
|
||||||
<NavLink
|
<NavLink
|
||||||
key={to}
|
key={to}
|
||||||
to={to}
|
to={to}
|
||||||
end={to === "/"}
|
end={to === "/"}
|
||||||
style={({ isActive }) => ({
|
className={({ isActive }) => `nav-link ${isActive ? "active" : ""}`}
|
||||||
...styles.link,
|
style={{ display: "flex", gap: "12px", padding: "0.875rem 1rem", fontSize: "0.9375rem" }}
|
||||||
...(isActive ? styles.activeLink : {}),
|
|
||||||
})}
|
|
||||||
>
|
>
|
||||||
|
<span style={{ opacity: 0.7 }}>{getIcon(icon)}</span>
|
||||||
{label}
|
{label}
|
||||||
</NavLink>
|
</NavLink>
|
||||||
))}
|
))}
|
||||||
|
</div>
|
||||||
|
<div style={{ flex: 1 }} />
|
||||||
|
<div style={{ display: "flex", alignItems: "center", gap: "12px", borderTop: "1px solid var(--border-light)", paddingTop: "1rem" }}>
|
||||||
|
<div style={{ width: "36px", height: "36px", borderRadius: "50%", background: "var(--brand-primary)", color: "white", display: "flex", alignItems: "center", justifyContent: "center", fontWeight: "bold" }}>A</div>
|
||||||
|
<div>
|
||||||
|
<div style={{ fontWeight: 600, fontSize: "0.875rem" }}>Alex Thompson</div>
|
||||||
|
<div style={{ fontSize: "0.75rem", color: "var(--text-secondary)" }}>Nexus Corp</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
</nav>
|
</nav>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|||||||
69
frontend/src/components/ReplayTimeline.test.tsx
Normal file
69
frontend/src/components/ReplayTimeline.test.tsx
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
import { describe, it, expect } from "vitest";
|
||||||
|
import { render, screen, fireEvent } from "@testing-library/react";
|
||||||
|
import { ReplayTimeline } from "./ReplayTimeline";
|
||||||
|
import type { ReplayStep } from "../api";
|
||||||
|
|
||||||
|
function makeStep(overrides: Partial<ReplayStep> = {}): ReplayStep {
|
||||||
|
return {
|
||||||
|
step: 1,
|
||||||
|
type: "message",
|
||||||
|
content: "Hello",
|
||||||
|
agent: null,
|
||||||
|
tool: null,
|
||||||
|
params: null,
|
||||||
|
result: null,
|
||||||
|
timestamp: "2026-04-01T12:00:00Z",
|
||||||
|
...overrides,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
describe("ReplayTimeline", () => {
|
||||||
|
it("returns null when steps array is empty", () => {
|
||||||
|
const { container } = render(<ReplayTimeline steps={[]} />);
|
||||||
|
expect(container.innerHTML).toBe("");
|
||||||
|
});
|
||||||
|
|
||||||
|
it("renders a list of steps with type badges", () => {
|
||||||
|
const steps = [
|
||||||
|
makeStep({ step: 1, type: "message", content: "User said hi" }),
|
||||||
|
makeStep({ step: 2, type: "tool_call", content: "Calling API", agent: "OrderBot", tool: "get_order" }),
|
||||||
|
];
|
||||||
|
render(<ReplayTimeline steps={steps} />);
|
||||||
|
|
||||||
|
expect(screen.getByText("message")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("tool call")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("User said hi")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("OrderBot")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("get_order()")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
it("expands step details when View JSON Payload button is clicked", () => {
|
||||||
|
const steps = [
|
||||||
|
makeStep({
|
||||||
|
step: 1,
|
||||||
|
type: "tool_call",
|
||||||
|
params: { order_id: "123" },
|
||||||
|
result: { status: "ok" },
|
||||||
|
}),
|
||||||
|
];
|
||||||
|
render(<ReplayTimeline steps={steps} />);
|
||||||
|
|
||||||
|
const expandButton = screen.getByText("View JSON Payload", { exact: false });
|
||||||
|
expect(expandButton).toBeInTheDocument();
|
||||||
|
|
||||||
|
fireEvent.click(expandButton);
|
||||||
|
|
||||||
|
// After expanding, the JSON payload should be visible
|
||||||
|
expect(screen.getByText("Hide JSON Payload", { exact: false })).toBeInTheDocument();
|
||||||
|
expect(screen.getByText(/"order_id": "123"/)).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
it("does not show expand button when step has no params or result", () => {
|
||||||
|
const steps = [
|
||||||
|
makeStep({ step: 1, type: "message", params: null, result: null }),
|
||||||
|
];
|
||||||
|
render(<ReplayTimeline steps={steps} />);
|
||||||
|
|
||||||
|
expect(screen.queryByText("View JSON Payload", { exact: false })).not.toBeInTheDocument();
|
||||||
|
});
|
||||||
|
});
|
||||||
@@ -2,31 +2,31 @@ import { useState } from "react";
|
|||||||
import type { ReplayStep } from "../api";
|
import type { ReplayStep } from "../api";
|
||||||
|
|
||||||
const TYPE_COLORS: Record<string, string> = {
|
const TYPE_COLORS: Record<string, string> = {
|
||||||
message: "#1976d2",
|
message: "var(--brand-primary)",
|
||||||
token: "#388e3c",
|
token: "#9CA3AF", // Soft gray
|
||||||
tool_call: "#f57c00",
|
tool_call: "#D97706", // Amber
|
||||||
tool_result: "#7b1fa2",
|
tool_result: "#059669", // Emerald
|
||||||
interrupt: "#d32f2f",
|
interrupt: "#DC2626", // Red for wait
|
||||||
interrupt_response: "#c2185b",
|
interrupt_response: "#7C3AED", // Purple for human action
|
||||||
error: "#c62828",
|
error: "#991B1B", // Dark red
|
||||||
};
|
};
|
||||||
|
|
||||||
function TypeBadge({ type }: { type: string }) {
|
function TypeBadge({ type }: { type: string }) {
|
||||||
const color = TYPE_COLORS[type] ?? "#555";
|
const color = TYPE_COLORS[type] ?? "var(--text-secondary)";
|
||||||
return (
|
return (
|
||||||
<span
|
<span
|
||||||
style={{
|
style={{
|
||||||
background: color,
|
background: color,
|
||||||
color: "#fff",
|
color: "#fff",
|
||||||
fontSize: "11px",
|
fontSize: "0.65rem",
|
||||||
fontWeight: 600,
|
fontWeight: 700,
|
||||||
padding: "2px 7px",
|
padding: "0.2rem 0.5rem",
|
||||||
borderRadius: "10px",
|
borderRadius: "99px",
|
||||||
textTransform: "uppercase",
|
textTransform: "uppercase",
|
||||||
letterSpacing: "0.5px",
|
letterSpacing: "0.05em",
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
{type}
|
{type.replace("_", " ")}
|
||||||
</span>
|
</span>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
@@ -38,9 +38,9 @@ function ReplayStepItem({ step }: { step: ReplayStep }) {
|
|||||||
return (
|
return (
|
||||||
<div
|
<div
|
||||||
style={{
|
style={{
|
||||||
borderLeft: "2px solid #e0e0e0",
|
borderLeft: "2px solid var(--border-light)",
|
||||||
paddingLeft: "12px",
|
paddingLeft: "1.25rem",
|
||||||
marginBottom: "12px",
|
paddingBottom: "1.5rem",
|
||||||
position: "relative",
|
position: "relative",
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
@@ -48,70 +48,91 @@ function ReplayStepItem({ step }: { step: ReplayStep }) {
|
|||||||
style={{
|
style={{
|
||||||
position: "absolute",
|
position: "absolute",
|
||||||
left: "-5px",
|
left: "-5px",
|
||||||
top: "4px",
|
top: "6px",
|
||||||
width: "8px",
|
width: "8px",
|
||||||
height: "8px",
|
height: "8px",
|
||||||
borderRadius: "50%",
|
borderRadius: "50%",
|
||||||
background: TYPE_COLORS[step.type] ?? "#555",
|
background: TYPE_COLORS[step.type] ?? "var(--text-secondary)",
|
||||||
|
boxShadow: `0 0 0 4px var(--bg-surface)`
|
||||||
}}
|
}}
|
||||||
/>
|
/>
|
||||||
<div style={{ display: "flex", alignItems: "center", gap: "8px", marginBottom: "4px" }}>
|
|
||||||
<span style={{ fontSize: "11px", color: "#888" }}>#{step.step}</span>
|
<div style={{ display: "flex", alignItems: "center", gap: "0.75rem", marginBottom: "0.5rem" }}>
|
||||||
<TypeBadge type={step.type} />
|
<TypeBadge type={step.type} />
|
||||||
{step.agent && (
|
{step.agent && (
|
||||||
<span style={{ fontSize: "11px", color: "#666", fontStyle: "italic" }}>
|
<span style={{ fontSize: "0.8125rem", color: "var(--text-primary)", fontWeight: 600 }}>
|
||||||
{step.agent}
|
{step.agent}
|
||||||
</span>
|
</span>
|
||||||
)}
|
)}
|
||||||
{step.tool && (
|
{step.tool && (
|
||||||
<span style={{ fontSize: "11px", color: "#555" }}>
|
<span style={{ fontSize: "0.8125rem", color: "var(--text-secondary)", fontFamily: "monospace", backgroundColor: "var(--bg-app)", padding: "2px 6px", borderRadius: "4px" }}>
|
||||||
tool: <strong>{step.tool}</strong>
|
{step.tool}()
|
||||||
</span>
|
</span>
|
||||||
)}
|
)}
|
||||||
<span style={{ fontSize: "11px", color: "#aaa", marginLeft: "auto" }}>
|
<span style={{ fontSize: "0.75rem", color: "var(--text-secondary)", marginLeft: "auto" }}>
|
||||||
{new Date(step.timestamp).toLocaleTimeString()}
|
{new Date(step.timestamp).toLocaleTimeString([], { hour: '2-digit', minute: '2-digit', second: '2-digit' })}
|
||||||
</span>
|
</span>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
{step.content && (
|
{step.content && (
|
||||||
<div
|
<div
|
||||||
style={{
|
style={
|
||||||
fontSize: "13px",
|
["message", "interrupt", "interrupt_response"].includes(step.type)
|
||||||
color: "#333",
|
? {
|
||||||
background: "#f9f9f9",
|
fontSize: "0.9375rem",
|
||||||
padding: "6px 10px",
|
color: "var(--text-primary)",
|
||||||
borderRadius: "4px",
|
background: step.type === "interrupt" ? "#FEF2F2" : (step.type === "interrupt_response" ? "#F5F3FF" : "var(--bg-app)"),
|
||||||
maxHeight: "80px",
|
border: step.type === "interrupt" ? "1px solid #FECACA" : (step.type === "interrupt_response" ? "1px solid #DDD6FE" : "1px solid var(--border-light)"),
|
||||||
overflow: "hidden",
|
padding: "0.875rem 1rem",
|
||||||
textOverflow: "ellipsis",
|
borderRadius: "var(--radius-md)",
|
||||||
}}
|
lineHeight: 1.5,
|
||||||
|
whiteSpace: "pre-wrap"
|
||||||
|
}
|
||||||
|
: {
|
||||||
|
fontSize: "0.8125rem",
|
||||||
|
color: "var(--text-secondary)",
|
||||||
|
padding: "0.25rem 0",
|
||||||
|
fontStyle: "italic",
|
||||||
|
lineHeight: 1.4
|
||||||
|
}
|
||||||
|
}
|
||||||
>
|
>
|
||||||
{step.content}
|
{step.content}
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
|
|
||||||
{hasDetails && (
|
{hasDetails && (
|
||||||
<button
|
<button
|
||||||
onClick={() => setExpanded((v) => !v)}
|
onClick={() => setExpanded((v) => !v)}
|
||||||
style={{
|
style={{
|
||||||
background: "none",
|
background: "none",
|
||||||
border: "none",
|
border: "none",
|
||||||
color: "#1976d2",
|
color: "var(--text-secondary)",
|
||||||
cursor: "pointer",
|
cursor: "pointer",
|
||||||
fontSize: "12px",
|
fontSize: "0.75rem",
|
||||||
padding: "2px 0",
|
fontWeight: 600,
|
||||||
|
padding: "0.5rem 0 0 0",
|
||||||
|
display: "flex",
|
||||||
|
alignItems: "center",
|
||||||
|
gap: "0.25rem"
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
{expanded ? "Hide details" : "Show details"}
|
{expanded ? "▼ Hide JSON Payload" : "▶ View JSON Payload"}
|
||||||
</button>
|
</button>
|
||||||
)}
|
)}
|
||||||
|
|
||||||
{expanded && hasDetails && (
|
{expanded && hasDetails && (
|
||||||
<pre
|
<pre
|
||||||
style={{
|
style={{
|
||||||
fontSize: "11px",
|
fontSize: "0.75rem",
|
||||||
background: "#f3f3f3",
|
background: "var(--text-primary)",
|
||||||
padding: "8px",
|
color: "white",
|
||||||
borderRadius: "4px",
|
padding: "1rem",
|
||||||
|
borderRadius: "var(--radius-md)",
|
||||||
overflow: "auto",
|
overflow: "auto",
|
||||||
maxHeight: "200px",
|
maxHeight: "250px",
|
||||||
|
marginTop: "0.5rem",
|
||||||
|
fontFamily: "monospace"
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
{JSON.stringify({ params: step.params, result: step.result }, null, 2)}
|
{JSON.stringify({ params: step.params, result: step.result }, null, 2)}
|
||||||
@@ -126,19 +147,14 @@ interface ReplayTimelineProps {
|
|||||||
}
|
}
|
||||||
|
|
||||||
export function ReplayTimeline({ steps }: ReplayTimelineProps) {
|
export function ReplayTimeline({ steps }: ReplayTimelineProps) {
|
||||||
if (steps.length === 0) {
|
if (!steps || steps.length === 0) return null;
|
||||||
return (
|
|
||||||
<div style={{ color: "#888", fontSize: "14px", padding: "16px 0" }}>
|
|
||||||
No steps recorded.
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div style={{ padding: "8px 0" }}>
|
<div style={{ marginLeft: "4px" }}>
|
||||||
{steps.map((step) => (
|
{steps.map((step) => (
|
||||||
<ReplayStepItem key={step.step} step={step} />
|
<ReplayStepItem key={step.step} step={step} />
|
||||||
))}
|
))}
|
||||||
|
<div style={{ borderLeft: "2px dashed var(--border-light)", height: "20px", marginLeft: "0px", opacity: 0.5 }}></div>
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|||||||
221
frontend/src/hooks/useWebSocket.test.ts
Normal file
221
frontend/src/hooks/useWebSocket.test.ts
Normal file
@@ -0,0 +1,221 @@
|
|||||||
|
import { describe, it, expect, vi, beforeEach, afterEach } from "vitest";
|
||||||
|
import { renderHook, act } from "@testing-library/react";
|
||||||
|
import { useWebSocket } from "./useWebSocket";
|
||||||
|
|
||||||
|
// Mock sessionStorage
|
||||||
|
const mockSessionStorage: Record<string, string> = {};
|
||||||
|
vi.stubGlobal("sessionStorage", {
|
||||||
|
getItem: (key: string) => mockSessionStorage[key] ?? null,
|
||||||
|
setItem: (key: string, value: string) => {
|
||||||
|
mockSessionStorage[key] = value;
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
// Mock crypto.randomUUID
|
||||||
|
vi.stubGlobal("crypto", { randomUUID: () => "test-uuid-1234" });
|
||||||
|
|
||||||
|
// Mock WebSocket
|
||||||
|
class MockWebSocket {
|
||||||
|
static OPEN = 1;
|
||||||
|
static CLOSED = 3;
|
||||||
|
static instances: MockWebSocket[] = [];
|
||||||
|
|
||||||
|
url: string;
|
||||||
|
readyState = 0;
|
||||||
|
onopen: (() => void) | null = null;
|
||||||
|
onclose: (() => void) | null = null;
|
||||||
|
onmessage: ((event: { data: string }) => void) | null = null;
|
||||||
|
onerror: (() => void) | null = null;
|
||||||
|
send = vi.fn();
|
||||||
|
close = vi.fn().mockImplementation(() => {
|
||||||
|
this.readyState = MockWebSocket.CLOSED;
|
||||||
|
// Trigger onclose asynchronously like real WebSocket
|
||||||
|
setTimeout(() => this.onclose?.(), 0);
|
||||||
|
});
|
||||||
|
|
||||||
|
constructor(url: string) {
|
||||||
|
this.url = url;
|
||||||
|
MockWebSocket.instances.push(this);
|
||||||
|
}
|
||||||
|
|
||||||
|
simulateOpen() {
|
||||||
|
this.readyState = MockWebSocket.OPEN;
|
||||||
|
this.onopen?.();
|
||||||
|
}
|
||||||
|
|
||||||
|
simulateMessage(data: unknown) {
|
||||||
|
this.onmessage?.({ data: JSON.stringify(data) });
|
||||||
|
}
|
||||||
|
|
||||||
|
simulateClose() {
|
||||||
|
this.readyState = MockWebSocket.CLOSED;
|
||||||
|
this.onclose?.();
|
||||||
|
}
|
||||||
|
|
||||||
|
simulateError() {
|
||||||
|
this.onerror?.();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
vi.stubGlobal("WebSocket", MockWebSocket);
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
MockWebSocket.instances = [];
|
||||||
|
delete mockSessionStorage["smart_support_thread_id"];
|
||||||
|
vi.useFakeTimers();
|
||||||
|
});
|
||||||
|
|
||||||
|
afterEach(() => {
|
||||||
|
vi.useRealTimers();
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("useWebSocket", () => {
|
||||||
|
it("establishes connection with correct URL on mount", () => {
|
||||||
|
const onMessage = vi.fn();
|
||||||
|
renderHook(() => useWebSocket(onMessage));
|
||||||
|
|
||||||
|
expect(MockWebSocket.instances).toHaveLength(1);
|
||||||
|
expect(MockWebSocket.instances[0].url).toContain("/ws");
|
||||||
|
});
|
||||||
|
|
||||||
|
it("sets status to connected when WebSocket opens", () => {
|
||||||
|
const onMessage = vi.fn();
|
||||||
|
const { result } = renderHook(() => useWebSocket(onMessage));
|
||||||
|
|
||||||
|
expect(result.current.status).toBe("connecting");
|
||||||
|
|
||||||
|
act(() => {
|
||||||
|
MockWebSocket.instances[0].simulateOpen();
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(result.current.status).toBe("connected");
|
||||||
|
});
|
||||||
|
|
||||||
|
it("parses incoming JSON messages and dispatches to callback", () => {
|
||||||
|
const onMessage = vi.fn();
|
||||||
|
renderHook(() => useWebSocket(onMessage));
|
||||||
|
|
||||||
|
act(() => {
|
||||||
|
MockWebSocket.instances[0].simulateOpen();
|
||||||
|
});
|
||||||
|
|
||||||
|
const serverMsg = { type: "token", agent: "bot", content: "Hello" };
|
||||||
|
act(() => {
|
||||||
|
MockWebSocket.instances[0].simulateMessage(serverMsg);
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(onMessage).toHaveBeenCalledWith(serverMsg);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("sends JSON through WebSocket via sendMessage", () => {
|
||||||
|
const onMessage = vi.fn();
|
||||||
|
const { result } = renderHook(() => useWebSocket(onMessage));
|
||||||
|
|
||||||
|
act(() => {
|
||||||
|
MockWebSocket.instances[0].simulateOpen();
|
||||||
|
});
|
||||||
|
|
||||||
|
act(() => {
|
||||||
|
result.current.sendMessage("Hi there");
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(MockWebSocket.instances[0].send).toHaveBeenCalledTimes(1);
|
||||||
|
const sent = JSON.parse(MockWebSocket.instances[0].send.mock.calls[0][0]);
|
||||||
|
expect(sent.type).toBe("message");
|
||||||
|
expect(sent.content).toBe("Hi there");
|
||||||
|
expect(sent.thread_id).toBeDefined();
|
||||||
|
});
|
||||||
|
|
||||||
|
it("calls onDisconnect when WebSocket closes", () => {
|
||||||
|
const onMessage = vi.fn();
|
||||||
|
const onDisconnect = vi.fn();
|
||||||
|
renderHook(() => useWebSocket(onMessage, { onDisconnect }));
|
||||||
|
|
||||||
|
act(() => {
|
||||||
|
MockWebSocket.instances[0].simulateOpen();
|
||||||
|
});
|
||||||
|
|
||||||
|
act(() => {
|
||||||
|
MockWebSocket.instances[0].simulateClose();
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(onDisconnect).toHaveBeenCalledTimes(1);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("sets status to disconnected on close and attempts reconnect", () => {
|
||||||
|
const onMessage = vi.fn();
|
||||||
|
const { result } = renderHook(() => useWebSocket(onMessage));
|
||||||
|
|
||||||
|
act(() => {
|
||||||
|
MockWebSocket.instances[0].simulateOpen();
|
||||||
|
});
|
||||||
|
|
||||||
|
act(() => {
|
||||||
|
MockWebSocket.instances[0].simulateClose();
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(result.current.status).toBe("disconnected");
|
||||||
|
|
||||||
|
// After timeout, a new WebSocket should be created (reconnect attempt)
|
||||||
|
act(() => {
|
||||||
|
vi.advanceTimersByTime(1500);
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(MockWebSocket.instances.length).toBeGreaterThanOrEqual(2);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("closes WebSocket on error event", () => {
|
||||||
|
const onMessage = vi.fn();
|
||||||
|
renderHook(() => useWebSocket(onMessage));
|
||||||
|
|
||||||
|
const ws = MockWebSocket.instances[0];
|
||||||
|
act(() => {
|
||||||
|
ws.simulateError();
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(ws.close).toHaveBeenCalledTimes(1);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("reconnect resets retries and creates a new connection", () => {
|
||||||
|
const onMessage = vi.fn();
|
||||||
|
const { result } = renderHook(() => useWebSocket(onMessage));
|
||||||
|
|
||||||
|
act(() => {
|
||||||
|
MockWebSocket.instances[0].simulateOpen();
|
||||||
|
});
|
||||||
|
|
||||||
|
const wsBeforeReconnect = MockWebSocket.instances[0];
|
||||||
|
|
||||||
|
act(() => {
|
||||||
|
result.current.reconnect();
|
||||||
|
});
|
||||||
|
|
||||||
|
// The old socket should have been closed
|
||||||
|
expect(wsBeforeReconnect.close).toHaveBeenCalled();
|
||||||
|
|
||||||
|
// Let the close callback fire and reconnect timer run
|
||||||
|
act(() => {
|
||||||
|
vi.advanceTimersByTime(100);
|
||||||
|
});
|
||||||
|
|
||||||
|
// A new WebSocket should have been created
|
||||||
|
expect(MockWebSocket.instances.length).toBeGreaterThan(1);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("sends interrupt response with approved flag", () => {
|
||||||
|
const onMessage = vi.fn();
|
||||||
|
const { result } = renderHook(() => useWebSocket(onMessage));
|
||||||
|
|
||||||
|
act(() => {
|
||||||
|
MockWebSocket.instances[0].simulateOpen();
|
||||||
|
});
|
||||||
|
|
||||||
|
act(() => {
|
||||||
|
result.current.sendInterruptResponse(true);
|
||||||
|
});
|
||||||
|
|
||||||
|
const sent = JSON.parse(MockWebSocket.instances[0].send.mock.calls[0][0]);
|
||||||
|
expect(sent.type).toBe("interrupt_response");
|
||||||
|
expect(sent.approved).toBe(true);
|
||||||
|
});
|
||||||
|
});
|
||||||
812
frontend/src/index.css
Normal file
812
frontend/src/index.css
Normal file
@@ -0,0 +1,812 @@
|
|||||||
|
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap');
|
||||||
|
|
||||||
|
:root {
|
||||||
|
/* Rich Warm Beige Theme based on Design Mockup */
|
||||||
|
--bg-app: #F4EFE7; /* Main app background (Sidebar & Main area) */
|
||||||
|
--bg-surface: #EBE4D8; /* Slightly darker for cards */
|
||||||
|
--bg-surface-inner: #F6F2EC; /* Lighter inner container */
|
||||||
|
--bg-hover: #E1D9CC; /* Hover state for sidebar and buttons */
|
||||||
|
|
||||||
|
--text-primary: #1C1917; /* Slate dark/brownish */
|
||||||
|
--text-secondary: #5C554D; /* Muted stone */
|
||||||
|
|
||||||
|
--border-light: #D5CCC0; /* Warm border */
|
||||||
|
--border-focus: #B6AAA0;
|
||||||
|
|
||||||
|
--brand-primary: #3B342D; /* Dark brown/grey for buttons */
|
||||||
|
--brand-hover: #26211C;
|
||||||
|
--brand-accent: #3B342D;
|
||||||
|
|
||||||
|
--shadow-sm: 0 2px 4px rgba(0, 0, 0, 0.02);
|
||||||
|
--shadow-md: 0 4px 12px rgba(0, 0, 0, 0.04);
|
||||||
|
--shadow-lg: 0 10px 25px rgba(0, 0, 0, 0.06);
|
||||||
|
|
||||||
|
--font-sans: 'Inter', system-ui, -apple-system, sans-serif;
|
||||||
|
--radius-md: 10px;
|
||||||
|
--radius-lg: 16px;
|
||||||
|
--radius-xl: 24px;
|
||||||
|
}
|
||||||
|
|
||||||
|
* {
|
||||||
|
box-sizing: border-box;
|
||||||
|
}
|
||||||
|
|
||||||
|
body {
|
||||||
|
margin: 0;
|
||||||
|
padding: 0;
|
||||||
|
font-family: var(--font-sans);
|
||||||
|
background-color: #DBD2C6; /* Subtle deeper tone */
|
||||||
|
color: var(--text-primary);
|
||||||
|
-webkit-font-smoothing: antialiased;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Application Shell Layout */
|
||||||
|
.app-layout {
|
||||||
|
display: flex;
|
||||||
|
height: 100vh;
|
||||||
|
width: 100vw;
|
||||||
|
overflow: hidden;
|
||||||
|
background-color: var(--bg-surface);
|
||||||
|
box-shadow: none;
|
||||||
|
border-radius: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
@media (min-width: 768px) {
|
||||||
|
body {
|
||||||
|
padding: 1.5rem;
|
||||||
|
box-sizing: border-box;
|
||||||
|
}
|
||||||
|
.app-layout {
|
||||||
|
height: calc(100vh - 3rem);
|
||||||
|
width: 100%;
|
||||||
|
border-radius: 20px;
|
||||||
|
box-shadow: 0 10px 30px rgba(0,0,0,0.06), 0 0 0 1px rgba(0,0,0,0.02);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Sidebar layout */
|
||||||
|
.app-sidebar {
|
||||||
|
width: 260px;
|
||||||
|
background-color: transparent; /* Makes it blend into the main background */
|
||||||
|
border-right: none;
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
padding: 1.5rem 1rem;
|
||||||
|
z-index: 10;
|
||||||
|
}
|
||||||
|
|
||||||
|
.brand-header {
|
||||||
|
font-weight: 700;
|
||||||
|
font-size: 1.25rem;
|
||||||
|
color: var(--text-primary);
|
||||||
|
margin-bottom: 2rem;
|
||||||
|
padding-left: 0.5rem;
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
gap: 0.75rem;
|
||||||
|
letter-spacing: -0.01em;
|
||||||
|
}
|
||||||
|
|
||||||
|
.nav-links {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
gap: 0.25rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.nav-link {
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
padding: 0.75rem 1rem;
|
||||||
|
border-radius: var(--radius-md);
|
||||||
|
color: var(--text-secondary);
|
||||||
|
text-decoration: none;
|
||||||
|
font-weight: 600;
|
||||||
|
font-size: 0.9375rem;
|
||||||
|
transition: all 0.2s ease;
|
||||||
|
}
|
||||||
|
|
||||||
|
.nav-link:hover {
|
||||||
|
background-color: var(--bg-hover);
|
||||||
|
color: var(--text-primary);
|
||||||
|
}
|
||||||
|
|
||||||
|
.nav-link.active {
|
||||||
|
background-color: var(--bg-hover);
|
||||||
|
color: var(--text-primary);
|
||||||
|
}
|
||||||
|
|
||||||
|
.app-main {
|
||||||
|
flex: 1;
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
position: relative;
|
||||||
|
overflow: hidden;
|
||||||
|
background-color: var(--bg-app);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* --- Chat Interface (Option B) --- */
|
||||||
|
.chat-page {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
height: 100%;
|
||||||
|
}
|
||||||
|
|
||||||
|
.chat-header {
|
||||||
|
display: flex;
|
||||||
|
justify-content: space-between;
|
||||||
|
align-items: center;
|
||||||
|
padding: 1rem 1.5rem;
|
||||||
|
border-bottom: 1px solid var(--border-light);
|
||||||
|
background-color: transparent;
|
||||||
|
z-index: 5;
|
||||||
|
}
|
||||||
|
|
||||||
|
.chat-header h1 {
|
||||||
|
font-size: 1.125rem;
|
||||||
|
font-weight: 600;
|
||||||
|
margin: 0;
|
||||||
|
color: var(--text-primary);
|
||||||
|
}
|
||||||
|
|
||||||
|
.chat-messages-container {
|
||||||
|
flex: 1;
|
||||||
|
overflow-y: auto;
|
||||||
|
padding: 2rem 0;
|
||||||
|
scroll-behavior: smooth;
|
||||||
|
}
|
||||||
|
|
||||||
|
.chat-message-row {
|
||||||
|
display: flex;
|
||||||
|
gap: 1.25rem;
|
||||||
|
padding: 1rem;
|
||||||
|
transition: background-color 0.2s;
|
||||||
|
}
|
||||||
|
|
||||||
|
@media (min-width: 768px) {
|
||||||
|
.chat-message-row {
|
||||||
|
padding: 1rem 3rem;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
.chat-message-row:hover {
|
||||||
|
background-color: var(--bg-hover);
|
||||||
|
}
|
||||||
|
|
||||||
|
.avatar {
|
||||||
|
flex-shrink: 0;
|
||||||
|
width: 32px;
|
||||||
|
height: 32px;
|
||||||
|
border-radius: 8px;
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
justify-content: center;
|
||||||
|
font-weight: 600;
|
||||||
|
font-size: 0.875rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.avatar.user {
|
||||||
|
background-color: var(--border-light);
|
||||||
|
color: var(--text-primary);
|
||||||
|
}
|
||||||
|
|
||||||
|
.avatar.agent {
|
||||||
|
background: linear-gradient(135deg, var(--brand-primary), #334155);
|
||||||
|
color: white;
|
||||||
|
}
|
||||||
|
|
||||||
|
.message-body {
|
||||||
|
flex: 1;
|
||||||
|
min-width: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.message-sender {
|
||||||
|
font-weight: 600;
|
||||||
|
font-size: 0.875rem;
|
||||||
|
margin-bottom: 0.375rem;
|
||||||
|
color: var(--text-primary);
|
||||||
|
}
|
||||||
|
|
||||||
|
.message-content {
|
||||||
|
font-size: 0.9375rem;
|
||||||
|
line-height: 1.6;
|
||||||
|
color: var(--text-primary);
|
||||||
|
white-space: pre-wrap;
|
||||||
|
}
|
||||||
|
|
||||||
|
.cursor-blink {
|
||||||
|
animation: blink 1s infinite alternate;
|
||||||
|
font-weight: 700;
|
||||||
|
color: var(--brand-accent);
|
||||||
|
}
|
||||||
|
|
||||||
|
@keyframes blink {
|
||||||
|
0% { opacity: 1; }
|
||||||
|
100% { opacity: 0.2; }
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Markdown Prose Styles */
|
||||||
|
.md-prose p {
|
||||||
|
margin: 0 0 0.75rem 0;
|
||||||
|
}
|
||||||
|
.md-prose p:last-child {
|
||||||
|
margin-bottom: 0;
|
||||||
|
}
|
||||||
|
.md-prose strong {
|
||||||
|
font-weight: 600;
|
||||||
|
color: var(--text-primary);
|
||||||
|
}
|
||||||
|
.md-prose ul, .md-prose ol {
|
||||||
|
margin: 0.25rem 0 0.75rem 0;
|
||||||
|
padding-left: 1.5rem;
|
||||||
|
}
|
||||||
|
.md-prose li {
|
||||||
|
margin-bottom: 0.25rem;
|
||||||
|
}
|
||||||
|
.md-prose pre {
|
||||||
|
background-color: var(--bg-hover);
|
||||||
|
padding: 0.75rem;
|
||||||
|
border-radius: var(--radius-md);
|
||||||
|
overflow-x: auto;
|
||||||
|
border: 1px solid var(--border-light);
|
||||||
|
font-size: 0.875rem;
|
||||||
|
}
|
||||||
|
.md-prose code {
|
||||||
|
font-family: monospace;
|
||||||
|
background-color: var(--bg-hover);
|
||||||
|
padding: 0.125rem 0.25rem;
|
||||||
|
border-radius: 4px;
|
||||||
|
}
|
||||||
|
.md-prose pre code {
|
||||||
|
background-color: transparent;
|
||||||
|
padding: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Chat Input Bar */
|
||||||
|
.chat-input-container {
|
||||||
|
padding: 1.5rem;
|
||||||
|
background: linear-gradient(to top, var(--bg-app) 80%, transparent);
|
||||||
|
}
|
||||||
|
|
||||||
|
.chat-input-wrapper {
|
||||||
|
margin: 0 1rem;
|
||||||
|
position: relative;
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
box-shadow: var(--shadow-md);
|
||||||
|
border-radius: var(--radius-lg);
|
||||||
|
background-color: var(--bg-surface);
|
||||||
|
border: 1px solid var(--border-light);
|
||||||
|
}
|
||||||
|
|
||||||
|
@media (min-width: 768px) {
|
||||||
|
.chat-input-wrapper {
|
||||||
|
margin: 0 3rem;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
.chat-input-wrapper:focus-within {
|
||||||
|
border-color: var(--border-focus);
|
||||||
|
box-shadow: var(--shadow-lg);
|
||||||
|
}
|
||||||
|
|
||||||
|
.chat-input-wrapper input {
|
||||||
|
flex: 1;
|
||||||
|
padding: 1rem 1.25rem;
|
||||||
|
border: none;
|
||||||
|
background: transparent;
|
||||||
|
font-size: 1rem;
|
||||||
|
font-family: inherit;
|
||||||
|
color: var(--text-primary);
|
||||||
|
outline: none;
|
||||||
|
}
|
||||||
|
|
||||||
|
.chat-input-wrapper input::placeholder {
|
||||||
|
color: var(--text-secondary);
|
||||||
|
}
|
||||||
|
|
||||||
|
.chat-send-btn {
|
||||||
|
margin-right: 0.75rem;
|
||||||
|
width: 32px;
|
||||||
|
height: 32px;
|
||||||
|
border-radius: 6px;
|
||||||
|
background-color: var(--brand-primary);
|
||||||
|
color: white;
|
||||||
|
border: none;
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
justify-content: center;
|
||||||
|
cursor: pointer;
|
||||||
|
transition: all 0.2s;
|
||||||
|
}
|
||||||
|
|
||||||
|
.chat-send-btn:hover:not(:disabled) {
|
||||||
|
background-color: var(--brand-hover);
|
||||||
|
}
|
||||||
|
|
||||||
|
.chat-send-btn:disabled {
|
||||||
|
opacity: 0.5;
|
||||||
|
cursor: not-allowed;
|
||||||
|
background-color: var(--text-secondary);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* --- Human in the loop Action Card --- */
|
||||||
|
.action-card-container {
|
||||||
|
margin: 1.5rem 1rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
@media (min-width: 768px) {
|
||||||
|
.action-card-container {
|
||||||
|
margin: 1.5rem 3rem;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
.action-card {
|
||||||
|
background-color: var(--bg-surface);
|
||||||
|
border: 1px solid var(--border-light);
|
||||||
|
border-radius: var(--radius-xl);
|
||||||
|
box-shadow: var(--shadow-lg);
|
||||||
|
overflow: hidden;
|
||||||
|
position: relative;
|
||||||
|
transition: transform 0.2s, box-shadow 0.2s;
|
||||||
|
}
|
||||||
|
|
||||||
|
.action-card::before {
|
||||||
|
content: '';
|
||||||
|
position: absolute;
|
||||||
|
top: 0;
|
||||||
|
left: 0;
|
||||||
|
bottom: 0;
|
||||||
|
width: 4px;
|
||||||
|
background-color: var(--brand-accent);
|
||||||
|
}
|
||||||
|
|
||||||
|
.action-card-header {
|
||||||
|
padding: 1.25rem 1.5rem 0.75rem 1.75rem;
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
gap: 0.75rem;
|
||||||
|
border-bottom: 1px solid var(--bg-hover);
|
||||||
|
}
|
||||||
|
|
||||||
|
.action-card-title {
|
||||||
|
font-weight: 700;
|
||||||
|
font-size: 1rem;
|
||||||
|
color: var(--text-primary);
|
||||||
|
margin: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.action-card-badge {
|
||||||
|
background-color: #FEF2F2;
|
||||||
|
color: #B91C1C;
|
||||||
|
padding: 0.25rem 0.625rem;
|
||||||
|
border-radius: 9999px;
|
||||||
|
font-size: 0.75rem;
|
||||||
|
font-weight: 600;
|
||||||
|
text-transform: uppercase;
|
||||||
|
letter-spacing: 0.05em;
|
||||||
|
}
|
||||||
|
|
||||||
|
.action-card-body {
|
||||||
|
padding: 1.25rem 1.75rem;
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
gap: 0.75rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.action-detail-row {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
gap: 0.25rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.action-detail-label {
|
||||||
|
font-size: 0.75rem;
|
||||||
|
text-transform: uppercase;
|
||||||
|
font-weight: 600;
|
||||||
|
color: var(--text-secondary);
|
||||||
|
letter-spacing: 0.05em;
|
||||||
|
}
|
||||||
|
|
||||||
|
.action-detail-value {
|
||||||
|
font-size: 0.9375rem;
|
||||||
|
color: var(--text-primary);
|
||||||
|
font-family: var(--font-sans);
|
||||||
|
background-color: var(--bg-hover);
|
||||||
|
padding: 0.5rem 0.75rem;
|
||||||
|
border-radius: var(--radius-md);
|
||||||
|
border: 1px solid var(--border-light);
|
||||||
|
word-break: break-word;
|
||||||
|
}
|
||||||
|
|
||||||
|
.action-card-footer {
|
||||||
|
padding: 1.25rem 1.75rem;
|
||||||
|
background-color: var(--bg-hover);
|
||||||
|
border-top: 1px solid var(--border-light);
|
||||||
|
display: flex;
|
||||||
|
justify-content: flex-end;
|
||||||
|
gap: 0.75rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn {
|
||||||
|
display: inline-flex;
|
||||||
|
align-items: center;
|
||||||
|
justify-content: center;
|
||||||
|
padding: 0.5rem 1.25rem;
|
||||||
|
font-size: 0.875rem;
|
||||||
|
font-weight: 600;
|
||||||
|
border-radius: var(--radius-md);
|
||||||
|
cursor: pointer;
|
||||||
|
transition: all 0.2s;
|
||||||
|
border: none;
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn-primary {
|
||||||
|
background-color: var(--brand-accent);
|
||||||
|
color: white;
|
||||||
|
box-shadow: var(--shadow-sm);
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn-primary:hover {
|
||||||
|
background-color: #C2410C;
|
||||||
|
box-shadow: var(--shadow-md);
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn-secondary {
|
||||||
|
background-color: transparent;
|
||||||
|
color: var(--text-primary);
|
||||||
|
border: 1px solid var(--border-focus);
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn-secondary:hover {
|
||||||
|
background-color: var(--bg-surface);
|
||||||
|
color: var(--text-primary);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* --- Agent Card Grid (Option 2) --- */
|
||||||
|
.page-container {
|
||||||
|
padding: 2.5rem;
|
||||||
|
width: 100%;
|
||||||
|
overflow-y: auto;
|
||||||
|
}
|
||||||
|
|
||||||
|
.page-header {
|
||||||
|
margin-bottom: 2rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.page-header h2 {
|
||||||
|
font-size: 1.5rem;
|
||||||
|
font-weight: 700;
|
||||||
|
color: var(--text-primary);
|
||||||
|
margin: 0 0 0.5rem 0;
|
||||||
|
letter-spacing: -0.01em;
|
||||||
|
}
|
||||||
|
|
||||||
|
.page-header p {
|
||||||
|
color: var(--text-secondary);
|
||||||
|
margin: 0;
|
||||||
|
font-size: 0.9375rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Form Styles */
|
||||||
|
.import-form {
|
||||||
|
display: flex;
|
||||||
|
gap: 0.75rem;
|
||||||
|
margin-bottom: 2rem;
|
||||||
|
background-color: var(--bg-surface);
|
||||||
|
padding: 1.5rem;
|
||||||
|
border-radius: var(--radius-lg);
|
||||||
|
box-shadow: var(--shadow-sm);
|
||||||
|
border: 1px solid var(--border-light);
|
||||||
|
}
|
||||||
|
|
||||||
|
.import-input {
|
||||||
|
flex: 1;
|
||||||
|
padding: 0.75rem 1rem;
|
||||||
|
border: 1px solid var(--border-light);
|
||||||
|
border-radius: var(--radius-md);
|
||||||
|
font-size: 0.9375rem;
|
||||||
|
font-family: inherit;
|
||||||
|
background-color: var(--bg-app);
|
||||||
|
color: var(--text-primary);
|
||||||
|
transition: all 0.2s;
|
||||||
|
}
|
||||||
|
|
||||||
|
.import-input:focus {
|
||||||
|
outline: none;
|
||||||
|
border-color: var(--border-focus);
|
||||||
|
box-shadow: 0 0 0 3px rgba(0,0,0,0.03);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Agent Grid */
|
||||||
|
.agent-grid {
|
||||||
|
display: grid;
|
||||||
|
grid-template-columns: repeat(auto-fill, minmax(340px, 1fr));
|
||||||
|
gap: 1.5rem;
|
||||||
|
margin-bottom: 2rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.agent-grid-card {
|
||||||
|
background-color: var(--bg-surface);
|
||||||
|
border-radius: var(--radius-xl);
|
||||||
|
border: 1px solid var(--border-light);
|
||||||
|
box-shadow: var(--shadow-sm);
|
||||||
|
overflow: hidden;
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
transition: transform 0.2s, box-shadow 0.2s;
|
||||||
|
}
|
||||||
|
|
||||||
|
.agent-grid-card:hover {
|
||||||
|
box-shadow: var(--shadow-md);
|
||||||
|
transform: translateY(-2px);
|
||||||
|
}
|
||||||
|
|
||||||
|
.agent-card-header-bg {
|
||||||
|
padding: 1.5rem 1.5rem 1rem 1.5rem;
|
||||||
|
border-bottom: 1px solid var(--bg-hover);
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
gap: 1rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.agent-avatar-lg {
|
||||||
|
width: 48px;
|
||||||
|
height: 48px;
|
||||||
|
border-radius: 12px;
|
||||||
|
background: linear-gradient(135deg, var(--text-primary), var(--text-secondary));
|
||||||
|
color: white;
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
justify-content: center;
|
||||||
|
font-size: 1.25rem;
|
||||||
|
font-weight: 700;
|
||||||
|
box-shadow: var(--shadow-sm);
|
||||||
|
}
|
||||||
|
|
||||||
|
.agent-card-meta h3 {
|
||||||
|
margin: 0 0 0.25rem 0;
|
||||||
|
font-size: 1.125rem;
|
||||||
|
font-weight: 700;
|
||||||
|
color: var(--text-primary);
|
||||||
|
}
|
||||||
|
|
||||||
|
.agent-card-meta span {
|
||||||
|
font-size: 0.75rem;
|
||||||
|
color: var(--brand-primary);
|
||||||
|
font-weight: 600;
|
||||||
|
text-transform: uppercase;
|
||||||
|
letter-spacing: 0.05em;
|
||||||
|
}
|
||||||
|
|
||||||
|
.agent-tools-list {
|
||||||
|
padding: 1.25rem 1.5rem;
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
gap: 0.75rem;
|
||||||
|
flex: 1;
|
||||||
|
background-color: var(--bg-surface-inner);
|
||||||
|
border-radius: 20px;
|
||||||
|
margin: 0 1.25rem 1.25rem 1.25rem;
|
||||||
|
border: 1px solid var(--border-light);
|
||||||
|
}
|
||||||
|
|
||||||
|
.tool-pill-item {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
gap: 0.25rem;
|
||||||
|
padding-bottom: 0.75rem;
|
||||||
|
border-bottom: 1px solid var(--border-light);
|
||||||
|
}
|
||||||
|
|
||||||
|
.tool-pill-item:last-child {
|
||||||
|
border-bottom: none;
|
||||||
|
padding-bottom: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.tool-pill-header {
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
gap: 0.5rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.tool-method-badge {
|
||||||
|
font-size: 0.65rem;
|
||||||
|
font-weight: 700;
|
||||||
|
text-transform: uppercase;
|
||||||
|
padding: 0.2rem 0.5rem;
|
||||||
|
border-radius: 99px;
|
||||||
|
background-color: var(--text-primary);
|
||||||
|
color: white;
|
||||||
|
letter-spacing: 0.05em;
|
||||||
|
}
|
||||||
|
|
||||||
|
.tool-path-text {
|
||||||
|
font-family: var(--font-sans);
|
||||||
|
font-size: 0.8125rem;
|
||||||
|
font-weight: 600;
|
||||||
|
color: var(--text-primary);
|
||||||
|
white-space: nowrap;
|
||||||
|
overflow: hidden;
|
||||||
|
text-overflow: ellipsis;
|
||||||
|
}
|
||||||
|
|
||||||
|
.tool-summary-text {
|
||||||
|
font-size: 0.8125rem;
|
||||||
|
color: var(--text-secondary);
|
||||||
|
line-height: 1.4;
|
||||||
|
margin-top: 0.25rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.tool-pill-controls {
|
||||||
|
display: flex;
|
||||||
|
gap: 0.5rem;
|
||||||
|
margin-top: 0.5rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.tool-select, .tool-input {
|
||||||
|
background-color: transparent;
|
||||||
|
border: 1px solid var(--border-focus);
|
||||||
|
border-radius: 6px;
|
||||||
|
padding: 0.375rem 0.625rem;
|
||||||
|
font-size: 0.75rem;
|
||||||
|
color: var(--text-primary);
|
||||||
|
flex: 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
.tool-select:focus, .tool-input:focus {
|
||||||
|
outline: none;
|
||||||
|
border-color: var(--text-primary);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* --- Shared Data Display Components --- */
|
||||||
|
|
||||||
|
.section-card {
|
||||||
|
background-color: var(--bg-surface);
|
||||||
|
border-radius: var(--radius-xl);
|
||||||
|
padding: 1.5rem;
|
||||||
|
border: 1px solid var(--border-light);
|
||||||
|
}
|
||||||
|
|
||||||
|
.stat-label {
|
||||||
|
font-size: 0.75rem;
|
||||||
|
text-transform: uppercase;
|
||||||
|
color: var(--text-secondary);
|
||||||
|
font-weight: 600;
|
||||||
|
letter-spacing: 0.05em;
|
||||||
|
}
|
||||||
|
|
||||||
|
.stat-value {
|
||||||
|
font-weight: 600;
|
||||||
|
font-size: 0.9375rem;
|
||||||
|
color: var(--text-primary);
|
||||||
|
}
|
||||||
|
|
||||||
|
.status-badge {
|
||||||
|
display: inline-block;
|
||||||
|
font-size: 0.75rem;
|
||||||
|
padding: 4px 10px;
|
||||||
|
border-radius: 6px;
|
||||||
|
font-weight: 600;
|
||||||
|
}
|
||||||
|
|
||||||
|
.status-badge--resolved {
|
||||||
|
background-color: #DEF7EC;
|
||||||
|
color: #03543F;
|
||||||
|
}
|
||||||
|
|
||||||
|
.status-badge--escalated {
|
||||||
|
background-color: #FDE8E8;
|
||||||
|
color: #9B1C1C;
|
||||||
|
}
|
||||||
|
|
||||||
|
.status-badge--active {
|
||||||
|
background-color: var(--bg-hover);
|
||||||
|
color: var(--text-secondary);
|
||||||
|
}
|
||||||
|
|
||||||
|
.data-table {
|
||||||
|
width: 100%;
|
||||||
|
border-collapse: collapse;
|
||||||
|
text-align: left;
|
||||||
|
}
|
||||||
|
|
||||||
|
.data-table th {
|
||||||
|
padding: 0.75rem 1.5rem;
|
||||||
|
font-size: 0.75rem;
|
||||||
|
text-transform: uppercase;
|
||||||
|
color: var(--text-secondary);
|
||||||
|
font-weight: 600;
|
||||||
|
}
|
||||||
|
|
||||||
|
.data-table td {
|
||||||
|
padding: 1.25rem 1.5rem;
|
||||||
|
font-size: 0.9375rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.data-table thead tr {
|
||||||
|
border-bottom: 2px solid var(--border-light);
|
||||||
|
}
|
||||||
|
|
||||||
|
.data-table tbody tr {
|
||||||
|
border-bottom: 1px solid var(--border-light);
|
||||||
|
transition: background-color 0.2s;
|
||||||
|
}
|
||||||
|
|
||||||
|
.data-table tbody tr:last-child {
|
||||||
|
border-bottom: none;
|
||||||
|
}
|
||||||
|
|
||||||
|
.data-table tbody tr:hover {
|
||||||
|
background-color: var(--bg-hover);
|
||||||
|
}
|
||||||
|
|
||||||
|
.empty-state {
|
||||||
|
padding: 3rem;
|
||||||
|
text-align: center;
|
||||||
|
color: var(--text-secondary);
|
||||||
|
}
|
||||||
|
|
||||||
|
.empty-state__title {
|
||||||
|
font-size: 1.125rem;
|
||||||
|
font-weight: 600;
|
||||||
|
margin: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.empty-state__description {
|
||||||
|
margin-top: 0.5rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.error-state {
|
||||||
|
padding: 3rem;
|
||||||
|
text-align: center;
|
||||||
|
color: var(--text-secondary);
|
||||||
|
}
|
||||||
|
|
||||||
|
.error-state__title {
|
||||||
|
font-size: 1.125rem;
|
||||||
|
font-weight: 600;
|
||||||
|
color: var(--brand-accent);
|
||||||
|
margin: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.error-state__description {
|
||||||
|
margin-top: 0.5rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.pagination-bar {
|
||||||
|
padding: 1.25rem 1.5rem;
|
||||||
|
border-top: 1px solid var(--border-light);
|
||||||
|
display: flex;
|
||||||
|
justify-content: space-between;
|
||||||
|
align-items: center;
|
||||||
|
background-color: var(--bg-surface-inner);
|
||||||
|
}
|
||||||
|
|
||||||
|
.pagination-bar__info {
|
||||||
|
font-size: 0.875rem;
|
||||||
|
color: var(--text-secondary);
|
||||||
|
}
|
||||||
|
|
||||||
|
.pagination-bar__controls {
|
||||||
|
display: flex;
|
||||||
|
gap: 0.5rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* --- Skeleton Loading Animation --- */
|
||||||
|
@keyframes pulse-skeleton {
|
||||||
|
0% { opacity: 0.5; background-color: var(--bg-hover); }
|
||||||
|
50% { opacity: 0.8; background-color: var(--border-light); }
|
||||||
|
100% { opacity: 0.5; background-color: var(--bg-hover); }
|
||||||
|
}
|
||||||
|
|
||||||
|
.skeleton-box {
|
||||||
|
animation: pulse-skeleton 1.5s infinite ease-in-out;
|
||||||
|
border-radius: var(--radius-md);
|
||||||
|
}
|
||||||
|
|
||||||
|
.skeleton-text {
|
||||||
|
height: 1rem;
|
||||||
|
width: 100%;
|
||||||
|
border-radius: 4px;
|
||||||
|
animation: pulse-skeleton 1.5s infinite ease-in-out;
|
||||||
|
}
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
import { StrictMode } from "react";
|
import { StrictMode } from "react";
|
||||||
import { createRoot } from "react-dom/client";
|
import { createRoot } from "react-dom/client";
|
||||||
|
import "./index.css";
|
||||||
import App from "./App";
|
import App from "./App";
|
||||||
|
|
||||||
createRoot(document.getElementById("root")!).render(
|
createRoot(document.getElementById("root")!).render(
|
||||||
|
|||||||
106
frontend/src/pages/ChatPage.test.tsx
Normal file
106
frontend/src/pages/ChatPage.test.tsx
Normal file
@@ -0,0 +1,106 @@
|
|||||||
|
import { describe, it, expect, vi, beforeEach } from "vitest";
|
||||||
|
import { render, screen, fireEvent, waitFor, act } from "@testing-library/react";
|
||||||
|
import { ChatPage } from "./ChatPage";
|
||||||
|
|
||||||
|
// Mock react-markdown
|
||||||
|
vi.mock("react-markdown", () => ({
|
||||||
|
default: ({ children }: { children: string }) => <span>{children}</span>,
|
||||||
|
}));
|
||||||
|
|
||||||
|
// Mock crypto.randomUUID for stable IDs
|
||||||
|
vi.stubGlobal("crypto", { randomUUID: () => `uuid-${Date.now()}-${Math.random()}` });
|
||||||
|
|
||||||
|
// Capture the onMessage callback from the hook
|
||||||
|
let capturedOnMessage: ((msg: unknown) => void) | null = null;
|
||||||
|
const mockSendMessage = vi.fn();
|
||||||
|
const mockSendInterruptResponse = vi.fn();
|
||||||
|
const mockReconnect = vi.fn();
|
||||||
|
let mockStatus = "connected";
|
||||||
|
|
||||||
|
vi.mock("../hooks/useWebSocket", () => ({
|
||||||
|
useWebSocket: (onMessage: (msg: unknown) => void) => {
|
||||||
|
capturedOnMessage = onMessage;
|
||||||
|
return {
|
||||||
|
status: mockStatus,
|
||||||
|
threadId: "test-thread",
|
||||||
|
sendMessage: mockSendMessage,
|
||||||
|
sendInterruptResponse: mockSendInterruptResponse,
|
||||||
|
reconnect: mockReconnect,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
}));
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
capturedOnMessage = null;
|
||||||
|
mockSendMessage.mockReset();
|
||||||
|
mockSendInterruptResponse.mockReset();
|
||||||
|
mockReconnect.mockReset();
|
||||||
|
mockStatus = "connected";
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("ChatPage", () => {
|
||||||
|
it("renders chat interface with input field and header", () => {
|
||||||
|
render(<ChatPage />);
|
||||||
|
|
||||||
|
expect(screen.getByText("Inbox")).toBeInTheDocument();
|
||||||
|
expect(screen.getByPlaceholderText("Message Smart Support...")).toBeInTheDocument();
|
||||||
|
expect(screen.getByRole("button", { name: "Send Message" })).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
it("user can type and submit a message", () => {
|
||||||
|
render(<ChatPage />);
|
||||||
|
|
||||||
|
const input = screen.getByPlaceholderText("Message Smart Support...");
|
||||||
|
fireEvent.change(input, { target: { value: "Hello bot" } });
|
||||||
|
fireEvent.keyDown(input, { key: "Enter" });
|
||||||
|
|
||||||
|
expect(mockSendMessage).toHaveBeenCalledWith("Hello bot");
|
||||||
|
expect(screen.getByText("Hello bot")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("You")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
it("displays streaming tokens as they arrive", () => {
|
||||||
|
render(<ChatPage />);
|
||||||
|
|
||||||
|
act(() => {
|
||||||
|
capturedOnMessage?.({ type: "token", agent: "Bot", content: "Hello " });
|
||||||
|
});
|
||||||
|
act(() => {
|
||||||
|
capturedOnMessage?.({ type: "token", agent: "Bot", content: "world" });
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(screen.getByText("Hello world")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
it("shows interrupt prompt when interrupt message received", () => {
|
||||||
|
render(<ChatPage />);
|
||||||
|
|
||||||
|
act(() => {
|
||||||
|
capturedOnMessage?.({
|
||||||
|
type: "interrupt",
|
||||||
|
thread_id: "t1",
|
||||||
|
action: "cancel_order",
|
||||||
|
params: { order_id: "ORD-999" },
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(screen.getByText("Action Requires Approval")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("cancel_order")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
it("shows error message when server sends error", () => {
|
||||||
|
render(<ChatPage />);
|
||||||
|
|
||||||
|
act(() => {
|
||||||
|
capturedOnMessage?.({ type: "error", message: "Something went wrong" });
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(screen.getByText("Error: Something went wrong")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
it("renders welcome message in empty state", () => {
|
||||||
|
render(<ChatPage />);
|
||||||
|
|
||||||
|
expect(screen.getByText("Hello! How can I help you today?")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
});
|
||||||
@@ -13,10 +13,8 @@ import type {
|
|||||||
ToolAction,
|
ToolAction,
|
||||||
} from "../types";
|
} from "../types";
|
||||||
|
|
||||||
let msgCounter = 0;
|
|
||||||
function nextId(): string {
|
function nextId(): string {
|
||||||
msgCounter += 1;
|
return crypto.randomUUID();
|
||||||
return `msg-${msgCounter}`;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
export function ChatPage() {
|
export function ChatPage() {
|
||||||
@@ -68,6 +66,48 @@ export function ChatPage() {
|
|||||||
setIsWaiting(false);
|
setIsWaiting(false);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
case "clarification": {
|
||||||
|
setMessages((prev) => [
|
||||||
|
...prev,
|
||||||
|
{
|
||||||
|
id: nextId(),
|
||||||
|
sender: "agent",
|
||||||
|
agent: "System",
|
||||||
|
content: msg.message,
|
||||||
|
timestamp: Date.now(),
|
||||||
|
},
|
||||||
|
]);
|
||||||
|
setIsWaiting(false);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case "interrupt_expired": {
|
||||||
|
setCurrentInterrupt(null);
|
||||||
|
setMessages((prev) => [
|
||||||
|
...prev,
|
||||||
|
{
|
||||||
|
id: nextId(),
|
||||||
|
sender: "agent",
|
||||||
|
agent: "System",
|
||||||
|
content: msg.message,
|
||||||
|
timestamp: Date.now(),
|
||||||
|
},
|
||||||
|
]);
|
||||||
|
setIsWaiting(false);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case "tool_result": {
|
||||||
|
setToolActions((prev) => {
|
||||||
|
const last = prev[prev.length - 1];
|
||||||
|
if (last && last.tool === msg.tool && last.agent === msg.agent) {
|
||||||
|
return [
|
||||||
|
...prev.slice(0, -1),
|
||||||
|
{ ...last, result: msg.result },
|
||||||
|
];
|
||||||
|
}
|
||||||
|
return prev;
|
||||||
|
});
|
||||||
|
break;
|
||||||
|
}
|
||||||
case "message_complete": {
|
case "message_complete": {
|
||||||
setMessages((prev) => {
|
setMessages((prev) => {
|
||||||
const last = prev[prev.length - 1];
|
const last = prev[prev.length - 1];
|
||||||
@@ -126,15 +166,15 @@ export function ChatPage() {
|
|||||||
);
|
);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div style={styles.page}>
|
<div className="chat-page">
|
||||||
<div style={styles.header}>
|
<div className="chat-header">
|
||||||
<h1 style={styles.title}>Smart Support</h1>
|
<h1>Inbox</h1>
|
||||||
<StatusIndicator status={status} />
|
<StatusIndicator status={status} />
|
||||||
</div>
|
</div>
|
||||||
<ErrorBanner status={status} onReconnect={reconnect} />
|
<ErrorBanner status={status} onReconnect={reconnect} />
|
||||||
<ChatMessages messages={messages} />
|
<ChatMessages messages={messages} />
|
||||||
{toolActions.length > 0 && (
|
{toolActions.length > 0 && (
|
||||||
<div style={styles.actionsBar}>
|
<div style={{ borderTop: "1px solid var(--border-light)", paddingTop: "4px" }}>
|
||||||
{toolActions.slice(-3).map((action) => (
|
{toolActions.slice(-3).map((action) => (
|
||||||
<AgentAction key={action.id} action={action} />
|
<AgentAction key={action.id} action={action} />
|
||||||
))}
|
))}
|
||||||
@@ -153,9 +193,9 @@ export function ChatPage() {
|
|||||||
|
|
||||||
function StatusIndicator({ status }: { status: ConnectionStatus }) {
|
function StatusIndicator({ status }: { status: ConnectionStatus }) {
|
||||||
const colors: Record<ConnectionStatus, string> = {
|
const colors: Record<ConnectionStatus, string> = {
|
||||||
connected: "#4caf50",
|
connected: "#10b981", // Emerald
|
||||||
connecting: "#ff9800",
|
connecting: "#f59e0b", // Amber
|
||||||
disconnected: "#f44336",
|
disconnected: "#ef4444", // Red
|
||||||
};
|
};
|
||||||
return (
|
return (
|
||||||
<div style={{ display: "flex", alignItems: "center", gap: "6px" }}>
|
<div style={{ display: "flex", alignItems: "center", gap: "6px" }}>
|
||||||
@@ -165,38 +205,10 @@ function StatusIndicator({ status }: { status: ConnectionStatus }) {
|
|||||||
height: "8px",
|
height: "8px",
|
||||||
borderRadius: "50%",
|
borderRadius: "50%",
|
||||||
background: colors[status],
|
background: colors[status],
|
||||||
|
boxShadow: `0 0 8px ${colors[status]}`,
|
||||||
}}
|
}}
|
||||||
/>
|
/>
|
||||||
<span style={{ fontSize: "12px", color: "#666" }}>{status}</span>
|
<span style={{ fontSize: "12px", color: "var(--text-secondary)", fontWeight: 500, textTransform: "capitalize" }}>{status}</span>
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
const styles: Record<string, React.CSSProperties> = {
|
|
||||||
page: {
|
|
||||||
height: "100vh",
|
|
||||||
display: "flex",
|
|
||||||
flexDirection: "column",
|
|
||||||
background: "white",
|
|
||||||
maxWidth: "800px",
|
|
||||||
margin: "0 auto",
|
|
||||||
boxShadow: "0 0 20px rgba(0,0,0,0.1)",
|
|
||||||
},
|
|
||||||
header: {
|
|
||||||
display: "flex",
|
|
||||||
justifyContent: "space-between",
|
|
||||||
alignItems: "center",
|
|
||||||
padding: "12px 16px",
|
|
||||||
borderBottom: "1px solid #e0e0e0",
|
|
||||||
},
|
|
||||||
title: {
|
|
||||||
fontSize: "18px",
|
|
||||||
fontWeight: 700,
|
|
||||||
margin: 0,
|
|
||||||
color: "#333",
|
|
||||||
},
|
|
||||||
actionsBar: {
|
|
||||||
borderTop: "1px solid #eee",
|
|
||||||
paddingTop: "4px",
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|||||||
70
frontend/src/pages/DashboardPage.test.tsx
Normal file
70
frontend/src/pages/DashboardPage.test.tsx
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
import { describe, it, expect, vi, beforeEach } from "vitest";
|
||||||
|
import { render, screen, waitFor } from "@testing-library/react";
|
||||||
|
import { DashboardPage } from "./DashboardPage";
|
||||||
|
|
||||||
|
vi.mock("../api", () => ({
|
||||||
|
fetchAnalytics: vi.fn(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
import { fetchAnalytics } from "../api";
|
||||||
|
const mockFetchAnalytics = vi.mocked(fetchAnalytics);
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
mockFetchAnalytics.mockReset();
|
||||||
|
});
|
||||||
|
|
||||||
|
const MOCK_DATA = {
|
||||||
|
range: "30d",
|
||||||
|
total_conversations: 100,
|
||||||
|
resolution_rate: 0.75,
|
||||||
|
escalation_rate: 0.25,
|
||||||
|
avg_turns_per_conversation: 3.5,
|
||||||
|
avg_cost_per_conversation_usd: 0.03,
|
||||||
|
agent_usage: [{ agent: "order_agent", count: 50, percentage: 0.5 }],
|
||||||
|
interrupt_stats: { total: 10, approved: 8, rejected: 2, expired: 0 },
|
||||||
|
};
|
||||||
|
|
||||||
|
describe("DashboardPage", () => {
|
||||||
|
it("renders loading state initially", () => {
|
||||||
|
mockFetchAnalytics.mockReturnValue(new Promise(() => {})); // never resolves
|
||||||
|
render(<DashboardPage />);
|
||||||
|
expect(document.querySelector(".skeleton-box")).toBeTruthy();
|
||||||
|
});
|
||||||
|
|
||||||
|
it("renders data after successful fetch", async () => {
|
||||||
|
mockFetchAnalytics.mockResolvedValue(MOCK_DATA);
|
||||||
|
render(<DashboardPage />);
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.getByText("100")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
expect(screen.getByText("75.0%")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("$0.03")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
it("renders error state on fetch failure", async () => {
|
||||||
|
mockFetchAnalytics.mockRejectedValue(new Error("Network error"));
|
||||||
|
render(<DashboardPage />);
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.getByText("Failed to load analytics")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
expect(screen.getByText("Network error")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
it("renders empty state when data has zero conversations", async () => {
|
||||||
|
mockFetchAnalytics.mockResolvedValue({
|
||||||
|
...MOCK_DATA,
|
||||||
|
total_conversations: 0,
|
||||||
|
agent_usage: [],
|
||||||
|
interrupt_stats: { total: 0, approved: 0, rejected: 0, expired: 0 },
|
||||||
|
});
|
||||||
|
render(<DashboardPage />);
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.getByText("0")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
expect(screen.getByText("No agent activity recorded yet.")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("No interrupt events recorded yet.")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
});
|
||||||
@@ -1,7 +1,5 @@
|
|||||||
import { useEffect, useState } from "react";
|
import { useState, useEffect } from "react";
|
||||||
import { fetchAnalytics } from "../api";
|
import { fetchAnalytics, AnalyticsData } from "../api";
|
||||||
import type { AnalyticsData } from "../api";
|
|
||||||
import { MetricCard } from "../components/MetricCard";
|
|
||||||
|
|
||||||
const RANGE_OPTIONS = [
|
const RANGE_OPTIONS = [
|
||||||
{ value: "7d", label: "7 days" },
|
{ value: "7d", label: "7 days" },
|
||||||
@@ -9,41 +7,52 @@ const RANGE_OPTIONS = [
|
|||||||
{ value: "30d", label: "30 days" },
|
{ value: "30d", label: "30 days" },
|
||||||
];
|
];
|
||||||
|
|
||||||
function pct(value: number): string {
|
|
||||||
return `${(value * 100).toFixed(1)}%`;
|
|
||||||
}
|
|
||||||
|
|
||||||
function formatCost(usd: number): string {
|
|
||||||
return usd < 0.01 ? "<$0.01" : `$${usd.toFixed(3)}`;
|
|
||||||
}
|
|
||||||
|
|
||||||
export function DashboardPage() {
|
export function DashboardPage() {
|
||||||
const [range, setRange] = useState("7d");
|
const [range, setRange] = useState("30d");
|
||||||
|
const [isLoading, setIsLoading] = useState(true);
|
||||||
const [data, setData] = useState<AnalyticsData | null>(null);
|
const [data, setData] = useState<AnalyticsData | null>(null);
|
||||||
const [loading, setLoading] = useState(true);
|
|
||||||
const [error, setError] = useState<string | null>(null);
|
const [error, setError] = useState<string | null>(null);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
setLoading(true);
|
setIsLoading(true);
|
||||||
setError(null);
|
setError(null);
|
||||||
fetchAnalytics(range)
|
fetchAnalytics(range)
|
||||||
.then(setData)
|
.then((result) => setData(result))
|
||||||
.catch((err: Error) => setError(err.message))
|
.catch((err: Error) => setError(err.message))
|
||||||
.finally(() => setLoading(false));
|
.finally(() => setIsLoading(false));
|
||||||
}, [range]);
|
}, [range]);
|
||||||
|
|
||||||
|
function pct(value: number): string {
|
||||||
|
return `${(value * 100).toFixed(1)}%`;
|
||||||
|
}
|
||||||
|
|
||||||
|
function formatCost(usd: number): string {
|
||||||
|
return `$${usd.toFixed(2)}`;
|
||||||
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div style={styles.container}>
|
<div className="page-container">
|
||||||
<div style={styles.header}>
|
<div className="page-header" style={{ display: "flex", justifyContent: "space-between", alignItems: "flex-end", marginBottom: "2rem" }}>
|
||||||
<h2 style={styles.heading}>Dashboard</h2>
|
<div>
|
||||||
<div style={styles.rangeSelector}>
|
<h2>Analytics Dashboard</h2>
|
||||||
|
<p>Monitor AI action performance, automation ROI, and agent efficiency.</p>
|
||||||
|
</div>
|
||||||
|
<div style={{ display: "flex", gap: "0.25rem", background: "var(--bg-hover)", padding: "0.25rem", borderRadius: "12px" }}>
|
||||||
{RANGE_OPTIONS.map((opt) => (
|
{RANGE_OPTIONS.map((opt) => (
|
||||||
<button
|
<button
|
||||||
key={opt.value}
|
key={opt.value}
|
||||||
onClick={() => setRange(opt.value)}
|
onClick={() => setRange(opt.value)}
|
||||||
|
disabled={isLoading}
|
||||||
style={{
|
style={{
|
||||||
...styles.rangeBtn,
|
padding: "0.5rem 1rem",
|
||||||
...(range === opt.value ? styles.rangeBtnActive : {}),
|
border: "none",
|
||||||
|
borderRadius: "8px",
|
||||||
|
cursor: isLoading ? "not-allowed" : "pointer",
|
||||||
|
fontSize: "0.875rem",
|
||||||
|
fontWeight: 600,
|
||||||
|
color: range === opt.value ? "white" : "var(--text-secondary)",
|
||||||
|
backgroundColor: range === opt.value ? "var(--brand-primary)" : "transparent",
|
||||||
|
transition: "all 0.2s"
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
{opt.label}
|
{opt.label}
|
||||||
@@ -52,133 +61,116 @@ export function DashboardPage() {
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
{loading && <div style={styles.center}>Loading analytics...</div>}
|
{isLoading ? (
|
||||||
{error && <div style={styles.error}>Error: {error}</div>}
|
|
||||||
|
|
||||||
{!loading && !error && data && (
|
|
||||||
<>
|
<>
|
||||||
{data.total_conversations === 0 ? (
|
<div style={{ display: "grid", gridTemplateColumns: "repeat(auto-fit, minmax(200px, 1fr))", gap: "1.5rem", marginBottom: "2.5rem" }}>
|
||||||
<div style={styles.empty}>
|
{[1, 2, 3, 4].map(i => (
|
||||||
No conversations yet. Start a chat to see analytics here.
|
<div key={i} className="skeleton-box section-card" style={{ height: "120px" }}>
|
||||||
|
<div className="skeleton-text" style={{ width: "60%", height: "12px", marginBottom: "1.5rem" }}></div>
|
||||||
|
<div className="skeleton-text" style={{ width: "40%", height: "30px", marginBottom: "1rem" }}></div>
|
||||||
|
<div className="skeleton-text" style={{ width: "80%", height: "12px" }}></div>
|
||||||
|
</div>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
<div style={{ display: "grid", gridTemplateColumns: "2fr 1fr", gap: "1.5rem" }}>
|
||||||
|
<div className="skeleton-box" style={{ height: "300px", borderRadius: "var(--radius-xl)", background: "var(--bg-surface)" }}></div>
|
||||||
|
<div className="skeleton-box" style={{ height: "300px", borderRadius: "var(--radius-xl)", background: "var(--bg-surface)" }}></div>
|
||||||
|
</div>
|
||||||
|
</>
|
||||||
|
) : error ? (
|
||||||
|
<div className="error-state">
|
||||||
|
<p className="error-state__title">Failed to load analytics</p>
|
||||||
|
<p className="error-state__description">{error}</p>
|
||||||
|
<button onClick={() => setRange(range)} className="btn btn-secondary" style={{ marginTop: "1rem" }}>Retry</button>
|
||||||
|
</div>
|
||||||
|
) : !data ? (
|
||||||
|
<div className="empty-state">
|
||||||
|
<p className="empty-state__title">No analytics data available</p>
|
||||||
|
<p className="empty-state__description">Start some conversations to see metrics here.</p>
|
||||||
</div>
|
</div>
|
||||||
) : (
|
) : (
|
||||||
<>
|
<>
|
||||||
<div style={styles.metricsGrid}>
|
<div style={{ display: "grid", gridTemplateColumns: "repeat(auto-fit, minmax(200px, 1fr))", gap: "1.5rem", marginBottom: "2.5rem" }}>
|
||||||
<MetricCard
|
<MetricBox label="Tickets Processed" value={data.total_conversations.toLocaleString()} trend={`Range: ${data.range}`} />
|
||||||
label="Total Conversations"
|
<MetricBox label="Auto-Resolution Rate" value={pct(data.resolution_rate)} trend="Target: 70%" positive={data.resolution_rate >= 0.7} />
|
||||||
value={data.total_conversations}
|
<MetricBox label="Human Escalations" value={pct(data.escalation_rate)} trend="Lower is better" />
|
||||||
/>
|
<MetricBox label="Avg Cost / Conversation" value={formatCost(data.avg_cost_per_conversation_usd)} trend={`${data.avg_turns_per_conversation.toFixed(1)} avg turns`} />
|
||||||
<MetricCard
|
|
||||||
label="Resolution Rate"
|
|
||||||
value={pct(data.resolution_rate)}
|
|
||||||
/>
|
|
||||||
<MetricCard
|
|
||||||
label="Escalation Rate"
|
|
||||||
value={pct(data.escalation_rate)}
|
|
||||||
/>
|
|
||||||
<MetricCard
|
|
||||||
label="Avg Turns"
|
|
||||||
value={data.avg_turns_per_conversation.toFixed(1)}
|
|
||||||
/>
|
|
||||||
<MetricCard
|
|
||||||
label="Total Tokens"
|
|
||||||
value={data.total_tokens.toLocaleString()}
|
|
||||||
/>
|
|
||||||
<MetricCard
|
|
||||||
label="Total Cost"
|
|
||||||
value={formatCost(data.total_cost_usd)}
|
|
||||||
/>
|
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<h3 style={styles.sectionHeading}>Agent Usage</h3>
|
<div style={{ display: "grid", gridTemplateColumns: "2fr 1fr", gap: "1.5rem" }}>
|
||||||
|
{/* Agent Workload Table */}
|
||||||
|
<div className="section-card">
|
||||||
|
<h3 style={{ fontSize: "1.125rem", color: "var(--text-primary)", fontWeight: 700, margin: "0 0 1rem 0" }}>Agent Workload Distribution</h3>
|
||||||
{data.agent_usage.length === 0 ? (
|
{data.agent_usage.length === 0 ? (
|
||||||
<div style={styles.empty}>No agent data.</div>
|
<p style={{ color: "var(--text-secondary)", fontSize: "0.875rem" }}>No agent activity recorded yet.</p>
|
||||||
) : (
|
) : (
|
||||||
<table style={styles.table}>
|
<table className="data-table">
|
||||||
<thead>
|
<thead>
|
||||||
<tr>
|
<tr>
|
||||||
<th style={styles.th}>Agent</th>
|
<th style={{ paddingLeft: 0 }}>Agent Name</th>
|
||||||
<th style={styles.th}>Messages</th>
|
<th>Message Count</th>
|
||||||
<th style={styles.th}>Tokens</th>
|
<th>Share</th>
|
||||||
<th style={styles.th}>Cost</th>
|
|
||||||
</tr>
|
</tr>
|
||||||
</thead>
|
</thead>
|
||||||
<tbody>
|
<tbody>
|
||||||
{data.agent_usage.map((a) => (
|
{data.agent_usage.map((a) => (
|
||||||
<tr key={a.agent_name}>
|
<tr key={a.agent}>
|
||||||
<td style={styles.td}>{a.agent_name}</td>
|
<td style={{ paddingLeft: 0, fontWeight: 600 }}>{a.agent}</td>
|
||||||
<td style={styles.td}>{a.message_count}</td>
|
<td>{a.count.toLocaleString()}</td>
|
||||||
<td style={styles.td}>{a.total_tokens.toLocaleString()}</td>
|
<td>{pct(a.percentage)}</td>
|
||||||
<td style={styles.td}>{formatCost(a.total_cost_usd)}</td>
|
|
||||||
</tr>
|
</tr>
|
||||||
))}
|
))}
|
||||||
</tbody>
|
</tbody>
|
||||||
</table>
|
</table>
|
||||||
)}
|
)}
|
||||||
|
|
||||||
<h3 style={styles.sectionHeading}>Interrupt Stats</h3>
|
|
||||||
<div style={styles.metricsGrid}>
|
|
||||||
<MetricCard label="Total Interrupts" value={data.interrupt_stats.total} />
|
|
||||||
<MetricCard label="Approved" value={data.interrupt_stats.approved} />
|
|
||||||
<MetricCard label="Rejected" value={data.interrupt_stats.rejected} />
|
|
||||||
<MetricCard label="Expired" value={data.interrupt_stats.expired} />
|
|
||||||
</div>
|
</div>
|
||||||
</>
|
|
||||||
|
{/* Human in the loop card */}
|
||||||
|
<div className="section-card">
|
||||||
|
<div style={{ display: "flex", alignItems: "center", gap: "0.5rem", marginBottom: "1rem" }}>
|
||||||
|
<h3 style={{ fontSize: "1.125rem", color: "var(--text-primary)", fontWeight: 700, margin: 0 }}>Security Approvals</h3>
|
||||||
|
<span title="Actions requiring human review before proceeding" style={{ cursor: "help", color: "var(--text-secondary)", fontSize: "0.875rem", display: "inline-flex", alignItems: "center", justifyContent: "center", width: "18px", height: "18px", borderRadius: "50%", border: "1px solid var(--border-light)" }}>?</span>
|
||||||
|
</div>
|
||||||
|
<p style={{ fontSize: "0.875rem", color: "var(--text-secondary)", marginBottom: "1.5rem", lineHeight: 1.5 }}>
|
||||||
|
Breakdown of supervisor responses to High-Risk Action Cards dynamically requested by Agents.
|
||||||
|
</p>
|
||||||
|
|
||||||
|
{data.interrupt_stats.total === 0 ? (
|
||||||
|
<p style={{ color: "var(--text-secondary)", fontSize: "0.875rem" }}>No interrupt events recorded yet.</p>
|
||||||
|
) : (
|
||||||
|
<div style={{ display: "flex", flexDirection: "column", gap: "1rem" }}>
|
||||||
|
<div style={{ display: "flex", justifyContent: "space-between", alignItems: "center" }}>
|
||||||
|
<span style={{ fontWeight: 600, fontSize: "0.875rem" }}>Action Approved</span>
|
||||||
|
<span style={{ color: "#059669", fontWeight: 700 }}>{data.interrupt_stats.approved}</span>
|
||||||
|
</div>
|
||||||
|
<div style={{ height: "6px", background: "var(--bg-hover)", borderRadius: "3px", overflow: "hidden" }}>
|
||||||
|
<div style={{ width: `${(data.interrupt_stats.approved / data.interrupt_stats.total) * 100}%`, height: "100%", background: "#059669" }} />
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div style={{ display: "flex", justifyContent: "space-between", alignItems: "center", marginTop: "0.5rem" }}>
|
||||||
|
<span style={{ fontWeight: 600, fontSize: "0.875rem" }}>Action Rejected (Escalated)</span>
|
||||||
|
<span style={{ color: "#DC2626", fontWeight: 700 }}>{data.interrupt_stats.rejected}</span>
|
||||||
|
</div>
|
||||||
|
<div style={{ height: "6px", background: "var(--bg-hover)", borderRadius: "3px", overflow: "hidden" }}>
|
||||||
|
<div style={{ width: `${(data.interrupt_stats.rejected / data.interrupt_stats.total) * 100}%`, height: "100%", background: "#DC2626" }} />
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
)}
|
)}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
</>
|
</>
|
||||||
)}
|
)}
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
const styles: Record<string, React.CSSProperties> = {
|
|
||||||
container: { padding: "24px", maxWidth: "1000px", margin: "0 auto" },
|
function MetricBox({ label, value, trend, positive }: { label: string, value: string | number, trend: string, positive?: boolean }) {
|
||||||
header: {
|
return (
|
||||||
display: "flex",
|
<div className="section-card" style={{ display: "flex", flexDirection: "column", gap: "0.5rem" }}>
|
||||||
justifyContent: "space-between",
|
<div className="stat-label">{label}</div>
|
||||||
alignItems: "center",
|
<div style={{ fontSize: "2rem", fontWeight: 700, color: "var(--text-primary)" }}>{value}</div>
|
||||||
marginBottom: "20px",
|
<div style={{ fontSize: "0.8125rem", color: positive ? "#059669" : "var(--text-secondary)", fontWeight: positive ? 600 : 400 }}>{trend}</div>
|
||||||
},
|
</div>
|
||||||
heading: { fontSize: "20px", fontWeight: 700, margin: 0 },
|
);
|
||||||
rangeSelector: { display: "flex", gap: "4px" },
|
}
|
||||||
rangeBtn: {
|
|
||||||
padding: "5px 14px",
|
|
||||||
border: "1px solid #e0e0e0",
|
|
||||||
borderRadius: "4px",
|
|
||||||
background: "#fff",
|
|
||||||
cursor: "pointer",
|
|
||||||
fontSize: "13px",
|
|
||||||
color: "#555",
|
|
||||||
},
|
|
||||||
rangeBtnActive: {
|
|
||||||
background: "#1976d2",
|
|
||||||
color: "#fff",
|
|
||||||
borderColor: "#1976d2",
|
|
||||||
},
|
|
||||||
metricsGrid: {
|
|
||||||
display: "flex",
|
|
||||||
flexWrap: "wrap" as const,
|
|
||||||
gap: "12px",
|
|
||||||
marginBottom: "24px",
|
|
||||||
},
|
|
||||||
sectionHeading: {
|
|
||||||
fontSize: "15px",
|
|
||||||
fontWeight: 600,
|
|
||||||
marginBottom: "12px",
|
|
||||||
color: "#333",
|
|
||||||
},
|
|
||||||
table: { width: "100%", borderCollapse: "collapse", fontSize: "13px", marginBottom: "24px" },
|
|
||||||
th: {
|
|
||||||
textAlign: "left",
|
|
||||||
padding: "8px 12px",
|
|
||||||
borderBottom: "2px solid #e0e0e0",
|
|
||||||
color: "#555",
|
|
||||||
fontWeight: 600,
|
|
||||||
textTransform: "uppercase",
|
|
||||||
fontSize: "11px",
|
|
||||||
},
|
|
||||||
td: { padding: "10px 12px", borderBottom: "1px solid #f0f0f0" },
|
|
||||||
center: { padding: "48px", textAlign: "center", color: "#888" },
|
|
||||||
error: { padding: "24px", color: "#c62828" },
|
|
||||||
empty: { color: "#888", fontSize: "14px", padding: "16px 0" },
|
|
||||||
};
|
|
||||||
|
|||||||
106
frontend/src/pages/ReplayListPage.test.tsx
Normal file
106
frontend/src/pages/ReplayListPage.test.tsx
Normal file
@@ -0,0 +1,106 @@
|
|||||||
|
import { describe, it, expect, vi, beforeEach } from "vitest";
|
||||||
|
import { render, screen, waitFor } from "@testing-library/react";
|
||||||
|
import { MemoryRouter } from "react-router-dom";
|
||||||
|
import { ReplayListPage } from "./ReplayListPage";
|
||||||
|
|
||||||
|
vi.mock("../api", () => ({
|
||||||
|
fetchConversations: vi.fn(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
import { fetchConversations } from "../api";
|
||||||
|
const mockFetchConversations = vi.mocked(fetchConversations);
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
mockFetchConversations.mockReset();
|
||||||
|
});
|
||||||
|
|
||||||
|
function renderWithRouter() {
|
||||||
|
return render(
|
||||||
|
<MemoryRouter>
|
||||||
|
<ReplayListPage />
|
||||||
|
</MemoryRouter>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
describe("ReplayListPage", () => {
|
||||||
|
it("renders loading state initially", () => {
|
||||||
|
mockFetchConversations.mockReturnValue(new Promise(() => {}));
|
||||||
|
renderWithRouter();
|
||||||
|
expect(document.querySelector(".skeleton-box")).toBeTruthy();
|
||||||
|
});
|
||||||
|
|
||||||
|
it("renders empty state when no conversations", async () => {
|
||||||
|
mockFetchConversations.mockResolvedValue({
|
||||||
|
conversations: [],
|
||||||
|
total: 0,
|
||||||
|
page: 1,
|
||||||
|
per_page: 20,
|
||||||
|
});
|
||||||
|
renderWithRouter();
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.getByText("No conversations yet")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
it("renders conversation list on success", async () => {
|
||||||
|
mockFetchConversations.mockResolvedValue({
|
||||||
|
conversations: [
|
||||||
|
{
|
||||||
|
thread_id: "t1",
|
||||||
|
created_at: "2026-04-01T00:00:00Z",
|
||||||
|
last_activity: "2026-04-01T00:01:00Z",
|
||||||
|
status: "resolved",
|
||||||
|
total_tokens: 100,
|
||||||
|
total_cost_usd: 0.01,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
total: 1,
|
||||||
|
page: 1,
|
||||||
|
per_page: 20,
|
||||||
|
});
|
||||||
|
renderWithRouter();
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.getByText("t1")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
expect(screen.getByText("resolved")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
it("renders error state on fetch failure", async () => {
|
||||||
|
mockFetchConversations.mockRejectedValue(new Error("Server down"));
|
||||||
|
renderWithRouter();
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.getByText("Failed to load conversations")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
expect(screen.getByText("Server down")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
it("applies correct status badge classes", async () => {
|
||||||
|
mockFetchConversations.mockResolvedValue({
|
||||||
|
conversations: [
|
||||||
|
{ thread_id: "t1", created_at: "", last_activity: "", status: "resolved", total_tokens: 0, total_cost_usd: 0 },
|
||||||
|
{ thread_id: "t2", created_at: "", last_activity: "", status: "escalated", total_tokens: 0, total_cost_usd: 0 },
|
||||||
|
{ thread_id: "t3", created_at: "", last_activity: "", status: null, total_tokens: 0, total_cost_usd: 0 },
|
||||||
|
],
|
||||||
|
total: 3,
|
||||||
|
page: 1,
|
||||||
|
per_page: 20,
|
||||||
|
});
|
||||||
|
renderWithRouter();
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.getByText("resolved")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
const resolvedBadge = screen.getByText("resolved");
|
||||||
|
expect(resolvedBadge.className).toContain("status-badge--resolved");
|
||||||
|
|
||||||
|
const escalatedBadge = screen.getByText("escalated");
|
||||||
|
expect(escalatedBadge.className).toContain("status-badge--escalated");
|
||||||
|
|
||||||
|
const activeBadge = screen.getByText("active");
|
||||||
|
expect(activeBadge.className).toContain("status-badge--active");
|
||||||
|
});
|
||||||
|
});
|
||||||
@@ -1,54 +1,84 @@
|
|||||||
import { useEffect, useState } from "react";
|
import { useState, useEffect } from "react";
|
||||||
import { useNavigate } from "react-router-dom";
|
import { useNavigate } from "react-router-dom";
|
||||||
import { fetchConversations } from "../api";
|
import { fetchConversations, ConversationSummary } from "../api";
|
||||||
import type { ConversationSummary } from "../api";
|
|
||||||
|
|
||||||
export function ReplayListPage() {
|
export function ReplayListPage() {
|
||||||
const [conversations, setConversations] = useState<ConversationSummary[]>([]);
|
|
||||||
const [total, setTotal] = useState(0);
|
|
||||||
const [page, setPage] = useState(1);
|
|
||||||
const [loading, setLoading] = useState(true);
|
|
||||||
const [error, setError] = useState<string | null>(null);
|
|
||||||
const navigate = useNavigate();
|
const navigate = useNavigate();
|
||||||
const perPage = 20;
|
const [page, setPage] = useState(1);
|
||||||
|
const [perPage] = useState(20);
|
||||||
|
const [total, setTotal] = useState(0);
|
||||||
|
const [conversations, setConversations] = useState<ConversationSummary[]>([]);
|
||||||
|
const [isLoading, setIsLoading] = useState(true);
|
||||||
|
const [error, setError] = useState<string | null>(null);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
setLoading(true);
|
setIsLoading(true);
|
||||||
setError(null);
|
setError(null);
|
||||||
fetchConversations(page, perPage)
|
fetchConversations(page, perPage)
|
||||||
.then((data) => {
|
.then((result) => {
|
||||||
setConversations(data.conversations);
|
setConversations(result.conversations);
|
||||||
setTotal(data.total);
|
setTotal(result.total);
|
||||||
})
|
})
|
||||||
.catch((err: Error) => setError(err.message))
|
.catch((err: Error) => setError(err.message))
|
||||||
.finally(() => setLoading(false));
|
.finally(() => setIsLoading(false));
|
||||||
}, [page]);
|
}, [page, perPage]);
|
||||||
|
|
||||||
if (loading) {
|
const totalPages = Math.max(1, Math.ceil(total / perPage));
|
||||||
return <div style={styles.center}>Loading conversations...</div>;
|
|
||||||
|
function formatDate(iso: string): string {
|
||||||
|
try {
|
||||||
|
return new Date(iso).toLocaleString();
|
||||||
|
} catch {
|
||||||
|
return iso;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (error) {
|
function formatCost(usd: number): string {
|
||||||
return <div style={styles.error}>Error: {error}</div>;
|
return `$${usd.toFixed(2)}`;
|
||||||
}
|
}
|
||||||
|
|
||||||
const totalPages = Math.ceil(total / perPage);
|
function statusClass(status: string | null): string {
|
||||||
|
if (status === "resolved") return "status-badge status-badge--resolved";
|
||||||
|
if (status === "escalated") return "status-badge status-badge--escalated";
|
||||||
|
return "status-badge status-badge--active";
|
||||||
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div style={styles.container}>
|
<div className="page-container">
|
||||||
<h2 style={styles.heading}>Conversations</h2>
|
<div className="page-header">
|
||||||
{conversations.length === 0 ? (
|
<h2>Conversation Replay</h2>
|
||||||
<div style={styles.empty}>No conversations yet.</div>
|
<p>Review autonomous agent sessions and audit MCP action execution trails.</p>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{error ? (
|
||||||
|
<div className="error-state">
|
||||||
|
<p className="error-state__title">Failed to load conversations</p>
|
||||||
|
<p className="error-state__description">{error}</p>
|
||||||
|
<button onClick={() => setPage(1)} className="btn btn-secondary" style={{ marginTop: "1rem" }}>Retry</button>
|
||||||
|
</div>
|
||||||
|
) : isLoading ? (
|
||||||
|
<div className="section-card" style={{ padding: "2rem" }}>
|
||||||
|
{[1, 2, 3, 4, 5].map(i => (
|
||||||
|
<div key={i} className="skeleton-box" style={{ height: "60px", marginBottom: "1rem", borderRadius: "8px" }}>
|
||||||
|
<div className="skeleton-text" style={{ width: "30%", height: "14px", margin: "12px 16px" }}></div>
|
||||||
|
</div>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
) : conversations.length === 0 ? (
|
||||||
|
<div className="empty-state">
|
||||||
|
<p className="empty-state__title">No conversations yet</p>
|
||||||
|
<p className="empty-state__description">Start a chat session to see conversations here.</p>
|
||||||
|
</div>
|
||||||
) : (
|
) : (
|
||||||
<>
|
<div className="section-card" style={{ padding: 0, overflow: "hidden" }}>
|
||||||
<table style={styles.table}>
|
<table className="data-table">
|
||||||
<thead>
|
<thead>
|
||||||
<tr>
|
<tr style={{ backgroundColor: "var(--bg-surface-inner)" }}>
|
||||||
<th style={styles.th}>Thread ID</th>
|
<th>Thread</th>
|
||||||
<th style={styles.th}>Started</th>
|
<th>Created</th>
|
||||||
<th style={styles.th}>Turns</th>
|
<th>Last Activity</th>
|
||||||
<th style={styles.th}>Agents</th>
|
<th>Status</th>
|
||||||
<th style={styles.th}>Resolution</th>
|
<th>Cost</th>
|
||||||
</tr>
|
</tr>
|
||||||
</thead>
|
</thead>
|
||||||
<tbody>
|
<tbody>
|
||||||
@@ -56,78 +86,47 @@ export function ReplayListPage() {
|
|||||||
<tr
|
<tr
|
||||||
key={c.thread_id}
|
key={c.thread_id}
|
||||||
onClick={() => navigate(`/replay/${c.thread_id}`)}
|
onClick={() => navigate(`/replay/${c.thread_id}`)}
|
||||||
style={styles.row}
|
style={{ cursor: "pointer" }}
|
||||||
>
|
>
|
||||||
<td style={styles.td}>
|
<td>
|
||||||
<span style={styles.threadId}>{c.thread_id}</span>
|
<span style={{ fontWeight: 600, fontFamily: "monospace" }}>{c.thread_id}</span>
|
||||||
</td>
|
</td>
|
||||||
<td style={styles.td}>
|
<td style={{ color: "var(--text-secondary)" }}>{formatDate(c.created_at)}</td>
|
||||||
{new Date(c.started_at).toLocaleString()}
|
<td style={{ color: "var(--text-secondary)" }}>{formatDate(c.last_activity)}</td>
|
||||||
|
<td>
|
||||||
|
<span className={statusClass(c.status)}>{c.status ?? "active"}</span>
|
||||||
|
</td>
|
||||||
|
<td style={{ color: "var(--text-secondary)" }}>
|
||||||
|
{c.total_tokens.toLocaleString()} tokens / {formatCost(c.total_cost_usd)}
|
||||||
</td>
|
</td>
|
||||||
<td style={styles.td}>{c.turn_count}</td>
|
|
||||||
<td style={styles.td}>{c.agents_used.join(", ") || "—"}</td>
|
|
||||||
<td style={styles.td}>{c.resolution_type ?? "open"}</td>
|
|
||||||
</tr>
|
</tr>
|
||||||
))}
|
))}
|
||||||
</tbody>
|
</tbody>
|
||||||
</table>
|
</table>
|
||||||
<div style={styles.pagination}>
|
|
||||||
|
<div className="pagination-bar">
|
||||||
|
<span className="pagination-bar__info">
|
||||||
|
Showing {(page - 1) * perPage + 1}-{Math.min(page * perPage, total)} of {total} sessions
|
||||||
|
</span>
|
||||||
|
<div className="pagination-bar__controls">
|
||||||
<button
|
<button
|
||||||
onClick={() => setPage((p) => Math.max(1, p - 1))}
|
onClick={(e) => { e.stopPropagation(); setPage(p => Math.max(1, p - 1)) }}
|
||||||
disabled={page === 1}
|
disabled={page === 1}
|
||||||
style={styles.pageBtn}
|
className="btn btn-secondary"
|
||||||
>
|
>
|
||||||
Previous
|
Previous
|
||||||
</button>
|
</button>
|
||||||
<span style={{ fontSize: "13px", color: "#555" }}>
|
|
||||||
Page {page} of {totalPages}
|
|
||||||
</span>
|
|
||||||
<button
|
<button
|
||||||
onClick={() => setPage((p) => Math.min(totalPages, p + 1))}
|
onClick={(e) => { e.stopPropagation(); setPage(p => Math.min(totalPages, p + 1)) }}
|
||||||
disabled={page >= totalPages}
|
disabled={page >= totalPages}
|
||||||
style={styles.pageBtn}
|
className="btn btn-secondary"
|
||||||
>
|
>
|
||||||
Next
|
Next
|
||||||
</button>
|
</button>
|
||||||
</div>
|
</div>
|
||||||
</>
|
</div>
|
||||||
|
</div>
|
||||||
)}
|
)}
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
const styles: Record<string, React.CSSProperties> = {
|
|
||||||
container: { padding: "24px", maxWidth: "1000px", margin: "0 auto" },
|
|
||||||
heading: { fontSize: "20px", fontWeight: 700, marginBottom: "16px" },
|
|
||||||
center: { padding: "48px", textAlign: "center", color: "#888" },
|
|
||||||
error: { padding: "24px", color: "#c62828" },
|
|
||||||
empty: { color: "#888", fontSize: "14px" },
|
|
||||||
table: { width: "100%", borderCollapse: "collapse", fontSize: "13px" },
|
|
||||||
th: {
|
|
||||||
textAlign: "left",
|
|
||||||
padding: "8px 12px",
|
|
||||||
borderBottom: "2px solid #e0e0e0",
|
|
||||||
color: "#555",
|
|
||||||
fontWeight: 600,
|
|
||||||
textTransform: "uppercase",
|
|
||||||
fontSize: "11px",
|
|
||||||
letterSpacing: "0.5px",
|
|
||||||
},
|
|
||||||
td: { padding: "10px 12px", borderBottom: "1px solid #f0f0f0" },
|
|
||||||
row: { cursor: "pointer", transition: "background 0.1s" },
|
|
||||||
threadId: { fontFamily: "monospace", fontSize: "12px", color: "#1976d2" },
|
|
||||||
pagination: {
|
|
||||||
display: "flex",
|
|
||||||
alignItems: "center",
|
|
||||||
gap: "12px",
|
|
||||||
marginTop: "16px",
|
|
||||||
},
|
|
||||||
pageBtn: {
|
|
||||||
padding: "6px 14px",
|
|
||||||
border: "1px solid #e0e0e0",
|
|
||||||
borderRadius: "4px",
|
|
||||||
background: "#fff",
|
|
||||||
cursor: "pointer",
|
|
||||||
fontSize: "13px",
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|||||||
85
frontend/src/pages/ReplayPage.test.tsx
Normal file
85
frontend/src/pages/ReplayPage.test.tsx
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
import { describe, it, expect, vi, beforeEach } from "vitest";
|
||||||
|
import { render, screen, waitFor } from "@testing-library/react";
|
||||||
|
import { MemoryRouter, Route, Routes } from "react-router-dom";
|
||||||
|
import { ReplayPage } from "./ReplayPage";
|
||||||
|
|
||||||
|
vi.mock("../api", () => ({
|
||||||
|
fetchReplay: vi.fn(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
vi.mock("../components/ReplayTimeline", () => ({
|
||||||
|
ReplayTimeline: ({ steps }: { steps: unknown[] }) => (
|
||||||
|
<div data-testid="replay-timeline">{steps.length} steps</div>
|
||||||
|
),
|
||||||
|
}));
|
||||||
|
|
||||||
|
import { fetchReplay } from "../api";
|
||||||
|
const mockFetchReplay = vi.mocked(fetchReplay);
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
mockFetchReplay.mockReset();
|
||||||
|
});
|
||||||
|
|
||||||
|
function renderWithRoute(threadId: string) {
|
||||||
|
return render(
|
||||||
|
<MemoryRouter initialEntries={[`/replay/${threadId}`]}>
|
||||||
|
<Routes>
|
||||||
|
<Route path="/replay/:threadId" element={<ReplayPage />} />
|
||||||
|
</Routes>
|
||||||
|
</MemoryRouter>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
describe("ReplayPage", () => {
|
||||||
|
it("renders loading state initially", () => {
|
||||||
|
mockFetchReplay.mockReturnValue(new Promise(() => {}));
|
||||||
|
renderWithRoute("t1");
|
||||||
|
expect(document.querySelector(".skeleton-box")).toBeTruthy();
|
||||||
|
});
|
||||||
|
|
||||||
|
it("renders replay steps on success", async () => {
|
||||||
|
mockFetchReplay.mockResolvedValue({
|
||||||
|
thread_id: "t1",
|
||||||
|
total_steps: 2,
|
||||||
|
page: 1,
|
||||||
|
per_page: 100,
|
||||||
|
steps: [
|
||||||
|
{ step: 1, type: "message", content: "Hello", agent: null, tool: null, params: null, result: null, timestamp: "2026-04-01T00:00:00Z" },
|
||||||
|
{ step: 2, type: "response", content: "Hi!", agent: "bot", tool: null, params: null, result: null, timestamp: "2026-04-01T00:00:01Z" },
|
||||||
|
],
|
||||||
|
});
|
||||||
|
renderWithRoute("t1");
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.getByTestId("replay-timeline")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
expect(screen.getByText("2 steps")).toBeInTheDocument();
|
||||||
|
// Thread ID appears in multiple places (header + sidebar)
|
||||||
|
expect(screen.getAllByText("t1").length).toBeGreaterThan(0);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("renders empty state when no steps", async () => {
|
||||||
|
mockFetchReplay.mockResolvedValue({
|
||||||
|
thread_id: "t1",
|
||||||
|
total_steps: 0,
|
||||||
|
page: 1,
|
||||||
|
per_page: 100,
|
||||||
|
steps: [],
|
||||||
|
});
|
||||||
|
renderWithRoute("t1");
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.getByText("No replay steps found")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
it("renders error state on fetch failure", async () => {
|
||||||
|
mockFetchReplay.mockRejectedValue(new Error("Not found"));
|
||||||
|
renderWithRoute("t1");
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.getByText("Failed to load replay")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
expect(screen.getByText("Not found")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
});
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user