Compare commits
22 Commits
98331abbd5
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f0699436c5 | ||
|
|
af53111928 | ||
|
|
b8654aa31f | ||
|
|
be5c84bcff | ||
|
|
19fc9f3289 | ||
|
|
036e12349d | ||
|
|
e0931daece | ||
|
|
e55ec42ae5 | ||
|
|
189a0fad34 | ||
|
|
d2b4610df9 | ||
|
|
0e78e5b06b | ||
|
|
38644594d2 | ||
|
|
ef6e5ac2be | ||
|
|
33db5aeb10 | ||
|
|
a2f750269d | ||
|
|
a54eb224e0 | ||
|
|
006b4ee5d7 | ||
|
|
b861ff055f | ||
|
|
512f988dd0 | ||
|
|
6e7b824b64 | ||
|
|
1050df780d | ||
|
|
7c3571b47d |
@@ -2,7 +2,22 @@
|
|||||||
"permissions": {
|
"permissions": {
|
||||||
"allow": [
|
"allow": [
|
||||||
"Bash(find:*)",
|
"Bash(find:*)",
|
||||||
|
"Bash(ruff:*)",
|
||||||
|
"Bash(pytest:*)",
|
||||||
|
"Bash(git status:*)",
|
||||||
|
"Bash(git diff:*)",
|
||||||
|
"Bash(git log:*)",
|
||||||
|
"Bash(git branch:*)",
|
||||||
|
"Bash(git add:*)",
|
||||||
|
"Bash(git commit:*)",
|
||||||
|
"Bash(git checkout:*)",
|
||||||
|
"Bash(git merge:*)",
|
||||||
|
"Bash(git tag:*)",
|
||||||
|
"Bash(git show:*)",
|
||||||
|
"Bash(docker:*)",
|
||||||
|
"Bash(docker-compose:*)",
|
||||||
"WebSearch"
|
"WebSearch"
|
||||||
]
|
],
|
||||||
|
"defaultMode": "bypassPermissions"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
35
.env.example
Normal file
35
.env.example
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
# Smart Support -- Docker Compose environment variables
|
||||||
|
# Copy to .env and fill in your values
|
||||||
|
|
||||||
|
# PostgreSQL password (used by both postgres and backend services)
|
||||||
|
POSTGRES_PASSWORD=dev_password
|
||||||
|
|
||||||
|
# LLM provider: anthropic | openai | azure_openai | google
|
||||||
|
LLM_PROVIDER=anthropic
|
||||||
|
LLM_MODEL=claude-sonnet-4-6
|
||||||
|
|
||||||
|
# API keys (provide the one matching LLM_PROVIDER)
|
||||||
|
ANTHROPIC_API_KEY=
|
||||||
|
OPENAI_API_KEY=
|
||||||
|
GOOGLE_API_KEY=
|
||||||
|
|
||||||
|
# Azure OpenAI (required when LLM_PROVIDER=azure_openai)
|
||||||
|
AZURE_OPENAI_API_KEY=
|
||||||
|
AZURE_OPENAI_ENDPOINT=
|
||||||
|
AZURE_OPENAI_DEPLOYMENT=
|
||||||
|
AZURE_OPENAI_API_VERSION=2024-12-01-preview
|
||||||
|
|
||||||
|
# Optional: webhook URL for escalation notifications
|
||||||
|
WEBHOOK_URL=
|
||||||
|
|
||||||
|
# Session and interrupt TTL in minutes
|
||||||
|
SESSION_TTL_MINUTES=30
|
||||||
|
INTERRUPT_TTL_MINUTES=30
|
||||||
|
|
||||||
|
# Optional: API key for admin endpoints (analytics, replay, openapi, websocket)
|
||||||
|
# Leave empty to disable authentication (dev mode)
|
||||||
|
ADMIN_API_KEY=
|
||||||
|
|
||||||
|
# Optional: load a named agent template instead of agents.yaml
|
||||||
|
# Available templates: ecommerce, saas, generic
|
||||||
|
TEMPLATE_NAME=
|
||||||
65
CLAUDE.md
65
CLAUDE.md
@@ -30,7 +30,7 @@ pytest --cov=app --cov-report=term-missing
|
|||||||
# - If any test fails, fix it before starting the new phase
|
# - If any test fails, fix it before starting the new phase
|
||||||
|
|
||||||
# 3. Create checkpoint to snapshot the starting state
|
# 3. Create checkpoint to snapshot the starting state
|
||||||
/everything-claude-code:checkpoint create [phase name]
|
/ecc:checkpoint create "phase-name"
|
||||||
|
|
||||||
# 4. Create the phase branch
|
# 4. Create the phase branch
|
||||||
git checkout main
|
git checkout main
|
||||||
@@ -50,25 +50,32 @@ git checkout -b phase-{N}/{short-description}
|
|||||||
3. Identify all tasks, acceptance criteria, and dependencies for this phase
|
3. Identify all tasks, acceptance criteria, and dependencies for this phase
|
||||||
4. Create a phase dev log **skeleton** at `docs/phases/phase-{N}-dev-log.md` (date, branch name, plan link only -- content filled in Step 5)
|
4. Create a phase dev log **skeleton** at `docs/phases/phase-{N}-dev-log.md` (date, branch name, plan link only -- content filled in Step 5)
|
||||||
|
|
||||||
### Step 2: Develop Using Orchestrate Skill
|
### Step 2: Develop Using ECC Skills
|
||||||
|
|
||||||
Route to the correct orchestration mode based on work type:
|
Route to the correct skill based on work type:
|
||||||
|
|
||||||
| Work Type | Skill Command |
|
| Work Type | Skill Command | What It Does |
|
||||||
|-----------|---------------|
|
|-----------|---------------|--------------|
|
||||||
| New feature | `/everything-claude-code:orchestrate feature` |
|
| New feature | `/ecc:feature-dev <desc>` | Discovery -> Exploration -> Architecture -> TDD -> Review -> Summary |
|
||||||
| Bug fix | `/everything-claude-code:orchestrate bugfix` |
|
| Bug fix | `/ecc:tdd` then `/ecc:code-review` | RED -> GREEN -> REFACTOR cycle, then review |
|
||||||
| Refactor | `/everything-claude-code:orchestrate refactor` |
|
| Refactor | `/ecc:plan` then `/ecc:tdd` then `/ecc:code-review` | Plan refactor scope, TDD, review |
|
||||||
|
| Security-sensitive | Add `/ecc:security-review` after code-review | Auth, payments, user input, external APIs |
|
||||||
|
| Final verification | `/ecc:verify` | Build + tests + lint + coverage + security scan |
|
||||||
|
|
||||||
ALWAYS use the appropriate orchestrate skill. Never develop without it.
|
A single phase may contain mixed work types. Call the appropriate skill **per sub-task**:
|
||||||
|
|
||||||
A single phase may contain mixed work types (e.g., Phase 5 has feature + bugfix + refactor). Call the orchestrate skill **per sub-task** with the matching mode. Example:
|
|
||||||
|
|
||||||
```
|
```
|
||||||
# Within Phase 5:
|
# Within a phase:
|
||||||
/everything-claude-code:orchestrate feature # for demo script
|
/ecc:feature-dev "demo script" # for new features
|
||||||
/everything-claude-code:orchestrate bugfix # for error handling fixes
|
/ecc:tdd # for bug fixes (write failing test, then fix)
|
||||||
/everything-claude-code:orchestrate refactor # for code cleanup
|
/ecc:plan "consolidate error handling" # for refactors (plan first, then TDD)
|
||||||
|
```
|
||||||
|
|
||||||
|
For full multi-phase autonomous execution, use GSD:
|
||||||
|
|
||||||
|
```
|
||||||
|
/gsd:autonomous # execute all remaining phases
|
||||||
|
/gsd:execute-phase 6 # execute a specific phase
|
||||||
```
|
```
|
||||||
|
|
||||||
### Step 3: Module Independence (CRITICAL)
|
### Step 3: Module Independence (CRITICAL)
|
||||||
@@ -171,10 +178,10 @@ After all development and testing, run verification in this exact order:
|
|||||||
|
|
||||||
```
|
```
|
||||||
# 1. Run the verification skill -- must pass
|
# 1. Run the verification skill -- must pass
|
||||||
/everything-claude-code:verify
|
/ecc:verify
|
||||||
|
|
||||||
# 2. Verify the checkpoint -- validates all phase deliverables
|
# 2. Verify the checkpoint -- validates all phase deliverables
|
||||||
/everything-claude-code:checkpoint verify [phase name]
|
/ecc:checkpoint verify "phase-name"
|
||||||
```
|
```
|
||||||
|
|
||||||
The checkpoint verify validates:
|
The checkpoint verify validates:
|
||||||
@@ -222,11 +229,11 @@ git push origin main --tags
|
|||||||
All four markers must be consistent. If any is missed, the next phase's Step 0 regression gate will catch the discrepancy.
|
All four markers must be consistent. If any is missed, the next phase's Step 0 regression gate will catch the discrepancy.
|
||||||
|
|
||||||
A checkpoint includes:
|
A checkpoint includes:
|
||||||
- `/everything-claude-code:checkpoint create` at phase start
|
- `/ecc:checkpoint create` at phase start
|
||||||
- `/everything-claude-code:checkpoint verify` at phase end
|
- `/ecc:checkpoint verify` at phase end
|
||||||
- All tests passing (80%+ coverage)
|
- All tests passing (80%+ coverage)
|
||||||
- Phase dev log written and linked
|
- Phase dev log written and linked
|
||||||
- `/everything-claude-code:verify` passed
|
- `/ecc:verify` passed
|
||||||
- Git tag `checkpoint/phase-{N}` created
|
- Git tag `checkpoint/phase-{N}` created
|
||||||
- Phase marked COMPLETED in four locations
|
- Phase marked COMPLETED in four locations
|
||||||
- Branch merged to main
|
- Branch merged to main
|
||||||
@@ -238,10 +245,10 @@ A checkpoint includes:
|
|||||||
| Phase | Branch | Focus | Status |
|
| Phase | Branch | Focus | Status |
|
||||||
|-------|--------|-------|--------|
|
|-------|--------|-------|--------|
|
||||||
| 1 | `phase-1/core-framework` | FastAPI + LangGraph + React chat loop + PostgresSaver | COMPLETED (2026-03-30) |
|
| 1 | `phase-1/core-framework` | FastAPI + LangGraph + React chat loop + PostgresSaver | COMPLETED (2026-03-30) |
|
||||||
| 2 | `phase-2/multi-agent-safety` | Supervisor routing + interrupts + templates | NOT STARTED |
|
| 2 | `phase-2/multi-agent-safety` | Supervisor routing + interrupts + templates | COMPLETED (2026-03-30) |
|
||||||
| 3 | `phase-3/openapi-discovery` | OpenAPI parsing + MCP generation + SSRF protection | NOT STARTED |
|
| 3 | `phase-3/openapi-discovery` | OpenAPI parsing + MCP generation + SSRF protection | COMPLETED (2026-03-30) |
|
||||||
| 4 | `phase-4/analytics-replay` | Replay API + analytics dashboard | NOT STARTED |
|
| 4 | `phase-4/analytics-replay` | Replay API + analytics dashboard | COMPLETED (2026-03-31) |
|
||||||
| 5 | `phase-5/polish-demo` | Error hardening + demo prep + Docker deploy | NOT STARTED |
|
| 5 | `phase-5/polish-demo` | Error hardening + demo prep + Docker deploy | COMPLETED (2026-03-31) |
|
||||||
|
|
||||||
Status values: `NOT STARTED` -> `IN PROGRESS` -> `COMPLETED (YYYY-MM-DD)`
|
Status values: `NOT STARTED` -> `IN PROGRESS` -> `COMPLETED (YYYY-MM-DD)`
|
||||||
|
|
||||||
@@ -264,7 +271,7 @@ This project inherits from `~/.claude/rules/`. CLAUDE.md only contains project-s
|
|||||||
|
|
||||||
### Hooks (ECC Plugin -- No Custom Hooks)
|
### Hooks (ECC Plugin -- No Custom Hooks)
|
||||||
|
|
||||||
All hooks come from the ECC plugin (`everything-claude-code`). No project-level hooks in `.claude/settings.local.json`.
|
All hooks come from the ECC plugin (`ecc`). No project-level hooks in `.claude/settings.local.json`.
|
||||||
|
|
||||||
| ECC Hook | Type | What It Does |
|
| ECC Hook | Type | What It Does |
|
||||||
|----------|------|-------------|
|
|----------|------|-------------|
|
||||||
@@ -290,7 +297,7 @@ Controlled by `ECC_HOOK_PROFILE` env var in `~/.claude/settings.json` (currently
|
|||||||
- Architecture doc: `docs/ARCHITECTURE.md`
|
- Architecture doc: `docs/ARCHITECTURE.md`
|
||||||
- Phase dev logs: `docs/phases/phase-{N}-dev-log.md`
|
- Phase dev logs: `docs/phases/phase-{N}-dev-log.md`
|
||||||
- Test command: `pytest --cov=app --cov-report=term-missing`
|
- Test command: `pytest --cov=app --cov-report=term-missing`
|
||||||
- **Phase start:** `/everything-claude-code:checkpoint create [phase name]`
|
- **Phase start:** `/ecc:checkpoint create "phase-name"`
|
||||||
- **Phase end:** `/everything-claude-code:checkpoint verify [phase name]`
|
- **Phase end:** `/ecc:checkpoint verify "phase-name"`
|
||||||
- Verify command: `/everything-claude-code:verify`
|
- Verify command: `/ecc:verify`
|
||||||
- Orchestrate: `/everything-claude-code:orchestrate {feature|bugfix|refactor}`
|
- Orchestrate: `/ecc:orchestrate {feature|bugfix|refactor}`
|
||||||
|
|||||||
267
README.md
267
README.md
@@ -1,159 +1,174 @@
|
|||||||
# Smart Support
|
# Smart Support
|
||||||
|
|
||||||
AI 客服行动层框架。粘贴你的 API,获得一个能执行真实操作的智能客服。
|
AI customer support action layer. Paste your API spec, get an AI agent that executes real actions.
|
||||||
|
|
||||||
## 问题
|
## The Problem
|
||||||
|
|
||||||
现有客服工具(Zendesk、Intercom、Ada)擅长回答 FAQ,但自动化率卡在 20-30%。剩下 70% 的工单需要人工登录内部系统,手动查订单、取消订单、发优惠券。
|
Existing support tools (Zendesk, Intercom, Ada) answer FAQs well but automation
|
||||||
|
rates stall at 20-30%. The remaining 70% of tickets require agents to manually
|
||||||
|
log into internal systems to look up orders, cancel orders, issue coupons.
|
||||||
|
|
||||||
Smart Support 是补全这个缺口的「行动层」。它不替代现有客服平台,而是让 AI 能直接调用内部系统完成操作。
|
Smart Support fills that gap as the "action layer" -- it does not replace your
|
||||||
|
existing support platform, it enables AI to directly call your internal systems.
|
||||||
|
|
||||||
## 工作原理
|
## How It Works
|
||||||
|
|
||||||
```
|
```
|
||||||
客户消息 → Chat UI → FastAPI WebSocket → LangGraph Supervisor → 专业 Agent → MCP Tools → 你的内部系统
|
User message -> Chat UI -> FastAPI WebSocket -> LangGraph Supervisor -> Specialist Agent -> MCP Tools -> Your systems
|
||||||
↑ ↑
|
| |
|
||||||
Agent 注册表 interrupt()
|
Agent Registry interrupt()
|
||||||
(YAML 配置) (人工确认)
|
(YAML config) (human approval)
|
||||||
↑
|
|
|
||||||
PostgresSaver
|
PostgresSaver
|
||||||
(会话状态持久化)
|
(session persistence)
|
||||||
```
|
```
|
||||||
|
|
||||||
1. 客户在聊天界面发送消息
|
1. User sends a message in the chat UI.
|
||||||
2. LangGraph Supervisor 分析意图,路由到对应的专业 Agent
|
2. LangGraph Supervisor classifies intent and routes to the right agent.
|
||||||
3. Agent 通过 MCP 协议调用你的内部系统(查订单、取消订单、发折扣...)
|
3. Agent calls your internal systems via MCP tools.
|
||||||
4. 涉及写操作时,自动触发人工确认流程
|
4. Write operations trigger a human-in-the-loop approval gate.
|
||||||
5. 所有操作全程记录,支持回放和分析
|
5. All operations are logged with full replay and analytics.
|
||||||
|
|
||||||
## 核心特性
|
## Key Features
|
||||||
|
|
||||||
- **多 Agent 协作** - 不同操作由不同 Agent 处理,各自拥有独立的权限边界和工具集
|
- **Multi-agent routing** -- each operation goes to a specialist agent with its own tools and permissions
|
||||||
- **即插即用** - 粘贴 OpenAPI 规范 URL,自动生成 MCP 工具和 Agent 配置
|
- **Zero-config import** -- paste an OpenAPI 3.0 URL, agents are generated automatically
|
||||||
- **人工确认** - 所有写操作(取消、退款、修改)需要人工审批,读操作直接执行
|
- **Human-in-the-loop** -- all write operations (cancel, refund, modify) require approval; reads execute immediately
|
||||||
- **会话上下文** - 支持多轮对话,Agent 能理解「取消那个订单」这样的指代
|
- **Session context** -- multi-turn conversation with persistent state across reconnects
|
||||||
- **实时流式输出** - WebSocket 双向通信,逐 token 流式返回
|
- **Real-time streaming** -- WebSocket token streaming with live tool call visibility
|
||||||
- **对话回放** - 逐步查看 Agent 决策过程、工具调用和返回结果
|
- **Conversation replay** -- step-by-step audit trail of every agent decision
|
||||||
- **数据分析** - 解决率、Agent 使用率、升级率、每次对话成本
|
- **Analytics dashboard** -- resolution rate, agent usage, escalation rate, cost per conversation
|
||||||
- **YAML 驱动配置** - Agent 定义、人设、垂直模板全部通过 YAML 配置
|
- **YAML-driven config** -- agents, personas, and vertical templates in a single file
|
||||||
|
|
||||||
## 技术栈
|
## Tech Stack
|
||||||
|
|
||||||
| 组件 | 技术选型 |
|
| Component | Technology |
|
||||||
|------|---------|
|
|-----------|-----------|
|
||||||
| 后端 | Python 3.11+, FastAPI |
|
| Backend | Python 3.11+, FastAPI |
|
||||||
| Agent 编排 | LangGraph v1.1, langgraph-supervisor |
|
| Agent orchestration | LangGraph 1.x, langgraph-supervisor |
|
||||||
| 工具集成 | langchain-mcp-adapters, @tool |
|
| Session state | PostgreSQL 16 + langgraph-checkpoint-postgres |
|
||||||
| 状态持久化 | PostgreSQL + langgraph-checkpoint-postgres |
|
| LLM | Claude Sonnet 4.6 (configurable: OpenAI, Azure OpenAI, Google) |
|
||||||
| LLM | Claude Sonnet 4.6(可切换 OpenAI、Google 等) |
|
| Frontend | React 19, TypeScript, Vite |
|
||||||
| 前端 | React |
|
| Testing | pytest (backend), vitest + happy-dom (frontend) |
|
||||||
| 部署 | Docker Compose |
|
| Deployment | Docker Compose |
|
||||||
|
|
||||||
## 项目结构
|
## Quick Start
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git clone <repo-url>
|
||||||
|
cd smart-support
|
||||||
|
|
||||||
|
# Configure your LLM API key
|
||||||
|
cp .env.example .env
|
||||||
|
# Edit .env: set LLM_PROVIDER and the corresponding API key
|
||||||
|
# anthropic -> ANTHROPIC_API_KEY
|
||||||
|
# openai -> OPENAI_API_KEY
|
||||||
|
# azure_openai -> AZURE_OPENAI_API_KEY + AZURE_OPENAI_ENDPOINT + AZURE_OPENAI_DEPLOYMENT
|
||||||
|
# google -> GOOGLE_API_KEY
|
||||||
|
|
||||||
|
# Start all services
|
||||||
|
docker compose up -d
|
||||||
|
|
||||||
|
# Open the app
|
||||||
|
open http://localhost
|
||||||
|
```
|
||||||
|
|
||||||
|
### Local Development
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Start only PostgreSQL via Docker (exposed on port 5433)
|
||||||
|
docker compose up postgres -d
|
||||||
|
|
||||||
|
# Backend (in one terminal)
|
||||||
|
cd backend
|
||||||
|
pip install -e ".[dev]"
|
||||||
|
uvicorn app.main:app --host 0.0.0.0 --port 8001 --reload
|
||||||
|
|
||||||
|
# Frontend (in another terminal)
|
||||||
|
cd frontend
|
||||||
|
npm install
|
||||||
|
npm run dev # http://localhost:5173 (proxies /api and /ws to :8001)
|
||||||
|
```
|
||||||
|
|
||||||
|
See [Deployment Guide](docs/deployment.md) for production setup, HTTPS, and scaling.
|
||||||
|
|
||||||
|
## Project Structure
|
||||||
|
|
||||||
```
|
```
|
||||||
smart-support/
|
smart-support/
|
||||||
├── backend/
|
├── backend/
|
||||||
│ ├── app/
|
│ ├── app/
|
||||||
│ │ ├── main.py # FastAPI + WebSocket 入口
|
│ │ ├── main.py # FastAPI + WebSocket entry point
|
||||||
│ │ ├── graph.py # LangGraph Supervisor 配置
|
│ │ ├── graph.py # LangGraph Supervisor construction
|
||||||
│ │ ├── agents/ # Agent 定义 + 工具
|
│ │ ├── graph_context.py # Typed wrapper for graph + classifier + registry
|
||||||
│ │ ├── registry.py # YAML Agent 注册表加载器
|
│ │ ├── ws_handler.py # WebSocket message dispatch + rate limiting
|
||||||
│ │ ├── openapi/ # OpenAPI 解析 + MCP 服务器生成
|
│ │ ├── ws_context.py # WebSocket dependency bundle
|
||||||
│ │ ├── replay/ # 对话回放 API
|
│ │ ├── auth.py # API key authentication middleware
|
||||||
│ │ ├── analytics/ # 数据分析查询 + API
|
│ │ ├── api_utils.py # Shared API response helpers
|
||||||
│ │ └── callbacks.py # Token 用量统计
|
│ │ ├── safety.py # Confirmation rules + MCP error taxonomy
|
||||||
│ ├── agents.yaml # Agent 注册表配置
|
│ │ ├── agents/ # Agent definitions and tools
|
||||||
│ ├── templates/ # 垂直行业模板
|
│ │ ├── registry.py # YAML agent registry loader
|
||||||
│ └── tests/
|
│ │ ├── openapi/ # OpenAPI parser, classifier, and review API
|
||||||
├── frontend/ # React 聊天 UI + 回放 + 仪表盘
|
│ │ ├── replay/ # Conversation replay API
|
||||||
├── docker-compose.yml # PostgreSQL + 应用
|
│ │ └── analytics/ # Analytics queries and API
|
||||||
└── pyproject.toml
|
│ ├── agents.yaml # Agent registry configuration
|
||||||
|
│ ├── templates/ # Vertical industry templates
|
||||||
|
│ └── tests/ # Unit, integration, and E2E tests
|
||||||
|
├── frontend/
|
||||||
|
│ ├── src/
|
||||||
|
│ │ ├── pages/ # Chat, Replay, Dashboard, Review pages
|
||||||
|
│ │ ├── components/ # NavBar, Layout, MetricCard, ReplayTimeline
|
||||||
|
│ │ ├── hooks/ # useWebSocket with reconnect support
|
||||||
|
│ │ └── api.ts # Typed API client
|
||||||
|
│ └── Dockerfile # Multi-stage nginx build
|
||||||
|
├── docs/ # Architecture, deployment, guides
|
||||||
|
├── docker-compose.yml # Full-stack compose
|
||||||
|
└── .env.example # Environment variable template
|
||||||
```
|
```
|
||||||
|
|
||||||
## 快速开始
|
## API Endpoints
|
||||||
|
|
||||||
|
| Method | Path | Auth | Description |
|
||||||
|
|--------|------|------|-------------|
|
||||||
|
| WS | `/ws` | Token | Main WebSocket chat endpoint (`?token=<key>`) |
|
||||||
|
| GET | `/api/health` | No | Health check |
|
||||||
|
| GET | `/api/conversations` | API Key | List conversations (paginated) |
|
||||||
|
| GET | `/api/replay/{thread_id}` | API Key | Replay conversation steps (paginated) |
|
||||||
|
| GET | `/api/analytics` | API Key | Analytics summary (`?range=7d`) |
|
||||||
|
| POST | `/api/openapi/import` | API Key | Start OpenAPI import job |
|
||||||
|
| GET | `/api/openapi/jobs/{id}` | API Key | Check import job status |
|
||||||
|
| GET | `/api/openapi/jobs/{id}/classifications` | API Key | Get endpoint classifications |
|
||||||
|
| PUT | `/api/openapi/jobs/{id}/classifications/{idx}` | API Key | Update a classification |
|
||||||
|
| POST | `/api/openapi/jobs/{id}/approve` | API Key | Approve and generate tools |
|
||||||
|
|
||||||
|
Authentication is controlled by the `ADMIN_API_KEY` environment variable.
|
||||||
|
API Key endpoints require the `X-API-Key` header. When `ADMIN_API_KEY` is unset, auth is disabled.
|
||||||
|
|
||||||
|
## Running Tests
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# 启动 PostgreSQL 和应用
|
# Backend (516 tests, 94% coverage)
|
||||||
docker compose up
|
cd backend
|
||||||
|
pytest --cov=app --cov-report=term-missing
|
||||||
|
|
||||||
# 访问聊天界面
|
# Frontend (23 tests, vitest + happy-dom)
|
||||||
open http://localhost:8000
|
cd frontend
|
||||||
|
npm test
|
||||||
```
|
```
|
||||||
|
|
||||||
## Agent 配置示例
|
Backend coverage is enforced at 80%+.
|
||||||
|
|
||||||
```yaml
|
## Documentation
|
||||||
# agents.yaml
|
|
||||||
agents:
|
|
||||||
- name: order_lookup
|
|
||||||
description: 查询订单状态、物流信息
|
|
||||||
permission: read
|
|
||||||
personality:
|
|
||||||
tone: professional
|
|
||||||
greeting: "您好,我来帮您查询订单信息。"
|
|
||||||
tools:
|
|
||||||
- get_order_status
|
|
||||||
- get_tracking_info
|
|
||||||
|
|
||||||
- name: order_actions
|
| Document | Description |
|
||||||
description: 取消订单、修改订单
|
|----------|-------------|
|
||||||
permission: write # 触发人工确认
|
| [Architecture](docs/ARCHITECTURE.md) | System design, component diagram, data flow, ADRs |
|
||||||
personality:
|
| [Development Plan](docs/DEVELOPMENT-PLAN.md) | Phase breakdown, task checklists, and status |
|
||||||
tone: careful
|
| [Agent Config Guide](docs/agent-config-guide.md) | agents.yaml format, fields, templates, routing logic |
|
||||||
greeting: "我可以帮您处理订单变更,所有操作都会先经过您的确认。"
|
| [OpenAPI Import Guide](docs/openapi-import-guide.md) | Auto-discovery workflow, REST API, SSRF protection |
|
||||||
tools:
|
| [Deployment Guide](docs/deployment.md) | Docker, local dev, production, HTTPS, backups, scaling |
|
||||||
- cancel_order
|
| [Demo Script](docs/demo-script.md) | Step-by-step live demo walkthrough (5 scenes) |
|
||||||
- modify_order
|
| [UX Design System](docs/ux_design_system.md) | Color palette, typography, component patterns, CSS tokens |
|
||||||
|
|
||||||
- name: discount
|
|
||||||
description: 发放优惠券、折扣码
|
|
||||||
permission: write
|
|
||||||
tools:
|
|
||||||
- apply_discount
|
|
||||||
- generate_coupon
|
|
||||||
```
|
|
||||||
|
|
||||||
## OpenAPI 自动接入
|
|
||||||
|
|
||||||
不需要手动写 MCP 连接器。粘贴你的 API 规范 URL:
|
|
||||||
|
|
||||||
1. 框架解析 OpenAPI 3.0 规范
|
|
||||||
2. LLM 自动分类每个端点(读/写、客户参数、Agent 分组)
|
|
||||||
3. 运维人员审核分类结果
|
|
||||||
4. 自动生成 MCP 服务器 + Agent YAML 配置
|
|
||||||
5. 新工具立即可用
|
|
||||||
|
|
||||||
## 安全设计
|
|
||||||
|
|
||||||
- **人工确认** - 所有写操作需要客户或运维人员批准
|
|
||||||
- **SSRF 防护** - OpenAPI URL 导入时屏蔽内网地址和 DNS 重绑定攻击
|
|
||||||
- **操作审计** - 每个操作记录 Agent、参数、结果、时间戳
|
|
||||||
- **权限隔离** - 每个 Agent 只能访问其配置的工具集
|
|
||||||
- **中断超时** - 30 分钟未确认的操作自动取消,防止过期审批
|
|
||||||
|
|
||||||
## 开发阶段
|
|
||||||
|
|
||||||
| 阶段 | 周期 | 内容 |
|
|
||||||
|------|------|------|
|
|
||||||
| Phase 1 | 第 1-3 周 | 核心框架:Chat UI + Supervisor + Agent 注册表 + 中断流程 |
|
|
||||||
| Phase 2 | 第 3-4 周 | 多 Agent 路由 + Webhook 升级 + 垂直模板 |
|
|
||||||
| Phase 3 | 第 4-6 周 | OpenAPI 自动发现 + MCP 服务器生成 + SSRF 防护 |
|
|
||||||
| Phase 4 | 第 6-7 周 | 对话回放 + 数据分析仪表盘 |
|
|
||||||
|
|
||||||
## 目标用户
|
|
||||||
|
|
||||||
中型电商公司(日均 500-5000 订单,5-20 名客服)的客户体验负责人。
|
|
||||||
|
|
||||||
他们的痛点:客服需要在 Zendesk 和 Shopify 后台之间反复切换,手动执行查询和操作。Smart Support 让 AI 直接完成这些操作,人工只需审批关键步骤。
|
|
||||||
|
|
||||||
## 相关文档
|
|
||||||
|
|
||||||
- [设计文档](design-doc.md) - 问题定义、约束、方案选择
|
|
||||||
- [CEO 计划](ceo-plan.md) - 产品愿景、范围决策
|
|
||||||
- [工程评审计划](eng-review-plan.md) - 架构决策、测试策略、失败模式
|
|
||||||
- [测试计划](eng-review-test-plan.md) - 测试路径、边界情况、E2E 流程
|
|
||||||
- [待办事项](TODOS.md) - 延迟到后续阶段的工作
|
|
||||||
|
|
||||||
## License
|
## License
|
||||||
|
|
||||||
|
|||||||
@@ -1,19 +1,34 @@
|
|||||||
# Database
|
# Smart Support Backend -- environment variables
|
||||||
|
# Copy to .env and fill in your values
|
||||||
|
|
||||||
|
# Required: PostgreSQL connection string
|
||||||
DATABASE_URL=postgresql://smart_support:dev_password@localhost:5432/smart_support
|
DATABASE_URL=postgresql://smart_support:dev_password@localhost:5432/smart_support
|
||||||
|
|
||||||
# LLM Provider: anthropic | openai | google
|
# Required: LLM provider configuration
|
||||||
|
# provider: anthropic | openai | google
|
||||||
LLM_PROVIDER=anthropic
|
LLM_PROVIDER=anthropic
|
||||||
LLM_MODEL=claude-sonnet-4-6
|
LLM_MODEL=claude-sonnet-4-6
|
||||||
|
|
||||||
# API Keys (set the one matching your LLM_PROVIDER)
|
# API keys -- provide the one matching LLM_PROVIDER
|
||||||
ANTHROPIC_API_KEY=
|
ANTHROPIC_API_KEY=
|
||||||
OPENAI_API_KEY=
|
OPENAI_API_KEY=
|
||||||
GOOGLE_API_KEY=
|
GOOGLE_API_KEY=
|
||||||
|
|
||||||
# Session
|
# Optional: webhook endpoint for escalation notifications
|
||||||
|
# The backend will POST a JSON payload when a conversation is escalated.
|
||||||
|
WEBHOOK_URL=
|
||||||
|
WEBHOOK_TIMEOUT_SECONDS=10
|
||||||
|
WEBHOOK_MAX_RETRIES=3
|
||||||
|
|
||||||
|
# Session management
|
||||||
SESSION_TTL_MINUTES=30
|
SESSION_TTL_MINUTES=30
|
||||||
INTERRUPT_TTL_MINUTES=30
|
INTERRUPT_TTL_MINUTES=30
|
||||||
|
|
||||||
# Server
|
# Optional: load a named agent template instead of agents.yaml
|
||||||
|
# Leave blank to use the default agents.yaml in the backend directory.
|
||||||
|
# Available templates: ecommerce, saas, generic
|
||||||
|
TEMPLATE_NAME=
|
||||||
|
|
||||||
|
# Server binding
|
||||||
WS_HOST=0.0.0.0
|
WS_HOST=0.0.0.0
|
||||||
WS_PORT=8000
|
WS_PORT=8000
|
||||||
|
|||||||
@@ -20,6 +20,17 @@ agents:
|
|||||||
tools:
|
tools:
|
||||||
- cancel_order
|
- cancel_order
|
||||||
|
|
||||||
|
- name: discount
|
||||||
|
description: "Applies discounts and generates coupon codes. Use when the customer asks for discounts, promotions, or coupon codes."
|
||||||
|
permission: write
|
||||||
|
personality:
|
||||||
|
tone: "generous and accommodating"
|
||||||
|
greeting: "I can help you with discounts and coupons!"
|
||||||
|
escalation_message: "Let me connect you with our promotions team."
|
||||||
|
tools:
|
||||||
|
- apply_discount
|
||||||
|
- generate_coupon
|
||||||
|
|
||||||
- name: fallback
|
- name: fallback
|
||||||
description: "Handles general questions, unclear requests, and conversations that don't match other agents."
|
description: "Handles general questions, unclear requests, and conversations that don't match other agents."
|
||||||
permission: read
|
permission: read
|
||||||
|
|||||||
149
backend/alembic.ini
Normal file
149
backend/alembic.ini
Normal file
@@ -0,0 +1,149 @@
|
|||||||
|
# A generic, single database configuration.
|
||||||
|
|
||||||
|
[alembic]
|
||||||
|
# path to migration scripts.
|
||||||
|
# this is typically a path given in POSIX (e.g. forward slashes)
|
||||||
|
# format, relative to the token %(here)s which refers to the location of this
|
||||||
|
# ini file
|
||||||
|
script_location = %(here)s/alembic
|
||||||
|
|
||||||
|
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
|
||||||
|
# Uncomment the line below if you want the files to be prepended with date and time
|
||||||
|
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
|
||||||
|
# for all available tokens
|
||||||
|
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
|
||||||
|
# Or organize into date-based subdirectories (requires recursive_version_locations = true)
|
||||||
|
# file_template = %%(year)d/%%(month).2d/%%(day).2d_%%(hour).2d%%(minute).2d_%%(second).2d_%%(rev)s_%%(slug)s
|
||||||
|
|
||||||
|
# sys.path path, will be prepended to sys.path if present.
|
||||||
|
# defaults to the current working directory. for multiple paths, the path separator
|
||||||
|
# is defined by "path_separator" below.
|
||||||
|
prepend_sys_path = .
|
||||||
|
|
||||||
|
|
||||||
|
# timezone to use when rendering the date within the migration file
|
||||||
|
# as well as the filename.
|
||||||
|
# If specified, requires the tzdata library which can be installed by adding
|
||||||
|
# `alembic[tz]` to the pip requirements.
|
||||||
|
# string value is passed to ZoneInfo()
|
||||||
|
# leave blank for localtime
|
||||||
|
# timezone =
|
||||||
|
|
||||||
|
# max length of characters to apply to the "slug" field
|
||||||
|
# truncate_slug_length = 40
|
||||||
|
|
||||||
|
# set to 'true' to run the environment during
|
||||||
|
# the 'revision' command, regardless of autogenerate
|
||||||
|
# revision_environment = false
|
||||||
|
|
||||||
|
# set to 'true' to allow .pyc and .pyo files without
|
||||||
|
# a source .py file to be detected as revisions in the
|
||||||
|
# versions/ directory
|
||||||
|
# sourceless = false
|
||||||
|
|
||||||
|
# version location specification; This defaults
|
||||||
|
# to <script_location>/versions. When using multiple version
|
||||||
|
# directories, initial revisions must be specified with --version-path.
|
||||||
|
# The path separator used here should be the separator specified by "path_separator"
|
||||||
|
# below.
|
||||||
|
# version_locations = %(here)s/bar:%(here)s/bat:%(here)s/alembic/versions
|
||||||
|
|
||||||
|
# path_separator; This indicates what character is used to split lists of file
|
||||||
|
# paths, including version_locations and prepend_sys_path within configparser
|
||||||
|
# files such as alembic.ini.
|
||||||
|
# The default rendered in new alembic.ini files is "os", which uses os.pathsep
|
||||||
|
# to provide os-dependent path splitting.
|
||||||
|
#
|
||||||
|
# Note that in order to support legacy alembic.ini files, this default does NOT
|
||||||
|
# take place if path_separator is not present in alembic.ini. If this
|
||||||
|
# option is omitted entirely, fallback logic is as follows:
|
||||||
|
#
|
||||||
|
# 1. Parsing of the version_locations option falls back to using the legacy
|
||||||
|
# "version_path_separator" key, which if absent then falls back to the legacy
|
||||||
|
# behavior of splitting on spaces and/or commas.
|
||||||
|
# 2. Parsing of the prepend_sys_path option falls back to the legacy
|
||||||
|
# behavior of splitting on spaces, commas, or colons.
|
||||||
|
#
|
||||||
|
# Valid values for path_separator are:
|
||||||
|
#
|
||||||
|
# path_separator = :
|
||||||
|
# path_separator = ;
|
||||||
|
# path_separator = space
|
||||||
|
# path_separator = newline
|
||||||
|
#
|
||||||
|
# Use os.pathsep. Default configuration used for new projects.
|
||||||
|
path_separator = os
|
||||||
|
|
||||||
|
# set to 'true' to search source files recursively
|
||||||
|
# in each "version_locations" directory
|
||||||
|
# new in Alembic version 1.10
|
||||||
|
# recursive_version_locations = false
|
||||||
|
|
||||||
|
# the output encoding used when revision files
|
||||||
|
# are written from script.py.mako
|
||||||
|
# output_encoding = utf-8
|
||||||
|
|
||||||
|
# database URL. This is consumed by the user-maintained env.py script only.
|
||||||
|
# other means of configuring database URLs may be customized within the env.py
|
||||||
|
# file.
|
||||||
|
sqlalchemy.url =
|
||||||
|
|
||||||
|
|
||||||
|
[post_write_hooks]
|
||||||
|
# post_write_hooks defines scripts or Python functions that are run
|
||||||
|
# on newly generated revision scripts. See the documentation for further
|
||||||
|
# detail and examples
|
||||||
|
|
||||||
|
# format using "black" - use the console_scripts runner, against the "black" entrypoint
|
||||||
|
# hooks = black
|
||||||
|
# black.type = console_scripts
|
||||||
|
# black.entrypoint = black
|
||||||
|
# black.options = -l 79 REVISION_SCRIPT_FILENAME
|
||||||
|
|
||||||
|
# lint with attempts to fix using "ruff" - use the module runner, against the "ruff" module
|
||||||
|
# hooks = ruff
|
||||||
|
# ruff.type = module
|
||||||
|
# ruff.module = ruff
|
||||||
|
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
|
||||||
|
|
||||||
|
# Alternatively, use the exec runner to execute a binary found on your PATH
|
||||||
|
# hooks = ruff
|
||||||
|
# ruff.type = exec
|
||||||
|
# ruff.executable = ruff
|
||||||
|
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
|
||||||
|
|
||||||
|
# Logging configuration. This is also consumed by the user-maintained
|
||||||
|
# env.py script only.
|
||||||
|
[loggers]
|
||||||
|
keys = root,sqlalchemy,alembic
|
||||||
|
|
||||||
|
[handlers]
|
||||||
|
keys = console
|
||||||
|
|
||||||
|
[formatters]
|
||||||
|
keys = generic
|
||||||
|
|
||||||
|
[logger_root]
|
||||||
|
level = WARNING
|
||||||
|
handlers = console
|
||||||
|
qualname =
|
||||||
|
|
||||||
|
[logger_sqlalchemy]
|
||||||
|
level = WARNING
|
||||||
|
handlers =
|
||||||
|
qualname = sqlalchemy.engine
|
||||||
|
|
||||||
|
[logger_alembic]
|
||||||
|
level = INFO
|
||||||
|
handlers =
|
||||||
|
qualname = alembic
|
||||||
|
|
||||||
|
[handler_console]
|
||||||
|
class = StreamHandler
|
||||||
|
args = (sys.stderr,)
|
||||||
|
level = NOTSET
|
||||||
|
formatter = generic
|
||||||
|
|
||||||
|
[formatter_generic]
|
||||||
|
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||||
|
datefmt = %H:%M:%S
|
||||||
1
backend/alembic/README
Normal file
1
backend/alembic/README
Normal file
@@ -0,0 +1 @@
|
|||||||
|
Generic single-database configuration.
|
||||||
67
backend/alembic/env.py
Normal file
67
backend/alembic/env.py
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
"""Alembic environment configuration for smart-support."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
from logging.config import fileConfig
|
||||||
|
|
||||||
|
from sqlalchemy import engine_from_config, pool
|
||||||
|
|
||||||
|
from alembic import context
|
||||||
|
|
||||||
|
config = context.config
|
||||||
|
|
||||||
|
if config.config_file_name is not None:
|
||||||
|
fileConfig(config.config_file_name)
|
||||||
|
|
||||||
|
# No SQLAlchemy ORM models -- we use raw DDL migrations
|
||||||
|
target_metadata = None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_url() -> str:
|
||||||
|
"""Read DATABASE_URL from environment, falling back to alembic.ini."""
|
||||||
|
return os.environ.get("DATABASE_URL", "") or config.get_main_option(
|
||||||
|
"sqlalchemy.url", ""
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def run_migrations_offline() -> None:
|
||||||
|
"""Run migrations in 'offline' mode.
|
||||||
|
|
||||||
|
Configures the context with just a URL so that an Engine
|
||||||
|
is not required.
|
||||||
|
"""
|
||||||
|
url = _get_url()
|
||||||
|
context.configure(
|
||||||
|
url=url,
|
||||||
|
target_metadata=target_metadata,
|
||||||
|
literal_binds=True,
|
||||||
|
dialect_opts={"paramstyle": "named"},
|
||||||
|
)
|
||||||
|
|
||||||
|
with context.begin_transaction():
|
||||||
|
context.run_migrations()
|
||||||
|
|
||||||
|
|
||||||
|
def run_migrations_online() -> None:
|
||||||
|
"""Run migrations in 'online' mode with a live database connection."""
|
||||||
|
configuration = config.get_section(config.config_ini_section, {})
|
||||||
|
configuration["sqlalchemy.url"] = _get_url()
|
||||||
|
|
||||||
|
connectable = engine_from_config(
|
||||||
|
configuration,
|
||||||
|
prefix="sqlalchemy.",
|
||||||
|
poolclass=pool.NullPool,
|
||||||
|
)
|
||||||
|
|
||||||
|
with connectable.connect() as connection:
|
||||||
|
context.configure(connection=connection, target_metadata=target_metadata)
|
||||||
|
|
||||||
|
with context.begin_transaction():
|
||||||
|
context.run_migrations()
|
||||||
|
|
||||||
|
|
||||||
|
if context.is_offline_mode():
|
||||||
|
run_migrations_offline()
|
||||||
|
else:
|
||||||
|
run_migrations_online()
|
||||||
28
backend/alembic/script.py.mako
Normal file
28
backend/alembic/script.py.mako
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
"""${message}
|
||||||
|
|
||||||
|
Revision ID: ${up_revision}
|
||||||
|
Revises: ${down_revision | comma,n}
|
||||||
|
Create Date: ${create_date}
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
${imports if imports else ""}
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = ${repr(up_revision)}
|
||||||
|
down_revision: Union[str, Sequence[str], None] = ${repr(down_revision)}
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
|
||||||
|
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Upgrade schema."""
|
||||||
|
${upgrades if upgrades else "pass"}
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Downgrade schema."""
|
||||||
|
${downgrades if downgrades else "pass"}
|
||||||
92
backend/alembic/versions/001_initial_schema.py
Normal file
92
backend/alembic/versions/001_initial_schema.py
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
"""Initial schema -- all application tables.
|
||||||
|
|
||||||
|
Revision ID: a1b2c3d4e5f6
|
||||||
|
Revises:
|
||||||
|
Create Date: 2026-04-06
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
revision: str = "a1b2c3d4e5f6"
|
||||||
|
down_revision: str | None = None
|
||||||
|
branch_labels: tuple[str, ...] | None = None
|
||||||
|
depends_on: tuple[str, ...] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE IF NOT EXISTS conversations (
|
||||||
|
thread_id TEXT PRIMARY KEY,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||||
|
last_activity TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||||
|
total_tokens INTEGER NOT NULL DEFAULT 0,
|
||||||
|
total_cost_usd DOUBLE PRECISION NOT NULL DEFAULT 0.0,
|
||||||
|
status TEXT NOT NULL DEFAULT 'active'
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE IF NOT EXISTS active_interrupts (
|
||||||
|
interrupt_id TEXT PRIMARY KEY,
|
||||||
|
thread_id TEXT NOT NULL REFERENCES conversations(thread_id),
|
||||||
|
action TEXT NOT NULL,
|
||||||
|
params JSONB NOT NULL DEFAULT '{}',
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||||
|
resolved_at TIMESTAMPTZ,
|
||||||
|
resolution TEXT
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE IF NOT EXISTS sessions (
|
||||||
|
thread_id TEXT PRIMARY KEY,
|
||||||
|
last_activity TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||||
|
has_pending_interrupt BOOLEAN NOT NULL DEFAULT FALSE,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE IF NOT EXISTS analytics_events (
|
||||||
|
id BIGSERIAL PRIMARY KEY,
|
||||||
|
thread_id TEXT NOT NULL,
|
||||||
|
event_type TEXT NOT NULL,
|
||||||
|
agent_name TEXT,
|
||||||
|
tool_name TEXT,
|
||||||
|
tokens_used INTEGER NOT NULL DEFAULT 0,
|
||||||
|
cost_usd DOUBLE PRECISION NOT NULL DEFAULT 0.0,
|
||||||
|
duration_ms INTEGER,
|
||||||
|
success BOOLEAN,
|
||||||
|
error_message TEXT,
|
||||||
|
metadata JSONB NOT NULL DEFAULT '{}',
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
# Migration columns added in Phase 4
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
ALTER TABLE conversations
|
||||||
|
ADD COLUMN IF NOT EXISTS resolution_type TEXT,
|
||||||
|
ADD COLUMN IF NOT EXISTS agents_used TEXT[],
|
||||||
|
ADD COLUMN IF NOT EXISTS turn_count INTEGER NOT NULL DEFAULT 0,
|
||||||
|
ADD COLUMN IF NOT EXISTS ended_at TIMESTAMPTZ
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.execute("DROP TABLE IF EXISTS analytics_events")
|
||||||
|
op.execute("DROP TABLE IF EXISTS sessions")
|
||||||
|
op.execute("DROP TABLE IF EXISTS active_interrupts")
|
||||||
|
op.execute("DROP TABLE IF EXISTS conversations")
|
||||||
@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from langchain_core.tools import BaseTool
|
from langchain_core.tools import BaseTool
|
||||||
|
|
||||||
|
from app.agents.discount import apply_discount, generate_coupon
|
||||||
from app.agents.fallback import fallback_respond
|
from app.agents.fallback import fallback_respond
|
||||||
from app.agents.order_actions import cancel_order
|
from app.agents.order_actions import cancel_order
|
||||||
from app.agents.order_lookup import get_order_status, get_tracking_info
|
from app.agents.order_lookup import get_order_status, get_tracking_info
|
||||||
@@ -16,6 +17,8 @@ _TOOL_MAP: dict[str, BaseTool] = {
|
|||||||
"get_tracking_info": get_tracking_info,
|
"get_tracking_info": get_tracking_info,
|
||||||
"cancel_order": cancel_order,
|
"cancel_order": cancel_order,
|
||||||
"fallback_respond": fallback_respond,
|
"fallback_respond": fallback_respond,
|
||||||
|
"apply_discount": apply_discount,
|
||||||
|
"generate_coupon": generate_coupon,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
79
backend/app/agents/discount.py
Normal file
79
backend/app/agents/discount.py
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
"""Discount agent tools -- apply discounts and generate coupons."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
from langgraph.types import interrupt
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def apply_discount(order_id: str, discount_percent: int) -> dict:
|
||||||
|
"""Apply a discount to an order. Requires human approval before execution."""
|
||||||
|
if discount_percent < 1 or discount_percent > 100:
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"order_id": order_id,
|
||||||
|
"message": f"Invalid discount: {discount_percent}%. Must be between 1 and 100.",
|
||||||
|
}
|
||||||
|
|
||||||
|
response = interrupt(
|
||||||
|
{
|
||||||
|
"action": "apply_discount",
|
||||||
|
"order_id": order_id,
|
||||||
|
"discount_percent": discount_percent,
|
||||||
|
"message": (
|
||||||
|
f"Please confirm: apply {discount_percent}% discount to order {order_id}?"
|
||||||
|
),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(response, bool):
|
||||||
|
approved = response
|
||||||
|
elif isinstance(response, dict):
|
||||||
|
approved = response.get("approved", False)
|
||||||
|
else:
|
||||||
|
approved = bool(response)
|
||||||
|
|
||||||
|
if approved:
|
||||||
|
return {
|
||||||
|
"status": "applied",
|
||||||
|
"order_id": order_id,
|
||||||
|
"discount_percent": discount_percent,
|
||||||
|
"message": (
|
||||||
|
f"{discount_percent}% discount applied to order {order_id}."
|
||||||
|
),
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
"status": "declined",
|
||||||
|
"order_id": order_id,
|
||||||
|
"message": f"Discount for order {order_id} was declined.",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def generate_coupon(discount_percent: int, expiry_days: int = 30) -> dict:
|
||||||
|
"""Generate a coupon code with the specified discount percentage."""
|
||||||
|
if discount_percent < 1 or discount_percent > 100:
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": f"Invalid discount: {discount_percent}%. Must be between 1 and 100.",
|
||||||
|
}
|
||||||
|
if expiry_days < 1:
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": f"Invalid expiry: {expiry_days} days. Must be at least 1.",
|
||||||
|
}
|
||||||
|
|
||||||
|
coupon_code = f"SAVE{discount_percent}-{uuid.uuid4().hex[:8].upper()}"
|
||||||
|
return {
|
||||||
|
"status": "generated",
|
||||||
|
"coupon_code": coupon_code,
|
||||||
|
"discount_percent": discount_percent,
|
||||||
|
"expiry_days": expiry_days,
|
||||||
|
"message": (
|
||||||
|
f"Coupon {coupon_code} generated: {discount_percent}% off, "
|
||||||
|
f"valid for {expiry_days} days."
|
||||||
|
),
|
||||||
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
"""Fallback agent tools -- handles unmatched intents."""
|
"""Fallback agent tools -- handles unmatched intents and clarification requests."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
@@ -13,6 +13,7 @@ def fallback_respond(query: str) -> str:
|
|||||||
"Here's what I can do:\n"
|
"Here's what I can do:\n"
|
||||||
"- Check order status (e.g., 'What is the status of order 1042?')\n"
|
"- Check order status (e.g., 'What is the status of order 1042?')\n"
|
||||||
"- Get tracking information (e.g., 'Track order 1042')\n"
|
"- Get tracking information (e.g., 'Track order 1042')\n"
|
||||||
"- Cancel an order (e.g., 'Cancel order 1042')\n\n"
|
"- Cancel an order (e.g., 'Cancel order 1042')\n"
|
||||||
|
"- Apply discounts or generate coupons\n\n"
|
||||||
"Could you please rephrase your request?"
|
"Could you please rephrase your request?"
|
||||||
)
|
)
|
||||||
|
|||||||
3
backend/app/analytics/__init__.py
Normal file
3
backend/app/analytics/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
"""Analytics module -- event recording and dashboard queries."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
60
backend/app/analytics/api.py
Normal file
60
backend/app/analytics/api.py
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
"""Analytics API router -- dashboard metrics endpoint."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
from dataclasses import asdict
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||||
|
|
||||||
|
from app.analytics.queries import get_analytics
|
||||||
|
from app.api_utils import envelope
|
||||||
|
from app.auth import require_admin_api_key
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from psycopg_pool import AsyncConnectionPool
|
||||||
|
|
||||||
|
router = APIRouter(
|
||||||
|
prefix="/api/v1/analytics",
|
||||||
|
tags=["analytics"],
|
||||||
|
dependencies=[Depends(require_admin_api_key)],
|
||||||
|
)
|
||||||
|
|
||||||
|
_RANGE_PATTERN = re.compile(r"^(\d+)d$")
|
||||||
|
_DEFAULT_RANGE = "7d"
|
||||||
|
_MAX_RANGE_DAYS = 365
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_pool(request: Request) -> AsyncConnectionPool:
|
||||||
|
"""Dependency: extract the shared pool from app state."""
|
||||||
|
return request.app.state.pool
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_range(range_str: str) -> int:
|
||||||
|
"""Parse 'Xd' range string to integer days. Raises 400 on invalid format."""
|
||||||
|
match = _RANGE_PATTERN.match(range_str)
|
||||||
|
if not match:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail="Invalid range format. Expected: '<N>d' e.g. '7d', '30d'.",
|
||||||
|
)
|
||||||
|
days = int(match.group(1))
|
||||||
|
if days < 1 or days > _MAX_RANGE_DAYS:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail=f"Range must be between 1 and {_MAX_RANGE_DAYS} days.",
|
||||||
|
)
|
||||||
|
return days
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("")
|
||||||
|
async def analytics(
|
||||||
|
request: Request,
|
||||||
|
range: str = Query(default=_DEFAULT_RANGE, alias="range"), # noqa: A002
|
||||||
|
) -> dict:
|
||||||
|
"""Return aggregated analytics metrics for the given time range."""
|
||||||
|
range_days = _parse_range(range)
|
||||||
|
pool = await _get_pool(request)
|
||||||
|
result = await get_analytics(pool, range_days=range_days)
|
||||||
|
return envelope(asdict(result))
|
||||||
97
backend/app/analytics/event_recorder.py
Normal file
97
backend/app/analytics/event_recorder.py
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
"""Analytics event recorder -- Protocol and implementations."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
|
||||||
|
|
||||||
|
from psycopg.types.json import Json
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from psycopg_pool import AsyncConnectionPool
|
||||||
|
|
||||||
|
_INSERT_SQL = """
|
||||||
|
INSERT INTO analytics_events
|
||||||
|
(thread_id, event_type, agent_name, tool_name, tokens_used, cost_usd,
|
||||||
|
duration_ms, success, error_message, metadata)
|
||||||
|
VALUES
|
||||||
|
(%(thread_id)s, %(event_type)s, %(agent_name)s, %(tool_name)s,
|
||||||
|
%(tokens_used)s, %(cost_usd)s, %(duration_ms)s, %(success)s,
|
||||||
|
%(error_message)s, %(metadata)s)
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
|
class AnalyticsRecorder(Protocol):
|
||||||
|
"""Protocol for recording analytics events."""
|
||||||
|
|
||||||
|
async def record(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
thread_id: str,
|
||||||
|
event_type: str,
|
||||||
|
agent_name: str | None = None,
|
||||||
|
tool_name: str | None = None,
|
||||||
|
tokens_used: int = 0,
|
||||||
|
cost_usd: float = 0.0,
|
||||||
|
duration_ms: int | None = None,
|
||||||
|
success: bool | None = None,
|
||||||
|
error_message: str | None = None,
|
||||||
|
metadata: dict | None = None,
|
||||||
|
) -> None: ...
|
||||||
|
|
||||||
|
|
||||||
|
class NoOpAnalyticsRecorder:
|
||||||
|
"""No-op implementation for testing or when the DB is unavailable."""
|
||||||
|
|
||||||
|
async def record(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
thread_id: str,
|
||||||
|
event_type: str,
|
||||||
|
agent_name: str | None = None,
|
||||||
|
tool_name: str | None = None,
|
||||||
|
tokens_used: int = 0,
|
||||||
|
cost_usd: float = 0.0,
|
||||||
|
duration_ms: int | None = None,
|
||||||
|
success: bool | None = None,
|
||||||
|
error_message: str | None = None,
|
||||||
|
metadata: dict | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Do nothing."""
|
||||||
|
|
||||||
|
|
||||||
|
class PostgresAnalyticsRecorder:
|
||||||
|
"""Postgres-backed analytics recorder -- INSERTs into analytics_events."""
|
||||||
|
|
||||||
|
def __init__(self, pool: AsyncConnectionPool) -> None:
|
||||||
|
self._pool = pool
|
||||||
|
|
||||||
|
async def record(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
thread_id: str,
|
||||||
|
event_type: str,
|
||||||
|
agent_name: str | None = None,
|
||||||
|
tool_name: str | None = None,
|
||||||
|
tokens_used: int = 0,
|
||||||
|
cost_usd: float = 0.0,
|
||||||
|
duration_ms: int | None = None,
|
||||||
|
success: bool | None = None,
|
||||||
|
error_message: str | None = None,
|
||||||
|
metadata: dict | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Insert one analytics event row."""
|
||||||
|
params: dict[str, Any] = {
|
||||||
|
"thread_id": thread_id,
|
||||||
|
"event_type": event_type,
|
||||||
|
"agent_name": agent_name,
|
||||||
|
"tool_name": tool_name,
|
||||||
|
"tokens_used": tokens_used,
|
||||||
|
"cost_usd": cost_usd,
|
||||||
|
"duration_ms": duration_ms,
|
||||||
|
"success": success,
|
||||||
|
"error_message": error_message,
|
||||||
|
"metadata": Json(metadata or {}),
|
||||||
|
}
|
||||||
|
async with self._pool.connection() as conn:
|
||||||
|
await conn.execute(_INSERT_SQL, params)
|
||||||
38
backend/app/analytics/models.py
Normal file
38
backend/app/analytics/models.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
"""Value objects for analytics dashboard."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class AgentUsage:
|
||||||
|
"""Agent usage statistics within a time range."""
|
||||||
|
|
||||||
|
agent: str
|
||||||
|
count: int
|
||||||
|
percentage: float
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class InterruptStats:
|
||||||
|
"""Interrupt approval/rejection statistics within a time range."""
|
||||||
|
|
||||||
|
total: int = 0
|
||||||
|
approved: int = 0
|
||||||
|
rejected: int = 0
|
||||||
|
expired: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class AnalyticsResult:
|
||||||
|
"""Full analytics result for a given time range."""
|
||||||
|
|
||||||
|
range: str
|
||||||
|
total_conversations: int
|
||||||
|
resolution_rate: float
|
||||||
|
escalation_rate: float
|
||||||
|
avg_turns_per_conversation: float
|
||||||
|
avg_cost_per_conversation_usd: float
|
||||||
|
agent_usage: tuple[AgentUsage, ...]
|
||||||
|
interrupt_stats: InterruptStats
|
||||||
184
backend/app/analytics/queries.py
Normal file
184
backend/app/analytics/queries.py
Normal file
@@ -0,0 +1,184 @@
|
|||||||
|
"""Analytics query functions -- all async, take pool + range_days."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from app.analytics.models import AgentUsage, AnalyticsResult, InterruptStats
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from psycopg_pool import AsyncConnectionPool
|
||||||
|
|
||||||
|
_RESOLUTION_RATE_SQL = """
|
||||||
|
SELECT
|
||||||
|
CASE WHEN COUNT(*) = 0 THEN 0.0
|
||||||
|
ELSE COUNT(*) FILTER (WHERE resolution_type = 'resolved')::float / COUNT(*)
|
||||||
|
END AS rate
|
||||||
|
FROM conversations
|
||||||
|
WHERE created_at >= NOW() - (%(days)s * INTERVAL '1 day')
|
||||||
|
"""
|
||||||
|
|
||||||
|
_ESCALATION_RATE_SQL = """
|
||||||
|
SELECT
|
||||||
|
CASE WHEN COUNT(*) = 0 THEN 0.0
|
||||||
|
ELSE COUNT(*) FILTER (WHERE resolution_type = 'escalated')::float / COUNT(*)
|
||||||
|
END AS rate
|
||||||
|
FROM conversations
|
||||||
|
WHERE created_at >= NOW() - (%(days)s * INTERVAL '1 day')
|
||||||
|
"""
|
||||||
|
|
||||||
|
_TOTAL_CONVERSATIONS_SQL = """
|
||||||
|
SELECT COUNT(*) AS total
|
||||||
|
FROM conversations
|
||||||
|
WHERE created_at >= NOW() - (%(days)s * INTERVAL '1 day')
|
||||||
|
"""
|
||||||
|
|
||||||
|
_AVG_TURNS_SQL = """
|
||||||
|
SELECT COALESCE(AVG(turn_count), 0.0) AS avg_turns
|
||||||
|
FROM conversations
|
||||||
|
WHERE created_at >= NOW() - (%(days)s * INTERVAL '1 day')
|
||||||
|
"""
|
||||||
|
|
||||||
|
_COST_PER_CONVERSATION_SQL = """
|
||||||
|
SELECT COALESCE(AVG(total_cost_usd), 0.0) AS avg_cost
|
||||||
|
FROM conversations
|
||||||
|
WHERE created_at >= NOW() - (%(days)s * INTERVAL '1 day')
|
||||||
|
"""
|
||||||
|
|
||||||
|
_AGENT_USAGE_SQL = """
|
||||||
|
SELECT
|
||||||
|
agent,
|
||||||
|
COUNT(*) AS count,
|
||||||
|
ROUND(COUNT(*) * 100.0 / NULLIF(SUM(COUNT(*)) OVER (), 0), 2) AS percentage
|
||||||
|
FROM (
|
||||||
|
SELECT UNNEST(agents_used) AS agent
|
||||||
|
FROM conversations
|
||||||
|
WHERE created_at >= NOW() - (%(days)s * INTERVAL '1 day')
|
||||||
|
AND agents_used IS NOT NULL
|
||||||
|
) sub
|
||||||
|
GROUP BY agent
|
||||||
|
ORDER BY count DESC
|
||||||
|
"""
|
||||||
|
|
||||||
|
_INTERRUPT_STATS_SQL = """
|
||||||
|
SELECT
|
||||||
|
COUNT(*) FILTER (WHERE event_type = 'interrupt') AS total,
|
||||||
|
COUNT(*) FILTER (WHERE event_type = 'interrupt' AND success = TRUE) AS approved,
|
||||||
|
COUNT(*) FILTER (WHERE event_type = 'interrupt' AND success = FALSE
|
||||||
|
AND error_message IS NULL) AS rejected,
|
||||||
|
COUNT(*) FILTER (WHERE event_type = 'interrupt' AND error_message = 'expired') AS expired
|
||||||
|
FROM analytics_events
|
||||||
|
WHERE created_at >= NOW() - (%(days)s * INTERVAL '1 day')
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
async def resolution_rate(pool: AsyncConnectionPool, range_days: int) -> float:
|
||||||
|
"""Return the fraction of resolved conversations in the given range."""
|
||||||
|
async with pool.connection() as conn:
|
||||||
|
cursor = await conn.execute(_RESOLUTION_RATE_SQL, {"days": range_days})
|
||||||
|
row = await cursor.fetchone()
|
||||||
|
if not row:
|
||||||
|
return 0.0
|
||||||
|
return float(row.get("rate") or 0.0)
|
||||||
|
|
||||||
|
|
||||||
|
async def escalation_rate(pool: AsyncConnectionPool, range_days: int) -> float:
|
||||||
|
"""Return the fraction of escalated conversations in the given range."""
|
||||||
|
async with pool.connection() as conn:
|
||||||
|
cursor = await conn.execute(_ESCALATION_RATE_SQL, {"days": range_days})
|
||||||
|
row = await cursor.fetchone()
|
||||||
|
if not row:
|
||||||
|
return 0.0
|
||||||
|
return float(row.get("rate") or 0.0)
|
||||||
|
|
||||||
|
|
||||||
|
async def _total_conversations(pool: AsyncConnectionPool, range_days: int) -> int:
|
||||||
|
"""Return the total number of conversations in the given range."""
|
||||||
|
async with pool.connection() as conn:
|
||||||
|
cursor = await conn.execute(_TOTAL_CONVERSATIONS_SQL, {"days": range_days})
|
||||||
|
row = await cursor.fetchone()
|
||||||
|
if not row:
|
||||||
|
return 0
|
||||||
|
return int(row.get("total") or 0)
|
||||||
|
|
||||||
|
|
||||||
|
async def _avg_turns(pool: AsyncConnectionPool, range_days: int) -> float:
|
||||||
|
"""Return the average turn count per conversation in the given range."""
|
||||||
|
async with pool.connection() as conn:
|
||||||
|
cursor = await conn.execute(_AVG_TURNS_SQL, {"days": range_days})
|
||||||
|
row = await cursor.fetchone()
|
||||||
|
if not row:
|
||||||
|
return 0.0
|
||||||
|
return float(row.get("avg_turns") or 0.0)
|
||||||
|
|
||||||
|
|
||||||
|
async def cost_per_conversation(pool: AsyncConnectionPool, range_days: int) -> float:
|
||||||
|
"""Return the average cost per conversation in the given range."""
|
||||||
|
async with pool.connection() as conn:
|
||||||
|
cursor = await conn.execute(_COST_PER_CONVERSATION_SQL, {"days": range_days})
|
||||||
|
row = await cursor.fetchone()
|
||||||
|
if not row:
|
||||||
|
return 0.0
|
||||||
|
return float(row.get("avg_cost") or 0.0)
|
||||||
|
|
||||||
|
|
||||||
|
async def agent_usage(
|
||||||
|
pool: AsyncConnectionPool, range_days: int
|
||||||
|
) -> tuple[AgentUsage, ...]:
|
||||||
|
"""Return per-agent usage statistics for the given range."""
|
||||||
|
async with pool.connection() as conn:
|
||||||
|
cursor = await conn.execute(_AGENT_USAGE_SQL, {"days": range_days})
|
||||||
|
rows = await cursor.fetchall()
|
||||||
|
if not rows:
|
||||||
|
return ()
|
||||||
|
return tuple(
|
||||||
|
AgentUsage(
|
||||||
|
agent=row["agent"],
|
||||||
|
count=int(row["count"]),
|
||||||
|
percentage=float(row["percentage"]),
|
||||||
|
)
|
||||||
|
for row in rows
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def interrupt_stats(
|
||||||
|
pool: AsyncConnectionPool, range_days: int
|
||||||
|
) -> InterruptStats:
|
||||||
|
"""Return interrupt approval/rejection statistics for the given range."""
|
||||||
|
async with pool.connection() as conn:
|
||||||
|
cursor = await conn.execute(_INTERRUPT_STATS_SQL, {"days": range_days})
|
||||||
|
row = await cursor.fetchone()
|
||||||
|
if not row:
|
||||||
|
return InterruptStats()
|
||||||
|
return InterruptStats(
|
||||||
|
total=int(row.get("total") or 0),
|
||||||
|
approved=int(row.get("approved") or 0),
|
||||||
|
rejected=int(row.get("rejected") or 0),
|
||||||
|
expired=int(row.get("expired") or 0),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_analytics(
|
||||||
|
pool: AsyncConnectionPool, range_days: int
|
||||||
|
) -> AnalyticsResult:
|
||||||
|
"""Aggregate all analytics metrics into a single AnalyticsResult."""
|
||||||
|
res_rate, esc_rate, cost, usage, i_stats, total, avg_t = await asyncio.gather(
|
||||||
|
resolution_rate(pool, range_days),
|
||||||
|
escalation_rate(pool, range_days),
|
||||||
|
cost_per_conversation(pool, range_days),
|
||||||
|
agent_usage(pool, range_days),
|
||||||
|
interrupt_stats(pool, range_days),
|
||||||
|
_total_conversations(pool, range_days),
|
||||||
|
_avg_turns(pool, range_days),
|
||||||
|
)
|
||||||
|
return AnalyticsResult(
|
||||||
|
range=f"{range_days}d",
|
||||||
|
total_conversations=total,
|
||||||
|
resolution_rate=res_rate,
|
||||||
|
escalation_rate=esc_rate,
|
||||||
|
avg_turns_per_conversation=avg_t,
|
||||||
|
avg_cost_per_conversation_usd=cost,
|
||||||
|
agent_usage=usage,
|
||||||
|
interrupt_stats=i_stats,
|
||||||
|
)
|
||||||
10
backend/app/api_utils.py
Normal file
10
backend/app/api_utils.py
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
"""Shared API response helpers."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
def envelope(data: Any, *, success: bool = True, error: str | None = None) -> dict:
|
||||||
|
"""Wrap API response data in a standard envelope format."""
|
||||||
|
return {"success": success, "data": data, "error": error}
|
||||||
72
backend/app/auth.py
Normal file
72
backend/app/auth.py
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
"""API key authentication for admin endpoints and WebSocket connections."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import secrets
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
import structlog
|
||||||
|
from fastapi import Depends, HTTPException, Query, Request, WebSocket, status
|
||||||
|
from fastapi.security import APIKeyHeader
|
||||||
|
|
||||||
|
logger = structlog.get_logger()
|
||||||
|
|
||||||
|
_API_KEY_HEADER = APIKeyHeader(name="X-API-Key", auto_error=False)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_admin_api_key(request: Request) -> str:
|
||||||
|
"""Retrieve the configured admin API key from app settings.
|
||||||
|
|
||||||
|
Returns empty string if settings are not configured (test/dev mode).
|
||||||
|
"""
|
||||||
|
settings = getattr(request.app.state, "settings", None)
|
||||||
|
if settings is None:
|
||||||
|
return ""
|
||||||
|
key = getattr(settings, "admin_api_key", "")
|
||||||
|
return key if isinstance(key, str) else ""
|
||||||
|
|
||||||
|
|
||||||
|
async def require_admin_api_key(
|
||||||
|
request: Request,
|
||||||
|
api_key: Annotated[str | None, Depends(_API_KEY_HEADER)] = None,
|
||||||
|
) -> None:
|
||||||
|
"""Dependency that enforces API key authentication on admin endpoints.
|
||||||
|
|
||||||
|
Skips validation when no admin_api_key is configured (dev mode).
|
||||||
|
"""
|
||||||
|
expected = _get_admin_api_key(request)
|
||||||
|
if not expected:
|
||||||
|
return
|
||||||
|
|
||||||
|
if api_key is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Missing X-API-Key header",
|
||||||
|
)
|
||||||
|
if not secrets.compare_digest(api_key, expected):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="Invalid API key",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def verify_ws_token(
|
||||||
|
ws: WebSocket,
|
||||||
|
token: str | None = Query(default=None),
|
||||||
|
) -> None:
|
||||||
|
"""Verify WebSocket connection token from query parameter.
|
||||||
|
|
||||||
|
Skips validation when no admin_api_key is configured (dev mode).
|
||||||
|
Usage: ws://host/ws?token=<api_key>
|
||||||
|
"""
|
||||||
|
settings = ws.app.state.settings
|
||||||
|
expected = settings.admin_api_key
|
||||||
|
if not expected:
|
||||||
|
return
|
||||||
|
|
||||||
|
if token is None or not secrets.compare_digest(token, expected):
|
||||||
|
await ws.close(code=4001, reason="Unauthorized")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="Invalid or missing WebSocket token",
|
||||||
|
)
|
||||||
@@ -17,7 +17,7 @@ class Settings(BaseSettings):
|
|||||||
|
|
||||||
database_url: str
|
database_url: str
|
||||||
|
|
||||||
llm_provider: Literal["anthropic", "openai", "google"] = "anthropic"
|
llm_provider: Literal["anthropic", "openai", "azure_openai", "google"] = "anthropic"
|
||||||
llm_model: str = "claude-sonnet-4-6"
|
llm_model: str = "claude-sonnet-4-6"
|
||||||
|
|
||||||
session_ttl_minutes: int = 30
|
session_ttl_minutes: int = 30
|
||||||
@@ -26,8 +26,22 @@ class Settings(BaseSettings):
|
|||||||
ws_host: str = "0.0.0.0"
|
ws_host: str = "0.0.0.0"
|
||||||
ws_port: int = 8000
|
ws_port: int = 8000
|
||||||
|
|
||||||
|
webhook_url: str = ""
|
||||||
|
webhook_timeout_seconds: int = 10
|
||||||
|
webhook_max_retries: int = 3
|
||||||
|
|
||||||
|
template_name: str = ""
|
||||||
|
|
||||||
|
log_format: str = "console" # "console" for dev, "json" for production
|
||||||
|
|
||||||
|
admin_api_key: str = ""
|
||||||
|
|
||||||
anthropic_api_key: str = ""
|
anthropic_api_key: str = ""
|
||||||
openai_api_key: str = ""
|
openai_api_key: str = ""
|
||||||
|
azure_openai_api_key: str = ""
|
||||||
|
azure_openai_endpoint: str = ""
|
||||||
|
azure_openai_api_version: str = "2024-12-01-preview"
|
||||||
|
azure_openai_deployment: str = ""
|
||||||
google_api_key: str = ""
|
google_api_key: str = ""
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
@@ -35,6 +49,7 @@ class Settings(BaseSettings):
|
|||||||
key_map = {
|
key_map = {
|
||||||
"anthropic": self.anthropic_api_key,
|
"anthropic": self.anthropic_api_key,
|
||||||
"openai": self.openai_api_key,
|
"openai": self.openai_api_key,
|
||||||
|
"azure_openai": self.azure_openai_api_key,
|
||||||
"google": self.google_api_key,
|
"google": self.google_api_key,
|
||||||
}
|
}
|
||||||
key = key_map.get(self.llm_provider, "")
|
key = key_map.get(self.llm_provider, "")
|
||||||
@@ -43,4 +58,13 @@ class Settings(BaseSettings):
|
|||||||
f"API key for provider '{self.llm_provider}' is required. "
|
f"API key for provider '{self.llm_provider}' is required. "
|
||||||
f"Set the corresponding environment variable."
|
f"Set the corresponding environment variable."
|
||||||
)
|
)
|
||||||
|
if self.llm_provider == "azure_openai":
|
||||||
|
if not self.azure_openai_endpoint:
|
||||||
|
raise ValueError(
|
||||||
|
"AZURE_OPENAI_ENDPOINT is required for azure_openai provider."
|
||||||
|
)
|
||||||
|
if not self.azure_openai_deployment:
|
||||||
|
raise ValueError(
|
||||||
|
"AZURE_OPENAI_DEPLOYMENT is required for azure_openai provider."
|
||||||
|
)
|
||||||
return self
|
return self
|
||||||
|
|||||||
135
backend/app/conversation_tracker.py
Normal file
135
backend/app/conversation_tracker.py
Normal file
@@ -0,0 +1,135 @@
|
|||||||
|
"""Conversation tracker -- Protocol and implementations for tracking conversation state."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Protocol, runtime_checkable
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from psycopg_pool import AsyncConnectionPool
|
||||||
|
|
||||||
|
_ENSURE_SQL = """
|
||||||
|
INSERT INTO conversations
|
||||||
|
(thread_id, created_at, last_activity)
|
||||||
|
VALUES
|
||||||
|
(%(thread_id)s, NOW(), NOW())
|
||||||
|
ON CONFLICT (thread_id) DO NOTHING
|
||||||
|
"""
|
||||||
|
|
||||||
|
_RECORD_TURN_SQL = """
|
||||||
|
UPDATE conversations
|
||||||
|
SET
|
||||||
|
turn_count = turn_count + 1,
|
||||||
|
agents_used = CASE
|
||||||
|
WHEN %(agent_name)s IS NOT NULL AND NOT (agents_used @> ARRAY[%(agent_name)s]::text[])
|
||||||
|
THEN agents_used || ARRAY[%(agent_name)s]::text[]
|
||||||
|
ELSE agents_used
|
||||||
|
END,
|
||||||
|
total_tokens = total_tokens + %(tokens)s,
|
||||||
|
total_cost_usd = total_cost_usd + %(cost)s,
|
||||||
|
last_activity = NOW()
|
||||||
|
WHERE thread_id = %(thread_id)s
|
||||||
|
"""
|
||||||
|
|
||||||
|
_RESOLVE_SQL = """
|
||||||
|
UPDATE conversations
|
||||||
|
SET
|
||||||
|
resolution_type = %(resolution_type)s,
|
||||||
|
ended_at = NOW()
|
||||||
|
WHERE thread_id = %(thread_id)s
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
|
class ConversationTrackerProtocol(Protocol):
|
||||||
|
"""Protocol for tracking conversation lifecycle and metrics."""
|
||||||
|
|
||||||
|
async def ensure_conversation(self, pool: AsyncConnectionPool, thread_id: str) -> None:
|
||||||
|
"""Create conversation row if it does not already exist."""
|
||||||
|
...
|
||||||
|
|
||||||
|
async def record_turn(
|
||||||
|
self,
|
||||||
|
pool: AsyncConnectionPool,
|
||||||
|
thread_id: str,
|
||||||
|
agent_name: str | None,
|
||||||
|
tokens: int,
|
||||||
|
cost: float,
|
||||||
|
) -> None:
|
||||||
|
"""Increment turn count and update aggregated metrics."""
|
||||||
|
...
|
||||||
|
|
||||||
|
async def resolve(
|
||||||
|
self,
|
||||||
|
pool: AsyncConnectionPool,
|
||||||
|
thread_id: str,
|
||||||
|
resolution_type: str,
|
||||||
|
) -> None:
|
||||||
|
"""Mark conversation as resolved with a resolution type."""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class NoOpConversationTracker:
|
||||||
|
"""No-op implementation -- used in tests or when DB is unavailable."""
|
||||||
|
|
||||||
|
async def ensure_conversation(self, pool: AsyncConnectionPool, thread_id: str) -> None:
|
||||||
|
"""Do nothing."""
|
||||||
|
|
||||||
|
async def record_turn(
|
||||||
|
self,
|
||||||
|
pool: AsyncConnectionPool,
|
||||||
|
thread_id: str,
|
||||||
|
agent_name: str | None,
|
||||||
|
tokens: int,
|
||||||
|
cost: float,
|
||||||
|
) -> None:
|
||||||
|
"""Do nothing."""
|
||||||
|
|
||||||
|
async def resolve(
|
||||||
|
self,
|
||||||
|
pool: AsyncConnectionPool,
|
||||||
|
thread_id: str,
|
||||||
|
resolution_type: str,
|
||||||
|
) -> None:
|
||||||
|
"""Do nothing."""
|
||||||
|
|
||||||
|
|
||||||
|
class PostgresConversationTracker:
|
||||||
|
"""Postgres-backed conversation tracker."""
|
||||||
|
|
||||||
|
async def ensure_conversation(self, pool: AsyncConnectionPool, thread_id: str) -> None:
|
||||||
|
"""Insert conversation row; do nothing if already exists (ON CONFLICT DO NOTHING)."""
|
||||||
|
params = {"thread_id": thread_id}
|
||||||
|
async with pool.connection() as conn:
|
||||||
|
await conn.execute(_ENSURE_SQL, params)
|
||||||
|
|
||||||
|
async def record_turn(
|
||||||
|
self,
|
||||||
|
pool: AsyncConnectionPool,
|
||||||
|
thread_id: str,
|
||||||
|
agent_name: str | None,
|
||||||
|
tokens: int,
|
||||||
|
cost: float,
|
||||||
|
) -> None:
|
||||||
|
"""Increment turn count, append agent if new, update token/cost totals."""
|
||||||
|
params = {
|
||||||
|
"thread_id": thread_id,
|
||||||
|
"agent_name": agent_name,
|
||||||
|
"tokens": tokens,
|
||||||
|
"cost": cost,
|
||||||
|
}
|
||||||
|
async with pool.connection() as conn:
|
||||||
|
await conn.execute(_RECORD_TURN_SQL, params)
|
||||||
|
|
||||||
|
async def resolve(
|
||||||
|
self,
|
||||||
|
pool: AsyncConnectionPool,
|
||||||
|
thread_id: str,
|
||||||
|
resolution_type: str,
|
||||||
|
) -> None:
|
||||||
|
"""Set resolution_type and ended_at on the conversation row."""
|
||||||
|
params = {
|
||||||
|
"thread_id": thread_id,
|
||||||
|
"resolution_type": resolution_type,
|
||||||
|
}
|
||||||
|
async with pool.connection() as conn:
|
||||||
|
await conn.execute(_RESOLVE_SQL, params)
|
||||||
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||||
@@ -34,6 +35,40 @@ CREATE TABLE IF NOT EXISTS active_interrupts (
|
|||||||
);
|
);
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
_ANALYTICS_EVENTS_DDL = """
|
||||||
|
CREATE TABLE IF NOT EXISTS analytics_events (
|
||||||
|
id BIGSERIAL PRIMARY KEY,
|
||||||
|
thread_id TEXT NOT NULL,
|
||||||
|
event_type TEXT NOT NULL,
|
||||||
|
agent_name TEXT,
|
||||||
|
tool_name TEXT,
|
||||||
|
tokens_used INTEGER NOT NULL DEFAULT 0,
|
||||||
|
cost_usd DOUBLE PRECISION NOT NULL DEFAULT 0.0,
|
||||||
|
duration_ms INTEGER,
|
||||||
|
success BOOLEAN,
|
||||||
|
error_message TEXT,
|
||||||
|
metadata JSONB NOT NULL DEFAULT '{}',
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
|
||||||
|
_SESSIONS_DDL = """
|
||||||
|
CREATE TABLE IF NOT EXISTS sessions (
|
||||||
|
thread_id TEXT PRIMARY KEY,
|
||||||
|
last_activity TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||||
|
has_pending_interrupt BOOLEAN NOT NULL DEFAULT FALSE,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
|
||||||
|
_CONVERSATIONS_MIGRATION_DDL = """
|
||||||
|
ALTER TABLE conversations
|
||||||
|
ADD COLUMN IF NOT EXISTS resolution_type TEXT,
|
||||||
|
ADD COLUMN IF NOT EXISTS agents_used TEXT[],
|
||||||
|
ADD COLUMN IF NOT EXISTS turn_count INTEGER NOT NULL DEFAULT 0,
|
||||||
|
ADD COLUMN IF NOT EXISTS ended_at TIMESTAMPTZ;
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
async def create_pool(settings: Settings) -> AsyncConnectionPool:
|
async def create_pool(settings: Settings) -> AsyncConnectionPool:
|
||||||
"""Create an async connection pool with the required psycopg settings."""
|
"""Create an async connection pool with the required psycopg settings."""
|
||||||
@@ -54,8 +89,22 @@ async def create_checkpointer(pool: AsyncConnectionPool) -> AsyncPostgresSaver:
|
|||||||
return checkpointer
|
return checkpointer
|
||||||
|
|
||||||
|
|
||||||
|
def run_alembic_migrations(database_url: str) -> None:
|
||||||
|
"""Run Alembic migrations to head."""
|
||||||
|
from alembic.config import Config
|
||||||
|
|
||||||
|
from alembic import command
|
||||||
|
|
||||||
|
alembic_cfg = Config(str(Path(__file__).parent.parent / "alembic.ini"))
|
||||||
|
alembic_cfg.set_main_option("sqlalchemy.url", database_url)
|
||||||
|
command.upgrade(alembic_cfg, "head")
|
||||||
|
|
||||||
|
|
||||||
async def setup_app_tables(pool: AsyncConnectionPool) -> None:
|
async def setup_app_tables(pool: AsyncConnectionPool) -> None:
|
||||||
"""Create application-specific tables (conversations, active_interrupts)."""
|
"""Create application-specific tables and apply migrations."""
|
||||||
async with pool.connection() as conn:
|
async with pool.connection() as conn:
|
||||||
await conn.execute(_CONVERSATIONS_DDL)
|
await conn.execute(_CONVERSATIONS_DDL)
|
||||||
await conn.execute(_INTERRUPTS_DDL)
|
await conn.execute(_INTERRUPTS_DDL)
|
||||||
|
await conn.execute(_SESSIONS_DDL)
|
||||||
|
await conn.execute(_ANALYTICS_EVENTS_DDL)
|
||||||
|
await conn.execute(_CONVERSATIONS_MIGRATION_DDL)
|
||||||
|
|||||||
140
backend/app/escalation.py
Normal file
140
backend/app/escalation.py
Normal file
@@ -0,0 +1,140 @@
|
|||||||
|
"""Webhook escalation module -- HTTP POST with exponential backoff retry."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Protocol
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import structlog
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
logger = structlog.get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
class EscalationPayload(BaseModel, frozen=True):
|
||||||
|
"""Immutable payload sent to the escalation webhook."""
|
||||||
|
|
||||||
|
thread_id: str
|
||||||
|
reason: str
|
||||||
|
conversation_summary: str
|
||||||
|
metadata: dict = {}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class EscalationResult:
|
||||||
|
"""Immutable result of an escalation attempt."""
|
||||||
|
|
||||||
|
success: bool
|
||||||
|
status_code: int | None
|
||||||
|
attempts: int
|
||||||
|
error: str | None
|
||||||
|
|
||||||
|
|
||||||
|
class EscalationService(Protocol):
|
||||||
|
"""Protocol for escalation implementations."""
|
||||||
|
|
||||||
|
async def escalate(self, payload: EscalationPayload) -> EscalationResult: ...
|
||||||
|
|
||||||
|
|
||||||
|
class WebhookEscalator:
|
||||||
|
"""Sends escalation requests via HTTP POST with exponential backoff retry."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
url: str,
|
||||||
|
timeout_seconds: int = 10,
|
||||||
|
max_retries: int = 3,
|
||||||
|
) -> None:
|
||||||
|
self._url = url
|
||||||
|
self._timeout = timeout_seconds
|
||||||
|
self._max_retries = max_retries
|
||||||
|
|
||||||
|
async def escalate(self, payload: EscalationPayload) -> EscalationResult:
|
||||||
|
"""POST the escalation payload to the configured webhook URL."""
|
||||||
|
if not self._url:
|
||||||
|
return EscalationResult(
|
||||||
|
success=False,
|
||||||
|
status_code=None,
|
||||||
|
attempts=0,
|
||||||
|
error="Webhook URL not configured",
|
||||||
|
)
|
||||||
|
|
||||||
|
last_error: str | None = None
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(timeout=self._timeout) as client:
|
||||||
|
for attempt in range(1, self._max_retries + 1):
|
||||||
|
try:
|
||||||
|
response = await client.post(
|
||||||
|
self._url,
|
||||||
|
json=payload.model_dump(),
|
||||||
|
)
|
||||||
|
|
||||||
|
if 200 <= response.status_code < 300:
|
||||||
|
logger.info(
|
||||||
|
"Escalation succeeded for thread %s (attempt %d)",
|
||||||
|
payload.thread_id,
|
||||||
|
attempt,
|
||||||
|
)
|
||||||
|
return EscalationResult(
|
||||||
|
success=True,
|
||||||
|
status_code=response.status_code,
|
||||||
|
attempts=attempt,
|
||||||
|
error=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
last_error = f"HTTP {response.status_code}"
|
||||||
|
logger.warning(
|
||||||
|
"Escalation attempt %d/%d failed: %s",
|
||||||
|
attempt,
|
||||||
|
self._max_retries,
|
||||||
|
last_error,
|
||||||
|
)
|
||||||
|
|
||||||
|
except httpx.TimeoutException:
|
||||||
|
last_error = "Request timed out"
|
||||||
|
logger.warning(
|
||||||
|
"Escalation attempt %d/%d timed out",
|
||||||
|
attempt,
|
||||||
|
self._max_retries,
|
||||||
|
)
|
||||||
|
except httpx.RequestError as exc:
|
||||||
|
last_error = str(exc)
|
||||||
|
logger.warning(
|
||||||
|
"Escalation attempt %d/%d error: %s",
|
||||||
|
attempt,
|
||||||
|
self._max_retries,
|
||||||
|
last_error,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Exponential backoff: skip delay after last attempt
|
||||||
|
if attempt < self._max_retries:
|
||||||
|
delay = 2**attempt
|
||||||
|
await asyncio.sleep(delay)
|
||||||
|
|
||||||
|
logger.error(
|
||||||
|
"Escalation failed for thread %s after %d attempts: %s",
|
||||||
|
payload.thread_id,
|
||||||
|
self._max_retries,
|
||||||
|
last_error,
|
||||||
|
)
|
||||||
|
return EscalationResult(
|
||||||
|
success=False,
|
||||||
|
status_code=None,
|
||||||
|
attempts=self._max_retries,
|
||||||
|
error=last_error,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class NoOpEscalator:
|
||||||
|
"""Escalator that does nothing -- used when webhook URL is not configured."""
|
||||||
|
|
||||||
|
async def escalate(self, payload: EscalationPayload) -> EscalationResult:
|
||||||
|
logger.info("Escalation disabled (no webhook URL). Thread: %s", payload.thread_id)
|
||||||
|
return EscalationResult(
|
||||||
|
success=False,
|
||||||
|
status_code=None,
|
||||||
|
attempts=0,
|
||||||
|
error="Escalation disabled",
|
||||||
|
)
|
||||||
@@ -4,27 +4,46 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from langgraph.prebuilt import create_react_agent
|
from langchain.agents import create_agent
|
||||||
from langgraph_supervisor import create_supervisor
|
from langgraph_supervisor import create_supervisor
|
||||||
|
|
||||||
from app.agents import get_tools_by_names
|
from app.agents import get_tools_by_names
|
||||||
|
from app.graph_context import GraphContext
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from langchain_core.language_models import BaseChatModel
|
from langchain_core.language_models import BaseChatModel
|
||||||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||||
from langgraph.graph.state import CompiledStateGraph
|
|
||||||
|
|
||||||
|
from app.intent import IntentClassifier
|
||||||
from app.registry import AgentRegistry
|
from app.registry import AgentRegistry
|
||||||
|
|
||||||
|
import structlog
|
||||||
|
|
||||||
|
logger = structlog.get_logger()
|
||||||
|
|
||||||
SUPERVISOR_PROMPT = (
|
SUPERVISOR_PROMPT = (
|
||||||
"You are a customer support supervisor. "
|
"You are a customer support supervisor. "
|
||||||
"Route customer requests to the appropriate agent based on their description. "
|
"Route customer requests to the appropriate agent based on their description.\n\n"
|
||||||
"For order status and tracking queries, use the order_lookup agent. "
|
"Available agents and their roles:\n"
|
||||||
"For order modifications like cancellations, use the order_actions agent. "
|
"{agent_descriptions}\n\n"
|
||||||
"For anything else, use the fallback agent."
|
"Routing rules:\n"
|
||||||
|
"- For order status and tracking queries, use the order_lookup agent.\n"
|
||||||
|
"- For order modifications like cancellations, use the order_actions agent.\n"
|
||||||
|
"- For discounts, promotions, or coupon codes, use the discount agent.\n"
|
||||||
|
"- For anything else or when uncertain, use the fallback agent.\n"
|
||||||
|
"- If the user's request involves multiple actions, execute them in order.\n"
|
||||||
|
"- If a previous intent classification is provided, follow it.\n"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _format_agent_descriptions(registry: AgentRegistry) -> str:
|
||||||
|
"""Build agent description text for the supervisor prompt."""
|
||||||
|
lines = []
|
||||||
|
for agent in registry.list_agents():
|
||||||
|
lines.append(f"- {agent.name}: {agent.description}")
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
def build_agent_nodes(
|
def build_agent_nodes(
|
||||||
registry: AgentRegistry,
|
registry: AgentRegistry,
|
||||||
llm: BaseChatModel,
|
llm: BaseChatModel,
|
||||||
@@ -41,11 +60,11 @@ def build_agent_nodes(
|
|||||||
f"Permission level: {agent_config.permission}."
|
f"Permission level: {agent_config.permission}."
|
||||||
)
|
)
|
||||||
|
|
||||||
agent_node = create_react_agent(
|
agent_node = create_agent(
|
||||||
model=llm,
|
model=llm,
|
||||||
tools=tools,
|
tools=tools,
|
||||||
name=agent_config.name,
|
name=agent_config.name,
|
||||||
prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
)
|
)
|
||||||
agent_nodes.append(agent_node)
|
agent_nodes.append(agent_node)
|
||||||
|
|
||||||
@@ -56,15 +75,29 @@ def build_graph(
|
|||||||
registry: AgentRegistry,
|
registry: AgentRegistry,
|
||||||
llm: BaseChatModel,
|
llm: BaseChatModel,
|
||||||
checkpointer: AsyncPostgresSaver,
|
checkpointer: AsyncPostgresSaver,
|
||||||
) -> CompiledStateGraph:
|
intent_classifier: IntentClassifier | None = None,
|
||||||
"""Build and compile the LangGraph supervisor graph."""
|
) -> GraphContext:
|
||||||
|
"""Build and compile the LangGraph supervisor graph.
|
||||||
|
|
||||||
|
Returns a GraphContext that bundles the compiled graph with its
|
||||||
|
associated registry and intent classifier.
|
||||||
|
"""
|
||||||
agent_nodes = build_agent_nodes(registry, llm)
|
agent_nodes = build_agent_nodes(registry, llm)
|
||||||
|
agent_descriptions = _format_agent_descriptions(registry)
|
||||||
|
|
||||||
|
prompt = SUPERVISOR_PROMPT.format(agent_descriptions=agent_descriptions)
|
||||||
|
|
||||||
workflow = create_supervisor(
|
workflow = create_supervisor(
|
||||||
agent_nodes,
|
agents=agent_nodes,
|
||||||
model=llm,
|
model=llm,
|
||||||
prompt=SUPERVISOR_PROMPT,
|
prompt=prompt,
|
||||||
output_mode="full_history",
|
output_mode="full_history",
|
||||||
)
|
)
|
||||||
|
|
||||||
return workflow.compile(checkpointer=checkpointer)
|
compiled = workflow.compile(checkpointer=checkpointer)
|
||||||
|
|
||||||
|
return GraphContext(
|
||||||
|
graph=compiled,
|
||||||
|
registry=registry,
|
||||||
|
intent_classifier=intent_classifier,
|
||||||
|
)
|
||||||
|
|||||||
36
backend/app/graph_context.py
Normal file
36
backend/app/graph_context.py
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
"""GraphContext -- typed wrapper around the compiled graph and its dependencies."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from langgraph.graph.state import CompiledStateGraph
|
||||||
|
|
||||||
|
from app.intent import ClassificationResult, IntentClassifier
|
||||||
|
from app.registry import AgentRegistry
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class GraphContext:
|
||||||
|
"""Bundles the compiled LangGraph graph with its associated services.
|
||||||
|
|
||||||
|
Replaces the previous pattern of monkey-patching attributes onto the
|
||||||
|
third-party CompiledStateGraph instance.
|
||||||
|
"""
|
||||||
|
|
||||||
|
graph: CompiledStateGraph
|
||||||
|
registry: AgentRegistry
|
||||||
|
intent_classifier: IntentClassifier | None = None
|
||||||
|
|
||||||
|
async def classify_intent(self, message: str) -> ClassificationResult | None:
|
||||||
|
"""Classify user intent using the attached classifier.
|
||||||
|
|
||||||
|
Returns None if no classifier is configured.
|
||||||
|
"""
|
||||||
|
if self.intent_classifier is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
agents = self.registry.list_agents()
|
||||||
|
return await self.intent_classifier.classify(message, agents)
|
||||||
119
backend/app/intent.py
Normal file
119
backend/app/intent.py
Normal file
@@ -0,0 +1,119 @@
|
|||||||
|
"""Intent classification using LLM structured output."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Protocol
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from langchain_core.language_models import BaseChatModel
|
||||||
|
|
||||||
|
from app.registry import AgentConfig
|
||||||
|
|
||||||
|
import structlog
|
||||||
|
|
||||||
|
logger = structlog.get_logger()
|
||||||
|
|
||||||
|
CLASSIFICATION_PROMPT = (
|
||||||
|
"You are an intent classifier for a customer support system.\n"
|
||||||
|
"Given a user message, determine which agent(s) should handle it.\n\n"
|
||||||
|
"Available agents:\n{agent_list}\n\n"
|
||||||
|
"Rules:\n"
|
||||||
|
"- If the message clearly maps to one agent, return a single intent.\n"
|
||||||
|
"- If the message contains multiple distinct requests, return multiple intents "
|
||||||
|
"in execution order.\n"
|
||||||
|
"- If the message is vague or doesn't match any agent, set is_ambiguous=True "
|
||||||
|
"and provide a clarification_question.\n"
|
||||||
|
"- Never route to the fallback agent unless truly ambiguous.\n"
|
||||||
|
"- confidence should be between 0.0 and 1.0.\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
AMBIGUITY_THRESHOLD = 0.5
|
||||||
|
|
||||||
|
|
||||||
|
class IntentTarget(BaseModel, frozen=True):
|
||||||
|
"""A single classified intent targeting a specific agent."""
|
||||||
|
|
||||||
|
agent_name: str
|
||||||
|
confidence: float
|
||||||
|
reasoning: str
|
||||||
|
|
||||||
|
|
||||||
|
class ClassificationResult(BaseModel, frozen=True):
|
||||||
|
"""Result of intent classification -- may contain multiple intents."""
|
||||||
|
|
||||||
|
intents: tuple[IntentTarget, ...]
|
||||||
|
is_ambiguous: bool = False
|
||||||
|
clarification_question: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class IntentClassifier(Protocol):
|
||||||
|
"""Protocol for intent classification implementations."""
|
||||||
|
|
||||||
|
async def classify(
|
||||||
|
self,
|
||||||
|
message: str,
|
||||||
|
available_agents: tuple[AgentConfig, ...],
|
||||||
|
) -> ClassificationResult: ...
|
||||||
|
|
||||||
|
|
||||||
|
def _build_agent_list(agents: tuple[AgentConfig, ...]) -> str:
|
||||||
|
"""Format agent descriptions for the classification prompt."""
|
||||||
|
lines = []
|
||||||
|
for agent in agents:
|
||||||
|
lines.append(f"- {agent.name}: {agent.description} (permission: {agent.permission})")
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
class LLMIntentClassifier:
|
||||||
|
"""Classifies user intent using LLM structured output."""
|
||||||
|
|
||||||
|
def __init__(self, llm: BaseChatModel) -> None:
|
||||||
|
self._llm = llm
|
||||||
|
|
||||||
|
async def classify(
|
||||||
|
self,
|
||||||
|
message: str,
|
||||||
|
available_agents: tuple[AgentConfig, ...],
|
||||||
|
) -> ClassificationResult:
|
||||||
|
"""Classify user message into one or more agent intents."""
|
||||||
|
agent_list = _build_agent_list(available_agents)
|
||||||
|
system_prompt = CLASSIFICATION_PROMPT.format(agent_list=agent_list)
|
||||||
|
|
||||||
|
structured_llm = self._llm.with_structured_output(ClassificationResult)
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = await structured_llm.ainvoke(
|
||||||
|
[
|
||||||
|
{"role": "system", "content": system_prompt},
|
||||||
|
{"role": "user", "content": message},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Intent classification failed, returning ambiguous")
|
||||||
|
return ClassificationResult(
|
||||||
|
intents=(),
|
||||||
|
is_ambiguous=True,
|
||||||
|
clarification_question="I'm not sure I understood. Could you please rephrase?",
|
||||||
|
)
|
||||||
|
|
||||||
|
if not isinstance(result, ClassificationResult):
|
||||||
|
return ClassificationResult(
|
||||||
|
intents=(),
|
||||||
|
is_ambiguous=True,
|
||||||
|
clarification_question="I'm not sure I understood. Could you please rephrase?",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply ambiguity threshold
|
||||||
|
if result.intents and all(i.confidence < AMBIGUITY_THRESHOLD for i in result.intents):
|
||||||
|
return ClassificationResult(
|
||||||
|
intents=result.intents,
|
||||||
|
is_ambiguous=True,
|
||||||
|
clarification_question=(
|
||||||
|
result.clarification_question
|
||||||
|
or "I'm not sure I understood. Could you please rephrase?"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
268
backend/app/interrupt_manager.py
Normal file
268
backend/app/interrupt_manager.py
Normal file
@@ -0,0 +1,268 @@
|
|||||||
|
"""Interrupt TTL management -- tracks pending interrupts with auto-expiration.
|
||||||
|
|
||||||
|
Provides both in-memory (InterruptManager) and PostgreSQL-backed
|
||||||
|
(PgInterruptManager) implementations behind a common Protocol.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import TYPE_CHECKING, Protocol
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from psycopg_pool import AsyncConnectionPool
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class InterruptRecord:
|
||||||
|
"""Immutable record of a pending interrupt."""
|
||||||
|
|
||||||
|
interrupt_id: str
|
||||||
|
thread_id: str
|
||||||
|
action: str
|
||||||
|
params: dict
|
||||||
|
created_at: float
|
||||||
|
ttl_seconds: int
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class InterruptStatus:
|
||||||
|
"""Current status of a tracked interrupt."""
|
||||||
|
|
||||||
|
is_expired: bool
|
||||||
|
remaining_seconds: float
|
||||||
|
record: InterruptRecord
|
||||||
|
|
||||||
|
|
||||||
|
class InterruptManagerProtocol(Protocol):
|
||||||
|
"""Protocol for interrupt TTL management."""
|
||||||
|
|
||||||
|
def register(self, thread_id: str, action: str, params: dict) -> InterruptRecord: ...
|
||||||
|
def check_status(self, thread_id: str) -> InterruptStatus | None: ...
|
||||||
|
def resolve(self, thread_id: str) -> None: ...
|
||||||
|
def has_pending(self, thread_id: str) -> bool: ...
|
||||||
|
def generate_retry_prompt(self, expired_record: InterruptRecord) -> dict: ...
|
||||||
|
|
||||||
|
|
||||||
|
def _build_retry_prompt(expired_record: InterruptRecord) -> dict:
|
||||||
|
"""Generate a WebSocket message prompting the user to retry an expired action."""
|
||||||
|
return {
|
||||||
|
"type": "interrupt_expired",
|
||||||
|
"thread_id": expired_record.thread_id,
|
||||||
|
"action": expired_record.action,
|
||||||
|
"message": (
|
||||||
|
f"The approval request for '{expired_record.action}' has expired "
|
||||||
|
f"after {expired_record.ttl_seconds // 60} minutes. "
|
||||||
|
f"Would you like to try again?"
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class InterruptManager:
|
||||||
|
"""In-memory interrupt manager for single-worker development.
|
||||||
|
|
||||||
|
Complements SessionManager -- this tracks interrupt-specific TTL
|
||||||
|
while SessionManager handles session-level TTL.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, ttl_seconds: int = 1800) -> None:
|
||||||
|
self._ttl_seconds = ttl_seconds
|
||||||
|
self._interrupts: dict[str, InterruptRecord] = {}
|
||||||
|
|
||||||
|
def register(
|
||||||
|
self,
|
||||||
|
thread_id: str,
|
||||||
|
action: str,
|
||||||
|
params: dict,
|
||||||
|
) -> InterruptRecord:
|
||||||
|
"""Register a new pending interrupt with TTL tracking."""
|
||||||
|
record = InterruptRecord(
|
||||||
|
interrupt_id=uuid.uuid4().hex,
|
||||||
|
thread_id=thread_id,
|
||||||
|
action=action,
|
||||||
|
params=dict(params),
|
||||||
|
created_at=time.time(),
|
||||||
|
ttl_seconds=self._ttl_seconds,
|
||||||
|
)
|
||||||
|
self._interrupts = {**self._interrupts, thread_id: record}
|
||||||
|
return record
|
||||||
|
|
||||||
|
def check_status(self, thread_id: str) -> InterruptStatus | None:
|
||||||
|
"""Check the TTL status of a pending interrupt."""
|
||||||
|
record = self._interrupts.get(thread_id)
|
||||||
|
if record is None:
|
||||||
|
return None
|
||||||
|
elapsed = time.time() - record.created_at
|
||||||
|
remaining = max(0.0, record.ttl_seconds - elapsed)
|
||||||
|
is_expired = elapsed > record.ttl_seconds
|
||||||
|
return InterruptStatus(
|
||||||
|
is_expired=is_expired,
|
||||||
|
remaining_seconds=remaining,
|
||||||
|
record=record,
|
||||||
|
)
|
||||||
|
|
||||||
|
def resolve(self, thread_id: str) -> None:
|
||||||
|
"""Remove a resolved interrupt from tracking."""
|
||||||
|
self._interrupts = {
|
||||||
|
k: v for k, v in self._interrupts.items() if k != thread_id
|
||||||
|
}
|
||||||
|
|
||||||
|
def cleanup_expired(self) -> tuple[InterruptRecord, ...]:
|
||||||
|
"""Find and remove all expired interrupts. Returns the expired records."""
|
||||||
|
now = time.time()
|
||||||
|
expired: list[InterruptRecord] = []
|
||||||
|
active: dict[str, InterruptRecord] = {}
|
||||||
|
for thread_id, record in self._interrupts.items():
|
||||||
|
if now - record.created_at > record.ttl_seconds:
|
||||||
|
expired.append(record)
|
||||||
|
else:
|
||||||
|
active[thread_id] = record
|
||||||
|
self._interrupts = active
|
||||||
|
return tuple(expired)
|
||||||
|
|
||||||
|
def generate_retry_prompt(self, expired_record: InterruptRecord) -> dict:
|
||||||
|
"""Generate a WebSocket message prompting the user to retry an expired action."""
|
||||||
|
return _build_retry_prompt(expired_record)
|
||||||
|
|
||||||
|
def has_pending(self, thread_id: str) -> bool:
|
||||||
|
"""Check if a thread has a pending (non-expired) interrupt."""
|
||||||
|
status = self.check_status(thread_id)
|
||||||
|
if status is None:
|
||||||
|
return False
|
||||||
|
return not status.is_expired
|
||||||
|
|
||||||
|
|
||||||
|
# Alias for explicit naming
|
||||||
|
InMemoryInterruptManager = InterruptManager
|
||||||
|
|
||||||
|
|
||||||
|
class PgInterruptManager:
    """PostgreSQL-backed interrupt manager for multi-worker production.

    Uses the existing active_interrupts table defined in db.py, so any
    worker process can observe and resolve an interrupt registered by
    another worker.

    NOTE(review): the public sync methods drive their async counterparts via
    ``asyncio.get_event_loop().run_until_complete(...)``.  That pattern emits
    a DeprecationWarning on modern Python when no loop is set, and raises
    RuntimeError if a loop is already running (e.g. when called from inside
    an async handler) -- confirm the intended calling context.
    """

    def __init__(
        self,
        pool: AsyncConnectionPool,
        ttl_seconds: int = 1800,
    ) -> None:
        # Shared psycopg async pool; a connection is borrowed per operation.
        self._pool = pool
        # Interrupts older than this are reported as expired by check_status().
        self._ttl_seconds = ttl_seconds

    def register(
        self,
        thread_id: str,
        action: str,
        params: dict,
    ) -> InterruptRecord:
        """Persist a new pending interrupt for *thread_id* (sync wrapper)."""
        import asyncio

        return asyncio.get_event_loop().run_until_complete(
            self._register(thread_id, action, params)
        )

    async def _register(
        self, thread_id: str, action: str, params: dict
    ) -> InterruptRecord:
        """Insert (or replace) the unresolved interrupt row for a thread."""
        import json

        record = InterruptRecord(
            interrupt_id=uuid.uuid4().hex,
            thread_id=thread_id,
            action=action,
            params=dict(params),  # defensive copy: caller may mutate theirs
            created_at=time.time(),
            ttl_seconds=self._ttl_seconds,
        )
        async with self._pool.connection() as conn:
            # The partial unique constraint on (thread_id) WHERE resolved_at
            # IS NULL allows at most one open interrupt per thread;
            # re-registering overwrites it and restarts the TTL clock
            # (created_at = NOW()).
            await conn.execute(
                """
                INSERT INTO active_interrupts (interrupt_id, thread_id, action, params)
                VALUES (%(iid)s, %(tid)s, %(action)s, %(params)s)
                ON CONFLICT (thread_id) WHERE resolved_at IS NULL
                DO UPDATE SET
                    interrupt_id = %(iid)s,
                    action = %(action)s,
                    params = %(params)s,
                    created_at = NOW(),
                    resolved_at = NULL
                """,
                {
                    "iid": record.interrupt_id,
                    "tid": thread_id,
                    "action": action,
                    "params": json.dumps(params),
                },
            )
        return record

    def check_status(self, thread_id: str) -> InterruptStatus | None:
        """Return the open interrupt's status for *thread_id*, or None (sync wrapper)."""
        import asyncio

        return asyncio.get_event_loop().run_until_complete(
            self._check_status(thread_id)
        )

    async def _check_status(self, thread_id: str) -> InterruptStatus | None:
        """Fetch the newest unresolved interrupt row and compute its expiry."""
        async with self._pool.connection() as conn:
            cursor = await conn.execute(
                """
                SELECT interrupt_id, action, params, created_at
                FROM active_interrupts
                WHERE thread_id = %(tid)s AND resolved_at IS NULL
                ORDER BY created_at DESC LIMIT 1
                """,
                {"tid": thread_id},
            )
            row = await cursor.fetchone()

        if row is None:
            return None

        # NOTE(review): dict-style row access assumes the pool/connection is
        # configured with a dict row factory -- verify against create_pool().
        created_at = row["created_at"].timestamp()
        elapsed = time.time() - created_at
        remaining = max(0.0, self._ttl_seconds - elapsed)
        is_expired = elapsed > self._ttl_seconds

        record = InterruptRecord(
            interrupt_id=row["interrupt_id"],
            thread_id=thread_id,
            action=row["action"],
            # params was stored via json.dumps; anything that does not come
            # back as a dict degrades to an empty params mapping.
            params=row["params"] if isinstance(row["params"], dict) else {},
            created_at=created_at,
            ttl_seconds=self._ttl_seconds,
        )

        return InterruptStatus(
            is_expired=is_expired,
            remaining_seconds=remaining,
            record=record,
        )

    def resolve(self, thread_id: str) -> None:
        """Mark the open interrupt for *thread_id* as resolved (sync wrapper)."""
        import asyncio

        asyncio.get_event_loop().run_until_complete(self._resolve(thread_id))

    async def _resolve(self, thread_id: str) -> None:
        """Stamp resolved_at/resolution on the unresolved row, if one exists."""
        async with self._pool.connection() as conn:
            await conn.execute(
                """
                UPDATE active_interrupts
                SET resolved_at = NOW(), resolution = 'resolved'
                WHERE thread_id = %(tid)s AND resolved_at IS NULL
                """,
                {"tid": thread_id},
            )

    def generate_retry_prompt(self, expired_record: InterruptRecord) -> dict:
        """Generate a WebSocket message prompting the user to retry an expired action."""
        return _build_retry_prompt(expired_record)

    def has_pending(self, thread_id: str) -> bool:
        """Check if a thread has a pending (non-expired) interrupt."""
        status = self.check_status(thread_id)
        if status is None:
            return False
        return not status.is_expired
|
||||||
@@ -31,6 +31,16 @@ def create_llm(settings: Settings) -> BaseChatModel:
|
|||||||
api_key=settings.openai_api_key,
|
api_key=settings.openai_api_key,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if provider == "azure_openai":
|
||||||
|
from langchain_openai import AzureChatOpenAI
|
||||||
|
|
||||||
|
return AzureChatOpenAI(
|
||||||
|
azure_deployment=settings.azure_openai_deployment,
|
||||||
|
azure_endpoint=settings.azure_openai_endpoint,
|
||||||
|
api_key=settings.azure_openai_api_key,
|
||||||
|
api_version=settings.azure_openai_api_version,
|
||||||
|
)
|
||||||
|
|
||||||
if provider == "google":
|
if provider == "google":
|
||||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||||
|
|
||||||
@@ -39,4 +49,7 @@ def create_llm(settings: Settings) -> BaseChatModel:
|
|||||||
google_api_key=settings.google_api_key,
|
google_api_key=settings.google_api_key,
|
||||||
)
|
)
|
||||||
|
|
||||||
raise ValueError(f"Unknown LLM provider: '{provider}'. Use 'anthropic', 'openai', or 'google'.")
|
raise ValueError(
|
||||||
|
f"Unknown LLM provider: '{provider}'. "
|
||||||
|
"Use 'anthropic', 'openai', 'azure_openai', or 'google'."
|
||||||
|
)
|
||||||
|
|||||||
57
backend/app/logging_config.py
Normal file
57
backend/app/logging_config.py
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
"""Structured logging configuration using structlog."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import structlog
|
||||||
|
|
||||||
|
|
||||||
|
def configure_logging(log_format: str = "console") -> None:
    """Configure structlog with stdlib integration.

    Routes all stdlib logging (uvicorn, third-party libraries, ...) through
    a single structlog renderer on stdout.

    Args:
        log_format: "console" for human-readable dev output,
            "json" for machine-parseable production output.
    """
    shared_processors: list[structlog.types.Processor] = [
        structlog.contextvars.merge_contextvars,
        structlog.stdlib.filter_by_level,
        structlog.stdlib.add_logger_name,
        structlog.stdlib.add_log_level,
        # FIX: interpolate %-style positional arguments, e.g.
        # logger.info("Cleaned up %d expired interrupt(s)", n).
        # Without this processor structlog drops the positional args and the
        # rendered message keeps the literal "%d" placeholder.
        structlog.stdlib.PositionalArgumentsFormatter(),
        structlog.processors.TimeStamper(fmt="iso"),
        structlog.processors.StackInfoRenderer(),
        structlog.processors.format_exc_info,
        structlog.processors.UnicodeDecoder(),
    ]

    # Final renderer: JSON for log aggregators, pretty console output for dev.
    if log_format == "json":
        renderer: structlog.types.Processor = structlog.processors.JSONRenderer()
    else:
        renderer = structlog.dev.ConsoleRenderer()

    structlog.configure(
        processors=[
            *shared_processors,
            # Hand the event dict off to the stdlib ProcessorFormatter below.
            structlog.stdlib.ProcessorFormatter.wrap_for_formatter,
        ],
        logger_factory=structlog.stdlib.LoggerFactory(),
        wrapper_class=structlog.stdlib.BoundLogger,
        cache_logger_on_first_use=True,
    )

    formatter = structlog.stdlib.ProcessorFormatter(
        processors=[
            structlog.stdlib.ProcessorFormatter.remove_processors_meta,
            renderer,
        ],
    )

    handler = logging.StreamHandler(sys.stdout)
    handler.setFormatter(formatter)

    # Replace any pre-existing handlers so output is not duplicated.
    root_logger = logging.getLogger()
    root_logger.handlers.clear()
    root_logger.addHandler(handler)
    root_logger.setLevel(logging.INFO)
|
||||||
@@ -2,79 +2,211 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import asyncio
|
||||||
|
import contextlib
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
from fastapi import FastAPI, HTTPException, Query, WebSocket, WebSocketDisconnect
|
||||||
|
from fastapi.exceptions import RequestValidationError
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
|
|
||||||
|
from app.analytics.api import router as analytics_router
|
||||||
|
from app.analytics.event_recorder import PostgresAnalyticsRecorder
|
||||||
|
from app.api_utils import envelope
|
||||||
from app.callbacks import TokenUsageCallbackHandler
|
from app.callbacks import TokenUsageCallbackHandler
|
||||||
from app.config import Settings
|
from app.config import Settings
|
||||||
from app.db import create_checkpointer, create_pool, setup_app_tables
|
from app.conversation_tracker import PostgresConversationTracker
|
||||||
|
from app.db import create_checkpointer, create_pool, run_alembic_migrations
|
||||||
|
from app.escalation import NoOpEscalator, WebhookEscalator
|
||||||
from app.graph import build_graph
|
from app.graph import build_graph
|
||||||
|
from app.intent import LLMIntentClassifier
|
||||||
|
from app.interrupt_manager import InterruptManager
|
||||||
from app.llm import create_llm
|
from app.llm import create_llm
|
||||||
|
from app.logging_config import configure_logging
|
||||||
|
from app.openapi.review_api import router as openapi_router
|
||||||
from app.registry import AgentRegistry
|
from app.registry import AgentRegistry
|
||||||
|
from app.replay.api import router as replay_router
|
||||||
from app.session_manager import SessionManager
|
from app.session_manager import SessionManager
|
||||||
|
from app.ws_context import WebSocketContext
|
||||||
from app.ws_handler import dispatch_message
|
from app.ws_handler import dispatch_message
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
import structlog
|
||||||
|
|
||||||
|
logger = structlog.get_logger()
|
||||||
|
|
||||||
AGENTS_YAML = Path(__file__).parent.parent / "agents.yaml"
|
AGENTS_YAML = Path(__file__).parent.parent / "agents.yaml"
|
||||||
FRONTEND_DIST = Path(__file__).parent.parent.parent / "frontend" / "dist"
|
FRONTEND_DIST = Path(__file__).parent.parent.parent / "frontend" / "dist"
|
||||||
|
|
||||||
|
|
||||||
|
async def _interrupt_cleanup_loop(
|
||||||
|
interrupt_manager: InterruptManager,
|
||||||
|
interval: int = 60,
|
||||||
|
) -> None:
|
||||||
|
"""Periodically remove expired interrupts in the background.
|
||||||
|
|
||||||
|
Runs until cancelled. Catches all exceptions to prevent the task
|
||||||
|
from dying unexpectedly.
|
||||||
|
"""
|
||||||
|
while True:
|
||||||
|
await asyncio.sleep(interval)
|
||||||
|
try:
|
||||||
|
expired = interrupt_manager.cleanup_expired()
|
||||||
|
if expired:
|
||||||
|
logger.info(
|
||||||
|
"Cleaned up %d expired interrupt(s)",
|
||||||
|
len(expired),
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Error during interrupt cleanup")
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||||
settings = Settings()
|
settings = Settings()
|
||||||
|
configure_logging(settings.log_format)
|
||||||
|
|
||||||
pool = await create_pool(settings)
|
pool = await create_pool(settings)
|
||||||
checkpointer = await create_checkpointer(pool)
|
checkpointer = await create_checkpointer(pool)
|
||||||
await setup_app_tables(pool)
|
run_alembic_migrations(settings.database_url)
|
||||||
|
|
||||||
|
# Load agents from template or default YAML
|
||||||
|
if settings.template_name:
|
||||||
|
registry = AgentRegistry.load_template(settings.template_name)
|
||||||
|
else:
|
||||||
registry = AgentRegistry.load(AGENTS_YAML)
|
registry = AgentRegistry.load(AGENTS_YAML)
|
||||||
|
|
||||||
llm = create_llm(settings)
|
llm = create_llm(settings)
|
||||||
graph = build_graph(registry, llm, checkpointer)
|
intent_classifier = LLMIntentClassifier(llm)
|
||||||
|
graph_ctx = build_graph(registry, llm, checkpointer, intent_classifier=intent_classifier)
|
||||||
|
|
||||||
session_manager = SessionManager(
|
session_manager = SessionManager(
|
||||||
session_ttl_seconds=settings.session_ttl_minutes * 60,
|
session_ttl_seconds=settings.session_ttl_minutes * 60,
|
||||||
)
|
)
|
||||||
|
interrupt_manager = InterruptManager(
|
||||||
|
ttl_seconds=settings.interrupt_ttl_minutes * 60,
|
||||||
|
)
|
||||||
|
|
||||||
app.state.graph = graph
|
# Configure escalation
|
||||||
|
if settings.webhook_url:
|
||||||
|
escalator = WebhookEscalator(
|
||||||
|
url=settings.webhook_url,
|
||||||
|
timeout_seconds=settings.webhook_timeout_seconds,
|
||||||
|
max_retries=settings.webhook_max_retries,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
escalator = NoOpEscalator()
|
||||||
|
|
||||||
|
app.state.graph_ctx = graph_ctx
|
||||||
app.state.session_manager = session_manager
|
app.state.session_manager = session_manager
|
||||||
|
app.state.interrupt_manager = interrupt_manager
|
||||||
|
app.state.escalator = escalator
|
||||||
app.state.settings = settings
|
app.state.settings = settings
|
||||||
app.state.pool = pool
|
app.state.pool = pool
|
||||||
|
app.state.analytics_recorder = PostgresAnalyticsRecorder(pool=pool)
|
||||||
|
app.state.conversation_tracker = PostgresConversationTracker()
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"Smart Support started: %d agents loaded, LLM=%s/%s",
|
"Smart Support started: %d agents loaded, LLM=%s/%s, template=%s",
|
||||||
len(registry),
|
len(registry),
|
||||||
settings.llm_provider,
|
settings.llm_provider,
|
||||||
settings.llm_model,
|
settings.llm_model,
|
||||||
|
settings.template_name or "(default)",
|
||||||
|
)
|
||||||
|
|
||||||
|
cleanup_task = asyncio.create_task(
|
||||||
|
_interrupt_cleanup_loop(interrupt_manager),
|
||||||
)
|
)
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
cleanup_task.cancel()
|
||||||
|
with contextlib.suppress(asyncio.CancelledError):
|
||||||
|
await cleanup_task
|
||||||
|
|
||||||
await pool.close()
|
await pool.close()
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(title="Smart Support", version="0.1.0", lifespan=lifespan)
|
# Application version reported by the health endpoint and the OpenAPI docs.
_VERSION = "0.6.0"

app = FastAPI(title="Smart Support", version=_VERSION, lifespan=lifespan)

# REST routers: OpenAPI auto-discovery review, conversation replay, analytics.
app.include_router(openapi_router)
app.include_router(replay_router)
app.include_router(analytics_router)
|
||||||
|
|
||||||
|
|
||||||
|
@app.exception_handler(HTTPException)
async def http_exception_handler(request, exc):  # type: ignore[no-untyped-def]
    """Wrap HTTPException in standard envelope format."""
    body = envelope(None, success=False, error=exc.detail)
    return JSONResponse(status_code=exc.status_code, content=body)
|
||||||
|
|
||||||
|
|
||||||
|
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc):  # type: ignore[no-untyped-def]
    """Wrap validation errors in standard envelope format."""
    body = envelope(None, success=False, error=str(exc))
    return JSONResponse(status_code=422, content=body)
|
||||||
|
|
||||||
|
|
||||||
|
@app.exception_handler(Exception)
async def general_exception_handler(request, exc):  # type: ignore[no-untyped-def]
    """Catch-all handler -- never leak stack traces to clients."""
    logger.exception("Unhandled exception: %s", exc)
    body = envelope(None, success=False, error="Internal server error")
    return JSONResponse(status_code=500, content=body)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/api/v1/health")
def health_check() -> dict:
    """Health check endpoint for load balancers and monitoring."""
    payload = {"status": "ok", "version": _VERSION}
    return payload
|
||||||
|
|
||||||
|
|
||||||
@app.websocket("/ws")
async def websocket_endpoint(
    ws: WebSocket,
    token: str | None = Query(default=None),
) -> None:
    """Main chat WebSocket: authenticates, then pumps messages until disconnect.

    The optional ``token`` query parameter is checked against the configured
    admin API key before the connection is accepted.
    """
    settings = app.state.settings

    # Verify WebSocket token when admin_api_key is configured
    if settings.admin_api_key:
        import secrets as _secrets

        # compare_digest: constant-time comparison to avoid timing attacks.
        if token is None or not _secrets.compare_digest(token, settings.admin_api_key):
            await ws.close(code=4001, reason="Unauthorized")
            return

    await ws.accept()
    # Per-connection token accounting for LLM usage reporting.
    callback_handler = TokenUsageCallbackHandler(model_name=settings.llm_model)

    # Bundle per-connection and shared dependencies for the message dispatcher.
    ws_ctx = WebSocketContext(
        graph_ctx=app.state.graph_ctx,
        session_manager=app.state.session_manager,
        callback_handler=callback_handler,
        interrupt_manager=app.state.interrupt_manager,
        analytics_recorder=app.state.analytics_recorder,
        conversation_tracker=app.state.conversation_tracker,
        pool=app.state.pool,
    )

    try:
        while True:
            raw_data = await ws.receive_text()
            await dispatch_message(ws, ws_ctx, raw_data)
    except WebSocketDisconnect:
        # Normal client hang-up; nothing to clean up beyond logging.
        logger.info("WebSocket client disconnected")
|
||||||
|
|
||||||
|
|||||||
2
backend/app/openapi/__init__.py
Normal file
2
backend/app/openapi/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
# OpenAPI auto-discovery module
|
||||||
|
# Parses OpenAPI specs, classifies endpoints via LLM, generates tools
|
||||||
169
backend/app/openapi/classifier.py
Normal file
169
backend/app/openapi/classifier.py
Normal file
@@ -0,0 +1,169 @@
|
|||||||
|
"""OpenAPI endpoint classifier.
|
||||||
|
|
||||||
|
Classifies endpoints into read/write access types and identifies
|
||||||
|
customer-identifying parameters. Provides a rule-based heuristic
|
||||||
|
classifier and an LLM-backed classifier with heuristic fallback.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
from typing import Protocol
|
||||||
|
|
||||||
|
import structlog
|
||||||
|
|
||||||
|
from app.openapi.models import ClassificationResult, EndpointInfo
|
||||||
|
|
||||||
|
logger = structlog.get_logger()

# HTTP methods that mutate server state -> classified as "write".
_WRITE_METHODS = frozenset({"POST", "PUT", "PATCH", "DELETE"})
# Methods that should pause for human approval; currently identical to
# _WRITE_METHODS but kept separate -- presumably so interrupt policy can
# diverge from write classification later (TODO confirm).
_INTERRUPT_METHODS = frozenset({"POST", "PUT", "PATCH", "DELETE"})

# Parameter names that identify the customer/order context
# (matched with fullmatch, so only exact, case-insensitive names qualify).
_CUSTOMER_PARAM_PATTERNS = re.compile(
    r"(customer_id|order_id|user_id|account_id|client_id|subscriber_id)",
    re.IGNORECASE,
)
|
||||||
|
|
||||||
|
|
||||||
|
class ClassifierProtocol(Protocol):
    """Structural protocol for endpoint classifiers.

    Implementations (HeuristicClassifier, LLMClassifier) receive a tuple of
    parsed endpoints and return one ClassificationResult per endpoint, in
    the same order.
    """

    async def classify(
        self, endpoints: tuple[EndpointInfo, ...]
    ) -> tuple[ClassificationResult, ...]: ...
|
||||||
|
|
||||||
|
|
||||||
|
class HeuristicClassifier:
    """Purely rule-based classifier: the HTTP verb decides everything.

    GET -> read, no interrupt.
    POST/PUT/PATCH/DELETE -> write, needs_interrupt=True.
    """

    async def classify(
        self, endpoints: tuple[EndpointInfo, ...]
    ) -> tuple[ClassificationResult, ...]:
        """Classify every endpoint via HTTP-method heuristics."""
        results = [_classify_one(endpoint) for endpoint in endpoints]
        return tuple(results)
|
||||||
|
|
||||||
|
|
||||||
|
def _classify_one(ep: EndpointInfo) -> ClassificationResult:
    """Heuristically classify one endpoint from its HTTP method alone."""
    is_write = ep.method in _WRITE_METHODS
    return ClassificationResult(
        endpoint=ep,
        access_type="write" if is_write else "read",
        customer_params=_detect_customer_params(ep),
        agent_group="write_agent" if is_write else "read_agent",
        # Fixed, moderate confidence: the verb heuristic is usually right
        # but knows nothing about the endpoint's actual semantics.
        confidence=0.7,
        needs_interrupt=ep.method in _INTERRUPT_METHODS,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _detect_customer_params(ep: EndpointInfo) -> tuple[str, ...]:
    """Collect parameter names that exactly match a known customer/order identifier."""
    matched: list[str] = []
    for param in ep.parameters:
        if _CUSTOMER_PARAM_PATTERNS.fullmatch(param.name):
            matched.append(param.name)
    return tuple(matched)
|
||||||
|
|
||||||
|
|
||||||
|
class LLMClassifier:
    """LLM-backed endpoint classifier with heuristic fallback.

    Uses an LLM to classify endpoints with higher accuracy.
    Falls back to HeuristicClassifier on any LLM error.
    """

    def __init__(self, llm: object) -> None:
        # ``llm`` is duck-typed: only an async ``ainvoke(prompt)`` returning
        # an object with a ``.content`` string is required.
        self._llm = llm
        self._fallback = HeuristicClassifier()

    async def classify(
        self, endpoints: tuple[EndpointInfo, ...]
    ) -> tuple[ClassificationResult, ...]:
        """Classify endpoints using LLM with heuristic fallback."""
        if not endpoints:
            return ()
        try:
            return await self._classify_with_llm(endpoints)
        except Exception:
            # Deliberately broad: any LLM/parsing failure degrades to the
            # rule-based classifier instead of failing the discovery flow.
            logger.warning(
                "LLM classification failed, falling back to heuristic",
                exc_info=True,
            )
            return await self._fallback.classify(endpoints)

    async def _classify_with_llm(
        self, endpoints: tuple[EndpointInfo, ...]
    ) -> tuple[ClassificationResult, ...]:
        """Attempt LLM-based classification.

        Propagates whatever the LLM client or response parsing raises;
        ``classify`` handles the fallback.
        """
        prompt = _build_classification_prompt(endpoints)
        response = await self._llm.ainvoke(prompt)
        parsed = _parse_llm_response(response.content, endpoints)
        return parsed
|
||||||
|
|
||||||
|
|
||||||
|
def _build_classification_prompt(endpoints: tuple[EndpointInfo, ...]) -> str:
|
||||||
|
"""Build a prompt for classifying endpoints."""
|
||||||
|
items = []
|
||||||
|
for i, ep in enumerate(endpoints):
|
||||||
|
items.append(
|
||||||
|
f"{i}: {ep.method} {ep.path} - {ep.summary or ep.operation_id}"
|
||||||
|
)
|
||||||
|
endpoint_list = "\n".join(items)
|
||||||
|
return (
|
||||||
|
"Classify each API endpoint as 'read' or 'write'. "
|
||||||
|
"For each, determine if it needs human interrupt approval, "
|
||||||
|
"identify customer-identifying parameters, and assign an agent_group.\n\n"
|
||||||
|
f"Endpoints:\n{endpoint_list}\n\n"
|
||||||
|
"Respond with a JSON array with one object per endpoint:\n"
|
||||||
|
'[{"access_type": "read|write", "agent_group": "...", '
|
||||||
|
'"needs_interrupt": true|false, "customer_params": [...], "confidence": 0.0-1.0}]'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_llm_response(
    content: str, endpoints: tuple[EndpointInfo, ...]
) -> tuple[ClassificationResult, ...]:
    """Parse LLM JSON response into ClassificationResult instances.

    Raises ValueError if the response cannot be parsed or is mismatched.
    """
    # Extract JSON array from response (greedy DOTALL match tolerates any
    # surrounding prose the model may have added).
    match = re.search(r"\[.*\]", content, re.DOTALL)
    if not match:
        raise ValueError(f"No JSON array found in LLM response: {content!r}")

    items = json.loads(match.group())
    if not isinstance(items, list) or len(items) != len(endpoints):
        raise ValueError(
            f"LLM returned {len(items)} items for {len(endpoints)} endpoints"
        )

    results: list[ClassificationResult] = []
    for ep, item in zip(endpoints, items, strict=True):
        # Sanitize each field: unknown access types degrade to "read",
        # confidence is clamped to [0, 1], blank groups become "support".
        raw_access = item.get("access_type", "read")
        access_type = raw_access if raw_access in {"read", "write"} else "read"
        confidence = max(0.0, min(1.0, float(item.get("confidence", 0.8))))
        raw_group = str(item.get("agent_group", "support"))
        agent_group = raw_group if raw_group.strip() else "support"
        results.append(
            ClassificationResult(
                endpoint=ep,
                access_type=access_type,
                customer_params=tuple(item.get("customer_params", [])),
                agent_group=agent_group,
                confidence=confidence,
                needs_interrupt=bool(item.get("needs_interrupt", False)),
            )
        )
    return tuple(results)
|
||||||
93
backend/app/openapi/fetcher.py
Normal file
93
backend/app/openapi/fetcher.py
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
"""OpenAPI spec fetcher with SSRF protection.
|
||||||
|
|
||||||
|
Fetches OpenAPI spec documents from remote URLs, validates them against
|
||||||
|
SSRF policy, and parses JSON or YAML format automatically.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
from app.openapi.ssrf import DEFAULT_POLICY, SSRFPolicy
|
||||||
|
|
||||||
|
_MAX_SIZE_BYTES = 10 * 1024 * 1024  # 10MB cap on fetched spec documents
|
||||||
|
|
||||||
|
|
||||||
|
async def fetch_spec(url: str, policy: SSRFPolicy = DEFAULT_POLICY) -> dict:
    """Fetch an OpenAPI spec from a URL and return as a dict.

    Auto-detects JSON or YAML format from content-type header or URL extension.
    Enforces a 10MB size limit.

    Raises:
        SSRFError: If the URL is blocked by SSRF policy.
        ValueError: If the response is too large or cannot be parsed.
    """
    # Local import -- presumably to avoid a circular import at module load
    # time (the top-level import only pulls in the policy types); TODO confirm.
    from app.openapi.ssrf import safe_fetch

    response = await safe_fetch(url, policy=policy)
    response.raise_for_status()

    content = response.text
    # NOTE(review): the size limit is enforced AFTER the body has been fully
    # downloaded and decoded, so an oversized response still occupies memory
    # once; streaming enforcement would require support in safe_fetch.
    if len(content.encode("utf-8")) > _MAX_SIZE_BYTES:
        raise ValueError(
            f"Response too large: {len(content.encode('utf-8'))} bytes "
            f"(max {_MAX_SIZE_BYTES} bytes)"
        )

    content_type = response.headers.get("content-type", "")
    return _parse_content(content, content_type, url)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_content(content: str, content_type: str, url: str) -> dict:
    """Parse content as JSON or YAML based on content-type or URL extension."""
    # YAML hint is checked first: a YAML-typed response would also parse as
    # JSON in some cases, so the explicit hint wins.
    if _is_yaml_format(content_type, url):
        return _parse_yaml(content)
    if _is_json_format(content_type, url):
        return _parse_json(content)
    # Fall back: try JSON first, then YAML
    try:
        return _parse_json(content)
    except ValueError:
        return _parse_yaml(content)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_yaml_format(content_type: str, url: str) -> bool:
|
||||||
|
"""Check if the content is YAML format."""
|
||||||
|
yaml_types = {"application/x-yaml", "text/yaml", "application/yaml"}
|
||||||
|
if any(t in content_type for t in yaml_types):
|
||||||
|
return True
|
||||||
|
lower_url = url.lower().split("?")[0]
|
||||||
|
return lower_url.endswith(".yaml") or lower_url.endswith(".yml")
|
||||||
|
|
||||||
|
|
||||||
|
def _is_json_format(content_type: str, url: str) -> bool:
|
||||||
|
"""Check if the content is JSON format."""
|
||||||
|
if "application/json" in content_type:
|
||||||
|
return True
|
||||||
|
lower_url = url.lower().split("?")[0]
|
||||||
|
return lower_url.endswith(".json")
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_json(content: str) -> dict:
|
||||||
|
"""Parse content as JSON, raising ValueError on failure."""
|
||||||
|
try:
|
||||||
|
result = json.loads(content)
|
||||||
|
except json.JSONDecodeError as exc:
|
||||||
|
raise ValueError(f"Invalid JSON: {exc}") from exc
|
||||||
|
if not isinstance(result, dict):
|
||||||
|
raise ValueError(f"Expected a JSON object, got {type(result).__name__}")
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_yaml(content: str) -> dict:
    """Decode *content* as a YAML mapping, raising ValueError on bad input."""
    try:
        parsed = yaml.safe_load(content)
    except yaml.YAMLError as exc:
        raise ValueError(f"Invalid YAML: {exc}") from exc
    if isinstance(parsed, dict):
        return parsed
    raise ValueError(f"Expected a YAML mapping, got {type(parsed).__name__}")
|
||||||
164
backend/app/openapi/generator.py
Normal file
164
backend/app/openapi/generator.py
Normal file
@@ -0,0 +1,164 @@
|
|||||||
|
"""Tool code generator for classified OpenAPI endpoints.
|
||||||
|
|
||||||
|
Generates Python source code for LangChain @tool functions and
|
||||||
|
YAML agent configurations from classification results.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import keyword
|
||||||
|
import re
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
from app.openapi.models import ClassificationResult, EndpointInfo, GeneratedTool, ParameterInfo
|
||||||
|
|
||||||
|
_INDENT = " "
|
||||||
|
|
||||||
|
|
||||||
|
def generate_tool_code(classification: ClassificationResult, base_url: str) -> GeneratedTool:
    """Generate a LangChain @tool function for a classified endpoint.

    Returns a GeneratedTool with the function source code as a string.
    The emitted code assumes ``tool`` and ``httpx`` are in scope wherever
    it is eventually written/executed -- TODO confirm against the consumer.
    """
    ep = classification.endpoint
    func_name = _to_snake_case(ep.operation_id)
    params = _collect_params(ep)
    sig = _build_signature(params, ep.request_body_schema)
    # Docstring falls back through summary -> description -> operation_id.
    docstring = _sanitize_docstring(ep.summary or ep.description or ep.operation_id)
    interrupt_comment = _interrupt_comment(classification)
    http_call = _build_http_call(ep, base_url, params)

    lines = [
        "@tool",
        f"async def {func_name}({sig}) -> str:",
        f'{_INDENT}"""{docstring}"""',
    ]
    if interrupt_comment:
        # Marker line flagging write endpoints that need human approval.
        lines.append(f"{_INDENT}{interrupt_comment}")
    lines += [
        f"{_INDENT}async with httpx.AsyncClient() as client:",
        f"{_INDENT}{_INDENT}{http_call}",
        f"{_INDENT}{_INDENT}return response.text",
    ]

    code = "\n".join(lines)
    return GeneratedTool(
        function_name=func_name,
        endpoint=ep,
        classification=classification,
        code=code,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def generate_agent_yaml(
    classifications: tuple[ClassificationResult, ...],
    base_url: str,
) -> str:
    """Render an agents.yaml document from a set of classification results.

    One agent entry is emitted per distinct agent_group; each entry lists the
    snake_case tool names belonging to that group.  The group's permission is
    taken from the first classification seen for that group.
    """
    if not classifications:
        return yaml.safe_dump({"agents": []})

    grouped: dict[str, dict] = {}
    for result in classifications:
        entry = grouped.get(result.agent_group)
        if entry is None:
            entry = {
                "name": result.agent_group,
                "description": f"Agent for {result.agent_group} operations",
                "permission": "read" if result.access_type == "read" else "write",
                "tools": [],
            }
            grouped[result.agent_group] = entry
        entry["tools"].append(_to_snake_case(result.endpoint.operation_id))

    return yaml.safe_dump({"agents": list(grouped.values())}, sort_keys=False)
|
||||||
|
|
||||||
|
|
||||||
|
# --- Private helpers ---
|
||||||
|
|
||||||
|
|
||||||
|
def _collect_params(ep: EndpointInfo) -> list[ParameterInfo]:
|
||||||
|
"""Return path params first, then query params."""
|
||||||
|
path_params = [p for p in ep.parameters if p.location == "path"]
|
||||||
|
other_params = [p for p in ep.parameters if p.location != "path"]
|
||||||
|
return path_params + other_params
|
||||||
|
|
||||||
|
|
||||||
|
def _build_signature(params: list[ParameterInfo], body_schema: dict | None) -> str:
    """Render the comma-joined parameter list for a generated tool function.

    Required parameters get a bare annotation; optional ones become
    ``name: T | None = None``. A request body adds a trailing
    ``body: dict | None = None``.
    """
    rendered = [
        f"{p.name}: {_schema_type_to_python(p.schema_type)}"
        + ("" if p.required else " | None = None")
        for p in params
    ]
    if body_schema:
        rendered.append("body: dict | None = None")
    return ", ".join(rendered)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_http_call(
|
||||||
|
ep: EndpointInfo, base_url: str, params: list[ParameterInfo]
|
||||||
|
) -> str:
|
||||||
|
"""Build the httpx client call line."""
|
||||||
|
method = ep.method.lower()
|
||||||
|
path = ep.path
|
||||||
|
|
||||||
|
# Replace path parameters with f-string expressions
|
||||||
|
for p in params:
|
||||||
|
if p.location == "path":
|
||||||
|
path = path.replace(f"{{{p.name}}}", f"{{{p.name}}}")
|
||||||
|
|
||||||
|
url_expr = f'f"{base_url}{path}"'
|
||||||
|
|
||||||
|
query_params = [p for p in params if p.location == "query"]
|
||||||
|
extra_args = []
|
||||||
|
if query_params:
|
||||||
|
qp_dict = "{" + ", ".join(f'"{p.name}": {p.name}' for p in query_params) + "}"
|
||||||
|
extra_args.append(f"params={qp_dict}")
|
||||||
|
|
||||||
|
if ep.request_body_schema and method in ("post", "put", "patch"):
|
||||||
|
extra_args.append("json=body")
|
||||||
|
|
||||||
|
args_str = ", ".join([url_expr] + extra_args)
|
||||||
|
return f"response = await client.{method}({args_str})"
|
||||||
|
|
||||||
|
|
||||||
|
def _interrupt_comment(classification: ClassificationResult) -> str:
|
||||||
|
"""Return a comment line if the endpoint requires interrupt/approval."""
|
||||||
|
if classification.needs_interrupt:
|
||||||
|
return "# INTERRUPT: requires human approval before execution"
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
def _schema_type_to_python(schema_type: str) -> str:
|
||||||
|
"""Map OpenAPI schema type to Python type annotation."""
|
||||||
|
mapping = {
|
||||||
|
"string": "str",
|
||||||
|
"integer": "int",
|
||||||
|
"number": "float",
|
||||||
|
"boolean": "bool",
|
||||||
|
"array": "list",
|
||||||
|
"object": "dict",
|
||||||
|
}
|
||||||
|
return mapping.get(schema_type, "str")
|
||||||
|
|
||||||
|
|
||||||
|
def _sanitize_docstring(text: str) -> str:
|
||||||
|
"""Escape triple-quotes and newlines to prevent docstring injection."""
|
||||||
|
return text.replace("\\", "\\\\").replace('"""', r"\"\"\"").replace("\n", " ")
|
||||||
|
|
||||||
|
|
||||||
|
def _to_snake_case(name: str) -> str:
|
||||||
|
"""Convert operationId to a valid snake_case Python identifier."""
|
||||||
|
clean = re.sub(r"[^a-zA-Z0-9]+", "_", name).strip("_")
|
||||||
|
result = clean.lower() or "unnamed_tool"
|
||||||
|
if keyword.iskeyword(result):
|
||||||
|
result = f"{result}_tool"
|
||||||
|
return result
|
||||||
111
backend/app/openapi/importer.py
Normal file
111
backend/app/openapi/importer.py
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
"""Import orchestrator for OpenAPI auto-discovery pipeline.
|
||||||
|
|
||||||
|
Orchestrates: fetch -> validate -> parse -> classify
|
||||||
|
Each stage updates the job status and calls the on_progress callback.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Callable
|
||||||
|
from dataclasses import replace
|
||||||
|
|
||||||
|
import structlog
|
||||||
|
|
||||||
|
from app.openapi.classifier import ClassifierProtocol, HeuristicClassifier
|
||||||
|
from app.openapi.fetcher import fetch_spec
|
||||||
|
from app.openapi.models import ImportJob
|
||||||
|
from app.openapi.parser import parse_endpoints
|
||||||
|
from app.openapi.ssrf import DEFAULT_POLICY, SSRFPolicy
|
||||||
|
from app.openapi.validator import validate_spec
|
||||||
|
|
||||||
|
logger = structlog.get_logger()
|
||||||
|
|
||||||
|
ProgressCallback = Callable[[str, ImportJob], None] | None
|
||||||
|
|
||||||
|
|
||||||
|
class ImportOrchestrator:
    """Orchestrates the full OpenAPI import pipeline.

    Stages:
    1. fetching -- download and parse spec from URL
    2. validating -- check spec structure
    3. parsing -- extract endpoint definitions
    4. classifying -- classify endpoints for agent routing
    5. done / failed

    Any exception raised by a stage transitions the job to "failed";
    start_import never raises.
    """

    def __init__(
        self,
        classifier: ClassifierProtocol | None = None,
        policy: SSRFPolicy = DEFAULT_POLICY,
    ) -> None:
        # Fall back to the heuristic classifier when none is injected;
        # the SSRF policy is forwarded to the fetch stage only.
        self._classifier = classifier or HeuristicClassifier()
        self._policy = policy

    async def start_import(
        self,
        url: str,
        job_id: str,
        on_progress: ProgressCallback,
    ) -> ImportJob:
        """Run the full import pipeline for a spec URL.

        Returns an ImportJob reflecting final status (done or failed).
        on_progress is called with (stage_name, current_job) at each stage.
        Passing None for on_progress is safe.
        """
        job = ImportJob(
            job_id=job_id,
            status="pending",
            spec_url=url,
        )

        try:
            # Stage 1: fetch -- download the spec, subject to the SSRF policy.
            job = _update(job, status="fetching")
            _notify(on_progress, "fetching", job)
            spec_dict = await fetch_spec(url, self._policy)

            # Stage 2: validate -- structural checks; any error aborts the run.
            job = _update(job, status="validating")
            _notify(on_progress, "validating", job)
            errors = validate_spec(spec_dict)
            if errors:
                raise ValueError(f"Invalid spec: {'; '.join(errors)}")

            # Stage 3: parse -- extract endpoint definitions from the spec.
            job = _update(job, status="parsing")
            _notify(on_progress, "parsing", job)
            endpoints = parse_endpoints(spec_dict)

            # Stage 4: classify -- total_endpoints is published before the
            # (potentially slow) classification so progress UIs can show it.
            job = _update(job, status="classifying", total_endpoints=len(endpoints))
            _notify(on_progress, "classifying", job)
            classifications = await self._classifier.classify(endpoints)

            # Done
            job = _update(
                job,
                status="done",
                total_endpoints=len(endpoints),
                classified_count=len(classifications),
            )
            _notify(on_progress, "done", job)
            return job

        except Exception as exc:
            # NOTE(review): positional %-style args assume structlog's
            # stdlib-compatible formatting -- confirm the configured
            # processors render job_id rather than dropping it.
            logger.exception("Import pipeline failed for job %s", job_id)
            # The raw exception text is stored on the job; callers decide
            # how much of it to expose to end users.
            job = _update(job, status="failed", error_message=str(exc))
            _notify(on_progress, "failed", job)
            return job
|
||||||
|
|
||||||
|
|
||||||
|
def _update(job: ImportJob, **kwargs: object) -> ImportJob:
    """Return a new ImportJob with updated fields (immutable update).

    ImportJob is a frozen dataclass, so status transitions are expressed
    as copies via dataclasses.replace rather than in-place mutation.
    """
    return replace(job, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def _notify(callback: ProgressCallback, stage: str, job: ImportJob) -> None:
|
||||||
|
"""Call the progress callback if provided."""
|
||||||
|
if callback is not None:
|
||||||
|
callback(stage, job)
|
||||||
67
backend/app/openapi/models.py
Normal file
67
backend/app/openapi/models.py
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
"""Data models for OpenAPI auto-discovery module.
|
||||||
|
|
||||||
|
Frozen dataclasses for all value objects to ensure immutability.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class ParameterInfo:
    """Describes a single endpoint parameter.

    Mirrors one entry of an OpenAPI operation's ``parameters`` list.
    """

    name: str
    location: str  # "path", "query", "header", "cookie"
    required: bool
    schema_type: str  # "string", "integer", "boolean", etc.
    description: str = ""  # human-readable text from the spec; may be empty
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class EndpointInfo:
    """Describes a single API endpoint (one path + method pair)."""

    path: str  # OpenAPI path template, e.g. "/users/{user_id}"
    method: str  # uppercase: GET, POST, PUT, DELETE, PATCH
    operation_id: str  # from the spec, or generated when absent
    summary: str
    description: str
    parameters: tuple[ParameterInfo, ...] = field(default_factory=tuple)
    request_body_schema: dict | None = None  # resolved JSON schema, if any
    response_schema: dict | None = None  # schema of the first 2xx response
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class ClassificationResult:
    """Result of classifying an endpoint for agent routing."""

    endpoint: EndpointInfo
    access_type: str  # "read" or "write"
    customer_params: tuple[str, ...]  # param names that identify the customer
    agent_group: str  # which agent group handles this endpoint
    confidence: float  # 0.0 to 1.0
    needs_interrupt: bool  # requires human approval before execution
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class ImportJob:
    """Tracks the state of an OpenAPI import job.

    Instances are immutable; pipeline stages produce updated copies.
    """

    job_id: str
    status: str  # "pending", "fetching", "validating", "parsing", "classifying", "done", "failed"
    spec_url: str
    total_endpoints: int = 0  # known once parsing completes
    classified_count: int = 0  # known once classification completes
    error_message: str | None = None  # set only when status == "failed"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class GeneratedTool:
    """A generated LangChain tool from a classified endpoint."""

    function_name: str  # snake_case name of the generated function
    endpoint: EndpointInfo
    classification: ClassificationResult
    code: str  # complete Python source for the tool function
|
||||||
152
backend/app/openapi/parser.py
Normal file
152
backend/app/openapi/parser.py
Normal file
@@ -0,0 +1,152 @@
|
|||||||
|
"""OpenAPI spec endpoint parser.
|
||||||
|
|
||||||
|
Extracts all endpoint definitions from a parsed OpenAPI spec dict,
|
||||||
|
resolving $ref references from components.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
|
||||||
|
from app.openapi.models import EndpointInfo, ParameterInfo
|
||||||
|
|
||||||
|
_HTTP_METHODS = ("get", "post", "put", "patch", "delete", "head", "options")
|
||||||
|
|
||||||
|
|
||||||
|
def parse_endpoints(spec_dict: dict) -> tuple[EndpointInfo, ...]:
    """Extract every operation from a validated spec as EndpointInfo values.

    Non-dict path items are skipped silently; the result is an immutable
    tuple, empty when the spec has no usable paths.
    """
    paths = spec_dict.get("paths", {})
    if not isinstance(paths, dict) or not paths:
        return ()

    collected: list[EndpointInfo] = []
    for route, item in paths.items():
        if not isinstance(item, dict):
            continue
        for verb in _HTTP_METHODS:
            operation = item.get(verb)
            if operation is not None:
                collected.append(
                    _parse_operation(route, verb.upper(), operation, spec_dict)
                )

    return tuple(collected)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_operation(
    path: str,
    method: str,
    operation: dict,
    spec_dict: dict,
) -> EndpointInfo:
    """Build an EndpointInfo from one operation object in the spec.

    A missing operationId is synthesized from the path and method.
    """
    op_id = operation.get("operationId") or _generate_operation_id(path, method)
    return EndpointInfo(
        path=path,
        method=method,
        operation_id=op_id,
        summary=operation.get("summary", ""),
        description=operation.get("description", ""),
        parameters=tuple(_parse_parameters(operation.get("parameters", []), spec_dict)),
        request_body_schema=_parse_request_body(operation.get("requestBody"), spec_dict),
        response_schema=_parse_response_schema(operation.get("responses", {}), spec_dict),
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_parameters(
    params_list: list,
    spec_dict: dict,
) -> list[ParameterInfo]:
    """Convert raw parameter dicts into ParameterInfo values.

    Entries that are not dicts are skipped; missing fields get defaults
    ("query" location, "string" type, not required).
    """
    parsed: list[ParameterInfo] = []
    for raw in params_list:
        if not isinstance(raw, dict):
            continue
        schema = raw.get("schema", {})
        if isinstance(schema, dict):
            schema_type = schema.get("type", "string")
        else:
            schema_type = "string"
        parsed.append(
            ParameterInfo(
                name=raw.get("name", ""),
                location=raw.get("in", "query"),
                required=bool(raw.get("required", False)),
                schema_type=schema_type,
                description=raw.get("description", ""),
            )
        )
    return parsed
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_request_body(request_body: dict | None, spec_dict: dict) -> dict | None:
    """Pick the requestBody schema, resolving any local $ref.

    application/json is preferred; otherwise the first media type that
    carries a schema wins. Returns None for malformed or schema-less bodies.
    """
    if not isinstance(request_body, dict):
        return None
    content = request_body.get("content", {})
    if not isinstance(content, dict):
        return None
    # Try application/json first, then every other media type in order.
    for candidate in ("application/json", *content):
        media = content.get(candidate)
        if isinstance(media, dict) and media.get("schema"):
            return _resolve_ref(media["schema"], spec_dict)
    return None
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_response_schema(responses: dict, spec_dict: dict) -> dict | None:
    """Return the schema of the first 2xx response that carries one.

    Checks 200/201/202/204 in order; within each, application/json is
    preferred over other media types. Returns None when nothing matches.
    """
    if not isinstance(responses, dict):
        return None
    for code in ("200", "201", "202", "204"):
        entry = responses.get(code)
        if not isinstance(entry, dict):
            continue
        content = entry.get("content", {})
        if not isinstance(content, dict):
            continue
        for candidate in ("application/json", *content):
            media = content.get(candidate)
            if isinstance(media, dict) and media.get("schema"):
                return _resolve_ref(media["schema"], spec_dict)
    return None
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_ref(schema: object, spec_dict: dict) -> dict:
|
||||||
|
"""Resolve a $ref to its target schema, or return the schema as-is."""
|
||||||
|
if not isinstance(schema, dict):
|
||||||
|
return {}
|
||||||
|
ref = schema.get("$ref")
|
||||||
|
if not ref:
|
||||||
|
return schema
|
||||||
|
# Only handle local refs like #/components/schemas/Foo
|
||||||
|
if not isinstance(ref, str) or not ref.startswith("#/"):
|
||||||
|
return schema
|
||||||
|
parts = ref.lstrip("#/").split("/")
|
||||||
|
target: object = spec_dict
|
||||||
|
for part in parts:
|
||||||
|
if not isinstance(target, dict):
|
||||||
|
return schema
|
||||||
|
target = target.get(part)
|
||||||
|
return target if isinstance(target, dict) else schema
|
||||||
|
|
||||||
|
|
||||||
|
def _generate_operation_id(path: str, method: str) -> str:
|
||||||
|
"""Generate a snake_case operation_id from path and method."""
|
||||||
|
# Remove path parameters braces and replace / with _
|
||||||
|
clean = re.sub(r"\{[^}]+\}", "by_param", path)
|
||||||
|
clean = re.sub(r"[^a-zA-Z0-9]+", "_", clean).strip("_")
|
||||||
|
return f"{method.lower()}_{clean}" if clean else method.lower()
|
||||||
282
backend/app/openapi/review_api.py
Normal file
282
backend/app/openapi/review_api.py
Normal file
@@ -0,0 +1,282 @@
|
|||||||
|
"""FastAPI router for OpenAPI import review workflow.
|
||||||
|
|
||||||
|
Exposes endpoints for:
|
||||||
|
- Starting an import job (triggers background pipeline)
|
||||||
|
- Querying job status
|
||||||
|
- Reviewing and editing classifications
|
||||||
|
- Approving a job to trigger tool generation
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import re
|
||||||
|
import uuid
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
import structlog
|
||||||
|
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
|
||||||
|
from pydantic import BaseModel, field_validator
|
||||||
|
|
||||||
|
from app.auth import require_admin_api_key
|
||||||
|
from app.openapi.generator import generate_agent_yaml, generate_tool_code
|
||||||
|
from app.openapi.importer import ImportOrchestrator
|
||||||
|
from app.openapi.models import ClassificationResult, ImportJob
|
||||||
|
|
||||||
|
logger = structlog.get_logger()
|
||||||
|
|
||||||
|
router = APIRouter(
|
||||||
|
prefix="/api/v1/openapi",
|
||||||
|
tags=["openapi"],
|
||||||
|
dependencies=[Depends(require_admin_api_key)],
|
||||||
|
)
|
||||||
|
|
||||||
|
# In-memory store: job_id -> job dict, guarded by async lock
|
||||||
|
_job_store: dict[str, dict] = {}
|
||||||
|
_store_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
# Shared orchestrator instance
|
||||||
|
_orchestrator = ImportOrchestrator()
|
||||||
|
|
||||||
|
|
||||||
|
# --- Request / Response schemas ---
|
||||||
|
|
||||||
|
|
||||||
|
class ImportRequest(BaseModel):
    """Body for POST /import: the OpenAPI spec URL to ingest."""

    url: str

    @field_validator("url")
    @classmethod
    def url_must_be_valid(cls, value: str) -> str:
        # Reject blank and non-HTTP(S) URLs up front; deeper SSRF checks
        # happen later in the fetch pipeline.
        stripped = value.strip()
        if not stripped:
            raise ValueError("url must not be empty")
        if not stripped.startswith(("http://", "https://")):
            raise ValueError("url must start with http:// or https://")
        return stripped
|
||||||
|
|
||||||
|
|
||||||
|
class JobResponse(BaseModel):
    """Public view of an import job's status."""

    job_id: str
    status: str  # pipeline stage name, "done", or "failed"
    spec_url: str
    total_endpoints: int = 0
    classified_count: int = 0
    error_message: str | None = None  # populated only for failed jobs
|
||||||
|
|
||||||
|
|
||||||
|
class ClassificationResponse(BaseModel):
    """Public view of one endpoint classification within a job."""

    index: int  # position in the job's classification list
    access_type: str
    needs_interrupt: bool
    agent_group: str
    confidence: float
    customer_params: list[str]
    endpoint: dict  # summary of the classified endpoint (path, method, ...)
|
||||||
|
|
||||||
|
|
||||||
|
class UpdateClassificationRequest(BaseModel):
    """Reviewer-editable fields of a classification."""

    access_type: Literal["read", "write"]
    needs_interrupt: bool
    agent_group: str

    @field_validator("agent_group")
    @classmethod
    def agent_group_must_be_safe(cls, value: str) -> str:
        # The group name ends up in generated YAML/code identifiers, so
        # restrict it to a safe character set.
        if not value.strip() or not re.fullmatch(r"[a-zA-Z0-9_\-]+", value):
            raise ValueError(
                "agent_group must be non-empty and contain only "
                "alphanumeric characters, underscores, or hyphens"
            )
        return value
|
||||||
|
|
||||||
|
|
||||||
|
# --- Helpers ---
|
||||||
|
|
||||||
|
|
||||||
|
def _job_to_response(job: dict) -> dict:
|
||||||
|
return {
|
||||||
|
"job_id": job["job_id"],
|
||||||
|
"status": job["status"],
|
||||||
|
"spec_url": job["spec_url"],
|
||||||
|
"total_endpoints": job.get("total_endpoints", 0),
|
||||||
|
"classified_count": job.get("classified_count", 0),
|
||||||
|
"error_message": job.get("error_message"),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _classification_to_response(idx: int, clf: ClassificationResult) -> dict:
|
||||||
|
ep = clf.endpoint
|
||||||
|
return {
|
||||||
|
"index": idx,
|
||||||
|
"access_type": clf.access_type,
|
||||||
|
"needs_interrupt": clf.needs_interrupt,
|
||||||
|
"agent_group": clf.agent_group,
|
||||||
|
"confidence": clf.confidence,
|
||||||
|
"customer_params": list(clf.customer_params),
|
||||||
|
"endpoint": {
|
||||||
|
"path": ep.path,
|
||||||
|
"method": ep.method,
|
||||||
|
"operation_id": ep.operation_id,
|
||||||
|
"summary": ep.summary,
|
||||||
|
"description": ep.description,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_import(job_id: str, url: str) -> None:
    """Run the import pipeline as a background task.

    Mirrors pipeline progress into the module-level _job_store so the
    status endpoints can serve it. NOTE(review): these store updates do
    not take _store_lock -- safe for a single asyncio event loop, but
    confirm no thread executors touch _job_store concurrently.
    """

    def on_progress(stage: str, result_job: ImportJob) -> None:
        # Guard against the job having been removed from the store.
        if job_id in _job_store:
            _job_store[job_id] = {
                **_job_store[job_id],
                "status": result_job.status,
                "total_endpoints": result_job.total_endpoints,
                "classified_count": result_job.classified_count,
                "error_message": result_job.error_message,
            }

    try:
        result = await _orchestrator.start_import(
            url=url, job_id=job_id, on_progress=on_progress,
        )
        # Final sync: the orchestrator's returned job is authoritative.
        if job_id in _job_store:
            _job_store[job_id] = {
                **_job_store[job_id],
                "status": result.status,
                "total_endpoints": result.total_endpoints,
                "classified_count": result.classified_count,
                "error_message": result.error_message,
            }
    except Exception:
        # The orchestrator is designed not to raise; this is a last-resort
        # guard that records a generic, non-leaking failure message.
        logger.exception("Background import failed for job %s", job_id)
        if job_id in _job_store:
            _job_store[job_id] = {
                **_job_store[job_id],
                "status": "failed",
                "error_message": "Import failed. Please check the URL and try again.",
            }
|
||||||
|
|
||||||
|
|
||||||
|
# --- Endpoints ---
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/import", status_code=202)
async def start_import(
    request: ImportRequest,
    background_tasks: BackgroundTasks,
) -> dict:
    """Start an OpenAPI import job for the given spec URL.

    Registers a pending job in the in-memory store, schedules the
    pipeline as a background task, and returns 202 with the job status
    immediately.
    """
    job_id = str(uuid.uuid4())
    job: dict = {
        "job_id": job_id,
        "status": "pending",
        "spec_url": request.url,
        "total_endpoints": 0,
        "classified_count": 0,
        "error_message": None,
        "classifications": [],
    }
    _job_store[job_id] = job
    # The pipeline runs after the response is sent (FastAPI BackgroundTasks).
    background_tasks.add_task(_run_import, job_id, request.url)
    return _job_to_response(job)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/jobs/{job_id}")
async def get_job(job_id: str) -> dict:
    """Get the status of an import job; 404 when the job is unknown."""
    job = _job_store.get(job_id)
    if job is None:
        raise HTTPException(status_code=404, detail=f"Job '{job_id}' not found")
    return _job_to_response(job)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/jobs/{job_id}/classifications")
async def get_classifications(job_id: str) -> list:
    """Get all classifications for an import job; 404 when unknown.

    Returns an empty list while the pipeline is still running.
    """
    job = _job_store.get(job_id)
    if job is None:
        raise HTTPException(status_code=404, detail=f"Job '{job_id}' not found")
    classifications: list[ClassificationResult] = job.get("classifications", [])
    return [
        _classification_to_response(i, clf)
        for i, clf in enumerate(classifications)
    ]
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/jobs/{job_id}/classifications/{idx}")
async def update_classification(
    job_id: str,
    idx: int,
    request: UpdateClassificationRequest,
) -> dict:
    """Overwrite the reviewer-editable fields of one classification.

    404 when the job or index is unknown. The endpoint, customer_params,
    and confidence are immutable through this API.
    """
    job = _job_store.get(job_id)
    if job is None:
        raise HTTPException(status_code=404, detail=f"Job '{job_id}' not found")

    stored: list[ClassificationResult] = job.get("classifications", [])
    if not 0 <= idx < len(stored):
        raise HTTPException(
            status_code=404,
            detail=f"Classification index {idx} out of range",
        )

    current = stored[idx]
    # Only the request-supplied fields change; everything else is carried
    # over from the existing (frozen) classification.
    revised = ClassificationResult(
        endpoint=current.endpoint,
        access_type=request.access_type,
        customer_params=current.customer_params,
        agent_group=request.agent_group,
        confidence=current.confidence,
        needs_interrupt=request.needs_interrupt,
    )
    replacement = [*stored]
    replacement[idx] = revised
    _job_store[job_id] = {**job, "classifications": replacement}

    return _classification_to_response(idx, revised)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/jobs/{job_id}/approve")
async def approve_job(job_id: str) -> dict:
    """Approve a job's classifications and trigger tool generation.

    Generates Python tool code for each classified endpoint, produces an
    agent YAML configuration snippet, stores both on the job, and marks
    it approved. 404 for unknown jobs, 400 when nothing is classified.
    """
    job = _job_store.get(job_id)
    if job is None:
        raise HTTPException(status_code=404, detail=f"Job '{job_id}' not found")

    classifications: list[ClassificationResult] = job.get("classifications", [])
    if not classifications:
        raise HTTPException(
            status_code=400,
            detail="No classifications to approve. Import must complete first.",
        )

    # Heuristic: the API base URL is the spec URL minus its last path segment.
    spec_base = job["spec_url"].rsplit("/", 1)[0]
    tools_payload: list[dict] = []
    for item in classifications:
        generated = generate_tool_code(item, spec_base)
        tools_payload.append(
            {
                "function_name": generated.function_name,
                "agent_group": item.agent_group,
                "code": generated.code,
            }
        )

    yaml_snippet = generate_agent_yaml(tuple(classifications), spec_base)

    approved = {
        **job,
        "status": "approved",
        "generated_tools": tools_payload,
        "agent_yaml": yaml_snippet,
    }
    _job_store[job_id] = approved

    body = _job_to_response(approved)
    body["generated_tools_count"] = len(tools_payload)
    return body
|
||||||
167
backend/app/openapi/ssrf.py
Normal file
167
backend/app/openapi/ssrf.py
Normal file
@@ -0,0 +1,167 @@
|
|||||||
|
"""SSRF protection module.
|
||||||
|
|
||||||
|
Validates URLs before making external HTTP requests.
|
||||||
|
Blocks private IPs, loopback addresses, link-local, and DNS rebinding attacks.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import ipaddress
|
||||||
|
import socket
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
|
||||||
|
class SSRFError(Exception):
    """Raised when a URL fails SSRF validation (blocked scheme, host, or IP)."""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class SSRFPolicy:
    """Configuration for SSRF protection."""

    allowed_schemes: frozenset[str] = frozenset({"http", "https"})
    allowed_hosts: frozenset[str] | None = None  # None = all public hosts allowed
    max_redirects: int = 5  # each redirect hop is re-validated
    timeout_seconds: float = 30.0  # per-request httpx timeout
||||||
|
|
||||||
|
|
||||||
|
_BLOCKED_NETWORKS = (
|
||||||
|
ipaddress.ip_network("10.0.0.0/8"),
|
||||||
|
ipaddress.ip_network("172.16.0.0/12"),
|
||||||
|
ipaddress.ip_network("192.168.0.0/16"),
|
||||||
|
ipaddress.ip_network("127.0.0.0/8"),
|
||||||
|
ipaddress.ip_network("169.254.0.0/16"),
|
||||||
|
ipaddress.ip_network("0.0.0.0/32"),
|
||||||
|
ipaddress.ip_network("100.64.0.0/10"), # Carrier-Grade NAT
|
||||||
|
ipaddress.ip_network("198.18.0.0/15"), # Benchmarking
|
||||||
|
ipaddress.ip_network("240.0.0.0/4"), # Reserved
|
||||||
|
ipaddress.ip_network("255.255.255.255/32"), # Broadcast
|
||||||
|
# IPv6
|
||||||
|
ipaddress.ip_network("::1/128"),
|
||||||
|
ipaddress.ip_network("fe80::/10"),
|
||||||
|
ipaddress.ip_network("fc00::/7"),
|
||||||
|
ipaddress.ip_network("::/128"),
|
||||||
|
ipaddress.ip_network("::ffff:0:0/96"), # IPv4-mapped IPv6
|
||||||
|
ipaddress.ip_network("2001:db8::/32"), # Documentation
|
||||||
|
)
|
||||||
|
|
||||||
|
DEFAULT_POLICY = SSRFPolicy()
|
||||||
|
|
||||||
|
|
||||||
|
def is_private_ip(ip_str: str) -> bool:
    """Return True when *ip_str* is private/reserved or unparseable.

    Fails closed: anything that does not parse as an IP address counts
    as blocked.
    """
    try:
        candidate = ipaddress.ip_address(ip_str)
    except ValueError:
        return True

    for network in _BLOCKED_NETWORKS:
        if candidate in network:
            return True
    return False
|
||||||
|
|
||||||
|
|
||||||
|
def validate_url(url: str, policy: SSRFPolicy = DEFAULT_POLICY) -> str:
    """Check *url* against the SSRF policy and return it unchanged.

    Raises SSRFError on a disallowed scheme, missing hostname, host not
    on the allowlist, failed DNS resolution, or any resolved address
    falling inside a blocked (private/reserved) network.
    """
    parts = urlparse(url)

    if parts.scheme not in policy.allowed_schemes:
        allowed = ", ".join(sorted(policy.allowed_schemes))
        raise SSRFError(
            f"URL scheme '{parts.scheme}' is not allowed. "
            f"Allowed: {allowed}"
        )

    host = parts.hostname
    if not host:
        raise SSRFError("URL has no hostname")

    if policy.allowed_hosts is not None and host not in policy.allowed_hosts:
        raise SSRFError(f"Host '{host}' is not in the allowed hosts list")

    # Resolve up front so hosts pointing at internal ranges are rejected
    # before any request is made.
    addresses = resolve_hostname(host)
    if not addresses:
        raise SSRFError(f"Could not resolve hostname '{host}'")

    # Every resolved address must be public; one private record blocks all.
    blocked = next((ip for ip in addresses if is_private_ip(ip)), None)
    if blocked is not None:
        raise SSRFError(
            f"Host '{host}' resolves to private/reserved IP {blocked}. "
            "Request blocked for SSRF protection."
        )

    return url
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_hostname(hostname: str) -> list[str]:
    """Return the deduplicated IPs *hostname* resolves to ([] on failure)."""
    try:
        infos = socket.getaddrinfo(
            hostname, None, socket.AF_UNSPEC, socket.SOCK_STREAM
        )
    except socket.gaierror:
        return []
    unique_ips = {info[4][0] for info in infos}
    return list(unique_ips)
|
||||||
|
|
||||||
|
|
||||||
|
async def safe_fetch(
    url: str,
    *,
    policy: SSRFPolicy = DEFAULT_POLICY,
    method: str = "GET",
    headers: dict[str, str] | None = None,
) -> httpx.Response:
    """Fetch a URL with SSRF protection.

    Validates the URL (scheme, host allowlist, DNS resolution, blocked
    networks) before the request, then follows redirects manually so that
    every redirect target URL is re-validated against the same policy.

    NOTE(review): DNS is resolved during validation but httpx resolves
    again when connecting, so a DNS-rebinding window remains between the
    check and the request; closing it would require pinning the
    connection to the validated IP. Confirm whether that risk is
    acceptable here.

    Raises SSRFError on a blocked URL/redirect target, a redirect with no
    target, or when the redirect budget (policy.max_redirects) is spent.
    """
    validate_url(url, policy)

    # Make the request with redirect following disabled so we can check each hop
    async with httpx.AsyncClient(
        follow_redirects=False,
        timeout=httpx.Timeout(policy.timeout_seconds),
    ) as client:
        current_url = url
        # max_redirects + 1 iterations: the initial request plus one per hop.
        for _redirect_count in range(policy.max_redirects + 1):
            response = await client.request(
                method,
                current_url,
                headers=headers,
            )

            if response.is_redirect:
                redirect_url = str(response.next_request.url) if response.next_request else None
                if not redirect_url:
                    raise SSRFError("Redirect with no target URL")
                # Validate the redirect target
                validate_url(redirect_url, policy)
                current_url = redirect_url
                continue

            return response

        raise SSRFError(
            f"Too many redirects (max {policy.max_redirects}). "
            "Possible redirect loop or evasion attempt."
        )
|
||||||
|
|
||||||
|
|
||||||
|
async def safe_fetch_text(
    url: str,
    *,
    policy: SSRFPolicy = DEFAULT_POLICY,
    headers: dict[str, str] | None = None,
) -> str:
    """Fetch a URL and return its body as text, with SSRF protection.

    Delegates to ``safe_fetch`` for validation and redirect handling, then
    raises ``httpx.HTTPStatusError`` for non-2xx responses via
    ``raise_for_status`` before decoding the body.
    """
    response = await safe_fetch(url, policy=policy, headers=headers)
    response.raise_for_status()
    return response.text
|
||||||
51
backend/app/openapi/validator.py
Normal file
51
backend/app/openapi/validator.py
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
"""OpenAPI spec validator.
|
||||||
|
|
||||||
|
Validates an OpenAPI spec dict for required fields and basic structural
|
||||||
|
correctness. Returns a list of human-readable error strings.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
# Accepted OpenAPI version prefixes and mandatory top-level keys.
_SUPPORTED_VERSIONS = ("3.0.", "3.1.")
_REQUIRED_FIELDS = ("openapi", "info", "paths")


def validate_spec(spec_dict: object) -> list[str]:
    """Validate an OpenAPI spec dict.

    Returns a list of error strings. An empty list means the spec is valid.
    Does not raise; all errors are captured and returned.
    """
    if not isinstance(spec_dict, dict):
        return [f"Spec must be a dict, got {type(spec_dict).__name__}"]

    problems: list[str] = []

    # Required top-level keys, reported in canonical order.
    problems.extend(
        f"Missing required field: '{name}'"
        for name in _REQUIRED_FIELDS
        if name not in spec_dict
    )

    # 'openapi' version string, when supplied.
    if "openapi" in spec_dict:
        version = spec_dict["openapi"]
        if not isinstance(version, str):
            problems.append(f"'openapi' must be a string, got {type(version).__name__}")
        elif not version.startswith(_SUPPORTED_VERSIONS):
            problems.append(
                f"Unsupported OpenAPI version '{version}'. "
                f"Supported versions start with: {', '.join(_SUPPORTED_VERSIONS)}"
            )

    # 'info' sub-fields, when 'info' is present and is a dict.
    info = spec_dict.get("info")
    if isinstance(info, dict):
        for required_key in ("title", "version"):
            if required_key not in info:
                problems.append(f"Missing required field in 'info': '{required_key}'")

    # 'paths' must be a mapping when present.
    if "paths" in spec_dict and not isinstance(spec_dict["paths"], dict):
        problems.append(f"'paths' must be a dict, got {type(spec_dict['paths']).__name__}")

    return problems
|
||||||
@@ -100,5 +100,41 @@ class AgentRegistry:
|
|||||||
def get_agents_by_permission(self, permission: str) -> tuple[AgentConfig, ...]:
|
def get_agents_by_permission(self, permission: str) -> tuple[AgentConfig, ...]:
|
||||||
return tuple(a for a in self._agents.values() if a.permission == permission)
|
return tuple(a for a in self._agents.values() if a.permission == permission)
|
||||||
|
|
||||||
|
@classmethod
def load_template(
    cls,
    template_name: str,
    templates_dir: str | Path | None = None,
) -> AgentRegistry:
    """Load agent configurations from a named template.

    Parameters
    ----------
    template_name:
        Stem of a ``<name>.yaml`` file inside *templates_dir*.
    templates_dir:
        Directory to search; defaults to the ``templates`` directory two
        levels above this module.

    Raises
    ------
    FileNotFoundError
        If no ``<template_name>.yaml`` exists; the message lists the
        available template names to aid discovery.
    """
    if templates_dir is None:
        templates_dir = Path(__file__).parent.parent / "templates"
    # Normalize str inputs to Path for the '/' operator below.
    templates_dir = Path(templates_dir)

    yaml_path = templates_dir / f"{template_name}.yaml"
    if not yaml_path.exists():
        available = cls.list_templates(templates_dir)
        raise FileNotFoundError(
            f"Template '{template_name}' not found. "
            f"Available: {', '.join(available) if available else 'none'}"
        )
    return cls.load(yaml_path)
|
||||||
|
|
||||||
|
@classmethod
def list_templates(
    cls,
    templates_dir: str | Path | None = None,
) -> tuple[str, ...]:
    """List available template names (file stems) from the templates directory."""
    base = (
        Path(__file__).parent.parent / "templates"
        if templates_dir is None
        else Path(templates_dir)
    )
    if not base.is_dir():
        return ()
    stems = [candidate.stem for candidate in base.glob("*.yaml")]
    stems.sort()
    return tuple(stems)
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
return len(self._agents)
|
return len(self._agents)
|
||||||
|
|||||||
3
backend/app/replay/__init__.py
Normal file
3
backend/app/replay/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
"""Replay module -- conversation replay API and transformer."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
125
backend/app/replay/api.py
Normal file
125
backend/app/replay/api.py
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
"""Replay API router -- conversation listing and step-by-step replay."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
from typing import TYPE_CHECKING, Annotated
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||||
|
|
||||||
|
from app.api_utils import envelope
|
||||||
|
from app.auth import require_admin_api_key
|
||||||
|
|
||||||
|
_THREAD_ID_PATTERN = re.compile(r"^[a-zA-Z0-9\-_]{1,128}$")
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from psycopg_pool import AsyncConnectionPool
|
||||||
|
|
||||||
|
router = APIRouter(
|
||||||
|
prefix="/api/v1",
|
||||||
|
tags=["replay"],
|
||||||
|
dependencies=[Depends(require_admin_api_key)],
|
||||||
|
)
|
||||||
|
|
||||||
|
_COUNT_CONVERSATIONS_SQL = """
|
||||||
|
SELECT COUNT(*) FROM conversations
|
||||||
|
"""
|
||||||
|
|
||||||
|
_LIST_CONVERSATIONS_SQL = """
|
||||||
|
SELECT thread_id, created_at, last_activity, status, total_tokens, total_cost_usd
|
||||||
|
FROM conversations
|
||||||
|
ORDER BY last_activity DESC
|
||||||
|
LIMIT %(limit)s OFFSET %(offset)s
|
||||||
|
"""
|
||||||
|
|
||||||
|
_GET_CHECKPOINTS_SQL = """
|
||||||
|
SELECT thread_id, checkpoint_id, checkpoint, metadata
|
||||||
|
FROM checkpoints
|
||||||
|
WHERE thread_id = %(thread_id)s
|
||||||
|
ORDER BY checkpoint_id ASC
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
async def get_pool(request: Request) -> AsyncConnectionPool:
    """FastAPI dependency: return the shared connection pool from app state.

    The pool is expected to be set on ``app.state.pool`` at startup; no
    fallback is attempted here.
    """
    return request.app.state.pool
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/conversations")
async def list_conversations(
    request: Request,
    page: Annotated[int, Query(ge=1)] = 1,
    per_page: Annotated[int, Query(ge=1, le=100)] = 20,
) -> dict:
    """List conversations, most recent activity first, with offset pagination.

    Returns an envelope containing the page of rows plus ``total``,
    ``page`` and ``per_page`` for client-side paging.
    """
    pool = await get_pool(request)
    offset = (page - 1) * per_page
    async with pool.connection() as conn:
        # Total row count for the pagination metadata.
        count_cursor = await conn.execute(_COUNT_CONVERSATIONS_SQL)
        count_row = await count_cursor.fetchone()
        total = count_row[0] if count_row else 0

        cursor = await conn.execute(
            _LIST_CONVERSATIONS_SQL,
            {"limit": per_page, "offset": offset},
        )
        rows = await cursor.fetchall()

    # NOTE(review): dict(row) requires a mapping-style row factory (e.g.
    # psycopg's dict_row) on the connection, while count_row[0] above
    # assumes tuple rows -- these two cannot both be right. Confirm the
    # pool's row factory and reconcile.
    return envelope({
        "conversations": [dict(row) for row in rows],
        "total": total,
        "page": page,
        "per_page": per_page,
    })
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/replay/{thread_id}")
async def get_replay(
    thread_id: str,
    request: Request,
    page: Annotated[int, Query(ge=1)] = 1,
    per_page: Annotated[int, Query(ge=1, le=100)] = 20,
) -> dict:
    """Return paginated replay steps for a conversation thread.

    Responds 400 for malformed thread ids and 404 when the thread has no
    checkpoints. Pagination is applied after transforming ALL checkpoints,
    so very long threads pay the full transform cost on every page.
    """
    # Deferred import -- presumably avoids a circular import; confirm.
    from app.replay.transformer import transform_checkpoints

    # Reject ids that cannot be valid before touching the database.
    if not _THREAD_ID_PATTERN.match(thread_id):
        raise HTTPException(status_code=400, detail="Invalid thread_id format")

    pool = await get_pool(request)
    async with pool.connection() as conn:
        cursor = await conn.execute(_GET_CHECKPOINTS_SQL, {"thread_id": thread_id})
        rows = await cursor.fetchall()

    if not rows:
        raise HTTPException(status_code=404, detail="Thread not found")

    # NOTE(review): dict(row) requires a mapping-style row factory (e.g.
    # psycopg dict_row) on the connection -- confirm pool configuration.
    all_steps = transform_checkpoints([dict(row) for row in rows])
    total_steps = len(all_steps)
    # Slice the requested window out of the full step list.
    start = (page - 1) * per_page
    end = start + per_page
    page_steps = all_steps[start:end]

    data = {
        "thread_id": thread_id,
        "total_steps": total_steps,
        "page": page,
        "per_page": per_page,
        "steps": [
            {
                "step": s.step,
                "type": s.type.value,
                "timestamp": s.timestamp,
                "content": s.content,
                "agent": s.agent,
                "tool": s.tool,
                "params": s.params,
                "result": s.result,
                "reasoning": s.reasoning,
                "tokens": s.tokens,
                "duration_ms": s.duration_ms,
            }
            for s in page_steps
        ],
    }
    return envelope(data)
|
||||||
52
backend/app/replay/models.py
Normal file
52
backend/app/replay/models.py
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
"""Value objects for conversation replay."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class StepType(str, Enum):
    """Types of steps in a conversation replay.

    Inherits ``str`` so member values serialize directly in JSON payloads.
    """

    user_message = "user_message"
    supervisor_routing = "supervisor_routing"
    tool_call = "tool_call"
    tool_result = "tool_result"
    agent_response = "agent_response"
    interrupt = "interrupt"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class ReplayStep:
    """A single step in a conversation replay.

    ``params`` and ``result`` are defensively copied in ``__post_init__``
    so the caller's dicts cannot mutate this frozen instance afterwards.
    Note the copies are shallow: nested containers remain shared.
    """

    step: int                    # 1-based sequence number within the thread
    type: StepType
    timestamp: str               # timestamp string taken from the source message
    content: str = ""
    agent: str | None = None     # responding agent name (agent_response steps)
    tool: str | None = None      # tool name (tool_call / tool_result steps)
    params: dict | None = None   # tool-call arguments
    result: dict | None = None   # parsed tool output
    reasoning: str | None = None
    tokens: int | None = None
    duration_ms: int | None = None

    def __post_init__(self) -> None:
        # Store params as a frozen copy to prevent mutation from the outside
        # (object.__setattr__ is needed because the dataclass is frozen).
        if self.params is not None:
            object.__setattr__(self, "params", dict(self.params))
        if self.result is not None:
            object.__setattr__(self, "result", dict(self.result))
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class ReplayPage:
    """A paginated page of replay steps for a conversation thread."""

    thread_id: str
    total_steps: int             # steps in the whole thread, not just this page
    page: int                    # 1-based page number
    per_page: int
    steps: tuple[ReplayStep, ...]
|
||||||
116
backend/app/replay/transformer.py
Normal file
116
backend/app/replay/transformer.py
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
"""Transforms PostgresSaver checkpoint rows into ReplayStep list."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import structlog
|
||||||
|
|
||||||
|
from app.replay.models import ReplayStep, StepType
|
||||||
|
|
||||||
|
logger = structlog.get_logger()
|
||||||
|
|
||||||
|
_EMPTY_TIMESTAMP = "1970-01-01T00:00:00Z"
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_messages(row: dict) -> list[dict]:
|
||||||
|
"""Safely extract messages list from a checkpoint row."""
|
||||||
|
checkpoint = row.get("checkpoint")
|
||||||
|
if not checkpoint or not isinstance(checkpoint, dict):
|
||||||
|
return []
|
||||||
|
channel_values = checkpoint.get("channel_values")
|
||||||
|
if not channel_values or not isinstance(channel_values, dict):
|
||||||
|
return []
|
||||||
|
messages = channel_values.get("messages")
|
||||||
|
if not messages or not isinstance(messages, list):
|
||||||
|
return []
|
||||||
|
return messages
|
||||||
|
|
||||||
|
|
||||||
|
def _step_from_message(msg: dict, step_number: int) -> ReplayStep | None:
    """Convert a single message dict to a ReplayStep.

    Handles LangChain message types ``human``, ``ai`` and ``tool``;
    returns ``None`` for any other type so the caller can skip it.
    """
    msg_type = msg.get("type", "")
    # Fall back to the epoch sentinel when the message has no timestamp.
    timestamp = msg.get("created_at") or _EMPTY_TIMESTAMP
    content = msg.get("content") or ""
    if isinstance(content, list):
        # LangChain may encode content as a list of parts; flatten the
        # textual parts into one space-joined string.
        content = " ".join(
            part.get("text", "") if isinstance(part, dict) else str(part)
            for part in content
        )

    if msg_type == "human":
        return ReplayStep(
            step=step_number,
            type=StepType.user_message,
            timestamp=timestamp,
            content=content,
        )

    if msg_type == "ai":
        tool_calls = msg.get("tool_calls") or []
        if tool_calls:
            # Only the first tool call is surfaced; any further calls in
            # the same AI message are dropped -- TODO confirm intended.
            first = tool_calls[0]
            return ReplayStep(
                step=step_number,
                type=StepType.tool_call,
                timestamp=timestamp,
                content=content,
                tool=first.get("name"),
                params=dict(first.get("args") or {}),
            )
        # AI message without tool calls is a plain agent response.
        return ReplayStep(
            step=step_number,
            type=StepType.agent_response,
            timestamp=timestamp,
            content=content,
            agent=msg.get("name"),
        )

    if msg_type == "tool":
        raw = content
        result: dict | None = None
        try:
            import json

            result = json.loads(raw)
        except (ValueError, TypeError):
            # Non-JSON tool output is preserved verbatim under "raw".
            result = {"raw": raw}
        return ReplayStep(
            step=step_number,
            type=StepType.tool_result,
            timestamp=timestamp,
            tool=msg.get("name"),
            result=result,
        )

    # NOTE(review): structlog loggers usually take keyword event context
    # rather than %-style positional args -- confirm msg_type is rendered.
    logger.debug("Skipping unknown message type: %s", msg_type)
    return None
|
||||||
|
|
||||||
|
|
||||||
|
def transform_checkpoints(rows: list[dict]) -> list[ReplayStep]:
    """Transform a list of checkpoint rows into an ordered list of ReplaySteps.

    Steps are numbered sequentially starting from 1 across all rows.
    Unknown or malformed messages are silently skipped (logged only).
    """
    collected: list[ReplayStep] = []

    for row in rows:
        try:
            row_messages = _extract_messages(row)
        except Exception:  # noqa: BLE001
            logger.exception("Error extracting messages from checkpoint row")
            continue

        for message in row_messages:
            converted: ReplayStep | None
            try:
                # Next step number is derived from how many steps we kept.
                converted = _step_from_message(message, len(collected) + 1)
            except Exception:  # noqa: BLE001
                logger.exception("Error converting message to ReplayStep")
                converted = None
            if converted is not None:
                collected.append(converted)

    return collected
|
||||||
131
backend/app/safety.py
Normal file
131
backend/app/safety.py
Normal file
@@ -0,0 +1,131 @@
|
|||||||
|
"""Safety policy for destructive-action confirmation rules.
|
||||||
|
|
||||||
|
This module makes the confirmation rules explicit and auditable. Every tool
|
||||||
|
call passes through ``requires_confirmation`` before execution to decide
|
||||||
|
whether human-in-the-loop approval is needed.
|
||||||
|
|
||||||
|
Policy summary
|
||||||
|
--------------
|
||||||
|
- ``read`` actions: execute immediately, no confirmation required.
|
||||||
|
- ``write`` actions: require human approval via interrupt gate.
|
||||||
|
- OpenAPI-imported endpoints: use ``needs_interrupt`` from classification.
|
||||||
|
- If both the agent permission AND the endpoint classification agree
|
||||||
|
the action is read-only, it executes without confirmation.
|
||||||
|
|
||||||
|
Multi-intent semantics
|
||||||
|
----------------------
|
||||||
|
When a user message contains multiple intents (e.g. "cancel my order and
|
||||||
|
apply a refund"), the supervisor routes them sequentially. Each action is
|
||||||
|
evaluated independently:
|
||||||
|
- If a write action is blocked by an interrupt, subsequent actions in the
|
||||||
|
same message are paused until the interrupt is resolved.
|
||||||
|
- Read actions that follow a blocked write are also paused (sequential,
|
||||||
|
not best-effort) to preserve causal ordering.
|
||||||
|
- If an interrupt is rejected, the remaining actions are skipped and the
|
||||||
|
agent informs the user.
|
||||||
|
|
||||||
|
MCP error taxonomy
|
||||||
|
------------------
|
||||||
|
Tool execution errors are classified into categories for retry decisions:
|
||||||
|
|
||||||
|
- ``transient``: network timeouts, rate limits, 5xx -- retryable up to 3 times.
|
||||||
|
- ``validation``: bad parameters, 4xx -- not retryable, report to user.
|
||||||
|
- ``auth``: 401/403 -- not retryable, escalate.
|
||||||
|
- ``unknown``: unclassified -- not retryable, log and escalate.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class ConfirmationPolicy:
    """Result of evaluating whether an action needs confirmation."""

    requires_confirmation: bool  # True when a human must approve first
    reason: str                  # human-readable explanation of the decision


def requires_confirmation(
    *,
    agent_permission: Literal["read", "write"],
    needs_interrupt: bool | None = None,
) -> ConfirmationPolicy:
    """Determine whether an action requires human confirmation.

    The OpenAPI classification (*needs_interrupt*) takes precedence when
    provided; otherwise the decision falls back to *agent_permission*
    alone (write -> confirm, read -> execute immediately).
    """
    # An explicit endpoint classification overrides the agent default.
    if needs_interrupt is not None:
        gated = bool(needs_interrupt)
        explanation = (
            "Endpoint classified as requiring human approval"
            if gated
            else "Endpoint classified as safe (no interrupt needed)"
        )
        return ConfirmationPolicy(requires_confirmation=gated, reason=explanation)

    if agent_permission == "write":
        return ConfirmationPolicy(
            requires_confirmation=True,
            reason="Write-permission agent actions require confirmation",
        )
    return ConfirmationPolicy(
        requires_confirmation=False,
        reason="Read-only agent actions execute immediately",
    )
|
||||||
|
|
||||||
|
|
||||||
|
# --- MCP Error Taxonomy ---
|
||||||
|
|
||||||
|
|
||||||
|
MCP_ERROR_CATEGORY = Literal["transient", "validation", "auth", "unknown"]

_TRANSIENT_STATUS_CODES = frozenset({408, 429, 500, 502, 503, 504})
_AUTH_STATUS_CODES = frozenset({401, 403})
_MAX_RETRIES = 3


def classify_mcp_error(
    *,
    status_code: int | None = None,
    error_message: str = "",
) -> MCP_ERROR_CATEGORY:
    """Classify an MCP tool error for retry decisions.

    The HTTP status code, when supplied and recognized, is authoritative;
    otherwise the error message is scanned for indicative keywords.
    Anything that cannot be classified falls through to ``"unknown"``.
    """
    if status_code is not None:
        if status_code in _TRANSIENT_STATUS_CODES:
            return "transient"
        if status_code in _AUTH_STATUS_CODES:
            return "auth"
        if 400 <= status_code < 500:
            return "validation"

    # Keyword-based fallback on the lowercased error message.
    message = error_message.lower()
    keyword_buckets: tuple[tuple[tuple[str, ...], MCP_ERROR_CATEGORY], ...] = (
        (("timeout", "timed out", "rate limit"), "transient"),
        (("unauthorized", "forbidden"), "auth"),
        (("invalid", "missing", "bad request"), "validation"),
    )
    for keywords, category in keyword_buckets:
        if any(keyword in message for keyword in keywords):
            return category

    return "unknown"
|
||||||
|
|
||||||
|
|
||||||
|
def is_retryable(category: MCP_ERROR_CATEGORY) -> bool:
    """Return True only for error categories worth retrying (transient)."""
    retryable_categories = ("transient",)
    return category in retryable_categories
|
||||||
|
|
||||||
|
|
||||||
|
def max_retries() -> int:
    """Return the maximum retry attempts allowed for transient errors."""
    return _MAX_RETRIES
|
||||||
@@ -1,9 +1,18 @@
|
|||||||
"""Session TTL management with sliding window and interrupt extension."""
|
"""Session TTL management with sliding window and interrupt extension.
|
||||||
|
|
||||||
|
Provides both in-memory (SessionManager) and PostgreSQL-backed
|
||||||
|
(PgSessionManager) implementations behind a common Protocol.
|
||||||
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import time
|
import time
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import TYPE_CHECKING, Protocol
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from psycopg_pool import AsyncConnectionPool
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
@@ -13,8 +22,19 @@ class SessionState:
|
|||||||
has_pending_interrupt: bool
|
has_pending_interrupt: bool
|
||||||
|
|
||||||
|
|
||||||
|
class SessionManagerProtocol(Protocol):
    """Structural interface for session TTL management.

    Implemented by both the in-memory manager and the PostgreSQL-backed
    manager; all methods are synchronous from the caller's point of view.
    """

    def touch(self, thread_id: str) -> SessionState: ...
    def is_expired(self, thread_id: str) -> bool: ...
    def extend_for_interrupt(self, thread_id: str) -> SessionState: ...
    def resolve_interrupt(self, thread_id: str) -> SessionState: ...
    def get_state(self, thread_id: str) -> SessionState | None: ...
    def remove(self, thread_id: str) -> None: ...
|
||||||
|
|
||||||
|
|
||||||
class SessionManager:
|
class SessionManager:
|
||||||
"""Manages session TTL with sliding window and interrupt extensions.
|
"""In-memory session manager for single-worker development.
|
||||||
|
|
||||||
- Each message resets the TTL (sliding window).
|
- Each message resets the TTL (sliding window).
|
||||||
- A pending interrupt suspends expiration until resolved.
|
- A pending interrupt suspends expiration until resolved.
|
||||||
@@ -40,10 +60,8 @@ class SessionManager:
|
|||||||
state = self._sessions.get(thread_id)
|
state = self._sessions.get(thread_id)
|
||||||
if state is None:
|
if state is None:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
if state.has_pending_interrupt:
|
if state.has_pending_interrupt:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
elapsed = time.time() - state.last_activity
|
elapsed = time.time() - state.last_activity
|
||||||
return elapsed > self._session_ttl
|
return elapsed > self._session_ttl
|
||||||
|
|
||||||
@@ -52,7 +70,6 @@ class SessionManager:
|
|||||||
existing = self._sessions.get(thread_id)
|
existing = self._sessions.get(thread_id)
|
||||||
if existing is None:
|
if existing is None:
|
||||||
return self.touch(thread_id)
|
return self.touch(thread_id)
|
||||||
|
|
||||||
new_state = SessionState(
|
new_state = SessionState(
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
last_activity=existing.last_activity,
|
last_activity=existing.last_activity,
|
||||||
@@ -76,3 +93,120 @@ class SessionManager:
|
|||||||
|
|
||||||
def remove(self, thread_id: str) -> None:
|
def remove(self, thread_id: str) -> None:
|
||||||
self._sessions = {k: v for k, v in self._sessions.items() if k != thread_id}
|
self._sessions = {k: v for k, v in self._sessions.items() if k != thread_id}
|
||||||
|
|
||||||
|
|
||||||
|
# Alias for explicit naming
|
||||||
|
InMemorySessionManager = SessionManager
|
||||||
|
|
||||||
|
|
||||||
|
class PgSessionManager:
    """PostgreSQL-backed session manager for multi-worker production.

    Persists one row per thread in a ``sessions`` table (upserted on every
    touch), so TTL state is shared across worker processes.

    NOTE(review): the synchronous public methods bridge to async psycopg
    via ``asyncio.get_event_loop().run_until_complete(...)``. That call
    raises RuntimeError when invoked from a thread whose event loop is
    already running (the normal situation inside an async web server),
    and ``get_event_loop()`` is deprecated for this pattern since
    Python 3.10. Confirm these methods are only called from synchronous,
    loop-free contexts, or convert the protocol to async end to end.
    """

    def __init__(
        self,
        pool: AsyncConnectionPool,
        session_ttl_seconds: int = 1800,
    ) -> None:
        # Shared connection pool; a connection is borrowed per operation.
        self._pool = pool
        # Sliding-window TTL in seconds (default 30 minutes).
        self._session_ttl = session_ttl_seconds

    def touch(self, thread_id: str) -> SessionState:
        """Record activity for *thread_id*, creating the session if absent."""
        import asyncio

        return asyncio.get_event_loop().run_until_complete(self._touch(thread_id))

    async def _touch(self, thread_id: str) -> SessionState:
        # Upsert: new sessions start without a pending interrupt; existing
        # ones only get last_activity refreshed (interrupt flag untouched).
        now = datetime.now(timezone.utc)
        async with self._pool.connection() as conn:
            await conn.execute(
                """
                INSERT INTO sessions (thread_id, last_activity, has_pending_interrupt)
                VALUES (%(tid)s, %(now)s, FALSE)
                ON CONFLICT (thread_id) DO UPDATE
                SET last_activity = %(now)s
                """,
                {"tid": thread_id, "now": now},
            )
        # NOTE(review): always reports has_pending_interrupt=False, even
        # when the existing row had a pending interrupt -- confirm callers
        # do not rely on this return value for interrupt state.
        return SessionState(
            thread_id=thread_id,
            last_activity=now.timestamp(),
            has_pending_interrupt=False,
        )

    def is_expired(self, thread_id: str) -> bool:
        """Return True when the session is missing or its TTL has lapsed.

        A pending interrupt suspends expiration entirely.
        """
        state = self.get_state(thread_id)
        if state is None:
            return True
        if state.has_pending_interrupt:
            return False
        elapsed = time.time() - state.last_activity
        return elapsed > self._session_ttl

    def extend_for_interrupt(self, thread_id: str) -> SessionState:
        """Mark the session as having a pending interrupt (suspends TTL)."""
        import asyncio

        return asyncio.get_event_loop().run_until_complete(
            self._set_interrupt(thread_id, True)
        )

    def resolve_interrupt(self, thread_id: str) -> SessionState:
        """Clear the pending-interrupt flag; the TTL countdown resumes."""
        import asyncio

        return asyncio.get_event_loop().run_until_complete(
            self._set_interrupt(thread_id, False)
        )

    async def _set_interrupt(
        self, thread_id: str, has_interrupt: bool
    ) -> SessionState:
        # Upsert the interrupt flag and refresh last_activity atomically.
        now = datetime.now(timezone.utc)
        async with self._pool.connection() as conn:
            await conn.execute(
                """
                INSERT INTO sessions (thread_id, last_activity, has_pending_interrupt)
                VALUES (%(tid)s, %(now)s, %(interrupt)s)
                ON CONFLICT (thread_id) DO UPDATE
                SET last_activity = %(now)s,
                    has_pending_interrupt = %(interrupt)s
                """,
                {"tid": thread_id, "now": now, "interrupt": has_interrupt},
            )
        return SessionState(
            thread_id=thread_id,
            last_activity=now.timestamp(),
            has_pending_interrupt=has_interrupt,
        )

    def get_state(self, thread_id: str) -> SessionState | None:
        """Fetch the current session state, or None when no row exists."""
        import asyncio

        return asyncio.get_event_loop().run_until_complete(
            self._get_state(thread_id)
        )

    async def _get_state(self, thread_id: str) -> SessionState | None:
        async with self._pool.connection() as conn:
            cursor = await conn.execute(
                "SELECT last_activity, has_pending_interrupt FROM sessions WHERE thread_id = %(tid)s",
                {"tid": thread_id},
            )
            row = await cursor.fetchone()
            if row is None:
                return None
            # NOTE(review): row["..."] indexing requires a dict-style row
            # factory on the connection -- confirm pool configuration.
            return SessionState(
                thread_id=thread_id,
                last_activity=row["last_activity"].timestamp(),
                has_pending_interrupt=row["has_pending_interrupt"],
            )

    def remove(self, thread_id: str) -> None:
        """Delete the session row for *thread_id* (no-op when absent)."""
        import asyncio

        asyncio.get_event_loop().run_until_complete(self._remove(thread_id))

    async def _remove(self, thread_id: str) -> None:
        async with self._pool.connection() as conn:
            await conn.execute(
                "DELETE FROM sessions WHERE thread_id = %(tid)s",
                {"tid": thread_id},
            )
|
||||||
|
|||||||
3
backend/app/tools/__init__.py
Normal file
3
backend/app/tools/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
"""Tools package for smart-support backend."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
72
backend/app/tools/error_handler.py
Normal file
72
backend/app/tools/error_handler.py
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
"""Error classification and retry logic for tool calls."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from enum import Enum
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Callable
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
|
||||||
|
class ErrorCategory(Enum):
    """Categories for error classification to guide retry decisions.

    Only RETRYABLE triggers automatic retries in ``with_retry``; every
    other category surfaces immediately to the caller.
    """

    RETRYABLE = "retryable"          # transient server/rate-limit errors (429, 5xx)
    NON_RETRYABLE = "non_retryable"  # permanent failures; retrying cannot help
    AUTH_FAILURE = "auth_failure"    # 401/403 -- credentials problem
    TIMEOUT = "timeout"              # request exceeded its deadline
    NETWORK = "network"              # connection could not be established
||||||
|
|
||||||
|
def classify_error(exc: Exception) -> ErrorCategory:
    """Classify an exception into an ErrorCategory.

    Rules:
    - httpx.TimeoutException -> TIMEOUT
    - httpx.ConnectError -> NETWORK
    - httpx.HTTPStatusError 401/403 -> AUTH_FAILURE
    - httpx.HTTPStatusError 429/500/502/503 -> RETRYABLE
    - anything else -> NON_RETRYABLE
    """
    # Bug fix: this module imports httpx only under TYPE_CHECKING, so the
    # isinstance checks below raised NameError at runtime. Import it for
    # real here (cheap after the first call; httpx is already a runtime
    # dependency of the tool layer).
    import httpx

    if isinstance(exc, httpx.TimeoutException):
        return ErrorCategory.TIMEOUT
    if isinstance(exc, httpx.ConnectError):
        return ErrorCategory.NETWORK
    if isinstance(exc, httpx.HTTPStatusError):
        code = exc.response.status_code
        # Auth failures are never retried; common transient server and
        # rate-limit codes are.
        if code in (401, 403):
            return ErrorCategory.AUTH_FAILURE
        if code in (429, 500, 502, 503):
            return ErrorCategory.RETRYABLE
        return ErrorCategory.NON_RETRYABLE
    return ErrorCategory.NON_RETRYABLE
|
||||||
|
|
||||||
|
|
||||||
|
async def with_retry(
    fn: Callable[..., Any],
    max_retries: int = 3,
    base_delay: float = 1.0,
) -> Any:
    """Execute an async callable with exponential backoff for RETRYABLE errors.

    Only ErrorCategory.RETRYABLE errors trigger retries. All other error
    categories raise immediately after the first attempt.
    """
    final_error: Exception | None = None
    attempt = 0
    while attempt < max_retries:
        attempt += 1
        try:
            return await fn()
        except Exception as exc:
            if classify_error(exc) != ErrorCategory.RETRYABLE:
                raise
            final_error = exc
            # Exponential backoff between attempts: base, 2*base, 4*base...
            if attempt < max_retries:
                await asyncio.sleep(base_delay * (2 ** (attempt - 1)))

    raise final_error  # type: ignore[misc]
|
||||||
30
backend/app/ws_context.py
Normal file
30
backend/app/ws_context.py
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
"""WebSocketContext -- bundles all dependencies needed by dispatch_message."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from app.analytics.event_recorder import AnalyticsRecorder
|
||||||
|
from app.callbacks import TokenUsageCallbackHandler
|
||||||
|
from app.conversation_tracker import ConversationTrackerProtocol
|
||||||
|
from app.graph_context import GraphContext
|
||||||
|
from app.interrupt_manager import InterruptManager
|
||||||
|
from app.session_manager import SessionManager
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class WebSocketContext:
|
||||||
|
"""All dependencies required for WebSocket message processing.
|
||||||
|
|
||||||
|
Replaces the previous 9-parameter function signature in dispatch_message.
|
||||||
|
"""
|
||||||
|
|
||||||
|
graph_ctx: GraphContext
|
||||||
|
session_manager: SessionManager
|
||||||
|
callback_handler: TokenUsageCallbackHandler
|
||||||
|
interrupt_manager: InterruptManager | None = None
|
||||||
|
analytics_recorder: AnalyticsRecorder | None = None
|
||||||
|
conversation_tracker: ConversationTrackerProtocol | None = None
|
||||||
|
pool: Any = None
|
||||||
@@ -3,47 +3,98 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
|
||||||
import re
|
import re
|
||||||
from typing import TYPE_CHECKING, Any
|
import time
|
||||||
|
from collections import defaultdict
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from langchain_core.messages import HumanMessage
|
from langchain_core.messages import HumanMessage
|
||||||
from langgraph.types import Command
|
from langgraph.types import Command
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from fastapi import WebSocket
|
from fastapi import WebSocket
|
||||||
from langgraph.graph.state import CompiledStateGraph
|
|
||||||
|
|
||||||
from app.callbacks import TokenUsageCallbackHandler
|
from app.callbacks import TokenUsageCallbackHandler
|
||||||
|
from app.graph_context import GraphContext
|
||||||
|
from app.interrupt_manager import InterruptManager
|
||||||
from app.session_manager import SessionManager
|
from app.session_manager import SessionManager
|
||||||
|
from app.ws_context import WebSocketContext
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
import structlog
|
||||||
|
|
||||||
|
logger = structlog.get_logger()
|
||||||
|
|
||||||
MAX_MESSAGE_SIZE = 32_768 # 32 KB
|
MAX_MESSAGE_SIZE = 32_768 # 32 KB
|
||||||
MAX_CONTENT_LENGTH = 8_000 # characters
|
MAX_CONTENT_LENGTH = 10_000 # characters
|
||||||
THREAD_ID_PATTERN = re.compile(r"^[a-zA-Z0-9\-_]{1,128}$")
|
THREAD_ID_PATTERN = re.compile(r"^[a-zA-Z0-9\-_]{1,128}$")
|
||||||
|
|
||||||
|
# Rate limiting: max 10 messages per 10-second window, per thread
|
||||||
|
_RATE_LIMIT_MAX = 10
|
||||||
|
_RATE_LIMIT_WINDOW = 10.0
|
||||||
|
_MAX_TRACKED_THREADS = 10_000
|
||||||
|
_thread_timestamps: dict[str, list[float]] = defaultdict(list)
|
||||||
|
|
||||||
|
|
||||||
|
def _evict_stale_threads(cutoff: float) -> None:
|
||||||
|
"""Remove thread entries with no recent timestamps to prevent memory leak."""
|
||||||
|
stale = [tid for tid, ts in _thread_timestamps.items() if not ts or ts[-1] < cutoff]
|
||||||
|
for tid in stale:
|
||||||
|
del _thread_timestamps[tid]
|
||||||
|
|
||||||
|
|
||||||
async def handle_user_message(
|
async def handle_user_message(
|
||||||
ws: WebSocket,
|
ws: WebSocket,
|
||||||
graph: CompiledStateGraph,
|
ctx: GraphContext,
|
||||||
session_manager: SessionManager,
|
session_manager: SessionManager,
|
||||||
callback_handler: TokenUsageCallbackHandler,
|
callback_handler: TokenUsageCallbackHandler,
|
||||||
thread_id: str,
|
thread_id: str,
|
||||||
content: str,
|
content: str,
|
||||||
|
interrupt_manager: InterruptManager | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Process a user message through the graph and stream results back."""
|
"""Process a user message through the graph and stream results back."""
|
||||||
if session_manager.is_expired(thread_id):
|
existing = session_manager.get_state(thread_id)
|
||||||
|
if existing is not None and session_manager.is_expired(thread_id):
|
||||||
msg = "Session expired. Please start a new conversation."
|
msg = "Session expired. Please start a new conversation."
|
||||||
await _send_json(ws, {"type": "error", "message": msg})
|
await _send_json(ws, {"type": "error", "message": msg})
|
||||||
return
|
return
|
||||||
|
|
||||||
session_manager.touch(thread_id)
|
session_manager.touch(thread_id)
|
||||||
|
|
||||||
|
classification = await ctx.classify_intent(content)
|
||||||
|
if classification is not None:
|
||||||
|
logger.info(
|
||||||
|
"Intent classification for thread %s: ambiguous=%s, intents=%s",
|
||||||
|
thread_id,
|
||||||
|
classification.is_ambiguous,
|
||||||
|
[i.agent_name for i in classification.intents],
|
||||||
|
)
|
||||||
|
|
||||||
|
if classification.is_ambiguous and classification.clarification_question:
|
||||||
|
await _send_json(
|
||||||
|
ws,
|
||||||
|
{
|
||||||
|
"type": "clarification",
|
||||||
|
"thread_id": thread_id,
|
||||||
|
"message": classification.clarification_question,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
await _send_json(ws, {"type": "message_complete", "thread_id": thread_id})
|
||||||
|
return
|
||||||
|
|
||||||
config = {"configurable": {"thread_id": thread_id}, "callbacks": [callback_handler]}
|
config = {"configurable": {"thread_id": thread_id}, "callbacks": [callback_handler]}
|
||||||
|
|
||||||
|
if classification and len(classification.intents) > 1:
|
||||||
|
agent_names = [i.agent_name for i in classification.intents]
|
||||||
|
hint = (
|
||||||
|
f"\n[System: This request involves multiple actions. "
|
||||||
|
f"Execute in order: {', '.join(agent_names)}]"
|
||||||
|
)
|
||||||
|
input_msg = {"messages": [HumanMessage(content=content + hint)]}
|
||||||
|
else:
|
||||||
input_msg = {"messages": [HumanMessage(content=content)]}
|
input_msg = {"messages": [HumanMessage(content=content)]}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async for chunk in graph.astream(input_msg, config=config, stream_mode="messages"):
|
async for chunk in ctx.graph.astream(input_msg, config=config, stream_mode="messages"):
|
||||||
msg_chunk, metadata = chunk
|
msg_chunk, metadata = chunk
|
||||||
node = metadata.get("langgraph_node", "")
|
node = metadata.get("langgraph_node", "")
|
||||||
|
|
||||||
@@ -68,10 +119,18 @@ async def handle_user_message(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
state = await graph.aget_state(config)
|
state = await ctx.graph.aget_state(config)
|
||||||
if _has_interrupt(state):
|
if _has_interrupt(state):
|
||||||
interrupt_data = _extract_interrupt(state)
|
interrupt_data = _extract_interrupt(state)
|
||||||
session_manager.extend_for_interrupt(thread_id)
|
session_manager.extend_for_interrupt(thread_id)
|
||||||
|
|
||||||
|
if interrupt_manager is not None:
|
||||||
|
interrupt_manager.register(
|
||||||
|
thread_id=thread_id,
|
||||||
|
action=interrupt_data.get("action", "unknown"),
|
||||||
|
params=interrupt_data.get("params", {}),
|
||||||
|
)
|
||||||
|
|
||||||
await _send_json(
|
await _send_json(
|
||||||
ws,
|
ws,
|
||||||
{
|
{
|
||||||
@@ -91,20 +150,32 @@ async def handle_user_message(
|
|||||||
|
|
||||||
async def handle_interrupt_response(
|
async def handle_interrupt_response(
|
||||||
ws: WebSocket,
|
ws: WebSocket,
|
||||||
graph: CompiledStateGraph,
|
ctx: GraphContext,
|
||||||
session_manager: SessionManager,
|
session_manager: SessionManager,
|
||||||
callback_handler: TokenUsageCallbackHandler,
|
callback_handler: TokenUsageCallbackHandler,
|
||||||
thread_id: str,
|
thread_id: str,
|
||||||
approved: bool,
|
approved: bool,
|
||||||
|
interrupt_manager: InterruptManager | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Resume graph execution after interrupt approval/rejection."""
|
"""Resume graph execution after interrupt approval/rejection."""
|
||||||
|
if interrupt_manager is not None:
|
||||||
|
status = interrupt_manager.check_status(thread_id)
|
||||||
|
if status is not None and status.is_expired:
|
||||||
|
retry_prompt = interrupt_manager.generate_retry_prompt(status.record)
|
||||||
|
interrupt_manager.resolve(thread_id)
|
||||||
|
session_manager.resolve_interrupt(thread_id)
|
||||||
|
await _send_json(ws, retry_prompt)
|
||||||
|
return
|
||||||
|
|
||||||
|
interrupt_manager.resolve(thread_id)
|
||||||
|
|
||||||
session_manager.resolve_interrupt(thread_id)
|
session_manager.resolve_interrupt(thread_id)
|
||||||
session_manager.touch(thread_id)
|
session_manager.touch(thread_id)
|
||||||
|
|
||||||
config = {"configurable": {"thread_id": thread_id}, "callbacks": [callback_handler]}
|
config = {"configurable": {"thread_id": thread_id}, "callbacks": [callback_handler]}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async for chunk in graph.astream(
|
async for chunk in ctx.graph.astream(
|
||||||
Command(resume=approved),
|
Command(resume=approved),
|
||||||
config=config,
|
config=config,
|
||||||
stream_mode="messages",
|
stream_mode="messages",
|
||||||
@@ -132,9 +203,7 @@ async def handle_interrupt_response(
|
|||||||
|
|
||||||
async def dispatch_message(
|
async def dispatch_message(
|
||||||
ws: WebSocket,
|
ws: WebSocket,
|
||||||
graph: CompiledStateGraph,
|
ctx: WebSocketContext,
|
||||||
session_manager: SessionManager,
|
|
||||||
callback_handler: TokenUsageCallbackHandler,
|
|
||||||
raw_data: str,
|
raw_data: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Parse and route an incoming WebSocket message."""
|
"""Parse and route an incoming WebSocket message."""
|
||||||
@@ -144,10 +213,14 @@ async def dispatch_message(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
data = json.loads(raw_data)
|
data = json.loads(raw_data)
|
||||||
except json.JSONDecodeError:
|
except (json.JSONDecodeError, ValueError):
|
||||||
await _send_json(ws, {"type": "error", "message": "Invalid JSON"})
|
await _send_json(ws, {"type": "error", "message": "Invalid JSON"})
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if not isinstance(data, dict):
|
||||||
|
await _send_json(ws, {"type": "error", "message": "Invalid JSON: expected object"})
|
||||||
|
return
|
||||||
|
|
||||||
msg_type = data.get("type")
|
msg_type = data.get("type")
|
||||||
thread_id = data.get("thread_id", "")
|
thread_id = data.get("thread_id", "")
|
||||||
|
|
||||||
@@ -161,24 +234,81 @@ async def dispatch_message(
|
|||||||
|
|
||||||
if msg_type == "message":
|
if msg_type == "message":
|
||||||
content = data.get("content", "")
|
content = data.get("content", "")
|
||||||
if not content:
|
if not content or not content.strip():
|
||||||
await _send_json(ws, {"type": "error", "message": "Missing message content"})
|
await _send_json(ws, {"type": "error", "message": "Missing message content"})
|
||||||
return
|
return
|
||||||
if len(content) > MAX_CONTENT_LENGTH:
|
if len(content) > MAX_CONTENT_LENGTH:
|
||||||
await _send_json(ws, {"type": "error", "message": "Message content too long"})
|
await _send_json(ws, {"type": "error", "message": "Message content too long"})
|
||||||
return
|
return
|
||||||
await handle_user_message(ws, graph, session_manager, callback_handler, thread_id, content)
|
|
||||||
|
# Rate limiting check (per-thread, with bounded memory)
|
||||||
|
now = time.time()
|
||||||
|
cutoff = now - _RATE_LIMIT_WINDOW
|
||||||
|
if len(_thread_timestamps) > _MAX_TRACKED_THREADS:
|
||||||
|
_evict_stale_threads(cutoff)
|
||||||
|
recent = [t for t in _thread_timestamps[thread_id] if t >= cutoff]
|
||||||
|
if len(recent) >= _RATE_LIMIT_MAX:
|
||||||
|
await _send_json(ws, {"type": "error", "message": "Rate limit exceeded"})
|
||||||
|
return
|
||||||
|
_thread_timestamps[thread_id] = [*recent, now]
|
||||||
|
|
||||||
|
await handle_user_message(
|
||||||
|
ws, ctx.graph_ctx, ctx.session_manager, ctx.callback_handler,
|
||||||
|
thread_id, content,
|
||||||
|
interrupt_manager=ctx.interrupt_manager,
|
||||||
|
)
|
||||||
|
await _fire_and_forget_tracking(
|
||||||
|
thread_id=thread_id,
|
||||||
|
pool=ctx.pool,
|
||||||
|
analytics_recorder=ctx.analytics_recorder,
|
||||||
|
conversation_tracker=ctx.conversation_tracker,
|
||||||
|
agent_name=None,
|
||||||
|
tokens=0,
|
||||||
|
cost=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
elif msg_type == "interrupt_response":
|
elif msg_type == "interrupt_response":
|
||||||
approved = data.get("approved", False)
|
approved = data.get("approved", False)
|
||||||
await handle_interrupt_response(
|
await handle_interrupt_response(
|
||||||
ws, graph, session_manager, callback_handler, thread_id, approved
|
ws, ctx.graph_ctx, ctx.session_manager, ctx.callback_handler,
|
||||||
|
thread_id, approved,
|
||||||
|
interrupt_manager=ctx.interrupt_manager,
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
await _send_json(ws, {"type": "error", "message": "Unknown message type"})
|
await _send_json(ws, {"type": "error", "message": "Unknown message type"})
|
||||||
|
|
||||||
|
|
||||||
|
async def _fire_and_forget_tracking(
|
||||||
|
thread_id: str,
|
||||||
|
pool: object,
|
||||||
|
analytics_recorder: object | None,
|
||||||
|
conversation_tracker: object | None,
|
||||||
|
agent_name: str | None,
|
||||||
|
tokens: int,
|
||||||
|
cost: float,
|
||||||
|
) -> None:
|
||||||
|
"""Fire-and-forget analytics/tracking; failures must NOT break chat."""
|
||||||
|
try:
|
||||||
|
if conversation_tracker is not None and pool is not None:
|
||||||
|
await conversation_tracker.ensure_conversation(pool, thread_id)
|
||||||
|
await conversation_tracker.record_turn(pool, thread_id, agent_name, tokens, cost)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Conversation tracker error for thread %s (suppressed)", thread_id)
|
||||||
|
|
||||||
|
try:
|
||||||
|
if analytics_recorder is not None:
|
||||||
|
await analytics_recorder.record(
|
||||||
|
thread_id=thread_id,
|
||||||
|
event_type="message",
|
||||||
|
agent_name=agent_name,
|
||||||
|
tokens_used=tokens,
|
||||||
|
cost_usd=cost,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Analytics recorder error for thread %s (suppressed)", thread_id)
|
||||||
|
|
||||||
|
|
||||||
def _has_interrupt(state: Any) -> bool:
|
def _has_interrupt(state: Any) -> bool:
|
||||||
"""Check if the graph state has a pending interrupt."""
|
"""Check if the graph state has a pending interrupt."""
|
||||||
tasks = getattr(state, "tasks", ())
|
tasks = getattr(state, "tasks", ())
|
||||||
|
|||||||
153
backend/fixtures/demo_data.py
Normal file
153
backend/fixtures/demo_data.py
Normal file
@@ -0,0 +1,153 @@
|
|||||||
|
"""Seed script -- inserts sample conversations and analytics events for demo purposes.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
cd backend
|
||||||
|
python fixtures/demo_data.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||||
|
|
||||||
|
import psycopg
|
||||||
|
|
||||||
|
DATABASE_URL = os.environ.get(
|
||||||
|
"DATABASE_URL",
|
||||||
|
"postgresql://smart_support:dev_password@localhost:5432/smart_support",
|
||||||
|
)
|
||||||
|
|
||||||
|
SAMPLE_CONVERSATIONS = [
|
||||||
|
{
|
||||||
|
"thread_id": "demo-thread-001",
|
||||||
|
"agents_used": ["order_agent"],
|
||||||
|
"turn_count": 3,
|
||||||
|
"total_tokens": 1250,
|
||||||
|
"total_cost_usd": 0.00375,
|
||||||
|
"resolution_type": "resolved",
|
||||||
|
"minutes_ago": 5,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"thread_id": "demo-thread-002",
|
||||||
|
"agents_used": ["order_agent", "refund_agent"],
|
||||||
|
"turn_count": 6,
|
||||||
|
"total_tokens": 3200,
|
||||||
|
"total_cost_usd": 0.0096,
|
||||||
|
"resolution_type": "resolved",
|
||||||
|
"minutes_ago": 30,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"thread_id": "demo-thread-003",
|
||||||
|
"agents_used": ["general_agent"],
|
||||||
|
"turn_count": 2,
|
||||||
|
"total_tokens": 800,
|
||||||
|
"total_cost_usd": 0.0024,
|
||||||
|
"resolution_type": None,
|
||||||
|
"minutes_ago": 60,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"thread_id": "demo-thread-004",
|
||||||
|
"agents_used": ["order_agent", "general_agent"],
|
||||||
|
"turn_count": 8,
|
||||||
|
"total_tokens": 4500,
|
||||||
|
"total_cost_usd": 0.0135,
|
||||||
|
"resolution_type": "escalated",
|
||||||
|
"minutes_ago": 120,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"thread_id": "demo-thread-005",
|
||||||
|
"agents_used": ["refund_agent"],
|
||||||
|
"turn_count": 4,
|
||||||
|
"total_tokens": 2100,
|
||||||
|
"total_cost_usd": 0.0063,
|
||||||
|
"resolution_type": "resolved",
|
||||||
|
"minutes_ago": 240,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
SAMPLE_EVENTS = [
|
||||||
|
{"thread_id": "demo-thread-001", "event_type": "message", "agent_name": "order_agent", "tokens_used": 400, "cost_usd": 0.0012, "success": True},
|
||||||
|
{"thread_id": "demo-thread-001", "event_type": "tool_call", "agent_name": "order_agent", "tool_name": "get_order_status", "tokens_used": 0, "cost_usd": 0.0, "success": True},
|
||||||
|
{"thread_id": "demo-thread-002", "event_type": "message", "agent_name": "order_agent", "tokens_used": 1600, "cost_usd": 0.0048, "success": True},
|
||||||
|
{"thread_id": "demo-thread-002", "event_type": "message", "agent_name": "refund_agent", "tokens_used": 1600, "cost_usd": 0.0048, "success": True},
|
||||||
|
{"thread_id": "demo-thread-002", "event_type": "tool_call", "agent_name": "refund_agent", "tool_name": "process_refund", "tokens_used": 0, "cost_usd": 0.0, "success": True},
|
||||||
|
{"thread_id": "demo-thread-003", "event_type": "message", "agent_name": "general_agent", "tokens_used": 800, "cost_usd": 0.0024, "success": True},
|
||||||
|
{"thread_id": "demo-thread-004", "event_type": "message", "agent_name": "order_agent", "tokens_used": 2000, "cost_usd": 0.006, "success": True},
|
||||||
|
{"thread_id": "demo-thread-004", "event_type": "escalation", "agent_name": "general_agent", "tokens_used": 2500, "cost_usd": 0.0075, "success": False},
|
||||||
|
{"thread_id": "demo-thread-005", "event_type": "message", "agent_name": "refund_agent", "tokens_used": 2100, "cost_usd": 0.0063, "success": True},
|
||||||
|
]
|
||||||
|
|
||||||
|
_INSERT_CONVERSATION = """
|
||||||
|
INSERT INTO conversations
|
||||||
|
(thread_id, started_at, last_activity, turn_count, agents_used,
|
||||||
|
total_tokens, total_cost_usd, resolution_type, ended_at)
|
||||||
|
VALUES
|
||||||
|
(%(thread_id)s, %(started_at)s, %(last_activity)s, %(turn_count)s,
|
||||||
|
%(agents_used)s, %(total_tokens)s, %(total_cost_usd)s,
|
||||||
|
%(resolution_type)s, %(ended_at)s)
|
||||||
|
ON CONFLICT (thread_id) DO NOTHING
|
||||||
|
"""
|
||||||
|
|
||||||
|
_INSERT_EVENT = """
|
||||||
|
INSERT INTO analytics_events
|
||||||
|
(thread_id, event_type, agent_name, tool_name, tokens_used, cost_usd, success)
|
||||||
|
VALUES
|
||||||
|
(%(thread_id)s, %(event_type)s, %(agent_name)s, %(tool_name)s,
|
||||||
|
%(tokens_used)s, %(cost_usd)s, %(success)s)
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
async def seed() -> None:
|
||||||
|
now = datetime.now(tz=timezone.utc)
|
||||||
|
|
||||||
|
async with await psycopg.AsyncConnection.connect(DATABASE_URL) as conn:
|
||||||
|
print("Seeding conversations...")
|
||||||
|
for conv in SAMPLE_CONVERSATIONS:
|
||||||
|
started_at = now - timedelta(minutes=conv["minutes_ago"])
|
||||||
|
last_activity = started_at + timedelta(minutes=conv["turn_count"] * 2)
|
||||||
|
ended_at = last_activity if conv["resolution_type"] else None
|
||||||
|
|
||||||
|
await conn.execute(
|
||||||
|
_INSERT_CONVERSATION,
|
||||||
|
{
|
||||||
|
"thread_id": conv["thread_id"],
|
||||||
|
"started_at": started_at,
|
||||||
|
"last_activity": last_activity,
|
||||||
|
"turn_count": conv["turn_count"],
|
||||||
|
"agents_used": conv["agents_used"],
|
||||||
|
"total_tokens": conv["total_tokens"],
|
||||||
|
"total_cost_usd": conv["total_cost_usd"],
|
||||||
|
"resolution_type": conv["resolution_type"],
|
||||||
|
"ended_at": ended_at,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
print(f" Inserted conversation {conv['thread_id']}")
|
||||||
|
|
||||||
|
print("Seeding analytics events...")
|
||||||
|
for event in SAMPLE_EVENTS:
|
||||||
|
await conn.execute(
|
||||||
|
_INSERT_EVENT,
|
||||||
|
{
|
||||||
|
"thread_id": event["thread_id"],
|
||||||
|
"event_type": event["event_type"],
|
||||||
|
"agent_name": event.get("agent_name"),
|
||||||
|
"tool_name": event.get("tool_name"),
|
||||||
|
"tokens_used": event.get("tokens_used", 0),
|
||||||
|
"cost_usd": event.get("cost_usd", 0.0),
|
||||||
|
"success": event.get("success"),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
print(f" Inserted event {event['event_type']} for {event['thread_id']}")
|
||||||
|
|
||||||
|
await conn.commit()
|
||||||
|
|
||||||
|
print("Done. Demo data seeded successfully.")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(seed())
|
||||||
238
backend/fixtures/sample_openapi.yaml
Normal file
238
backend/fixtures/sample_openapi.yaml
Normal file
@@ -0,0 +1,238 @@
|
|||||||
|
openapi: "3.0.3"
|
||||||
|
info:
|
||||||
|
title: "E-Commerce API"
|
||||||
|
description: "Sample e-commerce API for Smart Support demo."
|
||||||
|
version: "1.0.0"
|
||||||
|
|
||||||
|
servers:
|
||||||
|
- url: "https://api.example-shop.com/v1"
|
||||||
|
description: "Production server"
|
||||||
|
|
||||||
|
paths:
|
||||||
|
/orders/{order_id}:
|
||||||
|
get:
|
||||||
|
operationId: getOrder
|
||||||
|
summary: "Get order details"
|
||||||
|
description: "Retrieve the full details of a specific order."
|
||||||
|
parameters:
|
||||||
|
- name: order_id
|
||||||
|
in: path
|
||||||
|
required: true
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
responses:
|
||||||
|
"200":
|
||||||
|
description: "Order details"
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: "#/components/schemas/Order"
|
||||||
|
|
||||||
|
/orders/{order_id}/cancel:
|
||||||
|
post:
|
||||||
|
operationId: cancelOrder
|
||||||
|
summary: "Cancel an order"
|
||||||
|
description: "Cancel an order that has not yet been shipped."
|
||||||
|
parameters:
|
||||||
|
- name: order_id
|
||||||
|
in: path
|
||||||
|
required: true
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
requestBody:
|
||||||
|
required: true
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
reason:
|
||||||
|
type: string
|
||||||
|
responses:
|
||||||
|
"200":
|
||||||
|
description: "Order cancelled"
|
||||||
|
"400":
|
||||||
|
description: "Order cannot be cancelled (already shipped)"
|
||||||
|
|
||||||
|
/orders/{order_id}/refund:
|
||||||
|
post:
|
||||||
|
operationId: refundOrder
|
||||||
|
summary: "Request a refund"
|
||||||
|
description: "Submit a refund request for a completed order."
|
||||||
|
parameters:
|
||||||
|
- name: order_id
|
||||||
|
in: path
|
||||||
|
required: true
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
requestBody:
|
||||||
|
required: true
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
amount:
|
||||||
|
type: number
|
||||||
|
description: "Refund amount in USD. Leave null for full refund."
|
||||||
|
reason:
|
||||||
|
type: string
|
||||||
|
responses:
|
||||||
|
"200":
|
||||||
|
description: "Refund submitted"
|
||||||
|
"400":
|
||||||
|
description: "Invalid refund request"
|
||||||
|
|
||||||
|
/customers/{customer_id}:
|
||||||
|
get:
|
||||||
|
operationId: getCustomer
|
||||||
|
summary: "Get customer profile"
|
||||||
|
description: "Retrieve customer profile and account information."
|
||||||
|
parameters:
|
||||||
|
- name: customer_id
|
||||||
|
in: path
|
||||||
|
required: true
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
responses:
|
||||||
|
"200":
|
||||||
|
description: "Customer profile"
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: "#/components/schemas/Customer"
|
||||||
|
|
||||||
|
/customers/{customer_id}/orders:
|
||||||
|
get:
|
||||||
|
operationId: listCustomerOrders
|
||||||
|
summary: "List customer orders"
|
||||||
|
description: "Get a paginated list of orders for a customer."
|
||||||
|
parameters:
|
||||||
|
- name: customer_id
|
||||||
|
in: path
|
||||||
|
required: true
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
- name: page
|
||||||
|
in: query
|
||||||
|
schema:
|
||||||
|
type: integer
|
||||||
|
default: 1
|
||||||
|
- name: per_page
|
||||||
|
in: query
|
||||||
|
schema:
|
||||||
|
type: integer
|
||||||
|
default: 20
|
||||||
|
responses:
|
||||||
|
"200":
|
||||||
|
description: "List of orders"
|
||||||
|
|
||||||
|
/products/{product_id}:
|
||||||
|
get:
|
||||||
|
operationId: getProduct
|
||||||
|
summary: "Get product details"
|
||||||
|
description: "Retrieve product information including inventory status."
|
||||||
|
parameters:
|
||||||
|
- name: product_id
|
||||||
|
in: path
|
||||||
|
required: true
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
responses:
|
||||||
|
"200":
|
||||||
|
description: "Product details"
|
||||||
|
|
||||||
|
/support/tickets:
|
||||||
|
post:
|
||||||
|
operationId: createSupportTicket
|
||||||
|
summary: "Create support ticket"
|
||||||
|
description: "Open a new support ticket for a customer issue."
|
||||||
|
requestBody:
|
||||||
|
required: true
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: "#/components/schemas/CreateTicketRequest"
|
||||||
|
responses:
|
||||||
|
"201":
|
||||||
|
description: "Ticket created"
|
||||||
|
|
||||||
|
/support/tickets/{ticket_id}:
|
||||||
|
get:
|
||||||
|
operationId: getSupportTicket
|
||||||
|
summary: "Get support ticket"
|
||||||
|
description: "Retrieve a support ticket and its conversation history."
|
||||||
|
parameters:
|
||||||
|
- name: ticket_id
|
||||||
|
in: path
|
||||||
|
required: true
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
responses:
|
||||||
|
"200":
|
||||||
|
description: "Ticket details"
|
||||||
|
|
||||||
|
components:
|
||||||
|
schemas:
|
||||||
|
Order:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
order_id:
|
||||||
|
type: string
|
||||||
|
customer_id:
|
||||||
|
type: string
|
||||||
|
status:
|
||||||
|
type: string
|
||||||
|
enum: [pending, processing, shipped, delivered, cancelled, refunded]
|
||||||
|
items:
|
||||||
|
type: array
|
||||||
|
items:
|
||||||
|
$ref: "#/components/schemas/OrderItem"
|
||||||
|
total_usd:
|
||||||
|
type: number
|
||||||
|
created_at:
|
||||||
|
type: string
|
||||||
|
format: date-time
|
||||||
|
|
||||||
|
OrderItem:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
product_id:
|
||||||
|
type: string
|
||||||
|
name:
|
||||||
|
type: string
|
||||||
|
quantity:
|
||||||
|
type: integer
|
||||||
|
unit_price_usd:
|
||||||
|
type: number
|
||||||
|
|
||||||
|
Customer:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
customer_id:
|
||||||
|
type: string
|
||||||
|
email:
|
||||||
|
type: string
|
||||||
|
name:
|
||||||
|
type: string
|
||||||
|
tier:
|
||||||
|
type: string
|
||||||
|
enum: [standard, premium, vip]
|
||||||
|
created_at:
|
||||||
|
type: string
|
||||||
|
format: date-time
|
||||||
|
|
||||||
|
CreateTicketRequest:
|
||||||
|
type: object
|
||||||
|
required: [customer_id, subject, description]
|
||||||
|
properties:
|
||||||
|
customer_id:
|
||||||
|
type: string
|
||||||
|
subject:
|
||||||
|
type: string
|
||||||
|
description:
|
||||||
|
type: string
|
||||||
|
priority:
|
||||||
|
type: string
|
||||||
|
enum: [low, medium, high, urgent]
|
||||||
|
default: medium
|
||||||
@@ -6,18 +6,23 @@ requires-python = ">=3.11"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"fastapi>=0.115,<1.0",
|
"fastapi>=0.115,<1.0",
|
||||||
"uvicorn[standard]>=0.34,<1.0",
|
"uvicorn[standard]>=0.34,<1.0",
|
||||||
"langgraph>=0.4,<1.0",
|
"langgraph>=1.0,<2.0",
|
||||||
"langgraph-supervisor>=0.0.12,<1.0",
|
"langgraph-supervisor>=0.0.30,<1.0",
|
||||||
"langgraph-checkpoint-postgres>=3.0,<4.0",
|
"langgraph-checkpoint-postgres>=3.0,<4.0",
|
||||||
"langchain-core>=0.3,<1.0",
|
"langchain>=1.0,<2.0",
|
||||||
"langchain-anthropic>=0.3,<2.0",
|
"langchain-core>=1.0,<2.0",
|
||||||
"langchain-openai>=0.3,<1.0",
|
"langchain-anthropic>=1.0,<2.0",
|
||||||
|
"langchain-openai>=1.0,<2.0",
|
||||||
"langchain-google-genai>=2.1,<3.0",
|
"langchain-google-genai>=2.1,<3.0",
|
||||||
"psycopg[binary,pool]>=3.2,<4.0",
|
"psycopg[binary,pool]>=3.2,<4.0",
|
||||||
"pydantic>=2.10,<3.0",
|
"pydantic>=2.10,<3.0",
|
||||||
"pydantic-settings>=2.7,<3.0",
|
"pydantic-settings>=2.7,<3.0",
|
||||||
"pyyaml>=6.0,<7.0",
|
"pyyaml>=6.0,<7.0",
|
||||||
"python-dotenv>=1.0,<2.0",
|
"python-dotenv>=1.0,<2.0",
|
||||||
|
"httpx>=0.28,<1.0",
|
||||||
|
"openapi-spec-validator>=0.7,<1.0",
|
||||||
|
"alembic>=1.13,<2.0",
|
||||||
|
"structlog>=24.0,<26.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
@@ -27,6 +32,7 @@ dev = [
|
|||||||
"pytest-cov>=6.0,<7.0",
|
"pytest-cov>=6.0,<7.0",
|
||||||
"httpx>=0.28,<1.0",
|
"httpx>=0.28,<1.0",
|
||||||
"ruff>=0.9,<1.0",
|
"ruff>=0.9,<1.0",
|
||||||
|
"pytest-httpx>=0.35,<1.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
|
|||||||
42
backend/templates/e-commerce.yaml
Normal file
42
backend/templates/e-commerce.yaml
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
agents:
|
||||||
|
- name: order_lookup
|
||||||
|
description: "Looks up order status and tracking information. Use for queries about order status, shipping, and delivery."
|
||||||
|
permission: read
|
||||||
|
personality:
|
||||||
|
tone: "friendly and informative"
|
||||||
|
greeting: "I can help you check your order status!"
|
||||||
|
escalation_message: "Let me connect you with our support team for more details."
|
||||||
|
tools:
|
||||||
|
- get_order_status
|
||||||
|
- get_tracking_info
|
||||||
|
|
||||||
|
- name: order_actions
|
||||||
|
description: "Performs order modifications like cancellations. Use when the customer wants to cancel, modify, or change an order."
|
||||||
|
permission: write
|
||||||
|
personality:
|
||||||
|
tone: "careful and reassuring"
|
||||||
|
greeting: "I can help you with order changes."
|
||||||
|
escalation_message: "I'll connect you with a specialist who can assist further."
|
||||||
|
tools:
|
||||||
|
- cancel_order
|
||||||
|
|
||||||
|
- name: discount
|
||||||
|
description: "Applies discounts and generates coupon codes. Use when the customer asks for discounts, promotions, or coupon codes."
|
||||||
|
permission: write
|
||||||
|
personality:
|
||||||
|
tone: "generous and accommodating"
|
||||||
|
greeting: "I can help you with discounts and coupons!"
|
||||||
|
escalation_message: "Let me connect you with our promotions team."
|
||||||
|
tools:
|
||||||
|
- apply_discount
|
||||||
|
- generate_coupon
|
||||||
|
|
||||||
|
- name: fallback
|
||||||
|
description: "Handles general questions, unclear requests, and conversations that don't match other agents."
|
||||||
|
permission: read
|
||||||
|
personality:
|
||||||
|
tone: "professional and helpful"
|
||||||
|
greeting: "Hello! How can I help you today?"
|
||||||
|
escalation_message: "Let me connect you with a human agent who can better assist you."
|
||||||
|
tools:
|
||||||
|
- fallback_respond
|
||||||
31
backend/templates/fintech.yaml
Normal file
31
backend/templates/fintech.yaml
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
agents:
|
||||||
|
- name: transaction_lookup
|
||||||
|
description: "Looks up transaction history, balances, and payment details. Use for queries about transactions and account activity."
|
||||||
|
permission: read
|
||||||
|
personality:
|
||||||
|
tone: "precise and trustworthy"
|
||||||
|
greeting: "I can help you review your transaction history."
|
||||||
|
escalation_message: "Let me connect you with our financial support team."
|
||||||
|
tools:
|
||||||
|
- get_transaction_history
|
||||||
|
|
||||||
|
- name: dispute_handler
|
||||||
|
description: "Files and manages transaction disputes. Use when the customer wants to dispute a charge or check dispute status."
|
||||||
|
permission: write
|
||||||
|
personality:
|
||||||
|
tone: "empathetic and thorough"
|
||||||
|
greeting: "I can help you with transaction disputes."
|
||||||
|
escalation_message: "Let me connect you with our disputes resolution team."
|
||||||
|
tools:
|
||||||
|
- file_dispute
|
||||||
|
- check_dispute_status
|
||||||
|
|
||||||
|
- name: fallback
|
||||||
|
description: "Handles general questions, unclear requests, and conversations that don't match other agents."
|
||||||
|
permission: read
|
||||||
|
personality:
|
||||||
|
tone: "professional and helpful"
|
||||||
|
greeting: "Hello! How can I help you today?"
|
||||||
|
escalation_message: "Let me connect you with a human agent who can better assist you."
|
||||||
|
tools:
|
||||||
|
- fallback_respond
|
||||||
31
backend/templates/saas.yaml
Normal file
31
backend/templates/saas.yaml
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
agents:
|
||||||
|
- name: account_lookup
|
||||||
|
description: "Looks up account status, subscription details, and billing history. Use for queries about account information."
|
||||||
|
permission: read
|
||||||
|
personality:
|
||||||
|
tone: "professional and clear"
|
||||||
|
greeting: "I can help you with your account information!"
|
||||||
|
escalation_message: "Let me connect you with our account support team."
|
||||||
|
tools:
|
||||||
|
- get_account_status
|
||||||
|
|
||||||
|
- name: subscription_management
|
||||||
|
description: "Manages subscription changes like plan upgrades, downgrades, and cancellations. Use when the customer wants to change their subscription."
|
||||||
|
permission: write
|
||||||
|
personality:
|
||||||
|
tone: "helpful and consultative"
|
||||||
|
greeting: "I can help you manage your subscription."
|
||||||
|
escalation_message: "Let me connect you with our billing specialist."
|
||||||
|
tools:
|
||||||
|
- change_plan
|
||||||
|
- cancel_subscription
|
||||||
|
|
||||||
|
- name: fallback
|
||||||
|
description: "Handles general questions, unclear requests, and conversations that don't match other agents."
|
||||||
|
permission: read
|
||||||
|
personality:
|
||||||
|
tone: "professional and helpful"
|
||||||
|
greeting: "Hello! How can I help you today?"
|
||||||
|
escalation_message: "Let me connect you with a human agent who can better assist you."
|
||||||
|
tools:
|
||||||
|
- fallback_respond
|
||||||
@@ -15,6 +15,16 @@ if TYPE_CHECKING:
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def clear_rate_limit_state() -> None:
|
||||||
|
"""Clear module-level rate limit state between tests to prevent leakage."""
|
||||||
|
import app.ws_handler as ws_handler
|
||||||
|
|
||||||
|
ws_handler._thread_timestamps.clear()
|
||||||
|
yield
|
||||||
|
ws_handler._thread_timestamps.clear()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def test_settings() -> Settings:
|
def test_settings() -> Settings:
|
||||||
return Settings(
|
return Settings(
|
||||||
|
|||||||
230
backend/tests/e2e/conftest.py
Normal file
230
backend/tests/e2e/conftest.py
Normal file
@@ -0,0 +1,230 @@
|
|||||||
|
"""E2E test fixtures -- full FastAPI app with mocked LLM and database."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from typing import Any
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||||||
|
from httpx import ASGITransport, AsyncClient
|
||||||
|
|
||||||
|
from app.analytics.api import router as analytics_router
|
||||||
|
from app.callbacks import TokenUsageCallbackHandler
|
||||||
|
from app.graph_context import GraphContext
|
||||||
|
from app.interrupt_manager import InterruptManager
|
||||||
|
from app.openapi.review_api import _job_store, router as openapi_router
|
||||||
|
from app.replay.api import router as replay_router
|
||||||
|
from app.session_manager import SessionManager
|
||||||
|
from app.ws_context import WebSocketContext
|
||||||
|
from app.ws_handler import dispatch_message
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Graph helpers -- simulate LangGraph streaming behaviour
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncIterHelper:
|
||||||
|
"""Make a list behave as an async iterator."""
|
||||||
|
|
||||||
|
def __init__(self, items: list) -> None:
|
||||||
|
self._items = list(items)
|
||||||
|
|
||||||
|
def __aiter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __anext__(self):
|
||||||
|
if not self._items:
|
||||||
|
raise StopAsyncIteration
|
||||||
|
return self._items.pop(0)
|
||||||
|
|
||||||
|
|
||||||
|
def make_chunk(content: str, node: str = "order_lookup") -> tuple:
|
||||||
|
c = MagicMock()
|
||||||
|
c.content = content
|
||||||
|
c.tool_calls = []
|
||||||
|
return (c, {"langgraph_node": node})
|
||||||
|
|
||||||
|
|
||||||
|
def make_tool_chunk(name: str, args: dict, node: str = "order_lookup") -> tuple:
|
||||||
|
c = MagicMock()
|
||||||
|
c.content = ""
|
||||||
|
c.tool_calls = [{"name": name, "args": args}]
|
||||||
|
return (c, {"langgraph_node": node})
|
||||||
|
|
||||||
|
|
||||||
|
def make_state(*, interrupt: bool = False, data: dict | None = None) -> Any:
|
||||||
|
s = MagicMock()
|
||||||
|
if interrupt:
|
||||||
|
obj = MagicMock()
|
||||||
|
obj.value = data or {"action": "cancel_order", "order_id": "1042"}
|
||||||
|
t = MagicMock()
|
||||||
|
t.interrupts = (obj,)
|
||||||
|
s.tasks = (t,)
|
||||||
|
else:
|
||||||
|
s.tasks = ()
|
||||||
|
return s
|
||||||
|
|
||||||
|
|
||||||
|
def make_graph(
|
||||||
|
chunks: list | None = None,
|
||||||
|
state: Any = None,
|
||||||
|
resume_chunks: list | None = None,
|
||||||
|
) -> MagicMock:
|
||||||
|
"""Build a mock LangGraph CompiledStateGraph."""
|
||||||
|
g = MagicMock()
|
||||||
|
|
||||||
|
if state is None:
|
||||||
|
state = make_state()
|
||||||
|
|
||||||
|
streams = [chunks or [], resume_chunks or []]
|
||||||
|
idx = {"n": 0}
|
||||||
|
|
||||||
|
def astream_side_effect(*a, **kw):
|
||||||
|
i = min(idx["n"], len(streams) - 1)
|
||||||
|
idx["n"] += 1
|
||||||
|
return AsyncIterHelper(list(streams[i]))
|
||||||
|
|
||||||
|
g.astream = MagicMock(side_effect=astream_side_effect)
|
||||||
|
g.aget_state = AsyncMock(return_value=state)
|
||||||
|
return g
|
||||||
|
|
||||||
|
|
||||||
|
def make_graph_ctx(graph: MagicMock | None = None) -> GraphContext:
|
||||||
|
"""Build a GraphContext wrapping a mock graph."""
|
||||||
|
g = graph or make_graph()
|
||||||
|
registry = MagicMock()
|
||||||
|
registry.list_agents = MagicMock(return_value=())
|
||||||
|
return GraphContext(graph=g, registry=registry, intent_classifier=None)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Fake database pool
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class FakeCursor:
|
||||||
|
"""Minimal async cursor returning pre-configured rows."""
|
||||||
|
|
||||||
|
def __init__(self, rows: list[dict]) -> None:
|
||||||
|
self._rows = rows
|
||||||
|
|
||||||
|
async def fetchall(self) -> list[dict]:
|
||||||
|
return self._rows
|
||||||
|
|
||||||
|
async def fetchone(self) -> tuple | dict | None:
|
||||||
|
return self._rows[0] if self._rows else None
|
||||||
|
|
||||||
|
|
||||||
|
class FakeConnection:
|
||||||
|
"""Fake async connection that returns a FakeCursor."""
|
||||||
|
|
||||||
|
def __init__(self, rows: list[dict]) -> None:
|
||||||
|
self._rows = rows
|
||||||
|
|
||||||
|
async def execute(self, query: str, params: dict | None = None) -> FakeCursor:
|
||||||
|
return FakeCursor(self._rows)
|
||||||
|
|
||||||
|
|
||||||
|
class FakePool:
|
||||||
|
"""Minimal pool that yields a fake connection."""
|
||||||
|
|
||||||
|
def __init__(self, rows: list[dict] | None = None) -> None:
|
||||||
|
self._rows = rows or []
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def connection(self):
|
||||||
|
yield FakeConnection(self._rows)
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# App factory
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def create_e2e_app(
|
||||||
|
graph: MagicMock | None = None,
|
||||||
|
pool: FakePool | None = None,
|
||||||
|
session_ttl: int = 3600,
|
||||||
|
interrupt_ttl: int = 1800,
|
||||||
|
) -> FastAPI:
|
||||||
|
"""Create a FastAPI app wired with mocked dependencies for E2E testing."""
|
||||||
|
g = graph or make_graph()
|
||||||
|
graph_ctx = make_graph_ctx(g)
|
||||||
|
p = pool or FakePool()
|
||||||
|
sm = SessionManager(session_ttl_seconds=session_ttl)
|
||||||
|
im = InterruptManager(ttl_seconds=interrupt_ttl)
|
||||||
|
|
||||||
|
app = FastAPI(title="Smart Support E2E Test")
|
||||||
|
app.include_router(openapi_router)
|
||||||
|
app.include_router(replay_router)
|
||||||
|
app.include_router(analytics_router)
|
||||||
|
|
||||||
|
app.state.graph_ctx = graph_ctx
|
||||||
|
app.state.session_manager = sm
|
||||||
|
app.state.interrupt_manager = im
|
||||||
|
app.state.pool = p
|
||||||
|
app.state.settings = MagicMock(llm_model="test-model")
|
||||||
|
app.state.analytics_recorder = AsyncMock()
|
||||||
|
app.state.conversation_tracker = AsyncMock()
|
||||||
|
|
||||||
|
@app.get("/api/v1/health")
|
||||||
|
def health_check() -> dict:
|
||||||
|
return {"status": "ok", "version": "test"}
|
||||||
|
|
||||||
|
@app.websocket("/ws")
|
||||||
|
async def websocket_endpoint(ws: WebSocket) -> None:
|
||||||
|
await ws.accept()
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
raw_data = await ws.receive_text()
|
||||||
|
ws_ctx = WebSocketContext(
|
||||||
|
graph_ctx=app.state.graph_ctx,
|
||||||
|
session_manager=app.state.session_manager,
|
||||||
|
callback_handler=TokenUsageCallbackHandler(model_name="test-model"),
|
||||||
|
interrupt_manager=app.state.interrupt_manager,
|
||||||
|
analytics_recorder=app.state.analytics_recorder,
|
||||||
|
conversation_tracker=app.state.conversation_tracker,
|
||||||
|
pool=app.state.pool,
|
||||||
|
)
|
||||||
|
await dispatch_message(ws, ws_ctx, raw_data)
|
||||||
|
except WebSocketDisconnect:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def e2e_graph():
|
||||||
|
"""Default graph fixture -- returns tokens and message_complete."""
|
||||||
|
return make_graph(
|
||||||
|
chunks=[make_chunk("Order 1042 is "), make_chunk("shipped.")]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def e2e_app(e2e_graph):
|
||||||
|
"""Default E2E app fixture."""
|
||||||
|
return create_e2e_app(graph=e2e_graph)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def e2e_client(e2e_app):
|
||||||
|
"""Async HTTP client for E2E tests."""
|
||||||
|
transport = ASGITransport(app=e2e_app)
|
||||||
|
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||||
|
yield client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def clear_openapi_job_store():
|
||||||
|
"""Clear the in-memory job store between tests."""
|
||||||
|
_job_store.clear()
|
||||||
|
yield
|
||||||
|
_job_store.clear()
|
||||||
384
backend/tests/e2e/test_chat_flows.py
Normal file
384
backend/tests/e2e/test_chat_flows.py
Normal file
@@ -0,0 +1,384 @@
|
|||||||
|
"""E2E tests for critical chat user flows (flows 1-4).
|
||||||
|
|
||||||
|
Flow 1: Happy path -- query order, get answer
|
||||||
|
Flow 2: Approval flow -- write operation, interrupt, approve, execute
|
||||||
|
Flow 3: Rejection flow -- write operation, interrupt, reject, no execution
|
||||||
|
Flow 4: Multi-turn context -- sequential messages in same session
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from starlette.testclient import TestClient
|
||||||
|
|
||||||
|
from tests.e2e.conftest import (
|
||||||
|
create_e2e_app,
|
||||||
|
make_chunk,
|
||||||
|
make_graph,
|
||||||
|
make_state,
|
||||||
|
make_tool_chunk,
|
||||||
|
)
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.e2e
|
||||||
|
|
||||||
|
|
||||||
|
class TestFlow1HappyPath:
|
||||||
|
"""Flow 1: query order -> get answer with streaming tokens."""
|
||||||
|
|
||||||
|
def test_websocket_happy_path_order_query(self) -> None:
|
||||||
|
graph = make_graph(
|
||||||
|
chunks=[
|
||||||
|
make_tool_chunk("get_order_status", {"order_id": "1042"}),
|
||||||
|
make_chunk("Order 1042 has been shipped and is on its way."),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
app = create_e2e_app(graph=graph)
|
||||||
|
|
||||||
|
with TestClient(app) as client:
|
||||||
|
with client.websocket_connect("/ws") as ws:
|
||||||
|
ws.send_json({
|
||||||
|
"type": "message",
|
||||||
|
"thread_id": "e2e-happy-1",
|
||||||
|
"content": "What is the status of order 1042?",
|
||||||
|
})
|
||||||
|
|
||||||
|
messages = []
|
||||||
|
while True:
|
||||||
|
msg = ws.receive_json()
|
||||||
|
messages.append(msg)
|
||||||
|
if msg["type"] in ("message_complete", "error"):
|
||||||
|
break
|
||||||
|
|
||||||
|
tool_calls = [m for m in messages if m["type"] == "tool_call"]
|
||||||
|
assert len(tool_calls) == 1
|
||||||
|
assert tool_calls[0]["tool"] == "get_order_status"
|
||||||
|
assert tool_calls[0]["args"] == {"order_id": "1042"}
|
||||||
|
|
||||||
|
tokens = [m for m in messages if m["type"] == "token"]
|
||||||
|
assert len(tokens) == 1
|
||||||
|
assert "shipped" in tokens[0]["content"]
|
||||||
|
|
||||||
|
completes = [m for m in messages if m["type"] == "message_complete"]
|
||||||
|
assert len(completes) == 1
|
||||||
|
assert completes[0]["thread_id"] == "e2e-happy-1"
|
||||||
|
|
||||||
|
def test_websocket_multiple_token_stream(self) -> None:
|
||||||
|
"""Verify streaming returns multiple token chunks."""
|
||||||
|
graph = make_graph(
|
||||||
|
chunks=[
|
||||||
|
make_chunk("Your order "),
|
||||||
|
make_chunk("1042 "),
|
||||||
|
make_chunk("was delivered "),
|
||||||
|
make_chunk("yesterday."),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
app = create_e2e_app(graph=graph)
|
||||||
|
|
||||||
|
with TestClient(app) as client:
|
||||||
|
with client.websocket_connect("/ws") as ws:
|
||||||
|
ws.send_json({
|
||||||
|
"type": "message",
|
||||||
|
"thread_id": "e2e-stream-1",
|
||||||
|
"content": "Where is my order?",
|
||||||
|
})
|
||||||
|
|
||||||
|
messages = _collect_until_complete(ws)
|
||||||
|
|
||||||
|
tokens = [m for m in messages if m["type"] == "token"]
|
||||||
|
assert len(tokens) == 4
|
||||||
|
full_text = "".join(t["content"] for t in tokens)
|
||||||
|
assert "1042" in full_text
|
||||||
|
assert "delivered" in full_text
|
||||||
|
|
||||||
|
|
||||||
|
class TestFlow2ApprovalFlow:
|
||||||
|
"""Flow 2: write operation -> interrupt -> approve -> execute."""
|
||||||
|
|
||||||
|
def test_interrupt_approve_executes_action(self) -> None:
|
||||||
|
interrupt_state = make_state(
|
||||||
|
interrupt=True,
|
||||||
|
data={"action": "cancel_order", "order_id": "1042"},
|
||||||
|
)
|
||||||
|
graph = make_graph(
|
||||||
|
chunks=[],
|
||||||
|
state=interrupt_state,
|
||||||
|
resume_chunks=[
|
||||||
|
make_chunk("Order 1042 has been cancelled successfully.", "order_actions"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
app = create_e2e_app(graph=graph)
|
||||||
|
|
||||||
|
with TestClient(app) as client:
|
||||||
|
with client.websocket_connect("/ws") as ws:
|
||||||
|
# Step 1: Send cancel request
|
||||||
|
ws.send_json({
|
||||||
|
"type": "message",
|
||||||
|
"thread_id": "e2e-approve-1",
|
||||||
|
"content": "Cancel order 1042",
|
||||||
|
})
|
||||||
|
|
||||||
|
messages = _collect_until_type(ws, "interrupt")
|
||||||
|
|
||||||
|
interrupts = [m for m in messages if m["type"] == "interrupt"]
|
||||||
|
assert len(interrupts) == 1
|
||||||
|
assert interrupts[0]["action"] == "cancel_order"
|
||||||
|
assert interrupts[0]["thread_id"] == "e2e-approve-1"
|
||||||
|
|
||||||
|
# Step 2: Approve the interrupt
|
||||||
|
ws.send_json({
|
||||||
|
"type": "interrupt_response",
|
||||||
|
"thread_id": "e2e-approve-1",
|
||||||
|
"approved": True,
|
||||||
|
})
|
||||||
|
|
||||||
|
resume_messages = _collect_until_complete(ws)
|
||||||
|
|
||||||
|
tokens = [m for m in resume_messages if m["type"] == "token"]
|
||||||
|
assert len(tokens) == 1
|
||||||
|
assert "cancelled" in tokens[0]["content"]
|
||||||
|
assert tokens[0]["agent"] == "order_actions"
|
||||||
|
|
||||||
|
completes = [m for m in resume_messages if m["type"] == "message_complete"]
|
||||||
|
assert len(completes) == 1
|
||||||
|
|
||||||
|
|
||||||
|
class TestFlow3RejectionFlow:
|
||||||
|
"""Flow 3: write operation -> interrupt -> reject -> no execution."""
|
||||||
|
|
||||||
|
def test_interrupt_reject_does_not_execute(self) -> None:
|
||||||
|
interrupt_state = make_state(
|
||||||
|
interrupt=True,
|
||||||
|
data={"action": "cancel_order", "order_id": "1042"},
|
||||||
|
)
|
||||||
|
graph = make_graph(
|
||||||
|
chunks=[],
|
||||||
|
state=interrupt_state,
|
||||||
|
resume_chunks=[
|
||||||
|
make_chunk("Understood. Order 1042 will remain active.", "order_actions"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
app = create_e2e_app(graph=graph)
|
||||||
|
|
||||||
|
with TestClient(app) as client:
|
||||||
|
with client.websocket_connect("/ws") as ws:
|
||||||
|
# Step 1: Trigger interrupt
|
||||||
|
ws.send_json({
|
||||||
|
"type": "message",
|
||||||
|
"thread_id": "e2e-reject-1",
|
||||||
|
"content": "Cancel order 1042",
|
||||||
|
})
|
||||||
|
|
||||||
|
messages = _collect_until_type(ws, "interrupt")
|
||||||
|
assert any(m["type"] == "interrupt" for m in messages)
|
||||||
|
|
||||||
|
# Step 2: Reject
|
||||||
|
ws.send_json({
|
||||||
|
"type": "interrupt_response",
|
||||||
|
"thread_id": "e2e-reject-1",
|
||||||
|
"approved": False,
|
||||||
|
})
|
||||||
|
|
||||||
|
resume_messages = _collect_until_complete(ws)
|
||||||
|
|
||||||
|
tokens = [m for m in resume_messages if m["type"] == "token"]
|
||||||
|
assert len(tokens) == 1
|
||||||
|
assert "remain active" in tokens[0]["content"]
|
||||||
|
|
||||||
|
# Verify graph.astream was called with resume=False
|
||||||
|
resume_call = graph.astream.call_args_list[-1]
|
||||||
|
command = resume_call[0][0]
|
||||||
|
assert command.resume is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestFlow4MultiTurnContext:
|
||||||
|
"""Flow 4: multi-turn conversation in the same session."""
|
||||||
|
|
||||||
|
def test_multi_turn_messages_share_session(self) -> None:
|
||||||
|
"""Multiple messages in the same thread_id maintain session context."""
|
||||||
|
graph = make_graph(
|
||||||
|
chunks=[make_chunk("Order 1042 status: shipped.")],
|
||||||
|
)
|
||||||
|
app = create_e2e_app(graph=graph)
|
||||||
|
|
||||||
|
with TestClient(app) as client:
|
||||||
|
with client.websocket_connect("/ws") as ws:
|
||||||
|
# Turn 1: Query order
|
||||||
|
ws.send_json({
|
||||||
|
"type": "message",
|
||||||
|
"thread_id": "e2e-multi-1",
|
||||||
|
"content": "What is the status of order 1042?",
|
||||||
|
})
|
||||||
|
turn1 = _collect_until_complete(ws)
|
||||||
|
assert any(m["type"] == "message_complete" for m in turn1)
|
||||||
|
|
||||||
|
# Turn 2: Follow-up in same thread
|
||||||
|
ws.send_json({
|
||||||
|
"type": "message",
|
||||||
|
"thread_id": "e2e-multi-1",
|
||||||
|
"content": "When will it arrive?",
|
||||||
|
})
|
||||||
|
turn2 = _collect_until_complete(ws)
|
||||||
|
assert any(m["type"] == "message_complete" for m in turn2)
|
||||||
|
|
||||||
|
# Turn 3: Another follow-up
|
||||||
|
ws.send_json({
|
||||||
|
"type": "message",
|
||||||
|
"thread_id": "e2e-multi-1",
|
||||||
|
"content": "Can you track it?",
|
||||||
|
})
|
||||||
|
turn3 = _collect_until_complete(ws)
|
||||||
|
assert any(m["type"] == "message_complete" for m in turn3)
|
||||||
|
|
||||||
|
# Verify all turns used the same thread_id in graph calls
|
||||||
|
for call in graph.astream.call_args_list:
|
||||||
|
config = call[1].get("config", call[0][1] if len(call[0]) > 1 else {})
|
||||||
|
assert config["configurable"]["thread_id"] == "e2e-multi-1"
|
||||||
|
|
||||||
|
def test_separate_threads_are_independent(self) -> None:
|
||||||
|
"""Different thread_ids have independent sessions."""
|
||||||
|
graph = make_graph(
|
||||||
|
chunks=[make_chunk("Response.")],
|
||||||
|
)
|
||||||
|
app = create_e2e_app(graph=graph)
|
||||||
|
|
||||||
|
with TestClient(app) as client:
|
||||||
|
with client.websocket_connect("/ws") as ws:
|
||||||
|
# Thread A
|
||||||
|
ws.send_json({
|
||||||
|
"type": "message",
|
||||||
|
"thread_id": "e2e-thread-a",
|
||||||
|
"content": "Hello from thread A",
|
||||||
|
})
|
||||||
|
_collect_until_complete(ws)
|
||||||
|
|
||||||
|
# Thread B
|
||||||
|
ws.send_json({
|
||||||
|
"type": "message",
|
||||||
|
"thread_id": "e2e-thread-b",
|
||||||
|
"content": "Hello from thread B",
|
||||||
|
})
|
||||||
|
_collect_until_complete(ws)
|
||||||
|
|
||||||
|
# Both threads should exist as separate sessions
|
||||||
|
sm = app.state.session_manager
|
||||||
|
assert sm.get_state("e2e-thread-a") is not None
|
||||||
|
assert sm.get_state("e2e-thread-b") is not None
|
||||||
|
|
||||||
|
|
||||||
|
class TestChatEdgeCases:
|
||||||
|
"""Edge cases and error handling for the chat WebSocket."""
|
||||||
|
|
||||||
|
def test_invalid_json_returns_error(self) -> None:
|
||||||
|
app = create_e2e_app()
|
||||||
|
|
||||||
|
with TestClient(app) as client:
|
||||||
|
with client.websocket_connect("/ws") as ws:
|
||||||
|
ws.send_text("not valid json")
|
||||||
|
msg = ws.receive_json()
|
||||||
|
assert msg["type"] == "error"
|
||||||
|
assert "Invalid JSON" in msg["message"]
|
||||||
|
|
||||||
|
def test_missing_thread_id_returns_error(self) -> None:
|
||||||
|
app = create_e2e_app()
|
||||||
|
|
||||||
|
with TestClient(app) as client:
|
||||||
|
with client.websocket_connect("/ws") as ws:
|
||||||
|
ws.send_json({"type": "message", "content": "hello"})
|
||||||
|
msg = ws.receive_json()
|
||||||
|
assert msg["type"] == "error"
|
||||||
|
assert "thread_id" in msg["message"]
|
||||||
|
|
||||||
|
def test_empty_content_returns_error(self) -> None:
|
||||||
|
app = create_e2e_app()
|
||||||
|
|
||||||
|
with TestClient(app) as client:
|
||||||
|
with client.websocket_connect("/ws") as ws:
|
||||||
|
ws.send_json({
|
||||||
|
"type": "message",
|
||||||
|
"thread_id": "e2e-err-1",
|
||||||
|
"content": "",
|
||||||
|
})
|
||||||
|
msg = ws.receive_json()
|
||||||
|
assert msg["type"] == "error"
|
||||||
|
|
||||||
|
def test_expired_session_returns_error(self) -> None:
|
||||||
|
graph = make_graph(chunks=[make_chunk("Response.")])
|
||||||
|
app = create_e2e_app(graph=graph, session_ttl=0)
|
||||||
|
|
||||||
|
with TestClient(app) as client:
|
||||||
|
with client.websocket_connect("/ws") as ws:
|
||||||
|
# First message creates the session (TTL=0)
|
||||||
|
ws.send_json({
|
||||||
|
"type": "message",
|
||||||
|
"thread_id": "e2e-expired-1",
|
||||||
|
"content": "hello",
|
||||||
|
})
|
||||||
|
_collect_until_complete_or_error(ws)
|
||||||
|
|
||||||
|
# Second message finds the session expired (TTL=0)
|
||||||
|
ws.send_json({
|
||||||
|
"type": "message",
|
||||||
|
"thread_id": "e2e-expired-1",
|
||||||
|
"content": "hello again",
|
||||||
|
})
|
||||||
|
messages = _collect_until_complete_or_error(ws)
|
||||||
|
errors = [m for m in messages if m["type"] == "error"]
|
||||||
|
assert len(errors) >= 1
|
||||||
|
assert "expired" in errors[0]["message"].lower()
|
||||||
|
|
||||||
|
def test_oversized_message_returns_error(self) -> None:
|
||||||
|
app = create_e2e_app()
|
||||||
|
|
||||||
|
with TestClient(app) as client:
|
||||||
|
with client.websocket_connect("/ws") as ws:
|
||||||
|
ws.send_text("x" * 40_000)
|
||||||
|
msg = ws.receive_json()
|
||||||
|
assert msg["type"] == "error"
|
||||||
|
assert "too large" in msg["message"].lower()
|
||||||
|
|
||||||
|
def test_health_endpoint(self) -> None:
|
||||||
|
app = create_e2e_app()
|
||||||
|
with TestClient(app) as client:
|
||||||
|
resp = client.get("/api/v1/health")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json()["status"] == "ok"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _collect_until_complete(ws, *, max_messages: int = 50) -> list[dict]:
|
||||||
|
"""Receive WebSocket messages until message_complete or error."""
|
||||||
|
messages = []
|
||||||
|
for _ in range(max_messages):
|
||||||
|
msg = ws.receive_json()
|
||||||
|
messages.append(msg)
|
||||||
|
if msg["type"] in ("message_complete", "error"):
|
||||||
|
break
|
||||||
|
return messages
|
||||||
|
|
||||||
|
|
||||||
|
def _collect_until_type(ws, msg_type: str, *, max_messages: int = 50) -> list[dict]:
|
||||||
|
"""Receive until a specific message type is received."""
|
||||||
|
messages = []
|
||||||
|
for _ in range(max_messages):
|
||||||
|
msg = ws.receive_json()
|
||||||
|
messages.append(msg)
|
||||||
|
if msg["type"] == msg_type:
|
||||||
|
break
|
||||||
|
return messages
|
||||||
|
|
||||||
|
|
||||||
|
def _collect_until_complete_or_error(ws, *, max_messages: int = 50) -> list[dict]:
|
||||||
|
"""Receive until message_complete or error."""
|
||||||
|
messages = []
|
||||||
|
for _ in range(max_messages):
|
||||||
|
msg = ws.receive_json()
|
||||||
|
messages.append(msg)
|
||||||
|
if msg["type"] in ("message_complete", "error"):
|
||||||
|
break
|
||||||
|
return messages
|
||||||
201
backend/tests/e2e/test_openapi_import.py
Normal file
201
backend/tests/e2e/test_openapi_import.py
Normal file
@@ -0,0 +1,201 @@
|
|||||||
|
"""E2E tests for OpenAPI import flow (flow 5).
|
||||||
|
|
||||||
|
Flow 5: paste OpenAPI spec URL -> import job -> classify endpoints ->
|
||||||
|
review classifications -> approve -> tool generation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from starlette.testclient import TestClient
|
||||||
|
|
||||||
|
from app.openapi.models import ClassificationResult, EndpointInfo
|
||||||
|
from app.openapi.review_api import _job_store
|
||||||
|
from tests.e2e.conftest import create_e2e_app
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.e2e
|
||||||
|
|
||||||
|
|
||||||
|
def _fake_endpoint(
|
||||||
|
path: str = "/orders/{id}",
|
||||||
|
method: str = "GET",
|
||||||
|
operation_id: str = "getOrder",
|
||||||
|
summary: str = "Get order details",
|
||||||
|
) -> EndpointInfo:
|
||||||
|
return EndpointInfo(
|
||||||
|
path=path,
|
||||||
|
method=method,
|
||||||
|
operation_id=operation_id,
|
||||||
|
summary=summary,
|
||||||
|
description="",
|
||||||
|
parameters=(),
|
||||||
|
request_body_schema=None,
|
||||||
|
response_schema=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _fake_classification(
|
||||||
|
endpoint: EndpointInfo | None = None,
|
||||||
|
access_type: str = "read",
|
||||||
|
needs_interrupt: bool = False,
|
||||||
|
agent_group: str = "order_lookup",
|
||||||
|
) -> ClassificationResult:
|
||||||
|
return ClassificationResult(
|
||||||
|
endpoint=endpoint or _fake_endpoint(),
|
||||||
|
access_type=access_type,
|
||||||
|
customer_params=["order_id"],
|
||||||
|
agent_group=agent_group,
|
||||||
|
confidence=0.95,
|
||||||
|
needs_interrupt=needs_interrupt,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestFlow5OpenAPIImport:
|
||||||
|
"""Flow 5: full OpenAPI import lifecycle."""
|
||||||
|
|
||||||
|
def test_import_job_lifecycle(self) -> None:
|
||||||
|
"""Start import -> check status -> review classifications -> approve."""
|
||||||
|
app = create_e2e_app()
|
||||||
|
|
||||||
|
with TestClient(app) as client:
|
||||||
|
# Step 1: Start import job
|
||||||
|
resp = client.post(
|
||||||
|
"/api/v1/openapi/import",
|
||||||
|
json={"url": "https://api.example.com/openapi.json"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 202
|
||||||
|
body = resp.json()
|
||||||
|
assert body["status"] == "pending"
|
||||||
|
job_id = body["job_id"]
|
||||||
|
|
||||||
|
# Step 2: Check job status (still pending since background task hasn't run)
|
||||||
|
resp = client.get(f"/api/v1/openapi/jobs/{job_id}")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json()["job_id"] == job_id
|
||||||
|
|
||||||
|
def test_import_job_with_classifications(self) -> None:
|
||||||
|
"""Simulate completed import and review classified endpoints."""
|
||||||
|
app = create_e2e_app()
|
||||||
|
|
||||||
|
# Seed a completed job directly
|
||||||
|
ep_read = _fake_endpoint("/orders/{id}", "GET", "getOrder", "Get order")
|
||||||
|
ep_write = _fake_endpoint("/orders/{id}/cancel", "POST", "cancelOrder", "Cancel order")
|
||||||
|
|
||||||
|
clf_read = _fake_classification(ep_read, "read", False, "order_lookup")
|
||||||
|
clf_write = _fake_classification(ep_write, "write", True, "order_actions")
|
||||||
|
|
||||||
|
job_id = "test-job-001"
|
||||||
|
_job_store[job_id] = {
|
||||||
|
"job_id": job_id,
|
||||||
|
"status": "done",
|
||||||
|
"spec_url": "https://api.example.com/openapi.json",
|
||||||
|
"total_endpoints": 2,
|
||||||
|
"classified_count": 2,
|
||||||
|
"error_message": None,
|
||||||
|
"classifications": [clf_read, clf_write],
|
||||||
|
}
|
||||||
|
|
||||||
|
with TestClient(app) as client:
|
||||||
|
# Step 1: Get classifications
|
||||||
|
resp = client.get(f"/api/v1/openapi/jobs/{job_id}/classifications")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
classifications = resp.json()
|
||||||
|
assert len(classifications) == 2
|
||||||
|
|
||||||
|
# Verify read endpoint
|
||||||
|
read_clf = classifications[0]
|
||||||
|
assert read_clf["access_type"] == "read"
|
||||||
|
assert read_clf["needs_interrupt"] is False
|
||||||
|
assert read_clf["endpoint"]["path"] == "/orders/{id}"
|
||||||
|
|
||||||
|
# Verify write endpoint
|
||||||
|
write_clf = classifications[1]
|
||||||
|
assert write_clf["access_type"] == "write"
|
||||||
|
assert write_clf["needs_interrupt"] is True
|
||||||
|
assert write_clf["endpoint"]["path"] == "/orders/{id}/cancel"
|
||||||
|
|
||||||
|
# Step 2: Update a classification
|
||||||
|
resp = client.put(
|
||||||
|
f"/api/v1/openapi/jobs/{job_id}/classifications/0",
|
||||||
|
json={
|
||||||
|
"access_type": "write",
|
||||||
|
"needs_interrupt": True,
|
||||||
|
"agent_group": "order_actions",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
updated = resp.json()
|
||||||
|
assert updated["access_type"] == "write"
|
||||||
|
assert updated["needs_interrupt"] is True
|
||||||
|
assert updated["agent_group"] == "order_actions"
|
||||||
|
|
||||||
|
# Step 3: Approve the job
|
||||||
|
resp = client.post(f"/api/v1/openapi/jobs/{job_id}/approve")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json()["status"] == "approved"
|
||||||
|
|
||||||
|
def test_import_nonexistent_job_returns_404(self) -> None:
    """Fetching a job id that was never created yields HTTP 404."""
    app = create_e2e_app()

    with TestClient(app) as client:
        response = client.get("/api/v1/openapi/jobs/nonexistent")
        assert response.status_code == 404
|
||||||
|
|
||||||
|
def test_import_invalid_url_returns_422(self) -> None:
    """Submitting a malformed URL to the import endpoint is rejected with 422."""
    app = create_e2e_app()

    with TestClient(app) as client:
        response = client.post("/api/v1/openapi/import", json={"url": "not-a-url"})
        assert response.status_code == 422
|
||||||
|
|
||||||
|
def test_classification_index_out_of_range(self) -> None:
    """Updating a classification index past the end of the list yields 404."""
    app = create_e2e_app()

    job_id = "test-job-range"
    _job_store[job_id] = {
        "job_id": job_id,
        "status": "done",
        "spec_url": "https://example.com/spec.json",
        "total_endpoints": 1,
        "classified_count": 1,
        "error_message": None,
        "classifications": [_fake_classification()],
    }

    # Index 99 is far beyond the single stored classification.
    payload = {
        "access_type": "read",
        "needs_interrupt": False,
        "agent_group": "order_lookup",
    }
    with TestClient(app) as client:
        response = client.put(
            f"/api/v1/openapi/jobs/{job_id}/classifications/99",
            json=payload,
        )
        assert response.status_code == 404
|
||||||
|
|
||||||
|
def test_update_classification_invalid_agent_group(self) -> None:
    """An agent_group containing spaces/special chars fails validation (422)."""
    app = create_e2e_app()

    job_id = "test-job-invalid"
    _job_store[job_id] = {
        "job_id": job_id,
        "status": "done",
        "spec_url": "https://example.com/spec.json",
        "total_endpoints": 1,
        "classified_count": 1,
        "error_message": None,
        "classifications": [_fake_classification()],
    }

    payload = {
        "access_type": "read",
        "needs_interrupt": False,
        "agent_group": "invalid group!",  # spaces and special chars
    }
    with TestClient(app) as client:
        response = client.put(
            f"/api/v1/openapi/jobs/{job_id}/classifications/0",
            json=payload,
        )
        assert response.status_code == 422
|
||||||
230
backend/tests/e2e/test_replay_analytics.py
Normal file
230
backend/tests/e2e/test_replay_analytics.py
Normal file
@@ -0,0 +1,230 @@
|
|||||||
|
"""E2E tests for replay and analytics flows (flow 6).
|
||||||
|
|
||||||
|
Flow 6: list conversations -> select one -> step-by-step replay.
|
||||||
|
Also tests the analytics dashboard endpoint.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from starlette.testclient import TestClient
|
||||||
|
|
||||||
|
from tests.e2e.conftest import FakePool, create_e2e_app
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.e2e
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Custom pool that returns specific data per query
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class ReplayPool(FakePool):
    """Pool whose query results depend on which SQL statement is issued."""

    def __init__(
        self,
        conversations: list[dict] | None = None,
        checkpoints: list[dict] | None = None,
        analytics_rows: list[dict] | None = None,
    ) -> None:
        super().__init__()
        self._conversations = conversations or []
        self._checkpoints = checkpoints or []
        self._analytics = analytics_rows or []

    class _Conn:
        """Fake connection that routes queries to canned result sets."""

        def __init__(self, convos, checkpoints, analytics):
            self._convos = convos
            self._checkpoints = checkpoints
            self._analytics = analytics

        async def execute(self, query: str, params=None):
            from tests.e2e.conftest import FakeCursor

            touches_conversations = "conversations" in query
            if touches_conversations and "COUNT" in query:
                return FakeCursor([(len(self._convos),)])
            if touches_conversations and "SELECT" in query:
                # Honour LIMIT/OFFSET when the caller supplies them.
                rows = self._convos
                if params:
                    start = params.get("offset", 0)
                    stop = start + params.get("limit", len(rows))
                    rows = rows[start:stop]
                return FakeCursor(rows)
            if "checkpoints" in query:
                return FakeCursor(self._checkpoints)
            # Anything else is treated as an analytics query.
            return FakeCursor(self._analytics)

    def connection(self):
        from contextlib import asynccontextmanager

        conn = self._Conn(self._conversations, self._checkpoints, self._analytics)

        @asynccontextmanager
        async def _ctx():
            yield conn

        return _ctx()
|
||||||
|
|
||||||
|
|
||||||
|
class TestFlow6ReplayConversation:
    """Flow 6: list conversations -> select one -> step replay."""

    def test_list_conversations(self) -> None:
        """Two stored conversations are returned in order with total=2."""
        timestamp = datetime.now(tz=timezone.utc).isoformat()
        convo_rows = [
            {
                "thread_id": "conv-001",
                "created_at": timestamp,
                "last_activity": timestamp,
                "status": "active",
                "total_tokens": 150,
                "total_cost_usd": 0.003,
            },
            {
                "thread_id": "conv-002",
                "created_at": timestamp,
                "last_activity": timestamp,
                "status": "completed",
                "total_tokens": 300,
                "total_cost_usd": 0.006,
            },
        ]
        app = create_e2e_app(pool=ReplayPool(conversations=convo_rows))

        with TestClient(app) as client:
            response = client.get("/api/v1/conversations")
            assert response.status_code == 200
            payload = response.json()
            assert payload["success"] is True
            data = payload["data"]
            listed_ids = [c["thread_id"] for c in data["conversations"]]
            assert listed_ids == ["conv-001", "conv-002"]
            assert data["total"] == 2

    def test_list_conversations_pagination(self) -> None:
        """page/per_page query params slice the conversation listing."""
        convo_rows = [
            {
                "thread_id": f"conv-{i:03d}",
                "created_at": "2026-04-01T00:00:00Z",
                "last_activity": "2026-04-01T00:00:00Z",
                "status": "active",
                "total_tokens": 100,
                "total_cost_usd": 0.001,
            }
            for i in range(5)
        ]
        app = create_e2e_app(pool=ReplayPool(conversations=convo_rows))

        with TestClient(app) as client:
            response = client.get(
                "/api/v1/conversations", params={"page": 1, "per_page": 2}
            )
            assert response.status_code == 200
            payload = response.json()
            assert payload["success"] is True
            data = payload["data"]
            assert (data["total"], data["page"], data["per_page"]) == (5, 1, 2)
            assert len(data["conversations"]) == 2

    def test_replay_thread_not_found(self) -> None:
        """Replaying a thread with no checkpoints yields 404."""
        app = create_e2e_app(pool=ReplayPool(checkpoints=[]))

        with TestClient(app) as client:
            response = client.get("/api/v1/replay/nonexistent-thread")
            assert response.status_code == 404

    def test_replay_invalid_thread_id_format(self) -> None:
        """A thread id with special characters fails regex validation (400)."""
        app = create_e2e_app()

        with TestClient(app) as client:
            # URL-encoded space, '!' and '@' violate the thread-id pattern.
            response = client.get("/api/v1/replay/invalid%20thread%21%40")
            assert response.status_code == 400
|
||||||
|
|
||||||
|
|
||||||
|
class TestAnalyticsDashboard:
    """Analytics endpoint tests."""

    @staticmethod
    def _get_range(client, range_value: str):
        # Every test in this class only varies the `range` query parameter.
        return client.get("/api/v1/analytics", params={"range": range_value})

    def test_analytics_invalid_range_format(self) -> None:
        """A non-numeric range string is rejected with 400."""
        with TestClient(create_e2e_app()) as client:
            assert self._get_range(client, "invalid").status_code == 400

    def test_analytics_range_too_large(self) -> None:
        """A range beyond the supported maximum is rejected with 400."""
        with TestClient(create_e2e_app()) as client:
            assert self._get_range(client, "999d").status_code == 400

    def test_analytics_range_zero_rejected(self) -> None:
        """A zero-day range is rejected with 400."""
        with TestClient(create_e2e_app()) as client:
            assert self._get_range(client, "0d").status_code == 400
|
||||||
|
|
||||||
|
|
||||||
|
class TestFullUserJourney:
    """End-to-end journey: chat -> then check replay list shows the conversation."""

    def test_chat_then_check_conversations_endpoint(self) -> None:
        """After chatting via WebSocket, the conversations endpoint is reachable."""
        from tests.e2e.conftest import make_chunk, make_graph

        graph = make_graph(chunks=[make_chunk("Your order is shipped.")])
        timestamp = datetime.now(tz=timezone.utc).isoformat()
        pool = ReplayPool(
            conversations=[
                {
                    "thread_id": "e2e-journey-1",
                    "created_at": timestamp,
                    "last_activity": timestamp,
                    "status": "active",
                    "total_tokens": 50,
                    "total_cost_usd": 0.001,
                },
            ],
        )
        app = create_e2e_app(graph=graph, pool=pool)

        with TestClient(app) as client:
            # Step 1: chat over the WebSocket until the stream terminates.
            with client.websocket_connect("/ws") as ws:
                ws.send_json(
                    {
                        "type": "message",
                        "thread_id": "e2e-journey-1",
                        "content": "Where is my order?",
                    }
                )
                received = []
                # Bounded read loop so a misbehaving stream cannot hang the test.
                for _ in range(20):
                    event = ws.receive_json()
                    received.append(event)
                    if event["type"] in ("message_complete", "error"):
                        break
                assert any(e["type"] == "message_complete" for e in received)

            # Step 2: the conversation shows up in the replay listing.
            response = client.get("/api/v1/conversations")
            assert response.status_code == 200
            payload = response.json()
            assert payload["success"] is True
            listed_ids = {c["thread_id"] for c in payload["data"]["conversations"]}
            assert "e2e-journey-1" in listed_ids

            # Step 3: health check still works.
            response = client.get("/api/v1/health")
            assert response.status_code == 200
|
||||||
183
backend/tests/integration/test_analytics_api.py
Normal file
183
backend/tests/integration/test_analytics_api.py
Normal file
@@ -0,0 +1,183 @@
|
|||||||
|
"""Integration tests for the /api/v1/analytics endpoint.
|
||||||
|
|
||||||
|
Tests the full API layer (routing, parameter validation, serialization,
|
||||||
|
error handling) with a mocked database pool.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import asdict
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from httpx import ASGITransport, AsyncClient
|
||||||
|
|
||||||
|
from app.analytics.models import AnalyticsResult, InterruptStats
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.integration
|
||||||
|
|
||||||
|
_SAMPLE_RESULT = AnalyticsResult(
|
||||||
|
range="7d",
|
||||||
|
total_conversations=42,
|
||||||
|
resolution_rate=0.85,
|
||||||
|
escalation_rate=0.05,
|
||||||
|
avg_turns_per_conversation=3.2,
|
||||||
|
avg_cost_per_conversation_usd=0.012,
|
||||||
|
agent_usage=(),
|
||||||
|
interrupt_stats=InterruptStats(total=10, approved=7, rejected=2, expired=1),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_app():
    """Build a minimal FastAPI app with the analytics router and mocked deps."""
    from fastapi import FastAPI, HTTPException
    from fastapi.exceptions import RequestValidationError
    from fastapi.responses import JSONResponse

    from app.analytics.api import router as analytics_router
    from app.api_utils import envelope

    test_app = FastAPI()
    test_app.include_router(analytics_router)

    def _error_response(status_code, message):
        # All handlers share the standard error envelope shape.
        return JSONResponse(
            status_code=status_code,
            content=envelope(None, success=False, error=message),
        )

    @test_app.exception_handler(Exception)
    async def _catch_all(request, exc):
        return _error_response(500, "Internal server error")

    @test_app.exception_handler(HTTPException)
    async def _http_exc(request, exc):
        return _error_response(exc.status_code, exc.detail)

    @test_app.exception_handler(RequestValidationError)
    async def _validation_exc(request, exc):
        return _error_response(422, str(exc))

    # No admin_api_key set -> auth is skipped
    test_app.state.settings = MagicMock(admin_api_key="")
    test_app.state.pool = MagicMock()

    return test_app
|
||||||
|
|
||||||
|
|
||||||
|
class TestAnalyticsValidRange:
    """Test analytics endpoint with valid range parameters.

    Fix: the default-range assertion previously ended with
    ``or call_args[0][1] == 7``, which raises IndexError whenever
    ``get_analytics`` receives fewer than two positional arguments and the
    first clause evaluates False. The check is rewritten to extract the
    passed range (keyword or positional) once, safely.
    """

    async def test_valid_range_7d_returns_envelope(self) -> None:
        """GET /api/v1/analytics?range=7d returns success envelope with data."""
        test_app = _build_app()
        with patch(
            "app.analytics.api.get_analytics",
            new_callable=AsyncMock,
            return_value=_SAMPLE_RESULT,
        ):
            async with AsyncClient(
                transport=ASGITransport(app=test_app), base_url="http://test"
            ) as client:
                resp = await client.get("/api/v1/analytics", params={"range": "7d"})

        assert resp.status_code == 200
        body = resp.json()
        assert body["success"] is True
        assert body["error"] is None
        assert body["data"]["total_conversations"] == 42
        assert body["data"]["resolution_rate"] == 0.85

    async def test_default_range_returns_success(self) -> None:
        """GET /api/v1/analytics with no range param defaults to 7d."""
        test_app = _build_app()
        with patch(
            "app.analytics.api.get_analytics",
            new_callable=AsyncMock,
            return_value=_SAMPLE_RESULT,
        ) as mock_get:
            async with AsyncClient(
                transport=ASGITransport(app=test_app), base_url="http://test"
            ) as client:
                resp = await client.get("/api/v1/analytics")

        assert resp.status_code == 200
        # Verify the default range of 7 days was passed, whether positionally
        # or as a keyword; None means the callee-side default applies.
        mock_get.assert_called_once()
        args, kwargs = mock_get.call_args
        positional_range = args[1] if len(args) > 1 else None
        range_passed = kwargs.get("range_days", positional_range)
        assert range_passed in (7, None)

    async def test_large_range_365d_works(self) -> None:
        """GET /api/v1/analytics?range=365d is accepted (max boundary)."""
        test_app = _build_app()
        result = AnalyticsResult(
            range="365d",
            total_conversations=1000,
            resolution_rate=0.9,
            escalation_rate=0.02,
            avg_turns_per_conversation=4.0,
            avg_cost_per_conversation_usd=0.01,
            agent_usage=(),
            interrupt_stats=InterruptStats(),
        )
        with patch(
            "app.analytics.api.get_analytics",
            new_callable=AsyncMock,
            return_value=result,
        ):
            async with AsyncClient(
                transport=ASGITransport(app=test_app), base_url="http://test"
            ) as client:
                resp = await client.get("/api/v1/analytics", params={"range": "365d"})

        assert resp.status_code == 200
        assert resp.json()["success"] is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestAnalyticsInvalidRange:
    """Test analytics endpoint with invalid range parameters."""

    @staticmethod
    async def _request_range(range_value: str):
        # Shared driver: every test sends one GET with a different range.
        test_app = _build_app()
        async with AsyncClient(
            transport=ASGITransport(app=test_app), base_url="http://test"
        ) as client:
            return await client.get(
                "/api/v1/analytics", params={"range": range_value}
            )

    async def test_invalid_range_format_returns_400(self) -> None:
        """GET /api/v1/analytics?range=abc returns 400 error envelope."""
        resp = await self._request_range("abc")

        assert resp.status_code == 400
        body = resp.json()
        assert body["success"] is False
        assert body["data"] is None
        assert "Invalid range format" in body["error"]

    async def test_zero_day_range_returns_400(self) -> None:
        """GET /api/v1/analytics?range=0d returns 400 because 0 is below minimum."""
        resp = await self._request_range("0d")

        assert resp.status_code == 400
        body = resp.json()
        assert body["success"] is False
        assert "between 1 and 365" in body["error"]

    async def test_range_exceeding_max_returns_400(self) -> None:
        """GET /api/v1/analytics?range=999d returns 400 because it exceeds 365."""
        resp = await self._request_range("999d")

        assert resp.status_code == 400
        body = resp.json()
        assert body["success"] is False
        assert "between 1 and 365" in body["error"]
|
||||||
128
backend/tests/integration/test_error_responses.py
Normal file
128
backend/tests/integration/test_error_responses.py
Normal file
@@ -0,0 +1,128 @@
|
|||||||
|
"""Integration tests for global error handling and envelope format consistency.
|
||||||
|
|
||||||
|
Tests that all error responses from the FastAPI app conform to the
|
||||||
|
standard envelope: {"success": false, "data": null, "error": "..."}.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from httpx import ASGITransport, AsyncClient
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.integration
|
||||||
|
|
||||||
|
|
||||||
|
def _build_app():
    """Build the actual FastAPI app with exception handlers but mocked state."""
    from fastapi import FastAPI, HTTPException
    from fastapi.exceptions import RequestValidationError
    from fastapi.responses import JSONResponse

    from app.analytics.api import router as analytics_router
    from app.api_utils import envelope
    from app.replay.api import router as replay_router

    test_app = FastAPI()
    test_app.include_router(analytics_router)
    test_app.include_router(replay_router)

    def _error_response(status_code, message):
        # Shared builder keeps every error in the standard envelope shape.
        return JSONResponse(
            status_code=status_code,
            content=envelope(None, success=False, error=message),
        )

    @test_app.exception_handler(HTTPException)
    async def _http_exc(request, exc):
        return _error_response(exc.status_code, exc.detail)

    @test_app.exception_handler(RequestValidationError)
    async def _validation_exc(request, exc):
        return _error_response(422, str(exc))

    @test_app.exception_handler(Exception)
    async def _catch_all(request, exc):
        return _error_response(500, "Internal server error")

    @test_app.get("/api/v1/health")
    def health_check():
        return {"status": "ok", "version": "0.6.0"}

    test_app.state.settings = MagicMock(admin_api_key="")
    test_app.state.pool = MagicMock()

    return test_app
|
||||||
|
|
||||||
|
|
||||||
|
class TestEnvelopeFormat:
    """Tests that error responses consistently follow envelope format."""

    @staticmethod
    def _client(test_app):
        # Shared AsyncClient factory; callers use it as an async context manager.
        return AsyncClient(
            transport=ASGITransport(app=test_app), base_url="http://test"
        )

    async def test_http_400_produces_envelope(self) -> None:
        """A 400 error returns standard envelope with success=false."""
        async with self._client(_build_app()) as client:
            resp = await client.get("/api/v1/analytics", params={"range": "invalid"})

        assert resp.status_code == 400
        body = resp.json()
        assert body["success"] is False
        assert body["data"] is None
        assert isinstance(body["error"], str)
        assert len(body["error"]) > 0

    async def test_validation_error_produces_422_envelope(self) -> None:
        """Invalid query param type returns 422 with envelope format."""
        async with self._client(_build_app()) as client:
            # page must be >= 1; passing 0 triggers validation error
            resp = await client.get("/api/v1/conversations", params={"page": 0})

        assert resp.status_code == 422
        body = resp.json()
        assert body["success"] is False
        assert body["data"] is None
        assert isinstance(body["error"], str)

    async def test_all_error_fields_present(self) -> None:
        """Error envelope contains exactly success, data, and error keys."""
        async with self._client(_build_app()) as client:
            resp = await client.get("/api/v1/analytics", params={"range": "bad"})

        assert set(resp.json().keys()) == {"success", "data", "error"}

    async def test_health_endpoint_returns_200(self) -> None:
        """Health check returns 200 with status ok."""
        async with self._client(_build_app()) as client:
            resp = await client.get("/api/v1/health")

        assert resp.status_code == 200
        body = resp.json()
        assert body["status"] == "ok"
        assert "version" in body

    async def test_unknown_endpoint_returns_404(self) -> None:
        """Requesting a non-existent path returns 404."""
        async with self._client(_build_app()) as client:
            resp = await client.get("/api/v1/nonexistent-path")

        # FastAPI returns 404 for unknown routes; may or may not be wrapped
        assert resp.status_code == 404
|
||||||
203
backend/tests/integration/test_import_pipeline.py
Normal file
203
backend/tests/integration/test_import_pipeline.py
Normal file
@@ -0,0 +1,203 @@
|
|||||||
|
"""Integration tests for the OpenAPI import pipeline orchestrator.
|
||||||
|
|
||||||
|
Tests the full pipeline: fetch -> validate -> parse -> classify.
|
||||||
|
Uses mocked HTTP and mocked LLM classifier.
|
||||||
|
|
||||||
|
RED phase: written before implementation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.openapi.models import ImportJob
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.integration
|
||||||
|
|
||||||
|
_VALID_SPEC_JSON = """{
|
||||||
|
"openapi": "3.0.0",
|
||||||
|
"info": {"title": "Test API", "version": "1.0.0"},
|
||||||
|
"paths": {
|
||||||
|
"/orders": {
|
||||||
|
"get": {
|
||||||
|
"operationId": "list_orders",
|
||||||
|
"summary": "List orders",
|
||||||
|
"description": "Returns all orders",
|
||||||
|
"responses": {"200": {"description": "OK"}}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"/orders/{id}": {
|
||||||
|
"delete": {
|
||||||
|
"operationId": "delete_order",
|
||||||
|
"summary": "Delete order",
|
||||||
|
"description": "Deletes an order",
|
||||||
|
"parameters": [
|
||||||
|
{"name": "id", "in": "path", "required": true, "schema": {"type": "string"}}
|
||||||
|
],
|
||||||
|
"responses": {"204": {"description": "Deleted"}}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}"""
|
||||||
|
|
||||||
|
_INVALID_SPEC_JSON = '{"not": "a valid openapi spec"}'
|
||||||
|
|
||||||
|
_PUBLIC_IP = "93.184.216.34"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def mock_classifier():
    """Provide a heuristic-based classifier (no LLM calls involved)."""
    from app.openapi.classifier import HeuristicClassifier

    return HeuristicClassifier()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def orchestrator(mock_classifier):
    """Wire an ImportOrchestrator around the heuristic classifier fixture."""
    from app.openapi.importer import ImportOrchestrator

    return ImportOrchestrator(classifier=mock_classifier)
|
||||||
|
|
||||||
|
|
||||||
|
class TestImportOrchestratorSuccess:
    """Tests for successful import pipeline flows."""

    @pytest.mark.usefixtures("_mock_public_dns")
    async def test_full_pipeline_succeeds(self, orchestrator, httpx_mock) -> None:
        """Full pipeline with valid spec and mocked HTTP succeeds."""
        httpx_mock.add_response(
            url="https://example.com/api/spec.json",
            text=_VALID_SPEC_JSON,
            headers={"content-type": "application/json"},
        )
        job = await orchestrator.start_import(
            url="https://example.com/api/spec.json",
            job_id="test-job-1",
            on_progress=None,
        )
        assert isinstance(job, ImportJob)
        assert job.status == "done"
        assert job.job_id == "test-job-1"
        # The fixture spec declares exactly two operations.
        assert job.total_endpoints == 2
        assert job.classified_count == 2
        assert job.error_message is None

    @pytest.mark.usefixtures("_mock_public_dns")
    async def test_progress_callback_called_at_stages(self, orchestrator, httpx_mock) -> None:
        """on_progress callback is called at each pipeline stage."""
        httpx_mock.add_response(
            url="https://example.com/api/spec.json",
            text=_VALID_SPEC_JSON,
            headers={"content-type": "application/json"},
        )
        observed_stages: list[str] = []

        def record_stage(stage: str, job: ImportJob) -> None:
            observed_stages.append(stage)

        await orchestrator.start_import(
            url="https://example.com/api/spec.json",
            job_id="test-job-2",
            on_progress=record_stage,
        )
        for expected in ("fetching", "validating", "parsing", "classifying"):
            assert expected in observed_stages

    @pytest.mark.usefixtures("_mock_public_dns")
    async def test_none_progress_callback_does_not_raise(
        self, orchestrator, httpx_mock
    ) -> None:
        """Passing None as on_progress does not raise."""
        httpx_mock.add_response(
            url="https://example.com/api/spec.json",
            text=_VALID_SPEC_JSON,
            headers={"content-type": "application/json"},
        )
        job = await orchestrator.start_import(
            url="https://example.com/api/spec.json",
            job_id="test-job-3",
            on_progress=None,
        )
        assert job.status == "done"
|
||||||
|
|
||||||
|
|
||||||
|
class TestImportOrchestratorFailures:
    """Tests for error handling in the import pipeline."""

    @staticmethod
    def _stub_invalid_spec(httpx_mock) -> None:
        # Every validation-failure test fetches the same invalid spec body.
        httpx_mock.add_response(
            url="https://example.com/api/bad.json",
            text=_INVALID_SPEC_JSON,
            headers={"content-type": "application/json"},
        )

    async def test_fetch_failure_sets_failed_status(self, orchestrator) -> None:
        """When HTTP fetch fails, job status is 'failed'."""
        # Resolving to a private address trips the SSRF guard before any fetch.
        with patch("app.openapi.ssrf.resolve_hostname", return_value=["10.0.0.1"]):
            job = await orchestrator.start_import(
                url="http://internal.corp/spec.json",
                job_id="test-job-fail-1",
                on_progress=None,
            )
        assert job.status == "failed"
        assert job.error_message is not None

    @pytest.mark.usefixtures("_mock_public_dns")
    async def test_validation_failure_sets_failed_status(
        self, orchestrator, httpx_mock
    ) -> None:
        """When spec validation fails, job status is 'failed'."""
        self._stub_invalid_spec(httpx_mock)
        job = await orchestrator.start_import(
            url="https://example.com/api/bad.json",
            job_id="test-job-fail-2",
            on_progress=None,
        )
        assert job.status == "failed"
        assert job.error_message is not None

    @pytest.mark.usefixtures("_mock_public_dns")
    async def test_error_message_is_descriptive(self, orchestrator, httpx_mock) -> None:
        """Error message contains useful context."""
        self._stub_invalid_spec(httpx_mock)
        job = await orchestrator.start_import(
            url="https://example.com/api/bad.json",
            job_id="test-job-fail-3",
            on_progress=None,
        )
        assert len(job.error_message) > 0

    @pytest.mark.usefixtures("_mock_public_dns")
    async def test_failed_status_progress_called_with_failed(
        self, orchestrator, httpx_mock
    ) -> None:
        """When pipeline fails, on_progress is called with 'failed' stage."""
        self._stub_invalid_spec(httpx_mock)
        observed_stages: list[str] = []

        def record_stage(stage: str, job: ImportJob) -> None:
            observed_stages.append(stage)

        await orchestrator.start_import(
            url="https://example.com/api/bad.json",
            job_id="test-job-fail-4",
            on_progress=record_stage,
        )
        assert "failed" in observed_stages
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def _mock_public_dns():
    """Resolve every hostname to a public IP so the SSRF guard passes."""
    with patch("app.openapi.ssrf.resolve_hostname", return_value=[_PUBLIC_IP]):
        yield
|
||||||
164
backend/tests/integration/test_openapi_api.py
Normal file
164
backend/tests/integration/test_openapi_api.py
Normal file
@@ -0,0 +1,164 @@
|
|||||||
|
"""Integration tests for /api/v1/openapi/ endpoints.
|
||||||
|
|
||||||
|
Tests the full API layer for the OpenAPI import review workflow,
|
||||||
|
including job creation, status retrieval, classification updates,
|
||||||
|
and approval triggering.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from httpx import ASGITransport, AsyncClient
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.integration
|
||||||
|
|
||||||
|
|
||||||
|
def _build_app():
|
||||||
|
"""Build a minimal FastAPI app with the openapi router and mocked deps."""
|
||||||
|
from fastapi import FastAPI, HTTPException
|
||||||
|
from fastapi.exceptions import RequestValidationError
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
|
from app.api_utils import envelope
|
||||||
|
from app.openapi.review_api import router as openapi_router
|
||||||
|
|
||||||
|
test_app = FastAPI()
|
||||||
|
test_app.include_router(openapi_router)
|
||||||
|
|
||||||
|
@test_app.exception_handler(HTTPException)
|
||||||
|
async def _http_exc(request, exc):
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=exc.status_code,
|
||||||
|
content=envelope(None, success=False, error=exc.detail),
|
||||||
|
)
|
||||||
|
|
||||||
|
@test_app.exception_handler(RequestValidationError)
|
||||||
|
async def _validation_exc(request, exc):
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=422,
|
||||||
|
content=envelope(None, success=False, error=str(exc)),
|
||||||
|
)
|
||||||
|
|
||||||
|
test_app.state.settings = MagicMock(admin_api_key="")
|
||||||
|
|
||||||
|
return test_app
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _clear_job_store():
|
||||||
|
"""Clear the in-memory job store between tests."""
|
||||||
|
from app.openapi.review_api import _job_store
|
||||||
|
|
||||||
|
_job_store.clear()
|
||||||
|
yield
|
||||||
|
_job_store.clear()
|
||||||
|
|
||||||
|
|
||||||
|
class TestImportEndpoint:
|
||||||
|
"""Tests for POST /api/v1/openapi/import."""
|
||||||
|
|
||||||
|
async def test_import_returns_202_with_job_id(self) -> None:
|
||||||
|
"""Starting an import returns 202 with a job_id."""
|
||||||
|
test_app = _build_app()
|
||||||
|
async with AsyncClient(
|
||||||
|
transport=ASGITransport(app=test_app), base_url="http://test"
|
||||||
|
) as client:
|
||||||
|
resp = await client.post(
|
||||||
|
"/api/v1/openapi/import",
|
||||||
|
json={"url": "https://example.com/api/spec.json"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert resp.status_code == 202
|
||||||
|
body = resp.json()
|
||||||
|
assert "job_id" in body
|
||||||
|
assert body["status"] == "pending"
|
||||||
|
assert body["spec_url"] == "https://example.com/api/spec.json"
|
||||||
|
|
||||||
|
async def test_import_invalid_url_returns_422(self) -> None:
|
||||||
|
"""POST with invalid URL (no http/https) returns 422."""
|
||||||
|
test_app = _build_app()
|
||||||
|
async with AsyncClient(
|
||||||
|
transport=ASGITransport(app=test_app), base_url="http://test"
|
||||||
|
) as client:
|
||||||
|
resp = await client.post(
|
||||||
|
"/api/v1/openapi/import",
|
||||||
|
json={"url": "ftp://example.com/spec.json"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert resp.status_code == 422
|
||||||
|
body = resp.json()
|
||||||
|
assert body["success"] is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestJobStatusEndpoint:
|
||||||
|
"""Tests for GET /api/v1/openapi/jobs/{job_id}."""
|
||||||
|
|
||||||
|
async def test_get_existing_job_returns_status(self) -> None:
|
||||||
|
"""Retrieving an existing job returns its status."""
|
||||||
|
from app.openapi.review_api import _job_store
|
||||||
|
|
||||||
|
_job_store["test-job-1"] = {
|
||||||
|
"job_id": "test-job-1",
|
||||||
|
"status": "done",
|
||||||
|
"spec_url": "https://example.com/spec.json",
|
||||||
|
"total_endpoints": 5,
|
||||||
|
"classified_count": 5,
|
||||||
|
"error_message": None,
|
||||||
|
"classifications": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
test_app = _build_app()
|
||||||
|
async with AsyncClient(
|
||||||
|
transport=ASGITransport(app=test_app), base_url="http://test"
|
||||||
|
) as client:
|
||||||
|
resp = await client.get("/api/v1/openapi/jobs/test-job-1")
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
body = resp.json()
|
||||||
|
assert body["job_id"] == "test-job-1"
|
||||||
|
assert body["status"] == "done"
|
||||||
|
assert body["total_endpoints"] == 5
|
||||||
|
|
||||||
|
async def test_get_unknown_job_returns_404(self) -> None:
|
||||||
|
"""Retrieving a non-existent job returns 404 error envelope."""
|
||||||
|
test_app = _build_app()
|
||||||
|
async with AsyncClient(
|
||||||
|
transport=ASGITransport(app=test_app), base_url="http://test"
|
||||||
|
) as client:
|
||||||
|
resp = await client.get("/api/v1/openapi/jobs/unknown-id-999")
|
||||||
|
|
||||||
|
assert resp.status_code == 404
|
||||||
|
body = resp.json()
|
||||||
|
assert body["success"] is False
|
||||||
|
assert "not found" in body["error"].lower()
|
||||||
|
|
||||||
|
|
||||||
|
class TestApproveEndpoint:
|
||||||
|
"""Tests for POST /api/v1/openapi/jobs/{job_id}/approve."""
|
||||||
|
|
||||||
|
async def test_approve_with_no_classifications_returns_400(self) -> None:
|
||||||
|
"""Approving a job with no classifications returns 400."""
|
||||||
|
from app.openapi.review_api import _job_store
|
||||||
|
|
||||||
|
_job_store["empty-job"] = {
|
||||||
|
"job_id": "empty-job",
|
||||||
|
"status": "done",
|
||||||
|
"spec_url": "https://example.com/spec.json",
|
||||||
|
"total_endpoints": 0,
|
||||||
|
"classified_count": 0,
|
||||||
|
"error_message": None,
|
||||||
|
"classifications": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
test_app = _build_app()
|
||||||
|
async with AsyncClient(
|
||||||
|
transport=ASGITransport(app=test_app), base_url="http://test"
|
||||||
|
) as client:
|
||||||
|
resp = await client.post("/api/v1/openapi/jobs/empty-job/approve")
|
||||||
|
|
||||||
|
assert resp.status_code == 400
|
||||||
|
body = resp.json()
|
||||||
|
assert body["success"] is False
|
||||||
|
assert "no classifications" in body["error"].lower()
|
||||||
512
backend/tests/integration/test_phase2_checkpoints.py
Normal file
512
backend/tests/integration/test_phase2_checkpoints.py
Normal file
@@ -0,0 +1,512 @@
|
|||||||
|
"""Phase 2 checkpoint acceptance tests.
|
||||||
|
|
||||||
|
Each test maps to one checkpoint criterion from DEVELOPMENT-PLAN.md:
|
||||||
|
1. "查询订单 1042" -> routes to order_lookup agent
|
||||||
|
2. "取消订单 1042 并给我一个 10% 折扣" -> sequential multi-agent execution
|
||||||
|
3. Ambiguous message -> fallback asks for clarification
|
||||||
|
4. Interrupt > 30 min TTL -> auto-cancel + retry prompt
|
||||||
|
5. Agent escalation -> Webhook POST succeeds (or logs after retries)
|
||||||
|
6. E-commerce template -> 4 pre-configured agents work
|
||||||
|
7. pytest --cov >= 80% (verified separately)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.callbacks import TokenUsageCallbackHandler
|
||||||
|
from app.escalation import EscalationPayload, NoOpEscalator, WebhookEscalator
|
||||||
|
from app.graph_context import GraphContext
|
||||||
|
from app.intent import ClassificationResult, IntentTarget, LLMIntentClassifier
|
||||||
|
from app.interrupt_manager import InterruptManager
|
||||||
|
from app.registry import AgentConfig, AgentRegistry
|
||||||
|
from app.session_manager import SessionManager
|
||||||
|
from app.ws_context import WebSocketContext
|
||||||
|
from app.ws_handler import dispatch_message
|
||||||
|
|
||||||
|
TEMPLATES_DIR = Path(__file__).parent.parent.parent / "templates"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class AsyncIterHelper:
|
||||||
|
def __init__(self, items: list) -> None:
|
||||||
|
self._items = list(items)
|
||||||
|
|
||||||
|
def __aiter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __anext__(self):
|
||||||
|
if not self._items:
|
||||||
|
raise StopAsyncIteration
|
||||||
|
return self._items.pop(0)
|
||||||
|
|
||||||
|
|
||||||
|
class FakeWS:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.sent: list[dict] = []
|
||||||
|
|
||||||
|
async def send_json(self, data: dict) -> None:
|
||||||
|
self.sent.append(data)
|
||||||
|
|
||||||
|
|
||||||
|
def _chunk(content: str, node: str) -> tuple:
|
||||||
|
c = MagicMock()
|
||||||
|
c.content = content
|
||||||
|
c.tool_calls = []
|
||||||
|
return (c, {"langgraph_node": node})
|
||||||
|
|
||||||
|
|
||||||
|
def _tool_chunk(name: str, args: dict, node: str) -> tuple:
|
||||||
|
c = MagicMock()
|
||||||
|
c.content = ""
|
||||||
|
c.tool_calls = [{"name": name, "args": args}]
|
||||||
|
return (c, {"langgraph_node": node})
|
||||||
|
|
||||||
|
|
||||||
|
def _state(*, interrupt: bool = False, data: dict | None = None):
|
||||||
|
s = MagicMock()
|
||||||
|
if interrupt:
|
||||||
|
obj = MagicMock()
|
||||||
|
obj.value = data or {"action": "cancel_order", "order_id": "1042"}
|
||||||
|
t = MagicMock()
|
||||||
|
t.interrupts = (obj,)
|
||||||
|
s.tasks = (t,)
|
||||||
|
else:
|
||||||
|
s.tasks = ()
|
||||||
|
return s
|
||||||
|
|
||||||
|
|
||||||
|
def _agent(name: str, desc: str, perm: str = "read") -> AgentConfig:
|
||||||
|
return AgentConfig(name=name, description=desc, permission=perm, tools=["fallback_respond"])
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Checkpoint 1: "查询订单 1042" -> 路由到订单查询 Agent
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@pytest.mark.integration
|
||||||
|
class TestCheckpoint1OrderQueryRouting:
|
||||||
|
"""Verify intent classifier routes order queries to order_lookup."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_order_query_classified_to_order_lookup(self) -> None:
|
||||||
|
expected = ClassificationResult(
|
||||||
|
intents=(
|
||||||
|
IntentTarget(agent_name="order_lookup", confidence=0.95, reasoning="order query"),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
mock_structured = MagicMock()
|
||||||
|
mock_structured.ainvoke = AsyncMock(return_value=expected)
|
||||||
|
mock_llm = MagicMock()
|
||||||
|
mock_llm.with_structured_output = MagicMock(return_value=mock_structured)
|
||||||
|
|
||||||
|
classifier = LLMIntentClassifier(mock_llm)
|
||||||
|
agents = (
|
||||||
|
_agent("order_lookup", "Looks up order status and tracking"),
|
||||||
|
_agent("order_actions", "Modifies orders", "write"),
|
||||||
|
_agent("discount", "Applies discounts", "write"),
|
||||||
|
_agent("fallback", "Handles unclear requests"),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await classifier.classify("查询订单 1042", agents)
|
||||||
|
assert len(result.intents) == 1
|
||||||
|
assert result.intents[0].agent_name == "order_lookup"
|
||||||
|
assert result.intents[0].confidence >= 0.9
|
||||||
|
assert not result.is_ambiguous
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_order_query_streams_from_order_lookup_agent(self) -> None:
|
||||||
|
"""Full dispatch: classify -> route -> stream from order_lookup."""
|
||||||
|
graph = MagicMock()
|
||||||
|
# Classifier returns order_lookup
|
||||||
|
mock_classifier = AsyncMock()
|
||||||
|
mock_classifier.classify = AsyncMock(return_value=ClassificationResult(
|
||||||
|
intents=(IntentTarget(agent_name="order_lookup", confidence=0.95, reasoning="query"),),
|
||||||
|
))
|
||||||
|
mock_registry = MagicMock()
|
||||||
|
mock_registry.list_agents = MagicMock(return_value=())
|
||||||
|
|
||||||
|
# Graph streams order_lookup response
|
||||||
|
graph.astream = MagicMock(return_value=AsyncIterHelper([
|
||||||
|
_tool_chunk("get_order_status", {"order_id": "1042"}, "order_lookup"),
|
||||||
|
_chunk("Order 1042 is shipped.", "order_lookup"),
|
||||||
|
]))
|
||||||
|
graph.aget_state = AsyncMock(return_value=_state())
|
||||||
|
|
||||||
|
graph_ctx = GraphContext(
|
||||||
|
graph=graph, registry=mock_registry, intent_classifier=mock_classifier,
|
||||||
|
)
|
||||||
|
sm = SessionManager()
|
||||||
|
sm.touch("t1")
|
||||||
|
im = InterruptManager()
|
||||||
|
cb = TokenUsageCallbackHandler()
|
||||||
|
ws = FakeWS()
|
||||||
|
ws_ctx = WebSocketContext(
|
||||||
|
graph_ctx=graph_ctx, session_manager=sm,
|
||||||
|
callback_handler=cb, interrupt_manager=im,
|
||||||
|
)
|
||||||
|
|
||||||
|
raw = json.dumps({"type": "message", "thread_id": "t1", "content": "查询订单 1042"})
|
||||||
|
await dispatch_message(ws, ws_ctx, raw)
|
||||||
|
|
||||||
|
tool_msgs = [m for m in ws.sent if m["type"] == "tool_call"]
|
||||||
|
assert any(m["tool"] == "get_order_status" for m in tool_msgs)
|
||||||
|
|
||||||
|
token_msgs = [m for m in ws.sent if m["type"] == "token"]
|
||||||
|
assert any(m["agent"] == "order_lookup" for m in token_msgs)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Checkpoint 2: Multi-intent -> sequential execution
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@pytest.mark.integration
|
||||||
|
class TestCheckpoint2MultiIntentSequential:
|
||||||
|
"""Verify multi-intent classified and hint injected for sequential execution."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_multi_intent_classification(self) -> None:
|
||||||
|
expected = ClassificationResult(
|
||||||
|
intents=(
|
||||||
|
IntentTarget(agent_name="order_actions", confidence=0.9, reasoning="cancel"),
|
||||||
|
IntentTarget(agent_name="discount", confidence=0.85, reasoning="discount"),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
mock_structured = MagicMock()
|
||||||
|
mock_structured.ainvoke = AsyncMock(return_value=expected)
|
||||||
|
mock_llm = MagicMock()
|
||||||
|
mock_llm.with_structured_output = MagicMock(return_value=mock_structured)
|
||||||
|
|
||||||
|
classifier = LLMIntentClassifier(mock_llm)
|
||||||
|
agents = (
|
||||||
|
_agent("order_actions", "Modifies orders", "write"),
|
||||||
|
_agent("discount", "Applies discounts", "write"),
|
||||||
|
_agent("fallback", "Handles unclear requests"),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await classifier.classify("取消订单 1042 并给我一个 10% 折扣", agents)
|
||||||
|
assert len(result.intents) == 2
|
||||||
|
assert result.intents[0].agent_name == "order_actions"
|
||||||
|
assert result.intents[1].agent_name == "discount"
|
||||||
|
assert not result.is_ambiguous
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_multi_intent_injects_routing_hint(self) -> None:
|
||||||
|
"""When multi-intent detected, a [System: ...] hint is appended to the message."""
|
||||||
|
graph = MagicMock()
|
||||||
|
mock_classifier = AsyncMock()
|
||||||
|
mock_classifier.classify = AsyncMock(return_value=ClassificationResult(
|
||||||
|
intents=(
|
||||||
|
IntentTarget(agent_name="order_actions", confidence=0.9, reasoning="cancel"),
|
||||||
|
IntentTarget(agent_name="discount", confidence=0.85, reasoning="discount"),
|
||||||
|
),
|
||||||
|
))
|
||||||
|
mock_registry = MagicMock()
|
||||||
|
mock_registry.list_agents = MagicMock(return_value=())
|
||||||
|
graph.astream = MagicMock(return_value=AsyncIterHelper([]))
|
||||||
|
graph.aget_state = AsyncMock(return_value=_state())
|
||||||
|
|
||||||
|
graph_ctx = GraphContext(
|
||||||
|
graph=graph, registry=mock_registry, intent_classifier=mock_classifier,
|
||||||
|
)
|
||||||
|
sm = SessionManager()
|
||||||
|
sm.touch("t1")
|
||||||
|
im = InterruptManager()
|
||||||
|
cb = TokenUsageCallbackHandler()
|
||||||
|
ws = FakeWS()
|
||||||
|
ws_ctx = WebSocketContext(
|
||||||
|
graph_ctx=graph_ctx, session_manager=sm,
|
||||||
|
callback_handler=cb, interrupt_manager=im,
|
||||||
|
)
|
||||||
|
|
||||||
|
raw = json.dumps({
|
||||||
|
"type": "message",
|
||||||
|
"thread_id": "t1",
|
||||||
|
"content": "取消订单 1042 并给我一个 10% 折扣",
|
||||||
|
})
|
||||||
|
await dispatch_message(ws, ws_ctx, raw)
|
||||||
|
|
||||||
|
# Verify the graph was called with the routing hint in the message
|
||||||
|
call_args = graph.astream.call_args
|
||||||
|
input_msg = call_args[0][0]
|
||||||
|
msg_content = input_msg["messages"][0].content
|
||||||
|
assert "[System:" in msg_content
|
||||||
|
assert "order_actions" in msg_content
|
||||||
|
assert "discount" in msg_content
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Checkpoint 3: Ambiguous message -> clarification
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@pytest.mark.integration
|
||||||
|
class TestCheckpoint3AmbiguousClarification:
|
||||||
|
"""Verify ambiguous messages trigger clarification prompt."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_ambiguous_intent_returns_clarification(self) -> None:
|
||||||
|
expected = ClassificationResult(
|
||||||
|
intents=(IntentTarget(agent_name="fallback", confidence=0.3, reasoning="unclear"),),
|
||||||
|
is_ambiguous=False, # low confidence will trigger ambiguity threshold
|
||||||
|
)
|
||||||
|
mock_structured = MagicMock()
|
||||||
|
mock_structured.ainvoke = AsyncMock(return_value=expected)
|
||||||
|
mock_llm = MagicMock()
|
||||||
|
mock_llm.with_structured_output = MagicMock(return_value=mock_structured)
|
||||||
|
|
||||||
|
classifier = LLMIntentClassifier(mock_llm)
|
||||||
|
agents = (_agent("order_lookup", "Orders"), _agent("fallback", "Fallback"))
|
||||||
|
|
||||||
|
result = await classifier.classify("嗯...", agents)
|
||||||
|
assert result.is_ambiguous
|
||||||
|
assert result.clarification_question is not None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_ambiguous_sends_clarification_via_websocket(self) -> None:
|
||||||
|
graph = MagicMock()
|
||||||
|
mock_classifier = AsyncMock()
|
||||||
|
mock_classifier.classify = AsyncMock(return_value=ClassificationResult(
|
||||||
|
intents=(),
|
||||||
|
is_ambiguous=True,
|
||||||
|
clarification_question=(
|
||||||
|
"Could you please provide more details about what you need help with?"
|
||||||
|
),
|
||||||
|
))
|
||||||
|
mock_registry = MagicMock()
|
||||||
|
mock_registry.list_agents = MagicMock(return_value=())
|
||||||
|
graph.astream = MagicMock(return_value=AsyncIterHelper([]))
|
||||||
|
graph.aget_state = AsyncMock(return_value=_state())
|
||||||
|
|
||||||
|
graph_ctx = GraphContext(
|
||||||
|
graph=graph, registry=mock_registry, intent_classifier=mock_classifier,
|
||||||
|
)
|
||||||
|
sm = SessionManager()
|
||||||
|
sm.touch("t1")
|
||||||
|
im = InterruptManager()
|
||||||
|
cb = TokenUsageCallbackHandler()
|
||||||
|
ws = FakeWS()
|
||||||
|
ws_ctx = WebSocketContext(
|
||||||
|
graph_ctx=graph_ctx, session_manager=sm,
|
||||||
|
callback_handler=cb, interrupt_manager=im,
|
||||||
|
)
|
||||||
|
|
||||||
|
raw = json.dumps({"type": "message", "thread_id": "t1", "content": "嗯..."})
|
||||||
|
await dispatch_message(ws, ws_ctx, raw)
|
||||||
|
|
||||||
|
clarifications = [m for m in ws.sent if m["type"] == "clarification"]
|
||||||
|
assert len(clarifications) == 1
|
||||||
|
assert "more details" in clarifications[0]["message"]
|
||||||
|
|
||||||
|
# Should NOT call graph.astream since we returned early
|
||||||
|
graph.astream.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Checkpoint 4: Interrupt > 30 min -> auto-cancel + retry
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@pytest.mark.integration
|
||||||
|
class TestCheckpoint4InterruptTTLAutoCancel:
|
||||||
|
"""Verify interrupt TTL expiration triggers auto-cancel and retry prompt."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_30min_expired_interrupt_auto_cancels(self) -> None:
|
||||||
|
st = _state(interrupt=True, data={"action": "cancel_order", "order_id": "1042"})
|
||||||
|
graph = MagicMock()
|
||||||
|
graph.astream = MagicMock(return_value=AsyncIterHelper([]))
|
||||||
|
graph.aget_state = AsyncMock(return_value=st)
|
||||||
|
|
||||||
|
mock_registry = MagicMock()
|
||||||
|
mock_registry.list_agents = MagicMock(return_value=())
|
||||||
|
graph_ctx = GraphContext(graph=graph, registry=mock_registry, intent_classifier=None)
|
||||||
|
|
||||||
|
sm = SessionManager()
|
||||||
|
sm.touch("t1")
|
||||||
|
im = InterruptManager(ttl_seconds=1800) # 30 minutes
|
||||||
|
cb = TokenUsageCallbackHandler()
|
||||||
|
ws = FakeWS()
|
||||||
|
ws_ctx = WebSocketContext(
|
||||||
|
graph_ctx=graph_ctx, session_manager=sm,
|
||||||
|
callback_handler=cb, interrupt_manager=im,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Trigger interrupt
|
||||||
|
raw = json.dumps({"type": "message", "thread_id": "t1", "content": "Cancel order 1042"})
|
||||||
|
await dispatch_message(ws, ws_ctx, raw)
|
||||||
|
|
||||||
|
interrupts = [m for m in ws.sent if m["type"] == "interrupt"]
|
||||||
|
assert len(interrupts) == 1
|
||||||
|
|
||||||
|
# Simulate 31 minutes passing
|
||||||
|
record = im._interrupts["t1"]
|
||||||
|
ws.sent.clear()
|
||||||
|
|
||||||
|
with patch("app.interrupt_manager.time") as mock_time:
|
||||||
|
mock_time.time.return_value = record.created_at + 1860 # 31 min
|
||||||
|
|
||||||
|
raw = json.dumps({
|
||||||
|
"type": "interrupt_response",
|
||||||
|
"thread_id": "t1",
|
||||||
|
"approved": True,
|
||||||
|
})
|
||||||
|
await dispatch_message(ws, ws_ctx, raw)
|
||||||
|
|
||||||
|
# Should get retry prompt, NOT resume the graph
|
||||||
|
expired_msgs = [m for m in ws.sent if m["type"] == "interrupt_expired"]
|
||||||
|
assert len(expired_msgs) == 1
|
||||||
|
assert "30 minutes" in expired_msgs[0]["message"]
|
||||||
|
assert expired_msgs[0]["action"] == "cancel_order"
|
||||||
|
assert expired_msgs[0]["thread_id"] == "t1"
|
||||||
|
|
||||||
|
def test_cleanup_expired_returns_records(self) -> None:
|
||||||
|
im = InterruptManager(ttl_seconds=1800)
|
||||||
|
im.register("t1", "cancel_order", {"order_id": "1042"})
|
||||||
|
im.register("t2", "apply_discount", {"order_id": "1043"})
|
||||||
|
|
||||||
|
with patch("app.interrupt_manager.time") as mock_time:
|
||||||
|
record = im._interrupts["t1"]
|
||||||
|
mock_time.time.return_value = record.created_at + 1860
|
||||||
|
expired = im.cleanup_expired()
|
||||||
|
|
||||||
|
assert len(expired) == 2
|
||||||
|
actions = {r.action for r in expired}
|
||||||
|
assert "cancel_order" in actions
|
||||||
|
assert "apply_discount" in actions
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Checkpoint 5: Agent escalation -> Webhook POST
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@pytest.mark.integration
|
||||||
|
class TestCheckpoint5WebhookEscalation:
|
||||||
|
"""Verify webhook escalation sends POST and retries on failure."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_webhook_post_success(self) -> None:
|
||||||
|
mock_response = AsyncMock()
|
||||||
|
mock_response.status_code = 200
|
||||||
|
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.post = AsyncMock(return_value=mock_response)
|
||||||
|
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||||
|
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
|
||||||
|
with patch("app.escalation.httpx.AsyncClient", return_value=mock_client):
|
||||||
|
escalator = WebhookEscalator(url="https://support.example.com/escalate")
|
||||||
|
payload = EscalationPayload(
|
||||||
|
thread_id="t1",
|
||||||
|
reason="Agent cannot resolve customer issue",
|
||||||
|
conversation_summary="Customer asked about refund policy",
|
||||||
|
metadata={"customer_id": "C-123"},
|
||||||
|
)
|
||||||
|
result = await escalator.escalate(payload)
|
||||||
|
|
||||||
|
assert result.success
|
||||||
|
assert result.status_code == 200
|
||||||
|
assert result.attempts == 1
|
||||||
|
|
||||||
|
# Verify POST was called with correct payload
|
||||||
|
call_args = mock_client.post.call_args
|
||||||
|
assert call_args[0][0] == "https://support.example.com/escalate"
|
||||||
|
posted_data = call_args[1]["json"]
|
||||||
|
assert posted_data["thread_id"] == "t1"
|
||||||
|
assert posted_data["reason"] == "Agent cannot resolve customer issue"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_webhook_retries_then_logs(self) -> None:
|
||||||
|
fail_response = AsyncMock()
|
||||||
|
fail_response.status_code = 503
|
||||||
|
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.post = AsyncMock(return_value=fail_response)
|
||||||
|
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||||
|
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("app.escalation.httpx.AsyncClient", return_value=mock_client),
|
||||||
|
patch("app.escalation.asyncio.sleep", new_callable=AsyncMock),
|
||||||
|
):
|
||||||
|
escalator = WebhookEscalator(
|
||||||
|
url="https://support.example.com/escalate",
|
||||||
|
max_retries=3,
|
||||||
|
)
|
||||||
|
payload = EscalationPayload(
|
||||||
|
thread_id="t1",
|
||||||
|
reason="Escalation needed",
|
||||||
|
conversation_summary="Summary",
|
||||||
|
)
|
||||||
|
result = await escalator.escalate(payload)
|
||||||
|
|
||||||
|
assert not result.success
|
||||||
|
assert result.attempts == 3
|
||||||
|
assert "503" in result.error
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_noop_escalator_when_disabled(self) -> None:
|
||||||
|
escalator = NoOpEscalator()
|
||||||
|
payload = EscalationPayload(
|
||||||
|
thread_id="t1",
|
||||||
|
reason="Test",
|
||||||
|
conversation_summary="Test",
|
||||||
|
)
|
||||||
|
result = await escalator.escalate(payload)
|
||||||
|
assert not result.success
|
||||||
|
assert "disabled" in result.error.lower()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Checkpoint 6: E-commerce template -> pre-configured agents
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@pytest.mark.integration
|
||||||
|
class TestCheckpoint6EcommerceTemplate:
|
||||||
|
"""Verify e-commerce template loads with correct agents."""
|
||||||
|
|
||||||
|
def test_ecommerce_template_loads_4_agents(self) -> None:
|
||||||
|
registry = AgentRegistry.load_template("e-commerce", TEMPLATES_DIR)
|
||||||
|
assert len(registry) == 4
|
||||||
|
|
||||||
|
def test_ecommerce_template_has_correct_agents(self) -> None:
|
||||||
|
registry = AgentRegistry.load_template("e-commerce", TEMPLATES_DIR)
|
||||||
|
agents = registry.list_agents()
|
||||||
|
names = {a.name for a in agents}
|
||||||
|
assert names == {"order_lookup", "order_actions", "discount", "fallback"}
|
||||||
|
|
||||||
|
def test_ecommerce_order_lookup_is_read(self) -> None:
|
||||||
|
registry = AgentRegistry.load_template("e-commerce", TEMPLATES_DIR)
|
||||||
|
agent = registry.get_agent("order_lookup")
|
||||||
|
assert agent.permission == "read"
|
||||||
|
assert "get_order_status" in agent.tools
|
||||||
|
assert "get_tracking_info" in agent.tools
|
||||||
|
|
||||||
|
def test_ecommerce_order_actions_is_write(self) -> None:
|
||||||
|
registry = AgentRegistry.load_template("e-commerce", TEMPLATES_DIR)
|
||||||
|
agent = registry.get_agent("order_actions")
|
||||||
|
assert agent.permission == "write"
|
||||||
|
assert "cancel_order" in agent.tools
|
||||||
|
|
||||||
|
def test_ecommerce_discount_is_write(self) -> None:
|
||||||
|
registry = AgentRegistry.load_template("e-commerce", TEMPLATES_DIR)
|
||||||
|
agent = registry.get_agent("discount")
|
||||||
|
assert agent.permission == "write"
|
||||||
|
assert "apply_discount" in agent.tools
|
||||||
|
assert "generate_coupon" in agent.tools
|
||||||
|
|
||||||
|
def test_ecommerce_fallback_is_read(self) -> None:
|
||||||
|
registry = AgentRegistry.load_template("e-commerce", TEMPLATES_DIR)
|
||||||
|
agent = registry.get_agent("fallback")
|
||||||
|
assert agent.permission == "read"
|
||||||
|
|
||||||
|
def test_all_three_templates_available(self) -> None:
|
||||||
|
templates = AgentRegistry.list_templates(TEMPLATES_DIR)
|
||||||
|
assert "e-commerce" in templates
|
||||||
|
assert "saas" in templates
|
||||||
|
assert "fintech" in templates
|
||||||
213
backend/tests/integration/test_replay_api.py
Normal file
213
backend/tests/integration/test_replay_api.py
Normal file
@@ -0,0 +1,213 @@
|
|||||||
|
"""Integration tests for /api/v1/conversations and /api/v1/replay/{thread_id}.
|
||||||
|
|
||||||
|
Tests the full API layer with a mocked database pool, verifying routing,
|
||||||
|
serialization, pagination, and error handling in envelope format.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from httpx import ASGITransport, AsyncClient
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.integration
|
||||||
|
|
||||||
|
|
||||||
|
def _make_fake_cursor(rows, *, fetchone_value=None):
|
||||||
|
"""Build a fake async cursor returning the given rows on fetchall."""
|
||||||
|
cursor = AsyncMock()
|
||||||
|
cursor.fetchall = AsyncMock(return_value=rows)
|
||||||
|
if fetchone_value is not None:
|
||||||
|
cursor.fetchone = AsyncMock(return_value=fetchone_value)
|
||||||
|
return cursor
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeConnection:
|
||||||
|
"""Fake async connection that returns pre-configured cursors in order."""
|
||||||
|
|
||||||
|
def __init__(self, cursors: list) -> None:
|
||||||
|
self._cursors = list(cursors)
|
||||||
|
self._idx = 0
|
||||||
|
|
||||||
|
async def execute(self, sql, params=None):
|
||||||
|
cursor = self._cursors[self._idx]
|
||||||
|
self._idx += 1
|
||||||
|
return cursor
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, *args):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class _FakePool:
|
||||||
|
"""Fake connection pool that yields a fake connection."""
|
||||||
|
|
||||||
|
def __init__(self, conn: _FakeConnection) -> None:
|
||||||
|
self._conn = conn
|
||||||
|
|
||||||
|
def connection(self):
|
||||||
|
return self._conn
|
||||||
|
|
||||||
|
|
||||||
|
def _build_app(pool=None):
    """Build a minimal FastAPI app with the replay router and mocked deps.

    Args:
        pool: Optional fake connection pool stored on ``app.state.pool``;
            defaults to a bare MagicMock for tests that never hit the DB.

    Returns:
        A FastAPI instance with the replay router mounted and envelope-style
        handlers for HTTPException and request-validation errors.
    """
    # Function-local imports keep FastAPI/app modules off the module import
    # path until a test actually builds an app.
    from fastapi import FastAPI, HTTPException
    from fastapi.exceptions import RequestValidationError
    from fastapi.responses import JSONResponse

    from app.api_utils import envelope
    from app.replay.api import router as replay_router

    test_app = FastAPI()
    test_app.include_router(replay_router)

    # Wrap HTTP errors in the same {success, data, error} envelope the
    # assertions below check for.
    @test_app.exception_handler(HTTPException)
    async def _http_exc(request, exc):
        return JSONResponse(
            status_code=exc.status_code,
            content=envelope(None, success=False, error=exc.detail),
        )

    # Validation failures (e.g. page=0) also come back as error envelopes.
    @test_app.exception_handler(RequestValidationError)
    async def _validation_exc(request, exc):
        return JSONResponse(
            status_code=422,
            content=envelope(None, success=False, error=str(exc)),
        )

    # Empty admin_api_key -- presumably disables admin auth on these routes;
    # confirm against the replay router's auth dependency.
    test_app.state.settings = MagicMock(admin_api_key="")
    test_app.state.pool = pool or MagicMock()

    return test_app
|
||||||
|
|
||||||
|
|
||||||
|
class TestListConversations:
    """Tests for GET /api/v1/conversations endpoint."""

    async def test_returns_paginated_envelope(self) -> None:
        """Conversations list returns envelope with pagination metadata."""
        # Cursors are consumed in execute() order: the first presumably
        # serves the COUNT query, the second the page listing -- matches the
        # router's query order; confirm if the router changes.
        count_cursor = _make_fake_cursor([], fetchone_value=(3,))
        rows = [
            {"thread_id": "t1", "created_at": "2026-01-01", "last_activity": "2026-01-01",
             "status": "active", "total_tokens": 100, "total_cost_usd": 0.01},
            {"thread_id": "t2", "created_at": "2026-01-02", "last_activity": "2026-01-02",
             "status": "resolved", "total_tokens": 200, "total_cost_usd": 0.02},
        ]
        list_cursor = _make_fake_cursor(rows)
        conn = _FakeConnection([count_cursor, list_cursor])
        pool = _FakePool(conn)
        test_app = _build_app(pool)

        async with AsyncClient(
            transport=ASGITransport(app=test_app), base_url="http://test"
        ) as client:
            resp = await client.get("/api/v1/conversations")

        assert resp.status_code == 200
        body = resp.json()
        assert body["success"] is True
        # total comes from the count cursor; conversations from the row set.
        assert body["data"]["total"] == 3
        assert len(body["data"]["conversations"]) == 2
        # Defaults when no query params are sent.
        assert body["data"]["page"] == 1
        assert body["data"]["per_page"] == 20

    async def test_custom_page_and_per_page(self) -> None:
        """Custom page/per_page params are reflected in the response."""
        count_cursor = _make_fake_cursor([], fetchone_value=(50,))
        list_cursor = _make_fake_cursor([])
        conn = _FakeConnection([count_cursor, list_cursor])
        pool = _FakePool(conn)
        test_app = _build_app(pool)

        async with AsyncClient(
            transport=ASGITransport(app=test_app), base_url="http://test"
        ) as client:
            resp = await client.get("/api/v1/conversations", params={"page": 3, "per_page": 10})

        assert resp.status_code == 200
        body = resp.json()
        assert body["data"]["page"] == 3
        assert body["data"]["per_page"] == 10

    async def test_invalid_page_returns_422(self) -> None:
        """page=0 violates ge=1 constraint and returns 422 error envelope."""
        # No pool needed: validation rejects the request before any DB work.
        test_app = _build_app()
        async with AsyncClient(
            transport=ASGITransport(app=test_app), base_url="http://test"
        ) as client:
            resp = await client.get("/api/v1/conversations", params={"page": 0})

        assert resp.status_code == 422
        body = resp.json()
        assert body["success"] is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestReplayEndpoint:
    """Tests for GET /api/v1/replay/{thread_id} endpoint."""

    async def test_valid_thread_returns_timeline(self) -> None:
        """Replay with valid thread_id returns steps in envelope format."""
        # One checkpoint row whose channel_values carry a human/ai message
        # pair; the endpoint is expected to flatten these into steps.
        checkpoint_rows = [
            {
                "thread_id": "abc123",
                "checkpoint_id": "cp1",
                "checkpoint": {
                    "channel_values": {
                        "messages": [
                            {"type": "human", "content": "Hello", "created_at": "2026-01-01T00:00:00Z"},
                            {"type": "ai", "content": "Hi there!", "created_at": "2026-01-01T00:00:01Z"},
                        ]
                    }
                },
                "metadata": {},
            }
        ]
        cursor = _make_fake_cursor(checkpoint_rows)
        conn = _FakeConnection([cursor])
        pool = _FakePool(conn)
        test_app = _build_app(pool)

        async with AsyncClient(
            transport=ASGITransport(app=test_app), base_url="http://test"
        ) as client:
            resp = await client.get("/api/v1/replay/abc123")

        assert resp.status_code == 200
        body = resp.json()
        assert body["success"] is True
        assert body["data"]["thread_id"] == "abc123"
        assert body["data"]["total_steps"] == 2
        assert len(body["data"]["steps"]) == 2
        # human -> user_message, ai -> agent_response mapping.
        assert body["data"]["steps"][0]["type"] == "user_message"
        assert body["data"]["steps"][1]["type"] == "agent_response"

    async def test_invalid_thread_id_format_returns_400(self) -> None:
        """Thread IDs with path traversal characters are rejected with 400."""
        test_app = _build_app()
        async with AsyncClient(
            transport=ASGITransport(app=test_app), base_url="http://test"
        ) as client:
            resp = await client.get("/api/v1/replay/../../etc/passwd")

        # FastAPI may return 400 from our handler or 404 from routing
        assert resp.status_code in (400, 404, 422)

    async def test_nonexistent_thread_returns_404(self) -> None:
        """Replay with a thread_id that has no checkpoints returns 404."""
        # Empty cursor -> no checkpoints for this thread.
        cursor = _make_fake_cursor([])
        conn = _FakeConnection([cursor])
        pool = _FakePool(conn)
        test_app = _build_app(pool)

        async with AsyncClient(
            transport=ASGITransport(app=test_app), base_url="http://test"
        ) as client:
            resp = await client.get("/api/v1/replay/nonexistent-thread")

        assert resp.status_code == 404
        body = resp.json()
        assert body["success"] is False
        assert "not found" in body["error"].lower()
|
||||||
371
backend/tests/integration/test_routing.py
Normal file
371
backend/tests/integration/test_routing.py
Normal file
@@ -0,0 +1,371 @@
|
|||||||
|
"""Integration tests for multi-agent routing flow.
|
||||||
|
|
||||||
|
Tests the full pipeline: intent classification -> supervisor routing ->
|
||||||
|
agent execution -> response streaming, exercising cross-module integration
|
||||||
|
with mocked LLM.
|
||||||
|
|
||||||
|
Required by Phase 2 test plan:
|
||||||
|
- Unit: intent classification accuracy
|
||||||
|
- Unit: multi-intent sequential execution
|
||||||
|
- Integration: complete multi-agent routing flow
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.callbacks import TokenUsageCallbackHandler
|
||||||
|
from app.graph_context import GraphContext
|
||||||
|
from app.intent import ClassificationResult, IntentTarget, LLMIntentClassifier
|
||||||
|
from app.interrupt_manager import InterruptManager
|
||||||
|
from app.registry import AgentConfig
|
||||||
|
from app.session_manager import SessionManager
|
||||||
|
from app.ws_context import WebSocketContext
|
||||||
|
from app.ws_handler import dispatch_message
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class AsyncIterHelper:
    """Expose a plain list through the async-iterator protocol."""

    def __init__(self, items: list) -> None:
        # Defensive copy: consuming the iterator must not mutate the input.
        self._queue = list(items)

    def __aiter__(self):
        return self

    async def __anext__(self):
        try:
            return self._queue.pop(0)
        except IndexError:
            raise StopAsyncIteration from None
|
||||||
|
|
||||||
|
|
||||||
|
class FakeWS:
    """WebSocket double: every ``send_json`` payload is recorded in ``sent``."""

    def __init__(self) -> None:
        # Public on purpose -- tests assert directly against this list.
        self.sent: list[dict] = []

    async def send_json(self, data: dict) -> None:
        self.sent.append(data)
|
||||||
|
|
||||||
|
|
||||||
|
def _chunk(content: str, node: str) -> tuple:
|
||||||
|
c = MagicMock()
|
||||||
|
c.content = content
|
||||||
|
c.tool_calls = []
|
||||||
|
return (c, {"langgraph_node": node})
|
||||||
|
|
||||||
|
|
||||||
|
def _tool_chunk(name: str, args: dict, node: str) -> tuple:
|
||||||
|
c = MagicMock()
|
||||||
|
c.content = ""
|
||||||
|
c.tool_calls = [{"name": name, "args": args}]
|
||||||
|
return (c, {"langgraph_node": node})
|
||||||
|
|
||||||
|
|
||||||
|
def _state(*, interrupt: bool = False, data: dict | None = None):
|
||||||
|
s = MagicMock()
|
||||||
|
if interrupt:
|
||||||
|
obj = MagicMock()
|
||||||
|
obj.value = data or {}
|
||||||
|
t = MagicMock()
|
||||||
|
t.interrupts = (obj,)
|
||||||
|
s.tasks = (t,)
|
||||||
|
else:
|
||||||
|
s.tasks = ()
|
||||||
|
return s
|
||||||
|
|
||||||
|
|
||||||
|
# Fixed agent roster shared by the routing tests: two agents with "read"
# permission (order_lookup, fallback) and two with "write" permission
# (order_actions, discount), each exposing its tool names.
AGENTS = (
    AgentConfig(
        name="order_lookup", description="Looks up orders",
        permission="read", tools=["get_order_status", "get_tracking_info"],
    ),
    AgentConfig(
        name="order_actions", description="Modifies orders",
        permission="write", tools=["cancel_order"],
    ),
    AgentConfig(
        name="discount", description="Applies discounts",
        permission="write", tools=["apply_discount", "generate_coupon"],
    ),
    AgentConfig(
        name="fallback", description="Handles unclear requests",
        permission="read", tools=["fallback_respond"],
    ),
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_classifier(result: ClassificationResult) -> AsyncMock:
|
||||||
|
"""Create a mock classifier returning the given result."""
|
||||||
|
classifier = AsyncMock()
|
||||||
|
classifier.classify = AsyncMock(return_value=result)
|
||||||
|
return classifier
|
||||||
|
|
||||||
|
|
||||||
|
def _make_graph_and_ctx(
    classifier_result: ClassificationResult | None,
    chunks: list,
    state=None,
) -> tuple[MagicMock, GraphContext]:
    """Build a graph mock and GraphContext with optional intent classifier.

    Args:
        classifier_result: Result the mocked intent classifier should return,
            or None to build a context without a classifier (the registry then
            exposes no agents, exercising supervisor-prompt fallback).
        chunks: (message, metadata) tuples streamed by ``graph.astream``.
        state: Optional snapshot returned by ``graph.aget_state``; defaults to
            a non-interrupted state.

    Returns:
        Tuple of (graph mock, GraphContext).
    """
    graph = MagicMock()
    # astream streams the canned chunks once; aget_state reports the final
    # (optionally interrupted) snapshot.
    graph.astream = MagicMock(return_value=AsyncIterHelper(list(chunks)))
    graph.aget_state = AsyncMock(return_value=state or _state())

    # The two branches previously duplicated the registry construction; build
    # it once and vary only the classifier and the visible agent roster.
    has_classifier = classifier_result is not None
    classifier = _make_classifier(classifier_result) if has_classifier else None
    mock_registry = MagicMock()
    mock_registry.list_agents = MagicMock(return_value=AGENTS if has_classifier else ())
    graph_ctx = GraphContext(
        graph=graph, registry=mock_registry, intent_classifier=classifier,
    )

    return graph, graph_ctx
|
||||||
|
|
||||||
|
|
||||||
|
async def _dispatch(graph_ctx: GraphContext, content: str, thread_id: str = "t1") -> list[dict]:
    """Run one user message through dispatch_message and return sent frames.

    Builds fresh session/interrupt managers and a recording FakeWS per call,
    so individual tests stay independent of one another.
    """
    sm = SessionManager()
    # Pre-create the session so dispatch sees an active thread.
    sm.touch(thread_id)
    im = InterruptManager()
    cb = TokenUsageCallbackHandler()
    ws = FakeWS()
    ws_ctx = WebSocketContext(
        graph_ctx=graph_ctx, session_manager=sm,
        callback_handler=cb, interrupt_manager=im,
    )
    # dispatch_message takes the raw JSON frame exactly as a client sends it.
    raw = json.dumps({"type": "message", "thread_id": thread_id, "content": content})
    await dispatch_message(ws, ws_ctx, raw)
    return ws.sent
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Single-intent routing to each agent
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@pytest.mark.integration
class TestSingleIntentRouting:
    """Verify single-intent messages route to the correct agent."""

    @pytest.mark.asyncio
    async def test_routes_to_order_lookup(self) -> None:
        """A status query routes to order_lookup and streams its reply."""
        result = ClassificationResult(
            intents=(IntentTarget(
                agent_name="order_lookup", confidence=0.95, reasoning="status query",
            ),),
        )
        graph, graph_ctx = _make_graph_and_ctx(result, [
            _tool_chunk("get_order_status", {"order_id": "1042"}, "order_lookup"),
            _chunk("Order 1042 is shipped.", "order_lookup"),
        ])

        msgs = await _dispatch(graph_ctx, "What is the status of order 1042?")

        tools = [m for m in msgs if m["type"] == "tool_call"]
        assert len(tools) == 1
        assert tools[0]["tool"] == "get_order_status"
        assert tools[0]["agent"] == "order_lookup"

        tokens = [m for m in msgs if m["type"] == "token"]
        assert any("shipped" in t["content"] for t in tokens)

    @pytest.mark.asyncio
    async def test_routes_to_order_actions(self) -> None:
        """A cancel request routes to order_actions and raises an interrupt."""
        result = ClassificationResult(
            intents=(IntentTarget(agent_name="order_actions", confidence=0.9, reasoning="cancel"),),
        )
        graph, graph_ctx = _make_graph_and_ctx(
            result,
            [_tool_chunk("cancel_order", {"order_id": "1042"}, "order_actions")],
            state=_state(interrupt=True, data={"action": "cancel_order", "order_id": "1042"}),
        )

        msgs = await _dispatch(graph_ctx, "Cancel order 1042")

        tools = [m for m in msgs if m["type"] == "tool_call"]
        assert tools[0]["tool"] == "cancel_order"
        assert tools[0]["agent"] == "order_actions"

        # Write-permission action -> an interrupt frame is surfaced.
        interrupts = [m for m in msgs if m["type"] == "interrupt"]
        assert len(interrupts) == 1

    @pytest.mark.asyncio
    async def test_routes_to_discount(self) -> None:
        """A coupon request routes to the discount agent."""
        result = ClassificationResult(
            intents=(IntentTarget(agent_name="discount", confidence=0.88, reasoning="coupon"),),
        )
        graph, graph_ctx = _make_graph_and_ctx(result, [
            _tool_chunk("generate_coupon", {"discount_percent": 15, "expiry_days": 30}, "discount"),
            _chunk("Here is your coupon: SAVE15-ABC12345", "discount"),
        ])

        msgs = await _dispatch(graph_ctx, "Give me a 15% coupon")

        tools = [m for m in msgs if m["type"] == "tool_call"]
        assert tools[0]["tool"] == "generate_coupon"
        assert tools[0]["agent"] == "discount"

    @pytest.mark.asyncio
    async def test_routes_to_fallback(self) -> None:
        """A general question routes to the fallback agent."""
        result = ClassificationResult(
            intents=(IntentTarget(agent_name="fallback", confidence=0.7, reasoning="general"),),
        )
        graph, graph_ctx = _make_graph_and_ctx(result, [
            _chunk("I can help with order inquiries.", "fallback"),
        ])

        msgs = await _dispatch(graph_ctx, "What can you do?")

        tokens = [m for m in msgs if m["type"] == "token"]
        assert tokens[0]["agent"] == "fallback"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Multi-intent routing
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@pytest.mark.integration
class TestMultiIntentRouting:
    """Verify multi-intent triggers sequential execution hint."""

    @pytest.mark.asyncio
    async def test_two_intents_inject_routing_hint(self) -> None:
        """Two classified intents inject a [System: ...] routing hint."""
        result = ClassificationResult(
            intents=(
                IntentTarget(agent_name="order_actions", confidence=0.9, reasoning="cancel"),
                IntentTarget(agent_name="discount", confidence=0.85, reasoning="discount"),
            ),
        )
        graph, graph_ctx = _make_graph_and_ctx(result, [
            _tool_chunk("cancel_order", {"order_id": "1042"}, "order_actions"),
            _tool_chunk("apply_discount", {"order_id": "1042", "discount_percent": 10}, "discount"),
        ])

        # Wired by hand (not via _dispatch) because the test needs the graph
        # mock's call_args afterwards.
        sm = SessionManager()
        sm.touch("t1")
        im = InterruptManager()
        cb = TokenUsageCallbackHandler()
        ws = FakeWS()
        ws_ctx = WebSocketContext(
            graph_ctx=graph_ctx, session_manager=sm,
            callback_handler=cb, interrupt_manager=im,
        )

        raw = json.dumps({
            "type": "message",
            "thread_id": "t1",
            "content": "取消订单 1042 并给我一个 10% 折扣",
        })
        await dispatch_message(ws, ws_ctx, raw)

        # Verify routing hint was injected
        call_args = graph.astream.call_args[0][0]
        msg_content = call_args["messages"][0].content
        assert "[System:" in msg_content
        assert "order_actions" in msg_content
        assert "discount" in msg_content

        # Both tool calls should appear
        tools = [m for m in ws.sent if m["type"] == "tool_call"]
        tool_names = {t["tool"] for t in tools}
        assert "cancel_order" in tool_names
        assert "apply_discount" in tool_names

    @pytest.mark.asyncio
    async def test_single_intent_no_routing_hint(self) -> None:
        """A single classified intent leaves the user message unmodified."""
        result = ClassificationResult(
            intents=(IntentTarget(agent_name="order_lookup", confidence=0.95, reasoning="query"),),
        )
        graph, graph_ctx = _make_graph_and_ctx(result, [_chunk("Order shipped.", "order_lookup")])

        sm = SessionManager()
        sm.touch("t1")
        im = InterruptManager()
        cb = TokenUsageCallbackHandler()
        ws = FakeWS()
        ws_ctx = WebSocketContext(
            graph_ctx=graph_ctx, session_manager=sm,
            callback_handler=cb, interrupt_manager=im,
        )

        raw = json.dumps({"type": "message", "thread_id": "t1", "content": "查询订单 1042"})
        await dispatch_message(ws, ws_ctx, raw)

        # No hint prefix for a single intent.
        msg_content = graph.astream.call_args[0][0]["messages"][0].content
        assert "[System:" not in msg_content
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Ambiguity routing
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@pytest.mark.integration
class TestAmbiguityRouting:
    """Verify ambiguous intents produce clarification, not agent calls."""

    @pytest.mark.asyncio
    async def test_ambiguous_skips_graph_returns_clarification(self) -> None:
        """An ambiguous classification short-circuits before graph execution."""
        result = ClassificationResult(
            intents=(),
            is_ambiguous=True,
            clarification_question="Could you please clarify what you need?",
        )
        graph, graph_ctx = _make_graph_and_ctx(result, [])

        msgs = await _dispatch(graph_ctx, "嗯...")

        clarifications = [m for m in msgs if m["type"] == "clarification"]
        assert len(clarifications) == 1
        assert "clarify" in clarifications[0]["message"]

        # Graph should NOT have been called
        graph.astream.assert_not_called()

    @pytest.mark.asyncio
    async def test_low_confidence_triggers_ambiguity(self) -> None:
        """LLMIntentClassifier applies threshold -- low confidence -> ambiguous."""
        # The raw LLM result claims non-ambiguity but with confidence 0.2;
        # the classifier's threshold should override it.
        raw_result = ClassificationResult(
            intents=(IntentTarget(agent_name="fallback", confidence=0.2, reasoning="unclear"),),
            is_ambiguous=False,
        )
        mock_structured = MagicMock()
        mock_structured.ainvoke = AsyncMock(return_value=raw_result)
        mock_llm = MagicMock()
        mock_llm.with_structured_output = MagicMock(return_value=mock_structured)

        classifier = LLMIntentClassifier(mock_llm)
        result = await classifier.classify("hmm", AGENTS)

        assert result.is_ambiguous
        assert result.clarification_question is not None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# No classifier fallback
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@pytest.mark.integration
class TestNoClassifierFallback:
    """Verify system works without intent classifier (falls back to supervisor prompt)."""

    @pytest.mark.asyncio
    async def test_no_classifier_routes_via_supervisor(self) -> None:
        """With no classifier configured, dispatch still streams a response."""
        graph, graph_ctx = _make_graph_and_ctx(
            classifier_result=None,
            chunks=[_chunk("Order 1042 is shipped.", "order_lookup")],
        )

        msgs = await _dispatch(graph_ctx, "What is order 1042 status?")

        # One streamed token and one completion frame -- the normal happy path.
        tokens = [m for m in msgs if m["type"] == "token"]
        assert len(tokens) == 1
        completes = [m for m in msgs if m["type"] == "message_complete"]
        assert len(completes) == 1
|
||||||
159
backend/tests/integration/test_session_interrupt_lifecycle.py
Normal file
159
backend/tests/integration/test_session_interrupt_lifecycle.py
Normal file
@@ -0,0 +1,159 @@
|
|||||||
|
"""Integration tests for SessionManager + InterruptManager lifecycle.
|
||||||
|
|
||||||
|
These tests exercise the in-memory managers together, verifying the full
|
||||||
|
lifecycle of sessions and interrupts: creation, TTL sliding, interrupt
|
||||||
|
registration/resolution, and expired-interrupt cleanup.
|
||||||
|
|
||||||
|
No database required -- both managers are in-memory.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.interrupt_manager import InterruptManager
|
||||||
|
from app.session_manager import SessionManager
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.integration
|
||||||
|
|
||||||
|
|
||||||
|
class TestSessionInterruptLifecycle:
    """Tests for the combined session + interrupt lifecycle."""

    def test_create_session_register_interrupt_check_status(self) -> None:
        """Full lifecycle: create session, register interrupt, verify both states."""
        sm = SessionManager(session_ttl_seconds=3600)
        im = InterruptManager(ttl_seconds=300)

        # Create a session
        state = sm.touch("thread-1")
        assert state.thread_id == "thread-1"
        assert not state.has_pending_interrupt
        assert not sm.is_expired("thread-1")

        # Register an interrupt (the returned record is not needed here)
        im.register("thread-1", "cancel_order", {"order_id": "1042"})
        sm.extend_for_interrupt("thread-1")

        assert im.has_pending("thread-1")
        session_state = sm.get_state("thread-1")
        assert session_state is not None
        assert session_state.has_pending_interrupt

        # Session should not expire while interrupt is pending
        assert not sm.is_expired("thread-1")

    def test_interrupt_expiry_after_ttl(self) -> None:
        """Interrupt expires when TTL elapses, even if session is alive."""
        im = InterruptManager(ttl_seconds=5)

        record = im.register("thread-2", "refund", {"amount": 50})
        assert im.has_pending("thread-2")

        # Simulate time passing beyond TTL
        with patch("app.interrupt_manager.time") as mock_time:
            mock_time.time.return_value = record.created_at + 10
            assert not im.has_pending("thread-2")

            status = im.check_status("thread-2")
            assert status is not None
            assert status.is_expired
            assert status.remaining_seconds == 0.0

    def test_interrupt_resolve_flow(self) -> None:
        """Resolving an interrupt removes it from pending and resets session."""
        sm = SessionManager(session_ttl_seconds=3600)
        im = InterruptManager(ttl_seconds=300)

        sm.touch("thread-3")
        im.register("thread-3", "delete_account", {"user_id": "u1"})
        sm.extend_for_interrupt("thread-3")

        # Verify pending state
        assert im.has_pending("thread-3")
        assert sm.get_state("thread-3").has_pending_interrupt

        # Resolve
        im.resolve("thread-3")
        sm.resolve_interrupt("thread-3")

        assert not im.has_pending("thread-3")
        session_state = sm.get_state("thread-3")
        assert session_state is not None
        assert not session_state.has_pending_interrupt

    def test_cleanup_expired_removes_old_interrupts(self) -> None:
        """cleanup_expired removes only expired interrupts, keeping active ones."""
        im = InterruptManager(ttl_seconds=10)

        # Register two interrupts; only the first will be aged past its TTL.
        old_record = im.register("thread-old", "action_old", {})
        im.register("thread-new", "action_new", {})

        # Simulate time where only old one expired
        with patch("app.interrupt_manager.time") as mock_time:
            # Move old record's creation to the past
            im._interrupts["thread-old"] = old_record.__class__(
                interrupt_id=old_record.interrupt_id,
                thread_id=old_record.thread_id,
                action=old_record.action,
                params=old_record.params,
                created_at=time.time() - 20,
                ttl_seconds=old_record.ttl_seconds,
            )
            mock_time.time.return_value = time.time()

            expired = im.cleanup_expired()
            assert len(expired) == 1
            assert expired[0].thread_id == "thread-old"

            # New one should still be pending
            assert im.has_pending("thread-new")
            assert not im.has_pending("thread-old")

    def test_session_ttl_sliding_window(self) -> None:
        """Touching a session resets the sliding window TTL."""
        sm = SessionManager(session_ttl_seconds=3600)

        state1 = sm.touch("thread-5")
        first_activity = state1.last_activity

        time.sleep(0.01)
        state2 = sm.touch("thread-5")
        second_activity = state2.last_activity

        assert second_activity > first_activity
        assert not sm.is_expired("thread-5")

    def test_session_expires_after_ttl_without_activity(self) -> None:
        """Session expires when TTL passes without a touch or interrupt."""
        sm = SessionManager(session_ttl_seconds=0)
        sm.touch("thread-6")

        # TTL is 0 so session is immediately expired
        assert sm.is_expired("thread-6")

    def test_pending_interrupt_prevents_session_expiry(self) -> None:
        """A session with pending interrupt does not expire even with TTL=0."""
        sm = SessionManager(session_ttl_seconds=0)
        sm.touch("thread-7")
        sm.extend_for_interrupt("thread-7")

        # Even with TTL=0, session should not expire because of pending interrupt
        assert not sm.is_expired("thread-7")

    def test_retry_prompt_for_expired_interrupt(self) -> None:
        """InterruptManager generates a retry prompt for expired interrupts."""
        im = InterruptManager(ttl_seconds=300)
        record = im.register("thread-8", "cancel_order", {"order_id": "1042"})

        prompt = im.generate_retry_prompt(record)

        assert prompt["type"] == "interrupt_expired"
        assert prompt["thread_id"] == "thread-8"
        assert "cancel_order" in prompt["action"]
        assert "cancel_order" in prompt["message"]
        assert "expired" in prompt["message"].lower()
|
||||||
360
backend/tests/integration/test_websocket.py
Normal file
360
backend/tests/integration/test_websocket.py
Normal file
@@ -0,0 +1,360 @@
|
|||||||
|
"""Integration tests for WebSocket message flow.
|
||||||
|
|
||||||
|
These tests exercise dispatch_message end-to-end with a mocked LangGraph
|
||||||
|
graph, verifying streaming, interrupt approval/rejection, session TTL,
|
||||||
|
and interrupt TTL expiration through the full message handling pipeline.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from typing import Any
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.callbacks import TokenUsageCallbackHandler
|
||||||
|
from app.graph_context import GraphContext
|
||||||
|
from app.interrupt_manager import InterruptManager
|
||||||
|
from app.session_manager import SessionManager
|
||||||
|
from app.ws_context import WebSocketContext
|
||||||
|
from app.ws_handler import dispatch_message
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class AsyncIterHelper:
    """Adapter exposing a plain list through the async-iterator protocol."""

    def __init__(self, items: list) -> None:
        # Copy so exhausting the iterator never mutates the caller's list.
        self._remaining = list(items)

    def __aiter__(self):
        return self

    async def __anext__(self):
        try:
            return self._remaining.pop(0)
        except IndexError:
            raise StopAsyncIteration from None
|
||||||
|
|
||||||
|
|
||||||
|
class FakeWS:
    """WebSocket double that records every JSON frame instead of sending it."""

    def __init__(self) -> None:
        # Tests inspect this list directly after dispatching messages.
        self.sent: list[dict] = []

    async def send_json(self, data: dict) -> None:
        self.sent.append(data)
|
||||||
|
|
||||||
|
|
||||||
|
def _chunk(content: str, node: str = "order_lookup") -> tuple:
|
||||||
|
c = MagicMock()
|
||||||
|
c.content = content
|
||||||
|
c.tool_calls = []
|
||||||
|
return (c, {"langgraph_node": node})
|
||||||
|
|
||||||
|
|
||||||
|
def _tool_chunk(name: str, args: dict, node: str = "order_lookup") -> tuple:
|
||||||
|
c = MagicMock()
|
||||||
|
c.content = ""
|
||||||
|
c.tool_calls = [{"name": name, "args": args}]
|
||||||
|
return (c, {"langgraph_node": node})
|
||||||
|
|
||||||
|
|
||||||
|
def _state(*, interrupt: bool = False, data: dict | None = None) -> Any:
|
||||||
|
s = MagicMock()
|
||||||
|
if interrupt:
|
||||||
|
obj = MagicMock()
|
||||||
|
obj.value = data or {"action": "cancel_order", "order_id": "1042"}
|
||||||
|
t = MagicMock()
|
||||||
|
t.interrupts = (obj,)
|
||||||
|
s.tasks = (t,)
|
||||||
|
else:
|
||||||
|
s.tasks = ()
|
||||||
|
return s
|
||||||
|
|
||||||
|
|
||||||
|
def _graph(
    chunks: list | None = None,
    st: Any = None,
    resume_chunks: list | None = None,
) -> MagicMock:
    """Build a graph mock with distinct streams for first run and resume.

    Args:
        chunks: Chunks streamed by the first ``astream`` call.
        st: Snapshot returned by ``aget_state``; defaults to no interrupt.
        resume_chunks: Chunks streamed by every subsequent ``astream`` call
            (e.g. after an interrupt response resumes the graph).
    """
    g = MagicMock()

    if st is None:
        st = _state()

    streams = [chunks or [], resume_chunks or []]
    # Mutable cell so the closure below can count calls.
    idx = {"n": 0}

    def make_stream(*a, **kw):
        # Call 0 -> chunks; call 1 and beyond -> resume_chunks (min clamps).
        i = min(idx["n"], len(streams) - 1)
        idx["n"] += 1
        return AsyncIterHelper(list(streams[i]))

    g.astream = MagicMock(side_effect=make_stream)
    g.aget_state = AsyncMock(return_value=st)
    return g
|
||||||
|
|
||||||
|
|
||||||
|
def _make_graph_ctx(graph: MagicMock | None = None) -> GraphContext:
    """Wrap *graph* (or a fresh mock graph) in a GraphContext with no agents."""
    reg = MagicMock()
    reg.list_agents = MagicMock(return_value=())
    return GraphContext(graph=graph or _graph(), registry=reg, intent_classifier=None)
|
||||||
|
|
||||||
|
|
||||||
|
def _setup(
    graph=None,
    session_ttl: int = 1800,
    interrupt_ttl: int = 1800,
    thread_id: str = "t1",
    touch: bool = True,
):
    """Build the full websocket test harness.

    Returns (graph, session_manager, interrupt_manager, callback_handler,
    fake_ws, ws_context). The session for *thread_id* is pre-touched unless
    touch=False.
    """
    the_graph = graph or _graph()
    gctx = _make_graph_ctx(the_graph)
    sessions = SessionManager(session_ttl_seconds=session_ttl)
    interrupts = InterruptManager(ttl_seconds=interrupt_ttl)
    handler = TokenUsageCallbackHandler()
    sock = FakeWS()
    wctx = WebSocketContext(
        graph_ctx=gctx,
        session_manager=sessions,
        callback_handler=handler,
        interrupt_manager=interrupts,
    )
    if touch:
        sessions.touch(thread_id)
    return the_graph, sessions, interrupts, handler, sock, wctx
|
||||||
|
|
||||||
|
|
||||||
|
async def _send(ws, ws_ctx, *, thread_id="t1", content="hello", msg_type="message"):
    """Dispatch one chat frame through the websocket message handler."""
    frame = {"type": msg_type, "thread_id": thread_id, "content": content}
    await dispatch_message(ws, ws_ctx, json.dumps(frame))
|
||||||
|
|
||||||
|
|
||||||
|
async def _respond(ws, ws_ctx, *, thread_id="t1", approved=True):
    """Dispatch an interrupt_response frame (approve or reject)."""
    frame = {"type": "interrupt_response", "thread_id": thread_id, "approved": approved}
    await dispatch_message(ws, ws_ctx, json.dumps(frame))
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@pytest.mark.integration
class TestWebSocketHappyPath:
    """End-to-end streaming over the fake socket: tokens, tool calls, completion."""

    @pytest.mark.asyncio
    async def test_send_message_receives_tokens_and_complete(self) -> None:
        graph = _graph(chunks=[_chunk("Order 1042 is "), _chunk("shipped.")])
        _, _, _, _, sock, ctx = _setup(graph=graph)
        await _send(sock, ctx, content="What is the status of order 1042?")

        token_frames = [f for f in sock.sent if f["type"] == "token"]
        assert len(token_frames) == 2
        assert token_frames[0]["content"] == "Order 1042 is "
        assert token_frames[0]["agent"] == "order_lookup"
        assert token_frames[1]["content"] == "shipped."

        done = [f for f in sock.sent if f["type"] == "message_complete"]
        assert len(done) == 1
        assert done[0]["thread_id"] == "t1"

    @pytest.mark.asyncio
    async def test_tool_call_streamed(self) -> None:
        graph = _graph(
            chunks=[
                _tool_chunk("get_order_status", {"order_id": "1042"}),
                _chunk("Order shipped."),
            ]
        )
        _, _, _, _, sock, ctx = _setup(graph=graph)
        await _send(sock, ctx, content="Check order 1042")

        tool_frames = [f for f in sock.sent if f["type"] == "tool_call"]
        assert len(tool_frames) == 1
        assert tool_frames[0]["tool"] == "get_order_status"
        assert tool_frames[0]["args"] == {"order_id": "1042"}

    @pytest.mark.asyncio
    async def test_multiple_messages_same_session(self) -> None:
        _, _, _, _, sock, ctx = _setup()
        for n in range(3):
            await _send(sock, ctx, content=f"msg {n}")

        done = [f for f in sock.sent if f["type"] == "message_complete"]
        assert len(done) == 3
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.integration
class TestWebSocketInterruptApproval:
    """Human-in-the-loop flow: interrupt frame, then approve/reject resume."""

    @pytest.mark.asyncio
    async def test_interrupt_then_approve(self) -> None:
        paused = _state(interrupt=True, data={"action": "cancel_order", "order_id": "1042"})
        graph = _graph(
            chunks=[],
            st=paused,
            resume_chunks=[_chunk("Order 1042 cancelled.", "order_actions")],
        )
        _, _, interrupts, _, sock, ctx = _setup(graph=graph)

        # Sending the message surfaces the pending interrupt.
        await _send(sock, ctx, content="Cancel order 1042")

        int_frames = [f for f in sock.sent if f["type"] == "interrupt"]
        assert len(int_frames) == 1
        assert int_frames[0]["action"] == "cancel_order"
        assert int_frames[0]["thread_id"] == "t1"
        assert interrupts.has_pending("t1")

        # Approving resumes the graph and streams the resume chunk.
        sock.sent.clear()
        await _respond(sock, ctx, approved=True)

        token_frames = [f for f in sock.sent if f["type"] == "token"]
        assert len(token_frames) == 1
        assert "cancelled" in token_frames[0]["content"]

        done = [f for f in sock.sent if f["type"] == "message_complete"]
        assert len(done) == 1
        assert not interrupts.has_pending("t1")

    @pytest.mark.asyncio
    async def test_interrupt_then_reject(self) -> None:
        paused = _state(interrupt=True)
        graph = _graph(
            chunks=[],
            st=paused,
            resume_chunks=[_chunk("Order remains active.", "order_actions")],
        )
        _, _, _, _, sock, ctx = _setup(graph=graph)

        await _send(sock, ctx, content="Cancel order 1042")
        sock.sent.clear()

        await _respond(sock, ctx, approved=False)

        token_frames = [f for f in sock.sent if f["type"] == "token"]
        assert "remains active" in token_frames[0]["content"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.integration
class TestWebSocketSessionTTL:
    """Session expiry semantics: hard expiry, sliding window, interrupt extension."""

    @pytest.mark.asyncio
    async def test_expired_session_returns_error(self) -> None:
        _, _, _, _, sock, ctx = _setup(session_ttl=0)
        # The session was touched in _setup, but a zero TTL expires it instantly.
        await _send(sock, ctx, content="hello")
        assert sock.sent[0]["type"] == "error"
        assert "expired" in sock.sent[0]["message"].lower()

    @pytest.mark.asyncio
    async def test_new_session_not_expired(self) -> None:
        _, _, _, _, sock, ctx = _setup(session_ttl=3600)
        await _send(sock, ctx, content="hello")
        done = [f for f in sock.sent if f["type"] == "message_complete"]
        assert len(done) == 1

    @pytest.mark.asyncio
    async def test_sliding_window_resets_on_message(self) -> None:
        _, sessions, _, _, sock, ctx = _setup(session_ttl=3600)

        await _send(sock, ctx, content="hello")
        before = sessions.get_state("t1").last_activity

        time.sleep(0.01)
        await _send(sock, ctx, content="hello again")
        after = sessions.get_state("t1").last_activity

        assert after > before

    @pytest.mark.asyncio
    async def test_interrupt_extends_session_ttl(self) -> None:
        graph = _graph(chunks=[], st=_state(interrupt=True))
        _, sessions, _, _, sock, ctx = _setup(graph=graph, session_ttl=3600)

        await _send(sock, ctx, content="cancel order")

        session_state = sessions.get_state("t1")
        assert session_state is not None
        assert session_state.has_pending_interrupt
        assert not sessions.is_expired("t1")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.integration
class TestWebSocketValidation:
    """Malformed frames must yield an error frame, never an unhandled crash."""

    @pytest.mark.asyncio
    async def test_invalid_json(self) -> None:
        _, _, _, _, sock, ctx = _setup()
        await dispatch_message(sock, ctx, "not json")
        assert sock.sent[0]["type"] == "error"
        assert "Invalid JSON" in sock.sent[0]["message"]

    @pytest.mark.asyncio
    async def test_missing_thread_id(self) -> None:
        _, _, _, _, sock, ctx = _setup()
        await dispatch_message(sock, ctx, json.dumps({"type": "message", "content": "hi"}))
        assert sock.sent[0]["type"] == "error"
        assert "thread_id" in sock.sent[0]["message"]

    @pytest.mark.asyncio
    async def test_invalid_thread_id_format(self) -> None:
        _, _, _, _, sock, ctx = _setup()
        frame = json.dumps({"type": "message", "thread_id": "../../etc/passwd", "content": "hi"})
        await dispatch_message(sock, ctx, frame)
        assert sock.sent[0]["type"] == "error"

    @pytest.mark.asyncio
    async def test_missing_content(self) -> None:
        _, _, _, _, sock, ctx = _setup()
        await dispatch_message(sock, ctx, json.dumps({"type": "message", "thread_id": "t1"}))
        assert sock.sent[0]["type"] == "error"

    @pytest.mark.asyncio
    async def test_unknown_message_type(self) -> None:
        _, _, _, _, sock, ctx = _setup()
        await dispatch_message(sock, ctx, json.dumps({"type": "foobar", "thread_id": "t1"}))
        assert sock.sent[0]["type"] == "error"
        assert "Unknown" in sock.sent[0]["message"]

    @pytest.mark.asyncio
    async def test_message_too_large(self) -> None:
        _, _, _, _, sock, ctx = _setup()
        await dispatch_message(sock, ctx, "x" * 40_000)
        assert sock.sent[0]["type"] == "error"
        assert "too large" in sock.sent[0]["message"].lower()

    @pytest.mark.asyncio
    async def test_content_too_long(self) -> None:
        _, _, _, _, sock, ctx = _setup()
        frame = json.dumps({"type": "message", "thread_id": "t1", "content": "x" * 10001})
        await dispatch_message(sock, ctx, frame)
        assert sock.sent[0]["type"] == "error"
        assert "too long" in sock.sent[0]["message"].lower()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.integration
class TestWebSocketInterruptTTL:
    """An expired interrupt produces an interrupt_expired retry prompt."""

    @pytest.mark.asyncio
    async def test_expired_interrupt_sends_retry_prompt(self) -> None:
        paused = _state(interrupt=True, data={"action": "cancel_order", "order_id": "1042"})
        graph = _graph(chunks=[], st=paused)
        _, _, interrupts, _, sock, ctx = _setup(graph=graph, interrupt_ttl=5)

        # Trigger the interrupt.
        await _send(sock, ctx, content="Cancel order 1042")

        int_frames = [f for f in sock.sent if f["type"] == "interrupt"]
        assert len(int_frames) == 1

        # Fast-forward past the 5s TTL before responding.
        record = interrupts._interrupts["t1"]
        sock.sent.clear()

        with patch("app.interrupt_manager.time") as mock_time:
            mock_time.time.return_value = record.created_at + 10
            await _respond(sock, ctx, approved=True)

        assert sock.sent[0]["type"] == "interrupt_expired"
        assert "cancel_order" in sock.sent[0]["message"]
        assert sock.sent[0]["thread_id"] == "t1"
|
||||||
1
backend/tests/unit/analytics/__init__.py
Normal file
1
backend/tests/unit/analytics/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Unit tests for app.analytics module."""
|
||||||
149
backend/tests/unit/analytics/test_api.py
Normal file
149
backend/tests/unit/analytics/test_api.py
Normal file
@@ -0,0 +1,149 @@
|
|||||||
|
"""Unit tests for app.analytics.api."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
def _build_app() -> FastAPI:
    """Return a bare FastAPI app with only the analytics router mounted."""
    from app.analytics.api import router

    application = FastAPI()
    application.include_router(router)
    return application
|
||||||
|
|
||||||
|
|
||||||
|
def _make_mock_pool() -> MagicMock:
|
||||||
|
mock_conn = AsyncMock()
|
||||||
|
mock_ctx = AsyncMock()
|
||||||
|
mock_ctx.__aenter__ = AsyncMock(return_value=mock_conn)
|
||||||
|
mock_ctx.__aexit__ = AsyncMock(return_value=None)
|
||||||
|
mock_pool = MagicMock()
|
||||||
|
mock_pool.connection.return_value = mock_ctx
|
||||||
|
return mock_pool
|
||||||
|
|
||||||
|
|
||||||
|
def _make_analytics_result() -> object:
    """Build one fully-populated AnalyticsResult fixture for endpoint tests."""
    from app.analytics.models import AgentUsage, AnalyticsResult, InterruptStats

    usage = (AgentUsage(agent="order_agent", count=30, percentage=60.0),)
    stats = InterruptStats(total=5, approved=4, rejected=1, expired=0)
    return AnalyticsResult(
        range="7d",
        total_conversations=50,
        resolution_rate=0.8,
        escalation_rate=0.1,
        avg_turns_per_conversation=3.5,
        avg_cost_per_conversation_usd=0.02,
        agent_usage=usage,
        interrupt_stats=stats,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _get_analytics(app: FastAPI, path: str = "/api/v1/analytics", **patch_kwargs: object) -> tuple:
    """Patch ``get_analytics``, GET *path*, and return ``(response, mock)``.

    The return annotation was previously ``object``, hiding that callers
    unpack a 2-tuple. ``patch_kwargs`` is accepted for call-site flexibility
    but currently unused.
    """
    analytics_result = _make_analytics_result()
    with (
        patch("app.analytics.api.get_analytics", return_value=analytics_result) as mock_ga,
        TestClient(app) as client,
    ):
        resp = client.get(path)
        return resp, mock_ga
|
||||||
|
|
||||||
|
|
||||||
|
class TestAnalyticsEndpoint:
    """HTTP contract tests for GET /api/v1/analytics."""

    def test_returns_200_with_default_range(self) -> None:
        app = _build_app()
        app.state.pool = _make_mock_pool()
        resp, _ = _get_analytics(app)

        assert resp.status_code == 200
        body = resp.json()
        assert body["success"] is True
        assert body["error"] is None
        assert body["data"]["range"] == "7d"

    def test_returns_correct_analytics_structure(self) -> None:
        app = _build_app()
        app.state.pool = _make_mock_pool()
        resp, _ = _get_analytics(app)

        data = resp.json()["data"]
        assert "total_conversations" in data
        assert "resolution_rate" in data
        assert "escalation_rate" in data
        assert "avg_turns_per_conversation" in data
        assert "avg_cost_per_conversation_usd" in data
        assert "agent_usage" in data
        assert "interrupt_stats" in data

    def test_custom_range_7d(self) -> None:
        app = _build_app()
        app.state.pool = _make_mock_pool()
        resp, mock_ga = _get_analytics(app, "/api/v1/analytics?range=7d")

        assert resp.status_code == 200
        mock_ga.assert_called_once()
        call_kwargs = mock_ga.call_args
        # Accept range_days passed by keyword OR positionally. The previous
        # form indexed call_kwargs[1]["range_days"], which raised KeyError
        # (never reaching the `or` fallback) when the arg was positional,
        # and call_kwargs[0][1] lacked a length guard.
        assert call_kwargs[1].get("range_days") == 7 or (
            len(call_kwargs[0]) > 1 and call_kwargs[0][1] == 7
        )

    def test_custom_range_30d(self) -> None:
        app = _build_app()
        app.state.pool = _make_mock_pool()
        resp, mock_ga = _get_analytics(app, "/api/v1/analytics?range=30d")

        assert resp.status_code == 200
        call_kwargs = mock_ga.call_args
        assert call_kwargs[1].get("range_days") == 30 or (
            len(call_kwargs[0]) > 1 and call_kwargs[0][1] == 30
        )

    def test_invalid_range_format_returns_400(self) -> None:
        app = _build_app()
        app.state.pool = _make_mock_pool()

        with TestClient(app) as client:
            resp = client.get("/api/v1/analytics?range=invalid")

        assert resp.status_code == 400

    def test_range_without_d_suffix_returns_400(self) -> None:
        app = _build_app()
        app.state.pool = _make_mock_pool()

        with TestClient(app) as client:
            resp = client.get("/api/v1/analytics?range=7")

        assert resp.status_code == 400

    def test_agent_usage_in_response(self) -> None:
        app = _build_app()
        app.state.pool = _make_mock_pool()
        resp, _ = _get_analytics(app)

        data = resp.json()["data"]
        assert len(data["agent_usage"]) == 1
        assert data["agent_usage"][0]["agent"] == "order_agent"

    def test_interrupt_stats_in_response(self) -> None:
        app = _build_app()
        app.state.pool = _make_mock_pool()
        resp, _ = _get_analytics(app)

        data = resp.json()["data"]
        assert data["interrupt_stats"]["total"] == 5
        assert data["interrupt_stats"]["approved"] == 4

    def test_envelope_format(self) -> None:
        app = _build_app()
        app.state.pool = _make_mock_pool()
        resp, _ = _get_analytics(app)

        body = resp.json()
        assert "success" in body
        assert "data" in body
        assert "error" in body
|
||||||
155
backend/tests/unit/analytics/test_event_recorder.py
Normal file
155
backend/tests/unit/analytics/test_event_recorder.py
Normal file
@@ -0,0 +1,155 @@
|
|||||||
|
"""Unit tests for app.analytics.event_recorder."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
class TestAnalyticsRecorderProtocol:
    """Both recorder implementations expose a callable ``record`` method."""

    def test_postgres_recorder_implements_protocol(self) -> None:
        from app.analytics.event_recorder import PostgresAnalyticsRecorder

        recorder = PostgresAnalyticsRecorder(pool=MagicMock())
        # Structural (duck-typed) check: the protocol requires ``record``.
        assert hasattr(recorder, "record")
        assert callable(recorder.record)

    def test_noop_recorder_implements_protocol(self) -> None:
        from app.analytics.event_recorder import NoOpAnalyticsRecorder

        recorder = NoOpAnalyticsRecorder()
        assert hasattr(recorder, "record")
        assert callable(recorder.record)
|
||||||
|
|
||||||
|
|
||||||
|
class TestNoOpAnalyticsRecorder:
    """The no-op recorder accepts any argument combination and does nothing."""

    @pytest.mark.asyncio
    async def test_record_does_nothing(self) -> None:
        from app.analytics.event_recorder import NoOpAnalyticsRecorder

        # Must complete without raising or touching any backend.
        await NoOpAnalyticsRecorder().record(
            thread_id="t1",
            event_type="tool_call",
            agent_name="order_agent",
            tool_name="get_order",
            tokens_used=50,
            cost_usd=0.001,
        )

    @pytest.mark.asyncio
    async def test_record_with_all_params(self) -> None:
        from app.analytics.event_recorder import NoOpAnalyticsRecorder

        await NoOpAnalyticsRecorder().record(
            thread_id="t1",
            event_type="agent_response",
            agent_name="fallback",
            tool_name=None,
            tokens_used=100,
            cost_usd=0.002,
            duration_ms=150,
            success=True,
            error_message=None,
            metadata={"extra": "data"},
        )

    @pytest.mark.asyncio
    async def test_record_minimal_params(self) -> None:
        from app.analytics.event_recorder import NoOpAnalyticsRecorder

        # Only the required parameters.
        await NoOpAnalyticsRecorder().record(thread_id="t1", event_type="conversation_start")
|
||||||
|
|
||||||
|
|
||||||
|
class TestPostgresAnalyticsRecorder:
    """Behavior of the psycopg-backed recorder against a mocked pool.

    The mock-pool construction was previously duplicated verbatim in all
    three tests; it is factored into ``_pool_and_conn``.
    """

    @staticmethod
    def _pool_and_conn() -> tuple:
        """Build a mock pool whose connection() yields *conn*; return (pool, conn)."""
        conn = AsyncMock()
        acm = AsyncMock()
        acm.__aenter__ = AsyncMock(return_value=conn)
        acm.__aexit__ = AsyncMock(return_value=None)
        pool = MagicMock()
        pool.connection.return_value = acm
        return pool, conn

    @pytest.mark.asyncio
    async def test_record_executes_insert(self) -> None:
        from app.analytics.event_recorder import PostgresAnalyticsRecorder

        pool, conn = self._pool_and_conn()
        recorder = PostgresAnalyticsRecorder(pool=pool)
        await recorder.record(
            thread_id="t1",
            event_type="tool_call",
            agent_name="order_agent",
            tokens_used=50,
            cost_usd=0.001,
        )
        conn.execute.assert_awaited_once()
        sql = conn.execute.call_args[0][0]
        assert "INSERT INTO analytics_events" in sql

    @pytest.mark.asyncio
    async def test_record_passes_correct_params(self) -> None:
        from app.analytics.event_recorder import PostgresAnalyticsRecorder

        pool, conn = self._pool_and_conn()
        recorder = PostgresAnalyticsRecorder(pool=pool)
        await recorder.record(
            thread_id="thread-xyz",
            event_type="agent_response",
            agent_name="discount_agent",
            tool_name="apply_discount",
            tokens_used=75,
            cost_usd=0.002,
            duration_ms=300,
            success=True,
            error_message=None,
            metadata={"promo": "10PCT"},
        )
        params = conn.execute.call_args[0][1]
        assert params["thread_id"] == "thread-xyz"
        assert params["event_type"] == "agent_response"
        assert params["agent_name"] == "discount_agent"
        assert params["tokens_used"] == 75

    @pytest.mark.asyncio
    async def test_record_stores_metadata_as_dict(self) -> None:
        from app.analytics.event_recorder import PostgresAnalyticsRecorder

        pool, conn = self._pool_and_conn()
        recorder = PostgresAnalyticsRecorder(pool=pool)
        await recorder.record(
            thread_id="t1",
            event_type="tool_call",
            metadata={"key": "val"},
        )
        params = conn.execute.call_args[0][1]
        # PostgresAnalyticsRecorder wraps metadata with psycopg's Json adapter;
        # unwrap to compare the inner dict.
        from psycopg.types.json import Json

        meta = params["metadata"]
        if isinstance(meta, Json):
            meta = meta.obj
        assert meta == {"key": "val"}
|
||||||
106
backend/tests/unit/analytics/test_models.py
Normal file
106
backend/tests/unit/analytics/test_models.py
Normal file
@@ -0,0 +1,106 @@
|
|||||||
|
"""Unit tests for app.analytics.models."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentUsage:
    """Construction and immutability of the AgentUsage record."""

    def test_agent_usage_construction(self) -> None:
        from app.analytics.models import AgentUsage

        usage = AgentUsage(agent="order_agent", count=10, percentage=50.0)
        assert usage.agent == "order_agent"
        assert usage.count == 10
        assert usage.percentage == 50.0

    def test_agent_usage_is_frozen(self) -> None:
        from app.analytics.models import AgentUsage

        usage = AgentUsage(agent="a", count=1, percentage=100.0)
        # Frozen/immutable models reject attribute assignment.
        with pytest.raises((AttributeError, TypeError)):
            usage.count = 2  # type: ignore[misc]
|
||||||
|
|
||||||
|
|
||||||
|
class TestInterruptStats:
    """Defaults, explicit values, and immutability of InterruptStats."""

    def test_interrupt_stats_defaults(self) -> None:
        from app.analytics.models import InterruptStats

        stats = InterruptStats()
        # Every counter defaults to zero.
        assert stats.total == 0
        assert stats.approved == 0
        assert stats.rejected == 0
        assert stats.expired == 0

    def test_interrupt_stats_custom_values(self) -> None:
        from app.analytics.models import InterruptStats

        stats = InterruptStats(total=10, approved=7, rejected=2, expired=1)
        assert stats.total == 10
        assert stats.approved == 7
        assert stats.rejected == 2
        assert stats.expired == 1

    def test_interrupt_stats_is_frozen(self) -> None:
        from app.analytics.models import InterruptStats

        stats = InterruptStats()
        with pytest.raises((AttributeError, TypeError)):
            stats.total = 5  # type: ignore[misc]
|
||||||
|
|
||||||
|
|
||||||
|
class TestAnalyticsResult:
    """Construction, immutability, and empty-usage edge case of AnalyticsResult."""

    @staticmethod
    def _empty_result():
        """Zero-valued AnalyticsResult shared by the frozen/empty tests."""
        from app.analytics.models import AnalyticsResult, InterruptStats

        return AnalyticsResult(
            range="7d",
            total_conversations=0,
            resolution_rate=0.0,
            escalation_rate=0.0,
            avg_turns_per_conversation=0.0,
            avg_cost_per_conversation_usd=0.0,
            agent_usage=(),
            interrupt_stats=InterruptStats(),
        )

    def test_analytics_result_construction(self) -> None:
        from app.analytics.models import AgentUsage, AnalyticsResult, InterruptStats

        result = AnalyticsResult(
            range="7d",
            total_conversations=100,
            resolution_rate=0.85,
            escalation_rate=0.05,
            avg_turns_per_conversation=4.2,
            avg_cost_per_conversation_usd=0.03,
            agent_usage=(AgentUsage(agent="order_agent", count=60, percentage=60.0),),
            interrupt_stats=InterruptStats(total=5, approved=4, rejected=1, expired=0),
        )
        assert result.range == "7d"
        assert result.total_conversations == 100
        assert result.resolution_rate == 0.85
        assert result.escalation_rate == 0.05
        assert result.avg_turns_per_conversation == 4.2
        assert result.avg_cost_per_conversation_usd == 0.03
        assert len(result.agent_usage) == 1
        assert result.interrupt_stats.total == 5

    def test_analytics_result_is_frozen(self) -> None:
        result = self._empty_result()
        with pytest.raises((AttributeError, TypeError)):
            result.range = "30d"  # type: ignore[misc]

    def test_analytics_result_empty_agent_usage(self) -> None:
        result = self._empty_result()
        assert result.agent_usage == ()
|
||||||
249
backend/tests/unit/analytics/test_queries.py
Normal file
249
backend/tests/unit/analytics/test_queries.py
Normal file
@@ -0,0 +1,249 @@
|
|||||||
|
"""Unit tests for app.analytics.queries."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
def _make_pool_with_fetchone(result: dict | None) -> MagicMock:
|
||||||
|
mock_cursor = AsyncMock()
|
||||||
|
mock_cursor.fetchone = AsyncMock(return_value=result)
|
||||||
|
mock_conn = AsyncMock()
|
||||||
|
mock_conn.execute = AsyncMock(return_value=mock_cursor)
|
||||||
|
mock_ctx = AsyncMock()
|
||||||
|
mock_ctx.__aenter__ = AsyncMock(return_value=mock_conn)
|
||||||
|
mock_ctx.__aexit__ = AsyncMock(return_value=None)
|
||||||
|
mock_pool = MagicMock()
|
||||||
|
mock_pool.connection.return_value = mock_ctx
|
||||||
|
return mock_pool
|
||||||
|
|
||||||
|
|
||||||
|
def _make_pool_with_fetchall(result: list[dict]) -> MagicMock:
|
||||||
|
mock_cursor = AsyncMock()
|
||||||
|
mock_cursor.fetchall = AsyncMock(return_value=result)
|
||||||
|
mock_conn = AsyncMock()
|
||||||
|
mock_conn.execute = AsyncMock(return_value=mock_cursor)
|
||||||
|
mock_ctx = AsyncMock()
|
||||||
|
mock_ctx.__aenter__ = AsyncMock(return_value=mock_conn)
|
||||||
|
mock_ctx.__aexit__ = AsyncMock(return_value=None)
|
||||||
|
mock_pool = MagicMock()
|
||||||
|
mock_pool.connection.return_value = mock_ctx
|
||||||
|
return mock_pool
|
||||||
|
|
||||||
|
|
||||||
|
class TestResolutionRate:
    """resolution_rate() over a mocked pool."""

    @pytest.mark.asyncio
    async def test_returns_float(self) -> None:
        from app.analytics.queries import resolution_rate

        rate = await resolution_rate(_make_pool_with_fetchone({"rate": 0.85}), range_days=7)
        assert isinstance(rate, float)

    @pytest.mark.asyncio
    async def test_zero_state_returns_zero(self) -> None:
        from app.analytics.queries import resolution_rate

        # No row at all (empty table) must degrade to 0.0, not raise.
        rate = await resolution_rate(_make_pool_with_fetchone(None), range_days=7)
        assert rate == 0.0

    @pytest.mark.asyncio
    async def test_returns_correct_value(self) -> None:
        from app.analytics.queries import resolution_rate

        rate = await resolution_rate(_make_pool_with_fetchone({"rate": 0.75}), range_days=7)
        assert rate == 0.75
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentUsageQuery:
    """agent_usage() row mapping over a mocked pool."""

    @pytest.mark.asyncio
    async def test_returns_tuple(self) -> None:
        from app.analytics.queries import agent_usage

        rows = await agent_usage(_make_pool_with_fetchall([]), range_days=7)
        assert isinstance(rows, tuple)

    @pytest.mark.asyncio
    async def test_empty_state_returns_empty_tuple(self) -> None:
        from app.analytics.queries import agent_usage

        rows = await agent_usage(_make_pool_with_fetchall([]), range_days=7)
        assert rows == ()

    @pytest.mark.asyncio
    async def test_maps_rows_to_agent_usage_objects(self) -> None:
        from app.analytics.models import AgentUsage
        from app.analytics.queries import agent_usage

        pool = _make_pool_with_fetchall(
            [
                {"agent": "order_agent", "count": 10, "percentage": 66.7},
                {"agent": "discount_agent", "count": 5, "percentage": 33.3},
            ]
        )
        rows = await agent_usage(pool, range_days=7)
        assert len(rows) == 2
        assert isinstance(rows[0], AgentUsage)
        assert rows[0].agent == "order_agent"
        assert rows[0].count == 10
|
||||||
|
|
||||||
|
|
||||||
|
class TestEscalationRate:
    """escalation_rate() over a mocked pool."""

    @pytest.mark.asyncio
    async def test_returns_float(self) -> None:
        from app.analytics.queries import escalation_rate

        rate = await escalation_rate(_make_pool_with_fetchone({"rate": 0.05}), range_days=7)
        assert isinstance(rate, float)

    @pytest.mark.asyncio
    async def test_zero_state_returns_zero(self) -> None:
        from app.analytics.queries import escalation_rate

        rate = await escalation_rate(_make_pool_with_fetchone(None), range_days=7)
        assert rate == 0.0
|
||||||
|
|
||||||
|
|
||||||
|
class TestCostPerConversation:
    """Average-cost-per-conversation query."""

    @pytest.mark.asyncio
    async def test_returns_float(self) -> None:
        """A populated row yields a float cost."""
        from app.analytics.queries import cost_per_conversation

        mock_pool = _make_pool_with_fetchone({"avg_cost": 0.03})
        cost = await cost_per_conversation(mock_pool, range_days=7)
        assert isinstance(cost, float)

    @pytest.mark.asyncio
    async def test_zero_state_returns_zero(self) -> None:
        """A missing row degrades to 0.0 instead of raising."""
        from app.analytics.queries import cost_per_conversation

        mock_pool = _make_pool_with_fetchone(None)
        cost = await cost_per_conversation(mock_pool, range_days=7)
        assert cost == 0.0
|
class TestInterruptStatsQuery:
    """Interrupt approval/rejection statistics query."""

    @pytest.mark.asyncio
    async def test_returns_interrupt_stats(self) -> None:
        """Row counters are mapped onto an InterruptStats model."""
        from app.analytics.models import InterruptStats
        from app.analytics.queries import interrupt_stats

        row = {"total": 10, "approved": 7, "rejected": 2, "expired": 1}
        stats = await interrupt_stats(_make_pool_with_fetchone(row), range_days=7)
        assert isinstance(stats, InterruptStats)
        assert stats.total == 10
        assert stats.approved == 7

    @pytest.mark.asyncio
    async def test_zero_state_returns_zeros(self) -> None:
        """A missing row yields an InterruptStats with every counter at zero."""
        from app.analytics.models import InterruptStats
        from app.analytics.queries import interrupt_stats

        stats = await interrupt_stats(_make_pool_with_fetchone(None), range_days=7)
        assert isinstance(stats, InterruptStats)
        for counter in ("total", "approved", "rejected", "expired"):
            assert getattr(stats, counter) == 0
|
class TestTotalConversations:
    """Total-conversation count helper query."""

    @pytest.mark.asyncio
    async def test_returns_count(self) -> None:
        """The count from the row is passed through unchanged."""
        from app.analytics.queries import _total_conversations

        mock_pool = _make_pool_with_fetchone({"total": 42})
        total = await _total_conversations(mock_pool, range_days=7)
        assert total == 42

    @pytest.mark.asyncio
    async def test_zero_state_returns_zero(self) -> None:
        """A missing row degrades to 0 instead of raising."""
        from app.analytics.queries import _total_conversations

        mock_pool = _make_pool_with_fetchone(None)
        total = await _total_conversations(mock_pool, range_days=7)
        assert total == 0
|
class TestAvgTurns:
    """Average-turns helper query."""

    @pytest.mark.asyncio
    async def test_returns_float(self) -> None:
        """The average from the row is passed through unchanged."""
        from app.analytics.queries import _avg_turns

        mock_pool = _make_pool_with_fetchone({"avg_turns": 3.5})
        average = await _avg_turns(mock_pool, range_days=7)
        assert average == 3.5

    @pytest.mark.asyncio
    async def test_zero_state_returns_zero(self) -> None:
        """A missing row degrades to 0.0 instead of raising."""
        from app.analytics.queries import _avg_turns

        mock_pool = _make_pool_with_fetchone(None)
        average = await _avg_turns(mock_pool, range_days=7)
        assert average == 0.0
|
class TestGetAnalytics:
    """Assembly of the AnalyticsResult aggregate by get_analytics.

    Every sub-query is patched out, so only get_analytics' own composition
    logic is under test.  unittest.mock.patch substitutes an AsyncMock when
    the target is an async function, so ``return_value`` is what awaiting
    the stub yields.
    """

    @pytest.mark.asyncio
    async def test_returns_analytics_result(self) -> None:
        """Sub-query values are copied into the aggregate unchanged."""
        from unittest.mock import patch

        from app.analytics.models import AnalyticsResult, InterruptStats
        from app.analytics.queries import get_analytics

        # The pool is never touched directly once the sub-queries are patched.
        mock_pool = MagicMock()

        with (
            patch("app.analytics.queries.resolution_rate", return_value=0.85),
            patch("app.analytics.queries.escalation_rate", return_value=0.05),
            patch("app.analytics.queries.cost_per_conversation", return_value=0.03),
            patch("app.analytics.queries.agent_usage", return_value=()),
            patch(
                "app.analytics.queries.interrupt_stats",
                return_value=InterruptStats(),
            ),
            patch("app.analytics.queries._total_conversations", return_value=100),
            patch("app.analytics.queries._avg_turns", return_value=4.2),
        ):
            result = await get_analytics(mock_pool, range_days=7)

        assert isinstance(result, AnalyticsResult)
        # range_days=7 is expected to render as the "7d" range label.
        assert result.range == "7d"
        assert result.total_conversations == 100
        assert result.resolution_rate == 0.85

    @pytest.mark.asyncio
    async def test_zero_state_returns_zeros(self) -> None:
        """With all sub-queries reporting zero/empty, the aggregate is empty."""
        from unittest.mock import patch

        from app.analytics.models import AnalyticsResult, InterruptStats
        from app.analytics.queries import get_analytics

        mock_pool = MagicMock()

        with (
            patch("app.analytics.queries.resolution_rate", return_value=0.0),
            patch("app.analytics.queries.escalation_rate", return_value=0.0),
            patch("app.analytics.queries.cost_per_conversation", return_value=0.0),
            patch("app.analytics.queries.agent_usage", return_value=()),
            patch("app.analytics.queries.interrupt_stats", return_value=InterruptStats()),
            patch("app.analytics.queries._total_conversations", return_value=0),
            patch("app.analytics.queries._avg_turns", return_value=0.0),
        ):
            result = await get_analytics(mock_pool, range_days=7)

        assert isinstance(result, AnalyticsResult)
        assert result.total_conversations == 0
        assert result.resolution_rate == 0.0
        assert result.agent_usage == ()
0
backend/tests/unit/openapi/__init__.py
Normal file
0
backend/tests/unit/openapi/__init__.py
Normal file
249
backend/tests/unit/openapi/test_classifier.py
Normal file
249
backend/tests/unit/openapi/test_classifier.py
Normal file
@@ -0,0 +1,249 @@
|
|||||||
|
"""Tests for OpenAPI endpoint classifier module.
|
||||||
|
|
||||||
|
RED phase: written before implementation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.openapi.models import EndpointInfo, ParameterInfo
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
def _make_endpoint(
    path: str = "/items",
    method: str = "GET",
    operation_id: str = "list_items",
    summary: str = "List items",
    description: str = "",
    parameters: tuple[ParameterInfo, ...] = (),
) -> EndpointInfo:
    """Construct an EndpointInfo with test-friendly defaults."""
    fields = {
        "path": path,
        "method": method,
        "operation_id": operation_id,
        "summary": summary,
        "description": description,
        "parameters": parameters,
    }
    return EndpointInfo(**fields)
|
|
||||||
|
# Reusable parameter fixtures: a required path parameter and an optional
# query parameter, both typed as plain strings.
_ORDER_PARAM = ParameterInfo(
    name="order_id", location="path", required=True, schema_type="string"
)
_CUSTOMER_PARAM = ParameterInfo(
    name="customer_id", location="query", required=False, schema_type="string"
)
|
class TestHeuristicClassifier:
    """Behaviour of the rule-based HeuristicClassifier."""

    async def test_get_classified_as_read(self) -> None:
        """GET maps to read access."""
        from app.openapi.classifier import HeuristicClassifier

        results = await HeuristicClassifier().classify((_make_endpoint(method="GET"),))
        assert len(results) == 1
        assert results[0].access_type == "read"

    async def test_post_classified_as_write(self) -> None:
        """POST maps to write access."""
        from app.openapi.classifier import HeuristicClassifier

        (first,) = await HeuristicClassifier().classify((_make_endpoint(method="POST"),))
        assert first.access_type == "write"

    async def test_post_needs_interrupt(self) -> None:
        """POST requires human approval."""
        from app.openapi.classifier import HeuristicClassifier

        (first,) = await HeuristicClassifier().classify((_make_endpoint(method="POST"),))
        assert first.needs_interrupt is True

    async def test_put_classified_as_write(self) -> None:
        """PUT maps to write access."""
        from app.openapi.classifier import HeuristicClassifier

        (first,) = await HeuristicClassifier().classify((_make_endpoint(method="PUT"),))
        assert first.access_type == "write"

    async def test_delete_classified_as_write_with_interrupt(self) -> None:
        """DELETE maps to write access and requires approval."""
        from app.openapi.classifier import HeuristicClassifier

        (first,) = await HeuristicClassifier().classify((_make_endpoint(method="DELETE"),))
        assert first.access_type == "write"
        assert first.needs_interrupt is True

    async def test_get_does_not_need_interrupt(self) -> None:
        """GET never requires approval."""
        from app.openapi.classifier import HeuristicClassifier

        (first,) = await HeuristicClassifier().classify((_make_endpoint(method="GET"),))
        assert first.needs_interrupt is False

    async def test_empty_endpoints_returns_empty_tuple(self) -> None:
        """Classifying nothing yields nothing."""
        from app.openapi.classifier import HeuristicClassifier

        assert await HeuristicClassifier().classify(()) == ()

    async def test_customer_params_detected_order_id(self) -> None:
        """Parameters named order_id are surfaced as customer params."""
        from app.openapi.classifier import HeuristicClassifier

        endpoint = _make_endpoint(method="GET", parameters=(_ORDER_PARAM,))
        (first,) = await HeuristicClassifier().classify((endpoint,))
        assert "order_id" in first.customer_params

    async def test_customer_params_detected_customer_id(self) -> None:
        """Parameters named customer_id are surfaced as customer params."""
        from app.openapi.classifier import HeuristicClassifier

        endpoint = _make_endpoint(method="GET", parameters=(_CUSTOMER_PARAM,))
        (first,) = await HeuristicClassifier().classify((endpoint,))
        assert "customer_id" in first.customer_params

    async def test_result_is_tuple(self) -> None:
        """classify returns an immutable tuple."""
        from app.openapi.classifier import HeuristicClassifier

        results = await HeuristicClassifier().classify((_make_endpoint(),))
        assert isinstance(results, tuple)

    async def test_classification_has_confidence(self) -> None:
        """Confidence lies in the [0, 1] interval."""
        from app.openapi.classifier import HeuristicClassifier

        (first,) = await HeuristicClassifier().classify((_make_endpoint(),))
        assert 0.0 <= first.confidence <= 1.0

    async def test_patch_classified_as_write(self) -> None:
        """PATCH maps to write access."""
        from app.openapi.classifier import HeuristicClassifier

        (first,) = await HeuristicClassifier().classify((_make_endpoint(method="PATCH"),))
        assert first.access_type == "write"
|
class TestLLMClassifier:
    """Tests for the LLM-backed classifier."""

    def _make_mock_llm(self, classifications: list[dict]) -> MagicMock:
        """Create a mock LLM whose response body is *classifications* as JSON.

        Fix: the previous implementation used ``str(classifications)``, which
        renders a Python repr (single quotes, ``True``/``False``) -- that is
        not valid JSON, so a JSON-parsing classifier would always reject it
        and the helper could never exercise the happy path.  ``json.dumps``
        produces the wire format the classifier expects.
        """
        import json

        mock_llm = MagicMock()
        mock_response = MagicMock()
        mock_response.content = json.dumps(classifications)
        mock_llm.ainvoke = AsyncMock(return_value=mock_response)
        return mock_llm

    async def test_llm_classifier_classifies_endpoints(self) -> None:
        """LLM classifier returns ClassificationResult for each endpoint."""
        from app.openapi.classifier import LLMClassifier

        ep = _make_endpoint(method="GET")

        # Hand-written JSON payload: one read-classified endpoint.
        mock_llm = MagicMock()
        mock_response = MagicMock()
        mock_response.content = (
            '[{"access_type": "read", "agent_group": "support",'
            ' "needs_interrupt": false, "customer_params": [], "confidence": 0.9}]'
        )
        mock_llm.ainvoke = AsyncMock(return_value=mock_response)

        clf = LLMClassifier(llm=mock_llm)
        results = await clf.classify((ep,))
        assert len(results) == 1
        assert results[0].access_type == "read"

    async def test_llm_failure_falls_back_to_heuristic(self) -> None:
        """When LLM raises an exception, falls back to heuristic classifier."""
        from app.openapi.classifier import LLMClassifier

        ep = _make_endpoint(method="GET")

        mock_llm = MagicMock()
        mock_llm.ainvoke = AsyncMock(side_effect=RuntimeError("LLM unavailable"))

        clf = LLMClassifier(llm=mock_llm)
        results = await clf.classify((ep,))
        # Falls back to heuristic: GET = read
        assert len(results) == 1
        assert results[0].access_type == "read"

    async def test_llm_invalid_json_falls_back_to_heuristic(self) -> None:
        """When LLM returns unparseable output, falls back to heuristic."""
        from app.openapi.classifier import LLMClassifier

        ep = _make_endpoint(method="DELETE")

        mock_llm = MagicMock()
        mock_response = MagicMock()
        mock_response.content = "this is not valid json at all"
        mock_llm.ainvoke = AsyncMock(return_value=mock_response)

        clf = LLMClassifier(llm=mock_llm)
        results = await clf.classify((ep,))
        # Fallback: DELETE = write with interrupt
        assert results[0].access_type == "write"
        assert results[0].needs_interrupt is True

    async def test_llm_empty_endpoints_returns_empty(self) -> None:
        """Empty input yields empty output without calling LLM."""
        from app.openapi.classifier import LLMClassifier

        mock_llm = MagicMock()
        mock_llm.ainvoke = AsyncMock()

        clf = LLMClassifier(llm=mock_llm)
        results = await clf.classify(())
        assert results == ()
        mock_llm.ainvoke.assert_not_called()
||||||
|
class TestClassifierProtocol:
    """Both classifier implementations satisfy ClassifierProtocol."""

    def test_heuristic_has_classify_method(self) -> None:
        """HeuristicClassifier exposes a callable classify attribute."""
        from app.openapi.classifier import HeuristicClassifier

        instance = HeuristicClassifier()
        assert hasattr(instance, "classify")
        assert callable(instance.classify)

    def test_llm_has_classify_method(self) -> None:
        """LLMClassifier exposes a callable classify attribute."""
        from app.openapi.classifier import LLMClassifier

        instance = LLMClassifier(llm=MagicMock())
        assert hasattr(instance, "classify")
        assert callable(instance.classify)
120
backend/tests/unit/openapi/test_fetcher.py
Normal file
120
backend/tests/unit/openapi/test_fetcher.py
Normal file
@@ -0,0 +1,120 @@
|
|||||||
|
"""Tests for OpenAPI spec fetcher module.
|
||||||
|
|
||||||
|
RED phase: written before implementation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.openapi.ssrf import SSRFError
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
# Minimal-but-valid OpenAPI 3.0 documents in both wire formats, plus a
# public routable IP returned by the mocked DNS resolver so the SSRF
# guard accepts the test URLs.
_SAMPLE_JSON = '{"openapi": "3.0.0", "info": {"title": "Test", "version": "1.0"}, "paths": {}}'
_SAMPLE_YAML = "openapi: '3.0.0'\ninfo:\n title: Test\n version: '1.0'\npaths: {}\n"
_PUBLIC_IP = "93.184.216.34"
||||||
|
|
||||||
|
@pytest.fixture
def _mock_public_dns():
    # Force hostname resolution to a public address so the SSRF guard
    # passes; tests opt in via @pytest.mark.usefixtures("_mock_public_dns").
    with patch("app.openapi.ssrf.resolve_hostname", return_value=[_PUBLIC_IP]):
        yield
|
class TestFetchSpec:
    """fetch_spec: format detection, SSRF guard, size and parse limits."""

    @pytest.mark.usefixtures("_mock_public_dns")
    async def test_fetch_json_spec_succeeds(self, httpx_mock) -> None:
        """A JSON response body is parsed into a dict."""
        from app.openapi.fetcher import fetch_spec

        httpx_mock.add_response(
            url="https://example.com/spec.json",
            headers={"content-type": "application/json"},
            text=_SAMPLE_JSON,
        )
        spec = await fetch_spec("https://example.com/spec.json")
        assert isinstance(spec, dict)
        assert spec["openapi"] == "3.0.0"

    @pytest.mark.usefixtures("_mock_public_dns")
    async def test_fetch_yaml_spec_succeeds(self, httpx_mock) -> None:
        """A YAML response body is parsed into a dict."""
        from app.openapi.fetcher import fetch_spec

        httpx_mock.add_response(
            url="https://example.com/spec.yaml",
            headers={"content-type": "application/x-yaml"},
            text=_SAMPLE_YAML,
        )
        spec = await fetch_spec("https://example.com/spec.yaml")
        assert isinstance(spec, dict)
        assert spec["openapi"] == "3.0.0"

    @pytest.mark.usefixtures("_mock_public_dns")
    async def test_fetch_yaml_by_url_extension(self, httpx_mock) -> None:
        """YAML is detected from the .yaml extension despite text/plain."""
        from app.openapi.fetcher import fetch_spec

        httpx_mock.add_response(
            url="https://example.com/api.yaml",
            headers={"content-type": "text/plain"},
            text=_SAMPLE_YAML,
        )
        spec = await fetch_spec("https://example.com/api.yaml")
        assert isinstance(spec, dict)
        assert spec["openapi"] == "3.0.0"

    async def test_ssrf_blocked_url_raises(self) -> None:
        """A hostname resolving to a private address raises SSRFError."""
        from app.openapi.fetcher import fetch_spec

        with (
            patch("app.openapi.ssrf.resolve_hostname", return_value=["10.0.0.1"]),
            pytest.raises(SSRFError),
        ):
            await fetch_spec("http://internal.corp/spec.json")

    @pytest.mark.usefixtures("_mock_public_dns")
    async def test_oversized_response_raises(self, httpx_mock) -> None:
        """A body over the 10 MB cap raises ValueError."""
        from app.openapi.fetcher import fetch_spec

        oversized = "x" * (10 * 1024 * 1024 + 1)
        httpx_mock.add_response(
            url="https://example.com/huge.json",
            headers={"content-type": "application/json"},
            text=oversized,
        )
        with pytest.raises(ValueError, match="too large"):
            await fetch_spec("https://example.com/huge.json")

    @pytest.mark.usefixtures("_mock_public_dns")
    async def test_invalid_json_raises(self, httpx_mock) -> None:
        """A non-parseable JSON body raises ValueError."""
        from app.openapi.fetcher import fetch_spec

        httpx_mock.add_response(
            url="https://example.com/bad.json",
            headers={"content-type": "application/json"},
            text="not valid json {{{",
        )
        with pytest.raises(ValueError, match="[Pp]arse|[Ii]nvalid|[Dd]ecode"):
            await fetch_spec("https://example.com/bad.json")

    @pytest.mark.usefixtures("_mock_public_dns")
    async def test_invalid_yaml_raises(self, httpx_mock) -> None:
        """A non-parseable YAML body raises ValueError."""
        from app.openapi.fetcher import fetch_spec

        httpx_mock.add_response(
            url="https://example.com/bad.yaml",
            headers={"content-type": "application/x-yaml"},
            text=": invalid: yaml: {\n",
        )
        with pytest.raises(ValueError, match="[Pp]arse|[Ii]nvalid|[Yy]AML"):
            await fetch_spec("https://example.com/bad.yaml")
258
backend/tests/unit/openapi/test_generator.py
Normal file
258
backend/tests/unit/openapi/test_generator.py
Normal file
@@ -0,0 +1,258 @@
|
|||||||
|
"""Tests for OpenAPI tool generator module.
|
||||||
|
|
||||||
|
RED phase: written before implementation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.openapi.models import ClassificationResult, EndpointInfo, ParameterInfo
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
_BASE_URL = "https://api.example.com"
|
||||||
|
|
||||||
|
|
||||||
|
def _make_endpoint(
    path: str = "/items",
    method: str = "GET",
    operation_id: str = "list_items",
    summary: str = "List items",
    description: str = "Returns all items",
    parameters: tuple[ParameterInfo, ...] = (),
    request_body_schema: dict | None = None,
) -> EndpointInfo:
    """Build an EndpointInfo with test-friendly defaults."""
    fields = {
        "path": path,
        "method": method,
        "operation_id": operation_id,
        "summary": summary,
        "description": description,
        "parameters": parameters,
        "request_body_schema": request_body_schema,
    }
    return EndpointInfo(**fields)
|
||||||
|
def _make_classification(
    endpoint: EndpointInfo,
    access_type: str = "read",
    needs_interrupt: bool = False,
    agent_group: str = "read_agent",
) -> ClassificationResult:
    """Wrap *endpoint* in a ClassificationResult with fixed 0.9 confidence."""
    return ClassificationResult(
        endpoint=endpoint,
        access_type=access_type,
        agent_group=agent_group,
        needs_interrupt=needs_interrupt,
        customer_params=(),
        confidence=0.9,
    )
|
||||||
|
# Shared parameter fixtures: a required path segment and an optional
# query-string filter, both plain strings.
_PATH_PARAM = ParameterInfo(
    name="item_id", location="path", required=True, schema_type="string"
)
_QUERY_PARAM = ParameterInfo(
    name="filter", location="query", required=False, schema_type="string"
)
|
||||||
|
class TestGenerateToolCode:
    """Behaviour of generate_tool_code."""

    def test_generate_tool_for_get_endpoint(self) -> None:
        """A GET endpoint yields a GeneratedTool with non-empty @tool code."""
        from app.openapi.generator import generate_tool_code

        classification = _make_classification(_make_endpoint(method="GET"))
        generated = generate_tool_code(classification, _BASE_URL)

        assert generated.function_name == "list_items"
        assert generated.code != ""
        assert "@tool" in generated.code

    def test_generate_tool_contains_function_name(self) -> None:
        """The operation id appears verbatim in the emitted code."""
        from app.openapi.generator import generate_tool_code

        classification = _make_classification(
            _make_endpoint(operation_id="get_order", method="GET")
        )
        generated = generate_tool_code(classification, _BASE_URL)

        assert "get_order" in generated.code

    def test_generate_tool_contains_base_url(self) -> None:
        """The base URL appears verbatim in the emitted code."""
        from app.openapi.generator import generate_tool_code

        generated = generate_tool_code(_make_classification(_make_endpoint()), _BASE_URL)

        assert _BASE_URL in generated.code

    def test_generate_tool_contains_http_method(self) -> None:
        """The emitted code references the endpoint's HTTP method."""
        from app.openapi.generator import generate_tool_code

        classification = _make_classification(
            _make_endpoint(method="POST"), access_type="write"
        )
        generated = generate_tool_code(classification, _BASE_URL)

        assert "post" in generated.code.lower()

    def test_generate_tool_for_post_with_body(self) -> None:
        """A POST with a request body still produces method-aware code."""
        from app.openapi.generator import generate_tool_code

        endpoint = _make_endpoint(
            method="POST",
            request_body_schema={"type": "object", "properties": {"name": {"type": "string"}}},
        )
        generated = generate_tool_code(
            _make_classification(endpoint, access_type="write"), _BASE_URL
        )

        assert generated.code != ""
        assert any(token in generated.code for token in ("POST", "post"))

    def test_generate_tool_with_path_params(self) -> None:
        """Path parameters surface in the generated function."""
        from app.openapi.generator import generate_tool_code

        endpoint = _make_endpoint(
            path="/items/{item_id}",
            operation_id="get_item",
            parameters=(_PATH_PARAM,),
        )
        generated = generate_tool_code(_make_classification(endpoint), _BASE_URL)

        assert "item_id" in generated.code

    def test_write_tool_includes_interrupt_marker(self) -> None:
        """Interrupt-gated write tools carry an approval marker."""
        from app.openapi.generator import generate_tool_code

        classification = _make_classification(
            _make_endpoint(method="DELETE", operation_id="delete_item"),
            access_type="write",
            needs_interrupt=True,
        )
        generated = generate_tool_code(classification, _BASE_URL)

        lowered = generated.code.lower()
        assert "interrupt" in lowered or "approval" in lowered

    def test_generated_code_is_executable(self) -> None:
        """The emitted source compiles without syntax errors."""
        from app.openapi.generator import generate_tool_code

        endpoint = _make_endpoint(
            path="/items/{item_id}",
            operation_id="fetch_item",
            parameters=(_PATH_PARAM,),
        )
        generated = generate_tool_code(_make_classification(endpoint), _BASE_URL)

        # Must be valid Python syntax.
        compile(generated.code, "<generated>", "exec")

    def test_generated_tool_code_exec_imports(self) -> None:
        """Exec'ing the emitted source with its dependencies raises nothing."""
        from app.openapi.generator import generate_tool_code

        generated = generate_tool_code(_make_classification(_make_endpoint()), _BASE_URL)

        try:
            import httpx
            from langchain_core.tools import tool as lc_tool

            exec_globals: dict = {"httpx": httpx, "tool": lc_tool}
            exec(generated.code, exec_globals)  # noqa: S102
        except ImportError:
            pytest.skip("langchain_core not available for exec test")

    def test_returns_generated_tool_instance(self) -> None:
        """The return type is GeneratedTool."""
        from app.openapi.generator import generate_tool_code
        from app.openapi.models import GeneratedTool

        generated = generate_tool_code(_make_classification(_make_endpoint()), _BASE_URL)

        assert isinstance(generated, GeneratedTool)

    def test_generated_tool_is_frozen(self) -> None:
        """GeneratedTool rejects attribute assignment."""
        from app.openapi.generator import generate_tool_code

        generated = generate_tool_code(_make_classification(_make_endpoint()), _BASE_URL)

        with pytest.raises((AttributeError, TypeError)):
            generated.code = "new code"  # type: ignore[misc]
|
class TestGenerateAgentYaml:
    """Behaviour of generate_agent_yaml."""

    def test_generate_yaml_is_valid_string(self) -> None:
        """Output is a non-empty string."""
        from app.openapi.generator import generate_agent_yaml

        rendered = generate_agent_yaml((_make_classification(_make_endpoint()),), _BASE_URL)

        assert isinstance(rendered, str)
        assert len(rendered) > 0

    def test_generated_yaml_is_parseable(self) -> None:
        """Output round-trips through yaml.safe_load as a mapping."""
        import yaml

        from app.openapi.generator import generate_agent_yaml

        rendered = generate_agent_yaml((_make_classification(_make_endpoint()),), _BASE_URL)

        assert isinstance(yaml.safe_load(rendered), dict)

    def test_generated_yaml_contains_agents_key(self) -> None:
        """The document exposes an 'agents' key (AgentConfig format)."""
        import yaml

        from app.openapi.generator import generate_agent_yaml

        rendered = generate_agent_yaml((_make_classification(_make_endpoint()),), _BASE_URL)

        assert "agents" in yaml.safe_load(rendered)

    def test_generated_yaml_contains_tool_name(self) -> None:
        """The tool's function name is referenced in the document."""
        from app.openapi.generator import generate_agent_yaml

        classification = _make_classification(_make_endpoint(operation_id="list_orders"))
        rendered = generate_agent_yaml((classification,), _BASE_URL)

        assert "list_orders" in rendered

    def test_empty_classifications_returns_empty_agents(self) -> None:
        """No classifications yields an empty (or absent) agents list."""
        import yaml

        from app.openapi.generator import generate_agent_yaml

        data = yaml.safe_load(generate_agent_yaml((), _BASE_URL))
        agents = data.get("agents")
        assert agents == [] or agents is None
290
backend/tests/unit/openapi/test_parser.py
Normal file
290
backend/tests/unit/openapi/test_parser.py
Normal file
@@ -0,0 +1,290 @@
|
|||||||
|
"""Tests for OpenAPI endpoint parser module.
|
||||||
|
|
||||||
|
RED phase: written before implementation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
_MINIMAL_SPEC = {
|
||||||
|
"openapi": "3.0.0",
|
||||||
|
"info": {"title": "Test API", "version": "1.0.0"},
|
||||||
|
"paths": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
_GET_SPEC = {
|
||||||
|
"openapi": "3.0.0",
|
||||||
|
"info": {"title": "Orders API", "version": "1.0.0"},
|
||||||
|
"paths": {
|
||||||
|
"/orders/{order_id}": {
|
||||||
|
"get": {
|
||||||
|
"operationId": "get_order",
|
||||||
|
"summary": "Get an order",
|
||||||
|
"description": "Retrieves a single order by ID",
|
||||||
|
"parameters": [
|
||||||
|
{
|
||||||
|
"name": "order_id",
|
||||||
|
"in": "path",
|
||||||
|
"required": True,
|
||||||
|
"schema": {"type": "string"},
|
||||||
|
"description": "The order identifier",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"responses": {
|
||||||
|
"200": {
|
||||||
|
"description": "Order found",
|
||||||
|
"content": {
|
||||||
|
"application/json": {
|
||||||
|
"schema": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"id": {"type": "string"}},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_POST_SPEC = {
|
||||||
|
"openapi": "3.0.0",
|
||||||
|
"info": {"title": "Orders API", "version": "1.0.0"},
|
||||||
|
"paths": {
|
||||||
|
"/orders": {
|
||||||
|
"post": {
|
||||||
|
"operationId": "create_order",
|
||||||
|
"summary": "Create an order",
|
||||||
|
"description": "Creates a new order",
|
||||||
|
"requestBody": {
|
||||||
|
"required": True,
|
||||||
|
"content": {
|
||||||
|
"application/json": {
|
||||||
|
"schema": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"item": {"type": "string"},
|
||||||
|
"quantity": {"type": "integer"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"responses": {"201": {"description": "Created"}},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_MULTI_PARAM_SPEC = {
|
||||||
|
"openapi": "3.0.0",
|
||||||
|
"info": {"title": "Items API", "version": "1.0.0"},
|
||||||
|
"paths": {
|
||||||
|
"/items/{item_id}": {
|
||||||
|
"get": {
|
||||||
|
"operationId": "get_item",
|
||||||
|
"summary": "Get item",
|
||||||
|
"description": "",
|
||||||
|
"parameters": [
|
||||||
|
{
|
||||||
|
"name": "item_id",
|
||||||
|
"in": "path",
|
||||||
|
"required": True,
|
||||||
|
"schema": {"type": "integer"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "include_details",
|
||||||
|
"in": "query",
|
||||||
|
"required": False,
|
||||||
|
"schema": {"type": "boolean"},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"responses": {"200": {"description": "OK"}},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_REF_SPEC = {
|
||||||
|
"openapi": "3.0.0",
|
||||||
|
"info": {"title": "Ref API", "version": "1.0.0"},
|
||||||
|
"components": {
|
||||||
|
"schemas": {
|
||||||
|
"Item": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"id": {"type": "string"}},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"paths": {
|
||||||
|
"/items": {
|
||||||
|
"get": {
|
||||||
|
"operationId": "list_items",
|
||||||
|
"summary": "List items",
|
||||||
|
"description": "",
|
||||||
|
"responses": {
|
||||||
|
"200": {
|
||||||
|
"description": "OK",
|
||||||
|
"content": {
|
||||||
|
"application/json": {
|
||||||
|
"schema": {"$ref": "#/components/schemas/Item"}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_MULTI_ENDPOINT_SPEC = {
|
||||||
|
"openapi": "3.0.0",
|
||||||
|
"info": {"title": "Multi API", "version": "1.0.0"},
|
||||||
|
"paths": {
|
||||||
|
"/users": {
|
||||||
|
"get": {
|
||||||
|
"operationId": "list_users",
|
||||||
|
"summary": "List users",
|
||||||
|
"description": "",
|
||||||
|
"responses": {"200": {"description": "OK"}},
|
||||||
|
},
|
||||||
|
"post": {
|
||||||
|
"operationId": "create_user",
|
||||||
|
"summary": "Create user",
|
||||||
|
"description": "",
|
||||||
|
"responses": {"201": {"description": "Created"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"/users/{id}": {
|
||||||
|
"delete": {
|
||||||
|
"operationId": "delete_user",
|
||||||
|
"summary": "Delete user",
|
||||||
|
"description": "",
|
||||||
|
"parameters": [
|
||||||
|
{"name": "id", "in": "path", "required": True, "schema": {"type": "string"}}
|
||||||
|
],
|
||||||
|
"responses": {"204": {"description": "Deleted"}},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class TestParseEndpoints:
    """Behavioural tests for parse_endpoints."""

    def test_empty_paths_returns_empty_tuple(self) -> None:
        """A spec without paths produces no endpoints."""
        from app.openapi.parser import parse_endpoints

        assert parse_endpoints(_MINIMAL_SPEC) == ()

    def test_parse_get_endpoint(self) -> None:
        """A GET operation with a path parameter is parsed faithfully."""
        from app.openapi.parser import parse_endpoints

        endpoints = parse_endpoints(_GET_SPEC)
        assert len(endpoints) == 1
        first = endpoints[0]
        assert first.path == "/orders/{order_id}"
        assert first.method == "GET"
        assert first.operation_id == "get_order"
        assert first.summary == "Get an order"

    def test_parse_get_endpoint_parameters(self) -> None:
        """Path parameters carry name, location, required flag and type."""
        from app.openapi.parser import parse_endpoints

        first = parse_endpoints(_GET_SPEC)[0]
        assert len(first.parameters) == 1
        only_param = first.parameters[0]
        assert only_param.name == "order_id"
        assert only_param.location == "path"
        assert only_param.required is True
        assert only_param.schema_type == "string"

    def test_parse_post_with_request_body(self) -> None:
        """A POST operation exposes its JSON request-body schema."""
        from app.openapi.parser import parse_endpoints

        endpoints = parse_endpoints(_POST_SPEC)
        assert len(endpoints) == 1
        first = endpoints[0]
        assert first.method == "POST"
        assert first.request_body_schema is not None
        assert first.request_body_schema["type"] == "object"

    def test_parse_path_and_query_params(self) -> None:
        """Parameters from both 'path' and 'query' locations are captured."""
        from app.openapi.parser import parse_endpoints

        first = parse_endpoints(_MULTI_PARAM_SPEC)[0]
        seen_locations = {param.location for param in first.parameters}
        assert "path" in seen_locations
        assert "query" in seen_locations

    def test_autogenerate_operation_id_when_missing(self) -> None:
        """A missing operationId is synthesised, never left blank."""
        from app.openapi.parser import parse_endpoints

        spec = {
            "openapi": "3.0.0",
            "info": {"title": "Test", "version": "1.0"},
            "paths": {
                "/things/{id}": {
                    "get": {
                        "summary": "Get thing",
                        "description": "",
                        "responses": {"200": {"description": "OK"}},
                    }
                }
            },
        }
        first = parse_endpoints(spec)[0]
        assert first.operation_id != ""
        assert len(first.operation_id) > 0

    def test_multiple_endpoints_extracted(self) -> None:
        """Every path+method pair becomes its own endpoint."""
        from app.openapi.parser import parse_endpoints

        endpoints = parse_endpoints(_MULTI_ENDPOINT_SPEC)
        assert len(endpoints) == 3
        seen_methods = {endpoint.method for endpoint in endpoints}
        assert "GET" in seen_methods
        assert "POST" in seen_methods
        assert "DELETE" in seen_methods

    def test_ref_in_response_schema_resolved(self) -> None:
        """$ref responses are dereferenced into concrete schemas."""
        from app.openapi.parser import parse_endpoints

        first = parse_endpoints(_REF_SPEC)[0]
        assert first.response_schema is not None
        # A resolved reference should expose the target's properties.
        assert "properties" in first.response_schema or "$ref" not in first.response_schema

    def test_result_is_tuple(self) -> None:
        """The parser hands back an immutable tuple."""
        from app.openapi.parser import parse_endpoints

        assert isinstance(parse_endpoints(_GET_SPEC), tuple)

    def test_endpoint_info_is_frozen(self) -> None:
        """Attribute assignment on EndpointInfo must raise."""
        from app.openapi.parser import parse_endpoints

        first = parse_endpoints(_GET_SPEC)[0]
        with pytest.raises((AttributeError, TypeError)):
            first.method = "POST"  # type: ignore[misc]
219
backend/tests/unit/openapi/test_review_api.py
Normal file
219
backend/tests/unit/openapi/test_review_api.py
Normal file
@@ -0,0 +1,219 @@
|
|||||||
|
"""Tests for OpenAPI review API endpoints.
|
||||||
|
|
||||||
|
RED phase: written before implementation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
_SAMPLE_URL = "https://example.com/api/spec.json"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def client():
    """Build a TestClient wrapping a fresh app with the review router."""
    from fastapi import FastAPI

    from app.openapi.review_api import router

    application = FastAPI()
    application.include_router(router)
    return TestClient(application)
@pytest.fixture
def job_id(client):
    """Submit an import job and hand back its identifier."""
    created = client.post("/api/v1/openapi/import", json={"url": _SAMPLE_URL})
    assert created.status_code == 202
    return created.json()["job_id"]
@pytest.fixture
def job_with_classifications(client, job_id):
    """Seed the in-memory store with one mock classification for *job_id*."""
    from app.openapi.models import ClassificationResult, EndpointInfo
    from app.openapi.review_api import _job_store

    endpoint = EndpointInfo(
        path="/orders",
        method="GET",
        operation_id="list_orders",
        summary="List orders",
        description="",
    )
    classification = ClassificationResult(
        endpoint=endpoint,
        access_type="read",
        customer_params=(),
        agent_group="read_agent",
        confidence=0.9,
        needs_interrupt=False,
    )
    # Write straight into the store, bypassing the classification pipeline.
    existing = _job_store[job_id]
    _job_store[job_id] = {**existing, "classifications": [classification]}
    return job_id
class TestImportEndpoint:
    """Behaviour of POST /api/v1/openapi/import."""

    def test_post_import_returns_job_id(self, client) -> None:
        """A valid import request is accepted and assigned a job_id."""
        resp = client.post("/api/v1/openapi/import", json={"url": _SAMPLE_URL})
        assert resp.status_code == 202
        payload = resp.json()
        assert "job_id" in payload
        assert len(payload["job_id"]) > 0

    def test_post_import_empty_url_returns_422(self, client) -> None:
        """An empty URL fails request validation."""
        resp = client.post("/api/v1/openapi/import", json={"url": ""})
        assert resp.status_code == 422

    def test_post_import_missing_url_returns_422(self, client) -> None:
        """Omitting the url field fails request validation."""
        resp = client.post("/api/v1/openapi/import", json={})
        assert resp.status_code == 422

    def test_post_import_invalid_scheme_returns_422(self, client) -> None:
        """Non-http(s) URL schemes are rejected."""
        resp = client.post("/api/v1/openapi/import", json={"url": "ftp://evil.com/spec"})
        assert resp.status_code == 422

    def test_post_import_returns_pending_status(self, client) -> None:
        """Freshly created jobs start out pending."""
        resp = client.post("/api/v1/openapi/import", json={"url": _SAMPLE_URL})
        assert resp.json()["status"] == "pending"

    def test_post_import_returns_spec_url(self, client) -> None:
        """The submitted spec URL is echoed back in the response."""
        resp = client.post("/api/v1/openapi/import", json={"url": _SAMPLE_URL})
        assert resp.json()["spec_url"] == _SAMPLE_URL
class TestGetJobEndpoint:
    """Behaviour of GET /api/v1/openapi/jobs/{job_id}."""

    def test_get_job_returns_status(self, client, job_id) -> None:
        """An existing job reports both its status and its id."""
        resp = client.get(f"/api/v1/openapi/jobs/{job_id}")
        assert resp.status_code == 200
        payload = resp.json()
        assert "status" in payload
        assert "job_id" in payload

    def test_get_unknown_job_returns_404(self, client) -> None:
        """Unknown job identifiers yield 404."""
        resp = client.get("/api/v1/openapi/jobs/nonexistent-id")
        assert resp.status_code == 404

    def test_get_job_includes_spec_url(self, client, job_id) -> None:
        """The job payload carries the original spec URL."""
        resp = client.get(f"/api/v1/openapi/jobs/{job_id}")
        assert resp.json()["spec_url"] == _SAMPLE_URL
class TestGetClassificationsEndpoint:
    """Behaviour of GET /api/v1/openapi/jobs/{job_id}/classifications."""

    def test_get_classifications_returns_list(self, client, job_with_classifications) -> None:
        """Classifications come back as a JSON list."""
        resp = client.get(f"/api/v1/openapi/jobs/{job_with_classifications}/classifications")
        assert resp.status_code == 200
        payload = resp.json()
        assert isinstance(payload, list)
        assert len(payload) == 1

    def test_get_classifications_unknown_job_returns_404(self, client) -> None:
        """Unknown jobs yield 404 from the classifications route."""
        resp = client.get("/api/v1/openapi/jobs/unknown/classifications")
        assert resp.status_code == 404

    def test_classification_has_expected_fields(self, client, job_with_classifications) -> None:
        """Each item exposes access_type, endpoint and needs_interrupt."""
        resp = client.get(f"/api/v1/openapi/jobs/{job_with_classifications}/classifications")
        first = resp.json()[0]
        assert "access_type" in first
        assert "endpoint" in first
        assert "needs_interrupt" in first
class TestUpdateClassificationEndpoint:
    """Behaviour of PUT /api/v1/openapi/jobs/{job_id}/classifications/{idx}."""

    def test_update_classification_succeeds(self, client, job_with_classifications) -> None:
        """A well-formed update of index 0 is accepted."""
        resp = client.put(
            f"/api/v1/openapi/jobs/{job_with_classifications}/classifications/0",
            json={"access_type": "write", "needs_interrupt": True, "agent_group": "write_agent"},
        )
        assert resp.status_code == 200

    def test_update_unknown_job_returns_404(self, client) -> None:
        """Updating a classification on an unknown job yields 404."""
        resp = client.put(
            "/api/v1/openapi/jobs/unknown/classifications/0",
            json={"access_type": "write", "needs_interrupt": True, "agent_group": "write_agent"},
        )
        assert resp.status_code == 404

    def test_update_invalid_access_type_returns_422(self, client, job_with_classifications) -> None:
        """An access_type outside the allowed set fails validation."""
        resp = client.put(
            f"/api/v1/openapi/jobs/{job_with_classifications}/classifications/0",
            json={"access_type": "admin", "needs_interrupt": True, "agent_group": "x"},
        )
        assert resp.status_code == 422

    def test_update_invalid_agent_group_returns_422(self, client, job_with_classifications) -> None:
        """An agent_group containing illegal characters fails validation."""
        resp = client.put(
            f"/api/v1/openapi/jobs/{job_with_classifications}/classifications/0",
            json={"access_type": "read", "needs_interrupt": False, "agent_group": "evil group!"},
        )
        assert resp.status_code == 422

    def test_update_out_of_range_index_returns_404(self, client, job_with_classifications) -> None:
        """An index past the end of the classification list yields 404."""
        resp = client.put(
            f"/api/v1/openapi/jobs/{job_with_classifications}/classifications/999",
            json={"access_type": "read", "needs_interrupt": False, "agent_group": "read_agent"},
        )
        assert resp.status_code == 404
class TestApproveEndpoint:
    """Behaviour of POST /api/v1/openapi/jobs/{job_id}/approve."""

    def test_approve_job_succeeds(self, client, job_with_classifications) -> None:
        """Approving a job that has classifications returns 200."""
        resp = client.post(f"/api/v1/openapi/jobs/{job_with_classifications}/approve")
        assert resp.status_code == 200

    def test_approve_unknown_job_returns_404(self, client) -> None:
        """Approving an unknown job yields 404."""
        resp = client.post("/api/v1/openapi/jobs/unknown/approve")
        assert resp.status_code == 404

    def test_approve_returns_job_status(self, client, job_with_classifications) -> None:
        """The approve response reports the (updated) job status."""
        resp = client.post(f"/api/v1/openapi/jobs/{job_with_classifications}/approve")
        assert "status" in resp.json()
93
backend/tests/unit/openapi/test_validator.py
Normal file
93
backend/tests/unit/openapi/test_validator.py
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
"""Tests for OpenAPI spec validator module.
|
||||||
|
|
||||||
|
RED phase: written before implementation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
_VALID_SPEC = {
|
||||||
|
"openapi": "3.0.0",
|
||||||
|
"info": {"title": "Test API", "version": "1.0.0"},
|
||||||
|
"paths": {
|
||||||
|
"/items": {
|
||||||
|
"get": {
|
||||||
|
"summary": "List items",
|
||||||
|
"responses": {"200": {"description": "OK"}},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class TestValidateSpec:
    """Behavioural tests for validate_spec."""

    def test_valid_minimal_spec_passes(self) -> None:
        """A well-formed minimal spec yields no errors."""
        from app.openapi.validator import validate_spec

        assert validate_spec(_VALID_SPEC) == []

    def test_missing_openapi_key_returns_error(self) -> None:
        """Dropping the 'openapi' field is reported by name."""
        from app.openapi.validator import validate_spec

        stripped = {k: v for k, v in _VALID_SPEC.items() if k != "openapi"}
        problems = validate_spec(stripped)
        assert len(problems) > 0
        assert any("openapi" in msg.lower() for msg in problems)

    def test_missing_info_returns_error(self) -> None:
        """Dropping the 'info' field is reported by name."""
        from app.openapi.validator import validate_spec

        stripped = {k: v for k, v in _VALID_SPEC.items() if k != "info"}
        problems = validate_spec(stripped)
        assert len(problems) > 0
        assert any("info" in msg.lower() for msg in problems)

    def test_missing_paths_returns_error(self) -> None:
        """Dropping the 'paths' field is reported by name."""
        from app.openapi.validator import validate_spec

        stripped = {k: v for k, v in _VALID_SPEC.items() if k != "paths"}
        problems = validate_spec(stripped)
        assert len(problems) > 0
        assert any("paths" in msg.lower() for msg in problems)

    def test_non_dict_input_returns_error(self) -> None:
        """Non-dict input reports errors instead of raising."""
        from app.openapi.validator import validate_spec

        problems = validate_spec("not a dict")  # type: ignore[arg-type]
        assert len(problems) > 0

    def test_empty_dict_returns_multiple_errors(self) -> None:
        """An empty spec reports every missing required field."""
        from app.openapi.validator import validate_spec

        problems = validate_spec({})
        assert len(problems) >= 3  # at least openapi, info, paths

    def test_invalid_openapi_version_returns_error(self) -> None:
        """An unsupported version string is reported."""
        from app.openapi.validator import validate_spec

        problems = validate_spec({**_VALID_SPEC, "openapi": "1.0.0"})
        assert len(problems) > 0

    def test_errors_are_descriptive_strings(self) -> None:
        """Every reported error is a non-empty string."""
        from app.openapi.validator import validate_spec

        for message in validate_spec({}):
            assert isinstance(message, str)
            assert len(message) > 0
1
backend/tests/unit/replay/__init__.py
Normal file
1
backend/tests/unit/replay/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Unit tests for app.replay module."""
|
||||||
217
backend/tests/unit/replay/test_api.py
Normal file
217
backend/tests/unit/replay/test_api.py
Normal file
@@ -0,0 +1,217 @@
|
|||||||
|
"""Unit tests for app.replay.api."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import FastAPI, HTTPException
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
from app.api_utils import envelope
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
def _build_app() -> FastAPI:
    """Assemble a FastAPI app with the replay router and an
    HTTPException handler that wraps errors in the response envelope."""
    from app.replay.api import router

    application = FastAPI()
    application.include_router(router)

    @application.exception_handler(HTTPException)
    async def _handle_http_error(request, exc):  # type: ignore[no-untyped-def]
        return JSONResponse(
            status_code=exc.status_code,
            content=envelope(None, success=False, error=exc.detail),
        )

    return application
def _make_mock_pool(
|
||||||
|
fetchall_result: list[dict],
|
||||||
|
*,
|
||||||
|
count: int | None = None,
|
||||||
|
) -> MagicMock:
|
||||||
|
"""Build a mock pool that returns the given rows from fetchall.
|
||||||
|
|
||||||
|
When *count* is provided, the first execute() call returns a cursor
|
||||||
|
whose fetchone() yields ``(count,)`` (for the COUNT query) and the
|
||||||
|
second call returns the rows via fetchall(). When *count* is None
|
||||||
|
(the default), a single cursor backed by *fetchall_result* is used
|
||||||
|
for all calls.
|
||||||
|
"""
|
||||||
|
if count is not None:
|
||||||
|
count_cursor = AsyncMock()
|
||||||
|
count_cursor.fetchone = AsyncMock(return_value=(count,))
|
||||||
|
|
||||||
|
rows_cursor = AsyncMock()
|
||||||
|
rows_cursor.fetchall = AsyncMock(return_value=fetchall_result)
|
||||||
|
|
||||||
|
mock_conn = AsyncMock()
|
||||||
|
mock_conn.execute = AsyncMock(side_effect=[count_cursor, rows_cursor])
|
||||||
|
else:
|
||||||
|
mock_cursor = AsyncMock()
|
||||||
|
mock_cursor.fetchall = AsyncMock(return_value=fetchall_result)
|
||||||
|
mock_cursor.fetchone = AsyncMock(return_value=None)
|
||||||
|
|
||||||
|
mock_conn = AsyncMock()
|
||||||
|
mock_conn.execute = AsyncMock(return_value=mock_cursor)
|
||||||
|
|
||||||
|
mock_ctx = AsyncMock()
|
||||||
|
mock_ctx.__aenter__ = AsyncMock(return_value=mock_conn)
|
||||||
|
mock_ctx.__aexit__ = AsyncMock(return_value=None)
|
||||||
|
|
||||||
|
mock_pool = MagicMock()
|
||||||
|
mock_pool.connection.return_value = mock_ctx
|
||||||
|
return mock_pool
|
||||||
|
|
||||||
|
|
||||||
|
class TestListConversations:
    """Behaviour of GET /api/v1/conversations."""

    def test_returns_200_with_empty_list(self) -> None:
        """An empty table produces an empty first page with total 0."""
        app = _build_app()
        app.state.pool = _make_mock_pool([], count=0)

        with TestClient(app) as client:
            resp = client.get("/api/v1/conversations")
            assert resp.status_code == 200
            body = resp.json()
            assert body["success"] is True
            payload = body["data"]
            assert isinstance(payload["conversations"], list)
            assert payload["total"] == 0
            assert payload["page"] == 1
            assert body["error"] is None

    def test_returns_conversations_list(self) -> None:
        """Rows from the database surface in the conversations array."""
        app = _build_app()
        rows = [
            {
                "thread_id": "t1",
                "created_at": "2026-01-01T00:00:00",
                "last_activity": "2026-01-01T00:01:00",
                "status": "active",
                "total_tokens": 100,
                "total_cost_usd": 0.01,
            }
        ]
        app.state.pool = _make_mock_pool(rows, count=1)

        with TestClient(app) as client:
            resp = client.get("/api/v1/conversations")
            body = resp.json()
            assert resp.status_code == 200
            payload = body["data"]
            assert len(payload["conversations"]) == 1
            assert payload["conversations"][0]["thread_id"] == "t1"
            assert payload["total"] == 1

    def test_pagination_defaults(self) -> None:
        """The endpoint works without explicit paging parameters."""
        app = _build_app()
        app.state.pool = _make_mock_pool([], count=0)

        with TestClient(app) as client:
            resp = client.get("/api/v1/conversations")
            assert resp.status_code == 200

    def test_pagination_custom_params(self) -> None:
        """Explicit page/per_page values are accepted."""
        app = _build_app()
        app.state.pool = _make_mock_pool([], count=0)

        with TestClient(app) as client:
            resp = client.get("/api/v1/conversations?page=2&per_page=10")
            assert resp.status_code == 200

    def test_per_page_max_capped_at_100(self) -> None:
        """per_page above the Query(le=100) bound is rejected with 422."""
        app = _build_app()
        app.state.pool = _make_mock_pool([], count=0)

        with TestClient(app) as client:
            resp = client.get("/api/v1/conversations?per_page=200")
            assert resp.status_code == 422
class TestGetReplay:
    """Behaviour of GET /api/v1/replay/{thread_id}."""

    def test_thread_not_found_returns_404(self) -> None:
        """A thread with no stored checkpoints yields 404."""
        app = _build_app()
        app.state.pool = _make_mock_pool([])

        with TestClient(app) as client:
            resp = client.get("/api/v1/replay/nonexistent-thread")
            assert resp.status_code == 404

    def test_returns_replay_page_for_existing_thread(self) -> None:
        """Checkpoint rows become a paginated replay payload."""
        app = _build_app()
        rows = [
            {
                "thread_id": "thread-123",
                "checkpoint_id": "cp-001",
                "checkpoint": {
                    "channel_values": {
                        "messages": [{"type": "human", "content": "Hello"}]
                    }
                },
                "metadata": {},
            }
        ]
        app.state.pool = _make_mock_pool(rows)

        with TestClient(app) as client:
            resp = client.get("/api/v1/replay/thread-123")
            assert resp.status_code == 200
            body = resp.json()
            assert body["success"] is True
            payload = body["data"]
            assert payload["thread_id"] == "thread-123"
            for key in ("steps", "total_steps", "page", "per_page"):
                assert key in payload

    def test_replay_pagination_params(self) -> None:
        """page/per_page query parameters are honoured."""
        app = _build_app()
        rows = [
            {
                "thread_id": "t1",
                "checkpoint_id": "cp-001",
                "checkpoint": {
                    "channel_values": {"messages": [{"type": "human", "content": "Hi"}]}
                },
                "metadata": {},
            }
        ]
        app.state.pool = _make_mock_pool(rows)

        with TestClient(app) as client:
            resp = client.get("/api/v1/replay/t1?page=1&per_page=5")
            assert resp.status_code == 200

    def test_error_response_has_envelope(self) -> None:
        """Even a 404 is wrapped in the success/data/error envelope."""
        app = _build_app()
        app.state.pool = _make_mock_pool([])

        with TestClient(app) as client:
            resp = client.get("/api/v1/replay/missing")
            assert resp.status_code == 404
            body = resp.json()
            assert body["success"] is False
            assert body["data"] is None
            assert body["error"] is not None

    def test_invalid_thread_id_returns_400(self) -> None:
        """Whitespace in a thread id is rejected before touching the DB."""
        app = _build_app()
        app.state.pool = _make_mock_pool([])

        with TestClient(app) as client:
            resp = client.get("/api/v1/replay/id%20with%20spaces")
            assert resp.status_code == 400

    def test_thread_id_special_chars_returns_400(self) -> None:
        """SQL/shell-style metacharacters in a thread id are rejected."""
        app = _build_app()
        app.state.pool = _make_mock_pool([])

        with TestClient(app) as client:
            resp = client.get("/api/v1/replay/id;DROP TABLE")
            assert resp.status_code == 400
134
backend/tests/unit/replay/test_models.py
Normal file
134
backend/tests/unit/replay/test_models.py
Normal file
@@ -0,0 +1,134 @@
|
|||||||
|
"""Unit tests for app.replay.models."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
class TestStepType:
    """Sanity checks on the StepType enum's members and their string values."""

    def test_all_step_types_exist(self) -> None:
        from app.replay.models import StepType

        expected_members = (
            "user_message",
            "supervisor_routing",
            "tool_call",
            "tool_result",
            "agent_response",
            "interrupt",
        )
        # Every expected member must exist and be truthy.
        for member_name in expected_members:
            assert getattr(StepType, member_name)

    def test_step_type_values(self) -> None:
        from app.replay.models import StepType

        # Each member's .value is its own name, spot-checked on three members.
        for member_name in ("user_message", "tool_call", "agent_response"):
            assert getattr(StepType, member_name).value == member_name
||||||
|
|
||||||
|
class TestReplayStep:
    """Construction, field defaults, and immutability of the ReplayStep model."""

    def test_minimal_replay_step(self) -> None:
        from app.replay.models import ReplayStep, StepType

        step = ReplayStep(step=1, type=StepType.user_message, timestamp="2026-01-01T00:00:00Z")
        assert step.step == 1
        assert step.type == StepType.user_message
        assert step.timestamp == "2026-01-01T00:00:00Z"
        # Every optional field defaults to empty string / None.
        assert step.content == ""
        assert step.agent is None
        assert step.tool is None
        assert step.params is None
        assert step.result is None
        assert step.reasoning is None
        assert step.tokens is None
        assert step.duration_ms is None

    def test_full_replay_step(self) -> None:
        from app.replay.models import ReplayStep, StepType

        # A fully populated tool-call step round-trips all fields.
        step = ReplayStep(
            step=2,
            type=StepType.tool_call,
            timestamp="2026-01-01T00:00:01Z",
            content="calling get_order",
            agent="order_agent",
            tool="get_order_status",
            params={"order_id": "ORD-123"},
            result={"status": "shipped"},
            reasoning="user asked about order",
            tokens=50,
            duration_ms=200,
        )
        assert step.step == 2
        assert step.agent == "order_agent"
        assert step.tool == "get_order_status"
        assert step.params == {"order_id": "ORD-123"}
        assert step.tokens == 50

    def test_replay_step_is_frozen(self) -> None:
        from app.replay.models import ReplayStep, StepType

        step = ReplayStep(step=1, type=StepType.user_message, timestamp="2026-01-01T00:00:00Z")
        # Frozen models raise on attribute assignment; the exact exception
        # type depends on the model implementation, so accept either.
        with pytest.raises((AttributeError, TypeError)):
            step.step = 99  # type: ignore[misc]

    def test_replay_step_params_is_immutable_copy(self) -> None:
        from app.replay.models import ReplayStep, StepType

        params = {"key": "value"}
        step = ReplayStep(
            step=1,
            type=StepType.tool_call,
            timestamp="2026-01-01T00:00:00Z",
            params=params,
        )
        # Modifying original dict should not affect step
        params["new_key"] = "new_value"
        assert "new_key" not in (step.params or {})
|
||||||
|
|
||||||
|
class TestReplayPage:
    """Behavioral checks for the frozen ReplayPage container."""

    def test_replay_page_construction(self) -> None:
        from app.replay.models import ReplayPage, ReplayStep, StepType

        # Two sequential steps, named individually for readability.
        first = ReplayStep(step=1, type=StepType.user_message, timestamp="2026-01-01T00:00:00Z")
        second = ReplayStep(step=2, type=StepType.agent_response, timestamp="2026-01-01T00:00:01Z")
        page = ReplayPage(
            thread_id="thread-123",
            total_steps=2,
            page=1,
            per_page=20,
            steps=(first, second),
        )

        assert page.thread_id == "thread-123"
        assert page.total_steps == 2
        assert page.page == 1
        assert page.per_page == 20
        assert len(page.steps) == 2

    def test_replay_page_is_frozen(self) -> None:
        from app.replay.models import ReplayPage

        empty_page = ReplayPage(thread_id="t1", total_steps=0, page=1, per_page=20, steps=())
        # Assignment on a frozen model must raise.
        with pytest.raises((AttributeError, TypeError)):
            empty_page.page = 2  # type: ignore[misc]

    def test_replay_page_empty_steps(self) -> None:
        from app.replay.models import ReplayPage

        empty_page = ReplayPage(thread_id="t1", total_steps=0, page=1, per_page=20, steps=())
        assert empty_page.steps == ()
||||||
257
backend/tests/unit/replay/test_transformer.py
Normal file
257
backend/tests/unit/replay/test_transformer.py
Normal file
@@ -0,0 +1,257 @@
|
|||||||
|
"""Unit tests for app.replay.transformer."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
def _make_row(messages: list[dict], metadata: dict | None = None) -> dict:
|
||||||
|
"""Helper to build a checkpoint row with the given messages."""
|
||||||
|
return {
|
||||||
|
"thread_id": "thread-abc",
|
||||||
|
"checkpoint_id": "cp-001",
|
||||||
|
"checkpoint": {"channel_values": {"messages": messages}},
|
||||||
|
"metadata": metadata or {},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class TestTransformCheckpoints:
    """Behavior of transform_checkpoints over well-formed and malformed rows.

    Covers happy paths (each message type maps to a step type), step
    numbering across messages and rows, and graceful skipping of rows or
    messages that do not have the expected shape.
    """

    def test_empty_rows_returns_empty_list(self) -> None:
        from app.replay.transformer import transform_checkpoints

        result = transform_checkpoints([])
        assert result == []

    def test_human_message_produces_user_message_step(self) -> None:
        from app.replay.models import StepType
        from app.replay.transformer import transform_checkpoints

        rows = [_make_row([{"type": "human", "content": "Hello, I need help"}])]
        steps = transform_checkpoints(rows)
        assert len(steps) == 1
        assert steps[0].type == StepType.user_message
        assert steps[0].content == "Hello, I need help"
        assert steps[0].step == 1

    def test_ai_message_with_content_produces_agent_response(self) -> None:
        from app.replay.models import StepType
        from app.replay.transformer import transform_checkpoints

        rows = [
            _make_row(
                [{"type": "ai", "content": "I can help you with that.", "tool_calls": []}],
                metadata={"writes": {"some_agent": "response"}},
            )
        ]
        steps = transform_checkpoints(rows)
        assert len(steps) == 1
        assert steps[0].type == StepType.agent_response
        assert steps[0].content == "I can help you with that."

    def test_ai_message_with_tool_calls_produces_tool_call_step(self) -> None:
        from app.replay.models import StepType
        from app.replay.transformer import transform_checkpoints

        # AI message with empty content but one tool call -> a tool_call step.
        rows = [
            _make_row(
                [
                    {
                        "type": "ai",
                        "content": "",
                        "tool_calls": [
                            {
                                "name": "get_order_status",
                                "args": {"order_id": "ORD-123"},
                                "id": "call_abc",
                            }
                        ],
                    }
                ]
            )
        ]
        steps = transform_checkpoints(rows)
        assert len(steps) == 1
        assert steps[0].type == StepType.tool_call
        assert steps[0].tool == "get_order_status"
        assert steps[0].params == {"order_id": "ORD-123"}

    def test_tool_message_produces_tool_result_step(self) -> None:
        from app.replay.models import StepType
        from app.replay.transformer import transform_checkpoints

        rows = [
            _make_row(
                [
                    {
                        "type": "tool",
                        "content": '{"status": "shipped"}',
                        "name": "get_order_status",
                    }
                ]
            )
        ]
        steps = transform_checkpoints(rows)
        assert len(steps) == 1
        assert steps[0].type == StepType.tool_result
        assert steps[0].tool == "get_order_status"

    def test_multiple_messages_sequential_steps(self) -> None:
        from app.replay.transformer import transform_checkpoints

        rows = [
            _make_row(
                [
                    {"type": "human", "content": "Help"},
                    {"type": "ai", "content": "Sure!", "tool_calls": []},
                ]
            )
        ]
        steps = transform_checkpoints(rows)
        assert len(steps) == 2
        assert steps[0].step == 1
        assert steps[1].step == 2

    def test_unknown_message_type_skipped(self) -> None:
        from app.replay.transformer import transform_checkpoints

        rows = [_make_row([{"type": "unknown_type", "content": "test"}])]
        steps = transform_checkpoints(rows)
        # Should not crash; unknown types may be skipped
        assert isinstance(steps, list)

    def test_row_missing_checkpoint_skipped(self) -> None:
        from app.replay.transformer import transform_checkpoints

        rows = [{"thread_id": "t1", "checkpoint_id": "cp1", "checkpoint": None, "metadata": {}}]
        steps = transform_checkpoints(rows)
        assert isinstance(steps, list)

    def test_row_missing_messages_key_skipped(self) -> None:
        from app.replay.transformer import transform_checkpoints

        rows = [{"thread_id": "t1", "checkpoint_id": "cp1", "checkpoint": {}, "metadata": {}}]
        steps = transform_checkpoints(rows)
        assert isinstance(steps, list)

    def test_multiple_rows_steps_are_continuous(self) -> None:
        from app.replay.transformer import transform_checkpoints

        # Step numbering must continue across rows, not reset per row.
        rows = [
            _make_row([{"type": "human", "content": "Q1"}]),
            _make_row([{"type": "ai", "content": "A1", "tool_calls": []}]),
        ]
        steps = transform_checkpoints(rows)
        assert len(steps) == 2
        assert steps[0].step == 1
        assert steps[1].step == 2

    def test_timestamps_are_strings(self) -> None:
        from app.replay.transformer import transform_checkpoints

        rows = [_make_row([{"type": "human", "content": "Hi"}])]
        steps = transform_checkpoints(rows)
        assert isinstance(steps[0].timestamp, str)

    def test_list_content_joined_to_string(self) -> None:
        from app.replay.transformer import transform_checkpoints

        # Structured (list-of-parts) content is flattened to one string.
        rows = [
            _make_row(
                [
                    {
                        "type": "human",
                        "content": [
                            {"text": "Hello"},
                            {"text": " world"},
                        ],
                    }
                ]
            )
        ]
        steps = transform_checkpoints(rows)
        assert len(steps) == 1
        assert steps[0].content == "Hello world"

    def test_checkpoint_as_string_skipped(self) -> None:
        from app.replay.transformer import transform_checkpoints

        rows = [
            {
                "thread_id": "t1",
                "checkpoint_id": "cp1",
                "checkpoint": "not-a-dict",
                "metadata": {},
            }
        ]
        steps = transform_checkpoints(rows)
        assert steps == []

    def test_channel_values_not_dict_skipped(self) -> None:
        from app.replay.transformer import transform_checkpoints

        rows = [
            {
                "thread_id": "t1",
                "checkpoint_id": "cp1",
                "checkpoint": {"channel_values": "bad"},
                "metadata": {},
            }
        ]
        steps = transform_checkpoints(rows)
        assert steps == []

    def test_tool_result_valid_json_parsed(self) -> None:
        from app.replay.transformer import transform_checkpoints

        rows = [
            _make_row(
                [
                    {
                        "type": "tool",
                        "content": '{"order_id": "123", "status": "shipped"}',
                        "name": "get_order_status",
                    }
                ]
            )
        ]
        steps = transform_checkpoints(rows)
        assert len(steps) == 1
        assert steps[0].result == {"order_id": "123", "status": "shipped"}

    def test_tool_result_invalid_json_wrapped(self) -> None:
        from app.replay.transformer import transform_checkpoints

        # Non-JSON tool output is preserved verbatim under a "raw" key.
        rows = [
            _make_row(
                [
                    {
                        "type": "tool",
                        "content": "not valid json",
                        "name": "some_tool",
                    }
                ]
            )
        ]
        steps = transform_checkpoints(rows)
        assert len(steps) == 1
        assert steps[0].result == {"raw": "not valid json"}

    def test_malformed_message_skipped_gracefully(self) -> None:
        from app.replay.transformer import transform_checkpoints

        rows = [
            _make_row(
                [
                    {"type": "human", "content": "Good message"},
                    42,  # not a dict -- will raise in _step_from_message
                    {"type": "ai", "content": "Response", "tool_calls": []},
                ]
            )
        ]
        steps = transform_checkpoints(rows)
        # The malformed message is skipped; the other two produce steps.
        assert len(steps) == 2
        assert steps[0].step == 1
        assert steps[1].step == 2
||||||
@@ -7,10 +7,41 @@ import pytest
|
|||||||
from app.config import Settings
|
from app.config import Settings
|
||||||
|
|
||||||
|
|
||||||
|
def _isolated_settings(**kwargs: object) -> Settings:
    """Create a Settings instance that ignores .env files and process env vars.

    pydantic-settings reads from env_file and environment by default, which
    causes test results to depend on the machine they run on. We override
    model_config at the class level temporarily so that every test gets
    deterministic results.
    """
    # Build a throwaway subclass that disables env-file and env-var loading.
    class _IsolatedSettings(Settings):
        model_config = Settings.model_config.copy()
        model_config["env_file"] = None  # type: ignore[assignment]
        model_config["env_ignore_empty"] = True

    # _env_parse_none_str makes pydantic-settings treat missing env vars as
    # absent rather than empty-string, so required fields will raise.
    import os

    env_backup = os.environ.copy()
    # Strip all env vars that Settings knows about so they can't leak in.
    # NOTE(review): field names are compared via key.lower(); assumes Settings
    # field names are all lowercase -- confirm against app.config.
    settings_fields = set(Settings.model_fields)
    for key in list(os.environ):
        if key.lower() in settings_fields:
            del os.environ[key]
    try:
        return _IsolatedSettings(**kwargs)  # type: ignore[return-value]
    finally:
        # Always restore the caller's environment, even if construction raises.
        os.environ.clear()
        os.environ.update(env_backup)
||||||
|
|
||||||
|
|
||||||
@pytest.mark.unit
|
@pytest.mark.unit
|
||||||
class TestSettings:
|
class TestSettings:
|
||||||
def test_default_values(self) -> None:
|
def test_default_values(self) -> None:
|
||||||
settings = Settings(
|
settings = _isolated_settings(
|
||||||
database_url="postgresql://x:x@localhost/db",
|
database_url="postgresql://x:x@localhost/db",
|
||||||
anthropic_api_key="key",
|
anthropic_api_key="key",
|
||||||
)
|
)
|
||||||
@@ -20,7 +51,7 @@ class TestSettings:
|
|||||||
assert settings.interrupt_ttl_minutes == 30
|
assert settings.interrupt_ttl_minutes == 30
|
||||||
|
|
||||||
def test_custom_values(self) -> None:
|
def test_custom_values(self) -> None:
|
||||||
settings = Settings(
|
settings = _isolated_settings(
|
||||||
database_url="postgresql://x:x@localhost/db",
|
database_url="postgresql://x:x@localhost/db",
|
||||||
llm_provider="openai",
|
llm_provider="openai",
|
||||||
llm_model="gpt-4o",
|
llm_model="gpt-4o",
|
||||||
@@ -33,18 +64,18 @@ class TestSettings:
|
|||||||
|
|
||||||
def test_invalid_provider_rejected(self) -> None:
|
def test_invalid_provider_rejected(self) -> None:
|
||||||
with pytest.raises(Exception):
|
with pytest.raises(Exception):
|
||||||
Settings(
|
_isolated_settings(
|
||||||
database_url="postgresql://x:x@localhost/db",
|
database_url="postgresql://x:x@localhost/db",
|
||||||
llm_provider="invalid",
|
llm_provider="invalid",
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_missing_database_url_rejected(self) -> None:
|
def test_missing_database_url_rejected(self) -> None:
|
||||||
with pytest.raises(Exception):
|
with pytest.raises(Exception):
|
||||||
Settings(anthropic_api_key="key")
|
_isolated_settings(anthropic_api_key="key")
|
||||||
|
|
||||||
def test_empty_api_key_for_provider_rejected(self) -> None:
|
def test_empty_api_key_for_provider_rejected(self) -> None:
|
||||||
with pytest.raises(ValueError, match="API key"):
|
with pytest.raises(ValueError, match="API key"):
|
||||||
Settings(
|
_isolated_settings(
|
||||||
database_url="postgresql://x:x@localhost/db",
|
database_url="postgresql://x:x@localhost/db",
|
||||||
llm_provider="anthropic",
|
llm_provider="anthropic",
|
||||||
anthropic_api_key="",
|
anthropic_api_key="",
|
||||||
@@ -52,9 +83,27 @@ class TestSettings:
|
|||||||
|
|
||||||
def test_wrong_provider_key_rejected(self) -> None:
|
def test_wrong_provider_key_rejected(self) -> None:
|
||||||
with pytest.raises(ValueError, match="API key"):
|
with pytest.raises(ValueError, match="API key"):
|
||||||
Settings(
|
_isolated_settings(
|
||||||
database_url="postgresql://x:x@localhost/db",
|
database_url="postgresql://x:x@localhost/db",
|
||||||
llm_provider="openai",
|
llm_provider="openai",
|
||||||
anthropic_api_key="key",
|
anthropic_api_key="key",
|
||||||
openai_api_key="",
|
openai_api_key="",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_azure_openai_missing_endpoint_rejected(self) -> None:
|
||||||
|
with pytest.raises(ValueError, match="AZURE_OPENAI_ENDPOINT"):
|
||||||
|
_isolated_settings(
|
||||||
|
database_url="postgresql://x:x@localhost/db",
|
||||||
|
llm_provider="azure_openai",
|
||||||
|
azure_openai_api_key="key",
|
||||||
|
azure_openai_deployment="my-deploy",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_azure_openai_missing_deployment_rejected(self) -> None:
|
||||||
|
with pytest.raises(ValueError, match="AZURE_OPENAI_DEPLOYMENT"):
|
||||||
|
_isolated_settings(
|
||||||
|
database_url="postgresql://x:x@localhost/db",
|
||||||
|
llm_provider="azure_openai",
|
||||||
|
azure_openai_api_key="key",
|
||||||
|
azure_openai_endpoint="https://example.openai.azure.com",
|
||||||
|
)
|
||||||
|
|||||||
156
backend/tests/unit/test_conversation_tracker.py
Normal file
156
backend/tests/unit/test_conversation_tracker.py
Normal file
@@ -0,0 +1,156 @@
|
|||||||
|
"""Tests for app.conversation_tracker module."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.conversation_tracker import (
|
||||||
|
ConversationTrackerProtocol,
|
||||||
|
NoOpConversationTracker,
|
||||||
|
PostgresConversationTracker,
|
||||||
|
)
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
def _make_pool() -> tuple[AsyncMock, AsyncMock]:
    """Create a mock async connection pool together with its connection.

    Returns:
        A ``(pool, conn)`` pair: ``pool.connection()`` is a synchronous call
        returning an async context manager that yields ``conn``; ``conn`` is
        exposed so tests can inspect ``conn.execute`` call arguments.
    """
    # Bug fix: the original annotated the return type as ``AsyncMock`` even
    # though the function returns a 2-tuple; callers unpack ``pool, conn``.
    pool = AsyncMock()
    conn = AsyncMock()
    conn.execute = AsyncMock()
    # pool.connection() itself must be a plain (non-awaitable) call that
    # returns an async context manager, hence MagicMock rather than AsyncMock.
    pool.connection = MagicMock(return_value=_AsyncContextManager(conn))
    return pool, conn
||||||
|
|
||||||
|
|
||||||
|
class _AsyncContextManager:
|
||||||
|
"""Async context manager helper."""
|
||||||
|
|
||||||
|
def __init__(self, value: object) -> None:
|
||||||
|
self._value = value
|
||||||
|
|
||||||
|
async def __aenter__(self) -> object:
|
||||||
|
return self._value
|
||||||
|
|
||||||
|
async def __aexit__(self, *args: object) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class TestConversationTrackerProtocol:
    """Both tracker implementations must satisfy the runtime-checkable protocol."""

    def test_noop_satisfies_protocol(self) -> None:
        assert isinstance(NoOpConversationTracker(), ConversationTrackerProtocol)

    def test_postgres_satisfies_protocol(self) -> None:
        assert isinstance(PostgresConversationTracker(), ConversationTrackerProtocol)
|
|
||||||
|
|
||||||
|
class TestNoOpConversationTracker:
    """The no-op tracker must accept every call without raising or using the pool."""

    @pytest.mark.asyncio
    async def test_ensure_conversation_does_nothing(self) -> None:
        tracker = NoOpConversationTracker()
        pool = AsyncMock()
        # Should not raise
        await tracker.ensure_conversation(pool, "thread-1")

    @pytest.mark.asyncio
    async def test_record_turn_does_nothing(self) -> None:
        tracker = NoOpConversationTracker()
        pool = AsyncMock()
        await tracker.record_turn(pool, "thread-1", "agent_a", 100, 0.05)

    @pytest.mark.asyncio
    async def test_resolve_does_nothing(self) -> None:
        tracker = NoOpConversationTracker()
        pool = AsyncMock()
        await tracker.resolve(pool, "thread-1", "resolved")

    @pytest.mark.asyncio
    async def test_accepts_none_agent_name(self) -> None:
        # agent_name=None is a valid argument and must not raise.
        tracker = NoOpConversationTracker()
        pool = AsyncMock()
        await tracker.record_turn(pool, "thread-1", None, 0, 0.0)
|
|
||||||
|
|
||||||
|
class TestPostgresConversationTracker:
    """The Postgres tracker issues the expected SQL with the expected params.

    Each test inspects the positional args of the mocked ``conn.execute``
    call: ``call_args[0]`` is the ``(sql, params)`` pair.
    """

    @pytest.mark.asyncio
    async def test_ensure_conversation_executes_insert(self) -> None:
        tracker = PostgresConversationTracker()
        pool, conn = _make_pool()

        await tracker.ensure_conversation(pool, "thread-abc")

        conn.execute.assert_awaited_once()
        sql, params = conn.execute.call_args[0]
        # Upsert semantics: INSERT guarded by ON CONFLICT.
        assert "INSERT" in sql
        assert "ON CONFLICT" in sql
        assert params["thread_id"] == "thread-abc"

    @pytest.mark.asyncio
    async def test_record_turn_executes_update(self) -> None:
        tracker = PostgresConversationTracker()
        pool, conn = _make_pool()

        await tracker.record_turn(pool, "thread-abc", "order_agent", 250, 0.12)

        conn.execute.assert_awaited_once()
        sql, params = conn.execute.call_args[0]
        assert "UPDATE" in sql
        assert params["thread_id"] == "thread-abc"
        assert params["agent_name"] == "order_agent"
        assert params["tokens"] == 250
        assert params["cost"] == 0.12

    @pytest.mark.asyncio
    async def test_record_turn_accepts_none_agent_name(self) -> None:
        tracker = PostgresConversationTracker()
        pool, conn = _make_pool()

        await tracker.record_turn(pool, "thread-abc", None, 0, 0.0)

        conn.execute.assert_awaited_once()
        sql, params = conn.execute.call_args[0]
        # None is forwarded as-is, not replaced by a default.
        assert params["agent_name"] is None

    @pytest.mark.asyncio
    async def test_resolve_executes_update(self) -> None:
        tracker = PostgresConversationTracker()
        pool, conn = _make_pool()

        await tracker.resolve(pool, "thread-abc", "resolved")

        conn.execute.assert_awaited_once()
        sql, params = conn.execute.call_args[0]
        assert "UPDATE" in sql
        assert params["thread_id"] == "thread-abc"
        assert params["resolution_type"] == "resolved"

    @pytest.mark.asyncio
    async def test_resolve_sets_ended_at(self) -> None:
        tracker = PostgresConversationTracker()
        pool, conn = _make_pool()

        await tracker.resolve(pool, "thread-abc", "escalated")

        # Resolving must also stamp the conversation's end time.
        sql, params = conn.execute.call_args[0]
        assert "ended_at" in sql.lower()

    @pytest.mark.asyncio
    async def test_ensure_conversation_with_special_thread_id(self) -> None:
        tracker = PostgresConversationTracker()
        pool, conn = _make_pool()

        await tracker.ensure_conversation(pool, "thread-123-abc-XYZ")

        conn.execute.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_record_turn_with_zero_cost(self) -> None:
        tracker = PostgresConversationTracker()
        pool, conn = _make_pool()

        await tracker.record_turn(pool, "t1", "agent", 0, 0.0)

        conn.execute.assert_awaited_once()
||||||
@@ -55,7 +55,7 @@ class TestDbModule:
|
|||||||
from app.db import setup_app_tables
|
from app.db import setup_app_tables
|
||||||
|
|
||||||
await setup_app_tables(mock_pool)
|
await setup_app_tables(mock_pool)
|
||||||
assert mock_conn.execute.await_count == 2
|
assert mock_conn.execute.await_count == 5
|
||||||
|
|
||||||
def test_ddl_statements_valid(self) -> None:
|
def test_ddl_statements_valid(self) -> None:
|
||||||
assert "CREATE TABLE IF NOT EXISTS conversations" in _CONVERSATIONS_DDL
|
assert "CREATE TABLE IF NOT EXISTS conversations" in _CONVERSATIONS_DDL
|
||||||
|
|||||||
55
backend/tests/unit/test_db_phase4.py
Normal file
55
backend/tests/unit/test_db_phase4.py
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
"""Phase 4 DB migration tests -- analytics_events table and conversation columns."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
class TestAnalyticsEventsDDL:
    """Phase 4 schema checks: analytics_events DDL and conversations migration."""

    def test_analytics_events_ddl_exists(self) -> None:
        from app.db import _ANALYTICS_EVENTS_DDL

        assert "CREATE TABLE IF NOT EXISTS analytics_events" in _ANALYTICS_EVENTS_DDL

    def test_analytics_events_ddl_has_required_columns(self) -> None:
        from app.db import _ANALYTICS_EVENTS_DDL

        # Every column the analytics pipeline writes must appear in the DDL.
        assert "thread_id" in _ANALYTICS_EVENTS_DDL
        assert "event_type" in _ANALYTICS_EVENTS_DDL
        assert "agent_name" in _ANALYTICS_EVENTS_DDL
        assert "tool_name" in _ANALYTICS_EVENTS_DDL
        assert "tokens_used" in _ANALYTICS_EVENTS_DDL
        assert "cost_usd" in _ANALYTICS_EVENTS_DDL
        assert "duration_ms" in _ANALYTICS_EVENTS_DDL
        assert "success" in _ANALYTICS_EVENTS_DDL
        assert "error_message" in _ANALYTICS_EVENTS_DDL
        assert "metadata" in _ANALYTICS_EVENTS_DDL

    def test_conversations_migration_ddl_exists(self) -> None:
        from app.db import _CONVERSATIONS_MIGRATION_DDL

        # Migration must be idempotent (IF NOT EXISTS) and add the new columns.
        assert "ALTER TABLE" in _CONVERSATIONS_MIGRATION_DDL
        assert "resolution_type" in _CONVERSATIONS_MIGRATION_DDL
        assert "agents_used" in _CONVERSATIONS_MIGRATION_DDL
        assert "turn_count" in _CONVERSATIONS_MIGRATION_DDL
        assert "ended_at" in _CONVERSATIONS_MIGRATION_DDL
        assert "IF NOT EXISTS" in _CONVERSATIONS_MIGRATION_DDL

    @pytest.mark.asyncio
    async def test_setup_app_tables_executes_analytics_ddl(self) -> None:
        # Wire up a pool whose connection() is a sync call returning an
        # async context manager that yields the mock connection.
        mock_conn = AsyncMock()
        mock_ctx = AsyncMock()
        mock_ctx.__aenter__ = AsyncMock(return_value=mock_conn)
        mock_ctx.__aexit__ = AsyncMock(return_value=None)
        mock_pool = MagicMock()
        mock_pool.connection.return_value = mock_ctx

        from app.db import setup_app_tables

        await setup_app_tables(mock_pool)
        # Now expects 5 statements: conversations, interrupts, sessions, analytics_events, migrations
        assert mock_conn.execute.await_count == 5
||||||
79
backend/tests/unit/test_discount.py
Normal file
79
backend/tests/unit/test_discount.py
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
"""Tests for app.agents.discount module."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.agents.discount import apply_discount, generate_coupon
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
class TestApplyDiscount:
    """Validation and human-approval paths of the apply_discount tool.

    The tool is a LangChain tool, so it is exercised via ``.invoke`` with a
    dict payload; the interrupt (human approval) is patched per test.
    """

    def test_invalid_discount_zero(self) -> None:
        result = apply_discount.invoke({"order_id": "1042", "discount_percent": 0})
        assert result["status"] == "error"
        assert "Invalid" in result["message"]

    def test_invalid_discount_over_100(self) -> None:
        result = apply_discount.invoke({"order_id": "1042", "discount_percent": 101})
        assert result["status"] == "error"

    def test_invalid_discount_negative(self) -> None:
        result = apply_discount.invoke({"order_id": "1042", "discount_percent": -5})
        assert result["status"] == "error"

    @patch("app.agents.discount.interrupt", return_value=True)
    def test_approved_discount(self, mock_interrupt) -> None:
        # Approval may come back as a bare boolean True.
        result = apply_discount.invoke({"order_id": "1042", "discount_percent": 10})
        assert result["status"] == "applied"
        assert result["discount_percent"] == 10
        assert "1042" in result["message"]

    @patch("app.agents.discount.interrupt", return_value=False)
    def test_rejected_discount(self, mock_interrupt) -> None:
        result = apply_discount.invoke({"order_id": "1042", "discount_percent": 10})
        assert result["status"] == "declined"

    @patch("app.agents.discount.interrupt", return_value={"approved": True})
    def test_approved_via_dict(self, mock_interrupt) -> None:
        # Approval may also come back as a dict with an "approved" flag.
        result = apply_discount.invoke({"order_id": "1042", "discount_percent": 50})
        assert result["status"] == "applied"

    @patch("app.agents.discount.interrupt", return_value={"approved": False})
    def test_rejected_via_dict(self, mock_interrupt) -> None:
        result = apply_discount.invoke({"order_id": "1042", "discount_percent": 50})
        assert result["status"] == "declined"
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
class TestGenerateCoupon:
    """Validation, defaults, and uniqueness of the generate_coupon tool."""

    def test_valid_coupon(self) -> None:
        result = generate_coupon.invoke({"discount_percent": 15, "expiry_days": 7})
        assert result["status"] == "generated"
        assert result["discount_percent"] == 15
        assert result["expiry_days"] == 7
        # Codes are prefixed with SAVE<percent>- followed by a random suffix.
        assert result["coupon_code"].startswith("SAVE15-")

    def test_default_expiry(self) -> None:
        # expiry_days defaults to 30 when omitted.
        result = generate_coupon.invoke({"discount_percent": 20})
        assert result["status"] == "generated"
        assert result["expiry_days"] == 30

    def test_invalid_discount_zero(self) -> None:
        result = generate_coupon.invoke({"discount_percent": 0})
        assert result["status"] == "error"

    def test_invalid_discount_over_100(self) -> None:
        result = generate_coupon.invoke({"discount_percent": 101})
        assert result["status"] == "error"

    def test_invalid_expiry(self) -> None:
        result = generate_coupon.invoke({"discount_percent": 10, "expiry_days": 0})
        assert result["status"] == "error"

    def test_coupon_codes_unique(self) -> None:
        # Two generations with identical inputs must yield distinct codes.
        r1 = generate_coupon.invoke({"discount_percent": 10})
        r2 = generate_coupon.invoke({"discount_percent": 10})
        assert r1["coupon_code"] != r2["coupon_code"]
||||||
210
backend/tests/unit/test_edge_cases.py
Normal file
210
backend/tests/unit/test_edge_cases.py
Normal file
@@ -0,0 +1,210 @@
|
|||||||
|
"""Edge case tests for ws_handler input validation and rate limiting."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.callbacks import TokenUsageCallbackHandler
|
||||||
|
from app.graph_context import GraphContext
|
||||||
|
from app.session_manager import SessionManager
|
||||||
|
from app.ws_context import WebSocketContext
|
||||||
|
from app.ws_handler import dispatch_message
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
def _make_ws() -> AsyncMock:
|
||||||
|
ws = AsyncMock()
|
||||||
|
ws.send_json = AsyncMock()
|
||||||
|
return ws
|
||||||
|
|
||||||
|
|
||||||
|
def _make_graph() -> MagicMock:
|
||||||
|
graph = AsyncMock()
|
||||||
|
|
||||||
|
class AsyncIterHelper:
|
||||||
|
def __aiter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __anext__(self):
|
||||||
|
raise StopAsyncIteration
|
||||||
|
|
||||||
|
graph.astream = MagicMock(return_value=AsyncIterHelper())
|
||||||
|
state = MagicMock()
|
||||||
|
state.tasks = ()
|
||||||
|
graph.aget_state = AsyncMock(return_value=state)
|
||||||
|
return graph
|
||||||
|
|
||||||
|
|
||||||
|
def _make_ws_ctx(sm: SessionManager | None = None) -> WebSocketContext:
|
||||||
|
graph = _make_graph()
|
||||||
|
registry = MagicMock()
|
||||||
|
registry.list_agents = MagicMock(return_value=())
|
||||||
|
graph_ctx = GraphContext(graph=graph, registry=registry, intent_classifier=None)
|
||||||
|
return WebSocketContext(
|
||||||
|
graph_ctx=graph_ctx,
|
||||||
|
session_manager=sm or SessionManager(),
|
||||||
|
callback_handler=TokenUsageCallbackHandler(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestEmptyMessageHandling:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_empty_message_content_returns_error(self) -> None:
|
||||||
|
ws = _make_ws()
|
||||||
|
sm = SessionManager()
|
||||||
|
ws_ctx = _make_ws_ctx(sm=sm)
|
||||||
|
|
||||||
|
sm.touch("t1")
|
||||||
|
msg = json.dumps({"type": "message", "thread_id": "t1", "content": ""})
|
||||||
|
await dispatch_message(ws, ws_ctx, msg)
|
||||||
|
|
||||||
|
call_data = ws.send_json.call_args[0][0]
|
||||||
|
assert call_data["type"] == "error"
|
||||||
|
msg_lower = call_data["message"].lower()
|
||||||
|
assert "content" in msg_lower or "missing" in msg_lower
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_whitespace_only_message_treated_as_empty(self) -> None:
|
||||||
|
ws = _make_ws()
|
||||||
|
sm = SessionManager()
|
||||||
|
ws_ctx = _make_ws_ctx(sm=sm)
|
||||||
|
|
||||||
|
sm.touch("t1")
|
||||||
|
msg = json.dumps({"type": "message", "thread_id": "t1", "content": " "})
|
||||||
|
await dispatch_message(ws, ws_ctx, msg)
|
||||||
|
|
||||||
|
call_data = ws.send_json.call_args[0][0]
|
||||||
|
assert call_data["type"] == "error"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestOversizedMessageHandling:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_content_over_10000_chars_returns_error(self) -> None:
|
||||||
|
ws = _make_ws()
|
||||||
|
sm = SessionManager()
|
||||||
|
ws_ctx = _make_ws_ctx(sm=sm)
|
||||||
|
|
||||||
|
sm.touch("t1")
|
||||||
|
content = "x" * 10001
|
||||||
|
msg = json.dumps({"type": "message", "thread_id": "t1", "content": content})
|
||||||
|
await dispatch_message(ws, ws_ctx, msg)
|
||||||
|
|
||||||
|
call_data = ws.send_json.call_args[0][0]
|
||||||
|
assert call_data["type"] == "error"
|
||||||
|
assert "too long" in call_data["message"].lower()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_content_exactly_10000_chars_is_accepted(self) -> None:
|
||||||
|
ws = _make_ws()
|
||||||
|
sm = SessionManager()
|
||||||
|
ws_ctx = _make_ws_ctx(sm=sm)
|
||||||
|
|
||||||
|
sm.touch("t1")
|
||||||
|
content = "x" * 10000
|
||||||
|
msg = json.dumps({"type": "message", "thread_id": "t1", "content": content})
|
||||||
|
await dispatch_message(ws, ws_ctx, msg)
|
||||||
|
|
||||||
|
last_call = ws.send_json.call_args[0][0]
|
||||||
|
# Should be processed, not an error about length
|
||||||
|
msg_text = last_call.get("message", "").lower()
|
||||||
|
assert last_call["type"] != "error" or "too long" not in msg_text
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_raw_message_over_32kb_returns_error(self) -> None:
|
||||||
|
ws = _make_ws()
|
||||||
|
ws_ctx = _make_ws_ctx()
|
||||||
|
|
||||||
|
large_msg = "x" * 40_000
|
||||||
|
await dispatch_message(ws, ws_ctx, large_msg)
|
||||||
|
|
||||||
|
call_data = ws.send_json.call_args[0][0]
|
||||||
|
assert call_data["type"] == "error"
|
||||||
|
assert "too large" in call_data["message"].lower()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestInvalidJsonHandling:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_invalid_json_returns_error(self) -> None:
|
||||||
|
ws = _make_ws()
|
||||||
|
ws_ctx = _make_ws_ctx()
|
||||||
|
|
||||||
|
await dispatch_message(ws, ws_ctx, "not valid json {{")
|
||||||
|
|
||||||
|
call_data = ws.send_json.call_args[0][0]
|
||||||
|
assert call_data["type"] == "error"
|
||||||
|
assert "invalid json" in call_data["message"].lower()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_empty_string_returns_json_error(self) -> None:
|
||||||
|
ws = _make_ws()
|
||||||
|
ws_ctx = _make_ws_ctx()
|
||||||
|
|
||||||
|
await dispatch_message(ws, ws_ctx, "")
|
||||||
|
|
||||||
|
call_data = ws.send_json.call_args[0][0]
|
||||||
|
assert call_data["type"] == "error"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_json_array_not_object_returns_error(self) -> None:
|
||||||
|
ws = _make_ws()
|
||||||
|
ws_ctx = _make_ws_ctx()
|
||||||
|
|
||||||
|
await dispatch_message(ws, ws_ctx, '["not", "an", "object"]')
|
||||||
|
|
||||||
|
call_data = ws.send_json.call_args[0][0]
|
||||||
|
assert call_data["type"] == "error"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestRateLimiting:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rapid_fire_messages_rate_limited(self) -> None:
|
||||||
|
ws = _make_ws()
|
||||||
|
sm = SessionManager()
|
||||||
|
|
||||||
|
sm.touch("t1")
|
||||||
|
|
||||||
|
# Simulate 11 rapid messages (exceeds 10 per 10 seconds limit)
|
||||||
|
rate_limit_triggered = False
|
||||||
|
for i in range(11):
|
||||||
|
ws_ctx = _make_ws_ctx(sm=sm)
|
||||||
|
await dispatch_message(ws, ws_ctx, json.dumps({
|
||||||
|
"type": "message",
|
||||||
|
"thread_id": "t1",
|
||||||
|
"content": f"message {i}",
|
||||||
|
}))
|
||||||
|
last_call = ws.send_json.call_args[0][0]
|
||||||
|
if last_call["type"] == "error" and "rate" in last_call.get("message", "").lower():
|
||||||
|
rate_limit_triggered = True
|
||||||
|
break
|
||||||
|
|
||||||
|
assert rate_limit_triggered, "Rate limiting should trigger after 10 rapid messages"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_different_threads_have_separate_rate_limits(self) -> None:
|
||||||
|
ws = _make_ws()
|
||||||
|
sm = SessionManager()
|
||||||
|
|
||||||
|
sm.touch("t1")
|
||||||
|
sm.touch("t2")
|
||||||
|
|
||||||
|
# Send 5 messages on t1 and 5 on t2 -- neither should be rate limited
|
||||||
|
for i in range(5):
|
||||||
|
ws_ctx1 = _make_ws_ctx(sm=sm)
|
||||||
|
ws_ctx2 = _make_ws_ctx(sm=sm)
|
||||||
|
await dispatch_message(ws, ws_ctx1, json.dumps({
|
||||||
|
"type": "message", "thread_id": "t1", "content": f"msg {i}",
|
||||||
|
}))
|
||||||
|
await dispatch_message(ws, ws_ctx2, json.dumps({
|
||||||
|
"type": "message", "thread_id": "t2", "content": f"msg {i}",
|
||||||
|
}))
|
||||||
|
|
||||||
|
last_call = ws.send_json.call_args[0][0]
|
||||||
|
assert "rate" not in last_call.get("message", "").lower()
|
||||||
175
backend/tests/unit/test_error_handler.py
Normal file
175
backend/tests/unit/test_error_handler.py
Normal file
@@ -0,0 +1,175 @@
|
|||||||
|
"""Tests for app.tools.error_handler module."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.tools.error_handler import (
|
||||||
|
ErrorCategory,
|
||||||
|
classify_error,
|
||||||
|
with_retry,
|
||||||
|
)
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
class TestErrorClassification:
|
||||||
|
def test_timeout_exception_is_timeout(self) -> None:
|
||||||
|
exc = httpx.TimeoutException("timed out")
|
||||||
|
assert classify_error(exc) == ErrorCategory.TIMEOUT
|
||||||
|
|
||||||
|
def test_connect_error_is_network(self) -> None:
|
||||||
|
exc = httpx.ConnectError("connection refused")
|
||||||
|
assert classify_error(exc) == ErrorCategory.NETWORK
|
||||||
|
|
||||||
|
def test_401_is_auth_failure(self) -> None:
|
||||||
|
request = httpx.Request("GET", "http://example.com")
|
||||||
|
response = httpx.Response(401, request=request)
|
||||||
|
exc = httpx.HTTPStatusError("401", request=request, response=response)
|
||||||
|
assert classify_error(exc) == ErrorCategory.AUTH_FAILURE
|
||||||
|
|
||||||
|
def test_403_is_auth_failure(self) -> None:
|
||||||
|
request = httpx.Request("GET", "http://example.com")
|
||||||
|
response = httpx.Response(403, request=request)
|
||||||
|
exc = httpx.HTTPStatusError("403", request=request, response=response)
|
||||||
|
assert classify_error(exc) == ErrorCategory.AUTH_FAILURE
|
||||||
|
|
||||||
|
def test_429_is_retryable(self) -> None:
|
||||||
|
request = httpx.Request("GET", "http://example.com")
|
||||||
|
response = httpx.Response(429, request=request)
|
||||||
|
exc = httpx.HTTPStatusError("429", request=request, response=response)
|
||||||
|
assert classify_error(exc) == ErrorCategory.RETRYABLE
|
||||||
|
|
||||||
|
def test_500_is_retryable(self) -> None:
|
||||||
|
request = httpx.Request("GET", "http://example.com")
|
||||||
|
response = httpx.Response(500, request=request)
|
||||||
|
exc = httpx.HTTPStatusError("500", request=request, response=response)
|
||||||
|
assert classify_error(exc) == ErrorCategory.RETRYABLE
|
||||||
|
|
||||||
|
def test_502_is_retryable(self) -> None:
|
||||||
|
request = httpx.Request("GET", "http://example.com")
|
||||||
|
response = httpx.Response(502, request=request)
|
||||||
|
exc = httpx.HTTPStatusError("502", request=request, response=response)
|
||||||
|
assert classify_error(exc) == ErrorCategory.RETRYABLE
|
||||||
|
|
||||||
|
def test_503_is_retryable(self) -> None:
|
||||||
|
request = httpx.Request("GET", "http://example.com")
|
||||||
|
response = httpx.Response(503, request=request)
|
||||||
|
exc = httpx.HTTPStatusError("503", request=request, response=response)
|
||||||
|
assert classify_error(exc) == ErrorCategory.RETRYABLE
|
||||||
|
|
||||||
|
def test_404_is_non_retryable(self) -> None:
|
||||||
|
request = httpx.Request("GET", "http://example.com")
|
||||||
|
response = httpx.Response(404, request=request)
|
||||||
|
exc = httpx.HTTPStatusError("404", request=request, response=response)
|
||||||
|
assert classify_error(exc) == ErrorCategory.NON_RETRYABLE
|
||||||
|
|
||||||
|
def test_400_is_non_retryable(self) -> None:
|
||||||
|
request = httpx.Request("GET", "http://example.com")
|
||||||
|
response = httpx.Response(400, request=request)
|
||||||
|
exc = httpx.HTTPStatusError("400", request=request, response=response)
|
||||||
|
assert classify_error(exc) == ErrorCategory.NON_RETRYABLE
|
||||||
|
|
||||||
|
def test_generic_exception_is_non_retryable(self) -> None:
|
||||||
|
exc = ValueError("bad value")
|
||||||
|
assert classify_error(exc) == ErrorCategory.NON_RETRYABLE
|
||||||
|
|
||||||
|
def test_runtime_error_is_non_retryable(self) -> None:
|
||||||
|
exc = RuntimeError("boom")
|
||||||
|
assert classify_error(exc) == ErrorCategory.NON_RETRYABLE
|
||||||
|
|
||||||
|
|
||||||
|
class TestWithRetry:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_succeeds_on_first_try(self) -> None:
|
||||||
|
fn = AsyncMock(return_value="ok")
|
||||||
|
result = await with_retry(fn, max_retries=3, base_delay=0.0)
|
||||||
|
assert result == "ok"
|
||||||
|
assert fn.call_count == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_retries_on_retryable_error(self) -> None:
|
||||||
|
request = httpx.Request("GET", "http://example.com")
|
||||||
|
response = httpx.Response(503, request=request)
|
||||||
|
retryable_exc = httpx.HTTPStatusError("503", request=request, response=response)
|
||||||
|
|
||||||
|
fn = AsyncMock(side_effect=[retryable_exc, retryable_exc, "success"])
|
||||||
|
|
||||||
|
with patch("app.tools.error_handler.asyncio.sleep", new_callable=AsyncMock):
|
||||||
|
result = await with_retry(fn, max_retries=3, base_delay=0.0)
|
||||||
|
|
||||||
|
assert result == "success"
|
||||||
|
assert fn.call_count == 3
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_does_not_retry_non_retryable_error(self) -> None:
|
||||||
|
request = httpx.Request("GET", "http://example.com")
|
||||||
|
response = httpx.Response(404, request=request)
|
||||||
|
non_retryable_exc = httpx.HTTPStatusError("404", request=request, response=response)
|
||||||
|
|
||||||
|
fn = AsyncMock(side_effect=non_retryable_exc)
|
||||||
|
|
||||||
|
with pytest.raises(httpx.HTTPStatusError):
|
||||||
|
await with_retry(fn, max_retries=3, base_delay=0.0)
|
||||||
|
|
||||||
|
assert fn.call_count == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_does_not_retry_auth_failure(self) -> None:
|
||||||
|
request = httpx.Request("GET", "http://example.com")
|
||||||
|
response = httpx.Response(401, request=request)
|
||||||
|
auth_exc = httpx.HTTPStatusError("401", request=request, response=response)
|
||||||
|
|
||||||
|
fn = AsyncMock(side_effect=auth_exc)
|
||||||
|
|
||||||
|
with pytest.raises(httpx.HTTPStatusError):
|
||||||
|
await with_retry(fn, max_retries=3, base_delay=0.0)
|
||||||
|
|
||||||
|
assert fn.call_count == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_raises_after_max_retries_exhausted(self) -> None:
|
||||||
|
request = httpx.Request("GET", "http://example.com")
|
||||||
|
response = httpx.Response(500, request=request)
|
||||||
|
retryable_exc = httpx.HTTPStatusError("500", request=request, response=response)
|
||||||
|
|
||||||
|
fn = AsyncMock(side_effect=retryable_exc)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("app.tools.error_handler.asyncio.sleep", new_callable=AsyncMock),
|
||||||
|
pytest.raises(httpx.HTTPStatusError),
|
||||||
|
):
|
||||||
|
await with_retry(fn, max_retries=3, base_delay=0.0)
|
||||||
|
|
||||||
|
assert fn.call_count == 3
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_does_not_retry_timeout(self) -> None:
|
||||||
|
"""TimeoutException is TIMEOUT category -- not retried by default."""
|
||||||
|
fn = AsyncMock(side_effect=httpx.TimeoutException("timed out"))
|
||||||
|
|
||||||
|
with pytest.raises(httpx.TimeoutException):
|
||||||
|
await with_retry(fn, max_retries=3, base_delay=0.0)
|
||||||
|
|
||||||
|
assert fn.call_count == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_exponential_backoff_increases_delay(self) -> None:
|
||||||
|
request = httpx.Request("GET", "http://example.com")
|
||||||
|
response = httpx.Response(503, request=request)
|
||||||
|
retryable_exc = httpx.HTTPStatusError("503", request=request, response=response)
|
||||||
|
|
||||||
|
fn = AsyncMock(side_effect=[retryable_exc, retryable_exc, "done"])
|
||||||
|
sleep_delays: list[float] = []
|
||||||
|
|
||||||
|
async def capture_sleep(delay: float) -> None:
|
||||||
|
sleep_delays.append(delay)
|
||||||
|
|
||||||
|
with patch("app.tools.error_handler.asyncio.sleep", side_effect=capture_sleep):
|
||||||
|
await with_retry(fn, max_retries=3, base_delay=1.0)
|
||||||
|
|
||||||
|
assert len(sleep_delays) == 2
|
||||||
|
assert sleep_delays[1] > sleep_delays[0]
|
||||||
142
backend/tests/unit/test_error_responses.py
Normal file
142
backend/tests/unit/test_error_responses.py
Normal file
@@ -0,0 +1,142 @@
|
|||||||
|
"""Tests for standardized error response envelope format."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import FastAPI, HTTPException
|
||||||
|
from fastapi.exceptions import RequestValidationError
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from app.api_utils import envelope
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
def _build_test_app() -> FastAPI:
|
||||||
|
"""Build a minimal FastAPI app with the standard exception handlers."""
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
@app.exception_handler(HTTPException)
|
||||||
|
async def http_exception_handler(request, exc): # type: ignore[no-untyped-def]
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=exc.status_code,
|
||||||
|
content=envelope(None, success=False, error=exc.detail),
|
||||||
|
)
|
||||||
|
|
||||||
|
@app.exception_handler(RequestValidationError)
|
||||||
|
async def validation_exception_handler(request, exc): # type: ignore[no-untyped-def]
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=422,
|
||||||
|
content=envelope(None, success=False, error=str(exc)),
|
||||||
|
)
|
||||||
|
|
||||||
|
@app.exception_handler(Exception)
|
||||||
|
async def general_exception_handler(request, exc): # type: ignore[no-untyped-def]
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=500,
|
||||||
|
content=envelope(None, success=False, error="Internal server error"),
|
||||||
|
)
|
||||||
|
|
||||||
|
class ItemRequest(BaseModel):
|
||||||
|
name: str = Field(..., min_length=1)
|
||||||
|
count: int = Field(..., gt=0)
|
||||||
|
|
||||||
|
@app.get("/items/{item_id}")
|
||||||
|
def get_item(item_id: int) -> dict:
|
||||||
|
if item_id == 0:
|
||||||
|
raise HTTPException(status_code=400, detail="Invalid item ID")
|
||||||
|
if item_id == 999:
|
||||||
|
raise HTTPException(status_code=404, detail="Item not found")
|
||||||
|
if item_id == 401:
|
||||||
|
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||||
|
return envelope({"id": item_id, "name": "test"})
|
||||||
|
|
||||||
|
@app.post("/items")
|
||||||
|
def create_item(item: ItemRequest) -> dict:
|
||||||
|
return envelope({"id": 1, "name": item.name})
|
||||||
|
|
||||||
|
@app.get("/crash")
|
||||||
|
def crash() -> dict:
|
||||||
|
msg = "unexpected failure"
|
||||||
|
raise RuntimeError(msg)
|
||||||
|
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
class TestHttpExceptionEnvelope:
|
||||||
|
"""HTTPException responses use the standard envelope format."""
|
||||||
|
|
||||||
|
def test_400_returns_envelope(self) -> None:
|
||||||
|
app = _build_test_app()
|
||||||
|
with TestClient(app, raise_server_exceptions=False) as client:
|
||||||
|
resp = client.get("/items/0")
|
||||||
|
assert resp.status_code == 400
|
||||||
|
body = resp.json()
|
||||||
|
assert body["success"] is False
|
||||||
|
assert body["data"] is None
|
||||||
|
assert body["error"] == "Invalid item ID"
|
||||||
|
|
||||||
|
def test_404_returns_envelope(self) -> None:
|
||||||
|
app = _build_test_app()
|
||||||
|
with TestClient(app, raise_server_exceptions=False) as client:
|
||||||
|
resp = client.get("/items/999")
|
||||||
|
assert resp.status_code == 404
|
||||||
|
body = resp.json()
|
||||||
|
assert body["success"] is False
|
||||||
|
assert body["data"] is None
|
||||||
|
assert body["error"] == "Item not found"
|
||||||
|
|
||||||
|
def test_401_returns_envelope(self) -> None:
|
||||||
|
app = _build_test_app()
|
||||||
|
with TestClient(app, raise_server_exceptions=False) as client:
|
||||||
|
resp = client.get("/items/401")
|
||||||
|
assert resp.status_code == 401
|
||||||
|
body = resp.json()
|
||||||
|
assert body["success"] is False
|
||||||
|
assert body["data"] is None
|
||||||
|
assert body["error"] == "Not authenticated"
|
||||||
|
|
||||||
|
|
||||||
|
class TestValidationErrorEnvelope:
|
||||||
|
"""Validation errors return 422 with envelope format."""
|
||||||
|
|
||||||
|
def test_validation_error_returns_envelope(self) -> None:
|
||||||
|
app = _build_test_app()
|
||||||
|
with TestClient(app, raise_server_exceptions=False) as client:
|
||||||
|
resp = client.post("/items", json={"name": "", "count": -1})
|
||||||
|
assert resp.status_code == 422
|
||||||
|
body = resp.json()
|
||||||
|
assert body["success"] is False
|
||||||
|
assert body["data"] is None
|
||||||
|
assert isinstance(body["error"], str)
|
||||||
|
assert len(body["error"]) > 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestGeneralExceptionEnvelope:
|
||||||
|
"""Unhandled exceptions return 500 with safe envelope."""
|
||||||
|
|
||||||
|
def test_unhandled_exception_returns_500_envelope(self) -> None:
|
||||||
|
app = _build_test_app()
|
||||||
|
with TestClient(app, raise_server_exceptions=False) as client:
|
||||||
|
resp = client.get("/crash")
|
||||||
|
assert resp.status_code == 500
|
||||||
|
body = resp.json()
|
||||||
|
assert body["success"] is False
|
||||||
|
assert body["data"] is None
|
||||||
|
assert body["error"] == "Internal server error"
|
||||||
|
|
||||||
|
|
||||||
|
class TestSuccessResponseUnchanged:
|
||||||
|
"""Success responses still work normally."""
|
||||||
|
|
||||||
|
def test_success_returns_envelope(self) -> None:
|
||||||
|
app = _build_test_app()
|
||||||
|
with TestClient(app) as client:
|
||||||
|
resp = client.get("/items/42")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
body = resp.json()
|
||||||
|
assert body["success"] is True
|
||||||
|
assert body["data"]["id"] == 42
|
||||||
|
assert body["error"] is None
|
||||||
169
backend/tests/unit/test_escalation.py
Normal file
169
backend/tests/unit/test_escalation.py
Normal file
@@ -0,0 +1,169 @@
|
|||||||
|
"""Tests for app.escalation module."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.escalation import (
|
||||||
|
EscalationPayload,
|
||||||
|
EscalationResult,
|
||||||
|
NoOpEscalator,
|
||||||
|
WebhookEscalator,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_payload(**kwargs) -> EscalationPayload:
|
||||||
|
defaults = {
|
||||||
|
"thread_id": "t1",
|
||||||
|
"reason": "Agent cannot resolve",
|
||||||
|
"conversation_summary": "User asked about refund policy",
|
||||||
|
}
|
||||||
|
defaults.update(kwargs)
|
||||||
|
return EscalationPayload(**defaults)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestEscalationPayload:
|
||||||
|
def test_frozen(self) -> None:
|
||||||
|
payload = _make_payload()
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
payload.thread_id = "t2" # type: ignore[misc]
|
||||||
|
|
||||||
|
def test_default_metadata(self) -> None:
|
||||||
|
payload = _make_payload()
|
||||||
|
assert payload.metadata == {}
|
||||||
|
|
||||||
|
def test_model_dump(self) -> None:
|
||||||
|
payload = _make_payload(metadata={"key": "val"})
|
||||||
|
data = payload.model_dump()
|
||||||
|
assert data["thread_id"] == "t1"
|
||||||
|
assert data["metadata"] == {"key": "val"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestEscalationResult:
|
||||||
|
def test_frozen(self) -> None:
|
||||||
|
result = EscalationResult(success=True, status_code=200, attempts=1, error=None)
|
||||||
|
assert result.success
|
||||||
|
assert result.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestWebhookEscalator:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_empty_url_returns_failure(self) -> None:
|
||||||
|
escalator = WebhookEscalator(url="", max_retries=3)
|
||||||
|
result = await escalator.escalate(_make_payload())
|
||||||
|
assert not result.success
|
||||||
|
assert result.attempts == 0
|
||||||
|
assert "not configured" in result.error
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_successful_post(self) -> None:
|
||||||
|
mock_response = AsyncMock()
|
||||||
|
mock_response.status_code = 200
|
||||||
|
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.post = AsyncMock(return_value=mock_response)
|
||||||
|
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||||
|
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
|
||||||
|
with patch("app.escalation.httpx.AsyncClient", return_value=mock_client):
|
||||||
|
escalator = WebhookEscalator(url="https://example.com/hook")
|
||||||
|
result = await escalator.escalate(_make_payload())
|
||||||
|
|
||||||
|
assert result.success
|
||||||
|
assert result.status_code == 200
|
||||||
|
assert result.attempts == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_retry_on_server_error(self) -> None:
|
||||||
|
fail_response = AsyncMock()
|
||||||
|
fail_response.status_code = 500
|
||||||
|
success_response = AsyncMock()
|
||||||
|
success_response.status_code = 200
|
||||||
|
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.post = AsyncMock(side_effect=[fail_response, fail_response, success_response])
|
||||||
|
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||||
|
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("app.escalation.httpx.AsyncClient", return_value=mock_client),
|
||||||
|
patch("app.escalation.asyncio.sleep", new_callable=AsyncMock),
|
||||||
|
):
|
||||||
|
escalator = WebhookEscalator(url="https://example.com/hook", max_retries=3)
|
||||||
|
result = await escalator.escalate(_make_payload())
|
||||||
|
|
||||||
|
assert result.success
|
||||||
|
assert result.attempts == 3
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_all_retries_exhausted(self) -> None:
|
||||||
|
fail_response = AsyncMock()
|
||||||
|
fail_response.status_code = 500
|
||||||
|
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.post = AsyncMock(return_value=fail_response)
|
||||||
|
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||||
|
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("app.escalation.httpx.AsyncClient", return_value=mock_client),
|
||||||
|
patch("app.escalation.asyncio.sleep", new_callable=AsyncMock),
|
||||||
|
):
|
||||||
|
escalator = WebhookEscalator(url="https://example.com/hook", max_retries=3)
|
||||||
|
result = await escalator.escalate(_make_payload())
|
||||||
|
|
||||||
|
assert not result.success
|
||||||
|
assert result.attempts == 3
|
||||||
|
assert "500" in result.error
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_timeout_error(self) -> None:
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.post = AsyncMock(side_effect=httpx.TimeoutException("timeout"))
|
||||||
|
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||||
|
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("app.escalation.httpx.AsyncClient", return_value=mock_client),
|
||||||
|
patch("app.escalation.asyncio.sleep", new_callable=AsyncMock),
|
||||||
|
):
|
||||||
|
escalator = WebhookEscalator(url="https://example.com/hook", max_retries=2)
|
||||||
|
result = await escalator.escalate(_make_payload())
|
||||||
|
|
||||||
|
assert not result.success
|
||||||
|
assert "timed out" in result.error
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_request_error(self) -> None:
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.post = AsyncMock(
|
||||||
|
side_effect=httpx.RequestError("connection refused")
|
||||||
|
)
|
||||||
|
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||||
|
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("app.escalation.httpx.AsyncClient", return_value=mock_client),
|
||||||
|
patch("app.escalation.asyncio.sleep", new_callable=AsyncMock),
|
||||||
|
):
|
||||||
|
escalator = WebhookEscalator(url="https://example.com/hook", max_retries=1)
|
||||||
|
result = await escalator.escalate(_make_payload())
|
||||||
|
|
||||||
|
assert not result.success
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestNoOpEscalator:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_returns_disabled(self) -> None:
|
||||||
|
escalator = NoOpEscalator()
|
||||||
|
result = await escalator.escalate(_make_payload())
|
||||||
|
assert not result.success
|
||||||
|
assert result.attempts == 0
|
||||||
|
assert "disabled" in result.error.lower()
|
||||||
@@ -6,8 +6,11 @@ from typing import TYPE_CHECKING
|
|||||||
from unittest.mock import AsyncMock, MagicMock
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from langgraph.checkpoint.memory import InMemorySaver
|
||||||
|
|
||||||
from app.graph import SUPERVISOR_PROMPT, build_agent_nodes, build_graph
|
from app.graph import build_agent_nodes, build_graph
|
||||||
|
from app.graph_context import GraphContext
|
||||||
|
from app.intent import ClassificationResult, IntentTarget
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from app.registry import AgentRegistry
|
from app.registry import AgentRegistry
|
||||||
@@ -33,12 +36,59 @@ class TestBuildGraph:
|
|||||||
mock_llm = MagicMock()
|
mock_llm = MagicMock()
|
||||||
mock_llm.bind_tools = MagicMock(return_value=mock_llm)
|
mock_llm.bind_tools = MagicMock(return_value=mock_llm)
|
||||||
mock_llm.with_structured_output = MagicMock(return_value=mock_llm)
|
mock_llm.with_structured_output = MagicMock(return_value=mock_llm)
|
||||||
mock_checkpointer = AsyncMock()
|
checkpointer = InMemorySaver()
|
||||||
|
|
||||||
graph = build_graph(sample_registry, mock_llm, mock_checkpointer)
|
graph_ctx = build_graph(sample_registry, mock_llm, checkpointer)
|
||||||
assert graph is not None
|
assert graph_ctx is not None
|
||||||
|
assert graph_ctx.graph is not None
|
||||||
|
|
||||||
def test_supervisor_prompt_contains_routing_info(self) -> None:
|
def test_graph_has_classifier_attached(self, sample_registry: AgentRegistry) -> None:
|
||||||
assert "order_lookup" in SUPERVISOR_PROMPT
|
mock_llm = MagicMock()
|
||||||
assert "order_actions" in SUPERVISOR_PROMPT
|
mock_llm.bind_tools = MagicMock(return_value=mock_llm)
|
||||||
assert "fallback" in SUPERVISOR_PROMPT
|
mock_llm.with_structured_output = MagicMock(return_value=mock_llm)
|
||||||
|
checkpointer = InMemorySaver()
|
||||||
|
mock_classifier = MagicMock()
|
||||||
|
|
||||||
|
graph_ctx = build_graph(
|
||||||
|
sample_registry, mock_llm, checkpointer, intent_classifier=mock_classifier
|
||||||
|
)
|
||||||
|
assert graph_ctx.intent_classifier is mock_classifier
|
||||||
|
assert graph_ctx.registry is sample_registry
|
||||||
|
|
||||||
|
def test_graph_without_classifier(self, sample_registry: AgentRegistry) -> None:
|
||||||
|
mock_llm = MagicMock()
|
||||||
|
mock_llm.bind_tools = MagicMock(return_value=mock_llm)
|
||||||
|
mock_llm.with_structured_output = MagicMock(return_value=mock_llm)
|
||||||
|
checkpointer = InMemorySaver()
|
||||||
|
|
||||||
|
graph_ctx = build_graph(sample_registry, mock_llm, checkpointer)
|
||||||
|
assert graph_ctx.intent_classifier is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestClassifyIntent:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_returns_none_without_classifier(self) -> None:
|
||||||
|
mock_registry = MagicMock()
|
||||||
|
mock_registry.list_agents = MagicMock(return_value=())
|
||||||
|
graph_ctx = GraphContext(graph=MagicMock(), registry=mock_registry, intent_classifier=None)
|
||||||
|
result = await graph_ctx.classify_intent("hello")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_calls_classifier(self) -> None:
|
||||||
|
expected = ClassificationResult(
|
||||||
|
intents=(IntentTarget(agent_name="order_lookup", confidence=0.9, reasoning="test"),),
|
||||||
|
)
|
||||||
|
mock_classifier = AsyncMock()
|
||||||
|
mock_classifier.classify = AsyncMock(return_value=expected)
|
||||||
|
|
||||||
|
mock_registry = MagicMock()
|
||||||
|
mock_registry.list_agents = MagicMock(return_value=())
|
||||||
|
graph_ctx = GraphContext(
|
||||||
|
graph=MagicMock(), registry=mock_registry, intent_classifier=mock_classifier,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await graph_ctx.classify_intent("check order")
|
||||||
|
assert result is not None
|
||||||
|
assert result.intents[0].agent_name == "order_lookup"
|
||||||
|
|||||||
177
backend/tests/unit/test_intent.py
Normal file
177
backend/tests/unit/test_intent.py
Normal file
@@ -0,0 +1,177 @@
|
|||||||
|
"""Tests for app.intent module."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.intent import (
|
||||||
|
AMBIGUITY_THRESHOLD,
|
||||||
|
ClassificationResult,
|
||||||
|
IntentTarget,
|
||||||
|
LLMIntentClassifier,
|
||||||
|
_build_agent_list,
|
||||||
|
)
|
||||||
|
from app.registry import AgentConfig
|
||||||
|
|
||||||
|
|
||||||
|
def _make_agent(name: str, desc: str = "test", perm: str = "read") -> AgentConfig:
|
||||||
|
return AgentConfig(
|
||||||
|
name=name, description=desc, permission=perm, tools=["fallback_respond"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestIntentModels:
|
||||||
|
def test_intent_target_frozen(self) -> None:
|
||||||
|
target = IntentTarget(agent_name="order_lookup", confidence=0.9, reasoning="order query")
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
target.agent_name = "other" # type: ignore[misc]
|
||||||
|
|
||||||
|
def test_classification_result_frozen(self) -> None:
|
||||||
|
result = ClassificationResult(
|
||||||
|
intents=(IntentTarget(agent_name="a", confidence=0.9, reasoning="r"),),
|
||||||
|
)
|
||||||
|
assert not result.is_ambiguous
|
||||||
|
assert result.clarification_question is None
|
||||||
|
|
||||||
|
def test_classification_result_ambiguous(self) -> None:
|
||||||
|
result = ClassificationResult(
|
||||||
|
intents=(),
|
||||||
|
is_ambiguous=True,
|
||||||
|
clarification_question="What do you mean?",
|
||||||
|
)
|
||||||
|
assert result.is_ambiguous
|
||||||
|
|
||||||
|
def test_multi_intent(self) -> None:
|
||||||
|
result = ClassificationResult(
|
||||||
|
intents=(
|
||||||
|
IntentTarget(agent_name="order_actions", confidence=0.85, reasoning="cancel"),
|
||||||
|
IntentTarget(agent_name="discount", confidence=0.8, reasoning="discount"),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
assert len(result.intents) == 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestBuildAgentList:
|
||||||
|
def test_formats_agents(self) -> None:
|
||||||
|
agents = (
|
||||||
|
_make_agent("order_lookup", "Looks up orders", "read"),
|
||||||
|
_make_agent("order_actions", "Modifies orders", "write"),
|
||||||
|
)
|
||||||
|
text = _build_agent_list(agents)
|
||||||
|
assert "order_lookup" in text
|
||||||
|
assert "order_actions" in text
|
||||||
|
assert "read" in text
|
||||||
|
assert "write" in text
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestLLMIntentClassifier:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_single_intent_classification(self) -> None:
|
||||||
|
expected = ClassificationResult(
|
||||||
|
intents=(IntentTarget(agent_name="order_lookup", confidence=0.95, reasoning="query"),),
|
||||||
|
)
|
||||||
|
mock_structured = MagicMock()
|
||||||
|
mock_structured.ainvoke = AsyncMock(return_value=expected)
|
||||||
|
mock_llm = MagicMock()
|
||||||
|
mock_llm.with_structured_output = MagicMock(return_value=mock_structured)
|
||||||
|
|
||||||
|
classifier = LLMIntentClassifier(mock_llm)
|
||||||
|
agents = (_make_agent("order_lookup"), _make_agent("fallback"))
|
||||||
|
|
||||||
|
result = await classifier.classify("What is order 1042 status?", agents)
|
||||||
|
assert len(result.intents) == 1
|
||||||
|
assert result.intents[0].agent_name == "order_lookup"
|
||||||
|
assert not result.is_ambiguous
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_multi_intent_classification(self) -> None:
|
||||||
|
expected = ClassificationResult(
|
||||||
|
intents=(
|
||||||
|
IntentTarget(agent_name="order_actions", confidence=0.9, reasoning="cancel"),
|
||||||
|
IntentTarget(agent_name="discount", confidence=0.85, reasoning="discount"),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
mock_structured = MagicMock()
|
||||||
|
mock_structured.ainvoke = AsyncMock(return_value=expected)
|
||||||
|
mock_llm = MagicMock()
|
||||||
|
mock_llm.with_structured_output = MagicMock(return_value=mock_structured)
|
||||||
|
|
||||||
|
classifier = LLMIntentClassifier(mock_llm)
|
||||||
|
agents = (_make_agent("order_actions"), _make_agent("discount"), _make_agent("fallback"))
|
||||||
|
|
||||||
|
result = await classifier.classify("Cancel order 1042 and give me 10% off", agents)
|
||||||
|
assert len(result.intents) == 2
|
||||||
|
assert not result.is_ambiguous
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_ambiguous_classification(self) -> None:
|
||||||
|
expected = ClassificationResult(
|
||||||
|
intents=(IntentTarget(agent_name="fallback", confidence=0.3, reasoning="unclear"),),
|
||||||
|
is_ambiguous=False,
|
||||||
|
)
|
||||||
|
mock_structured = MagicMock()
|
||||||
|
mock_structured.ainvoke = AsyncMock(return_value=expected)
|
||||||
|
mock_llm = MagicMock()
|
||||||
|
mock_llm.with_structured_output = MagicMock(return_value=mock_structured)
|
||||||
|
|
||||||
|
classifier = LLMIntentClassifier(mock_llm)
|
||||||
|
agents = (_make_agent("order_lookup"), _make_agent("fallback"))
|
||||||
|
|
||||||
|
result = await classifier.classify("hmm", agents)
|
||||||
|
# Low confidence triggers ambiguity
|
||||||
|
assert result.is_ambiguous
|
||||||
|
assert result.clarification_question is not None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_llm_error_returns_ambiguous(self) -> None:
|
||||||
|
mock_structured = MagicMock()
|
||||||
|
mock_structured.ainvoke = AsyncMock(side_effect=RuntimeError("LLM error"))
|
||||||
|
mock_llm = MagicMock()
|
||||||
|
mock_llm.with_structured_output = MagicMock(return_value=mock_structured)
|
||||||
|
|
||||||
|
classifier = LLMIntentClassifier(mock_llm)
|
||||||
|
agents = (_make_agent("fallback"),)
|
||||||
|
|
||||||
|
result = await classifier.classify("test", agents)
|
||||||
|
assert result.is_ambiguous
|
||||||
|
assert result.clarification_question is not None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_non_result_type_returns_ambiguous(self) -> None:
|
||||||
|
mock_structured = MagicMock()
|
||||||
|
mock_structured.ainvoke = AsyncMock(return_value="not a ClassificationResult")
|
||||||
|
mock_llm = MagicMock()
|
||||||
|
mock_llm.with_structured_output = MagicMock(return_value=mock_structured)
|
||||||
|
|
||||||
|
classifier = LLMIntentClassifier(mock_llm)
|
||||||
|
agents = (_make_agent("fallback"),)
|
||||||
|
|
||||||
|
result = await classifier.classify("test", agents)
|
||||||
|
assert result.is_ambiguous
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_high_confidence_not_ambiguous(self) -> None:
|
||||||
|
expected = ClassificationResult(
|
||||||
|
intents=(
|
||||||
|
IntentTarget(
|
||||||
|
agent_name="order_lookup",
|
||||||
|
confidence=AMBIGUITY_THRESHOLD + 0.1,
|
||||||
|
reasoning="clear",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
mock_structured = MagicMock()
|
||||||
|
mock_structured.ainvoke = AsyncMock(return_value=expected)
|
||||||
|
mock_llm = MagicMock()
|
||||||
|
mock_llm.with_structured_output = MagicMock(return_value=mock_structured)
|
||||||
|
|
||||||
|
classifier = LLMIntentClassifier(mock_llm)
|
||||||
|
agents = (_make_agent("order_lookup"),)
|
||||||
|
|
||||||
|
result = await classifier.classify("order status 1042", agents)
|
||||||
|
assert not result.is_ambiguous
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user